In [1]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torchvision.io import read_image, write_png

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Seed PyTorch's global RNG so that the noise realisation drawn below (and any
# later random draw made in the same kernel session) is identical on every
# Restart & Run All.
seed = 0
torch.manual_seed(seed)

# Standard deviation of the additive Gaussian noise, expressed in the same
# units as the pixel values returned by `read_image`.
sigma = 30
path_in = './test_images/cameraman.png'

# `read_image` returns a (channels, height, width) tensor; the leading `None`
# index adds a batch dimension so the result can be fed to 2-D conv layers.
img = read_image(path_in)[None, :, :, :].float().to(device)
# Noisy observation: clean image plus zero-mean Gaussian noise with std `sigma`.
img_noisy = img + sigma * torch.randn_like(img)

In [3]:
class conv_block(nn.Module):
    """3x3 convolution followed by batch normalisation and ReLU.

    The convolution uses stride 1 and padding 1, so the spatial size of the
    input is preserved and only the number of channels changes.

    Args:
        ch_in: number of channels of the input tensor.
        ch_out: number of channels of the output tensor.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        stages = [
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        # The attribute stays named `conv` so that parameter names in the
        # state dict (`conv.0.weight`, ...) are unchanged.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through conv -> batch norm -> ReLU and return the result."""
        return self.conv(x)
        
class up_conv(nn.Module):
    """Upsample by a factor of 2, then apply 3x3 conv, batch norm and ReLU.

    `nn.Upsample` is used with its default interpolation mode; the 3x3
    convolution (stride 1, padding 1) keeps the upsampled spatial size.

    Args:
        ch_in: number of channels of the input tensor.
        ch_out: number of channels of the output tensor.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        stages = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        # The attribute stays named `up` so that parameter names in the
        # state dict (`up.1.weight`, ...) are unchanged.
        self.up = nn.Sequential(*stages)

    def forward(self, x):
        """Return `x` upsampled 2x and passed through conv -> batch norm -> ReLU."""
        return self.up(x)

class U_Net(nn.Module):
    """U-Net style encoder/decoder with skip connections.

    The encoder applies five `conv_block`s (64, 128, 256, 512, 1024 channels)
    separated by four 2x2 max-pooling steps. The decoder mirrors it: each
    stage upsamples with `up_conv`, concatenates the matching encoder feature
    map along the channel axis, and fuses the result with a `conv_block`.
    A final 1x1 convolution maps the 64 decoder channels to `output_ch`; no
    activation is applied to the output.

    Inputs whose height or width is not a multiple of 16 are replicate-padded
    on the bottom/right before the encoder and the output is cropped back, so
    the returned tensor always has the same spatial size as the input.

    Args:
        img_ch: number of channels of the input image.
        output_ch: number of channels of the output.
    """

    def __init__(self, img_ch=1, output_ch=1):
        super().__init__()

        # Shared, parameter-free 2x downsampling used between encoder stages.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.conv2 = conv_block(ch_in=64, ch_out=128)
        self.conv3 = conv_block(ch_in=128, ch_out=256)
        self.conv4 = conv_block(ch_in=256, ch_out=512)
        self.conv5 = conv_block(ch_in=512, ch_out=1024)

        # Decoder. Each `up_convN` block takes twice the channels produced by
        # `upN` because the encoder skip connection is concatenated first.
        self.up5 = up_conv(ch_in=1024, ch_out=512)
        self.up_conv5 = conv_block(ch_in=1024, ch_out=512)
        self.up4 = up_conv(ch_in=512, ch_out=256)
        self.up_conv4 = conv_block(ch_in=512, ch_out=256)
        self.up3 = up_conv(ch_in=256, ch_out=128)
        self.up_conv3 = conv_block(ch_in=256, ch_out=128)
        self.up2 = up_conv(ch_in=128, ch_out=64)
        self.up_conv2 = conv_block(ch_in=128, ch_out=64)

        # 1x1 projection to the requested number of output channels.
        self.up_conv1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Run the network on a (batch, img_ch, height, width) tensor.

        Returns:
            Tensor of shape (batch, output_ch, height, width).
        """
        height, width = x.shape[-2:]
        # Four 2x poolings followed by four 2x upsamplings only reproduce the
        # skip-connection sizes when height and width are multiples of
        # 2**4 = 16; otherwise `torch.cat` below fails on mismatched shapes.
        # Pad up to the next multiple (no-op for sizes that already fit).
        pad_h = (-height) % 16
        pad_w = (-width) % 16
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        # Encoder path: keep every stage's output for the skip connections.
        x1 = self.conv1(x)
        x2 = self.conv2(self.maxpool(x1))
        x3 = self.conv3(self.maxpool(x2))
        x4 = self.conv4(self.maxpool(x3))
        x5 = self.conv5(self.maxpool(x4))

        # Decoder path: upsample, concatenate the skip (encoder features
        # first), then fuse.
        d5 = self.up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.up_conv5(d5)

        d4 = self.up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.up_conv4(d4)

        d3 = self.up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.up_conv3(d3)

        d2 = self.up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.up_conv2(d2)

        d1 = self.up_conv1(d2)
        if pad_h or pad_w:
            # Drop the rows/columns that were added by the padding above.
            d1 = d1[..., :height, :width]
        return d1

In [4]:
# Fixed random input fed to the network at every iteration. `randn_like`
# gives it the same shape, dtype and device as `img_noisy`.
x = torch.randn_like(img_noisy)
# The model has to live on the same device as its input; without `.to(device)`
# the weights stay on the CPU and the forward pass raises a device-mismatch
# error as soon as CUDA is available.
model = U_Net(img_ch=1, output_ch=1).to(device)
criterion = nn.MSELoss()
# The optimizer is created after the model has been moved, so it holds
# references to the parameters on their final device.
optimizer = optim.Adam(model.parameters(), lr=1e-2)

In [5]:
t = time.time()
num_epochs = 200
# Training mode only needs to be selected once; nothing inside the loop
# switches the model to eval mode.
model.train()
for epoch in range(num_epochs):
    output = model(x)
    # Fit the network output to the *noisy* observation. Using the clean
    # image `img` as the target would leak the ground truth into training and
    # make the PSNR reported afterwards meaningless as a denoising score.
    loss = criterion(output, img_noisy)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Report the loss on the first iteration and every 10th one after that.
    if epoch % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.6f}")

print("Time elapsed U-Net:", round(time.time() - t, 3), "seconds")

In [6]:
# Final forward pass. No gradients are needed here, so skip building the
# autograd graph. The model's train/eval mode is deliberately left untouched
# so the output values match what the training loop was optimising.
with torch.no_grad():
    img_unet = model(x)
# Restrict to the valid 8-bit intensity range.
img_unet = img_unet.clip(0, 255)
# `.byte()` truncates towards zero, which would darken the saved image by up
# to one grey level per pixel; round to the nearest integer first. Only the
# copy written to disk is rounded, `img_unet` itself stays in floating point.
write_png(img_unet[0].round().byte().to("cpu"), "./img_unet.png")

In [7]:
# Peak signal-to-noise ratio of the U-Net output against the clean image,
# using 255 as the peak intensity: PSNR = 10 * log10(255^2 / MSE).
mse_unet = ((img_unet - img) ** 2).mean()
psnr = 10 * torch.log10(255**2 / mse_unet)
print("PSNR u-net:", round(float(psnr), 2), "dB")