## RSNA-MICCAI Brain Tumor Radiogenomic Classification
This notebook can be used to finetune a pretrained model on the RSNA-MICCAI dataset. It uses MedicalNet[[1]](https://github.com/Tencent/MedicalNet) as a starting point.

## Table of Contents
1. [Setup](#setup) 
    1. [Configuration](#config) 
    2. [Imports](#imports) 
    3. [Helpers](#helpers) 
    4. [System setup](#env) 
2. [Model Definition](#model)
3. [Dataset](#data)
4. [Trainer](#trainer)
5. [Training](#training)

### Trusting CV vs LB
In short: don't trust CV. It is still unclear whether our models can find a meaningful signal for predicting the MGMT value at all, and I have not yet found a valid cross-validation scheme. This is reflected in the subpar learning curve.

### Notes
The competition data consists of MRI scans of 4 different sequence types (FLAIR, T1w, T1Gd, T2), each of which can have different orientations and resolutions. If we want to train a model that leverages all of this information, we would most likely have to superimpose these scans to preserve the spatial information. A promising approach for this was already introduced by @boojum [[2]](https://www.kaggle.com/boojum/connecting-voxel-spaces). However, we are only interested in a baseline for now, so we will only use the FLAIR scans and discard the additional information in the DICOM files. This allows us to use the PNG conversion of the RSNA-MICCAI data [[3]](https://www.kaggle.com/jonathanbesomi/rsna-miccai-png), which gives us a simple and fast preprocessing pipeline.

**Thank you to all the domain experts for your valuable insights.**


### EDA
There are several excellent kernels out there. I mainly referred to these kernels and discussions:
 - https://www.kaggle.com/ayuraj/brain-tumor-eda-and-interactive-viz-with-w-b
 - https://www.kaggle.com/ihelon/brain-tumor-eda-with-animations-and-modeling
 - https://www.kaggle.com/c/rsna-miccai-brain-tumor-radiogenomic-classification/discussion/253736
 - https://www.kaggle.com/c/rsna-miccai-brain-tumor-radiogenomic-classification/discussion/253488
 

## TODO
There is a lot of room for improvement, for example:
 - data augmentation
 - preprocessing, e.g. center cropping
 - adding T1w, T1Gd, T2 channels
 - architecture choices
 - hyperparameter optimization (this notebook uses default params)
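To make the center-cropping item concrete, here is a minimal numpy sketch. `center_crop` is a hypothetical helper, not part of this notebook; it assumes a scan that has already been stacked into a `(D, H, W)` array, which is the layout `BTRCDataset` builds before adding the channel axis.

```python
import numpy as np

def center_crop(volume: np.ndarray, crop_hw: int) -> np.ndarray:
    """Crop the (H, W) axes of a (D, H, W) volume around their center.

    Hypothetical helper for the "center cropping" TODO item.
    """
    _, h, w = volume.shape
    if crop_hw > h or crop_hw > w:
        raise ValueError(f"crop size {crop_hw} exceeds slice size {(h, w)}")
    top = (h - crop_hw) // 2
    left = (w - crop_hw) // 2
    # every slice along D is cropped with the same window
    return volume[:, top:top + crop_hw, left:left + crop_hw]

vol = np.arange(2 * 6 * 6).reshape(2, 6, 6)
cropped = center_crop(vol, 4)
print(cropped.shape)  # (2, 4, 4)
```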

## References
[1] Chen, Sihong and Ma, Kai and Zheng, Yefeng (2019). Med3D: Transfer Learning for 3D Medical Image Analysis. arXiv preprint arXiv:1904.00625

[2] https://www.kaggle.com/boojum/connecting-voxel-spaces

[3] https://www.kaggle.com/jonathanbesomi/rsna-miccai-png


## Setup <a class="anchor" id="setup"></a>
### Configuration <a class="anchor" id="config"></a>

In [None]:
class CFG:
    # system
    seed=42
    no_cuda=False
    
    # model
    model_name = 'resnet_34_23dataset'
    use_pretrained = True
    
    # cohort to use
    cohort = 'FLAIR'
    
    img_size=256
    
class TrainerConfig:
    num_epochs = 15
    batch_size = 8
    gradient_accumulation_steps = 1

    # optimizer
    lr = 3e-2
    warm_up_ratio = 0.1
    weight_decay = 0.0

    # log every log_steps to wandb
    log_steps = 20

    # environment
    device = 'cuda'

### Imports <a class="anchor" id="imports"></a>

Add MedicalNet to sys path and import libraries

In [None]:
# add medicalnet to path
import sys
import os
os.system('cp -r ../input/medicalnet-with-weights/MedicalNet MedicalNet')
os.system('touch MedicalNet/__init__.py')
sys.path.append('/kaggle/working/MedicalNet')

In [None]:
# system
import sys
import os
import random
import time

# data processing
import numpy as np
import pandas as pd

# deep learning libraries
## torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

## transformers
from transformers import get_cosine_schedule_with_warmup
from transformers.optimization import AdamW

## sklearn
from sklearn import model_selection
from sklearn.metrics import roc_auc_score

# wandb for logging
import wandb
from kaggle_secrets import UserSecretsClient

# plotting
import cv2
import matplotlib.pyplot as plt
from IPython import display

# med
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

import MedicalNet.model

### Helpers <a class="anchor" id="helpers"></a>
Some helper functions. `Struct` is only needed for MedicalNet  

Functions: `seed_everything`, `get_folds`

Classes: `Struct`

In [None]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
# MedicalNet expects a class to hold hyperparameters
class Struct:
    def __init__(self, entries):
        self.__dict__.update(entries)


# create 5 stratified folds; rng is the random seed used to shuffle before splitting
def get_folds(df, rng=42):
    # read df from a csv path or take the DataFrame as-is
    if isinstance(df, str):
        train_df = pd.read_csv(df)
    elif isinstance(df, pd.DataFrame):
        train_df = df
    else:
        raise TypeError(f"Didn't understand data type for df: {type(df)}")

    # shuffle before split
    train_df = train_df.sample(frac=1, random_state=rng).reset_index(drop=True)

    # get stratified splits with sklearn; a row's fold is the split in which it is held out
    kf = model_selection.StratifiedKFold(n_splits=5)
    train_df['fold'] = -1
    for fold_idx, (_, indices_test) in enumerate(kf.split(X=train_df.BraTS21ID, y=train_df.MGMT_value)):
        train_df.loc[indices_test, 'fold'] = fold_idx
    return train_df
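To check what `get_folds` produces, the sketch below repeats its steps on a synthetic labels table. The IDs and the 60/40 label split are made up, not the competition data. Because both class counts are divisible by 5, `StratifiedKFold` should place exactly the same label ratio in every fold.

```python
import numpy as np
import pandas as pd
from sklearn import model_selection

# Toy table with the same columns as train_labels.csv (values are synthetic).
rng = np.random.default_rng(0)
toy_df = pd.DataFrame({
    "BraTS21ID": np.arange(100),
    "MGMT_value": rng.permutation([0] * 60 + [1] * 40),
})

# Same steps as get_folds(): shuffle, then record each row's held-out fold.
toy_df = toy_df.sample(frac=1, random_state=42).reset_index(drop=True)
kf = model_selection.StratifiedKFold(n_splits=5)
toy_df["fold"] = -1
for fold_idx, (_, test_idx) in enumerate(kf.split(X=toy_df.BraTS21ID, y=toy_df.MGMT_value)):
    toy_df.loc[test_idx, "fold"] = fold_idx

# Rows per fold and fraction of positive labels per fold.
summary = toy_df.groupby("fold").MGMT_value.agg(["size", "mean"])
print(summary)
```

On the real data the class counts are not divisible by 5, so the per-fold ratios will only be approximately equal.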

### System setup <a class="anchor" id="env"></a>
Initialize the random seed, set up CUDA and configure wandb.

Variables: `DEVICE`, `config_wandb`, `wandb_group_name`

In [None]:
# system setup
seed_everything(CFG.seed)

DEVICE = torch.device("cuda") if torch.cuda.is_available() and not CFG.no_cuda else torch.device("cpu")

# setup some wandb params
# login wandb
user_secrets = UserSecretsClient()
wandb_api = user_secrets.get_secret("wandb_api")
wandb.login(key=wandb_api)

# config for logging
config_wandb = {
    'learning_rate': TrainerConfig.lr,
    'num_epochs': TrainerConfig.num_epochs,
    'batch_size': TrainerConfig.batch_size,
    'warm_up_ratio': TrainerConfig.warm_up_ratio,
    'weight_decay': TrainerConfig.weight_decay,
    'architecture': CFG.model_name,
}
wandb_group_name = "exp" + wandb.util.generate_id()

## Model Definition <a class="anchor" id="model"></a>
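Take a pretrained segmentation model from MedicalNet, configured with a single output channel (`n_seg_classes: 1`), add a global average pooling layer over its output map and *et voilà*, we have a classifier that produces one logit per scan.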

Dict: `model_pretrained_params`

Class: `MedicalNetWithHead`

In [None]:
# model specific args
model_pretrained_params = {
    'resnet_10': {'model_depth': 10, 'resnet_shortcut': 'B'},
    'resnet_10_23dataset': {'model_depth': 10, 'resnet_shortcut': 'B'},
    'resnet_18': {'model_depth': 18, 'resnet_shortcut': 'A'},
    'resnet_18_23dataset': {'model_depth': 18, 'resnet_shortcut': 'A'},
    'resnet_34': {'model_depth': 34, 'resnet_shortcut': 'A'},
    'resnet_34_23dataset': {'model_depth': 34, 'resnet_shortcut': 'A'},
    'resnet_50': {'model_depth': 50, 'resnet_shortcut': 'B'},
    'resnet_50_23dataset': {'model_depth': 50, 'resnet_shortcut': 'B'},
    'resnet_101': {'model_depth': 101, 'resnet_shortcut': 'B'},
    'resnet_152': {'model_depth': 152, 'resnet_shortcut': 'B'},
    'resnet_200': {'model_depth': 200, 'resnet_shortcut': 'B'},
}
# args shared by all models; W/H match the size the dataset resizes slices to
opts = {
    'model': 'resnet',
    'input_W': CFG.img_size,
    'input_H': CFG.img_size,
    'input_D': 64,
    'no_cuda': CFG.no_cuda,
    'n_seg_classes': 1,
    'phase': 'train',
    'pretrain_path': None,
    'gpu_id': [1],
}

# merge model-specific args and shared args
for model_name, model_dict in model_pretrained_params.items():
    model_pretrained_params[model_name] = Struct({**model_dict, **opts})
    
    
# MedicalNet with a global pooling head
class MedicalNetWithHead(nn.Module):
    def __init__(self, model_name, pretrain_path=None):
        super().__init__()
        self.model_name = model_name
        model, parameters = MedicalNet.model.generate_model(model_pretrained_params[model_name])
        self.medical_net = model
        self.drop_in = nn.Dropout(p=0.1)
        
        # init model with pretrained weights
        if not pretrain_path and CFG.use_pretrained:
            self.init_model()
            
        # use simple pooling for now
        self.pool = nn.AdaptiveAvgPool3d(1)
        
    def forward(self, x):
        x = self.medical_net(self.drop_in(x))
        out = self.pool(x)
        return out
            
    def init_model(self):
        net_dict = self.medical_net.state_dict()
        # load pretrain
        pretrain = torch.load(f'../input/medicalnet-with-weights/MedicalNet_pytorch_files2/pretrain/{self.model_name}.pth', map_location=DEVICE)
        pretrain_dict = {k: v for k, v in pretrain['state_dict'].items() if k in net_dict.keys()}
        net_dict.update(pretrain_dict)
        self.medical_net.load_state_dict(net_dict)
        print("loaded pretrained weights")
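The head reduces MedicalNet's output to one logit per scan. Assuming the backbone returns a single-channel map of shape `(batch, 1, D', H', W')` (because `n_seg_classes` is 1 in `opts`), the numpy sketch below mimics `nn.AdaptiveAvgPool3d(1)` with a mean over the three spatial axes. The random map and its spatial size are placeholders, not MedicalNet's actual output size.

```python
import numpy as np

# Placeholder for the backbone output: (batch=2, channels=1, D', H', W').
rng = np.random.default_rng(0)
seg_map = rng.normal(size=(2, 1, 8, 16, 16)).astype(np.float32)

# AdaptiveAvgPool3d(1) averages over all spatial positions -> (2, 1, 1, 1, 1).
pooled = seg_map.mean(axis=(2, 3, 4), keepdims=True)

# The Trainer flattens this with logits.view(-1): one logit per scan.
logits = pooled.reshape(-1)

# BCEWithLogitsLoss applies the sigmoid internally; the Trainer only applies it
# explicitly to get probabilities for the ROC AUC.
probs = 1.0 / (1.0 + np.exp(-logits))
print(pooled.shape, logits.shape)
```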

## Dataset <a class="anchor" id="data"></a>
Simple dataset without augmentations. We want every MRI scan to have the same dimensions DxWxH, but the number of slices varies from scan to scan. If a scan has more than the required 64 slices, the dataset randomly picks 64 of them; if it has fewer, it randomly repeats slices until there are 64. In both cases the slices stay in their original order.
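The depth handling can be sketched on filenames alone. `pick_slices` and `slice_key` are hypothetical helpers that mirror the logic of `BTRCDataset.__getitem__`; they assume filenames of the form `Image-<n>.png`, which is what the `x[6:-4]` slice in the dataset implies, and they use a seeded numpy `Generator` instead of the global `np.random` state. The numeric key matters: a plain string sort would put `Image-10.png` before `Image-2.png`.

```python
import numpy as np

DEPTH = 64  # slices per scan, same as input_D in opts

def slice_key(filename):
    # "Image-12.png" -> 12; equivalent to int(x[6:-4]) in BTRCDataset
    return int(filename[len("Image-"):-len(".png")])

def pick_slices(filenames, depth=DEPTH, rng=None):
    """Return `depth` filenames in slice order (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    files = list(filenames)
    if len(files) > depth:
        # too many slices: keep a random subset, without replacement
        files = list(rng.choice(files, size=depth, replace=False))
    files = sorted(files, key=slice_key)
    if len(files) < depth:
        # too few slices: draw indices with replacement, then sort them
        # so repeated slices stay next to their neighbours
        idx = np.sort(rng.choice(len(files), size=depth, replace=True))
        files = [files[i] for i in idx]
    return files

rng = np.random.default_rng(0)
few = ["Image-10.png", "Image-2.png", "Image-1.png"]
many = [f"Image-{i}.png" for i in range(1, 201)]
out_few = pick_slices(few, rng=rng)
out_many = pick_slices(many, rng=rng)
print(out_few[:4], len(out_few))
print(out_many[:3], len(out_many))
```

Note that resampling with replacement can leave some of the available slices out entirely, as in the dataset itself.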

Classes: `BTRCDataset`

In [None]:
class BTRCDataset(torch.utils.data.Dataset):
    def __init__(self, df, data_dir,  cohort='FLAIR'):
        self.df = df
        self.data_dir = data_dir
        self.cohort = cohort
        
        
    def __getitem__(self, idx):
        # get sample info
        sample_id, target = self.df.loc[idx].values
        
        # get sample path. combination of dir, padded id and cohort
        sample_dir = os.path.join(self.data_dir, f'{sample_id:05d}', self.cohort)
        sample_files = os.listdir(sample_dir)
        
        # take subset of available images if n_images > 64
        if len(sample_files) > 64:
            sample_files = np.random.choice(sample_files, size=64, replace=False)
        
        # sort samples
        sample_files = sorted(sample_files, key=lambda x: int(x[6:-4]))
        
        # load images
        imgs = [self.read_img(os.path.join(sample_dir, path)) for path in sample_files]
        imgs = np.stack(imgs)
        
        # resample images if not enough samples are available
        if len(sample_files) < 64:
            # draw 64 slice indices with replacement and keep them in slice order
            indices = sorted(np.random.choice(len(sample_files), size=64, replace=True))
            imgs = imgs[indices]
            
        return torch.tensor(imgs, dtype=torch.float32).unsqueeze(0), torch.tensor(target, dtype=torch.float32)
        
    def __len__(self):
        return len(self.df)
    
    def read_img(self, path):
        if path.endswith('dcm'):
            img = self.read_dicom(path)
        elif path.endswith('png'):
            img = self.read_png(path)
        else:
            raise ValueError(f'unknown file format: {path}')
        img = cv2.resize(img, (CFG.img_size, CFG.img_size))
        return img
    
    @staticmethod
    def read_png(path):
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        return img
    
    @staticmethod
    def read_dicom(path):
        # pydicom.read_file is deprecated (removed in pydicom 3.0); dcmread is the supported reader
        dicom = pydicom.dcmread(path)
        data = apply_voi_lut(dicom.pixel_array, dicom)
        if dicom.PhotometricInterpretation == "MONOCHROME1":
            data = np.amax(data) - data
        data = data - np.min(data)
        # guard against blank slices to avoid dividing by zero
        if np.max(data) > 0:
            data = data / np.max(data)
        data = (data * 255).astype(np.uint8)
        return data
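The intensity scaling in `read_dicom` can be checked without a DICOM file. The numpy sketch below takes a raw pixel array directly and skips `apply_voi_lut`, so it only covers the MONOCHROME1 inversion and the min-max scaling to 0..255. `scale_to_uint8` is a hypothetical helper and the 2x2 input is made up.

```python
import numpy as np

def scale_to_uint8(pixels, monochrome1=False):
    """Min-max scale one slice to 0..255 (hypothetical helper)."""
    data = np.asarray(pixels, dtype=np.float64)
    if monochrome1:
        # MONOCHROME1: the lowest stored value is meant to display as white, so invert
        data = np.amax(data) - data
    data = data - np.min(data)
    peak = np.max(data)
    if peak > 0:  # a constant slice would otherwise divide by zero
        data = data / peak
    # astype truncates, e.g. 63.75 -> 63
    return (data * 255).astype(np.uint8)

raw = np.array([[100, 300], [500, 900]])
scaled = scale_to_uint8(raw)
inverted = scale_to_uint8(raw, monochrome1=True)
print(scaled.tolist())    # [[0, 63], [127, 255]]
print(inverted.tolist())  # [[255, 191], [127, 0]]
```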


## Trainer <a class="anchor" id="trainer"></a>
Does the heavy lifting for training.

Classes: `AverageMeter`, `Trainer`

In [None]:
class AverageMeter(object):
    """
    Computes and stores the average and current value
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class Trainer:
    def __init__(self, cfg: type(TrainerConfig),
                 model: torch.nn.Module,
                 model_path: str,
                 dataset_train: Dataset = None,
                 dataset_val: Dataset = None,
                 wandb_run: wandb.sdk.wandb_run.Run = None):
        self.cfg = cfg
        self.model = model
        self.best_model = None
        self.model_path = model_path
        self.wandb_run = wandb_run

        # datasets
        self.dataset_train = dataset_train
        self.dataset_eval = dataset_val

        # dataloaders, train/eval is optional
        kwargs_dataloader = {'batch_size': self.cfg.batch_size, 'num_workers': 2}
        if self.dataset_train is not None:
            self.dataloader_train = DataLoader(self.dataset_train, shuffle=True, **kwargs_dataloader)
        if self.dataset_eval is not None:
            self.dataloader_eval = DataLoader(self.dataset_eval, shuffle=False, **kwargs_dataloader)

        # setup loss
        self.loss_fnc = torch.nn.BCEWithLogitsLoss()

        # if train set is provided, setup training
        if dataset_train is not None:
            # init optimizer
            #self.optimizer = AdamW(self.model.parameters(), self.cfg.lr, weight_decay=self.cfg.weight_decay)
            self.optimizer = torch.optim.SGD(self.model.parameters(), self.cfg.lr, weight_decay=self.cfg.weight_decay, momentum=0.9)

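            # setup lr scheduler; it is stepped once per optimizer step, i.e. ceil(n_batches / accum) times per epoch
            accum = cfg.gradient_accumulation_steps
            _n_steps = cfg.num_epochs * ((len(self.dataloader_train) + accum - 1) // accum)
            self.scheduler = get_cosine_schedule_with_warmup(
                optimizer=self.optimizer,
                num_warmup_steps=int(_n_steps * cfg.warm_up_ratio),
                num_training_steps=_n_steps,
            )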
            # call optimizer once so we can properly init the lr in step_train()
            self.optimizer.zero_grad()
            self.optimizer.step()
            
        else:
            self.optimizer = None
            self.scheduler = None
            
        self.epoch = 0

    def train(self):
        print(f'Training model for {self.cfg.num_epochs} epochs.')
        print('Epoch | train_loss | eval_loss | train_auc | val_auc')
        log_items = []
        
        best_loss = float('inf')

        for epoch in range(self.cfg.num_epochs):
            train_loss, train_auc = self.step_train()
            eval_loss, eval_auc = self.step_eval(return_predictions=False)
            
            log_item = {
                'epoch': epoch,
                'step': epoch*len(self.dataloader_train),
                'train_loss': train_loss,
                'eval_loss': eval_loss,
                'train_auc': train_auc,
                'eval_auc': eval_auc,
                'lr': self.optimizer.param_groups[0]['lr']
            }
            if self.wandb_run is not None:
                self.wandb_run.log(log_item)
            # collect rows in a list; DataFrame.append was removed in pandas 2.0
            log_items.append(log_item)
            print(f"{epoch: <6}|{train_loss: >12.3f}|{eval_loss: >11.3f}|{train_auc: >11.3f}|{eval_auc: >8.3f}")
      
            # checkpointing
            if eval_loss < best_loss:
                torch.save(self.model, self.model_path)
                best_loss = eval_loss
            self.epoch += 1
        train_log = pd.DataFrame(log_items)
        best_epoch = train_log.eval_loss.idxmin()
        print("Training done. Best model at epoch {} with eval_loss {:3.2f} and auc {:3.2f}".format(
            train_log.loc[best_epoch, 'epoch'],
            train_log.loc[best_epoch, 'eval_loss'],
            train_log.loc[best_epoch, 'eval_auc']
        ))
        return train_log

    def step_train(self):
        self.model.train()
        
        # setup logging
        loss_agg = AverageMeter()
        targets = []
        predictions = []
        
        # train one epoch
        for batch_idx, (x, y) in enumerate(self.dataloader_train):
            
            x = x.to(self.cfg.device)
            y = y.to(self.cfg.device)
            
            # forward pass
            logits = self.model(x)
            loss = self.loss_fnc(logits.view(-1), y)

            # backward pass; scale the loss so accumulated gradients are averaged, not summed
            (loss / self.cfg.gradient_accumulation_steps).backward()
            if (batch_idx+1) % self.cfg.gradient_accumulation_steps == 0 or (batch_idx+1) == len(self.dataloader_train):
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.scheduler.step()

            # update loss meter, weighted by the actual batch size (the last batch may be smaller)
            loss_agg.update(loss.item(), x.size(0))

            # save preds/targets for roc computation
            predictions.append(torch.sigmoid(logits.view(-1)).detach().cpu().squeeze().numpy())
            targets.append(y.detach().cpu().squeeze().numpy())

            # log every cfg.log_steps steps (only if a wandb run was provided)
            if self.wandb_run is not None and batch_idx > 0 and batch_idx % self.cfg.log_steps == 0:
                log_item = {
                    'step': self.epoch*len(self.dataloader_train) + batch_idx,
                    'train_loss': loss_agg.avg,
                    'train_auc': roc_auc_score(np.hstack(targets), np.hstack(predictions)),
                    'lr': self.optimizer.param_groups[0]['lr']
                }
                self.wandb_run.log(log_item)
            
        # compute auc for whole epoch
        auc = roc_auc_score(np.hstack(targets), np.hstack(predictions))
        return loss_agg.avg, auc

    @torch.no_grad()
    def step_eval(self, return_predictions=False):
        self.model.eval()
        loss_agg = AverageMeter()
        predictions = []
        targets = []
        for batch_idx, (x, y) in enumerate(self.dataloader_eval):
            x = x.to(self.cfg.device)
            y = y.to(self.cfg.device)
            logits = self.model(x)
            loss = self.loss_fnc(logits.view(-1), y)

            # update loss meter, weighted by the actual batch size (the last batch may be smaller)
            loss_agg.update(loss.item(), x.size(0))

            # save preds/targets for roc computation; predictions are also returned if requested
            predictions.append(torch.sigmoid(logits.view(-1)).detach().cpu().squeeze().numpy())
            targets.append(y.detach().cpu().squeeze().numpy())

        # compute auc
        targets = np.hstack(targets)
        predictions = np.hstack(predictions)
        auc = roc_auc_score(targets, predictions)

        # setup output
        if return_predictions:
            out = (loss_agg.avg, auc, predictions)
        else:
            out = (loss_agg.avg, auc)
        return out
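The learning-rate curve logged to wandb follows from two pieces of the Trainer: how often the optimizer (and scheduler) steps, and the cosine-with-warmup multiplier. The sketch below is my reading of transformers' `get_cosine_schedule_with_warmup` with its default `num_cycles=0.5`, written out in plain Python to reason about the schedule; it is not a replacement for the library call. The 40 batches per epoch are a made-up number. With `gradient_accumulation_steps = 1` (the default in `TrainerConfig`), optimizer steps and batches coincide.

```python
import math

def optimizer_steps_per_epoch(n_batches, accum_steps):
    # step_train() steps on every accum_steps-th batch and on the last batch
    return sum(
        1 for b in range(n_batches)
        if (b + 1) % accum_steps == 0 or (b + 1) == n_batches
    )

def cosine_with_warmup(step, warmup, total):
    """LR multiplier: linear warmup from 0 to 1, then a half cosine down to 0."""
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

# Toy numbers: 15 epochs of 40 batches, no accumulation, 10% warmup.
total = 15 * optimizer_steps_per_epoch(40, 1)
warmup = int(total * 0.1)
for step in (0, 30, 60, 330, 600):
    # the actual lr is TrainerConfig.lr times this multiplier
    print(step, round(cosine_with_warmup(step, warmup, total), 3))
```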
    

## Training <a class="anchor" id="training"></a>
Train one model per fold (5 folds) and save the best epoch of each. Logging is done in the console and with wandb.

In [None]:
# load training dataframe
train_df = pd.read_csv('../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv')

# drop 3 samples, see https://www.kaggle.com/c/rsna-miccai-brain-tumor-radiogenomic-classification/discussion/262046
train_df = train_df[~train_df.BraTS21ID.isin([109, 123, 709])]

# setup cross validation
train_df = get_folds(train_df, rng=CFG.seed)
train_df.to_csv('folds.csv')

# train each fold
for fold_idx in range(int(train_df['fold'].max())+1):
    # datasets
    df_train = train_df.loc[train_df.fold != fold_idx, ['BraTS21ID', 'MGMT_value']].reset_index(drop=True)
    df_eval = train_df.loc[train_df.fold == fold_idx, ['BraTS21ID', 'MGMT_value']].reset_index(drop=True)
    dataset_train = BTRCDataset(df_train, data_dir='../input/rsna-miccai-png/train', cohort=CFG.cohort)
    dataset_eval = BTRCDataset(df_eval, data_dir='../input/rsna-miccai-png/train', cohort=CFG.cohort)

    # init new model
    model = MedicalNetWithHead(model_name=CFG.model_name).to(DEVICE)
    
    # setup wandb
    wandb_run = wandb.init(project="kaggle-BTRC", config=config_wandb, group=wandb_group_name, name=f"fold{fold_idx}", job_type="finetuning")
    wandb_run.define_metric("step")
    wandb_run.define_metric("*", step_metric="step", step_sync=True)
    
    # setup trainer
    trainer = Trainer(cfg=TrainerConfig,
                      model=model,
                      model_path=f'model_fold{fold_idx}.torch',
                      dataset_train=dataset_train,
                      dataset_val=dataset_eval,
                      wandb_run=wandb_run)

    # train the model
    train_log = trainer.train()
    train_log.to_csv(f"fold{fold_idx}_log.csv")
    wandb_run.finish()
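Given how noisy CV is here, it can help to look at the spread across folds rather than a single number. `summarize_folds` is a hypothetical helper that picks each fold's best epoch by `eval_loss` (the same criterion `Trainer.train()` uses) and reports the mean and standard deviation of the eval AUC. The two toy logs below are made up for illustration; they are not results from this notebook.

```python
import numpy as np
import pandas as pd

def summarize_folds(fold_logs):
    """Best epoch per fold by eval_loss, plus mean/std of eval AUC (hypothetical helper)."""
    rows = []
    for fold_idx, log in enumerate(fold_logs):
        best = log.loc[log["eval_loss"].idxmin()]  # same criterion as Trainer.train()
        rows.append({
            "fold": fold_idx,
            "epoch": int(best["epoch"]),
            "eval_loss": float(best["eval_loss"]),
            "eval_auc": float(best["eval_auc"]),
        })
    summary = pd.DataFrame(rows)
    return summary, summary["eval_auc"].mean(), summary["eval_auc"].std()

# Made-up logs for illustration only.
toy_logs = [
    pd.DataFrame({"epoch": [0, 1, 2], "eval_loss": [0.70, 0.68, 0.69], "eval_auc": [0.50, 0.55, 0.53]}),
    pd.DataFrame({"epoch": [0, 1, 2], "eval_loss": [0.71, 0.72, 0.69], "eval_auc": [0.52, 0.49, 0.60]}),
]
summary, mean_auc, std_auc = summarize_folds(toy_logs)
print(summary)
print(f"eval AUC {mean_auc:.3f} +/- {std_auc:.3f}")
```

On the real run, the input would be the saved logs, e.g. `[pd.read_csv(f"fold{i}_log.csv") for i in range(5)]`; the extra index column written by `to_csv` is simply ignored.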
    

In [None]:
trainer = Trainer(cfg=TrainerConfig,
                  model=nn.Linear(10,10),
                  model_path=f'model_fold{fold_idx}.torch',
                  dataset_train=dataset_train,
                  dataset_val=dataset_eval,
                  wandb_run=None)