In [None]:
#!/usr/bin/env python3
# encoding: utf-8
import yaml
from data import make_data_loaders
from models import build_model
from models.discriminator import get_style_discriminator
from solver import make_optimizer_double
from losses import get_losses, bce_loss, DiceLoss
import os
import torch
import torch.optim as optim
import torch.nn.functional as F
from utils import init_env
import nibabel as nib
import numpy as np

In [None]:
def load_old_model(model_full, model_missing, saved_model_path, map_location=None):
    """Restore both segmentation networks from a single checkpoint file.

    The checkpoint is read with ``torch.load`` and is expected to hold the two
    state dicts under the keys ``"model_full"`` and ``"model_missing"``.

    Args:
        model_full: module that receives ``checkpoint["model_full"]``.
        model_missing: module that receives ``checkpoint["model_missing"]``.
        saved_model_path: path of the file written with ``torch.save``.
        map_location: forwarded to ``torch.load``. Pass e.g. ``"cpu"`` or a
            ``torch.device`` to load a checkpoint saved on a different device.
            ``None`` (the default) keeps the previous behaviour.

    Returns:
        ``(model_full, model_missing)``: the same module objects, loaded in place.

    Raises:
        KeyError: if the checkpoint lacks either expected key.
    """
    print("Constructing model from saved file... ")
    checkpoint = torch.load(saved_model_path, map_location=map_location)
    model_full.load_state_dict(checkpoint["model_full"])
    model_missing.load_state_dict(checkpoint["model_missing"])

    return model_full, model_missing
        
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array, passing plain Python numbers through.

    ``int`` and ``float`` inputs (``bool`` too, being an ``int`` subclass) are
    returned unchanged. Anything else is assumed to be a torch tensor: it is
    detached via ``.data``, moved to the CPU and converted with ``.numpy()``.
    """
    is_python_number = isinstance(tensor, (int, float))
    if is_python_number:
        return tensor
    host_tensor = tensor.data.cpu()
    return host_tensor.numpy()

def eval_metrics(gt, pred, criterion=None):
    """Score a predicted label map against the ground truth for three tumour regions.

    Each region is binarised (1 inside, 0 outside) from the integer label maps
    and passed to ``criterion(gt_region, pred_region)``:

    * whole tumour  -- every label > 0
    * tumour core   -- labels 1 and 4
    * enhancing     -- label 4 only

    Args:
        gt: integer label array (0/1/2/4, as produced by ``get_mask``).
        pred: integer label array of the same shape as ``gt``.
        criterion: callable taking ``(gt_binary, pred_binary)``. Defaults to
            the module-level ``criteria`` (the ``DiceLoss`` built in the main
            section), so existing two-argument calls behave as before.

    Returns:
        Tuple ``(whole, core, enhancing)`` of whatever ``criterion`` returns.
    """
    if criterion is None:
        # Resolved lazily: ``criteria`` is defined in a later notebook cell.
        criterion = criteria

    loss_wt = criterion(np.where(gt > 0, 1, 0), np.where(pred > 0, 1, 0))
    # Core = labels 1 and 4 (the original local name ``loss_ed`` was misleading).
    loss_tc = criterion(np.where(np.isin(gt, (1, 4)), 1, 0),
                        np.where(np.isin(pred, (1, 4)), 1, 0))
    loss_et = criterion(np.where(gt == 4, 1, 0), np.where(pred == 4, 1, 0))
    return (loss_wt, loss_tc, loss_et)
        
        
def evaluate_sample(batch_pred_full, batch_pred_missing, batch_y, threshold=0.4):
    """Turn network outputs and ground truth into label maps and score both models.

    Args:
        batch_pred_full: output tensor of the full-modality model.
        batch_pred_missing: output tensor of the missing-modality model.
        batch_y: ground-truth tensor, thresholded exactly like the predictions.
        threshold: cut-off applied to each region channel. It was previously
            hard-coded as 0.4, which remains the default.

    Returns:
        ``(metric_full, metric_miss)``: each is the ``(whole, core, enhancing)``
        tuple produced by ``eval_metrics`` against the ground-truth labels.
    """
    def get_mask(seg_volume):
        # Channels 0/1/2 are read as whole tumour / tumour core / enhancing.
        # NOTE(review): np.squeeze assumes a batch of size 1. With a larger
        # batch, indices 0..2 would pick samples, not channels. TODO confirm
        # that the eval loader uses batch_size=1.
        seg_volume = np.squeeze(seg_volume.cpu().numpy())
        wt_pred, tc_pred, et_pred = seg_volume[0], seg_volume[1], seg_volume[2]
        mask = np.zeros_like(wt_pred)
        # Later writes overwrite earlier ones, so the inner regions win:
        # whole tumour -> 2, core -> 1, enhancing -> 4.
        mask[wt_pred >= threshold] = 2
        mask[tc_pred >= threshold] = 1
        mask[et_pred >= threshold] = 4
        return mask.astype("uint8")

    pred_nii_full = get_mask(batch_pred_full)
    pred_nii_miss = get_mask(batch_pred_missing)
    gt_nii = get_mask(batch_y)

    metric_full = eval_metrics(gt_nii, pred_nii_full)
    metric_miss = eval_metrics(gt_nii, pred_nii_miss)

    return metric_full, metric_miss


In [None]:
## Main section
# Build everything the evaluation below needs: the config, the data loaders,
# both segmentation networks on the GPU, the style discriminator, the log
# directory and the metric.
with open('./config.yml') as config_file:
    # Closed deterministically; FullLoader is kept for compatibility with the
    # existing config file.
    config = yaml.load(config_file, Loader=yaml.FullLoader)
init_env('0')
loaders = make_data_loaders(config)
model_full, model_missing = build_model(inp_dim1 = 4, inp_dim2 = 1)
model_full    = model_full.cuda()
model_missing = model_missing.cuda()
# NOTE(review): d_style is not used by the evaluation code visible here.
d_style = get_style_discriminator(num_classes = 128).cuda()
task_name = 'brats2018_flair'
log_dir = os.path.join(config['path_to_log'], task_name)
criteria = DiceLoss()

In [None]:
def evaluate_performance(model_full, model_missing, loaders):
    """Average the per-region metrics of both models over ``loaders['eval']``.

    The full-modality model receives every input channel. The missing-modality
    model receives only channel 0. The results are printed and also returned
    (the original returned ``None``, so existing callers are unaffected).

    Side effect: both models are left in ``eval()`` mode.

    Returns:
        ``(class_score_full, class_score_mono)``: arrays of shape ``(3,)``
        holding the mean ``(whole, core, enhancing)`` metric for each model.
        Both stay at zero if the loader yields no batches.
    """
    # Freshly built or loaded nn.Modules default to training mode. Inference
    # must disable dropout and make batch-norm use its running statistics.
    model_full.eval()
    model_missing.eval()

    class_score_full = np.zeros(3)
    class_score_mono = np.zeros(3)
    num_batches = 0
    with torch.no_grad():
        for batch_x, batch_y in loaders['eval']:
            batch_x = batch_x.cuda(non_blocking=True)
            batch_y = batch_y.cuda(non_blocking=True)
            seg_f, _, _ = model_full(batch_x)
            # The missing-modality model only sees the first input channel.
            seg_m, _, _ = model_missing(batch_x[:, 0:1])
            metric_full, metric_miss = evaluate_sample(seg_f, seg_m, batch_y)
            class_score_full += metric_full
            class_score_mono += metric_miss
            num_batches += 1

    if num_batches == 0:
        # Avoid a 0/0 division that would silently produce NaN scores.
        print(' validation skipped: the eval loader yielded no batches')
        return class_score_full, class_score_mono

    class_score_full /= num_batches
    class_score_mono /= num_batches
    print(f' validation Dice score full modalities class>> whole: {class_score_full[0]} core:{class_score_full[1]}  enh:{class_score_full[2]}')
    print(f' validation Dice score missing modality  class>> whole: {class_score_mono[0]} core:{class_score_mono[1]}  enh:{class_score_mono[2]}')
    return class_score_full, class_score_mono


# Restore the trained weights for this task and report the validation scores.
saved_model_path = os.path.join(log_dir, 'model_weights.pth')
model_full, model_missing = load_old_model(model_full, model_missing, saved_model_path)
evaluate_performance(model_full, model_missing, loaders)

