In [None]:
import sys
sys.path.append('../input/monai-v070')

In [None]:
import os
import cv2
import glob
import pydicom
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import seaborn as sns
import time
import datetime
from dataclasses import dataclass, field
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
from copy import deepcopy
from monai.data import CacheDataset, DataLoader
from monai.transforms import *

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for reproducibility.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. spawned workers);
    # the current process's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different algorithms between runs, so it has
    # to be off for `deterministic` to yield reproducible results.
    torch.backends.cudnn.benchmark = False
seed_everything(42)

class AverageMeter:
    """Track the latest value and a running, count-weighted mean of a metric."""

    def __init__(self):
        self.reset()
        # Never written by this class; kept so callers relying on the attribute still work.
        self.history = []

    def reset(self):
        """Zero the latest value, the running sum, the sample count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and recompute the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

class History:
    """Append-only record of values (e.g. one metric value per epoch)."""

    def __init__(self):
        self.history = list()

    def add(self, val):
        """Append ``val`` to ``self.history``."""
        self.history += [val]

In [None]:
# Root of the RSNA-MICCAI competition data; Config derives train/, test/ and train_labels.csv from it.
DATA_DIR = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/'
# MRI sequences available per patient; each is a sub-directory of a patient folder.
MRI_TYPES = "FLAIR T1w T2w T1wCE".split()

In [None]:
class BrainTumorDataset(CacheDataset):
    """MONAI ``CacheDataset`` with one item per selected DICOM slice.

    Each item is ``{'image': <path to .dcm>, 'label': MGMT_value, 'patient_id': id}``;
    loading the image is left to the ``transform`` forwarded to ``CacheDataset``.

    Args:
        root_dir: directory holding one sub-directory per patient, named by the
            5-digit zero-padded patient id.
        patient_ids: ids of the patients to include.
        mri_types: subset of ``MRI_TYPES`` whose slices are collected.
        annotations: DataFrame with ``BraTS21ID`` and ``MGMT_value`` columns, or
            None, in which case every label is 0.
        section: None to use every patient. Otherwise the patients are split
            80/20 (fixed ``random_state``); ``'train'`` selects the 80 % part and
            any other value the 20 % part.
        *args, **kwargs: forwarded to ``CacheDataset`` (e.g. ``transform``,
            ``cache_rate``, ``num_workers``).
    """

    def __init__(self, root_dir, patient_ids, mri_types, annotations, section, *args, **kwargs):
        self.root_dir = root_dir
        self.patient_ids = patient_ids
        self.mri_types = mri_types
        self.annotations = annotations
        selected_ids = list(patient_ids)
        if section is not None:
            # Split by patient, not by slice: a slice-level split puts slices of
            # the same patient on both sides, so the (per-patient averaged)
            # validation score would be measured on patients seen in training.
            train_ids, val_ids = train_test_split(selected_ids, test_size=0.2, random_state=42)
            selected_ids = train_ids if section == 'train' else val_ids
        data = self.get_data(selected_ids)
        super(BrainTumorDataset, self).__init__(data, *args, **kwargs)

    def get_data(self, patient_ids=None):
        """Build the list of per-slice item dicts.

        Args:
            patient_ids: patients to collect; defaults to ``self.patient_ids``.

        Returns:
            list of ``{'image', 'label', 'patient_id'}`` dicts.
        """
        if patient_ids is None:
            patient_ids = self.patient_ids
        data = []
        for patient_id in tqdm(patient_ids):
            if self.annotations is not None:
                # `.item()` raises unless exactly one row matches the patient.
                label = self.annotations[self.annotations['BraTS21ID']
                                         == int(patient_id)]['MGMT_value'].item()
            else:
                label = 0
            for slice_path in self.get_patient_slice_paths(patient_id):
                data.append({
                    'image': slice_path,
                    'label': label,
                    'patient_id': patient_id
                })
        return data

    def get_patient_slice_paths(self, patient_id):
        '''
        Return the selected slice paths for one patient, over all ``self.mri_types``.

        For each sequence the ``.dcm`` files are ordered by the number after the
        last '-' in the path. The middle half (25 %-75 %) is kept, taking every
        3rd slice, or every slice when the sequence has fewer than 10 images.
        '''
        assert(set(self.mri_types) <= set(MRI_TYPES))
        patient_path = os.path.join(self.root_dir, str(patient_id).zfill(5))
        patient_slice_paths = []
        for mri_type in self.mri_types:
            paths = sorted(
                glob.glob(os.path.join(patient_path, mri_type, "*.dcm")),
                key=lambda x: int(x[:-4].split("-")[-1]),
            )

            num_images = len(paths)
            start = int(num_images * 0.25)
            end = int(num_images * 0.75)

            interval = 1 if num_images < 10 else 3
            patient_slice_paths.extend(paths[start:end:interval])
        return patient_slice_paths
    
class LoadDicomd(MapTransform):
    """Dictionary transform that replaces each ``keys`` entry (a DICOM file path)
    with the pixel data resized to ``img_size`` x ``img_size``, scaled to [0, 1],
    with a leading channel axis added.
    """

    def __init__(self, img_size, *args, **kwargs):
        """
        Args:
            img_size: target height and width of the resized image.
            *args, **kwargs: forwarded to ``MapTransform`` (e.g. ``keys``).
        """
        self.img_size = img_size
        super(LoadDicomd, self).__init__(*args, **kwargs)

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            d[key] = self.load_dicom(d[key])
        return d

    def load_dicom(self, path):
        """Read one DICOM file and return a ``(1, img_size, img_size)`` float array.

        Presumably ``pixel_array`` is a single-frame 2-D image (TODO confirm for
        multi-frame files).
        """
        # `pydicom.read_file` is a deprecated alias removed in pydicom 3.0.
        dicom = pydicom.dcmread(path)
        data = dicom.pixel_array
        max_value = np.max(data)
        if max_value != 0:
            data = data / max_value
        # Negative pixel values would wrap around when cast to uint8; clamp
        # first (no effect when all values are non-negative).
        data = np.clip(data, 0, 1)
        # Quantise to 8 bits before resizing, then rescale back to [0, 1].
        data = (data * 255).astype(np.uint8)
        data = cv2.resize(data, (self.img_size, self.img_size)) / 255
        return np.expand_dims(data, axis=0)

In [None]:
class Simple2dCNN(nn.Module):
    """Small two-convolution CNN that classifies single 2-D slices.

    Layout: conv(4x4) -> ReLU -> BN -> maxpool(2) -> dropout ->
    conv(2x2) -> ReLU -> BN -> maxpool(1) -> dropout -> flatten ->
    fc1 -> ReLU -> fc2. ``forward`` returns raw logits ``(batch, n_classes)``.

    Args:
        input_channels: channels of the input image.
        n_classes: number of output logits.
        img_size: input height == width; used to size ``fc1``.
        conv1_filters, conv2_filters: output channels of the two convolutions.
        dropout_prob: probability of the (shared) dropout layer.
        fc1_units: width of the hidden fully connected layer.
    """

    def __init__(self,
                 input_channels=1,
                 n_classes=2,
                 img_size=32,
                 conv1_filters=128,
                 conv2_filters=64,
                 dropout_prob=0.1,
                 fc1_units=48):
        super(Simple2dCNN, self).__init__()
        # Modules are registered in a fixed order: parameter initialisation
        # consumes the RNG in this order and the state_dict keys depend on it.
        self.relu = nn.ReLU()

        self.conv1 = nn.Conv2d(input_channels, conv1_filters, kernel_size=4)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.bn1 = nn.BatchNorm2d(conv1_filters)

        self.conv2 = nn.Conv2d(conv1_filters, conv2_filters, kernel_size=2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=1)  # 1x1 pooling leaves the map unchanged
        self.bn2 = nn.BatchNorm2d(conv2_filters)

        self.dropout = nn.Dropout(dropout_prob)

        # Side length: conv1 (k=4) -> img_size-3; maxpool1 (k=2) -> floor half;
        # conv2 (k=2) -> minus one; maxpool2 (k=1) -> unchanged.
        side = (img_size - 3) // 2 - 1
        flat_features = conv2_filters * side * side
        self.fc1 = nn.Linear(flat_features, fc1_units)
        self.fc2 = nn.Linear(fc1_units, n_classes)

    def forward(self, x):
        """Map a batch of images ``(batch, input_channels, img_size, img_size)`` to logits."""
        # Stage 1: conv -> ReLU -> BN -> pool -> dropout
        h = self.dropout(self.maxpool1(self.bn1(self.relu(self.conv1(x)))))
        # Stage 2: conv -> ReLU -> BN -> (identity) pool
        h = self.maxpool2(self.bn2(self.relu(self.conv2(h))))
        # Head: dropout -> flatten per sample -> fc1 -> ReLU -> fc2
        h = torch.flatten(self.dropout(h), start_dim=1)
        return self.fc2(self.relu(self.fc1(h)))

In [None]:
@dataclass
class Config:
    """Run configuration: data locations plus model, loader and training settings."""
    # Data locations, derived from DATA_DIR when the class is defined
    train_dir: str = field(default=os.path.join(DATA_DIR, 'train'))
    test_dir: str = field(default=os.path.join(DATA_DIR, 'test'))
    annotation_path: str = field(default=os.path.join(DATA_DIR, 'train_labels.csv'))
    # Model / input
    n_classes: int = field(default=2)
    img_size: int = field(default=32)
    # Data loading
    n_workers: int = field(default=4)
    # Training schedule
    early_stopping_rounds: int = field(default=3)
    n_folds: int = field(default=5)
        
        
class Pipeline:
    """Cross-validated training, evaluation and inference for ``Simple2dCNN``.

    Holds the model, the MONAI transform chains and the fold-annotated label
    table. Per-slice softmax probabilities of class 1 are averaged per
    ``BraTS21ID`` to obtain patient-level predictions and AUCs.
    """

    def __init__(self, config):
        """
        Args:
            config: a ``Config``; its data paths, ``img_size``, ``n_classes``,
                ``n_workers``, ``n_folds`` and ``early_stopping_rounds`` are read.
        """
        self.args = config
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        self.annotations = None
        self.model = None
        self.load_model()
        # Transform chain: load + resize DICOM -> (augmentation, train only) -> tensors.
        self.preaugment_transform = [
            LoadDicomd(keys="image", img_size=self.args.img_size),
        ]
        self.augment_transform = []
        self.postaugment_transform = [
            ToTensord(keys="image", dtype=torch.float),
            ToTensord(keys="label", dtype=torch.int64),
        ]

    def load_annotations(self):
        """Read the annotation CSV and add a stratified integer ``fold`` column."""
        self.annotations = pd.read_csv(self.args.annotation_path)
        # exclude 3 cases
        self.annotations = self.annotations[~self.annotations['BraTS21ID'].isin([109, 123, 709])]
        self.annotations = self.annotations.reset_index(drop=True)
        skf = StratifiedKFold(n_splits=self.args.n_folds, shuffle=True, random_state=42)
        folds = skf.split(self.annotations['BraTS21ID'].values, self.annotations['MGMT_value'].values)
        # `split` yields positional indices; they match the index labels thanks
        # to the reset_index above.
        for i, (_, val_indices) in enumerate(folds):
            self.annotations.loc[val_indices, 'fold'] = i
        self.annotations['fold'] = self.annotations['fold'].astype(int)

    def load_model(self, weights_path=None):
        """(Re)create the model on ``self.device``, optionally loading a state dict."""
        self.model = Simple2dCNN(input_channels=1,
                                 n_classes=self.args.n_classes,
                                 img_size=self.args.img_size).to(self.device)
        if weights_path:
            weights = torch.load(weights_path, map_location=self.device)
            self.model.load_state_dict(weights)
        return self.model

    def _build_transform(self, augment):
        """Compose the transform chain, including ``augment_transform`` only if ``augment``."""
        steps = list(self.preaugment_transform)
        if augment:
            steps += self.augment_transform
        steps += self.postaugment_transform
        return Compose(steps)

    def _make_dataset(self, root_dir, patient_ids, mri_types, annotations, transform, section, cache_rate):
        """Construct a ``BrainTumorDataset`` using this pipeline's worker count."""
        return BrainTumorDataset(root_dir=root_dir,
                                 patient_ids=patient_ids,
                                 mri_types=mri_types,
                                 annotations=annotations,
                                 transform=transform,
                                 section=section,
                                 cache_rate=cache_rate,
                                 num_workers=self.args.n_workers)

    def prepare_datasets(self, mri_types, fold, cache_rate):
        """Build ``(train_ds, val_ds, test_ds)`` for cross-validation fold ``fold``.

        Patients outside ``fold`` are divided between ``train_ds`` and ``val_ds``
        by ``BrainTumorDataset``'s ``section`` split. Patients assigned to
        ``fold`` form ``test_ds``, which keeps all of their selected slices.
        Requires :meth:`load_annotations` to have been called.
        """
        in_fold = self.annotations['fold'] == fold
        train_ids = self.annotations.loc[~in_fold, 'BraTS21ID'].values.tolist()
        # Held-out patients are those *assigned to* `fold`.
        test_ids = self.annotations.loc[in_fold, 'BraTS21ID'].values.tolist()

        train_transform = self._build_transform(augment=True)
        val_transform = self._build_transform(augment=False)

        train_ds = self._make_dataset(self.args.train_dir, train_ids, mri_types,
                                      self.annotations, train_transform, 'train', cache_rate)
        val_ds = self._make_dataset(self.args.train_dir, train_ids, mri_types,
                                    self.annotations, val_transform, 'val', cache_rate)
        # section=None: use every selected slice of every held-out patient.
        test_ds = self._make_dataset(self.args.train_dir, test_ids, mri_types,
                                     self.annotations, val_transform, None, cache_rate)
        return train_ds, val_ds, test_ds

    def prepare_test_dataset(self, mri_types, cache_rate):
        """Dataset over every numeric patient directory in ``test_dir`` (labels are 0)."""
        # Skip non-numeric entries (e.g. hidden files) instead of failing on int().
        test_ids = sorted(int(name) for name in os.listdir(self.args.test_dir) if name.isdigit())
        return self._make_dataset(self.args.test_dir, test_ids, mri_types, None,
                                  self._build_transform(augment=False), None, cache_rate)

    @staticmethod
    def _summarize_predictions(patient_ids, probabilities, labels=None):
        """Aggregate per-slice predictions into per-patient means.

        Returns:
            ``(slice_auc, patient_auc, patient_df)``. ``patient_df`` has
            ``BraTS21ID``, ``probability`` (and ``label`` when given) columns;
            both AUCs are None when ``labels`` is None.
        """
        result = {
            # Collated ids arrive as 0-d tensors; plain values are kept as-is.
            'BraTS21ID': [pid.item() if hasattr(pid, 'item') else pid for pid in patient_ids],
            'probability': probabilities,
        }
        if labels is not None:
            result['label'] = labels
        slice_df = pd.DataFrame(result)
        patient_df = slice_df.groupby("BraTS21ID", as_index=False).mean()
        if labels is None:
            return None, None, patient_df
        slice_auc = roc_auc_score(slice_df['label'], slice_df['probability'])
        patient_auc = roc_auc_score(patient_df['label'], patient_df['probability'])
        return slice_auc, patient_auc, patient_df

    def train_epoch(self, loader, loss_function, optimizer, verbose):
        """Run one training epoch.

        Returns:
            ``(mean per-slice loss, patient-level AUC)``.
        """
        self.model.train()
        summary_loss = AverageMeter()
        start = time.time()
        n = len(loader)
        patient_ids_all, probabilities_all, labels_all = [], [], []
        for step, batch_data in enumerate(loader):
            # images: (batch, 1, img_size, img_size) as produced by LoadDicomd
            inputs = batch_data["image"].to(self.device)
            labels = batch_data["label"].to(self.device)
            patient_ids_all.extend(batch_data["patient_id"])
            labels_all.extend(labels.tolist())

            optimizer.zero_grad()
            outputs = self.model(inputs)  # (batch, n_classes) logits
            loss = loss_function(outputs, labels)
            probabilities_all.extend(F.softmax(outputs, dim=1)[:, 1].detach().tolist())
            loss.backward()
            optimizer.step()

            summary_loss.update(loss.item(), inputs.size(0))
            if verbose:
                print('Train step {}/{}, loss: {:.5f}'.format(step + 1, n,
                                                              summary_loss.avg), end='\r')
        elapsed_time = str(datetime.timedelta(seconds=time.time() - start))
        print('Train loss: {:.5f} - time: {}'.format(summary_loss.avg, elapsed_time))
        slice_auc, patient_auc, _ = self._summarize_predictions(
            patient_ids_all, probabilities_all, labels_all)
        print('Patient AUC: {:.5f} - Slice AUC: {:.5f}'.format(patient_auc, slice_auc))
        return summary_loss.avg, patient_auc

    def evaluate_epoch(self, loader, loss_function, verbose):
        """Evaluate on a labelled loader without gradient tracking.

        Returns:
            ``(mean per-slice loss, patient-level AUC, per-patient DataFrame)``.
        """
        self.model.eval()
        summary_loss = AverageMeter()
        n = len(loader)
        patient_ids_all, probabilities_all, labels_all = [], [], []
        with torch.no_grad():
            for step, batch_data in enumerate(loader):
                inputs = batch_data["image"].to(self.device)
                labels = batch_data["label"].to(self.device)

                outputs = self.model(inputs)
                loss = loss_function(outputs, labels)

                probabilities_all.extend(F.softmax(outputs, dim=1)[:, 1].tolist())
                labels_all.extend(labels.tolist())
                patient_ids_all.extend(batch_data["patient_id"])

                summary_loss.update(loss.item(), inputs.size(0))
                if verbose:
                    print('Val step {}/{}, loss: {:.5f}'.format(step + 1, n,
                                                                summary_loss.avg), end='\r')
        _, patient_auc, result = self._summarize_predictions(
            patient_ids_all, probabilities_all, labels_all)
        return summary_loss.avg, patient_auc, result

    def infer_epoch(self, loader, verbose):
        """Predict on an unlabelled loader; returns per-patient mean probabilities."""
        self.model.eval()
        start = time.time()
        n = len(loader)
        patient_ids_all, probabilities_all = [], []
        with torch.no_grad():
            for step, batch_data in enumerate(loader):
                inputs = batch_data["image"].to(self.device)
                outputs = self.model(inputs)
                probabilities_all.extend(F.softmax(outputs, dim=1)[:, 1].tolist())
                patient_ids_all.extend(batch_data["patient_id"])
                if verbose:
                    print('Infer step {}/{}'.format(step + 1, n), end='\r')

        _, _, result = self._summarize_predictions(patient_ids_all, probabilities_all)

        elapsed_time = str(datetime.timedelta(seconds=time.time() - start))
        print('Elapsed time: {}'.format(elapsed_time))

        return result

    def fit(self, train_ds, val_ds, test_ds, batch_size, epochs, lr, model_name, verbose):
        """Train with Adam, selecting the epoch with the best validation patient AUC.

        Training stops once more than ``early_stopping_rounds`` epochs pass
        without improvement. The best state dict is saved to
        ``'{model_name}.pth'`` and loaded back into ``self.model``.
        ``test_ds`` is accepted for backward compatibility but not used; score it
        with :meth:`evaluate`.

        Returns:
            ``(val_auc_history, val_loss_history, train_loss_history, train_auc_history)``.
        """
        val_auc_hist = History()
        val_loss_hist = History()
        train_auc_hist = History()
        train_loss_hist = History()
        train_loader = DataLoader(train_ds,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=self.args.n_workers)
        val_loader = DataLoader(val_ds,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=self.args.n_workers)
        loss_function = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        best_metric = -np.inf
        best_loss = np.inf
        best_epoch = 1
        best_state_dict = None
        save_path = '{}.pth'.format(model_name)
        for epoch in range(1, epochs + 1):
            print('\nEpoch {}/{}:'.format(epoch, epochs))
            train_loss, train_auc = self.train_epoch(train_loader, loss_function, optimizer, verbose)
            train_loss_hist.add(train_loss)
            train_auc_hist.add(train_auc)
            print(' Validation:')
            val_loss, val_metric, _ = self.evaluate_epoch(val_loader, loss_function, verbose)
            val_auc_hist.add(val_metric)
            val_loss_hist.add(val_loss)

            if val_metric > best_metric:
                print('Val AUC improved from {:.5f} to {:.5f}'.format(best_metric, val_metric))
                best_metric = val_metric
                best_loss = val_loss
                best_epoch = epoch
                best_state_dict = deepcopy(self.model.state_dict())
            elif (epoch - best_epoch) > self.args.early_stopping_rounds:
                print('Early stopping. Best model is epoch {}'.format(best_epoch))
                break
        else:
            print('Finished training. Best model is epoch {}'.format(best_epoch))

        if best_state_dict is None:
            # No epoch recorded an improvement (e.g. epochs < 1): save the
            # current weights rather than `None`.
            best_state_dict = deepcopy(self.model.state_dict())
        print('Val loss: {:.5f}, Val auc: {:.5f}'.format(best_loss, best_metric))
        print('Saving model...')
        torch.save(best_state_dict, save_path)
        # Leave the model holding the selected (best) weights, not the last epoch's.
        self.model.load_state_dict(best_state_dict)
        return val_auc_hist.history, val_loss_hist.history, train_loss_hist.history, train_auc_hist.history

    def evaluate(self, test_ds, batch_size, verbose):
        """Score ``test_ds``; returns ``(patient-level AUC, per-patient DataFrame)``."""
        test_loader = DataLoader(test_ds,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=self.args.n_workers)
        loss_function = nn.CrossEntropyLoss()
        _, test_metric, test_result = self.evaluate_epoch(test_loader, loss_function, verbose)
        return test_metric, test_result

    def predict(self, test_ds, batch_size, verbose):
        """Per-patient mean probabilities for an unlabelled dataset."""
        test_loader = DataLoader(test_ds,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=self.args.n_workers)
        return self.infer_epoch(test_loader, verbose)

In [None]:
# Input / loader settings
img_size = 128   # DICOM slices are resized to img_size x img_size
batch_size = 32
n_workers = 8
# Optimisation and cross-validation schedule
epochs, lr = 50, 1e-4
early_stopping_rounds = 3
n_folds = 4

In [None]:
config_overrides = dict(
    img_size=img_size,
    n_workers=n_workers,
    early_stopping_rounds=early_stopping_rounds,
    n_folds=n_folds,
)
args = Config(**config_overrides)
pipeline = Pipeline(args)
# Pipeline.__init__ already built a model; this rebuilds it and shows its layers.
model = pipeline.load_model()
print(model)

In [None]:
pipeline.load_annotations()
# Histories are appended MRI-type-major, fold-minor: run k is MRI_TYPES[k // n_folds], fold k % n_folds.
auc, loss, t_loss, t_auc = [], [], [], []
for mri in MRI_TYPES:
    for fold in range(n_folds):
        print(f'### Train {mri} on fold {fold}: ###')
        fold_datasets = pipeline.prepare_datasets(mri_types=[mri], fold=fold, cache_rate=1.0)
        # Fresh, randomly initialised model for every (MRI type, fold) run.
        pipeline.load_model()
        # NOTE: "_".join on a string separates its characters ('FLAIR' -> 'F_L_A_I_R');
        # the evaluation cell builds weight paths the same way, so keep them in sync.
        run_name = f'{"_".join(mri)}_fold{fold}'
        run_val_auc, run_val_loss, run_train_loss, run_train_auc = pipeline.fit(
            *fold_datasets,
            batch_size=batch_size, epochs=epochs, lr=lr,
            model_name=run_name,
            verbose=True,
        )
        loss.append(run_val_loss)
        auc.append(run_val_auc)
        t_loss.append(run_train_loss)
        t_auc.append(run_train_auc)

In [None]:
def roc(fpr, tpr, roc_auc, mri):
    """Plot a ROC curve next to the chance diagonal and show the figure.

    Args:
        fpr, tpr: false / true positive rates (e.g. from ``roc_curve``).
        roc_auc: area value written into the legend (2 decimals).
        mri: name of the MRI sequence, used in the title.
    """
    line_width = 2
    fig, ax = plt.subplots()
    curve_label = "ROC curve (area = %0.2f)" % roc_auc
    ax.plot(fpr, tpr, color="darkorange", lw=line_width, label=curve_label)
    ax.plot([0, 1], [0, 1], color="navy", lw=line_width, linestyle="--")
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curve for {} test dataset".format(mri))
    ax.legend(loc="lower right")
    plt.show()

In [None]:
sns.set(rc={'figure.figsize': (20, 13)})
first_run = 0  # MRI_TYPES[0], fold 0 (runs are stored MRI-type-major, fold-minor)
plot = sns.lineplot(data=auc[first_run], label='val AUC')
plot.set_title(MRI_TYPES[0])
plot = sns.lineplot(data=loss[first_run], label='val Loss')
plot = sns.lineplot(data=t_auc[first_run], label='train AUC')
plot = sns.lineplot(data=t_loss[first_run], label='train Loss')
a = plot.set(xlabel='Epochs', ylabel='Value')

In [None]:
mri_idx = 1
# Runs are stored MRI-type-major, fold-minor, so fold 0 of MRI_TYPES[mri_idx]
# sits at mri_idx * n_folds (the old hard-coded 4 only held for n_folds == 4).
i = mri_idx * n_folds
sns.set(rc={'figure.figsize':(20,13)})
plot = sns.lineplot(data = auc[i], label = 'val AUC').set_title(MRI_TYPES[mri_idx])
plot = sns.lineplot(data = loss[i], label = 'val Loss')
plot = sns.lineplot(data = t_auc[i], label = 'train AUC')
plot = sns.lineplot(data = t_loss[i], label = 'train Loss')
a = plot.set(xlabel='Epochs', ylabel='Value')

In [None]:
mri_idx = 2
# Runs are stored MRI-type-major, fold-minor, so fold 0 of MRI_TYPES[mri_idx]
# sits at mri_idx * n_folds (the old hard-coded 8 only held for n_folds == 4).
i = mri_idx * n_folds
sns.set(rc={'figure.figsize':(20,13)})
plot = sns.lineplot(data = auc[i], label = 'val AUC').set_title(MRI_TYPES[mri_idx])
plot = sns.lineplot(data = loss[i], label = 'val Loss')
plot = sns.lineplot(data = t_auc[i], label = 'train AUC')
plot = sns.lineplot(data = t_loss[i], label = 'train Loss')
a = plot.set(xlabel='Epochs', ylabel='Value')

In [None]:
mri_idx = 3
# Runs are stored MRI-type-major, fold-minor, so fold 0 of MRI_TYPES[mri_idx]
# sits at mri_idx * n_folds (the old hard-coded 12 only held for n_folds == 4).
i = mri_idx * n_folds
sns.set(rc={'figure.figsize':(20,13)})
plot = sns.lineplot(data = auc[i], label = 'val AUC').set_title(MRI_TYPES[mri_idx])
plot = sns.lineplot(data = loss[i], label = 'val Loss')
plot = sns.lineplot(data = t_auc[i], label = 'train AUC')
plot = sns.lineplot(data = t_loss[i], label = 'train Loss')
a = plot.set(xlabel='Epochs', ylabel='Value')

In [None]:
sns.set(rc={'figure.figsize': (20, 13)})
# One labelled line per (MRI type, fold) run on separate AUC / loss axes; the old
# single-axis wide-form plot gave every series the same 'AUC'/'Loss' label.
# Runs are stored MRI-type-major, fold-minor by the training cell.
fig, (ax_auc, ax_loss) = plt.subplots(2, 1, sharex=True)
for run_idx, (run_auc, run_loss) in enumerate(zip(auc, loss)):
    run_label = f'{MRI_TYPES[run_idx // n_folds]} fold {run_idx % n_folds}'
    ax_auc.plot(run_auc, label=run_label)
    ax_loss.plot(run_loss, label=run_label)
ax_auc.set(ylabel='Val AUC', title='Validation AUC per run')
ax_loss.set(xlabel='Epochs', ylabel='Val loss', title='Validation loss per run')
ax_auc.legend(ncol=n_folds, fontsize='small')
plt.show()

In [None]:
for mri in MRI_TYPES:
    metrics = []
    results = []
    for fold in range(n_folds):
        # Same naming as the training cell ("_".join on a string separates its
        # characters, e.g. 'FLAIR' -> 'F_L_A_I_R').
        weights_path = f'{"_".join(mri)}_fold{fold}.pth'
        if not os.path.exists(weights_path):
            raise FileNotFoundError(f'Missing weights for {mri} fold {fold}: {weights_path}')
        # Evaluate each model only on the MRI sequence it was trained on.
        _, _, test_ds = pipeline.prepare_datasets(mri_types=[mri],
                                                  fold=fold,
                                                  cache_rate=0.0)
        pipeline.load_model(weights_path)
        fold_auc, fold_result = pipeline.evaluate(test_ds, batch_size=batch_size, verbose=False)
        metrics.append(fold_auc)
        results.append(fold_result)
    results = pd.concat(results, ignore_index=True)
    mean_auc = np.mean(metrics)
    oof_auc = roc_auc_score(results['label'], results['probability'])
    y_true = results['label'].astype(int)
    y_pred = (results['probability'] >= 0.5).astype(int)
    f1 = f1_score(y_true, y_pred, average='binary')
    print('---')
    print(f'{mri} test result:')
    print(' Mean AUC: {:.5f}'.format(mean_auc))
    print(' OOF AUC: {:.5f}'.format(oof_auc))
    print(' OOF F1 score: {:.5f}'.format(f1))
    fpr, tpr, _ = roc_curve(results['label'], results['probability'])
    # The curve is built from pooled out-of-fold predictions, so label it with the pooled AUC.
    roc(fpr, tpr, oof_auc, mri)
    print('---')