In [2]:
import torch
import torchvision
import numpy as np
import pandas as pd
import os
import cv2
from matplotlib import pyplot as plt

In [None]:
!rm -rf *
!pip install -U gdown
!gdown 1E1poQDBsRp53Ip9qndB2nPLu1po3AbEp
!gdown 18UXIh3mGA_M0oxlapPIrUG-VVpF6PN2S
!unzip Phase2.zip

In [4]:
class DenoisingAutoEncoder(torch.nn.Module):
    def __init__(self, in_channels):
        super(DenoisingAutoEncoder, self).__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(num_features=128),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(num_features=128),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),

            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(num_features=128),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(num_features=128),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(num_features=128),
            torch.nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=(2, 2), stride=(2, 2)),
            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(num_features=128),
            torch.nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=(2, 2), stride=(2, 2)),

            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(num_features=128),
            torch.nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=(2, 2), stride=(2, 2)),
            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(num_features=128),
            torch.nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=(2, 2), stride=(2, 2)),

            torch.nn.Conv2d(in_channels=128, out_channels=in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            torch.nn.Sigmoid()
        )


    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [6]:
in_channels = 1
learning_rate = 0.001
batch_size = 8
num_epochs = 30

In [43]:
class NoisyImagesDataset(torch.utils.data.Dataset):
    def __init__(self, noise_type, root_dir='DataSet2', spreadsheet_file='Labels.xlsx', transform=None): # 'Gaussian' 'Periodic' 'Salt'
        self.annotations = pd.read_excel(spreadsheet_file)
        self.annotations = self.annotations[self.annotations['Noise Type'] == noise_type].drop(columns='Noise Type')[['Noisy Image', 'Denoise Image']]
        self.root_dir = root_dir
        self.transform = transform
        self.noise_type = noise_type

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        noisy_img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 1])
        noisy_image = cv2.imread(filename=noisy_img_path, flags=cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255
        image = cv2.imread(filename=img_path, flags=cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255
        if self.transform:
            image = self.transform(image)
            noisy_image = self.transform(noisy_image)
        return noisy_image, image

gaussian_dataset = NoisyImagesDataset(noise_type='Gaussian', transform=torchvision.transforms.ToTensor())
gaussian_train_dataset, gaussian_test_dataset = torch.utils.data.random_split(gaussian_dataset, [319, 80])
gaussian_train_loader = torch.utils.data.DataLoader(dataset=gaussian_train_dataset, batch_size=batch_size, shuffle=True)
gaussian_test_loader = torch.utils.data.DataLoader(dataset=gaussian_test_dataset, batch_size=batch_size, shuffle=True)

periodic_dataset = NoisyImagesDataset(noise_type='Periodic', transform=torchvision.transforms.ToTensor())
periodic_train_dataset, periodic_test_dataset = torch.utils.data.random_split(periodic_dataset, [360, 91])
periodic_train_loader = torch.utils.data.DataLoader(dataset=periodic_train_dataset, batch_size=batch_size, shuffle=True)
periodic_test_loader = torch.utils.data.DataLoader(dataset=periodic_test_dataset, batch_size=batch_size, shuffle=True)

salt_dataset = NoisyImagesDataset(noise_type='Salt', transform=torchvision.transforms.ToTensor())
salt_train_dataset, salt_test_dataset = torch.utils.data.random_split(salt_dataset, [320, 80])
salt_train_loader = torch.utils.data.DataLoader(dataset=salt_train_dataset, batch_size=batch_size, shuffle=True)
salt_test_loader = torch.utils.data.DataLoader(dataset=salt_test_dataset, batch_size=batch_size, shuffle=True)

In [8]:
model = DenoisingAutoEncoder(in_channels=in_channels).to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [41]:
from torchmetrics.functional.image import learned_perceptual_image_patch_similarity
from torchmetrics.functional.image import peak_signal_noise_ratio
from torchmetrics.functional.image import structural_similarity_index_measure


def check_performance(loader, denoiser):
    model.eval()
    psnr = 0
    lpips = 0
    ssim = 0
    count = 0
    with torch.no_grad():
        for noisy_image_gray, noise_free_image_gray in loader:
            noisy_image_gray = noisy_image_gray.to(device)
            noise_free_image_gray = noise_free_image_gray.to(device)

            output = denoiser(noisy_image_gray)
            
            lpips += noisy_image_gray.shape[0] * learned_perceptual_image_patch_similarity(noise_free_image_gray.repeat(1, 3, 1, 1), output.repeat(1, 3, 1, 1), net_type='alex', normalize=True)
            psnr += noisy_image_gray.shape[0] * peak_signal_noise_ratio(output, noise_free_image_gray, data_range=(0, 1))
            ssim += noisy_image_gray.shape[0] * structural_similarity_index_measure(output, noise_free_image_gray, data_range=(0, 1))
            count += noisy_image_gray.shape[0]
    model.train()
    return psnr / count, ssim / count, lpips / count

In [9]:
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(gaussian_train_loader):
        data = data.to(device=device)
        targets = targets.to(device=device)

        scores = model(data)
        loss = criterion(scores, targets)

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
    print(f'Loss at epoch {epoch} was {loss}')

Loss at epoch 0 was 0.004186984151601791
Loss at epoch 1 was 0.003701667068526149
Loss at epoch 2 was 0.005305266939103603
Loss at epoch 3 was 0.002752698725089431
Loss at epoch 4 was 0.0019535550381988287
Loss at epoch 5 was 0.001668474287725985
Loss at epoch 6 was 0.0015025107422843575
Loss at epoch 7 was 0.0020987791940569878
Loss at epoch 8 was 0.0012610041303560138
Loss at epoch 9 was 0.0010946514084935188
Loss at epoch 10 was 0.0016950821736827493
Loss at epoch 11 was 0.0013403536286205053
Loss at epoch 12 was 0.003376695327460766
Loss at epoch 13 was 0.0011945064179599285
Loss at epoch 14 was 0.001443510758690536
Loss at epoch 15 was 0.0012131485855206847
Loss at epoch 16 was 0.002522748429328203
Loss at epoch 17 was 0.0010677635436877608
Loss at epoch 18 was 0.001199766411446035
Loss at epoch 19 was 0.0010867835953831673
Loss at epoch 20 was 0.0011253263801336288
Loss at epoch 21 was 0.0017491122707724571
Loss at epoch 22 was 0.0009218865889124572
Loss at epoch 23 was 0.0007500

In [42]:
psnr, ssim, lpips = check_performance(gaussian_train_loader, model)
print(f"PSNR mean: {psnr:.4}\nSSIM mean: {ssim:.4}\nLPIPS mean: {lpips:.4}")

PSNR mean: 31.43
SSIM mean: 0.8978
LPIPS mean: 0.1796


In [44]:
psnr, ssim, lpips = check_performance(gaussian_test_loader, model)
print(f"PSNR mean: {psnr:.4}\nSSIM mean: {ssim:.4}\nLPIPS mean: {lpips:.4}")

PSNR mean: 31.29
SSIM mean: 0.8981
LPIPS mean: 0.1823


In [45]:
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(salt_train_loader):
        data = data.to(device=device)
        targets = targets.to(device=device)

        scores = model(data)
        loss = criterion(scores, targets)

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
    print(f'Loss at epoch {epoch} was {loss}')

Loss at epoch 0 was 0.0008832307066768408
Loss at epoch 1 was 0.0009629302658140659
Loss at epoch 2 was 0.001117172185331583
Loss at epoch 3 was 0.0015351871261373162
Loss at epoch 4 was 0.0010546728735789657
Loss at epoch 5 was 0.0006712418980896473
Loss at epoch 6 was 0.0004528420395217836
Loss at epoch 7 was 0.0014227894134819508
Loss at epoch 8 was 0.0011205270420759916
Loss at epoch 9 was 0.0006742628756910563
Loss at epoch 10 was 0.0013149648439139128
Loss at epoch 11 was 0.0010858721798285842
Loss at epoch 12 was 0.000611037015914917
Loss at epoch 13 was 0.0009519215091131628
Loss at epoch 14 was 0.0008486582664772868
Loss at epoch 15 was 0.0006593860452994704
Loss at epoch 16 was 0.0010011414997279644
Loss at epoch 17 was 0.0006958458689041436
Loss at epoch 18 was 0.0008002764661796391
Loss at epoch 19 was 0.0007245963206514716
Loss at epoch 20 was 0.0006785043515264988
Loss at epoch 21 was 0.0008716938318684697
Loss at epoch 22 was 0.0007666379678994417
Loss at epoch 23 was 0.

In [46]:
psnr, ssim, lpips = check_performance(salt_train_loader, model)
print(f"PSNR mean: {psnr:.4}\nSSIM mean: {ssim:.4}\nLPIPS mean: {lpips:.4}")

PSNR mean: 31.21
SSIM mean: 0.909
LPIPS mean: 0.1747


In [47]:
psnr, ssim, lpips = check_performance(salt_test_loader, model)
print(f"PSNR mean: {psnr:.4}\nSSIM mean: {ssim:.4}\nLPIPS mean: {lpips:.4}")

PSNR mean: 30.99
SSIM mean: 0.9066
LPIPS mean: 0.1802


In [48]:
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(periodic_train_loader):
        data = data.to(device=device)
        targets = targets.to(device=device)

        scores = model(data)
        loss = criterion(scores, targets)

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
    print(f'Loss at epoch {epoch} was {loss}')

Loss at epoch 0 was 0.00978017132729292
Loss at epoch 1 was 0.00809676293283701
Loss at epoch 2 was 0.005208727903664112
Loss at epoch 3 was 0.005853255279362202
Loss at epoch 4 was 0.005594780668616295
Loss at epoch 5 was 0.0038132229819893837
Loss at epoch 6 was 0.0031190761364996433
Loss at epoch 7 was 0.002788532990962267
Loss at epoch 8 was 0.0028617081698030233
Loss at epoch 9 was 0.00289516500197351
Loss at epoch 10 was 0.0035722162574529648
Loss at epoch 11 was 0.0028106672689318657
Loss at epoch 12 was 0.003088792786002159
Loss at epoch 13 was 0.0027675097808241844
Loss at epoch 14 was 0.0020893793553113937
Loss at epoch 15 was 0.003744495566934347
Loss at epoch 16 was 0.002473799278959632
Loss at epoch 17 was 0.0032352718990296125
Loss at epoch 18 was 0.002390999346971512
Loss at epoch 19 was 0.0016637505032122135
Loss at epoch 20 was 0.0023125119041651487
Loss at epoch 21 was 0.0035929980222135782
Loss at epoch 22 was 0.003120951820164919
Loss at epoch 23 was 0.0025263347197

In [49]:
psnr, ssim, lpips = check_performance(periodic_train_loader, model)
print(f"PSNR mean: {psnr:.4}\nSSIM mean: {ssim:.4}\nLPIPS mean: {lpips:.4}")

PSNR mean: 27.06
SSIM mean: 0.882
LPIPS mean: 0.1837


In [50]:
psnr, ssim, lpips = check_performance(periodic_test_loader, model)
print(f"PSNR mean: {psnr:.4}\nSSIM mean: {ssim:.4}\nLPIPS mean: {lpips:.4}")

PSNR mean: 25.55
SSIM mean: 0.8762
LPIPS mean: 0.1942
