"""
Run this cell only the first time (it installs keras-contrib and pydicom).
"""
!pip install git+https://www.github.com/keras-team/keras-contrib.git
! pip install pydicom

In [None]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import keras
import math
import pydicom
import os

from keras.optimizers import Adam
from keras.utils import Sequence
import keras.backend as K
from keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler
from sklearn.metrics import confusion_matrix, classification_report

# Consecutive Slices module

In [None]:
def getMaxLength(arr):
    """
    Return the length of the longest run of consecutive non-zero entries in `arr`.

    A run is broken exactly where ``arr[i] == 0``; every other value extends it.
    An empty `arr`, or one containing only zeros, gives 0.
    """
    longest = run = 0
    for position in range(len(arr)):
        # Extend the current run, or restart it when a zero is hit.
        run = 0 if arr[position] == 0 else run + 1
        longest = run if run > longest else longest
    return longest

# 3D Model Declaration

In [None]:
"""
The code block for 3D CNN Architecture
Returns: V-Net Model
"""

import keras
import tensorflow as tf
import keras_contrib


# Building blocks
def adding_conv(x, a, filters, kernel_size, padding, strides, data_format, groups):
    """
    Residual unit: Conv3D(x) + a, followed by GroupNormalization and PReLU.

    `a` is added to the convolution output, so it must be shape-compatible with
    Conv3D(x) (``filters`` channels, same spatial size).
    """
    norm_axis = 1 if data_format != 'channels_last' else -1
    conv_layer = keras.layers.Conv3D(filters, kernel_size, padding=padding, strides=strides,
                                     activation=None, data_format=data_format)
    summed = keras.layers.add([conv_layer(x), a])
    normed = keras_contrib.layers.GroupNormalization(groups=groups, axis=norm_axis)(summed)
    # NOTE(review): PReLU with the default shared_axes=None learns one slope per
    # output element; kept as-is because saved weights depend on this layout.
    return keras.layers.advanced_activations.PReLU()(normed)

def conv(x, filters, kernel_size, padding, strides, data_format, groups):
    """
    Basic unit: Conv3D (no activation) -> GroupNormalization -> PReLU.

    The normalisation axis is the channel axis implied by `data_format`.
    """
    norm_axis = 1 if data_format != 'channels_last' else -1
    features = keras.layers.Conv3D(filters, kernel_size, padding=padding, strides=strides,
                                   activation=None, data_format=data_format)(x)
    features = keras_contrib.layers.GroupNormalization(groups=groups, axis=norm_axis)(features)
    # NOTE(review): default PReLU (shared_axes=None) -> one slope per output element.
    return keras.layers.advanced_activations.PReLU()(features)

def down_conv(x, filters, kernel_size, padding, data_format, groups):
    """
    Downsampling unit: the same Conv3D -> GroupNormalization -> PReLU stack as
    `conv`, but with stride 2 (halves each spatial dimension with 'same' padding).
    """
    return conv(x, filters, kernel_size, padding, 2, data_format, groups)

def up_conv_concat_conv(x, skip, filters, kernel_size, padding, strides, data_format, groups):
    """
    Decoder upsampling unit.

    Steps: 2x Conv3DTranspose -> GroupNorm -> Conv3D -> concatenate with `skip`
    along the channel axis -> GroupNorm -> PReLU. `skip` must have twice the
    spatial size of `x`.
    """
    axis = 1 if data_format != 'channels_last' else -1

    upsampled = keras.layers.Conv3DTranspose(filters, kernel_size=(2, 2, 2), strides=(2, 2, 2),
                                             data_format=data_format)(x)  # spatial size x2
    upsampled = keras_contrib.layers.GroupNormalization(groups=groups, axis=axis)(upsampled)
    refined = keras.layers.Conv3D(filters, kernel_size, padding=padding, strides=strides,
                                  activation=None, data_format=data_format)(upsampled)
    merged = keras.layers.Concatenate(axis=axis)([refined, skip])
    merged = keras_contrib.layers.GroupNormalization(groups=groups, axis=axis)(merged)
    return keras.layers.advanced_activations.PReLU()(merged)


# Encoders
def encoder1(x, filters, kernel_size, padding, strides, data_format, groups):
    """
    First encoder stage (full resolution N).

    One conv unit, then a residual unit whose residual term is that conv output
    itself (the raw input's channel count need not equal `filters`), then a
    stride-2 downsampling to ``2 * filters`` channels.

    Returns:
        (skip, down): skip features at N and downsampled features at N/2.
    """
    shared = (kernel_size, padding, strides, data_format, groups)
    with tf.variable_scope('encoder1'):
        with tf.variable_scope('conv'):
            features = conv(x, filters, *shared)
        with tf.variable_scope('addconv'):
            skip = adding_conv(features, features, filters, *shared)  # N
        with tf.variable_scope('downconv'):
            down = down_conv(skip, filters * 2, kernel_size, padding, data_format, groups)  # N/2
    return skip, down

def encoder2(x, filters, kernel_size, padding, strides, data_format, groups):
    """
    Second encoder stage (input at N/2).

    One conv unit, then a residual unit that adds the stage input `x` (so `x`
    must already have `filters` channels), then a stride-2 downsampling to
    ``2 * filters`` channels.

    Returns:
        (skip, down): skip features at N/2 and downsampled features at N/4.
    """
    shared = (kernel_size, padding, strides, data_format, groups)
    with tf.variable_scope('encoder2'):
        with tf.variable_scope('conv'):
            features = conv(x, filters, *shared)
        with tf.variable_scope('addconv'):
            skip = adding_conv(features, x, filters, *shared)  # N/2
        with tf.variable_scope('downconv'):
            down = down_conv(skip, filters * 2, kernel_size, padding, data_format, groups)  # N/4
    return skip, down

def encoder3(x, filters, kernel_size, padding, strides, data_format, groups):
    """
    Third encoder stage (input at N/4).

    Two conv units, a residual unit adding the stage input `x` (which must have
    `filters` channels), then a stride-2 downsampling to ``2 * filters`` channels.

    Returns:
        (skip, down): skip features at N/4 and downsampled features at N/8.
    """
    shared = (kernel_size, padding, strides, data_format, groups)
    with tf.variable_scope('encoder3'):
        with tf.variable_scope('conv1'):
            hidden = conv(x, filters, *shared)  # N/4
        with tf.variable_scope('conv2'):
            hidden = conv(hidden, filters, *shared)  # N/4
        with tf.variable_scope('addconv'):
            skip = adding_conv(hidden, x, filters, *shared)  # N/4
        with tf.variable_scope('downconv'):
            down = down_conv(skip, filters * 2, kernel_size, padding, data_format, groups)  # N/8
    return skip, down

def encoder4(x, filters, kernel_size, padding, strides, data_format, groups):
    """
    Fourth encoder stage (input at N/8).

    Two conv units, a residual unit adding the stage input `x` (which must have
    `filters` channels), then a stride-2 downsampling to ``2 * filters`` channels.

    Returns:
        (skip, down): skip features at N/8 and downsampled features at N/16.
    """
    shared = (kernel_size, padding, strides, data_format, groups)
    with tf.variable_scope('encoder4'):
        with tf.variable_scope('conv1'):
            hidden = conv(x, filters, *shared)  # N/8
        with tf.variable_scope('conv2'):
            hidden = conv(hidden, filters, *shared)  # N/8
        with tf.variable_scope('addconv'):
            skip = adding_conv(hidden, x, filters, *shared)  # N/8
        with tf.variable_scope('downconv'):
            down = down_conv(skip, filters * 2, kernel_size, padding, data_format, groups)  # N/16
    return skip, down


# Bottom
def bottom(x, filters, kernel_size, padding, strides, data_format, groups):
    """
    Bottleneck stage at N/16: two conv units plus a residual unit adding `x`
    (which must have `filters` channels).
    """
    shared = (kernel_size, padding, strides, data_format, groups)
    with tf.variable_scope('bottom'):
        with tf.variable_scope('conv1'):
            hidden = conv(x, filters, *shared)
        with tf.variable_scope('conv2'):
            hidden = conv(hidden, filters, *shared)
        with tf.variable_scope('addconv'):
            hidden = adding_conv(hidden, x, filters, *shared)  # N/16
    return hidden  # N/16


# Decoders
def decoder4(x, skip, filters, kernel_size, padding, strides, data_format, groups):
    """
    Deepest decoder stage: upsample `x` by 2 and merge it with `skip`
    (N/16 -> N/8 in VNet), then apply two conv units.
    """
    shared = (kernel_size, padding, strides, data_format, groups)
    with tf.variable_scope('decoder4'):
        with tf.variable_scope('upconv'):
            out = up_conv_concat_conv(x, skip, filters, *shared)  # N/8
        with tf.variable_scope('conv1'):
            out = conv(out, filters, *shared)
        with tf.variable_scope('conv2'):
            out = conv(out, filters, *shared)
    return out  # N/8

def decoder3(x, skip, filters, kernel_size, padding, strides, data_format, groups):
    """
    Decoder stage: upsample `x` by 2 and merge it with `skip`
    (N/8 -> N/4 in VNet), then apply two conv units.
    """
    shared = (kernel_size, padding, strides, data_format, groups)
    with tf.variable_scope('decoder3'):
        with tf.variable_scope('upconv'):
            out = up_conv_concat_conv(x, skip, filters, *shared)  # N/4
        with tf.variable_scope('conv1'):
            out = conv(out, filters, *shared)
        with tf.variable_scope('conv2'):
            out = conv(out, filters, *shared)
    return out  # N/4

def decoder2(x, skip, filters, kernel_size, padding, strides, data_format, groups):
    """
    Decoder stage: upsample `x` by 2 and merge it with `skip`
    (N/4 -> N/2 in VNet), then apply one conv unit.
    """
    shared = (kernel_size, padding, strides, data_format, groups)
    with tf.variable_scope('decoder2'):
        with tf.variable_scope('upconv'):
            merged = up_conv_concat_conv(x, skip, filters, *shared)  # N/2
        with tf.variable_scope('conv'):
            refined = conv(merged, filters, *shared)
    return refined  # N/2

def decoder1(x, skip, filters, kernel_size, padding, strides, data_format, groups):
    """
    Last decoder stage: upsample `x` by 2 and merge it with `skip`
    (N/2 -> N in VNet); no further conv units.
    """
    with tf.variable_scope('decoder1'):
        with tf.variable_scope('upconv'):
            merged = up_conv_concat_conv(x, skip, filters, kernel_size, padding,
                                         strides, data_format, groups)  # N
    return merged  # N


# Attention gate
def attention_gate(inp, g, intra_filters, data_format='channels_first', groups=8):
    """
    Additive attention gate applied to a skip connection.

    The skip features `inp` get a stride-2 Conv3D so they match the spatial size
    of the gating signal `g` (expected to be half that of `inp`). Both are
    projected to `intra_filters` channels, summed, and passed through ReLU, a
    1-channel 1x1x1 Conv3D and a sigmoid. The resulting attention map is
    upsampled by 2 and multiplied into `inp`.

    Arguments:
        inp: skip-connection tensor (spatial size N).
        g: gating tensor from the next-coarser level (spatial size N/2).
        intra_filters: channel count of the intermediate projections; must be
            divisible by `groups` (GroupNormalization requirement).
        data_format: 'channels_first' (default) or 'channels_last'; also selects
            the normalisation axis.
        groups: GroupNormalization groups for the two projections (default 8).

    Returns:
        Tensor shaped like `inp`, scaled voxel-wise by the attention map.
    """
    channel_axis = -1 if data_format == 'channels_last' else 1
    with tf.variable_scope('attention_gate'):
        # Gating signal: 1x1x1 projection to intra_filters channels (N/2)
        g = keras.layers.Conv3D(intra_filters, kernel_size=1, data_format=data_format)(g)
        g = keras_contrib.layers.GroupNormalization(groups=groups, axis=channel_axis)(g)

        # Skip signal: strided projection N --> N/2 so it can be added to g
        x = keras.layers.Conv3D(intra_filters, kernel_size=2, strides=2, padding='same', data_format=data_format)(inp)
        x = keras_contrib.layers.GroupNormalization(groups=groups, axis=channel_axis)(x)  # N/2

        # relu(g + x) -> 1-channel map -> sigmoid, all at N/2
        g_x = keras.layers.Add()([g, x])
        psi = keras.layers.Activation('relu')(g_x)
        psi = keras.layers.Conv3D(1, kernel_size=1, padding='same', data_format=data_format)(psi)
        psi = keras_contrib.layers.GroupNormalization(groups=1, axis=channel_axis)(psi)
        psi = keras.layers.Activation('sigmoid')(psi)
        # Attention map back to the resolution of inp: N/2 --> N
        alpha = keras.layers.UpSampling3D(size=2, data_format=data_format)(psi)

        # alpha has a single channel and is broadcast over the channels of inp
        x_hat = keras.layers.Multiply()([inp, alpha])
        return x_hat


# Model
def VNet(n_in, n_out, image_shape, filters, kernel_size, padding, strides, data_format, groups, inter_filters):
    """
    Build a 3D V-Net-style encoder/decoder with attention-gated skip connections.

    Arguments:
        n_in: number of input channels.
        n_out: number of output channels (each passed through a sigmoid).
        image_shape: spatial shape of one input volume, e.g. (rows, cols, depth).
            Four stride-2 downsamplings are undone by 2x upsamplings that are
            concatenated with the skips, so each spatial dimension should be
            divisible by 16.
        filters: base filter count; encoder stage k (k = 0..3) works with
            filters * 2**k channels and the bottleneck with filters * 2**4.
        kernel_size, padding, strides: Conv3D settings of the conv units
            (the downsampling convs always use stride 2).
        data_format: 'channels_last' or 'channels_first'; decides where the
            channel axis is added to the input shape.
        groups: GroupNormalization groups in the encoder/bottom/decoder blocks.
        inter_filters: intermediate channel count of every attention gate.
            NOTE(review): attention_gate is called without data_format/groups
            and uses its own settings for those.

    Returns:
        An uncompiled keras Model producing n_out sigmoid maps at input resolution.

    The N, N/2, ... comments give spatial size relative to the input. Reordering
    the calls below changes Keras' auto-generated layer names and can break
    loading of previously saved weights.
    """
    with tf.variable_scope('VNet'):
        # Channel axis goes last or first depending on data_format
        input_dim = image_shape+(n_in,) if data_format=='channels_last' \
            else (n_in,)+image_shape
        
        inputs = keras.layers.Input(input_dim)

        # Encoder: each stage returns (skip features, downsampled features)
        (encoder1_addconv, encoder1_downconv) = encoder1(inputs, filters*2**0, kernel_size, padding, strides, data_format, groups) # N, N/2
        (encoder2_addconv, encoder2_downconv) = encoder2(encoder1_downconv, filters*2**1, kernel_size, padding, strides, data_format, groups) # N/2, N/4
        (encoder3_addconv, encoder3_downconv) = encoder3(encoder2_downconv, filters*2**2, kernel_size, padding, strides, data_format, groups) # N/4, N/8
        (encoder4_addconv, encoder4_downconv) = encoder4(encoder3_downconv, filters*2**3, kernel_size, padding, strides, data_format, groups) # N/8, N/16

        bottom_addconv = bottom(encoder4_downconv, filters*2**4, kernel_size, padding, strides, data_format, groups) # N/16

        # Decoder: every skip is first gated by the next-coarser feature map,
        # then merged into the upsampled decoder path
        encoder4_ag = attention_gate(encoder4_addconv, bottom_addconv, inter_filters) # (N/8, N/16) --> N/8
        decoder4_conv = decoder4(bottom_addconv, encoder4_ag, filters*2**3, kernel_size, padding, strides, data_format, groups) # N/8
        encoder3_ag = attention_gate(encoder3_addconv, decoder4_conv, inter_filters) # (N/4, N/8) --> N/4
        decoder3_conv = decoder3(decoder4_conv, encoder3_ag, filters*2**2, kernel_size, padding, strides, data_format, groups) # N/4
        encoder2_ag = attention_gate(encoder2_addconv, decoder3_conv, inter_filters) # (N/2, N/4) --> N/2
        decoder2_conv = decoder2(decoder3_conv, encoder2_ag, filters*2**1, kernel_size, padding, strides, data_format, groups) # N/2
        encoder1_ag = attention_gate(encoder1_addconv, decoder2_conv, inter_filters) # (N, N/2) --> N
        decoder1_conv = decoder1(decoder2_conv, encoder1_ag, filters*2**0, kernel_size, padding, strides, data_format, groups) # N
       
        # 1x1x1 projection to n_out sigmoid maps
        with tf.variable_scope("output"):
            outputs = keras.layers.Conv3D(n_out,
                (1,1,1),
                padding='same',
                activation='sigmoid',
                data_format=data_format)(decoder1_conv)

            model = keras.models.Model(inputs, outputs)

        return model

# ImageSlice Viewer

In [None]:
import ipywidgets as ipyw
import matplotlib.pyplot as plt
%matplotlib inline

class ImageSliceViewer3D:
    """
    Interactive viewer for slices of a 3D volume in Jupyter/IPython notebooks.

    A radio-button widget selects the slice plane ('x-y', 'y-z' or 'z-x') and a
    slider selects the slice index within that plane. Every slice is drawn on
    the same intensity scale (the volume's global min/max).

    Arguments:
        volume: 3D input array.
        figsize: matplotlib figure size, default (8, 8).
        cmap: matplotlib colormap name, default 'plasma'. More colormaps:
            https://matplotlib.org/users/colormaps.html
    """

    def __init__(self, volume, figsize=(8,8), cmap='plasma'):
        self.volume = volume
        self.figsize = figsize
        self.cmap = cmap
        # Global intensity range, used as fixed vmin/vmax in plot_slice.
        self.v = [np.min(volume), np.max(volume)]

        # Call to select slice plane
        ipyw.interact(self.view_selection, view=ipyw.RadioButtons(
            options=['x-y','y-z', 'z-x'], value='x-y',
            description='Slice plane selection:', disabled=False,
            style={'description_width': 'initial'}))

    def view_selection(self, view):
        """Re-orient the volume for the chosen plane and show a slice slider."""
        # Axis permutation that moves the selected plane into the first two axes
        orient = {"y-z": [1, 2, 0], "z-x": [2, 0, 1], "x-y": [0, 1, 2]}
        self.vol = np.transpose(self.volume, orient[view])
        maxZ = self.vol.shape[2] - 1

        # Call to view a slice within the selected slice plane
        ipyw.interact(self.plot_slice,
            z=ipyw.IntSlider(min=0, max=maxZ, step=1, continuous_update=False,
            description='Image Slice:'))

    def plot_slice(self, z):
        """Draw slice `z` of the re-oriented volume."""
        self.fig = plt.figure(figsize=self.figsize)
        # Fixed vmin/vmax: without them imshow rescales each slice to its own
        # min/max, so brightness would not be comparable between slices.
        plt.imshow(self.vol[:, :, z], cmap=plt.get_cmap(self.cmap),
                   vmin=self.v[0], vmax=self.v[1])

In [None]:
"""
Custom ImageSliceViewer3D
Arguments:
    volume : 3D Volume of CT scan
    volume_1 : 3D Volume of Label/True masks
    volume_2 : 3D Volume of Predicted masks
"""


import ipywidgets as ipyw
import matplotlib.pyplot as plt
%matplotlib inline

class ImageSliceViewer3DMultipleColour:
    """
    Interactive side-by-side viewer for a CT volume, its ground-truth mask and a
    predicted mask, for Jupyter/IPython notebooks.

    All three volumes are first scaled so that their maximum is 1. Each frame then
    shows ``[ct | ct + 0.7*true_mask | ct + 0.7*pred_mask]`` for the chosen slice.
    A radio-button widget selects the slice plane and a slider selects the slice.

    Arguments:
        volume: 3D CT volume.
        volume_1: 3D ground-truth mask volume (same shape as `volume`).
        volume_2: 3D predicted mask volume (same shape as `volume`).
        figsize: matplotlib figure size, default (8, 8).
        cmap: matplotlib colormap name, default 'plasma'. More colormaps:
            https://matplotlib.org/users/colormaps.html
    """

    def __init__(self, volume, volume_1, volume_2, figsize=(8,8), cmap='plasma'):
        self.volume = self._normalize(volume)
        self.volume_1 = self._normalize(volume_1)
        self.volume_2 = self._normalize(volume_2)
        self.figsize = figsize
        self.cmap = cmap
        # Intensity range of the raw CT volume (kept as an attribute; not used for drawing).
        self.v = [np.min(volume), np.max(volume)]

        # Call to select slice plane
        ipyw.interact(self.view_selection, view=ipyw.RadioButtons(
            options=['x-y','y-z', 'z-x'], value='x-y',
            description='Slice plane selection:', disabled=False,
            style={'description_width': 'initial'}))

    @staticmethod
    def _normalize(volume):
        """
        Return `volume` as float64 divided by its maximum.

        An all-zero volume (e.g. an empty mask or an all-False prediction) comes back
        as zeros instead of the NaNs 0/0 would produce, which would blank the overlay.
        Always returning float64 also keeps the dtypes given to cv2.addWeighted equal.
        """
        arr = np.asarray(volume, dtype=np.float64)
        peak = arr.max()
        if peak == 0:
            return np.zeros_like(arr)
        return arr / peak

    def view_selection(self, view):
        """Re-orient all three volumes for the chosen plane and show a slice slider."""
        orient = {"y-z": [1, 2, 0], "z-x": [2, 0, 1], "x-y": [0, 1, 2]}
        self.vol = np.transpose(self.volume, orient[view])
        self.vol_1 = np.transpose(self.volume_1, orient[view])
        self.vol_2 = np.transpose(self.volume_2, orient[view])
        maxZ = self.vol.shape[2] - 1

        # Call to view a slice within the selected slice plane
        ipyw.interact(self.plot_slice,
            z=ipyw.IntSlider(min=0, max=maxZ, step=1, continuous_update=False,
            description='Image Slice:'))

    def plot_slice(self, z):
        """Draw CT, CT+true-mask and CT+predicted-mask for slice `z` side by side."""
        import cv2  # imported here so the viewer does not depend on cell execution order

        self.fig = plt.figure(figsize=self.figsize)

        ct_slice = self.vol[:, :, z]
        # 0.7 * mask + 1.0 * ct
        img_true_mask = cv2.addWeighted(self.vol_1[:, :, z], 0.7, ct_slice, 1, 0)
        img_pred_mask = cv2.addWeighted(self.vol_2[:, :, z], 0.7, ct_slice, 1, 0)
        final_image = np.hstack([ct_slice, img_true_mask, img_pred_mask])

        plt.imshow(final_image, cmap=plt.get_cmap(self.cmap))

# 3D Batch Generator

In [None]:
"""
Data generator for 3D Volumes:
Arguments:
    ct_dataframe = Final prepared dataframe
    unique_id_list = list of unique scan identifiers (list of strings)
    list_num_slices_list = list of number of slices in the scans (list of numbers)
    batch_size = batch size (int)
    image_size = image size (tuple), eg (512,512)
    stack_size = depth of smaller volumes (int)
    overlap_size = overlap factor (int)
    
NOTE: Adapt the column names to the dataframe/pathology being used, i.e. wherever
      a column is read as "row.<column>".
      The mask column currently read is "bleed"; the consolidation, predominant_consolidation,
      ground_glass_opacity and predominant_ground_glass_opacity columns are not read by this code.
"""

from scipy import ndimage
import cv2
import random
class VolDataGenerator3D(Sequence):
    """
    Keras ``Sequence`` yielding batches of overlapping 3D CT sub-volumes and masks.

    Each scan (all rows sharing one ``unique_identifier``) is cut along the slice
    axis into windows of at most ``stack_size`` slices. Consecutive windows start
    ``stack_size - overlap_size`` slices apart. Item ``index`` is the ``index``-th
    group of ``batch_size`` windows, returned as ``(X, Y)``, each of shape
    ``(n, 1, image_size[0], image_size[1], stack_size)``. Windows shorter than
    ``stack_size`` are zero-padded at the end of the depth axis.

    Arguments:
        ct_dataframe: slice-level dataframe, one row per DICOM slice. The rows of
            a scan are assumed to be in slice order. Columns read:
            ``unique_identifier``, ``dicom_path`` and ``bleed``, a string holding a
            Python list literal whose first element is an RLE mask string.
        unique_id_list: scan identifiers, in the order windows are produced.
        list_num_slices_list: number of slices of each scan in ``unique_id_list``.
        batch_size: windows per batch.
        image_size: (rows, cols) that every slice and mask is resized to.
        stack_size: depth of each window.
        overlap_size: slices shared by consecutive windows; must be < stack_size.
    """

    def __init__(self, ct_dataframe, unique_id_list, list_num_slices_list, batch_size, image_size, stack_size, overlap_size):
        self.unique_id_list = unique_id_list
        self.batch_size = batch_size
        self.image_size = image_size
        self.ct_dataframe = ct_dataframe
        self.stack_size = stack_size
        self.overlap_size = overlap_size
        self.list_num_slices_list = list_num_slices_list

        # Windows must advance by at least one slice.
        assert(self.overlap_size < self.stack_size)
        if len(unique_id_list) != len(list_num_slices_list):
            raise ValueError("unique_id_list and list_num_slices_list must have the same length")
        # Seeds the global `random` module; this class only needs it if the
        # shuffle below is re-enabled.
        random.seed(7)
        # Flat list of (scan_id, (start, end)) windows over all scans, in scan order.
        self.list_of_stacks = [
            (scan_id, window)
            for scan_id, num_slices in zip(self.unique_id_list, self.list_num_slices_list)
            for window in self.make_overlapping_stacks(num_slices, self.stack_size, self.overlap_size)
        ]
#         random.shuffle(self.list_of_stacks)

    def __len__(self):
        """Number of batches (the last one may be partial)."""
        return int(math.ceil(len(self.list_of_stacks) / self.batch_size))

    def rle_decode(self, mask_rle, shape):
        """
        Decode a run-length-encoded mask into a 0/1 array of `shape`.

        Arguments:
            mask_rle: list whose first element is a string of space-separated
                ``start length`` pairs with 1-based starts. Only the first element
                is decoded. An empty list yields an all-zero mask.
            shape: (rows, cols) of the decoded mask; pixels are filled in
                row-major (C) order.
        """
        if len(mask_rle) == 0:
            return np.zeros((shape[0], shape[1]), dtype=np.uint)

        img = np.zeros(shape[0] * shape[1], dtype=np.uint)
        tokens = mask_rle[0].split()
        starts = np.asarray(tokens[0::2], dtype=int) - 1  # RLE starts are 1-based
        lengths = np.asarray(tokens[1::2], dtype=int)
        for start, length in zip(starts, lengths):
            img[start:start + length] = 1
        return img.reshape(shape)

    def make_overlapping_stacks(self, num_slices, stack_size, overlap_size):
        """
        Split one scan into overlapping windows.

        Arguments:
            num_slices: total number of slices in the scan.
            stack_size: maximum depth of a window.
            overlap_size: slices shared by consecutive windows.

        Returns:
            list of (start, end) pairs, with end exclusive (used as an iloc slice).
            Generation stops at the first window that reaches the last slice.
            Example: num_slices=320, stack_size=80, overlap_size=0 ->
            [(0, 80), (80, 160), (160, 240), (240, 320)]. An empty scan gives [].
        """
        step = stack_size - overlap_size
        list_of_indices = []
        for start in range(0, num_slices, step):
            # end is exclusive, so clip to num_slices (not num_slices - 1) to keep the last slice
            end = min(start + stack_size, num_slices)
            list_of_indices.append((start, end))
            if end == num_slices:
                # later windows would only be shorter suffixes of this one
                break
        self.list_of_indices = list_of_indices
        return list_of_indices

    def create_3d_vol(self, name_vs_num_slices_element):
        """
        Build the image and mask sub-volumes for one window.

        Arguments:
            name_vs_num_slices_element: (scan_id, (start, end)), e.g. ('ct_name', (0, 32)).
                Rows start:end (end exclusive) of that scan are used.

        Returns:
            img_vol, label_vol: float arrays of shape
            (image_size[0], image_size[1], stack_size). Depth positions past the
            window's length stay zero.
        """
        import ast

        self.name_vs_num_slices_element = name_vs_num_slices_element
        scan_id, (start, end) = name_vs_num_slices_element
        temp_df = self.ct_dataframe.loc[self.ct_dataframe["unique_identifier"] == scan_id]
        temp_df = temp_df.iloc[start:end].reset_index(drop=True)

        rows, cols = self.image_size[0], self.image_size[1]
        # cv2.resize takes dsize as (width, height) = (cols, rows)
        dsize = (cols, rows)
        img_vol = np.zeros((rows, cols, self.stack_size))
        label_vol = np.zeros((rows, cols, self.stack_size))

        for i, row in temp_df.iterrows():
            img = pydicom.dcmread(row.dicom_path).pixel_array
            img_vol[:, :, i] = cv2.resize(img, dsize)
            # Decode the mask at the native DICOM resolution, then resize it like the image.
            # literal_eval instead of eval: the column comes from a CSV file and must
            # not be able to execute code.
            mask = self.rle_decode(ast.literal_eval(row.bleed), img.shape)
            label_vol[:, :, i] = cv2.resize(mask.astype(float), dsize)

        self.temp_df = temp_df
        return img_vol, label_vol

    def create_volume_image(self, name_vs_num_slices_element):
        """
        Return the window's (image, mask) volumes with a leading channel axis,
        i.e. shape (1, image_size[0], image_size[1], stack_size) each.

        Arguments:
            name_vs_num_slices_element: (scan_id, (start, end)), e.g. ('ct_name', (0, 32)).
        """
        img_vol, label_vol = self.create_3d_vol(name_vs_num_slices_element)
        return np.expand_dims(img_vol, 0), np.expand_dims(label_vol, 0)

    def __getitem__(self, index):
        """Return batch `index` as (X, Y), each of shape (n, 1, rows, cols, stack_size)."""
        studies_for_this_batch = self.list_of_stacks[self.batch_size * index:self.batch_size * (index + 1)]
        self.studies_for_this_batch = studies_for_this_batch

        shape = (len(studies_for_this_batch), 1, self.image_size[0], self.image_size[1], self.stack_size)
        X = np.zeros(shape)
        Y = np.zeros(shape)
        for j, study in enumerate(studies_for_this_batch):
            X[j], Y[j] = self.create_volume_image(study)
        return X, Y

In [None]:
def find_num_slices_list(dataframe, scan_name_list):
    """
    Return the number of slices (rows) of each scan identifier.

    Arguments:
        dataframe: slice-level dataframe with a ``unique_identifier`` column
            (one row per slice).
        scan_name_list: iterable of scan identifiers.

    Returns:
        list of int aligned with `scan_name_list`. Identifiers absent from
        `dataframe` get 0.
    """
    # Count all identifiers in one pass instead of re-filtering the whole
    # dataframe for every scan (O(rows * scans)).
    counts = dataframe.unique_identifier.value_counts()
    return [int(counts.get(name, 0)) for name in scan_name_list]

In [None]:
"""
Change train_csv accordignly
"""
# Slice-level training table (one row per DICOM slice).
final_train_df = pd.read_csv('./CSV/brain_bleeds/train_platform_ycm.csv')
# final_train_df = get_modified_final_csv("", train_list)

# Scan identifiers (order of first appearance) and the slice count of each.
final_train_df_unique_series_id_list = (final_train_df.unique_identifier.unique())
final_train_df_unique_series_id_num_list = find_num_slices_list(final_train_df, final_train_df_unique_series_id_list)
assert(len(final_train_df_unique_series_id_list)==len(final_train_df_unique_series_id_num_list))

In [None]:
"""
Change val_csv accordignly
"""
# Slice-level validation table (one row per DICOM slice).
final_val_df = pd.read_csv('./CSV/brain_bleeds/val_platform_ycm.csv')
# final_val_df = get_modified_final_csv("",val_list)

# Scan identifiers (order of first appearance) and the slice count of each.
final_val_df_unique_series_id_list = (final_val_df.unique_identifier.unique())
final_val_df_unique_series_id_num_list = find_num_slices_list(final_val_df, final_val_df_unique_series_id_list)
assert(len(final_val_df_unique_series_id_list)==len(final_val_df_unique_series_id_num_list))

In [None]:
"""
Change test_csv accordignly
"""
# Slice-level test table (one row per DICOM slice).
final_test_df = pd.read_csv('./CSV/brain_bleeds/test_platform_ycm.csv')
# final_test_df = get_modified_final_csv("",test_list)

# Scan identifiers (order of first appearance) and the slice count of each.
final_test_df_unique_series_id_list = (final_test_df.unique_identifier.unique())
final_test_df_unique_series_id_num_list = find_num_slices_list(final_test_df, final_test_df_unique_series_id_list)
assert(len(final_test_df_unique_series_id_list)==len(final_test_df_unique_series_id_num_list))

# END*****
# Number of unique scans in the train, val and test splits.
# (`final_out_test_df` is not defined anywhere in this notebook, so it is not printed.)
print(len(final_train_df.unique_identifier.unique()), len(final_val_df.unique_identifier.unique()), len(final_test_df.unique_identifier.unique()))

In [None]:
"""
Parameters for the Data Generator
"""
# (rows, cols) every slice and mask is resized to
image_size = (512, 512)
# Depth (slices) of each sub-volume window
stack_size = 32
# Slices shared by consecutive windows of a scan (windows start stack_size - overlap_size apart)
overlap_size = 20
# Windows per batch; with 1, generator index i and prediction index i are the same window
batch_size = 1
assert(overlap_size < stack_size)

In [None]:
# Training batches of overlapping windows over every training scan.
final_train_df_gen = VolDataGenerator3D(ct_dataframe = final_train_df,
                               unique_id_list = final_train_df_unique_series_id_list,
                               list_num_slices_list = final_train_df_unique_series_id_num_list,
                               batch_size = batch_size,
                               image_size = image_size,
                               stack_size = stack_size,
                               overlap_size = overlap_size)

In [None]:
# Validation batches, built with the same window parameters as training.
final_val_df_gen = VolDataGenerator3D(ct_dataframe = final_val_df,
                               unique_id_list = final_val_df_unique_series_id_list,
                               list_num_slices_list = final_val_df_unique_series_id_num_list,
                               batch_size = batch_size,
                               image_size = image_size,
                               stack_size = stack_size,
                               overlap_size = overlap_size)

In [None]:
# Test batches, built with the same window parameters as training.
final_test_df_gen = VolDataGenerator3D(ct_dataframe = final_test_df,
                               unique_id_list = final_test_df_unique_series_id_list,
                               list_num_slices_list = final_test_df_unique_series_id_num_list,
                               batch_size = batch_size,
                               image_size = image_size,
                               stack_size = stack_size,
                               overlap_size = overlap_size)

In [None]:
x=80
# Build the batch once: each generator lookup re-reads every DICOM of the window.
train_images, train_masks = final_train_df_gen[x]
# The ground-truth mask is passed as both the "true" and the "predicted" volume,
# presumably just to preview the labels on training data.
ImageSliceViewer3DMultipleColour(np.squeeze(train_images), np.squeeze(train_masks), np.squeeze(train_masks), cmap='gray')

In [None]:
# Slice 23 of the first CT window in training batch 80
# (batch -> X -> first item; squeeze drops the channel axis).
plt.imshow(np.squeeze(final_train_df_gen[80][0][0])[:,:,23])

# MODEL

In [None]:
"""
Importing V-Net
"""

# Spatial shape of one window (rows, cols, depth); VNet adds the channel axis
# in front because data_format='channels_first'.
image_shape = (image_size[0], image_size[1], stack_size)
# GroupNormalization groups in the encoder/bottom/decoder blocks
group_size = 2
# Intermediate channel count of the attention gates (passed as inter_filters)
f_root = 8
# Base filter count; encoder stage k uses filters * 2**k
filters = 4
model = VNet(image_shape=image_shape, n_in=1, n_out=1, 
        strides=1, padding='same', kernel_size=3,
        groups=group_size, data_format='channels_first',
        inter_filters=f_root, filters = filters)

In [None]:
# Layer table; the wide line length keeps long layer names and shapes on one line.
model.summary(line_length=150)

In [None]:
"""
Model compile
"""
smooth = 1


def dice_coef(y_true, y_pred):
    """
    Hard Dice coefficient, used as a metric.

    Predictions are rounded to {0, 1}, and all tensors are flattened, so the sums
    run over the whole batch. The module-level `smooth` is added to both the
    numerator and the denominator so that an empty prediction on an empty mask
    scores 1 rather than 0.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(K.round(y_pred))
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    """
    1 - dice_coef.

    NOTE(review): dice_coef rounds its predictions, so this is presumably unusable
    as a training loss (K.round has no useful gradient). The compile call below
    uses binary cross-entropy instead.
    """
    return 1-dice_coef(y_true, y_pred)


model.compile(optimizer=Adam(lr=1e-3), loss='binary_crossentropy', metrics=['accuracy', dice_coef])

In [None]:
"""
Define Model checkpoints and lr scheduler
"""

model_checkpoint_path = '/opt/bucketdata/Users/Rohit/3D_CT_Model/'

# Make sure the output folders exist before training. CSVLogger opens its file
# when training starts, and ModelCheckpoint writes after the first epoch;
# neither is guaranteed to create missing parent directories.
for _output_subdir in ('Weights', 'CSV_Log'):
    os.makedirs(os.path.join(model_checkpoint_path, _output_subdir), exist_ok=True)

# Keep only the weights with the lowest validation loss seen so far.
model_checkpoint_1 = ModelCheckpoint(os.path.join(model_checkpoint_path, 'Weights', 'brain_bleeds_loss_platform_ycm_pretrained_covid.hdf5'),
                                     monitor='val_loss', mode='min',save_best_only=True, verbose=1)

# model_checkpoint_2 = ModelCheckpoint(os.path.join(model_checkpoint_path, 'Weights', 'brain_bleeds_1_dice.hdf5'),
#                                      monitor='val_dice_coef', mode='max',save_best_only=True, verbose=1)

# Per-epoch metric log.
csvlogger = CSVLogger(os.path.join(model_checkpoint_path, 'CSV_Log', 'brain_bleeds_platform_ycm_pretrained.csv'))


def scheduler(epoch, lr):
    """Divide the current learning rate by 3 at every non-zero epoch that is a multiple of 4 (compounding)."""
    if epoch % 4 == 0 and epoch != 0:
        return lr / 3
    return lr

lr_scheduler = LearningRateScheduler(scheduler, verbose=1)

callbacks = [model_checkpoint_1, csvlogger, lr_scheduler]

In [None]:
"""
Load previous weights if any
"""

# Initialise from an earlier run's best-val-loss weights. Without by_name=True,
# Keras matches saved weights by network topology, so this presumably requires
# exactly the architecture built above (512x32 windows with overlap 20, per the filename).
model_checkpoint_path_abhishek = '/opt/bucketdata/Users/Abhishek/3D_SEGMENTATION_COMPARE_23_NOV/stats_&_weights'
model.load_weights(os.path.join(model_checkpoint_path_abhishek, 'Weights', 'minvalloss_seg_3d_UNIQUE_INDENTIFIER_BATCH_APPROACH_trial_6_512X32_20overlap.hdf5'))

In [None]:
"""
Start the training
"""
# Train for 30 epochs with the checkpoint / CSV log / LR-schedule callbacks;
# shuffle=False keeps the batch order the same every epoch.
model.fit_generator(final_train_df_gen,
                    validation_data=final_val_df_gen,
                    epochs=30,
                    callbacks=callbacks,
                    verbose=1,
                    shuffle=False)

In [None]:
# Validation loss per epoch, read from the log written by this run's CSVLogger
# (the previous hard-coded filename pointed at a different 352X80 experiment).
plt.plot(pd.read_csv(csvlogger.filename).val_loss)

# INFERENCE

In [None]:
# Sigmoid predictions for every test window, stacked in generator order.
final_test_df_preds = model.predict_generator(final_test_df_gen, verbose=1)

In [None]:
test_idx = 4
# Probability threshold turning sigmoid outputs into a binary mask for display.
pred_threshold = 0.08
# Build the batch once (each generator lookup re-reads every DICOM of the window).
# Prediction index test_idx matches generator batch test_idx only because batch_size == 1.
test_images, test_masks = final_test_df_gen[test_idx]
ImageSliceViewer3DMultipleColour(np.squeeze(test_images), np.squeeze(test_masks), (np.squeeze(final_test_df_preds[test_idx]>pred_threshold)), cmap='gray', figsize=(8,8))