In [1]:
import numpy as np
import os
import SimpleITK as sitk
import torch
from model.propagation.prop_net import PropagationNetwork
from model.fusion_net import FusionNet
from interact.interactive_utils import load_volume, images_to_torch
from inference_core import InferenceCore

model_folder = '/playpen-raid2/qinliu/projects/iSegFormer/maskprop/MiVOS/saves'
prop_model_path = os.path.join(model_folder, 'stcn.pth')
fusion_model_path = os.path.join(model_folder, 'fusion_stcn.pth')

# Load checkpoints
prop_saved = torch.load(prop_model_path)
prop_model = PropagationNetwork().cuda().eval()
prop_model.load_state_dict(prop_saved)

fusion_saved = torch.load(fusion_model_path)
fusion_model = FusionNet().cuda().eval()
fusion_model.load_state_dict(fusion_saved)

<All keys matched successfully>

In [29]:
def get_selected_mask(label, label_frame_idx, label_idx):
    label_frame = label[label_frame_idx]

    label_frame_fc = np.zeros_like(label_frame)
    label_frame_fc[label_frame == label_idx] = 1

    label_frame_bg = np.ones_like(label_frame)
    label_frame_bg[label_frame_fc] = 0

    label_frame = np.stack([label_frame_bg, label_frame_fc], axis=0)

    return label_frame


def mask_prop(processor, label, label_frame_idxes, label_idx):
    current_mask = None
    for label_frame_idx in label_frame_idxes:
        label_frame = get_selected_mask(label, label_frame_idx, label_idx)
        label_frame = torch.from_numpy(label_frame)
        label_frame = torch.unsqueeze(label_frame, dim=1)
        current_mask = processor.interact(label_frame, label_frame_idx)

    return current_mask

In [44]:
def metrics(gt, pred, ignored_idx, label_idx):
    gt_ = np.copy(gt)
    pred_ = np.copy(pred)

    gt_[gt_ != label_idx] = 0
    gt_[gt_ == label_idx] = 1

    pred_[pred_ != label_idx] = 0
    pred_[pred_ == label_idx] = 1

    for idx in ignored_idx:
        pred_[idx, :, :] = gt_[idx, :, :]

    intersect = np.sum(np.logical_and(gt_, pred_))
    union = np.sum(gt_) + np.sum(pred_)
    dice = 2 * intersect / union
    sen = intersect / np.sum(gt_)
    ppv = intersect / np.sum(pred_)

    return dice, sen, ppv

In [53]:
volume_folder = '/playpen-raid2/qinliu/data/OAI-ZIB/test_volumes'
volume_names = os.listdir(volume_folder)
dices, sens, ppvs = [], [], []
for idx, volume_name in enumerate(volume_names):
    if not volume_name.endswith('_image.nii.gz'):
        continue
    volume_name = volume_name.split('_')[0]
    print(idx, volume_name)
    # volume_name = '9092628'

    volume_path = os.path.join(volume_folder, '{}_image.nii.gz'.format(volume_name))
    images = load_volume(volume_path, normalize=True)

    label_path = os.path.join(volume_folder, '{}_label.nii.gz'.format(volume_name))
    label = load_volume(label_path, normalize=False)
    if len(label.shape) == 4:
        label = label[:, :, :, 0]

    num_objects = 1
    mem_freq = 5
    mem_profile = 0

    processor = InferenceCore(
        prop_model, 
        fusion_model, 
        images_to_torch(images, device='cpu'), 
        num_objects, 
        mem_freq=mem_freq, 
        mem_profile=mem_profile
    )

    mask_save_folder = '/playpen-raid2/qinliu/projects/iSegFormer/maskprop/MiVOS/test_results'

    label_frame_idxes = [16, 32, 48, 64, 80, 96, 112, 128, 144, 159]
    label_idx = 2 # 2: fc, 4: tc, 
    propagated_mask = mask_prop(processor, label, label_frame_idxes, label_idx)
    propagated_mask[propagated_mask != 0] = label_idx

    dice, sen, ppv = metrics(label, propagated_mask, label_frame_idxes, label_idx)
    print('dice: ', dice, ' sen: ', sen, ' ppv: ', ppv)
    dices.append(dice)
    sens.append(sen)
    ppvs.append(ppv)

    # save the propagated mask
    mask = sitk.GetImageFromArray(propagated_mask)
    mask.CopyInformation(sitk.ReadImage(label_path))
    mask_save_path = os.path.join(mask_save_folder,
        '{}_frames_{}.nii.gz'.format(volume_name, str(len(label_frame_idxes))))
    sitk.WriteImage(mask, mask_save_path)


3 9223980
Intensity range: [0.0, 352.0]
Intensity range: [0, 4]
dice:  0.8783520835788767  sen:  0.9067534745338592  ppv:  0.8516758379189595
4 9184790
Intensity range: [0.0, 303.0]
Intensity range: [0, 4]
dice:  0.8987379558949325  sen:  0.9203641669578084  ppv:  0.8781047337539154
6 9102858
Intensity range: [0.0, 281.0]
Intensity range: [0, 4]
dice:  0.8364469464575502  sen:  0.8597393780926974  ppv:  0.8143833207736749
8 9180558
Intensity range: [0.0, 302.0]
Intensity range: [0, 4]
dice:  0.8429150353733086  sen:  0.8904471391338167  ppv:  0.8002003108992082
9 9884591
Intensity range: [0.0, 339.0]
Intensity range: [0, 4]
dice:  0.8580722876044344  sen:  0.9172209598567789  ppv:  0.8060900927068961
12 9577982
Intensity range: [0.0, 284.0]
Intensity range: [0, 4]
dice:  0.8292071950333363  sen:  0.8369048461493924  ppv:  0.8216498556802607
13 9023193
Intensity range: [0.0, 336.0]
Intensity range: [0, 4]
dice:  0.8396927079718144  sen:  0.8774438693857687  ppv:  0.8050559651028496
14 9

In [54]:
# num_frames = 1
# 0.5506417985859106 0.13633750471782477
# 0.6165515508274226 0.22521868998136085
# 0.5533239215969692 0.11975684697964131

# num_frames = 3
# 0.7873367468636533 0.04113931741488201
# 0.8576886777596207 0.06374817063972507
# 0.731163060081232 0.04914972548031398

# num_frames = 5
# 0.8220504703204792 0.029006211762585322
# 0.8714322320500417 0.041223072935403324
# 0.7792633009666465 0.03441123489574747

# num_frames = 10
# 0.8504732751874041 0.024136538201735572
# 0.8811221432431899 0.03238036611790692
# 0.8225995738226196 0.027563815699245237

print(np.mean(dices), np.std(dices))
print(np.mean(sens), np.std(sens))
print(np.mean(ppvs), np.std(ppvs))

0.8504732751874041 0.024136538201735572
0.8811221432431899 0.03238036611790692
0.8225995738226196 0.027563815699245237
