In [2]:
from torch import nn

def conv_layer(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
        nn.ReLU(inplace=True),
    )

In [3]:
import torch.nn.functional as F

class BicubicRefiner(nn.Module):

    def __init__(self, num_channels, num_blocks, clamp=False):
        super().__init__()
        self.clamp = clamp
        layers = []

        # Feature extraction on HR images
        layers.append(conv_layer(3, num_channels, kernel_size=3, stride=1, padding=1))

        for _ in range(num_blocks):
            layers.append(conv_layer(num_channels, num_channels, kernel_size=3, stride=1, padding=1))

        # Reconstruction to RGB image layer
        layers.append(nn.Conv2d(num_channels, 3, 3, 1, 1))
        self.head = nn.Sequential(*layers[:-1])  # split head and conv_out to solve an error of pytorch_model_summary
        self.conv_out = layers[-1]

    def forward(self, lr):

        base = F.interpolate(lr, scale_factor=4, mode="bicubic", align_corners=False)

        # predict residual correction on HR img
        res = self.conv_out(self.head(base))
        out = base + res

        return out

In [4]:
import torch
from pytorch_model_summary import summary

batch_size = 16
model = BicubicRefiner(32, 2).cuda()

print(summary(model, torch.rand(size=(batch_size, 3, 32, 32)).cuda(), show_input=True))

--------------------------------------------------------------------------
      Layer (type)            Input Shape         Param #     Tr. Param #
          Conv2d-1      [16, 3, 128, 128]             896             896
            ReLU-2     [16, 32, 128, 128]               0               0
          Conv2d-3     [16, 32, 128, 128]           9,248           9,248
            ReLU-4     [16, 32, 128, 128]               0               0
          Conv2d-5     [16, 32, 128, 128]           9,248           9,248
            ReLU-6     [16, 32, 128, 128]               0               0
          Conv2d-7     [16, 32, 128, 128]             867             867
Total params: 20,259
Trainable params: 20,259
Non-trainable params: 0
--------------------------------------------------------------------------


In [5]:
from torch import nn

class SrNet(nn.Module):
    def __init__(
        self,
        blocks: int = 12,
        channels: int = 64,
    ):
        super().__init__()
        self.blocks = blocks
        self.channels = channels
        self.residual_scale = 0.1

        # 1. Feature extraction (b, 3, 32, 32) -> (b, channels, 32, 32)
        self.feature_extraction = nn.Conv2d(3, channels, kernel_size=3, stride=1, padding=1)
        self.activation = nn.LeakyReLU(0.2, inplace=True)

        # 2. Residual Blocks -> keep same dimension ( b, channels, 32, 32 )
        self.residual_blocks = nn.ModuleList([self._residual_block() for _ in range(blocks)])

        # 3. Upsample Blocks -> (b, channels, 32, 32) -> (b, channels, 64, 64) -> (b, channels, 128, 128)
        self.upsample1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.upsample2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

        # 4. SR Img reconstruction
        self.sr_img_reconstruction = nn.Conv2d(channels, 3, kernel_size=3, stride=1, padding=1)

    def _residual_block(self):
        return nn.Sequential(
            nn.Conv2d(self.channels, self.channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.channels, self.channels, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, lr_img):
        # 1. Feature Extraction
        x = self.feature_extraction(lr_img)
        x = self.activation(x)

        # 2. Residual Blocks
        for block in self.residual_blocks:
            x = x + self.residual_scale * block(x)

        # 3. Upsample blocks
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = self.upsample1(x)
        x = self.activation(x)

        x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = self.upsample2(x)
        x = self.activation(x)

        # 4. Sr Img Reconstruction
        sr_img = self.sr_img_reconstruction(x)
        return sr_img

In [6]:
batch_size = 16
model = SrNet(32, 2).cuda()

print(summary(model, torch.rand(size=(batch_size, 3, 32, 32)).cuda(), show_input=True))

-------------------------------------------------------------------------
      Layer (type)           Input Shape         Param #     Tr. Param #
          Conv2d-1       [16, 3, 32, 32]              56              56
       LeakyReLU-2       [16, 2, 32, 32]               0               0
          Conv2d-3       [16, 2, 32, 32]              38              38
       LeakyReLU-4       [16, 2, 32, 32]               0               0
          Conv2d-5       [16, 2, 32, 32]              38              38
          Conv2d-6       [16, 2, 32, 32]              38              38
       LeakyReLU-7       [16, 2, 32, 32]               0               0
          Conv2d-8       [16, 2, 32, 32]              38              38
          Conv2d-9       [16, 2, 32, 32]              38              38
      LeakyReLU-10       [16, 2, 32, 32]               0               0
         Conv2d-11       [16, 2, 32, 32]              38              38
         Conv2d-12       [16, 2, 32, 32]          

In [8]:
x = torch.randn(2, 3, 32, 32).cuda()
y = model(x)
print(y.shape)

torch.Size([2, 3, 128, 128])
