In [4]:
import os
import json
import torch
import pickle
import numpy as np
import SimpleITK as sitk
from os.path import join
from skimage import morphology
from scipy.ndimage import zoom
import medpy.metric.binary as metric

def load_json(file: str):
    """Parse the JSON file at `file` and return the decoded object."""
    with open(file) as fh:
        return json.load(fh)

def load_pickle(file: str, mode: str = 'rb'):
    """Unpickle and return the object stored in `file`.

    `mode` is forwarded unchanged to open(); pickle needs a binary stream,
    so the default 'rb' should normally be kept.

    Security: pickle.load can execute arbitrary code embedded in the file,
    so only call this on files you produced or otherwise trust.
    """
    with open(file, mode) as fh:
        return pickle.load(fh)

def get_attribute(case_dic, sample_slices=3):
    """Load one test case: centred image slices, ground-truth volume and properties.

    Args:
        case_dic: entry of the dataset JSON; must provide the keys 'name',
            'preprocess_npy', 'preprocess_pkl' and 'label'.
        sample_slices: number of consecutive slices to keep, centred on the
            middle of the stored consecutive-slice stack (intended values:
            3/5/7). Defaults to 3, the previously hard-coded value.

    Returns:
        (name, image, GT, properties):
          - image: channel 0 of the preprocessed array, restricted to the
            centred `sample_slices` window along axis 2 (a copy).
          - GT: the label image read with SimpleITK, as a numpy array.
          - properties: the unpickled preprocessing metadata.

    Raises:
        ValueError: if `sample_slices` is not a positive odd number no larger
            than the number of stored consecutive slices. Otherwise the window
            start would go negative (wrapping around) or the window would be
            truncated.
    """
    name = case_dic['name']
    # Layout per the preprocessing step: (3, all_slices, consecutive_slices, h, w)
    # -- TODO confirm against the preprocessing code.
    data = np.load(case_dic['preprocess_npy'])
    consecutive_slices = data.shape[2]
    if sample_slices < 1 or sample_slices % 2 == 0 or sample_slices > consecutive_slices:
        raise ValueError(
            'sample_slices must be a positive odd number <= {} (got {})'.format(
                consecutive_slices, sample_slices))
    properties = load_pickle(case_dic['preprocess_pkl'])
    GT = sitk.GetArrayFromImage(sitk.ReadImage(case_dic['label']))

    # Window of `sample_slices` slices centred on the middle stored slice.
    mid_slice = consecutive_slices // 2
    half = sample_slices // 2
    image = data[0][:, mid_slice - half:mid_slice + half + 1, ...].copy()
    return name, image, GT, properties

def pad2origin(arr, box, origin_shape, extend_slices=0):
    """Zero-pad a cropped volume back to its full, uncropped shape.

    Args:
        arr: 3-D array holding the cropped region.
        box: per-axis [start, stop] crop bounds in the original volume,
            indexed as box[0], box[1], box[2].
        origin_shape: shape of the original (uncropped) volume.
        extend_slices: extra zero slices added on BOTH sides of axis 0, for
            an `arr` that is `2 * extend_slices` slices shorter than box[0].

    Returns:
        The zero-padded array; an AssertionError is raised if its shape
        does not equal `origin_shape`.
    """
    widths = [
        (box[0][0] + extend_slices, origin_shape[0] - box[0][1] + extend_slices),
        (box[1][0], origin_shape[1] - box[1][1]),
        (box[2][0], origin_shape[2] - box[2][1]),
    ]
    padded = np.pad(arr, widths, mode='constant', constant_values=0)
    assert all(padded.shape[axis] == origin_shape[axis] for axis in range(3))
    return padded
def reduction_for2D(predict, properties):
    """Map a network-resolution label volume back to the original volume geometry.

    The in-plane axes of `predict` are resized with nearest-neighbour
    interpolation (order=0, so labels stay discrete) to
    properties['crop_shape'][1:]. The slice axis is left untouched. The
    result is then zero-padded into the full volume with pad2origin, using
    properties['liver_box'] and properties['origin_shape'].

    Args:
        predict: 3-D label array (slices, h, w).
        properties: preprocessing metadata with 'crop_shape', 'liver_box'
            and 'origin_shape'.

    Returns:
        Label array of shape properties['origin_shape'].
    """
    crop_shape = properties['crop_shape']
    # Use predict's real in-plane size instead of assuming a 256x256 network
    # input: identical factors for 256x256 predictions, correct for any other size.
    factors = (1.0,
               crop_shape[1] / predict.shape[1],
               crop_shape[2] / predict.shape[2])
    restored = zoom(predict, factors, order=0)
    return pad2origin(restored, properties['liver_box'], properties['origin_shape'], 0)

def cal_metric(PR, GT):
    """Score a predicted mask against the ground truth with medpy binary metrics.

    Args:
        PR: predicted 3-D mask, passed as medpy's `result` argument.
        GT: ground-truth mask with the same rank, passed as `reference`.

    Returns:
        dict with keys 'dice', 'jc', 'precition', 'recall', 'specificity'.
        NOTE: 'precition' is a misspelling of "precision". It is kept as-is
        because the key appears in printed results and may be consumed
        elsewhere.
    """
    assert len(PR.shape) == 3 and len(GT.shape) == 3
    scorers = [
        ('dice', metric.dc),
        ('jc', metric.jc),
        ('precition', metric.precision),
        ('recall', metric.recall),
        ('specificity', metric.specificity),
    ]
    return {key: score_fn(PR, GT) for key, score_fn in scorers}

def print_metrics(metric_res):
    """Print each case's metrics, then the mean of every metric over all cases.

    Args:
        metric_res: mapping case name -> {metric name: value}.

    Each metric's sum is divided by the total number of cases, even if some
    case lacks that metric.
    """
    totals = {}
    for case_name, scores in metric_res.items():
        print('Case-name: ', case_name)
        for metric_name, score in scores.items():
            print(f'{metric_name}: {score}')
            if metric_name in totals:
                totals[metric_name] += score
            else:
                totals[metric_name] = score
        print('\n')
    n_cases = len(metric_res.keys())
    print('AVG:')
    for metric_name, total in totals.items():
        print(f'{metric_name}: {total / n_cases}')
    print('\n')

def postprocess(pre, threshhold=200, connectivity=3):
    """Drop connected foreground components smaller than `threshhold` voxels.

    Args:
        pre: label/mask array; every non-zero voxel counts as foreground.
        threshhold: component-size threshold passed as the size argument of
            skimage.morphology.remove_small_objects. (The misspelt name is
            kept for backward compatibility with keyword callers.)
        connectivity: neighbourhood connectivity forwarded to
            remove_small_objects. Defaults to 3, the previously hard-coded
            value (for 3-D input this is the full 26-neighbourhood per the
            scikit-image docs).

    Returns:
        float32 array of 0.0 / 1.0 with the same shape as `pre`.
    """
    # np.bool8 was deprecated in NumPy 1.24 and removed in NumPy 2.0; the
    # builtin bool maps to the same dtype (np.bool_).
    mask = pre.astype(bool)
    kept = morphology.remove_small_objects(mask, threshhold, connectivity=connectivity)
    return kept.astype(np.float32)

def get_image_from_array(arr, properties):
    """Wrap `arr` in a SimpleITK image carrying the geometry saved in `properties`.

    Origin, direction and spacing are taken from the 'itk_origin',
    'itk_direction' and 'itk_spacing' entries, in that order.
    """
    image = sitk.GetImageFromArray(arr)
    geometry = (
        (image.SetOrigin, 'itk_origin'),
        (image.SetDirection, 'itk_direction'),
        (image.SetSpacing, 'itk_spacing'),
    )
    for setter, key in geometry:
        setter(properties[key])
    return image

def save_predict(arr, case_name, properties, save_dir='/data1/zfx/code/latentAugmentation/predict'):
    """Write `arr` to `<save_dir>/<case_name>.nii.gz` with the geometry from `properties`.

    `save_dir` is created (including parents) if it does not exist yet.
    sitk.WriteImage does not create missing directories, so a fresh output
    folder would otherwise make the write fail.
    """
    os.makedirs(save_dir, exist_ok=True)
    label = get_image_from_array(arr, properties)
    sitk.WriteImage(label, join(save_dir, case_name + '.nii.gz'))

    

In [5]:
# Inputs for cooperative-training experiment 6: the folder of exported soft
# predictions and the dataset index that lists the test cases.
report_dir = '/data1/zfx/code/latentAugmentation/medseg/saved/train_BileDuct_keep-origin-false_n_cls_2/BileDuct/cooperative_training/6/model/best/checkpoints/report'
pred_path = join(report_dir, 'pred_npy')
test_json = load_json(join('/data1/zfx/data/BileDuct/preprocessed_data', 'preprocess_dataset.json'))

In [6]:
# Evaluate every test case: load the exported soft prediction, take the
# argmax over axis 1, restore it to the original volume geometry, remove
# small components, score it against the ground truth, and write the
# post-processed mask to disk.
#
# The output folder is derived from the experiment index embedded in
# pred_path (.../cooperative_training/<exp>/model/best/checkpoints/report/pred_npy)
# instead of a hard-coded 'cop-6'. Re-running this cell after pointing
# pred_path at another experiment therefore cannot overwrite a different
# experiment's predictions.
PREDICT_ROOT = '/data1/zfx/code/latentAugmentation/predict'
_pred_parts = os.path.normpath(pred_path).split(os.sep)
if len(_pred_parts) < 7 or _pred_parts[-7] != 'cooperative_training':
    raise ValueError('Unexpected pred_path layout: {}'.format(pred_path))
exp_id = _pred_parts[-6]
save_dir = join(PREDICT_ROOT, 'cop-' + exp_id)

metric_res = {}
for case_dic in test_json['test']:
    # The sampled image slices are not needed for scoring.
    case_name, _image, GT, properties = get_attribute(case_dic)
    case_id = case_name.split('_')[-1]
    soft_pred = np.load(join(pred_path, case_id + "_soft_pred.npy"))
    # Index of the max along axis 1 -- presumably the class axis; TODO confirm layout.
    PR = torch.from_numpy(soft_pred).max(1)[1].numpy()
    restore_PR = reduction_for2D(PR, properties)
    postprocess_PR = postprocess(pre=restore_PR)
    metric_res[case_name] = cal_metric(postprocess_PR, GT)
    save_predict(postprocess_PR, case_name, properties, save_dir=save_dir)
    print('save case: ', case_name, ' done...')

print_metrics(metric_res)

save case:  BileDuct_002  done...
save case:  BileDuct_010  done...
save case:  BileDuct_012  done...
Case-name:  BileDuct_002
dice: 0.8088799627012875
jc: 0.6790918944959653
precition: 0.8228143651407475
recall: 0.7954096603092202
specificity: 0.9998067904116106


Case-name:  BileDuct_010
dice: 0.7453131427899778
jc: 0.5940232325755683
precition: 0.8054108980058428
recall: 0.6935613241942541
specificity: 0.9999609433516458


Case-name:  BileDuct_012
dice: 0.7960017145714762
jc: 0.6611319336623187
precition: 0.9175542135602526
recall: 0.7028871225502565
specificity: 0.9999376698333615


AVG:
dice: 0.7833982733542472
jc: 0.6447490202446174
precition: 0.8485931589022809
recall: 0.7306193690179104
specificity: 0.9999018011988726




In [2]:
# Inputs for cooperative-training experiment 5.
# NOTE(review): this cell ran as In [2] in the saved session. Under
# Restart & Run All it runs after the evaluation cell, so it only re-points
# pred_path / test_json and does not affect the printed metrics above.
report_dir = '/data1/zfx/code/latentAugmentation/medseg/saved/train_BileDuct_keep-origin-false_n_cls_2/BileDuct/cooperative_training/5/model/best/checkpoints/report'
pred_path = join(report_dir, 'pred_npy')
test_json = load_json(join('/data1/zfx/data/BileDuct/preprocessed_data', 'preprocess_dataset.json'))

In [14]:
def _print_case_properties(props):
    """Print the key set and the geometry fields of a properties dict."""
    print(props.keys())
    for field in ('size', 'liver_box', 'origin_shape', 'crop_shape'):
        print(props[field])


# Debug check for case 010: print the properties saved with the preprocessed
# volume, then those saved at the crop stage, and stop scanning.
for case_dic in test_json['test']:
    case_name, image, GT, properties = get_attribute(case_dic)
    case_id = case_name.split('_')[-1]
    print(case_id)
    if case_id != '010':
        continue
    _print_case_properties(properties)
    crop_properties = load_pickle(case_dic['crop_pkl'])
    _print_case_properties(crop_properties)
    break
        

002
010
odict_keys(['liver_box', 'size', 'itk_origin', 'itk_spacing', 'itk_direction', 'origin_shape', 'crop_shape'])
[449 512 512]
[[291, 425], [120, 359], [68, 353]]
(449, 512, 512)
(134, 239, 285)
odict_keys(['liver_box', 'size', 'itk_origin', 'itk_spacing', 'itk_direction', 'origin_shape', 'crop_shape'])
[449 512 512]
[[291, 425], [120, 359], [68, 353]]
(449, 512, 512)
(134, 239, 285)
