In [1]:
import numpy as np
import nibabel as nib
import os
import gc
from natsort import natsorted
import matplotlib.pyplot as plt
import tensorflow as tf
from skimage.util import montage
from datasets.nifti_dataset import resample_nifti
from datasets import base_dataset
from datasets.base_dataset import _roll2center_crop
from tensorflow.keras.optimizers import Adam
from models import deep_strain_model
from scipy.ndimage import center_of_mass
import warnings
warnings.filterwarnings(action='once')

## LOADING RAW DATA

In [2]:
DATA_FOLDER = 'ACDC'
patients = [f"patient{i:03d}" for i in range(1, 151)]
patientsGT = [f"patient{i:03d}" for i in range(1, 101)]

volumesED = [ nib.load(os.path.join(os.path.join(DATA_FOLDER, patient),[ f for f in natsorted(os.listdir(os.path.join(DATA_FOLDER, patient))) 
            if f.endswith("nii.gz") and not f.endswith("gt.nii.gz") ][1])) for patient in patients ]

segsED = [ nib.load(os.path.join(os.path.join(DATA_FOLDER, patient),[ f for f in natsorted(os.listdir(os.path.join(DATA_FOLDER, patient))) 
            if f.endswith("gt.nii.gz") ][0])) for patient in patientsGT ]

volumesES = [ nib.load(os.path.join(os.path.join(DATA_FOLDER, patient),[ f for f in natsorted(os.listdir(os.path.join(DATA_FOLDER, patient))) 
            if f.endswith("nii.gz") and not f.endswith("gt.nii.gz") ][2])) for patient in patients ]

segsES = [ nib.load(os.path.join(os.path.join(DATA_FOLDER, patient),[ f for f in natsorted(os.listdir(os.path.join(DATA_FOLDER, patient))) 
            if f.endswith("gt.nii.gz") ][1])) for patient in patientsGT ]

## RESAMPLING IMAGES

In [3]:
VED_nifti_resampled = [resample_nifti(
    V_nifti, order=1, in_plane_resolution_mm=1.25, number_of_slices=16
) for V_nifti in volumesED]

VES_nifti_resampled = [resample_nifti(
    V_nifti, order=1, in_plane_resolution_mm=1.25, number_of_slices=16
) for V_nifti in volumesES]


MED_nifti_resampled = [resample_nifti(
    M_nifti, order=1, in_plane_resolution_mm=1.25, number_of_slices=16
) for M_nifti in segsED]

MES_nifti_resampled = [resample_nifti(
    M_nifti, order=1, in_plane_resolution_mm=1.25, number_of_slices=16
) for M_nifti in segsES]



## SEGMENTATIONS FOR PATIENTS 100-150

In [4]:
########################## Mask and normalization ######################################


def normalize(x, axis=(0, 1, 2)):
    # normalize per volume (x,y,z) frame
    mu = x.mean(axis=axis, keepdims=True)
    sd = x.std(axis=axis, keepdims=True)
    return (x - mu) / (sd + 1e-8)


def get_mask(V, netS):
    nx, ny, nz, nt = V.shape

    M = np.zeros((nx, ny, nz, nt))
    v = V.transpose((2, 3, 0, 1)).reshape((-1, nx, ny))  # (nz*nt,nx,ny)
    v = normalize(v)
    for t in range(nt):
        for z in range(nz):
            m = netS(
                v[z * nt + t, nx // 2 - 64 : nx // 2 + 64, ny // 2 - 64 : ny // 2 + 64][
                    None, ..., None
                ]
            )
            M[nx // 2 - 64 : nx // 2 + 64, ny // 2 - 64 : ny // 2 + 64, z, t] += (
                np.argmax(m, -1).transpose((1, 2, 0)).reshape((128, 128))
            )
    return M

######################### Constants and arguments ######################################

DATASET_FOLDER = "ACDC"

class CarSON_options:
    def __init__(self):
        self.isTrain = False
        self.image_shape = (128, 128, 1)
        self.nlabels = 4
        self.pretrained_models_netS = "models/carson_Jan2021.h5"

In [5]:
opt = CarSON_options()
model = deep_strain_model.DeepStrain(Adam, opt=opt)
netS = model.get_netS()

for i, (VED_resampled, VES_resampled) in enumerate(zip(VED_nifti_resampled[100:150], VES_nifti_resampled[100:150])):
    
    print(f"Processing patient {i+1}...")

    VED = VED_resampled.get_fdata()
    VES = VES_resampled.get_fdata()
    V = np.stack((VED, VES), axis=-1)
    V_unnormalized = V.copy()
    V = normalize(V, axis=(0, 1))
    M = get_mask(V, netS)

    center_resampled = center_of_mass(M[:, :, :, 0] == 2)
    V = base_dataset.roll2center(x=V, center=center_resampled)
    V_unnormalized = base_dataset.roll2center(x=V_unnormalized, center=center_resampled)
    M = base_dataset.roll2center(x=M, center=center_resampled)

    M = get_mask(V, netS)

    VED_nifti_resampled[100+i] = nib.Nifti1Image(V_unnormalized[...,0], VED_resampled.affine)
    VES_nifti_resampled[100+i] = nib.Nifti1Image(V_unnormalized[...,1], VES_resampled.affine)
    MED_nifti_resampled.append(nib.Nifti1Image(M[...,0], VED_resampled.affine))
    MES_nifti_resampled.append(nib.Nifti1Image(M[...,1], VES_resampled.affine))


Processing patient 1...
Processing patient 2...
Processing patient 3...
Processing patient 4...
Processing patient 5...
Processing patient 6...
Processing patient 7...
Processing patient 8...
Processing patient 9...
Processing patient 10...
Processing patient 11...
Processing patient 12...
Processing patient 13...
Processing patient 14...
Processing patient 15...
Processing patient 16...
Processing patient 17...
Processing patient 18...
Processing patient 19...
Processing patient 20...
Processing patient 21...
Processing patient 22...
Processing patient 23...
Processing patient 24...
Processing patient 25...
Processing patient 26...
Processing patient 27...
Processing patient 28...
Processing patient 29...
Processing patient 30...
Processing patient 31...
Processing patient 32...
Processing patient 33...
Processing patient 34...
Processing patient 35...
Processing patient 36...
Processing patient 37...
Processing patient 38...
Processing patient 39...
Processing patient 40...
Processin

## GET DATA, CROP AND CENTER

In [6]:
V0_list = []
Vt_list = []
M0_list = []
Mt_list = []
resx_list = []
resy_list = []
resz_list = []
for i in range(150):
    center = center_of_mass(MED_nifti_resampled[i].get_fdata() == 2)
    V0 = _roll2center_crop(VED_nifti_resampled[i].get_fdata().astype('float32'), center)[...,None]
    Vt = _roll2center_crop(VES_nifti_resampled[i].get_fdata().astype('float32'), center)[...,None]
    M0 = _roll2center_crop(MED_nifti_resampled[i].get_fdata().astype('float32'), center)[...,None]
    Mt = _roll2center_crop(MES_nifti_resampled[i].get_fdata().astype('float32'), center)[...,None]
    res = np.array(VED_nifti_resampled[i].header.get_zooms(), dtype=np.float32)
    resx = np.ones_like(M0)*res[0]
    resy = np.ones_like(M0)*res[1]
    resz = np.ones_like(M0)*res[2]
    V0_list.append(V0)
    V0_list.append(V0)
    Vt_list.append(V0)
    Vt_list.append(Vt)
    M0_list.append(M0)
    M0_list.append(M0)
    Mt_list.append(M0)
    Mt_list.append(Mt)
    resx_list.append(resx)
    resx_list.append(resx)
    resy_list.append(resy)
    resy_list.append(resy)
    resz_list.append(resz)
    resz_list.append(resz)

## SERIALIZE DATA AND SAVE TO DISK

In [7]:
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def serialize_example(feature0, feature1, feature2, feature3, feature4, feature5, feature6):
    feature = {
        'V0': _bytes_feature(feature0),
        'Vt': _bytes_feature(feature1),
        'M0': _bytes_feature(feature2),
        'Mt': _bytes_feature(feature3),
        'resx': _bytes_feature(feature4),
        'resy': _bytes_feature(feature5),
        'resz': _bytes_feature(feature6)
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

def write_tfrecord(data, filename):
    with tf.io.TFRecordWriter(filename) as writer:
        for i in range(len(data)):
            example = serialize_example(data[i][0].tobytes(), data[i][1].tobytes(), data[i][2].tobytes(), data[i][3].tobytes(), data[i][4].tobytes(), data[i][5].tobytes(), data[i][6].tobytes())
            writer.write(example)

In [8]:
data = list(zip(V0_list, Vt_list, M0_list, Mt_list, resx_list, resy_list, resz_list))
write_tfrecord(data, 'data/training/trainingEDES_con_res.tfrecord')

: 

In [None]:
# montage imshow plot of V0
