### Model

In [None]:
from DnCNN_PyTorch.models import DnCNN
channels = 1
num_of_layers = 5
# Presumably (C, H, W) for one input image — not consumed in this cell.
input_size = (channels, 256, 256)
# NOTE(review): later cells build DnCNN without `do_fuse` and call
# `fuse_layers()` explicitly; confirm which construction path is intended.
model = DnCNN(
    do_fuse=True,
    num_of_layers=num_of_layers,
    channels=channels,
)

### QConfig

In [None]:
from enum import Enum
import math

import torch
import torch.ao.quantization as tq


class PowerOfTwoObserver(tq.MinMaxObserver):
    """
    Observer module for power-of-two quantization (dyadic quantization with b = 1).

    Records the running min/max exactly like ``MinMaxObserver`` but reports a
    symmetric 8-bit scale rounded *up* to the nearest power of two, so the
    quantized grid always covers the observed range.
    """

    def scale_approximate(self, scale: float, max_shift_amount=8) -> torch.Tensor:
        """Round ``scale`` up to the nearest power of two.

        Args:
            scale: Positive scale, as a Python float or 0-d tensor.
            max_shift_amount: Upper bound on the exponent; the result is at
                most ``2 ** max_shift_amount``.

        Returns:
            0-d float32 tensor ``2 ** min(ceil(log2(scale)), max_shift_amount)``.
        """
        # NOTE(review): the clamp only bounds the exponent from above
        # (scale <= 2**8). If the intent was to bound a hardware right-shift,
        # a lower bound (min=-max_shift_amount) may have been meant — confirm.
        # as_tensor avoids the copy-construct warning when given a tensor.
        scale_log2 = torch.ceil(torch.log2(torch.as_tensor(scale, dtype=torch.float32)))
        scale_log2 = torch.clamp(scale_log2, max=max_shift_amount)
        return 2 ** scale_log2

    def calculate_qparams(self):
        """Calculate ``(scale, zero_point)`` with the scale as a power of two.

        The scale is ``max(|min_val|, |max_val|) / 127`` (8-bit symmetric),
        floored at ``self.eps`` and then rounded up to a power of two.
        ``zero_point`` is 0 for ``torch.qint8`` and 128 for any other dtype
        (e.g. ``torch.quint8``).

        Returns:
            Tuple of a 0-d float32 scale tensor and a 0-d int64 zero-point tensor.
        """
        min_val, max_val = self.min_val.item(), self.max_val.item()

        scale = max(abs(min_val), abs(max_val)) / (2**7 - 1)  # 8-bit symmetric
        # An all-zero observed range would give log2(0) = -inf and a zero
        # scale, which quantization cannot use; floor it like the base class.
        scale = max(scale, self.eps.item())

        if self.dtype == torch.qint8:
            zero_point = 0
        else:
            zero_point = 128

        scale = self.scale_approximate(scale).to(torch.float32)
        zero_point = torch.tensor(zero_point, dtype=torch.int64)

        return scale, zero_point

    def extra_repr(self):
        return f"min_val={self.min_val}, max_val={self.max_val}, scale=PowerOfTwo"
    
class CustomQConfig(Enum):
    """Named quantization configurations; ``.value`` is assignable to ``model.qconfig``."""

    # Power-of-two observers on both tensors, symmetric per-tensor:
    # signed qint8 for weights, unsigned quint8 for activations.
    POWER2 = tq.QConfig(
        weight=PowerOfTwoObserver.with_args(
            qscheme=torch.per_tensor_symmetric,
            dtype=torch.qint8,
        ),
        activation=PowerOfTwoObserver.with_args(
            qscheme=torch.per_tensor_symmetric,
            dtype=torch.quint8,
        ),
    )
    # No custom configuration (value is None).
    DEFAULT = None

### Quantization

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import glob
import os

class ImageFolderDataset(Dataset):
    """Grayscale PNG dataset that yields noisy copies of each image.

    Args:
        image_dir: Directory scanned (non-recursively) for ``*.png`` files.
        transform: Optional callable applied to the grayscale PIL image; it must
            return a tensor. Defaults to ``transforms.ToTensor()``.
        noise_std: Standard deviation of the additive Gaussian noise, given on
            the 0-255 intensity scale (it is divided by 255 here).

    Raises:
        FileNotFoundError: If ``image_dir`` contains no PNG files. Otherwise a
            wrong path would produce an empty dataset and calibration would
            silently observe nothing.
    """

    def __init__(self, image_dir, transform=None, noise_std=25.0):
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")))
        if not self.image_paths:
            raise FileNotFoundError(f"No .png images found in {image_dir!r}")
        self.transform = transform
        self.noise_std = noise_std / 255.0

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager closes the file handle deterministically.
        with Image.open(self.image_paths[idx]) as raw:
            img = raw.convert("L")
        to_tensor = self.transform if self.transform else transforms.ToTensor()
        clean_tensor = to_tensor(img)
        # Float32 CPU noise, matching the previous torch.FloatTensor(...).normal_().
        noise = torch.randn(clean_tensor.shape, dtype=torch.float32) * self.noise_std
        return clean_tensor + noise

def get_calibration_loader(image_dir="./data/train", image_size=40, batch_size=1, noise_std=25.0):
    """Return a non-shuffling DataLoader over noisy, resized calibration images.

    Every PNG in ``image_dir`` is resized to ``image_size`` x ``image_size``,
    converted to a tensor and corrupted with Gaussian noise of ``noise_std``
    (0-255 scale) by ``ImageFolderDataset``.
    """
    pipeline = [transforms.Resize((image_size, image_size)), transforms.ToTensor()]
    calibration_set = ImageFolderDataset(
        image_dir,
        transform=transforms.Compose(pipeline),
        noise_std=noise_std,
    )
    return DataLoader(calibration_set, shuffle=False, batch_size=batch_size)

def preprocess_filename(filename: str, existed: str = "keep_both") -> str:
    """Decide which path to write to when ``filename`` may already exist.

    Args:
        filename: Desired output path.
        existed: Policy applied when ``filename`` exists:
            ``"overwrite"`` returns ``filename`` unchanged;
            ``"keep_both"`` appends ``-1``, ``-2``, ... before the extension
            until an unused path is found;
            ``"raise"`` raises ``FileExistsError``.

    Returns:
        The path to write to.

    Raises:
        FileExistsError: ``existed == "raise"`` and ``filename`` exists.
        ValueError: ``existed`` is not one of the policies above.
    """
    if existed == "overwrite":
        return filename
    if existed == "keep_both":
        base, ext = os.path.splitext(filename)
        candidate, cnt = filename, 1
        while os.path.exists(candidate):
            candidate = f"{base}-{cnt}{ext}"
            cnt += 1
        return candidate
    if existed == "raise":
        # Only an existing file is an error; a fresh path is returned as-is.
        if os.path.exists(filename):
            raise FileExistsError(f"{filename} already exists.")
        return filename
    raise ValueError(f"Unknown value for 'existed': {existed}")

def save_model(
    model, filename: str, verbose: bool = True, existed: str = "keep_both"
) -> None:
    """Save ``model.state_dict()`` to ``filename`` with ``torch.save``.

    Args:
        model: Module whose state dict is saved.
        filename: Target path; resolved through ``preprocess_filename``
            according to ``existed``.
        verbose: If True, the confirmation line also reports the file size in
            MB. A confirmation line is printed in both cases.
        existed: Collision policy forwarded to ``preprocess_filename``
            (``"overwrite"``, ``"keep_both"`` or ``"raise"``).
    """
    filename = preprocess_filename(filename, existed)

    # os.makedirs("") raises FileNotFoundError, so only create a parent
    # directory when the path actually has one (not for e.g. "model.pt").
    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), filename)
    if verbose:
        print(f"Model saved at {filename} ({os.path.getsize(filename) / 1e6} MB)")
    else:
        print(f"Model saved at {filename}")

In [None]:
import torch
import torch.ao.quantization as tq
from DnCNN_PyTorch.models import DnCNN

""" Calibrate Method """
def calibrate(model, loader, device="cpu", num_batches=None):
    """Feed calibration batches through ``model`` so its observers record ranges.

    Args:
        model: Observer-instrumented (prepared) model. It is switched to eval
            mode and moved to ``device``.
        loader: Iterable of input tensors (e.g. a DataLoader).
        device: Device the model and inputs are moved to.
        num_batches: Maximum number of batches to run. ``None`` (the default)
            consumes the whole loader; pass ``1`` to calibrate on a single batch.
    """
    model.eval().to(device)
    # Only observer statistics are needed, so skip building autograd graphs.
    with torch.no_grad():
        for i, x in enumerate(loader):
            if num_batches is not None and i >= num_batches:
                break
            model(x.to(device))

""" Load Pretrained Model """
model_path = "./model/DnCNN_layer5.pt"
# NOTE(review): the "Test original" cell loads "./model/DnCNN_5_layers.pt";
# confirm both cells are meant to use the same float checkpoint.
channels = 1
num_of_layers = 5

Pretrained_model = DnCNN(channels=channels, num_of_layers=num_of_layers)
Pretrained_model.eval()
Pretrained_model.cpu()

state_dict = torch.load(model_path, map_location='cpu')

from collections import OrderedDict
# Strip a *leading* "module." (as added when a wrapped model, e.g. presumably
# nn.DataParallel, is saved); other occurrences inside a key are left intact.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[len("module."):] if k.startswith("module.") else k
    new_state_dict[name] = v
Pretrained_model.load_state_dict(new_state_dict)

# Fuse modules
Pretrained_model.fuse_layers()

# Configure quantization
fused_model = tq.QuantWrapper(Pretrained_model)
fused_model.qconfig = CustomQConfig.POWER2.value
print(f"Quantization qconfig: {fused_model.qconfig}")

# Insert observers
tq.prepare(fused_model, inplace=True)

# Calibration: run noisy training images through the observers
calibrate(fused_model, get_calibration_loader(image_dir="./DnCNN_PyTorch/data/train"))

# Convert model to its quantized version
tq.convert(fused_model.cpu(), inplace=True)

# Save the quantized model to the path the "Test" cell loads it from
quantized_model_path = "./model/DnCNN_5_layers_int8.pt"
save_model(fused_model, quantized_model_path, existed="overwrite")


### Test quantized model

In [None]:
import torch.ao.quantization as tq
from DnCNN_PyTorch.models import DnCNN

# Rebuild the quantized model skeleton (fuse -> wrap -> prepare -> convert)
# so that the saved int8 state dict can be loaded into it.
channels = 1
num_of_layers = 5
input_size = (channels, 256, 256)
backend = "power2"

quantized_model = DnCNN(channels=channels, num_of_layers=num_of_layers).eval().cpu()
quantized_model.fuse_layers()
quantized_model = tq.QuantWrapper(quantized_model)

qconfig = CustomQConfig["POWER2"].value
quantized_model.qconfig = qconfig

# No calibration here: the quantization parameters are expected to come from
# the loaded state dict below.
tq.prepare(quantized_model, inplace=True)
tq.convert(quantized_model, inplace=True)

quantized_model_path = "./model/DnCNN_5_layers_int8.pt"
quantized_state = torch.load(quantized_model_path, map_location="cpu")
quantized_model.load_state_dict(quantized_state)
print(quantized_model)

In [None]:
import cv2
import os
import glob
import numpy as np
import torch
from DnCNN_PyTorch.utils import batch_PSNR

def normalize(data):
    """Divide pixel values by 255 (8-bit [0, 255] -> [0, 1])."""
    return data / 255.0


device = torch.device('cpu')
quantized_model.to(device).eval()

test_data = "Set12"
input_dir = os.path.join("DnCNN_PyTorch", "data", test_data)
num_layers = "layer5"

output_dir = os.path.join("results", num_layers, test_data)
os.makedirs(output_dir, exist_ok=True)

print('Loading test images...\n')
files = sorted(glob.glob(os.path.join(input_dir, '*.png')))
if not files:
    # Fail loudly instead of dividing by zero when averaging below.
    raise FileNotFoundError(f"No .png test images found in {input_dir!r}")

psnr_sum = 0.0
test_noiseL = 25.0

for f in files:
    name = os.path.basename(f)
    gray = cv2.imread(f, cv2.IMREAD_GRAYSCALE)
    if gray is None:  # cv2.imread returns None instead of raising
        raise OSError(f"Could not read image {f!r}")
    img = normalize(gray.astype(np.float32))
    ISource = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device)
    # Re-seed per image so the noise matches the float-model evaluation cell.
    torch.manual_seed(0)
    noise = torch.randn_like(ISource) * (test_noiseL / 255.0)
    INoisy = ISource + noise
    with torch.no_grad():
        # The model output is subtracted from the noisy input (residual learning).
        denoised = quantized_model(INoisy)
        Out = torch.clamp(INoisy - denoised, 0.0, 1.0)
    psnr = batch_PSNR(Out, ISource, data_range=1.0)
    psnr_sum += psnr
    print(f"{name}  PSNR: {psnr:.4f}")
    # Write the denoised image; the summary below reports output_dir as its location.
    out_img = (Out[0, 0].cpu().numpy() * 255.0).round().astype(np.uint8)
    cv2.imwrite(os.path.join(output_dir, name), out_img)

avg_psnr = psnr_sum / len(files)
print(f"\nAverage PSNR on test data: {avg_psnr:.4f}")
print(f"Denoised images saved to {output_dir}")


### Test original (float, unquantized) model


In [None]:
import torch.ao.quantization as tq
from DnCNN_PyTorch.models import DnCNN

model_path = "./model/DnCNN_5_layers.pt"
channels = 1
num_of_layers = 5

# Float reference model, in eval mode on the CPU.
model = DnCNN(channels=channels, num_of_layers=num_of_layers).eval().cpu()
state_dict = torch.load(model_path, map_location='cpu')

from collections import OrderedDict
# Remove every "module." occurrence from the checkpoint keys before loading.
new_state_dict = OrderedDict(
    (key.replace('module.', ''), tensor) for key, tensor in state_dict.items()
)
model.load_state_dict(new_state_dict)


In [None]:
import cv2
import os
import glob
import numpy as np
import torch
from DnCNN_PyTorch.utils import batch_PSNR

def normalize(data):
    """Divide pixel values by 255 (8-bit [0, 255] -> [0, 1])."""
    return data / 255.0


device = torch.device('cpu')
model.to(device).eval()

test_data = "Set12"
input_dir = os.path.join("DnCNN_PyTorch", "data", test_data)
output_dir = os.path.join("results", test_data)
os.makedirs(output_dir, exist_ok=True)

print('Loading test images...\n')
files = sorted(glob.glob(os.path.join(input_dir, '*.png')))
if not files:
    # Fail loudly instead of dividing by zero when averaging below.
    raise FileNotFoundError(f"No .png test images found in {input_dir!r}")

psnr_sum = 0.0
test_noiseL = 25.0

for f in files:
    name = os.path.basename(f)
    gray = cv2.imread(f, cv2.IMREAD_GRAYSCALE)
    if gray is None:  # cv2.imread returns None instead of raising
        raise OSError(f"Could not read image {f!r}")
    img = normalize(gray.astype(np.float32))
    ISource = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device)
    # Re-seed per image so the noise matches the quantized-model evaluation cell.
    torch.manual_seed(0)
    noise = torch.randn_like(ISource) * (test_noiseL / 255.0)
    INoisy = ISource + noise
    with torch.no_grad():
        # The model output is subtracted from the noisy input (residual learning).
        denoised = model(INoisy)
        Out = torch.clamp(INoisy - denoised, 0.0, 1.0)
    psnr = batch_PSNR(Out, ISource, data_range=1.0)
    psnr_sum += psnr
    print(f"{name}  PSNR: {psnr:.4f}")

avg_psnr = psnr_sum / len(files)
print(f"\nAverage PSNR on test data: {avg_psnr:.4f}")
