In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import h5py
import glob
from BraTSdataset import GBMset, GBMValidset, GBMValidset2
import SimpleITK as sitk
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from itertools import chain, combinations

from model import RobustMseg
from transform import transforms
from evaluation import eval_overlap_save, eval_overlap, eval_overlap_recon
from utils import seed_everything

# Restrict this kernel to one GPU. Both variables are set in a single update
# so the ordering policy and the selected device id are always changed together.
# With PCI_BUS_ID ordering, the id "2" refers to the PCI bus enumeration.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "2",
})

In [None]:
# Indices of the input modalities; every non-empty combination is evaluated.
MODALITIES = [0, 1, 2, 3]


def all_subsets(l, num_modalities=4):
    """Enumerate every non-empty subset of ``l`` as a boolean mask.

    Subsets are generated in order of increasing size (all singletons, then
    all pairs, ...), each size in the order produced by
    ``itertools.combinations``.

    Args:
        l: Sequence of modality indices to combine.
        num_modalities: Number of columns in each mask row. Column ``k`` is
            True when index ``k`` belongs to the subset. Defaults to 4, the
            width that was previously hard-coded.

    Returns:
        Boolean ``np.ndarray`` of shape ``(2**len(l) - 1, num_modalities)``.
        An empty ``l`` gives a ``(0, num_modalities)`` array rather than a
        shapeless float array.
    """
    # The empty set is intentionally excluded: sizes start at 1.
    subsets = [
        combo
        for size in range(1, len(l) + 1)
        for combo in combinations(l, size)
    ]
    rows = [[k in combo for k in range(num_modalities)] for combo in subsets]
    # Explicit dtype/shape keeps the result well-formed even with zero rows.
    return np.array(rows, dtype=bool).reshape(len(subsets), num_modalities)


SUBSETS_MODALITIES = all_subsets(MODALITIES)

In [None]:
%%time
# Number of cases the split is computed over.
pat_num = 285

# Placeholder feature/target arrays: only the index split is actually used
# below, but the zero arrays are passed through train_test_split alongside
# the indices so all three are partitioned consistently.
x_p = np.zeros(pat_num)
y_p = np.zeros(pat_num)  # target value
indices = np.arange(pat_num)

# Fixed random_state keeps the 80/20 partition reproducible between runs.
(
    x_train_p, x_valid_p,
    y_train_p, y_valid_p,
    idx_train, idx_valid,
) = train_test_split(x_p, y_p, indices, test_size=0.2, random_state=20)

# Validation loader: one case per batch, in sorted index order, no shuffling.
ov_validset = GBMset(sorted(idx_valid), transform=transforms(), lazy=True)
ov_validloader = DataLoader(
    ov_validset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Build the network and restore the trained weights for evaluation.
model = RobustMseg()
model_name = 'dice_norm2_real_missing_adain'
epoch = '360'

# Deserialize onto the CPU first: without map_location, torch.load tries to
# restore tensors onto the CUDA device they were saved from, which fails when
# that ordinal is not visible in this process (only one GPU is exposed here).
checkpoint_path = f'{model_name}/{epoch}.pth'
state_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict)

# Weights are loaded into the bare module before wrapping it in DataParallel.
model = nn.DataParallel(model)
model.eval()
model.cuda()

In [None]:
%%time
# 112 - 1 : 12 /  draw 10 : 
# NOTE(review): the line above is the original author's scratch note; its
# meaning is not recoverable from this notebook - confirm or remove.
seed = 20
seed_everything(seed)
crop_size = 112
valid_batch = 10

# (label, column width) for each modality flag in a subset row; an absent
# modality is rendered as blanks of the same width so the table stays aligned.
modality_columns = (('T1c', 4), ('T1', 3), ('T2', 3), ('FLAIR', 6))

tot_eval = np.zeros((2, 3))  # dice hd95 - wt tc et
n_evaluated = 0  # number of subsets actually accumulated into tot_eval
for idx, subset in enumerate(SUBSETS_MODALITIES):
    result_text = ''.join(
        (name if present else '').ljust(width)
        for (name, width), present in zip(modality_columns, subset)
    ) + '|'

    # Sliding-window evaluation with 50% patch overlap; the subset is
    # identified to eval_overlap by its index in SUBSETS_MODALITIES.
    va_eval = eval_overlap(ov_validloader, model, idx, draw=None, patch_size=crop_size, overlap_stepsize=crop_size//2, batch_size=valid_batch,
                           num_classes=4, verbose=False, save=False, dir_name=f'{model_name}_{epoch}')

    tot_eval += va_eval
    n_evaluated += 1
    print(f'{result_text} {va_eval[0][0]*100:.2f} {va_eval[0][1]*100:.2f} {va_eval[0][2]*100:.2f} {va_eval[1][0]:.2f} {va_eval[1][1]:.2f} {va_eval[1][2]:.2f}')

# Average over the subsets that were really evaluated instead of a fixed 15,
# so the figure stays correct if the modality list or the loop is changed.
mean_eval = tot_eval / max(n_evaluated, 1)
print(f'{"Average":16s}| {mean_eval[0][0]*100:.2f} {mean_eval[0][1]*100:.2f} {mean_eval[0][2]*100:.2f} {mean_eval[1][0]:.2f} {mean_eval[1][1]:.2f} {mean_eval[1][2]:.2f}')