## Part I : Global includes



In [None]:

%matplotlib inline
import warnings
warnings.filterwarnings("ignore")

import os
import pathlib
from matplotlib import pyplot as plt
import torch
import numpy as np
import cv2
import tensorflow as tf

# Hide all GPUs from TensorFlow (comment out the call below to re-enable GPU usage).
# This is required for some models like Pridnet, which have too many trainable parameters
tf.config.set_visible_devices([], 'GPU')

from tqdm.notebook import tqdm
import random

import data_importer

In [None]:
# TF1-compat session setup: allow_growth asks TensorFlow to claim GPU memory on
# demand instead of reserving it all up front.
# NOTE(review): the imports cell calls tf.config.set_visible_devices([], 'GPU'),
# so TensorFlow presumably sees no GPU here and this option may have no effect - confirm.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
# The session is kept in a module-level name so it stays alive for the whole notebook.
session = InteractiveSession(config=config)

## Part II : Loading test images

In [None]:
# Build the LoDoPaB-CT validation loader that feeds every evaluation below.
# NOTE(review): the path points at the ground-truth validation folder while the
# variable is called `noisy_dataset` - confirm which images the loader actually yields.
from dataloader_lodopab_ct import get_validation_dataloader
noisy_dataset = get_validation_dataloader("../../../../Dataset/LoDoPaB-CT/ground_truth_validation/")

In [None]:
def reconstruct_image_from_patches(patches, num_patches_per_row):
    """Tile a stack of square patches (stored row by row) into a single image.

    Args:
        patches: array-like exposing ``shape`` and ``reshape``; axis 0 indexes the
            patches and axis 1 gives the patch side length.
        num_patches_per_row: how many patches sit side by side in each image row.

    Returns:
        float64 NumPy array of shape
        ``(rows * side, num_patches_per_row * side, 1)``.
    """
    side = patches.shape[1]  # patches are assumed to be square
    rows = patches.shape[0] // num_patches_per_row

    # View the flat patch stack as a (grid_row, grid_col, y, x) grid.
    grid = np.asarray(patches.reshape((rows, num_patches_per_row, side, side)))

    # Bring each patch's y axis next to its grid row and its x axis next to its
    # grid column, then merge the pairs: this pastes every patch into place at once.
    mosaic = grid.transpose(0, 2, 1, 3).reshape(rows * side, num_patches_per_row * side)

    # The result is always float64 with a trailing channel axis of size 1.
    return mosaic.astype(np.float64)[..., np.newaxis]

# Collect at most NUM_TEST_IMAGES full-size images from the validation loader.
NUM_TEST_IMAGES = 28
noisy_images = []
print(len(noisy_dataset))
for i, data in enumerate(noisy_dataset):
    # Stop *before* handling an item past the requested count. The previous code
    # tested `i == 28` only after writing slot 28 of a 28-element list, which
    # raised IndexError whenever the loader held more than 28 items.
    if i >= NUM_TEST_IMAGES:
        break
    # NOTE(review): `data[i]` indexes every loaded item with the loop counter;
    # kept unchanged because the loader's item layout is not visible here - confirm.
    noisy_images.append(reconstruct_image_from_patches(torch.squeeze(data[i], axis=0), 8))
# Appending (instead of pre-filling with None) keeps the array dense even when
# the loader yields fewer than NUM_TEST_IMAGES items.
noisy_array = np.array(noisy_images)
print(noisy_array)

Visualization of the noisy / ground truth image pair

In [None]:
# Display every loaded test image in grayscale with a fixed [-160, 240] intensity window.
from data_importer import denormalize, trunc

# NOTE(review): no torch operation is visible inside this loop, so the no_grad()
# context looks unnecessary - confirm whether denormalize/trunc rely on it.
with torch.no_grad():    
    for i, data in enumerate(noisy_array):
        plt.imshow(trunc(denormalize(data)), vmin=-160.0, vmax=240.0, cmap='gray')
        plt.show()

## Part III : Setup for Inference

In [None]:
# Inference
def inference_single_image(model, noisy_image):
    """Run ``model`` on a single image and return the prediction without its batch axis.

    Args:
        model: object exposing ``predict(batch)``.
        noisy_image: one input image; a leading batch axis of size 1 is added
            before it is handed to the model.

    Returns:
        The first (and only) element of the batched prediction.
    """
    batch = np.expand_dims(noisy_image, axis=0)
    # The min/max of the prediction used to be computed here for a rescaling step
    # that was commented out; those unused values have been dropped.
    return model.predict(batch)[0]

def inference_batch_images(model, noisy_images):
    """Predict a whole batch in one call and return the result as float64.

    Args:
        model: object exposing ``predict(batch)`` that returns a NumPy array.
        noisy_images: batch passed to the model unchanged.
    """
    return model.predict(noisy_images).astype(np.float64)

In [None]:
def rgb2gray(rgb):
    """Collapse the first three channels into one weighted-sum channel.

    Args:
        rgb: array whose last axis holds at least three channels (extra
            channels are ignored).

    Returns:
        Array with the same leading axes and a trailing channel axis of size 1.
    """
    channel_weights = [0.2989, 0.5870, 0.1140]
    gray = np.dot(rgb[..., :3], channel_weights)
    return gray[..., np.newaxis]

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio
import sys
sys.path.append('../')

from metrics import compute_SSIM, compute_PSNR
from skimage.metrics import mean_squared_error  as mse

def calculate_psnr(original_image, reconstructed_image,range=400):
    """Peak signal-to-noise ratio between a reference image and a reconstruction.

    Args:
        original_image: reference image.
        reconstructed_image: image compared against the reference.
        range: value forwarded as ``data_range``. The default of 400 equals
            240 + 160, the width of the [-160, 240] display window used in this
            notebook. The name shadows the built-in ``range`` inside this
            function; it is kept so keyword callers keep working.

    Returns:
        The PSNR value computed by ``skimage.metrics.peak_signal_noise_ratio``.
    """
    # A second, unreachable computation with a hard-coded data_range of 240+160
    # used to follow this return; it has been removed.
    return peak_signal_noise_ratio(original_image, reconstructed_image, data_range=range)

def calculate_ssim(original_image, reconstructed_image, range=400.0):
    """Structural similarity between a reference image and a reconstruction.

    Both images are cast to int16 before the comparison, so fractional parts
    are discarded. SSIM uses an 11-pixel window and treats axis 2 as channels.

    Args:
        original_image: reference image (array with an ``astype`` method).
        reconstructed_image: image compared against the reference.
        range: value forwarded as ``data_range``.
    """
    reference = original_image.astype(np.int16)
    candidate = reconstructed_image.astype(np.int16)
    return ssim(reference, candidate, win_size=11, channel_axis=2, data_range=range)

def calculate_rmse(original_image, reconstructed_image):
    """Return the mean squared error between the two images.

    NOTE(review): despite the function name, no square root is taken. `mse` is
    skimage's ``mean_squared_error``, so the returned value is the MSE, not the
    RMSE - confirm which metric is intended before renaming or changing it.
    """
    return mse(original_image, reconstructed_image)

def visualize_predictions(model, X_test,  n, predictions, model_name):
    """Plot the first ``n`` reference/prediction pairs side by side with their metrics.

    Args:
        model: unused; kept so existing call sites keep working.
        X_test: indexable collection of reference images.
        n: number of leading images to display.
        predictions: indexable collection of predicted images aligned with ``X_test``.
        model_name: label shown in the title of the prediction panel.
    """
    for i in range(n):
        # NOTE(review): the float16 cast reduces precision before the metrics are
        # computed, unlike get_average_metrics which uses the arrays as-is - confirm
        # that this is intended.
        gt_image = X_test[i].astype(np.float16)
        predicted_image = predictions[i].astype(np.float16)

        if predicted_image.shape[-1] == 3:
            predicted_image = rgb2gray(predicted_image)

        # Denormalise/window each image once and reuse the result for all three
        # metrics and for display (previously recomputed five times per image).
        # This assumes denormalize/trunc return the same values on every call.
        gt_display = trunc(denormalize(gt_image))
        predicted_display = trunc(denormalize(predicted_image))

        psnr_recon = round(calculate_psnr(gt_display, predicted_display), 4)
        ssim_recon = round(calculate_ssim(gt_display, predicted_display), 4)
        # calculate_rmse returns skimage's mean_squared_error (no square root),
        # so the value is reported as MSE.
        mse_recon = round(calculate_rmse(gt_display, predicted_display), 4)

        f, axarr = plt.subplots(1,2, figsize=(21,21))

        axarr[0].imshow(gt_display, cmap='gray', vmin=-160.0, vmax=240.0)
        axarr[0].set_title("QD Image")
        axarr[0].set_axis_off()
        axarr[1].imshow(predicted_display,  cmap='gray', vmin=-160.0, vmax=240.0)
        axarr[1].set_title("{} Predicted Image : PSNR={}\nSSIM={}\nMSE={}".format(model_name, psnr_recon, ssim_recon, mse_recon))
        axarr[1].set_axis_off()

        plt.show()
        # Release the figure explicitly so long runs do not accumulate open figures.
        plt.close(f)

In [None]:
from skimage.metrics import peak_signal_noise_ratio

def get_average_metrics(predicted_images,  _noisy_array):
    """Average PSNR, SSIM and MSE of predictions against their reference images.

    Args:
        predicted_images: iterable of predicted images. A prediction whose last
            axis has size 3 is converted to a single channel first.
        _noisy_array: iterable of reference images aligned with ``predicted_images``.

    Returns:
        Tuple ``(psnr_mean, ssim_mean, mse_mean)``, each rounded to 4 decimals.

    Raises:
        ValueError: if there is no reference/prediction pair to evaluate.
    """
    psnr_total = 0.0
    ssim_total = 0.0
    mse_total = 0.0
    num_pairs = 0

    # Iterate over the references that were passed in. The previous code ignored
    # the parameter and silently read the global `noisy_array` instead.
    for gt_img, predicted_img in zip(_noisy_array, predicted_images):
        if predicted_img.shape[-1] == 3:
            predicted_img = rgb2gray(predicted_img)

        gt_display = trunc(denormalize(gt_img))
        predicted_display = trunc(denormalize(predicted_img))

        psnr_total += calculate_psnr(gt_display, predicted_display)
        ssim_total += calculate_ssim(gt_display, predicted_display)
        # calculate_rmse returns skimage's mean_squared_error (no square root).
        mse_total += calculate_rmse(gt_display, predicted_display)
        num_pairs += 1

    if num_pairs == 0:
        raise ValueError("get_average_metrics needs at least one reference/prediction pair")

    # Divide by the number of pairs actually evaluated: zip() stops at the shorter
    # input, so dividing by the reference count could understate the means.
    psnr_prediction_mean = psnr_total / num_pairs
    ssim_prediction_mean = ssim_total / num_pairs
    mse_prediction_mean = mse_total / num_pairs

    print("Predicted average gt-predicted PSNR ->", psnr_prediction_mean)

    print("Predicted average gt-predicted SSIM ->", ssim_prediction_mean)

    print("Predicted average gt-predicted MSE->", mse_prediction_mean)

    return round(psnr_prediction_mean, 4), round(ssim_prediction_mean, 4), round(mse_prediction_mean, 4)


## Part IV : Evaluation of each model

## Model 1 : Hformer (for base reference)

In [None]:
# Rebuild the Hformer architecture and restore its trained weights.
sys.path.append('../denoising-models/hformer_vit/model/')
sys.path.append('../denoising-models/hformer_vit/')
from hformer_model_extended import get_hformer_model, PatchExtractor

hformer_model = get_hformer_model(num_channels_to_be_generated=64, name="hformer_model_extended")
# Build with the 64x64 single-channel patch shape before loading the weights.
hformer_model.build(input_shape=(None, 64, 64, 1))
hformer_model.load_weights('../denoising-models/hformer_vit/test/experiments/full_dataset/hformer_64_channel_custom_loss_epochs_48.h5')
print('Model summary : ')
# Keras' summary() prints the table itself and returns None, so wrapping it in
# print() only added a stray "None" line to the output.
hformer_model.summary()

In [None]:
# NOTE(review): a function with this exact name is also defined in the data-loading
# cell earlier in the notebook; this later definition is the one in effect from here on.
def reconstruct_image_from_patches(patches, num_patches_per_row):
    """Tile a stack of square patches (stored row by row) into a single image.

    Args:
        patches: array-like exposing ``shape`` and ``reshape``; axis 0 indexes the
            patches and axis 1 gives the patch side length.
        num_patches_per_row: how many patches sit side by side in each image row.

    Returns:
        float64 NumPy array of shape
        ``(rows * side, num_patches_per_row * side, 1)``.
    """
    side = patches.shape[1]  # patches are assumed to be square
    rows = patches.shape[0] // num_patches_per_row

    # View the flat patch stack as a (grid_row, grid_col, y, x) grid.
    grid = np.asarray(patches.reshape((rows, num_patches_per_row, side, side)))

    # Bring each patch's y axis next to its grid row and its x axis next to its
    # grid column, then merge the pairs: this pastes every patch into place at once.
    mosaic = grid.transpose(0, 2, 1, 3).reshape(rows * side, num_patches_per_row * side)

    # The result is always float64 with a trailing channel axis of size 1.
    return mosaic.astype(np.float64)[..., np.newaxis]

In [None]:
# View the predictions
# Cut every test image into non-overlapping 64x64 patches for the patch-based models.
patch_extractor = PatchExtractor(patch_size=64, stride=64, name="patch_extractor")
noisy_image_patches_array = patch_extractor(noisy_array)

hformer_prediction_patches = hformer_model.predict(noisy_image_patches_array)

# Every image owns 64 consecutive patches (an 8 x 8 grid). The first image is
# always rebuilt; the remaining start offsets follow the number of complete groups.
hformer_image_starts = [0] + [64 * k for k in range(1, int(hformer_prediction_patches.shape[0] / 64))]
hformer_predictions = np.stack(
    [
        reconstruct_image_from_patches(hformer_prediction_patches[start:start + 64], 8)
        for start in hformer_image_starts
    ],
    axis=0,
)
visualize_predictions(hformer_predictions, noisy_array,  len(noisy_array), hformer_predictions, "hformer")

## Model 2 : XB Model 

In [None]:
# Load the X model and its trained weights.
# Unlike the SA/Z/W models later in this notebook, x_model is not moved to the GPU here.
from torchinfo import summary
sys.path.append('../denoising-models/novel_model/')
from xb_denoiser_with_hf_conv import XModel 
x_model = XModel(num_channels=1)
# NOTE(review): torch.load unpickles the checkpoint file - only load weights from a trusted source.
x_model.load_state_dict(torch.load('../denoising-models/novel_model_weights/xb_hf_model_889.pth'))
# eval() switches the module to inference behaviour.
x_model.eval()
print('Model summary : ')
print(summary(x_model))

In [None]:
# Get prediction images.
from pytorch_wavelets import DTCWTForward, DTCWTInverse
# 3-level dual-tree complex wavelet transform and its inverse; both are also
# reused by later cells of this notebook.
dwt = DTCWTForward(J=3).cuda()
idwt = DTCWTInverse().cuda()

x_model_prediction_patches = []

for img in noisy_array:
    # (H, W, 1) -> (1, H, W, 1): add the batch axis the model is fed with.
    img_tensor = torch.from_numpy(np.expand_dims(img, axis=0)).float()

    with torch.no_grad():
        output_tensor = x_model(img_tensor)

        # transpose(1, 3) turns a channels-last (N, H, W, C) tensor into (N, C, W, H)
        # for the wavelet transform (assumes the model output is channels-last like
        # its input, which is how the result is consumed below).
        _ot = torch.transpose(output_tensor, 1, 3)

        # Apply the *same* transpose to the noisy input. The previous code expanded
        # the image to (H, W, 1, 1) and swapped axes (1, 3) then (0, 2), which yields
        # (1, 1, H, W): spatially transposed relative to `_ot`, so the noisy fine
        # detail was merged into the prediction with rows and columns exchanged.
        noisy = torch.transpose(img_tensor, 1, 3).to('cuda')

        print('ot shape : ', _ot.shape, ' noisy shape : ', noisy.shape)

        prediction_approx, prediction_high_freq = dwt(_ot.cuda())
        noisy_high_freq = dwt(noisy)[1]

        # Keep the first (fine) detail level of the noisy image and take the two
        # remaining detail levels plus the low-pass band from the prediction.
        wavelet_high_freq_swapped = [
            noisy_high_freq[0],
            prediction_high_freq[1],
            prediction_high_freq[2],
        ]

        reconstructed_prediction_image = idwt((prediction_approx, wavelet_high_freq_swapped))
        # Back to channels-last so the result lines up with `noisy_array`.
        reconstructed_prediction_image = torch.transpose(reconstructed_prediction_image, 1, 3)
        x_model_prediction_patches.append(reconstructed_prediction_image.detach().cpu().numpy())

visualize_predictions(x_model, noisy_array,  len(noisy_array), np.concatenate(x_model_prediction_patches,axis=0), "x_model")

## Model 3 : Y model

In [None]:
# Load the Y model and its trained weights (the model is not moved to the GPU here).
# `summary` is the torchinfo function imported in the X-model cell above.
sys.path.append('../denoising-models/mwcnn/')
from yb_denoiser import YModel
y_model = YModel(num_channels=4)
# NOTE(review): torch.load unpickles the checkpoint file - only load weights from a trusted source.
y_model.load_state_dict(torch.load('../denoising-models/novel_model_weights/yb_model_203.pth'))
# eval() switches the module to inference behaviour.
y_model.eval()
print('Model summary : ')
print(summary(y_model))

In [None]:
# Get prediction images.
# Relies on `dwt` / `idwt` created in the X-model inference cell.

y_model_prediction_patches = []

for img in noisy_array:
    # (H, W, 1) -> (1, H, W, 1): add the batch axis the model is fed with.
    img_tensor = torch.unsqueeze(torch.from_numpy(img).float(), dim=0)

    with torch.no_grad():
        output_tensor = y_model(img_tensor)

        # transpose(1, 3) turns a channels-last (N, H, W, C) tensor into (N, C, W, H)
        # for the wavelet transform (assumes the model output is channels-last like
        # its input, which is how the result is consumed below).
        _ot = torch.transpose(output_tensor, 1, 3)

        # Apply the *same* transpose to the noisy input. The previous code expanded
        # the image to (H, W, 1, 1) and swapped axes (1, 3) then (0, 2), which yields
        # (1, 1, H, W): spatially transposed relative to `_ot`, so the noisy fine
        # detail was merged into the prediction with rows and columns exchanged.
        noisy = torch.transpose(img_tensor, 1, 3).to('cuda')

        prediction_approx, prediction_high_freq = dwt(_ot.cuda())
        noisy_high_freq = dwt(noisy)[1]

        # Keep the first (fine) detail level of the noisy image and take the two
        # remaining detail levels plus the low-pass band from the prediction.
        wavelet_high_freq_swapped = [
            noisy_high_freq[0],
            prediction_high_freq[1],
            prediction_high_freq[2],
        ]

        reconstructed_prediction_image = idwt((prediction_approx, wavelet_high_freq_swapped))
        # Back to channels-last so the result lines up with `noisy_array`.
        reconstructed_prediction_image = torch.transpose(reconstructed_prediction_image, 1, 3)
        y_model_prediction_patches.append(reconstructed_prediction_image.detach().cpu().numpy())


visualize_predictions(y_model, noisy_array,  len(noisy_array), np.concatenate(y_model_prediction_patches,axis=0), "y_model")

## Model 4 : SA Model

In [None]:


# Load the SA model onto the GPU and restore its trained weights.
sys.path.append('../denoising-models/hformer_self_attention')
from torchinfo import summary


from sa_b_model import SABModel

sa_model = SABModel(num_channels=4).cuda()
# NOTE(review): torch.load unpickles the checkpoint file - only load weights from a trusted source.
sa_model.load_state_dict(torch.load('../denoising-models/novel_model_weights/sab_model_333.pth'))
# eval() switches the module to inference behaviour.
sa_model.eval()
# input_size presumably describes a batch of 64 patches of 64x64x1 - confirm.
print('model summary\n', summary(sa_model, input_size=(64, 64, 64, 1)))


In [None]:

# Run the SA model patch by patch, then rebuild one full image per group of 64 patches.
# Relies on `noisy_image_patches_array` (Hformer cell) and `dwt` / `idwt` (X-model cell).
sa_prediction_patches = []

with torch.no_grad():    
    for i, data in enumerate(noisy_image_patches_array):
        noisy = data
    
        # One patch per forward pass: unsqueeze adds a batch axis of size 1.
        predictions = sa_model(torch.unsqueeze(torch.from_numpy(noisy.numpy()), dim=0).to('cuda')).cpu()

        sa_prediction_patches.append(predictions.detach().cpu())
    
sa_prediction_patches = np.concatenate(sa_prediction_patches, axis=0)
# Image 0 is rebuilt here and never goes through the wavelet steps in the loop below.
sa_predictions = np.expand_dims(reconstruct_image_from_patches(sa_prediction_patches[0:64], 8), axis=0)


for i in range(1, int(sa_prediction_patches.shape[0] / 64)): 
    reconstructed_image = reconstruct_image_from_patches(sa_prediction_patches[i * 64 : i * 64 + 64], num_patches_per_row=8)
    reconstructed_image = torch.from_numpy(np.expand_dims(reconstructed_image, axis=0).astype(np.float32))
    # transpose(1, 3) swaps axes 1 and 3: (1, H, W, 1) -> (1, 1, W, H).
    _ot = torch.transpose(reconstructed_image, 1, 3)
    # NOTE(review): the two swaps below turn (H, W, 1, 1) into (1, 1, H, W), i.e. the
    # noisy tensor is not spatially transposed the way `_ot` is - confirm.
    noisy = torch.from_numpy(np.expand_dims(noisy_array[i], axis=-1).astype(np.float32)).to('cuda')
    noisy = torch.transpose(noisy, 1, 3)
    noisy = torch.transpose(noisy, 0, 2)

    print('ot shape : ', _ot.shape, ' noisy shape : ', noisy.shape)


    prediction_approx, prediction_high_freq = dwt(_ot.cuda())
    prediction_high_freq_low, prediction_high_freq_mid, prediction_high_freq_coarse = prediction_high_freq[0], prediction_high_freq[1], prediction_high_freq[2]

    noisy_approx, noisy_high_freq = dwt(noisy)

    noisy_high_freq_fine, noisy_high_freq_mid, noisy_high_freq_coarse = noisy_high_freq[0], noisy_high_freq[1] , noisy_high_freq[2]

    # This round-trip of the noisy image is computed but not used afterwards.
    reconstructed_prediction_image_with_high_freq_swap = idwt((noisy_approx, noisy_high_freq))
    reconstructed_prediction_image_with_high_freq_swap = torch.transpose(reconstructed_prediction_image_with_high_freq_swap, 3, 1)


    #denoise_high_freq = torch.transpose(torch.from_numpy(denoised_high_freq), 1, 3)
    #denoised_noisy_approx, denoised_high_freq = dwt(denoise_high_freq.cuda())
    #noisy_high_freq_fine, noisy_high_freq_mid, noisy_high_freq_coarse = denoised_high_freq[0], denoised_high_freq[1] , denoised_high_freq[2]

    # Detail level 0 comes from the noisy image, levels 1 and 2 from the prediction.
    wavelet_high_freq_swapped = [None] * 3
    wavelet_high_freq_swapped[0] =noisy_high_freq_fine
    wavelet_high_freq_swapped[1] =prediction_high_freq_mid
    wavelet_high_freq_swapped[2] =prediction_high_freq_coarse

    reconstructed_prediction_image = idwt((prediction_approx,wavelet_high_freq_swapped))
    reconstructed_prediction_image = torch.transpose(reconstructed_prediction_image, 1, 3)


    # NOTE(review): the plain reconstruction (`reconstructed_image`) is appended here,
    # not the wavelet-merged `reconstructed_prediction_image`, so the wavelet work
    # above does not affect `sa_predictions` - confirm which one is intended.
    sa_predictions= np.append(sa_predictions, reconstructed_image, axis=0)



visualize_predictions(sa_model, noisy_array,  len(noisy_array), sa_predictions, "sa model")


## Model 5 : Z Model

In [None]:

# Load the Z model onto the GPU and restore its trained weights.
sys.path.append('../denoising-models/hformer_pytorch')
from torchinfo import summary

from zb_model import ZModel 

z_model = ZModel(num_channels=4).cuda()
# NOTE(review): torch.load unpickles the checkpoint file - only load weights from a trusted source.
z_model.load_state_dict(torch.load('../denoising-models/novel_model_weights/zb_model_335.pth'))
# eval() switches the module to inference behaviour.
z_model.eval()
# input_size presumably describes a batch of 64 patches of 64x64x1 - confirm.
print('model summary\n', summary(z_model, input_size=(64, 64, 64, 1)))

In [None]:

# Run the Z model patch by patch, then rebuild one full image per group of 64 patches.
# Relies on `noisy_image_patches_array` (Hformer cell) and `dwt` / `idwt` (X-model cell).
z_prediction_patches = []

with torch.no_grad():    
    for i, data in enumerate(noisy_image_patches_array):
        noisy = data
    
        # One patch per forward pass: unsqueeze adds a batch axis of size 1.
        predictions = z_model(torch.unsqueeze(torch.from_numpy(noisy.numpy()), dim=0).float().to('cuda')).cpu()

        z_prediction_patches.append(predictions.detach().cpu())
    
z_prediction_patches = np.concatenate(z_prediction_patches, axis=0)
# Image 0 is rebuilt here and never goes through the wavelet steps in the loop below.
z_predictions = np.expand_dims(reconstruct_image_from_patches(z_prediction_patches[0:64], 8), axis=0)


for i in range(1, int(z_prediction_patches.shape[0] / 64)): 
    reconstructed_image = reconstruct_image_from_patches(z_prediction_patches[i * 64 : i * 64 + 64], num_patches_per_row=8)
    reconstructed_image = torch.from_numpy(np.expand_dims(reconstructed_image, axis=0).astype(np.float32))
    
    # transpose(1, 3) swaps axes 1 and 3: (1, H, W, 1) -> (1, 1, W, H).
    _ot = torch.transpose(reconstructed_image, 1, 3)

    # NOTE(review): the two swaps below turn (H, W, 1, 1) into (1, 1, H, W), i.e. the
    # noisy tensor is not spatially transposed the way `_ot` is - confirm.
    noisy = torch.from_numpy(np.expand_dims(noisy_array[i], axis=-1).astype(np.float32)).to('cuda')
    noisy = torch.transpose(noisy, 1, 3)
    noisy = torch.transpose(noisy, 0, 2)

    print('ot shape : ', _ot.shape, ' noisy shape : ', noisy.shape)


    prediction_approx, prediction_high_freq = dwt(_ot.cuda())
    prediction_high_freq_low, prediction_high_freq_mid, prediction_high_freq_coarse = prediction_high_freq[0], prediction_high_freq[1], prediction_high_freq[2]

    noisy_approx, noisy_high_freq = dwt(noisy)

    noisy_high_freq_fine, noisy_high_freq_mid, noisy_high_freq_coarse = noisy_high_freq[0], noisy_high_freq[1] , noisy_high_freq[2]

    # This round-trip of the noisy image is computed but not used afterwards.
    reconstructed_prediction_image_with_high_freq_swap = idwt((noisy_approx, noisy_high_freq))
    reconstructed_prediction_image_with_high_freq_swap = torch.transpose(reconstructed_prediction_image_with_high_freq_swap, 3, 1)


    #denoise_high_freq = torch.transpose(torch.from_numpy(denoised_high_freq), 1, 3)
    #denoised_noisy_approx, denoised_high_freq = dwt(denoise_high_freq.cuda())
    #noisy_high_freq_fine, noisy_high_freq_mid, noisy_high_freq_coarse = denoised_high_freq[0], denoised_high_freq[1] , denoised_high_freq[2]

    # Detail level 0 comes from the noisy image, levels 1 and 2 from the prediction.
    wavelet_high_freq_swapped = [None] * 3
    wavelet_high_freq_swapped[0] =noisy_high_freq_fine
    wavelet_high_freq_swapped[1] =prediction_high_freq_mid
    wavelet_high_freq_swapped[2] =prediction_high_freq_coarse

    reconstructed_prediction_image = idwt((prediction_approx,wavelet_high_freq_swapped))
    reconstructed_prediction_image = torch.transpose(reconstructed_prediction_image, 1, 3)


    # NOTE(review): the plain reconstruction (`reconstructed_image`) is appended here,
    # not the wavelet-merged `reconstructed_prediction_image`, so the wavelet work
    # above does not affect `z_predictions` - confirm which one is intended.
    z_predictions= np.append(z_predictions, reconstructed_image, axis=0)
visualize_predictions(z_model, noisy_array,  len(noisy_array), z_predictions, "z model")


## Model 6 : W Model

In [None]:
# Load the W model onto the GPU and restore its trained weights.
# The same directory was already appended to sys.path in the Z-model cell.
sys.path.append('../denoising-models/hformer_pytorch')
from torchinfo import summary

from w_b_model import WModel 

w_model = WModel(num_channels=8).cuda()
# NOTE(review): torch.load unpickles the checkpoint file - only load weights from a trusted source.
w_model.load_state_dict(torch.load('../denoising-models/novel_model_weights/wb_model_662.pth'))
# eval() switches the module to inference behaviour.
w_model.eval()
# input_size presumably describes a batch of 64 patches of 64x64x1 - confirm.
print('model summary\n', summary(w_model, input_size=(64, 64, 64, 1)))

In [None]:

# Run the W model one patch at a time and gather the outputs on the CPU.
w_prediction_patches = []

with torch.no_grad():
    for data in noisy_image_patches_array:
        # unsqueeze adds the batch axis of size 1 expected by the model.
        patch_batch = torch.unsqueeze(torch.from_numpy(data.numpy()), dim=0).to('cuda')
        predictions = w_model(patch_batch).cpu()
        w_prediction_patches.append(predictions.detach().cpu())

w_prediction_patches = np.concatenate(w_prediction_patches, axis=0)

# Every image owns 64 consecutive patches (an 8 x 8 grid). The first image is
# always rebuilt; the remaining start offsets follow the number of complete groups.
w_image_starts = [0] + [64 * k for k in range(1, int(w_prediction_patches.shape[0] / 64))]
w_predictions = np.stack(
    [
        reconstruct_image_from_patches(w_prediction_patches[start:start + 64], 8)
        for start in w_image_starts
    ],
    axis=0,
)

visualize_predictions(w_model, noisy_array,  len(noisy_array), w_predictions, "w model")


In [None]:
# Wavelet with wmodel

# W model with wavelet
from pytorch_wavelets import DTCWTForward, DTCWTInverse
dwt = DTCWTForward(J=3).cuda()
idwt = DTCWTInverse().cuda()
wavelet_w_model_prediction_patches =[]


with torch.no_grad():
    # Process the patches of one image (64 patches) per forward pass.
    for i in range(noisy_image_patches_array.shape[0] // 64):
        noisy = noisy_image_patches_array[i * 64 : i * 64 + 64]
        noisy = torch.from_numpy(noisy.cpu().numpy()).to('cuda')
        prediction = w_model(noisy)

        # Prediction and noisy input go through the same axis swap, so their
        # wavelet coefficients line up spatially.
        prediction_img = torch.transpose(prediction, 1, 3)
        transposed_noisy_image = torch.transpose(noisy, 1, 3)

        prediction_approx, prediction_high_freq = dwt(prediction_img.cuda())
        noisy_high_freq = dwt(transposed_noisy_image.cuda())[1]

        # Keep the first (fine) detail level of the noisy patches and take the two
        # remaining detail levels plus the low-pass band from the prediction.
        wavelet_high_freq_swapped = [
            noisy_high_freq[0],
            prediction_high_freq[1],
            prediction_high_freq[2],
        ]

        reconstructed_prediction_image = idwt((prediction_approx, wavelet_high_freq_swapped))
        reconstructed_prediction_image = torch.transpose(reconstructed_prediction_image, 1, 3)
        wavelet_w_model_prediction_patches.append(reconstructed_prediction_image.detach().cpu().numpy())

wavelet_w_predictions = [None] * noisy_array.shape[0]
def reconstruct(patches,  num_images):
    """Rebuild ``num_images`` full images from ``patches`` into ``wavelet_w_predictions``.

    The patches are split into ``num_images`` equal consecutive groups; each group
    is tiled 8 patches per row and stored at the matching index of the
    module-level list ``wavelet_w_predictions``.
    """
    num_patches_per_image = patches.shape[0] // num_images
    for image_index in range(num_images):
        first = image_index * num_patches_per_image
        image_patches = patches[first:first + num_patches_per_image]
        wavelet_w_predictions[image_index] = reconstruct_image_from_patches(image_patches, 8)
wavelet_w_model_prediction_patches = np.concatenate(wavelet_w_model_prediction_patches)
# Use the real number of test images. The previous hard-coded 28 disagreed with
# the result list above (sized from noisy_array) whenever the image count differed.
reconstruct(wavelet_w_model_prediction_patches, noisy_array.shape[0])

In [None]:

# Show every wavelet-merged W-model reconstruction next to its reference image with per-image metrics.
visualize_predictions(w_model, noisy_array,  len(noisy_array), wavelet_w_predictions, "wavelet w model")

## Part V : Side by side comparison of all models

In [None]:
# Build the comparison table and compute the averaged metrics of every model.
from prettytable import PrettyTable

pt = PrettyTable()
pt.field_names = ["Model", "PSNR", "SSIM", "MSE"]

# Each call prints its averages and returns (psnr, ssim, mse) rounded to 4 decimals.
hformer_metrics = get_average_metrics(hformer_predictions,  noisy_array)
# The X and Y predictions are stored as one array per image, so they are joined first.
x_metrics= get_average_metrics(np.concatenate(x_model_prediction_patches, axis=0), noisy_array)
y_metrics= get_average_metrics(np.concatenate(y_model_prediction_patches, axis=0),  noisy_array)
sa_metrics = get_average_metrics(sa_predictions,  noisy_array)
z_metrics = get_average_metrics(z_predictions,  noisy_array)
w_metrics = get_average_metrics(w_predictions,  noisy_array)
wavelet_w_metrics = get_average_metrics(wavelet_w_predictions, noisy_array)





In [None]:

# Fill the comparison table: a placeholder row first, then one row per model.
pt.add_row(["Original X-y pairs (No Model)", '-', '-', "-"])

for model_label, model_metrics in [
    ("Hformer", hformer_metrics),
    ("X Model", x_metrics),
    ("Y Model", y_metrics),
    ("SA Model", sa_metrics),
    ("Z Model", z_metrics),
    ("W Model", w_metrics),
    ("W Wavelet Model", wavelet_w_metrics),
]:
    # Each metrics tuple is (psnr, ssim, mse).
    pt.add_row([
        model_label,
        str(model_metrics[0]),
        str(model_metrics[1]),
        str(round(model_metrics[2], 4)),
    ])

print(pt)

## Part VI : Predictions of all models side by side for direct visual comparison

In [None]:
def visualize_predictions_all_models(X_test,  n, hformer_predictions, x_predictions, y_predictions, sa_predictions, z_predictions, w_predictions, wavelet_w_predictions):
    """Plot the reference image and every model's prediction in one row per image.

    For each of the first ``n`` images, one figure is shown and also saved as a PNG
    under ``../../output/lodopab_ct_wavelet_b_models/``. Each prediction panel is
    titled with its PSNR, SSIM and MSE against the reference image.

    Args:
        X_test: indexable collection of reference images.
        n: number of leading images to plot.
        hformer_predictions, x_predictions, y_predictions, sa_predictions,
        z_predictions, w_predictions, wavelet_w_predictions: indexable collections
            of predicted images, each aligned with ``X_test``.
    """
    model_names = ["HFORMER", "X", "Y", "SA",  "Z", "W", "Wavelet W"]
    per_model_predictions = [
        hformer_predictions,
        x_predictions,
        y_predictions,
        sa_predictions,
        z_predictions,
        w_predictions,
        wavelet_w_predictions,
    ]

    # Create the output folder up front so savefig cannot fail on a missing
    # directory after all the figures have been computed.
    output_dir = '../../output/lodopab_ct_wavelet_b_models'
    os.makedirs(output_dir, exist_ok=True)

    for i in range(n):
        # Denormalise/window the reference once per image instead of once per metric.
        gt_display = trunc(denormalize(X_test[i]))

        # Display QD and FD images
        f, axarr = plt.subplots(1, 1 + len(model_names), figsize=(41,41))

        axarr[0].imshow(gt_display, cmap='gray', vmin=-160.0, vmax=240.0)
        axarr[0].set_title("FD Image")
        axarr[0].set_axis_off()

        for j, (model_name, model_predictions) in enumerate(zip(model_names, per_model_predictions), start=1):
            predicted_image = model_predictions[i]
            if predicted_image.shape[-1] == 3:
                predicted_image = rgb2gray(predicted_image)

            predicted_display = trunc(denormalize(predicted_image))

            psnr_recon = round(calculate_psnr(gt_display, predicted_display), 4)
            ssim_recon = round(calculate_ssim(gt_display, predicted_display), 4)
            # calculate_rmse returns skimage's mean_squared_error (no square root).
            mse_recon = round(calculate_rmse(gt_display, predicted_display), 4)

            axarr[j].imshow(predicted_display, cmap='gray', vmin=-160.0, vmax=240.0)
            axarr[j].set_title("{}\nPSNR={}\nSSIM={}\nMSE={}".format(model_name, psnr_recon, ssim_recon, mse_recon))
            axarr[j].set_axis_off()

        plt.savefig(output_dir + '/combined_outputs_image_index_{}.png'.format(i))
        plt.show()
        # Release the figure explicitly so long runs do not accumulate open figures.
        plt.close(f)

In [None]:

# Render and save the side-by-side comparison for every test image.
# The X and Y predictions are stored as one array per image, so they are joined first.
visualize_predictions_all_models(noisy_array,  len(noisy_array), hformer_predictions, np.concatenate(x_model_prediction_patches, axis=0), np.concatenate(y_model_prediction_patches, axis=0), sa_predictions, z_predictions, w_predictions, wavelet_w_predictions)