In [2]:
import torch
import torch.nn as nn
from torchinfo import summary

In [18]:
class ResidualBlock(nn.Module):
    """Residual block: (conv3x3 -> BatchNorm -> PReLU -> conv3x3 -> BatchNorm) plus identity skip.

    Args:
        channels: number of feature channels in and out (default 64). Both
            convolutions keep ``channels`` channels because the identity skip in
            ``forward`` adds the input element-wise to the branch output.
    """

    def __init__(self, channels: int = 64):
        super().__init__()
        # stride 1 with padding 1 on a 3x3 kernel keeps H and W unchanged, which the
        # element-wise skip addition in ``forward`` requires.
        self.residual = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        """Return ``x + residual(x)``; the output has the same shape as ``x``."""
        return x + self.residual(x)

class Generator(nn.Module):
    """Image generator: conv feature extractor, residual trunk, pixel-shuffle upsampler.

    The spatial size is multiplied by ``scale ** 2`` overall, because two
    successive ``PixelShuffle(scale)`` stages are applied. With the defaults, a
    ``(N, 3, 64, 64)`` input yields a ``(N, 3, 256, 256)`` output.

    Args:
        BLOCK: number of ``ResidualBlock`` modules in the trunk.
        scale: upscale factor of EACH of the two pixel-shuffle stages, so the
            total upscale is ``scale ** 2``.
    """

    def __init__(self, BLOCK=16, scale=2):
        super().__init__()
        # Channel count fed to conv3_ / conv4 after both pixel shuffles. Each
        # PixelShuffle(scale) divides channels by scale**2, so conv3 must emit
        # up_channels * scale**4 channels (256 for the default scale=2).
        up_channels = 16

        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4)
        # NOTE(review): this one PReLU (a single learnable slope) is reused after
        # conv1, after the pixel shuffle and after conv3_, so all three activations
        # share one parameter. Kept shared so existing state_dicts still load.
        self.prelu = nn.PReLU()
        self.res_blocks = nn.Sequential(*[ResidualBlock() for _ in range(BLOCK)])
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, up_channels * scale ** 4, kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.Sequential(*[nn.PixelShuffle(scale) for _ in range(2)])
        # padding = kernel_size // 2 keeps H and W unchanged ("same" padding):
        # 1 for the 3x3 conv, 4 for the 9x9 conv.
        self.conv3_ = nn.Conv2d(up_channels, up_channels, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(up_channels, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Map a 3-channel ``(N, 3, H, W)`` batch to ``(N, 3, H * scale**2, W * scale**2)``."""
        x = self.conv1(x)
        x = self.prelu(x)
        res = self.res_blocks(x)
        res = self.conv2(res)
        res = self.bn(res)
        # Long skip connection around the whole residual trunk.
        x = x + res
        x = self.conv3(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        x = self.conv3_(x)
        x = self.prelu(x)
        x = self.conv4(x)
        return x

    


In [20]:
# Build the generator with its default configuration spelled out:
# 16 residual blocks and two PixelShuffle(2) stages (4x total upscale).
model = Generator(BLOCK=16, scale=2)

In [21]:
# Layer-by-layer output shapes and parameter counts for a single
# 3-channel 64x64 input (batch of 1).
summary(
    model,
    input_size=(1, 3, 64, 64),
)

Layer (type:depth-idx)                   Output Shape              Param #
Generator                                [1, 3, 256, 256]          --
├─Conv2d: 1-1                            [1, 64, 64, 64]           15,616
├─PReLU: 1-2                             [1, 64, 64, 64]           1
├─Sequential: 1-3                        [1, 64, 64, 64]           --
│    └─ResidualBlock: 2-1                [1, 64, 64, 64]           --
│    │    └─Sequential: 3-1              [1, 64, 64, 64]           74,113
│    └─ResidualBlock: 2-2                [1, 64, 64, 64]           --
│    │    └─Sequential: 3-2              [1, 64, 64, 64]           74,113
│    └─ResidualBlock: 2-3                [1, 64, 64, 64]           --
│    │    └─Sequential: 3-3              [1, 64, 64, 64]           74,113
│    └─ResidualBlock: 2-4                [1, 64, 64, 64]           --
│    │    └─Sequential: 3-4              [1, 64, 64, 64]           74,113
│    └─ResidualBlock: 2-5                [1, 64, 64, 64]          