## Use stacked images (3D) and a GRU model (the EfficientNet3D model is commented out)

Acknowledgements:

- https://www.kaggle.com/ihelon/brain-tumor-eda-with-animations-and-modeling
- https://www.kaggle.com/furcifer/torch-efficientnet3d-for-mri-no-train
- https://github.com/shijianjian/EfficientNet-PyTorch-3D
    
    
Use models with only one MRI type, then ensemble the 4 models 


In [None]:
import os
import sys 
import json
import glob
import random
import collections
import time

import numpy as np
import pandas as pd
import pydicom
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional
import torch.nn.functional as F
from torch.optim import lr_scheduler

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

In [None]:
# Pick the data location: the first path is used when it exists on disk,
# otherwise fall back to a local copy of the dataset.
if os.path.exists("../input/rsna-miccai-brain-tumor-radiogenomic-classification"):
    data_directory = '../input/rsna-miccai-brain-tumor-radiogenomic-classification'
    pytorch3dpath = "../input/efficientnetpyttorch3d/EfficientNet-PyTorch-3D"
else:
    data_directory = '/media/roland/data/kaggle/rsna-miccai-brain-tumor-radiogenomic-classification'
    pytorch3dpath = "EfficientNet-PyTorch-3D"
    
# MRI modalities to train on; one model is trained per entry.
#mri_types = ['FLAIR','T1w','T1wCE','T2w']
#mri_types = ['FLAIR','T2w','T1w','T1wCE']
mri_types = ['T1wCE']
# Side length (pixels) each slice is resized to.
SIZE = 256
# Number of slices every volume is resampled to.
NUM_IMAGES = 64

# EfficientNet3D is not imported; pytorch3dpath is only referenced by the
# commented-out lines below.
# sys.path.append(pytorch3dpath)
# from efficientnet_pytorch_3d import EfficientNet3D

## Functions to load images

In [None]:
from fractions import Fraction
import shutil
def image_aug(n,kijyun = NUM_IMAGES):
    """Resample a stack of slices along axis 0 until it holds ``kijyun`` slices.

    On every pass the ratio ``kijyun / n.shape[0]`` is approximated by a
    fraction (numerator ``bunshi``, denominator ``bunnbo``) and one of four
    strategies is applied; the loop repeats until the slice count equals
    ``kijyun`` (or the fraction reduces to 1):

    * ratio >= 2: keep every slice and append ``int(ratio) - 1`` blended
      slices built from it and its successor (the last slice is paired with
      itself).
    * 1 < ratio < 2: append one extra slice after each group of ``ins_cnt``
      slices (a duplicate near the end of the stack, otherwise an average
      with a later slice), then append the remaining slices.
    * ratio <= 1/2: average consecutive groups of ``small_size`` slices
      into a single slice.
    * 1/2 < ratio < 1: drop individual slices at a running interval.

    Args:
        n: array read as (slice, rows, cols); ``n.shape[0..2]`` are used.
        kijyun: target number of slices (defaults to ``NUM_IMAGES``).

    Returns:
        The resampled array.
    """
    #print('n before:', n.shape)
    im_x = n.shape[1]
    im_y = n.shape[2]
    while kijyun != n.shape[0]:
        n_cnt = n.shape[0]
        n_tmpo = kijyun/n_cnt
        # Rational approximation of the target/current ratio.
        bunkatsu = Fraction(n_tmpo).limit_denominator(kijyun)
        bunshi = bunkatsu.numerator
        bunnbo = bunkatsu.denominator
        #print('bunshi:',bunshi,' bunnbo:',bunnbo)
        # Accumulator for the resampled stack; empty until the first append.
        n_n = np.array([])
        if bunshi == bunnbo:
            # Ratio is 1: leave the stack as is.
            break
        elif bunshi > bunnbo:
            # Expand (more slices are needed).
            if (bunshi / bunnbo) >= 2:
                # Number of blended slices appended after each original slice.
                baisu = int(bunshi / bunnbo) - 1
                for w,n_x in enumerate(n):
                    #print('w:',w)
                    n_x = n_x.reshape(-1,im_x,im_y)
                    if n.shape[0] == w + 1:
                        # Last slice has no successor: pair it with itself.
                        n_y = n_x
                    else:
                        n_y = n[w+1].reshape(-1,im_x,im_y)
                    
                    if(n_n.shape[0] == 0):
                        n_n = n_x
                    else:
                        n_n = np.concatenate([n_n,n_x])
                    #print('baisu:',baisu)
                    for baisu_cnt in range(baisu):
                        baisu_cnt += 1
                        # NOTE(review): n_x / n_y are rebound here, so the
                        # weights compound across iterations of this loop, and
                        # np.mean over the two weighted slices yields half of
                        # their sum - confirm this is the intended blend.
                        n_x = n_x * (baisu_cnt / baisu)
                        n_y = n_y * ((baisu - baisu_cnt) / baisu)
                        n_z = np.concatenate([n_x,n_y])
                        n_z = np.mean(n_z,axis=0).reshape(-1,im_x,im_y)
                        n_n = np.concatenate([n_n,n_z])
            else:
                #ins_cnt = int(bunnbo / 2)
                # Group length after which one extra slice is appended.
                ins_cnt = int((kijyun) / (kijyun - n_cnt))
                if ins_cnt == kijyun:
                    ins_cnt -= 1
                #print('n_cnt:',n_cnt)
                #print('kijyun:',kijyun)
                #print('ins_cnt:',ins_cnt)
                #print('int(n.shape[0]/ins_cnt):',int(n.shape[0]/ins_cnt))
                for w in range(int(n_cnt/ins_cnt)):
                    #print('w:',w)
                    w += 1
                    # w becomes the exclusive end index of the current group.
                    w = (w * ins_cnt)
                    #print('w:',w)
                    #print('w-ins_cnt:',w-ins_cnt)
                    
                    n_x = n[w-ins_cnt:w]
                    
                    #print('n_cnt:',n_cnt)
                    #print('w:',w)
                    if n_cnt <= w + 1:
                        # Near the end of the stack: duplicate the group's last slice.
                        #print('n_x:',n_x.shape)
                        #print('n_x[-1:]:',n_x[-1:].shape)
                        n_x = np.concatenate([n_x,n_x[-1:]])
                    else:
                        # Average the group's last slice with n[w+1].
                        # NOTE(review): the slice right after the group is n[w],
                        # not n[w+1] - confirm the index is intended.
                        n_y = n[w+1].reshape(-1,im_x,im_y)
                        n_y = np.concatenate([n_x[-1:],n_y])
                        n_y = np.mean(n_y,axis=0).reshape(-1,im_x,im_y)
                        n_x = np.concatenate([n_x,n_y])
                    if(n_n.shape[0] == 0):
                        n_n = n_x
                    else:
                        n_n = np.concatenate([n_n,n_x])
                # Append whatever is left after the last complete group
                # (relies on w from the final loop iteration).
                n_n = np.concatenate([n_n,n[w:]])
        elif bunshi < bunnbo:
            # Shrink (fewer slices are needed).
            #print('n:',n.shape)
            img_size = n.shape[0]
            if int(bunnbo / bunshi) >= 2:
                small_size = int(bunnbo / bunshi)
                
                img_size_current = 0
                while img_size > img_size_current:
                    # Merge each run of small_size slices into one by averaging;
                    # the final run may be shorter.
                    if img_size >= (img_size_current + small_size):
                        n_x = np.expand_dims(n[img_size_current:img_size_current + small_size].mean(axis=0),0)
                    else:
                        n_x = np.expand_dims(n[img_size_current:img_size].mean(axis=0),0)
                    
                    if(n_n.shape[0] == 0):
                        n_n = n_x
                    else:
                        n_n = np.concatenate([n_n,n_x])
                    img_size_current = img_size_current + small_size
            else:
                # Halve the fraction terms until the denominator is below kijyun.
                while bunnbo >= kijyun: 
                    bunnbo = int(bunnbo / 2)
                    bunshi = int(bunshi / 2)
                # 1-based position of the next slice to drop.
                del_cnt = int(bunnbo / 2)
                i = 0
                #print(del_cnt)
                for n_x in n:
                    i += 1
                    if i % del_cnt != 0:
                        # Keep this slice.
                        n_x = n_x.reshape(-1,n_x.shape[0],n_x.shape[1])
                        if(n_n.shape[0] == 0):
                            n_n = n_x
                        else:
                            n_n = np.concatenate([n_n,n_x])
                    else:
                        # Drop this slice and push the next drop position back.
                        del_cnt = del_cnt + bunnbo
        # Feed the result into the next pass.
        n = n_n
        #print('n.shape:',n.shape)
    return n

In [None]:
# def load_dicom_image(path, img_size=SIZE):
#     dicom = pydicom.read_file(path)
#     data = dicom.pixel_array
#     if np.min(data)==np.max(data):
#         data = np.zeros((img_size,img_size))
#         return data
#     data = data - np.min(data)
#     if np.max(data) != 0:
#         data = data / np.max(data)
    
#     #data = (data * 255).astype(np.uint8)
#     data = cv2.resize(data, (img_size, img_size))
#     return data

# def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=SIZE, mri_type="FLAIR", split="train"):

#     files = sorted(glob.glob(f"{data_directory}/{split}/{scan_id}/{mri_type}/*.dcm"))
    
#     middle = len(files)//2
#     num_imgs2 = num_imgs//2
#     p1 = max(0, middle - num_imgs2)
#     p2 = min(len(files), middle + num_imgs2)
#     img3d = np.stack([load_dicom_image(f) for f in files[p1:p2]]).T 
#     if img3d.shape[-1] < num_imgs:
#         n_zero = np.zeros((img_size, img_size, num_imgs - img3d.shape[-1]))
#         img3d = np.concatenate((img3d,  n_zero), axis = -1)
            
#     return np.expand_dims(img3d,0)

# load_dicom_images_3d("00000").shape
def load_dicom_image(path, img_size=SIZE, min_nonzero=1000):
    """Read one DICOM slice, resize it, and reject near-empty slices.

    Args:
        path: Path of the DICOM file to read.
        img_size: Side length of the square output passed to ``cv2.resize``.
        min_nonzero: Minimum number of non-zero pixels (counted after
            resizing) a slice must contain to be kept.

    Returns:
        The resized pixel array, or ``False`` when the slice holds fewer than
        ``min_nonzero`` non-zero pixels. The ``False`` sentinel is kept so
        callers that test the return type keep working.
    """
    # dcmread is the supported reader; read_file was only a deprecated alias
    # and is no longer available in recent pydicom releases.
    dicom = pydicom.dcmread(path)
    data = cv2.resize(dicom.pixel_array, (img_size, img_size))
    if np.count_nonzero(data) >= min_nonzero:
        return data
    return False

def _natural_sort_key(path):
    """Sort key that orders digit runs in the file name numerically."""
    import re

    name = os.path.basename(path)
    # Digit runs become (int, ""), text runs become (-1, text), so elements
    # never compare an int against a str.
    return [
        (int(token), "") if token.isdigit() else (-1, token)
        for token in re.split(r"(\d+)", name)
    ]


def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=SIZE, mri_type="FLAIR", split="train"):
    """Build (or load from the on-disk cache) the normalized volume of a scan.

    Args:
        scan_id: Scan folder name under ``{data_directory}/{split}``.
        num_imgs: Number of slices the volume is resampled to.
        img_size: Side length each slice is resized to.
        mri_type: Modality sub-folder to read.
        split: Dataset split folder.

    Returns:
        Array of shape ``(1, img_size, img_size, num_imgs)`` scaled so its
        minimum is 0 and, unless the volume is constant, its maximum is 1.
    """
    cache_dir = f"/tmp_np/{split}/{scan_id}/{mri_type}"
    cache_file = f"{cache_dir}/np_sa.npy"
    # Test for the file, not the directory: a run interrupted between
    # makedirs and save would otherwise leave a directory that makes every
    # later call fail in np.load.
    # NOTE: the cache key does not include num_imgs / img_size; clear
    # /tmp_np when changing them.
    if os.path.exists(cache_file):
        return np.load(cache_file)

    # Numeric-aware ordering so that e.g. "Image-10" does not sort before
    # "Image-2"; the path is the tie-breaker.
    files = sorted(
        glob.glob(f"{data_directory}/{split}/{scan_id}/{mri_type}/*.dcm"),
        key=lambda p: (_natural_sort_key(p), p),
    )

    slices = []
    for f in files:
        img = load_dicom_image(f, img_size)
        # load_dicom_image returns False for near-empty slices.
        if isinstance(img, np.ndarray):
            slices.append(img)
    img3d = np.stack(slices) if slices else np.zeros((1, img_size, img_size))
    img3d = image_aug(img3d, num_imgs)

    # Slices live on axis 0 at this point; zero-pad there if resampling
    # returned fewer than requested.
    missing = num_imgs - img3d.shape[0]
    if missing > 0:
        padding = np.zeros((missing, img_size, img_size))
        img3d = np.concatenate((img3d, padding), axis=0)

    # Min-max normalize over the whole volume.
    img3d = img3d - np.min(img3d)
    peak = np.max(img3d)
    if peak != 0:
        img3d = img3d / peak
    # (slice, rows, cols) -> (1, rows, cols, slice)
    img3d = np.expand_dims(img3d.transpose(1, 2, 0), 0)

    # Write to a per-process temp file and rename, so concurrent loader
    # workers never observe a partially written cache file.
    os.makedirs(cache_dir, exist_ok=True)
    tmp_file = f"{cache_dir}/np_sa.{os.getpid()}.tmp.npy"
    np.save(tmp_file, img3d)
    os.replace(tmp_file, cache_file)
    return img3d

# Smoke test: build (or load from cache) the volume for scan "00000" with the
# default arguments and show its shape as the cell output.
load_dicom_images_3d("00000").shape

In [None]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs with ``seed``.

    Also exports ``PYTHONHASHSEED`` and, when CUDA is available, seeds every
    GPU and switches cuDNN to deterministic mode.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


set_seed(42)

## train / test splits

In [None]:
# Load the label table from the data directory and show it
# (`display` is not imported here; presumably the notebook-provided helper).
train_df = pd.read_csv(f"{data_directory}/train_labels.csv")
display(train_df)

# Hold out 20% of the rows for validation, stratified on MGMT_value, with a
# fixed random_state so the split is repeatable.
df_train, df_valid = sk_model_selection.train_test_split(
    train_df, 
    test_size=0.2, 
    random_state=42, 
    stratify=train_df["MGMT_value"],
)


In [None]:
# Show the last rows of the training split as the cell output.
df_train.tail()

## Model and training classes

In [None]:
# class Dataset(torch_data.Dataset):
#     def __init__(self, paths, targets=None, mri_type=None, label_smoothing=0.01, split="train"):
#         self.paths = paths
#         self.targets = targets
#         self.mri_type = mri_type
#         self.label_smoothing = label_smoothing
#         self.split = split
          
#     def __len__(self):
#         return len(self.paths)
    
#     def __getitem__(self, index):
#         scan_id = self.paths[index]
#         if self.targets is None:
#             data = load_dicom_images_3d(str(scan_id).zfill(5), mri_type=self.mri_type[index], split=self.split)
#         else:
#             data = load_dicom_images_3d(str(scan_id).zfill(5), mri_type=self.mri_type[index], split="train")

#         if self.targets is None:
#             return {"X": torch.tensor(data).float(), "id": scan_id}
#         else:
#             y = torch.tensor(abs(self.targets[index]-self.label_smoothing), dtype=torch.float)
#             return {"X": torch.tensor(data).float(), "y": y}


In [None]:
class Dataset_LSTM(torch_data.Dataset):
    """Dataset yielding one scan as a sequence of flattened slices.

    Each item is a dict with ``"X"`` (float tensor, one row per slice) plus
    ``"y"`` (label-smoothed target) when targets were supplied, or ``"id"``
    (the scan id) when they were not.
    """

    def __init__(self, paths, targets=None, mri_type=None, label_smoothing=0.01, split="train"):
        self.paths = paths
        self.targets = targets
        self.mri_type = mri_type
        self.label_smoothing = label_smoothing
        self.split = split

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        scan_id = self.paths[index]
        has_labels = self.targets is not None
        # Labelled samples are always read from the "train" split folder.
        split = "train" if has_labels else self.split
        volume = load_dicom_images_3d(
            str(scan_id).zfill(5),
            mri_type=self.mri_type[index],
            split=split,
        )

        # Move the last axis to the front, drop size-1 axes, then flatten
        # everything after the first axis into one feature vector per row.
        frames = torch.tensor(volume).permute(3, 0, 1, 2).squeeze()
        frames = frames.reshape([frames.shape[0], -1])

        sample = {"X": frames.float()}
        if has_labels:
            # |target - smoothing| maps 0 -> smoothing and 1 -> 1 - smoothing.
            smoothed = abs(self.targets[index] - self.label_smoothing)
            sample["y"] = torch.tensor(smoothed, dtype=torch.float)
        else:
            sample["id"] = scan_id
        return sample

In [None]:
# class Model(nn.Module):
#     def __init__(self):
#         super().__init__()
#         #モデルを修正 efficientnet-b0→efficientnet-b3
#         #self.net = EfficientNet3D.from_name("efficientnet-b0", override_params={'num_classes': 2}, in_channels=1)
#         self.net = EfficientNet3D.from_name("efficientnet-b3", override_params={'num_classes': 2}, in_channels=1)
#         n_features = self.net._fc.in_features
#         self.net._fc = nn.Linear(in_features=n_features, out_features=1, bias=True)
    
#     def forward(self, x):
#         out = self.net(x)
#         return out

In [None]:
class SwishImplementation2(torch.autograd.Function):
    """Swish (``x * sigmoid(x)``) with a hand-written backward pass.

    Only the input is saved for backward; the sigmoid is recomputed there.
    """

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        gate = torch.sigmoid(saved_input)
        # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
        local_grad = gate * (1 + saved_input * (1 - gate))
        return grad_output * local_grad

class MemoryEfficientSwish2(nn.Module):
    """Module wrapper that applies ``SwishImplementation2`` to its input."""

    def forward(self, x):
        activated = SwishImplementation2.apply(x)
        return activated

In [None]:
class Model_LSTM(nn.Module):
    """GRU over the slice sequence followed by a four-branch MLP head.

    The GRU output is sampled at four fixed time steps (``_TAP_STEPS``); each
    sample goes through its own Dropout -> Linear -> BatchNorm -> ELU branch,
    the four results are concatenated and reduced to a single output value.

    Args:
        input_size: Number of features per time step (flattened slice).
        hidden_size: GRU hidden size and width of every branch.
        batch_first: Layout the GRU assumes for its input. ``True`` (default)
            means ``(batch, time, features)``, which is what ``forward``'s
            ``[:, step, :]`` indexing expects. ``False`` reproduces the
            previous behaviour, in which the GRU treated the batch axis as
            time; use it only to evaluate checkpoints trained that way.

    Attribute names and registration order are unchanged, so state dicts
    saved with the default sizes still load.
    """

    # Time steps of the GRU output feeding the four branches, in the order
    # the branch outputs are concatenated. Requires at least 46 time steps.
    _TAP_STEPS = (-1, 32, 45, 15)

    def __init__(self, input_size=256 * 256, hidden_size=1024, batch_first=True):
        super().__init__()

        self.xh = torch.nn.GRU(input_size, hidden_size, batch_first=batch_first)

        self.drop = nn.Dropout(0.1)
        self.hy = torch.nn.Linear(hidden_size, hidden_size)
        self.bat = nn.BatchNorm1d(hidden_size)
        self.swift = nn.ELU()

        self.drop1 = nn.Dropout(0.1)
        self.hy1 = torch.nn.Linear(hidden_size, hidden_size)
        self.bat1 = nn.BatchNorm1d(hidden_size)
        self.swift1 = nn.ELU()

        self.drop3 = nn.Dropout(0.1)
        self.hy3 = torch.nn.Linear(hidden_size, hidden_size)
        self.bat3 = nn.BatchNorm1d(hidden_size)
        self.swift3 = nn.ELU()

        self.drop4 = nn.Dropout(0.1)
        self.hy4 = torch.nn.Linear(hidden_size, hidden_size)
        self.bat4 = nn.BatchNorm1d(hidden_size)
        self.swift4 = nn.ELU()

        # Fusion of the four concatenated branch outputs.
        self.drop2 = nn.Dropout(0.1)
        self.bat2 = nn.BatchNorm1d(4 * hidden_size)
        self.hy2 = torch.nn.Linear(4 * hidden_size, hidden_size)
        self.elu = nn.ELU()

        # Output layer (one value per sample).
        self.drop5 = nn.Dropout(0.1)
        self.bat5 = nn.BatchNorm1d(hidden_size)
        self.hy5 = torch.nn.Linear(hidden_size, 1)
        self.elu5 = nn.ELU()

    def _init_weights(self, module):
        """Optional initializer (not called by ``__init__``).

        Linear layers get zero-mean normal weights with
        ``std = sqrt(2 / (fan_in + fan_out))`` and zero bias; LayerNorm
        layers get weight 1 and bias 0.
        """
        if isinstance(module, nn.Linear):
            self._std = np.sqrt(2/(module.in_features + module.out_features))
            print('self._std:',self._std)
            module.weight.data.normal_(mean=0.0, std=self._std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x):
        """Return one value per sample, shape ``(batch, 1)``.

        Args:
            x: Tensor laid out as ``(batch, time, input_size)``.
        """
        seq_out, _ = self.xh(x)

        branches = (
            (self.drop, self.hy, self.bat, self.swift),
            (self.drop1, self.hy1, self.bat1, self.swift1),
            (self.drop3, self.hy3, self.bat3, self.swift3),
            (self.drop4, self.hy4, self.bat4, self.swift4),
        )
        features = []
        for step, (drop, linear, norm, act) in zip(self._TAP_STEPS, branches):
            features.append(act(norm(linear(drop(seq_out[:, step, :])))))

        out = torch.cat(features, dim=1)

        out = self.drop2(out)
        out = self.bat2(out)
        out = self.hy2(out)
        out = self.elu(out)

        out = self.drop5(out)
        out = self.bat5(out)
        out = self.hy5(out)

        # NOTE(review): ELU is bounded below (at -1 with its default alpha),
        # so when this output is used as a logit the predicted probability
        # cannot drop below sigmoid(-1). Kept for checkpoint compatibility -
        # confirm whether this activation is intended.
        out = self.elu5(out)
        return out

In [None]:
class Trainer:
    """Training loop with one-cycle LR scheduling, best-loss checkpointing
    and patience-based early stopping.

    Args:
        model: Module to train; called as ``model(X)``.
        device: Device batches are moved to.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss.
    """

    def __init__(
        self,
        model,
        device,
        optimizer,
        criterion
    ):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion

        # Lowest validation loss seen so far.
        self.best_valid_score = np.inf
        # Epochs since the last improvement (early stopping).
        self.n_patience = 0
        # Epochs since the last improvement or scheduler restart.
        self.n_scheduler = 0
        # Path of the most recent checkpoint, if any.
        self.lastmodel = None

    def _one_cycle(self, epochs, steps_per_epoch, max_lr, div_factor):
        """Create a OneCycleLR schedule spanning ``epochs * steps_per_epoch`` steps."""
        return lr_scheduler.OneCycleLR(
            optimizer=self.optimizer,
            pct_start=0.3,
            div_factor=div_factor,
            max_lr=max_lr,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
        )

    def fit(self, epochs, train_loader, valid_loader, save_path, patience):
        """Train for up to ``epochs`` epochs.

        A checkpoint is written whenever the validation loss improves.
        After 3 epochs without improvement the LR schedule is restarted
        with a higher peak; after ``patience`` such epochs training stops.

        Args:
            epochs: Maximum number of epochs.
            train_loader: Iterable of batches with keys ``"X"`` and ``"y"``.
            valid_loader: Same format, used for validation.
            save_path: Prefix of the checkpoint file name.
            patience: Non-improving epochs tolerated before stopping.

        Returns:
            A list with one dict per completed epoch (``epoch``,
            ``train_loss``, ``valid_loss``, ``valid_auc``).
        """
        steps_per_epoch = len(train_loader)
        self.scheduler = self._one_cycle(
            epochs, steps_per_epoch, max_lr=0.0001, div_factor=100
        )
        history = []

        for n_epoch in range(1, epochs + 1):
            self.info_message("EPOCH: {}", n_epoch)

            train_loss, train_time = self.train_epoch(train_loader)
            valid_loss, valid_auc, valid_time = self.valid_epoch(valid_loader)
            history.append(
                {
                    "epoch": n_epoch,
                    "train_loss": train_loss,
                    "valid_loss": valid_loss,
                    "valid_auc": valid_auc,
                }
            )

            self.info_message(
                "[Epoch Train: {}] loss: {:.4f}, time: {:.2f} s            ",
                n_epoch, train_loss, train_time
            )

            self.info_message(
                "[Epoch Valid: {}] loss: {:.4f}, auc: {:.4f}, time: {:.2f} s",
                n_epoch, valid_loss, valid_auc, valid_time
            )

            if valid_loss < self.best_valid_score:
                previous_best = self.best_valid_score
                # Update before saving so the checkpoint records the score
                # of the weights it actually contains.
                self.best_valid_score = valid_loss
                self.save_model(n_epoch, save_path, valid_loss, valid_auc)
                self.info_message(
                    "loss improved from {:.4f} to {:.4f}. Saved model to '{}'",
                    previous_best, valid_loss, self.lastmodel
                )
                self.n_patience = 0
                self.n_scheduler = 0
            else:
                self.n_patience += 1
                self.n_scheduler += 1
                if self.n_scheduler >= 3:
                    # Plateau: restart the cycle with a 10x higher peak LR.
                    self.scheduler = self._one_cycle(
                        epochs, steps_per_epoch, max_lr=0.001, div_factor=1000
                    )
                    self.n_scheduler = 0

            if self.n_patience >= patience:
                self.info_message("\nValid loss didn't improve last {} epochs.", patience)
                break

        return history

    def train_epoch(self, train_loader):
        """Run one training epoch; return ``(mean_loss, elapsed_seconds)``."""
        self.model.train()
        t = time.time()
        sum_loss = 0

        for step, batch in enumerate(train_loader, 1):
            X = batch["X"].to(self.device)
            targets = batch["y"].to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(X).squeeze(1)

            loss = self.criterion(outputs, targets)
            loss.backward()

            sum_loss += loss.detach().item()

            self.optimizer.step()
            # The schedule is defined per batch, so step it here.
            self.scheduler.step()
            _lr = self.scheduler.get_last_lr()[0]
            message = 'Train Step {}/{}, train_loss: {:.4f}, get_lr: {:.7f}'
            self.info_message(message, step, len(train_loader), sum_loss/step,_lr, end="\r")

        return sum_loss/len(train_loader), int(time.time() - t)

    def valid_epoch(self, valid_loader):
        """Evaluate on ``valid_loader``; return ``(mean_loss, auc, elapsed_seconds)``."""
        self.model.eval()
        t = time.time()
        sum_loss = 0
        y_all = []
        outputs_all = []

        for step, batch in enumerate(valid_loader, 1):
            with torch.no_grad():
                X = batch["X"].to(self.device)
                targets = batch["y"].to(self.device)

                outputs = self.model(X).squeeze(1)
                loss = self.criterion(outputs, targets)

                sum_loss += loss.detach().item()
                y_all.extend(batch["y"].tolist())
                outputs_all.extend(outputs.tolist())

            message = 'Valid Step {}/{}, valid_loss: {:.4f}'
            self.info_message(message, step, len(valid_loader), sum_loss/step, end="\r")

        # Targets may be soft (e.g. label-smoothed); threshold them back to
        # hard 0/1 labels for the AUC computation.
        y_all = [1 if x > 0.5 else 0 for x in y_all]
        auc = roc_auc_score(y_all, outputs_all)

        return sum_loss/len(valid_loader), auc, int(time.time() - t)

    def save_model(self, n_epoch, save_path, loss, auc):
        """Write a checkpoint and drop the previous one.

        The new file is written first, so a failed save never leaves the
        trainer without a checkpoint on disk.
        """
        previous = self.lastmodel
        target = f"{save_path}-e{n_epoch}-loss{loss:.3f}-auc{auc:.3f}.pth"
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_valid_score": self.best_valid_score,
                "n_epoch": n_epoch,
            },
            target,
        )
        self.lastmodel = target

        if previous is not None and previous != target:
            try:
                os.remove(previous)
            except FileNotFoundError:
                # The old checkpoint was already moved or deleted; nothing to do.
                pass

    @staticmethod
    def info_message(message, *args, end="\n"):
        """Print ``message`` formatted with ``args``."""
        print(message.format(*args), end=end)

## train models

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_mri_type(df_train, df_valid, mri_type):
    """Train one model for ``mri_type`` and return its checkpoint path.

    Args:
        df_train: Training rows with ``BraTS21ID`` and ``MGMT_value`` columns.
        df_valid: Validation rows with the same columns.
        mri_type: Modality name, or ``"all"`` to stack one copy of the data
            per entry of ``mri_types``.

    Returns:
        Path of the last checkpoint written by the trainer (``None`` if no
        checkpoint was saved).

    The input frames are not modified; the ``MRI_Type`` column is added to
    copies.
    """
    if mri_type == "all":
        # A separate loop variable keeps `mri_type` == "all" for the
        # checkpoint prefix below.
        df_train = pd.concat([df_train.assign(MRI_Type=one_type) for one_type in mri_types])
        df_valid = pd.concat([df_valid.assign(MRI_Type=one_type) for one_type in mri_types])
    else:
        df_train = df_train.assign(MRI_Type=mri_type)
        df_valid = df_valid.assign(MRI_Type=mri_type)

    print(df_train.shape, df_valid.shape)
    display(df_train.head())

    train_data_retriever = Dataset_LSTM(
        df_train["BraTS21ID"].values,
        df_train["MGMT_value"].values,
        df_train["MRI_Type"].values
    )

    valid_data_retriever = Dataset_LSTM(
        df_valid["BraTS21ID"].values,
        df_valid["MGMT_value"].values,
        df_valid["MRI_Type"].values
    )

    train_loader = torch_data.DataLoader(
        train_data_retriever,
        batch_size=4,
        shuffle=True,
        num_workers=8,
    )

    valid_loader = torch_data.DataLoader(
        valid_data_retriever,
        batch_size=4,
        shuffle=False,
        num_workers=8,
    )

    model = Model_LSTM()
    model.to(device)

    # Learning rate lowered from 0.001 to 0.0001.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    criterion = torch_functional.binary_cross_entropy_with_logits

    trainer = Trainer(
        model,
        device,
        optimizer,
        criterion
    )

    # 15 epochs at most, checkpoint prefix = modality name, patience 10.
    trainer.fit(
        15,
        train_loader,
        valid_loader,
        f"{mri_type}",
        10,
    )

    # Remove the on-disk volume cache (added together with the .npy
    # caching); a missing directory is not an error.
    shutil.rmtree("/tmp_np/", ignore_errors=True)

    return trainer.lastmodel

# Set this to a list of checkpoint paths to skip training; None (falsy)
# triggers training below.
modelfiles = None

if not modelfiles:
    # Train one model per entry of mri_types and collect the checkpoint paths.
    modelfiles = [train_mri_type(df_train, df_valid, m) for m in mri_types]
    print(modelfiles)

## Predict function

In [None]:
# def predict(modelfile, df, mri_type, split):
#     print("Predict:", modelfile, mri_type, df.shape)
#     df.loc[:,"MRI_Type"] = mri_type
#     data_retriever = Dataset_LSTM(
#         df.index.values, 
#         mri_type=df["MRI_Type"].values,
#         split=split
#     )

#     data_loader = torch_data.DataLoader(
#         data_retriever,
#         batch_size=4,
#         shuffle=False,
#         num_workers=8,
#     )
   
#     model = Model_LSTM()
#     model.to(device)
    
#     checkpoint = torch.load(modelfile)
#     model.load_state_dict(checkpoint["model_state_dict"])
#     model.eval()
    
#     y_pred = []
#     ids = []

#     for e, batch in enumerate(data_loader,1):
#         print(f"{e}/{len(data_loader)}", end="\r")
#         with torch.no_grad():
#             tmp_pred = torch.sigmoid(model(batch["X"].to(device))).cpu().numpy().squeeze()
#             #tmp_pred = model(batch["X"].to(device)).cpu().numpy().astype(np.float64).squeeze()
#             #tmp_pred = np.where(tmp_pred>1.0,1.0,np.where(tmp_pred<0.0,0.0,tmp_pred))
#             #print("tmp_pred:",tmp_pred)
#             if tmp_pred.size == 1:
#                 y_pred.append(tmp_pred)
#             else:
#                 y_pred.extend(tmp_pred.tolist())
#             ids.extend(batch["id"].numpy().tolist())
            
#     preddf = pd.DataFrame({"BraTS21ID": ids, "MGMT_value": y_pred}) 
#     preddf = preddf.set_index("BraTS21ID")
#     return preddf

## Ensemble for validation

In [None]:
# modelfiles = ["../input/gru-net-training-data/FLAIR-e12-loss0.654-auc0.692.pth"
#              ,"../input/gru-net-training-data/T1w-e9-loss0.685-auc0.578.pth"
#              ,"../input/gru-net-training-data/T1wCE-e15-loss0.679-auc0.606.pth"
#              ,"../input/gru-net-training-data/T2w-e6-loss0.657-auc0.698.pth"]

In [None]:
# df_valid = df_valid.set_index("BraTS21ID")
# df_valid["MGMT_pred"] = 0
# for m, mtype in zip(modelfiles,  mri_types):
#     pred = predict(m, df_valid, mtype, "train")
#     df_valid["MGMT_pred"] += pred["MGMT_value"]
# df_valid["MGMT_pred"] /= len(modelfiles)
# auc = roc_auc_score(df_valid["MGMT_value"], df_valid["MGMT_pred"])
# print(f"Validation ensemble AUC: {auc:.4f}")
# sns.displot(df_valid["MGMT_pred"])

In [None]:
# df_valid = df_valid.set_index("BraTS21ID")
# df_valid["MGMT_pred"] = 0
# for m, mtype in zip(modelfiles,  mri_types):
#     pred = predict(m, df_valid, mtype, "train")
#     df_valid["MGMT_pred" + mtype] = pred["MGMT_value"].astype(np.float64)
# # df_valid["MGMT_pred"] /= len(modelfiles)
# for mtype in mri_types:
#     auc = roc_auc_score(df_valid["MGMT_value"], df_valid["MGMT_pred" + mtype])
#     print(f"Validation ensemble AUC: {auc:.4f}")
#     sns.displot(df_valid["MGMT_pred" + mtype].astype(np.float64))

## Ensemble for submission

In [None]:
# submission = pd.read_csv(f"{data_directory}/sample_submission.csv", index_col="BraTS21ID")

# submission["MGMT_value"] = 0
# for m, mtype in zip(modelfiles, mri_types):
#     pred = predict(m, submission, mtype, split="test")
#     submission["MGMT_value"] += pred["MGMT_value"]

# submission["MGMT_value"] /= len(modelfiles)
# submission["MGMT_value"].to_csv("submission.csv")

In [None]:
# submission = pd.read_csv(f"{data_directory}/sample_submission.csv", index_col="BraTS21ID")

# submission["MGMT_value"] = 0
# for m, mtype in zip(modelfiles, mri_types):
#     pred = predict(m, submission, mtype, split="test")
#     submission["MGMT_value"+mtype] = pred["MGMT_value"]

# submission["MGMT_value"] += submission["MGMT_valueFLAIR"] * 0.5
# submission["MGMT_value"] += submission["MGMT_valueT1w"] * 0.00
# submission["MGMT_value"] += submission["MGMT_valueT1wCE"] * 0.00
# submission["MGMT_value"] += submission["MGMT_valueT2w"] * 0.5

# submission["MGMT_value"].to_csv("submission.csv")

In [None]:
# submission

In [None]:
# sns.displot(submission["MGMT_value"])