## References:
* https://www.kaggle.com/xuxu1234/efficientnet3d-for-mri

In [None]:
import os
import glob
from tqdm import tqdm_notebook as tqdm
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms, utils
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2


import warnings
# NOTE(review): this silences *every* warning for the rest of the session, including
# NumPy RuntimeWarnings such as division by zero and library deprecation notices.
# Consider narrowing it to the specific category/module that is actually noisy.
warnings.filterwarnings("ignore")

In [None]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs so runs are repeatable.

    Args:
        seed: integer seed applied to every RNG.

    Also exports PYTHONHASHSEED (read by Python only at interpreter start, so it
    affects processes started afterwards) and, on CUDA, forces cuDNN into
    deterministic mode with benchmark autotuning disabled. Autotuning may pick
    different (non-deterministic) kernels from run to run.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # `deterministic` alone is not enough if benchmark mode was switched on.
        torch.backends.cudnn.benchmark = False


set_seed(42)

In [None]:
# Competition data root; the image loaders below read from it as a global.
path = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification'
train_data = pd.read_csv(os.path.join(path, 'train_labels.csv'))
print('Num of train samples:', train_data.shape[0])
# NOTE: not the last expression of the cell, so the notebook does not display it.
train_data.head()
# Side length (pixels) every DICOM slice is resized to by dicom2array.
img_size = 128

In [None]:
def dicom2array(path, voi_lut=True, fix_monochrome=True, size=None):
    """Read one DICOM slice and return it as a square ``np.uint8`` image.

    Args:
        path: path to a ``.dcm`` file.
        voi_lut: apply the file's VOI LUT via ``apply_voi_lut`` before scaling.
        fix_monochrome: invert MONOCHROME1 images so larger values are brighter.
        size: output side length in pixels; ``None`` uses the module-level
            ``img_size``.

    Returns:
        ``np.uint8`` array of shape ``(size, size)``, min-max scaled to [0, 255].
        A constant slice (e.g. an all-black padding slice) maps to all zeros
        instead of dividing 0 by 0.
    """
    if size is None:
        size = img_size
    dicom = pydicom.dcmread(path)  # `read_file` is deprecated / removed in pydicom 3
    data = apply_voi_lut(dicom.pixel_array, dicom) if voi_lut else dicom.pixel_array
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        # MONOCHROME1 stores inverted intensities; flip to the MONOCHROME2 convention.
        data = np.amax(data) - data
    data = data.astype(np.float64) - np.min(data)
    peak = np.max(data)
    if peak > 0:  # guard: a constant slice would otherwise yield NaN -> garbage uint8
        data = data / peak
    data = (data * 255).astype(np.uint8)
    return cv2.resize(data, (size, size))

def _slice_sort_key(file_path):
    """Natural-sort key so numbered file names order numerically ('x-2' before 'x-10')."""
    import re  # local import: the notebook's import cell does not bring in `re`
    parts = re.split(r"(\d+)", os.path.basename(file_path))
    # re.split with a capture group alternates text / digit-run, so odd indices are numbers.
    return [int(tok) if i % 2 else tok for i, tok in enumerate(parts)]


def _load_modality_volume(files, n_slices=50):
    """Stack up to ``n_slices`` central slices of one series into a uint8 volume.

    Returns an array of shape ``(img_size, img_size, n_slices)``. Missing slices
    (short or empty series) are zero-padded at the end, in the same uint8 dtype,
    so every volume has the same dtype and intensity range.
    """
    # Clamp at 0: for series shorter than n_slices a negative start would wrap
    # around and keep only the tail of the series.
    start = max(0, len(files) // 2 - n_slices // 2)
    volume = np.zeros((img_size, img_size, n_slices), dtype=np.uint8)
    for k, dcm_path in enumerate(files[start:start + n_slices]):
        # Each slice is stored transposed, i.e. the layout np.array(slices).T gives.
        volume[..., k] = dicom2array(dcm_path).T
    return volume


def load_3d_dicom_images(scan_id, split = "train"):
    """Load one patient's four MRI series as a single stacked uint8 array.

    Args:
        scan_id: patient folder name under ``{path}/{split}``.
        split: sub-folder of the data root to read from.

    Returns:
        ``np.uint8`` array of shape ``(img_size, img_size, 200)``: 50 central
        slices from each of FLAIR, T1w, T1wCE and T2w, in that order, zero-padded
        where a series has fewer than 50 slices (or is missing).
    """
    series_dir = f"{path}/{split}/{scan_id}"
    volumes = [
        _load_modality_volume(sorted(glob.glob(f"{series_dir}/{modality}/*.dcm"), key=_slice_sort_key))
        for modality in ("FLAIR", "T1w", "T1wCE", "T2w")
    ]
    return np.concatenate(volumes, axis=-1)

In [None]:
# Quick check: load patient 00000 from the train folder and show the stacked shape.
slices = load_3d_dicom_images("00000", split="train")
print(slices.shape)

In [None]:
class BrainTumor(Dataset):
    """Per-patient 3D MRI dataset for the RSNA-MICCAI MGMT task.

    Each item is a ``(C, H, W)`` float32 tensor in [-1, 1], built from
    ``load_3d_dicom_images``. C is the number of stacked slices, 200 with that
    loader's 4 x 50 layout. Labeled splits also return the MGMT label as a
    ``torch.long`` scalar.

    Args:
        path: data root containing ``train_labels.csv`` and the image folders.
            NOTE(review): ``load_3d_dicom_images`` reads images from the global
            ``path``, not this argument.
        split: ``"train"``, ``"valid"`` or ``"test"``. ``"valid"`` is carved out
            of the ``train`` image folder, because there is no separate
            validation folder.
        validation_split: fraction of the sorted training patients held out. The
            first fraction is ``"valid"``, the remainder is ``"train"``.
    """

    def __init__(self, path = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification', split = "train", validation_split = 0.0):
        train_data = pd.read_csv(os.path.join(path, 'train_labels.csv'))
        # Patient folders are zero-padded 5-digit ids, the CSV stores plain ints.
        self.labels = {
            str(b).zfill(5): m
            for b, m in zip(train_data["BraTS21ID"], train_data["MGMT_value"])
        }
        self.split = split
        # Folder the images actually live in: "valid" is a slice of train/.
        self.image_split = "train" if split in ("train", "valid") else split
        ids = [os.path.basename(p) for p in sorted(glob.glob(os.path.join(path, self.image_split, "*")))]
        n_valid = int(len(ids) * validation_split)
        if split == "valid":
            ids = ids[:n_valid]
        elif split == "train":
            ids = ids[n_valid:]
        self.ids = ids

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        scan_id = self.ids[idx]
        volume = load_3d_dicom_images(scan_id, self.image_split)
        # HWC -> CHW, then scale by 255 explicitly. ToTensor only rescales uint8
        # input, so float input would silently stay in [0, 255].
        imgs = torch.from_numpy(np.ascontiguousarray(volume.transpose(2, 0, 1))).to(torch.float32)
        imgs = (imgs / 255.0 - 0.5) / 0.5  # [0, 1] -> [-1, 1], same as Normalize(0.5, 0.5)
        if self.split == "test":
            return imgs
        return imgs, torch.tensor(self.labels[scan_id], dtype=torch.long)

In [None]:
# Training set with the default root/split; batches of 2 volumes, shuffled, 4 worker processes.
train_dataset = BrainTumor()
train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Pull a single batch and show the image / label tensor shapes.
for img, label in train_loader:
    print(img.shape, label.shape, sep="\n")
    break

## MODEL

In [None]:
#https://github.com/MontaEllis/Pytorch-Medical-Classification/blob/main/models/three_d/densenet3d.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class _DenseLayer(nn.Sequential):

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super().__init__()
        self.add_module('norm1', nn.BatchNorm3d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module(
            'conv1',
            nn.Conv3d(num_input_features,
                      bn_size * growth_rate,
                      kernel_size=1,
                      stride=1,
                      bias=False))
        self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module(
            'conv2',
            nn.Conv3d(bn_size * growth_rate,
                      growth_rate,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super().forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features,
                                     p=self.drop_rate,
                                     training=self.training)
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    """Chain of ``num_layers`` dense layers named ``denselayer1..N``.

    Each layer receives the block input plus everything produced before it, so
    layer k (1-based) takes ``num_input_features + (k - 1) * growth_rate``
    channels.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate,
                 drop_rate):
        super().__init__()
        in_channels = num_input_features
        for idx in range(1, num_layers + 1):
            self.add_module(f'denselayer{idx}',
                            _DenseLayer(in_channels, growth_rate, bn_size, drop_rate))
            in_channels += growth_rate


class _Transition(nn.Sequential):

    def __init__(self, num_input_features, num_output_features):
        super().__init__()
        self.add_module('norm', nn.BatchNorm3d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module(
            'conv',
            nn.Conv3d(num_input_features,
                      num_output_features,
                      kernel_size=1,
                      stride=1,
                      bias=False))
        self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    """3D DenseNet-BC classifier.

    Args:
        n_input_channels (int) - channels of the input volume
        conv1_t_size (int) - stem conv kernel size along the first spatial axis
            (the other two axes use 7)
        conv1_t_stride (int) - stem conv stride along that axis (the others use 2)
        no_max_pool (bool) - if True, omit the max-pool after the stem
        growth_rate (int) - how many filters to add each layer (k in paper)
        block_config (sequence of ints) - how many layers in each dense block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes

    Input is a 5D tensor ``(N, n_input_channels, D1, D2, D3)``; output is
    ``(N, num_classes)`` unnormalized scores (no softmax is applied).
    """

    def __init__(self,
                 n_input_channels=1,
                 conv1_t_size=7,
                 conv1_t_stride=1,
                 no_max_pool=False,
                 growth_rate=32,
                 block_config=(6, 12, 24, 16),
                 num_init_features=64,
                 bn_size=4,
                 drop_rate=0,
                 num_classes=2):

        super().__init__()

        # Stem. Submodule names (features.conv1/norm1/relu1/pool1, ...) are part
        # of the state_dict keys, so they are kept as-is.
        stem = [('conv1',
                 nn.Conv3d(n_input_channels,
                           num_init_features,
                           kernel_size=(conv1_t_size, 7, 7),
                           stride=(conv1_t_stride, 2, 2),
                           padding=(conv1_t_size // 2, 3, 3),
                           bias=False)),
                ('norm1', nn.BatchNorm3d(num_init_features)),
                ('relu1', nn.ReLU(inplace=True))]
        if not no_max_pool:
            stem.append(
                ('pool1', nn.MaxPool3d(kernel_size=3, stride=2, padding=1)))
        self.features = nn.Sequential(OrderedDict(stem))

        # Dense blocks, each followed (except the last) by a halving transition.
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers,
                                num_input_features=num_features,
                                bn_size=bn_size,
                                growth_rate=growth_rate,
                                drop_rate=drop_rate)
            self.features.add_module('denseblock{}'.format(i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition{}'.format(i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm3d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # One initialization pass over every layer, classifier included.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        # Global average pool over the three spatial axes, then flatten to (N, C).
        out = F.adaptive_avg_pool3d(out, output_size=1)
        out = torch.flatten(out, 1)
        return self.classifier(out)


def generate_model(model_depth, **kwargs):
    """Build a DenseNet-121/169/201/264 (64 initial features, growth rate 32).

    Args:
        model_depth: one of 121, 169, 201, 264.
        **kwargs: forwarded to ``DenseNet`` (e.g. ``n_input_channels``,
            ``num_classes``).

    Raises:
        ValueError: for an unsupported depth. This is checked explicitly rather
            than with ``assert``, which is stripped under ``python -O``.
    """
    block_configs = {
        121: (6, 12, 24, 16),
        169: (6, 12, 32, 32),
        201: (6, 12, 48, 32),
        264: (6, 12, 64, 48),
    }
    if model_depth not in block_configs:
        raise ValueError(
            f"model_depth must be one of {sorted(block_configs)}, got {model_depth!r}")
    return DenseNet(num_init_features=64,
                    growth_rate=32,
                    block_config=block_configs[model_depth],
                    **kwargs)


if __name__ == "__main__":
    # Shape smoke test: one 128^3 single-channel volume through DenseNet-201.
    # NOTE: in a notebook __name__ is "__main__", so this cell always runs.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_size = 128
    # torch.Tensor(*sizes) is uninitialized memory (may contain NaN/inf); zeros
    # are well-defined and, unlike randn, do not consume the seeded RNG stream.
    x = torch.zeros(1, 1, image_size, image_size, image_size, device=device)
    print("x size: {}".format(x.size()))

    model = generate_model(201, n_input_channels=1, num_classes=2).to(device)

    # Only the output shape is needed; skip autograd bookkeeping to save memory.
    with torch.no_grad():
        out1 = model(x)
    print("out size: {}".format(out1.size()))

In [None]:
class FocalLoss(nn.Module):
    """Focal loss on top of cross-entropy: mean of ``alpha * (1 - p_t)**gamma * CE``.

    ``p_t = exp(-CE)`` is the predicted probability of the true class, so
    confidently-correct samples are down-weighted.

    Args:
        alpha: constant scale applied to every sample's loss.
        gamma: focusing exponent; ``gamma=0`` reduces to plain mean cross-entropy.

    ``forward(inputs, targets)`` takes the same inputs as ``F.cross_entropy``
    (e.g. ``(N, C)`` logits and ``(N,)`` class indices). Extra keyword
    arguments are accepted and ignored.
    """

    def __init__(self, alpha=1., gamma=1.):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs, targets, **kwargs):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        p_true = per_sample_ce.neg().exp()
        modulator = (1 - p_true) ** self.gamma
        return (self.alpha * modulator * per_sample_ce).mean()

    
def roc_score(inp, target):
    """Return the ROC AUC of two-class logits as a 0-dim float32 tensor.

    Args:
        inp: ``(N, 2)`` logits, e.g. the DenseNet classifier output. Column 1
            (after softmax) is the positive-class score.
        target: ``(N,)`` labels; assumes 0/1 values — TODO confirm for callers.

    Raises:
        ValueError: if ``target`` contains only one class (AUC is undefined).

    Computed as the Mann-Whitney statistic: the fraction of (positive,
    negative) pairs where the positive scores higher, with ties counting 1/2.
    Continuous scores are used rather than argmax labels, which would collapse
    the ROC curve to a single threshold. Tensors may live on any device.
    """
    scores = torch.softmax(inp.detach().float(), dim=1)[:, 1].cpu()
    labels = target.detach().reshape(-1).cpu().bool()
    pos, neg = scores[labels], scores[~labels]
    if pos.numel() == 0 or neg.numel() == 0:
        raise ValueError("ROC AUC is undefined when target contains a single class")
    diff = pos[:, None] - neg[None, :]  # pairwise: O(n_pos * n_neg) memory
    return (diff > 0).float().mean() + 0.5 * (diff == 0).float().mean()

In [None]:
n_epochs = 10
# DenseNet() defaults (64 init features, growth 32, blocks (6, 12, 24, 16)) are
# the generate_model(121) layout.
# NOTE(review): the training cell saves to 'best_Densenet_201_loss.pt', but this
# is not generate_model(201); confirm which depth is intended.
model = DenseNet()
criterion = FocalLoss()  # alpha=1, gamma=1
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

## TRAIN

In [None]:
gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(gpu)

# Best epoch-average training loss seen so far. It must be initialized once,
# outside the epoch loop; resetting it every epoch made every epoch a new "best"
# and overwrote the checkpoint unconditionally.
# NOTE(review): selection is on *training* loss; a validation metric is usually preferable.
best_pres = float("inf")

for epoch in range(n_epochs):  # loop over the dataset multiple times
    train_loss = []
    model.train()
    for x, y in tqdm(train_loader):
        # Add a singleton channel axis at dim 1 for the 3D CNN (n_input_channels=1).
        x = torch.unsqueeze(x, dim=1).to(gpu)
        y = y.to(gpu)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())

    if not train_loss:
        raise RuntimeError("train_loader yielded no batches")
    avg_train = sum(train_loss) / len(train_loss)
    print(f"epoch {epoch+1} train: {avg_train}")

    if avg_train < best_pres:
        print('save model...')
        best_pres = avg_train
        torch.save(model.state_dict(), 'best_Densenet_201_loss.pt')