In [2]:
import numpy as np
import nibabel as nib
import os
import gc
from natsort import natsorted
import matplotlib.pyplot as plt
import tensorflow as tf
from skimage.util import montage
from datasets.nifti_dataset import resample_nifti
from datasets import base_dataset
from datasets.base_dataset import _roll2center_crop
from tensorflow.keras.optimizers import Adam
from models import deep_strain_model
from scipy.ndimage import center_of_mass
import warnings
warnings.filterwarnings(action='once')

## LOADING RAW DATA

In [3]:
DATA_FOLDER = 'ACDC'
patients = [f"patient{i:03d}" for i in range(1, 151)]
patientsGT = [f"patient{i:03d}" for i in range(1, 101)]


def _patient_nifti_paths(patient, gt):
    """Return full paths of a patient's NIfTI files, in natural sort order.

    Args:
        patient: folder name under ``DATA_FOLDER`` (e.g. ``"patient001"``).
        gt: if True, return only files ending in ``gt.nii.gz`` (segmentations);
            otherwise return the remaining ``nii.gz`` files (images).
    """
    folder = os.path.join(DATA_FOLDER, patient)
    return [
        os.path.join(folder, f)
        for f in natsorted(os.listdir(folder))
        if f.endswith("nii.gz") and f.endswith("gt.nii.gz") == gt
    ]


# List each patient folder once instead of once per loaded file.
image_paths = {patient: _patient_nifti_paths(patient, gt=False) for patient in patients}
gt_paths = {patient: _patient_nifti_paths(patient, gt=True) for patient in patientsGT}

# Files are picked by position in the naturally sorted listing. With ACDC file naming this
# is presumably images = [*_4d, *_frameED, *_frameES] and gts = [*_frameED_gt, *_frameES_gt]
# (ED frame number < ES frame number) -- NOTE(review): verify for your copy of the dataset.
volumesED = [nib.load(image_paths[patient][1]) for patient in patients]
segsED = [nib.load(gt_paths[patient][0]) for patient in patientsGT]
volumesES = [nib.load(image_paths[patient][2]) for patient in patients]
segsES = [nib.load(gt_paths[patient][1]) for patient in patientsGT]

## RESAMPLING IMAGES

In [4]:
# Common target grid for every volume: 1.25 mm in-plane spacing, 16 slices.
RESAMPLE_KWARGS = dict(in_plane_resolution_mm=1.25, number_of_slices=16)

# Intensity images: linear interpolation (order=1).
VED_nifti_resampled = [
    resample_nifti(V_nifti, order=1, **RESAMPLE_KWARGS) for V_nifti in volumesED
]
VES_nifti_resampled = [
    resample_nifti(V_nifti, order=1, **RESAMPLE_KWARGS) for V_nifti in volumesES
]

# Label maps must be resampled with nearest-neighbour interpolation (order=0).
# Linear interpolation blends neighbouring class ids into fractional / wrong labels,
# which corrupts the `== 2` label tests used later for centering.
MED_nifti_resampled = [
    resample_nifti(M_nifti, order=0, **RESAMPLE_KWARGS) for M_nifti in segsED
]
MES_nifti_resampled = [
    resample_nifti(M_nifti, order=0, **RESAMPLE_KWARGS) for M_nifti in segsES
]



## SEGMENTATIONS FOR PATIENTS 101-150 (NO GROUND TRUTH, PREDICTED WITH CarSON)

In [5]:
########################## Mask and normalization ######################################


def normalize(x, axis=(0, 1, 2)):
    """Z-score ``x`` over the given ``axis``.

    Mean and standard deviation are reduced over ``axis`` with ``keepdims=True`` so
    they broadcast back against ``x``. The ``1e-8`` term guards against division by
    zero for constant inputs.
    """
    centered = x - np.mean(x, axis=axis, keepdims=True)
    scale = np.std(x, axis=axis, keepdims=True) + 1e-8
    return centered / scale


def get_mask(V, netS):
    """Segment every (slice, frame) of a 4D volume with ``netS``.

    Args:
        V: array of shape (nx, ny, nz, nt). NOTE(review): assumes nx, ny >= 128,
            otherwise the central crop below cannot be reshaped to (128, 128).
        netS: callable applied to one (1, 128, 128, 1) patch at a time; presumably it
            returns per-class scores whose last axis is the class axis.

    Returns:
        float64 array with V's shape holding the argmax class of each voxel inside
        the central 128x128 crop; voxels outside the crop stay 0.
    """
    nx, ny, nz, nt = V.shape
    half = 64
    rows = slice(nx // 2 - half, nx // 2 + half)
    cols = slice(ny // 2 - half, ny // 2 + half)

    # One 2D image per (slice, frame) pair, flattened so index = z * nt + t,
    # normalized jointly over all slices and frames.
    frames = normalize(V.transpose((2, 3, 0, 1)).reshape((-1, nx, ny)))

    M = np.zeros((nx, ny, nz, nt))
    for t in range(nt):
        for z in range(nz):
            patch = frames[z * nt + t, rows, cols][np.newaxis, :, :, np.newaxis]
            scores = netS(patch)
            labels = np.argmax(scores, -1).transpose((1, 2, 0)).reshape((2 * half, 2 * half))
            M[rows, cols, z, t] += labels
    return M

######################### Constants and arguments ######################################

DATASET_FOLDER = "ACDC"

class CarSON_options:
    """Option bag passed as ``opt`` to ``deep_strain_model.DeepStrain`` for the CarSON net.

    Attributes are set unconditionally in ``__init__``; edit them on the instance
    to change the configuration.
    """

    def __init__(self):
        # Inference mode (no training).
        self.isTrain = False
        # Single-channel 128x128 patches, matching the crop produced by get_mask.
        self.image_shape = (128, 128, 1)
        # Number of segmentation classes.
        self.nlabels = 4
        # Path of the pretrained segmentation weights file.
        self.pretrained_models_netS = "models/carson_Jan2021.h5"

In [7]:
opt = CarSON_options()
model = deep_strain_model.DeepStrain(Adam, opt=opt)
netS = model.get_netS()

# Patients without a ground-truth mask are those past the end of the mask lists
# (initially list indices 100-149, i.e. patient101-patient150). Starting at the
# current mask count makes the cell safe to re-run: already-processed patients are
# skipped instead of appending duplicate, misaligned masks.
assert len(MED_nifti_resampled) == len(MES_nifti_resampled)
for i in range(len(MED_nifti_resampled), len(VED_nifti_resampled)):
    VED_resampled = VED_nifti_resampled[i]
    VES_resampled = VES_nifti_resampled[i]

    V = np.stack((VED_resampled.get_fdata(), VES_resampled.get_fdata()), axis=-1)
    # NOTE(review): V is normalized per slice/frame and stored normalized below, whereas
    # patients with ground truth keep raw intensities. Confirm the training pipeline
    # re-normalizes, otherwise the serialized dataset mixes intensity scales.
    V = normalize(V, axis=(0, 1))

    # First pass on the uncentered volume: only used to locate label 2 in the ED frame.
    M = get_mask(V, netS)
    label2_ed = M[:, :, :, 0] == 2
    if label2_ed.any():
        V = base_dataset.roll2center(x=V, center=center_of_mass(label2_ed))
    else:
        # center_of_mass of an empty mask is NaN and would break roll2center.
        warnings.warn(f"patient index {i}: no voxels predicted as label 2; volume left uncentered")

    # Second pass on the centered volume, whose central crop now covers the target.
    M = get_mask(V, netS)

    # The affine is reused unchanged although the voxels were rolled; only voxel data
    # and zooms of these images are read later in this notebook.
    VED_nifti_resampled[i] = nib.Nifti1Image(V[..., 0], VED_resampled.affine)
    VES_nifti_resampled[i] = nib.Nifti1Image(V[..., 1], VES_resampled.affine)
    MED_nifti_resampled.append(nib.Nifti1Image(M[..., 0], VED_resampled.affine))
    MES_nifti_resampled.append(nib.Nifti1Image(M[..., 1], VES_resampled.affine))


## GET DATA, CROP AND CENTER

In [14]:
def _crop_to_center(nifti, center):
    """Roll a NIfTI volume's data so ``center`` is centered, crop it, and add a channel axis.

    Rolling/cropping is delegated to ``_roll2center_crop`` on float32 data; a trailing
    singleton axis is appended to whatever it returns.
    """
    return _roll2center_crop(nifti.get_fdata().astype('float32'), center)[..., None]


V0_list = []
Vt_list = []
M0_list = []
Mt_list = []
res_list = []

n_patients = len(VED_nifti_resampled)
# Every patient needs ED/ES images and ED/ES masks (ground truth or predicted); the
# predicted masks only exist once the segmentation cell has run.
assert len(VES_nifti_resampled) == len(MED_nifti_resampled) == len(MES_nifti_resampled) == n_patients, (
    "image and mask lists differ in length - run the segmentation cell first"
)

for i in range(n_patients):
    # Center all four volumes of this patient on label 2 of the ED mask.
    center = center_of_mass(MED_nifti_resampled[i].get_fdata() == 2)
    V0 = _crop_to_center(VED_nifti_resampled[i], center)
    Vt = _crop_to_center(VES_nifti_resampled[i], center)
    M0 = _crop_to_center(MED_nifti_resampled[i], center)
    Mt = _crop_to_center(MES_nifti_resampled[i], center)

    # Encode the voxel spacing as a volume shaped like V0 so it can be serialized and
    # batched with the images: res[:3, 0, 0, 0] holds the 3 zooms, all other entries are 1.
    # Pad widths are derived from V0 instead of assuming a 128x128x16 crop.
    nx, ny, nz = V0.shape[:3]
    zooms = np.array(VED_nifti_resampled[i].header.get_zooms(), dtype=np.float32).reshape(3, 1, 1)
    res = np.pad(zooms, ((0, nx - 3), (0, ny - 1), (0, nz - 1)), constant_values=(1,))[..., None]

    # Two training pairs per patient, appended in this order:
    #   1) (V0 -> V0) with masks (M0 -> M0): identical frames (presumably a zero-motion pair);
    #   2) (V0 -> Vt) with masks (M0 -> Mt): the ED -> ES pair.
    for moving_image, moving_mask in ((V0, M0), (Vt, Mt)):
        V0_list.append(V0)
        Vt_list.append(moving_image)
        M0_list.append(M0)
        Mt_list.append(moving_mask)
        res_list.append(res)

## SERIALIZE DATA AND SAVE TO DISK

In [16]:
def _bytes_feature(value):
    """Wrap a single ``bytes`` value in a ``tf.train.Feature`` (bytes_list)."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)

def _float_feature(value):
    """Wrap a single float in a ``tf.train.Feature`` (float_list); not used by serialize_example."""
    payload = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=payload)

def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` (int64_list); not used by serialize_example."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)

def serialize_example(feature0, feature1, feature2, feature3, feature4):
    """Pack five raw byte strings into a serialized ``tf.train.Example``.

    The arguments are stored, in order, under the keys 'V0', 'Vt', 'M0', 'Mt' and
    'res'. A reader must parse these keys as bytes and decode the payloads itself.

    Returns:
        The serialized Example as a ``bytes`` string.
    """
    keys = ('V0', 'Vt', 'M0', 'Mt', 'res')
    payloads = (feature0, feature1, feature2, feature3, feature4)
    feature = {key: _bytes_feature(payload) for key, payload in zip(keys, payloads)}
    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features).SerializeToString()

def write_tfrecord(data, filename):
    """Write one serialized Example per record of ``data`` to the TFRecord ``filename``.

    Each record must expose at least five items with a ``tobytes()`` method (numpy
    arrays here); items 0-4 become the V0/Vt/M0/Mt/res features. Only raw bytes are
    stored -- dtype and shape are not, so a reader must know them in advance.
    """
    with tf.io.TFRecordWriter(filename) as writer:
        for record in data:
            payloads = [record[k].tobytes() for k in range(5)]
            writer.write(serialize_example(*payloads))

In [17]:
# The five lists are built in lockstep; zip() would silently drop trailing items otherwise.
assert len(V0_list) == len(Vt_list) == len(M0_list) == len(Mt_list) == len(res_list)
data = list(zip(V0_list, Vt_list, M0_list, Mt_list, res_list))

TFRECORD_PATH = 'data/training/trainingEDES.tfrecord'
# TFRecordWriter does not create missing parent directories.
os.makedirs(os.path.dirname(TFRECORD_PATH), exist_ok=True)
write_tfrecord(data, TFRECORD_PATH)

In [82]:

class Augment(tf.keras.layers.Layer):
  """Jointly augment a pair of input volumes and a stack of label volumes.

  Every 3D volume of shape (batch, H, W, nz, 1) is flattened into an nz-channel 2D
  image, and all volumes are concatenated along the channel axis. One call to the
  geometric pipeline therefore applies the same random flip/rotation to every slice
  of every volume.

  Random contrast is an intensity transform. It is applied only to the channels of
  the two inputs and of the first ``n_image_labels`` label volumes. The remaining
  label volumes (e.g. the segmentation masks stacked in the demo cell) receive only
  the geometric transforms. Contrast would rescale their values around the channel
  mean and corrupt the class ids.

  Args:
    seed: seed passed to every random layer.
    n_image_labels: number of leading entries of ``labels`` (axis 1) that are
      intensity images rather than label maps. The default of 2 matches the demo
      cell, where labels[:, 0] and labels[:, 1] hold the ED/ES volumes.

  NOTE(review): RandomRotation interpolates bilinearly by default, so label
  boundaries can still receive fractional values; round or re-threshold masks
  downstream if exact class ids are required.
  NOTE(review): depending on the Keras version, RandomContrast may clip its output to
  a fixed value range (e.g. [0, 255]); confirm this suits the stored intensity range.
  """

  def __init__(self, seed=42, n_image_labels=2):
    super().__init__()
    self.n_image_labels = n_image_labels
    # Geometric transforms: applied to all channels in a single call so every volume
    # (inputs and labels alike) receives an identical flip and rotation.
    self.augment_batch = tf.keras.Sequential([
      tf.keras.layers.RandomFlip("horizontal_and_vertical", seed=seed),
      tf.keras.layers.RandomRotation(1, seed=seed),
    ])
    # Intensity transform: image channels only (see class docstring).
    self.augment_contrast = tf.keras.layers.RandomContrast(0.5, seed=seed)

  def call(self, inputs, labels):
    """Augment ``inputs`` (two (batch, H, W, nz, 1) volumes) and ``labels``.

    ``labels`` has shape (batch, n_labels, H, W, nz, 1). The method returns the
    augmented ``[input0, input1]`` list and the augmented labels with the same
    shapes. ``nz`` and ``n_labels`` must be statically known.
    """
    nz = labels.shape[-2]
    n_labels = labels.shape[1]

    volumes = [inputs[0], inputs[1]] + [labels[:, k] for k in range(n_labels)]
    # (batch, H, W, nz, 1) -> (batch, H, W, nz): one channel per slice, then all
    # volumes side by side along the channel axis.
    bigbatch = tf.concat(
      [tf.concat(tf.unstack(v, axis=-2), axis=-1) for v in volumes], axis=-1
    )

    augmented = self.augment_batch(bigbatch)

    # Contrast only on the image channels: the two inputs + the leading image labels.
    n_image_channels = (2 + self.n_image_labels) * nz
    augmented = tf.concat(
      [self.augment_contrast(augmented[..., :n_image_channels]),
       augmented[..., n_image_channels:]],
      axis=-1,
    )

    def volume(k):
      # Channel block k back to (batch, H, W, nz, 1).
      return tf.expand_dims(augmented[..., k * nz:(k + 1) * nz], axis=-1)

    inputs = [volume(0), volume(1)]
    labels = tf.stack([volume(2 + k) for k in range(n_labels)], axis=1)
    return inputs, labels

In [83]:
def _demo_batch(img, copies=7):
    """Repeat a NIfTI volume's data into a (copies, X, Y, Z, 1) batch."""
    return np.repeat(img.get_fdata()[np.newaxis, ..., np.newaxis], copies, axis=0)


aug = Augment()
inputs = [_demo_batch(volumesED[0]), _demo_batch(volumesES[0])]

# Five label volumes per sample, stacked on axis 1 -> (7, 5, X, Y, Z, 1):
# ED image, ES image, ED mask, ES mask, ED mask.
label_volumes = [
    _demo_batch(volumesED[0]),
    _demo_batch(volumesES[0]),
    _demo_batch(segsED[0]),
    _demo_batch(segsES[0]),
    _demo_batch(segsED[0]),
]
labels = tf.convert_to_tensor(np.stack(label_volumes, axis=1).astype(np.float32))



In [None]:
# TODO: plot a montage of V0 slices (skimage.util.montage + plt.imshow)
