In [None]:
# Run-mode switches read by the cells below.
# INFER: when True, config.json and modelfiles.json are loaded from MDLS_PATH
#        instead of being written by a training run, and the validation cell is skipped.
INFER = False
# KAGGLE: when True, the '../input/...' paths and CUDA device '0' are used;
#         otherwise the local './...' paths and CUDA device '1'.
KAGGLE = False

In [None]:
import os
import sys 
import json
import glob
import random
import collections
import time
import re
import numpy as np
import pandas as pd
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

# Location of the EfficientNet-PyTorch-3D checkout; it is appended to sys.path
# so that the import on the next line can resolve.
IIID_PATH = '../input/efficientnetpyttorch3d/EfficientNet-PyTorch-3D' if KAGGLE else './EfficientNet-PyTorch-3D'
sys.path.append(IIID_PATH)
from efficientnet_pytorch_3d import EfficientNet3D

# Silence UserWarning-category warnings only.
warnings.filterwarnings('ignore', category=UserWarning) 
# Restrict the visible CUDA devices to one index; this is set before the first
# CUDA query (torch.cuda.is_available) below.
os.environ['CUDA_VISIBLE_DEVICES'] = '0' if KAGGLE else '1'
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
    print('GPU is available')
else:
    DEVICE = torch.device('cpu')
    print('CPU is used')

# Config

In [None]:
# Version tag used to name the model directory.
VER = 'v19'
DATA_PATH = '../input/rsna-miccai-brain-tumor-radiogenomic-classification' if KAGGLE else './data'
MDLS_PATH = f'../input/brain-models-{VER}' if KAGGLE else f'./models_{VER}'
# MRI sequences; one model is trained per (fold, sequence).
MRI_TYPES = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
if INFER:
    # Reuse the hyper-parameters stored next to the trained models.
    with open(f'{MDLS_PATH}/config.json', 'r') as file:
        CONFIG = json.load(file)
    print('config loaded:', CONFIG)
    # The stored device describes the machine that trained the models;
    # re-resolve it so a CPU-only inference host does not request 'cuda'.
    CONFIG['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
    CONFIG = {
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'batch_size': 8,
        'img_size': 300, # 224, 240, 260, 300, 380, 456, 528, 600
        'num_images': 64,
        'bbone': 'efficientnet-b3',
        'auc': False,   # False: model selection by validation loss; True: by validation AUC
        'folds': 5,
        'epochs': 50,
        'lr': 1e-3,
        'patience': 10,
        'seed': 2020
    }
    # exist_ok avoids the check-then-create race and also creates missing parents.
    os.makedirs(MDLS_PATH, exist_ok=True)
    # Persist the config so inference can reload the same settings.
    with open(f'{MDLS_PATH}/config.json', 'w') as file:
        json.dump(CONFIG, file)

def seed_all(seed=0):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.

    Returns:
        A fresh ``np.random.RandomState`` seeded with ``seed``.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # Trade cuDNN autotuning for repeatable kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    return np.random.RandomState(seed)

# Seed every RNG from the configured seed and keep the returned RandomState.
random_state = seed_all(CONFIG['seed'])
# Wall-clock timestamp taken right after seeding.
start_time = time.time()

# Utils

In [None]:
def load_dicom_image(path, img_size=256, voi_lut=True, rotate=0):
    """Read one DICOM file and return its pixels as a square float32 array.

    Args:
        path: path of the ``.dcm`` file.
        img_size: side length of the returned image; the slice is resized to
            ``(img_size, img_size)``.
        voi_lut: when True, pass the pixels through ``apply_voi_lut``.
        rotate: 0 for no rotation, otherwise an index into the rotation table
            (1 = 90 degrees clockwise, 2 = 90 degrees counter-clockwise, 3 = 180 degrees).

    Returns:
        ``np.ndarray`` of dtype float32 and shape ``(img_size, img_size)``.
    """
    # dcmread is the supported reader; the read_file alias was removed in pydicom 3.
    dicom = pydicom.dcmread(path)
    pixels = dicom.pixel_array
    data = apply_voi_lut(pixels, dicom) if voi_lut else pixels
    if rotate > 0:
        # Index 0 is a placeholder so that `rotate` can index the table directly.
        rot_choices = [0, 
                       cv2.ROTATE_90_CLOCKWISE, 
                       cv2.ROTATE_90_COUNTERCLOCKWISE, 
                       cv2.ROTATE_180]
        data = cv2.rotate(data, rot_choices[rotate])
    data = cv2.resize(data, (img_size, img_size))
    return data.astype(np.float32)

def natural_sort(l): 
    """Return the strings of ``l`` sorted case-insensitively, with runs of
    digits compared as integers (so 'img2' sorts before 'img10')."""
    def alphanum_key(key):
        parts = []
        for chunk in re.split('([0-9]+)', key):
            parts.append(int(chunk) if chunk.isdigit() else chunk.lower())
        return parts

    return sorted(l, key=alphanum_key)

def load_dicom_images_3d(scan_id, num_imgs=64, img_size=256, 
                         mri_type='FLAIR', split='train', rotate=0):
    """Load the central slices of one MRI series as a normalised 4-D volume.

    Files matching ``{DATA_PATH}/{split}/{scan_id}/{mri_type}/*.dcm`` are
    naturally sorted, a window of up to ``num_imgs`` slices centred on the
    middle file is loaded, the stack is transposed so the slice axis comes
    last, zero slices are appended when fewer than ``num_imgs`` were available,
    and intensities are min-max scaled to [0, 1].

    Args:
        scan_id: directory name of the scan (zero-padded id string).
        num_imgs: number of slices in the returned volume.
        img_size: side length each slice is resized to.
        mri_type: sub-directory name of the MRI sequence.
        split: data sub-directory ('train' or 'test').
        rotate: rotation index forwarded to ``load_dicom_image``.

    Returns:
        float32 ``np.ndarray`` of shape ``(1, img_size, img_size, num_imgs)``.

    Raises:
        ValueError: if no DICOM file matches the pattern.
    """
    pattern = f'{DATA_PATH}/{split}/{scan_id}/{mri_type}/*.dcm'
    files = natural_sort(glob.glob(pattern))
    if not files:
        # Fail with the offending path instead of an opaque np.stack error.
        raise ValueError(f'no DICOM files match "{pattern}"')
    middle, num_imgs2 = len(files) // 2, num_imgs // 2
    p1 = max(0, middle - num_imgs2)
    # Same window as `middle + num_imgs2` for even num_imgs; for odd num_imgs
    # this keeps the last real slice instead of padding it with zeros.
    p2 = min(len(files), p1 + num_imgs)
    img3d = np.stack([load_dicom_image(f, img_size, rotate=rotate) 
                      for f in files[p1:p2]]).T 
    if img3d.shape[-1] < num_imgs:
        n_zero = np.zeros((img_size, img_size, num_imgs - img3d.shape[-1]), 
                          dtype=np.float32)
        img3d = np.concatenate((img3d,  n_zero), axis=-1)
    lo, hi = np.min(img3d), np.max(img3d)
    if lo < hi:
        # Constant volumes are left untouched to avoid dividing by zero.
        img3d = (img3d - lo) / (hi - lo)
    return np.expand_dims(img3d, 0)

# Smoke test: load scan '00000' with the default arguments and print the
# volume's shape and basic intensity statistics.
img = load_dicom_images_3d('00000')
print(img.shape)
print(np.min(img), np.max(img), np.mean(img), np.median(img))

# Data load and models

In [None]:
# Scan ids removed from the training labels.
# NOTE(review): presumably known-bad cases — confirm the reason for each id.
samples_to_exclude = [109, 123, 709]
df = pd.read_csv(f'{DATA_PATH}/train_labels.csv')
print('loaded:', df.shape, df['MGMT_value'].sum())
df = df[~df.BraTS21ID.isin(samples_to_exclude)]
# drop=True: a fresh 0..n-1 index is needed because the fold loop below uses
# positional indices as labels; the old index is discarded rather than kept
# as a stray 'index' column.
df = df.reset_index(drop=True)
print('cleaned:', df.shape, df['MGMT_value'].sum())
display(df.head())
# Stratified folds on the target; every row gets the number of the fold in
# which it is a validation sample.
skf = StratifiedKFold(CONFIG['folds'], shuffle=True, random_state=CONFIG['seed'])
df['fold'] = -1
for i, (train_idxs, val_idxs) in enumerate(skf.split(df, df['MGMT_value'])):
    df.loc[val_idxs, 'fold'] = i
display(df.head())

In [None]:
class BrainDataset(torch_data.Dataset):
    """Dataset yielding one 3-D MRI volume (and optionally its label) per item.

    Args:
        paths: sequence of scan ids; each is zero-padded to 5 characters to
            form the scan directory name.
        img_size: side length each slice is resized to.
        targets: sequence of labels aligned with ``paths``, or None for
            prediction mode.
        mri_type: sequence of MRI sequence names aligned with ``paths``.
        lbl_smth: label-smoothing offset; the returned label is
            ``abs(target - lbl_smth)``.
        split: data sub-directory used in prediction mode (labelled items
            always read from 'train').
        aug: when True, labelled items get a random rotation index in 0..3.
        albu: optional callable applied to the first slice; it must return a
            dict with 'image' and 'replay' entries, and the recorded 'replay'
            is applied to the remaining slices via ``A.ReplayCompose.replay``.
    """
    def __init__(self, paths, img_size, targets=None, mri_type=None, 
                 lbl_smth=.001, split='train', aug=False, albu=None):
        self.paths = paths
        self.img_size = img_size
        self.targets = targets
        self.mri_type = mri_type
        self.lbl_smth = lbl_smth
        self.split = split
        self.aug = aug
        self.albu = albu
          
    def __len__(self):
        """Number of scans in the dataset."""
        return len(self.paths)
    
    def __getitem__(self, index):
        """Return ``{'X', 'id'}`` in prediction mode, ``{'X', 'y'}`` otherwise."""
        scan_id = self.paths[index]
        if self.targets is None:
            data = load_dicom_images_3d(
                str(scan_id).zfill(5), 
                img_size=self.img_size,
                mri_type=self.mri_type[index], 
                split=self.split
            )
            return {'X': torch.tensor(data).float(), 'id': scan_id}

        rotation = np.random.randint(0, 4) if self.aug else 0
        data = load_dicom_images_3d(
            str(scan_id).zfill(5), 
            img_size=self.img_size,
            mri_type=self.mri_type[index], 
            split='train', 
            rotate=rotation
        )
        if self.albu:
            # Use the pipeline given to this instance (not a module-level
            # global): sample parameters on the first slice, then replay the
            # same parameters on every other slice of the volume.
            frozen = self.albu(image=data[0, :, :, 0])
            data[0, :, :, 0] = frozen['image']
            # Iterate over the volume's actual depth rather than a config value.
            for i in range(1, data.shape[-1]):
                data[0, :, :, i] = A.ReplayCompose.replay(
                    frozen['replay'], 
                    image=data[0, :, :, i]
                )['image']
        y = torch.tensor(
            abs(self.targets[index] - self.lbl_smth), 
            dtype=torch.float
        )
        return {'X': torch.tensor(data).float(), 'y': y}

In [None]:
import albumentations as A
# 2-D augmentation pipeline. It is a ReplayCompose so that the parameters
# sampled for one slice can be re-applied to the other slices of a volume
# (see BrainDataset.__getitem__). Each top-level entry fires with its own `p`.
albu = A.ReplayCompose([
    # Intensity change: one of brightness/contrast or gamma, 25% of the time.
    A.OneOf([
        A.RandomBrightnessContrast(
            brightness_limit=.2, 
            contrast_limit=.2, 
            p=1), 
        A.RandomGamma(p=1)
    ], p=.25),
    A.Blur(blur_limit=3, p=.25),
    A.GaussNoise(.002, p=.25),
    # Non-rigid deformation: one of elastic transform or grid distortion, 25% of the time.
    A.OneOf([
           A.ElasticTransform(
               alpha=120, 
               sigma=120 * .05, 
               alpha_affine=120 * .03, 
               p=.5),
           A.GridDistortion(p=.5),
       ], p=.25),
    A.ShiftScaleRotate(p=.5)
])

In [None]:
# Visual sanity check: build an augmented FLAIR dataset over all labelled
# scans and show the first slices of item 0 together with its label.
df.loc[:, 'MRI_Type'] = 'FLAIR'
dataset_show = BrainDataset(
    df['BraTS21ID'].values,
    CONFIG['img_size'],
    targets=df['MGMT_value'].values,
    mri_type=df['MRI_Type'].values,
    aug=True,
    albu=albu,
)
data_show = dataset_show[0]

n_imgs = 8
print('test X: ', data_show['X'].shape)
print('test y: ', data_show['y'].shape)
fig, axes = plt.subplots(nrows=1, ncols=n_imgs, figsize=(16, 4))
for j, ax in enumerate(axes):
    # Slice j of the single-channel volume.
    ax.imshow(data_show['X'][0, :, :, j].numpy())
    ax.set_title(data_show['y'].numpy())
plt.show()

In [None]:
class BrainModel(nn.Module): 
    """EfficientNet3D backbone (1 input channel) whose final fully-connected
    layer is replaced by a single-output linear layer."""

    def __init__(self, bbone='efficientnet-b0'):
        super().__init__()
        backbone = EfficientNet3D.from_name(
            bbone,
            override_params={'num_classes': 2},
            in_channels=1,
        )
        # Swap the classifier head for a one-logit head of the same input width.
        backbone._fc = nn.Linear(backbone._fc.in_features, 1, bias=True)
        self.net = backbone

    def forward(self, x):
        """Run the backbone on ``x`` and return its output unchanged."""
        return self.net(x)

# Training

In [None]:
class BrainTrainer:
    """Training loop with mixed precision, early stopping and checkpointing.

    Args:
        model: network whose forward output has shape ``(batch, 1)``.
        device: device the batches are moved to.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler; stepped once per epoch.
        criterion: callable ``criterion(outputs, targets)`` returning the loss.
        auc_flag: when True, the best model is chosen by validation AUC
            (higher is better); otherwise by validation loss (lower is better).
    """
    def __init__(self, model, device, optimizer, scheduler, 
                 criterion, auc_flag=True):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.auc_flag = auc_flag
        if auc_flag:
            self.best_val_auc = 0
        else:
            self.best_val_loss = np.inf
        self.val_losses = []
        self.train_losses = []
        self.val_auc = []
        # One scaler for the whole run so its loss-scale state carries over
        # from epoch to epoch instead of being reset.
        self.scaler = torch.cuda.amp.GradScaler()
        
        # Path of the most recently saved checkpoint (None until the first save).
        self.lastmodel = None
        
    def fit(self, epochs, train_loader, val_loader, save_path, max_patience):     
        """Train for up to ``epochs`` epochs with early stopping.

        A checkpoint is saved whenever the monitored validation metric
        improves; training stops after ``max_patience`` consecutive epochs
        without improvement.

        Returns:
            dict with the per-epoch 'train losses', 'val losses' and 'val AUC'.
        """
        n_patience = 0
        for n_epoch in range(1, epochs + 1):
            self.info_message('EPOCH: {}', n_epoch)
            train_loss, train_time = self.train_epoch(train_loader)
            val_loss, val_auc, val_time = self.val_epoch(val_loader)
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.val_auc.append(val_auc)
            self.info_message(
                'epoch train: {} | loss: {:.4f} | time: {:.2f} sec',
                n_epoch, train_loss, train_time
            )
            self.info_message(
                'epoch val: {} | loss: {:.4f} | auc: {:.4f} | time: {:.2f} sec',
                n_epoch, val_loss, val_auc, val_time
            )
            if self.auc_flag:
                if self.best_val_auc < val_auc: 
                    self.save_model(n_epoch, save_path, val_loss, val_auc)
                    self.info_message(
                        'val auc improved {:.2f} -> {:.2f} | saved model to "{}"', 
                        self.best_val_auc, val_auc, self.lastmodel
                    )
                    self.best_val_auc = val_auc
                    n_patience = 0
                else:
                    n_patience += 1
            else:
                if self.best_val_loss > val_loss: 
                    self.save_model(n_epoch, save_path, val_loss, val_auc)
                    self.info_message(
                        'val loss improved {:.4f} -> {:.4f} | saved model to "{}"', 
                        self.best_val_loss, val_loss, self.lastmodel
                    )
                    self.best_val_loss = val_loss
                    n_patience = 0
                else:
                    n_patience += 1
            if n_patience >= max_patience:
                self.info_message(
                    '\nno improvement for last {} epochs', 
                    n_patience
                )
                break
        history = {
            'train losses': self.train_losses, 
            'val losses': self.val_losses, 
            'val AUC': self.val_auc
        }
        return history
            
    def train_epoch(self, train_loader):
        """Run one training epoch.

        Returns:
            ``(mean batch loss, elapsed whole seconds)``.
        """
        self.model.train()
        t = time.time()
        sum_loss = 0
        for step, batch in enumerate(train_loader, 1):
            X = batch['X'].to(self.device)
            targets = batch['y'].to(self.device)
            self.optimizer.zero_grad()
            # Only the forward pass and the loss run under autocast; backward
            # and the optimizer step are done outside of it.
            with torch.cuda.amp.autocast():
                outputs = self.model(X).squeeze(1)
                loss = self.criterion(outputs, targets)
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            sum_loss += loss.detach().item()
            self.info_message(
                'train step {}/{} | train loss: {:.4f}           ',
                step, len(train_loader), sum_loss / step, end='\r'
            )
        # The scheduler's horizon is expressed in epochs, so advance it once
        # per epoch rather than once per batch.
        self.scheduler.step()
        return sum_loss / len(train_loader), int(time.time() - t)
    
    def val_epoch(self, val_loader):
        """Evaluate on ``val_loader`` without gradient tracking.

        Returns:
            ``(mean batch loss, ROC-AUC, elapsed whole seconds)``. The AUC is
            computed on the raw model outputs against the labels thresholded
            at 0.5 (which undoes label smoothing).
        """
        self.model.eval()
        t = time.time()
        sum_loss = 0
        y_all = []
        outputs_all = []
        for step, batch in enumerate(val_loader, 1):
            with torch.no_grad():
                X = batch['X'].to(self.device)
                targets = batch['y'].to(self.device)
                outputs = self.model(X).squeeze(1)
                loss = self.criterion(outputs, targets)
                sum_loss += loss.detach().item()
                y_all.extend(batch["y"].tolist())
                outputs_all.extend(outputs.tolist())
            self.info_message(
                'val step {}/{} | val loss: {:.4f}               ', 
                step, len(val_loader), sum_loss / step, end='\r'
            )
        y_all = [1 if x > 0.5 else 0 for x in y_all]
        auc = roc_auc_score(y_all, outputs_all)
        return sum_loss / len(val_loader), auc, int(time.time() - t)
    
    def save_model(self, n_epoch, save_path, loss, auc):
        """Write a checkpoint under ``MDLS_PATH`` and remember its path.

        The checkpoint holds the model and optimizer state dicts, the epoch
        number and the monitored metric of the model being saved.
        """
        self.lastmodel = f'{MDLS_PATH}/{save_path}-e{n_epoch}-loss{loss:.3f}-auc{auc:.3f}.pth'
        dict_save = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'n_epoch': n_epoch,
        }
        # Store the metric of THIS model: fit() updates self.best_val_* only
        # after calling save_model, so those attributes still hold the
        # previous best here. Plain floats keep the file free of NumPy scalars.
        if self.auc_flag:
            dict_save['best_val_auc'] = float(auc)
        else:
            dict_save['best_val_loss'] = float(loss)
        torch.save(dict_save, self.lastmodel)
    
    def display_plots(self, mri_type):
        """Plot the recorded train/val losses and the validation AUC."""
        fig, axes = plt.subplots(figsize=(16, 4), nrows=1, ncols=2)
        axes[0].set_title(f'{mri_type}: training and validation losses')
        axes[0].plot(self.val_losses, label='val')
        axes[0].plot(self.train_losses, label='train')
        axes[0].set_xlabel('iterations')
        axes[0].set_ylabel('loss')
        axes[0].legend()
        axes[1].set_title(f'{mri_type}: validation ROC-AUC')
        axes[1].plot(self.val_auc, label='val')
        axes[1].set_xlabel('iterations')
        axes[1].set_ylabel('AUC')
        axes[1].legend()
        plt.show()
        plt.close()
    
    @staticmethod
    def info_message(message, *args, end='\n'):
        """Print ``message`` formatted with ``args``."""
        print(message.format(*args), end=end)

In [None]:
def train_mri_type(df_train, df_val, mri_type, fold, device, 
                   epochs, patience, batch_size):
    """Train one model for a given MRI sequence and fold.

    Args:
        df_train: training rows (needs 'BraTS21ID' and 'MGMT_value' columns).
        df_val: validation rows with the same columns.
        mri_type: one of ``MRI_TYPES``, or 'all' to train on every sequence
            (each row is then repeated once per sequence).
        fold: fold number, used in log output and file names.
        device: device the model is moved to and trained on.
        epochs: maximum number of epochs.
        patience: early-stopping patience in epochs.
        batch_size: batch size of both data loaders.

    Returns:
        Path of the last checkpoint saved by the trainer.
    """
    # Work on copies (`assign`) so the caller's frames are not modified.
    if mri_type == 'all':
        df_train = pd.concat([df_train.assign(MRI_Type=m) for m in MRI_TYPES])
        df_val = pd.concat([df_val.assign(MRI_Type=m) for m in MRI_TYPES])
    else:
        df_train = df_train.assign(MRI_Type=mri_type)
        df_val = df_val.assign(MRI_Type=mri_type)
    print('=' * 20, f'MRI_Type {mri_type} | FOLD {fold}', '=' * 20)
    print('train:', df_train.shape, '| val:', df_val.shape)
    display(df_train.head())
    train_dataset = BrainDataset(
        paths=df_train['BraTS21ID'].values, 
        img_size=CONFIG['img_size'],
        targets=df_train['MGMT_value'].values, 
        mri_type=df_train['MRI_Type'].values,
        aug=True,
        albu=albu
    )
    # Validation data is loaded without rotation or albumentations.
    val_dataset = BrainDataset(
        paths=df_val['BraTS21ID'].values, 
        img_size=CONFIG['img_size'],
        targets=df_val['MGMT_value'].values,
        mri_type=df_val['MRI_Type'].values
    )
    train_loader = torch_data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8
    )
    val_loader = torch_data.DataLoader(
        val_dataset, 
        batch_size=batch_size,
        shuffle=False,
        num_workers=8
    )
    model = BrainModel(bbone=CONFIG['bbone'])
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, 
        epochs
    )
    criterion = torch_functional.binary_cross_entropy_with_logits
    trainer = BrainTrainer(
        model, 
        device, 
        optimizer, 
        scheduler,
        criterion,
        auc_flag=CONFIG['auc']
    )
    history = trainer.fit(
        epochs, 
        train_loader, 
        val_loader, 
        save_path=f'{mri_type}-f{fold}', 
        max_patience=patience
    )
    trainer.display_plots(mri_type)
    # Keep the learning curves next to the checkpoints.
    with open(f'{MDLS_PATH}/history_{mri_type}_f{fold}.json', 'w') as file:
        json.dump(history, file)
    return trainer.lastmodel

# Either reuse the stored list of trained checkpoints (INFER) or start empty.
if INFER:
    with open(f'{MDLS_PATH}/modelfiles.json', 'r') as file:
        modelfiles = json.load(file)
    print("model's list loaded:", modelfiles)
else:
    modelfiles = []
# Train only when no checkpoint list is available.
if not modelfiles:
    for fold_num in range(CONFIG['folds']): 
        # Row positions of this fold's split, used as labels with df.loc
        # (assumes df has a 0..n-1 index). `val_idxs` stays at module scope
        # after the loop and is read again by the validation cell.
        train_idxs = np.where((df['fold'] != fold_num))[0]
        val_idxs = np.where((df['fold'] == fold_num))[0]
        df_train = df.loc[train_idxs]
        df_val = df.loc[val_idxs]
        # One model per MRI type. The resulting list is fold-major and
        # type-minor; later cells recover the type from the list position.
        for m in MRI_TYPES:
            modelfiles.append(train_mri_type(
                df_train, 
                df_val, 
                m, 
                fold_num,
                device=CONFIG['device'], 
                epochs=CONFIG['epochs'],
                patience=CONFIG['patience'],
                batch_size=CONFIG['batch_size']
            ))
    print(modelfiles)
    # Persist the checkpoint list so INFER mode can reload it.
    with open(f'{MDLS_PATH}/modelfiles.json', 'w') as file:
        json.dump(modelfiles, file)

In [None]:
# Re-root every checkpoint path onto the current MDLS_PATH (only the file name is kept).
modelfiles = [f'{MDLS_PATH}/{x.split("/")[-1]}' for x in modelfiles]
# Every file in MDLS_PATH whose name contains '.pth'.
allmodelfiles = [f'{MDLS_PATH}/{x}' for x in os.listdir(MDLS_PATH) if '.pth' in x]
# Delete checkpoints that are not in the final model list.
for file_path in allmodelfiles:
    if file_path not in modelfiles:
        os.remove(file_path)

# Inference

In [None]:
def infer(model_file, df, mri_type, split, device, batch_size):
    """Predict probabilities for every scan in ``df`` with one checkpoint.

    Args:
        model_file: path of a checkpoint holding 'model_state_dict'.
        df: frame whose index holds the scan ids; it is not modified.
        mri_type: MRI sequence to load for every scan.
        split: data sub-directory to read from ('train' or 'test').
        device: device used for the model and the batches.
        batch_size: batch size of the prediction loader.

    Returns:
        DataFrame indexed by 'BraTS21ID' with a 'MGMT_value' column holding
        the sigmoid of the model output.
    """
    print('PREDICT:', model_file, mri_type, df.shape)
    pred_dataset = BrainDataset(
        paths=df.index.values, 
        img_size=CONFIG['img_size'],
        # Same sequence for every row; built locally instead of writing a
        # column into the caller's frame.
        mri_type=[mri_type] * len(df),
        split=split
    )
    pred_loader = torch_data.DataLoader(
        pred_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
    )
    model = BrainModel(bbone=CONFIG['bbone'])
    model.to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    checkpoint = torch.load(model_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    y_pred = []
    ids = []
    with torch.no_grad():
        for i, batch in enumerate(pred_loader, start=1):
            print(f'{i}/{len(pred_loader)}', end='\r')
            probs = torch.sigmoid(model(batch['X'].to(device)))
            # reshape(-1) yields a flat list for any batch size, including 1.
            y_pred.extend(probs.reshape(-1).cpu().tolist())
            ids.extend(batch['id'].numpy().tolist()) 
    df_pred = pd.DataFrame({'BraTS21ID': ids, 'MGMT_value': y_pred}) 
    df_pred = df_pred.set_index('BraTS21ID')
    return df_pred

# Validation

In [None]:
# Validation rows = the rows selected by whatever `val_idxs` currently holds.
# NOTE(review): `val_idxs` is a loop variable left over from an earlier cell
# (the fold-assignment loop or the training loop), so this depends on which
# cells ran before — confirm it is the intended fold.
df_val = df.loc[val_idxs]

In [None]:
# Ensemble validation: average the predictions of every stored model on
# df_val and report the ROC-AUC of the averaged prediction.
if not INFER:
    df_val = df_val.set_index('BraTS21ID')
    df_val['MGMT_pred'] = 0
    for i, m in enumerate(modelfiles):
        # modelfiles is ordered fold-major, so the MRI type cycles with i.
        mtype = MRI_TYPES[i % len(MRI_TYPES)]
        preds = infer(
            m, df_val, mtype, 'train',
            CONFIG['device'], CONFIG['batch_size'],
        )
        df_val['MGMT_pred'] += preds['MGMT_value']
    df_val['MGMT_pred'] /= len(modelfiles)
    auc = roc_auc_score(df_val['MGMT_value'], df_val['MGMT_pred'])
    print(f'validation ensemble AUC: {auc:.4f}')
    sns.displot(df_val['MGMT_pred'])
    plt.show()
else:
    print('infer mode, no validation')

# Submission

In [None]:
# Submission: average the predictions of every stored model over the test
# scans listed in the sample submission and write them to submission.csv.
df_subm = pd.read_csv(
    f'{DATA_PATH}/sample_submission.csv',
    index_col='BraTS21ID',
)
df_subm['MGMT_value'] = 0
for i, m in enumerate(modelfiles):
    # modelfiles is ordered fold-major, so the MRI type cycles with i.
    mtype = MRI_TYPES[i % len(MRI_TYPES)]
    preds = infer(
        m, df_subm, mtype, 'test',
        CONFIG['device'], CONFIG['batch_size'],
    )
    df_subm['MGMT_value'] += preds['MGMT_value']
df_subm['MGMT_value'] /= len(modelfiles)
display(df_subm.head())
df_subm['MGMT_value'].to_csv('submission.csv')

In [None]:
# Distribution of the averaged test predictions.
sns.displot(df_subm['MGMT_value'])
plt.show()