In [1]:
from einops.layers.torch import Rearrange
import argparse
import torch
from dataset import make_data_loaders
from model import MMCFormer

from eval_utils import *


In [None]:
use_cuda = torch.cuda.is_available()
parser = argparse.ArgumentParser(description='MMCFormer')

# --- experiment / data locations ---
parser.add_argument('--task_name',type=str, default='MMCFormer', 
                    help='task name')
# NOTE(review): this default is a directory, but load_model() hands the value straight
# to torch.load, which needs a checkpoint *file* — point it at the saved .pth/.pt file.
parser.add_argument('--saved_model_path', type=str, default='./results/', 
                    help='Pre-trained model path')
parser.add_argument('--path_to_data', type=str, default='../../brats/MICCAI_BraTS_2018_Data_Training/', 
                    help='path to dataset')
parser.add_argument('--modalities', type=str, nargs='*', default=['t1ce', 't1', 'flair', 't2'], 
                    help='List of modalities needed to be used for training and evaluating the model')
parser.add_argument('--n_missing_modalities', type=int, default=1, 
                    help='number of modalities for the missing path. Sort [flair, t1, t1ce, t2] based on your desired modalities for the missing path')

# --- data loading / splitting ---
parser.add_argument('--number_classes', type=int, default=4, 
                    help='number of classes in the target dataset')
parser.add_argument('--batch_size_tr', type=int, default=1, 
                    help='batch size for train')
parser.add_argument('--batch_size_va', type=int, default=1, 
                    help='batch size for validation')
parser.add_argument('--test_p', type=float, default=0.2, 
                    # '%%' is required: argparse %-formats help strings, so a bare '%'
                    # makes --help / format_help() raise ValueError.
                    help='test percentage (20%%)')
parser.add_argument('--progress_p', type=float, default=0.1, 
                    help='value between 0-1 shows the number of time we need to report training progress in each epoch')
parser.add_argument('--validation_p', type=float, default=0.1, 
                    help='validation percentage')
# Three integers so a CLI value parses into a list like the default, not a single string.
parser.add_argument('--inputshape', type=int, nargs=3, default=[160, 192, 128], 
                    help='input shape')

# --- model input channels ---
parser.add_argument('--missing_in_chans', type=int, default=1, 
                    help='missing modality input channels')
parser.add_argument('--full_in_chans', type=int, default=4, 
                    help='full modality input channels')

# Parse an empty argv so only the defaults above are used; without args=[] argparse
# would read sys.argv, which inside a notebook kernel holds the kernel's own arguments.
args = parser.parse_args(args=[])

In [None]:
# load data
# The evaluation fixes its own modality order, overriding the --modalities default.
# NOTE(review): presumably make_data_loaders stacks channels in this order, so the
# first `n_missing_modalities` entries ('t1ce' by default) feed the missing path —
# confirm in dataset.make_data_loaders.
eval_modalities = ['t1ce', 't2', 't1', 'flair']
args.modalities = eval_modalities
loaders = make_data_loaders(args)

In [None]:
def build_model(inp_shape , num_classes, full_in_chans, missing_in_chans, device='cuda'):
    """Construct the full-modality and missing-modality MMCFormer networks.

    Both networks share the same image size, class count, ``head_count=1`` and
    ``token_mlp_mode="mix_skip"``; they differ only in ``model_mode`` and input channels.

    Args:
        inp_shape: image size forwarded to MMCFormer as ``img_size``.
        num_classes: number of output classes.
        full_in_chans: input channels of the full-modality network.
        missing_in_chans: input channels of the missing-modality network.
        device: device both networks are moved to. Defaults to ``'cuda'``, which
            matches the previous hard-coded ``.cuda()``; pass ``'cpu'`` to run
            without a GPU.

    Returns:
        Tuple ``(model_full, model_missing)``.
    """
    shared_kwargs = dict(img_size=inp_shape, num_classes=num_classes,
                         head_count=1, token_mlp_mode="mix_skip")
    model_full = MMCFormer(model_mode='full', in_chans=full_in_chans, **shared_kwargs).to(device)
    model_missing = MMCFormer(model_mode='missing', in_chans=missing_in_chans, **shared_kwargs).to(device)

    return model_full, model_missing

In [None]:
def load_model(model_full, model_missing, saved_model_path, map_location=None):
    """Restore both networks' weights from one checkpoint file.

    Args:
        model_full: module whose state dict is stored under the ``"model_full"`` key.
        model_missing: module whose state dict is stored under the ``"model_missing"`` key.
        saved_model_path: path to the checkpoint file written with ``torch.save``.
        map_location: forwarded to ``torch.load``. ``None`` (default) keeps the previous
            behaviour; e.g. ``'cpu'`` restores a checkpoint saved from GPU tensors on a
            machine without CUDA.

    Returns:
        The same two module objects, loaded in place, as ``(model_full, model_missing)``.

    Raises:
        KeyError: if the checkpoint lacks either expected key.
        RuntimeError: from ``load_state_dict`` if a state dict does not match its module.
    """
    print("Constructing model from saved file... ")
    # NOTE: torch.load unpickles the file (older torch defaults to weights_only=False);
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(saved_model_path, map_location=map_location)
    model_full.load_state_dict(checkpoint["model_full"])
    model_missing.load_state_dict(checkpoint["model_missing"])

    return model_full, model_missing

In [None]:
# Load Model: build both networks from the parsed config, restore their weights from
# the checkpoint, then put the missing-modality network into eval mode for inference.
model_full, model_missing = build_model(
    inp_shape=args.inputshape,
    num_classes=args.number_classes,
    full_in_chans=args.full_in_chans,
    missing_in_chans=args.missing_in_chans,
)
model_full, model_missing = load_model(model_full, model_missing, args.saved_model_path)
model_missing.eval()

In [None]:
# Evaluate the missing-modality network and report Dice averaged over batches.
# NOTE(review): thresholds and the wt_j/ct_j/et_j indices are interpreted by
# eval_utils.measure_dice_score — confirm they match the training configuration.
DICE_THRESH = [0.48, 0.42, 0.31]
# Built once and reused for every batch: moves the axis named `d` in front of `h, w`.
to_dhw = Rearrange('b c h w d -> b c d h w')

for phase in ['eval']:
    loader = loaders[phase]
    total = len(loader)

    # Reset per phase so scores from different phases are never mixed.
    # Despite the `loss` names, these accumulate Dice scores returned by measure_dice_score.
    val_scores_miss = 0
    val_loss_wt = 0
    val_loss_et = 0
    val_loss_ct = 0
    n_batches = 0

    for n_batches, (batch_x, batch_y) in enumerate(loader, start=1):
        batch_x = to_dhw(batch_x.cuda(non_blocking=True))
        batch_y = to_dhw(batch_y.cuda(non_blocking=True))

        with torch.no_grad():
            # Only the first `n_missing_modalities` channels go to the missing-modality path.
            output_missing = model_missing(batch_x[:, 0: args.n_missing_modalities])

        val_sc_miss, val_wt_miss, val_et_miss, val_ct_miss = measure_dice_score(output_missing, batch_y,
                                                                                thresh=DICE_THRESH,
                                                                                wt_j=3, ct_j=2, et_j=None)

        val_scores_miss += val_sc_miss
        val_loss_wt += val_wt_miss
        val_loss_et += val_et_miss
        val_loss_ct += val_ct_miss

    # Previously an empty loader surfaced as a NameError on the print below.
    if n_batches == 0:
        raise RuntimeError(f"loader '{phase}' yielded no batches; nothing to evaluate")

    # Mean over batches, computed once after the loop instead of on every iteration.
    dice_missing_1 = val_scores_miss / n_batches
    dice_wt_1 = val_loss_wt / n_batches
    dice_et_1 = val_loss_et / n_batches
    dice_ct_1 = val_loss_ct / n_batches

    print(f'### Val DSC missing: {dice_missing_1}, WT: {dice_wt_1}, CT: {dice_ct_1}, ET: {dice_et_1}')