In [1]:
import torch
import numpy as np


def get_selected_mask(label, label_frame_idx, label_idx):
    label_frame = label[label_frame_idx]

    label_frame_fc = np.zeros_like(label_frame)
    label_frame_fc[label_frame == label_idx] = 1

    label_frame_bg = np.ones_like(label_frame)
    label_frame_bg[label_frame_fc] = 0

    label_frame = np.stack([label_frame_bg, label_frame_fc], axis=0)

    return label_frame


def mask_prop(processor, label, label_frame_idxes, label_idx):
    current_mask = None
    for label_frame_idx in label_frame_idxes:
        label_frame = get_selected_mask(label, label_frame_idx, label_idx)
        label_frame = torch.from_numpy(label_frame)
        label_frame = torch.unsqueeze(label_frame, dim=1)
        current_mask = processor.interact(label_frame, label_frame_idx)

    return current_mask


def metrics(gt, pred, ignored_idx, label_idx):
    gt_ = np.copy(gt)
    pred_ = np.copy(pred)

    gt_[gt_ != label_idx] = 0
    gt_[gt_ == label_idx] = 1

    pred_[pred_ != label_idx] = 0
    pred_[pred_ == label_idx] = 1

    for idx in ignored_idx:
        pred_[idx, :, :] = gt_[idx, :, :]

    intersect = np.sum(np.logical_and(gt_, pred_))
    union = np.sum(gt_) + np.sum(pred_)
    if union == 0:
        dice = 1
    else:
        dice = 2 * intersect / union

    if np.sum(gt_) == 0:
        sen = 1
    else:
        sen = intersect / np.sum(gt_)

    if np.sum(pred_) == 0:
        ppv = 1
    else:
        ppv = intersect / np.sum(pred_)

    return dice, sen, ppv

In [2]:
import sys
sys.path.insert(0, '/playpen-raid2/qinliu/projects/iSegFormer')
sys.path.insert(0, '/playpen-raid2/qinliu/projects/iSegFormer/maskprop')
sys.path.insert(0, '/playpen-raid2/qinliu/projects/iSegFormer/maskprop/MiVOS')


import numpy as np
import os
import SimpleITK as sitk
import torch
from maskprop.MiVOS.model.propagation.prop_net import PropagationNetwork
from maskprop.MiVOS.model.fusion_net import FusionNet
from maskprop.MiVOS.interact.interactive_utils import load_volume, images_to_torch
from inference_core import InferenceCore

stcn_no_ft = 'stcn.pth'
stcn_ft_no_cycle = 'Aug01_15.34.08_retrain_s4_ft_from_med_10000.pth'
stcn_ft_cycle = 'Aug01_22.03.33_retrain_s4_ft_from_med_10000.pth'

model_path = stcn_ft_cycle

model_folder = '/playpen-raid2/qinliu/projects/iSegFormer/maskprop/Med-STCN/saves'
prop_model_path = os.path.join(model_folder, model_path)
fusion_model_path = os.path.join(model_folder, 'fusion_stcn.pth')

# Load checkpoints
prop_saved = torch.load(prop_model_path)
prop_model = PropagationNetwork().cuda().eval()
prop_model.load_state_dict(prop_saved)

fusion_saved = torch.load(fusion_model_path)
fusion_model = FusionNet().cuda().eval()
fusion_model.load_state_dict(fusion_saved)



<All keys matched successfully>

In [3]:
volume_folder = '/playpen-raid2/qinliu/data/OAI-ZIB/test_volumes'
volume_names = os.listdir(volume_folder)
dices, sens, ppvs = [], [], []
for idx, volume_name in enumerate(volume_names):
    if not volume_name.endswith('_image.nii.gz'):
        continue
    volume_name = volume_name.split('_')[0]
    print(idx, volume_name)

    volume_path = os.path.join(volume_folder, '{}_image.nii.gz'.format(volume_name))
    images = load_volume(volume_path, normalize=True)

    label_path = os.path.join(volume_folder, '{}_label.nii.gz'.format(volume_name))
    label = load_volume(label_path, normalize=False)
    if len(label.shape) == 4:
        label = label[:, :, :, 0]

    num_objects = 1
    mem_freq = 5
    mem_profile = 0

    processor = InferenceCore(
        prop_model, 
        fusion_model, 
        images_to_torch(images, device='cpu'), 
        num_objects, 
        mem_freq=mem_freq, 
        mem_profile=mem_profile
    )

    # # pick up the frame with the maximum area as the memory frame
    max_area, max_area_idx = -1, label.shape[0] // 2
    for i in range(label.shape[0]):
        area = np.count_nonzero(label[i])
        if area > max_area:
            max_area = area
            max_area_idx = i
    
    label_frame_idxes = [max_area_idx]
    label_idx = 2 # 2: fc, 4: tc, 
    propagated_mask = mask_prop(processor, label, label_frame_idxes, label_idx)
    propagated_mask[propagated_mask != 0] = label_idx

    dice, sen, ppv = metrics(label, propagated_mask, label_frame_idxes, label_idx)
    print('dice: ', dice, ' sen: ', sen, ' ppv: ', ppv)
    dices.append(dice)
    sens.append(sen)
    ppvs.append(ppv)

    # # save the propagated mask
    # mask_save_folder = '/playpen-raid2/qinliu/projects/iSegFormer/maskprop/MiVOS/test_results'
    # mask = sitk.GetImageFromArray(propagated_mask)
    # mask.CopyInformation(sitk.ReadImage(label_path))
    # mask_save_path = os.path.join(mask_save_folder,
    #     '{}_frames_{}.nii.gz'.format(volume_name, str(len(label_frame_idxes))))
    # sitk.WriteImage(mask, mask_save_path)

print(np.mean(dices), np.std(dices))
print(np.mean(sens), np.std(sens))
print(np.mean(ppvs), np.std(ppvs))

3 9223980
dice:  0.7307789224659457  sen:  0.8055868025105717  ppv:  0.6686840199702733
4 9184790
dice:  0.7958945345083959  sen:  0.8456246531987915  ppv:  0.7516886723085143
6 9102858
dice:  0.735788489996852  sen:  0.7871534720417505  ppv:  0.6907164198553571
8 9180558
dice:  0.763632689961377  sen:  0.8555954792683705  ppv:  0.6895203166078796
9 9884591
dice:  0.7761836646244564  sen:  0.8755399481282563  ppv:  0.6970791181802234
12 9577982
dice:  0.5935523728820634  sen:  0.45737665896823915  ppv:  0.8451936565301125
13 9023193
dice:  0.6543255026681959  sen:  0.7995629346086534  ppv:  0.5537407029313111
14 9263504
dice:  0.7257349979863069  sen:  0.752020296936666  ppv:  0.7012251377442416
17 9065272
dice:  0.7375889259480615  sen:  0.8433772123197077  ppv:  0.6553816937737686
20 9056326
dice:  0.782266043548257  sen:  0.8441337277045801  ppv:  0.7288478171803159
21 9500225
dice:  0.7026591425704977  sen:  0.8537933899419902  ppv:  0.5969840106227005
22 9094783
dice:  0.732525687