In [2]:
import os
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Define the function to compute MAE
def compute_mae(image1, image2):
    """Return the mean absolute error between two arrays of the same shape.

    Both inputs are promoted to float64 before subtracting, so unsigned-integer
    images (e.g. uint8/uint16) cannot wrap around on negative differences.

    Args:
        image1: array-like, first image.
        image2: array-like, second image; must have exactly the same shape as image1.

    Returns:
        numpy.float64: mean of |image1 - image2| over all elements.

    Raises:
        ValueError: if the shapes differ. NumPy broadcasting would otherwise
            silently compare mismatched arrays (e.g. (H, W) against (H, W, 1)).
    """
    a = np.asarray(image1, dtype=np.float64)
    b = np.asarray(image2, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return np.abs(a - b).mean()

# Paths to the folders: ground truth (B), conditioning inputs (A) and model predictions.
gt_folder = "/workdir/carrot/splited_data_15k/test/B"
condition_folder = "/workdir/carrot/splited_data_15k/test/A"

dataset_name = 'LBBDMxVq13_15k'

pre_folder = "/workdir/ssd2/nguyent_petct/tiennh/BBDM_folk/results/" + dataset_name + "/LBBDM-f4/sample_to_eval/400"

# Images are divided by 32767 (the int16 maximum) before comparison, and the MAE is
# reported back on the original intensity scale.
# NOTE(review): assumes the .npy files hold int16-range intensities — confirm with the data export.
INTENSITY_SCALE = 32767.0
# Slices whose MAE (original intensity scale) exceeds this are kept for visual inspection.
HIGH_MAE_THRESHOLD = 900

mae_scores = []  # per-slice MAE on the original intensity scale
high_mae, high_mae_gts, high_mae_pds, high_mae_conditions = [], [], [], []
skipped_files = []  # (filename, exception) for pairs whose GT or prediction failed to load

# sorted() makes the iteration order reproducible, and with it the order of the
# high-MAE lists and any indices derived from them.
for filename in sorted(os.listdir(gt_folder)):
    if not filename.endswith(".npy"):
        continue

    gt_path = os.path.join(gt_folder, filename)
    pre_path = os.path.join(pre_folder, filename)
    try:
        # NOTE(review): allow_pickle=True can execute code from a crafted file;
        # keep it only if these arrays are trusted, pickled object arrays.
        gt_img = np.load(gt_path, allow_pickle=True)
        pre_img = np.load(pre_path, allow_pickle=True)
    except Exception as exc:  # best-effort: skip unloadable pairs, but keep a record
        skipped_files.append((filename, exc))
        continue

    # Predictions are averaged over their last axis before comparison
    # (presumably a channel axis of the saved samples — TODO confirm).
    pre_norm = pre_img.mean(axis=-1) / INTENSITY_SCALE
    gt_norm = gt_img / INTENSITY_SCALE
    slice_mae = compute_mae(pre_norm, gt_norm) * INTENSITY_SCALE
    mae_scores.append(slice_mae)

    if slice_mae > HIGH_MAE_THRESHOLD:
        high_mae_gts.append(gt_img)
        high_mae_pds.append(pre_img)
        high_mae_conditions.append(np.load(os.path.join(condition_folder, filename), allow_pickle=True))
        high_mae.append(slice_mae)

if skipped_files:
    first_name, first_err = skipped_files[0]
    print("Skipped {} file(s) that could not be loaded (first: {}: {!r})".format(
        len(skipped_files), first_name, first_err))

# np.mean([]) would return nan with a RuntimeWarning; make the empty case explicit.
mean_mae = np.mean(mae_scores) if mae_scores else float("nan")
print("Mean MAE: {}".format(mean_mae))


Mean MAE: 351.2983712748704


In [9]:
import matplotlib.pyplot as plt

# Output directory for the high-MAE comparison figures.
save_fig_dir = './fig/' + dataset_name + '/400'

# exist_ok=True replaces the exists()/makedirs() pair: it is idempotent on re-run
# and has no race between the check and the creation.
os.makedirs(save_fig_dir, exist_ok=True)


# Define a function to visualize images
def _as_grayscale(img):
    """Collapse a trailing channel axis so imshow applies the gray colormap (assumes (H, W, C))."""
    arr = np.asarray(img)
    # imshow ignores `cmap` for 3-channel input and clips RGB data outside
    # [0, 255] (ints) / [0, 1] (floats), so large-range intensities would be washed out.
    return arr.mean(axis=-1) if arr.ndim == 3 else arr


def visualize_and_save(gt, pre, condition, mae, filename):
    """Save a side-by-side figure of the condition, ground-truth and predicted images.

    Any 3-D input is averaged over its last axis before display. For predictions,
    this matches the reduction applied before the MAE is computed.

    Args:
        gt: ground-truth image.
        pre: predicted image.
        condition: conditioning image, shown in the panel titled 'CT'.
        mae: MAE value shown in the figure title.
        filename: output path passed to ``Figure.savefig``.
    """
    panels = (
        (condition, 'CT'),
        (gt, 'Ground Truth'),
        (pre, 'Predicted'),
    )
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    try:
        for ax, (img, title) in zip(axes, panels):
            ax.imshow(_as_grayscale(img), cmap='gray')
            ax.set_title(title)
        fig.suptitle(f'MAE: {mae:.2f}')
        fig.savefig(filename)
    finally:
        # Always release the figure: this is called in a loop, and pyplot keeps every
        # open figure alive until it is explicitly closed.
        plt.close(fig)

# Visualize images with high MAE



# Write one comparison figure per high-MAE slice, numbered by its position in the lists.
high_mae_records = zip(high_mae_gts, high_mae_pds, high_mae, high_mae_conditions)
for idx, (gt_image, pre_image, mae_value, condition_img) in enumerate(high_mae_records):
    filename = os.path.join(save_fig_dir, f"{idx}.png")
    visualize_and_save(gt_image, pre_image, condition_img, mae_value, filename)


In [3]:
# Smallest and largest per-slice MAE (original intensity scale).
(np.amin(mae_scores), np.amax(mae_scores))

(57.192235168588546, 1096.5176792527561)

In [1]:
import torch 

torch.arange(100).flip(0)  # descending indices 99 ... 0

  from .autonotebook import tqdm as notebook_tqdm


tensor([99, 98, 97, 96, 95, 94, 93, 92, 91, 90, 89, 88, 87, 86, 85, 84, 83, 82,
        81, 80, 79, 78, 77, 76, 75, 74, 73, 72, 71, 70, 69, 68, 67, 66, 65, 64,
        63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46,
        45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28,
        27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10,
         9,  8,  7,  6,  5,  4,  3,  2,  1,  0])

In [21]:
# Skip-step schedule: indices spaced by a float stride (truncated to int) from
# NUM_TIMESTEPS - 1 down towards 1, followed by the final steps 1 and 0.
# With these values this gives 198 + 2 = 200 entries.
NUM_TIMESTEPS = 1000
SAMPLE_STEP = 200
stride = (NUM_TIMESTEPS - 1) / (SAMPLE_STEP - 2)
midsteps = torch.arange(NUM_TIMESTEPS - 1, 1, step=-stride).long()
steps = torch.cat((midsteps, torch.tensor([1, 0], dtype=torch.long)), dim=0)
np.flip(steps.numpy())

array([  0,   1,   5,  10,  15,  20,  25,  30,  35,  40,  45,  50,  55,
        60,  65,  70,  75,  80,  85,  90,  95, 100, 105, 111, 116, 121,
       126, 131, 136, 141, 146, 151, 156, 161, 166, 171, 176, 181, 186,
       191, 196, 201, 206, 211, 216, 222, 227, 232, 237, 242, 247, 252,
       257, 262, 267, 272, 277, 282, 287, 292, 297, 302, 307, 312, 317,
       322, 327, 333, 338, 343, 348, 353, 358, 363, 368, 373, 378, 383,
       388, 393, 398, 403, 408, 413, 418, 423, 428, 433, 438, 444, 449,
       454, 459, 464, 469, 474, 479, 484, 489, 494, 499, 504, 509, 514,
       519, 524, 529, 534, 539, 544, 549, 555, 560, 565, 570, 575, 580,
       585, 590, 595, 600, 605, 610, 615, 620, 625, 630, 635, 640, 645,
       650, 655, 660, 666, 671, 676, 681, 686, 691, 696, 701, 706, 711,
       716, 721, 726, 731, 736, 741, 746, 751, 756, 761, 766, 771, 777,
       782, 787, 792, 797, 802, 807, 812, 817, 822, 827, 832, 837, 842,
       847, 852, 857, 862, 867, 872, 877, 882, 888, 893, 898, 90

In [9]:
c = 5  # stride between consecutive time steps
# 1-based time steps: range(0, 1000, c) shifted up by one.
time_steps = np.arange(0, 1000, c) + 1

time_steps

array([  1,   6,  11,  16,  21,  26,  31,  36,  41,  46,  51,  56,  61,
        66,  71,  76,  81,  86,  91,  96, 101, 106, 111, 116, 121, 126,
       131, 136, 141, 146, 151, 156, 161, 166, 171, 176, 181, 186, 191,
       196, 201, 206, 211, 216, 221, 226, 231, 236, 241, 246, 251, 256,
       261, 266, 271, 276, 281, 286, 291, 296, 301, 306, 311, 316, 321,
       326, 331, 336, 341, 346, 351, 356, 361, 366, 371, 376, 381, 386,
       391, 396, 401, 406, 411, 416, 421, 426, 431, 436, 441, 446, 451,
       456, 461, 466, 471, 476, 481, 486, 491, 496, 501, 506, 511, 516,
       521, 526, 531, 536, 541, 546, 551, 556, 561, 566, 571, 576, 581,
       586, 591, 596, 601, 606, 611, 616, 621, 626, 631, 636, 641, 646,
       651, 656, 661, 666, 671, 676, 681, 686, 691, 696, 701, 706, 711,
       716, 721, 726, 731, 736, 741, 746, 751, 756, 761, 766, 771, 776,
       781, 786, 791, 796, 801, 806, 811, 816, 821, 826, 831, 836, 841,
       846, 851, 856, 861, 866, 871, 876, 881, 886, 891, 896, 90

In [10]:
time_steps.__class__  # confirm time_steps is an ndarray rather than a list

numpy.ndarray

array([  1,   6,  11,  16,  21,  26,  31,  36,  41,  46,  51,  56,  61,
        66,  71,  76,  81,  86,  91,  96, 101, 106, 112, 117, 122, 127,
       132, 137, 142, 147, 152, 157, 162, 167, 172, 177, 182, 187, 192,
       197, 202, 207, 212, 217, 223, 228, 233, 238, 243, 248, 253, 258,
       263, 268, 273, 278, 283, 288, 293, 298, 303, 308, 313, 318, 323,
       328, 334, 339, 344, 349, 354, 359, 364, 369, 374, 379, 384, 389,
       394, 399, 404, 409, 414, 419, 424, 429, 434, 439, 445, 450, 455,
       460, 465, 470, 475, 480, 485, 490, 495, 500, 505, 510, 515, 520,
       525, 530, 535, 540, 545, 550, 556, 561, 566, 571, 576, 581, 586,
       591, 596, 601, 606, 611, 616, 621, 626, 631, 636, 641, 646, 651,
       656, 661, 667, 672, 677, 682, 687, 692, 697, 702, 707, 712, 717,
       722, 727, 732, 737, 742, 747, 752, 757, 762, 767, 772, 778, 783,
       788, 793, 798, 803, 808, 813, 818, 823, 828, 833, 838, 843, 848,
       853, 858, 863, 868, 873, 878, 883, 889, 894, 899, 904, 90