In [None]:
from IPython.display import clear_output
!pip install git+https://github.com/shijianjian/EfficientNet-PyTorch-3D
clear_output()

In [None]:
import os
import sys 
import json
from glob import glob
import random
import collections
import time
import re
import warnings

import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional
import torch.nn.functional as F
from torchvision import transforms
import torchvision.models as models
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score
from efficientnet_pytorch_3d import EfficientNet3D
import joblib
from tqdm.notebook import tqdm

def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Covers Python's ``random`` module, the ``PYTHONHASHSEED`` environment
    variable, NumPy's global generator and PyTorch. When CUDA is available,
    all GPU generators are seeded too and cuDNN is put in deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # The CPU-side generators all take the same integer seed.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing GPU-specific to configure on a CPU-only machine.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

# Fix all RNG seeds once, before any data splitting or model construction.
set_seed(1)
# Silence every warning category for the rest of the notebook.
warnings.filterwarnings("ignore") 
# Redundant with the blanket filter above (DeprecationWarning is already covered).
warnings.filterwarnings("ignore", category=DeprecationWarning) 

# Replace torch.hub's private fork-validation hook with a stub that always
# returns True.
# NOTE(review): presumably a workaround for torch.hub failures in this
# environment; it relies on a private API — confirm it is still needed.
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True

In [None]:
# ---- Run configuration ------------------------------------------------------
# Root under which the converted PNG slices are written and later read back.
data_directory = './'
# Which series to train on: 'FLAIR', 'T1w', 'T1wCE', 'T2w', or 'all'.
mri_type = 'FLAIR'

train = pd.read_csv("../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv")
# Case ids are used as folder names further down (zero-padded to five digits,
# e.g. "00452"), so store them as padded strings. They are derived from the
# CSV's own id column: overwriting the column with a directory listing, as was
# done before, pairs labels with ids in whatever order `glob` happens to
# return and can silently misalign them.
train["BraTS21ID"] = train["BraTS21ID"].astype(str).str.zfill(5)

# Hyper-parameters shared by every configuration.
SIZE = 256          # edge length (pixels) of each resized slice
NUM_IMAGES = 64     # slices kept per volume
batch_size = 4
in_channels = 3

if mri_type == 'all':
  train_df = train.iloc[50::]
  lr = 0.000001
elif mri_type in ('FLAIR', 'T1w', 'T1wCE', 'T2w'):
  # NOTE(review): only 'FLAIR' selected a subset originally; the other single
  # series left `train_df` undefined. They now share the same 50-case subset —
  # confirm this is the intended selection.
  train_df = train.iloc[0:50]
  lr = 0.00002
else:
  raise ValueError(
      f"Unknown mri_type {mri_type!r}; expected 'FLAIR', 'T1w', 'T1wCE', 'T2w' or 'all'."
  )

In [None]:
# Edge length (pixels) every slice is resized to; read as a global by
# read_mri below. Re-assigns the same value the configuration cell sets.
SIZE = 256

def read_mri(path, voi_lut=False, fix_monochrome=True):
    """Load one DICOM slice as a ``SIZE`` x ``SIZE`` ``uint8`` image.

    The pixel data is min-max scaled to the 0-255 range (an all-constant
    slice becomes all zeros) and resized with OpenCV.

    Args:
        path: Path of the DICOM file.
        voi_lut: If True, apply the file's VOI LUT / windowing transform.
        fix_monochrome: If True, invert slices whose photometric
            interpretation is ``MONOCHROME1``.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` with shape ``(SIZE, SIZE)``
        (``SIZE`` is the module-level constant).
    """
    # `dcmread` is the supported reader; `read_file` is a deprecated alias.
    dicom = pydicom.dcmread(path)
    pixels = dicom.pixel_array

    if voi_lut:
        # Apply the VOI transform to the stored pixel values, i.e. before the
        # float cast: a VOI LUT is indexed by the stored integer values.
        pixels = apply_voi_lut(pixels, dicom)

    data = pixels.astype(float)

    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data

    # Min-max normalise, guarding against division by zero on blank slices.
    data = data - np.min(data)
    peak = np.max(data)
    if peak != 0:
        data = data / peak
    data = (data * 255).astype(np.uint8)

    return cv2.resize(data, (SIZE, SIZE))

def id2path(img_id, is_test):
    """Return the paths of every DICOM file belonging to one case.

    Args:
        img_id: Case identifier, used verbatim as the case folder name.
        is_test: If truthy, look under the ``test`` folder of the dataset,
            otherwise under ``train``.

    Returns:
        List of matching ``*.dcm`` paths, one folder level below the case
        folder (empty if nothing matches).
    """
    pattern = (
        f'../input/rsna-miccai-brain-tumor-radiogenomic-classification/test/{img_id}/*/*.dcm'
        if is_test
        else f'../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/{img_id}/*/*.dcm'
    )
    return glob(pattern)
    
def save_train_img(_id):
    """Convert every DICOM slice of one training case to PNG under ``train/``.

    Output files mirror the dataset layout: a source file
    ``<dataset>/train/<id>/<series>/<name>.dcm`` is written to
    ``train/<id>/<series>/<name>.png`` relative to the working directory.

    Args:
        _id: Case identifier (folder name of the case in the dataset).
    """
    src_root = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/'
    fnames = id2path(_id, False)

    # exist_ok makes the conversion re-runnable: plain os.mkdir raised
    # FileExistsError the second time the cell was executed.
    case_dir = os.path.join('train', str(_id))
    for series in ('FLAIR', 'T1w', 'T1wCE', 'T2w'):
        os.makedirs(os.path.join(case_dir, series), exist_ok=True)

    for fname in fnames:
        im = read_mri(fname)
        # Strip the dataset prefix, then swap only the file extension
        # (a blanket 'dcm' -> 'png' replace would also rewrite any "dcm"
        # occurring elsewhere in the path).
        rel_path = fname.replace(src_root, '', 1)
        out_path = os.path.splitext(rel_path)[0] + '.png'
        cv2.imwrite(out_path, im)

# Create the output root in pure Python: unlike `!mkdir train`, this is
# portable and does not complain when the cell is run a second time.
os.makedirs('train', exist_ok=True)

# Convert all selected cases to PNG in parallel, one job per case, using
# every available core; tqdm shows progress as jobs are dispatched.
_ = joblib.Parallel(n_jobs=-1)(joblib.delayed(save_train_img)(_id) for _id in tqdm(train_df.BraTS21ID))

In [None]:
def get_transforms():
    """Build the augmentation pipeline applied to a single slice.

    The pipeline converts its input to a PIL image, then applies a vertical
    flip with probability 0.05, random autocontrast and random grayscale
    (the latter two with torchvision's default probabilities). The result
    stays a PIL image; nothing converts it back to an array or tensor.

    Returns:
        A ``transforms.Compose`` instance.
    """
    steps = [
        transforms.ToPILImage(),
        transforms.RandomVerticalFlip(p=0.05),
        transforms.RandomAutocontrast(),
        transforms.RandomGrayscale(),
    ]
    return transforms.Compose(steps)

def load_dicom_image(path, img_size=SIZE, agument=False):
    """Read one converted PNG slice, resize it and optionally augment it.

    Args:
        path: Path of the image file.
        img_size: Edge length in pixels of the square output. Previously this
            argument was ignored in favour of the global ``SIZE``.
        agument: If True, run the slice through ``get_transforms()``.
            (The spelling is kept because callers pass it by keyword.)

    Returns:
        The resized image as returned by OpenCV, or the output of the
        augmentation pipeline when ``agument`` is True.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path``.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None; fail here with the
        # offending path instead of with an opaque error inside cv2.resize.
        raise FileNotFoundError(f"Could not read image file: {path}")

    img = cv2.resize(img, (img_size, img_size))

    if agument:
        img = get_transforms()(img)

    return img

def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=SIZE, mri_type="FLAIR", split="train", agument=False):
    """Assemble the central slices of one series into a single volume.

    Slices are read from ``{data_directory}/{split}/{scan_id}/{mri_type}/*.png``
    in natural (numeric-aware) filename order. Up to ``num_imgs`` slices
    centred on the middle of the series are stacked and the stack is
    transposed, so the slice axis ends up last. If fewer than ``num_imgs``
    slices are available, the slice axis is zero-padded at the end.

    Args:
        scan_id: Case folder name.
        num_imgs: Target number of slices along the last axis.
        img_size: Edge length each slice is resized to; forwarded to
            ``load_dicom_image`` (it was previously accepted but unused).
        mri_type: Series sub-folder to read.
        split: Top-level folder under ``data_directory``.
        agument: Forwarded to ``load_dicom_image``.

    Returns:
        ``numpy.ndarray`` whose last axis has length ``num_imgs`` whenever
        padding was applied. Assuming each slice loads as ``(H, W, 3)``, the
        shape is ``(3, W, H, num_imgs)`` — TODO confirm. The array is
        float64 when padded, otherwise it keeps the dtype of the slices.

    Raises:
        FileNotFoundError: If the series folder contains no PNG slices.
    """
    pattern = f"{data_directory}/{split}/{scan_id}/{mri_type}/*.png"
    # Natural sort, so that "Image-10" comes after "Image-9".
    files = sorted(glob(pattern),
               key=lambda var:[int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])
    if not files:
        # Without this guard np.stack fails with "need at least one array".
        raise FileNotFoundError(f"No PNG slices found for pattern: {pattern}")

    # Window of (at most) num_imgs slices centred on the middle of the series.
    middle = len(files)//2
    half_window = num_imgs//2
    p1 = max(0, middle - half_window)
    p2 = min(len(files), middle + half_window)

    slices = [load_dicom_image(f, img_size=img_size, agument=agument) for f in files[p1:p2]]
    img3d = np.stack(slices).T

    n_missing = num_imgs - img3d.shape[-1]
    if n_missing > 0:
        padding = np.zeros(img3d.shape[:-1] + (n_missing,))
        img3d = np.concatenate((img3d, padding), axis=-1)

    return img3d

In [None]:
# Sanity check: load one volume and print its shape and value statistics.
# The case id is taken from the subset that was actually converted to PNG;
# a hard-coded id fails whenever that case is not part of `train_df`.
sample_id = str(train_df["BraTS21ID"].iloc[0]).zfill(5)
a = load_dicom_images_3d(sample_id)
print(a.shape)
print(np.min(a), np.max(a), np.mean(a), np.median(a))

In [None]:
# Hold out 10% of the selected cases for validation. The split is stratified
# on the MGMT_value label and uses a fixed random_state so it is repeatable.
df_train, df_valid = sk_model_selection.train_test_split(
    train_df, 
    test_size=0.1, 
    random_state=1, 
    stratify=train_df["MGMT_value"],
)

# Display the (rows, columns) of each split as the cell output.
df_train.shape, df_valid.shape

In [None]:
def onehot(size, target):
    """Return a float64 vector of zeros with ones at the ``target`` position(s).

    Args:
        size: Length (or shape) of the vector to create.
        target: Index, or anything usable as a tensor index, to set to 1.

    Returns:
        ``torch.Tensor`` of dtype ``torch.float64``.
    """
    encoded = torch.zeros(size).double()
    encoded[target] = 1.0
    return encoded

class Dataset(torch_data.Dataset):
    """Map-style dataset yielding one MRI volume (and its label) per case.

    Args:
        paths: Sequence of case ids; each is zero-padded to five characters
            to form the case folder name.
        targets: Optional sequence of labels aligned with ``paths``. When
            omitted, items carry the case id instead of a label.
        mri_type: Sequence aligned with ``paths`` naming the series to load
            for each case (it is indexed per item).
        split: Top-level folder the volumes are read from. It is now honoured
            for labelled data as well; previously labelled items always read
            from ``"train"`` regardless of this argument.
        augment: If True, labelled items are loaded with augmentation.
            Unlabelled items are never augmented.
    """

    def __init__(self, paths, targets=None, mri_type=None, split="train", augment=False):
        self.paths = paths
        self.targets = targets
        self.mri_type = mri_type
        self.split = split
        self.augment = augment

    def __len__(self):
        """Number of cases."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the volume for case ``index``.

        Returns:
            ``{"X": float tensor, "id": case id}`` when no targets were given,
            otherwise ``{"X": float tensor, "y": smoothed label}``.
        """
        scan_id = self.paths[index]
        has_targets = self.targets is not None

        data = load_dicom_images_3d(
            str(scan_id).zfill(5),
            mri_type=self.mri_type[index],
            split=self.split,
            # Augmentation only ever applies to labelled (training) items.
            agument=bool(self.augment) and has_targets,
        )
        X = torch.tensor(data).float()

        if not has_targets:
            return {"X": X, "id": scan_id}

        # Label smoothing: 0 -> 0.01 and 1 -> 0.99.
        y = torch.tensor(abs(self.targets[index] - 0.01), dtype=torch.float)
        return {"X": X, "y": y}

class Model(nn.Module):
    """3D EfficientNet-B7 whose final layer is replaced by a one-output linear layer.

    The number of input channels is taken from the module-level
    ``in_channels`` setting.
    """

    def __init__(self):
        super().__init__()
        backbone = EfficientNet3D.from_name(
            "efficientnet-b7",
            override_params={'num_classes': 2},
            in_channels=in_channels,
        )
        # Swap the stock classifier for a single-output layer of the same
        # input width.
        backbone._fc = nn.Linear(
            in_features=backbone._fc.in_features,
            out_features=1,
            bias=True,
        )
        self.net = backbone

    def forward(self, x):
        """Run ``x`` through the network and return its raw output."""
        return self.net(x)
    
class Trainer:
    """Training loop with AUC-based checkpointing and early stopping.

    Args:
        model: Module to train; called as ``model(X)`` and expected to return
            one value per sample in dimension 1 (it is squeezed away).
        device: Device the batches are moved to.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss.
        scheduler: Object whose ``step(valid_loss)`` is called once per epoch.
    """

    def __init__(
        self, 
        model, 
        device, 
        optimizer, 
        criterion,
        scheduler
    ):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler

        self.best_valid_auc = 0
        self.n_patience = 0
        # Path of the most recently written checkpoint (None until one exists).
        self.lastmodel = None

    def fit(self, epochs, train_loader, valid_loader, save_path, patience):
        """Train for up to ``epochs`` epochs with early stopping.

        A checkpoint is written whenever the validation AUC improves on the
        best value seen so far. Training stops early once the AUC has failed
        to improve for ``patience`` consecutive epochs.

        Args:
            epochs: Maximum number of epochs.
            train_loader: Iterable of batches with keys ``"X"`` and ``"y"``.
            valid_loader: Iterable of batches with keys ``"X"`` and ``"y"``.
            save_path: Prefix of the checkpoint file names.
            patience: Number of non-improving epochs tolerated before stopping.

        Returns:
            A list with one dict per completed epoch, holding ``epoch``,
            ``train_loss``, ``valid_loss`` and ``valid_auc``. (Previously
            nothing was returned.)
        """
        history = []

        for n_epoch in range(1, epochs + 1):
            self.info_message("EPOCH: {}, \nLR: {}", n_epoch, self.optimizer.param_groups[0]['lr'])

            train_loss, train_time = self.train_epoch(train_loader)
            valid_loss, valid_auc, valid_time = self.valid_epoch(valid_loader)

            # The scheduler is driven by the validation loss, so a
            # plateau-style scheduler must treat lower values as better.
            self.scheduler.step(valid_loss)

            history.append(
                {
                    "epoch": n_epoch,
                    "train_loss": train_loss,
                    "valid_loss": valid_loss,
                    "valid_auc": valid_auc,
                }
            )

            self.info_message(
                "[Epoch Train: {}] loss: {:.4f}, time: {:.2f} s            ",
                n_epoch, train_loss, train_time
            )

            self.info_message(
                "[Epoch Valid: {}] loss: {:.4f}, auc: {:.4f}, time: {:.2f} s",
                n_epoch, valid_loss, valid_auc, valid_time
            )

            if self.best_valid_auc < valid_auc:
                self.save_model(n_epoch, save_path, valid_loss, valid_auc)
                self.info_message(
                     "auc improved from {:.4f} to {:.4f}. Saved model to '{}'", 
                    self.best_valid_auc, valid_auc, self.lastmodel
                )
                self.best_valid_auc = valid_auc
                self.n_patience = 0
            else:
                self.n_patience += 1

            if self.n_patience >= patience:
                self.info_message("\nValid auc didn't improve last {} epochs.", patience)
                break

        return history

    def train_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            ``(mean loss per batch, elapsed whole seconds)``.
        """
        self.model.train()
        t = time.time()
        sum_loss = 0

        for step, batch in enumerate(train_loader, 1):
            X = batch["X"].to(self.device)
            targets = batch["y"].to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(X).squeeze(1)

            loss = self.criterion(outputs, targets)
            loss.backward()

            sum_loss += loss.detach().item()

            self.optimizer.step()

            # Running mean, rewritten in place on a single console line.
            message = 'Train Step {}/{}, train_loss: {:.4f}'
            self.info_message(message, step, len(train_loader), sum_loss/step, end="\r")

        return sum_loss/len(train_loader), int(time.time() - t)

    def valid_epoch(self, valid_loader):
        """Evaluate the model on ``valid_loader`` without tracking gradients.

        Returns:
            ``(mean loss per batch, ROC AUC, elapsed whole seconds)``.
        """
        self.model.eval()
        t = time.time()
        sum_loss = 0
        y_all = []
        outputs_all = []

        for step, batch in enumerate(valid_loader, 1):
            with torch.no_grad():
                X = batch["X"].to(self.device)
                targets = batch["y"].to(self.device)

                outputs = self.model(X).squeeze(1)
                loss = self.criterion(outputs, targets)

                sum_loss += loss.detach().item()
                y_all.extend(batch["y"].tolist())
                # Model outputs are treated as logits; AUC is computed on
                # their sigmoid.
                outputs_all.extend(torch.sigmoid(outputs).tolist())

            message = 'Valid Step {}/{}, valid_loss: {:.4f}'
            self.info_message(message, step, len(valid_loader), sum_loss/step, end="\r")

        # Targets may be non-binary floats (e.g. smoothed labels), so they are
        # thresholded at 0.5 to recover hard classes for the AUC.
        y_all = [1 if x > 0.5 else 0 for x in y_all]
        auc = roc_auc_score(y_all, outputs_all)

        return sum_loss/len(valid_loader), auc, int(time.time() - t)

    def save_model(self, n_epoch, save_path, loss, auc):
        """Write a checkpoint and remember its path in ``self.lastmodel``.

        The file name encodes the epoch, loss and AUC. The stored
        ``best_valid_score`` is the best AUC including the one being saved;
        previously it held the stale value from before this epoch (0 for the
        first checkpoint), because ``fit`` only updates ``best_valid_auc``
        after saving.
        """
        self.lastmodel = f"{save_path}-e{n_epoch}-loss{loss:.3f}-auc{auc:.3f}.pth"
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_valid_score": max(self.best_valid_auc, auc),
                "n_epoch": n_epoch,
            },
            self.lastmodel,
        )

    @staticmethod
    def info_message(message, *args, end="\n"):
        """Print ``message`` formatted with ``args``."""
        print(message.format(*args), end=end)

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_mri_type(df_train, df_valid, mri_type, epochs=1, patience=1):
    """Train one model for ``mri_type`` and return the last checkpoint path.

    For a single series the given frames are used as-is. For ``"all"`` the
    frames passed in are ignored; a separate label subset is drawn for each of
    the four series, split 90/10, and the parts are concatenated.

    Reads the module-level ``batch_size``, ``lr`` and ``device`` settings.

    Args:
        df_train: Training cases with ``BraTS21ID`` and ``MGMT_value`` columns.
        df_valid: Validation cases with the same columns.
        mri_type: ``"FLAIR"``, ``"T1w"``, ``"T1wCE"``, ``"T2w"`` or ``"all"``.
            Also used as the checkpoint file-name prefix.
        epochs: Maximum number of epochs (default 1, the former fixed value).
        patience: Early-stopping patience (default 1, the former fixed value).

    Returns:
        Path of the last checkpoint written, or ``None`` if none was saved.
    """
    if mri_type == "all":
        # NOTE(review): labels are re-read from a CSV in the working
        # directory, unlike the configuration cell which reads the dataset
        # copy — confirm that file exists and that its ids match the
        # converted PNG folders.
        labels = pd.read_csv("train_labels.csv")

        # Row subset used for each series (same selections as before). The
        # series were previously iterated from an undefined `mri_types` name.
        subsets = {
            "FLAIR": labels.iloc[0:50],
            "T1w": pd.concat([labels.iloc[0:50], labels.iloc[200::]]).iloc[50:100],
            "T1wCE": pd.concat([labels.iloc[0:250], labels.iloc[300:350], labels.iloc[400::]]).iloc[0:50],
            "T2w": pd.concat([labels.iloc[50:550]]).iloc[0:50],
        }

        train_list = []
        valid_list = []
        # A separate loop name keeps the `mri_type` argument intact, so the
        # checkpoint prefix below stays "all" rather than the last series.
        for series, subset in subsets.items():
            part_train, part_valid = sk_model_selection.train_test_split(
                subset, test_size=0.1, random_state=1, stratify=subset["MGMT_value"]
            )
            train_list.append(part_train.assign(MRI_Type=series))
            valid_list.append(part_valid.assign(MRI_Type=series))

        df_train = pd.concat(train_list)
        df_valid = pd.concat(valid_list)
    else:
        # `assign` returns new frames, so the caller's frames are not mutated
        # (and no chained-assignment warning is triggered on the split slices).
        df_train = df_train.assign(MRI_Type=mri_type)
        df_valid = df_valid.assign(MRI_Type=mri_type)

    print(df_train.shape, df_valid.shape)
    display(df_train.head())

    train_data_retriever = Dataset(
        df_train["BraTS21ID"].values,
        df_train["MGMT_value"].values,
        df_train["MRI_Type"].values,
        augment=True
    )

    valid_data_retriever = Dataset(
        df_valid["BraTS21ID"].values,
        df_valid["MGMT_value"].values,
        df_valid["MRI_Type"].values
    )

    train_loader = torch_data.DataLoader(
        train_data_retriever,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )

    valid_loader = torch_data.DataLoader(
        valid_data_retriever,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    model = Model()
    model.to(device)

    # Apply weight decay to everything except parameters whose name matches
    # one of the `no_decay` fragments.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr)

    criterion = torch_functional.binary_cross_entropy_with_logits
    # The trainer steps the scheduler with the validation *loss*, so the
    # plateau detector must look for a value that stops decreasing. With
    # mode='max' the learning rate was cut while the loss was still improving.
    # `verbose=False` (the default) is omitted for newer PyTorch releases.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5,
        threshold=0.0001, threshold_mode='abs', cooldown=0,
        min_lr=0, eps=1e-08,
    )

    trainer = Trainer(
        model,
        device,
        optimizer,
        criterion,
        scheduler
    )

    trainer.fit(
        epochs,
        train_loader,
        valid_loader,
        f"./{mri_type}",
        patience,
    )

    return trainer.lastmodel


In [None]:
# Train on the configured series. The value displayed is the path of the last
# checkpoint written (None if no checkpoint was saved).
train_mri_type(df_train, df_valid, mri_type)