My notebook extends [this notebook](https://www.kaggle.com/agrica/btrc-pretrained-medicalnet-finetuning-public) and uses the dataset from [this discussion](https://www.kaggle.com/c/rsna-miccai-brain-tumor-radiogenomic-classification/discussion/253000).
My work:
- Use the MONAI framework (developed for medical imaging; it includes many architectures such as ViT, EfficientNet, ... in 2D and 3D)
- Use MONAI for augmentation (makes it easy to transform data in 2D and 3D planes).
- Apply SAM optimizer
- More information: [MONAI tutorials](https://github.com/Project-MONAI/tutorials)
I hope my work can be useful for this challenge.
Many thanks to [agrica](https://www.kaggle.com/agrica) and [jonathanbesomi](https://www.kaggle.com/jonathanbesomi) for their work, which is a good starting point for this challenge.

In [None]:
!pip install wandb --upgrade

In [None]:
!git clone https://github.com/Project-MONAI/MONAI.git
%cd MONAI/
!pip install -e '.[all]'

In [None]:
!pip install adamp

In [None]:
import monai
from monai.data import ImageDataset
from monai.transforms import Zoom,AddChannel, Compose, CenterScaleCrop,RandRotate90, Resize, ScaleIntensity, \
EnsureType,BorderPad,Zoom,RandZoom,SpatialPad,RandCoarseDropout,\
Flip,CenterScaleCrop,RandGaussianNoise,RandShiftIntensity,AdjustContrast,\
RandAdjustContrast
from monai.transforms import RandCoarseDropout
# system
import sys
import os
import random
import time

# data processing
import numpy as np
import pandas as pd

# dl librarires
## torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

## transformers
from transformers import get_cosine_schedule_with_warmup
from transformers.optimization import AdamW

## sklearn
from sklearn import model_selection
from sklearn.metrics import roc_auc_score

# wandb for logging
import wandb
from kaggle_secrets import UserSecretsClient

# plotting
import cv2
import matplotlib.pyplot as plt
from IPython import display

# med
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

from adamp import AdamP

In [None]:
class CFG:
    """Notebook-wide experiment settings, read as plain class attributes."""

    # system
    seed = 42  # passed to seed_everything() and get_folds()
    no_cuda = False  # not read by any active code visible in this notebook
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # model
    use_pretrained = True  # not read by any active code visible in this notebook
    dropout_rate = 0.3
    # Volume size as a 3-tuple; used for slice resizing, padding and the ViT
    # input size. (A stale scalar `img_size=256` that this tuple silently
    # overwrote has been removed.)
    img_size = (128, 128, 128)
    roi_scale = (0.8, 0.8, 0.8)  # only referenced by commented-out crops here

    # optimisation
    optimizer_name = 'AdamW'  # Trainer accepts 'AdamW' or 'SGD'
    SAM = True  # NOTE(review): no active code visible here reads this flag

    # data
    cohort = 'FLAIR'  # sub-directory of each sample that is loaded
    visualize = True  # show a few slices before training each fold
class TrainerConfig:
    """Hyperparameters consumed by ``Trainer`` (read as class attributes)."""

    # environment: prefer the GPU whenever one is available
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # schedule
    num_epochs = 100
    batch_size = 8
    gradient_accumulation_steps = 1

    # optimizer
    lr = 0.0003
    weight_decay = 5e-4
    warm_up_ratio = 0.1  # fraction of all steps spent in warm-up

    # log every log_steps to wandb
    log_steps = 20

In [None]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer seed applied to every random number generator below.
    """
    # hash seed is exported as a string for the interpreter environment
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Python and NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator, then the current and all CUDA devices
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # trade cuDNN autotuning for reproducible kernels
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# create stratified folds (5 by default), optionally pass rng generator
def get_folds(df, rng=42, n_splits=5):
    """Shuffle the label table and assign each row a stratified fold index.

    Args:
        df: path to a CSV file, or a DataFrame. It must provide the columns
            ``BraTS21ID`` and ``MGMT_value``.
        rng: random state forwarded to ``DataFrame.sample`` for the shuffle.
        n_splits: number of stratified folds to create.

    Returns:
        A shuffled DataFrame with a fresh 0..n-1 index and an added ``fold``
        column holding the index of the fold in which the row is held out.

    Raises:
        TypeError: if ``df`` is neither a string nor a DataFrame.
    """
    # read df
    if isinstance(df, str):
        train_df = pd.read_csv(df)
    elif isinstance(df, pd.DataFrame):
        train_df = df
    else:
        # Previously this only printed and then failed later on an undefined
        # variable; fail immediately with a meaningful error instead.
        raise TypeError(f"Didn't understand data type for df: {type(df)}")

    # shuffle before split; sample() returns a new frame, the input is untouched
    train_df = train_df.sample(frac=1, random_state=rng).reset_index(drop=True)

    # get stratified splits with sklearn (stratified on the MGMT label)
    kf = model_selection.StratifiedKFold(n_splits=n_splits)
    for fold_idx, (_, indices_test) in enumerate(kf.split(X=train_df.BraTS21ID, y=train_df.MGMT_value)):
        # positional test indices equal index labels thanks to reset_index above
        train_df.loc[indices_test, 'fold'] = fold_idx
    return train_df

In [None]:
# make every RNG reproducible before anything stochastic runs
seed_everything(CFG.seed)

# authenticate against wandb with the API key stored as a Kaggle secret
user_secrets = UserSecretsClient()
wandb_api = user_secrets.get_secret("wandb_api")
wandb.login(key=wandb_api)

# hyperparameters attached to every run: wandb key -> TrainerConfig attribute
config_wandb = {
    wandb_key: getattr(TrainerConfig, attribute)
    for wandb_key, attribute in (
        ('learning_rate', 'lr'),
        ('num_epochs', 'num_epochs'),
        ('batch_size', 'batch_size'),
        ('warm_up_ratio', 'warm_up_ratio'),
        ('weight_decay', 'weight_decay'),
    )
}

# one generated group name shared by all fold runs of this experiment
wandb_group_name = "exp" + wandb.util.generate_id()

In [None]:
class BTRCDataset(torch.utils.data.Dataset):
    """Dataset yielding a fixed-depth stack of slices and a label per sample.

    Slices are read from ``<data_dir>/<zero-padded 5-digit id>/<cohort>``.
    Every item contains exactly ``N_SLICES`` slices: larger series are randomly
    sub-sampled without replacement, smaller ones are re-sampled with
    replacement; slice order is preserved in both cases.

    Args:
        df: DataFrame whose rows unpack to ``(sample id, target)`` and whose
            index runs 0..len-1 (rows are fetched with ``df.loc[idx]``).
        data_dir: root directory containing one folder per sample.
        transform: optional callable applied to the stacked ``(N, H, W)`` array.
        cohort: name of the sub-directory to read inside each sample folder.
    """

    # number of slices every returned volume contains
    N_SLICES = 64

    def __init__(self, df, data_dir,  transform=None, cohort='FLAIR'):
        self.df = df
        self.data_dir = data_dir
        self.cohort = cohort
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(float32 image tensor, float32 target tensor)`` for row ``idx``."""
        # get sample info
        sample_id, target = self.df.loc[idx].values

        # get sample path. combination of dir, padded id and cohort
        sample_dir = os.path.join(self.data_dir, f'{sample_id:05d}', self.cohort)
        sample_files = os.listdir(sample_dir)
        if len(sample_files) == 0:
            raise FileNotFoundError(f'no image files found in {sample_dir}')

        # take a random subset if more slices than needed are available
        if len(sample_files) > self.N_SLICES:
            sample_files = np.random.choice(sample_files, size=self.N_SLICES, replace=False)

        # sort by the integer between the 6-character prefix and the 4-character
        # extension -- assumes names shaped like 'Image-<n>.png'; TODO confirm
        sample_files = sorted(sample_files, key=lambda x: int(x[6:-4]))

        # load images into one (n_files, H, W) array
        imgs = np.stack([self.read_img(os.path.join(sample_dir, name)) for name in sample_files])

        # resample (with replacement, order kept) if too few slices are available
        if len(sample_files) < self.N_SLICES:
            indices = sorted(np.random.choice(len(sample_files), size=self.N_SLICES, replace=True))
            imgs = imgs[indices]

        if self.transform:
            imgs = self.transform(imgs)
        # Without a transform (or with one returning an array) `imgs` is still a
        # numpy array, which has no `.type`; convert before casting.
        if not torch.is_tensor(imgs):
            imgs = torch.as_tensor(imgs)
        return imgs.type(torch.float32), torch.tensor(target, dtype=torch.float32)

    def __len__(self):
        return len(self.df)

    def read_img(self, path):
        """Load one slice from ``path`` and resize it to the configured size.

        Raises:
            ValueError: if the file extension is neither ``dcm`` nor ``png``.
        """
        if path.endswith('dcm'):
            img = self.read_dicom(path)
        elif path.endswith('png'):
            img = self.read_png(path)
        else:
            # Previously this only printed and then crashed on an undefined `img`.
            raise ValueError(f'unknown file format: {path}')
        # square target whose side is the first entry of CFG.img_size
        img = cv2.resize(img, (CFG.img_size[0], CFG.img_size[0]))
        return img

    @staticmethod
    def read_png(path):
        """Read a PNG with its stored bit depth and channels unchanged."""
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise ValueError(f'could not read image: {path}')
        return img

    @staticmethod
    def read_dicom(path):
        """Read a DICOM slice and rescale it to ``uint8`` in the range 0..255."""
        dicom = pydicom.dcmread(path)
        data = apply_voi_lut(dicom.pixel_array, dicom)
        # MONOCHROME1 stores inverted intensities; flip them
        if dicom.PhotometricInterpretation == "MONOCHROME1":
            data = np.amax(data) - data
        data = data - np.min(data)
        peak = np.max(data)
        # a constant image has peak 0; skip the division to avoid NaNs
        if peak > 0:
            data = data / peak
        data = (data * 255).astype(np.uint8)
        return data

In [None]:
import torch

class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    One update needs two forward/backward passes: ``first_step`` moves the
    weights along the gradient direction by a distance controlled by ``rho``,
    gradients are recomputed there, and ``second_step`` restores the weights
    and lets the base optimizer apply those new gradients.

    Args:
        params: parameters or parameter groups to optimise.
        base_optimizer: optimizer class (e.g. ``torch.optim.SGD``); it is
            instantiated on this optimizer's parameter groups.
        rho: non-negative size of the perturbation.
        **kwargs: forwarded to ``base_optimizer`` (e.g. ``lr``).
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)

        # share one list of param groups so hyperparameter changes made through
        # either optimizer are seen by both
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb the weights along the current gradient and store the offset."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # epsilon keeps the division finite when the gradient norm is zero
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo the perturbation, then step the base optimizer."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run a full SAM update.

        Args:
            closure: callable performing a full forward-backward pass. Gradients
                for the first step must already be populated when this is called.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def load_state_dict(self, state_dict):
        """Restore state and re-link the base optimizer's parameter groups.

        ``Optimizer.load_state_dict`` replaces ``self.param_groups`` with a new
        list; without re-linking, the base optimizer would keep using the old
        groups and ignore the restored hyperparameters.
        """
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

    def _grad_norm(self):
        """Return the L2 norm over all existing gradients as a scalar tensor."""
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        per_param_norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups for p in group["params"]
            if p.grad is not None
        ]
        if not per_param_norms:
            # torch.stack fails on an empty list; no gradients means norm 0
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(per_param_norms), p=2)

In [None]:
class AverageMeter:
    """
    Running weighted average that also remembers the latest value.
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        weighted_total = self.sum + val * n
        observations = self.count + n
        self.val = val
        self.sum = weighted_total
        self.count = observations
        self.avg = weighted_total / observations


class Trainer:
    """Train and evaluate a binary classifier with BCE-with-logits loss.

    Args:
        cfg: configuration class such as ``TrainerConfig``; the attributes
            ``batch_size``, ``lr``, ``weight_decay``, ``num_epochs``,
            ``warm_up_ratio``, ``gradient_accumulation_steps``, ``log_steps``
            and ``device`` are read.
        model: network to optimise. Its output is indexed with ``[0]`` before
            the loss -- assumes the model returns a sequence whose first
            element holds the logits; TODO confirm for models other than the
            one built in this notebook.
        model_path: file the best model (lowest eval loss) is saved to.
        dataset_train: optional training set; enables optimizer and scheduler.
        dataset_val: optional evaluation set.
        wandb_run: optional wandb run; when ``None`` nothing is logged to wandb.
    """

    def __init__(self, cfg: type,
                 model: torch.nn.Module,
                 model_path: str,
                 dataset_train: Dataset = None,
                 dataset_val: Dataset = None,
                 wandb_run: "wandb.sdk.wandb_run.Run" = None):
        self.cfg = cfg
        self.model = model
        self.best_model = None
        self.model_path = model_path
        self.wandb_run = wandb_run

        # datasets
        self.dataset_train = dataset_train
        self.dataset_eval = dataset_val

        # dataloaders, train/eval is optional
        kwargs_dataloader = {'batch_size': self.cfg.batch_size, 'num_workers': 2}
        self.dataloader_train = None
        self.dataloader_eval = None
        # each loader is guarded by its own dataset (the train loader used to
        # be guarded by the eval dataset)
        if self.dataset_train is not None:
            self.dataloader_train = DataLoader(self.dataset_train, shuffle=True, **kwargs_dataloader)
        if self.dataset_eval is not None:
            self.dataloader_eval = DataLoader(self.dataset_eval, shuffle=False, **kwargs_dataloader)

        # setup loss
        self.loss_fnc = torch.nn.BCEWithLogitsLoss()

        # if train set is provided, setup training
        if dataset_train is not None:
            # init optimizer
            if CFG.optimizer_name == 'AdamW':
                self.optimizer = AdamW(self.model.parameters(), self.cfg.lr, weight_decay=self.cfg.weight_decay)
            elif CFG.optimizer_name == 'SGD':
                self.optimizer = torch.optim.SGD(self.model.parameters(), self.cfg.lr, weight_decay=self.cfg.weight_decay, momentum=0.9)
            else:
                # NOTE: the SAM optimizer is not wired in here.
                raise ValueError(f"Unsupported CFG.optimizer_name: {CFG.optimizer_name!r}")

            # setup lr scheduler; it is stepped once per optimizer step, i.e.
            # once per accumulation group (ceil division covers the last,
            # possibly shorter, group of each epoch)
            _steps_per_epoch = -(-len(self.dataloader_train) // cfg.gradient_accumulation_steps)
            _n_steps = cfg.num_epochs * _steps_per_epoch
            self.scheduler = get_cosine_schedule_with_warmup(
                optimizer=self.optimizer,
                num_warmup_steps=_n_steps * cfg.warm_up_ratio,
                num_training_steps=_n_steps,
            )
            # call optimizer once so we can properly init the lr in step_train()
            self.optimizer.zero_grad()
            self.optimizer.step()

        else:
            self.optimizer = None
            self.scheduler = None

        self.epoch = 0

    def _log(self, item):
        """Send ``item`` to wandb when a run was provided."""
        if self.wandb_run is not None:
            self.wandb_run.log(item)

    def train(self):
        """Run all epochs, checkpoint on best eval loss and return the log.

        Returns:
            DataFrame with one row per epoch (epoch, losses, AUCs, lr, step).

        Raises:
            ValueError: if the trainer was built without both datasets.
        """
        if self.dataloader_train is None or self.dataloader_eval is None:
            raise ValueError('train() requires both dataset_train and dataset_val')

        print(f'Training model for {self.cfg.num_epochs} epochs.')
        print('Epoch | train_loss | eval_loss | train_auc | val_auc')
        log_columns = ['epoch', 'train_loss', 'eval_loss', 'train_auc', 'eval_auc', 'lr', 'step']
        # rows are collected in a list; DataFrame.append no longer exists in pandas 2.x
        log_rows = []

        best_loss = float('inf')

        for epoch in range(self.cfg.num_epochs):
            train_loss, train_auc = self.step_train()
            eval_loss, eval_auc = self.step_eval(return_predictions=False)

            log_item = {
                'epoch': epoch,
                'step': epoch*len(self.dataloader_train),
                'train_loss': train_loss,
                'eval_loss': eval_loss,
                'train_auc': train_auc,
                'eval_auc': eval_auc,
                'lr': self.optimizer.param_groups[0]['lr']
            }
            self._log(log_item)
            log_rows.append(log_item)
            print(f"{epoch: <6}|{train_loss: >12.3f}|{eval_loss: >11.3f}|{train_auc: >11.3f}|{eval_auc: >8.3f}")

            # checkpointing
            if eval_loss < best_loss:
                torch.save(self.model, self.model_path)
                best_loss = eval_loss
            self.epoch += 1

        train_log = pd.DataFrame(log_rows, columns=log_columns)
        if train_log.empty:
            # nothing was trained (num_epochs == 0); there is no best epoch to report
            return train_log
        best_epoch = train_log.eval_loss.idxmin()
        print("Training done. Best model at epoch {} with eval_loss {:3.2f} and auc {:3.2f}".format(
            train_log.loc[best_epoch, 'epoch'],
            train_log.loc[best_epoch, 'eval_loss'],
            train_log.loc[best_epoch, 'eval_auc']
        ))
        return train_log

    def step_train(self):
        """Train for one epoch.

        Returns:
            ``(mean loss, ROC AUC)`` computed over all samples of the epoch.
        """
        self.model.train()

        # setup logging
        loss_agg = AverageMeter()
        targets = []
        predictions = []

        # train one epoch
        for batch_idx, (x, y) in enumerate(self.dataloader_train):

            x = x.to(self.cfg.device)
            y = y.to(self.cfg.device)
            # forward pass
            logits = self.model(x)
            loss = self.loss_fnc(logits[0].view(-1), y)
            # backward pass
            loss.backward()
            # step at the end of each accumulation group and on the last batch
            if (batch_idx+1) % self.cfg.gradient_accumulation_steps == 0 or (batch_idx+1) == len(self.dataloader_train):
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.scheduler.step()

            # update loss meter, weighted by the real batch size (the last
            # batch can be smaller than cfg.batch_size)
            loss_agg.update(loss.item(), y.size(0))

            # save preds/targets for roc computation
            predictions.append(torch.sigmoid(logits[0].view(-1)).detach().cpu().squeeze().numpy())
            targets.append(y.detach().cpu().squeeze().numpy())

            # log every cfg.log_steps steps
            if batch_idx > 0 and batch_idx % self.cfg.log_steps == 0:
                log_item = {
                    'step': self.epoch*len(self.dataloader_train) + batch_idx,
                    'train_loss': loss_agg.avg,
                    'train_auc': roc_auc_score(np.hstack(targets), np.hstack(predictions)),
                    'lr': self.optimizer.param_groups[0]['lr']
                }
                self._log(log_item)

        # compute auc for whole epoch
        auc = roc_auc_score(np.hstack(targets), np.hstack(predictions))
        return loss_agg.avg, auc

    @torch.no_grad()
    def step_eval(self, return_predictions=False):
        """Evaluate on the eval set without gradient tracking.

        Args:
            return_predictions: also return the sigmoid outputs for all samples.

        Returns:
            ``(mean loss, ROC AUC)`` or ``(mean loss, ROC AUC, predictions)``.
        """
        self.model.eval()
        loss_agg = AverageMeter()
        predictions = []
        targets = []
        for batch_idx, (x, y) in enumerate(self.dataloader_eval):
            x = x.to(self.cfg.device)
            y = y.to(self.cfg.device)
            logits = self.model(x)
            loss = self.loss_fnc(logits[0].view(-1), y)
            # update loss meter, weighted by the real batch size
            loss_agg.update(loss.item(), y.size(0))

            # save preds/targets for roc computation
            predictions.append(torch.sigmoid(logits[0].view(-1)).detach().cpu().squeeze().numpy())
            targets.append(y.detach().cpu().squeeze().numpy())

        # compute auc
        targets = np.hstack(targets)
        predictions = np.hstack(predictions)
        auc = roc_auc_score(targets, predictions)

        # setup output
        if return_predictions:
            out = (loss_agg.avg, auc, predictions)
        else:
            out = (loss_agg.avg, auc)
        return out

In [None]:
cd ..

In [None]:
#SpatialPad(spatial_size=CFG.img_size),\
#                          RandCoarseDropout(holes=5,max_holes=10,prob=0.3,max_spatial_size=(28,28,28),spatial_size=(10,10,10)),\
#                          RandGaussianNoise(prob=1),\

#                          CenterScaleCrop(roi_scale=[0.3,0.4,0.5]),

#                          SpatialPad(spatial_size=CFG.img_size),

In [None]:
def get_transform(phase):
    """Build the MONAI preprocessing pipeline for ``phase``.

    The 'train' pipeline and the pipeline for every other phase currently
    contain the same steps: intensity scaling, adding a channel axis, resizing
    to (64, 128, 128), padding to ``CFG.img_size`` and type conversion.
    Training-only augmentations belong in the 'train' branch.
    """
    resized_shape = (64, 128, 128)
    if phase == 'train':
        steps = [
            ScaleIntensity(),
            AddChannel(),
            Resize(resized_shape),
            SpatialPad(spatial_size=CFG.img_size),
            EnsureType(),
        ]
    else:
        steps = [
            ScaleIntensity(),
            AddChannel(),
            Resize(resized_shape),
            SpatialPad(spatial_size=CFG.img_size),
            EnsureType(),
        ]
    return Compose(steps)

In [None]:
# df_train = train_df.loc[train_df.fold != 0, ['BraTS21ID', 'MGMT_value']].reset_index(drop=True)
# df_eval = train_df.loc[train_df.fold == 0, ['BraTS21ID', 'MGMT_value']].reset_index(drop=True)
# dataset_train = BTRCDataset(df_train, data_dir='../input/rsna-miccai-png/train', transform = get_transform('train'), cohort=CFG.cohort)
# dataset_eval = BTRCDataset(df_eval, data_dir='../input/rsna-miccai-png/train', transform = get_transform('valid'), cohort=CFG.cohort)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
def show_list_img(img):
    """Plot ten consecutive entries of ``img[0]`` in a 5x2 grid.

    The entries at positions 51 to 60 of ``img[0]`` are shown, one per subplot.
    """
    n_cols, n_rows = 2, 5
    figure = plt.figure(figsize=(15, 15))
    for cell in range(1, n_cols * n_rows + 1):
        # fetch the slice first so an out-of-range index fails before plotting
        current = img[0][cell + 50]
        figure.add_subplot(n_rows, n_cols, cell)
        plt.imshow(current)
    plt.show()

In [None]:
def visualize(dataloader_train):
    """Show slice grids for the first two items yielded by ``dataloader_train``.

    Each yielded item must unpack into ``(img, label)``; only ``img`` is shown.
    """
    for position, (img, label) in enumerate(dataloader_train):
        show_list_img(img)
        # stop right after the second item without fetching a third
        if position == 1:
            break

In [None]:
# !wandb login --relogin

In [None]:
# load the competition labels
train_df = pd.read_csv('../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv')

# drop 3 samples, see https://www.kaggle.com/c/rsna-miccai-brain-tumor-radiogenomic-classification/discussion/262046
excluded_ids = [109, 123, 709]
train_df = train_df[~train_df.BraTS21ID.isin(excluded_ids)]

# assign cross-validation folds and keep a copy of the assignment on disk
train_df = get_folds(train_df, rng=CFG.seed)
train_df.to_csv('folds.csv')

# one training run per fold, each holding that fold out for evaluation
n_folds = int(train_df['fold'].max()) + 1
label_columns = ['BraTS21ID', 'MGMT_value']
for fold_idx in range(n_folds):
    # split the labels into the training part and the held-out part
    held_out = train_df.fold == fold_idx
    df_train = train_df.loc[~held_out, label_columns].reset_index(drop=True)
    df_eval = train_df.loc[held_out, label_columns].reset_index(drop=True)

    dataset_train = BTRCDataset(
        df_train,
        data_dir='../input/rsna-miccai-png/train',
        transform=get_transform('train'),
        cohort=CFG.cohort,
    )
    dataset_eval = BTRCDataset(
        df_eval,
        data_dir='../input/rsna-miccai-png/train',
        transform=get_transform('valid'),
        cohort=CFG.cohort,
    )
    if CFG.visualize:
        visualize(dataset_train)

    # 3D vision transformer with a single-logit classification head
    vit_kwargs = dict(
        spatial_dims=3,
        in_channels=1,
        img_size=CFG.img_size,
        pos_embed='conv',
        patch_size=32,
        dropout_rate=CFG.dropout_rate-0.2,
        num_layers=24,
        num_heads=16,
        mlp_dim=4096,
        hidden_size=1024,
        num_classes=1,
        classification=True,
    )
    model = monai.networks.nets.ViT(**vit_kwargs).to(CFG.device)

    # one wandb run per fold; every metric is plotted against the custom "step"
    wandb_run = wandb.init(
        project="kaggle-BTRC",
        config=config_wandb,
        group=wandb_group_name,
        name=f"fold{fold_idx}",
        job_type="finetuning",
    )
    wandb_run.define_metric("step")
    wandb_run.define_metric("*", step_metric="step", step_sync=True)

    # train this fold, then store its per-epoch log next to the checkpoint
    trainer = Trainer(
        cfg=TrainerConfig,
        model=model,
        model_path=f'model_fold{fold_idx}.torch',
        dataset_train=dataset_train,
        dataset_val=dataset_eval,
        wandb_run=wandb_run,
    )
    train_log = trainer.train()
    train_log.to_csv(f"fold{fold_idx}_log.csv")
    wandb_run.finish()


In [None]:
# for img,label in dataset_train:
#     print(img)
#     break