In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
import wandb
import glob
import re 
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import tensorflow as tf
import imageio
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import collections
import json
from tqdm import tqdm
%matplotlib inline
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        os.path.join(dirname, filename)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
wandb.login()

In [None]:
CONFIG = {
    'IMG_SIZE':224,
    'competition':'rsna-miccai-brain',
    '_wandb_kernel':'rooneyy'
}

In [None]:
filename = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv'
train_df = pd.read_csv(filename)
train_df.head()

In [None]:
print(f'Number of rows: {len(train_df)}')

In [None]:
fig, ax = plt.subplots(figsize=(10,6))
sns.countplot(y='MGMT_value', data=train_df);
ax.set_title('Distribution of labels', fontsize=15, weight='heavy')

In [None]:
filenames = glob.glob('/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train/*/*/*')
print(f'Total number of files: {len(filenames)}')

In [None]:
# Bucket every DICOM path by the MRI sequence folder it lives in.
label_dict = {scan_type: [] for scan_type in ('FLAIR', 'T1w', 'T1wCE', 'T2w')}

for filename in tqdm(filenames):
    # The glob pattern yields .../<patient>/<scan type>/<file>, so the
    # second-to-last path component names the sequence.
    scan = filename.split('/')[-2]
    if scan in label_dict:
        label_dict[scan].append(filename)

# Report the number of files per sequence, in a fixed order.
scan_counts = [len(label_dict[name]) for name in ('FLAIR', 'T1w', 'T1wCE', 'T2w')]
print('Size of FLAIR scan: {}\nT1w scan: {}\nT1wCE scan: {}\nT2w scan: {}'.format(*scan_counts))

In [None]:
# Log how many DICOM files each scan type contributes as a W&B bar chart.
run = wandb.init(project='brain-tumor-wizz', config=CONFIG)
try:
    # Derive the counts from label_dict rather than hard-coding them, so the
    # chart always matches the files that were actually found on disk.
    data = [[scan_type, len(paths)] for scan_type, paths in label_dict.items()]
    table = wandb.Table(data=data, columns=['Scan type','Size of Files'])
    wandb.log({'my_bar_chart_id':wandb.plot.bar(table, 'Scan type', 'Size of Files', title='Scan types vs Number of Dicom Files')})
finally:
    # Always close the run, even if building or logging the chart fails.
    run.finish()

## **Read DICOM Files**

In [None]:
def ReadMRI(path, voi_lut=True, fix_monochrome=True):
    """Read one DICOM slice and return it as an 8-bit grayscale array.

    Parameters
    ----------
    path : str
        Path to a DICOM file.
    voi_lut : bool
        When True, pass the raw pixels through ``apply_voi_lut`` using the
        file's own metadata; otherwise use the raw ``pixel_array``.
    fix_monochrome : bool
        When True and the file's ``PhotometricInterpretation`` is
        ``'MONOCHROME1'``, flip the intensities so the result follows the
        same polarity as the other images.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array scaled to the 0-255 range (all zeros for a constant
        image).
    """
    # dcmread is the current reader; read_file is a deprecated alias for it.
    dicom = pydicom.dcmread(path)

    pixels = dicom.pixel_array
    data = apply_voi_lut(pixels, dicom) if voi_lut else pixels

    # Previously this branch only rescaled and never inverted, so the flag
    # did not actually "fix" MONOCHROME1 images.
    if fix_monochrome and dicom.PhotometricInterpretation == 'MONOCHROME1':
        data = np.amax(data) - data

    # Scale every image the same way so callers always get uint8 in 0-255,
    # regardless of the photometric interpretation.
    data = data - np.min(data)
    peak = np.max(data)
    if peak != 0:  # guard against dividing by zero on blank slices
        data = data / peak
    data = (data * 255).astype(np.uint8)

    return data

In [None]:
path = filenames[32346]
data = ReadMRI(path)
plt.imshow(data, cmap='gray')

In [None]:
def sorted_nicely(l):
    """Return the strings in ``l`` in natural ("human") order.

    Each string is cut into alternating text and digit runs; digit runs are
    compared as integers, so ``'Image-2'`` sorts before ``'Image-10'``.
    """
    def natural_key(item):
        # The capturing group keeps the digit runs in the split result.
        pieces = re.split('([0-9]+)', item)
        return [int(piece) if piece.isdigit() else piece for piece in pieces]

    return sorted(l, key=natural_key)

In [None]:
def get_patient_id(patient_id):
    """Return ``patient_id`` as a zero-padded, five-character folder name.

    Uses ``str.zfill(5)``, the same formatting the other cells in this
    notebook apply to ``BraTS21ID``. IDs that already have five or more
    digits are returned unchanged (the earlier branch chain prepended an
    extra ``'0'`` to them).
    """
    return str(patient_id).zfill(5)

In [None]:
# Keep only the rows labelled MGMT_value == 1.
train_df_1 = train_df[train_df.MGMT_value == 1].reset_index(drop=True)
print(f'Number of patients with brain tumor: {len(train_df_1)}')

IMG_2_log = 20
# Seed the draw so the same patients are logged on every run, and never ask
# for more rows than exist (DataFrame.sample raises in that case).
train_df_1_sampled = train_df_1.sample(
    n=min(IMG_2_log, len(train_df_1)),
    random_state=42,
).reset_index(drop=True)
print(f'Number of sampled patients: {len(train_df_1_sampled)}') 

# Publish the sampled rows as a W&B table.
sampled_data_at = wandb.Table(dataframe=train_df_1_sampled)
run = wandb.init(project='brain-tumor-viz(Sampled Patients)', config=CONFIG)
try:
    wandb.log({'Sampled DataFrame': sampled_data_at})
finally:
    # Close the run even if logging fails.
    run.finish()

In [None]:
# For each sampled patient, open one W&B run and log every slice of every
# available scan type in natural filename order.
for i in tqdm(range(len(train_df_1_sampled))):
    ID = train_df_1_sampled.BraTS21ID[i]
    patient_id = get_patient_id(ID)
    
    run = wandb.init(project='brain-tumor-viz(Animate MRI)', config=CONFIG, name=f'{patient_id}')
    try:
        for key in label_dict.keys():
            scan_dir = f'/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train/{patient_id}/{key}'
            if not os.path.isdir(scan_dir):
                # This patient has no folder for this scan type.
                continue
            for filename in sorted_nicely(os.listdir(scan_dir)):
                mri_data = ReadMRI(f'{scan_dir}/{filename}')
                wandb.log({f'{key}': [wandb.Image(mri_data)]})
    finally:
        # Close the run even when a DICOM file fails to load, so the next
        # patient starts from a clean run.
        run.finish()

In [None]:
def load_dicom(path):
    """Load one DICOM file and return its pixels as a ``uint8`` array.

    The raw ``pixel_array`` is shifted so its minimum is 0, divided by its
    maximum (skipped when the image is constant) and scaled to 0-255.
    """
    # dcmread is the current reader; read_file is a deprecated alias for it.
    dicom = pydicom.dcmread(path)
    data = dicom.pixel_array
    data = data - np.min(data)
    peak = np.amax(data)
    if peak != 0:  # a blank slice stays all zeros instead of dividing by 0
        data = data / peak
    data = (data * 255).astype(np.uint8)
    
    return data

In [None]:
def visualise_sample(brats21id, slice_i, mgmt_value, types=('FLAIR','T1w','T1wCE','T2w')):
    """Show one slice per scan type for a patient, side by side.

    Parameters
    ----------
    brats21id : int or str
        Patient identifier; zero-padded to five characters to form the folder name.
    slice_i : float
        Relative position inside each series (0.5 picks the middle slice).
        Positions at or past the end are clamped to the last slice.
    mgmt_value : any
        Label shown in the figure title.
    types : sequence of str
        Scan type sub-folders to display; one subplot column is created per entry.
    """
    plt.figure(figsize=(10,6))
    patient_path = os.path.join('/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train/', str(brats21id).zfill(5))
    n_types = len(types)
    
    for i, t in enumerate(types, 1):
        # Order slices by the integer after the last '-' in the file name.
        t_paths = sorted(glob.glob(os.path.join(patient_path, t, '*')), key = lambda x: int(x[:-4].split('-')[-1]))
        # Size the grid from `types` instead of assuming four columns.
        plt.subplot(1, n_types, i)
        plt.title(f'{t}', fontsize=16)
        plt.axis('off')
        if not t_paths:
            # Nothing on disk for this scan type: leave the panel empty.
            continue
        # Clamp so slice_i == 1.0 selects the last slice rather than overflowing.
        slice_idx = min(int(len(t_paths) * slice_i), len(t_paths) - 1)
        data = load_dicom(t_paths[slice_idx])
        plt.imshow(data, cmap='gray');
        
    plt.suptitle(f'MGMT value: {mgmt_value}', fontsize=14)
    plt.show()

In [None]:
# Show the middle slice of each scan type for up to 10 distinct random patients.
# replace=False prevents the same patient from being drawn twice.
for i in np.random.choice(len(train_df), size=min(10, len(train_df)), replace=False):
    _brats21id = train_df.iloc[i].BraTS21ID
    _mgmt_value = train_df.iloc[i].MGMT_value
    visualise_sample(brats21id=_brats21id, mgmt_value=_mgmt_value, slice_i=0.5)

In [None]:
from matplotlib import animation, rc
rc('animation', html='jshtml')

def create_animations(ims):
    """Build a matplotlib animation that steps through the frames in ``ims``.

    The first frame is drawn once; each animation step swaps the displayed
    array for the next frame. Returns the ``FuncAnimation`` object.
    """
    figure = plt.figure(figsize=(10,6))
    plt.axis('off')
    artist = plt.imshow(ims[0], cmap='gray')

    def animate_func(i):
        # Reuse the single image artist and only replace its data.
        artist.set_array(ims[i])
        return [artist]

    frame_delay_ms = 1000 // 20
    return animation.FuncAnimation(
        figure,
        animate_func,
        frames=len(ims),
        interval=frame_delay_ms,
    )

In [None]:
def load_dicom_line(path):
    """Load every non-blank slice found directly under ``path``.

    Files are ordered by the integer that follows the last '-' in the name
    (after dropping the final four characters). Slices whose maximum pixel
    value is 0 after ``load_dicom`` are skipped. Returns a list of arrays.
    """
    def slice_number(file_path):
        return int(file_path[:-4].split('-')[-1])

    ordered_paths = sorted(glob.glob(os.path.join(path, '*')), key=slice_number)
    frames = (load_dicom(file_path) for file_path in ordered_paths)
    return [frame for frame in frames if frame.max() != 0]

In [None]:
path = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train/00234/T2w'

In [None]:
train_df[train_df.BraTS21ID == 234].MGMT_value.values

In [None]:
print('MGMT Value of patient:', train_df[train_df.BraTS21ID == 234].MGMT_value.values)
images = load_dicom_line(path)
create_animations(images)

In [None]:
print('MGMT value of patient:', train_df[train_df.BraTS21ID == 510].MGMT_value.values)
path = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train/00510/T2w'
images = load_dicom_line(path)
create_animations(images)

In [None]:
import time 

import torch 
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional 
from efficientnet_pytorch import EfficientNet
import cv2

from sklearn.model_selection import StratifiedKFold


In [None]:
def set_seed(seed):
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and, when available,
    every CUDA device), and switches cuDNN to deterministic mode.
    ``PYTHONHASHSEED`` is exported as well.
    """
    import random  # local import: only needed for seeding

    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # The cuDNN autotuner may pick different kernels between runs.
        torch.backends.cudnn.benchmark = False
        
        
set_seed(42)        

In [None]:
df = pd.read_csv('/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv')
df_train, df_valid = sk_model_selection.train_test_split(df, test_size=0.2, random_state=42, stratify = df['MGMT_value'])

In [None]:
class DataRetriever(torch_data.Dataset):
    """Dataset yielding one averaged 256x256 image per scan type for a patient.

    Parameters
    ----------
    paths : sequence
        Patient identifiers; each is zero-padded to five characters to build
        the patient's folder name.
    targets : sequence
        Label for each patient, aligned with ``paths``.

    Each item is a dict with ``'X'`` (float tensor with one channel per scan
    type, in the order FLAIR, T1w, T1wCE, T2w) and ``'y'`` (float scalar
    tensor holding the label).
    """

    def __init__(self, paths, targets):
        self.paths = paths
        self.targets = targets
        
    def __len__(self):
        return len(self.paths)
        
    def __getitem__(self, index):
        _id = self.paths[index]
        patient_path = f'/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train/{str(_id).zfill(5)}/'
        channels = []
        for t in ('FLAIR','T1w','T1wCE','T2w'):
            # Order slices by the integer after the last '-' in the file name.
            t_paths = sorted(glob.glob(os.path.join(patient_path, t, '*')), key = lambda x: int(x[:-4].split('-')[-1]))
            x = len(t_paths)
            if x == 0:
                # Fail loudly: averaging an empty list would yield NaN and
                # either poison the batch or break stacking further down.
                raise FileNotFoundError(f'No {t} slices found under {patient_path}')
            if x < 10:
                r = range(x)
            else:
                # Take evenly spaced slices, skipping the first and last stride.
                d = x // 10
                r = range(d, x - d, d)
            
            channel = [cv2.resize(load_dicom(t_paths[i]), (256,256)) / 255 for i in r]
            channels.append(np.mean(channel, axis=0))
        
        y = torch.tensor(self.targets[index], dtype = torch.float)
        
        # Stack into one ndarray first; building a tensor directly from a
        # list of ndarrays is a slow path in PyTorch.
        X = torch.from_numpy(np.stack(channels)).float()
        
        return {'X': X, 'y':y}

In [None]:
train_data_retriever = DataRetriever(df_train['BraTS21ID'].values, df_train['MGMT_value'].values)

valid_data_retriever = DataRetriever(df_valid['BraTS21ID'].values, df_valid['MGMT_value'].values)

In [None]:
# Fetch the sample once: every __getitem__ call re-reads and resizes the
# patient's DICOM slices, so indexing inside the loop repeated that work.
sample_X = train_data_retriever[100]['X'].numpy()

plt.figure(figsize=(10,10))
for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.imshow(sample_X[i], cmap='gray')

In [None]:
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        os.path.join(dirname, filename)

In [None]:
class Model(nn.Module):
    """EfficientNet-B0 backbone with a single-logit classification head.

    Parameters
    ----------
    checkpoint_path : str or None
        File holding a state dict for the backbone. Pass ``None`` to keep the
        randomly initialised weights. The default is the placeholder location
        from the original notebook, which is a directory and still has to be
        replaced with a real weights file.

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` is given but is not an existing file.
    """

    def __init__(self, checkpoint_path='/kaggle/input/nfnets/pytorch-image-models-master/'):  # to be completed
        super().__init__()
        self.net = EfficientNet.from_name('efficientnet-b0')
        # NOTE(review): DataRetriever in this notebook yields 4-channel inputs;
        # confirm the backbone's stem accepts that many channels.
        if checkpoint_path is not None:
            if not os.path.isfile(checkpoint_path):
                # torch.load on a directory fails with an unhelpful error.
                raise FileNotFoundError(
                    f'checkpoint_path must point to a weights file, got: {checkpoint_path}'
                )
            # Load on CPU first so a checkpoint saved on GPU also works on a
            # CPU-only session; the caller moves the model to its device.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            self.net.load_state_dict(checkpoint)
        n_features = self.net._fc.in_features
        # Replace the classifier with a single output (one logit).
        self.net._fc = nn.Linear(in_features=n_features, out_features=1, bias = True)
        
    def forward(self, x):
        """Return the backbone's output for batch ``x``."""
        out = self.net(x)
        return(out)

In [None]:
class LossMeter:
    """Running mean of the values passed to ``update``."""

    def __init__(self):
        self.avg = 0
        self.n = 0

    def update(self, val):
        """Fold ``val`` into the running mean and bump the sample count."""
        self.n += 1
        # Weight the previous mean by (n - 1) / n and the new value by 1 / n.
        previous_weight = (self.n - 1) / self.n
        self.avg = val / self.n + previous_weight * self.avg
        

class AccMeter:
    """Running accuracy over batches of labels and raw model outputs."""

    def __init__(self):
        self.avg = 0
        self.n = 0

    def update(self, y_true, y_pred):
        """Fold one batch into the running accuracy.

        ``y_true`` and ``y_pred`` are tensors; an output counts as a positive
        prediction when it is >= 0.
        """
        labels = y_true.cpu().numpy().astype(int)
        predicted_positive = y_pred.cpu().numpy() >= 0
        seen_before = self.n
        self.n = seen_before + len(labels)
        hits = np.sum(labels == predicted_positive)
        # Blend this batch's hits with the accuracy accumulated so far.
        self.avg = hits / self.n + seen_before / self.n * self.avg

In [None]:
class Trainer:
    """Minimal train/validate loop with checkpointing and early stopping.

    Parameters
    ----------
    model : torch.nn.Module
        Model to optimise; its output is squeezed on dim 1 before the loss.
    device : torch.device
        Device every batch is moved to.
    optimizer : torch.optim.Optimizer
        Optimiser stepping ``model``'s parameters.
    criterion : callable
        Called as ``criterion(outputs, targets)`` and must return a scalar loss.
    loss_meter, score_meter : callable
        Zero-argument factories (e.g. classes). A fresh meter is created per
        epoch; it must expose ``update(...)`` and ``avg``.
    """

    def __init__(
        self,
        model,
        device,
        optimizer,
        criterion,
        loss_meter,
        score_meter):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion
        self.loss_meter = loss_meter
        self.score_meter = score_meter
        
        self.best_valid_score = -np.inf
        self.n_patience = 0
        
        self.messages = {
            'epoch': '[Epoch {}: {}] loss: {:.5f}, score: {:.5f}, time: {} s',
            'checkpoint': 'The score improved from {:.5f} to {:.5f}. Save model to {}',
            'patience': "\nValid score didn't improve last {} epochs"
        }
        
    def fit(self, epochs, train_loader, valid_loader, save_path, patience):
        """Train for up to ``epochs`` epochs, keeping the best checkpoint.

        A checkpoint is written to ``save_path`` whenever the validation score
        beats the best seen so far. Training stops early after ``patience``
        consecutive epochs without improvement.

        Returns a list with one dict per completed epoch (keys: ``epoch``,
        ``train_loss``, ``train_score``, ``valid_loss``, ``valid_score``).
        """
        history = []
        for n_epoch in range(1, epochs + 1):
            self.info_message(f'EPOCH: {n_epoch}')
            
            train_loss, train_score, train_time = self.train_epoch(train_loader)
            valid_loss, valid_score, valid_time = self.valid_epoch(valid_loader)
            
            self.info_message(self.messages['epoch'], 'Train', n_epoch, train_loss, train_score, train_time)
            
            self.info_message(self.messages['epoch'], 'Valid', n_epoch, valid_loss, valid_score, valid_time)
            
            history.append({
                'epoch': n_epoch,
                'train_loss': train_loss,
                'train_score': train_score,
                'valid_loss': valid_loss,
                'valid_score': valid_score,
            })
            
            # Only checkpoint on a real improvement; with the unconditional
            # branch the patience counter could never increase.
            if self.best_valid_score < valid_score:
                self.info_message(self.messages['checkpoint'], self.best_valid_score, valid_score, save_path)
                self.best_valid_score = valid_score
                self.save_model(n_epoch, save_path)
                self.n_patience = 0
            else:
                self.n_patience += 1
                
            if self.n_patience >= patience:
                self.info_message(self.messages['patience'], patience)
                break
                
        return history
                
    def train_epoch (self, train_loader):
        """Run one optimisation pass; return (mean loss, mean score, seconds)."""
        self.model.train()
        t = time.time()
        train_loss = self.loss_meter()
        train_score = self.score_meter()
        
        for step, batch in enumerate(train_loader, 1):
            X = batch['X'].to(self.device)
            targets = batch['y'].to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(X).squeeze(1)
            
            loss = self.criterion(outputs, targets)
            loss.backward()
            
            train_loss.update(loss.detach().item())
            train_score.update(targets, outputs.detach())
            
            self.optimizer.step()
            
            _loss, _score = train_loss.avg, train_score.avg
            message = 'Train Step {}/{}, train_loss: {:.5f}, train_score: {:.5f}'
            # '\r' keeps the progress on a single, continuously rewritten line.
            self.info_message(message, step, len(train_loader), _loss, _score, end='\r')
            
        return train_loss.avg, train_score.avg, int(time.time() - t)
    
    
    def valid_epoch(self, valid_loader):
        """Evaluate without gradients; return (mean loss, mean score, seconds)."""
        self.model.eval()
        t = time.time()
        valid_loss = self.loss_meter()
        valid_score = self.score_meter()
        
        for step, batch in enumerate(valid_loader, 1):
            with torch.no_grad():
                X = batch['X'].to(self.device)
                targets = batch['y'].to(self.device)
                outputs = self.model(X).squeeze(1)
                loss = self.criterion(outputs, targets)
                
                valid_loss.update(loss.detach().item())
                valid_score.update(targets, outputs)
                
            _loss, _score = valid_loss.avg, valid_score.avg
            message = 'Valid Step {}/{}, valid_loss: {:.5f}, valid_score: {:.5f}'
            self.info_message(message, step, len(valid_loader), _loss, _score, end='\r')
            
        return valid_loss.avg, valid_score.avg, int(time.time() - t)
    
    def save_model(self, n_epoch, save_path):
        """Write model and optimiser state plus the best score to ``save_path``."""
        torch.save(
        {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_valid_score': self.best_valid_score,
            'n_epoch': n_epoch
        },
        save_path
        )
        
    @staticmethod
    def info_message(message, *args, end='\n'):
        """Print ``message`` formatted with ``args``."""
        print(message.format(*args), end=end)

In [None]:
# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Datasets for the two splits.
train_data_retreiver = DataRetriever(df_train['BraTS21ID'].values, df_train['MGMT_value'].values)
valid_data_retreiver = DataRetriever(df_valid['BraTS21ID'].values, df_valid['MGMT_value'].values)

# Both loaders share batch size and worker count; only shuffling differs.
loader_kwargs = {'batch_size': 8, 'num_workers': 8}
train_loader = torch_data.DataLoader(train_data_retreiver, shuffle=True, **loader_kwargs)
valid_loader = torch_data.DataLoader(valid_data_retreiver, shuffle=False, **loader_kwargs)

model = Model()
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch_functional.binary_cross_entropy_with_logits

# The meter classes are passed uninstantiated; Trainer creates one per epoch.
trainer = Trainer(
    model=model,
    device=device,
    optimizer=optimizer,
    criterion=criterion,
    loss_meter=LossMeter,
    score_meter=AccMeter,
)

history = trainer.fit(
    epochs=2,
    train_loader=train_loader,
    valid_loader=valid_loader,
    save_path='best_model-0.path',
    patience=100,
)

# Work in Progress . .