In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch

import os
import glob
from tqdm.notebook import tqdm
import SimpleITK as sitk
import pydicom

import sys
sys.path.append('../input/monai-v060-deep-learning-in-healthcare-imaging/')
from monai.transforms import (
    AddChannel,
    Compose,
    RandRotate90,
    Resize,
    ScaleIntensity,
    EnsureType,
    Randomizable,
    LoadImaged,
    EnsureTyped,
    RandRotate,
    RandZoom,
    RandDeformGrid,
    RandAffine,
    Transform
)
from monai.data import CacheDataset, DataLoader, ImageDataset


# 1. Config

In [None]:
# Input folder with one sub-folder per test patient; each holds one folder of .dcm files per series.
DICOM_IM_FOLDER = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/test/'
# Working-directory folder that receives the converted NIfTI volumes.
IM_FOLDER = 'BraTS2021_Testing_Data'
# Series folder names as they appear under each patient folder.
MRI_TYPES = ['T1wCE', 'T1w', 'T2w', 'FLAIR']
# Short modality tags used in output file names and in CANDIDATES.
# NOTE: this list is NOT index-aligned with MRI_TYPES; translate with mri_type_mapping instead of zip().
SHORT_MRI_TYPES = ['t1', 't1ce', 't2', 'flair']

SEED = 67
# (dim0, dim1, dim2, channels); get_model() reads the first three entries.
DIM = (128, 128, 128, 1)
NUM_CLASSES = 1
# Number of segmentation classes; the ResNet only builds its segmentation head when this is > 0.
NUM_SEG_CLASSES = 0
BATCH_SIZE = 6
# Use the first GPU when present, otherwise fall back to CPU so the notebook still runs
# (slowly) instead of failing at model.to(DEVICE).
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# When True, the conversion cell may process only a few hard-coded patients for a quick commit.
FAST_COMMIT = True

# One entry per ensemble member: backbone constructor name, checkpoint path and the
# modality (short tag) whose volumes the checkpoint is evaluated on.
CANDIDATES = [
    {
        'backbone_name':'resnet50',
        'model_path':'../input/brainbaselinemodels/v8/v8/t1_Fold0_resnet50_v8_data_task2_npy_norm_ValidLoss0.662_ValidAUC0.596_Ep19.pth',
        'mri_type':'t1',
    },
    {
        'backbone_name':'resnet50',
        'model_path':'../input/brainbaselinemodels/v8/v8/t1ce_Fold0_resnet50_v8_data_task2_npy_norm_ValidLoss0.626_ValidAUC0.697_Ep13.pth',
        'mri_type':'t1ce',
    },
    {
        'backbone_name':'resnet50',
        'model_path':'../input/brainbaselinemodels/v8/v8/t2_Fold0_resnet50_v8_data_task2_npy_norm_ValidLoss0.691_ValidAUC0.569_Ep02.pth',
        'mri_type':'t2',
    },
    {
        'backbone_name':'resnet50',
        'model_path':'../input/brainbaselinemodels/v8/v8/flair_Fold0_resnet50_v8_data_task2_npy_norm_ValidLoss0.668_ValidAUC0.617_Ep06.pth',
        'mri_type':'flair',
    }
]

In [None]:
def visualize_3_planes_sitk(image):
    """Show the three central orthogonal slices of a SimpleITK image.

    The image is converted to a numpy array with ``sitk.GetArrayFromImage`` and
    handed to ``visualize_3_planes`` so both helpers share one plotting routine.

    Args:
        image: a SimpleITK image holding a 3-D volume.
    """
    visualize_3_planes(sitk.GetArrayFromImage(image))
    
def visualize_3_planes(voxels):
    """Plot the middle slice along each of the three axes of a volume, side by side.

    Args:
        voxels: 3-D array-like supporting ``.shape`` and numpy-style slicing.
    """
    plt.figure(figsize=(9, 3))
    mid0, mid1, mid2 = (extent // 2 for extent in voxels.shape[:3])
    central_slices = (
        voxels[mid0],
        voxels[:, mid1, :],
        voxels[:, :, mid2],
    )
    # One panel per axis, left to right.
    for panel, plane_img in enumerate(central_slices, start=1):
        plt.subplot(1, 3, panel)
        plt.imshow(plane_img)

# 2. Read Voxels

In [None]:
def get_image_plane(data):
    """Classify a DICOM slice as 'Coronal', 'Axial' or 'Sagittal' from its orientation.

    Components 0, 1, 3 and 4 of ``data.ImageOrientationPatient`` are rounded to
    the nearest integer and matched against three known patterns.

    Args:
        data: object exposing a six-element ``ImageOrientationPatient``.

    Returns:
        The plane name, or 'Unknown' when the rounded pattern is not recognised.
    """
    a_x, a_y, _, b_x, b_y, _ = (round(component) for component in data.ImageOrientationPatient)
    plane_by_pattern = {
        (1, 0, 0, 0): 'Coronal',
        (1, 0, 0, 1): 'Axial',
        (0, 1, 0, 0): 'Sagittal',
    }
    return plane_by_pattern.get((a_x, a_y, b_x, b_y), 'Unknown')
    
def get_voxel(study_id, scan_type):
    """Load one DICOM series into a 3-D array and bring it into a common orientation.

    Files are read from ``DICOM_IM_FOLDER/<study_id>/<scan_type>/*.dcm`` and ordered
    by the integer after the last '-' in the file name. Depending on the acquisition
    plane, the slice stack is optionally reversed (based on the first and last
    ``ImagePositionPatient``) and then transposed / rotated.

    Args:
        study_id: patient folder name, e.g. '00001'.
        scan_type: series folder name, e.g. 'T1w'.

    Returns:
        (voxel, plane): the reoriented numpy volume and the detected plane name.

    Raises:
        ValueError: if the series folder contains no .dcm files, or the plane of
            the series cannot be determined.
    """
    dcm_dir = os.path.join(DICOM_IM_FOLDER, study_id, scan_type, '*.dcm')
    dcm_paths = sorted(glob.glob(dcm_dir), key=lambda x: int(x.replace('.dcm','').split("-")[-1]))
    # Fail with a clear message instead of hitting an unbound `img` further down.
    if not dcm_paths:
        raise ValueError(f"No DICOM files found for {study_id} {scan_type} ({dcm_dir})")

    imgs = []
    positions = []
    for dcm_path in dcm_paths:
        img = pydicom.dcmread(str(dcm_path))
        imgs.append(img.pixel_array)
        positions.append(img.ImagePositionPatient)

    # The plane is taken from the last slice read; all slices of a series are assumed to share it.
    plane = get_image_plane(img)
    voxel = np.stack(imgs)

    # reorder planes if needed and rotate voxel
    if plane == "Coronal":
        if positions[0][1] < positions[-1][1]:
            voxel = voxel[::-1]
            print(f"{study_id} {scan_type} {plane} reordered")
        voxel = voxel.transpose((1, 0, 2))
    elif plane == "Sagittal":
        if positions[0][0] < positions[-1][0]:
            voxel = voxel[::-1]
            print(f"{study_id} {scan_type} {plane} reordered")
        voxel = voxel.transpose((1, 2, 0))
        voxel = np.rot90(voxel, 2, axes=(1, 2))
    elif plane == "Axial":
        if positions[0][2] > positions[-1][2]:
            voxel = voxel[::-1]
            print(f"{study_id} {scan_type} {plane} reordered")
        # NOTE(review): np.rot90 defaults to axes=(0, 1), so this also reverses the slice
        # axis, unlike the sagittal branch which rotates axes (1, 2). Left unchanged because
        # the checkpoints were presumably trained on volumes produced this way; confirm.
        voxel = np.rot90(voxel, 2)
    else:
        raise ValueError(f"Unknown plane {plane}")
    return voxel, plane

def normalize_contrast(voxel):
    """Min-max rescale a volume to the 0..255 range and cast it to uint8.

    Args:
        voxel: numpy array of intensities.

    Returns:
        A uint8 array of the same shape. A volume whose sum is zero is returned
        unchanged (same object, original dtype). A constant non-zero volume has
        no contrast to stretch and yields all zeros.
    """
    if voxel.sum() == 0:
        return voxel
    voxel = voxel - np.min(voxel)
    peak = np.max(voxel)
    # Constant volume: dividing by a zero range would produce NaNs before the uint8 cast.
    if peak == 0:
        return np.zeros(voxel.shape, dtype=np.uint8)
    voxel = voxel / peak
    voxel = (voxel * 255).astype(np.uint8)
    return voxel

def crop_voxel(voxel):
    """Drop every plane whose mean intensity is not positive, one axis at a time.

    Axes are processed in the order 2, 1, 0, and each mask is computed on the
    volume as already cropped along the previous axes.

    Args:
        voxel: 3-D numpy array.

    Returns:
        The cropped volume, or the input itself when its sum is zero.
    """
    if voxel.sum() == 0:
        return voxel
    for axis in (2, 1, 0):
        reduce_over = tuple(other for other in range(3) if other != axis)
        non_empty = voxel.mean(axis=reduce_over) > 0
        selector = [slice(None)] * 3
        selector[axis] = non_empty
        voxel = voxel[tuple(selector)]
    return voxel

def resize_voxel(voxel, sz=128):
    """Resample a volume to a ``sz`` x ``sz`` x ``sz`` uint8 cube.

    ``sz`` evenly spaced positions are taken along the longest axis (truncated to
    integer indices); the 2-D slice at each position is resized to ``sz`` x ``sz``
    with ``cv2.resize`` and written to the same axis of the output.

    Args:
        voxel: 3-D numpy array.
        sz: edge length of the output cube.

    Returns:
        A new uint8 array of shape (sz, sz, sz).
    """
    output = np.zeros((sz, sz, sz), dtype=np.uint8)

    longest_axis = int(np.argmax(voxel.shape))
    if longest_axis > 2:
        # Only the first three axes are handled; anything else leaves the cube empty.
        return output

    # Views with the longest axis moved to the front; writes to `dst` land in `output`.
    src = np.moveaxis(voxel, longest_axis, 0)
    dst = np.moveaxis(output, longest_axis, 0)
    sample_positions = np.linspace(0, voxel.shape[longest_axis] - 1, sz)
    for out_idx, position in enumerate(sample_positions):
        dst[out_idx] = cv2.resize(src[int(position)], (sz, sz))

    return output

In [None]:
# Smoke-test the preprocessing chain on a single series before converting everything.
voxel, plane = get_voxel('00001', 'T1w')
voxel = resize_voxel(crop_voxel(normalize_contrast(voxel)))


In [None]:
visualize_3_planes(voxel)

In [None]:
# Single SimpleITK writer instance, reused for every volume written by the conversion loop.
writer = sitk.ImageFileWriter()

In [None]:
patient_ids = []
image_names = []
mri_types = []
metas = []

# Series folder name -> short modality tag used in the output file name.
mri_type_mapping = {
    'T1w':'t1',
    'T1wCE':'t1ce',
    'T2w':'t2',
    'FLAIR':'flair'
}

# List the test folder once and reuse the result.
all_patient_ids = os.listdir(DICOM_IM_FOLDER)

# Quick-commit shortcut: when the folder holds exactly 87 patients, convert only three
# hard-coded ones (presumably 87 is the size of the public test set; confirm).
if FAST_COMMIT and len(all_patient_ids) == 87:
    iterations = tqdm(['00001','00013', '00015'])
else:
    iterations = tqdm(all_patient_ids)

for patient_id in iterations:
    for mri_type in MRI_TYPES:
        try:
            voxel, plane = get_voxel(patient_id, mri_type)
            voxel = normalize_contrast(voxel)
            voxel = crop_voxel(voxel)
            voxel = resize_voxel(voxel)
        except Exception as ex:
            # Best effort: every patient/modality must still yield a file, so a series
            # that cannot be read is replaced by an empty volume with the same shape
            # and dtype as a successfully converted one.
            print(ex)
            print('patient id:', patient_id)
            voxel = np.zeros(DIM[:3], dtype=np.uint8)

        sitk_voxel = sitk.GetImageFromArray(voxel)

        outputImageFileName = os.path.join(IM_FOLDER, f'BraTS2021_{patient_id}',
                                               f'BraTS2021_{patient_id}_{mri_type_mapping[mri_type]}.nii.gz')

        os.makedirs(os.path.dirname(outputImageFileName), exist_ok=True)
        writer.SetFileName(outputImageFileName)
        writer.Execute(sitk_voxel)

In [None]:
# `sitk_voxel` is whatever the conversion loop assigned on its final iteration,
# so this previews the last volume that was written.
visualize_3_planes_sitk(sitk_voxel)

# 3. Modeling

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
from functools import partial

# Names of the ResNet class and its factory functions defined in this cell.
# `__all__` only matters for `from <module> import *`; inside the notebook it has no effect.
__all__ = [
    'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
    'resnet152', 'resnet200'
]


def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
    """Return a bias-free 3x3x3 ``nn.Conv3d``.

    Padding is set equal to ``dilation``, so with stride 1 the spatial size of
    the input is preserved.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        dilation: kernel dilation (also used as padding).
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )
    return nn.Conv3d(in_planes, out_planes, **conv_options)


def downsample_basic_block(x, planes, stride, no_cuda=False):
    """Parameter-free shortcut: strided sub-sampling followed by zero-padding of channels.

    Args:
        x: 5-D input tensor (batch, channels, and three spatial dimensions).
        planes: number of channels the output must have; must be >= the channels of ``x``.
        stride: stride of the size-1 average pooling used to sub-sample.
        no_cuda: kept for interface compatibility. The padding is now always created
            on the device of the pooled tensor, so this flag no longer has any effect.

    Returns:
        Tensor with ``planes`` channels: the pooled input followed by zero channels,
        on the same device and with the same dtype as the pooled input.
    """
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    # Allocate the padding with the dtype/device of `out`, so half, double and GPU
    # inputs concatenate cleanly instead of mixing with a CPU float32 tensor.
    zero_pads = out.new_zeros(
        (out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4)))

    # The result is detached from the autograd graph, as in the previous implementation
    # (which concatenated `out.data`).
    return torch.cat([out.detach(), zero_pads], dim=1)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3x3 convolutions, each followed by batch norm.

    The block keeps the channel count (``expansion = 1``). ``downsample``, when
    given, is applied to the input to produce the shortcut branch.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super().__init__()
        # Only the first convolution carries the stride.
        self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes, dilation=dilation)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1x1 conv, 3x3x3 conv, 1x1x1 conv, each with batch norm.

    The last convolution widens the channels to ``planes * 4``. Stride and dilation
    are applied by the middle 3x3x3 convolution, whose padding equals the dilation.
    ``downsample``, when given, is applied to the input to produce the shortcut branch.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super().__init__()
        widened = planes * 4
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False,
        )
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

class GAP3D(nn.Module):
    """Global average pooling over the three spatial axes, flattened to (-1, feat_dim)."""

    def __init__(self, feat_dim):
        super().__init__()
        self.feat_dim = feat_dim

    def forward(self, x):
        """Pool ``x`` down to a 1x1x1 spatial size and reshape to ``(-1, feat_dim)``."""
        pooled = F.adaptive_avg_pool3d(x, output_size=1)
        return pooled.view(-1, self.feat_dim)

class ResNet(nn.Module):
    """3-D ResNet backbone for single-channel volumes with a linear classification head.

    The stem (7x7x7 stride-2 convolution and stride-2 max pooling) is followed by four
    residual stages. Stage 2 uses stride 2; stages 3 and 4 keep stride 1 and use
    dilation 2 and 4 instead. Features are globally average-pooled and passed through
    one linear layer.

    Args:
        block: residual block class (must expose an ``expansion`` attribute).
        layers: number of blocks in each of the four stages.
        sample_input_D, sample_input_H, sample_input_W: accepted for interface
            compatibility; not used by this implementation.
        num_classes: output size of the classification head.
        num_seg_classes: when > 0, a segmentation head ``conv_seg`` is also built.
            ``forward`` does not use it; it only returns the classification logits.
        shortcut_type: 'A' selects the parameter-free ``downsample_basic_block``
            shortcut; any other value uses a 1x1x1 convolution plus batch norm.
        no_cuda: forwarded to ``downsample_basic_block`` for type-'A' shortcuts.
    """

    def __init__(self,
                 block,
                 layers,
                 sample_input_D,
                 sample_input_H,
                 sample_input_W,
                 num_classes,
                 num_seg_classes,
                 shortcut_type='B',
                 no_cuda = False):
        # Initialise nn.Module before assigning any attribute.
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.no_cuda = no_cuda
        self.num_seg_classes = num_seg_classes
        self.num_classes = num_classes

        self.conv1 = nn.Conv3d(
            1,
            64,
            kernel_size=7,
            stride=(2, 2, 2),
            padding=(3, 3, 3),
            bias=False)

        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(
            block, 128, layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(
            block, 256, layers[2], shortcut_type, stride=1, dilation=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], shortcut_type, stride=1, dilation=4)

        # classification head
        self.feat_dim = 512 * block.expansion
        self.clf_head = nn.Sequential(
            GAP3D(self.feat_dim),
            nn.Linear(self.feat_dim, self.num_classes)
        )

        if(num_seg_classes > 0):
            self.conv_seg = nn.Sequential(
                nn.ConvTranspose3d(
                    512 * block.expansion,
                    32,
                    2,
                    stride=2
                ),
                nn.BatchNorm3d(32),
                nn.ReLU(inplace=True),
                nn.Conv3d(
                    32,
                    32,
                    kernel_size=3,
                    stride=(1, 1, 1),
                    padding=(1, 1, 1),
                    bias=False),
                nn.BatchNorm3d(32),
                nn.ReLU(inplace=True),
                nn.Conv3d(
                    32,
                    num_seg_classes,
                    kernel_size=1,
                    stride=(1, 1, 1),
                    bias=False)
            )

            # NOTE(review): this initialisation only runs when the segmentation head is
            # built; with num_seg_classes == 0 every layer keeps PyTorch's default init.
            # That may be an indentation slip; confirm before moving it out of the `if`.
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    # In-place initialiser; replaces the deprecated nn.init.kaiming_normal.
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                elif isinstance(m, nn.BatchNorm3d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
        """Build one residual stage of ``blocks`` blocks and update ``self.inplanes``.

        A shortcut projection is created for the first block whenever the stride is
        not 1 or the incoming channel count differs from ``planes * block.expansion``.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride,
                    no_cuda=self.no_cuda)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False),
                    nn.BatchNorm3d(planes * block.expansion))

        layers = []
        # Only the first block of a stage carries the stride and the shortcut projection.
        layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return classification logits of shape (batch, num_classes)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        logits = self.clf_head(x)

        # The segmentation head (conv_seg) is intentionally not evaluated here.
        return logits

def resnet10(**kwargs):
    """Build a ResNet-10: ``BasicBlock`` with one block in each of the four stages.

    All keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    return ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)


def resnet18(**kwargs):
    """Build a ResNet-18: ``BasicBlock`` with stage depths [2, 2, 2, 2].

    All keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet34(**kwargs):
    """Build a ResNet-34: ``BasicBlock`` with stage depths [3, 4, 6, 3].

    All keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50(**kwargs):
    """Build a ResNet-50: ``Bottleneck`` with stage depths [3, 4, 6, 3].

    All keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101(**kwargs):
    """Build a ResNet-101: ``Bottleneck`` with stage depths [3, 4, 23, 3].

    All keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


def resnet152(**kwargs):
    """Build a ResNet-152: ``Bottleneck`` with stage depths [3, 8, 36, 3].

    All keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)


def resnet200(**kwargs):
    """Build a ResNet-200: ``Bottleneck`` with stage depths [3, 24, 36, 3].

    All keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    return ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)

import torch
from torch import nn

def get_medicalnet_resnet_model(model_name, inp_w, inp_h, inp_d, short_cut_type='B', num_classes=1, num_seg_classes=1, backbone_pretrained=None):
    """Instantiate one of the ResNet factories by name and optionally load backbone weights.

    Args:
        model_name: name of a factory function looked up in this module's globals
            (e.g. 'resnet50').
        inp_w, inp_h, inp_d: passed on as ``sample_input_W`` / ``_H`` / ``_D``.
        short_cut_type: passed on as ``shortcut_type``.
        num_classes: size of the classification output.
        num_seg_classes: number of segmentation classes passed to the model.
        backbone_pretrained: optional path to a checkpoint with a 'state_dict' entry.
            Only tensors whose key (with 'module.' removed) exists in the model are loaded.

    Returns:
        The constructed model.
    """
    build_model = globals()[model_name]
    model = build_model(
        sample_input_W=inp_w,
        sample_input_H=inp_h,
        sample_input_D=inp_d,
        shortcut_type=short_cut_type,
        no_cuda=False,
        num_classes=num_classes,
        num_seg_classes=num_seg_classes,
    )

    if backbone_pretrained is None:
        return model

    print('Load pretrained:', backbone_pretrained)
    merged_state = model.state_dict()
    checkpoint = torch.load(backbone_pretrained, map_location='cpu')
    # Overwrite only the entries the model actually has; everything else is ignored.
    for raw_key, tensor in checkpoint['state_dict'].items():
        key = raw_key.replace('module.', '')
        if key in merged_state:
            merged_state[key] = tensor
    model.load_state_dict(merged_state)

    return model

def get_model(candidate):
    """Build the network described by one ensemble candidate.

    Args:
        candidate: dict with a 'backbone_name' entry; optionally 'dim' (defaults to
            ``DIM``) and 'backbone_pretrained' (checkpoint path for the resnet backbone).

    Returns:
        A model with ``NUM_CLASSES`` outputs.

    Raises:
        ValueError: if the backbone name contains neither 'resnet' nor 'efficientnet'.
    """
    backbone_name = candidate['backbone_name']
    if 'resnet' in backbone_name:
        dim = candidate.get('dim', DIM)
        model = get_medicalnet_resnet_model(backbone_name, dim[1], dim[0], dim[2], num_classes=NUM_CLASSES,
                                            num_seg_classes=NUM_SEG_CLASSES,
                                            backbone_pretrained=candidate.get('backbone_pretrained'))
    elif 'efficientnet' in backbone_name:
        # Imported here because the notebook only does `from monai.x import ...`, so the
        # bare name `monai` is never bound and the old attribute lookup raised NameError.
        from monai.networks.nets.efficientnet import EfficientNetBN
        model = EfficientNetBN(model_name=backbone_name, spatial_dims=3, in_channels=1,
                               pretrained=False, num_classes=NUM_CLASSES)
    else:
        raise ValueError('No such backbone name: '+ backbone_name)
    return model

def predict_fn(dataloader,model,scaler, device='cuda:0'):
    """Run the model over a dataloader and return sigmoid probabilities.

    Args:
        dataloader: iterable of input batches (tensors only, no labels).
        model: network returning one logit per sample.
        scaler: accepted for interface compatibility; not used.
        device: device the batches are moved to before the forward pass.

    Returns:
        1-D numpy array with one probability per sample, in dataloader order.
    """
    model.eval()

    progress = tqdm(dataloader, total=len(dataloader))
    batch_probs = []
    for batch in progress:
        voxels = batch.to(device)

        with torch.cuda.amp.autocast(), torch.no_grad():
            logits = model(voxels).view(-1)
            # NaN logits are replaced by 0, i.e. a probability of 0.5.
            logits = logits.masked_fill(torch.isnan(logits), 0)
            probs = logits.sigmoid()

        batch_probs.append(probs.detach().cpu().numpy())

        # Release the per-batch tensors before the next iteration.
        del batch, voxels, logits, probs
        torch.cuda.empty_cache()

    return np.concatenate(batch_probs)

In [None]:
# One row per converted patient folder, plus one column per modality with the NIfTI path.
patient_folders = os.listdir(IM_FOLDER)
test_df = pd.DataFrame({'pfolder': patient_folders})
# The numeric patient id is the part after the last underscore of the folder name.
test_df['BraTS21ID'] = [int(folder.rsplit('_', 1)[-1]) for folder in test_df['pfolder']]

for t in SHORT_MRI_TYPES:
    test_df[f'{t}_data_path'] = [
        os.path.join(IM_FOLDER, folder, folder + f'_{t}.nii.gz')
        for folder in test_df['pfolder']
    ]

In [None]:
test_df.head()

In [None]:
# Inference-time transforms, shared by every ensemble member.
test_transforms = Compose([AddChannel(), ScaleIntensity()])
# Preview dataset on the first short modality tag only; the ensemble loop builds its own per candidate.
mri_type = SHORT_MRI_TYPES[0]

test_dataset = ImageDataset(image_files=test_df[f'{mri_type}_data_path'].tolist(),
                            transform=test_transforms)

In [None]:
# Visual sanity check on the first sample of the preview dataset.
# voxels[0] presumably drops the channel axis added by AddChannel; confirm.
voxels  = test_dataset[0]
visualize_3_planes(voxels[0])

In [None]:
# Average the predicted probabilities of all candidates (one model per MRI modality).
test_ensembled_prediction = 0

for candidate in CANDIDATES:

    # create data loader for the modality this candidate is evaluated on
    mri_type = candidate.get('mri_type')
    test_dataset = ImageDataset(image_files=test_df[f'{mri_type}_data_path'].tolist(),
                            transform=test_transforms)

    batch_size = candidate.get('batch_size', BATCH_SIZE)
    # shuffle=False keeps predictions aligned with the row order of test_df.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                    num_workers=4, pin_memory=torch.cuda.is_available())

    # Model
    model = get_model(candidate)
    print('Load trained model:', candidate['model_path'] )
    model.load_state_dict(torch.load(candidate['model_path'], map_location='cpu'))
    model = model.to(DEVICE)
    print()

    # predict_fn never uses its `scaler` argument, so no GradScaler is created for inference.
    test_ensembled_prediction += predict_fn(test_loader, model, None, device=DEVICE)

test_ensembled_prediction /= len(CANDIDATES)

In [None]:
# Attach the ensembled probabilities; they are in the same row order as test_df because
# the data loaders were built from its path columns without shuffling.
test_df['MGMT_value'] = test_ensembled_prediction

In [None]:
# Sample submission: supplies the BraTS21ID values (and their order) used for the final merge.
sub = pd.read_csv('../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv')

In [None]:
# Right-merge so the result has exactly the ids of the sample submission.
test_df = test_df.merge(sub[['BraTS21ID']], on='BraTS21ID', how='right')
# Ids without a prediction (e.g. patients skipped by the FAST_COMMIT shortcut) would be
# NaN after the merge; give them a neutral probability so the submission has no missing values.
test_df['MGMT_value'] = test_df['MGMT_value'].fillna(0.5)

In [None]:
test_df

In [None]:
test_df[['BraTS21ID', 'MGMT_value']].head()

In [None]:
# Write only the id and prediction columns, without the DataFrame index.
test_df[['BraTS21ID', 'MGMT_value']].to_csv('submission.csv', index=False)

In [None]:
ls

In [None]:
# clear working dir: delete the converted NIfTI folder.
# The folder name is the same literal as IM_FOLDER in the config cell; keep the two in sync.
!rm -rf BraTS2021_Testing_Data