In [None]:
import os
import sys 
import json
import glob
import random
import re
import collections
import time

import numpy as np
import pandas as pd
import pydicom
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

In [None]:
# Install MONAI 0.7.0 from a wheel shipped in the "monai-private" Kaggle input dataset.
!pip install /kaggle/input/monai-private/monai-0.7.0-202109240007-py3-none-any.whl

In [None]:
from monai.networks.nets.resnet import ResNet, resnet34, resnet50, resnet101, resnet152, resnet200

In [None]:
# Competition data: <split>/<scan_id>/<mri_type>/*.dcm plus the label / sample-submission CSVs.
data_directory = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification'
# Read-only MONAI copy shipped as a Kaggle input dataset (presumably v0.6.0, per its name).
input_monaipath = "/kaggle/input/monai-v060-deep-learning-in-healthcare-imaging/"
# Destination that copy is cp'd into and then appended to sys.path.
monaipath = "/kaggle/tmp/monai/"

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _fold_checkpoint_paths(mri_type, n_folds=5):
    """Return the per-fold checkpoint paths for one MRI sequence, fold 0 first.

    The files live in the "resnet50-weights" input dataset even though their
    names (and ``build_model``) refer to ResNet-34.
    """
    return [
        f"../input/resnet50-weights/resnet34_fold{fold}_best_loss_{mri_type}_mri_type.pth"
        for fold in range(n_folds)
    ]


# One checkpoint per cross-validation fold for each MRI sequence.
paths_flair = _fold_checkpoint_paths("FLAIR")
paths_t1w = _fold_checkpoint_paths("T1w")
paths_t1wCE = _fold_checkpoint_paths("T1wCE")
paths_t2w = _fold_checkpoint_paths("T2w")

In [None]:
# Copy the bundled MONAI files into monaipath (appended to sys.path in a later cell).
# NOTE(review): MONAI is already pip-installed and imported in earlier cells, so
# this copy is probably redundant — confirm before relying on it.
!mkdir -p {monaipath}
!cp -r {input_monaipath}/* {monaipath}

In [None]:
# MRI sequence(s) used by the single-scan smoke test further down.
mri_types = ['FLAIR']
SIZE = 256  # height/width every DICOM slice is resized to
NUM_IMAGES = 64  # slices sampled per scan, i.e. depth of the 3-D volume
# NOTE(review): BATCH_SIZE, N_EPOCHS, LEARNING_RATE and LR_DECAY look like training
# leftovers; none of them is read by the inference cells in this notebook
# (predict() does not use BATCH_SIZE).
BATCH_SIZE = 4
N_EPOCHS = 16
SEED = 12345  # used by set_seed() and the train/validation split
LEARNING_RATE = 0.0005
LR_DECAY = 0.9

# NOTE(review): monai was already imported in an earlier cell, and append() puts
# monaipath *last* on sys.path, so this copy is most likely not the one in use — confirm.
sys.path.append(monaipath)

from monai.networks.nets.resnet import resnet34

## Functions to load DICOM images

In [None]:
def load_dicom_image(path, img_size=SIZE):
    """Read one DICOM slice and resize it to ``img_size`` x ``img_size``.

    Args:
        path: Path to a single ``.dcm`` file.
        img_size: Output height and width in pixels.

    Returns:
        2-D NumPy array of shape ``(img_size, img_size)``. A constant slice
        (min == max) is replaced by an all-zero array instead of being resized.
    """
    # dcmread is pydicom's supported reader; read_file was a deprecated alias
    # and no longer exists in pydicom 3.0.
    dicom = pydicom.dcmread(path)
    data = dicom.pixel_array
    if np.min(data) == np.max(data):
        # Blank slice: skip the resize and return zeros of the target size.
        return np.zeros((img_size, img_size))
    return cv2.resize(data, (img_size, img_size))


def natural_sort(l):
    """Return the items of ``l`` sorted in human ("natural") order.

    Runs of ASCII digits compare numerically and all other text compares
    case-insensitively, so ``"Image-2.dcm"`` sorts before ``"Image-10.dcm"``.
    """
    def _natural_key(name):
        # The capturing group keeps the digit runs in the split result.
        pieces = re.split('([0-9]+)', name)
        return [int(piece) if piece.isdigit() else piece.lower() for piece in pieces]

    return sorted(l, key=_natural_key)


def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=SIZE, mri_type="FLAIR", split="train"):
    """Load ``num_imgs`` evenly spaced slices of one scan as a normalised 3-D volume.

    Args:
        scan_id: Zero-padded scan directory name, e.g. ``"00000"``.
        num_imgs: Number of slices to sample (slices repeat if the scan has fewer).
        img_size: Height/width each slice is resized to.
        mri_type: MRI sequence sub-directory, e.g. ``"FLAIR"``.
        split: Dataset split directory, e.g. ``"train"`` or ``"test"``.

    Returns:
        Array of shape ``(1, img_size, img_size, num_imgs)`` min-max scaled to
        [0, 1] (left at zero when the whole volume is constant).

    Raises:
        FileNotFoundError: If no ``.dcm`` files exist for the scan/sequence.
    """
    pattern = f"{data_directory}/{split}/{scan_id}/{mri_type}/*.dcm"
    files = natural_sort(glob.glob(pattern))
    if not files:
        # Without this guard the index arithmetic below produces -1 and the
        # lookup fails with an opaque IndexError.
        raise FileNotFoundError(f"No DICOM files match {pattern}")

    # Evenly spaced (rounded) positions across the series, clamped to the last slice.
    every_nth = len(files) / num_imgs
    indexes = [min(int(round(i * every_nth)), len(files) - 1) for i in range(num_imgs)]
    files_to_load = [files[i] for i in indexes]

    # Forward img_size so the argument is actually honoured (it was previously
    # dropped and every slice was resized to the global default).
    img3d = np.stack([load_dicom_image(f, img_size=img_size) for f in files_to_load]).T

    # Min-max normalise the whole volume to [0, 1].
    img3d = img3d - np.min(img3d)
    if np.max(img3d) != 0:
        img3d = img3d / np.max(img3d)

    return np.expand_dims(img3d, 0)


# Smoke test: load one training scan; the shape should be (1, SIZE, SIZE, NUM_IMAGES).
load_dicom_images_3d("00000", mri_type=mri_types[0]).shape

In [None]:
def set_seed(seed):
    """Seed every RNG this notebook touches so runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and torch (plus all CUDA
    devices when CUDA is available), records the seed in ``PYTHONHASHSEED``
    and, on CUDA machines, asks cuDNN for deterministic kernels.
    """
    # NOTE: Python reads PYTHONHASHSEED at interpreter start, so this only
    # affects processes launched afterwards, not the current kernel.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

# Seed all RNGs before anything stochastic runs (the split below also uses SEED).
set_seed(SEED)

## Train / validation split

In [None]:
# Scan ids dropped from the labelled set before splitting.
# NOTE(review): the reason is not stated here — presumably known-bad scans; confirm.
samples_to_exclude = [109, 123, 709]

train_df = pd.read_csv(f"{data_directory}/train_labels.csv")
print("original shape", train_df.shape)
train_df = train_df[~train_df.BraTS21ID.isin(samples_to_exclude)]
print("new shape", train_df.shape)
display(train_df)

# Stratified 80/20 split on the MGMT label. Only df_train is inspected below;
# neither split feeds the test-set predictions made later in this notebook.
df_train, df_valid = sk_model_selection.train_test_split(
    train_df, 
    test_size=0.2, 
    random_state=SEED, 
    stratify=train_df["MGMT_value"],
)


In [None]:
# Peek at the last rows of the training split.
df_train.tail()

## Dataset and model definitions

In [None]:
class Dataset(torch_data.Dataset):
    """Map-style dataset yielding one 3-D MRI volume per BraTS21 scan id.

    Args:
        paths: Sequence of scan ids; each is zero-padded to five characters to
            form the scan's directory name.
        targets: Optional sequence of labels aligned with ``paths``. When given,
            items carry a float ``"y"`` tensor; otherwise they carry ``"id"``.
        mri_type: Sequence aligned with ``paths`` giving the MRI sequence folder
            for each scan. It is indexed per item, so pass one entry per scan
            rather than a single string.
        split: Dataset split directory to read from (``"train"`` by default).
    """

    def __init__(self, paths, targets=None, mri_type=None, split="train"):
        self.paths = paths
        self.targets = targets
        self.mri_type = mri_type
        self.split = split

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        scan_id = self.paths[index]
        # Always read from self.split so the constructor argument is honoured
        # for labelled data too (it used to be forced to "train" whenever
        # targets were supplied; the default split keeps that behaviour).
        data = load_dicom_images_3d(
            str(scan_id).zfill(5),
            mri_type=self.mri_type[index],
            split=self.split,
        )
        if self.targets is None:
            return {"X": data, "id": scan_id}
        return {"X": data, "y": torch.tensor(self.targets[index], dtype=torch.float)}


In [None]:
def build_model():
    """Create the MONAI 3-D ResNet-34 shared by every MRI-type checkpoint.

    Single input channel, one output unit (a logit; predictions apply the
    sigmoid afterwards).
    """
    return resnet34(
        spatial_dims=3,
        n_input_channels=1,
        num_classes=1,
    )

## Prediction

In [None]:
def predict(path, df, mri_type, split, batch_size=4, num_workers=8):
    """Score every scan listed in ``df`` with one fold checkpoint.

    Args:
        path: Path to a ``state_dict`` compatible with ``build_model()``.
        df: Frame whose index holds the BraTS21 scan ids to score. It is not
            modified.
        mri_type: MRI sequence folder to read for every scan (e.g. ``"FLAIR"``).
        split: Dataset split directory the scans live in (e.g. ``"test"``).
        batch_size: Scans per forward pass.
        num_workers: DataLoader worker processes used to read the DICOM files.

    Returns:
        DataFrame indexed by ``BraTS21ID`` with a float ``MGMT_value`` column of
        sigmoid probabilities, one row per scan in ``df`` order (no shuffling).
    """
    print("Predict:", path, mri_type, df.shape)
    # Build the per-scan MRI-type list locally instead of writing an
    # "MRI_Type" column into the caller's frame.
    data_retriever = Dataset(
        df.index.values,
        mri_type=[mri_type] * len(df),
        split=split,
    )
    data_loader = torch_data.DataLoader(
        data_retriever,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    model = build_model()
    model.to(device)
    state_dict = torch.load(path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()

    y_pred = []
    ids = []
    with torch.no_grad():
        for e, batch in enumerate(data_loader, 1):
            print(f"{e}/{len(data_loader)}", end="\r")
            # Default collation already yields tensors; as_tensor avoids the
            # extra copy (and warning) torch.tensor() incurs on tensor input.
            inputs = torch.as_tensor(batch["X"]).float().to(device)
            logits = model(inputs)
            # reshape(-1) keeps a 1-D array even for a final batch of one scan,
            # so every prediction is appended as a plain Python float.
            probs = torch.sigmoid(logits).reshape(-1).cpu().numpy()
            y_pred.extend(probs.tolist())
            ids.extend(torch.as_tensor(batch["id"]).tolist())

    preddf = pd.DataFrame({"BraTS21ID": ids, "MGMT_value": y_pred})
    preddf = preddf.set_index("BraTS21ID")
    return preddf

## Submission

In [None]:
# Average the FLAIR probabilities of every fold checkpoint in paths_flair.
# NOTE(review): this cell is repeated for T1w / T1wCE / T2w; a helper taking
# (paths, mri_type) would remove the copy-paste.
submission_flair = pd.read_csv(f"{data_directory}/sample_submission.csv", index_col="BraTS21ID")

submission_flair["MGMT_value"] = 0
for m in paths_flair:
    print(m)
    pred = predict(m, submission_flair, "FLAIR", split="test")
    # += aligns on the BraTS21ID index, not on row position.
    submission_flair["MGMT_value"] += pred["MGMT_value"]

submission_flair["MGMT_value"] /= len(paths_flair)
submission_flair.head()


In [None]:
# Average the T1w probabilities of every fold checkpoint in paths_t1w.
submission_t1w = pd.read_csv(f"{data_directory}/sample_submission.csv", index_col="BraTS21ID")

submission_t1w["MGMT_value"] = 0
for m in paths_t1w:
    print(m)
    pred = predict(m, submission_t1w, "T1w", split="test")
    # += aligns on the BraTS21ID index, not on row position.
    submission_t1w["MGMT_value"] += pred["MGMT_value"]

submission_t1w["MGMT_value"] /= len(paths_t1w)
submission_t1w.head()

In [None]:
# Average the T1wCE probabilities of every fold checkpoint in paths_t1wCE.
submission_t1wCE = pd.read_csv(f"{data_directory}/sample_submission.csv", index_col="BraTS21ID")

submission_t1wCE["MGMT_value"] = 0
for m in paths_t1wCE:
    print(m)
    pred = predict(m, submission_t1wCE, "T1wCE", split="test")
    # += aligns on the BraTS21ID index, not on row position.
    submission_t1wCE["MGMT_value"] += pred["MGMT_value"]

submission_t1wCE["MGMT_value"] /= len(paths_t1wCE)
submission_t1wCE.head()

In [None]:
# Average the T2w probabilities of every fold checkpoint in paths_t2w.
submission_t2w = pd.read_csv(f"{data_directory}/sample_submission.csv", index_col="BraTS21ID")

submission_t2w["MGMT_value"] = 0
for m in paths_t2w:
    print(m)
    pred = predict(m, submission_t2w, "T2w", split="test")
    # += aligns on the BraTS21ID index, not on row position.
    submission_t2w["MGMT_value"] += pred["MGMT_value"]

submission_t2w["MGMT_value"] /= len(paths_t2w)
submission_t2w.head()

In [None]:
# Fresh copy of the sample submission; its MGMT_value column is replaced by the blended prediction.
submission = pd.read_csv(f"{data_directory}/sample_submission.csv", index_col="BraTS21ID")

In [None]:
# Final prediction: unweighted mean of the four per-sequence fold ensembles
# (FLAIR, T1w, T1wCE, T2w), aligned on the BraTS21ID index.
submission['MGMT_value'] = (
    submission_flair['MGMT_value']
    + submission_t1w["MGMT_value"]
    + submission_t1wCE["MGMT_value"]
    + submission_t2w["MGMT_value"]
) / 4.0
submission["MGMT_value"].to_csv("submission.csv")