I've trained a model for 3D tumor segmentation on the data from [Task1](https://www.kaggle.com/dschettler8845/brats-2021-task1). A [DenseVNet](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6076994/) model was used for the segmentation. You can reuse the model and the code provided here in your own pipeline for
MGMT classification.

## Model

In [None]:
"""Module with DenseVNet"""

import numpy as np
import torch


class DenseVNet(torch.nn.Module):
    """DenseVNet-style network for 3D segmentation.

    Three consecutive ``DownsampleWithDfs`` stages each halve the spatial
    resolution and emit a skip tensor. The skips of the second and third
    stage are upsampled by 2 and 4 so that all three can be concatenated,
    fused by a single convolution and upsampled by 2 once more.

    Args:
        in_channels: Number of channels of the input volume.
        out_channels: Number of channels of the returned tensor.
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 1):
        super().__init__()

        # One row per stage: (downsample kernel size, downsample channels,
        # skip channels, dense units, growth rate).
        stage_settings = [
            (5, 24, 12, 5, 4),
            (3, 24, 24, 10, 8),
            (3, 24, 24, 10, 16),
        ]

        stages = []
        stage_inputs = in_channels
        for kernel, down_channels, skip_channels, n_units, growth in stage_settings:
            stages.append(
                DownsampleWithDfs(
                    in_channels=stage_inputs,
                    downsample_channels=down_channels,
                    skip_channels=skip_channels,
                    kernel_size=kernel,
                    units=n_units,
                    growth_rate=growth,
                )
            )
            # A stage outputs its downsampled features plus everything the
            # dense units appended to them.
            stage_inputs = down_channels + n_units * growth
        self.dfs_blocks = torch.nn.ModuleList(stages)

        self.upsample_1 = torch.nn.Upsample(scale_factor=2, mode='trilinear')
        self.upsample_2 = torch.nn.Upsample(scale_factor=4, mode='trilinear')

        self.out_conv = ConvBlock(
            in_channels=sum(settings[2] for settings in stage_settings),
            out_channels=out_channels,
            kernel_size=3,
            batch_norm=True,
            preactivation=True,
        )
        self.upsample_out = torch.nn.Upsample(scale_factor=2, mode='trilinear')

    def forward(self, x):
        """Return the fused, upsampled prediction for the batch ``x``."""
        first_stage, second_stage, third_stage = self.dfs_blocks

        features, skip_1 = first_stage(x)
        features, skip_2 = second_stage(features)
        _, skip_3 = third_stage(features)

        # Bring the deeper skips to the spatial size of the first one; the
        # sizes only line up when the input dimensions survive three
        # halvings cleanly (e.g. are divisible by 8).
        skips = [skip_1, self.upsample_1(skip_2), self.upsample_2(skip_3)]

        fused = self.out_conv(torch.cat(skips, 1))
        return self.upsample_out(fused)


class ConvBlock(torch.nn.Module):
    """Zero-padded 3D convolution with ReLU and optional batch norm.

    Layer order inside ``self.conv``:

    * ``preactivation=True``:  [BatchNorm3d(in_channels)] -> ReLU -> pad -> Conv3d
    * ``preactivation=False``: pad -> Conv3d -> [BatchNorm3d(out_channels)] -> ReLU

    The total zero padding per axis is ``kernel_size - stride``. An even
    total is split equally; an odd total puts ``total % 2`` voxels in front
    and the rest behind.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the convolution.
        kernel_size: Cubic kernel edge length (an int).
        dilation: Only ``1`` is supported.
        stride: Convolution stride.
        batch_norm: Whether to include the batch-norm layer.
        preactivation: Whether normalisation/activation precede the conv.

    Raises:
        NotImplementedError: If ``dilation`` is not 1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        batch_norm=True,
        preactivation=False,
    ):
        super().__init__()

        if dilation != 1:
            raise NotImplementedError()

        total_padding = kernel_size - stride
        leading = total_padding % 2
        if leading:
            # Asymmetric (front, back) padding repeated for the three axes.
            pad_spec = (leading, total_padding - leading) * 3
        else:
            pad_spec = total_padding // 2
        pad_layer = torch.nn.ConstantPad3d(pad_spec, 0)

        conv_layer = torch.nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

        # Pre-activation normalises the input, post-activation the output.
        norm_channels = in_channels if preactivation else out_channels
        norm_layers = [torch.nn.BatchNorm3d(norm_channels)] if batch_norm else []

        if preactivation:
            layers = [*norm_layers, torch.nn.ReLU(), pad_layer, conv_layer]
        else:
            layers = [pad_layer, conv_layer, *norm_layers, torch.nn.ReLU()]

        self.conv = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer sequence to ``x``."""
        return self.conv(x)


class DenseFeatureStack(torch.nn.Module):
    """Stack of densely connected pre-activated ``ConvBlock`` units.

    Every unit receives the channel-wise concatenation of the stack input
    and the outputs of all previous units, and contributes ``growth_rate``
    new channels. The result therefore has
    ``in_channels + units * growth_rate`` channels.

    Args:
        in_channels: Channels of the tensor entering the stack.
        units: Number of convolution units.
        growth_rate: Channels added by each unit.
        kernel_size: Kernel size used by every unit.
        dilation: Forwarded to ``ConvBlock``.
        batch_norm: Forwarded to ``ConvBlock``.
        batchwise_spatial_dropout: Not supported.

    Raises:
        NotImplementedError: If ``batchwise_spatial_dropout`` is set and at
            least one unit is requested.
    """

    def __init__(
        self,
        in_channels,
        units,
        growth_rate,
        kernel_size,
        dilation=1,
        batch_norm=True,
        batchwise_spatial_dropout=False,
    ):
        super().__init__()

        blocks = []
        for index in range(units):
            if batchwise_spatial_dropout:
                raise NotImplementedError

            blocks.append(
                ConvBlock(
                    # Input width grows by one growth_rate per earlier unit.
                    in_channels=in_channels + index * growth_rate,
                    out_channels=growth_rate,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    stride=1,
                    batch_norm=batch_norm,
                    preactivation=True,
                )
            )
        self.units = torch.nn.ModuleList(blocks)

    def forward(self, x):
        """Return ``x`` concatenated with the output of every unit."""
        collected = [x]

        for block in self.units:
            collected.append(block(torch.cat(collected, 1)))

        return torch.cat(collected, 1)


class DownsampleWithDfs(torch.nn.Module):
    """Strided convolution, dense feature stack and a skip projection.

    Args:
        in_channels: Channels of the incoming tensor.
        downsample_channels: Channels produced by the stride-2 convolution.
        skip_channels: Channels of the skip output.
        kernel_size: Kernel size of the stride-2 convolution.
        units: Number of units in the dense feature stack.
        growth_rate: Channels added by each dense unit.
    """

    def __init__(
        self,
        in_channels,
        downsample_channels,
        skip_channels,
        kernel_size,
        units,
        growth_rate,
    ):
        super().__init__()

        # Width of the tensor leaving the dense feature stack.
        stacked_channels = downsample_channels + units * growth_rate

        self.downsample = ConvBlock(
            in_channels=in_channels,
            out_channels=downsample_channels,
            kernel_size=kernel_size,
            stride=2,
            batch_norm=True,
            preactivation=True,
        )
        self.dfs = DenseFeatureStack(
            in_channels=downsample_channels,
            units=units,
            growth_rate=growth_rate,
            kernel_size=3,
            batch_norm=True,
        )
        self.skip = ConvBlock(
            in_channels=stacked_channels,
            out_channels=skip_channels,
            kernel_size=3,
            batch_norm=True,
            preactivation=True,
        )

    def forward(self, x):
        """Return ``(dense features, skip projection)`` for ``x``."""
        dense_features = self.dfs(self.downsample(x))
        return dense_features, self.skip(dense_features)


## Model Evaluator

In [None]:
"""Module with model evaluator"""

from typing import Tuple

import numpy as np
import scipy.ndimage
import torch


class ModelEvaluator:
    """Run a segmentation model on a single 3D volume.

    The volume is resized to ``image_size``, clipped to
    ``[min_value, max_value]``, scaled to ``[0, 1]`` and passed through the
    model; a sigmoid is applied to the model output.

    Args:
        model: Module to evaluate. It is switched to eval mode and moved to
            ``device``. It is called with one positional tensor of shape
            ``(1, 1, *image_size)``.
        min_value: Lower clipping bound for the input intensities.
        max_value: Upper clipping bound for the input intensities.
        image_size: Spatial size the volume is resized to before inference.
        device: Device used for inference.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        min_value: int = -200,
        max_value: int = 200,
        image_size: Tuple[int, int, int] = (184, 184, 128),
        device: torch.device = torch.device('cpu'),
    ):
        self.model = model
        self.model.eval()
        self.model.to(device)

        self.min_value = min_value
        self.max_value = max_value
        # Stored as a tuple so it compares equal to ``ndarray.shape`` even
        # when the caller passed a list.
        self.image_size = tuple(image_size)

        self.device = device

    def evaluate(self, image: np.ndarray) -> np.ndarray:
        """Segment ``image`` and return the sigmoid of the model output.

        Args:
            image: 3D volume of arbitrary spatial size.

        Returns:
            The model output after sigmoid as a NumPy array with the batch
            dimension removed.
        """
        image = self._resize_ct(ct=image, ct_size=self.image_size)
        image_tensor = self.preprocess_image(
            image=image,
            min_value=self.min_value,
            max_value=self.max_value,
        )

        # Add the batch axis and cast to float32 *before* the transfer so a
        # float64 volume is not copied to the device at double size.
        image_tensor = image_tensor.unsqueeze(0).float().to(self.device)

        # Inference only: skip building the autograd graph, which would
        # otherwise keep every intermediate activation of the 3D network.
        with torch.no_grad():
            mask = torch.sigmoid(self.model(image_tensor))

        return mask.detach().cpu().numpy()[0]

    @staticmethod
    def preprocess_image(
        image: np.ndarray,
        min_value: int = -200,
        max_value: int = 200,
    ) -> torch.Tensor:
        """Clip ``image``, scale it to ``[0, 1]`` and add a channel axis.

        Returns:
            Tensor with shape ``(1, *image.shape)``.
        """
        image = np.clip(image, min_value, max_value)
        image = (image - min_value) / (max_value - min_value)

        return torch.tensor(image).unsqueeze(0)

    @staticmethod
    def _resize_ct(
        ct: np.ndarray, ct_size: Tuple[int, int, int] = (512, 512, 256)
    ) -> np.ndarray:
        """Resample ``ct`` to ``ct_size`` with ``scipy.ndimage.zoom``.

        Raises:
            AssertionError: If the resampled shape differs from ``ct_size``.
        """
        ct_size = tuple(ct_size)
        zoom_factor = [
            target / current for target, current in zip(ct_size, ct.shape)
        ]
        res_ct = scipy.ndimage.zoom(ct, zoom=zoom_factor)

        # zoom() rounds the output shape itself; verify it hit the target.
        assert res_ct.shape == ct_size, f'Bad result size: {res_ct.shape}!={ct_size}'

        return res_ct


## Dicom loading functions

In [None]:
from pathlib import Path
from typing import Dict, List, Tuple, Union

import numpy as np
import nibabel as nib
import SimpleITK as sitk


def _load_and_align_ct(dicom_folder_path: Union[str, Path]) -> np.ndarray:
    """Load a DICOM series and return its voxels reoriented to LPS.

    Args:
        dicom_folder_path: Folder containing the DICOM series.

    Returns:
        The reoriented volume as a float64 array.
    """
    image_and_meta = _load_dicom(dicom_folder_path=dicom_folder_path)

    # SimpleITK hands the voxels back with the axes reversed, i.e. indexed
    # (z, y, x), while the affine from _make_affine maps (x, y, z) indices
    # to physical space. Reverse the axes so data and affine agree before
    # nibabel derives the orientation from the affine.
    image = np.transpose(sitk.GetArrayFromImage(image_and_meta))
    affine = _make_affine(image_and_meta=image_and_meta)

    ct_all_info = nib.Nifti1Image(image, affine)
    aligned = _align_ct(ct_all_info=ct_all_info)

    return aligned.get_fdata(dtype=np.float64)

def _load_dicom(dicom_folder_path: Union[str, Path]) -> sitk.Image:
    """Read the first DICOM series found in ``dicom_folder_path``.

    SimpleITK warnings are disabled globally as a side effect.

    Args:
        dicom_folder_path: Folder containing the DICOM files.

    Returns:
        The series as a ``sitk.Image`` with private tags loaded.

    Raises:
        FileNotFoundError: If no DICOM series is found in the folder.
    """
    sitk.ProcessObject_SetGlobalWarningDisplay(False)

    folder = str(dicom_folder_path)
    series_ids = sitk.ImageSeriesReader.GetGDCMSeriesIDs(directory=folder)
    if not series_ids:
        # Without this check the failure surfaces as a bare IndexError on
        # ``series_ids[0]`` that says nothing about the offending folder.
        raise FileNotFoundError(f'No DICOM series found in {folder!r}')

    # Only the first series is read when the folder holds several.
    series_file_names = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(
        folder, series_ids[0]
    )
    series_reader = sitk.ImageSeriesReader()
    series_reader.SetFileNames(series_file_names)
    series_reader.LoadPrivateTagsOn()

    return series_reader.Execute()
    

def _align_ct(ct_all_info: nib.Nifti1Image) -> nib.Nifti1Image:
    """Return ``ct_all_info`` reoriented to the 'LPS' axis codes."""
    current_ornt = nib.io_orientation(ct_all_info.affine)
    wanted_ornt = nib.orientations.axcodes2ornt(axcodes='LPS')

    return ct_all_info.as_reoriented(
        ornt=nib.orientations.ornt_transform(
            start_ornt=current_ornt, end_ornt=wanted_ornt
        )
    )
def _make_affine(image_and_meta: sitk.Image):
    """Build a 4x4 index-to-physical affine for ``image_and_meta``.

    The affine is derived from the physical positions of the index origin
    and the three unit index steps, and then converted from LPS to RAS to
    match nibabel.
    """
    index_points = ((1, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, 0))
    physical = np.array(
        [
            image_and_meta.TransformContinuousIndexToPhysicalPoint(point)
            for point in index_points
        ]
    )

    origin = physical[3:]
    axis_steps = physical[0:3] - origin

    # Rows: the three per-axis step vectors followed by the origin; the
    # appended column becomes the homogeneous bottom row after transposing.
    rows = np.concatenate([axis_steps, origin], axis=0)
    lps_affine = np.concatenate(
        [rows, [[0.], [0.], [0.], [1.]]], axis=1
    ).T

    # Flip the first two physical axes: LPS -> RAS.
    lps_to_ras = np.diag([-1., -1., 1., 1.])
    return np.matmul(lps_to_ras, lps_affine)


## Model and image loading

In [None]:
def rename_keys(
    state_dict: Dict[str, torch.Tensor], prefix: str = 'model.'
) -> Dict[str, torch.Tensor]:
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed.

    Only a prefix at the very start of a key is stripped; keys that merely
    contain the text elsewhere (e.g. ``'submodel.weight'``) are left intact.

    Args:
        state_dict: Mapping from parameter names to tensors.
        prefix: Text to strip from the start of every key.

    Returns:
        A new dict with the renamed keys and the same values.
    """
    cut = len(prefix)
    return {
        (name[cut:] if name.startswith(prefix) else name): weights
        for name, weights in state_dict.items()
    }


In [None]:
# Input FLAIR DICOM series to segment and the trained checkpoint to load.
flair_dicom_path = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/test/00013/FLAIR'
model_path = '../input/densevnettumorsegmentaion/DenseVNet_27epoch_best.ckpt'

# Channel counts used to build DenseVNet; they must match the checkpoint.
in_channels = 1
out_channels = 2
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load the FLAIR volume and the checkpoint.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
image = _load_and_align_ct(dicom_folder_path=flair_dicom_path)
model_meta = torch.load(model_path, map_location='cpu')

# The weights live under 'state_dict'; their keys carry a 'model.' prefix
# that DenseVNet's own parameter names do not have.
model_state_dict = model_meta['state_dict']
model_state_dict = rename_keys(state_dict=model_state_dict)

model = DenseVNet(in_channels=in_channels, out_channels=out_channels)
model.load_state_dict(model_state_dict)
model.to(device=device)
model.eval()

# Intensities are clipped to [min_value, max_value] and scaled to [0, 1]
# before inference.
model_evaluator = ModelEvaluator(
    model=model,
    device=device,
    min_value=-200,
    max_value=2500,
)

## Mask calculation and visualization

In [None]:
# Segment the volume; the result is the sigmoid of the model output with
# the batch axis removed, i.e. one 3D map per output channel.
mask = model_evaluator.evaluate(image=image)

In [None]:
# ``import cv2`` binds the same name as the former ``from cv2 import cv2``
# but also works on OpenCV builds that no longer expose a ``cv2.cv2``
# submodule.
import cv2
import matplotlib.pyplot as plt
from matplotlib import animation, rc

# Render animations as interactive JavaScript/HTML in the notebook output.
rc('animation', html='jshtml')


def create_animation(
    image,
    mask,
    min_value: int = -200,
    max_value: int = 200,
    edge: float = 0.5,
    image_size: Tuple[int, int] = (512, 512),
):
    """Animate the slices of ``image`` with ``mask`` blended on top.

    Args:
        image: 3D volume; slices are taken along the last axis.
        mask: 3D array indexed like ``image``.
        min_value: Lower intensity bound used to normalise the image.
        max_value: Upper intensity bound used to normalise the image.
        edge: Threshold above which a mask value counts as foreground.
        image_size: Size every rendered frame is resized to.

    Returns:
        A ``matplotlib.animation.FuncAnimation`` with one frame per slice.
    """

    def render_slice(slice_idx):
        # Single rendering path for the first and all later frames, so every
        # frame honours ``image_size`` (later frames used to fall back to
        # the default size).
        return combine_image_and_mask_slice(
            image=image,
            mask=mask,
            slice_idx=slice_idx,
            min_value=min_value,
            max_value=max_value,
            edge=edge,
            image_size=image_size,
        )

    first_frame = render_slice(0)

    fig = plt.figure(figsize=(6, 6))
    plt.axis('off')
    im = plt.imshow(first_frame, vmin=0, vmax=255)

    def animate_func(i):
        im.set_array(render_slice(i))
        return [im]

    return animation.FuncAnimation(
        fig,
        animate_func,
        frames=image.shape[2],
        interval=1000 // 24,
    )


def combine_image_and_mask_slice(
    image: np.ndarray,
    mask: np.ndarray,
    slice_idx: int,
    min_value: int = -200,
    max_value: int = 200,
    edge: float = 0.5, 
    image_size: Tuple[int, int] = (512, 512),
):
    """Blend one slice of ``image`` with the matching slice of ``mask``.

    Both arrays are sliced along their last axis at ``slice_idx``, turned
    into 3-channel uint8 pictures of size ``image_size`` and mixed with
    weights 0.7 (image) and 0.3 (mask).
    """
    image_plane, mask_plane = image[:, :, slice_idx], mask[:, :, slice_idx]

    background = preprocess_image(
        image=image_plane,
        min_value=min_value,
        max_value=max_value,
        image_size=image_size,
    )
    overlay = preprocess_mask(
        mask=mask_plane,
        edge=edge,
        image_size=image_size,
    )

    return cv2.addWeighted(background, 0.7, overlay, 0.3, 0.0)
    
def preprocess_image(
    image: np.ndarray,
    min_value: int = -200,
    max_value: int = 200,
    image_size: Tuple[int, int] = (512, 512),
) -> np.ndarray:
    """Convert a 2D slice into a 3-channel uint8 picture.

    The slice is clipped to ``[min_value, max_value]``, scaled to
    ``[0, 1]``, resized to ``image_size``, mapped to 0..255 and expanded
    from grey to three channels.
    """
    clipped = np.clip(image, min_value, max_value)
    normalized = (clipped - min_value) / (max_value - min_value)
    resized = cv2.resize(src=normalized, dsize=image_size)

    grey = np.uint8(resized * 255)
    return cv2.cvtColor(grey, cv2.COLOR_GRAY2BGR)

def preprocess_mask(
    mask: np.ndarray,
    edge: float = 0.5,
    image_size: Tuple[int, int] = (512, 512),
) -> np.ndarray:
    """Convert a 2D mask slice into a 3-channel uint8 overlay.

    Values above ``edge`` become 255, everything else 0. The result is
    resized to ``image_size`` and placed in the first channel; the other
    two channels are zero.
    """
    binary = np.uint8(mask > edge)
    scaled = np.uint8(binary * 255)
    first_channel = cv2.resize(src=scaled, dsize=image_size)[..., np.newaxis]

    empty_channel = np.zeros_like(first_channel)
    return np.concatenate(
        [first_channel, empty_channel, empty_channel],
        axis=-1,
    )
    

In [None]:
def _resize_ct(
    ct: np.ndarray, ct_size: Tuple[int, int, int] = (512, 512, 256)
) -> np.ndarray:
    zoom_factor = [
        first_value / second_value
        for first_value, second_value in zip(ct_size, ct.shape)
    ]
    res_ct = scipy.ndimage.zoom(ct, zoom=zoom_factor)

    assert res_ct.shape == ct_size, f'Bad result size: {res_ct.shape}!={ct_size}'

    return res_ct

In [None]:
# Resize the volume to the spatial size the evaluator used, so image and
# mask slices line up.
image_to_draw = _resize_ct(
    ct=image,
    ct_size=(184, 184, 128)
)

# mask[1] is the second of the model's two output channels.
# NOTE(review): presumably the tumor channel - confirm against training.
create_animation(
    image=image_to_draw, 
    mask=mask[1],
    min_value=image.min(),
    max_value=image.max()
)

## Conclusion

You can use the resulting masks as an additional input channel for your classification model, or in some other way. Have fun!