In [2]:
import sys
sys.path.insert(0, '/playpen-raid2/qinliu/projects/iSegFormer')
sys.path.insert(0, '/playpen-raid2/qinliu/projects/iSegFormer/maskprop')
sys.path.insert(0, '/playpen-raid2/qinliu/projects/iSegFormer/maskprop/MiVOS')


import numpy as np
import os
import SimpleITK as sitk
import torch
from maskprop.MiVOS.model.propagation.prop_net import PropagationNetwork
from maskprop.MiVOS.model.fusion_net import FusionNet
from maskprop.MiVOS.interact.interactive_utils import load_volume, images_to_torch
from inference_core import InferenceCore

model_folder = '/playpen-raid2/qinliu/projects/iSegFormer/maskprop/Med-STCN/saves'
prop_model_path = os.path.join(model_folder, 'Aug01_22.03.33_retrain_s4_ft_from_med_10000.pth')
fusion_model_path = os.path.join(model_folder, 'fusion_stcn.pth')

# Load checkpoints
prop_saved = torch.load(prop_model_path)
prop_model = PropagationNetwork().cuda().eval()
prop_model.load_state_dict(prop_saved)

fusion_saved = torch.load(fusion_model_path)
fusion_model = FusionNet().cuda().eval()
fusion_model.load_state_dict(fusion_saved)

<All keys matched successfully>

In [3]:
def get_selected_mask(label, label_frame_idx, label_idx):
    label_frame = label[label_frame_idx]

    label_frame_fc = np.zeros_like(label_frame)
    label_frame_fc[label_frame == label_idx] = 1

    label_frame_bg = np.ones_like(label_frame)
    label_frame_bg[label_frame_fc] = 0

    label_frame = np.stack([label_frame_bg, label_frame_fc], axis=0)

    return label_frame


def mask_prop(processor, label, label_frame_idxes, label_idx):
    current_mask = None
    for label_frame_idx in label_frame_idxes:
        label_frame = get_selected_mask(label, label_frame_idx, label_idx)
        label_frame = torch.from_numpy(label_frame)
        label_frame = torch.unsqueeze(label_frame, dim=1)
        current_mask = processor.interact(label_frame, label_frame_idx)

    return current_mask


def metrics(gt, pred, ignored_idx, label_idx):
    gt_ = np.copy(gt)
    pred_ = np.copy(pred)

    gt_[gt_ != label_idx] = 0
    gt_[gt_ == label_idx] = 1

    pred_[pred_ != label_idx] = 0
    pred_[pred_ == label_idx] = 1

    for idx in ignored_idx:
        pred_[idx, :, :] = gt_[idx, :, :]

    intersect = np.sum(np.logical_and(gt_, pred_))
    union = np.sum(gt_) + np.sum(pred_)
    dice = 2 * intersect / union
    sen = intersect / np.sum(gt_)
    ppv = intersect / np.sum(pred_)

    return dice, sen, ppv

In [4]:
val_file = '/playpen-raid2/qinliu/data/AbdomenCT-1K/Organ-12-Subset_frames/trainval/ImageSets/val.txt'
volume_names = []
with open(val_file, "r") as lines:
    for line in lines:
        volume_names.append(line.rstrip('\n'))

volume_folder = '/playpen-raid2/qinliu/data/AbdomenCT-1K/Organ-12-Subset/Image'
mask_folder = '/playpen-raid2/qinliu/data/AbdomenCT-1K/Organ-12-Subset/Mask'



for label_idx in range(1, 13):
    dices, sens, ppvs = [], [], []
    for idx, volume_name in enumerate(volume_names):
        # print(idx, volume_name)

        volume_path = os.path.join(volume_folder, f'{volume_name}.nii.gz')
        images = load_volume(volume_path, normalize=True)

        label_path = os.path.join(mask_folder, f'{volume_name}.nii.gz')
        label = load_volume(label_path, normalize=False)
        if len(label.shape) == 4:
            label = label[:, :, :, 0]

        num_objects = 1
        mem_freq = 5
        mem_profile = 0

        processor = InferenceCore(
            prop_model, 
            fusion_model, 
            images_to_torch(images, device='cpu'), 
            num_objects, 
            mem_freq=mem_freq, 
            mem_profile=mem_profile
        )


        # # pick up the frame with the maximum area as the memory frame
        max_area, max_area_idx = -1, label.shape[0] // 2
        for i in range(label.shape[0]):
            area = np.count_nonzero(label[i])
            if area > max_area:
                max_area = area
                max_area_idx = i
        
        label_frame_idxes = [max_area_idx]
        propagated_mask = mask_prop(processor, label, label_frame_idxes, label_idx)
        propagated_mask[propagated_mask != 0] = label_idx

        dice, sen, ppv = metrics(label, propagated_mask, label_frame_idxes, label_idx)
        # print('dice: ', dice, ' sen: ', sen, ' ppv: ', ppv)
        dices.append(dice)
        sens.append(sen)
        ppvs.append(ppv)

        # # save the propagated mask
        # mask_save_folder = '/playpen-raid2/qinliu/projects/iSegFormer/maskprop/MiVOS/test_results'
        # mask = sitk.GetImageFromArray(propagated_mask)
        # mask.CopyInformation(sitk.ReadImage(label_path))
        # mask_save_path = os.path.join(mask_save_folder,
        #     '{}_frames_{}.nii.gz'.format(volume_name, str(len(label_frame_idxes))))
        # sitk.WriteImage(mask, mask_save_path)

    print(f'label {label_idx}: DSC: {np.mean(dices)} ({np.std(dices)})')
    print(f'label {label_idx}: SEN: {np.mean(sens)} ({np.std(sens)})')
    print(f'label {label_idx}: PPV: {np.mean(ppvs)} ({np.std(ppvs)})')

label 1: DSC: 0.9545882195054233 (0.026941315941833875)
label 1: SEN: 0.9710432383842382 (0.010572612375934038)
label 1: PPV: 0.9399559451997733 (0.048535453516696685)


  ppv = intersect / np.sum(pred_)


label 2: DSC: 0.1902508018283256 (0.35551885669896016)
label 2: SEN: 0.18662704525695234 (0.35307694940167456)
label 2: PPV: nan (nan)
label 3: DSC: 0.7940970098016753 (0.18400191117402911)
label 3: SEN: 0.9583234662925456 (0.025942715654001063)
label 3: PPV: 0.718019198541557 (0.24338945633648412)
label 4: DSC: 0.007429484491687725 (0.04639711577972179)
label 4: SEN: 0.012727872798928341 (0.0794855401531769)
label 4: PPV: nan (nan)


KeyboardInterrupt: 