In [2]:
import numpy as np
import sys
import os
from pathlib import Path
import glob
import json
import re

## Run Evaluation

In [3]:
from utils import basic, eval, vis

BART toolbox not setup properly or not available


In [4]:
# Working locations: the directory the notebook runs from, its parent directory,
# and the checkpoint folder (later joined onto the parent directory).
notebook_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(notebook_dir, os.pardir))
checkpoint_path = "results/checkpoints"

In [5]:
# Experiment selection: checkpoint iteration, subject, slice index, and the
# paired acceleration factors / lambda values to compare.
model_state = 4999
sub = "S10"
slice = 0  # NOTE: shadows the builtin `slice`; later cells read this name
acc_factors = [104]  # [15, 26, 52, 104]
lambd = [0.15]

# ### Define all models to compare
# acc_factors and lambd are walked in lockstep. `model`, `plot_order` and the
# loop variables keep the values of the last pair after the loop finishes.
for acc_factor, l in zip(acc_factors, lambd):
    # Both variants use the same experiment-folder naming scheme.
    exp_name = "lamda{}_alpha0.0001___hdr0.0_slice{}_R{}".format(l, slice, acc_factor)
    model = {
        # PISCO-NIK
        "PISCO": {
            "group": "pisco",
            "exp": exp_name,
            "model_state": model_state,
        },
        # PISCO-NIK-dist
        "PISCO-dist": {
            "group": "pisco-kreg-L1dist",
            "exp": exp_name,
            "model_state": model_state,
        },
    }

    # Left-to-right order of the panels in the comparison figure.
    plot_order = ["ref", "nufft25", "xdgrasp25", "nik", "PISCO-dist", "PISCO"]
    # plot_order = ["nufft25", "xdgrasp25", "nik", "PISCO-dist", "PISCO"]

In [7]:
# Output folder for this subject / slice / acceleration factor, created on demand.
results_dir_name = "comparison_results_{}_slice{}_R{}".format(sub, slice, acc_factor)
results_path = os.path.join(notebook_dir, results_dir_name)
if not os.path.exists(results_path):
    os.makedirs(results_path)

## Load further reference methods
# The PISCO experiment folder lives under <parent>/<checkpoints>/<group>_<subject>/<exp>.
pisco_group_dir = model["PISCO"]["group"] + "_" + sub
pisco_path = os.path.join(parent_dir, checkpoint_path, pisco_group_dir, model["PISCO"]["exp"])
# The first directory entry is taken as the run folder; os.listdir gives no
# ordering guarantee, so this is only unambiguous with a single run per experiment.
pisco_exp = os.listdir(pisco_path)[0]
config_file = os.path.join(pisco_path, pisco_exp, "model_checkpoints/config.yml")
config = basic.parse_config(config_file)

print("Loading pisco model from: ", pisco_path)

Loading pisco model from:  /home/iml/veronika.spieker/workspace/PISCO-priv/results/checkpoints/pisco_S10/lamda0.15_alpha0.0001___hdr0.0_slice0_R104


In [8]:
######### grasp reference #########
# Collect every reconstruction archive stored for this slice and for the
# acceleration factor recorded in the loaded config.
grasp_path = os.path.join(parent_dir, "dummy_data/grasprecon/")
grasp_pattern = grasp_path + "/grasp_reference_{}_*_R{}.npz".format(
    slice, config["dataset"]["acc_factor"]
)
grasp_paths = glob.glob(grasp_pattern)

print("Loading grasp recon from: ", grasp_paths)

Loading grasp recon from:  ['/home/iml/veronika.spieker/workspace/PISCO-priv/dummy_data/grasprecon/grasp_reference_0_25MS_R104.npz']


In [11]:
######### NIK reference #########
# The pretrained NIK run is located from the settings stored in the PISCO config.
pretrained_cfg = config["model"]["pretrained"]
nik_group_name = pretrained_cfg["pretrain_group"] + "_S" + str(config["subject_name"])
# Experiment-name pattern; the "*" parts are wildcards handed to basic.find_subfolder
# (presumably glob-style matching - confirm against the helper).
nik_exp_name = "".join([
    pretrained_cfg["pretrain_exp"],
    "*_omega", str(config['model']['params']['omega_0']),
    "*_sigma", str(config['encoding']['sigma']),
    "*_hdr", str(config['hdr_ff_factor']),
    "_slice", str(config["slice"]),
    "_R", str(config['dataset']['acc_factor']),
])
checkpoint_root = os.path.join(parent_dir, checkpoint_path)
nik_group_path = basic.find_subfolder(checkpoint_root, nik_group_name)
nik_exp_path = basic.find_subfolder(nik_group_path, nik_exp_name)
# First directory entry is used as the run folder.
nik_path = os.listdir(nik_exp_path)[0]
nik_recon_file = nik_exp_path + "/" + nik_path + '/rec_test/recon_{}.npz'.format(model_state)
nik = np.load(nik_recon_file, allow_pickle=True)

In [13]:
######### LOAD IMAGES ##############################################################################################################
# img maps a method name to its reconstruction array; eval_dict is filled with metrics later.
img = {}
eval_dict = {}

### GRASP
for g in grasp_paths:
    # allow_pickle=True can execute code embedded in the archive - only load trusted files.
    grasp = np.load(g, allow_pickle=True)

    # Read the "<n>MS" tag from the file name only, so that digits followed by
    # "MS" in a parent directory name cannot be picked up by mistake.
    ms_match = re.search(r'(\d+)MS', os.path.basename(g))
    if ms_match is None:
        raise ValueError("Could not find an '<n>MS' tag in GRASP file name: {}".format(g))
    nMS = int(ms_match.group(1))

    # Archive key -> prefix of the entry stored in img (xdgrasp first, then nufft).
    for archive_key, prefix in (("grasp", "xdgrasp"), ("nufft", "nufft")):
        # Drop index 0 of the second axis and swap the first two remaining axes
        # (original author's note: from batch, ms1, ms2, z, y, x to ms, ech, z, y, x).
        stack = grasp[archive_key][:, 0, ...].transpose(1, 0, 2, 3, 4)
        # Keep only the configured echo; list indexing preserves the axis.
        img["{}{}".format(prefix, nMS)] = stack[:, [config["echo"]], ...]

In [14]:
### NIK
img["nik"] = nik["recon"]

### PISCO (and all comparisons)
# For every configured model, resolve its run folder and load the stored reconstruction.
paths = {}
checkpoint_root = os.path.join(parent_dir, checkpoint_path)
for name, spec in model.items():
    exp_dir = os.path.join(checkpoint_root, spec["group"] + "_" + sub, spec["exp"])
    run_name = os.listdir(exp_dir)[0]  # first directory entry is used as the run folder
    recon_file = 'rec_test/recon_{}.npz'.format(spec["model_state"])
    paths[name] = os.path.join(exp_dir, run_name, recon_file)
    img[name] = np.load(paths[name], allow_pickle=True)["recon"]


# Record which files went into this comparison so results can be traced back.
paths["NIK"] = nik_exp_path
paths["GRASP"] = grasp_path
trace_file = os.path.join(results_path, "paths.txt")
with open(trace_file, "w") as f:
    # Model settings first, then the resolved paths, separated by a text marker.
    json.dump(model, f, indent=4)
    f.write("\n Model Settings")
    json.dump(paths, f, indent=4)

In [15]:
######### EVALUATE ################################################################################################################
### Define reference
# The reference archive is read from the same folder as the first GRASP file.
# Each data type uses its own reference file / archive key.
ref_dir = os.path.dirname(grasp_paths[0])
if config["data_type"] == "knee":
    ref_file = "grasp_reference_{}_1MS_R1.npz".format(config["slice"])
    img["ref"] = np.load(os.path.join(ref_dir, ref_file))["nufft"][..., 0, :, :].transpose(2, 0, 1, 3, 4)
    img["ref"] = img["ref"][:, [config["echo"]], ...]  # select echo
elif config["data_type"] == "abdominal_sos":
    ref_file = "grasp_reference_{}_1MS_R1.npz".format(config["slice"])  # ToDo: decide for reference here?
    img["ref"] = np.load(os.path.join(ref_dir, ref_file))["nufft"][..., 0, :, :].transpose(2, 0, 1, 3, 4)
    # Repeat the reference 100 times along the first axis, then select the echo.
    img["ref"] = img["ref"].repeat(100, 0)[:, [config["echo"]], ...]
elif config["data_type"] == "cardiac_cine":
    ref_file = "grasp_reference_{}_25MS_R1.npz".format(config["slice"])
    img["ref"] = np.load(os.path.join(ref_dir, ref_file))["grasp"][..., 0, :, :].transpose(2, 0, 1, 3, 4)
else:
    # The exception used to be constructed but never raised, so an unsupported
    # data type slipped through and only failed later on the missing img["ref"].
    raise ValueError("Invalid data type: {}".format(config["data_type"]))

In [16]:
### PostProcess
# Crop all images to reference
import utils.mri
final_shape = [config["final_shape"], config["final_shape"]]
img["ref"] = utils.mri.center_crop(img["ref"], final_shape)
for i in img:
    if i == "ref":
        continue
    # Every other image is cropped to the reference's last two dimensions.
    img[i] = utils.mri.center_crop(img[i], img["ref"].shape[-2:])

# Rescale
# Each step is written back into img[key] immediately: when key is "ref", the
# later steps must see the already-converted reference.
phantom_types = ("abdominal_phantom", "abdominal_phantom_nohist")
sos_types = ("abdominal_sos",)
for key in img:
    img[key] = np.abs(basic.torch2numpy(img[key]))
    if config["data_type"] in phantom_types:
        # Only expand images whose first-axis length differs from the reference.
        if img["ref"].shape[0] != img[key].shape[0]:
            img[key] = eval.create_hystereses(img[key], dim_axis=0)
    elif config["data_type"] in sos_types:
        img[key] = eval.create_hystereses(img[key], dim_axis=0)
    img[key] = eval.postprocess(img[key], img["ref"])

## cutoff hysteresis (if no hysteresis considered then inhale=exhale -> process/plot only half)
if config["data_type"] in sos_types:
    for key in img:
        half = img[key].shape[0] // 2
        img[key] = img[key][:half, ...]

In [17]:
## Compute metrics
# For every reconstruction except the reference: compute the metrics, build the
# strings shown in the figure, and store the absolute difference to the reference.
ech = 0
eval_str_xy, eval_str_xt, eval_str_yt = {}, {}, {}
img_diff = {}
for key in img.keys():
    if key == "ref":
        continue

    if config["data_type"] in ["abdominal_phantom", "abdominal_phantom_nohist", "cardiac_cine", "knee"]:
        eval_dict[key] = eval.get_eval_metrics(img[key][:, ech, ...], img["ref"][:, ech, ...])
    elif "11_R0" in sub:
        # Calculate the metric only for a single temporal value.
        # NOTE(review): `t` is only assigned in the later plotting cell, so on a
        # fresh kernel this branch raises NameError - confirm the intended frame
        # index and define it before this cell.
        eval_dict[key] = eval.get_eval_metrics(img[key][[0], ech, ...], img["ref"][[t], ech, ...])
    else:
        # Without this guard the lookups below failed with an opaque KeyError
        # because no metrics were ever stored for `key`.
        raise ValueError(
            "No metric evaluation defined for data type '{}' and subject '{}'".format(config["data_type"], sub)
        )

    eval_str_xy[key] = eval.make_string_from_value_dict(eval_dict[key], default_keys=["psnr", "fsim"])
    eval_str_xt[key] = eval.make_string_from_value_dict(eval_dict[key], default_keys=["fsim_xt"])
    eval_str_yt[key] = eval.make_string_from_value_dict(eval_dict[key], default_keys=["fsim_yt"])

    # Absolute difference to the reference, with axes 1 and 2 (both length 1) squeezed away.
    img_diff[key] = np.abs(img[key] - img["ref"]).squeeze(1).squeeze(1)

# The reference differs from itself by zero; computed once instead of on every iteration.
img_diff["ref"] = np.zeros_like(img["ref"]).squeeze(1).squeeze(1)


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/iml/veronika.spieker/anaconda3/envs/pisco_nik/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /home/iml/veronika.spieker/anaconda3/envs/pisco_nik/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth
ssim : 0.609
ssim_std : 0.042
psnr : 19.158
psnr_std : 0.76
rmse : 0.329
rmse_std : 0.034
fsim : 0.641
fsim_std : 0.015
lpips_alex : 0.144
lpips_alex_std : 0.018
lpips_vgg : 0.309
lpips_vgg_std : 0.022
fsim_xt : 0.487
fsim_xt_std : 0.047
fsim_yt : 0.491
fsim_yt_std : 0.035
lpips_xt_alex : 0.069
lpips_xt_alex_std : 0.025
lpips_yt_alex : 0.078
lpips_yt_alex_std : 0.03
lpips_xt_vgg : 0.199
lpips_xt_vgg_std : 0.03
lpips_yt_vgg : 0.197
lpips_yt_vgg_std : 0.03
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/iml/veronika.spieker/anaconda3/envs/pisco_nik/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: /home/iml/veronika.spieker/anaconda3/envs/pisco_nik/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth
ssim : 0.171


In [18]:
import medutils
# Display versions of every plotted image: contrast-stretched without saturation
# (img_nosat) and with saturated_pixel=0.04 (img_sat), both divided by 255, plus
# the matching difference image.
img_sat, img_nosat, img_diff_plot = {}, {}, {}
stretch = medutils.visualization.contrastStretching
for key in plot_order:
    frames = img[key][...].squeeze(1).squeeze(1)  # drop the two length-1 axes

    img_nosat[key] = stretch(frames, saturated_pixel=0.00)
    img_nosat[key] /= 255.0

    img_diff_plot[key] = img_diff[key]

    img_sat[key] = stretch(frames, saturated_pixel=0.04)
    img_sat[key] /= 255.0


In [None]:
# Indices used for the figure: frame t and the x / y cut positions.
t = 10
x = img["ref"].shape[-2] // 2 + 2  # to avoid center point
y = img["ref"].shape[-1] // 2 + 2  # to avoid center point

# Cardiac cine data of subject 10 is flipped along axis 2 before plotting.
needs_flip = config["data_type"] == "cardiac_cine" and config["subject_name"] == 10
if needs_flip:
    for key in img_nosat:
        img_nosat[key] = np.flip(img_nosat[key], axis=2)
        if key == "ref":
            continue
        img_diff_plot[key] = np.flip(img_diff_plot[key], axis=2)


print("Example reconstruction for subject {} at R={} and slice {}".format(sub, acc_factor, slice))
figure_file = results_path + "/recon_comp_3d_slices_t{}_x{}_y{}.eps".format(t, x, y)
metric_labels = {"t": eval_str_xy, "x": eval_str_xt}
vis.plot_3d_slices_from_dict(
    img_nosat,
    t=t,
    x=x,
    fontsize=14,
    eval_str_dict=metric_labels,
    results_path=figure_file,
    cmap="gray",
)
