# RSNA-MICCAI Brain Tumor Radiogenomic Classification - **An approach with PyTorch EfficientNet 3D**

## **Problem Description**:

The dataset provides structural multi-parametric MRI (mpMRI) scans for each subject in DICOM format. The mpMRI scans included are:

* Fluid Attenuated Inversion Recovery (FLAIR)
* T1-weighted pre-contrast (T1w)
* T1-weighted post-contrast (T1Gd, stored in the `T1wCE` folders)
* T2-weighted (T2)

`train_labels.csv` contains the target **MGMT_value** for each subject in the training data **(i.e. the presence of MGMT promoter methylation)**.

So this is a binary classification problem.

## **An EfficientNet3D solution**:

* For each patient, we consider 4 sequences (FLAIR, T1w, T1Gd, T2) and take the 64 middle slices of each sequence, resizing every slice to (256, 256).

* Construct an EfficientNet-3D in PyTorch with input shape (256, 256, 256) (the four sequences concatenated along the slice axis) or (4, 256, 256, 64) (one channel per sequence).

* Perform binary classification.
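The slice-selection step above can be sketched with NumPy alone. This is a minimal illustration on random stand-in data, not the notebook's loader: `middle_window` and `stack_middle_slices` are hypothetical helpers, and clamping the start index at 0 (so a series shorter than 64 slices keeps all of its slices before zero-padding) is a design choice of this sketch. It also shows how the two input layouts, (256, 256, 256) and (4, 256, 256, 64), are built from the same four stacks.

```python
import numpy as np

def middle_window(n_slices, depth=64):
    """(start, stop) of `depth` slices centred on the middle of a series, clamped at 0."""
    start = max(0, n_slices // 2 - depth // 2)
    stop = min(n_slices, start + depth)
    return start, stop

def stack_middle_slices(volume, depth=64):
    """(n_slices, H, W) -> (H, W, depth), zero-padded at the end when the series is short."""
    start, stop = middle_window(volume.shape[0], depth)
    chunk = np.moveaxis(volume[start:stop], 0, -1)  # (H, W, k) with k <= depth
    n_pad = depth - chunk.shape[-1]
    if n_pad > 0:
        pad = np.zeros(chunk.shape[:2] + (n_pad,), dtype=chunk.dtype)
        chunk = np.concatenate([chunk, pad], axis=-1)
    return chunk

# Hypothetical stand-ins for FLAIR, T1w, T1Gd and T2 series of different lengths
rng = np.random.default_rng(0)
series = [rng.integers(1, 256, size=(n, 256, 256), dtype=np.uint8) for n in (120, 40, 64, 10)]
stacks = [stack_middle_slices(s) for s in series]

concatenated = np.concatenate(stacks, axis=-1)  # sequences side by side along the slice axis
channel_first = np.stack(stacks, axis=0)        # one channel per sequence
print(concatenated.shape, channel_first.shape)  # (256, 256, 256) (4, 256, 256, 64)
```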

## **An EfficientNet2D solution**:

* For each patient, we consider 4 sequences (FLAIR, T1w, T1Gd, T2) and pick a random slice from each. The idea comes from [https://github.com/zabir-nabil/Fibro-CoSANet](https://github.com/zabir-nabil/Fibro-CoSANet)

* Construct a 4-channel image out of these 4 sequences.

* Add data augmentation. Check out the augmentation notebook: https://www.kaggle.com/furcifer/mri-data-augmentation-pipeline

* Add a few heuristics to avoid black/empty scans.

* Use a modified EfficientNet with a 4-channel input stem as the model.

* Perform binary classification.
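The random-slice idea and the black-scan heuristic can be illustrated on toy arrays. The sketch below mirrors the strategy used by `load_rand_dicom_images` and `dicom2array` (sample half of a series at random, keep the first slice that is not entirely black, fall back to the last one tried) but works on in-memory NumPy slices instead of DICOM files; `pick_nonblack_slice` and `toy_series` are hypothetical names.

```python
import random
import numpy as np

def pick_nonblack_slice(slices, rng):
    """Try a random half of the series and return the first slice with any signal."""
    candidates = rng.sample(range(len(slices)), max(len(slices) // 2, 1))
    for i in candidates:
        if slices[i].max() > 0:
            return slices[i]
    return slices[candidates[-1]]  # every candidate was black: fall back to the last one tried

def toy_series(n, n_black, size=64):
    """Hypothetical series: the first n_black slices are empty, the rest contain signal."""
    return [np.zeros((size, size), np.uint8) if i < n_black
            else np.full((size, size), i, np.uint8) for i in range(n)]

rng = random.Random(42)
modalities = [toy_series(30, 10) for _ in range(4)]  # stand-ins for FLAIR, T1w, T1Gd, T2
image = np.stack([pick_nonblack_slice(m, rng) for m in modalities], axis=-1)
print(image.shape)  # (64, 64, 4)
```

With 10 black slices out of 30 and 15 candidates drawn, at least 5 candidates carry signal, so every channel is guaranteed to be non-black here.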


## **Check out my other kernels**

### ⚡ **Training kernel:** https://www.kaggle.com/furcifer/torch-efficientnet3d-for-mri-no-train/

### ⚡ **Inference kernel:** https://www.kaggle.com/furcifer/torch-effnet3d-for-mri-no-inference/

### ⚡ **Training kernel (EfficientNet2D):** https://www.kaggle.com/furcifer/no-baseline-pytorch-cnn-for-mri/


In [None]:
from IPython.display import Image
Image("../input/diagram/diagram.png")

### **Importing libraries**

In [None]:
import os
import glob
from tqdm.notebook import tqdm
import math
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms, utils
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
from sklearn.metrics import roc_auc_score
import imgaug as ia
import imgaug.augmenters as iaa


import warnings
warnings.filterwarnings("ignore")

### **Importing EfficientNet-(2D + 3D)**

In [None]:
import sys
sys.path.append('../input/efficientnetpyttorch3d/EfficientNet-PyTorch-3D')
sys.path.append('../input/efficientnet-pytorch')
sys.path.append('../input/efficientnet/EfficientNet-PyTorch-master/')

In [None]:
MODEL = "4C" # "4C": 2D EfficientNet-B1 with 4 input channels; "3D": EfficientNet-3D B0 ("4C+3D" is listed but has no branch below)

### **Inspecting Labels**

In [None]:
path = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification'
train_data = pd.read_csv(os.path.join(path, 'train_labels.csv'))
print('Num of train samples:', len(train_data))
train_data.head()
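`BraTS21ID` is stored as a plain integer in the CSV, while the Dataset class below looks labels up by the zero-padded, five-digit folder name (`str(b).zfill(5)`). A minimal sketch of that mapping with the standard library; the CSV excerpt is made up for illustration and does not reproduce real rows.

```python
import csv
import io

# Hypothetical excerpt in the same column layout as train_labels.csv (values are made up)
toy_csv = """BraTS21ID,MGMT_value
0,1
2,1
12,0
"""

# key = zero-padded folder name, value = binary MGMT target
labels = {row["BraTS21ID"].zfill(5): int(row["MGMT_value"])
          for row in csv.DictReader(io.StringIO(toy_csv))}
print(labels)  # {'00000': 1, '00002': 1, '00012': 0}
```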

## **Augmentation**

In [None]:
sometimes = lambda aug: iaa.Sometimes(0.1, aug)

seq = iaa.Sequential(
    [
        # apply the following augmenters to most images
        iaa.Fliplr(0.5), # horizontally flip 50% of all images
        iaa.Flipud(0.5), # vertically flip 50% of all images
        # crop or pad images by -5% to 5% of their height/width
        sometimes(iaa.CropAndPad(
            percent=(-0.05, 0.05),
            pad_mode=ia.ALL,
            pad_cval=(0, 255)
        )),
        sometimes(iaa.Affine(
            scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, # scale images to 80-120% of their size, individually per axis
            translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, # translate by -20 to +20 percent (per axis)
            rotate=(-45, 45), # rotate by -45 to +45 degrees
            shear=(-16, 16), # shear by -16 to +16 degrees
            order=[0, 1], # use nearest neighbour or bilinear interpolation (fast)
            cval=(0, 255), # if mode is constant, use a cval between 0 and 255
            mode=ia.ALL # use any of scikit-image's warping modes
        )),
        # execute 0 to 5 of the following (less important) augmenters per image
        # don't execute all of them, as that would often be way too strong
        iaa.SomeOf((0, 5),
            [
                sometimes(iaa.Superpixels(p_replace=(0, 1.0), n_segments=(20, 200))), # convert images into their superpixel representation
                iaa.OneOf([
                    iaa.GaussianBlur((0, 3.0)), # blur images with a sigma between 0 and 3.0
                    iaa.AverageBlur(k=(2, 7)), # blur image using local means with kernel sizes between 2 and 7
                    iaa.MedianBlur(k=(3, 11)), # blur image using local medians with kernel sizes between 3 and 11
                ]),
                iaa.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5)), # sharpen images
                iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0)), # emboss images
                # search either for all edges or for directed edges,
                # blend the result with the original image using a blobby mask
                iaa.SimplexNoiseAlpha(iaa.OneOf([
                    iaa.EdgeDetect(alpha=(0.5, 1.0)),
                    iaa.DirectedEdgeDetect(alpha=(0.5, 1.0), direction=(0.0, 1.0)),
                ])),
                iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05*255), per_channel=0.5), # add gaussian noise to images
                iaa.OneOf([
                    iaa.Dropout((0.01, 0.1), per_channel=0.5), # randomly remove up to 10% of the pixels
                    iaa.CoarseDropout((0.03, 0.15), size_percent=(0.02, 0.05), per_channel=0.2),
                ]),
                iaa.Invert(0.05, per_channel=True), # invert color channels
                iaa.Add((-10, 10), per_channel=0.5), # change brightness of images (by -10 to 10 of original value)
                
                # either change the brightness of the whole image (sometimes
                # per channel) or change the brightness of subareas
                iaa.OneOf([
                    iaa.Multiply((0.5, 1.5), per_channel=0.5),
                    iaa.FrequencyNoiseAlpha(
                        exponent=(-4, 0),
                        first=iaa.Multiply((0.5, 1.5), per_channel=True),
                        second=iaa.LinearContrast((0.5, 2.0))
                    )
                ]),
                iaa.LinearContrast((0.5, 2.0), per_channel=0.5), # improve or worsen the contrast
                sometimes(iaa.ElasticTransformation(alpha=(0.5, 3.5), sigma=0.25)), # move pixels locally around (with random strengths)
                sometimes(iaa.PiecewiseAffine(scale=(0.01, 0.05))), # sometimes move parts of the image around
                sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.1)))
            ],
            random_order=True
        )
    ],
    random_order=True
)
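When `aug=True`, `dicom2array` applies `seq` to each modality separately, so the four channels of one sample draw independent random parameters (one channel may be flipped while another is not). If spatially aligned channels are preferred, one option is to sample the geometric transform once per sample and apply it to every channel; imgaug's `Augmenter.to_deterministic()` is one way to do that with the existing `seq`. The NumPy sketch below shows the idea for flips only; `flip_all_channels` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def flip_all_channels(image, rng, p_lr=0.5, p_ud=0.5):
    """Apply the same random flips to every channel of an (H, W, C) image."""
    if rng.random() < p_lr:
        image = image[:, ::-1, :]  # horizontal flip, shared by all channels
    if rng.random() < p_ud:
        image = image[::-1, :, :]  # vertical flip, shared by all channels
    return np.ascontiguousarray(image)

rng = np.random.default_rng(7)
base = np.arange(16, dtype=np.uint8).reshape(4, 4)
img = np.repeat(base[..., None], 4, axis=-1)  # 4 identical channels make alignment easy to check
out = flip_all_channels(img, rng)
print(all(np.array_equal(out[..., 0], out[..., c]) for c in range(4)))  # True
```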

### **MRI Slice Loading/Processing**

In [None]:
def dicom2array(paths, voi_lut=True, fix_monochrome=True, remove_black_boundary=True, aug = False):
    """
    Read the DICOM files in `paths` in order and return the first non-black slice
    (or the last one tried) as a (512, 512) uint8 image.
    """
    for path in paths:
        dicom = pydicom.dcmread(path)
        # VOI LUT (if available by DICOM device) is used to
        # transform raw DICOM data to "human-friendly" view
        if voi_lut:
            data = apply_voi_lut(dicom.pixel_array, dicom)
        else:
            data = dicom.pixel_array
        if data.max() > 0.0: # avoiding black images (if possible)
            break
    # depending on this value, the image may look inverted - fix that:
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data
    data = data - np.min(data)
    if np.max(data) > 0: # guard against division by zero on all-black slices
        data = data / np.max(data)
    data = (data * 255).astype(np.uint8)
    if remove_black_boundary: # we get slightly more details
        (x, y) = np.where(data > 0) # x: row indices, y: column indices
        if len(x) > 0 and len(y) > 0:
            x_mn, x_mx = np.min(x), np.max(x)
            y_mn, y_mx = np.min(y), np.max(y)
            if (x_mx - x_mn) > 10 and (y_mx - y_mn) > 10:
                data = data[:, y_mn:y_mx] # only the column range is cropped; rows keep the full height
    data = cv2.resize(data, (512, 512))
    if aug and random.randint(0,1) == 1: # augmenting only 50% of the time
        data = seq(image=data) # a single 2D image, so use image= rather than images=
    return data


def load_3d_dicom_images(scan_id, split = "train", channel_expand = True):
    """
    Load the 64 middle slices of each sequence as a (256, 256, 64) volume; missing or
    short series are zero-padded. Returns a (256, 256, 256) array (sequences concatenated
    along the last axis) or, with channel_expand=True, a (256, 256, 64, 4) array with one
    channel per sequence. dicom2array's heuristics avoid all-black slices where possible.
    """
    volumes = []
    for modality in ["FLAIR", "T1w", "T1wCE", "T2w"]:
        files = sorted(glob.glob(f"{path}/{split}/{scan_id}/{modality}/*.dcm"))
        start = max(0, len(files)//2 - 32) # clamp so series shorter than 64 slices are not truncated
        # dicom2array expects a list of paths and returns 512x512, so wrap each path and downsize
        slices = [cv2.resize(dicom2array([a]), (256, 256)) for a in files[start:start + 64]]
        img = np.array(slices).T if slices else np.zeros((256, 256, 0))

        if img.shape[-1] < 64: # zero-pad missing slices
            n_zero = 64 - img.shape[-1]
            img = np.concatenate((img, np.zeros((256, 256, n_zero))), axis = -1)
        volumes.append(img)

    if channel_expand:
        return np.moveaxis(np.array(volumes), 0, -1)
    return np.concatenate(volumes, axis = -1)


def load_rand_dicom_images(scan_id, split = "train", aug = False):
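    """
    Return one random, non-black slice per modality, stacked into a 4-channel image.
    For each modality, half of the slices are sampled at random and dicom2array keeps
    the first one that is not entirely black.
    """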
    if split != "train" and split != "test":
        split = "train"
    flair = sorted(glob.glob(f"{path}/{split}/{scan_id}/FLAIR/*.dcm"))
    flair_img = dicom2array(random.sample(flair, max(len(flair)//2, 1)), aug = aug)
    t1w = sorted(glob.glob(f"{path}/{split}/{scan_id}/T1w/*.dcm"))
    t1w_img = dicom2array(random.sample(t1w, max(len(t1w)//2, 1)), aug = aug)
    t1wce = sorted(glob.glob(f"{path}/{split}/{scan_id}/T1wCE/*.dcm"))
    t1wce_img = dicom2array(random.sample(t1wce, max(len(t1wce)//2, 1)), aug = aug)
    t2w = sorted(glob.glob(f"{path}/{split}/{scan_id}/T2w/*.dcm"))
    t2w_img = dicom2array(random.sample(t2w, max(len(t2w)//2, 1)), aug = aug)
    
    return np.array((flair_img, t1w_img, t1wce_img, t2w_img)).T
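The intensity processing inside `dicom2array` (MONOCHROME1 inversion, min-max scaling to uint8, cropping the black border) can be tried on a synthetic array without any DICOM files. This is a sketch under that assumption: `normalize_and_crop` is a hypothetical helper, and unlike `dicom2array`, which crops only the column range, it crops both axes to the bounding box of non-zero pixels (note that `np.nonzero` returns row indices first, then column indices).

```python
import numpy as np

def normalize_and_crop(data, monochrome1=False, min_extent=10):
    """Scale a raw slice to uint8 [0, 255] and crop it to its non-zero bounding box."""
    data = data.astype(np.float32)
    if monochrome1:                 # MONOCHROME1: low values are bright, so invert
        data = data.max() - data
    data -= data.min()
    peak = data.max()
    if peak > 0:                    # guard against all-black slices (avoids 0/0)
        data /= peak
    data = (data * 255).astype(np.uint8)
    rows, cols = np.nonzero(data)
    if rows.size and (rows.max() - rows.min()) > min_extent and (cols.max() - cols.min()) > min_extent:
        data = data[rows.min():rows.max() + 1, cols.min():cols.max() + 1]
    return data

toy = np.zeros((64, 64), np.int16)
toy[10:40, 20:50] = np.arange(1, 901).reshape(30, 30)  # hypothetical "brain" region
out = normalize_and_crop(toy)
print(out.shape, out.max())  # (30, 30) 255
```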

### **Data Loader**

In [None]:
# a simple PyTorch Dataset that yields one 4-channel image per patient

class BrainTumor4C(Dataset): # 4 channel data-loader
    def __init__(self, path = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification', split = "train", validation_split = 0.2):
        # labels
        train_data = pd.read_csv(os.path.join(path, 'train_labels.csv'))
        self.labels = {}
        brats = list(train_data["BraTS21ID"])
        mgmt = list(train_data["MGMT_value"])
        for b, m in zip(brats, mgmt):
            self.labels[str(b).zfill(5)] = m
            
        if split == "valid":
            self.split = split
            self.ids = [a.split("/")[-1] for a in sorted(glob.glob(path + f"/train/" + "/*"))]
            self.ids = self.ids[:int(len(self.ids)* validation_split)] # first `validation_split` fraction (20% by default) as validation
        elif split == "train":
            self.split = split
            self.ids = [a.split("/")[-1] for a in sorted(glob.glob(path + f"/{split}/" + "/*"))]
            self.ids = self.ids[int(len(self.ids)* validation_split):] # remaining patients (80% by default) as train
        else:
            self.split = split
            self.ids = [a.split("/")[-1] for a in sorted(glob.glob(path + f"/{split}/" + "/*"))]
            
    
    def __len__(self):
        return len(self.ids)
    
    def __getitem__(self, idx):
        imgs = load_rand_dicom_images(self.ids[idx], self.split, aug = False)
        
        transform = transforms.Compose([transforms.ToTensor()]) # transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5))
        imgs = transform(imgs)
        
        imgs = imgs - imgs.min()
        imgs = (imgs + 1e-5) / (imgs.max() - imgs.min() + 1e-5)
        
        if self.split != "test":
            label = self.labels[self.ids[idx]]
            return imgs.float(), torch.tensor(label, dtype = torch.long)
        else:
            return imgs.float(), self.ids[idx]
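The train/validation split in `BrainTumor4C` is deterministic and unshuffled: patient folders are sorted, the first `validation_split` fraction becomes the validation set and the rest is used for training. A stdlib sketch of the same index arithmetic; `ordered_split` and the folder names are hypothetical.

```python
def ordered_split(ids, validation_split=0.2):
    """Sorted, unshuffled split: first fraction -> validation, remainder -> train."""
    ids = sorted(ids)
    cut = int(len(ids) * validation_split)
    return ids[cut:], ids[:cut]  # (train, valid)

ids = [str(i).zfill(5) for i in range(10)]  # hypothetical patient folder names
train_ids, valid_ids = ordered_split(ids)
print(valid_ids)       # ['00000', '00001']
print(len(train_ids))  # 8
```

Because the split depends only on sorted folder names, every run (and every `split` value) sees the same partition, which keeps validation scores comparable across experiments; the trade-off is that class balance in the validation set is not controlled.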

In [None]:
# testing the dataloader
test_dataset = BrainTumor4C(split = "test")
test_bs = 8
test_loader = DataLoader(test_dataset, batch_size = test_bs, shuffle=True)

In [None]:
for img, idx in test_loader:
    print(img.shape)
    print(img.max())
    print(img.mean())
    print(img.min())
    print(idx)
    break

### **Model: EfficientNet-3D B0 / EfficientNet B1**

In [None]:
if MODEL == "3D":
    from efficientnet_pytorch_3d import EfficientNet3D

    PATH = "../input/rsna-efficientnet3db0/best_roc_0.29_loss_1826.83.pt"
    model = EfficientNet3D.from_name("efficientnet-b0", override_params={'num_classes': 2}, in_channels=4)
    model.load_state_dict(torch.load(PATH, map_location="cpu")) # load on CPU; moved to the device later
    model.eval()

In [None]:
if MODEL == "4C":
    from efficientnet_pytorch import EfficientNet
    from efficientnet_pytorch.utils import Conv2dStaticSamePadding

    PATH = "../input/effb14c/best_roc_0.35_loss_39.45.pt"
    model = EfficientNet.from_name('efficientnet-b1')

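    # adapt the model to 4 input channels: replace the 3-channel stem (32 filters in B1)
    # and the classifier head (1280 features -> 2 classes)
    model._conv_stem = Conv2dStaticSamePadding(4, 32, kernel_size = (3,3), stride = (2,2),
                                               bias = False, image_size = 512)
    model._fc = torch.nn.Linear(in_features=1280, out_features=2, bias=True)

    model.load_state_dict(torch.load(PATH, map_location="cpu")) # load on CPU; moved to the device later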
    model.eval()
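`Conv2dStaticSamePadding` takes an `image_size` because it precomputes a fixed "SAME" padding, which is why the replacement stem is built with `image_size = 512`, matching the 512×512 resize in `dicom2array`. The sketch below reproduces the TensorFlow-style "SAME" arithmetic that, to our understanding, this layer follows (output size = ceil(input / stride), with any odd leftover padding placed after the input); `same_padding` is a hypothetical helper.

```python
import math

def same_padding(in_size, kernel, stride, dilation=1):
    """TF-style 'SAME' padding along one axis. Returns (pad_before, pad_after)."""
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + (kernel - 1) * dilation + 1 - in_size, 0)
    return total // 2, total - total // 2

print(same_padding(512, kernel=3, stride=2))  # (0, 1)
```

For the 3×3, stride-2 stem on a 512-pixel side, one extra row and column are padded after the image, producing a 256×256 feature map.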

In [None]:
# test-time inference: each pass samples a new random slice per modality,
# so n_bootstrap passes collect several predicted labels per patient
gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(gpu)
n_bootstrap = 9

if MODEL == "4C":
    labels = {}

    model.eval()
    with torch.no_grad(): # inference only: no gradients needed
        for i_b in range(n_bootstrap):
            for i, data in tqdm(enumerate(test_loader, 0)):

                x, idx = data

                x = x.to(gpu)

                # forward
                outputs = model(x)

                label = torch.argmax(outputs, dim = -1).tolist()

                for i_, idx_ in enumerate(list(idx)):
                    labels[idx_] = labels.get(idx_, []) + [label[i_]]
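The loop above stores one argmax label per pass, so a patient's final score is the fraction of passes that voted for class 1. A common alternative (not what this notebook does) is to average the class-1 softmax probability across passes, which keeps the model's confidence. A NumPy sketch comparing both aggregations on hypothetical logits; the numbers and patient IDs are made up.

```python
from collections import defaultdict
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Hypothetical logits from 3 test-time passes over 2 patients (passes x patients x 2 classes)
logits = np.array([
    [[0.2, 0.1], [-1.0, 2.0]],
    [[0.0, 0.3], [-0.5, 1.5]],
    [[0.4, 0.2], [0.1, 0.9]],
])
ids = ["00001", "00013"]

votes, probs = defaultdict(list), defaultdict(list)
for pass_logits in logits:
    hard = pass_logits.argmax(axis=-1)     # hard label, as stored by the notebook
    p1 = softmax(pass_logits)[:, 1]        # probability of class 1 (methylated)
    for pid, h, p in zip(ids, hard, p1):
        votes[pid].append(int(h))
        probs[pid].append(float(p))

vote_score = {pid: np.mean(v) for pid, v in votes.items()}  # fraction of class-1 votes
prob_score = {pid: np.mean(p) for pid, p in probs.items()}  # mean class-1 probability
```

Here patient `00001` gets a vote score of 1/3 but a mean probability close to 0.5, reflecting that the model was nearly undecided on every pass.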

In [None]:
import collections
labels_od = collections.OrderedDict(sorted(labels.items()))
print(labels_od)

In [None]:
f_idxs = []
f_labels = []
for idx in labels_od.keys():
    f_idxs.append(int(idx))
    f_labels.append(np.array(labels_od[idx], dtype = np.float32).mean())    
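Averaging the votes turns hard 0/1 labels into graded scores, which matters because the competition was evaluated with ROC AUC, a ranking metric in which tied scores only earn half credit. The sketch below computes AUC as the probability that a random positive outranks a random negative (ties count 1/2), which matches `sklearn.metrics.roc_auc_score`; the labels and scores are hypothetical.

```python
import numpy as np

def roc_auc(y_true, scores):
    """AUC = P(score of a random positive > score of a random negative), ties count 1/2."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true == 1], scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]  # all positive/negative pairs
    return (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / (len(pos) * len(neg))

y = [0, 0, 1, 1]
hard = [0, 1, 1, 1]              # single-pass argmax labels (hypothetical)
soft = [1/9, 5/9, 6/9, 8/9]      # vote fractions after 9 passes (hypothetical)
print(roc_auc(y, hard), roc_auc(y, soft))  # 0.75 1.0
```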

In [None]:
print(f_idxs)
print(f_labels)

In [None]:
submission = pd.read_csv("../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv")
preds = dict(zip(f_idxs, f_labels))

for i, row in submission.iterrows():
    idx = int(row['BraTS21ID'])
    if idx in preds: # patients without a prediction keep the sample value
        submission.loc[i, 'MGMT_value'] = float(preds[idx])

In [None]:
submission.to_csv("submission.csv", index=False)
submission