In [None]:
import os
import sys 
import gc
import glob
import random 
import cv2
import numpy as np 
import pandas as pd 
from sklearn import metrics
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm_notebook as tqdm
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

gc.enable()

In [None]:
# Directory containing the EfficientNet-PyTorch source tree. It is appended to
# sys.path so the `import efficientnet_pytorch` below can resolve
# (presumably because the package is not installed in this environment — confirm).
package_path = "../input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master/"
sys.path.append(package_path)

# Must come after the sys.path.append above; the module is looked up on sys.path.
import efficientnet_pytorch

In [None]:
# --- Training configuration -------------------------------------------------
# Number of samples per batch handed to the DataLoaders.
BATCH_SIZE = 48 
# Number of cross-validation folds; the driver loop trains one model per fold.
NUM_FOLDS = 5
# Number of passes over the training split for each fold.
NUM_EPOCHS = 3
# Device identifier string.
DEVICE = 'cuda'
# Step size passed to the Adam optimizer.
LEARNING_RATE = 1e-5

# Table with fold assignments; the training code reads the columns
# `kfold`, `BraTS21ID` and `MGMT_value` from it.
train_df = pd.read_csv('../input/rsna-brain-folds/train_folds.csv')

In [None]:
def set_random_seed(random_seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU, current CUDA device and all CUDA devices), exports
    ``PYTHONHASHSEED`` and asks cuDNN for deterministic kernels.

    Parameters
    ----------
    random_seed : int
        Seed value applied to all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(random_seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(random_seed)

    torch.backends.cudnn.deterministic = True


set_random_seed(1729)

In [None]:
def load_dicom(path):
    """Read a DICOM file and return its pixels min-max scaled to ``uint8``.

    Parameters
    ----------
    path : str
        Location of the DICOM file.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as the file's ``pixel_array`` and dtype
        ``uint8``. The smallest input value maps to 0 and the largest to 255;
        a constant image maps to all zeros.
    """
    # `read_file` is deprecated (removed in pydicom 3.0); `dcmread` is the
    # supported spelling.
    dicom = pydicom.dcmread(path)

    # Promote to float *before* subtracting the minimum: doing the subtraction
    # in the stored integer dtype (e.g. int16) can wrap around and corrupt the
    # normalisation.
    data = dicom.pixel_array.astype(np.float64)
    data = data - data.min()

    peak = data.max()
    if peak != 0:  # constant image: skip the division to avoid 0/0
        data = data / peak

    return (data * 255).astype(np.uint8)

In [None]:
class Dataset:
    """Map-style dataset that builds a 3-channel image per patient.

    For every patient id, each modality folder is read, a subset of its
    slices is loaded with ``load_dicom``, resized, scaled to ``[0, 1]`` and
    averaged into a single 2-D image. The per-modality images are stacked
    into one tensor (one channel per modality).

    Parameters
    ----------
    paths : sequence
        Patient ids; each is zero-padded to five characters to form the
        patient folder name.
    targets : sequence
        Label for each entry of ``paths`` (same order).
    """

    # Folder that contains one zero-padded sub-folder per patient.
    DATA_ROOT = "../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/"
    # Modalities used as channels, in channel order ("T2w" is intentionally left out).
    MODALITIES = ("FLAIR", "T1w", "T1wCE")
    # Size passed to cv2.resize for every slice.
    IMAGE_SIZE = (256, 256)

    def __init__(self, paths, targets):
        self.paths = paths
        self.targets = targets

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _slice_indices(n_slices):
        """Return the slice indices to average for a series of ``n_slices`` files.

        Series with fewer than 10 slices use every slice; longer series use
        every ``n_slices // 10``-th slice, skipping one stride at each end.
        """
        if n_slices < 10:
            return range(n_slices)
        stride = n_slices // 10
        return range(stride, n_slices - stride, stride)

    def _load_modality(self, patient_path, modality):
        """Load one modality folder and return the mean of its sampled slices."""
        slice_paths = sorted(
            glob.glob(os.path.join(patient_path, modality, "*")),
            # File names end in "-<number>.<ext>"; order numerically, not lexically.
            key=lambda p: int(os.path.splitext(os.path.basename(p))[0].split("-")[-1]),
        )
        if not slice_paths:
            # Averaging an empty list would silently yield NaN features.
            raise FileNotFoundError(
                f"No image files found for modality {modality!r} under {patient_path!r}"
            )

        frames = [
            cv2.resize(load_dicom(slice_paths[i]), self.IMAGE_SIZE) / 255
            for i in self._slice_indices(len(slice_paths))
        ]
        return np.mean(frames, axis=0)

    def __getitem__(self, index, inference_only=False):
        """Return the sample at ``index``.

        Parameters
        ----------
        index : int
            Position in ``paths`` / ``targets``.
        inference_only : bool, optional
            When true, the label is omitted from the result.

        Returns
        -------
        dict
            ``"X"``: float32 tensor with one channel per modality.
            ``"y"``: float tensor holding the label (absent when
            ``inference_only`` is true).

        Raises
        ------
        FileNotFoundError
            If a modality folder of the patient contains no files.
        """
        _id = self.paths[index]
        patient_path = os.path.join(self.DATA_ROOT, str(_id).zfill(5))

        channels = [
            self._load_modality(patient_path, modality)
            for modality in self.MODALITIES
        ]

        # Stack into a single ndarray first: building a tensor from a Python
        # list of ndarrays is very slow and emits a UserWarning.
        sample = {"X": torch.from_numpy(np.stack(channels)).float()}
        if not inference_only:
            sample["y"] = torch.tensor(self.targets[index], dtype=torch.float)
        return sample

In [None]:
class Model(nn.Module):
    """EfficientNet-B0 backbone with a single-output linear head.

    The backbone is built by name, initialised from a local checkpoint file,
    and its final fully-connected layer is then replaced by a fresh
    ``nn.Linear`` producing one value per sample.
    """

    def __init__(self):
        super().__init__()
        backbone = efficientnet_pytorch.EfficientNet.from_name("efficientnet-b0")
        backbone.load_state_dict(
            torch.load("../input/efficientnet-pytorch/efficientnet-b0-08094119.pth")
        )
        # Swap the original classifier for a one-unit head, keeping its input width.
        backbone._fc = nn.Linear(
            in_features=backbone._fc.in_features,
            out_features=1,
            bias=True,
        )
        self.net = backbone

    def forward(self, x):
        """Return the backbone's output for the batch ``x``."""
        return self.net(x)

In [None]:
def train_fn(data_loader, model, optimizer, device, valid_data, best_score, fold):
    """Train for one epoch, then validate and checkpoint on improvement.

    Parameters
    ----------
    data_loader : iterable with ``len``
        Yields dict batches with keys ``'X'`` (features) and ``'y'`` (targets).
    model : torch.nn.Module
        Model whose flattened output is treated as one logit per sample.
    optimizer : torch.optim.Optimizer
        Optimizer stepping ``model``'s parameters.
    device : torch.device or str
        Device the batches are moved to.
    valid_data : iterable
        Validation batches, forwarded to ``eval_fn``.
    best_score : float
        Best validation score seen so far.
    fold : int
        Fold id, used in the checkpoint file name ``model_{fold}.pth``.

    Returns
    -------
    float
        ``best_score``, updated if this epoch's validation score is strictly
        higher. Unchanged when ``data_loader`` is empty (no validation runs).
    """
    # eval_fn() switches the model to eval mode and nothing switches it back,
    # so without this call every epoch after the first would train with
    # dropout / batch-norm in inference mode.
    model.train()

    last_iteration = len(data_loader) - 1

    for iteration, data in enumerate(data_loader):
        features = data['X'].to(device, dtype=torch.float)
        target = data['y'].to(device, dtype=torch.float)

        optimizer.zero_grad()

        predictions = model(features)
        loss = F.binary_cross_entropy_with_logits(predictions.flatten(), target)
        loss.backward()
        optimizer.step()

        # Validate once, right after the final optimisation step of the epoch.
        if iteration == last_iteration:
            current_score = eval_fn(valid_data, model, device)
            if current_score > best_score:
                best_score = current_score
                torch.save(model.state_dict(), f'model_{fold}.pth')
            print(f'Step: {iteration}, Current Score: {current_score}, Best Score: {best_score}')

    return best_score
        
def eval_fn(data_loader, model, device):
    """Compute the ROC AUC of ``model`` over ``data_loader``.

    The model is switched to eval mode (and left in it) and run without
    gradient tracking; its outputs are passed through a sigmoid before scoring.

    Parameters
    ----------
    data_loader : iterable
        Yields dict batches with keys ``'X'`` (features) and ``'y'`` (targets).
    model : torch.nn.Module
        Model producing one logit per sample.
    device : torch.device or str
        Device the features are moved to.

    Returns
    -------
    float
        Result of ``sklearn.metrics.roc_auc_score`` over all batches.
    """
    final_predictions = []
    final_targets = []

    model.eval()

    with torch.no_grad():
        for data in data_loader:
            features = data['X'].to(device, dtype=torch.float)

            # flatten() (not squeeze()) so a batch of size 1 stays 1-D;
            # squeeze() would yield a 0-d tensor whose tolist() is a bare
            # float, which list.extend() cannot consume.
            logits = model(features).flatten()
            final_predictions.extend(torch.sigmoid(logits).cpu().tolist())

            # Targets are only collected for scoring, so they stay on the CPU.
            targets = data['y'].to(dtype=torch.float).flatten()
            final_targets.extend(targets.cpu().tolist())

    return metrics.roc_auc_score(final_targets, final_predictions)

In [None]:
def run(data, fold):
    """Train and validate a fresh model on one cross-validation fold.

    Parameters
    ----------
    data : pandas.DataFrame
        Must provide the columns ``kfold``, ``BraTS21ID`` and ``MGMT_value``.
    fold : int
        Rows whose ``kfold`` equals this value form the validation split;
        all remaining rows are used for training.

    Returns
    -------
    float
        Best validation score reported by ``train_fn`` over all epochs
        (``-1`` if no validation ever ran).
    """
    print(f'Fold: {fold}')

    train_data = data[data['kfold'] != fold].reset_index(drop=True)
    val_data = data[data['kfold'] == fold].reset_index(drop=True)

    train_dataset = Dataset(
        train_data.BraTS21ID.values,
        train_data.MGMT_value.values
    )
    val_dataset = Dataset(
        val_data.BraTS21ID.values,
        val_data.MGMT_value.values
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        # Reshuffle every epoch; otherwise each epoch replays identical
        # batches in the row order of `data`.
        shuffle=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE
    )

    # Honour the module-level DEVICE setting instead of shadowing it with a
    # hard-coded CUDA device, and degrade to CPU when CUDA is unavailable.
    device = torch.device(DEVICE)
    if device.type == 'cuda' and not torch.cuda.is_available():
        print('CUDA is not available; falling back to CPU.')
        device = torch.device('cpu')

    model = Model()
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_score = -1
    for epoch in tqdm(range(NUM_EPOCHS)):
        print(f'Epoch: {epoch + 1}/{NUM_EPOCHS}')
        best_score = train_fn(train_loader, model, optimizer, device, val_loader, best_score, fold)
        print(f'Best Score for epoch {epoch + 1}: {best_score}')

    # The optimizer also references the parameters, so both must be dropped
    # before collection can actually release them.
    del model, optimizer
    gc.collect()
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    return best_score

In [None]:
# Train one model per fold: run() is called with fold ids 0 .. NUM_FOLDS - 1,
# each time on the full train_df (run() performs the train/validation split).
for fold in range(NUM_FOLDS):
    run(train_df, fold)