In [2]:
import matplotlib.pyplot as plt
import os
import zipfile
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
import nibabel as nib
from scipy import ndimage
import glob
import sys,os
from natsort import natsorted
import tensorflow_probability as tfp
import random
from aifnet_utils.preprocess import read_nifti_file, normalize, normalize_aif, process_scan
from aifnet_utils.losses import MaxCorrelation
from aifnet_utils.isles_loaders import read_isles_annotations, read_isles_volumes

%matplotlib inline

In [3]:
keras.backend.set_image_data_format('channels_last')

In [4]:
root_dir     = '/media/sebastian/data/ASAP/ISLES2018_Training'
#At insel: /media/sebastian/data/ASAP/ISLES2018_Training
#Local: '/Users/sebastianotalora/work/postdoc/data/ISLES/'
aif_annotations_path = '/home/sebastian/experiments/aifnet_replication/annotated_aif_vof_validation.csv'
min_num_volumes_ctp = 43
ROOT_EXP = '/home/sebastian/experiments/aifnet_replication'#'/Users/sebastianotalora/work/postdoc/ctp/aifnet_replication'


In [5]:
aif_annotations, vof_annotations = read_isles_annotations(aif_annotations_path, root_dir, 
                                         min_num_volumes_ctp, return_aif_only = False)

In [138]:
dataset_dir = os.path.join(root_dir, "TRAINING")

filenames_4D = natsorted(glob.glob(dataset_dir + "/case_*/*4D*/*nii*"))

cases_paths = {}
cases_paths = {path.split('.')[-2]: path for path in filenames_4D}
filenames_4D

['/media/sebastian/data/ASAP/ISLES2018_Training/TRAINING/case_1/SMIR.Brain.XX.O.CT_4DPWI.345561/SMIR.Brain.XX.O.CT_4DPWI.345561.nii',
 '/media/sebastian/data/ASAP/ISLES2018_Training/TRAINING/case_2/SMIR.Brain.XX.O.CT_4DPWI.345568/SMIR.Brain.XX.O.CT_4DPWI.345568.nii',
 '/media/sebastian/data/ASAP/ISLES2018_Training/TRAINING/case_3/SMIR.Brain.XX.O.CT_4DPWI.345575/SMIR.Brain.XX.O.CT_4DPWI.345575.nii',
 '/media/sebastian/data/ASAP/ISLES2018_Training/TRAINING/case_4/SMIR.Brain.XX.O.CT_4DPWI.345582/SMIR.Brain.XX.O.CT_4DPWI.345582.nii',
 '/media/sebastian/data/ASAP/ISLES2018_Training/TRAINING/case_5/SMIR.Brain.XX.O.CT_4DPWI.345589/SMIR.Brain.XX.O.CT_4DPWI.345589.nii',
 '/media/sebastian/data/ASAP/ISLES2018_Training/TRAINING/case_6/SMIR.Brain.XX.O.CT_4DPWI.345596/SMIR.Brain.XX.O.CT_4DPWI.345596.nii',
 '/media/sebastian/data/ASAP/ISLES2018_Training/TRAINING/case_7/SMIR.Brain.XX.O.CT_4DPWI.345603/SMIR.Brain.XX.O.CT_4DPWI.345603.nii',
 '/media/sebastian/data/ASAP/ISLES2018_Training/TRAINING/case_

In [6]:
def read_isles_volumes(root_dir, aif_annotations_path, min_num_volumes_ctp, take_two_slices_only=False):
    dataset_dir = os.path.join(root_dir, "TRAINING")
    filenames_4D = natsorted(glob.glob(dataset_dir + "/case_*/*4D*/*nii*"))
    filenames_precontrast = 
    cases_paths = {}
    cases_paths = {path.split('.')[-2]: path for path in filenames_4D}
    #This is a little bit awful, but we need to get the coordinates from the annotations file to 
    #know which two slices get exactly
    cases_annotations = {}
    aif_annotations_file = open(aif_annotations_path,'r')
    aif_annotations_file.readline()
    for line in aif_annotations_file: #Here we substract one to account for 0-indexing in python
        cases_annotations[line.split(',')[0]] = [np.array([int(line.split(',')[1]),int(line.split(',')[2]),int(line.split(',')[3])])-1,
                                             np.array([int(line.split(',')[4]),int(line.split(',')[5]),int(line.split(',')[6])])-1]
    
    datalist = []
    
    for cur_case in cases_annotations.keys():
        fname = cases_paths[cur_case]
        cur_nib = nib.load(fname)    
        ctp_vals = cur_nib.get_fdata()
        AIFx,AIFy,AIFz = cases_annotations[cur_case][0][0],cases_annotations[cur_case][0][1], cases_annotations[cur_case][0][2]
        VOFx,VOFy,VOFz = cases_annotations[cur_case][1][0],cases_annotations[cur_case][1][1], cases_annotations[cur_case][1][2]
        #Four cases either is it possible to have the slice up, or the one down, or any of them
        if take_two_slices_only: 
            if ctp_vals.shape[2] != 2 and AIFz+1 <=  ctp_vals.shape[2] and AIFz>0:            
                    ctp_vals = ctp_vals[:,:,AIFz-1:AIFz+1,:]
                    AIFz = 1
                    print("After " + str(ctp_vals.shape))
            if ctp_vals.shape[2] != 2 and AIFz+1 <  ctp_vals.shape[2] and AIFz==0:
                    AIFx,AIFy,AIFz = cases_annotations[cur_case][0][0],cases_annotations[cur_case][0][1], cases_annotations[cur_case][0][2]
                    #print("Adding it in the other direction " + str(ctp_vals.shape))
                    ctp_vals = ctp_vals[:,:,AIFz:AIFz+2,:]
                    AIFz = 0
                    print("After " + str(ctp_vals.shape))
            else:
                ("Not processed")
        datalist.append({"image": fname, "ctpvals": ctp_vals[:,:,:,0:min_num_volumes_ctp]})
    return datalist

In [7]:
ctp_volumes = read_isles_volumes(root_dir, aif_annotations_path, min_num_volumes_ctp, take_two_slices_only=False)

In [8]:
print(len(ctp_volumes), len(aif_annotations))

1 1


In [9]:
#Dataset generator
class ISLES18DataGen_aifvof(tf.keras.utils.Sequence):
  
    def __init__(self, 
                 ctp_volumes,
                 annotations_aif,
                 annotations_vof,
                 minimum_number_volumes_ctp,
                 batch_size=1,
                 input_size=(256, 256, None,43),
                 shuffle=True):
        self.ctp_volumes = ctp_volumes 
        self.labels_aif = annotations_aif,
        self.labels_vof = annotations_vof,
        self.minimum_number_volumes_ctp = minimum_number_volumes_ctp
        self.annotations = [annotations_aif],[annotations_vof]
        self.batch_size = batch_size
        self.input_size = input_size
        self.shuffle = shuffle        
        self.n = len(self.ctp_volumes)
        self.indices = np.arange(len(self.ctp_volumes))

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __get_input(self, img_idx):
        #Get the volume
        ctp_vals = self.ctp_volumes[img_idx]['ctpvals']
        volume = normalize(ctp_vals)
        #Get the labels
        case_id = ctp_volumes[img_idx]['image'].split('.')[-2]

        label_aif = normalize_aif(self.labels_aif[0][case_id])
        label_vof = self.labels_vof[0][case_id]
        #labels = np.array([label_aif,label_vof])
        return volume,[label_aif,label_vof]
        #return volume,label_aif
    
    def pct_augment():
        pass
    
    def __getitem__(self, idx): #This function returns the batch 
        inds = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        #print(inds)
        #batch_x = [self.ctp_volumes[index] for index in inds]
        #batch_y = self.annotations[inds]
        batch_x, batch_y = [], []
        for index in inds:            
            x, y = self.__get_input(index)
            batch_x.append(x)
            batch_y.append(y)
        return np.array(batch_x), np.array(batch_y).squeeze()

    def __len__(self):
        return self.n // self.batch_size

In [10]:
pp = ctp_volumes[0]['ctpvals']
np.zeros(pp.shape)
pp.shape
np.sum(np.sqrt(normalize(pp).flatten()))

3256993.0

In [11]:
aif_annotations['345582'].shape[-1]

43

In [118]:
def late_bolus(volume_sequence, labels, delay_t=None):
    delayed_volume, delayed_intensity = np.zeros(volume_sequence.shape), np.zeros(labels.shape)
    nb_timepoints = volume_sequence.shape[-1]

    first_volume = volume_sequence[:,:,:,0]    
    if delay_t == None:
        delay_t = random.randint(1,int(nb_timepoints/3))
    if delay_t == 0:
        return volume_sequence, labels
    
    for i in range(0,delay_t+1):
        delayed_volume[:,:,:,i] = first_volume
    #print(i)
    delayed_volume[:,:,:,i:] = volume_sequence[:,:,:,1:nb_timepoints-i+1]
    
    delayed_intensity[0:delay_t] = labels[0] #Repeating the first time point
    delayed_intensity[delay_t:] = labels[1:nb_timepoints-delay_t+1]
    
    return delayed_volume, delayed_intensity

def early_bolus(volume_sequence, labels, delay_t=None):
    early_volume, early_intensity = np.zeros(volume_sequence.shape), np.zeros(labels.shape)
    nb_timepoints = volume_sequence.shape[-1]

    last_volume = volume_sequence[:,:,:,-1]    
    if delay_t == None:
        delay_t = random.randint(1,int(nb_timepoints/3))
    if delay_t == 0:
        return volume_sequence, labels
    
    early_volume[:,:,:,0:nb_timepoints-delay_t] = volume_sequence[:,:,:,delay_t:nb_timepoints]
    for i in range(nb_timepoints-delay_t,nb_timepoints):   
        early_volume[:,:,:,i] = volume_sequence[:,:,:,-1]
    
    early_intensity[0:nb_timepoints-delay_t] = labels[delay_t:nb_timepoints] #Shifting the first time points
    early_intensity[nb_timepoints-delay_t:] = labels[-1]
    
    return early_volume, early_intensity


In [119]:
pp = ctp_volumes[0]['ctpvals']
np.sum(np.sqrt(normalize(pp).flatten()))


3256993.0

In [120]:
pp.shape

(256, 256, 8, 43)

In [121]:
43-10

33

In [122]:
sum(pp[:,:,:,42].flatten())

18082584.740959097

In [127]:
pp_late, pl_late   = late_bolus(pp,aif_annotations['345582'],20)
pp_early, pl_early = early_bolus(pp,aif_annotations['345582'],20)

In [128]:
for i in range(43):
    print(str(i) + " <=> " + str(int(np.sum(pp[:,:,:,i].flatten())))  + ", " + str(int(np.sum(pp_early[:,:,:,i].flatten())))
         + "   ==  " +str((aif_annotations['345582'][i])) +", " + str((pl_early[i])))

0 <=> 17446955, 18274742   ==  46.00687026977539, 64.83368682861328
1 <=> 17491673, 18187071   ==  45.46455383300781, 68.8190689086914
2 <=> 17497241, 18154510   ==  44.42026138305664, 63.13859558105469
3 <=> 17516392, 18138570   ==  51.08286666870117, 68.21202850341797
4 <=> 17535508, 18099539   ==  72.60696411132812, 59.45867156982422
5 <=> 17697656, 18098316   ==  100.85197448730469, 92.30133819580078
6 <=> 17921554, 18347820   ==  122.21617889404297, 79.43942260742188
7 <=> 18159512, 18081879   ==  137.44642639160156, 79.08687591552734
8 <=> 18370063, 18094479   ==  119.1580810546875, 53.11215591430664
9 <=> 18637204, 18108415   ==  172.7520751953125, 69.41795349121094
10 <=> 18880359, 18087250   ==  163.13604736328125, 65.49166107177734
11 <=> 19055675, 18068658   ==  147.68309020996094, 73.1170883178711
12 <=> 19117529, 18054814   ==  137.92901611328125, 77.52625274658203
13 <=> 19151811, 18051509   ==  117.84191131591797, 71.06597137451172
14 <=> 19108125, 18077758   ==  103.170

In [126]:
for i in range(43):
    print(str(i) + " <=> " + str(int(np.sum(pp[:,:,:,i].flatten())))  + ", " + str(int(np.sum(pp_late[:,:,:,i].flatten())))
         + "   ==  " +str((aif_annotations['345582'][i])) +", " + str((pl_late[i])))

0 <=> 17446955, 17446955   ==  46.00687026977539, 46.00687026977539
1 <=> 17491673, 17446955   ==  45.46455383300781, 46.00687026977539
2 <=> 17497241, 17446955   ==  44.42026138305664, 46.00687026977539
3 <=> 17516392, 17446955   ==  51.08286666870117, 46.00687026977539
4 <=> 17535508, 17446955   ==  72.60696411132812, 46.00687026977539
5 <=> 17697656, 17446955   ==  100.85197448730469, 46.00687026977539
6 <=> 17921554, 17446955   ==  122.21617889404297, 46.00687026977539
7 <=> 18159512, 17446955   ==  137.44642639160156, 46.00687026977539
8 <=> 18370063, 17446955   ==  119.1580810546875, 46.00687026977539
9 <=> 18637204, 17446955   ==  172.7520751953125, 46.00687026977539
10 <=> 18880359, 17491673   ==  163.13604736328125, 45.46455383300781
11 <=> 19055675, 17497241   ==  147.68309020996094, 44.42026138305664
12 <=> 19117529, 17516392   ==  137.92901611328125, 51.08286666870117
13 <=> 19151811, 17535508   ==  117.84191131591797, 72.60696411132812
14 <=> 19108125, 17697656   ==  103.1

In [85]:
43-10

33

In [51]:
pp_late
np.sum(np.sqrt(normalize(pp_late[0]).flatten()))

3256993.0

In [130]:
def scale_pct_values(pre_contrast,pct_sequence):
    scaled_sequence = np.zeros(pct_sequence.shape)
    #Substracting the pre-contrast to the sequence
    for i in range(nb)
    
    return scaled_sequence

In [135]:
scale_pct_values(pre_contrast=None,pct_sequence=pp)

array([[[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
    

In [132]:

pp.shape

(256, 256, 8, 43)