## Use stacked images (3D) and simple 3D CNN model

This is a variation of https://www.kaggle.com/rluethy/efficientnet3d-with-one-mri-type using a simple 3D CNN and 5-fold cross validation. Also, a threshold is applied to the images.

Acknowledgements:

- https://www.kaggle.com/ihelon/brain-tumor-eda-with-animations-and-modeling
- https://www.kaggle.com/furcifer/torch-efficientnet3d-for-mri-no-train
- https://www.kaggle.com/davidbroberts/adjusting-contrast-on-mr-images
 
    
Train one model per MRI type, then ensemble the four models.

The resulting models were used in the 26th-ranked submission on the private LB: https://www.kaggle.com/rluethy/predict-only-simple-3d-cnn-with-one-mri-type

In [None]:
import os
import sys 
import json
import glob
import random
import collections
import time
import re
import numpy as np
import pandas as pd
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional
import torch.nn.functional as F

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, log_loss

## Functions to load images

In [None]:
# Load and process images

import os
import glob
import re
import pickle

import numpy as np
import pandas as pd
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
import matplotlib.pyplot as plt
from IPython.display import HTML
from base64 import b64encode
import matplotlib.animation as animation

# Settings

# Use the Kaggle input layout when the competition data is mounted there,
# otherwise fall back to a local copy of the data.
if os.path.exists("../input/rsna-miccai-brain-tumor-radiogenomic-classification"):
    data_directory = '../input/rsna-miccai-brain-tumor-radiogenomic-classification'
    modelpath = "../input/brain-tumor-models"
else:
    data_directory = '/media/roland/data/kaggle/rsna-miccai-brain-tumor-radiogenomic-classification'

# Scratch directory for cached volumes, checkpoints and the preview video.
working_dir = "/tmp/rsna"
# makedirs(exist_ok=True) also creates missing parents and does not race with an exists() check.
os.makedirs(working_dir, exist_ok=True)

mri_types = ['FLAIR', 'T1w', 'T1wCE', 'T2w']  # one model is trained per type
SIZE = 256  # in-plane size (pixels) every slice is resized to
NUM_IMAGES = 32  # number of slices stacked into one volume
USE_IMAGE_CACHE = False  # pickle loaded volumes into working_dir
USE_LUT_CONTRAST = None  # {"window_width": 2000, "window_level": 2000}
USE_VOI_LUT = True  # apply the DICOM VOI LUT when reading each slice


def find_crop_area(imgfiles):
    """Return the union bounding box ``(x, y, w, h)`` of non-zero pixels over several slices.

    Each DICOM slice is min-max scaled to uint8 and passed to ``cv2.boundingRect``.
    The per-slice boxes are merged into one box that covers every slice. Constant
    (blank) slices are skipped. If no slice yields a box, the full frame of the
    last slice read is returned.

    Args:
        imgfiles: sequence of DICOM file paths belonging to one series.

    Returns:
        ``(x, y, w, h)`` in pixels, in the layout ``load_dicom_image`` expects:
        columns ``x:x + w`` and rows ``y:y + h``.

    Raises:
        ValueError: if *imgfiles* is empty.
    """
    x1 = y1 = None
    x2 = y2 = 0
    shape = None

    for f in imgfiles:
        # float64 avoids int16 overflow when subtracting the minimum.
        data = pydicom.dcmread(f).pixel_array.astype(np.float64)
        shape = data.shape
        lo, hi = np.min(data), np.max(data)
        if hi <= lo:
            continue  # blank slice: no foreground to bound
        scaled = ((data - lo) / (hi - lo) * 255).astype(np.uint8)
        x, y, w, h = cv2.boundingRect(scaled)
        if w > 0 and h > 0:
            x1 = x if x1 is None else min(x1, x)
            y1 = y if y1 is None else min(y1, y)
            x2 = max(x2, x + w)
            y2 = max(y2, y + h)

    # The original tested `y2 > y2` (always False), so the crop was never applied.
    if x1 is not None and x2 > x1 and y2 > y1:
        return x1, y1, x2 - x1, y2 - y1
    if shape is None:
        raise ValueError("find_crop_area() needs at least one image file")
    # Full frame: width is the column count (shape[1]), height the row count (shape[0]).
    return 0, 0, shape[1], shape[0]


def load_dicom_image(path, img_size=SIZE, crop_area=None, voi_lut=USE_VOI_LUT, rotate=0):
    """Read one DICOM slice, optionally crop and rotate it, and resize it to a square.

    Args:
        path: DICOM file path.
        img_size: output width and height in pixels.
        crop_area: optional ``(x, y, w, h)`` box, e.g. from ``find_crop_area``.
        voi_lut: apply the file's VOI LUT via ``apply_voi_lut`` before cropping.
        rotate: 0 = none, 1 = 90 deg clockwise, 2 = 90 deg counter-clockwise, 3 = 180 deg.

    Returns:
        The ``(img_size, img_size)`` image array produced by ``cv2.resize``.
    """
    # dcmread replaces the deprecated read_file alias (removed in pydicom 3).
    dicom = pydicom.dcmread(path)
    data = apply_voi_lut(dicom.pixel_array, dicom) if voi_lut else dicom.pixel_array

    if crop_area:
        x, y, w, h = crop_area
        data = data[y:y + h, x:x + w]
    if rotate > 0:
        rot_choices = [0, cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]
        data = cv2.rotate(data, rot_choices[rotate])

    try:
        data = cv2.resize(data, (img_size, img_size))
    except Exception as exc:
        print(exc)
        print(path)
        print(crop_area)
        raise  # bare raise keeps the original traceback
    return data


def get_image_plane(img_path):
    """Classify the acquisition plane of a DICOM slice from ImageOrientationPatient.

    Ref: https://www.kaggle.com/davidbroberts/determining-mr-image-planes

    Args:
        img_path: path to a DICOM file. Alternatively, an already-loaded dataset
            (any object with an ``ImageOrientationPatient`` attribute), which
            avoids reading the file a second time.

    Returns:
        ``(0, "Coronal")``, ``(1, "Sagittal")`` or ``(2, "Axial")``. Returns
        ``(-1, "Unknown")`` when the rounded direction cosines match none of them.
        The unknown case used to return the bare string "Unknown", which broke
        callers that unpack the result.
    """
    if hasattr(img_path, "ImageOrientationPatient"):
        dicom = img_path
    else:
        dicom = pydicom.dcmread(img_path)
    loc = dicom.ImageOrientationPatient

    # Only the x/y components of the row and column direction cosines matter here.
    key = (round(loc[0]), round(loc[1]), round(loc[3]), round(loc[4]))
    planes = {
        (1, 0, 0, 0): (0, "Coronal"),
        (0, 1, 0, 0): (1, "Sagittal"),
        (1, 0, 0, 1): (2, "Axial"),
    }
    return planes.get(key, (-1, "Unknown"))


def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=SIZE, mri_type="FLAIR",
                         split="train", lut_contrast=USE_LUT_CONTRAST,  # {"window_width": 1000, "window_level": 2000},
                         offset=0, voi_lut=USE_VOI_LUT, rotate=0, use_cache=USE_IMAGE_CACHE,
                         threshold=0):
    """Load consecutive slices around the middle of one MRI series as a 3D volume.

    Steps, in order:
    1. Select up to ``num_imgs`` files centred on the middle file, shifted by ``offset``.
    2. Crop them to their common non-zero box, rotate, and resize to ``img_size`` x ``img_size``.
    3. Zero-pad to ``num_imgs`` slices.
    4. Apply an optional linear LUT.
    5. Min-max scale to [0, 1] (skipped for a constant volume).
    6. Zero values below ``threshold``.

    Args:
        scan_id: five-digit scan folder name, e.g. "00000".
        num_imgs: number of slices in the returned volume.
        img_size: in-plane output size in pixels.
        mri_type: series sub-folder, e.g. "FLAIR".
        split: data sub-folder, e.g. "train" or "test".
        lut_contrast: optional dict with "window_width" and "window_level"; when
            set, a LUT from ``make_lut`` is applied before scaling.
        offset: shift of the window centre. It is ignored when the centre would
            land within 5 files of either end.
        voi_lut: apply the DICOM VOI LUT to each slice.
        rotate: rotation choice forwarded to ``load_dicom_image`` (0 = none).
        use_cache: read/write the finished volume as a pickle in ``working_dir``.
        threshold: after scaling, values below this become 0 (0 disables).

    Returns:
        Array of shape ``(1, img_size, img_size, num_imgs)``.
    """
    # The key covers every argument that changes the result. Previously it used
    # the global SIZE instead of img_size, and omitted split and threshold, so
    # stale volumes could be returned.
    cfn = (f"{working_dir}/{scan_id}_{split}_{mri_type}_{img_size}_{num_imgs}_{offset}_{rotate}"
           f"_t{threshold}{'L' if voi_lut else ''}{'C' if lut_contrast else ''}.pkl")
    if use_cache and os.path.exists(cfn):
        # NOTE: pickle is only acceptable because this function wrote the file;
        # never point working_dir at untrusted data.
        with open(cfn, "rb") as fh:
            img3d = pickle.load(fh)
    else:
        # Natural sort so that "Image-10" comes after "Image-9".
        files = sorted(glob.glob(f"{data_directory}/{split}/{scan_id}/{mri_type}/*.dcm"),
                       key=lambda var: [int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])

        assert len(files) > 0, f"no image files for {data_directory}/{split}/{scan_id}/{mri_type}/*.dcm"
        middle = len(files) // 2 + offset
        if (middle <= 5) or (middle >= (len(files) - 5)):
            middle = len(files) // 2
        num_imgs2 = num_imgs // 2
        p1 = max(0, middle - num_imgs2)
        p2 = min(len(files), middle + num_imgs2)

        crop_area = find_crop_area(files[p1:p2])
        # Honour img_size; the global SIZE used here before broke the padding
        # below whenever img_size != SIZE.
        img3d = np.stack([load_dicom_image(f, img_size, crop_area, voi_lut, rotate) for f in files[p1:p2]]).T
        if img3d.shape[-1] < num_imgs:
            n_zero = np.zeros((img_size, img_size, num_imgs - img3d.shape[-1]))
            img3d = np.concatenate((img3d, n_zero), axis=-1)

        if lut_contrast:
            img3d = img3d - np.min(img3d)
            img3d = img3d / np.max(img3d)
            img3d = lut_contrast["window_level"] * img3d
            lut = make_lut(img3d, windowWidth=lut_contrast["window_width"], windowLevel=lut_contrast["window_level"])
            img3d = np.reshape(apply_lut(img3d, lut), (img3d.shape[0], img3d.shape[1], img3d.shape[2]))
        if np.min(img3d) < np.max(img3d):
            img3d = img3d - np.min(img3d)
            img3d = img3d / np.max(img3d)

        if threshold > 0:
            img3d[img3d < threshold] = 0

        if use_cache:
            with open(cfn, "wb") as fh:
                pickle.dump(img3d, fh)

    return np.expand_dims(img3d, 0)


# Adjusting Contrast on MR Images
# https://www.kaggle.com/davidbroberts/adjusting-contrast-on-mr-images
# Make a simple linear VOI LUT from the raw (stored) pixel data
def make_lut(storedPixels, windowWidth, windowLevel, p_i="MONOCHROME2"):
    # Slope and Intercept set to 1 and 0 for MR. Get these from DICOM tags instead if using
    # on a modality that requires them (CT, PT etc)
    slope = 1.0
    intercept = 0.0
    minPixel = int(np.amin(storedPixels))
    maxPixel = int(np.amax(storedPixels))

    # Make an empty array for the LUT the size of the pixel 'width' in the raw pixel data
    lut = [0] * (maxPixel + 1)

    # Invert pixels and windowLevel for MONOCHROME1. We invert the specified windowLevel so that 
    # increasing the level value makes the images brighter regardless of photometric intrepretation
    invert = False
    if p_i == "MONOCHROME1":
        invert = True
    else:
        windowLevel = (maxPixel - minPixel) - windowLevel

    # Loop through the pixels and calculate each LUT value
    for storedValue in range(minPixel, maxPixel):
        modalityLutValue = storedValue * slope + intercept
        voiLutValue = (((modalityLutValue - windowLevel) / windowWidth + 0.5) * 255.0)
        clampedValue = min(max(voiLutValue, 0), 255)
        if invert:
            lut[storedValue] = round(255 - clampedValue)
        else:
            lut[storedValue] = round(clampedValue)

    return lut


# Apply the LUT to a pixel array
def apply_lut(pixels_in, lut):
    """Map every pixel through *lut* and return the results as a flat list of ints.

    Pixels are flattened in C order and truncated toward zero (like ``int()``)
    to form indices into *lut*.

    Args:
        pixels_in: array-like of pixel values of any shape.
        lut: sequence of output values indexed by pixel value, e.g. from ``make_lut``.

    Returns:
        List of Python ints, one per input pixel.
    """
    # Vectorised fancy indexing replaces a per-pixel Python loop. The loop cost
    # seconds per 256x256x32 volume; the output is the same.
    indices = np.asarray(pixels_in).flatten().astype(int)
    return np.asarray(lut)[indices].astype(int).tolist()


# Embed an mp4 file inline in the notebook
def play(filename):
    """Return an IPython ``HTML`` object that plays the mp4 at *filename* inline.

    The video is base64-encoded into a data URI and shown autoplaying and looping.
    """
    # Context manager closes the handle; the bare open().read() leaked it.
    with open(filename, 'rb') as fh:
        video = fh.read()
    src = 'data:video/mp4;base64,' + b64encode(video).decode()
    html = '<video width=500 controls autoplay loop><source src="%s" type="video/mp4"></video>' % src
    return HTML(html)


def create_video(imgs, output=f'{working_dir}/vis_video.mp4', frame_delay=200):
    """Save the slices of a stacked volume as an animated video and return its path.

    Args:
        imgs: volume indexed as ``imgs[0, :, :, slice]``, e.g. the output of
            ``load_dicom_images_3d``.
        output: destination file path.
        frame_delay: milliseconds between frames.
    """
    fig, ax = plt.subplots(figsize=(15, 10))
    # One single-artist frame per slice along the last axis.
    frames = [
        [ax.imshow(imgs[0, :, :, k], animated=True, cmap='gray')]
        for k in range(imgs.shape[-1])
    ]
    # Close the figure so the notebook does not also render it as a static image.
    plt.close(fig)
    anim = animation.ArtistAnimation(fig, frames, interval=frame_delay, blit=True, repeat_delay=1000)
    anim.save(output)
    return output


# Smoke test: load training scan 00000 with default settings (threshold=0).
# In a notebook __name__ is "__main__", so this always runs.
if __name__ == "__main__":
    a = load_dicom_images_3d("00000")
    print(a.shape)
    print(np.min(a), np.max(a), np.mean(a), np.median(a))

# Intensity cut-off passed to load_dicom_images_3d by the Dataset class.
# Voxels below it (after min-max scaling) are set to 0.
THRESHOLD = 0.5

# IPython shell escape: delete cached volume pickles from working_dir.
# These only exist when the cache is enabled.
! rm {working_dir}/*.pkl
# Preview scan 00144 with the window centre shifted two slices back and the
# threshold applied, rendered as an inline video.
a = load_dicom_images_3d("00144", offset=-2, threshold=THRESHOLD)
print(a.shape)
print(np.min(a), np.max(a), np.mean(a), np.median(a))

play(create_video(a))

In [None]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs (plus all CUDA devices, if present).

    Also exports PYTHONHASHSEED; only interpreters started after this call read it.
    On CUDA machines, cuDNN is switched to deterministic algorithms.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


set_seed(42)

## Dataset class

In [None]:
class Dataset(torch_data.Dataset):
    """Map-style dataset that yields one stacked MRI volume per scan.

    Args:
        paths: sequence of BraTS21 scan ids (zero-padded to five digits when loaded).
        targets: optional sequence of 0/1 labels. When given, items carry a smoothed "y".
        mri_type: sequence of MRI type names, indexed in lockstep with *paths*.
        label_smoothing: a label of 1 becomes ``1 - label_smoothing``; a label of 0
            becomes ``label_smoothing``.
        split: "test" reads the test folder. "valid" reads the train folder without
            augmentation. Anything else is treated as training data.
        augment: for training data, pick a random rotation choice (0-3) per item.
    """

    def __init__(self, paths, targets=None, mri_type=None, label_smoothing=0.01, split="train", augment=False):
        self.paths = paths
        self.targets = targets
        self.mri_type = mri_type
        self.label_smoothing = label_smoothing
        self.split = split
        self.augment = augment

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        scan_id = self.paths[index]

        # Only training items are ever augmented.
        is_training = self.split not in ("test", "valid")
        rotation = np.random.randint(0, 4) if (is_training and self.augment) else 0
        # Validation scans live in the train folder.
        source_split = "test" if self.split == "test" else "train"

        data = load_dicom_images_3d(
            str(scan_id).zfill(5),
            mri_type=self.mri_type[index],
            split=source_split,
            offset=0,
            rotate=rotation,
            threshold=THRESHOLD,
        )
        volume = torch.tensor(data).float()

        if self.targets is None:
            return {"X": volume, "id": scan_id}
        smoothed = abs(self.targets[index] - self.label_smoothing)
        return {"X": volume, "y": torch.tensor(smoothed, dtype=torch.float)}




## Model class

In [None]:
class Model(nn.Module):
    """Small 3D CNN that produces ``num_classes`` logits per volume.

    Architecture:
    - Three ``Conv3d(3x3x3, no padding) -> LeakyReLU -> MaxPool3d(2)`` stages
      with 8, 16 and 32 channels.
    - A 128-unit ``Linear -> LeakyReLU -> BatchNorm1d`` layer.
    - A final ``Linear`` head.

    Args:
        num_classes: number of output logits.
        num_channels: input channels.
        img_size: size of the first two spatial dims of the input. Defaults to the global SIZE.
        num_images: size of the last spatial dim. Defaults to the global NUM_IMAGES.

    Input is expected as ``(batch, num_channels, img_size, img_size, num_images)``.
    Parameter names and creation order match the original layout, so existing
    checkpoints still load.

    Raises:
        ValueError: if the input is too small to survive the three conv/pool stages.
    """

    def __init__(self, num_classes=1, num_channels=1, img_size=None, num_images=None):
        super(Model, self).__init__()
        d1 = SIZE if img_size is None else img_size
        d2 = NUM_IMAGES if num_images is None else num_images

        channels = [num_channels, 8, 16, 32]
        self.conv_layers = nn.ModuleList()
        for in_c, out_c in zip(channels[:-1], channels[1:]):
            self.conv_layers.append(self._conv_layer_set(in_c, out_c))
            # Each stage loses 2 to the unpadded 3x3x3 conv, then halves in the pool.
            d1 = int((d1 - 2) / 2)
            d2 = int((d2 - 2) / 2)
        if d1 <= 0 or d2 <= 0:
            raise ValueError(f"input too small for three conv/pool stages (feature map {d1}x{d1}x{d2})")

        self.fc1 = nn.Linear(channels[-1] * d1 * d1 * d2, 128)
        self.fc_final = nn.Linear(128, num_classes)
        self.activation1 = nn.LeakyReLU()
        self.batchnorm1 = nn.BatchNorm1d(128)

    def _conv_layer_set(self, in_c, out_c):
        """One conv -> LeakyReLU -> 2x max-pool stage."""
        return nn.Sequential(
            nn.Conv3d(in_c, out_c, kernel_size=(3, 3, 3), padding=0),
            nn.LeakyReLU(),
            nn.MaxPool3d((2, 2, 2)),
        )

    def forward(self, x):
        out = x
        for layer in self.conv_layers:
            out = layer(out)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.activation1(out)
        out = self.batchnorm1(out)
        return self.fc_final(out)

In [None]:
class Trainer:
    """Epoch loop with validation, best-loss checkpointing and early stopping.

    Args:
        model: network to train; presumably already moved to *device* by the caller.
        device: device that each batch's "X" and "y" tensors are moved to.
        optimizer: optimizer over the model's parameters.
        criterion: callable ``criterion(logits, targets)`` returning a scalar loss tensor.
    """

    def __init__(
        self,
        model,
        device,
        optimizer,
        criterion
    ):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion

        self.best_valid_score = np.inf  # lowest validation loss seen so far
        self.n_patience = 0  # consecutive epochs without improvement
        self.lastmodel = None  # path of the most recently saved (best) checkpoint

    def fit(self, epochs, train_loader, valid_loader, save_path, patience):
        """Train for up to *epochs* epochs and checkpoint whenever validation loss improves.

        Stops early after *patience* consecutive epochs without improvement.
        The best checkpoint's path is left in ``self.lastmodel``.
        """
        for n_epoch in range(1, epochs + 1):
            self.info_message("EPOCH: {}", n_epoch)

            train_loss, train_time = self.train_epoch(train_loader)
            valid_loss, valid_auc, valid_time = self.valid_epoch(valid_loader)

            self.info_message(
                "[Epoch Train: {}] loss: {:.4f}, time: {:.2f} s            ",
                n_epoch, train_loss, train_time
            )

            self.info_message(
                "[Epoch Valid: {}] loss: {:.4f}, auc: {:.4f}, time: {:.2f} s",
                n_epoch, valid_loss, valid_auc, valid_time
            )

            if self.best_valid_score > valid_loss:
                previous_best = self.best_valid_score
                # Update before saving so the checkpoint records the score it
                # achieved. It used to store the previous best (inf on the first save).
                self.best_valid_score = valid_loss
                self.save_model(n_epoch, save_path, valid_loss, valid_auc)
                self.info_message(
                    "loss improved from {:.4f} to {:.4f}. Saved model to '{}'",
                    previous_best, valid_loss, self.lastmodel
                )
                self.n_patience = 0
            else:
                self.n_patience += 1

            if self.n_patience >= patience:
                self.info_message("\nValid score didn't improve last {} epochs.", patience)
                break

    def train_epoch(self, train_loader):
        """Run one optimisation pass. Returns (mean batch loss, whole seconds elapsed)."""
        self.model.train()
        t = time.time()
        sum_loss = 0

        for step, batch in enumerate(train_loader, 1):
            X = batch["X"].to(self.device)
            targets = batch["y"].to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(X).squeeze(1)

            loss = self.criterion(outputs, targets)
            loss.backward()

            sum_loss += loss.detach().item()

            self.optimizer.step()

            message = 'Train Step {}/{}, train_loss: {:.4f}'
            self.info_message(message, step, len(train_loader), sum_loss/step, end="\r")

        return sum_loss/len(train_loader), int(time.time() - t)

    def valid_epoch(self, valid_loader):
        """Evaluate without gradients. Returns (mean batch loss, AUC, whole seconds elapsed).

        For the AUC, the label-smoothed targets are re-binarised at 0.5.
        """
        self.model.eval()
        t = time.time()
        sum_loss = 0
        y_all = []
        outputs_all = []

        for step, batch in enumerate(valid_loader, 1):
            with torch.no_grad():
                X = batch["X"].to(self.device)
                targets = batch["y"].to(self.device)

                outputs = self.model(X).squeeze(1)
                loss = self.criterion(outputs, targets)

                sum_loss += loss.detach().item()
                y_all.extend(batch["y"].tolist())
                outputs_all.extend(torch.sigmoid(outputs).tolist())

            message = 'Valid Step {}/{}, valid_loss: {:.4f}'
            self.info_message(message, step, len(valid_loader), sum_loss/step, end="\r")

        y_all = [1 if x > 0.5 else 0 for x in y_all]
        auc = roc_auc_score(y_all, outputs_all)

        return sum_loss/len(valid_loader), auc, int(time.time() - t)

    def save_model(self, n_epoch, save_path, loss, auc):
        """Save model/optimizer state to working_dir and remember the path in ``lastmodel``."""
        self.lastmodel = f"{working_dir}/{save_path}-e{n_epoch}-loss{loss:.3f}-auc{auc:.3f}.pth"
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_valid_score": self.best_valid_score,
                "n_epoch": n_epoch,
            },
            self.lastmodel,
        )

    @staticmethod
    def info_message(message, *args, end="\n"):
        print(message.format(*args), end=end)

## Train models

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_mri_type(trn, val, mri_type, start_model=None, fold=None):
    """Train one Model on a single MRI type, or on all four, and return the best checkpoint path.

    Args:
        trn, val: label frames with "BraTS21ID" and "MGMT_value" columns. They are
            not modified; an "MRI_Type" column is added to internal copies.
        mri_type: one of ``mri_types``, or "all" to stack one copy of every
            scan per MRI type.
        start_model: optional checkpoint path to initialise the weights from.
        fold: optional CV fold number, appended to the checkpoint name as "_fold<k>".

    Returns:
        Path of the best checkpoint (``Trainer.lastmodel``).
    """
    if mri_type == "all":
        # A separate loop name keeps mri_type == "all" for logging and the
        # checkpoint name. The original loop rebound it to the last type, "T2w".
        trn = pd.concat([trn.assign(MRI_Type=mt) for mt in mri_types])
        val = pd.concat([val.assign(MRI_Type=mt) for mt in mri_types])
    else:
        # assign() returns new frames, so the caller's (often sliced) frames are
        # not mutated, and there is no SettingWithCopyWarning.
        trn = trn.assign(MRI_Type=mri_type)
        val = val.assign(MRI_Type=mri_type)

    print(fold, mri_type, trn.shape, val.shape)
    display(trn.head())

    train_data_retriever = Dataset(
        trn["BraTS21ID"].values,
        trn["MGMT_value"].values,
        trn["MRI_Type"].values,
        augment=False
    )

    valid_data_retriever = Dataset(
        val["BraTS21ID"].values,
        val["MGMT_value"].values,
        val["MRI_Type"].values,
        split="valid"
    )

    train_loader = torch_data.DataLoader(
        train_data_retriever,
        batch_size=2,
        shuffle=True,
        num_workers=4, drop_last=True, pin_memory=True
    )

    valid_loader = torch_data.DataLoader(
        valid_data_retriever,
        batch_size=2,
        shuffle=False,
        num_workers=4, pin_memory=True
    )

    model = Model()
    model.to(device)

    if start_model:
        print("Loading checkpoint:", start_model)
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        checkpoint = torch.load(start_model, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch_functional.binary_cross_entropy_with_logits

    trainer = Trainer(model, device, optimizer, criterion)

    save_path = f"{mri_type}"
    if fold is not None:  # `if fold:` silently dropped a fold numbered 0
        save_path += f"_fold{fold}"
    trainer.fit(
        20,
        train_loader,
        valid_loader,
        save_path,
        10,
    )

    return trainer.lastmodel

In [None]:
# Training driver: load the labels, record per-type file counts, then train one
# model per MRI type, either per stratified CV fold or on a single 80/20 split.
# NOTE(review): start_modelfiles is not used in the visible code.
start_modelfiles = None 
useCV = True

train_df = pd.read_csv(f"{data_directory}/train_labels.csv")
print(train_df.shape)
# Exclude three scan ids from training.
# NOTE(review): the reason is not stated here; presumably bad scans. Confirm.
train_df = train_df[~train_df["BraTS21ID"].isin([109,123,709])]
print(train_df.shape)

# Count the files in each scan's MRI-type folders and store them as "<type>_img_cnt" columns.
for idx, row in train_df.iterrows():
    for mri_type in mri_types:
        fpath = os.path.join(data_directory,"train", str(row["BraTS21ID"]).zfill(5), mri_type)
        imgfiles = glob.glob(os.path.join(fpath, "*"))
        train_df.loc[idx,f"{mri_type}_img_cnt"] = len(imgfiles)

display(train_df)


if useCV:
    modelfiles = []
    # Fixed random_state: the validation-ensemble cell calls skf.split again and
    # must get the same folds.
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1234)
    fold=0
    # Folds are numbered from 1; one checkpoint per (fold, MRI type).
    for trn_index, val_index in skf.split(train_df, train_df["MGMT_value"]):
        fold+=1
        trn_df = train_df.iloc[trn_index]
        val_df = train_df.iloc[val_index]
        for mt in mri_types:
            modelfiles.append(train_mri_type(trn_df, val_df, mt, fold=fold))
else:
    df_train, df_valid = sk_model_selection.train_test_split(
        train_df, 
        test_size=0.2, 
        random_state=42, 
        stratify=train_df["MGMT_value"],
    )

    # No effect: the value is discarded (not the last expression of the cell).
    df_train.head()
    modelfiles = [train_mri_type(df_train, df_valid, mt) for mt in mri_types]

print(modelfiles)        

## Predict function

In [None]:
def predict(modelfile, df, mri_type, split):
    """Run one checkpoint over every scan in *df* and return sigmoid probabilities.

    Args:
        modelfile: checkpoint path written by ``Trainer.save_model``.
        df: frame whose index holds the BraTS21 scan ids. It is not modified.
        mri_type: MRI type to load for every scan.
        split: Dataset split: "valid" reads the train folder, "test" the test folder.

    Returns:
        DataFrame indexed by "BraTS21ID" with one "MGMT_value" probability per scan,
        in the order of ``df.index``.
    """
    print("Predict:", modelfile, mri_type, df.shape)
    # Build the per-scan type list directly. The original wrote an "MRI_Type"
    # column into the caller's frame, e.g. the submission.
    data_retriever = Dataset(
        df.index.values,
        mri_type=[mri_type] * len(df),
        split=split
    )

    data_loader = torch_data.DataLoader(
        data_retriever,
        batch_size=4,
        shuffle=False,
        num_workers=4, pin_memory=True
    )

    model = Model()
    model.to(device)

    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    checkpoint = torch.load(modelfile, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    y_pred = []
    ids = []

    for e, batch in enumerate(data_loader, 1):
        print(f"{e}/{len(data_loader)}", end="\r")
        with torch.no_grad():
            probs = torch.sigmoid(model(batch["X"].to(device)))
        # reshape(-1) always yields plain floats. squeeze() turned a 1-element
        # final batch into a 0-d ndarray inside the list.
        y_pred.extend(probs.cpu().reshape(-1).tolist())
        ids.extend(batch["id"].numpy().tolist())

    preddf = pd.DataFrame({"BraTS21ID": ids, "MGMT_value": y_pred})
    return preddf.set_index("BraTS21ID")

## Ensemble for validation

In [None]:
# Average the per-MRI-type predictions on held-out scans and report log loss / AUC.
if useCV:
    # Regenerate the training folds; skf's fixed random_state reproduces them.
    splits = {}
    for fold, (trn_index, val_index) in enumerate(skf.split(train_df, train_df["MGMT_value"]), 1):
        splits[fold] = (trn_index, val_index)
    print(splits.keys())
    # Float accumulator: adding floats into an int column triggers pandas' dtype-upcast warning.
    train_df["MGMT_pred"] = 0.0
    for m in modelfiles:
        # Checkpoint names look like "<type>_fold<k>-e...". Parse the basename only,
        # so underscores in working_dir or folds >= 10 cannot break it.
        mtype, fold = re.match(r"([^_]+)_fold(\d+)-", os.path.basename(m)).groups()
        fold = int(fold)
        val_df0 = train_df.iloc[splits[fold][1]]
        val_df = val_df0.set_index("BraTS21ID")
        pred = predict(m, val_df, mtype, "valid")
        # predict preserves val_df's row order, so positional .values line up.
        train_df.loc[val_df0.index, "MGMT_pred"] += pred["MGMT_value"].values
    # Each scan was predicted once per MRI type.
    train_df["MGMT_pred"] /= len(mri_types)
    auc = roc_auc_score(train_df["MGMT_value"], train_df["MGMT_pred"])
    loss = log_loss(train_df["MGMT_value"], train_df["MGMT_pred"])
else:
    df_valid = df_valid.set_index("BraTS21ID")
    df_valid["MGMT_pred"] = 0.0
    for m, mtype in zip(modelfiles, mri_types):
        pred = predict(m, df_valid, mtype, "valid")
        df_valid["MGMT_pred"] += pred["MGMT_value"]
    df_valid["MGMT_pred"] /= len(modelfiles)
    auc = roc_auc_score(df_valid["MGMT_value"], df_valid["MGMT_pred"])
    loss = log_loss(df_valid["MGMT_value"], df_valid["MGMT_pred"])
print(f"Validation ensemble loss {loss:.4f}, AUC: {auc:.4f}")

In [None]:
# Histogram of the ensembled validation probabilities, titled with the preprocessing settings and AUC.
pred_series = train_df["MGMT_pred"] if useCV else df_valid["MGMT_pred"]
cv_note = ", 5 fold CV," if useCV else ""
sns.displot(pred_series)
plt.title(f'lut_contrast={USE_LUT_CONTRAST}, voi_lut={USE_VOI_LUT}{cv_note} val AUC {auc:.3f}')

## Ensemble for submission

In [None]:
# Average every trained model's test predictions and write submission.csv.
submission = pd.read_csv(f"{data_directory}/sample_submission.csv", index_col="BraTS21ID")

submission["MGMT_value"] = 0.0
for m in modelfiles:
    # Checkpoint names are "<type>-e..." (single split) or "<type>_fold<k>-e..." (CV).
    # Parsing only the basename works for both, whatever characters working_dir contains.
    mtype = os.path.basename(m).split("-")[0].split("_")[0]
    print(m, mtype)
    pred = predict(m, submission, mtype, split="test")
    submission["MGMT_value"] += pred["MGMT_value"]

submission["MGMT_value"] /= len(modelfiles)
submission["MGMT_value"].to_csv("submission.csv")

In [None]:
# Display the submission frame as the cell output.
submission

In [None]:
# Histogram of the ensembled test-set probabilities.
sns.displot(submission["MGMT_value"])