In [1]:
import torch
from torch import nn

In [4]:
class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        discriminator=False,
        use_act=True, #use activation 
        use_bn=True, #batch norm
        **kwargs
    ):
        super().__init__()
        self.use_act = use_act
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
        self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        self.act = (
            nn.LeakyReLU(0.2, inplace=True)
            if discriminator
            else nn.PReLU(num_parameters=out_channels)
        )
    
    def forward(self, x):
      return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x))

In [5]:
class UpsampleBlock(nn.Module):
  def __init__(self, in_c, scale_factor):
    super().__init__()
    self.conv = nn.Conv2d(in_c, in_c * scale_factor ** 2, 3, 1, 1)
    self.ps = nn.PixelShuffle(scale_factor) #in_c * 4, H, W -> in_c, H*2, W*2
    self.act = nn.PReLU(num_parameters=in_c)

  def forward(self, x):
    return self.act(self.ps(self.conv(x)))

In [6]:
class ResidualBlock(nn.Module):
  def __init__(self, in_channels):
    super().__init__()
    self.block1 = ConvBlock(
        in_channels,
        in_channels,
        kernel_size = 3,
        stride = 1,
        padding = 1
    )

    self.block2 = ConvBlock(
        in_channels,
        in_channels,
        kernel_size = 3,
        stride = 1,
        padding = 1,
        use_act = False
    )

  def forward(self, x):
    out = self.block2(self.block1(x))
    return out + x

In [7]:
class Generator(nn.Module):
  

SyntaxError: ignored