In [None]:
from IPython.core.magic import register_cell_magic
import os
from pathlib import Path


@register_cell_magic
def write_and_run(line, cell):
    """Cell magic: save the cell body to a file, then execute it.

    Usage::

        %%write_and_run [-a] path/to/file.py

    Args:
        line: Text following the magic name. The last whitespace-separated
            token is the target file. When the line is exactly ``-a <file>``
            the cell is appended to the file instead of overwriting it.
        cell: Source of the cell; written verbatim and then run.

    Raises:
        ValueError: If no target file is given on the magic line.
    """
    argz = line.split()
    if not argz:
        # Previously surfaced as an opaque IndexError from argz[-1].
        raise ValueError('usage: %%write_and_run [-a] <file>')
    file = argz[-1]
    mode = 'a' if len(argz) == 2 and argz[0] == '-a' else 'w'
    # Make sure the target directory exists so a relative path such as
    # "scripts/foo.py" works regardless of the current working directory.
    Path(file).parent.mkdir(parents=True, exist_ok=True)
    with open(file, mode) as f:
        f.write(cell)
    get_ipython().run_cell(cell)
    
# Output directories for saved scripts and model checkpoints.
# parents=True so this cell also works when /kaggle/working does not exist yet.
Path('/kaggle/working/scripts').mkdir(parents=True, exist_ok=True)
models_dir = Path('/kaggle/working/models')
models_dir.mkdir(parents=True, exist_ok=True)

In [None]:

import os
import json
import glob
import random
import collections

import numpy as np
import pandas as pd
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use('seaborn-talk')

import torch

from sklearn.metrics import roc_auc_score, roc_curve, auc

from torchvision import transforms


import time

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as F

from sklearn.model_selection import StratifiedKFold
import pytorch_lightning as pl
from transformers import DeiTFeatureExtractor, DeiTForImageClassification, AutoConfig
from pytorch_lightning.core.memory import ModelSummary
import sys
sys.path.append('../input/efficientnetpyttorch3d/EfficientNet-PyTorch-3D')
from efficientnet_pytorch_3d import EfficientNet3D
import joblib
import nibabel as nb
import tarfile
from pathlib import Path
!pip install --upgrade https://github.com/VincentStimper/mclahe/archive/numpy.zip

from mclahe import mclahe

In [None]:
class Config:
    """Notebook-wide settings shared by the data-loading and training cells."""

    # MRI modalities, in the order they are stacked by load_volume.
    mri_types = ['FL', 'T1', 'T1CE', 'T2']
    # Root directory of the NIfTI volumes (one sub-directory per split).
    data_dir = Path('/tmp/input/nifti')
    # Learning rate handed to the model's optimizer.
    lr = 2e-8
    # Global random seed.
    seed = 42

# Seed all random number generators (workers=True also seeds DataLoader workers).
# Called through the top-level pl.seed_everything alias: it is the same function,
# but the pl.utilities.seed module path is not available in newer Lightning releases.
pl.seed_everything(Config.seed, workers=True)

In [None]:
# tar = tarfile.open('../input/tumorbrainnifty/input.tar')
# tar.extractall('/tmp/input')
# tar.close()

In [None]:
from tqdm import tqdm
tqdm.pandas()

# Per-case fold assignments. IDs are read as strings so that zero-padded
# identifiers keep their padding.
kfold_df = pd.read_csv(
    "../input/braintumor-sampling/brain_tumor_kfold.csv",
    dtype={'BraTS21ID': str},
)
print(kfold_df.shape)

def filt_func(name):
    """Return True when case ``name`` has NIfTI files and none of them is blank.

    A case is rejected when nothing matches ``/tmp/input/nifti/train/<name>/*``
    or when the maximum voxel value of any matched file equals zero.
    """
    case_files = glob.glob(f'/tmp/input/nifti/train/{name}/*')
    if not case_files:
        return False

    # Stops at the first blank volume, so later files are not loaded.
    return all(
        not (nb.load(path).get_fdata().max() == 0)
        for path in case_files
    )


# Keep only the cases accepted by filt_func.
# filt_func loads every volume from disk, so this step is slow; progress_apply
# (registered by the tqdm.pandas() call earlier in this cell) shows a progress bar
# while producing the same result as a plain apply.
kfold_df['present'] = kfold_df.BraTS21ID.progress_apply(filt_func)
kfold_df = kfold_df[kfold_df.present]
print(kfold_df.shape)


In [None]:

def load_volume(dp_id, split="train"):
    """Load and standardise every MRI modality of one case.

    For each entry of ``Config.mri_types`` the file
    ``Config.data_dir / split / dp_id / '<type>_to_SRI.nii.gz'`` is read and
    normalised to zero mean and unit standard deviation over the whole volume.

    Args:
        dp_id: Case identifier (directory name inside the split folder).
        split: Sub-directory of ``Config.data_dir`` to read from.

    Returns:
        ``np.ndarray`` with the modalities stacked along a new leading axis,
        in ``Config.mri_types`` order.
    """
    mri_volumes = []
    for mri_type in Config.mri_types:
        path = Config.data_dir / split / dp_id / f'{mri_type}_to_SRI.nii.gz'
        volume = nb.load(path).get_fdata()

        volume = volume - volume.mean()
        std = volume.std()
        # A constant volume has zero std; dividing would turn the whole array
        # into NaN and silently poison the loss. Leave it centred (all zeros).
        if std > 0:
            volume = volume / std
        mri_volumes.append(volume)

    return np.stack(mri_volumes)

In [None]:
%%write_and_run scripts/dataset.py

from torchvision.transforms import ToTensor

class DataRetriever(Dataset):
    """Dataset yielding one multi-modal MRI volume (and optional label) per case.

    Args:
        paths: Sequence of case identifiers; one dataset item per entry.
        files_dir: Stored on the instance but not read by this class
            (volumes are located by ``load_volume``).
        targets: Optional sequence of labels aligned with ``paths``. When
            ``None`` the dataset is unlabeled and ``y`` is an empty list.
    """

    def __init__(self, paths, files_dir, targets=None):
        self.paths = paths
        self.targets = targets
        self.files_dir = files_dir
        self.split = 'val' if targets is None else 'train'

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``{"X": float tensor of the volume, "y": label tensor or []}``."""
        dp_id = self.paths[index]

        # NOTE(review): self.split is not forwarded, so volumes are always read
        # from load_volume's default split - confirm this is intended.
        volume = load_volume(dp_id)
        if self.targets is not None:
            y = torch.tensor(abs(self.targets[index]), dtype=torch.float)
        else:
            y = []

        # Convert straight to float32 instead of building a full-size
        # intermediate tensor in the array's original dtype first.
        return {"X": torch.tensor(volume, dtype=torch.float), 'y': y}
    
    


In [None]:
# Smoke test: build the dataset from every fold except fold 0 and load one item
# to check that the loading pipeline runs end to end.
train_df = kfold_df[kfold_df.fold!=0]
train_ds = DataRetriever(
    train_df["BraTS21ID"].values, 
    '../input/rsna-miccai-brain-tumor-radiogenomic-classification/train',
    targets=train_df["MGMT_value"].values
)

# Index 48 looks arbitrary (presumably a case that was being debugged - confirm).
X = train_ds[48]['X']


# for i in range(40, 100):
#     X = train_ds[i]['X']
#     print(i, X.mean().item(), X.sum().item())

# train_dl = DataLoader(
#     train_ds,
#     batch_size=1,
#     shuffle=True,
#     num_workers=1,
# )

# next(iter(train_dl))['X'].shape

In [None]:
# Visual check of case 00142: for each of its NIfTI files, print the volume
# shape and display the slice at index 77 of the last axis.
for file in glob.glob('/tmp/input/nifti/train/00142/*'):
    img = nb.load(file).get_fdata()
    print(img.shape)
#     print(np.nonzero(img))
    plt.imshow(img[:,:,77])
    plt.show()

In [None]:
# Visual check of the raw DICOM data of case 00142: for each series directory,
# show one image 10 positions past the middle of the series.
for mri_type in glob.glob('../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/00142/*'):
    files = glob.glob(f'{mri_type}/*')
    print(len(files))
    if not files:
        # Nothing to display for an empty series directory.
        continue
    # The key parses the integer between the last '-' and the following '.',
    # so files are ordered numerically rather than lexically.
    ordered = sorted(files, key=lambda n: int(n.split('-')[-1].split('.')[0]))
    # Clamp the index: for series with 20 or fewer files, middle + 10 would
    # run past the end and raise IndexError.
    mid_file = ordered[min(len(ordered) // 2 + 10, len(ordered) - 1)]
    print(mid_file)
    # dcmread is the current name of the reader; read_file is its deprecated alias.
    img = pydicom.dcmread(mid_file).pixel_array
    print(img.shape)
    plt.imshow(img)
    plt.show()

In [None]:
def calc_roc_auc(y_true, y_pred):
    """Return the ROC AUC of the scores ``y_pred`` against the labels ``y_true``."""
    score = roc_auc_score(y_true, y_pred)
    return score


In [None]:
%%write_and_run scripts/model.py


class Model(pl.LightningModule):
    """Lightning wrapper that trains ``net`` as a single-logit binary classifier.

    Args:
        net: Module mapping a batch ``X`` to logits; dimension 1 of its output
            is squeezed before the loss is computed.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, net, lr):
        super().__init__()
        self.net = net
        self.lr = lr

    def forward(self, x):
        return self.net(x)

    def _shared_step(self, batch):
        """Compute logits and BCE-with-logits loss for one ``{'X', 'y'}`` batch."""
        X, y = batch['X'], batch['y']
        y_hat = self(X).squeeze(1)
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        return {
            'loss': loss,
            'y': y,
            # Only consumed for metrics: detach so the autograd graph is not
            # kept alive by whoever stores the predictions.
            'y_hat': y_hat.detach(),
        }

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch)

    def validation_step(self, batch, batch_idx):
        # Deliberately no self.train() here: Lightning puts the module in eval
        # mode for validation. Forcing train mode would enable dropout and make
        # BatchNorm use (and update its running statistics from) validation
        # batches, so metrics would be noisy and depend on batch composition.
        return self._shared_step(batch)

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = None):
        # Same reasoning as validation_step: predictions must come from eval mode.
        return self(batch['X'])

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
    

In [None]:
from time import time
class Metrics:
    """Plain container for the values collected during one phase (train or validation)."""

    def __init__(self):
        # Raw per-batch buffers.
        self.y_hat_batches = list()
        self.y_batches = list()
        self.losses = list()

        # Aggregated values, one entry appended per summary.
        self.roc_auc_list = list()
        self.reduced_losses = list()

        # Epoch start timestamps; initialised to None.
        self.validation_epoch_start_time = None
        self.train_epoch_start_time = None


class MetricsCallback(pl.callbacks.Callback):
    """Collects batch outputs and reports loss / ROC AUC after each validation run.

    The step outputs are expected to be dicts with ``'loss'``, ``'y'`` and
    ``'y_hat'`` entries. At the end of every validation epoch the buffered
    train and validation values are averaged, printed, stored on
    ``train_metrics`` / ``validation_metrics`` and then cleared. The validation
    ROC AUC is logged under ``'roc_auc'`` and its best value is kept in
    ``best_validation_roc_auc``.
    """

    def __init__(self):
        super().__init__()
        self.train_metrics = Metrics()
        self.validation_metrics = Metrics()
        self.best_validation_roc_auc = float('-inf')
        # Defined up front so the *_epoch_end hooks never hit a missing
        # attribute when a matching *_epoch_start hook did not run first.
        self.train_epoch_start_time = None
        self.validation_epoch_start_time = None

    def on_train_epoch_start(self, trainer, pl_module):
        self.train_epoch_start_time = time()

    def on_validation_epoch_start(self, trainer, pl_module):
        self.validation_epoch_start_time = time()

    # dataloader_idx has a default so the hooks work whether or not the
    # installed Lightning version passes that argument.
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        self.after_batch(self.train_metrics, outputs)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        self.after_batch(self.validation_metrics, outputs)

    def on_train_epoch_end(self, trainer, pl_module):
        if self.train_epoch_start_time is not None:
            print('train epoch took', time() - self.train_epoch_start_time)
        self.train_epoch_start_time = None

    def on_validation_epoch_end(self, trainer, pl_module):
        train_loss = self.get_avg_loss(self.train_metrics)
        validation_loss = self.get_avg_loss(self.validation_metrics)

        train_roc_auc = self.get_roc_auc(self.train_metrics)
        validation_roc_auc = self.get_roc_auc(self.validation_metrics)

        # get_roc_auc returns None when no validation batch was seen; logging
        # or comparing None would raise.
        if validation_roc_auc is not None:
            self.log('roc_auc', validation_roc_auc)
            if validation_roc_auc > self.best_validation_roc_auc:
                self.best_validation_roc_auc = validation_roc_auc

        if self.validation_epoch_start_time is not None:
            print('validation epoch took', time() - self.validation_epoch_start_time)
        self.validation_epoch_start_time = None
        print('train loss:', train_loss)
        print('validation loss:', validation_loss)
        print('train roc_auc:', train_roc_auc)
        print('validation roc_auc:', validation_roc_auc)
        print()

    def get_avg_loss(self, metrics):
        """Average and clear the buffered batch losses; NaN when none were buffered."""
        # NaN is still appended for an empty buffer so the train and validation
        # histories stay aligned, without the NumPy empty-mean warning.
        avg_loss = np.array(metrics.losses).mean() if metrics.losses else float('nan')
        metrics.reduced_losses.append(avg_loss)
        metrics.losses = []
        return avg_loss

    def after_batch(self, metrics, outputs):
        """Buffer the loss, labels and predictions of one batch."""
        metrics.losses.append(outputs['loss'].item())
        # Move to CPU and detach right away so a whole epoch of batches does
        # not stay on the GPU (or keep autograd graphs alive).
        metrics.y_batches.append(outputs['y'].detach().cpu())
        metrics.y_hat_batches.append(outputs['y_hat'].detach().cpu())

    def get_roc_auc(self, metrics):
        """Compute ROC AUC over the buffered batches and clear them; None if empty."""
        if not metrics.y_batches:
            return None
        y_np = torch.hstack(metrics.y_batches).detach().cpu().numpy()
        y_hat_np = torch.hstack(metrics.y_hat_batches).detach().cpu().numpy()
        roc_auc = calc_roc_auc(y_np, y_hat_np)
        metrics.roc_auc_list.append(roc_auc)

        metrics.y_batches = []
        metrics.y_hat_batches = []
        return roc_auc

In [None]:
def plot_metrics(metrics_callback):
    """Draw two figures from a MetricsCallback: losses, then ROC AUC values.

    Each figure overlays the train and the validation history.
    """
    train = metrics_callback.train_metrics
    val = metrics_callback.validation_metrics

    figures = (
        ((train.reduced_losses, 'train_loss'), (val.reduced_losses, 'val_loss')),
        ((train.roc_auc_list, 'train_roc_auc'), (val.roc_auc_list, 'val_roc_auc')),
    )
    for curves in figures:
        for series, label in curves:
            plt.plot(series, label=label)
        plt.legend()
        plt.show()

In [None]:
# 5-fold cross-validation: for each fold, train on the other folds, validate on
# the held-out one, and record the best validation ROC AUC seen by the callback.
roc_aucs = []
for fold_n in range(5):
    print('Fold:', fold_n+1)
    # Training set: every case whose fold differs from the held-out fold.
    train_df = kfold_df[kfold_df.fold!=fold_n]
#     train_df = kfold_df.iloc[:10]
    train_ds = DataRetriever(
    train_df["BraTS21ID"].values, 
    '../input/rsna-miccai-brain-tumor-radiogenomic-classification/train',
    targets=train_df["MGMT_value"].values
)

    # Validation set: the held-out fold.
    val_df = kfold_df[kfold_df.fold==fold_n]
#     val_df = kfold_df.iloc[:10]
    val_ds = DataRetriever(
    val_df["BraTS21ID"].values, 
    '../input/rsna-miccai-brain-tumor-radiogenomic-classification/train',
    targets=val_df["MGMT_value"].values
)

    train_dl = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=8,
        pin_memory=True
    )

    val_dl = DataLoader(
        val_ds, 
        batch_size=2,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    # Fresh network per fold: 4 input channels (one per entry of
    # Config.mri_types) and a single output logit.
    net = EfficientNet3D.from_name("efficientnet-b3", override_params={'num_classes': 1}, in_channels=4)
    model = Model(net, Config.lr)
    # Checkpoints are selected on the 'roc_auc' value logged by MetricsCallback.
    # NOTE(review): dirpath='models' is relative, while models_dir (created in the
    # first cell) is absolute; they coincide only when the working directory is
    # /kaggle/working - confirm.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath='models', filename=f'model_{fold_n}_' + '{epoch}_{roc_auc:.3}', monitor='roc_auc', mode='max', save_weights_only=True)
    metrics_callback = MetricsCallback()
    print(ModelSummary(model))
    # NOTE(review): auto_lr_find=True is set, but the trainer.tune(...) call below
    # is commented out, so presumably the LR finder never runs and Config.lr is
    # used unchanged - confirm.
    trainer = pl.Trainer(fast_dev_run=False, max_epochs=10, gpus=1,
                         auto_lr_find=True, precision=16, limit_train_batches=1.0, limit_val_batches=1.0, 
                         num_sanity_val_steps=0, val_check_interval=0.333, 
                         accumulate_grad_batches=4,
                         callbacks=[metrics_callback, checkpoint_callback])



#     lr_find = trainer.tune(model, train_dl, val_dl)['lr_find']
#     fig = lr_find.plot(suggest=True)
#     plt.show()
#     new_lr = lr_find.suggestion()
#     print('new_lr', new_lr)

    trainer.fit(model, train_dl, val_dl)
    plot_metrics(metrics_callback)
    roc_aucs.append(metrics_callback.best_validation_roc_auc)


# Summary over folds: per-fold best validation ROC AUC, its mean and std.
print('roc_auc_s:', roc_aucs)
print('roc_auc mean:', np.array(roc_aucs).mean())
print('roc_auc std:', np.array(roc_aucs).std())





In [None]:
# List the extracted NIfTI case directories whose names start with "0008".
!ls /tmp/input/nifti/train/0008*

In [None]:
# Free memory between experiments: run Python's garbage collector, then ask
# PyTorch to release its cached CUDA memory.
import gc
# del model, trainer, train_dl, val_dl, train_ds, val_ds
gc.collect()
torch.cuda.empty_cache()

In [None]:
# !rm -rf models/*

In [None]:
# List the saved checkpoints in ./models with sizes in megabytes.
!ls --block-size=M -l models

In [None]:
# base_model = Model(deit_model)

# for ckpt_name in glob.glob('models/*.ckpt'):
#     state_dict = torch.load(ckpt_name)['state_dict']
#     base_model.load_state_dict(state_dict)
#     fold_n = int(ckpt_name.split('_')[0][-1])
    
#     val_df = kfold_df[kfold_df.fold==fold_n]
#     val_ds = DataRetriever(
#         val_df["BraTS21ID"].values, 
#         '../input/rsna-miccai-brain-tumor-radiogenomic-classification/train',
#         targets=val_df["MGMT_value"].values
#     )

#     val_dl = DataLoader(
#         val_ds, 
#         batch_size=8,
#         shuffle=False,
#         num_workers=16,
#         collate_fn=FeatureExtractorCollate(feats_extractor)
#     )
#     metrics_callback = MetricsCallback()

#     trainer = pl.Trainer(gpus=1, num_sanity_val_steps=0, callbacks=[metrics_callback])
#     trainer.validate(base_model, val_dl)




In [None]:
# base_model = Model(deit_model)
# test_df = pd.read_csv("../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv")

# for ckpt_name in glob.glob('models/*.ckpt'):
#     state_dict = torch.load(ckpt_name)['state_dict']
#     base_model.load_state_dict(state_dict)
#     fold_n = int(ckpt_name.split('_')[0][-1])
    
#     test_ds = DataRetriever(
#         test_df["BraTS21ID"].values, 
#         '../input/rsna-miccai-brain-tumor-radiogenomic-classification/test',
#         targets=test_df["MGMT_value"].values
#     )

#     test_dl = DataLoader(
#         test_ds, 
#         batch_size=8,
#         shuffle=False,
#         num_workers=16,
#         collate_fn=FeatureExtractorCollate(feats_extractor)
#     )

#     trainer = pl.Trainer(gpus=1, num_sanity_val_steps=0)
#     preds = trainer.predict(base_model, test_dl)
#     print([p.flatten().tolist() for p in preds])
#     print()
