In [18]:
import torch
from torch import nn
import torchvision
import math
import numpy as np
import matplotlib.pyplot as plt

In [20]:
# Sampling parameters. Units are not stated in this notebook (presumably metres -- confirm).
pixel_size = 50e-6
z = 50e-3  # not used in this cell
nx = ny = 1024

# Coordinate axes, scaled by pixel_size, and the 2-D grids built from them.
# NOTE(review): the original bounds were written as `-nx//2**2` and `nx//2**2-1`, which
# Python parses as `(-nx) // 4` and `nx // 4 - 1`. With nx samples over that range the
# sample spacing is roughly pixel_size / 2, not pixel_size. If one sample per pixel was
# intended, the bounds should be `-nx // 2` and `nx // 2 - 1` -- confirm before relying
# on this grid.
x = pixel_size * np.linspace((-nx) // 4, nx // 4 - 1, nx)
y = pixel_size * np.linspace((-ny) // 4, ny // 4 - 1, ny)
X, Y = np.meshgrid(x, y)

# Input field: an isotropic Gaussian centred on the coordinate origin, with a standard
# deviation of 20 * pixel_size.
sigma = 20 * pixel_size
field = np.exp(-(X**2 + Y**2) / (2 * sigma**2))

In [32]:
class ConvolutionalBlock(nn.Module):
    """
    A 2-D convolution optionally followed by batch normalisation and an activation.

    The layers are assembled in the order Conv2d -> BatchNorm2d (optional) -> activation
    (optional) and stored as the ``nn.Sequential`` attribute ``conv_block``.

    :param in_channels: number of channels of the input tensor
    :param out_channels: number of channels produced by the convolution
    :param kernel_size: integer kernel size; the padding is ``kernel_size // 2``, so an odd
                        kernel keeps the spatial size when ``stride`` is 1
    :param stride: stride of the convolution
    :param batch_norm: if truthy, a ``BatchNorm2d`` layer is inserted after the convolution
    :param activation: ``None`` for no activation, or one of "prelu", "leakyrelu", "tanh"
                       (matched case-insensitively)
    :raises AssertionError: if ``activation`` is a string outside the supported set
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, batch_norm = False, activation = None):
        super().__init__()

        if activation is not None:
            activation = activation.lower()
            assert activation in {"prelu", "leakyrelu", "tanh"}, \
                "activation must be None or one of 'prelu', 'leakyrelu', 'tanh'"

        # Layers of this block, in execution order
        layers = list()

        # The convolution itself; padding of kernel_size // 2 gives "same" output size for
        # odd kernels at stride 1
        layers.append(nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, stride = stride,
                                padding = kernel_size//2))

        # Batch normalisation, if requested. Previously this flag was accepted but ignored;
        # with batch_norm = False the layer layout (and state-dict keys) are unchanged.
        if batch_norm:
            layers.append(nn.BatchNorm2d(num_features = out_channels))

        # Activation, if requested
        if activation == "prelu":
            layers.append(nn.PReLU())
        elif activation == "leakyrelu":
            layers.append(nn.LeakyReLU(0.2))
        elif activation == "tanh":
            layers.append(nn.Tanh())

        # Assemble the convolutional block
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """
        Apply the block to ``x``.

        :param x: input tensor accepted by ``nn.Conv2d`` with ``in_channels`` channels
        :return: the output of the assembled layer sequence
        """
        return self.conv_block(x)

In [33]:
class SubPixelConvolutionBlock(nn.Module):
    """
    Upsampling block: convolution, then pixel shuffle, then PReLU.

    The convolution expands ``n_channels`` to ``n_channels * scaling_factor ** 2`` channels,
    which ``nn.PixelShuffle`` rearranges into ``scaling_factor`` times larger height and
    width with ``n_channels`` channels again.

    :param kernel_size: integer kernel size of the convolution (padding is ``kernel_size // 2``)
    :param n_channels: number of channels entering and leaving the block
    :param scaling_factor: upscale factor handed to ``nn.PixelShuffle``
    """

    def __init__(self, kernel_size = 3, n_channels = 64, scaling_factor = 2):
        super().__init__()

        expanded_channels = n_channels * (scaling_factor ** 2)
        same_padding = kernel_size // 2

        # Attribute names are visible outside the class, so the existing spelling
        # "pixel_suffle" is kept unchanged.
        self.conv = nn.Conv2d(n_channels, expanded_channels, kernel_size, padding = same_padding)
        self.pixel_suffle = nn.PixelShuffle(scaling_factor)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """
        Upsample ``x``.

        :param x: input tensor with ``n_channels`` channels
        :return: PReLU(PixelShuffle(Conv2d(x)))
        """
        return self.prelu(self.pixel_suffle(self.conv(x)))

In [36]:
class ResidualBlock(nn.Module):
    """
    Two stacked ConvolutionalBlocks wrapped by an identity skip connection.

    Both blocks map ``n_channels`` to ``n_channels`` and are created with
    ``batch_norm = False``; the first uses a PReLU activation, the second has none.

    :param kernel_size: kernel size used by both convolutional blocks
    :param n_channels: number of channels entering and leaving the block
    """

    def __init__(self, kernel_size = 3, n_channels = 64):
        super().__init__()

        # Settings shared by both convolutional blocks
        shared = dict(in_channels = n_channels, out_channels = n_channels,
                      kernel_size = kernel_size, batch_norm = False)

        self.conv_block1 = ConvolutionalBlock(activation = "PReLU", **shared)
        self.conv_block2 = ConvolutionalBlock(activation = None, **shared)

    def forward(self, x):
        """
        Run both convolutional blocks and add the untouched input back on.

        :param x: input tensor with ``n_channels`` channels
        :return: ``conv_block2(conv_block1(x)) + x``
        """
        transformed = self.conv_block2(self.conv_block1(x))
        return transformed + x

In [None]:
class SRResNet(nn.Module):
    """
    Residual super-resolution network.

    Pipeline: large-kernel convolutional block -> ``n_blocks`` residual blocks -> small-kernel
    convolutional block -> skip addition with the first block's output ->
    ``log2(scaling_factor)`` sub-pixel blocks (each upscaling by 2) -> large-kernel
    convolutional block with a Tanh activation.

    :param large_kernel_size: kernel size of the first and last convolutional blocks
    :param small_kernel_size: kernel size of every block in between
    :param n_channels: number of feature channels used inside the network
    :param n_blocks: number of residual blocks
    :param scaling_factor: overall upscale factor; converted with ``int()`` and must be 2, 4 or 8
    :param in_channels: channels of the input images (default 3, the previously hard-coded value)
    :param out_channels: channels of the produced images (default 3, the previously hard-coded value)
    :raises AssertionError: if ``scaling_factor`` is not 2, 4 or 8
    """

    def __init__(self, large_kernel_size = 9, small_kernel_size = 3, n_channels = 64, n_blocks = 16, scaling_factor = 4,
                 in_channels = 3, out_channels = 3):
        super().__init__()

        scaling_factor = int(scaling_factor)
        assert scaling_factor in {2, 4, 8}, "The Scaling Factor must be 2, 4, or 8!"

        # First convolutional block: lifts the input into n_channels feature maps
        self.conv_block1 = ConvolutionalBlock(in_channels = in_channels, out_channels = n_channels, kernel_size = large_kernel_size,
                                              batch_norm = False, activation = "PReLU")

        # Stack of n_blocks residual blocks, each with its own skip connection
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(kernel_size = small_kernel_size, n_channels = n_channels) for _ in range(n_blocks)])

        # Convolutional block closing the residual trunk.
        # NOTE(review): a Tanh is applied here, i.e. before the skip addition in forward(),
        # which bounds this branch to [-1, 1] -- confirm this is intended.
        self.conv_block2 = ConvolutionalBlock(in_channels = n_channels, out_channels = n_channels, kernel_size = small_kernel_size,
                                              batch_norm = False, activation = "Tanh")

        # Each sub-pixel block upscales by 2, so log2(scaling_factor) of them are chained
        n_subpixel_convolution_blocks = int(math.log2(scaling_factor))
        self.subpixel_convolutional_blocks = nn.Sequential(
            *[SubPixelConvolutionBlock(kernel_size = small_kernel_size, n_channels = n_channels, scaling_factor = 2) for _
              in range(n_subpixel_convolution_blocks)])

        # Last convolutional block: maps the features to out_channels; Tanh bounds the output to [-1, 1]
        self.conv_block3 = ConvolutionalBlock(in_channels = n_channels, out_channels = out_channels, kernel_size = large_kernel_size,
                                              batch_norm = False, activation = 'Tanh')

    def forward(self, lr_imgs):
        """
        Super-resolve a batch of images.

        :param lr_imgs: input tensor with ``in_channels`` channels
                        (presumably shaped (N, in_channels, h, w) -- confirm against the caller)
        :return: output of the final Tanh convolutional block, with ``out_channels`` channels
        """
        output = self.conv_block1(lr_imgs)
        residual = output  # kept for the skip connection around the residual trunk
        output = self.residual_blocks(output)
        output = self.conv_block2(output)
        output = output + residual
        output = self.subpixel_convolutional_blocks(output)
        sr_imgs = self.conv_block3(output)

        return sr_imgs