## Part I : Global includes



In [None]:

%matplotlib inline
import warnings
warnings.filterwarnings("ignore")

import os
import pathlib
from matplotlib import pyplot as plt
import torch
import numpy as np
import cv2
import tensorflow as tf
from Modified_EMD2D import EMD2D
# NOTE(review): a notebook-level `class EMD2D` is defined in a later cell and `emd2d` is
# re-created from it there, so this imported class/instance is shadowed once that cell has run.
emd2d = EMD2D()

# Hide every GPU from TensorFlow (comment out the call below to let TensorFlow use the GPU).
# This is required for some models like Pridnet which have too many trainable parameters
tf.config.set_visible_devices([], 'GPU')

from tqdm.notebook import tqdm
import random

import data_importer

In [None]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Create a TF1-compat interactive session with `gpu_options.allow_growth` switched on.
# NOTE(review): the first cell already hides all GPUs from TensorFlow via
# tf.config.set_visible_devices([], 'GPU'), so this GPU option probably has no effect —
# confirm whether this cell is still needed.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

## Part II : Loading test images

In [None]:
#from sklearn.model_selection import train_test_split
from data_importer import load_training_images

# Load noisy / ground-truth pairs from the training folder, limited to `num_images_to_load=1`.
# NOTE(review): the exact meaning of `load_limited_images` / `num_images_to_load` lives in
# data_importer and is not visible here — confirm how many images this actually yields.
noisy_array, gt_array = load_training_images('../../../../Dataset/LowDoseCTGrandChallenge/Training_Image_Data/', load_limited_images=True, num_images_to_load=1)

# Load the hand-picked pairs; `load_limited_images=False`, so `num_images_to_load` is presumably ignored here.
_n, _g = load_training_images('../../../../Dataset/LowDoseCTGrandChallenge/Selected_Image_Pairs/', load_limited_images=False, num_images_to_load=1)

# Append the selected pairs after the training images along the image axis (axis 0).
noisy_array = np.concatenate((noisy_array, _n), axis=0)
gt_array = np.concatenate((gt_array, _g), axis=0)

# "Extended" evaluation set used for the averaged metrics.
# NOTE(review): this call uses the same folder and the same arguments as the first load above,
# so it presumably returns the same images — confirm this is intended.
extended_noisy_array, extended_gt_array = load_training_images('../../../../Dataset/LowDoseCTGrandChallenge/Training_Image_Data/', load_limited_images=True, num_images_to_load=1)

Visualization of the noisy / ground truth image pair

In [None]:
from data_importer import denormalize, trunc

# Show up to three noisy / ground-truth pairs side by side.
# The upper bound is clamped to the number of loaded images so a small dataset
# does not raise an IndexError.
for i in range(min(3, len(noisy_array))):
    f, axarr = plt.subplots(1, 2, figsize=(14, 14))
    # Both panels use the same fixed display window (vmin=-160, vmax=240).
    axarr[0].imshow(trunc(denormalize(noisy_array[i])), vmin=-160.0, vmax=240.0, cmap='gray')
    axarr[0].set_title("Noisy image (QD)")
    axarr[1].imshow(trunc(denormalize(gt_array[i])), vmin=-160.0, vmax=240.0, cmap='gray')
    axarr[1].set_title("Ground Truth image (FD)")
    plt.show()

## Part III : Setup for Inference

In [None]:
# Inference
def inference_single_image(model, noisy_image):
    """Run `model` on a single image and return its prediction.

    Args:
        model: object exposing `predict(batch)`.
        noisy_image: one image without a batch axis.

    Returns:
        The first (and only) element of the predicted batch.
    """
    # `predict` expects a batch, so add a leading batch axis of size 1.
    input_image = np.expand_dims(noisy_image, axis=0)
    predicted_image = model.predict(input_image)
    return predicted_image[0]

def inference_batch_images(model, noisy_images):
    """Predict a whole batch with `model` and return the result as float64."""
    raw_predictions = model.predict(noisy_images)
    return raw_predictions.astype(np.float64)

In [None]:
def rgb2gray(rgb):
    """Collapse the first three channels of `rgb` into one weighted grayscale channel.

    The result keeps a trailing channel axis of size 1.
    """
    channel_weights = [0.2989, 0.5870, 0.1140]
    weighted_sum = np.dot(rgb[..., :3], channel_weights)
    return np.expand_dims(weighted_sum, axis=-1)

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio
import sys
sys.path.append('../')

from metrics import compute_SSIM, compute_PSNR
from skimage.metrics import mean_squared_error  as mse

def calculate_psnr(original_image, reconstructed_image,range=400):
    """Return the PSNR between two images via skimage's `peak_signal_noise_ratio`.

    Args:
        original_image: reference image.
        reconstructed_image: image compared against the reference.
        range: value forwarded as `data_range`; the default 400 matches the
            -160..240 display window used throughout this notebook.
            (The name shadows the builtin but is kept for keyword callers.)
    """
    return peak_signal_noise_ratio(original_image, reconstructed_image, data_range=range)

def calculate_ssim(original_image, reconstructed_image, range=400.0):
    """Return the SSIM between two images after casting both to int16.

    Uses an 11-pixel window, treats axis 2 as the channel axis, and forwards
    `range` as `data_range`.
    """
    reference = original_image.astype(np.int16)
    candidate = reconstructed_image.astype(np.int16)
    return ssim(reference, candidate, win_size=11, channel_axis=2, data_range=range)

def calculate_rmse(original_image, reconstructed_image):
    """Return skimage's `mean_squared_error` between the two images.

    NOTE(review): despite the name, no square root is taken — the value is the
    MSE, not the RMSE. Callers label it inconsistently ("RMSE" in some titles,
    "MSE" in the printed averages); confirm which is intended before renaming
    or adding the square root.
    """
    return mse(original_image, reconstructed_image)

def visualize_predictions(model, X_test, y_test, n, predictions, model_name):
    """Plot noisy / ground-truth / predicted triples with their metrics.

    Args:
        model: unused; kept so existing calls keep working.
        X_test: noisy images, indexed by image.
        y_test: ground-truth images, indexed by image.
        n: number of leading images to show.
        predictions: predicted images aligned with `X_test`.
        model_name: label used in the title of the prediction panel.
    """
    for i in range(n):
        # float32 instead of the former float16: half precision visibly quantises
        # the image values and made these per-image metrics disagree with
        # get_average_metrics, which works on the arrays' own dtype.
        noisy_image = X_test[i].astype(np.float32)
        gt_image = y_test[i].astype(np.float32)
        predicted_image = predictions[i].astype(np.float32)

        # Three-channel predictions are reduced to one channel first.
        if predicted_image.shape[-1] == 3:
            predicted_image = rgb2gray(predicted_image)

        # Convert each image once and reuse it for every metric and panel
        # (assumes trunc/denormalize are deterministic — they were previously
        # re-evaluated with the same inputs for each use).
        noisy_view = trunc(denormalize(noisy_image))
        gt_view = trunc(denormalize(gt_image))
        predicted_view = trunc(denormalize(predicted_image))

        psnr_recon = round(calculate_psnr(gt_view, predicted_view), 4)
        psnr_qd = round(calculate_psnr(gt_view, noisy_view), 4)
        ssim_recon = round(calculate_ssim(gt_view, predicted_view), 4)
        ssim_qd = round(calculate_ssim(gt_view, noisy_view), 4)
        # calculate_rmse returns the mean squared error (no square root),
        # so the panels are labelled "MSE".
        mse_recon = round(calculate_rmse(gt_view, predicted_view), 4)
        mse_qd = round(calculate_rmse(gt_view, noisy_view), 4)

        f, axarr = plt.subplots(1, 3, figsize=(21, 21))

        axarr[0].imshow(noisy_view, cmap='gray', vmin=-160.0, vmax=240.0)
        axarr[0].set_title("QD Image : PSNR={}\nSSIM={}\nMSE={}".format(psnr_qd, ssim_qd, mse_qd))
        axarr[0].set_axis_off()
        axarr[1].imshow(gt_view, cmap='gray', vmin=-160.0, vmax=240.0)
        axarr[1].set_title("FD Image")
        axarr[1].set_axis_off()
        axarr[2].imshow(predicted_view, cmap='gray', vmin=-160.0, vmax=240.0)
        axarr[2].set_title("{} Predicted Image : PSNR={}\nSSIM={}\nMSE={}".format(model_name, psnr_recon, ssim_recon, mse_recon))
        axarr[2].set_axis_off()

        plt.show()
        # Release the figure so long loops do not accumulate open figures.
        plt.close(f)

In [None]:
from skimage.metrics import peak_signal_noise_ratio

def get_average_metrics(predicted_images, _gt_array, _noisy_array):
    """Average PSNR / SSIM / MSE over a set of images and print the results.

    Args:
        predicted_images: model outputs, aligned with the other two arrays.
        _gt_array: ground-truth images; when None the notebook-level
            `gt_array` is used.
        _noisy_array: noisy images; when None the notebook-level
            `noisy_array` is used.

    Returns:
        Tuple of six values rounded to 4 decimals:
        (prediction PSNR, prediction SSIM, prediction MSE,
         PSNR gain, SSIM gain, MSE change), where each gain/change is
        "prediction mean - noisy mean".

    Raises:
        ZeroDivisionError: if there is no image triple to average over.
    """
    # Explicit None checks replace the former `np.all(_gt_array) != None` test,
    # which never selected the fallback and left the local arrays unbound.
    gt_images = gt_array if _gt_array is None else _gt_array
    noisy_images = noisy_array if _noisy_array is None else _noisy_array

    psnr_original_sum = 0
    psnr_prediction_sum = 0
    ssim_original_sum = 0
    ssim_prediction_sum = 0
    mse_original_sum = 0
    mse_prediction_sum = 0

    num_images = 0
    for gt_img, noisy_img, predicted_img in zip(gt_images, noisy_images, predicted_images):
        # Three-channel predictions are reduced to one channel first.
        if predicted_img.shape[-1] == 3:
            predicted_img = rgb2gray(predicted_img)

        # Convert each image once and reuse it for all three metrics.
        gt_view = trunc(denormalize(gt_img))
        noisy_view = trunc(denormalize(noisy_img))
        predicted_view = trunc(denormalize(predicted_img))

        psnr_original_sum += calculate_psnr(gt_view, noisy_view)
        psnr_prediction_sum += calculate_psnr(gt_view, predicted_view)

        ssim_original_sum += calculate_ssim(gt_view, noisy_view)
        ssim_prediction_sum += calculate_ssim(gt_view, predicted_view)

        # calculate_rmse returns the mean squared error (no square root).
        mse_original_sum += calculate_rmse(gt_view, noisy_view)
        mse_prediction_sum += calculate_rmse(gt_view, predicted_view)

        num_images += 1

    # Divide by the number of triples actually processed; zip() stops at the
    # shortest input, so this can be smaller than the ground-truth array length.
    psnr_original_mean = psnr_original_sum / num_images
    psnr_prediction_mean = psnr_prediction_sum / num_images

    ssim_original_mean = ssim_original_sum / num_images
    ssim_prediction_mean = ssim_prediction_sum / num_images

    mse_original_mean = mse_original_sum / num_images
    mse_prediction_mean = mse_prediction_sum / num_images

    print("Original average gt-noisy PSNR ->", psnr_original_mean)
    print("Predicted average gt-predicted PSNR ->", psnr_prediction_mean)

    print("Original average gt-noisy SSIM ->", ssim_original_mean)
    print("Predicted average gt-predicted SSIM ->", ssim_prediction_mean)

    print("Original average gt-noisy MSE->", mse_original_mean)
    print("Predicted average gt-predicted MSE->", mse_prediction_mean)

    return (
        round(psnr_prediction_mean, 4),
        round(ssim_prediction_mean, 4),
        round(mse_prediction_mean, 4),
        round(psnr_prediction_mean - psnr_original_mean, 4),
        round(ssim_prediction_mean - ssim_original_mean, 4),
        round(mse_prediction_mean - mse_original_mean, 4),
    )


## Part IV : Evaluation of each model

## Model 1 : Hformer (for base reference)

In [None]:
# Make the Hformer sources importable from this notebook.
sys.path.append('../denoising-models/hformer_vit/model/')
sys.path.append('../denoising-models/hformer_vit/')
from hformer_model_extended import get_hformer_model, PatchExtractor

# Build the 64-channel Hformer for 64x64 single-channel patches and load its trained weights.
hformer_model = get_hformer_model(num_channels_to_be_generated=64, name="hformer_model_extended")
hformer_model.build(input_shape=(None, 64, 64, 1))
hformer_model.load_weights('../denoising-models/hformer_vit/test/experiments/full_dataset/hformer_64_channel_custom_loss_epochs_48.h5')
print('Model summary : ')
# summary() prints the table itself and returns None, so wrapping it in print()
# only added a stray "None" line to the output.
hformer_model.summary()

In [None]:
def reconstruct_image_from_patches(patches, num_patches_per_row):
    """Tile square patches, given in row-major order, back into one image.

    Args:
        patches: array whose first axis indexes patches and whose second axis
            is the (square) patch size.
        num_patches_per_row: number of patches laid side by side per row.

    Returns:
        float64 array of shape (rows * patch, cols * patch, 1).
    """
    side = patches.shape[1]  # patches are assumed to be square
    rows = patches.shape[0] // num_patches_per_row

    # (rows, cols, side, side): one 2-D tile per grid position.
    tile_grid = patches.reshape((rows, num_patches_per_row, side, side))

    # Bring each tile's row axis next to the grid's row axis, then merge the
    # pairs of axes so every tile lands at its own (row, col) offset.
    stitched = tile_grid.transpose(0, 2, 1, 3).reshape(rows * side, num_patches_per_row * side)

    return stitched.astype(np.float64)[..., np.newaxis]

In [None]:
# Split every noisy image into 64x64 patches; stride 64 equals the patch size,
# so the patches do not overlap.
# NOTE(review): later cells regroup the patches 64 at a time with 8 per row, which
# presumes 512x512 inputs — confirm against the dataset.
patch_extractor = PatchExtractor(patch_size=64, stride=64, name="patch_extractor")
noisy_image_patches_array = patch_extractor(noisy_array)
extended_noisy_image_patches_array = patch_extractor(extended_noisy_array)


In [None]:

def _stitch_hformer_patches(prediction_patches, patches_per_image=64, patches_per_row=8):
    """Rebuild full images from consecutive groups of predicted patches.

    Returns an array of shape (num_images, height, width, 1). Building the list
    first and stacking once avoids the repeated full-array copies of np.append.
    """
    num_images = prediction_patches.shape[0] // patches_per_image
    return np.stack(
        [
            reconstruct_image_from_patches(
                prediction_patches[k * patches_per_image:(k + 1) * patches_per_image],
                num_patches_per_row=patches_per_row,
            )
            for k in range(num_images)
        ],
        axis=0,
    )


# Denoise every patch of both evaluation sets with the Hformer model.
hformer_prediction_patches = hformer_model.predict(noisy_image_patches_array)
extended_hformer_prediction_patches = hformer_model.predict(extended_noisy_image_patches_array)

# Every 64 consecutive patches (8 per row) form one image.
hformer_predictions = _stitch_hformer_patches(hformer_prediction_patches)
extended_hformer_predictions = _stitch_hformer_patches(extended_hformer_prediction_patches)

# The first argument is the (unused) model slot; pass the model rather than the predictions.
visualize_predictions(hformer_model, noisy_array, gt_array, len(gt_array), hformer_predictions, "hformer")

## Model 2 : WB Model

In [None]:

# Make the PyTorch model sources importable from this notebook.
sys.path.append('../denoising-models/hformer_pytorch')
from torchinfo import summary

from w_b_model import WModel 

# Instantiate the WB model on the GPU; `.cuda()` requires a CUDA-capable device.
w_model = WModel(num_channels=8).cuda()
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
w_model.load_state_dict(torch.load('../denoising-models/novel_model_weights/wb_model_662.pth'))
# Switch to evaluation mode before running inference.
w_model.eval()
# NOTE(review): input_size=(64, 64, 64, 1) presumably means a batch of 64 channel-last
# 64x64 patches — confirm against WModel's expected layout.
print('model summary\n', summary(w_model, input_size=(64, 64, 64, 1)))

In [None]:
import numpy as np

from scipy.interpolate import SmoothBivariateSpline as SBS
from scipy.interpolate import LSQBivariateSpline as LBS
from scipy.ndimage.filters import maximum_filter
from scipy.ndimage.morphology import binary_erosion, generate_binary_structure


class EMD2D:
    """Empirical Mode Decomposition for 2-D images.

    An image is split into a stack of intrinsic mode functions (IMFs) by
    repeatedly subtracting the mean of its upper and lower spline envelopes
    ("sifting"). `emd()` is the entry point; calling the instance is equivalent.

    Attributes:
        mse_thr: threshold on the mean square of a proto-IMF (see check_proto_imf).
        mean_thr: threshold on envelope flatness / mean absolute value (see check_proto_imf).
        FIXE: when non-zero, sifting stops after that many iterations.
        FIXE_H: when non-zero (and FIXE is zero), sifting stops once
            check_proto_imf has passed that many times in a row.
        MAX_ITERATION: hard cap on sifting iterations per IMF.
    """

    def __init__(self, **config):
        """Set the default thresholds, then apply any overrides from `config`.

        Keyword arguments naming an existing attribute (e.g. `MAX_ITERATION=50`)
        replace its default; unknown keys are ignored.
        """
        self.mse_thr = 0.1
        self.mean_thr = 0.1

        self.FIXE = 0
        self.FIXE_H = 0

        self.MAX_ITERATION = 10000

        # Previously `config` was accepted but silently dropped.
        for key, value in config.items():
            if key in self.__dict__:
                self.__dict__[key] = value

    def __call__(self, image, max_imf=-1):
        """Shorthand for `emd(image, max_imf=max_imf)`."""
        return self.emd(image, max_imf=max_imf)

    def extract_max_min_spline(self, image):
        """Return the (lower, upper) spline envelopes of `image`.

        Extrema are located on a 3x3 mirrored copy of the image, and the fitted
        splines are evaluated only on the central tile, i.e. the original image.
        """
        big_image = self.prepare_image(image)
        big_min_peaks, big_max_peaks = self.find_extrema(big_image)

        # Prepare grid for interpolation: the coordinates of the central tile.
        xi = np.arange(image.shape[0], image.shape[0] * 2)
        yi = np.arange(image.shape[1], image.shape[1] * 2)

        big_min_image_val = big_image[big_min_peaks]
        big_max_image_val = big_image[big_max_peaks]
        min_env = self.spline_points(big_min_peaks[0], big_min_peaks[1], big_min_image_val, xi, yi)
        max_env = self.spline_points(big_max_peaks[0], big_max_peaks[1], big_max_image_val, xi, yi)

        return min_env, max_env

    @classmethod
    def prepare_image(cls, image):
        """Return a 3x3 tiling of `image` whose outer tiles are mirrored copies.

        The original sits in the centre; side tiles are flipped left-right or
        up-down, and corner tiles are flipped both ways.
        """
        shape = image.shape
        big_image = np.zeros((shape[0] * 3, shape[1] * 3))

        image_lr = np.fliplr(image)
        image_ud = np.flipud(image)
        image_ud_lr = np.flipud(image_lr)
        image_lr_ud = np.fliplr(image_ud)

        # Fill center with default image
        big_image[shape[0] : 2 * shape[0], shape[1] : 2 * shape[1]] = image

        # Fill left center
        big_image[shape[0] : 2 * shape[0], : shape[1]] = image_lr

        # Fill right center
        big_image[shape[0] : 2 * shape[0], 2 * shape[1] :] = image_lr

        # Fill center top
        big_image[: shape[0], shape[1] : shape[1] * 2] = image_ud

        # Fill center bottom
        big_image[2 * shape[0] :, shape[1] : 2 * shape[1]] = image_ud

        # Fill left top
        big_image[: shape[0], : shape[1]] = image_ud_lr

        # Fill left bottom
        big_image[2 * shape[0] :, : shape[1]] = image_ud_lr

        # Fill right top
        big_image[: shape[0], 2 * shape[1] :] = image_lr_ud

        # Fill right bottom
        big_image[2 * shape[0] :, 2 * shape[1] :] = image_lr_ud

        return big_image

    @classmethod
    def spline_points(cls, X, Y, Z, xi, yi):
        """Fit a smoothing bivariate spline to the points (X, Y, Z) and evaluate it at `xi`, `yi`."""
        spline = SBS(X, Y, Z)

        return spline(xi, yi)

    @classmethod
    def find_extrema(cls, image):
        """Return index tuples of the local minima and local maxima of `image`.

        Each result has the `np.nonzero` form (row indices, column indices).
        Pixels on the outermost rows and columns are never reported.
        """
        # define an 3x3 neighborhood
        neighborhood = generate_binary_structure(2, 2)

        # apply the local maximum filter; all pixel of maximal value
        # in their neighborhood are set to 1
        local_min = maximum_filter(-image, footprint=neighborhood) == -image
        local_max = maximum_filter(image, footprint=neighborhood) == image

        # can't distinguish between background zero and filter zero
        background = image == 0

        # appear along the bg border (artifact of the local max filter)
        eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)

        # we obtain the final mask, containing only peaks,
        # by removing the background from the local_max mask (xor operation)
        min_peaks = local_min ^ eroded_background
        max_peaks = local_max ^ eroded_background

        # For the borders.
        min_peaks[[0, -1], :] = False
        min_peaks[:, [0, -1]] = False
        max_peaks[[0, -1], :] = False
        max_peaks[:, [0, -1]] = False


        # False is interpreted as zero...
        min_peaks = np.nonzero(min_peaks)
        max_peaks = np.nonzero(max_peaks)

        return min_peaks, max_peaks

    @classmethod
    def end_condition(cls, image, IMFs):
        """Return True when the sum of `IMFs` already reproduces `image`."""
        rec = np.sum(IMFs, axis=0)

        # If reconstruction is perfect, no need for more tests
        if np.allclose(image, rec):
            return True

        return False

    def check_proto_imf(self, proto_imf, proto_imf_prev, mean_env):
        """Check whether passed (proto) IMF is actual IMF.

        Returns True as soon as any one of these holds: the mean envelope is flat
        to within `mean_thr`; sifting changed almost nothing; the mean absolute
        value of the proto-IMF is below `mean_thr`; or its mean square is below
        `mse_thr`.
        """

        if np.all(np.abs(mean_env - mean_env.mean()) < self.mean_thr):
            return True

        # If very little change with sifting
        if np.allclose(proto_imf, proto_imf_prev):
            return True

        # If IMF mean close to zero (below threshold)
        if np.mean(np.abs(proto_imf)) < self.mean_thr:
            return True

        # Everything relatively close to 0
        mse_proto_imf = np.mean(proto_imf * proto_imf)
        if mse_proto_imf < self.mse_thr:
            return True

        return False

    def emd(self, image, max_imf=-1):
        """Decompose `image` into IMFs.

        Args:
            image: 2-D numpy array.
            max_imf: stop after this many IMFs; values <= 0 mean no limit.

        Returns:
            Array of shape (num_components, *image.shape) in the original value
            range. The last component carries the residual and the offset, so
            the components sum back to (approximately) `image`. A constant image
            is returned unchanged as a single component.
        """
        image_min, image_max = np.min(image), np.max(image)
        offset = image_min
        scale = image_max - image_min

        # A constant image has nothing to decompose; normalising it would divide
        # by zero and fill every component with NaN.
        if scale == 0:
            return np.array(image, dtype=float)[None, ...]

        # Work on the image rescaled to [0, 1].
        image_s = (image - offset) / scale

        imf = np.zeros(image.shape)
        imf_old = imf.copy()

        imfNo = 0
        IMF = np.empty((imfNo,) + image.shape)
        notFinished = True

        while notFinished:

            # Whatever has not yet been assigned to an IMF.
            res = image_s - np.sum(IMF[:imfNo], axis=0)
            imf = res.copy()
            mean_env = np.zeros(image.shape)
            stop_sifting = False

            # Counters
            n = 0  # All iterations for current imf.
            n_h = 0  # counts when mean(proto_imf) < threshold

            while not stop_sifting and n < self.MAX_ITERATION:
                n += 1

                min_peaks, max_peaks = self.find_extrema(imf)

                if len(min_peaks[0]) > 4 and len(max_peaks[0]) > 4:
                    # NOTE(review): the previous iteration's mean envelope is
                    # subtracted again here before the new one is computed and
                    # subtracted below — confirm this double subtraction is intended.
                    imf_old = imf.copy()
                    imf = imf - mean_env

                    min_env, max_env = self.extract_max_min_spline(imf)

                    mean_env = 0.5 * (min_env + max_env)

                    imf_old = imf.copy()
                    imf = imf - mean_env

                    # Fix number of iterations
                    if self.FIXE:
                        if n >= self.FIXE + 1:
                            stop_sifting = True

                    # Fix number of iterations after number of zero-crossings
                    # and extrema differ at most by one.
                    elif self.FIXE_H:
                        if n == 1:
                            continue
                        if self.check_proto_imf(imf, imf_old, mean_env):
                            n_h += 1
                        else:
                            n_h = 0

                        # STOP if enough n_h
                        if n_h >= self.FIXE_H:
                            stop_sifting = True

                    # Stops after default stopping criteria are met
                    else:
                        if self.check_proto_imf(imf, imf_old, mean_env):
                            stop_sifting = True

                else:
                    # Too few extrema to fit envelopes: the current signal is
                    # kept as the last component and the decomposition ends.
                    notFinished = False
                    stop_sifting = True

            IMF = np.vstack((IMF, imf.copy()[None, :]))
            imfNo += 1

            # NOTE(review): this compares the un-normalised `image` with IMFs that
            # live in the normalised [0, 1] domain; `image_s` looks like the
            # intended operand — confirm before changing.
            if self.end_condition(image, IMF) or (max_imf > 0 and imfNo >= max_imf):
                notFinished = False
                break

        # Append whatever is left as a final residual component.
        res = image_s - np.sum(IMF[:imfNo], axis=0)
        if not np.allclose(res, 0):
            IMF = np.vstack((IMF, res[None, :]))
            imfNo += 1

        # Undo the normalisation; the offset is added back on the last component only.
        IMF = IMF * scale
        IMF[-1] += offset
        return IMF


emd2d = EMD2D()

# --- 1. Denoise every patch with the WB model (one patch per forward pass) ---
w_prediction_patches = []

with torch.no_grad():
    for noisy in noisy_image_patches_array:
        # Add a batch axis of size 1 and move the patch to the GPU.
        patch_batch = torch.unsqueeze(torch.from_numpy(noisy.numpy()), dim=0).to('cuda')
        # Under no_grad the output carries no graph, so one .cpu() is enough.
        w_prediction_patches.append(w_model(patch_batch).cpu().numpy())

w_prediction_patches = np.concatenate(w_prediction_patches, axis=0)

# --- 2. Stitch every 64 consecutive patches (8 per row) back into one image ---
# Stacking once avoids the repeated full-array copies of np.append.
w_predictions = np.stack(
    [
        reconstruct_image_from_patches(w_prediction_patches[k * 64:(k + 1) * 64], num_patches_per_row=8)
        for k in range(w_prediction_patches.shape[0] // 64)
    ],
    axis=0,
)

# --- 3. Blend the first two IMFs of the noisy input and of the prediction ---
# For each image the blend weight x in {0.00, 0.01, ..., 0.99} with the highest
# SSIM against the ground truth is kept.
# NOTE(review): the weight is selected using the ground truth of the very image being
# evaluated, so the resulting metrics are an optimistic upper bound.
# NOTE(review): only IMF 0 and IMF 1 are used; further components returned by emd()
# are dropped, and an image yielding a single component raises IndexError.
for i in range(w_predictions.shape[0]):

    noisy_reshaped = np.squeeze(noisy_array[i], -1)
    print(noisy_reshaped.shape)
    noisy_imfs = emd2d.emd(noisy_reshaped, max_imf=-1)

    pred_reshaped = np.squeeze(w_predictions[i], axis=-1)
    pred_imfs = emd2d.emd(pred_reshaped, max_imf=-1)

    # The reference image is the same for all 100 candidates, so convert it once.
    gt_reference = trunc(denormalize(gt_array[i]))

    best_ssim = 0
    best_performing_lerp_image = None

    for y in range(0, 100):
        x = y / 100.0
        blended = (noisy_imfs[1] * x + pred_imfs[1] * (1.0 - x)) + (pred_imfs[0] * (1.0 - x) + noisy_imfs[0] * x)
        candidate = np.expand_dims(blended, -1)

        _ssim = calculate_ssim(gt_reference, trunc(denormalize(candidate)))
        print('index :: ', x, 'ssim :: ', _ssim)

        if _ssim > best_ssim:
            best_ssim = _ssim
            best_performing_lerp_image = candidate

    # Keep the plain model output when no candidate scored above zero instead of
    # writing None into the array.
    if best_performing_lerp_image is not None:
        w_predictions[i] = best_performing_lerp_image

visualize_predictions(w_model, noisy_array, gt_array, len(gt_array), w_predictions, "w model")

In [None]:

# --- 1. Denoise every patch of the extended set with the WB model ---
extended_w_prediction_patches = []

with torch.no_grad():
    for noisy in extended_noisy_image_patches_array:
        # Add a batch axis of size 1 and move the patch to the GPU.
        patch_batch = torch.unsqueeze(torch.from_numpy(noisy.numpy()), dim=0).to('cuda')
        # Under no_grad the output carries no graph, so one .cpu() is enough.
        extended_w_prediction_patches.append(w_model(patch_batch).cpu().numpy())

extended_w_prediction_patches = np.concatenate(extended_w_prediction_patches, axis=0)

# --- 2. Stitch every 64 consecutive patches (8 per row) back into one image ---
# Stacking once avoids the repeated full-array copies of np.append.
extended_w_predictions = np.stack(
    [
        reconstruct_image_from_patches(extended_w_prediction_patches[k * 64:(k + 1) * 64], num_patches_per_row=8)
        for k in range(extended_w_prediction_patches.shape[0] // 64)
    ],
    axis=0,
)

# --- 3. Blend the first two IMFs of the noisy input and of the prediction ---
# For each image the blend weight x in {0.01, ..., 0.99} with the highest SSIM
# against the ground truth is kept. Results are stored as torch tensors of shape
# (1, H, W, 1), which is what the following conversion cell expects.
# NOTE(review): the non-extended cell sweeps range(0, 100), i.e. it also tries x = 0
# (the unblended prediction); this cell starts at 1 — confirm which is intended.
# NOTE(review): the weight is selected using the ground truth of the very image being
# evaluated, so the resulting metrics are an optimistic upper bound.
emd_extended_w_predictions = [None] * extended_w_predictions.shape[0]
for i in range(extended_w_predictions.shape[0]):

    noisy_reshaped = np.squeeze(extended_noisy_array[i], -1)
    print(noisy_reshaped.shape)
    noisy_imfs = emd2d.emd(noisy_reshaped, max_imf=-1)

    pred_reshaped = np.squeeze(extended_w_predictions[i], axis=-1)
    pred_imfs = emd2d.emd(pred_reshaped, max_imf=-1)

    # The reference image is the same for every candidate, so convert it once.
    gt_reference = trunc(denormalize(extended_gt_array[i]))

    best_ssim = 0
    best_performing_lerp_image = None

    for y in range(1, 100):
        x = y / 100.0
        blended = (noisy_imfs[1] * x + pred_imfs[1] * (1.0 - x)) + (pred_imfs[0] * (1.0 - x) + noisy_imfs[0] * x)
        predictions = torch.from_numpy(np.expand_dims(np.expand_dims(blended, -1), 0))

        _ssim = calculate_ssim(gt_reference, trunc(denormalize(np.squeeze(predictions.numpy(), axis=0))))
        print('index :: ', x, 'ssim :: ', _ssim)

        if _ssim > best_ssim:
            best_ssim = _ssim
            best_performing_lerp_image = predictions

    # Fall back to the plain model output when no candidate scored above zero;
    # a None entry would make the following conversion cell fail.
    if best_performing_lerp_image is None:
        best_performing_lerp_image = torch.from_numpy(np.expand_dims(extended_w_predictions[i], 0))

    emd_extended_w_predictions[i] = best_performing_lerp_image


In [None]:
# Turn each stored tensor into a numpy array, dropping its leading axis, and
# stack the results into a single array.
emd_numpy_w_pred = np.array(
    [np.squeeze(blended_tensor.numpy(), 0) for blended_tensor in emd_extended_w_predictions]
)

# Both shapes are printed so they can be compared by eye.
print(extended_gt_array.shape, emd_numpy_w_pred.shape)

In [None]:

# Average metrics of the EMD-blended WB predictions over the extended set; the call
# also prints the noisy-input baseline averages for comparison.
extended_w_metrics = get_average_metrics(emd_numpy_w_pred, extended_gt_array, extended_noisy_array)

In [None]:
# Display the tuple: (PSNR, SSIM, MSE, PSNR gain, SSIM gain, MSE change), each rounded to 4 decimals.
extended_w_metrics

## Part V : Side by side comparison of all models

In [None]:
from prettytable import PrettyTable

# Comparison table; one row per model is added in the next cell.
pt = PrettyTable()
pt.field_names = ["Model", "PSNR", "SSIM", "MSE", "PSNR Improvement", "SSIM improvement", "MSE Improvement", "Num Parameter"]

# Disabled: per-model metrics on the non-extended arrays. The next cell uses the
# `extended_*_metrics` values instead.
#hformer_metrics = get_average_metrics(hformer_predictions, gt_array, noisy_array)
#x_metrics= get_average_metrics(np.concatenate(x_model_prediction_patches, axis=0), gt_array, noisy_array)
#y_metrics= get_average_metrics(np.concatenate(y_model_prediction_patches, axis=0), gt_array, noisy_array)
#sa_metrics = get_average_metrics(sa_predictions, gt_array, noisy_array)
#z_metrics = get_average_metrics(z_predictions, gt_array, noisy_array)
#w_metrics = get_average_metrics(w_predictions, gt_array, noisy_array)





In [None]:
# NOTE(review): of these, only `extended_w_metrics` is computed in the cells visible
# here; the others must come from cells that are run beforehand.
hformer_metrics = extended_hformer_metrics
x_metrics = extended_x_metrics
y_metrics = extended_y_metrics
sa_metrics = extended_sa_metrics
z_metrics = extended_z_metrics
w_metrics = extended_w_metrics


def _metrics_table_row(model_name, metrics, num_parameters):
    """Format one get_average_metrics() result as a row of the comparison table."""
    psnr, ssim_value, mse_value, psnr_gain, ssim_gain, mse_gain = metrics
    return [
        model_name,
        str(psnr),
        str(ssim_value),
        str(round(mse_value, 4)),
        str(round(psnr_gain, 4)),
        # The SSIM gain is the value scaled by 100, so it is the one carrying the
        # percent sign; the MSE change is a raw difference and gets no unit.
        str(round(ssim_gain * 100, 4)) + '%',
        str(mse_gain),
        num_parameters,
    ]


# Baseline row: noisy input against ground truth, no model involved (hard-coded values).
pt.add_row(["Original X-y pairs (No Model)","21.97","0.78911", '-', '-', "-", "-", "-"])

for model_name, metrics, num_parameters in [
    ("Hformer", hformer_metrics, '1,511,681'),
    ("XB Model", x_metrics, '1306'),
    ("YB Model", y_metrics, '145,453'),
    ("SAB Model", sa_metrics, '15,345'),
    ("ZB Model", z_metrics, '245,905'),
    ("WB Model", w_metrics, '977,233'),
]:
    pt.add_row(_metrics_table_row(model_name, metrics, num_parameters))

print(pt)

## Part VI : Output of predictions of all models side by side for direct visual comparison

In [None]:
def visualize_predictions_all_models(X_test, y_test, n, hformer_predictions, x_predictions, y_predictions, sa_predictions, z_predictions, w_predictions, output_dir='../../output/novel_comparison_b_emd'):
    """Plot the noisy input, the ground truth and every model's prediction in one row.

    One figure is produced per image; it is saved as
    `<output_dir>/combined_outputs_image_index_<i>.png` and then shown.

    Args:
        X_test: noisy images, indexed by image.
        y_test: ground-truth images, indexed by image.
        n: number of leading images to plot.
        hformer_predictions, x_predictions, y_predictions, sa_predictions,
        z_predictions, w_predictions: per-model predictions aligned with `X_test`.
        output_dir: folder receiving the PNG files; created when missing.
    """
    model_names = ["HFORMER", "X", "Y", "SA",  "Z", "W"]
    all_predictions = [hformer_predictions, x_predictions, y_predictions, sa_predictions, z_predictions, w_predictions]

    # savefig does not create folders, so make sure the target exists.
    os.makedirs(output_dir, exist_ok=True)

    for i in range(n):
        # Convert the shared images once and reuse them for every panel and metric.
        noisy_view = trunc(denormalize(X_test[i]))
        gt_view = trunc(denormalize(y_test[i]))

        # Two reference panels (noisy, ground truth) followed by one panel per model.
        f, axarr = plt.subplots(1, 2 + len(model_names), figsize=(41,41))

        psnr_qd = calculate_psnr(gt_view, noisy_view)
        ssim_qd = calculate_ssim(gt_view, noisy_view)
        # calculate_rmse returns the mean squared error, hence the "MSE" label.
        mse_qd = calculate_rmse(gt_view, noisy_view)

        axarr[0].imshow(noisy_view, cmap='gray', vmin=-160.0, vmax=240.0)
        axarr[0].set_title("QD Image\nPSNR={}\nSSIM={}\nMSE={}".format(round(psnr_qd, 4), round(ssim_qd, 4), round(mse_qd, 4)))
        axarr[0].set_axis_off()

        axarr[1].imshow(gt_view, cmap='gray', vmin=-160.0, vmax=240.0)
        axarr[1].set_title("FD Image")
        axarr[1].set_axis_off()

        for j, (model_name, model_predictions) in enumerate(zip(model_names, all_predictions), start=2):
            predicted_image = model_predictions[i]
            # Three-channel predictions are reduced to one channel first.
            if predicted_image.shape[-1] == 3:
                predicted_image = rgb2gray(predicted_image)

            predicted_view = trunc(denormalize(predicted_image))

            psnr_recon = round(calculate_psnr(gt_view, predicted_view), 4)
            ssim_recon = round(calculate_ssim(gt_view, predicted_view), 4)
            mse_recon = round(calculate_rmse(gt_view, predicted_view), 4)

            axarr[j].imshow(predicted_view, cmap='gray', vmin=-160.0, vmax=240.0)
            axarr[j].set_title("{}\nPSNR={}\nSSIM={}\nMSE={}".format(model_name, psnr_recon, ssim_recon, mse_recon))
            axarr[j].set_axis_off()

        plt.savefig('{}/combined_outputs_image_index_{}.png'.format(output_dir, i))
        plt.show()
        # Release the figure so long loops do not accumulate open figures.
        plt.close(f)

In [None]:

# Plot every image of the test set with all six model outputs side by side.
# NOTE(review): x_model_prediction_patches, y_model_prediction_patches, sa_predictions and
# z_predictions are not defined in the cells visible here — they must come from cells
# that are run beforehand, otherwise this raises NameError on a fresh kernel.
visualize_predictions_all_models(noisy_array, gt_array, len(gt_array), hformer_predictions, np.concatenate(x_model_prediction_patches, axis=0), np.concatenate(y_model_prediction_patches, axis=0), sa_predictions, z_predictions, w_predictions)