# Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising
https://paperswithcode.com/paper/beyond-a-gaussian-denoiser-residual-learning

In [2]:
# import libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter


import os
import random

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2


from skimage.transform import radon, iradon
from skimage.metrics import peak_signal_noise_ratio


from data_utils import filterd_back_projection, split_dataset

## define DnCNN

In [3]:
# define neural network
class DnCNN(nn.Module):
    """Residual denoising CNN.

    The layer stack is kept in ``self.dncnn`` and is laid out as:
    Conv+ReLU, then ``num_layers - 2`` repetitions of Conv+BatchNorm+ReLU,
    then a final Conv. ``forward`` subtracts the stack's output from its
    input, so the stack plays the role of a noise estimator.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the final convolution.
        num_layers: total number of convolution layers in the stack.
        num_features: channel width of every hidden layer.
    """

    def __init__(self, in_channels=1, out_channels=1, num_layers=17, num_features=64):
        super().__init__()

        def conv3x3(c_in, c_out):
            # 3x3 kernel with padding 1 keeps the spatial size unchanged.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        # Head: no batch norm on the first layer.
        stack = [conv3x3(in_channels, num_features), nn.ReLU(inplace=True)]

        # Body: every layer except the first and the last.
        hidden_count = num_layers - 2
        for _ in range(hidden_count):
            stack.append(conv3x3(num_features, num_features))
            stack.append(nn.BatchNorm2d(num_features))
            stack.append(nn.ReLU(inplace=True))

        # Tail: project back to the output channel count, no activation.
        stack.append(conv3x3(num_features, out_channels))

        self.dncnn = nn.Sequential(*stack)

    def forward(self, x):
        """Return ``x`` minus the stack's prediction for ``x``."""
        # NOTE(review): the subtraction assumes the stack output broadcasts
        # against x (e.g. in_channels == out_channels) -- confirm for
        # non-default channel settings.
        predicted = self.dncnn(x)
        return x - predicted

## define dataset, dataloader

In [4]:
# define custom dataset
class WaterlooPairDataset(Dataset):
    """Dataset of (clean, noisy) grayscale image pairs read from two folders.

    Both directory listings are sorted, and item ``i`` pairs the i-th clean
    file with the i-th noisy file. The pairing is therefore purely positional.
    NOTE(review): this presumes both folders use matching file names --
    confirm against how the noisy images were generated.

    Args:
        clean_dir: directory holding the clean images.
        noisy_dir: directory holding the noisy images.
        transform: optional callable applied independently to each image of
            a pair.

    Raises:
        AssertionError: if the two directories contain a different number of
            entries.
    """

    def __init__(self, clean_dir, noisy_dir, transform=None):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting both lists
        # makes the positional pairing deterministic.
        self.clean_files = sorted(os.listdir(clean_dir))
        self.noisy_files = sorted(os.listdir(noisy_dir))

        assert len(self.clean_files) == len(self.noisy_files), \
            "Number of clean files and noisy files should be equal"

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.clean_files)

    def __getitem__(self, index):
        """Load and return the ``(clean, noisy)`` pair at ``index``.

        Raises:
            OSError: if either image cannot be read.
        """
        clean_path = os.path.join(self.clean_dir, self.clean_files[index])
        noisy_path = os.path.join(self.noisy_dir, self.noisy_files[index])

        # Flag 0 is cv2.IMREAD_GRAYSCALE: load as a single-channel image.
        clean_img = cv2.imread(clean_path, 0)
        noisy_img = cv2.imread(noisy_path, 0)

        # cv2.imread reports failure by returning None instead of raising,
        # so fail here with the offending path rather than deep inside the
        # transform.
        if clean_img is None:
            raise OSError(f"Could not read image: {clean_path}")
        if noisy_img is None:
            raise OSError(f"Could not read image: {noisy_path}")

        if self.transform is not None:
            clean_img = self.transform(clean_img)
            noisy_img = self.transform(noisy_img)

        # The original returned undefined names (clean_sinogram /
        # noisy_sinogram), which raised NameError on every access.
        return clean_img, noisy_img

In [31]:
# 
# Locations of the clean images and of the pre-generated noisy variants.
clean_dir = 'data/exploration_database_and_code/clean'
noisy30_dir = 'data/exploration_database_and_code/noisy30'
noisy15_dir = 'data/exploration_database_and_code/noisy15'
noisy10_dir = 'data/exploration_database_and_code/noisy10'


# define data transform: convert to a tensor, then compute (x - 0.5) / 0.5
data_transfrom = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# create dataset
# The constructor's keyword is `transform` (singular); the previous
# `transforms=` spelling raised TypeError before any data was loaded.
noisy30_pair_dataset = WaterlooPairDataset(clean_dir, noisy30_dir, transform=data_transfrom)

# split_dataset comes from data_utils; presumably it returns the
# train/val/test subsets in that order -- verify against data_utils.
train_dataset, val_dataset, test_dataset = split_dataset(noisy30_pair_dataset)

print(f"train_dataset {len(train_dataset)}")
print(f"val_dataset {len(val_dataset)}")
print(f"test_dataset {len(test_dataset)}")

train_dataset : 3795
val_dataset : 474
test_dataset : 475


In [None]:
# Batch the three splits. Only the training loader shuffles; validation and
# test keep a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)

In [6]:
def create_denoising_datasets(clean_dir, noisy_dir=None, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Build train/val/test splits of a Poisson-noise dataset from a folder.

    Every entry of ``clean_dir`` is listed, shuffled and wrapped in a
    ``PoissonNoisyDataset``, which is then divided with ``random_split``.

    Args:
        clean_dir: directory whose entries are used as the source images.
        noisy_dir: accepted for backward compatibility and ignored; this
            function never reads a noisy folder. It defaults to ``None`` so
            the function can be called with the image folder alone.
        train_ratio: fraction of the images assigned to the training split.
        val_ratio: fraction of the images assigned to the validation split.
        test_ratio: unused; the test split always receives whatever remains
            after the training and validation splits are taken.

    Returns:
        A ``(train_dataset, val_dataset, test_dataset)`` tuple.

    Raises:
        ValueError: if a ratio is negative or ``train_ratio + val_ratio``
            exceeds 1.
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1"
        )

    # read local images
    # The original body referenced an undefined `root_path`; the image root
    # is the first positional argument.
    image_paths = [os.path.join(clean_dir, img_name) for img_name in os.listdir(clean_dir)]
    random.shuffle(image_paths)

    total_images = len(image_paths)
    train_size = int(total_images * train_ratio)
    val_size = int(total_images * val_ratio)
    # The remainder goes to the test split so the three sizes always sum to
    # total_images, as random_split requires.
    test_size = total_images - train_size - val_size

    # define data transform
    data_transfrom = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    # create dataset with designed poisson noise
    # NOTE(review): PoissonNoisyDataset is not defined in the visible part of
    # this file -- confirm it is defined or imported before this cell runs.
    dataset = PoissonNoisyDataset(image_paths, lam=5, scale=1, transform=data_transfrom)

    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

    return train_dataset, val_dataset, test_dataset



# Build the train/val/test splits from the images under data_root.
# NOTE: this rebinds train_dataset / val_dataset / test_dataset.
data_root = "data/exploration_database_and_code/pristine_images/"
(
    train_dataset,
    val_dataset,
    test_dataset,
) = create_denoising_datasets(data_root)