## Use stacked images (3D) and MedicalNet model + volumentations


In this notebook you will find:

* Use of volumentations with PyTorch
* Training with pretrained MedicalNet model
* k-fold cross validation

Acknowledgements:

- https://www.kaggle.com/ihelon/brain-tumor-eda-with-animations-and-modeling
- https://www.kaggle.com/furcifer/torch-efficientnet3d-for-mri-no-train
- https://github.com/shijianjian/EfficientNet-PyTorch-3D
- https://www.kaggle.com/rluethy/efficientnet3d-with-one-mri-type
    
    
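Original plan: use models with only one MRI type, then ensemble the 4 models. (The current version instead trains one model per fold on all MRI types and averages the predictions per patient.)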

V14: add image rotation augmentation

TODO: 

* measure time volumentations -> DONE
* use EfficientNet3D with in_channels=4 -> NOT DONE
* natural sorting, sigmoid -> DONE
* Review configuration: batch_norm, architecture, etc.
* Add augmentation + external data
* Do EDA to see if possible improve data quality.
* Unfreeze layers in MedicalNet


Results



In [None]:
# thanks to https://www.kaggle.com/ipythonx
!pip install ../input/keras-3d-model-and-3d-augmentation/volumentations_3D-1.0.3-py3-none-any.whl -q

In [None]:
import os
import sys 
import json
import glob
import random
import collections
import time
import re
import torchvision
import numpy as np
import pandas as pd
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional
import torch.nn.functional as F
import glob

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

In [None]:
# Default data location (Kaggle); the next cell overrides `data_directory` depending on the environment.
data_directory = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification'
# MONAI dataset location; not referenced elsewhere in this notebook.
# NOTE(review): presumably intended for the DenseNet121 model option — confirm or remove.
input_monaipath = "/kaggle/input/monai-v060-deep-learning-in-healthcare-imaging/"


# Adam learning rate used when training each fold.
LEARNING_RATE=0.0005


In [None]:
if os.path.exists("../input/rsna-miccai-brain-tumor-radiogenomic-classification"):
    # Running on Kaggle: competition data and helper repos are mounted under ../input.
    data_directory = '../input/rsna-miccai-brain-tumor-radiogenomic-classification'
    pytorch3dpath = "../input/efficientnetpyttorch3d/EfficientNet-PyTorch-3D"
    medicalpath = "../input/medicalnet"
else:
    # Local run.
    # NOTE(review): the MedicalNet checkout is assumed to sit next to the notebook,
    # like EfficientNet-PyTorch-3D — adjust if it lives elsewhere.
    data_directory = '/media/roland/data/kaggle/rsna-miccai-brain-tumor-radiogenomic-classification'
    pytorch3dpath = "EfficientNet-PyTorch-3D"
    medicalpath = "MedicalNet"

# `models.resnet` (MedicalNet) is imported below in every environment, so its repo must
# always be importable — previously this only happened on Kaggle.
sys.path.append(medicalpath)

mri_types = ['FLAIR','T1w','T1wCE','T2w']
PRETRAINED_WEIGHTS='../input/medicalnet-pretrained-weights/resnet_34.pth'  # MedicalNet ResNet-34 checkpoint
SIZE = 256                     # in-plane side length of each slice
NUM_IMAGES = 64                # slices per volume
VOLUMENTATIONS = True # quite slow; gates augmentation in load_dicom_images_3d
EPOCHS = 6
FOLDS = 3
PATCH_SIZE = (SIZE, SIZE, NUM_IMAGES)  # currently unused
RESIZING_VOLUMENTATIONS = True # resize augmented slices back to SIZE x SIZE
DATA_PATH = '/kaggle/input/rsna-processed-voxels-64x256x256-clahe'  # pre-processed .npy voxels
sys.path.append(pytorch3dpath)


from efficientnet_pytorch_3d import EfficientNet3D
from models import resnet

## Functions to load images

In [None]:
split = 'train'
# One row per pre-processed volume; files are laid out as voxels/<MRI_Type>/<BraTS21ID>.npy.
train_voxels = sorted(glob.glob(f"{DATA_PATH}/voxels/*/*.npy"))

df_train = pd.DataFrame(train_voxels, columns=['voxel_paths'])
# splitext removes exactly the ".npy" suffix; str.strip('.npy') would strip any of the
# characters '.', 'n', 'p', 'y' from both ends of the name.
df_train['BraTS21ID'] = df_train.voxel_paths.map(lambda path: os.path.splitext(os.path.basename(path))[0])
df_train['MRI_Type'] = df_train.voxel_paths.map(lambda path: os.path.basename(os.path.dirname(path)))
# Keep ids as strings (leading zeros preserved) so they match the file names.
# `np.object` was removed in NumPy 1.24; the builtin `str` is the supported spelling.
df_train_labels = pd.read_csv(f'{data_directory}/train_labels.csv',
                              dtype={'BraTS21ID': str,
                                     'MGMT_value': np.int32})
df_train = df_train.set_index('BraTS21ID').join(df_train_labels.set_index('BraTS21ID'), on='BraTS21ID', how='left')
df_train = df_train.reset_index()
df_train.to_csv("/kaggle/working/df_train_meta.csv")
df_train




### Augmentations

In [None]:
from functools import partial
from volumentations import *

def get_augmentation(patch_size):
    """Build the volumentations pipeline used to augment a 3D volume.

    ``patch_size`` is only referenced by the disabled ``Resize`` step below, so it is
    currently unused. Several transforms (marked ``# slow``) are disabled for speed.
    """
    transforms = [
        # Rotate((-5, 5), (0, 0), (0, 0), p=0.5),               # slow
        # GridDropout(ratio=0., holes_number_y=10, p=1.0),
        RandomCropFromBorders(crop_value=0.25, p=0.3),
        # ElasticTransform((0, 0.15), interpolation=2, p=0.5),  # slow
        # Resize(patch_size, interpolation=1, always_apply=True, p=1.0),
        Flip(0, p=0.5),
        Flip(1, p=0.5),
        RandomRotate90((0, 1), p=0.6),
        # GaussianNoise(var_limit=(0, 5), p=1.0),               # slow
        RandomGamma(gamma_limit=(0.5, 1.5), p=0.7),
    ]
    return Compose(transforms, p=1.0)


# Shared pipeline instance used by load_dicom_images_3d.
volume3D = get_augmentation((SIZE, SIZE, NUM_IMAGES))

def load_dicom_image(path, img_size=SIZE, voi_lut=True, rotate=0):
    """Read one DICOM slice, optionally apply its VOI LUT and a rotation, then resize it.

    Args:
        path: path to a ``.dcm`` file.
        img_size: the slice is resized to ``(img_size, img_size)``.
        voi_lut: if True, pass the pixels through ``apply_voi_lut`` with the file's dataset.
        rotate: 0 = no rotation, 1 = 90 deg clockwise, 2 = 90 deg counter-clockwise, 3 = 180 deg.

    Returns:
        2-D numpy array of shape ``(img_size, img_size)``.
    """
    # `read_file` is a deprecated alias (removed in pydicom 3.0); `dcmread` is the supported reader.
    dicom = pydicom.dcmread(path)
    # Decode the pixel data once and reuse it for the VOI LUT.
    data = dicom.pixel_array
    if voi_lut:
        data = apply_voi_lut(data, dicom)

    if rotate > 0:
        # Index 0 is a placeholder; the `rotate > 0` guard means it is never used.
        rot_choices = [0, cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]
        data = cv2.rotate(data, rot_choices[rotate])

    return cv2.resize(data, (img_size, img_size))


def natural_sort(l):
    """Return ``l`` sorted in case-insensitive "human" order ("img2" before "img10")."""
    def _natural_key(text):
        # Splitting on digit runs yields alternating text / number chunks; numbers compare numerically.
        chunks = re.split('([0-9]+)', text)
        return [int(chunk) if chunk.isdigit() else chunk.lower() for chunk in chunks]

    return sorted(l, key=_natural_key)


def load_normalized_images_3d(path, target_size=(64, 256, 256)):
    """Load a pre-processed ``.npy`` voxel file as a channel-first float32 volume.

    The stored values are divided by 255 (presumably uint8 input, giving [0, 1]) and
    reshaped to ``target_size``. Its first axis is presumably the slice axis — confirm
    against the pre-processing. The slice axis is then moved last and a leading
    channel axis is added.

    Returns:
        float32 array of shape ``(1, target_size[1], target_size[2], target_size[0])``.
    """
    scaled = np.load(path).astype(np.float32) / 255.0
    volume = np.reshape(scaled, target_size)
    # (D, H, W) -> (H, W, D), then prepend the channel axis -> (1, H, W, D).
    return np.transpose(volume, (1, 2, 0))[np.newaxis, ...]

def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=SIZE, 
                         mri_type="FLAIR", split="train", 
                         rotate=0, volumentations=True):
    """Build a normalized 3D volume from the central DICOM slices of one scan.

    Args:
        scan_id: patient folder name under ``{data_directory}/{split}``.
        num_imgs: number of slices kept around the middle of the series; shorter
            stacks are padded with zero slices at the end.
        img_size: in-plane side length passed to ``load_dicom_image`` and used when
            resizing augmented slices.
        mri_type: sequence sub-folder to read (e.g. "FLAIR").
        split: "train" or "test" folder.
        rotate: rotation code forwarded to ``load_dicom_image``; 0 also disables
            augmentation (0 is used for the test set).
        volumentations: per-call augmentation switch; augmentation only runs when this
            AND the global ``VOLUMENTATIONS`` are true and ``rotate != 0``.

    Returns:
        array of shape ``(1, H, W, num_imgs)`` with H == W == img_size (unless an
        augmentation crop is kept because ``RESIZING_VOLUMENTATIONS`` is off),
        min-max scaled to [0, 1] unless the volume is constant.

    Raises:
        ValueError: if no DICOM files exist for the scan.
    """
    files = natural_sort(glob.glob(f"{data_directory}/{split}/{scan_id}/{mri_type}/*.dcm"))
    if not files:
        raise ValueError(f"No DICOM files found for scan {scan_id!r}, {mri_type} ({split})")

    middle = len(files) // 2
    half = num_imgs // 2
    p1 = max(0, middle - half)
    p2 = min(len(files), middle + half)
    # .T reverses the axes: (slices, H, W) -> (W, H, slices), so slices end up last.
    img3d = np.stack([load_dicom_image(f, img_size=img_size, rotate=rotate) for f in files[p1:p2]]).T

    if VOLUMENTATIONS and volumentations and rotate != 0:  # rotate = 0 means test set
        # Keep the augmented volume even when it is not resized afterwards.
        img3d = volume3D(image=img3d)['image']
        if RESIZING_VOLUMENTATIONS:
            # Random border crops shrink the slices; bring each one back to img_size.
            img3d = np.stack([cv2.resize(img3d[:, :, z], (img_size, img_size))
                              for z in range(img3d.shape[2])], axis=-1)

    if img3d.shape[-1] < num_imgs:  # Fill gaps with empty slices matching the actual slice shape
        n_zero = np.zeros((img3d.shape[0], img3d.shape[1], num_imgs - img3d.shape[-1]))
        img3d = np.concatenate((img3d, n_zero), axis=-1)

    if np.min(img3d) < np.max(img3d):  # normalize
        img3d = img3d - np.min(img3d)
        img3d = img3d / np.max(img3d)

    return np.expand_dims(img3d, 0)

# Sanity check: time loading one pre-processed voxel, then inspect its shape and value range.
start_time = time.time()
a = load_normalized_images_3d(df_train.loc[0].voxel_paths)
#a = load_dicom_images_3d("00000", mri_type="T1wCE", rotate=1)

print(f'Time for reading a voxel: {time.time() - start_time}')

print(a.shape)
print(a.min(), a.max(), a.mean(), np.median(a))

In [None]:
# Show the first four slices of the sample volume side by side.
first_slices = a[0][:, :, :4]
plt.figure(figsize=(16, 5))
for slice_idx in range(first_slices.shape[2]):
    ax = plt.subplot(1, 4, slice_idx + 1)
    ax.imshow(first_slices[:, :, slice_idx], cmap="gray")
    ax.set_title("Sample", fontsize=16)
    ax.axis("off")
plt.show()

In [None]:
def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a montage of ``num_rows * num_columns`` slices of a 3D volume.

    Args:
        num_rows: number of rows in the grid.
        num_columns: number of columns in the grid.
        width, height: in-plane size of each slice; ``data`` must contain exactly
            ``num_rows * num_columns * width * height`` values.
        data: 3D array-like with slices along the last axis (e.g. ``a[0][:, :, :50]``).
    """
    data = np.rot90(np.array(data))
    # Reverse all axes so the slice axis comes first before splitting it into the grid.
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)
    # squeeze=False keeps `axarr` 2-D, so axarr[i, j] also works for a single row or column.
    f, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
        squeeze=False,
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()
# Visualize montage of slices.
# 5 rows and 10 columns for the first 50 slices of the sample volume `a`.
plot_slices(5, 10, SIZE, SIZE, a[0][:, :, :50])

In [None]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for reproducible runs.

    Note: setting PYTHONHASHSEED here only affects Python interpreters started
    afterwards (e.g. subprocesses), not string hashing in the running kernel.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # deterministic=True alone is not enough: benchmark mode times candidate kernels and
    # may pick different algorithms run to run, so disable it as well.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


SEED = 12
set_seed(SEED)

## Train / validation splits (stratified k-fold by patient)

In [None]:
# Full metadata table: one row per pre-processed .npy voxel file.
display(df_train)

In [None]:
# Stratified k-fold over PATIENTS (not over individual MRI volumes), so all MRI types
# of one patient always land in the same fold.
skf = StratifiedKFold(n_splits=FOLDS, random_state=SEED, shuffle=True)

# One row per patient; every row of a patient carries the same joined label, so max() just picks it.
patient_df = df_train.groupby('BraTS21ID').MGMT_value.max().reset_index()

print('Class Ratio:',sum(patient_df['MGMT_value'])/len(patient_df['MGMT_value']))

target = patient_df.loc[:,'MGMT_value']

fold_no = 1
train_fold_dict = {}  # currently unused
val_fold_dict = {}  # currently unused
train_indices = []  # per fold: df_train row labels used for training
val_indices = []  # per fold: df_train row labels used for validation
for train_index, val_index in skf.split(patient_df, target):
    train = patient_df.loc[train_index,:]
    val = patient_df.loc[val_index,:]
    # Map the selected patients back to df_train rows: reset_index() exposes df_train's row
    # labels as column 'index', and the right merge keeps one row per matching volume.
    train_indices.append(df_train.reset_index().merge(train, on='BraTS21ID', how="right").set_index('index').index)
    val_indices.append(df_train.reset_index().merge(val, on='BraTS21ID', how="right").set_index('index').index) # add indices from the general dataframe with all MRI types
    print('Fold',str(fold_no),'Class Ratio:',sum(patient_df.iloc[val_index]['MGMT_value'])/len(patient_df.iloc[val_index]),
          ',\t len train, val, sum:',len(train), len(val), len(train)+len(val))
    fold_no += 1



In [None]:
# Sanity check: every patient in the first validation fold has the same number of voxel files.
# iloc works with these row labels because df_train has a fresh RangeIndex (reset_index above).
assert(df_train.iloc[val_indices[0],:].groupby('BraTS21ID').voxel_paths.count().nunique() == 1)

In [None]:
# Sanity check: the validation fold sizes add up to the number of rows in df_train.
assert sum([len(x) for x in val_indices]) == df_train.shape[0]

In [None]:
# df_train row labels used for training in each fold.
train_indices

## Model and training classes

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

In [None]:
class Dataset(torch_data.Dataset):
    """Serves pre-processed voxel volumes (and optionally smoothed labels) by row index.

    Each item is a dict with ``"X"`` (float tensor from ``load_normalized_images_3d``) plus
    either ``"y"`` when ``targets`` is given, or ``"id"`` (the scan id) for inference.

    NOTE(review): ``mri_type``, ``split`` and ``augment`` are stored but never read by
    ``__getitem__``; in particular no augmentation is applied even when ``augment=True``.
    """

    def __init__(self, paths, targets=None, mri_type=None, scan_ids=None, label_smoothing=0.01, split="train", augment=False):
        self.paths, self.targets = paths, targets
        self.mri_type, self.scan_ids = mri_type, scan_ids
        self.label_smoothing = label_smoothing
        self.split, self.augment = split, augment

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        voxel_path = self.paths[index]
        scan_id = self.scan_ids[index]
        volume = torch.tensor(load_normalized_images_3d(voxel_path)).float()
        if self.targets is not None:
            # |t - eps| maps a 1 label to 1 - eps and a 0 label to eps (symmetric smoothing).
            smoothed = abs(self.targets[index] - self.label_smoothing)
            return {"X": volume, "y": torch.tensor(smoothed, dtype=torch.float)}
        return {"X": volume, "id": scan_id}


In [None]:
from models import resnet
from collections import OrderedDict

class MedicalNet(nn.Module):
    """MedicalNet 3D ResNet-34 backbone with a single-logit classification head.

    The backbone's ``conv_seg`` head is replaced by global max pooling, flatten and
    dropout, producing a 512-d feature vector that ``fc`` maps to one logit.
    """

    def __init__(self, path_to_weights=None, device='cuda'):
        super().__init__()
        self.model = resnet.resnet34(sample_input_D=1, sample_input_H=256, sample_input_W=256, num_seg_classes=2)
        if path_to_weights:
            self._load_pretrained(path_to_weights, device)
        self.model.conv_seg = nn.Sequential(OrderedDict([
            ('adapt1', nn.AdaptiveMaxPool3d(output_size=(1, 1, 1))),
            ('flatten1', nn.Flatten(start_dim=1)),
            ('dropout1', nn.Dropout(0.1)),
        ]))
        self.model = self.model.to(device)
        self.fc = nn.Linear(512, 1).to(device)

    def _load_pretrained(self, path_to_weights, device):
        """Copy matching tensors from a MedicalNet checkpoint into ``self.model``.

        ``"module."`` is removed from checkpoint keys (presumably a DataParallel prefix),
        and only keys present in the freshly built network are loaded.
        """
        own_state = self.model.state_dict()
        checkpoint = torch.load(path_to_weights, map_location=torch.device(device))
        matched = {}
        for key, tensor in checkpoint['state_dict'].items():
            name = key.replace("module.", "")
            if name in own_state.keys():
                matched[name] = tensor
        own_state.update(matched)
        self.model.load_state_dict(own_state)

    def forward(self, x):
        return self.fc(self.model(x))

class EfficientNetModel(nn.Module):
    """EfficientNet3D-b0 on single-channel volumes with a one-logit output layer."""

    def __init__(self, device='cuda'):
        super().__init__()
        backbone = EfficientNet3D.from_name("efficientnet-b0", override_params={'num_classes': 2}, in_channels=1)
        # Replace the 2-class head with a single logit (trained with a logits-based BCE loss).
        backbone._fc = nn.Linear(in_features=backbone._fc.in_features, out_features=1, bias=True)
        self.net = backbone.to(device)

    def forward(self, x):
        return self.net(x)
    

    

In [None]:
class Trainer:
    """Training loop that tracks validation loss and checkpoints the best epoch.

    Args:
        model: module whose output is expected to be (batch, 1); it is squeezed to (batch,).
        device: device that batches are moved to.
        optimizer: torch optimizer over the trainable parameters.
        criterion: callable ``criterion(logits, targets)`` returning a scalar loss tensor.
    """

    def __init__(
        self, 
        model, 
        device, 
        optimizer, 
        criterion
    ):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion
        # Lower the LR when the validation loss has not improved for 5 epochs.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
        self.best_valid_score = np.inf  # lowest validation loss seen so far
        self.n_patience = 0             # consecutive epochs without improvement
        self.lastmodel = None           # path of the most recently saved (best) checkpoint

    def fit(self, epochs, train_loader, valid_loader, save_path, patience, kfold):
        """Train for up to ``epochs`` epochs, checkpointing whenever validation loss improves.

        Args:
            epochs: maximum number of epochs.
            train_loader: iterable of dicts with "X" and "y" tensors.
            valid_loader: same format, used for loss / AUC.
            save_path: prefix of the checkpoint file names.
            patience: stop early after this many consecutive epochs without a
                validation-loss improvement (None disables early stopping).
            kfold: fold number, only used in the checkpoint name.
        """
        for n_epoch in range(1, epochs + 1):
            self.info_message("EPOCH: {}", n_epoch)

            train_loss, train_time = self.train_epoch(train_loader)
            valid_loss, valid_auc, valid_time = self.valid_epoch(valid_loader)

            self.info_message(
                "[Epoch Train: {}] loss: {:.4f}, time: {:.2f} s            ",
                n_epoch, train_loss, train_time
            )

            self.info_message(
                "[Epoch Valid: {}] loss: {:.4f}, auc: {:.4f}, time: {:.2f} s",
                n_epoch, valid_loss, valid_auc, valid_time
            )

            self.scheduler.step(valid_loss)
            if valid_loss < self.best_valid_score:
                previous_best = self.best_valid_score
                # Update before saving so the checkpoint records the score it actually achieved.
                self.best_valid_score = valid_loss
                self.n_patience = 0
                self.save_model(n_epoch, save_path, valid_loss, valid_auc, kfold)
                self.info_message(
                    "valid loss improved from {:.4f} to {:.4f}. Saved model to '{}'",
                    previous_best, valid_loss, self.lastmodel
                )
            else:
                self.n_patience += 1
                if patience is not None and self.n_patience >= patience:
                    self.info_message("Early stopping: no improvement for {} epochs", self.n_patience)
                    break

    def train_epoch(self, train_loader):
        """Run one optimisation pass; return (mean batch loss, elapsed whole seconds)."""
        self.model.train()
        t = time.time()
        sum_loss = 0

        for step, batch in enumerate(train_loader, 1):
            X = batch["X"].to(self.device)
            targets = batch["y"].to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(X).squeeze(1)

            loss = self.criterion(outputs, targets)
            loss.backward()

            sum_loss += loss.detach().item()

            self.optimizer.step()
            message = 'Train Step {}/{}, train_loss: {:.4f}'
            self.info_message(message, step, len(train_loader), sum_loss/step, end="\r")

        return sum_loss/len(train_loader), int(time.time() - t)

    def valid_epoch(self, valid_loader):
        """Evaluate without gradients; return (mean batch loss, ROC AUC, elapsed whole seconds)."""
        self.model.eval()
        t = time.time()
        sum_loss = 0
        y_all = []
        outputs_all = []

        for step, batch in enumerate(valid_loader, 1):
            with torch.no_grad():
                X = batch["X"].to(self.device)
                targets = batch["y"].to(self.device)

                outputs = self.model(X).squeeze(1)
                loss = self.criterion(outputs, targets)

                sum_loss += loss.detach().item()
                y_all.extend(batch["y"].tolist())
                outputs_all.extend(outputs.tolist())

            message = 'Valid Step {}/{}, valid_loss: {:.4f}'
            self.info_message(message, step, len(valid_loader), sum_loss/step, end="\r")

        # Targets may be label-smoothed; threshold them back to hard 0/1 labels for the AUC.
        y_all = [1 if x > 0.5 else 0 for x in y_all]
        auc = roc_auc_score(y_all, outputs_all)

        return sum_loss/len(valid_loader), auc, int(time.time() - t)

    def save_model(self, n_epoch, save_path, loss, auc, kfold):
        """Save model/optimizer state and remember the file path in ``self.lastmodel``."""
        self.lastmodel = f"{save_path}-k{kfold}-e{n_epoch}-loss{loss:.3f}-auc{auc:.3f}.pth"
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_valid_score": self.best_valid_score,
                "n_epoch": n_epoch,
            },
            self.lastmodel,
        )

    @staticmethod
    def info_message(message, *args, end="\n"):
        print(message.format(*args), end=end)

## train models

In [None]:
# Select model: 'medicalnet', 'efficientnet' or 'densenet'
MODEL_TYPE = 'medicalnet' 
FEATURES_EXTRACTION = False  # True: freeze the MedicalNet backbone and train only `fc`

def create_model(device):
    """Build the model selected by ``MODEL_TYPE`` and the parameters to optimize.

    Returns:
        (model, params_to_update)

    Raises:
        ValueError: if ``MODEL_TYPE`` is not one of the supported names.
    """
    if MODEL_TYPE == 'medicalnet':
        model = MedicalNet(path_to_weights=PRETRAINED_WEIGHTS, device=device)
        for name, param in model.named_parameters():
            param.requires_grad = name.startswith("fc") or not FEATURES_EXTRACTION
        params_to_update = [param for param in model.parameters() if param.requires_grad]
        print(f"Params to learn: {len(params_to_update)} tensors")

    elif MODEL_TYPE == 'efficientnet':
        model = EfficientNetModel(device=device)
        params_to_update = model.parameters()
    elif MODEL_TYPE == 'densenet':
        # MONAI is not imported at the top of the notebook; import it only when this option is used.
        from monai.networks.nets import DenseNet121
        model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1)
        model = model.to(device)
        params_to_update = model.parameters()
    else:
        raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE!r}")
    return model, params_to_update

In [None]:
def train_mri_type(df_train, df_valid, mri_type, kfold, device):
    """Train one model on ``df_train``, validate on ``df_valid`` and return the best checkpoint.

    Returns:
        path of the last checkpoint saved by the Trainer (the best validation loss), or
        None if the validation loss never improved.
    """
    print(f'kfold: {kfold}')

    def _make_dataset(frame, **extra):
        return Dataset(
            paths=frame.voxel_paths.values,
            targets=frame["MGMT_value"].values,
            mri_type=frame["MRI_Type"].values,
            scan_ids=frame["BraTS21ID"].values,
            **extra
        )

    loader_kwargs = dict(batch_size=4, num_workers=8)
    train_loader = torch_data.DataLoader(_make_dataset(df_train, augment=True), shuffle=True, **loader_kwargs)
    valid_loader = torch_data.DataLoader(_make_dataset(df_valid), shuffle=False, **loader_kwargs)

    # Only the parameters returned by create_model are optimized (see FEATURES_EXTRACTION).
    model, params_to_update = create_model(device=device)
    optimizer = torch.optim.Adam(params_to_update, lr=LEARNING_RATE)
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    trainer = Trainer(model, device, optimizer, torch_functional.binary_cross_entropy_with_logits)

    print(f'Fitting model for mri_type: {mri_type}')
    trainer.fit(EPOCHS, train_loader, valid_loader, f"{mri_type}", patience=8, kfold=kfold)
    return trainer.lastmodel

# Set `modelfiles` to a list of existing checkpoint paths to skip training.
modelfiles = None

# all mri types
if not modelfiles:
    modelfiles = [
        train_mri_type(df_train.loc[fold_train, :], df_train.loc[fold_val, :],
                       mri_type='all', kfold=fold, device=DEVICE)
        for fold, (fold_train, fold_val) in enumerate(zip(train_indices, val_indices))
    ]
    print(modelfiles)

## Predict function

In [None]:
def predict(modelfile, df, mri_type, split, device):
    """Run one checkpoint over every row of ``df`` and return sigmoid probabilities.

    Args:
        modelfile: checkpoint written by ``Trainer.save_model``.
        df: frame with ``voxel_paths``, ``MRI_Type`` and ``BraTS21ID`` columns.
        mri_type: only used in the log line.
        split: forwarded to ``Dataset``.
        device: device used for inference.

    Returns:
        DataFrame with columns ``BraTS21ID`` and ``MGMT_pred`` (floats), one row per row of ``df``.
    """
    print("Predict:", modelfile, mri_type, df.shape)
    data_retriever = Dataset(
        paths=df.voxel_paths.values, 
        mri_type=df["MRI_Type"].values,
        scan_ids=df.BraTS21ID.values,
        split=split
    )

    data_loader = torch_data.DataLoader(
        data_retriever,
        batch_size=4,
        shuffle=False,
        num_workers=8,
    )

    model, _ = create_model(device=device)
    checkpoint = torch.load(modelfile, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    y_pred = []
    ids = []

    with torch.no_grad():
        for e, batch in enumerate(data_loader, 1):
            print(f"{e}/{len(data_loader)}", end="\r")
            probs = torch.sigmoid(model(batch["X"].float().to(device))).cpu().numpy()
            # reshape(-1) keeps a 1-D array even for a final batch of size 1, so every
            # prediction is stored as a plain float (squeeze() produced a 0-d array there).
            y_pred.extend(probs.reshape(-1).tolist())
            ids.extend(batch["id"])

    return pd.DataFrame({"BraTS21ID": ids, "MGMT_pred": y_pred})

## Ensemble for validation

In [None]:
#modelfiles = ['../input/output-medical-volumentations/all-k0-e3-loss0.689-auc0.550.pth', '../input/output-medical-volumentations/all-k1-e5-loss0.697-auc0.532.pth', '../input/output-medical-volumentations/all-k2-e1-loss0.697-auc0.531.pth'] # TODO: remove it!

In [None]:
# K-FOLD CV: score each fold's validation patients with the model trained without them.
predictions_by_fold = []
for model_path, fold_rows in zip(modelfiles, val_indices):
    fold_valid = df_train.loc[fold_rows, :].copy()
    fold_pred = predict(model_path, fold_valid, mri_type='all', split="train", device=DEVICE)
    # give a single vote for the same patient study using all mri types
    per_patient = fold_pred.groupby(fold_pred.BraTS21ID).mean()
    predictions_by_fold.append(fold_valid.merge(per_patient, on='BraTS21ID', how="left"))

In [None]:
df_valid = pd.concat(predictions_by_fold)
# Every MRI row of a patient must carry the same, patient-averaged prediction.
assert (df_valid.groupby('BraTS21ID').MGMT_pred.nunique() == 1).all()
# Score one row per patient; scoring per MRI row would weight patients by their number of MRI types.
patient_valid = df_valid.groupby('BraTS21ID')[['MGMT_value', 'MGMT_pred']].first()
auc = roc_auc_score(patient_valid["MGMT_value"], patient_valid["MGMT_pred"])
print(f"Validation ensemble AUC: {auc:.4f}")
sns.displot(patient_valid["MGMT_pred"])

## Ensemble for submission

In [None]:
# TODO: train the model in the whole dataset
#modelfiles = train_mri_type(df_train, df_train, mri_type='all') 
#                  for (train, val) in zip(train_indices, val_indices)

In [None]:
# NOTE(review): the processed-voxel dataset (DATA_PATH) is only known to hold training
# volumes (voxels/<MRI_Type>/<BraTS21ID>.npy). Test scans must be pre-processed into the
# same format; the location below is an assumption — point it at the real test voxels.
TEST_VOXEL_GLOB = f"{DATA_PATH}/test/voxels/*/*.npy"

submission = pd.read_csv(f"{data_directory}/sample_submission.csv", index_col="BraTS21ID")

# `predict` needs per-volume rows (voxel_paths / MRI_Type / BraTS21ID), not the submission table.
test_voxels = sorted(glob.glob(TEST_VOXEL_GLOB))
assert test_voxels, f"No test voxels found matching {TEST_VOXEL_GLOB}"
df_test = pd.DataFrame({"voxel_paths": test_voxels})
df_test["BraTS21ID"] = df_test.voxel_paths.map(lambda p: os.path.splitext(os.path.basename(p))[0])
df_test["MRI_Type"] = df_test.voxel_paths.map(lambda p: os.path.basename(os.path.dirname(p)))

fold_predictions = []
for m in modelfiles:
    pred = predict(m, df_test, 'all', split="test", device=DEVICE)
    # Ids from file names are strings; cast them to whatever dtype read_csv gave the submission index.
    pred["BraTS21ID"] = pred["BraTS21ID"].astype(submission.index.dtype)
    # One vote per patient: average over that patient's MRI types.
    fold_predictions.append(pred.groupby("BraTS21ID")["MGMT_pred"].mean())

# Average the fold models and align to the submission's patient order.
ensemble = pd.concat(fold_predictions, axis=1).mean(axis=1)
missing = submission.index.difference(ensemble.index)
if len(missing):
    print(f"WARNING: no prediction for {len(missing)} patients; using 0.5")
submission["MGMT_value"] = ensemble.reindex(submission.index).fillna(0.5)
submission["MGMT_value"].to_csv("submission.csv")

In [None]:
# Final submission table, one row per BraTS21ID from sample_submission.csv.
submission

In [None]:
# Distribution of the submitted MGMT_value predictions.
sns.displot(submission["MGMT_value"])