#### Image Denoising using ContourletSD: Demo

In [None]:
# Import libraries.
import sys
sys.path.append('..')

import matplotlib.pyplot as plt
import torch
from torchvision.io import read_image

from contourletSD_pytorch.contourlet_ops import (get_norm_scaling,
                                                 hard_thresholding)
from contourletSD_pytorch.contourlet_sd import ContourletSD
from contourletSD_pytorch.contourlet_sd_dec import VALID_INPUT_PRECISION
from utils.smooth_functions import VALID_SMOOTH_FUNC
import piqa

In [None]:
# Set configuration.
# ContourletSD operator settings.
nlev_SD = [2, 2, 3, 4, 5]  # forwarded to ContourletSD as `nlevs`
Pyr_mode = 1               # forwarded to ContourletSD and to get_norm_scaling
smooth_func = 'rcos'       # name looked up in VALID_SMOOTH_FUNC
dfilt = 'pkva'             # filter name forwarded to ContourletSD as `dfilt`

# Numerics and noise.
input_precision = 'single'  # name looked up in VALID_INPUT_PRECISION
sigma = 30.0                # std of the added Gaussian noise; also passed to hard_thresholding

# Directory holding the pre-computed norm-scaling data read by get_norm_scaling.
SDmm_dir = '../extras/sdmm_matlab'

# Input image: processed as RGB when True, as grayscale otherwise.
color_images = True
image_path = '../datasets/test_images/peppers_color.png'

In [None]:
# Get the contourlets operator.
# The same ContourletSD instance is used for the forward decomposition and,
# when called with reconstruct=True, for the inverse reconstruction.
contourlet_sd = ContourletSD(
    dfilt=dfilt,
    nlevs=nlev_SD,
    Pyr_mode=Pyr_mode,
    # Look up the configured smoothing function by its name.
    smooth_func=VALID_SMOOTH_FUNC[smooth_func],
    color_mode=('rgb' if color_images else 'gray'),
)

In [None]:
# Read a test image as a (1, C, H, W) batch in the configured precision.
# The read mode is forced to match the operator's color mode. Without an
# explicit mode, read_image keeps the file's native channels: an RGBA PNG would
# give 4 channels, and a color file would give 3 even when color_images is False.
from torchvision.io import ImageReadMode

read_mode = ImageReadMode.RGB if color_images else ImageReadMode.GRAY
X = read_image(image_path, mode=read_mode).unsqueeze(0)
X = X.to(VALID_INPUT_PRECISION[input_precision])

# Add zero-mean Gaussian noise with standard deviation `sigma`.
Xn = X + sigma * torch.randn_like(X)

In [None]:
# Load pre-computed per-subband norm scaling factors; hard_thresholding uses
# them to scale the threshold for each subband.
# NOTE(review): only the last dimension (width) is passed as `image_size`,
# which presumably assumes a square input image. Confirm against get_norm_scaling.
E = get_norm_scaling(
    device=Xn.device,
    image_size=Xn.shape[-1],
    Pyr_mode=Pyr_mode,
    SDmm_dir=SDmm_dir,
)

In [None]:
# Denoise in three steps: forward transform, hard-threshold, inverse transform.

# Forward ContourletSD decomposition of the noisy image.
Y = contourlet_sd(x=Xn)

# Zero out small coefficients, using the noise level `sigma` together with
# the per-subband norm scaling factors `E`.
Yth = hard_thresholding(y=Y, sigma=sigma, E=E)

# Reconstructing from the thresholded coefficients gives the denoised image.
Xd = contourlet_sd(x=Yth, reconstruct=True)

In [None]:
# Compute PSNR and SSIM of the noisy observation and of the denoised
# reconstruction against the clean image. Pixel values are on a 0-255 scale,
# so each image is clamped once to that range, matching value_range=255.
n_channels = X.shape[1]
X_ref = X.clamp(0, 255)
Xn_eval = Xn.clamp(0, 255)
Xd_eval = Xd.clamp(0, 255)

psnr = piqa.PSNR(value_range=255).to(X.device)
psnr_n = psnr(Xn_eval, X_ref).item()
psnr_d = psnr(Xd_eval, X_ref).item()

# piqa.SSIM defaults to n_channels=3. Take the channel count from the data so
# that the grayscale configuration (color_images=False, 1 channel) also works.
ssim = piqa.SSIM(n_channels=n_channels, value_range=255).to(X.device)
ssim_n = ssim(Xn_eval, X_ref).item()
ssim_d = ssim(Xd_eval, X_ref).item()

In [None]:
# Show results: clean image, noisy observation and denoised reconstruction
# side by side, with the quality metrics in the panel titles.
titles = [
    'Original',
    f'Observation\n(PSNR: {psnr_n:.2f} dB, SSIM: {ssim_n:.2f})',
    f'ContourletSD Reconstruction\n(PSNR: {psnr_d:.2f} dB, SSIM: {ssim_d:.2f})',
]
f01 = plt.figure(figsize=(30, 10))
for idx, (tag, img) in enumerate(zip(titles, [X, Xn, Xd])):
    # Take the first batch element, move it to host memory, and map the
    # 0-255 pixel range to [0, 1] for both color and grayscale images.
    img_ = img[0].detach().cpu().clamp(0, 255) / 255
    if color_images:
        img_ = img_.permute(1, 2, 0)  # (C, H, W) -> (H, W, C) as imshow expects
    else:
        img_ = img_[0]  # single channel -> (H, W)
    plt.subplot(1, 3, idx + 1)
    # A fixed vmin/vmax stops matplotlib from auto-stretching each grayscale
    # panel to its own min/max, so all panels share one intensity scale.
    # For RGB input these values are ignored.
    plt.imshow(img_, cmap='gray', vmin=0, vmax=1)
    plt.axis('off')
    plt.title(tag, fontsize=24)