In [None]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from skimage import io
import numpy as np
from skimage.measure import block_reduce

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
seed = 0              # seeds torch and numpy so the random input z and the pixel mask are reproducible
num_iter = 5000       # number of optimisation steps
learning_rate = 0.01  # Adam learning rate
img_path = ''         # TODO: set to the image to restore before running the notebook
sigma = 1.0 / 10.0    # standard deviation of the fixed random network input z
noise_ratio = 0.5     # for denoising: fraction of pixels dropped from the target
scale_factor = 2      # for super resolution: downsampling factor of the target

# Seed every RNG the notebook draws from (torch for z, numpy for the mask).
torch.manual_seed(seed)
np.random.seed(seed)
scale_factor = 2 #for super resulotion

In [None]:
# Load the image as 3-channel float data in [0, 1] and crop it to a size the U-Net accepts.
raw = io.imread(img_path)
if np.issubdtype(raw.dtype, np.integer):
    # Scale by the dtype's maximum so both 8-bit (255) and 16-bit (65535) files map to [0, 1].
    img = raw / np.iinfo(raw.dtype).max
else:
    img = raw.astype(np.float64)
if img.ndim == 2:
    img = img[..., None]  # grayscale -> explicit single channel
if img.shape[2] < 3:
    # Gray (+ alpha): replicate the intensity channel so the image has 3 channels.
    img = np.repeat(img[..., :1], 3, axis=-1)
img = img[..., :3]  # drop an alpha channel; the network's first layer expects 3 channels
# The U-Net max-pools 5 times, so height and width must be multiples of 2**5.
size_multiple = 2 ** 5
h = img.shape[0] - img.shape[0] % size_multiple
w = img.shape[1] - img.shape[1] % size_multiple
if h == 0 or w == 0:
    raise ValueError(
        f'image must be at least {size_multiple}x{size_multiple} pixels, '
        f'got {img.shape[0]}x{img.shape[1]}')
img = np.ascontiguousarray(img[:h, :w])
c = img.shape[2]
# Fixed random input of the network; only the weights are optimised.
z = (torch.randn(1, c, h, w) * sigma).to(device)
to_tensor = transforms.ToTensor()

In [None]:
# Denoising (inpainting variant): zero out a random `noise_ratio` fraction of the
# pixels (all channels of a dropped pixel) and let the network fill them in.
# The numpy array `img` is left untouched, so this cell can be re-run and
# `to_tensor(img)` still works in the super-resolution cell.
img_t = to_tensor(img).float().to(device)             # (c, h, w)
keep = np.random.rand(h, w) > noise_ratio             # True = pixel kept in the target
# float32 so multiplying the model output by the mask does not promote it to float64.
mask = to_tensor(keep.astype(np.float32)).to(device)  # (1, h, w), broadcasts over channels
y = (img_t * mask).reshape((1, c, h, w))


In [None]:
# Super resolution (disabled by default). To use it, set `super_resolution = True`
# and, in the training loop, use `out = resize(x)` instead of `out = x * mask`.
# A flag (rather than wrapping the code in a string literal) keeps the cell silent:
# a bare string as a cell's last line is echoed as the cell's output.
super_resolution = False
if super_resolution:
    resize = transforms.Resize((h // scale_factor, w // scale_factor))
    # The denoising cell may already have replaced `img` with a tensor, and
    # ToTensor only accepts PIL images / ndarrays, so convert only when needed.
    hr = img if torch.is_tensor(img) else to_tensor(img)
    y = resize(hr).reshape((1, c, h // scale_factor, w // scale_factor)).to(device)

In [None]:
# Display the degraded target the network is fitted to (channels moved last for imshow).
plt.imshow(np.transpose(y[0].detach().cpu().numpy(), (1, 2, 0)))
plt.show()

In [None]:
class ConvBlock(nn.Module):
    """Conv -> ReLU -> Conv -> BatchNorm -> ReLU.

    ``padding='same'`` keeps height and width unchanged; only the channel count
    changes from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size, padding='same'),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size, padding='same'),
            nn.BatchNorm2d(out_ch),
            nn.ReLU())

    def forward(self, x):
        return self.block(x)


class UNET(nn.Module):
    """Five-level U-Net with skip connections.

    Encoder: ConvBlocks of 8, 16, 32, 64 and 128 channels, each followed by a 2x
    max-pool, then a 256-channel bottleneck. Decoder: five 2x transposed-conv
    upsamplings, each concatenated with the encoder feature map of the same
    resolution. The last layer is a 1x1 ConvBlock, so the output goes through
    BatchNorm and ReLU (non-negative, not bounded above).

    Args:
        in_channels: channels of the input tensor (default 3).
        out_channels: channels of the output tensor (default 3).

    Raises:
        ValueError: in ``forward`` if the input height or width is not a
            multiple of ``2 ** DEPTH``.
    """

    # Number of 2x max-pool stages; input H and W must be divisible by 2 ** DEPTH.
    DEPTH = 5

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, 8)
        self.conv2 = ConvBlock(8, 16)
        self.conv3 = ConvBlock(16, 32)
        self.conv4 = ConvBlock(32, 64)
        self.conv5 = ConvBlock(64, 128)
        self.conv6 = ConvBlock(128, 256)
        self.conv7 = ConvBlock(256, 128)
        self.conv8 = ConvBlock(128, 64)
        self.conv9 = ConvBlock(64, 32)
        self.conv10 = ConvBlock(32, 16)
        self.conv11 = ConvBlock(16, 8)
        self.conv12 = ConvBlock(8, out_channels, kernel_size=1)
        self.pool = nn.MaxPool2d(2)
        self.up_conv1 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up_conv2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up_conv3 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.up_conv4 = nn.ConvTranspose2d(32, 16, 2, stride=2)
        self.up_conv5 = nn.ConvTranspose2d(16, 8, 2, stride=2)

    def forward(self, x):
        factor = 2 ** self.DEPTH
        height, width = x.shape[-2], x.shape[-1]
        # Otherwise pooling floors odd sizes and the skip concatenation fails with
        # an opaque size-mismatch error deep inside the decoder.
        if height % factor or width % factor:
            raise ValueError(
                f'UNET input height and width must be divisible by {factor}, '
                f'got {height}x{width}')
        # Encoder: keep every stage's output for the skip connections.
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool(conv1))
        conv3 = self.conv3(self.pool(conv2))
        conv4 = self.conv4(self.pool(conv3))
        conv5 = self.conv5(self.pool(conv4))
        conv6 = self.conv6(self.pool(conv5))  # bottleneck
        # Decoder: upsample, concatenate the same-resolution encoder features along
        # the channel dim, then the next ConvBlock halves the channel count.
        up1 = torch.cat((self.up_conv1(conv6), conv5), 1)
        up2 = torch.cat((self.up_conv2(self.conv7(up1)), conv4), 1)
        up3 = torch.cat((self.up_conv3(self.conv8(up2)), conv3), 1)
        up4 = torch.cat((self.up_conv4(self.conv9(up3)), conv2), 1)
        up5 = torch.cat((self.up_conv5(self.conv10(up4)), conv1), 1)
        return self.conv12(self.conv11(up5))

In [None]:
# Network, pixel-wise MSE objective and Adam optimiser.
model = UNET()
model.to(device)  # nn.Module.to moves parameters in place and returns the same module
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Fit the network weights so that its output for the fixed input `z`, passed
# through the degradation (mask or downsampling), matches the target `y`.
# The full network output `x` is shown periodically and after the last step.
log_every = 100
target = y.float()  # hoisted: `y` does not change during training
model.train()
for i in range(num_iter):
    x = model(z)
    out = x * mask  # for denoising: compare only the pixels kept in the target
    # out = resize(x)  # for super resolution: compare the downsampled output
    loss = criterion(out.float(), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % log_every == 0 or i == num_iter - 1:
        print('i =', i, 'loss =', loss.item())
        # The output is not guaranteed to lie in [0, 1]; clamp so imshow
        # neither warns nor clips silently.
        plt.imshow(x[0].detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy())
        plt.show()