# RSNA MICCAI Brain Tumor classifier -- Ordered MRI voxel data

## Acknowledgements

Voxel Data Ordering
* https://www.kaggle.com/davidbroberts/determining-dicom-image-order

MobileNetV3 for 2D images
* https://arxiv.org/abs/1905.02244
* https://github.com/pytorch/vision/blob/v0.9.0/torchvision/models/mobilenetv3.py

Troubleshooting
* https://fullstackdeeplearning.com/spring2021/lecture-7/

Other work on 3D MRI voxel data
* http://www.ajnr.org/content/ajnr/42/5/845.full.pdf

In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import pydicom
import cv2 as cv

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, Subset
# from torchvision import models
import torchvision
import kornia as K  # batch image augmentations with torch.Tensor
from kornia.augmentation import AugmentationSequential
from kornia.augmentation.base import AugmentationBase3D  # Subclassing this is too complicated.
from kornia.enhance import invert

from tqdm.notebook import tqdm

from pathlib import Path
from typing import Union, Tuple, List, Optional, Type, Dict, Iterable
import time

# Run-mode switches for this notebook.
DEBUG = False  # when True, the `if DEBUG:` cells below run quick smoke tests
REPRODUCTIVE = True  # when True, seed the NumPy and torch RNGs with `random_state`
INFERENCE_ONLY = True  # when True, the cross-validation training loop further below is skipped
USE_CROSS_VALIDATION = True  # the training loop below runs only if this is True and INFERENCE_ONLY is False

random_state = 42
model_name = "Net-3D"
data_dir = Path("../input/rsna-miccai-brain-tumor-radiogenomic-classification")
models_dir = Path("../input/model-weights-for-rsna-miccai-brain-tumor-dataset")
# models_dir = Path(".")  # use this instead when training on a local machine

device = "cuda" if torch.cuda.is_available() else "cpu"

time_begin = time.time()  # wall-clock start of the run

if REPRODUCTIVE:
    np.random.seed(random_state)
    torch.random.manual_seed(random_state)
# `display` is not imported: it is the IPython/Jupyter builtin, so this cell needs a notebook.
display(list(data_dir.iterdir()), torch.__version__, torchvision.__version__)

## Data manipulation

In [None]:
# Integer code -> name, and the reverse lookups, for MRI series and image planes.
mri_series = dict(enumerate(["FLAIR", "T1w", "T1wCE", "T2w"]))
mri_series_map = {name: code for code, name in mri_series.items()}
planes = dict(enumerate(["Unknown", "Coronal", "Sagittal", "Axial"]))
planes_map = {name: code for code, name in planes.items()}

In [None]:
# Training labels. BraTS21ID is read as str (not int) so any leading zeros survive;
# the rest of the notebook compares it against 5-digit zero-padded IDs.
labels_train = pd.read_csv(data_dir / "train_labels.csv", dtype={"BraTS21ID": str})
labels_train

In [None]:
def look_one_dcm(instance_id: str, img_dir: Path, mri_series="FLAIR", verbose=False):
    """
    Print how many DICOM files one patient/series has and show one slice of it.

    :params:
        instance_id: patient ID (BraTS21ID); zero-padded to 5 digits to match the folder name.
        img_dir: split directory laid out as ``<img_dir>/<id>/<series>/*.dcm``.
        mri_series: series sub-folder name, e.g. "FLAIR".
        verbose: additionally print the DICOM header and a few selected elements.
    """
    # Sorted so the chosen "middle" file is deterministic; Path.glob order is filesystem dependent.
    dcm_paths = sorted(img_dir.glob("./{}/{}/*.dcm".format(instance_id.zfill(5), mri_series)))
    print("Containing {} dicom files(including blank).".format(len(dcm_paths)))
    if not dcm_paths:
        return
    dcm_mid = dcm_paths[(len(dcm_paths) - 1) // 2]
    # dcmread replaces the deprecated read_file (removed in pydicom 3).
    dcm_ds = pydicom.dcmread(str(dcm_mid))
    if verbose:
        print(dir(dcm_ds))
        print(dcm_ds)
        print(type(dcm_ds[("0010", "0010")].value))
        # (0020,0032) is multi-valued: convert its values explicitly instead of
        # eval()-ing their string representation.
        image_position = dcm_ds[("0020", "0032")]
        print(image_position.name, [float(v) for v in image_position.value])
        print(dir(image_position))
        print(dcm_ds.pixel_array.dtype)
    plt.imshow(dcm_ds.pixel_array, cmap=plt.cm.gray)
    plt.show()


# Inspect patient 00000's FLAIR series (the default) with the verbose header dump.
look_one_dcm("00000", data_dir / "train", verbose=True)

In [None]:
def get_image_plane(loc):
    """
    Name the image plane of a slice from its six orientation values (DICOM 0020,0037).

    Only the rounded x/y components of the row and column direction vectors are
    inspected; combinations other than the three listed map to ``planes[0]``.
    """
    row_x, row_y, _, col_x, col_y, _ = (round(value) for value in loc)
    known_planes = {
        (1, 0, 0, 0): planes[1],
        (0, 1, 0, 0): planes[2],
        (1, 0, 0, 1): planes[3],
    }
    return known_planes.get((row_x, row_y, col_x, col_y), planes[0])


class DICOMMetaLoader(Dataset):
    """
    Dataset that reads one DICOM file per item and returns its metadata as a dict.

    Meant to be wrapped in a DataLoader so header parsing and per-image
    statistics run in worker processes.
    """

    def __init__(self, img_dir: Path, glob=None):
        """
        :params:
            img_dir: directory searched for DICOM files.
            glob: glob pattern relative to ``img_dir``; defaults to ``"./*/*/*.dcm"``.
        """
        super(DICOMMetaLoader, self).__init__()
        if glob is None:
            glob = "./*/*/*.dcm"
        # Sorted for a reproducible row order; Path.glob order is filesystem dependent.
        self.dcm_paths = sorted(img_dir.glob(glob))

    def __len__(self): return len(self.dcm_paths)

    def __getitem__(self, idx) -> dict:
        dcm_path = str(self.dcm_paths[idx])
        dcm_obj = pydicom.dcmread(dcm_path)
        # Compare the element's *value*: str() of the DataElement itself is the whole
        # "(0028, 0004) Photometric Interpretation ..." line and never equals
        # "MONOCHROME1", which made the inversion below unreachable.
        photometric = str(dcm_obj[0x28, 0x04].value).strip()
        array = dcm_obj.pixel_array
        if photometric == "MONOCHROME1":
            # MONOCHROME1 stores inverted intensities: flip them around the dtype maximum.
            info_func = np.iinfo if np.issubdtype(array.dtype, np.integer) else np.finfo
            array = info_func(array.dtype).max - array
        image_mean, image_std = np.mean(array), np.std(array)

        impo_x, impo_y, impo_z = [float(v) for v in dcm_obj[0x20, 0x32]]  # Image Position (Patient)
        plane = get_image_plane(dcm_obj[0x20, 0x37])  # from Image Orientation (Patient)

        patient_id = str(dcm_obj[0x0010, 0x0020].value).strip().zfill(5)  # Patient ID
        series_desc = str(dcm_obj[0x0008, 0x103e].value).strip()  # Series Description
        row = dict(dcm_path=dcm_path, BraTS21ID=patient_id, series_description=series_desc,
                   image_mean=image_mean, image_std=image_std,
                   plane=plane,
                   image_position_x=impo_x, image_position_y=impo_y, image_position_z=impo_z)
        return row


def get_meta_from_glob(img_dir: Path, glob=None, batch_size: int = 256, num_workers: int = 6) -> pd.DataFrame:
    """
    Collect per-file DICOM metadata under ``img_dir`` into one DataFrame.

    :params:
        img_dir: directory passed to ``DICOMMetaLoader``.
        glob: optional glob pattern forwarded to ``DICOMMetaLoader``.
        batch_size: files per DataLoader batch.
        num_workers: DataLoader worker processes that read the headers.
    :returns:
        pd.DataFrame with one row per matched file (an empty DataFrame if none matched).
    """
    dcm_ds = DICOMMetaLoader(img_dir, glob)
    dcm_dl = DataLoader(dcm_ds, batch_size=batch_size, num_workers=num_workers)
    # Gather the batches and concatenate once: concatenating inside the loop copies
    # the growing frame on every iteration (quadratic in the number of files).
    chunks = [pd.DataFrame.from_dict({k: np.asarray(v) for k, v in item.items()})
              for item in tqdm(dcm_dl)]
    if not chunks:
        return pd.DataFrame()
    return pd.concat(chunks, ignore_index=True)


# Per-file DICOM metadata for the whole training split.
df_train = get_meta_from_glob(data_dir / "train")

In [None]:
# Replace the plane and series-description names with integer codes via
# planes_map / mri_series_map (names missing from a map become NaN).

df_train.loc[:, "plane"] = df_train.loc[:, "plane"].map(planes_map)
df_train.loc[:, "series_description"] = df_train.loc[:, "series_description"].map(mri_series_map)

In [None]:
def keep_non_blank(df: pd.DataFrame):
    """
    Drop rows whose image is blank.

    A row survives only if its pixel statistics show content: both
    ``image_std`` and ``image_mean`` must be strictly positive.

    :params:
        df: pd.DataFrame with "image_std" and "image_mean" columns.
    :returns:
        pd.DataFrame: the surviving rows, original index preserved.
    """
    has_variation = df["image_std"].gt(0)
    has_signal = df["image_mean"].gt(0)
    return df[has_variation & has_signal]


display(len(df_train))  # row count before removing blank images
df_train = keep_non_blank(df_train)
display(len(df_train))  # row count after removing blank images

In [None]:
def drop_by_id(df: pd.DataFrame, ids: List[Union[int, str]]):
    """
    Remove every row whose ``BraTS21ID`` appears in ``ids``.

    IDs are first normalised to zero-padded 5-character strings, so ``109``,
    ``"109"`` and ``"00109"`` all match ``"00109"``.

    :returns: a new DataFrame with a fresh 0..n-1 index.
    """
    padded_ids = [str(patient).zfill(5) for patient in ids]
    excluded = df["BraTS21ID"].isin(padded_ids)
    return df[~excluded].reset_index(drop=True)


# Patients removed from both the metadata and the labels.
# NOTE(review): the reason for excluding these IDs is not stated here; confirm and document it.
drop_ids = "00109, 00123, 00709".split(", ")
df_train = drop_by_id(df_train, drop_ids)
labels_train = drop_by_id(labels_train, drop_ids)

In [None]:
def count_values(df: pd.DataFrame):
    """
    Display how many files each (BraTS21ID, series_description) pair has.

    Shows summary statistics of the per-pair counts, then the pairs with the
    fewest and with the most files (via the notebook's ``display``).
    """
    counts = df.groupby(["BraTS21ID", "series_description"])["dcm_path"].count()
    display(counts.describe())
    fewest, most = counts.min(), counts.max()
    display(counts[counts == fewest])
    display(counts[counts == most])


display(df_train.describe())
count_values(df_train)
# Look at a few individual series: FLAIR for 00571 and 00818, T2w for 00012.
look_one_dcm("00571", data_dir / "train", mri_series[0])
look_one_dcm("00818", data_dir / "train", mri_series[0])
look_one_dcm("00012", data_dir / "train", mri_series[3])

In [None]:
def is_retrievable(df: pd.DataFrame,
                   patient_id: str,
                   series_desc_idx: int):
    """Return True if ``df`` has at least one row for this patient and series code."""
    same_patient = df["BraTS21ID"].eq(patient_id)
    same_series = df["series_description"].eq(series_desc_idx)
    return bool((same_patient & same_series).any())


def _load_dicom_slice(dcm_path: str, size: Tuple[int, int]) -> np.ndarray:
    """Read one DICOM slice, resize it to ``size`` = (height, width) and divide it by its dtype maximum (float32)."""
    dcm_obj = pydicom.dcmread(dcm_path)
    array = dcm_obj.pixel_array
    # cv.resize takes dsize as (width, height).
    array = cv.resize(array, (size[1], size[0]))
    dinfo = np.iinfo(array.dtype) if np.issubdtype(array.dtype, np.integer) else np.finfo(array.dtype)
    array = (array / dinfo.max).astype(np.float32)  # like (a / 255) if a.dtype is uint8
    # Compare the element's value: a DataElement never equals a plain str.
    # NOTE(review): weights trained with the previous code saw MONOCHROME1 slices uninverted.
    if str(dcm_obj[0x0028, 0x0004].value).strip() == "MONOCHROME1":
        # The slice is already divided by the dtype maximum, so invert within that scale.
        array = 1.0 - array
    return array


def get_voxel_by_id_series(df: pd.DataFrame,
                           patient_id: str,
                           series_desc_idx: int = 0,
                           size: Union[int, Tuple[int, int]] = 256) -> Tuple[np.ndarray, int]:
    """
    Load one patient's MRI series as a slice-ordered (D, H, W) float32 volume.

    :params:
        :df: required columns: [dcm_path, BraTS21ID, series_description, plane,
                                image_position_x, image_position_y, image_position_z]
        :patient_id: zero-padded BraTS21ID to select.
        :series_desc_idx: integer series code (see ``mri_series``).
        :size: in-plane size; an int for square slices or (height, width).
    :returns:
        (voxel, plane): the min-max normalised volume and the plane code shared by all slices.
    :raises:
        AssertionError: if no rows match, or if the matching slices disagree on plane.
    """
    size = (int(size), int(size)) if isinstance(size, (int, float, np.integer)) else tuple(size)
    retrieved_idx = (df["BraTS21ID"].eq(patient_id)) & (df["series_description"].eq(series_desc_idx))
    assert retrieved_idx.sum() > 0, "Nothing retrived."
    retrieved_df = df.loc[retrieved_idx].copy()
    plane = retrieved_df["plane"].unique()
    assert len(plane) == 1, "Different plane in a folder."
    # Sort primarily by the position column with the largest spread (presumably
    # the through-plane axis), then by the remaining columns.
    img_pos_cols = [c for c in retrieved_df.columns if c.startswith("image_position_")]
    img_pos_stds = np.array([retrieved_df[c].std() for c in img_pos_cols])
    sort_cols = [img_pos_cols[i] for i in np.argsort(img_pos_stds)[::-1]]
    sorted_df = retrieved_df.sort_values(sort_cols, ascending=True, ignore_index=True)
    voxel = np.stack([_load_dicom_slice(path, size) for path in sorted_df["dcm_path"]])
    # Min-max normalisation divides by the range (not the maximum) so the result
    # spans [0, 1]; identical to the previous formula whenever the minimum is 0.
    v_min, v_max = np.min(voxel), np.max(voxel)
    voxel = (voxel - v_min) / max(v_max - v_min, 1e-8)
    return voxel, plane[0]


def plot_voxel(voxel, max_n_plots=10, cols=10):
    """
    Show up to ``max_n_plots`` leading slices of ``voxel`` on a grid with ``cols`` columns.

    :params:
        voxel: array-like indexed as ``voxel[i, :, :]`` for slice i.
        max_n_plots: upper bound on the number of slices drawn.
        cols: number of grid columns.
    """
    actual_n_plots = min(max_n_plots, len(voxel))
    rows = max(1, int(np.ceil(actual_n_plots / cols)))  # at least one row, even for an empty voxel
    # squeeze=False keeps `axes` 2-D for single-row/column grids; with the default,
    # plt.subplots returns a 1-D array there and [row, col] indexing raised IndexError.
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), tight_layout=True, squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < actual_n_plots:
            ax.imshow(voxel[i, :, :], cmap=plt.cm.gray)
        ax.set_axis_off()  # also hides the unused trailing cells
    plt.show()


# Same series (patient 00571, default series code 0 = FLAIR) at decreasing in-plane sizes.
vox, plane = get_voxel_by_id_series(df_train, "00571", size=256)
plot_voxel(vox, 14, 7)
vox, plane = get_voxel_by_id_series(df_train, "00571", size=128)
plot_voxel(vox, 14, 7)
vox, plane = get_voxel_by_id_series(df_train, "00571", size=64)
plot_voxel(vox, 14, 7)
vox, plane = get_voxel_by_id_series(df_train, "00571", size=32)
plot_voxel(vox, 14, 7)
display(plane)  # plane code of the last loaded series

In [None]:
class MRIVoxelDataset(Dataset):
    """
    One sample per (patient, MRI series) pair: a resized, slice-ordered volume.

    Each item is ``(voxel, label, (series_desc, plane))``:
    voxel float32 [1, D, H, W], label float32 [1] (NaN when unlabeled),
    series_desc and plane int64 scalars.
    """

    def __init__(self, meta_df: pd.DataFrame, label_df: Optional[pd.DataFrame] = None,
                 voxel_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = (64, 256, 256),
                 including_series: Optional[np.ndarray] = None):
        """
        :params:
            :meta_df: required columns: [dcm_path, BraTS21ID, series_description, plane,
                                         image_position_x, image_position_y, image_position_z]
            :label_df(Optional): required columns: [BraTS21ID, MGMT_value];
                                 if None, every patient in meta_df gets a NaN label.
            :voxel_size: if int, the D, H, W will be set to the same;
                         if (int, int), D by voxel_size[0], H, W by voxel_size[1];
                         if (int, int, int), D, H, W will be set respectively.
            :including_series: integer series codes to keep; defaults to every code in ``mri_series``.
        """
        super(MRIVoxelDataset, self).__init__()
        if including_series is None:
            # Built per call instead of as a default-argument array shared by all calls.
            including_series = np.array(list(mri_series.keys()), dtype=np.int64)
        self.including_series = including_series
        self.voxel_size = self._normalize_voxel_size(voxel_size)
        self.meta_df = meta_df.loc[meta_df["series_description"].isin(including_series)].copy()
        if label_df is None:
            label_df = self._unlabeled_frame(self.meta_df)
        expanded = self._expand_by_series(label_df, self.meta_df["series_description"].unique())
        self.label_df = self._keep_retrievable(expanded, self.meta_df)
        print(f"Got {len(self)} samples in dataset.")

    @staticmethod
    def _normalize_voxel_size(voxel_size):
        """Expand an int or a (D, HW) pair into a (D, H, W) tuple."""
        if isinstance(voxel_size, (int, np.integer)):
            return (int(voxel_size),) * 3
        if isinstance(voxel_size, (tuple, list)):
            voxel_size = tuple(voxel_size)
            if len(voxel_size) == 2:
                return (voxel_size[0], voxel_size[1], voxel_size[1])
        return voxel_size

    @staticmethod
    def _unlabeled_frame(meta_df: pd.DataFrame) -> pd.DataFrame:
        """One row per patient in ``meta_df`` with a NaN ``MGMT_value``."""
        patient_ids = [str(i).zfill(5) for i in meta_df["BraTS21ID"].unique()]
        return pd.DataFrame({"BraTS21ID": patient_ids,
                             "MGMT_value": np.full(len(patient_ids), np.nan, dtype=np.float64)})

    @staticmethod
    def _expand_by_series(label_df: pd.DataFrame, series_codes) -> pd.DataFrame:
        """Repeat ``label_df`` once per series code, appending an int64 ``series_description`` column."""
        base = label_df.reset_index(drop=True)
        frames = [base.assign(series_description=np.full(len(base), code, dtype=np.int64))
                  for code in series_codes]
        if not frames:
            return base.iloc[0:0].assign(series_description=np.array([], dtype=np.int64))
        return pd.concat(frames, axis=0, ignore_index=True)

    @staticmethod
    def _keep_retrievable(label_df: pd.DataFrame, meta_df: pd.DataFrame) -> pd.DataFrame:
        """Keep label rows whose (BraTS21ID, series_description) has at least one slice in ``meta_df``."""
        # One set lookup per label row instead of a full scan of meta_df per row.
        available = set(zip(meta_df["BraTS21ID"], meta_df["series_description"]))
        keep = []
        for patient_id, series_code in zip(label_df["BraTS21ID"], label_df["series_description"]):
            found = (patient_id, series_code) in available
            if not found:
                print(patient_id, series_code)
            keep.append(found)
        # dtype=bool so an empty mask is still a valid boolean indexer for iloc.
        return label_df.iloc[np.asarray(keep, dtype=bool)]

    def __len__(self): return len(self.label_df)

    def __getitem__(self, idx):
        row = self.label_df.iloc[idx]
        voxel, plane = get_voxel_by_id_series(self.meta_df, row["BraTS21ID"], row["series_description"], self.voxel_size[1:])
        voxel = torch.tensor(voxel, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # [N=1, C=1, D, H, W]
        # Resample the depth (slice count varies per series) and size to voxel_size.
        voxel = F.interpolate(voxel, self.voxel_size, mode="trilinear", align_corners=False)
        voxel = voxel.squeeze(0)  # [C=1, D, H, W]
        label = torch.tensor([row["MGMT_value"]], dtype=torch.float32)
        plane = torch.tensor(plane, dtype=torch.int64)
        series_desc = torch.tensor(row["series_description"], dtype=torch.int64)
        return voxel, label, (series_desc, plane)


if DEBUG:
    # Smoke test: FLAIR-only dataset (series code 0) at voxel size (64, 128, 128);
    # print the shapes and dtypes of one batch.
    ds_ = MRIVoxelDataset(df_train, labels_train, (64, 128), np.array([0], dtype=np.int64))
    dl_ = DataLoader(ds_, batch_size=4, num_workers=4)
    for voxel, label, (series_desc, plane) in dl_:
        print(voxel.shape, label.shape, plane.shape, series_desc.shape)
        print(voxel.dtype, label.dtype, plane.dtype, series_desc.dtype)
        break

## Model architecture

In [None]:
# Annotation aliases for "a layer class to instantiate" (e.g. nn.BatchNorm3d, nn.ReLU).
NormLayerClass = ActivationLayerClass = Type


class SqueezeExcitation(nn.Module):
    """
    3-D squeeze-and-excitation: rescale each channel by a gate derived from its global average.

    AdaptiveAvgPool3d(1) -> 1x1x1 conv to ``squeeze_channels`` -> ReLU ->
    1x1x1 conv back to ``in_channels`` -> Hardsigmoid; the gate multiplies the input.
    """

    def __init__(self, in_channels):
        super(SqueezeExcitation, self).__init__()
        self.in_channels = in_channels
        # in_channels // 4, but never 0: a zero-width squeeze layer would make the
        # gate a constant (bias-only) that no longer depends on the input.
        self.squeeze_channels = max(1, self.in_channels // 4)

        # Layer order is unchanged so saved state_dict keys (seq.1.*, seq.3.*) still load.
        self.seq = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(self.in_channels, self.squeeze_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv3d(self.squeeze_channels, self.in_channels, 1),
            nn.Hardsigmoid(inplace=True),
        )

    def forward(self, x):
        scale = self.seq(x)  # [N, C, 1, 1, 1], broadcast over D, H, W
        out = scale * x
        return out


class ConvBNActivation(nn.Module):
    """
    Conv3d -> normalization -> activation, optionally followed by squeeze-and-excitation.

    All layers live in ``self.seq`` in that order (conv at index 0, norm at 1,
    activation at 2, SE at 3 when ``use_se``).
    """

    def __init__(self, conv_config: dict,
                 norm_layer_cls: NormLayerClass = nn.BatchNorm3d,
                 activation_layer_cls: ActivationLayerClass = nn.ReLU,
                 use_se: bool = False) -> None:
        super(ConvBNActivation, self).__init__()
        n_out = conv_config["out_channels"]
        conv = nn.Conv3d(**conv_config)
        norm = norm_layer_cls(n_out)
        # nn.Identity (used for linear projections) accepts and ignores `inplace`.
        act = activation_layer_cls(inplace=True)
        tail = [SqueezeExcitation(n_out)] if use_se else []
        self.seq = nn.Sequential(conv, norm, act, *tail)

    def forward(self, x):
        return self.seq(x)

    @staticmethod
    def config(in_channels: int,
               out_channels: int,
               kernel_size: Union[int, Tuple[int, int, int]],
               stride: Union[int, Tuple[int, int, int]] = 1,
               padding: Union[int, Tuple[int, int, int]] = 0,
               dilation: Union[int, Tuple[int, int, int]] = 1,
               groups: int = 1,
               bias: bool = True,
               padding_mode: str = 'zeros') -> dict:
        """Bundle nn.Conv3d keyword arguments into a dict whose keys match nn.Conv3d's parameters."""
        return dict(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                    stride=stride, padding=padding, dilation=dilation, groups=groups,
                    bias=bias, padding_mode=padding_mode)


class BottleNeck(nn.Module):
    """
    MobileNetV3-style inverted-residual block for 3-D feature maps.

    Main path: 1x1x1 expansion -> depthwise conv (optional SE) -> 1x1x1 projection
    without activation. Skip path: 1x1x1 conv, trilinearly resized to the main
    path's spatial shape so strided blocks can still be summed.
    """

    def __init__(self, residual_config: dict):
        super(BottleNeck, self).__init__()
        self.residual_config = residual_config
        cfg = residual_config
        c_in, c_exp, c_out = cfg["in_channels"], cfg["expand_channels"], cfg["out_channels"]
        norm_cls, act_cls = cfg["norm_layer_cls"], cfg["activation_layer_cls"]

        expand = ConvBNActivation(ConvBNActivation.config(c_in, c_exp, 1, 1, 0), norm_cls, act_cls)
        depthwise = ConvBNActivation(
            ConvBNActivation.config(c_exp, c_exp, cfg["kernel_size"], cfg["stride"], cfg["padding"], groups=c_exp),
            norm_cls, act_cls, cfg["use_se"],
        )
        project = ConvBNActivation(ConvBNActivation.config(c_exp, c_out, 1, 1, 0), norm_cls, nn.Identity)
        self.seq = nn.Sequential(expand, depthwise, project)
        # The shortcut: Same as nn.Linear if channels at last dim.
        self.shortcut = nn.Conv3d(c_in, c_out, 1)

    def forward(self, x):
        main = self.seq(x)
        skip = F.interpolate(self.shortcut(x), main.shape[-3:], mode="trilinear", align_corners=False)
        return skip + main

    @staticmethod
    def config(in_channels: int,
               out_channels: int,
               expand_channels: int,
               kernel_size: Union[int, Tuple[int, int, int]],
               stride: Union[int, Tuple[int, int, int]] = 1,
               padding: Union[int, Tuple[int, int, int]] = 0,
               norm_layer_cls: NormLayerClass = nn.BatchNorm3d,
               activation_layer_cls: ActivationLayerClass = nn.Hardswish,
               use_se: bool = False) -> dict:
        """Bundle the settings of one BottleNeck into the dict its constructor expects."""
        return dict(in_channels=in_channels, out_channels=out_channels, expand_channels=expand_channels,
                    kernel_size=kernel_size, stride=stride, padding=padding,
                    norm_layer_cls=norm_layer_cls, activation_layer_cls=activation_layer_cls,
                    use_se=use_se)


class NetFeatures(nn.Module):
    """
    Convolutional trunk of Net.

    Stem: 3x3x3 stride-2 ConvBNActivation (ReLU) into the first block's width,
    then one BottleNeck per config, then a 1x1x1 ConvBNActivation
    (BatchNorm3d, Hardswish, SE) to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, residual_config_list: List[dict]):
        super(NetFeatures, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.residual_config_list = residual_config_list

        stem_channels = residual_config_list[0]["in_channels"]
        head_in_channels = residual_config_list[-1]["out_channels"]
        self.first_conv = ConvBNActivation(ConvBNActivation.config(in_channels, stem_channels, 3, 2, 1),
                                           activation_layer_cls=nn.ReLU)
        self.residual_block = nn.Sequential(*[BottleNeck(conf) for conf in residual_config_list])
        self.last_conv = ConvBNActivation(ConvBNActivation.config(head_in_channels, out_channels, 1),
                                          nn.BatchNorm3d,
                                          nn.Hardswish,
                                          use_se=True)

    def forward(self, x):
        return self.last_conv(self.residual_block(self.first_conv(x)))


class ConcatEmbeddingLinear(nn.Module):
    """
    Prepend a learned embedding of an integer index to the features, then apply a linear layer.

    forward(x, idx_emb) = fc(cat([emb(idx_emb), x], dim=-1)).
    """

    def __init__(self, in_features: int, out_features: int, n_embeddings: int, embed_dim: Optional[int] = None):
        super(ConcatEmbeddingLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_embeddings = n_embeddings
        # The embedding width defaults to the input feature width.
        self.embed_dim = in_features if embed_dim is None else embed_dim

        self.emb = nn.Embedding(self.n_embeddings, self.embed_dim)
        self.fc = nn.Linear(self.in_features + self.embed_dim, self.out_features)

    def forward(self, x, idx_emb):
        joined = torch.cat((self.emb(idx_emb), x), dim=-1)
        return self.fc(joined)


class Net(nn.Module):
    """
    3-D MRI classifier returning unnormalized logits of shape [N, n_classes].

    NetFeatures -> AdaptiveAvgPool3d(1) -> Flatten -> Linear(hidden_features)
    -> concat series embedding + Linear -> concat plane embedding + Linear
    -> Dropout(0.2) -> Linear(n_classes).
    """

    def __init__(self, in_channels, feature_out_channels, hidden_features, n_classes, n_series, n_planes,
                 residual_config_list: List[dict]) -> None:
        super(Net, self).__init__()
        self.in_channels = in_channels
        self.feature_out_channels = feature_out_channels
        self.n_classes = n_classes
        self.hidden_features = hidden_features
        self.n_planes = n_planes
        self.n_series = n_series
        self.residual_config_list = residual_config_list

        self.features = NetFeatures(in_channels, feature_out_channels, residual_config_list)
        self.pool_flat_linear = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(self.features.out_channels, hidden_features),
        )
        self.emb_series = ConcatEmbeddingLinear(hidden_features, hidden_features, n_series)
        self.emb_planes = ConcatEmbeddingLinear(hidden_features, hidden_features, n_planes)
        self.classifier = nn.Sequential(nn.Dropout(0.2), nn.Linear(hidden_features, n_classes))

    def forward(self, x, idx_series, idx_planes):
        pooled = self.pool_flat_linear(self.features(x))
        with_series = self.emb_series(pooled, idx_series)
        with_plane = self.emb_planes(with_series, idx_planes)
        return self.classifier(with_plane)


def get_residual_config_backup():
    """
    Alternative, deeper BottleNeck stack resembling MobileNetV3-Small (possibly too deep).

    :returns: list of BottleNeck.config dicts in block order; every block uses BatchNorm3d.
    """
    # (in, out, expand, kernel, stride, padding, activation, use_se)
    table = [
        (16, 16, 16, 3, 2, 1, nn.Hardswish, True),
        (16, 24, 72, 3, 2, 1, nn.ReLU, False),
        (24, 24, 88, 3, 1, 1, nn.ReLU, False),
        (24, 40, 96, 5, 2, 2, nn.ReLU, True),
        (40, 40, 240, 5, 1, 2, nn.Hardswish, True),
        (40, 40, 240, 5, 1, 2, nn.Hardswish, True),
        (40, 48, 120, 5, 1, 2, nn.Hardswish, True),
        (48, 48, 144, 5, 1, 2, nn.Hardswish, True),
        (48, 96, 288, 5, 2, 2, nn.Hardswish, True),
        (96, 96, 576, 5, 1, 2, nn.Hardswish, True),
        (96, 96, 576, 5, 1, 2, nn.Hardswish, True),
    ]
    return [BottleNeck.config(c_in, c_out, c_exp, k, s, p, nn.BatchNorm3d, act, se)
            for c_in, c_out, c_exp, k, s, p, act, se in table]


def get_residual_config():
    """
    BottleNeck stack used by Net in this notebook.

    :returns: list of BottleNeck.config dicts in block order; every block uses BatchNorm3d.
    """
    # (in, out, expand, kernel, stride, padding, activation, use_se)
    table = [
        (16, 16, 16, 3, 2, 1, nn.Hardswish, True),
        (16, 24, 72, 3, 2, 1, nn.ReLU, False),
        (24, 24, 88, 3, 1, 1, nn.ReLU, False),
        (24, 40, 96, 5, 2, 2, nn.ReLU, True),
        (40, 40, 240, 5, 1, 2, nn.Hardswish, True),
        (40, 80, 288, 5, 2, 2, nn.Hardswish, True),
        (80, 96, 576, 5, 1, 2, nn.Hardswish, True),
    ]
    return [BottleNeck.config(c_in, c_out, c_exp, k, s, p, nn.BatchNorm3d, act, se)
            for c_in, c_out, c_exp, k, s, p, act, se in table]


if DEBUG:
    # Smoke test: push four all-ones volumes [4, 1, 64, 256, 256] through the model
    # (1 input channel, 512 feature/hidden units, 1 logit, 4 series, 4 planes)
    # and print the BCE-with-logits loss.
    t_ = torch.ones(4, 1, 64, 256, 256, dtype=torch.float32)
    l_ = torch.ones(4, 1, dtype=torch.float32)
    s_ = torch.ones(4, dtype=torch.int64)
    p_ = torch.ones(4, dtype=torch.int64)
    config_ = get_residual_config()
    net_ = Net(1, 512, 512, 1, 4, 4, config_).to(dtype=torch.float32)
    print(net_)
    with torch.no_grad():
        o_ = net_(t_, s_, p_)
        loss_ = F.binary_cross_entropy_with_logits(o_, l_)
        print(loss_.item())

## Voxel MRI augmentations

In [None]:
class RandomInvert3D(AugmentationBase3D):
    """
    Kornia 3-D augmentation that randomly inverts voxel intensities with ``kornia.enhance.invert``.

    ``invert(input, max_val)`` returns ``max_val - input``, so with the default
    ``max_val`` of 1.0 it expects inputs scaled to [0, 1].
    NOTE(review): written against a kornia version whose AugmentationBase3D still
    accepts ``return_transform``; confirm against the installed kornia.
    """

    def __init__(
        self,
        max_val: Union[float, torch.Tensor] = torch.tensor(1.0),
        return_transform: bool = False,
        same_on_batch: bool = False,
        p: float = 0.5,
    ) -> None:
        # p is the per-sample application probability; p_batch=1.0 lets every batch be considered.
        super(RandomInvert3D, self).__init__(
            p=p, return_transform=return_transform, same_on_batch=same_on_batch, p_batch=1.0
        )
        self.max_val = max_val

    def __repr__(self) -> str:
        return self.__class__.__name__ + f"({super().__repr__()})"

    def generate_parameters(self, batch_shape: torch.Size):
        # The same max_val is used for every sample; there are no random parameters to draw.
        return dict(max_val=torch.as_tensor(self.max_val), batch_shape=torch.as_tensor(batch_shape))

    def compute_transformation(self, input, params: Dict[str, torch.Tensor]):
        # Intensity-only augmentation: the geometric transform is the identity.
        return self.identity_matrix(input)

    def apply_transform(
        self, input: torch.Tensor,
        params: Dict[str, torch.Tensor],
        transform: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        max_val = params["max_val"]
        return invert(input, max_val)

    
# Scalar type accepted inside RandomShift3D's `shift_limit`.
Numeric = Union[int, float]


class RandomShift3D(nn.Module):
    """
    Randomly translate each volume of a [N, C, D, H, W] batch, zero-filling vacated voxels.

    For each sample, with probability ``p``, one offset per spatial axis is drawn
    uniformly from ``shift_limit`` as a fraction of that axis' length and
    truncated to whole voxels. Uses NumPy's global RNG.

    :params:
        shift_limit: one of
            * a number ``s`` -> range ``(-|s|, |s|)`` on all three axes;
            * a pair ``(low, high)`` or ``[[low, high]]`` -> same range on all three axes;
            * three pairs ``[(low_d, high_d), (low_h, high_h), (low_w, high_w)]``.
            Values are clipped to [-1, 1].
        p: probability of shifting each sample.
    """

    def __init__(self,
                 shift_limit: "Union[Numeric, List[Numeric], Tuple[Numeric, Numeric]]" = 0.125,
                 p: float = 0.5):
        super(RandomShift3D, self).__init__()
        self.p = p
        if isinstance(shift_limit, (float, int)):
            bound = abs(shift_limit)
            limits = np.array([(-bound, bound)] * 3, dtype=np.float64)
        elif isinstance(shift_limit, (tuple, list)):
            limits = np.array(shift_limit, dtype=np.float64)
            if limits.ndim == 1:
                # A flat (low, high) pair: promote to [[low, high]] so all axes share it.
                # Previously this shape always failed the (3, 2) check below.
                limits = limits.reshape(1, -1)
        else:
            raise TypeError("shift_limit expects a number, a (low, high) pair or three such pairs, "
                            f"got {type(shift_limit).__name__}")
        limits = np.clip(limits, -1., 1.)
        if limits.shape[0] == 1:
            limits = np.concatenate([limits, limits, limits])
        assert limits.shape == (3, 2), f"shift_limit must resolve to shape (3, 2), got {limits.shape}"
        self.shift_limit = limits

    def forward(self, tensor):
        assert len(tensor.shape) == 5, f"Requires 5 dims torch.Tensor[N, C, D, H, W], got {tensor.shape}"
        n, c, d, h, w = tensor.shape
        apply_proba = np.random.uniform(size=(n,))
        shift_size = np.random.uniform(low=self.shift_limit[:, 0], high=self.shift_limit[:, 1], size=(n, 3))
        # Fractions of each axis length -> whole-voxel offsets (truncated toward zero).
        shift_d, shift_h, shift_w = (np.array(tensor.shape[2:])[np.newaxis, :] * shift_size).astype(np.int64).T
        out = torch.zeros_like(tensor)
        for i in range(n):
            if apply_proba[i] <= self.p:
                # A positive offset s copies in[:L-s] to out[s:]; a negative one copies in[-s:] to out[:L+s].
                out[i, :,
                    max(0, 0+shift_d[i]):min(d, d+shift_d[i]),
                    max(0, 0+shift_h[i]):min(h, h+shift_h[i]),
                    max(0, 0+shift_w[i]):min(w, w+shift_w[i]),
                ] = tensor[i, :,
                    max(0, 0-shift_d[i]):min(d, d-shift_d[i]),
                    max(0, 0-shift_h[i]):min(h, h-shift_h[i]),
                    max(0, 0-shift_w[i]):min(w, w-shift_w[i]),
                ]
            else:
                out[i] = tensor[i]  # Unchanged.
        return out


def get_augmentation(split="train") -> nn.Sequential:
    """
    Build the augmentation pipeline for a data split.

    "train": random affine, horizontal flip, random shift and random intensity
    inversion. "val" / "test": an empty (identity) nn.Sequential.
    All parameters of the returned module have requires_grad disabled.

    :return: nn.Sequential: requires input: torch.FloatTensor[N, C, D, H, W] in range[0., 1.]
    :raises ValueError: for any other split name.
    """
    if split == "train":
        pipeline = nn.Sequential(
            K.augmentation.RandomAffine3D(degrees=(5., 5., 90.), translate=(.05, .05, .05), scale=(.98, 1.02), p=.3),
            K.augmentation.RandomHorizontalFlip3D(p=.3),
            RandomShift3D(shift_limit=0.2, p=.3),
            RandomInvert3D(p=.1),
        )
    elif split in ("test", "val"):
        pipeline = nn.Sequential()
    else:
        raise ValueError(f"Argument `split` must in {{'train', 'val', 'test'}}, got {split}")
    return pipeline.requires_grad_(False)


def plot_grid(t: torch.Tensor) -> None:
    """
    Plot the middle depth slice of every volume in a batch on a square grid.

    Each panel title shows the mean and std of the whole volume.

    :argument: t: torch.Tensor[N, C, D, H, W]; the channel axis is squeezed, so C is expected to be 1.
    """
    side = max(1, int(np.ceil(np.sqrt(len(t)))))
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid; with the default a batch
    # of one returned a bare Axes and axes[i, j] raised TypeError.
    fig, axes = plt.subplots(side, side, figsize=(14, 14), squeeze=False)
    for nth, ax in enumerate(axes.flat):  # row-major, like product(range(side), range(side))
        ax.set_axis_off()
        if nth >= len(t):
            continue  # unused cell
        # detach/cpu so batches on an accelerator or with grad history can be plotted too.
        nth_img = t[nth].detach().cpu().squeeze(0).numpy()
        nth_img_mid = nth_img[len(nth_img) // 2]
        mean, std = np.mean(nth_img), np.std(nth_img)
        ax.imshow(nth_img_mid, cmap=plt.cm.gray)
        ax.set_title(f"mean: {mean:.4f}, std: {std:.4f}")
    plt.show()


# Check the effect of augmentation: apply the "train" pipeline to one shuffled
# batch of 16 volumes (voxel size (64, 128, 128)) and plot each middle slice.
ds_ = MRIVoxelDataset(df_train, labels_train, (64, 128))
dl_ = DataLoader(ds_, batch_size=16, shuffle=True, num_workers=6)
aug_ = get_augmentation(split="train")
for voxel, label, (_, _) in dl_:
    voxel = aug_(voxel)
    plot_grid(voxel)
    break

## Training and evaluating functions

In [None]:
from sklearn.metrics import roc_auc_score


@torch.no_grad()
def evaluation(loader: DataLoader, model, aug_list, device):
    """
    Compute the mean BCE-with-logits loss and the ROC AUC of ``model`` over ``loader``.

    Predictions are matched to ``label_df`` rows by position, so ``loader`` must
    iterate its dataset in order (no shuffling).

    :params:
        loader: DataLoader over an MRIVoxelDataset or a Subset of one.
        model: called as model(voxel, series_desc, plane); returns logits [N, 1].
        aug_list: transform applied to each voxel batch before the model.
        device: device the model and batches are moved to.
    :returns:
        (loss averaged over batches, ROC AUC of sigmoid(logits) against MGMT_value)
    """
    if isinstance(loader.dataset, Subset):
        df_copy: pd.DataFrame = loader.dataset.dataset.label_df.iloc[loader.dataset.indices].copy()
    else:
        df_copy: pd.DataFrame = loader.dataset.label_df.copy()
    df_copy = df_copy.reset_index(drop=True)
    model.to(device)
    model.eval()
    loss_val = 0.
    batch_preds = []
    for voxel, label, (series_desc, plane) in tqdm(loader, desc="Evaluating with AUC", total=len(loader)):
        voxel, label, series_desc, plane = (aug_list(voxel).to(device), label.to(device),
                                            series_desc.to(device), plane.to(device))
        out = model(voxel, series_desc, plane)
        loss_val += nn.functional.binary_cross_entropy_with_logits(out, label).item()
        batch_preds.append(torch.sigmoid(out)[:, 0].cpu().numpy())
    loss_val /= max(1, len(loader))
    # Concatenate the collected batches instead of slicing by loader.batch_size,
    # which is None for loaders built from a batch_sampler.
    df_copy["MGMT_value_pred"] = (np.concatenate(batch_preds).astype(np.float64) if batch_preds
                                  else np.full((len(df_copy),), fill_value=np.nan, dtype=np.float64))
    auc = roc_auc_score(df_copy["MGMT_value"].values.astype("int64"), df_copy["MGMT_value_pred"].values)
    return loss_val, auc


def train_an_epoch(loader, model, aug_list, device, optimizer, epoch_idx=0):
    """
    Run one training epoch and return its mean loss.

    :params:
        loader: DataLoader yielding (voxel, label, (series_desc, plane)).
        model: called as model(voxel, series_desc, plane); returns logits.
        aug_list: augmentation applied to each voxel batch before moving it to ``device``.
        device: training device.
        optimizer: optimizer over ``model``'s parameters.
        epoch_idx: epoch number, used only for the progress bar and the log line.
    :returns:
        float: loss averaged over batches.
    :raises ValueError: if a batch loss is NaN.
    """
    model.train()
    model.to(device)
    loss_epoch = 0.
    for voxel, label, (series_desc, plane) in tqdm(loader, desc=f"Training epoch: {epoch_idx}"):
        voxel, label = aug_list(voxel).to(device), label.to(device)
        series_desc, plane = series_desc.to(device), plane.to(device)
        out = model(voxel, series_desc, plane)
        loss = nn.functional.binary_cross_entropy_with_logits(out, label)
        loss_item = loss.item()  # one host sync per batch, reused below
        if np.isnan(loss_item):
            raise ValueError("loss is nan")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_epoch += loss_item
    loss_epoch /= max(1, len(loader))
    # Log the epoch that was passed in; the previous version read a global `e`,
    # raising NameError (or printing a stale value) outside a `for e in ...` loop.
    print(f"Epoch: {epoch_idx}, train loss: {loss_epoch}")
    return loss_epoch


if DEBUG:
    # Smoke test: one epoch of training, then evaluation, on the first 101 samples
    # with the empty "val" pipeline. Training and evaluation use the same subset,
    # so this only checks that the loop runs end to end.
    ds_ = MRIVoxelDataset(df_train, labels_train, (64, 128))
    sub_ = Subset(ds_, list(range(len(ds_)))[:101])
    dl_ = DataLoader(sub_, batch_size=4, num_workers=4)
    aug_list_ = get_augmentation(split="val")
    conf_ = get_residual_config()
    net_ = Net(1, 512, 512, 1, 4, 4, conf_).to(torch.float32)
    optim_ = torch.optim.Adam(net_.parameters(), lr=3e-4)
    for e in range(1):
        train_an_epoch(dl_, net_, aug_list_, device, optim_, e)
        loss_val, auc = evaluation(dl_, net_, aug_list_, device)
        print(loss_val, auc)

## Train the model with the data pipeline

In [None]:
voxel_size = (64, 64, 64)  # (D, H, W) of every sample fed to the model
# Integer codes of the MRI series included in the datasets (all four).
including_series = np.array(
    [mri_series_map[name] for name in ("FLAIR", "T1w", "T1wCE", "T2w")],
    dtype=np.int64,
)
# Annotation-only alias: `NumpyNDArray[np.int64]` means `Iterable[np.int64]`.
NumpyNDArray = Iterable

def get_dataset_in_pipeline(img_dir: Path, including_series: NumpyNDArray[np.int64],
                            voxel_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = (64, 256, 256),
                            glob: str = None,
                            df_labels: pd.DataFrame = None, drop_ids: List[str] = None):
    """Build an ``MRIVoxelDataset`` from the DICOM metadata found under ``img_dir``.

    Metadata comes from ``get_meta_from_glob(img_dir, glob)``. Its ``plane`` and
    ``series_description`` columns are mapped to integer codes through ``planes_map``
    and ``mri_series_map``, and the rows are then filtered with ``keep_non_blank``.

    Args:
        img_dir: root directory of the DICOM series (e.g. ``data_dir / "train"``).
        including_series: integer series codes forwarded to ``MRIVoxelDataset``.
        voxel_size: target voxel size forwarded to ``MRIVoxelDataset``.
        glob: pattern forwarded unchanged to ``get_meta_from_glob`` (may be ``None``).
        df_labels: label table, or ``None`` for unlabeled (test) data.
        drop_ids: patient IDs passed to ``drop_by_id`` for both the labels and the
            metadata. Only used when ``df_labels`` is given.

    Returns:
        The constructed ``MRIVoxelDataset``.
    """
    meta = get_meta_from_glob(img_dir, glob)
    for column, code_map in (("plane", planes_map), ("series_description", mri_series_map)):
        meta.loc[:, column] = meta.loc[:, column].map(code_map)
    meta = keep_non_blank(meta)
    if df_labels is None:
        # Unlabeled data: nothing to drop.
        return MRIVoxelDataset(meta, df_labels, voxel_size, including_series)
    df_labels = drop_by_id(df_labels, drop_ids)
    meta = drop_by_id(meta, drop_ids)
    return MRIVoxelDataset(meta, df_labels, voxel_size, including_series)


# Keep patient IDs as strings so zero-padded IDs such as "00109" survive parsing.
labels_train = pd.read_csv(data_dir / "train_labels.csv", dtype={"BraTS21ID": str})
# NOTE(review): hard-coded exclusion list; the reason is not recorded here
# (presumably problematic scans) — confirm.
ds_train = get_dataset_in_pipeline(
    data_dir / "train",
    including_series,
    voxel_size,
    glob=None,
    df_labels=labels_train,
    drop_ids=["00109", "00123", "00709"],
)
ds_test = get_dataset_in_pipeline(data_dir / "test", including_series, voxel_size)

### Train model by cross validation

In [None]:
from sklearn.model_selection import StratifiedKFold

# Parameters to construct Net
in_channels = 1
feature_out_channels = 576
hidden_features = 512
n_classes = 1
n_series = len(mri_series)
n_planes = len(planes)
residual_config = get_residual_config()

# Training Parameters
n_splits = 5
batch_size = 16
epochs = 15
lr = 3e-4
num_workers = 6
weight_decay = 1e-5

save_filename_template = "{}-model-fold_{:03d}-best-state_dict.pt"
if USE_CROSS_VALIDATION and (not INFERENCE_ONLY):
    # ``label_df`` holds several rows per patient (one per MRI series, judging by its
    # ``series_description`` column). Its row order must match the dataset indices,
    # because ``Subset`` below indexes ``ds_train`` by row position.
    df_label_train = ds_train.label_df
    # Split at the *patient* level. A row-level split would put other series of a
    # validation patient into the training fold, leaking information into the
    # validation loss used for checkpoint selection. Patients are stratified by label.
    df_patients = df_label_train.drop_duplicates("BraTS21ID")
    splitter = StratifiedKFold(n_splits, shuffle=True, random_state=42)
    for nth_fold, (_, pat_val) in enumerate(splitter.split(df_patients, df_patients["MGMT_value"].values)):
        val_patient_ids = df_patients["BraTS21ID"].values[pat_val]
        is_val_row = df_label_train["BraTS21ID"].isin(val_patient_ids).to_numpy()
        idx_trn, idx_val = np.flatnonzero(~is_val_row), np.flatnonzero(is_val_row)
        ds_trn, ds_val = (Subset(ds_train, idx_trn), Subset(ds_train, idx_val))
        dl_trn, dl_val = (DataLoader(ds_trn, batch_size=batch_size, shuffle=True, num_workers=num_workers),
                          DataLoader(ds_val, batch_size=batch_size, num_workers=num_workers))
        aug_trn, aug_val = get_augmentation(split="train"), get_augmentation(split="val")
        model = Net(
            in_channels,
            feature_out_channels,
            hidden_features,
            n_classes,
            n_series,
            n_planes,
            residual_config,
        ).to(torch.float32)
        optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        # Halve the LR after 3 epochs without validation-loss improvement.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, "min", factor=0.5, patience=3, cooldown=0, verbose=True
        )
        best_val_loss = np.inf
        for e in range(epochs):
            train_an_epoch(dl_trn, model, aug_trn, device, optim, e)
            loss_val, auc_val = evaluation(dl_val, model, aug_val, device)
            print(f"fold {nth_fold}, epoch {e}, val loss: {loss_val:.6f}, val AUC: {auc_val:.6f}")
            if loss_val < best_val_loss:
                print(f"Best val loss changed from {best_val_loss:.6f} to {loss_val:.6f}")
                best_val_loss = loss_val
                torch.save(model.state_dict(), save_filename_template.format(model_name, nth_fold))
            scheduler.step(loss_val)

### Train model by full data

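Train on the whole (augmented) training set. There is no held-out data here: the "validation" pass re-uses the training set with random train-time augmentation. The model deliberately overfits this set, so the reported validation loss and AUC are optimistic.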

In [None]:
# --- Positional arguments of ``Net`` ---
in_channels, feature_out_channels, hidden_features, n_classes = 1, 576, 512, 1
n_series = len(mri_series)
n_planes = len(planes)
residual_config = get_residual_config()

# --- Optimisation settings for the full-data run ---
batch_size, epochs = 16, 18
lr, weight_decay = 3e-4, 1e-5
num_workers = 6

save_filename_template = "{}-model-whole-dataset-best-state_dict.pt"

if not INFERENCE_ONLY:
    # Both loaders read the full training set; the "validation" one is unshuffled and,
    # like training, uses the *train* augmentation pipeline.
    dl_trn = DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    dl_val = DataLoader(ds_train, batch_size=batch_size, num_workers=num_workers)
    aug_trn = get_augmentation(split="train")
    aug_val = get_augmentation(split="train")
    model = Net(in_channels, feature_out_channels, hidden_features, n_classes,
                n_series, n_planes, residual_config).to(torch.float32)
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Multiply the LR by 0.3 after 3 epochs without improvement of the monitored loss.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.3, patience=3, cooldown=0, verbose=True,
    )
    best_val_loss = np.inf
    for e in range(epochs):
        train_an_epoch(dl_trn, model, aug_trn, device, optim, e)
        loss_val, auc_val = evaluation(dl_val, model, aug_val, device)
        print(f"epoch {e}, val loss: {loss_val:.6f}, val AUC: {auc_val:.6f}")
        if loss_val < best_val_loss:
            print(f"Best val loss changed from {best_val_loss:.6f} to {loss_val:.6f}")
            best_val_loss = loss_val
            torch.save(model.state_dict(), save_filename_template.format(model_name))
        scheduler.step(loss_val)

## Performance summary on the training set

In [None]:
from sklearn.metrics import confusion_matrix


def load_model(path, *net_args, **net_kwargs):
    """Construct ``Net(*net_args, **net_kwargs)`` and load the state dict stored at ``path``.

    The checkpoint is deserialised onto the CPU (``map_location="cpu"``). Weights saved
    from a GPU run can therefore be restored on a CPU-only machine. The inference helpers
    in this notebook move each model to the target device with ``model.to(device)``.

    Args:
        path: path to a ``state_dict`` written by ``torch.save``.
        *net_args, **net_kwargs: forwarded unchanged to ``Net``.

    Returns:
        The ``Net`` instance with the loaded weights.
    """
    net = Net(*net_args, **net_kwargs)
    state_dict = torch.load(path, map_location="cpu")
    net.load_state_dict(state_dict)
    return net


@torch.no_grad()
def train_performance(loader: DataLoader, models_list: List[Net], aug_list, device, n_pics=10):
    """Evaluate an ensemble on ``loader``, print its AUC and visualise its errors.

    Every model predicts every sample. The per-model sigmoid probabilities are
    averaged into ``MGMT_value_pred``. The function prints the class balance and the
    ROC AUC, draws the confusion matrix at a 0.5 threshold, and shows the middle slice
    of up to ``n_pics`` examples for each of TP / FP / TN / FN.

    Args:
        loader: non-shuffling DataLoader over a dataset with ``label_df``,
            ``meta_df`` and ``voxel_size`` attributes, or a ``Subset`` of one.
            Predictions are written back by batch position, so the loader must
            yield samples in ``label_df`` row order.
        models_list: trained models, each called as ``model(voxel, series_desc, plane)``.
        aug_list: callable applied to every voxel batch before inference.
        device: device used for inference.
        n_pics: maximum number of example images per confusion category;
            ``n_pics <= 0`` skips the image gallery.

    Raises:
        ValueError: if ``models_list`` is empty (there would be nothing to average).
    """
    if not models_list:
        raise ValueError("models_list must contain at least one model")

    dataset = loader.dataset
    if isinstance(dataset, Subset):
        base_dataset = dataset.dataset
        df_copy: pd.DataFrame = base_dataset.label_df.iloc[dataset.indices].copy()
    else:
        base_dataset = dataset
        df_copy = base_dataset.label_df.copy()
    meta_df = base_dataset.meta_df.copy()
    voxel_size = base_dataset.voxel_size
    print(df_copy["MGMT_value"].value_counts())

    batch_size = loader.batch_size
    # Average only the columns created here. A prefix match could also pick up
    # pre-existing columns such as ``MGMT_value_pred``.
    pred_cols = [f"MGMT_value_{i}" for i in range(len(models_list))]
    for i, (model, col) in enumerate(zip(models_list, pred_cols)):
        df_copy[col] = np.zeros(len(df_copy), dtype=np.float64)
        col_idx = df_copy.columns.get_loc(col)
        model.to(device)
        model.eval()
        for n, (voxel, _, (series_desc, plane)) in tqdm(enumerate(loader),
                                                        desc=f"Inferencing with model idx {i}", total=len(loader)):
            voxel, series_desc, plane = aug_list(voxel).to(device), series_desc.to(device), plane.to(device)
            out = model(voxel, series_desc, plane)
            pred_proba = torch.sigmoid(out.detach())[:, 0].cpu().numpy()
            start = n * batch_size
            df_copy.iloc[start:start + len(voxel), col_idx] = pred_proba
    df_copy["MGMT_value_pred"] = df_copy.loc[:, pred_cols].mean(axis=1)

    y_true = df_copy["MGMT_value"].values.astype("int64")
    y_score = df_copy["MGMT_value_pred"].values
    auc = roc_auc_score(y_true, y_score)
    print(f"AUC: {auc}")

    # ``labels=[0, 1]`` keeps the matrix 2x2 even when one class is never predicted.
    cm = confusion_matrix(y_true, (y_score >= 0.5).astype("int64"), labels=[0, 1])
    _, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", ax=ax)
    ax.set_xlabel("Pred")
    ax.set_ylabel("True")
    plt.show()

    if n_pics <= 0:
        return

    is_pos = df_copy["MGMT_value"] == 1
    is_neg = df_copy["MGMT_value"] == 0
    pred_pos = df_copy["MGMT_value_pred"] >= 0.5
    pred_neg = df_copy["MGMT_value_pred"] < 0.5
    # Most confident first: highest scores for predicted positives, lowest for predicted negatives.
    groups = [
        ("TP", df_copy.loc[is_pos & pred_pos].sort_values("MGMT_value_pred", ascending=False)),
        ("FP", df_copy.loc[is_neg & pred_pos].sort_values("MGMT_value_pred", ascending=False)),
        ("TN", df_copy.loc[is_neg & pred_neg].sort_values("MGMT_value_pred", ascending=True)),
        ("FN", df_copy.loc[is_pos & pred_neg].sort_values("MGMT_value_pred", ascending=True)),
    ]
    # ``squeeze=False`` keeps ``axes`` 2-D so ``axes[n, j]`` also works when n_pics == 1.
    _, axes = plt.subplots(n_pics, 4, figsize=(16, int(5 * n_pics)), squeeze=False)
    for j, (title, df_group) in enumerate(groups):
        for n in range(n_pics):
            cell = axes[n, j]
            if n >= len(df_group):
                # Fewer than n_pics samples in this category: leave the cell blank.
                cell.set_axis_off()
                continue
            row = df_group.iloc[n]
            voxel, plane = get_voxel_by_id_series(meta_df, row["BraTS21ID"], row["series_description"], voxel_size[1:])
            img = voxel[(len(voxel) - 1) // 2]  # middle slice along the first axis
            cell.set_title(f"{title}, likelihood of positive: {row['MGMT_value_pred']:.4f}\n"
                           f"type: {row['series_description']}, plane: {plane}.")
            cell.imshow(img, cmap=plt.cm.gray)
            cell.set_axis_off()
    plt.show()


# Deterministic validation-time augmentation for the training-set summary.
aug_list = get_augmentation(split="val")
# Whole-dataset checkpoint(s) first, then (optionally) the per-fold checkpoints.
models_path = sorted(models_dir.glob(f"{model_name}*whole*best-state_dict.pt"))
if USE_CROSS_VALIDATION:
    models_path += sorted(models_dir.glob(f"{model_name}*fold*best-state_dict.pt"))
print(models_path)
models_list = [
    load_model(weights_path, in_channels, feature_out_channels, hidden_features,
               n_classes, n_series, n_planes, residual_config)
    for weights_path in models_path
]
dl_train = DataLoader(ds_train, batch_size=batch_size, num_workers=num_workers)
train_performance(dl_train, models_list, aug_list, device)

In [None]:
@torch.no_grad()
def inference_by_models(loader: DataLoader, models_list: List[Net], aug_list, device):
    """Predict one MGMT probability per patient with an ensemble of models.

    Each model scores every row (one per patient series) of the dataset's
    ``label_df``. Scores are first averaged over a patient's rows for each model,
    then averaged across models.

    Args:
        loader: non-shuffling DataLoader over a dataset with a ``label_df``
            attribute, or a ``Subset`` of one. Predictions are written back by batch
            position, so the loader must yield samples in ``label_df`` row order.
        models_list: trained models, each called as ``model(voxel, series_desc, plane)``.
        aug_list: callable applied to every voxel batch before inference.
        device: device used for inference.

    Returns:
        DataFrame with columns ``BraTS21ID`` and ``MGMT_value``, one row per patient,
        ordered by ``BraTS21ID`` (``groupby`` default sort).

    Raises:
        ValueError: if ``models_list`` is empty. Otherwise the result would silently
            be an all-NaN submission.
    """
    if not models_list:
        raise ValueError("models_list must contain at least one model")

    dataset = loader.dataset
    if isinstance(dataset, Subset):
        df_copy: pd.DataFrame = dataset.dataset.label_df.iloc[dataset.indices].copy()
    else:
        df_copy = dataset.label_df.copy()
    batch_size = loader.batch_size
    pred_cols = [f"MGMT_value_{i}" for i in range(len(models_list))]
    for i, (model, col) in enumerate(zip(models_list, pred_cols)):
        df_copy[col] = np.full(len(df_copy), np.nan, dtype=np.float64)
        col_idx = df_copy.columns.get_loc(col)
        model.to(device)
        model.eval()
        for n, (voxel, _, (series_desc, plane)) in tqdm(enumerate(loader),
                                                        desc=f"Inferencing with model idx {i}", total=len(loader)):
            voxel, series_desc, plane = aug_list(voxel).to(device), series_desc.to(device), plane.to(device)
            out = model(voxel, series_desc, plane)
            pred_proba = torch.sigmoid(out.detach())[:, 0]
            start = n * batch_size
            df_copy.iloc[start:start + len(voxel), col_idx] = pred_proba.cpu().numpy()
    # Aggregate only the prediction columns. Averaging every column fails in
    # pandas >= 2 when ``label_df`` carries non-numeric columns.
    per_patient = df_copy.groupby("BraTS21ID")[pred_cols].mean()
    per_patient["MGMT_value"] = per_patient.mean(axis=1)
    submission = per_patient.reset_index().loc[:, ["BraTS21ID", "MGMT_value"]].copy()
    return submission


# Test-time augmentation pipeline and the ensemble used for the submission.
aug_list = get_augmentation(split="test")
models_path = sorted(models_dir.glob(f"{model_name}*whole*best-state_dict.pt"))
if USE_CROSS_VALIDATION:
    models_path += sorted(models_dir.glob(f"{model_name}*fold*best-state_dict.pt"))
print(models_path)
models_list = [
    load_model(weights_path, in_channels, feature_out_channels, hidden_features,
               n_classes, n_series, n_planes, residual_config)
    for weights_path in models_path
]
dl_test = DataLoader(ds_test, batch_size=batch_size, num_workers=num_workers)
submission = inference_by_models(dl_test, models_list, aug_list, device)
submission.to_csv("submission.csv", index=False)

In [None]:
# Histogram (50 bins) plus KDE of the submitted per-patient probabilities.
sns.displot(data=submission, x="MGMT_value", kde=True, bins=50)
plt.show()

In [None]:
# Wall-clock time since ``time_begin`` was recorded in the first cell.
time_end = time.time()
duration = time_end - time_begin
print("Kernel duration: {:.4f} seconds ({:.4f} Min, {:.4f} Hour).".format(
    duration, duration / 60., duration / 3600.))