In [None]:
# basics
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import os
from os.path import join
import sys
from sklearn.model_selection import train_test_split
# pytorch model

import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from torch.utils.data import DataLoader, Dataset

# pyTorch Lightning

import pytorch_lightning as pl
from torchmetrics import Accuracy

# image processing

from skimage.io import imread
from scipy.ndimage import zoom # image resizing 3D
from skimage.transform import resize
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut


# Importing Data

In [None]:
# Create the extraction target folder (relative to the working directory).
! mkdir my-train-data
import zipfile
# The archive has no .zip extension but is opened as a zip file; its contents
# (the .pt files loaded later by PTDataset) are unpacked into my-train-data.
with zipfile.ZipFile('../input/rsnabraintumorclassification-64-256-256/1MS8S5qFadxAqPCrd0MtKts4ciGH5W1L-', 'r') as zip_ref:
    zip_ref.extractall('my-train-data')
    

In [None]:
# Delete the extracted files of cases 00109, 00123 and 00709; the same three ids
# are dropped from the label table in the next cell so files and labels stay in sync.
! rm ./my-train-data/00109.pt ./my-train-data/00123.pt ./my-train-data/00709.pt

# **Reading filenames**

In [None]:
INPUT_FOLDER = '../input/rsna-miccai-brain-tumor-radiogenomic-classification'
INPUT_FOLDER_PNG = '../input/rsna-miccai-png'

# Labels are sorted by patient id so that they line up, row by row, with the
# sorted list of training files built below.
labels_df = pd.read_csv(join(INPUT_FOLDER, 'train_labels.csv'))
labels_df = labels_df.sort_values('BraTS21ID')
patients_train = os.listdir('my-train-data')
patients_test = os.listdir(join(INPUT_FOLDER, 'test'))

# Erroneous examples mentioned in the competition discussion.
error_examples = ['00109', '00123', '00709']

# Their files were already deleted from my-train-data, so only the label rows
# need removing here. `.copy()` makes the result an independent frame, so that
# adding columns to it later does not trigger SettingWithCopyWarning.
labels_df = labels_df[
    ~labels_df.BraTS21ID.isin([int(example) for example in error_examples])
].copy()
patients_train.sort()

print(f'Number of train data : {len(patients_train)}\nNumber of test data : {len(patients_test)}')

In [None]:
# Positional assignment of a plain list: this relies on `patients_train`
# (sorted file names) and `labels_df` (sorted by BraTS21ID) having the same
# length and the same patient order.
labels_df['patient_folder'] = patients_train

In [None]:
# Preview the first rows to eyeball that ids and file names match.
labels_df.head()

# **Split Data to train and validation**

In [None]:
# Pair every training file with its MGMT label.
train_info = pd.DataFrame({
    'patient_id': patients_train,
    'patient_label': labels_df.MGMT_value,
})

# Hold out 18% of the patients for validation; the split is stratified on the
# label and uses a fixed seed so it is reproducible.
train_info, val_info = train_test_split(
    train_info,
    test_size=0.18,
    stratify=train_info.patient_label,
    random_state=42,
)
print(train_info.head(), '------', val_info.head(), sep='\n')

# **Constants & Enums**

In [None]:
# Target volume geometry used by the preprocessing functions:
# number of slices, then [height, width] of each slice.
IMAGE_DEPTH = 64
IMAGE_SIZE = [256, 256]
IMAGE_DIMS = [IMAGE_DEPTH] + IMAGE_SIZE
# Batch size shared by all DataLoaders.
BATCH_SIZE = 4
# Not referenced anywhere else in the visible notebook.
GET_ITEM_ACCESS = 0
# MedicalNet ResNet-34 weights; only used by the disabled ResNet variant of build_model.
PRETRAINED_PATH = '../input/medicalnet-pretrained-weights/resnet_34.pth'
class MRITypes:
    """String names of the MRI sequence types.

    The values are joined onto a patient folder path to locate the scans of
    that type, so they must match the sub-folder names of the input data.
    """
    t1w = 'T1w'
    tw1ce = 'T1wCE'
    t2w = 'T2w'
    flair = 'FLAIR'
    
def get_types():
    """Return the MRI types used by the datasets and the model (currently only T1wCE)."""
    active_types = [MRITypes.tw1ce]
    return active_types
    
def get_index(mri_type : str):
    """Return the position of ``mri_type`` within the list from ``get_types()``.

    ``mri_type`` is one of the string constants defined on ``MRITypes``.
    ``list.index`` raises ``ValueError`` when the type is not currently active.
    """
    return get_types().index(mri_type)

# **Load image functions**

In [None]:
# load scans in a folder
def load_scan(path):
    """Read every DICOM file in ``path`` and return the datasets in slice order.

    Args:
        path: folder containing the DICOM files of one scan.

    Returns:
        List of pydicom datasets sorted by the third component of their
        ``ImagePositionPatient`` attribute.
    """
    # `dcmread` is the supported reader; `read_file` is only a deprecated alias
    # of it and is no longer available in recent pydicom releases.
    slices = [pydicom.dcmread(join(path, slice_file))
              for slice_file in os.listdir(path)]

    slices.sort(key=lambda ds: float(ds.ImagePositionPatient[2]))

    return slices

def load_scan_png(path):
    """Load all image files in ``path`` as grayscale arrays, ordered by slice number.

    The slice number is parsed from each file name by dropping its first six
    and last four characters (presumably an ``Image-<n>.png`` naming pattern;
    TODO confirm against the PNG dataset).
    """
    def slice_number(filename):
        return int(filename[6:-4])

    slices = []
    for name in sorted(os.listdir(path), key=slice_number):
        slices.append(imread(join(path, name), as_gray=True))
    return slices

# Making the train dataset class

The following functions preprocess the images of each source format (PNG and DICOM).


In [None]:
#-------------------------------------------------------------------------------------------------------------------------
# preprocessing png
def preprocess_png_image(image):
    """Convert a sequence of 2D PNG slices into a float32 volume tensor.

    Steps: resize every slice to IMAGE_SIZE, rescale the depth axis by
    IMAGE_DEPTH / current depth, convert to a tensor and divide by 255.

    NOTE(review): skimage's ``resize`` may already rescale integer input to
    [0, 1]; if so, the division by 255 scales the data twice. Confirm.
    """
    # resize each slice to IMAGE_SIZE
    resized = np.array([resize(slice_2d, IMAGE_SIZE) for slice_2d in image])

    # interpolate along the first axis only, leaving height and width untouched
    depth_factor = IMAGE_DEPTH / resized.shape[0]
    volume = zoom(resized, (depth_factor, 1, 1))

    # tensor conversion followed by the 255 scaling
    return torch.tensor(volume, dtype=torch.float32) / 255
#-------------------------------------------------------------------------------------------------------------------------
# preprocessing function

def preprocess_dicom_image(image):
    """Convert a list of DICOM slices into a normalized float32 volume tensor.

    Steps: drop all-black slices, apply the VOI LUT, invert MONOCHROME1 slices,
    resize every slice to IMAGE_SIZE, rescale the depth axis by
    IMAGE_DEPTH / current depth, then min-max normalize each slice to [0, 1].

    Args:
        image: list of pydicom datasets (one per slice), e.g. from ``load_scan``.

    Returns:
        A ``torch.float32`` tensor holding the processed volume.

    Raises:
        ValueError: if every slice of the scan is completely black.
    """
    # Filter the datasets themselves (not just their pixels) so that each
    # pixel array stays paired with its own header in the steps below.
    kept = [ds for ds in image if np.any(ds.pixel_array != 0)]
    if not kept:
        raise ValueError('Scan contains no non-black slices')

    slices = []
    for ds in kept:
        # apply voi lut (a filter that makes it easier to spot things)
        pixels = apply_voi_lut(ds.pixel_array, ds)

        # MONOCHROME1 images are stored inverted; flip them back
        if ds.PhotometricInterpretation == "MONOCHROME1":
            pixels = np.amax(pixels) - pixels

        slices.append(resize(pixels, IMAGE_SIZE))

    volume = np.array(slices)

    # set image depth to IMAGE_DEPTH
    volume = zoom(volume, (IMAGE_DEPTH / volume.shape[0], 1, 1))

    # Per-slice min-max normalization. A constant slice has a zero range;
    # it is mapped to all zeros instead of producing NaN through 0 / 0.
    volume = volume - volume.min(axis=(1, 2), keepdims=True)
    peak = volume.max(axis=(1, 2), keepdims=True)
    volume = np.divide(volume, peak, out=np.zeros_like(volume), where=peak > 0)

    # A single ndarray converts to a tensor far faster than a list of arrays.
    return torch.tensor(volume, dtype=torch.float32)
#-------------------------------------------------------------------------------------------------------------------------


## **Dataset for my already preprocessed data**

As preprocessing took a long time, I preprocessed all the data, and uploaded it as a Kaggle Dataset to train faster.

In [None]:
# dataset definition
class TumorDataset(Dataset):
    """Dataset that loads and preprocesses the raw scans of a patient on access.

    Each item is read from ``<input_folder>/<split>/<patient_id>/<scan_type>``
    for every type returned by ``get_types()``, using ``load_function`` to read
    the scan and ``transform`` to preprocess it.
    """

    def __init__(self, patient_ids, patient_labels, transform, load_function, input_folder, split):
        super().__init__()
        self.patient_ids = patient_ids
        # labels are discarded for the test split
        self.patient_labels = None if split == 'test' else patient_labels
        self.transform = transform
        self.load_function = load_function
        self.input_folder = input_folder
        self.split = split

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        """Return ``(scans, label)`` for the 'train' split, otherwise only ``scans``."""
        is_train = self.split == 'train'
        current_label = self.patient_labels[idx] if is_train else None

        patient_folder = join(self.input_folder, self.split, self.patient_ids[idx])

        # One volume per MRI type: load, preprocess, then add a leading
        # channel dimension to the 3D image.
        patient_scans = [
            self.transform(self.load_function(join(patient_folder, scan_type))).unsqueeze(0)
            for scan_type in get_types()
        ]
        stacked_scans = torch.stack(patient_scans)

        if not is_train:
            return stacked_scans
        return (
            stacked_scans,
            torch.tensor(current_label, dtype=torch.float32)
        )

#---------------------------------------------------------------------
class PTDataset(Dataset):
    """Dataset over already-preprocessed scans stored as ``.pt`` tensor files.

    Args:
        x: sequence of tensor file names, relative to ``data_dir``.
        y: sequence of labels, one per file.
        data_dir: folder holding the files; defaults to the folder the
            notebook extracts its training archive into.
    """

    def __init__(self, x, y, data_dir='my-train-data'):
        self.x = x
        self.y = y
        self.data_dir = data_dir

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(scan, label)``: the stored tensor as float32 divided by 255, and a float32 label."""
        filename = join(self.data_dir, self.x[idx])
        # NOTE: torch.load unpickles the file; only point this at trusted data.
        patient_scans = torch.load(filename)
        patient_scans = patient_scans.type(torch.FloatTensor) / 255
        return (
            patient_scans,
            torch.tensor(self.y[idx], dtype=torch.float32)
        )

In [None]:
def load_dataset(patient_ids, patient_labels, image_type, split):
    """Build a ``TumorDataset`` wired up for the given image format.

    Args:
        patient_ids: patient folder names.
        patient_labels: labels matching ``patient_ids`` (may be None for test data).
        image_type: 'png', 'dicom' or 'dcm' (case-insensitive).
        split: sub-folder of the input folder to read from, e.g. 'train' or 'test'.

    Raises:
        ValueError: if ``image_type`` is not one of the supported formats.
            Without this check the dataset would be built with no loader or
            transform and only fail later, on first item access.
    """
    kind = image_type.lower()
    if kind == 'png':
        transform = preprocess_png_image
        load_fn = load_scan_png
        input_folder = INPUT_FOLDER_PNG
    elif kind in ('dicom', 'dcm'):
        transform = preprocess_dicom_image
        load_fn = load_scan
        input_folder = INPUT_FOLDER
    else:
        raise ValueError(
            f"Unsupported image_type {image_type!r}; expected 'png', 'dicom' or 'dcm'"
        )

    return TumorDataset(patient_ids, patient_labels, transform, load_fn, input_folder, split)

In [None]:
# Train Data
train_dataset = PTDataset(train_info.patient_id.values,
                             train_info.patient_label.values)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=1)

# Validation Data
val_dataset = PTDataset(val_info.patient_id.values,
                           val_info.patient_label.values)

val_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=1)

# Test Data
test_dataset = load_dataset(patients_test,
                           None,
                           image_type='dcm', 
                           split='test')

test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=1)

In [None]:
def plot_slices(image):
    """Draw the first 64 slices of a 3D volume on an 8x8 grid of grayscale images.

    ``image`` is indexed as ``image[i, :, :]`` for i in 0..63, so it needs at
    least 64 entries along its first axis. Nothing is returned.
    """
    fig, axes = plt.subplots(8, 8, figsize=(50,50))
    # axes.flat walks the grid row by row, matching the slice order
    for i, ax in enumerate(axes.flat):
        ax.imshow(image[i, :, :], cmap='gray')

## **Trying out dataloaders**

In [None]:
# Pull one batch from the test loader (scans only, no labels), report its
# shape and plot the first MRI type of the first sample.
x = next(iter(test_dataloader))
print(f'size of x : {x.shape}')
print(f'size of x[:,0] : {x[:,0].shape}')
x_show = x[0, 0].squeeze().numpy()
t = plot_slices(x_show)

In [None]:
# Pull one batch from the train loader (scans and labels), report its shape
# and plot the first MRI type of the second sample.
x, y = next(iter(train_dataloader))
print(f'size of x : {x.shape}')
print(f'size of x[:,0] : {x[:,0].shape}')
x_show = x[1, 0].squeeze()
t = plot_slices(x_show.numpy())

# Adding the libraries that provide the 3D models

In [None]:
# Source folders of the attached library datasets and the locations they are
# copied to by the shell commands in the next cell.
input_monaipath = '../input/monai-v060-deep-learning-in-healthcare-imaging'
monaipath = '/kaggle/monai'
input_medicalnet_path = '../input/medicalnet'
medicalnet_path = '/kaggle/medicalnet'

In [None]:
# Copy both libraries out of the input folders; {name} interpolates the Python
# variables defined in the previous cell into the shell commands.
! mkdir -p {monaipath}
! cp -r {input_monaipath}/* {monaipath}
! mkdir -p {medicalnet_path}
! cp -r {input_medicalnet_path}/* {medicalnet_path}

In [None]:
sys.path.append(monaipath)
# sys.path.append(medicalnet_path)


from monai.networks.nets.efficientnet import EfficientNetBN
# from models.resnet import resnet18 , resnet34

In [None]:
def remove_last_n_layers(model, n):
    """Return a new ``nn.Sequential`` made of all but the last ``n`` children of ``model``.

    Args:
        model: module whose direct children (``model.children()``) are reused.
        n: number of trailing children to drop; ``0`` keeps all of them and a
            value larger than the number of children yields an empty container.

    Returns:
        ``nn.Sequential`` wrapping the remaining child modules (shared with
        ``model``, not copied).

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f'n must be non-negative, got {n}')

    components_list = list(model.children())

    # An explicit end index is needed: `components_list[:-0]` would be empty.
    keep = max(len(components_list) - n, 0)

    return nn.Sequential(*components_list[:keep])

def remove_last_2_layers(model):
    """Convenience wrapper that drops the final two child modules of ``model``."""
    return remove_last_n_layers(model, n=2)

In [None]:
def build_model():
    """Build one feature extractor for a single MRI type.

    The network is a 3D EfficientNet-B1 with one input channel and no
    pretrained weights, with its last two child modules removed.
    """
    # A commented-out alternative based on a MedicalNet ResNet-34 with weights
    # loaded from PRETRAINED_PATH previously lived here; it was not in use.
    backbone = EfficientNetBN(
        'efficientnet-b1',
        spatial_dims=3,
        in_channels=1,
        num_classes=1,
        pretrained=False,
    )
    return remove_last_2_layers(backbone)

class AllTypesNet(pl.LightningModule):
    """Binary classifier that combines one 3D feature extractor per MRI type.

    Every type returned by ``get_types()`` gets its own backbone from
    ``build_model()``; the backbone outputs are concatenated and mapped to a
    single logit by a linear layer.

    Args:
        num_features: number of features each backbone outputs per sample.
        lr: learning rate handed to the AdamW optimizer.
    """

    def __init__(self, num_features, lr=0.001):
        super().__init__()

        self.lr = lr
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

        self.classifiers = nn.ModuleList([build_model()
                                          for _ in get_types()])

        self.fc = nn.Linear(in_features=len(get_types()) * num_features, out_features=1)

    def forward(self, x):
        """Return one logit per sample; ``x[:, i]`` is the batch of scans for MRI type ``i``."""
        # flatten(start_dim=1) drops the trailing singleton dimensions like
        # squeeze() did, but keeps the batch dimension when the batch holds a
        # single sample.
        pred_list = [classifier(x[:, i]).flatten(start_dim=1)
                     for i, classifier in enumerate(self.classifiers)]
        pred = self.fc(torch.cat(pred_list, -1))

        return pred

    def training_step(self, batch, batch_idx):
        """Compute the BCE-with-logits loss for a batch and update the train accuracy."""
        x, y = batch

        y_pred = self(x).view(-1)

        assert not bool(torch.any(torch.isnan(y_pred)).item()), f'Model outputs nan on Epoch {self.current_epoch}, batch {batch_idx}' # assert error msg

        loss = F.binary_cross_entropy_with_logits(y_pred, y)

        # .long() converts on whatever device y lives on; the previous
        # torch.cuda.LongTensor cast only worked on a GPU.
        self.train_acc(torch.sigmoid(y_pred), y.long())

        return {'loss': loss, 'predictions': None }

    def validation_step(self, batch, batch_idx):
        """Compute the BCE-with-logits loss for a batch and update the validation accuracy."""
        x, y = batch

        y_pred = self(x).view(-1)

        loss = F.binary_cross_entropy_with_logits(y_pred, y)

        self.val_acc(torch.sigmoid(y_pred), y.long())

        return {'loss': loss, 'predictions': None }

    def training_epoch_end(self, training_step_outputs):
        """Print the accumulated train accuracy, then reset the metric for the next epoch."""
        print(f'Epoch {self.current_epoch} train accuracy : {round(self.train_acc.compute().item() * 100, 2)}%')
        self.train_acc.reset()

    def validation_epoch_end(self, validation_step_outputs):
        """Print the accumulated validation accuracy, then reset the metric for the next epoch."""
        print(f'Epoch {self.current_epoch} val accuracy : {round(self.val_acc.compute().item() * 100, 2)}%')
        self.val_acc.reset()

    def configure_optimizers(self):
        """Create an AdamW optimizer over all parameters using the current ``self.lr``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


## **Check Model size to make sure there's enough memory in GPU**

In [None]:
# Size of one backbone's parameters and buffers in bytes, converted to
# megabytes and multiplied by the number of backbones (one per MRI type).
model = build_model()
mem_params = sum(param.numel() * param.element_size() for param in model.parameters())
mem_bufs = sum(buf.numel() * buf.element_size() for buf in model.buffers())
mem = (mem_params + mem_bufs) / 2**20 # in Megabytes
mem * len(get_types())

# Check that there are no NaNs in the train dataloader

In [None]:
from tqdm import tqdm
# Walk over every training batch once and stop at the first one whose scans
# contain a NaN value.
for data in tqdm(train_dataloader):
    x, y = data

    assert not torch.isnan(x).any().item()
    

In [None]:
# Input width of the `_fc` layer of an unmodified EfficientNet-B1 with the same
# settings as build_model(); used below as the per-backbone feature count.
num_features = EfficientNetBN('efficientnet-b1', spatial_dims=3, in_channels=1,
                           num_classes=1, pretrained=False)._fc.in_features
num_features

In [None]:
# Stage 1: fresh model, learning rate 0.0003, up to 3 epochs on one GPU.
model = AllTypesNet(num_features, 0.0003)

trainer = pl.Trainer(gpus=1,
                     max_epochs=3, log_every_n_steps=75)
trainer.fit(model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader)

In [None]:
# Stage 2: keep training the same model with a third of the initial learning
# rate. A new Trainer is created; presumably fit() calls configure_optimizers
# again, which reads the updated `model.lr` (confirm against Lightning docs).
model.lr = 0.0003 / 3
trainer = pl.Trainer(gpus=1,
                     max_epochs=5, log_every_n_steps=75)
trainer.fit(model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader)

In [None]:
# Stage 3: same model again, a fifth of the initial learning rate, up to 10
# epochs with another new Trainer.
model.lr = 0.0003 / 5
trainer = pl.Trainer(gpus=1,
                     max_epochs=10, log_every_n_steps=75)
trainer.fit(model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader)

In [None]:
# Run the trained model over the test loader with the most recently created
# `trainer`; the result is iterated below as one collection of outputs per batch.
y_pred = trainer.predict(model, test_dataloader)

In [None]:
# Flatten the per-batch outputs into one list of Python floats (raw logits),
# keeping the order in which the test loader produced them.
preds = [element.item() for pred_list in y_pred for element in pred_list]

In [None]:
# Applying sigmoid
# Convert the raw logits to probabilities with the logistic function
# 1 / (1 + exp(-z)); `preds` becomes a NumPy array.
preds = 1 / ( 1 + np.exp(-1 * np.array(preds)))

In [None]:
# Remove everything matched by ./* in the working directory (this includes the
# extracted my-train-data folder) before the submission file is written.
! rm ./* -rf

In [None]:
# `patients_test` is the same list the test dataset was built from, so its
# order matches the order of `preds`.
submission = pd.DataFrame({'BraTS21ID': patients_test,
                           'MGMT_value': preds})
# Written without the index column: only BraTS21ID and MGMT_value are saved.
submission.to_csv('submission.csv', index=False)