In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import argparse
import time
import matplotlib.pyplot as plt
from torchvision.io import read_image, write_png

In [16]:
parser = argparse.ArgumentParser()
parser.add_argument("--sigma", type=float, dest="sigma",
                    help="Standard deviation of the noise (noise level). Should be between 0 and 50.", default=30)
parser.add_argument("--in", type=str, dest="path_in",
                    help="Path to the image to denoise (PNG or JPEG).", default="./test_images/cameraman.png")
parser.add_argument("--out", type=str, dest="path_out",
                    help="Path to save the denoised image.", default="./img_lichi.png")
parser.add_argument("--add_noise", action='store_true',
                    help="Add artificial Gaussian noise to the image.", default=True)

# To avoid conflicts with Jupyter's arguments, pass an empty list to parse_args()
args = parser.parse_args(args=[])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Read image and Add noise
img = read_image(args.path_in)[None, :, :, :].float().to(device)
img_noisy = img + args.sigma * torch.randn_like(img) if args.add_noise else img

In [3]:
class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        x = self.conv(x)
        return x

In [4]:
class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x

In [26]:
class U_Net(nn.Module):
    def __init__(self, img_ch=1, output_ch=1):
        super(U_Net, self).__init__()

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.conv2 = conv_block(ch_in=64, ch_out=128)
        self.conv3 = conv_block(ch_in=128, ch_out=256)
        self.conv4 = conv_block(ch_in=256, ch_out=512)
        self.conv5 = conv_block(ch_in=512, ch_out=1024)

        self.up5 = up_conv(ch_in=1024, ch_out=512)
        self.up_conv5 = conv_block(ch_in=1024, ch_out=512)
        self.up4 = up_conv(ch_in=512, ch_out=256)
        self.up_conv4 = conv_block(ch_in=512, ch_out=256)
        self.up3 = up_conv(ch_in=256, ch_out=128)
        self.up_conv3 = conv_block(ch_in=256, ch_out=128)
        self.up2 = up_conv(ch_in=128, ch_out=64)
        self.up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.up_conv1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.maxpool(x1)
        x2 = self.conv2(x2)
        x3 = self.maxpool(x2)
        x3 = self.conv3(x3)
        x4 = self.maxpool(x3)
        x4 = self.conv4(x4)
        x5 = self.maxpool(x4)
        x5 = self.conv5(x5)

        d5 = self.up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.up_conv5(d5)

        d4 = self.up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.up_conv4(d4)

        d3 = self.up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.up_conv3(d3)

        d2 = self.up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.up_conv2(d2)

        d1 = self.up_conv1(d2)
        return d1

In [27]:
model = U_Net(img_ch=1, output_ch=1)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

In [38]:
num_epochs = 200
for epoch in range(num_epochs):
    model.train()
    output = model(x)
    loss = criterion(output, img)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.6f}")

Epoch 1/100, Loss: 234.239731
Epoch 2/100, Loss: 232.548737
Epoch 3/100, Loss: 228.991150
Epoch 4/100, Loss: 219.061920
Epoch 5/100, Loss: 217.956680
Epoch 6/100, Loss: 213.243607
Epoch 7/100, Loss: 206.292557
Epoch 8/100, Loss: 203.736053
Epoch 9/100, Loss: 204.720245
Epoch 10/100, Loss: 201.171692
Epoch 11/100, Loss: 198.926086
Epoch 12/100, Loss: 192.928177
Epoch 13/100, Loss: 186.848633
Epoch 14/100, Loss: 187.060272
Epoch 15/100, Loss: 182.014069
Epoch 16/100, Loss: 178.458984
Epoch 17/100, Loss: 173.670807
Epoch 18/100, Loss: 174.541367
Epoch 19/100, Loss: 175.195038
Epoch 20/100, Loss: 173.221436
Epoch 21/100, Loss: 169.434143
Epoch 22/100, Loss: 166.464371
Epoch 23/100, Loss: 169.682938
Epoch 24/100, Loss: 179.470367
Epoch 25/100, Loss: 166.946503
Epoch 26/100, Loss: 157.082611
Epoch 27/100, Loss: 161.061310
Epoch 28/100, Loss: 155.353546
Epoch 29/100, Loss: 150.937485
Epoch 30/100, Loss: 151.647095
Epoch 31/100, Loss: 152.720413
Epoch 32/100, Loss: 146.687195
Epoch 33/100, Los

In [39]:
output = model(x)

psnr = 10*torch.log10(255**2 / torch.mean((output - img)**2))
print("PSNR LLR:", round(float(psnr), 2), "dB")

PSNR LLR: 29.58 dB
