In this notebook, I try to optimize the cross-validation (CV) score of the split-by-patient strategy. The public leaderboard score might be lower than in my [previous notebook](https://www.kaggle.com/nanguyen/brain-tumor-2d-cnn-pytorch-split-by-patient), but I hope to get a more reliable CV result.

In [None]:
import sys
sys.path.append('../input/monai-v070')
import os
import cv2
import glob
import pydicom
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import time
import datetime
from dataclasses import dataclass, field
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.metrics import roc_auc_score
from copy import deepcopy
import gc

from monai.data import CacheDataset, DataLoader
from monai.transforms import *
from monai.networks.nets import *

def seed_everything(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy, and PyTorch on the CPU and on *all* CUDA
    devices. It also sets ``PYTHONHASHSEED``, which only affects Python processes
    started after this call, because the current interpreter's hash seed is fixed
    at startup. cuDNN is put in deterministic mode and its autotuner is disabled,
    because benchmark mode may select different (non-deterministic) kernels
    between runs.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed only seeds the current GPU; manual_seed_all covers every device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)

class AverageMeter:
    """Track the most recent value together with a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running total, the sample count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` new samples and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
# Root folder of the competition data (contains train/, test/ and train_labels.csv).
DATA_DIR = "../input/rsna-miccai-brain-tumor-radiogenomic-classification/"
# MRI sequence sub-folders expected inside every patient directory.
MRI_TYPES = [
    "FLAIR",
    "T1w",
    "T2w",
    "T1wCE",
]

## Dataset

In [None]:
class BrainTumorDataset(CacheDataset):
    """Slice-level dataset: one sample per selected DICOM slice of each patient.

    Every sample is a dict ``{'image': <path to a .dcm file>, 'label': MGMT_value,
    'patient_id': <id>}``. The transform passed through ``*args/**kwargs`` to
    ``CacheDataset`` turns the path into pixels (``LoadDicomd`` in this notebook).

    Args:
        root_dir: directory holding one sub-folder per patient, named with the
            patient id zero-padded to 5 digits.
        patient_ids: ids of the patients to include.
        mri_types: subset of ``MRI_TYPES`` whose slices are used.
        annotations: DataFrame with ``BraTS21ID`` and ``MGMT_value`` columns, or
            ``None`` for unlabeled (test) data, in which case every label is 0.
    """

    def __init__(self, root_dir, patient_ids, mri_types, annotations, *args, **kwargs):
        self.root_dir = root_dir
        self.patient_ids = patient_ids
        self.mri_types = mri_types
        self.annotations = annotations
        data = self.get_data()
        super().__init__(data, *args, **kwargs)

    def get_data(self):
        """Build the list of per-slice sample dicts for all patients, in patient order."""
        data = []
        for patient_id in tqdm(self.patient_ids):
            if self.annotations is not None:
                # .item() raises unless exactly one annotation row matches the patient.
                label = self.annotations.loc[
                    self.annotations['BraTS21ID'] == int(patient_id), 'MGMT_value'
                ].item()
            else:
                label = 0  # dummy value for unlabeled data
            data.extend(
                {'image': slice_path, 'label': label, 'patient_id': patient_id}
                for slice_path in self.get_patient_slice_paths(patient_id)
            )
        return data

    def get_patient_slice_paths(self, patient_id):
        """Return the selected slice paths of every requested MRI type for one patient.

        For each type, the slices are ordered by the trailing number of their file
        name (``...-<n>.dcm``). The middle half (25%-75%) is kept, taking every 3rd
        slice, or every slice when the sequence has fewer than 10. A single-slice
        sequence gives an empty window; in that case its only (middle) slice is
        kept, so the patient still contributes a sample for this type.

        Raises:
            ValueError: if ``mri_types`` contains a type not listed in ``MRI_TYPES``.
        """
        unknown = set(self.mri_types) - set(MRI_TYPES)
        if unknown:
            raise ValueError(f'Unknown MRI types: {sorted(unknown)}')
        patient_path = os.path.join(self.root_dir, str(patient_id).zfill(5))
        patient_slice_paths = []
        for mri_type in self.mri_types:
            paths = sorted(
                glob.glob(os.path.join(patient_path, mri_type, "*.dcm")),
                key=lambda x: int(x[:-4].split("-")[-1]),
            )

            num_images = len(paths)
            start = int(num_images * 0.25)
            end = int(num_images * 0.75)
            interval = 1 if num_images < 10 else 3
            selected = paths[start:end:interval]
            if not selected and paths:
                # Window collapsed (num_images == 1): fall back to the middle slice.
                selected = [paths[num_images // 2]]
            patient_slice_paths.extend(selected)
        return patient_slice_paths
    
class LoadDicomd(MapTransform):
    """Dictionary transform that replaces DICOM file paths with resized pixel arrays.

    For every key in ``keys``, the value is read as a DICOM path and replaced by
    the array returned by ``load_dicom``.

    Args:
        img_size: side length (pixels) of the square output image.
    """

    def __init__(self, img_size, *args, **kwargs):
        self.img_size = img_size
        super().__init__(*args, **kwargs)

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            d[key] = self.load_dicom(d[key])
        return d

    def load_dicom(self, path):
        """Read one DICOM slice and return a ``(1, img_size, img_size)`` array in [0, 1].

        The slice is divided by its maximum and quantized to 8 bits. It is then
        resized with OpenCV and rescaled back to [0, 1]. Values that fall outside
        [0, 1] after normalization (negative pixels) are clipped before the uint8
        cast, which would otherwise wrap them to large values. Assumes single-frame
        2-D pixel data — TODO confirm for multi-frame files.
        """
        # dcmread replaces read_file, which is deprecated and removed in pydicom 3.
        dicom = pydicom.dcmread(path)
        data = dicom.pixel_array
        max_value = np.max(data)
        if max_value > 0:
            data = data / max_value
        data = (np.clip(data, 0, 1) * 255).astype(np.uint8)
        data = cv2.resize(data, (self.img_size, self.img_size)) / 255
        return np.expand_dims(data, axis=0)

## Model

In [None]:
class Simple2dCNN(nn.Module):
    """Small two-convolution CNN for square single-slice inputs.

    Layout: conv(4x4) -> ReLU -> maxpool(2) -> conv(2x2) -> ReLU -> maxpool(1)
    -> dropout -> flatten -> fc -> ReLU -> fc. With ``img_size=32`` the feature
    maps are 29x29 -> 14x14 -> 13x13.

    Args:
        input_channels: channels of the input image.
        n_classes: number of output logits.
        img_size: side length of the square input; must be at least 7.
        conv1_filters: output channels of the first convolution.
        conv2_filters: output channels of the second convolution.
        dropout_prob: dropout probability applied before the classifier head.
        fc1_units: width of the hidden fully-connected layer.

    Raises:
        ValueError: if ``img_size`` leaves no spatial extent after the second
            convolution (previously a negative size was silently squared).
    """

    def __init__(self,
                 input_channels=1,
                 n_classes=2,
                 img_size=32,
                 conv1_filters=128,
                 conv2_filters=64,
                 dropout_prob=0.1,
                 fc1_units=48):
        super().__init__()

        # Spatial size after conv1 (kernel 4: -3), maxpool1 (//2) and conv2
        # (kernel 2: -1); maxpool2 has kernel 1 and leaves the size unchanged.
        last_feature_map_size = (img_size - 3) // 2 - 1
        if last_feature_map_size < 1:
            raise ValueError(
                f'img_size={img_size} is too small; Simple2dCNN needs img_size >= 7')

        # Layers are created in the original order so state_dict keys and the
        # random initialization sequence are unchanged.
        self.relu = nn.ReLU()

        self.conv1 = nn.Conv2d(input_channels, conv1_filters, 4)
        self.maxpool1 = nn.MaxPool2d(2)

        self.conv2 = nn.Conv2d(conv1_filters, conv2_filters, 2)
        self.maxpool2 = nn.MaxPool2d(1)

        self.dropout = nn.Dropout(dropout_prob)
        self.fc1 = nn.Linear(conv2_filters * last_feature_map_size**2, fc1_units)
        self.fc2 = nn.Linear(fc1_units, n_classes)

    def forward(self, x):
        """Map a ``(batch, input_channels, img_size, img_size)`` tensor to ``(batch, n_classes)`` logits."""
        # Shapes below are for the defaults (img_size=32, 128/64 filters).
        x = self.maxpool1(self.relu(self.conv1(x)))  # (None, 128, 14, 14)
        x = self.maxpool2(self.relu(self.conv2(x)))  # (None, 64, 13, 13)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)  # (None, 64 * 13 * 13)
        x = self.relu(self.fc1(x))  # (None, 48)
        return self.fc2(x)  # (None, 2)

## Pipeline

In [None]:
@dataclass
class Config:
    """Run configuration read by ``Pipeline``: data paths, model and CV settings."""

    # Data locations, all derived from DATA_DIR.
    train_dir: str = field(default=os.path.join(DATA_DIR, "train"))
    test_dir: str = field(default=os.path.join(DATA_DIR, "test"))
    annotation_path: str = field(default=os.path.join(DATA_DIR, "train_labels.csv"))
    # Model settings. NOTE(review): the active Pipeline code does not read
    # model_name; only a commented-out EfficientNetBN constructor does.
    model_name: str = "efficientnet-b0"
    n_classes: int = 2
    img_size: int = 224
    # Data loading and training control.
    n_workers: int = 4
    early_stopping_rounds: int = 3
    n_folds: int = 5
        
        
class Pipeline:
    """Fold assignment, dataset construction, training, validation and test inference
    for one 2-D slice classifier.

    ``config`` is a ``Config``-like object. The fields read here are ``img_size``,
    ``n_classes``, ``n_workers``, ``n_folds``, ``early_stopping_rounds``,
    ``annotation_path``, ``train_dir`` and ``test_dir``.
    """

    def __init__(self, config):
        self.args = config
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        self.annotations = None
        self.model = None
        self.load_model()
        # Transforms are kept in three stages so validation/test can reuse loading
        # and tensor conversion without the random augmentation.
        self.preaugment_transform = [
            LoadDicomd(keys="image", img_size=self.args.img_size),
        ]
        self.augment_transform = [
            RandAffined(
                keys="image",
                prob=1.0,
                rotate_range=(np.pi/9, np.pi/9),
                scale_range=(0.1, 0.1),
                shear_range=(0.1, 0.1),
            ),
            RandScaleIntensityd(keys="image", factors=0.3, prob=1.0),
            RandShiftIntensityd(keys="image", offsets=0.3, prob=1.0),
        ]
        self.postaugment_transform = [
            ToTensord(keys="image", dtype=torch.float),
            ToTensord(keys="label", dtype=torch.int64),
        ]

    def load_annotations(self):
        """Load the training labels and assign every patient to a CV fold.

        Patients 109, 123 and 709 are excluded. Folds come from a shuffled
        ``StratifiedKFold`` (seed 42) over annotation rows, stratified on
        ``MGMT_value``, so all slices of a patient share one fold. The fold index
        is stored in an integer ``fold`` column.
        """
        annotations = pd.read_csv(self.args.annotation_path)
        # exclude 3 cases
        annotations = annotations[~annotations['BraTS21ID'].isin([109, 123, 709])]
        annotations = annotations.reset_index(drop=True)
        skf = StratifiedKFold(n_splits=self.args.n_folds, shuffle=True, random_state=42)
        # split by patient, stratify based on target value
        folds = skf.split(annotations['BraTS21ID'].values, annotations['MGMT_value'].values)
        annotations['fold'] = -1
        # val_indices are positional; valid as labels because of reset_index above.
        for i, (_, val_indices) in enumerate(folds):
            annotations.loc[val_indices, 'fold'] = i
        annotations['fold'] = annotations['fold'].astype(int)
        self.annotations = annotations

    def load_model(self, weights_path=None):
        """(Re)create the model on ``self.device``, optionally loading saved weights.

        The active architecture is a 2-D, single-input-channel ``resnet18``. The
        commented-out constructors are alternative architectures kept for experiments.

        Args:
            weights_path: optional path of a state dict saved by ``fit``.
        """
        self.model = resnet18(spatial_dims=2,
                              n_input_channels=1,
                              num_classes=self.args.n_classes).to(self.device)
        # self.model = EfficientNetBN(model_name=self.args.model_name,
        #                             pretrained=False,
        #                             spatial_dims=2,
        #                             in_channels=1,
        #                             num_classes=self.args.n_classes).to(self.device)
        # self.model = Simple2dCNN(input_channels=1,
        #                          n_classes=self.args.n_classes,
        #                          img_size=self.args.img_size,
        #                          conv1_filters=128,
        #                          conv2_filters=64,
        #                          dropout_prob=0.1,
        #                          fc1_units=48).to(self.device)
        if weights_path:
            weights = torch.load(weights_path, map_location=self.device)
            self.model.load_state_dict(weights)

    def prepare_datasets(self, mri_types, fold, cache_rate):
        """Build the train and validation datasets for one CV fold.

        Patients whose ``fold`` equals ``fold`` form the validation set; all others
        are used for training. Only the training set is augmented. Requires
        ``load_annotations()`` to have been called. Batches look like::

            {
                'image': float tensor (batch_size, 1, img_size, img_size),
                'label': int64 tensor (batch_size,),
                'patient_id': patient ids (batch_size,),
            }
        """
        train_transform = Compose(
            self.preaugment_transform +
            self.augment_transform +
            self.postaugment_transform
        )
        val_transform = Compose(
            self.preaugment_transform +
            self.postaugment_transform
        )

        train_ids = self.annotations[self.annotations['fold'] != fold]['BraTS21ID'].values.tolist()
        val_ids = self.annotations[self.annotations['fold'] == fold]['BraTS21ID'].values.tolist()

        train_ds = BrainTumorDataset(root_dir=self.args.train_dir,
                                     patient_ids=train_ids,
                                     mri_types=mri_types,
                                     annotations=self.annotations,
                                     transform=train_transform,
                                     cache_rate=cache_rate,
                                     num_workers=self.args.n_workers)
        val_ds = BrainTumorDataset(root_dir=self.args.train_dir,
                                   patient_ids=val_ids,
                                   mri_types=mri_types,
                                   annotations=self.annotations,
                                   transform=val_transform,
                                   cache_rate=cache_rate,
                                   num_workers=self.args.n_workers)
        return train_ds, val_ds

    def prepare_test_dataset(self, mri_types, cache_rate):
        """Build an unlabeled dataset over every patient folder in ``test_dir``, in numeric id order."""
        test_transform = Compose(
            self.preaugment_transform +
            self.postaugment_transform
        )
        test_ids = sorted(int(patient_id) for patient_id in os.listdir(self.args.test_dir))
        test_ds = BrainTumorDataset(root_dir=self.args.test_dir,
                                    patient_ids=test_ids,
                                    mri_types=mri_types,
                                    annotations=None,
                                    transform=test_transform,
                                    cache_rate=cache_rate,
                                    num_workers=self.args.n_workers)
        return test_ds

    def train_epoch(self, loader, loss_function, optimizer, verbose):
        """Run one optimization pass over ``loader``; return the sample-weighted mean loss."""
        self.model.train()
        summary_loss = AverageMeter()
        start = time.time()
        n = len(loader)
        for step, batch_data in enumerate(loader):
            inputs, labels = (
                batch_data["image"].to(self.device),  # (None, 1, img_size, img_size)
                batch_data["label"].to(self.device),  # (None, )
            )
            batch_size = inputs.size(0)
            # back propagation
            optimizer.zero_grad()
            outputs = self.model(inputs)  # (None, 2)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            # update stats
            summary_loss.update(loss.item(), batch_size)
            if verbose:
                print('Train step {}/{}, loss: {:.5f}'.format(step + 1, n,
                                                              summary_loss.avg), end='\r')
        elapsed_time = str(datetime.timedelta(seconds=time.time() - start))
        print('Train loss: {:.5f} - time: {}'.format(summary_loss.avg, elapsed_time))
        return summary_loss.avg

    def evaluate_epoch(self, loader, loss_function, verbose):
        """Score a labeled loader without gradients.

        Slice probabilities (softmax of class 1) are averaged per patient.

        Returns:
            (mean loss, patient-level AUC, per-patient DataFrame with
            ``BraTS21ID``, ``probability`` and ``label`` columns).
        """
        self.model.eval()
        summary_loss = AverageMeter()
        start = time.time()
        n = len(loader)
        patient_ids_all = []
        probabilities_all = []
        labels_all = []
        with torch.no_grad():
            for step, batch_data in enumerate(loader):
                inputs, labels, patient_ids = (
                    batch_data["image"].to(self.device),  # (None, 1, img_size, img_size)
                    batch_data["label"].to(self.device),  # (None, )
                    batch_data["patient_id"],  # (None, )
                )
                batch_size = inputs.size(0)
                # forward pass only (gradients are disabled)
                outputs = self.model(inputs)  # (None, 2)
                loss = loss_function(outputs, labels)
                # update stats
                probabilities = F.softmax(outputs, dim=1)[:, 1].tolist()
                probabilities_all.extend(probabilities)
                labels_all.extend(labels.tolist())
                patient_ids_all.extend(patient_ids)

                summary_loss.update(loss.item(), batch_size)
                if verbose:
                    print('Val step {}/{}, loss: {:.5f}'.format(step + 1, n,
                                                                summary_loss.avg), end='\r')
        elapsed_time = str(datetime.timedelta(seconds=time.time() - start))
        print('Val loss: {:.5f} - time: {}'.format(summary_loss.avg, elapsed_time))
        result = pd.DataFrame({
            'BraTS21ID': [patient_id.item() for patient_id in patient_ids_all],
            'probability': probabilities_all,
            'label': labels_all,
        })
        slice_auc = roc_auc_score(result['label'], result['probability'])
        result = result.groupby("BraTS21ID", as_index=False).mean()
        patient_auc = roc_auc_score(result['label'], result['probability'])
        print('Patient AUC: {:.5f} - Slice AUC: {:.5f}'.format(patient_auc, slice_auc))

        return summary_loss.avg, patient_auc, result

    def infer_epoch(self, loader, verbose):
        """Predict an unlabeled loader; return per-patient mean probabilities (``BraTS21ID``, ``probability``)."""
        self.model.eval()
        start = time.time()
        n = len(loader)
        patient_ids_all = []
        probabilities_all = []
        with torch.no_grad():
            for step, batch_data in enumerate(loader):
                inputs, patient_ids = (
                    batch_data["image"].to(self.device),  # (None, 1, img_size, img_size)
                    batch_data["patient_id"],  # (None, )
                )
                # forward
                outputs = self.model(inputs)  # (None, 2)
                # update stats
                probabilities = F.softmax(outputs, dim=1)[:, 1].tolist()
                probabilities_all.extend(probabilities)
                patient_ids_all.extend(patient_ids)
                if verbose:
                    print('Infer step {}/{}'.format(step + 1, n), end='\r')

        result = pd.DataFrame({
            'BraTS21ID': [patient_id.item() for patient_id in patient_ids_all],
            'probability': probabilities_all,
        })
        result = result.groupby("BraTS21ID", as_index=False).mean()

        elapsed_time = str(datetime.timedelta(seconds=time.time() - start))
        print('Elapsed time: {}'.format(elapsed_time))

        return result

    def _save_best(self, state_dict, model_name, val_loss, val_metric):
        """Report the best scores and save ``state_dict`` to a checkpoint in the working directory.

        File name: ``{model_name}_imgsize{img_size}_valloss{loss:.3f}_valauc{auc:.3f}.pth``.
        Returns the path written.
        """
        print('Val loss: {:.5f}, Val auc: {:.5f}'.format(val_loss, val_metric))
        if state_dict is None:
            # No epoch beat the initial +inf loss (e.g. the loss was NaN every
            # epoch); save the current weights so a loadable checkpoint exists.
            print('No improving epoch; saving the last weights instead.')
            state_dict = self.model.state_dict()
        print('Saving model...')
        save_path = '{}_imgsize{}_valloss{:.3f}_valauc{:.3f}.pth'.format(
            model_name, self.args.img_size, val_loss, val_metric)
        torch.save(state_dict, save_path)
        return save_path

    def fit(self, train_ds, val_ds, batch_size, epochs, lr, model_name, verbose):
        """Train with Adam + cross-entropy, keeping the weights with the lowest validation loss.

        Training stops once ``early_stopping_rounds`` consecutive epochs pass
        without a new best validation loss, or after ``epochs`` epochs. The best
        weights are then written to disk (see ``_save_best``).
        """
        train_loader = DataLoader(train_ds,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=self.args.n_workers)
        val_loader = DataLoader(val_ds,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=self.args.n_workers)
        loss_function = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        best_metric = -np.inf
        best_loss = np.inf
        best_epoch = 1
        best_state_dict = None
        for epoch in range(1, epochs + 1):
            print('\nEpoch {}/{}:'.format(epoch, epochs))
            self.train_epoch(train_loader, loss_function, optimizer, verbose)
            val_loss, val_metric, _ = self.evaluate_epoch(val_loader, loss_function, verbose)

            if val_loss < best_loss:
                print('Val loss improved from {:.5f} to {:.5f}'.format(best_loss, val_loss))
                best_metric = val_metric
                best_loss = val_loss
                best_epoch = epoch
                best_state_dict = deepcopy(self.model.state_dict())
            # '>=' so exactly early_stopping_rounds non-improving epochs trigger the stop.
            elif (epoch - best_epoch) >= self.args.early_stopping_rounds:
                print('Early stopping. Best model is epoch {}'.format(best_epoch))
                self._save_best(best_state_dict, model_name, best_loss, best_metric)
                break
            if epoch == epochs:
                print('Finished training. Best model is epoch {}'.format(best_epoch))
                self._save_best(best_state_dict, model_name, best_loss, best_metric)

    def evaluate(self, val_ds, batch_size, verbose):
        """Score ``val_ds`` with the current model; return (patient AUC, per-patient DataFrame)."""
        val_loader = DataLoader(val_ds,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=self.args.n_workers)
        loss_function = nn.CrossEntropyLoss()
        _, val_metric, val_result = self.evaluate_epoch(val_loader, loss_function, verbose)
        return val_metric, val_result

    def predict(self, test_ds, batch_size, verbose):
        """Predict ``test_ds`` with the current model; return per-patient probabilities."""
        test_loader = DataLoader(test_ds,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=self.args.n_workers)
        return self.infer_epoch(test_loader, verbose)

In [None]:
# Run hyperparameters (values noted in trailing comments are alternatives from the original notebook).
model_name: str = 'resnet18'  # alternative: '2dCNN'
img_size: int = 224  # alternative: 32
batch_size: int = 128
n_workers: int = 4
early_stopping_rounds: int = 5
n_folds: int = 5
epochs: int = 50
lr: float = 1e-5

In [None]:
# Override only the knobs tuned above; paths and n_classes keep their Config defaults.
args = Config(
    n_folds=n_folds,
    early_stopping_rounds=early_stopping_rounds,
    n_workers=n_workers,
    img_size=img_size,
    model_name=model_name,
)
pipeline = Pipeline(config=args)

## Train

In [None]:
def train(pipeline, mri_type, n_folds, batch_size, epochs, lr, model_name):
    """Train one freshly initialized model per CV fold on a single MRI type.

    For every fold, the datasets are rebuilt fully cached (``cache_rate=1.0``) and
    the model is re-created with ``pipeline.load_model()``. ``pipeline.fit`` is
    then called with the checkpoint name ``{model_name}_{mri_type}_fold{fold}``.
    Each fold's datasets are dropped and garbage-collected before the next fold
    starts.
    """
    for fold in range(n_folds):
        print(f'### Train {mri_type} on fold {fold}: ###')
        datasets = pipeline.prepare_datasets(
            mri_types=[mri_type], fold=fold, cache_rate=1.0,
        )
        pipeline.load_model()
        checkpoint_name = f'{model_name}_{mri_type}_fold{fold}'
        pipeline.fit(*datasets,
                     batch_size=batch_size,
                     epochs=epochs,
                     lr=lr,
                     model_name=checkpoint_name,
                     verbose=True)
        del datasets
        gc.collect()

In [None]:
# Fold assignment must exist before any dataset is built.
pipeline.load_annotations()
for mri_type in MRI_TYPES:
    train(pipeline, mri_type,
          n_folds=n_folds, batch_size=batch_size, epochs=epochs,
          lr=lr, model_name=model_name)

## Evaluate

In [None]:
def evaluate(pipeline, mri_type, n_folds, batch_size, model_name):
    """Evaluate the saved per-fold models of one MRI type on their validation folds.

    For each fold, the checkpoint ``{model_name}_{mri_type}_fold{fold}_imgsize*.pth``
    (the naming used by ``Pipeline.fit``) is loaded from the working directory.
    It is scored on that fold's validation patients, without caching.

    Returns:
        (results, mean_auc, oof_auc). ``results`` holds the concatenated
        per-patient out-of-fold rows (``BraTS21ID``, ``probability``, ``label``).
        ``mean_auc`` is the mean of the per-fold patient AUCs. ``oof_auc`` is the
        patient AUC over all out-of-fold predictions.

    Raises:
        FileNotFoundError: if some fold has no checkpoint.
    """
    def find_weight(prefix):
        # Anchor on '<prefix>_imgsize' so e.g. 'fold1' cannot match a 'fold10'
        # checkpoint. If a fold was trained more than once, use the newest file
        # rather than whichever os.listdir happens to return first.
        candidates = [name for name in os.listdir()
                      if name.startswith(prefix + '_imgsize') and name.endswith('.pth')]
        if not candidates:
            raise FileNotFoundError(f'No checkpoint found for {prefix!r}')
        return max(candidates, key=os.path.getmtime)

    # Resolve every checkpoint up front so a missing fold fails before any work.
    weights_paths = [find_weight(f'{model_name}_{mri_type}_fold{fold}')
                     for fold in range(n_folds)]
    metrics = []
    results = []
    for fold, weights_path in enumerate(weights_paths):
        print(f'### Evaluate {mri_type} on fold {fold}: ###')
        _, val_ds = pipeline.prepare_datasets(mri_types=[mri_type],
                                              fold=fold,
                                              cache_rate=0.0)
        pipeline.load_model(weights_path)
        val_metric, val_result = pipeline.evaluate(val_ds, batch_size=batch_size, verbose=True)
        metrics.append(val_metric)
        results.append(val_result)
    results = pd.concat(results, ignore_index=True)
    mean_auc = np.mean(metrics)
    oof_auc = roc_auc_score(results['label'], results['probability'])
    print('---')
    print(f'{mri_type} validation result:')
    print(' Mean AUC: {:.5f}'.format(mean_auc))
    print(' Out-of-fold AUC: {:.5f}'.format(oof_auc))
    print('---')
    return results, mean_auc, oof_auc

In [None]:
# Per-modality out-of-fold predictions and AUC summaries, keyed by MRI type.
oof_predictions, mean_aucs, oof_aucs = {}, {}, {}
for mri_type in MRI_TYPES:
    oof_prediction, mean_auc, oof_auc = evaluate(
        pipeline, mri_type, n_folds, batch_size, model_name,
    )
    oof_predictions[mri_type] = oof_prediction
    mean_aucs[mri_type] = mean_auc
    oof_aucs[mri_type] = oof_auc

## Blend modalities

In this section, I try to find the blending weights for the 4 modalities that maximize the out-of-fold AUC of the ensemble.

In [None]:
label_df = pd.read_csv(os.path.join(DATA_DIR, 'train_labels.csv')).set_index('BraTS21ID')
# One probability column per modality, indexed by patient id.
per_modality = [
    oof_predictions[mri_type][['BraTS21ID', 'probability']]
    .rename(columns={'probability': mri_type})
    .set_index('BraTS21ID')
    for mri_type in MRI_TYPES
]
# Inner join keeps only patients present in every modality and in the labels file.
oof_df = pd.concat(per_modality + [label_df], axis=1, join='inner').reset_index()
oof_df

In [None]:
def find_best_weight(oof_df):
    """Grid-search modality blending weights that maximize the out-of-fold AUC.

    The FLAIR, T1w and T2w weights are searched on a 0.01 grid. The T1wCE weight
    is the remainder, so all four are non-negative and sum to 1. The grid is
    enumerated in integer percentages: float sums of ``np.arange`` grid values
    can exceed 1 by a rounding error, which previously dropped valid combinations
    silently.

    Args:
        oof_df: DataFrame with per-patient probability columns 'FLAIR', 'T1w',
            'T2w', 'T1wCE' and the binary label column 'MGMT_value'.

    Returns:
        (best_w, best_auc): ``[w_flair, w_t1w, w_t2w, w_t1wce]`` and its AUC.
        Ties keep the first combination in ascending (FLAIR, T1w, T2w) order.
    """
    steps = 100  # grid resolution: weights are multiples of 1 / steps
    # Plain numpy arrays avoid per-iteration pandas overhead in the ~177k-step loop.
    labels = oof_df['MGMT_value'].to_numpy()
    flair = oof_df['FLAIR'].to_numpy()
    t1w = oof_df['T1w'].to_numpy()
    t2w = oof_df['T2w'].to_numpy()
    t1wce = oof_df['T1wCE'].to_numpy()

    best_auc = -np.inf
    best_w = None
    for i_flair in tqdm(range(steps + 1)):
        for i_t1w in range(steps + 1 - i_flair):
            for i_t2w in range(steps + 1 - i_flair - i_t1w):
                i_t1wce = steps - i_flair - i_t1w - i_t2w
                w = [i_flair / steps, i_t1w / steps, i_t2w / steps, i_t1wce / steps]
                pred = w[0] * flair + w[1] * t1w + w[2] * t2w + w[3] * t1wce
                auc = roc_auc_score(labels, pred)
                if auc > best_auc:
                    best_auc = auc
                    best_w = w
    return best_w, best_auc

In [None]:
best_w, best_auc = find_best_weight(oof_df)
# Show the chosen weights (FLAIR, T1w, T2w, T1wCE), then the blended OOF AUC.
for shown in (best_w, best_auc):
    print(shown)

In [None]:
# Blend the per-modality OOF probabilities with the searched weights, summed in
# FLAIR, T1w, T2w, T1wCE order, and save them.
oof_df['pred'] = sum(
    weight * oof_df[column]
    for weight, column in zip(best_w, ('FLAIR', 'T1w', 'T2w', 'T1wCE'))
)
oof_df.loc[:, ['BraTS21ID', 'pred']].to_csv('oof_prediction.csv', index=False)

## Inference

* The prediction for each modality is the average of the predictions of its 5 fold models.
* The final test-set prediction is a weighted blend of the 4 modalities, using the weights found above.

In [None]:
def inference(pipeline, mri_type, n_folds, batch_size, model_name):
    """Predict the test set with every fold model of one MRI type and average them.

    Checkpoints ``{model_name}_{mri_type}_fold{fold}_imgsize*.pth`` (the naming
    used by ``Pipeline.fit``) are loaded from the working directory.

    Returns:
        DataFrame with ``BraTS21ID`` and ``probability`` (mean over folds).

    Raises:
        FileNotFoundError: if some fold has no checkpoint.
    """
    def find_weight(prefix):
        # Anchor on '<prefix>_imgsize' so e.g. 'fold1' cannot match a 'fold10'
        # checkpoint. If a fold was trained more than once, use the newest file
        # rather than whichever os.listdir happens to return first.
        candidates = [name for name in os.listdir()
                      if name.startswith(prefix + '_imgsize') and name.endswith('.pth')]
        if not candidates:
            raise FileNotFoundError(f'No checkpoint found for {prefix!r}')
        return max(candidates, key=os.path.getmtime)

    weights_paths = [find_weight(f'{model_name}_{mri_type}_fold{fold}')
                     for fold in range(n_folds)]
    # The test dataset does not depend on the fold, so build it once.
    test_ds = pipeline.prepare_test_dataset(mri_types=[mri_type], cache_rate=0.0)
    test_results = []
    for fold, weights_path in enumerate(weights_paths):
        print(f'### Inference {mri_type} on fold {fold}: ###')
        pipeline.load_model(weights_path)
        test_results.append(pipeline.predict(test_ds, batch_size=batch_size, verbose=True))
    # Align the fold predictions on patient id and average across folds.
    prediction = pd.concat([x.set_index('BraTS21ID') for x in test_results], axis=1).mean(axis=1)
    return prediction.rename('probability').reset_index()

In [None]:
# Fold-averaged test probabilities, keyed by MRI type.
test_predictions = {}
for mri_type in MRI_TYPES:
    test_prediction = inference(
        pipeline, mri_type, n_folds, batch_size, model_name,
    )
    test_predictions[mri_type] = test_prediction

In [None]:
# One probability column per modality, restricted to patients present in all of them.
per_modality = [
    test_predictions[mri_type][['BraTS21ID', 'probability']]
    .rename(columns={'probability': mri_type})
    .set_index('BraTS21ID')
    for mri_type in MRI_TYPES
]
test_df = pd.concat(per_modality, axis=1, join='inner').reset_index()
# Final score: weighted blend with the weights from find_best_weight, summed in
# FLAIR, T1w, T2w, T1wCE order.
test_df['MGMT_value'] = sum(
    weight * test_df[column]
    for weight, column in zip(best_w, ('FLAIR', 'T1w', 'T2w', 'T1wCE'))
)
test_df.loc[:, ['BraTS21ID', 'MGMT_value']].to_csv('submission.csv', index=False)