In [None]:
# !pip install torchio
# import torchio as tio

In [None]:
import os
import sys 
import json
import glob
import random
import collections
import time

import numpy as np
import pandas as pd
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional
import torch.nn.functional as F
from torchvision import transforms, utils

from sklearn import model_selection
from sklearn import metrics
from skimage import exposure

from albumentations import Resize, Normalize, Compose
from albumentations.pytorch import ToTensorV2
import albumentations as album

import warnings
warnings.filterwarnings("ignore")

plt.style.use("dark_background")

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device} device")

In [None]:
# pip install --upgrade batchgenerators

In [None]:
# train_data_retriever = Dataset(df_train["BraTS21ID"].values, df_train["MGMT_value"].values)

# train_loader = torch_data.DataLoader(train_data_retriever, batch_size=4, shuffle=True, num_workers=1,)
# batch = next(iter(train_loader))

In [None]:
# def plot_batch(batch):
#     batch_size = batch['X'].shape[0]
#     plt.figure(figsize=(16, 10))
#     for i in range(batch_size):
#         plt.subplot(1, batch_size, i+1)
#         plt.imshow(batch['X'][i, 0, 0, :, :], cmap="gray") # only grayscale image here
#     plt.show()

# plot_batch(batch)

In [None]:
# array = batch['X'][0, 0, :, :, :].numpy().astype(np.float32)

In [None]:
# from batchgenerators.transforms.color_transforms import ContrastAugmentationTransform
# from batchgenerators.transforms.spatial_transforms import MirrorTransform
# from batchgenerators.transforms.abstract_transforms import Compose
# from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter

# my_transforms = []
# brightness_transform = ContrastAugmentationTransform((0.3, 3.), preserve_range=True)
# my_transforms.append(brightness_transform)
# mirror_transform = MirrorTransform(axes=(0, 1))
# my_transforms.append(mirror_transform)

# all_transforms = Compose(my_transforms)

# multithreaded_generator = MultiThreadedAugmenter(array, all_transforms, 4, 1, seeds=None)
# plot_batch(next(iter(multithreaded_generator)))
# multithreaded_generator._finish()


In [None]:
# Start from the Kaggle layout, where the competition data and the
# EfficientNet-3D sources are mounted under ../input ...
data_directory = '../input/rsna-miccai-brain-tumor-radiogenomic-classification'
pytorch3dpath = "../input/efficientnetpyttorch3d/EfficientNet-PyTorch-3D"
if not os.path.exists(data_directory):
    # ... and switch to the local copies when that directory is absent.
    data_directory = '/media/roland/data/kaggle/rsna-miccai-brain-tumor-radiogenomic-classification'
    pytorch3dpath = "EfficientNet-PyTorch-3D"

# MRI sequences; the datasets below stack them as channels in this order.
mri_types = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
SIZE = 128        # side length (pixels) every slice is resized to
NUM_IMAGES = 64   # number of slices kept per volume

# Make the EfficientNet-3D checkout importable before importing from it.
sys.path.append(pytorch3dpath)
from efficientnet_pytorch_3d import EfficientNet3D

In [None]:
def load_dicom_image(path, img_size=SIZE):
    """Read one DICOM slice and return it as an ``img_size`` x ``img_size`` uint8 image.

    The pixel data goes through the file's VOI LUT / windowing, is inverted
    for ``MONOCHROME1`` images, resized, and min-max scaled to 0..255.

    Args:
        path: Path of the ``.dcm`` file.
        img_size: Side length of the square output image.

    Returns:
        ``np.ndarray`` of shape ``(img_size, img_size)`` and dtype ``uint8``.
        A slice with constant pixel values yields an all-zero image of the
        same dtype, so stacking slices never silently upcasts the volume.
    """
    # dcmread is the supported name; read_file is a deprecated alias that
    # newer pydicom releases no longer provide.
    dicom = pydicom.dcmread(path)
    pixels = dicom.pixel_array  # decode the pixel data only once

    blank = np.zeros((img_size, img_size), dtype=np.uint8)
    if np.min(pixels) == np.max(pixels):
        return blank

    data = apply_voi_lut(pixels, dicom)
    if dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data

    data = cv2.resize(data, (img_size, img_size))

    data = data - np.min(data)
    peak = np.max(data)
    if peak == 0:
        # The LUT/resize collapsed the slice to a constant: avoid 0/0 -> NaN.
        return blank

    return (data / peak * 255).astype(np.uint8)

In [None]:
def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=SIZE, mri_type="FLAIR", split="train"):
    """Load the central slices of one MRI sequence as a 3D volume.

    Args:
        scan_id: Zero-padded scan folder name (e.g. ``"00000"``).
        num_imgs: Number of slices in the returned volume.
        img_size: Side length every slice is resized to.
        mri_type: Sequence sub-folder (``FLAIR``, ``T1w``, ``T1wCE``, ``T2w``).
        split: ``"train"`` or ``"test"`` sub-folder of ``data_directory``.

    Returns:
        ``np.ndarray`` of shape ``(img_size, img_size, num_imgs)``. Each slice
        is stored transposed (the stacked volume is ``.T``-ed), and the volume
        is zero-padded along the last axis when fewer slices are available.

    Raises:
        FileNotFoundError: if the sequence folder contains no ``.dcm`` file.

    Note:
        Slices are ordered by the numbers in their file names. The previous
        plain string sort placed ``Image-10`` before ``Image-2``; checkpoints
        trained on that ordering should be re-validated or retrained.
    """
    import re  # local import: only needed for the natural-order sort key

    def slice_order(path):
        # "Image-10.dcm" -> ["Image-", 10, ".dcm"]; odd positions are the
        # captured digit runs, so they compare numerically.
        parts = re.split(r"(\d+)", os.path.basename(path))
        return [int(part) if pos % 2 else part for pos, part in enumerate(parts)]

    pattern = f"{data_directory}/{split}/{scan_id}/{mri_type}/*.dcm"
    files = sorted(glob.glob(pattern), key=slice_order)
    if not files:
        raise FileNotFoundError(f"No DICOM files match {pattern}")

    # Keep a window of (at most) num_imgs slices centred on the middle slice.
    middle = len(files) // 2
    half = num_imgs // 2
    p1 = max(0, middle - half)
    p2 = min(len(files), middle + half)
    img3d = np.stack([load_dicom_image(f, img_size=img_size) for f in files[p1:p2]]).T

    n_missing = num_imgs - img3d.shape[-1]
    if n_missing > 0:
        # Pad with the volume's own dtype so padding alone never upcasts it.
        padding = np.zeros((img_size, img_size, n_missing), dtype=img3d.dtype)
        img3d = np.concatenate((img3d, padding), axis=-1)

    return img3d

In [None]:
# Sanity check: load every MRI sequence of scan "00000" and look at one slice
# of the first sequence.
images = [load_dicom_images_3d(scan_id="00000", mri_type=sequence) for sequence in mri_types]
# Move the slice axis from the last position to right after the channel axis.
four_channel_pack = np.stack(images).transpose(0, 3, 1, 2)

print(four_channel_pack.shape)
plt.imshow(four_channel_pack[0, 10])

In [None]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and, if present, CUDA) with ``seed``.

    Also exports ``PYTHONHASHSEED`` and, when CUDA is available, switches
    cuDNN to deterministic mode.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


set_seed(42)

In [None]:
# Competition labels; the BraTS21ID and MGMT_value columns are used below.
train_df = pd.read_csv(f"{data_directory}/train_labels.csv")
display(train_df)

# Hold out 30% of the rows for validation, stratified on the MGMT label.
df_train, df_valid = sk_model_selection.train_test_split(
    train_df,
    stratify=train_df["MGMT_value"],
    test_size=0.3,
    random_state=42,
)

### 3D Augmentation and Transformation

In [None]:
def get_training_augmentation():
    """Return the albumentations pipeline applied to every training slice.

    The pipeline consists of a single per-channel ``Normalize`` step using
    the mean/std triple commonly quoted for ImageNet RGB images.
    """
    normalize = album.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    return album.Compose([normalize])

In [None]:
# _transforms = {
#     tio.RandomFlip(axes=['LR', 'AP', 'IS'], p=0.8),
#     tio.RandomElasticDeformation(p=0.2),
# #     tio.RandomAffine(scales=(0.9, 1.2), degrees=10, isotropic=True, image_interpolation="nearest", p=1),
# #     tio.RandomNoise(p=0.2),
# #     tio.RandomMotion(p=0.3),
# #     tio.RandomGhosting(p=0.4),
# #     tio.RandomSpike(p=0.2),
#     tio.ZNormalization(masking_method=tio.ZNormalization.mean, p=1),
#     tio.RescaleIntensity(out_min_max=(0, 1), p=1),
    
# #     tio.OneOf([
# # #         tio.RandomMotion(p=0.2),
# #         tio.RandomBiasField(p=0.3),
# # #         tio.RandomNoise(p=0.5),
# #     ]),
    
# #     tio.OneOf([
# #         tio.ZNormalization(masking_method=tio.ZNormalization.mean, p=0.5),
# #         tio.RescaleIntensity(out_min_max=(0, 1), p=0.5),  
# #     ])
# }

# outer_transforms = _transforms
# transform = tio.Compose(outer_transforms)

### Dataset and Dataloader

In [None]:
class Dataset(torch_data.Dataset):
    """Four-sequence MRI volumes (one channel per entry of ``mri_types``) with labels.

    Args:
        paths: Sequence of scan ids (converted to 5-digit, zero-padded folder names).
        labels: Sequence of targets aligned with ``paths``; when ``None`` the
            items carry the scan id under ``"id"`` instead of ``"y"``.
        mri_type, label_smoothing, split: Stored for API compatibility but
            not used when building an item (volumes are always read with the
            loader's default split).
        augmentation: Optional albumentations-style callable, invoked per
            slice as ``augmentation(image=rgb_slice)["image"]``.
        transformation: Optional callable applied to the stacked
            4-channel array before the axes are reordered.
    """

    def __init__(self, paths, labels=None, mri_type=None, label_smoothing=0.01, augmentation=None, transformation=None, split="train"):
        self.paths = paths
        self.labels = labels
        self.mri_type = mri_type
        self.label_smoothing = label_smoothing
        self.augmentation = augmentation
        self.transformation = transformation
        self.split = split

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``{"X": float tensor, "y": label}`` (or ``"id"`` when unlabeled)."""
        scan_id = self.paths[index]

        volumes = []
        for sequence in mri_types:
            image_3d = load_dicom_images_3d(scan_id=str(scan_id).zfill(5), mri_type=sequence)
            if self.augmentation:
                image_3d = self._augment_volume(image_3d)
            volumes.append(image_3d)
        four_channel_pack = np.stack(volumes)

        if self.transformation:
            four_channel_pack = self.transformation(four_channel_pack)

        # Move the slice axis from last to second position.
        four_channel_pack = np.transpose(four_channel_pack, (0, 3, 1, 2))
        X = torch.tensor(four_channel_pack).float()

        if self.labels is None:
            return {"X": X, "id": scan_id}
        return {"X": X, "y": self.labels[index]}

    def _augment_volume(self, image_3d):
        """Apply ``self.augmentation`` slice by slice and return a float32 volume."""
        # Write into a float buffer: storing the normalised (possibly negative,
        # fractional) values back into a uint8 volume would truncate them.
        augmented = np.empty(image_3d.shape, dtype=np.float32)
        for k in range(image_3d.shape[-1]):
            gray = np.ascontiguousarray(image_3d[:, :, k].astype(np.uint8))
            # Slices are single-channel, so the grayscale -> RGB code is
            # required; BGR2RGB rejects 1-channel input.
            rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
            # All three channels are copies of the slice; keep the first one.
            augmented[:, :, k] = self.augmentation(image=rgb)['image'][:, :, 0]
        return augmented

### Test Transformations and Normalization

In [None]:
# Smoke test of the training pipeline: build the dataset and loader, then
# pull a single sample for inspection in the next cell.
train_data_retriever = Dataset(
    labels=df_train["MGMT_value"].values,
    paths=df_train["BraTS21ID"].values,
    augmentation=get_training_augmentation(),
)

train_loader = torch_data.DataLoader(
    dataset=train_data_retriever,
    batch_size=4,
    num_workers=8,
    shuffle=False,
)
a = train_data_retriever[5]["X"]

In [None]:
# Show slice 12 of the first channel and print the value range of slice 14.
plt.imshow(a[0, 12])
print(a[0, 14].max(), a[0, 14].min())

### Model and Training

In [None]:
class Model(nn.Module):
    """EfficientNet3D-b0 taking 4 input channels and producing 2 logits."""

    def __init__(self):
        super().__init__()
        backbone = EfficientNet3D.from_name("efficientnet-b0", override_params={'num_classes': 2}, in_channels=4)
        # Swap the classifier for a fresh 2-way linear layer of matching width.
        backbone._fc = nn.Linear(in_features=backbone._fc.in_features, out_features=2, bias=True)
        self.net = backbone

    def forward(self, x):
        """Return the backbone's output for the batch ``x``."""
        return self.net(x)

In [None]:
class Trainer:
    """Train/validate loop with loss-based checkpointing and early stopping.

    Args:
        model: ``nn.Module`` to optimise (already moved to ``device``).
        device: Device every batch is moved to.
        optimizer: Optimizer built over ``model``'s parameters.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss.

    Batches are expected to be dicts with an ``"X"`` input tensor and a
    ``"y"`` target tensor.
    """

    def __init__(
        self,
        model,
        device,
        optimizer,
        criterion
    ):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion

        self.best_valid_score = np.inf  # lowest validation loss seen so far
        self.n_patience = 0             # epochs since the last improvement
        self.lastmodel = None           # path of the most recent checkpoint

    def fit(self, epochs, train_loader, valid_loader, save_path, patience):
        """Run up to ``epochs`` epochs, checkpointing whenever the validation loss improves.

        Training stops early once the validation loss has failed to improve
        for ``patience`` consecutive epochs. Checkpoints are written to
        ``f"{save_path}-e{epoch}-loss{loss:.3f}.pth"``.
        """
        for n_epoch in range(1, epochs + 1):
            self.info_message("EPOCH: {}", n_epoch)

            train_acc, train_loss, train_time = self.train_epoch(train_loader)
            valid_acc, valid_loss, valid_time = self.valid_epoch(valid_loader)

            self.info_message(
                "[Epoch Train: {}] accuracy: {:.4f}, loss: {:.4f}, time: {:.2f} s            ",
                n_epoch, train_acc, train_loss, train_time
            )

            self.info_message(
                "[Epoch Valid: {}] accuracy: {:.4f}, loss: {:.4f}, time: {:.2f} s",
                n_epoch, valid_acc, valid_loss, valid_time
            )

            if self.best_valid_score > valid_loss:
                previous_best = self.best_valid_score
                # Update the best score BEFORE saving so that the checkpoint
                # records the loss of the weights it actually contains.
                self.best_valid_score = valid_loss
                self.n_patience = 0
                self.save_model(n_epoch, save_path, valid_loss)
                self.info_message(
                     "loss has decresed from {:.4f} to {:.4f}. Saved model to '{}'",
                    previous_best, valid_loss, self.lastmodel
                )
            else:
                self.n_patience += 1

            if self.n_patience >= patience:
                # The monitored quantity is the validation loss, not the AUC.
                self.info_message("\nValid loss didn't improve last {} epochs.", patience)
                break

    def train_epoch(self, train_loader):
        """Train for one epoch; return ``(accuracy, mean batch loss, seconds)``."""
        self.model.train()
        t = time.time()
        sum_loss = 0.0
        n_correct = 0

        for step, batch in enumerate(train_loader, 1):
            X = batch["X"].to(self.device)
            targets = batch["y"].to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(X).squeeze(1)

            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()

            sum_loss += loss.item()
            n_correct += (outputs.argmax(dim=1) == targets).sum().item()

            message = 'Train Step {}/{}, train_loss: {:.4f}'
            self.info_message(message, step, len(train_loader), sum_loss/step, end="\r")

        return n_correct / len(train_loader.dataset), sum_loss / len(train_loader), int(time.time() - t)

    def valid_epoch(self, valid_loader):
        """Evaluate without gradients; return ``(accuracy, mean batch loss, seconds)``."""
        self.model.eval()
        t = time.time()
        sum_loss = 0.0
        n_correct = 0

        with torch.no_grad():
            for step, batch in enumerate(valid_loader, 1):
                X = batch["X"].to(self.device)
                targets = batch["y"].to(self.device)

                outputs = self.model(X)
                loss = self.criterion(outputs, targets)

                sum_loss += loss.item()
                n_correct += (outputs.argmax(dim=1) == targets).sum().item()

                message = 'Valid Step {}/{}, valid_loss: {:.4f}'
                self.info_message(message, step, len(valid_loader), sum_loss/step, end="\r")

        return n_correct / len(valid_loader.dataset), sum_loss/len(valid_loader), int(time.time() - t)

    def save_model(self, n_epoch, save_path, loss):
        """Write model/optimizer state to disk and remember the path in ``self.lastmodel``."""
        self.lastmodel = f"{save_path}-e{n_epoch}-loss{loss:.3f}.pth"
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_valid_score": self.best_valid_score,
                "n_epoch": n_epoch,
            },
            self.lastmodel,
        )

    @staticmethod
    def info_message(message, *args, end="\n"):
        """Print ``message.format(*args)`` terminated by ``end``."""
        print(message.format(*args), end=end)

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.empty_cache()

# def train_mri_type(df_train, df_valid):
    
#     train_data_retriever = Dataset(
#         paths=df_train["BraTS21ID"].values, 
#         labels=df_train["MGMT_value"].values,
#         augmentation=get_training_augmentation(),
# #         transformation=transform,
#     )
#     valid_data_retriever = Dataset(
#         paths=df_valid["BraTS21ID"].values, 
#         labels=df_valid["MGMT_value"].values,
#         augmentation=get_training_augmentation(),
# #         transformation=transform,
#     )

#     train_loader = torch_data.DataLoader(train_data_retriever, batch_size=8, shuffle=True, num_workers=12,)
#     valid_loader = torch_data.DataLoader(valid_data_retriever, batch_size=8, shuffle=True, num_workers=8,)

#     model = Model()
#     model.to(device)

#     # UPTRAIN
#     checkpoint_file = "../input/3dbrainmrimodels/classification_model-e5-loss0.683.pth"
#     if torch.cuda.is_available():
#         checkpoint = torch.load(checkpoint_file)
#     else:
#         checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))    
#     model.load_state_dict(checkpoint["model_state_dict"])


#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
# #     optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
# #     criterion = torch_functional.binary_cross_entropy_with_logits
#     criterion = nn.CrossEntropyLoss()
# #     criterion = nn.BCEWithLogitsLoss()

#     trainer = Trainer(
#         model, 
#         device, 
#         optimizer, 
#         criterion
#     )

#     history = trainer.fit(
#         10, 
#         train_loader, 
#         valid_loader, 
#         "classification_model",
#         10,
#     )
    
#     return trainer.lastmodel

# train_mri_type(df_train, df_valid)

In [None]:
# Extra save the model

# lastmodel = f"classification_model-e10-loss0.5304.pth"
# torch.save(
#     {
#         "model_state_dict": model.state_dict(),
#         "n_epoch": 10,
#     },
#     lastmodel,
# )

### Testing the model

In [None]:
# Index the training rows by scan id and add a placeholder prediction column.
# Guarded so the cell can be re-run: once "BraTS21ID" is the index it is no
# longer a column and a second set_index would raise KeyError.
if "BraTS21ID" in df_train.columns:
    df_train = df_train.set_index("BraTS21ID")
df_train["MGMT_pred"] = 0

In [None]:
# Index the validation rows by scan id and add a placeholder prediction column.
# Guarded so the cell can be re-run: once "BraTS21ID" is the index it is no
# longer a column and a second set_index would raise KeyError.
if "BraTS21ID" in df_valid.columns:
    df_valid = df_valid.set_index("BraTS21ID")
df_valid["MGMT_pred"] = 0

In [None]:
# modelfile = "./classification_model-e5-loss0.683.pth"
modelfile = "./classification_model-e10-loss0.5304.pth"
model = Model()
model.to(device)

# map_location=device covers both the CPU-only case and checkpoints saved
# from a CUDA device index that does not exist on this machine.
checkpoint = torch.load(modelfile, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

# Everything below is inference only; switch to eval mode here so later
# cells do not depend on another cell having done it.
model.eval();

In [None]:
class TestDataset(torch_data.Dataset):
    """Labelled four-sequence volumes scaled with ``ToTensor`` + ``Normalize(0.5, 0.5)``.

    Args:
        paths: Sequence of scan ids (converted to 5-digit, zero-padded folder names).
        labels: Sequence of targets aligned with ``paths``. Required in
            practice despite the ``None`` default: ``__getitem__`` indexes it.
        mri_type, label_smoothing, augmentation, split: Stored for API
            compatibility but not used when building an item.

    Items are ``(X, y)`` tuples where ``X`` is a float tensor with one
    channel per entry of ``mri_types``.
    """

    def __init__(self, paths, labels=None, mri_type=None, label_smoothing=0.01, augmentation=None, split="train"):
        self.paths = paths
        self.labels = labels
        self.mri_type = mri_type
        self.label_smoothing = label_smoothing
        self.augmentation = augmentation
        self.split = split
        # Built once instead of on every item. ToTensor moves the last
        # (slice) axis to the front, so Normalize gets one mean/std per slice.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,) * NUM_IMAGES, (0.5,) * NUM_IMAGES)
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        scan_id = self.paths[index]

        images = []
        for sequence in mri_types:
            image_3d = load_dicom_images_3d(scan_id=str(scan_id).zfill(5), mri_type=sequence)
            # ToTensor rescales to [0, 1] only for uint8 input. Volumes that
            # were upcast to float (e.g. by zero padding) would otherwise skip
            # the scaling and end up on a different value range. The loader
            # produces integer intensities in 0..255, so the cast is lossless.
            images.append(self.transform(image_3d.astype(np.uint8)))

        return torch.stack(images).float(), self.labels[index]

In [None]:
# Dataset and loader over the training rows, built with the TestDataset
# preprocessing, for the visual check in the next cell.
test_data_retriever = TestDataset(
    paths=df_train.index.values,
    labels=df_train["MGMT_value"].values,
    augmentation=get_training_augmentation(),
    split="test",
)

test_data_loader = torch_data.DataLoader(
    dataset=test_data_retriever,
    batch_size=16,
    num_workers=8,
    shuffle=False,
)

In [None]:
# First item of the dataset: volume tensor and label; show slice 10 of channel 0.
a, b = test_data_retriever[0]
plt.imshow(a[0, 10])

In [None]:
# fig, axis = plt.subplots(4, 4, figsize=(8, 10))
# with torch.no_grad():
#     model.eval()
#     for i, ax in enumerate(axis.flat):
#         image, label = test_fetures[i].to(device), test_labels[i].to(device)
#         ax.imshow(image[0, 20, :, :].cpu())
        
# #         image_tensor = image.unsqueeze_(0)
#         output_ = model(image).squeeze(1)
#         _, pred = torch.max(output_, dim=1)
#         print(pred, label)
        

#         print(
#             torch.nn.functional.softmax(model(image), dim=1)[0] * 100,
#         )
#         _, index = torch.max(output_, 1)
        
#         percentage = torch.nn.functional.softmax(output_, dim=1)[0] * 100
#         result = output_.argmax()
        
#         print(index, percentage[0], percentage[1])
#         print(sum([percentage[0], percentage[1]]), "\n")
#         ax.set(title = f"actual:{label}\nprediction:{result}")

### AUC Score

In [None]:
# Score every validation scan, one scan per batch.
data_retriever = Dataset(
    paths=df_valid.index.values,
    labels=df_valid["MGMT_value"].values,
    split="test",
)

data_loader = torch_data.DataLoader(
    dataset=data_retriever,
    batch_size=1,
    num_workers=8,
    shuffle=False,
)

y_preds = []   # softmax probability of class 1 per scan
y = []         # ground-truth label per scan
outputs = []   # arg-max class per scan (one-element tensors)

model.eval()
with torch.no_grad():
    for e, batch in enumerate(data_loader):
        print(f"{e}/{len(data_loader)}", end="\r")
        image, label = batch["X"].to(device), batch["y"]

        output_ = model(image).squeeze(1)
        _, index = torch.max(output_, 1)
        percentage = torch.nn.functional.softmax(output_, dim=1)[0]
        print(percentage * 100)

        y.append(label.detach().cpu().numpy()[0])
        y_preds.append(float(percentage[1].detach().cpu().numpy()))
        outputs.append(index)

In [None]:
# Score the validation predictions: ROC AUC on the class-1 probabilities and
# accuracy on the arg-max classes.
y = np.array(y)
y_preds = np.array(y_preds)

fpr, tpr, thresholds = metrics.roc_curve(y, y_preds, pos_label=1)
roc_auc = metrics.auc(fpr, tpr)

# `outputs` holds one-element tensors (possibly on the GPU); convert them to
# plain ints so the accuracy is an ordinary float rather than a tensor.
predicted_classes = np.array([int(out) for out in outputs])
acc = float(np.mean(predicted_classes == y))

print(f"AUC score is: {roc_auc}")
print(f"Accuracy score is: {acc}")

In [None]:
# ROC curve of the validation predictions.
lr_fpr, lr_tpr, _ = metrics.roc_curve(y, y_preds)
roc_axes = plt.gca()
roc_axes.plot(lr_fpr, lr_tpr, marker='.', label='abc')
roc_axes.set_xlabel('False Positive Rate')
roc_axes.set_ylabel('True Positive Rate')
roc_axes.legend()
plt.show()

In [None]:
class SUBMISSIONDataset(torch_data.Dataset):
    """Unlabelled volumes of every scan folder under ``{data_directory}/test``.

    Args:
        augmentation, preprocessing: Stored for API compatibility but not
            used when building an item.

    Items are ``{"X": float tensor, "id": folder name}``. ``X`` always has one
    channel per entry of ``mri_types``: a sequence that cannot be loaded is
    reported and replaced by an all-zero volume, so a partially missing scan
    no longer yields a short stack or ``None``.
    """

    def __init__(self,
                 augmentation=None,
                 preprocessing=None,
                ):
        # Same root the volume loader reads from, so the listing also works
        # outside Kaggle.
        self.indexes = sorted(os.listdir(f"{data_directory}/test"))
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        # Built once instead of on every item. ToTensor moves the last
        # (slice) axis to the front, so Normalize gets one mean/std per slice.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,) * NUM_IMAGES, (0.5,) * NUM_IMAGES)
        ])

    def __len__(self):
        return len(self.indexes)

    def __getitem__(self, index):
        scan_id = self.indexes[index]

        images = []
        for sequence in mri_types:
            try:
                image_3d = load_dicom_images_3d(scan_id=str(scan_id).zfill(5), split="test", mri_type=sequence)
            except Exception as err:
                # Best effort: keep the channel layout intact and say what
                # went wrong instead of silently dropping the sequence.
                print(f"[SUBMISSIONDataset] {scan_id}/{sequence}: {err!r}; using an all-zero volume")
                image_3d = np.zeros((SIZE, SIZE, NUM_IMAGES), dtype=np.uint8)
            # ToTensor rescales to [0, 1] only for uint8 input; the loader
            # yields integer intensities in 0..255, so the cast is lossless.
            images.append(self.transform(image_3d.astype(np.uint8)))

        return {"X": torch.stack(images).float(), "id": scan_id}

In [None]:
# Read the sample submission from the same root as the rest of the data, so
# the cell also works with the local fallback directory.
submission = pd.read_csv(f"{data_directory}/sample_submission.csv")

test_dataset = SUBMISSIONDataset()

# One scan per batch; the inference loop below reads batch["id"][0].
data_loader = torch_data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=8,
)

In [None]:
# Predict the class-1 probability for every test scan.
ids = []
preds = []

# Set eval mode explicitly instead of relying on an earlier cell having done it.
model.eval()
with torch.no_grad():
    for e, batch in enumerate(data_loader):
        print(f"{e}/{len(data_loader)}", end="\r")
        # `scan_id` rather than `id`, which would shadow the builtin.
        image, scan_id = batch["X"].to(device), str(batch["id"][0])

        try:
            output_ = model(image).squeeze(1)
            percentage = torch.nn.functional.softmax(output_, dim=1)[0]
            prediction = float(percentage[1].detach().cpu().numpy())
        except Exception as err:
            # Keep the best-effort fallback, but make the failure visible.
            print(f"Scan {scan_id}: inference failed ({err!r}); falling back to 0.5")
            prediction = 0.5

        preds.append(prediction)
        ids.append(scan_id)

In [None]:
# Assemble the submission table and write it without the row index.
df = pd.DataFrame({"BraTS21ID": ids, "MGMT_value": preds})
df.to_csv("submission.csv", columns=['BraTS21ID', 'MGMT_value'], index=False)

In [None]:
# Display the submission table for a visual check.
df

In [None]:
# Distribution of the predicted MGMT_value scores in the submission.
sns.displot(df["MGMT_value"])