In [8]:
import os
import matplotlib.pyplot as plt
import glob
import pandas as pd
import numpy as np
import time
import cv2
import keras
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import layers
from tensorflow import keras
from tensorflow.keras.preprocessing import image
import tensorflow as tf

In [15]:
# Patient index: one row per folder_id with the MGMT_value label and FLAIR/T1w/T1wCE/T2w
# columns (presumably per-modality slice counts — confirm); every column is kept as text.
df = pd.read_csv(filepath_or_buffer='result.csv', dtype='str')

In [16]:
# Preview the first four rows of the patient index.
df.head(4)

Unnamed: 0,folder_id,FLAIR,T1w,T1wCE,T2w,BraTS21ID,MGMT_value
0,0,288,29,86,274,0,1
1,2,67,27,88,269,2,1
2,3,71,28,88,285,3,0
3,5,277,26,83,272,5,1


In [28]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding padded ResNet50 embedding sequences per MRI modality.

    Each CSV row names a patient folder (``folder_id``) under ``data_dir``. For
    every patient, all images in the ``FLAIR``, ``T1w``, ``T1wCE`` and ``T2w``
    sub-folders are embedded with an ImageNet ResNet50 followed by global average
    pooling. Each modality is then padded (or truncated) to ``max_len`` slices and
    paired with a ``max_len x max_len`` 0/1 mask that is 1 where both slices are real.

    Args:
        csv_path: CSV read with every column as string; must contain the
            ``folder_id`` and ``MGMT_value`` columns.
        width, height: size each image is resized to before embedding.
        batch_size: number of patients per item.
        shuffle: reshuffle rows at construction and after every epoch.
        max_len: number of slices each modality is padded or truncated to.
        data_dir: root directory holding one folder per ``folder_id``.
        predict_batch_size: images per ``predict_on_batch`` call; defaults to
            ``batch_size``, which matches the original behaviour.
        seed: optional seed for row shuffling. ``None`` uses numpy's global RNG,
            as before.
        verbose: print per-batch timing and array shapes (debug output).
    """

    # Order of the modality sub-folders; it fixes the order of the returned X1..X4.
    MODALITIES = ('FLAIR', 'T1w', 'T1wCE', 'T2w')

    def __init__(self, csv_path='result.csv', width=256, height=256, batch_size=16,
                 shuffle=True, max_len=364, data_dir='train', predict_batch_size=None,
                 seed=None, verbose=False):
        super().__init__()
        self.df = pd.read_csv(csv_path, dtype='str')
        self.batch_size = batch_size
        self.width = width
        self.height = height
        self.shuffle = shuffle
        self.max_len = max_len
        self.data_dir = data_dir
        self.predict_batch_size = predict_batch_size or batch_size
        self.verbose = verbose
        # RandomState is accepted by every pandas version's `random_state`; None -> global RNG.
        self._rng = np.random.RandomState(seed) if seed is not None else None
        self.predictor = self.load_model()
        # Width of one embedding vector, derived from the model instead of hard-coding 2048.
        self.embed_dim = int(self.predictor.outputs[0].shape[-1])
        self.on_epoch_end()

    def load_model(self):
        """Build the feature extractor: headless ImageNet ResNet50 + GlobalAveragePooling2D."""
        model = ResNet50(weights='imagenet', include_top=False,
                         input_shape=(self.height, self.width, 3))
        x = layers.GlobalAveragePooling2D()(model.output)
        new_model = Model(inputs=model.inputs, outputs=x)
        return new_model

    def on_epoch_end(self):
        """Reshuffle the patient rows (when ``shuffle``) so batches differ between epochs."""
        if self.shuffle:
            self.df = self.df.sample(frac=1, random_state=self._rng).reset_index(drop=True)

    def __len__(self):
        """Number of full batches; a trailing partial batch is dropped."""
        return len(self.df) // self.batch_size

    def __getitem__(self, index):
        """Return the embeddings, masks and labels for batch number ``index``.

        Returns:
            ``(X1, X1_mask, X2, X2_mask, X3, X3_mask, X4, X4_mask, labels)``, where
            X1..X4 are FLAIR, T1w, T1wCE and T2w, and ``labels`` is the
            ``MGMT_value`` column slice for this batch.
        """
        start = time.time()
        batch = self.df.iloc[index * self.batch_size:(index + 1) * self.batch_size]
        # NOTE(review): folder_id is used verbatim (e.g. '0'); confirm the on-disk
        # folder names are not zero-padded.
        folders = [os.path.join(self.data_dir, folder_id) for folder_id in batch['folder_id']]
        X1, X1_mask, X2, X2_mask, X3, X3_mask, X4, X4_mask = self.__data_generation_batch(folders)
        if self.verbose:
            print('time taken:', (time.time() - start) * 1000)
        # NOTE(review): labels are the raw CSV strings (read with dtype='str') and this is a
        # flat 9-tuple. Keras fit() expects (inputs, targets) with numeric targets, so adapt
        # before training.
        return X1, X1_mask, X2, X2_mask, X3, X3_mask, X4, X4_mask, batch['MGMT_value']

    @staticmethod
    def _natural_key(path):
        """Sort key that compares digit runs numerically, so 'img-2' sorts before 'img-10'."""
        import re
        # re.split with a capturing group alternates non-digit (even index) and digit (odd index) parts.
        return [int(part) if i % 2 else part for i, part in enumerate(re.split(r'(\d+)', path))]

    def _load_images(self, paths):
        """Load each path as an RGB image resized to (height, width) and return float arrays."""
        return [image.img_to_array(image.load_img(p, target_size=(self.height, self.width)))
                for p in paths]

    def __data_generation(self, folder_name):
        """Embed every slice of one patient folder, split by modality.

        Returns:
            ``(flair, t1w, t1wce, t2w)``: arrays of shape ``(n_slices, embed_dim)``,
            in deterministic (natural-sorted) file order.

        Raises:
            FileNotFoundError: if no image exists under any modality folder
                (for example a wrong ``data_dir`` or ``folder_id``).
        """
        # glob order is arbitrary; sort so the slice sequence is reproducible.
        paths_per_modality = [
            sorted(glob.glob(os.path.join(folder_name, modality, '*')), key=self._natural_key)
            for modality in self.MODALITIES
        ]
        all_paths = [p for paths in paths_per_modality for p in paths]
        if not all_paths:
            raise FileNotFoundError('no images found under %r for modalities %s'
                                    % (folder_name, ', '.join(self.MODALITIES)))
        all_images = preprocess_input(np.array(self._load_images(all_paths)))
        if self.verbose:
            print(all_images.shape)
        embeds = np.concatenate([
            np.asarray(self.predictor.predict_on_batch(all_images[i:i + self.predict_batch_size]))
            for i in range(0, len(all_images), self.predict_batch_size)
        ])
        # Cut the stacked embeddings back into per-modality runs, in MODALITIES order.
        bounds = np.cumsum([len(paths) for paths in paths_per_modality])[:-1]
        flair_embed, t1w_embed, t1wce_embed, t2w_embed = np.split(embeds, bounds)
        return flair_embed, t1w_embed, t1wce_embed, t2w_embed

    def pad_and_mask(self, embed_vec):
        """Pad (or truncate) a slice sequence to ``max_len`` and build its mask.

        Args:
            embed_vec: array-like of shape ``(n_slices, embed_dim)``.

        Returns:
            ``(padded, mask_matrix)``:
            - ``padded`` has shape ``(max_len, embed_dim)``, with zero rows after
              the real slices.
            - ``mask_matrix`` is the ``(max_len, max_len)`` uint8 outer product of
              the 0/1 validity vector, i.e. 1 where both slices are real.
        """
        # Truncate instead of letting np.pad fail on a negative pad width.
        embed_vec = np.asarray(embed_vec)[:self.max_len]
        if self.verbose:
            print(embed_vec.shape)
        n_valid = len(embed_vec)
        pad = self.max_len - n_valid
        mask = np.pad(np.ones(n_valid, dtype=np.uint8), (0, pad), 'constant')
        embed_vec = np.pad(embed_vec, ((0, pad), (0, 0)), 'constant')
        return embed_vec, np.outer(mask, mask)

    def __data_generation_batch(self, batch):
        """Embed, pad and mask all four modalities for every folder in ``batch``.

        Args:
            batch: iterable of patient folder paths.

        Returns:
            ``(X1, X1_mask, X2, X2_mask, X3, X3_mask, X4, X4_mask)`` for FLAIR,
            T1w, T1wCE and T2w. Each X is ``(n, max_len, embed_dim)`` and each
            mask is ``(n, max_len, max_len)``, where ``n`` is the number of folders.
        """
        folders = list(batch)
        n = len(folders)
        n_mod = len(self.MODALITIES)
        # np.zeros (not np.empty) and sized by the real batch: no uninitialised rows.
        embeds = [np.zeros((n, self.max_len, self.embed_dim)) for _ in range(n_mod)]
        masks = [np.zeros((n, self.max_len, self.max_len)) for _ in range(n_mod)]
        for i, folder in enumerate(folders):
            for m, modality_embed in enumerate(self.__data_generation(folder)):
                embeds[m][i], masks[m][i] = self.pad_and_mask(modality_embed)
        X1, X2, X3, X4 = embeds
        X1_mask, X2_mask, X3_mask, X4_mask = masks
        return X1, X1_mask, X2, X2_mask, X3, X3_mask, X4, X4_mask

In [30]:
# Build the generator: reads result.csv, shuffles its rows, and constructs the
# ResNet50 (ImageNet weights) + global-average-pooling embedding model.
datagen = DataGenerator(csv_path='result.csv')

In [31]:
# Layer table of the embedding model: headless ResNet50 followed by GlobalAveragePooling2D.
datagen.predictor.summary()

Model: "model_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_6 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 262, 262, 3)  0           input_6[0][0]                    
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, 128, 128, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, 128, 128, 64) 256         conv1_conv[0][0]                 
____________________________________________________________________________________________

In [32]:
# Build the first batch by embedding every slice of each patient with ResNet50.
# This is slow: the recorded run of this cell was interrupted (KeyboardInterrupt).
(X1, X1_mask,
 X2, X2_mask,
 X3, X3_mask,
 X4, X4_mask,
 y) = datagen[0]

KeyboardInterrupt: 

In [None]:
# NOTE(review): stray, never-executed scratch cell that only evaluates a tuple; its
# meaning (perhaps an array shape) is not recorded. Remove it or explain it.
100,512,512,3 