# Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising
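This is an implementation of post-reconstruction denoising (sinograms are first reconstructed with filtered back-projection, then the reconstructed images are denoised) using the exact network proposed in
https://paperswithcode.com/paper/beyond-a-gaussian-denoiser-residual-learning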

In [None]:
# import liberies
import os 
import time

import numpy as np
import matplotlib.pyplot as plt

import pandas as pd

# import cv2
from PIL import Image

#
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# about dataset 
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms

# 
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

# test metrics
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# import external liberaries
import importlib
import data_utils
importlib.reload(data_utils)
from data_utils import show_images_grid, show_error_profile, filterd_back_projection

# set seed
# Only NumPy's and PyTorch's global RNGs are seeded here (Python's `random`
# module is not). NOTE(review): bit-exact GPU reproducibility would additionally
# need cuDNN determinism settings, which this notebook does not configure.
np.random.seed(42)
torch.manual_seed(42)

# from experiments.pytorchtools import EarlyStopping

## define DnCNN

In [None]:
# define neural network
class DnCNN(nn.Module):
    """DnCNN denoiser with residual learning.

    The convolution stack's output is subtracted from the input, i.e. the stack
    learns the residual (noise) and ``forward`` returns ``x - stack(x)``.

    Layout of ``self.dncnn`` (the indices matter: checkpoint keys are
    ``dncnn.<index>.<param>``):
      * Conv(in_channels -> num_features), ReLU
      * (num_layers - 2) x [Conv(num_features -> num_features), BatchNorm, ReLU]
      * Conv(num_features -> out_channels)
    Every convolution is 3x3 with padding 1, so the spatial size is preserved.
    """

    def __init__(self, in_channels=1, out_channels=1, num_layers=17, num_features=64):
        super().__init__()

        def conv3x3(c_in, c_out):
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        modules = [conv3x3(in_channels, num_features), nn.ReLU(inplace=True)]
        for _ in range(num_layers - 2):
            modules += [
                conv3x3(num_features, num_features),
                nn.BatchNorm2d(num_features),
                nn.ReLU(inplace=True),
            ]
        modules.append(conv3x3(num_features, out_channels))

        self.dncnn = nn.Sequential(*modules)

    def forward(self, x):
        residual = self.dncnn(x)
        return x - residual

## define custom dataset

In [None]:
# define custom dataset
class WaterlooPairDataset(Dataset):
    """Paired clean/noisy sinogram dataset that yields FBP reconstructions.

    Image files in ``clean_dir`` and ``noisy_dir`` are sorted by file name and
    paired by position (i-th clean file with i-th noisy file). For each pair,
    both sinograms are loaded as grayscale, reconstructed with
    ``filterd_back_projection`` (from ``data_utils``) and min-max normalized to
    [0, 1]. If ``transform`` is given, it is applied to each reconstruction
    wrapped as a PIL image.

    NOTE(review): pairing is purely positional; it assumes both directories
    contain matching files in the same sorted order — confirm for new data.

    Raises:
        ValueError: if the two directories contain different numbers of images.
    """

    # File extensions treated as images when scanning the directories.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, clean_dir, noisy_dir, transform=None):
        self.transform = transform
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir

        # scan image files
        clean_sinograms = self._list_images(clean_dir)
        noisy_sinograms = self._list_images(noisy_dir)

        # Raise instead of `assert` so the check is not stripped under `python -O`.
        if len(clean_sinograms) != len(noisy_sinograms):
            raise ValueError("Number of clean sinograms and noisy sinograms should be equal")

        self.sinogram_pairs = [
            (os.path.join(clean_dir, c), os.path.join(noisy_dir, n))
            for c, n in zip(clean_sinograms, noisy_sinograms)
        ]

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted names of the image files directly inside ``directory``."""
        return sorted(f for f in os.listdir(directory) if f.endswith(cls._IMAGE_EXTENSIONS))

    @staticmethod
    def _min_max_normalize(image):
        """Linearly rescale ``image`` so its minimum maps to 0 and maximum to 1.

        A constant image (max == min) is mapped to all zeros instead of the
        NaNs that a 0/0 division would produce.
        """
        low = np.min(image)
        span = np.max(image) - low
        shifted = image - low
        if span == 0:
            return shifted * 0.0
        return shifted / span

    def __len__(self):
        return len(self.sinogram_pairs)

    def __getitem__(self, index):
        clean_path, noisy_path = self.sinogram_pairs[index]

        # Load both sinograms as 8-bit grayscale; `with` closes the file handles.
        with Image.open(clean_path) as img:
            clean_sinogram = np.array(img.convert('L'))
        with Image.open(noisy_path) as img:
            noisy_sinogram = np.array(img.convert('L'))

        # apply reconstruction algorithm, then normalize the recon images to [0, 1]
        clean_recon = self._min_max_normalize(filterd_back_projection(clean_sinogram))
        noisy_recon = self._min_max_normalize(filterd_back_projection(noisy_sinogram))

        if self.transform is not None:
            clean_recon = self.transform(Image.fromarray(clean_recon))
            noisy_recon = self.transform(Image.fromarray(noisy_recon))

        return clean_recon, noisy_recon

## create dataloaders

In [None]:
def create_dataloaders(clean_dir, noisy_dir, transform=None, batch_size=32, num_workers=4, seed=42):
    """Build train / val / test DataLoaders (80% / 10% / 10% split) for one noise level.

    Args:
        clean_dir: directory of clean sinograms (passed to ``WaterlooPairDataset``).
        noisy_dir: directory of noisy sinograms (passed to ``WaterlooPairDataset``).
        transform: optional transform the dataset applies to each reconstruction.
        batch_size: batch size of all three loaders.
        num_workers: worker processes of all three loaders.
        seed: seed of a dedicated ``torch.Generator`` used for the random split,
            so every call over a dataset of the same length yields the same
            partition. Pass ``None`` to draw the split from PyTorch's global RNG.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader shuffles.
    """
    dataset = WaterlooPairDataset(clean_dir, noisy_dir, transform)

    # calculate split lengths; the test split takes the rounding remainder
    total_len = len(dataset)
    train_len = int(0.8 * total_len)
    val_len = int(0.1 * total_len)
    test_len = total_len - train_len - val_len

    # A dedicated, seeded generator makes the split independent of how much global
    # RNG state earlier cells consumed. With the global RNG, each call (e.g. one per
    # noise level when building test loaders after training) drew a different split,
    # so "test" images could overlap training images and the test sets differed
    # between noise levels.
    # NOTE(review): with seed=42 this presumably reproduces the split of a training
    # run whose first random draw came right after torch.manual_seed(42) — confirm.
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_len, val_len, test_len], **split_kwargs
    )

    # create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    print(f"train:{len(train_loader.dataset)}, val:{len(val_loader.dataset)}, test:{len(test_loader.dataset)}")

    return train_loader, val_loader, test_loader

In [None]:
# class AddChannelDimension:
#     def __call__(self, img):
#         return np.expand_dims(img, axis=0)
    

# define data transform: resize to 256x256, convert to a tensor, then apply
# (x - 0.5) / 0.5, which maps the dataset's [0, 1] reconstructions to [-1, 1].
data_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])


# define input dirs: one clean set and five noisy variants
clean_dir = 'data/exploration_database_and_code/clean'
noisy30_dir = 'data/exploration_database_and_code/noisy30'
noisy25_dir = 'data/exploration_database_and_code/noisy25'
noisy20_dir = 'data/exploration_database_and_code/noisy20'
noisy15_dir = 'data/exploration_database_and_code/noisy15'
noisy10_dir = 'data/exploration_database_and_code/noisy10'


# define dataloader setting
batch_size = 32
num_workers = 12

# create the noisy10 dataloaders (the noise level used for this training run)
train_loader, val_loader, test_loader = create_dataloaders(clean_dir, noisy10_dir, transform=data_transform, batch_size=batch_size, num_workers=num_workers)

In [None]:
# show one training batch
dataiter = iter(train_loader)

# Use the built-in next(): DataLoader iterators do not provide a `.next()`
# method in recent PyTorch releases (calling it raises AttributeError).
batch_clean_tensor, batch_noisy_tensor = next(dataiter)
batch_clean_recons, batch_noisy_recons = batch_clean_tensor.numpy(), batch_noisy_tensor.numpy()

# batch shape and value range of the first clean reconstruction
print(f"batch shape : {batch_clean_recons.shape}")
print(f"feed data, range: {np.min(batch_clean_recons[0])} {np.max(batch_clean_recons[0])}")

# drop the singleton channel axis (axis 1) and show each batch as an image grid
show_images_grid(np.squeeze(batch_clean_recons, axis=1), cmap='gray', figsize=(15, 15), suptitle='clean recon batch')
show_images_grid(np.squeeze(batch_noisy_recons, axis=1), cmap='gray', figsize=(15, 15), suptitle='noisy recon batch')

In [None]:
# check that the clean and noisy reconstructions have similar brightness
idx = 0
clean_recon = np.squeeze(batch_clean_recons[idx], axis=0)
noisy_recon = np.squeeze(batch_noisy_recons[idx], axis=0)

# compare mean intensities of the selected pair
print(f"clean_recon, mean: {np.mean(clean_recon):.5f}")
print(f"noisy_recon, mean: {np.mean(noisy_recon):.5f}")

# plot the error profile between the two reconstructions
show_error_profile(clean_recon, noisy_recon, suptitle="clean_recon vs noisy_recon")

## model training

In [None]:
# training function
def train(train_loader, model, criterion, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Puts ``model`` in train mode, then for each ``(clean_img, noisy_img)`` batch
    computes ``criterion(model(noisy_img), clean_img)``, backpropagates and
    steps ``optimizer``. The returned value is the sum of batch losses divided
    by the number of batches.
    """
    model.train()
    batch_losses = []

    for clean_batch, noisy_batch in train_loader:
        noisy_batch = noisy_batch.to(device)
        clean_batch = clean_batch.to(device)

        optimizer.zero_grad()
        loss = criterion(model(noisy_batch), clean_batch)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(train_loader)

# validation function
def validate(val_loader, model, criterion, device):
    """Return the mean per-batch loss of ``model`` on ``val_loader``.

    The model is put in eval mode and gradients are disabled, so no parameters
    are updated. Each batch is ``(clean_img, noisy_img)`` and the loss is
    ``criterion(model(noisy_img), clean_img)``.
    """
    model.eval()
    total = 0.0

    with torch.no_grad():
        for clean_batch, noisy_batch in val_loader:
            prediction = model(noisy_batch.to(device))
            total += criterion(prediction, clean_batch.to(device)).item()

    return total / len(val_loader)

In [None]:
# set hyperparameters
# SGD optimizer: learning rate and momentum
learning_rate, momentum = 0.01, 0.9
# StepLR scheduler: multiply the learning rate by `gamma` every `step_size` epochs
step_size, gamma = 10, 0.1
# number of training epochs
epochs = 50
# patience = 10 # earlystop patience

In [None]:
# create the model together with its loss, optimizer and LR scheduler
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = DnCNN().to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), momentum=momentum, lr=learning_rate)
scheduler = StepLR(optimizer, gamma=gamma, step_size=step_size)
# early_stopping = EarlyStopping(patience=patience, verbose=True)

print(f"device: {device}")

In [None]:
# print a layer-by-layer summary for a single-channel 256x256 input
summary(model, input_size=(1, 256, 256))

In [None]:
# # writer = SummaryWriter("tf-logs/runs/DnCNN-noisy10-MSE-50epoch")

# # training loop
# start_time = time.time()
# for epoch in range(epochs):
#     train_loss = train(train_loader, model, criterion, optimizer, device)
#     val_loss = validate(val_loader, model, criterion, device)
#     print(f'Epoch: {epoch+1}/{epochs}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}')
    
#     # Log the losses to TensorBoard
#     writer.add_scalars("Losses", {"Train": train_loss, "Val": val_loss}, epoch)

#     # early_stopping(val_loss, model)
#     # if early_stopping.early_stop:
#     #     print("Early stopping")
#     #     break

#     scheduler.step()
    
    
# end_time = time.time()
# total_second = int(end_time - start_time)

In [None]:
# Save model's weight
# torch.save(model.state_dict(), 'checkpoints/DnCNN-noisy10-MSE-50epoch.pt')

In [None]:
# send result to wechat (AutoDL notification endpoint)
# SECURITY: never hard-code the API token here. The token previously written in
# this cell was committed to the source and should be revoked/rotated; read it
# from the environment instead (AUTODL_TOKEN is a placeholder variable name).

# import requests
# headers = {"Authorization": os.environ["AUTODL_TOKEN"]}
# resp = requests.post("https://www.autodl.com/api/v1/wechat/message/send",
#                      json={
#                          "title": "训练结束",  # "Training finished"
#                          "name": "DnCNN-noisy10-MSE-50epoch",
#                          "content": f"training time: {total_second // 3600}h {(total_second % 3600) // 60}m"
#                      }, headers = headers)

# print(resp.content.decode())

## model testing

In [None]:
# define evaluation metrics
def calculate_psnr(pred, target):
    """Return the PSNR of ``pred`` against ``target`` (two tensors).

    A leading size-1 dimension is squeezed, values are clamped to [0, 1] and
    moved to the CPU, and PSNR is computed with ``data_range=1``.
    NOTE: anything outside [0, 1] is clipped, so inputs should already be
    scaled to that range.
    """
    reference = target.squeeze(0).clamp(0, 1).cpu().numpy()
    estimate = pred.squeeze(0).clamp(0, 1).cpu().numpy()
    return psnr(reference, estimate, data_range=1)


def calculate_ssim(pred, target):
    """Return the SSIM of ``pred`` against ``target`` (two tensors).

    A leading size-1 dimension is squeezed, values are clamped to [0, 1] and
    moved to the CPU, and SSIM is computed with ``data_range=1``.
    NOTE: anything outside [0, 1] is clipped, so inputs should already be
    scaled to that range.
    """
    reference = target.squeeze(0).clamp(0, 1).cpu().numpy()
    estimate = pred.squeeze(0).clamp(0, 1).cpu().numpy()
    return ssim(reference, estimate, data_range=1)

def evaluate(model, test_loader, device, value_range=(-1.0, 1.0)):
    """Evaluate a denoiser on every sample of one test loader.

    Runs ``model`` on each ``(clean_img, noisy_img)`` batch and computes per-image
    PSNR / SSIM of the denoised output against the clean image.

    Args:
        model: denoiser mapping a noisy batch to a denoised batch.
        test_loader: iterable of ``(clean_img, noisy_img)`` tensor batches.
        device: device the batches are moved to before inference.
        value_range: ``(low, high)`` intensity range of the loader's tensors.
            ``calculate_psnr`` / ``calculate_ssim`` clamp to [0, 1] with
            ``data_range=1``, so images are linearly rescaled from this range to
            [0, 1] before scoring. The default matches ``data_transform`` in this
            notebook (``Normalize(mean=[0.5], std=[0.5])`` applied to [0, 1]
            images). Pass ``(0.0, 1.0)`` for un-normalized data.

    Returns:
        ``(avg_psnr, avg_ssim, (clean_imgs, noisy_imgs, denoised_imgs))``. The
        image lists hold one squeezed NumPy array per sample, in the loader's
        original (un-rescaled) value range.

    Raises:
        ValueError: if ``value_range`` does not satisfy ``low < high``.
    """
    low, high = value_range
    if not high > low:
        raise ValueError(f"value_range must satisfy low < high, got {value_range}")
    span = high - low

    model.eval()
    psnr_list, ssim_list = [], []
    clean_imgs, noisy_imgs, denoised_imgs = [], [], []

    with torch.no_grad():
        for clean_img, noisy_img in test_loader:
            clean_img, noisy_img = clean_img.to(device), noisy_img.to(device)
            denoised = model(noisy_img)

            # Rescale to [0, 1] for the metric helpers; without this, every value
            # below 0 (half of the [-1, 1] range) would be clipped by clamp(0, 1).
            clean_unit = (clean_img - low) / span
            denoised_unit = (denoised - low) / span

            # Copy each batch to host memory once rather than once per image.
            clean_np = clean_img.cpu().numpy()
            noisy_np = noisy_img.cpu().numpy()
            denoised_np = denoised.cpu().numpy()

            for i in range(clean_img.size(0)):
                psnr_list.append(calculate_psnr(denoised_unit[i], clean_unit[i]))
                ssim_list.append(calculate_ssim(denoised_unit[i], clean_unit[i]))

                # Keep images in the loader's original value range for plotting.
                clean_imgs.append(clean_np[i].squeeze())
                noisy_imgs.append(noisy_np[i].squeeze())
                denoised_imgs.append(denoised_np[i].squeeze())

    avg_psnr = np.mean(psnr_list)
    avg_ssim = np.mean(ssim_list)

    return avg_psnr, avg_ssim, (clean_imgs, noisy_imgs, denoised_imgs)

def get_evaluation_results(model, test_loaders, device):
    """Run ``evaluate`` on every loader in the ``test_loaders`` dict.

    Returns:
        avg_psnr_results, avg_ssim_results: lists with one mean value per
            loader, in the dict's iteration order.
        clean_imgs_results, noisy_imgs_results, denoised_imgs_results:
            ``np.array`` of the per-loader image lists (axis 0 indexes the
            loader when every loader yields the same number and shape of images).
    """
    per_loader = [evaluate(model, loader, device) for loader in test_loaders.values()]

    avg_psnr_results = [mean_psnr for mean_psnr, _, _ in per_loader]
    avg_ssim_results = [mean_ssim for _, mean_ssim, _ in per_loader]

    clean_imgs_results = np.array([images[0] for _, _, images in per_loader])
    noisy_imgs_results = np.array([images[1] for _, _, images in per_loader])
    denoised_imgs_results = np.array([images[2] for _, _, images in per_loader])

    return avg_psnr_results, avg_ssim_results, clean_imgs_results, noisy_imgs_results, denoised_imgs_results

def plot_evaluation_results(x, avg_psnr_results, avg_ssim_results, title="PSNR & SSIM plot", save_path=None):
    """Plot mean PSNR (top) and mean SSIM (bottom) against the labels in ``x``.

    Each point is annotated with its value (4 decimals). If ``save_path`` is
    given, the figure is saved there at 300 dpi (its parent directory is
    created if missing) before being shown.
    """
    fig, (ax_psnr, ax_ssim) = plt.subplots(2, 1, figsize=(10, 5))

    panels = (
        (ax_psnr, avg_psnr_results, 'blue', "PSNR (dB)"),
        # SSIM is a unitless score, not a percentage.
        (ax_ssim, avg_ssim_results, 'red', "SSIM"),
    )
    for ax, values, color, ylabel in panels:
        ax.plot(x, values, color=color)
        ax.set_ylabel(ylabel)
        for xi, yi in zip(x, values):
            ax.annotate(f"{yi:.4f}", xy=(xi, yi), xycoords='data', xytext=(0, 10),
                        textcoords='offset points', ha='center', va='bottom')

    fig.subplots_adjust(hspace=0.5)
    fig.suptitle(title)

    if save_path is not None:
        # savefig fails if the target directory (e.g. "media/") does not exist.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
    

def plot_noise_level_images(x, clean_imgs_results, noisy_imgs_results, denoised_imgs_results, idx=0):
    """Show the clean, noisy and denoised version of sample ``idx`` for each test set.

    One row per test set (row ``i`` is labelled with ``x[i]``) and three
    columns: clean, noisy, denoised. The image arrays are indexed as
    ``[test_set, sample, ...]``; presumably ``(test_set, sample, H, W)`` as
    produced by ``get_evaluation_results`` — confirm for other callers.
    """
    num_noise_levels = clean_imgs_results.shape[0]

    # squeeze=False keeps `axs` 2-D even for a single row, so axs[i, j] always works.
    fig, axs = plt.subplots(num_noise_levels, 3, figsize=(20, num_noise_levels * 5), squeeze=False)

    columns = (
        ("Clean", clean_imgs_results),
        ("Noisy", noisy_imgs_results),
        ("Denoised", denoised_imgs_results),
    )
    for i in range(num_noise_levels):
        for j, (label, images) in enumerate(columns):
            ax = axs[i, j]
            ax.imshow(images[i, idx], cmap='gray')
            ax.set_title(f"{label} ({x[i]})")
            # Remove axis ticks and labels
            ax.set_xticks([])
            ax.set_yticks([])

    # Adjust the spacing between subplots
    fig.subplots_adjust(hspace=0.2, wspace=0.3)



In [None]:
# build a fresh model to receive the checkpoint weights loaded below
# np.random.seed(10)
# torch.manual_seed(10)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = DnCNN().to(device)

In [None]:
# create one test dataloader per noisy directory (only the test split is kept)
test_loaders = {
    name: create_dataloaders(clean_dir, noisy_dir, transform=data_transform, batch_size=batch_size, num_workers=num_workers)[-1]
    for name, noisy_dir in (
        ("noisy30", noisy30_dir),
        ("noisy25", noisy25_dir),
        ("noisy20", noisy20_dir),
        ("noisy15", noisy15_dir),
        ("noisy10", noisy10_dir),
    )
}

# plot labels, in the same order as test_loaders
x = list(test_loaders)

### 1. Evaluate the model trained on `noisy30`

In [None]:
# load trained network weights; map_location lets a GPU-saved checkpoint load on a CPU-only machine
model.load_state_dict(torch.load('checkpoints/DnCNN-noisy30-MSE-50epoch.pt', map_location=device))

In [None]:
# evaluate the noisy30-trained network on the test set of every noise level
(
    avg_psnr_results,
    avg_ssim_results,
    clean_imgs_results,
    noisy_imgs_results,
    denoised_imgs_results,
) = get_evaluation_results(model=model, test_loaders=test_loaders, device=device)

In [None]:
# plot mean PSNR and SSIM per test set and save the figure
plot_evaluation_results(
    x=x,
    avg_psnr_results=avg_psnr_results,
    avg_ssim_results=avg_ssim_results,
    title="PSNR & SSIM plot(noisy30)",
    save_path="media/PSNR & SSIM plot(noisy30).png",
)

In [None]:
# show clean / noisy / denoised versions of sample 0 for every test set
plot_noise_level_images(
    x=x,
    clean_imgs_results=clean_imgs_results,
    noisy_imgs_results=noisy_imgs_results,
    denoised_imgs_results=denoised_imgs_results,
    idx=0,
)

### 2. Evaluate the model trained on `noisy25`

In [None]:
# load trained network weights; map_location lets a GPU-saved checkpoint load on a CPU-only machine
model.load_state_dict(torch.load('checkpoints/DnCNN-noisy25-MSE-50epoch.pt', map_location=device))

In [None]:
# evaluate the noisy25-trained network on the test set of every noise level
(
    avg_psnr_results,
    avg_ssim_results,
    clean_imgs_results,
    noisy_imgs_results,
    denoised_imgs_results,
) = get_evaluation_results(model=model, test_loaders=test_loaders, device=device)

In [None]:
# plot mean PSNR and SSIM per test set and save the figure
plot_evaluation_results(
    x=x,
    avg_psnr_results=avg_psnr_results,
    avg_ssim_results=avg_ssim_results,
    title="PSNR & SSIM plot(noisy25)",
    save_path="media/PSNR & SSIM plot(noisy25).png",
)

In [None]:
# show clean / noisy / denoised versions of sample 0 for every test set
plot_noise_level_images(
    x=x,
    clean_imgs_results=clean_imgs_results,
    noisy_imgs_results=noisy_imgs_results,
    denoised_imgs_results=denoised_imgs_results,
    idx=0,
)

### 3. Evaluate the model trained on `noisy20`

In [None]:
# load trained network weights; map_location lets a GPU-saved checkpoint load on a CPU-only machine
model.load_state_dict(torch.load('checkpoints/DnCNN-noisy20-MSE-50epoch.pt', map_location=device))

In [None]:
# evaluate the noisy20-trained network on the test set of every noise level
(
    avg_psnr_results,
    avg_ssim_results,
    clean_imgs_results,
    noisy_imgs_results,
    denoised_imgs_results,
) = get_evaluation_results(model=model, test_loaders=test_loaders, device=device)

In [None]:
# plot mean PSNR and SSIM per test set and save the figure
plot_evaluation_results(
    x=x,
    avg_psnr_results=avg_psnr_results,
    avg_ssim_results=avg_ssim_results,
    title="PSNR & SSIM plot(noisy20)",
    save_path="media/PSNR & SSIM plot(noisy20).png",
)

In [None]:
# show clean / noisy / denoised versions of sample 0 for every test set
plot_noise_level_images(
    x=x,
    clean_imgs_results=clean_imgs_results,
    noisy_imgs_results=noisy_imgs_results,
    denoised_imgs_results=denoised_imgs_results,
    idx=0,
)

### 4. Evaluate the model trained on `noisy15`

In [None]:
# load trained network weights; map_location lets a GPU-saved checkpoint load on a CPU-only machine
model.load_state_dict(torch.load('checkpoints/DnCNN-noisy15-MSE-50epoch.pt', map_location=device))

In [None]:
# evaluate the noisy15-trained network on the test set of every noise level
(
    avg_psnr_results,
    avg_ssim_results,
    clean_imgs_results,
    noisy_imgs_results,
    denoised_imgs_results,
) = get_evaluation_results(model=model, test_loaders=test_loaders, device=device)

In [None]:
# plot mean PSNR and SSIM per test set and save the figure
plot_evaluation_results(
    x=x,
    avg_psnr_results=avg_psnr_results,
    avg_ssim_results=avg_ssim_results,
    title="PSNR & SSIM plot(noisy15)",
    save_path="media/PSNR & SSIM plot(noisy15).png",
)

In [None]:
# show clean / noisy / denoised versions of sample 0 for every test set
plot_noise_level_images(
    x=x,
    clean_imgs_results=clean_imgs_results,
    noisy_imgs_results=noisy_imgs_results,
    denoised_imgs_results=denoised_imgs_results,
    idx=0,
)

### 5. Evaluate the model trained on `noisy10`

In [None]:
# load the noisy10-trained network weights (this section previously loaded the
# noisy15 checkpoint by mistake); map_location lets a GPU-saved checkpoint load
# on a CPU-only machine
model.load_state_dict(torch.load('checkpoints/DnCNN-noisy10-MSE-50epoch.pt', map_location=device))

In [None]:
# evaluate the noisy10-trained network on the test set of every noise level
(
    avg_psnr_results,
    avg_ssim_results,
    clean_imgs_results,
    noisy_imgs_results,
    denoised_imgs_results,
) = get_evaluation_results(model=model, test_loaders=test_loaders, device=device)

In [None]:
# plot mean PSNR and SSIM per test set and save the figure
plot_evaluation_results(
    x=x,
    avg_psnr_results=avg_psnr_results,
    avg_ssim_results=avg_ssim_results,
    title="PSNR & SSIM plot(noisy10)",
    save_path="media/PSNR & SSIM plot(noisy10).png",
)

In [None]:
# show clean / noisy / denoised versions of sample 0 for every test set
plot_noise_level_images(
    x=x,
    clean_imgs_results=clean_imgs_results,
    noisy_imgs_results=noisy_imgs_results,
    denoised_imgs_results=denoised_imgs_results,
    idx=0,
)