In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
from functools import reduce
import random

from keras.preprocessing.image import ImageDataGenerator
import keras 

from keras import regularizers
import keras.backend as K

import tensorflow as tf
from keras.models import Model
from keras.layers import Lambda, UpSampling2D, \
            Flatten, Dense, Conv2D, MaxPooling2D, ZeroPadding2D

from keras.optimizers import Adam
from keras.applications import ResNet50, densenet

In [None]:
class Image_class:
    """One preprocessed slice together with its region-of-interest (ROI) mask.

    Attributes:
        name: identifier passed as ``pat_n``.
        image: slice pixel data, drawn by ``print_image``.
        roi: ROI mask for the slice, drawn by ``print_roi``.
        mal: True when ``np.sum(roi) > 0``, otherwise False.
        n_serial: serial number passed by the caller.
    """

    def __init__(self, pat_n, image, roi, n_serial):
        self.name = pat_n
        self.image = image
        self.roi = roi
        # Positive label whenever the mask sums to a positive value.
        self.mal = bool(np.sum(roi) > 0)
        self.n_serial = n_serial

    def print_roi(self, cmap = 'plasma', linewidths = 0.5, **args):
        """Draw the ROI mask as contour lines on the current matplotlib axes.

        Extra keyword arguments are forwarded to ``plt.contour``.
        """
        plt.contour(self.roi, cmap = cmap, linewidths = linewidths, **args)

    def print_image(self, cmap = 'gray', **args):
        """Draw the slice with ``plt.imshow`` using the requested colormap.

        Extra keyword arguments are forwarded to ``plt.imshow``.
        """
        # Use the caller's cmap; the default keeps the grayscale rendering.
        plt.imshow(self.image, cmap = cmap, **args)
 

In [None]:
def create_train_test_sets(poss, negs, admitted_pos, num_test = 50,
                           excluded = ('Brats18_CBICA_AWI_1',), seed = None):
    """Split patients into train/test and build class-balanced file lists.

    The patients (keys of ``poss``, minus ``excluded``) are shuffled; the first
    ``num_test`` form the test split and the rest the training split. In each
    split the positive files are filtered to those in ``admitted_pos`` and
    paired with the same number of randomly chosen negative files drawn from
    the same patients (fewer if the patients do not have enough negatives).

    Args:
        poss: dict patient id -> list of positive file names.
        negs: dict patient id -> list of negative file names; must contain
            every non-excluded patient of ``poss``.
        admitted_pos: collection of positive file names allowed in the sets.
        num_test: number of patients held out for the test split.
        excluded: patient ids dropped before splitting; ids that are not
            present are ignored.
        seed: if given, a private ``random.Random(seed)`` drives all shuffles
            so the split is reproducible; otherwise the global ``random``
            module state is used.

    Returns:
        ``(files_train, files_test)``: shuffled lists of file names.

    The input dicts are not modified, so the call can be repeated safely.
    """
    rng = random if seed is None else random.Random(seed)
    excluded = set(excluded)
    admitted = set(admitted_pos)  # O(1) membership instead of a list scan

    patients = [p for p in poss if p not in excluded]
    rng.shuffle(patients)
    test_pts = patients[:num_test]
    train_pts = patients[num_test:]

    def _balanced_files(pts):
        # Admitted positives plus an equal-size random sample of negatives.
        # Flat comprehensions avoid the quadratic reduce(+) and its TypeError
        # when ``pts`` is empty.
        pos = [f for p in pts for f in poss[p] if f in admitted]
        neg = [f for p in pts for f in negs[p]]
        rng.shuffle(neg)
        return pos + neg[:len(pos)]

    files_train = _balanced_files(train_pts)
    files_test = _balanced_files(test_pts)

    rng.shuffle(files_train)
    rng.shuffle(files_test)

    return files_train, files_test
    

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that streams pickled ``Image_class`` slices in batches.

    Each ID in ``list_IDs`` is a file name under ``data_dir``. The file is
    unpickled, its ``image`` is reshaped to ``dim`` and replicated across
    ``n_channels`` channels. When ``mode == 'train'`` a random geometric
    augmentation (rotation, shifts, horizontal flip) is applied. Labels are
    one-hot: column 0 when the loaded object's ``mal`` is true, column 1
    otherwise.

    Args:
        list_IDs: file names to load.
        batch_size: samples per batch; a trailing partial batch is dropped.
        dim: spatial shape the stored image is reshaped to.
        n_channels: number of channels the single-channel slice is repeated to.
        shuffle: reshuffle the sample order at the end of every epoch.
        OUT_SIZE: width of each label row (must be >= 2).
        TOP_K: stored on the instance but not used by this class.
        mode: 'train' enables augmentation; any other value disables it.
        data_dir: path prefix prepended to each ID.

    NOTE: samples are read with ``pickle.load``; only point this at trusted data.
    """

    def __init__(self, list_IDs, batch_size=16, dim=(170, 140), n_channels=3,
                  shuffle=True, OUT_SIZE = 2, TOP_K = 1, mode = 'train',
                  data_dir = 'BRATS_preprocessed_/'):
        self.dim = dim
        self.OUT_SIZE = OUT_SIZE
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.TOP_K = TOP_K
        self.mode = mode
        self.data_dir = data_dir

        self.aug = ImageDataGenerator(rotation_range=45, width_shift_range=0.1,
                                      height_shift_range=0.1, horizontal_flip=True)

        self.on_epoch_end()

    def __len__(self):
        """Number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return the ``index``-th batch as ``(X, y)``."""
        start = index * self.batch_size
        batch_positions = self.indexes[start:start + self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in batch_positions]
        return self.__data_generation(list_IDs_temp)

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``self.shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Load, optionally augment, and label the samples named in ``list_IDs_temp``.

        Returns:
            X: array of shape ``(batch_size, *dim, n_channels)``.
            y: float32 array of shape ``(batch_size, OUT_SIZE)``.
        """
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.zeros((self.batch_size, self.OUT_SIZE), dtype='float32')

        for i, ID in enumerate(list_IDs_temp):
            with open(self.data_dir + ID, 'rb') as filehandler:
                photo = pickle.load(filehandler)

            # Replicate the single-channel slice to the requested channel
            # count (the ImageNet ResNet50 in this notebook expects 3).
            XX = np.repeat(photo.image.reshape(*self.dim, 1), self.n_channels, axis = 2)

            if self.mode == 'train':
                params = self.aug.get_random_transform(XX.shape)
                X[i,] = self.aug.apply_transform(XX, params)
            else:
                X[i,] = XX

            # One-hot label: column 0 for ``mal`` samples, column 1 otherwise.
            y[i, 0 if photo.mal else 1] = 1.0

        return X, y

In [None]:
def custom_acc(y_true, y_pred):
    """Categorical accuracy for one-hot targets.

    A sample counts as correct when the arg-max class of ``y_pred`` equals
    the hot class of ``y_true``. Comparing the per-row maxima instead (as
    ``max(round(y_pred))`` vs ``max(y_true)``) is ~1 for every sample of a
    two-way softmax, whichever class is predicted.
    """
    matches = K.equal(K.argmax(y_true, axis = -1), K.argmax(y_pred, axis = -1))
    return K.mean(K.cast(matches, K.floatx()))

def custom_loss(y_true, y_pred):
    """Element-wise binary cross-entropy, averaged over every entry.

    Predictions are clipped into [1e-7, 1 - 1e-7] so neither logarithm is
    ever evaluated at 0. Equivalent in intent to
    ``keras.losses.binary_crossentropy``.
    """
    eps = 1e-7
    p = tf.clip_by_value(y_pred, eps, 1.0 - eps)
    neg_term = (1.0 - y_true) * tf.math.log(1.0 - p)
    pos_term = y_true * tf.math.log(p)
    return -tf.reduce_mean(neg_term + pos_term)

def create_model(mm_, lambd_reg = 5e-6, learning_rate = 5e-5, mu_reg = 1e-6):
    """Attach a multi-scale saliency head to a ResNet50 backbone.

    Outputs of the backbone layers "res2c_branch2c", "res3d_branch2c" and
    "res4f_branch2c" are each projected to one channel by a learnable 1x1
    convolution. They are fused coarse-to-fine with bilinear 2x upsampling
    and addition, then passed through a sigmoid 1x1 convolution to form a
    single-channel map. A softmax ``Dense(2)`` on the flattened map gives the
    class prediction.

    Args:
        mm_: Keras model exposing the three layer names above.
        lambd_reg: unused; kept for backward compatibility.
        learning_rate: Adam learning rate.
        mu_reg: L1 activity-regularisation strength on the sigmoid map.

    Returns:
        ``(model_train, model_image)``: the compiled classifier and an
        uncompiled model that shares its layers and outputs the sigmoid map.
    """
    inp = mm_.input

    # Backbone taps, finest to coarsest. The one-sided zero padding of the
    # finest map lines it up with the 2x-upsampled coarser maps; this assumes
    # the 170x140 input used in this notebook -- TODO confirm for other sizes.
    c_6 = ZeroPadding2D(padding=((1, 0), (1, 0)))(mm_.get_layer("res2c_branch2c").output)
    c_18 = mm_.get_layer("res3d_branch2c").output
    c_42 = mm_.get_layer("res4f_branch2c").output

    # A bias-free 1x1 Conv2D computes the same channel-wise tensordot as
    # Lambda(tf.tensordot(x, w)). Its kernel, however, is a tracked Keras
    # weight, so the optimizer updates it. Raw tf variables captured by a
    # Lambda layer are not in model.trainable_weights and are never trained.
    def project(name):
        return Conv2D(1, (1, 1), use_bias = False, name = name)

    def upsample():
        return UpSampling2D(size=(2, 2), data_format="channels_last",
                            interpolation='bilinear')

    layer_4 = project('fuse_w3')(c_42)
    layer_3 = keras.layers.Add()([upsample()(layer_4), project('fuse_w2')(c_18)])
    layer_2 = keras.layers.Add()([upsample()(layer_3), project('fuse_w1')(c_6)])

    conv = Conv2D(1, (1, 1), name = 'Conv_Sigmoid',
                  activity_regularizer = regularizers.l1(mu_reg),
                  activation = 'sigmoid')(layer_2)

    out = Dense(2, activation = 'softmax')(Flatten()(conv))

    model_train = Model(inputs = inp, outputs = out)
    model_image = Model(inputs = inp, outputs = conv)

    model_train.compile(optimizer = Adam(learning_rate), loss = 'binary_crossentropy',
                        metrics = ['accuracy'])

    return model_train, model_image

In [None]:
# Load the pickled patient -> file-name indices and split them into
# training and validation file lists.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
def _load_pickle(path):
    """Return the object pickled at ``path``."""
    with open(path, 'rb') as filehandler:
        return pickle.load(filehandler)

patients_pos = _load_pickle('patients_pos')
patients_neg = _load_pickle('patients_neg')
admitted_pos = _load_pickle('10_percent_lesions')

# Seed the global RNG so the random patient-level split is reproducible
# across Restart & Run All.
SPLIT_SEED = 0
random.seed(SPLIT_SEED)

train_namespace, val_namespace = create_train_test_sets(
    patients_pos, patients_neg, admitted_pos)

In [None]:
# ImageNet-pretrained ResNet50 trunk without its classification head,
# sized for 170x140 slices replicated to 3 channels.
model = ResNet50(weights='imagenet', include_top=False,
                 input_shape=(170, 140, 3))

In [None]:
# Scratch cell left over from development: it only evaluates 1 + 1 and
# nothing else in the notebook uses it. Safe to delete.
1 + 1

In [None]:
# Print the backbone's layer table; useful to confirm the layer names that
# create_model looks up (e.g. "res2c_branch2c").
model.summary()

In [None]:
# Classifier plus the saliency-map model built on the same layers.
mm, mm_image = create_model(model)

# Augmentation is enabled only for the training generator (mode='train').
training_generator = DataGenerator(list_IDs=train_namespace, mode='train')
validation_generator = DataGenerator(list_IDs=val_namespace, mode='val')


In [None]:
# NOTE(review): 500 epochs with no callbacks means no checkpointing or
# early stopping -- consider adding them.
mm.fit_generator(
    generator=training_generator,
    validation_data=validation_generator,
    epochs=500,
    verbose=1,
    workers=4,
    use_multiprocessing=True,
)

In [None]:
def downsample_map(image, dims_tr = (44, 34)):
    """Mark which cells of a coarser ``dims_tr`` grid contain a positive pixel.

    Pixel ``(i, j)`` of ``image`` with value > 0 sets output cell
    ``(i * dims_tr[0] // h, j * dims_tr[1] // w)`` to 1.0; every other cell
    stays 0.0.

    NOTE(review): the 44x34 default differs from the 44x36 map that the
    prediction cell reshapes to -- confirm which grid is intended.

    Args:
        image: 2-D array of shape ``(h, w)``.
        dims_tr: ``(rows, cols)`` of the output grid.

    Returns:
        float ndarray of shape ``dims_tr``.
    """
    h, w = image.shape
    image_tr = np.zeros(dims_tr)
    rows, cols = np.nonzero(np.asarray(image) > 0)
    # Exact integer arithmetic: the float form (i / h) // (1.0 / dims) can be
    # off by one cell from rounding when i * dims / h is a whole number.
    image_tr[rows * dims_tr[0] // h, cols * dims_tr[1] // w] = 1.0
    return image_tr

# For each validation slice, show the image with its ROI contour, then the
# ROI reduced to the 44x34 grid by downsample_map.
for ID in val_namespace:
    with open('BRATS_preprocessed_/' + ID, 'rb') as filehandler:
        photo = pickle.load(filehandler)

    # Full-resolution slice with the ROI outline.
    photo.print_image()
    photo.print_roi()
    plt.show()

    # Coarse ROI occupancy map.
    img = downsample_map(photo.roi, dims_tr=(44, 34))
    plt.imshow(img)
    plt.show()

In [None]:
# For each validation slice: print its label, show the slice with the ROI
# outline, then the sigmoid map predicted by mm_image.
for ID in val_namespace:
    with open('BRATS_preprocessed_/' + ID, 'rb') as filehandler:
        photo = pickle.load(filehandler)

    print(photo.mal)

    # Add batch/channel axes and replicate the slice to the 3 channels the
    # ResNet50 backbone was built with.
    batch = np.repeat(photo.image.reshape(1, 170, 140, 1), 3, axis=3)
    pred = mm_image.predict(batch)

    photo.print_image()
    photo.print_roi()
    plt.show()

    # Drop batch and channel axes instead of hard-coding the map size.
    plt.imshow(pred[0, :, :, 0], cmap = 'Blues')
    plt.colorbar()
    plt.show()

In [None]:
# Flat list of every positive file name across all patients (linear time,
# and an empty dict gives [] instead of a TypeError from reduce).
# NOTE(review): the split cell may have removed patients from patients_pos
# earlier in this session -- confirm which patients should be included.
all_files_pos = [name for files in patients_pos.values() for name in files]

In [None]:
# FIXME(review): admitted_positives is never appended to and each loaded
# `img` is discarded, so this loop only reads every positive file; the
# admission criterion is not implemented here. It also reads from
# 'BRATS_preprocessed/', whereas DataGenerator uses 'BRATS_preprocessed_/'.
admitted_positives = []

for idx, name in enumerate(all_files_pos):
    if not idx % 100:
        print(idx)  # progress every 100 files

    with open('BRATS_preprocessed/' + name, 'rb') as filehandler:
        img = pickle.load(filehandler)
        

In [None]:


# Inspect one specific slice: ROI contour in red plus the grayscale image.
sample_path = 'BRATS_preprocessed/Brats18_2013_4_1_119'
with open(sample_path, 'rb') as filehandler:
    img = pickle.load(filehandler)

img.print_roi(cmap='Reds')
img.print_image()
plt.show()

In [None]:
# Number of pixels whose ROI value is positive (despite the name, a count).
roi = (img.roi > 0).sum()

In [None]:
# Display the positive-pixel count computed from img.roi.
roi

In [None]:
# Shape of the currently loaded slice.
img.image.shape

In [None]:
# Is slice file 'Brats18_2013_4_1_120' listed among this patient's positive files?
'Brats18_2013_4_1_120' in patients_pos['Brats18_2013_4_1']

In [None]:
# Duplicate of an earlier shape check; safe to delete.
img.image.shape

In [None]:
# Flatten every patient's positive file list into one list. (sum() takes the
# start value second, so sum([], values) just returned the dict_values view.)
l = [name for files in patients_pos.values() for name in files]

In [None]:
# sum() concatenates lists only with a list start value; the default start
# of 0 makes 0 + [1, 2, 3] raise TypeError.
sum([[1, 2, 3], [4, 5]], [])

In [None]:
# Scratch check of np.tile: repeats the 2x3 array 3 times along the last axis.
np.tile(np.array([[1,2,3],[4,5,6]]), 3)

In [None]:
# FIXME(review): `f` is not defined anywhere in this notebook, so this cell
# raises NameError on a fresh kernel; the intended array is unknown.
np.min(f)

In [None]:
# Scratch check of np.var (population variance by default).
np.var([1,2,3])

In [None]:
# Displays the whole pixel array of the loaded slice; consider slicing
# (e.g. img.image[:5]) to keep the output short.
img.image

In [None]:
from skimage.transform import resize

In [None]:
# Scratch check of skimage's resize. NOTE(review): output_shape is
# (rows, cols); (140, 170) is the transpose of the (170, 140) shape used
# elsewhere in this notebook -- confirm this is intended.
resize(img.image, (140, 170))

In [None]:
# For every patient that has positive files, show the first one: print its
# file name, then overlay the ROI contour (red) on the grayscale image.
first_slices = [files[0] for files in patients_pos.values() if len(files) > 0]
for file in first_slices:
    with open('BRATS_preprocessed/' + file, 'rb') as filehandler:
        img = pickle.load(filehandler)
    print(file)
    img.print_roi(cmap='Reds')
    img.print_image()
    plt.show()

In [None]:
import random
# `all_files` is not defined anywhere else in this notebook; build it as a
# shuffled copy of all_files_pos so the original list keeps its order.
# NOTE(review): assumes all_files was meant to hold the positive files -- confirm.
all_files = list(all_files_pos)
random.shuffle(all_files)

In [None]:
# Number of patients with '2013' in their id that have at least one positive file.
sum(1 for pid, files in patients_pos.items() if '2013' in pid and len(files) > 0)

In [None]:
# Display `all_files`, the list shuffled in the preceding cell.
all_files

In [None]:
# Number of patients in the positive-file index.
len(patients_pos)

In [None]:
# Layer table of the full classifier (backbone plus saliency head).
mm.summary()