### Import Path

In [None]:
import os
import sys

# Make the parent directory importable (presumably where the project's
# config / model / utils modules live). Insert it at the FRONT of sys.path
# rather than appending: those are generic module names, and an installed
# package with the same name earlier on sys.path would otherwise shadow the
# project's own module.
parent_dir = os.path.abspath("..")
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

In [None]:
import glob
from collections import OrderedDict
from enum import Enum

import numpy as np
import torch
import torch.ao.quantization as tq
from config import Config
from model import DnCNN
from PIL import Image
from utils import batch_PSNR

### QConfig

In [None]:
class PowerOfTwoObserver(tq.MinMaxObserver):
    """
    Observer module for power-of-two quantization (dyadic quantization with b = 1).

    Tracks the running min/max exactly like ``MinMaxObserver``, but
    ``calculate_qparams`` produces a symmetric 8-bit scale that is rounded
    *up* to a power of two, so rescaling can be done with bit shifts.
    """

    def scale_approximate(self, scale: float, max_shift_amount=8) -> torch.Tensor:
        """
        Round ``scale`` up to the smallest power of two that is >= ``scale``.

        Args:
            scale: Positive scale to approximate.
            max_shift_amount: Upper bound on the base-2 exponent; the result
                never exceeds ``2 ** max_shift_amount``.

        Returns:
            0-dim float tensor ``2 ** k`` with
            ``k = min(ceil(log2(scale)), max_shift_amount)``.
        """
        # ceil (not round) of the exponent: unless the clamp below kicks in,
        # the power-of-two scale is never smaller than the exact scale, so the
        # largest observed magnitude still fits in the 8-bit range.
        scale_log2 = torch.ceil(torch.log2(torch.tensor(scale)))
        scale_log2 = torch.clamp(scale_log2, max=max_shift_amount)
        power_of_two_scale = 2**scale_log2

        return power_of_two_scale

    def calculate_qparams(self):
        """
        Calculate symmetric 8-bit quantization parameters with a power-of-two scale.

        Returns:
            ``(scale, zero_point)`` where ``scale`` is a 0-dim float tensor
            holding a power of two, and ``zero_point`` is a 0-dim int64 tensor:
            0 for ``torch.qint8`` and 128 (the midpoint of the unsigned 8-bit
            range) for any other dtype.

        NOTE(review): an observer that has seen no data keeps MinMaxObserver's
        initial +/-inf min/max, which yields the clamp ceiling
        ``2 ** max_shift_amount`` rather than a meaningful scale.
        """
        min_val, max_val = self.min_val.item(), self.max_val.item()

        # Symmetric 8-bit scale: the largest observed magnitude maps to 127.
        scale = max(abs(min_val), abs(max_val)) / (2**7 - 1)
        # An all-zero input would give scale == 0. log2(0) is -inf, so the
        # power-of-two scale would be 0 (division by zero when quantizing).
        # Clamp to the observer's eps, as the MinMaxObserver base class does.
        scale = max(scale, float(self.eps))

        # Symmetric quantization: real 0.0 maps to 0 for signed int8 and to the
        # midpoint 128 for unsigned int8.
        zero_point = 0 if self.dtype == torch.qint8 else 128

        scale = self.scale_approximate(scale)
        zero_point = torch.tensor(zero_point, dtype=torch.int64)

        return scale, zero_point

    def extra_repr(self):
        return f"min_val={self.min_val}, max_val={self.max_val}, scale=PowerOfTwo"


class CustomQConfig(Enum):
    """
    Named quantization configurations that can be assigned to ``module.qconfig``.

    Members:
        POWER2: ``PowerOfTwoObserver`` for weights (``torch.qint8``) and
            activations (``torch.quint8``), both per-tensor symmetric.
        DEFAULT: ``None`` (no observer configuration).
    """

    POWER2 = tq.QConfig(
        weight=PowerOfTwoObserver.with_args(
            qscheme=torch.per_tensor_symmetric,
            dtype=torch.qint8,
        ),
        activation=PowerOfTwoObserver.with_args(
            qscheme=torch.per_tensor_symmetric,
            dtype=torch.quint8,
        ),
    )
    DEFAULT = None


def normalize(data):
    """
    Scale 8-bit pixel intensities into the unit range.

    Args:
        data: A number or array-like supporting true division (e.g. a NumPy array).

    Returns:
        ``data`` divided by 255.0, so values in [0, 255] map to [0.0, 1.0].
    """
    max_intensity = 255.0
    return data / max_intensity

### Test Quantized Model

In [None]:
# Rebuild the int8 model skeleton so the saved quantized state dict can be
# loaded into it: fuse, wrap with quant/dequant stubs, attach the power-of-two
# qconfig, then prepare + convert. No calibration data is passed between
# prepare and convert, so the observer-derived qparams here are placeholders;
# presumably they are replaced by the values stored in the checkpoint.
channels = 1
num_of_layers = 5
input_size = (channels, 256, 256)  # NOTE(review): not referenced elsewhere in this notebook
backend = "power2"  # NOTE(review): not referenced elsewhere in this notebook

qconfig = CustomQConfig.POWER2.value

quantized_model = DnCNN(channels=channels, num_of_layers=num_of_layers)
quantized_model.eval()
quantized_model.cpu()
# Project-defined fusion step (presumably conv/bn/relu) — TODO confirm in model.py.
quantized_model.fuse_layers()

quantized_model = tq.QuantWrapper(quantized_model)
quantized_model.qconfig = qconfig

tq.prepare(quantized_model, inplace=True)
tq.convert(quantized_model, inplace=True)

checkpoint = torch.load(Config.quantized_model_path, map_location="cpu")
quantized_model.load_state_dict(checkpoint)
print(quantized_model)

In [None]:
# Evaluate the quantized model on the test set: add Gaussian noise to each
# image, denoise it, report PSNR against the clean image, and save the result.
device = torch.device("cpu")
quantized_model.to(device).eval()

test_data = "Set12"
input_dir = os.path.join(Config.base_dir, "data", test_data)
num_layers = "layer5"  # NOTE(review): not referenced elsewhere in this notebook

output_dir = os.path.join("output", "int8")
os.makedirs(output_dir, exist_ok=True)

print("Loading test images...\n")
files = sorted(glob.glob(os.path.join(input_dir, "*.png")))
if not files:
    # Fail clearly here instead of with a ZeroDivisionError at the average below.
    raise FileNotFoundError(f"No .png test images found in {input_dir}")

psnr_sum = 0.0
test_noiseL = 25.0  # noise standard deviation on the 0-255 pixel scale

for f in files:
    # Re-seed per image so each image's noise is reproducible regardless of
    # how many images were processed before it.
    torch.manual_seed(0)

    name = os.path.basename(f)
    # The context manager releases the file handle once the pixels are read.
    with Image.open(f) as im:
        gray = im.convert("L")
    img = normalize(np.array(gray, dtype=np.float32))
    # (H, W) -> (1, 1, H, W): a batch of one single-channel image.
    ISource = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device)
    noise = torch.randn_like(ISource) * (test_noiseL / 255.0)
    INoisy = ISource + noise

    with torch.no_grad():
        # Despite its name, `denoised` is subtracted from the noisy input, so
        # the model output is used as the predicted noise residual.
        denoised = quantized_model(INoisy)
        out = torch.clamp(INoisy - denoised, 0.0, 1.0)

    psnr = batch_PSNR(out, ISource, data_range=1.0)
    psnr_sum += psnr
    print(f"{name}  PSNR: {psnr:.4f}")

    # Save the denoised image as 8-bit grayscale. Round to the nearest integer
    # first: a bare astype(np.uint8) truncates and biases pixels downward.
    out = out.squeeze().cpu().numpy() * 255.0
    out = np.clip(np.rint(out), 0, 255).astype(np.uint8)
    output_path = os.path.join(output_dir, name)
    Image.fromarray(out).save(output_path)

avg_psnr = psnr_sum / len(files)
print(f"\nAverage PSNR on test data: {avg_psnr:.4f}")
print(f"Denoised images saved to {output_dir}")

### Test Full-Precision Model

In [None]:
# Build the full-precision DnCNN and load its float checkpoint.
channels = 1
num_of_layers = 5

model = DnCNN(channels=channels, num_of_layers=num_of_layers)
model.eval()
model.cpu()
state_dict = torch.load(Config.model_path, map_location="cpu")

# Keys presumably carry a leading "module." from an nn.DataParallel wrapper.
# Strip only that leading prefix: str.replace would also mangle any key that
# merely contains "module." further in (e.g. "...submodule.weight").
_DP_PREFIX = "module."
new_state_dict = OrderedDict(
    (k[len(_DP_PREFIX):] if k.startswith(_DP_PREFIX) else k, v)
    for k, v in state_dict.items()
)
model.load_state_dict(new_state_dict)

In [None]:
# Evaluate the full-precision model on the test set: add Gaussian noise to each
# image, denoise it, report PSNR against the clean image, and save the result.
device = torch.device("cpu")
model.to(device).eval()

test_data = "Set12"
input_dir = os.path.join(Config.base_dir, "data", test_data)
output_dir = os.path.join("output", "float32")
os.makedirs(output_dir, exist_ok=True)

print("Loading test images...\n")
files = sorted(glob.glob(os.path.join(input_dir, "*.png")))
if not files:
    # Fail clearly here instead of with a ZeroDivisionError at the average below.
    raise FileNotFoundError(f"No .png test images found in {input_dir}")

psnr_sum = 0.0
test_noiseL = 25.0  # noise standard deviation on the 0-255 pixel scale

for f in files:
    # Re-seed per image so each image's noise is reproducible regardless of
    # how many images were processed before it.
    torch.manual_seed(0)

    name = os.path.basename(f)
    # The context manager releases the file handle once the pixels are read.
    with Image.open(f) as im:
        gray = im.convert("L")
    img = normalize(np.array(gray, dtype=np.float32))
    # (H, W) -> (1, 1, H, W): a batch of one single-channel image.
    ISource = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device)
    noise = torch.randn_like(ISource) * (test_noiseL / 255.0)
    INoisy = ISource + noise

    with torch.no_grad():
        # Despite its name, `denoised` is subtracted from the noisy input, so
        # the model output is used as the predicted noise residual.
        denoised = model(INoisy)
        out = torch.clamp(INoisy - denoised, 0.0, 1.0)

    psnr = batch_PSNR(out, ISource, data_range=1.0)
    psnr_sum += psnr
    print(f"{name}  PSNR: {psnr:.4f}")

    # Save the denoised image as 8-bit grayscale. Round to the nearest integer
    # first: a bare astype(np.uint8) truncates and biases pixels downward.
    out = out.squeeze().cpu().numpy() * 255.0
    out = np.clip(np.rint(out), 0, 255).astype(np.uint8)
    output_path = os.path.join(output_dir, name)
    Image.fromarray(out).save(output_path)

avg_psnr = psnr_sum / len(files)
print(f"\nAverage PSNR on test data: {avg_psnr:.4f}")
print(f"Denoised images saved to {output_dir}")