- requirement

In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import os

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score
import glob

In [2]:
import albumentations as A
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import roc_auc_score
from torch.optim import lr_scheduler
from tqdm import tqdm
import re
import cv2

In [3]:
import random
from torch.utils.data import Dataset

In [4]:
from pydicom.pixel_data_handlers.util import apply_modality_lut, apply_voi_lut

- 하이퍼 파라미터

In [10]:
# Number of slices kept per volume; default depth used by
# BrainRSNADataset.load_dicom_images_3d (shorter stacks are zero-padded to it).
NUM_IMAGES_3D = 64
# Not referenced in the visible cells — presumably the training batch size.
TRAINING_BATCH_SIZE = 4
# Batch size of the test DataLoader.
TEST_BATCH_SIZE = 4
# Side length (pixels) every slice is resized to by load_dicom_image.
IMAGE_SIZE = 256
# Not referenced in the visible cells — presumably the number of training epochs.
N_EPOCHS = 10
# Not referenced in the visible cells — presumably toggles validation during training.
do_valid = True
# Passed to the DataLoader as num_workers.
n_workers = 0
# MRI sequence name: selects the per-case sub-folder and filters checkpoint files.
type_ = "T1wCE"

- load dicom

In [11]:
def load_dicom_image(path, img_size=IMAGE_SIZE, voi_lut=True, rotate=0):
    """Read one DICOM slice and return it as a square 2D array.

    Args:
        path: Path of the DICOM file to read.
        img_size: Side length, in pixels, of the resized output.
        voi_lut: When true, the pixel data is passed through
            ``apply_voi_lut``; otherwise the raw ``pixel_array`` is used.
        rotate: Index into the rotation table below (1, 2 or 3). Values
            ``<= 0`` leave the slice unrotated.

    Returns:
        The slice after optional rotation, resized with ``cv2.resize`` to
        ``(img_size, img_size)``.

    Raises:
        IndexError: If ``rotate`` is greater than 3.
    """
    # dcmread is the supported reader; read_file was a deprecated alias and
    # is no longer available in pydicom 3.0.
    dicom = pydicom.dcmread(path)
    pixels = dicom.pixel_array

    # Optional VOI LUT / windowing.
    data = apply_voi_lut(pixels, dicom) if voi_lut else pixels

    # Optional rotation. Slot 0 is only a placeholder: it is never selected
    # because rotation is applied for rotate > 0 only.
    if rotate > 0:
        rot_choices = [
            0,
            cv2.ROTATE_90_CLOCKWISE,
            cv2.ROTATE_90_COUNTERCLOCKWISE,
            cv2.ROTATE_180,
        ]
        data = cv2.rotate(data, rot_choices[rotate])

    return cv2.resize(data, (img_size, img_size))

- BrainRSNADataset

In [17]:
class BrainRSNADataset(Dataset):
    """Dataset that yields one min-max normalised 3D MRI volume per table row.

    Args:
        data: Table with a ``BraTS21ID`` column and, in training mode, the
            ``target`` column. Rows are fetched with ``data.loc[index]``, so
            the index labels must match the keys that are requested.
        transform: Stored on the instance; no method of this class applies it.
        target: Name of the label column read in training mode.
        mri_type: Name of the per-case sub-folder holding the ``.dcm`` files.
        is_train: Selects the ``train`` or ``test`` data folder and the keys
            of the dictionaries returned by ``__getitem__``.
    """

    def __init__(
        self, data, transform=None, target="MGMT_value", mri_type="FLAIR", is_train=True
    ):
        self.target = target
        self.data = data
        self.type = mri_type

        self.transform = transform
        self.is_train = is_train
        self.folder = "train" if self.is_train else "test"

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.data)

    def __getitem__(self, index):
        """Load the volume for the row labelled ``index``.

        Returns:
            ``{"image": tensor, "target": int}`` in training mode, otherwise
            ``{"image": tensor, "case_id": int}``. The image is a float
            tensor built from the array of ``load_dicom_images_3d``.
        """
        row = self.data.loc[index]
        case_id = int(row.BraTS21ID)
        volume = torch.tensor(self.load_dicom_images_3d(case_id)).float()
        if self.is_train:
            # The label is only read where it is returned, so test tables
            # without a usable target column still work.
            return {"image": volume, "target": int(row[self.target])}
        return {"image": volume, "case_id": case_id}

    def load_dicom_images_3d(
        self,
        case_id,
        num_imgs=NUM_IMAGES_3D,
        img_size=IMAGE_SIZE,
        rotate=0,
    ):
        """Build a normalised 3D volume from the central slices of one case.

        Args:
            case_id: Case identifier; zero-padded to five digits to form the
                folder name.
            num_imgs: Depth of the returned volume. Up to ``num_imgs // 2``
                slices are taken on each side of the middle file.
            img_size: Side length each slice is resized to.
            rotate: Rotation index forwarded to ``load_dicom_image``.

        Returns:
            Array of shape ``(1, img_size, img_size, num_imgs)``, scaled to
            ``[0, 1]`` unless all values are equal.

        Raises:
            ValueError: If no ``.dcm`` file matches the case folder.
        """
        case_id = str(case_id).zfill(5)

        path = f"./data/{self.folder}/{case_id}/{self.type}/*.dcm"
        # Natural sort: digit runs compare as integers so that
        # "Image-10" comes after "Image-9".
        files = sorted(
            glob.glob(path),
            key=lambda var: [
                int(x) if x.isdigit() else x for x in re.findall(r"[^0-9]|[0-9]+", var)
            ],
        )
        if not files:
            # Without this guard np.stack fails with an error that does not
            # say which case is missing.
            raise ValueError(f"No DICOM files found for pattern {path!r}")

        middle = len(files) // 2
        half = num_imgs // 2
        first = max(0, middle - half)
        last = min(len(files), middle + half)
        # img_size is forwarded so the slices always match the padding below.
        image_stack = [
            load_dicom_image(f, img_size=img_size, rotate=rotate)
            for f in files[first:last]
        ]

        # (slices, rows, cols) -> transposed so that the slice axis is last.
        img3d = np.stack(image_stack).T
        depth = img3d.shape[-1]
        if depth < num_imgs:
            # Pad short series with zero slices at the end.
            padding = np.zeros((img_size, img_size, num_imgs - depth))
            img3d = np.concatenate((img3d, padding), axis=-1)

        # Min-max scale; skipped for a constant volume to avoid dividing by 0.
        if np.min(img3d) < np.max(img3d):
            img3d = img3d - np.min(img3d)
            img3d = img3d / np.max(img3d)

        # Add the leading channel axis.
        return np.expand_dims(img3d, 0)

In [18]:
import monai

# Model: 3D ResNet-10 with one input channel and one output logit.
model = monai.networks.nets.resnet10(spatial_dims=3, n_input_channels=1, n_classes=1)
# Use the GPU when one is present, otherwise fall back to the CPU instead of
# failing when the model is moved to the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device);
# os.listdir returns entries in arbitrary order; sorting makes the position
# of every checkpoint in fold_files reproducible.
all_weights = sorted(os.listdir("./data/resnet10-rsna"))
# Keep only the checkpoints whose file name contains the selected MRI type.
fold_files = [f for f in all_weights if type_ in f]
criterion = nn.BCEWithLogitsLoss()

In [19]:
# Display the checkpoint files selected for the current MRI type.
fold_files

['3d-resnet10_T1wCE_fold0_0.565.pth',
 '3d-resnet10_T1wCE_fold1_0.573.pth',
 '3d-resnet10_T1wCE_fold2_0.538.pth',
 '3d-resnet10_T1wCE_fold3_0.664.pth',
 '3d-resnet10_T1wCE_fold4_0.551.pth']

In [20]:
# Table of test cases; it is passed to the test dataset and later overwritten
# with the predictions.
sample = pd.read_csv("./data/sample_submission.csv")

In [21]:
tta_true_labels = []
tta_preds = []
test_dataset = BrainRSNADataset(data=sample, mri_type=type_, is_train=False)
# shuffle=False keeps the predictions aligned with the row order of `sample`.
test_dl = torch.utils.data.DataLoader(
    test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=n_workers
)

if not fold_files:
    raise FileNotFoundError(
        f"No checkpoint matching {type_!r} found in ./data/resnet10-rsna"
    )

# Average over however many fold checkpoints were found instead of assuming 5.
n_folds = len(fold_files)
preds_f = np.zeros(len(sample))
# Inference mode only needs to be set once; loading weights does not reset it.
model.eval()
for weight_file in fold_files:
    image_ids = []
    preds = []
    # map_location lets checkpoints saved on a GPU load on the active device.
    model.load_state_dict(
        torch.load(f"./data/resnet10-rsna/{weight_file}", map_location=device)
    )
    with torch.no_grad():
        for batch in tqdm(test_dl):
            images = batch["image"].to(device)

            outputs = model(images)
            # Sigmoid turns the single logit into a probability.
            preds.append(outputs.sigmoid().cpu().numpy())
            image_ids.append(batch["case_id"].cpu().numpy())

    # Stack the (batch, 1) outputs and add this fold's share of the mean.
    preds_f += np.vstack(preds)[:, 0] / n_folds

    # Case ids in loader order; identical for every fold.
    ids_f = np.hstack(image_ids)

100%|██████████| 22/22 [00:39<00:00,  1.78s/it]
100%|██████████| 22/22 [00:15<00:00,  1.43it/s]
100%|██████████| 22/22 [00:15<00:00,  1.46it/s]
100%|██████████| 22/22 [00:18<00:00,  1.19it/s]
100%|██████████| 22/22 [00:15<00:00,  1.46it/s]


In [22]:
# Overwrite the submission columns: case ids as returned by the loader and
# the probabilities averaged over the fold checkpoints.
sample["BraTS21ID"] = ids_f
sample["MGMT_value"] = preds_f

In [23]:
# Order rows by case id and rebuild a clean 0..n-1 index.
sample = sample.sort_values(by="BraTS21ID").reset_index(drop=True)

In [24]:
# Write the submission file without the DataFrame index column.
sample.to_csv("submission.csv", index=False)

In [25]:
# Display the final submission table.
sample

Unnamed: 0,BraTS21ID,MGMT_value
0,1,0.543852
1,13,0.551282
2,15,0.558330
3,27,0.552038
4,37,0.546800
...,...,...
82,826,0.457712
83,829,0.430619
84,833,0.502583
85,997,0.525578
