In [None]:
import os
import sys 
import json
import glob
import random
import collections
import time
import re
import math
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

from random import shuffle
from sklearn import model_selection as sk_model_selection
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader

In [None]:
data_directory = '../input/rsna-miccai-brain-tumor-radiogenomic-classification'

# MRI sequences available per scan; mri_types_id selects the one the loaders use.
mri_types = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
mri_types_id = 0  # index into mri_types: 0=FLAIR, 1=T1w, 2=T1wCE, 3=T2w

IMAGE_SIZE = 256   # each slice is resized to IMAGE_SIZE x IMAGE_SIZE
NUM_IMAGES = 100   # number of slices (depth) kept per volume
BATCH_SIZE = 16
NUM_EPOCHS = 50
# Fall back to CPU so the notebook still runs on a machine without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Seed every RNG in use so DataLoader shuffling and weight init are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_df = pd.read_csv(f"{data_directory}/train_labels.csv")
# Zero-padded 5-digit id, the form the loaders use as the per-scan folder name.
train_df['BraTS21ID5'] = [format(x, '05d') for x in train_df.BraTS21ID]
train_df.head(5)

In [None]:
# Scan ids removed from the training table before splitting.
# NOTE(review): the reason these particular scans are excluded is not recorded here — document it.
EXCLUDED_IDS = ["00109", "00123", "00709"]
indexs = ~train_df["BraTS21ID5"].isin(EXCLUDED_IDS)
train_df = train_df[indexs]

In [None]:
def load_dicom_image(path, img_size=IMAGE_SIZE, voi_lut=True, rotate=0):
    """Read one DICOM slice and return it resized to ``img_size`` x ``img_size``.

    Args:
        path: path to a ``.dcm`` file (assumed single-frame, single-channel).
        img_size: side length in pixels that the slice is resized to.
        voi_lut: if True, pass the pixels through ``apply_voi_lut`` using the
            file's own metadata; otherwise use the raw ``pixel_array``.
        rotate: 0 for no rotation; 1, 2 or 3 select a 90-degree clockwise,
            90-degree counter-clockwise or 180-degree rotation respectively.

    Returns:
        The resized slice as a ``np.ndarray``.
    """
    # dcmread is the canonical reader; read_file was a deprecated alias
    # (removed in pydicom 3.0).
    dicom = pydicom.dcmread(path)
    data = dicom.pixel_array
    if voi_lut:
        data = apply_voi_lut(data, dicom)

    if rotate > 0:
        # Index 0 is a placeholder and is never used because of the guard above.
        rot_choices = [0, cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]
        data = cv2.rotate(data, rot_choices[rotate])

    return cv2.resize(data, (img_size, img_size))


def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=IMAGE_SIZE, mri_type=mri_types[mri_types_id], split="train", rotate=0):
    """Load the central ``num_imgs`` slices of one MRI series as a volume.

    Slices are read from ``{data_directory}/{split}/{scan_id}/{mri_type}/*.dcm``
    in numeric-aware filename order. The window of ``num_imgs`` slices around
    the middle of the series is kept. A shorter series is zero-padded
    symmetrically along the last axis. A non-constant volume is min-max scaled
    to [0, 1].

    Args:
        scan_id: zero-padded scan folder name, e.g. "00003".
        num_imgs: depth (last axis) of the returned volume.
        img_size: in-plane side length each slice is resized to.
        mri_type: sequence sub-folder name (one of ``mri_types``).
        split: dataset split folder, e.g. "train".
        rotate: rotation code forwarded to ``load_dicom_image``.

    Returns:
        ``np.ndarray`` of shape ``(1, img_size, img_size, num_imgs)``.

    Raises:
        FileNotFoundError: if no ``.dcm`` file matches the scan's folder.
    """
    pattern = f"{data_directory}/{split}/{scan_id}/{mri_type}/*.dcm"
    # Split names into digit / non-digit runs so "Image-10" sorts after "Image-9".
    files = sorted(glob.glob(pattern),
                   key=lambda var: [int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])
    if not files:
        raise FileNotFoundError(f"no DICOM files match {pattern}")

    # Window of at most num_imgs slices centred on the middle of the series.
    # p2 is derived from p1 so that an odd num_imgs still yields num_imgs slices.
    middle = len(files) // 2
    p1 = max(0, middle - num_imgs // 2)
    p2 = min(len(files), p1 + num_imgs)
    # Forward img_size so the slices agree with the zero padding below.
    # np.stack gives (depth, H, W); .T reverses the axes to (W, H, depth).
    img3d = np.stack([load_dicom_image(f, img_size=img_size, rotate=rotate) for f in files[p1:p2]]).T

    missing = num_imgs - img3d.shape[-1]
    if missing > 0:
        # Split the padding evenly; an odd extra slice goes at the end.
        front = np.zeros((img_size, img_size, missing // 2))
        back = np.zeros((img_size, img_size, missing - missing // 2))
        img3d = np.concatenate((front, img3d, back), axis=-1)

    lo, hi = np.min(img3d), np.max(img3d)
    if lo < hi:
        # Work in float64 so subtracting the minimum cannot overflow an integer dtype.
        img3d = (img3d.astype(np.float64) - float(lo)) / (float(hi) - float(lo))

    return np.expand_dims(img3d, 0)

# a = load_dicom_images_3d("00003")
# print(a.shape)
# print(np.min(a), np.max(a), np.mean(a), np.median(a))
# image = a[0]
# print("Dimension of the CT scan is:", image.shape)
# plt.imshow(np.squeeze(image[:, :, 100]), cmap="gray")

In [None]:
class Dataset(object):
    """Map-style dataset that yields one MRI volume (and its label) per scan id.

    Args:
        df: frame with a ``BraTS21ID5`` column (zero-padded scan folder names)
            and, when ``is_train`` is True, an ``MGMT_value`` label column.
        is_train: if True, items are ``(X, Y)``; otherwise items are just ``X``.
        device: device on which the returned tensors are created.
        split: split folder the volumes are read from. The default is "train"
            for both modes, as before; pass e.g. "test" for inference on another split.
    """

    def __init__(self, df, is_train=True, device="cpu", split="train"):
        self.paths = df["BraTS21ID5"].values
        # Labels are required for training (KeyError if missing) but optional
        # for inference frames that do not carry them.
        self.y = df["MGMT_value"].values if (is_train or "MGMT_value" in df.columns) else None
        self.is_train = is_train
        self.len = df.shape[0]
        self.device = device
        self.split = split

    def __len__(self):
        return self.len

    def __getitem__(self, ids):
        # NOTE(review): tensors are created directly on self.device. With a CUDA
        # device this is only safe when the DataLoader uses num_workers=0.
        X = load_dicom_images_3d(self.paths[ids], split=self.split)
        X = torch.tensor(X, dtype=torch.float32, device=self.device)
        if not self.is_train:
            return X
        Y = torch.tensor(self.y[ids], dtype=torch.float32, device=self.device)
        return X, Y

In [None]:
class View(nn.Module):
    """Parameter-free module that reshapes its input with ``Tensor.view``."""

    def __init__(self, size):
        super().__init__()
        # Target shape handed verbatim to Tensor.view (may contain a single -1).
        self.size = size

    def forward(self, tensor):
        target_shape = self.size
        reshaped = tensor.view(target_shape)
        return reshaped
    
class Model(nn.Module):
    """Small 3-D CNN that maps one single-channel MRI volume to a probability.

    The classifier head is sized for inputs of shape ``(batch, 1, 256, 256, 100)``
    (``IMAGE_SIZE=256``, ``NUM_IMAGES=100``). The conv/pool stack reduces that to
    ``(batch, 128, 2, 2, 3)``, i.e. 1536 features per sample. A mis-sized input
    raises a shape error instead of being reshaped across the batch.
    """

    def __init__(self):
        super(Model, self).__init__()
        h_dim = [32, 64, 128]
        # Shape comments are (H, W, D) for a 256 x 256 x 100 input.
        # NOTE(review): there is no activation between conv blocks (only BatchNorm,
        # Dropout and max-pooling). Adding nn.ReLU would likely help, but it shifts
        # the Sequential indices and so breaks loading previously saved weights.
        self.net = nn.Sequential(
            nn.Conv3d(1, h_dim[0], 4, 2, 1),         # 128, 128, 50
            nn.BatchNorm3d(h_dim[0]),
            nn.Dropout(0.05),
            nn.MaxPool3d((2, 2, 2)),                  # 64, 64, 25
            nn.Conv3d(h_dim[0], h_dim[1], 4, 2, 1),   # 32, 32, 12
            nn.BatchNorm3d(h_dim[1]),
            nn.Dropout(0.05),
            nn.MaxPool3d((2, 2, 2)),                  # 16, 16, 6
            nn.Conv3d(h_dim[1], h_dim[2], 4, 2, 1),   # 8, 8, 3
            nn.BatchNorm3d(h_dim[2]),
            nn.Dropout(0.05),
            nn.MaxPool3d((4, 4, 1)),                  # 2, 2, 3
            # Flatten keeps the batch dimension. A fixed view((-1, 1536)) could
            # silently regroup features across samples on a mis-sized input.
            nn.Flatten(),
            nn.Linear(h_dim[-1] * 2 * 2 * 3, 128),
            nn.ReLU(True),
            nn.Linear(128, 1))

    def forward(self, x):
        """Return probabilities of shape ``(batch, 1)``, strictly inside (0, 1)."""
        x = self.net(x)
        x = torch.sigmoid(x)
        # Squeeze into [eps/2, 1 - eps/2] so torch.log in the loss never sees 0 or 1.
        eps = 1e-6
        x = 0.5 * eps + (1 - eps) * x
        return x

In [None]:
def get_loss(y_pred, y):
    """Compute binary cross-entropy, accuracy and ROC AUC for one batch.

    Args:
        y_pred: predicted probabilities of shape ``(batch, 1)``, strictly in (0, 1).
        y: 0/1 float targets of shape ``(batch,)``.

    Returns:
        ``(loss, acc, auc)``:
        - ``loss``: mean BCE as a differentiable tensor.
        - ``acc``: mean accuracy at a 0.5 threshold, as a tensor.
        - ``auc``: ROC AUC as a float, or ``nan`` when the batch contains only
          one class (AUC is undefined there).
    """
    y = y.unsqueeze(-1)
    loss = -torch.log(y_pred) * y - torch.log(1 - y_pred) * (1 - y)
    acc = ((y_pred > 0.5) == y).to(torch.float32)

    y_true = y.detach().cpu().numpy().ravel()
    # roc_auc_score raises ValueError on single-class input, and a small batch
    # (e.g. the last one of an epoch) can easily be all 0s or all 1s.
    if np.unique(y_true).size < 2:
        auc = float("nan")
    else:
        auc = roc_auc_score(y_true, y_pred.detach().cpu().numpy().ravel())
    return loss.mean(), acc.mean(), auc

In [None]:
# Positional split: the first 85% of rows train, the remaining rows validate.
# NOTE(review): the split is not shuffled or stratified by MGMT_value.
TRAIN_FRACTION = 0.85
n_train = int(train_df.shape[0] * TRAIN_FRACTION)
train_dataset = Dataset(train_df.iloc[:n_train], device=DEVICE)
val_dataset = Dataset(train_df.iloc[n_train:], device=DEVICE)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = Model().to(DEVICE)

In [None]:
def train(train_dataloader, val_dataloader, device, num_epochs, net=None):
    """Train with Adam + StepLR and checkpoint on the best validation loss.

    Each epoch does one pass over ``train_dataloader`` (printing per-batch
    loss/accuracy/AUC and the epoch's mean loss), then evaluates on
    ``val_dataloader``. Whenever the mean validation loss improves, the whole
    module is written with ``torch.save`` to ``best_model.pt``. An untrained
    copy is saved before the first epoch.

    Args:
        train_dataloader: yields ``(X, y)`` training batches.
        val_dataloader: yields ``(X, y)`` validation batches; its ``dataset``
            must support ``len()``.
        device: device batches are moved to before the forward pass
            (a no-op when they already live there).
        num_epochs: number of epochs to run.
        net: module to train; defaults to the notebook-level ``model``.
    """
    if net is None:
        net = model  # notebook-level model built in an earlier cell
    optimizer = optim.Adam(net.parameters(), lr=1e-4)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.75)
    best_model_path = "best_model.pt"
    best_val_loss = np.inf
    torch.save(net, best_model_path)

    for epoch in tqdm(range(num_epochs)):
        net.train()
        train_loss = 0.0
        for X, y in train_dataloader:
            X, y = X.to(device), y.to(device)
            y_pred = net(X)
            loss, acc, auc = get_loss(y_pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() detaches; accumulating the graph-attached loss tensor would
            # chain every batch's autograd history into train_loss.
            train_loss += loss.item() * X.shape[0]

            print("loss: %.4f" % loss.item(), "acc: %.4f" % acc.item(), "auc : %.2f" % auc)
        print("Epoch %d train loss: %.4f" % (epoch, train_loss / len(train_dataloader.dataset)))

        net.eval()
        val_loss = 0.0
        val_acc = 0.0
        val_true, val_prob = [], []
        with torch.no_grad():
            for X, y in val_dataloader:
                X, y = X.to(device), y.to(device)
                y_pred = net(X)
                loss, acc, _ = get_loss(y_pred, y)
                val_loss += loss.item() * X.shape[0]
                val_acc += acc.item() * X.shape[0]
                val_true.append(y.cpu().numpy().ravel())
                val_prob.append(y_pred.cpu().numpy().ravel())

        val_data_num = len(val_dataloader.dataset)
        val_true = np.concatenate(val_true)
        val_prob = np.concatenate(val_prob)
        # AUC over the whole validation set: a size-weighted mean of per-batch
        # AUCs is not itself an AUC. It is undefined if only one class is present.
        if np.unique(val_true).size > 1:
            val_auc = roc_auc_score(val_true, val_prob)
        else:
            val_auc = float("nan")
        print("Current val loss: %.4f" % (val_loss / val_data_num))
        print("Current acc: %.4f" % (val_acc / val_data_num))
        print("Current auc: %.4f" % val_auc)

        if best_val_loss > val_loss / val_data_num:
            best_val_loss = val_loss / val_data_num
            print("Impove! saving the model. Current best val loss: %.4f" % best_val_loss)
            torch.save(net, best_model_path)
        scheduler.step()

In [None]:
# To run in the background and keep the output in a log file, redirect stdout first:
# import sys
# stdout_backup = sys.stdout
# log_file = open("message.log", "w")
# sys.stdout = log_file
train(train_loader, val_loader, device=DEVICE, num_epochs=NUM_EPOCHS)