In [1]:
import torch
from torch import nn
import torch.nn.functional as F


In [None]:
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.2, in-place) building block.

    Args:
        in_channels: Number of channels in the input.
        out_channels: Number of filters produced by the convolution.
        kernel_size: Side length of the square convolution kernel.
        stride: Convolution stride.
        padding: Zero padding added to each spatial side.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int,
                 padding: int) -> None:
        super().__init__()
        # The attribute names (conv / norm / relu) define the state_dict keys,
        # so they are kept stable for checkpoint compatibility.
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding)
        self.norm = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch normalization and activation in turn."""
        features = self.conv(x)
        normalized = self.norm(features)
        return self.relu(normalized)

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional network built from ``ConvBlock``s with a 1-channel output.

    Layout: a head block, ``num_layers - 2`` body blocks whose channel width
    halves at each step (never below ``min_channels``), and a tail block that
    maps to a single channel. Every block uses stride 1, so the spatial size
    changes only through ``kernel_size`` / ``padding``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Nominal width of the head block; body widths are derived
            from it by repeated halving.
        min_channels: Lower bound on every block's width.
        kernel_size: Convolution kernel size used by every block.
        padding: Zero padding used by every block.
        num_layers: Total number of blocks (head + body + tail); must be >= 2.

    Raises:
        ValueError: If ``num_layers < 2``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 min_channels: int,
                 kernel_size: int,
                 padding: int,
                 num_layers: int) -> None:
        super().__init__()
        # Fewer than 2 layers would need a negative power of two below and
        # produce non-integer channel counts.
        if num_layers < 2:
            raise ValueError(f"num_layers must be >= 2 (head + tail), got {num_layers}")

        def width(depth: int) -> int:
            # Channel width after `depth` halvings, floored at min_channels.
            return max(out_channels // 2 ** depth, min_channels)

        # The head emits width(0) so the next block (first body block, or the
        # tail when there is no body) receives exactly the channels it expects,
        # even when out_channels < min_channels.
        self.head = ConvBlock(in_channels, width(0), kernel_size, 1, padding)
        # nn.Sequential takes modules as positional *args; a bare generator
        # object is rejected by add_module, so the blocks are unpacked.
        self.body = nn.Sequential(*(
            ConvBlock(width(i), width(i + 1), kernel_size, 1, padding)
            for i in range(num_layers - 2)
        ))
        self.tail = ConvBlock(width(num_layers - 2), 1, kernel_size, 1, padding)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through head, body and tail.

        Args:
            x: Input batch; presumably (N, in_channels, H, W) as nn.Conv2d
                expects — TODO confirm with callers.

        Returns:
            The single-channel output of the tail block.
        """
        x = self.head(x)
        x = self.body(x)
        return self.tail(x)

In [None]:
class Generator(nn.Module):
    """Fully convolutional network built from ``ConvBlock``s with a Tanh output.

    Layout: a head block, ``num_layers - 2`` body blocks whose channel width
    halves at each step (never below ``min_channels``), and a tail that maps
    to a single channel followed by ``nn.Tanh``, so outputs lie in (-1, 1).
    Every block uses stride 1.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Nominal width of the head block; body widths are derived
            from it by repeated halving.
        min_channels: Lower bound on every block's width.
        kernel_size: Convolution kernel size used by every block.
        padding: Zero padding used by every block.
        num_layers: Total number of blocks (head + body + tail); must be >= 2.

    Raises:
        ValueError: If ``num_layers < 2``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 min_channels: int,
                 kernel_size: int,
                 padding: int,
                 num_layers: int) -> None:
        super().__init__()
        # Fewer than 2 layers would need a negative power of two below and
        # produce non-integer channel counts.
        if num_layers < 2:
            raise ValueError(f"num_layers must be >= 2 (head + tail), got {num_layers}")

        def width(depth: int) -> int:
            # Channel width after `depth` halvings, floored at min_channels.
            return max(out_channels // 2 ** depth, min_channels)

        # The head emits width(0) so the next block (first body block, or the
        # tail when there is no body) receives exactly the channels it expects,
        # even when out_channels < min_channels.
        self.head = ConvBlock(in_channels, width(0), kernel_size, 1, padding)
        # nn.Sequential takes modules as positional *args; a bare generator
        # object is rejected by add_module, so the blocks are unpacked.
        self.body = nn.Sequential(*(
            ConvBlock(width(i), width(i + 1), kernel_size, 1, padding)
            for i in range(num_layers - 2)
        ))
        self.tail = nn.Sequential(
            ConvBlock(width(num_layers - 2), 1, kernel_size, 1, padding),
            nn.Tanh(),
        )

    def forward(self,
                x: torch.Tensor,
                y: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through head, body and tail and return the result.

        Args:
            x: Input batch; presumably (N, in_channels, H, W) as nn.Conv2d
                expects — TODO confirm with callers.
            y: Accepted for interface compatibility but currently unused.
                NOTE(review): the extra argument suggests a residual design
                (e.g. returning ``tail(...) + y``); confirm the intended use.

        Returns:
            The Tanh-activated single-channel output of the tail.
        """
        x = self.head(x)
        x = self.body(x)
        # The original body computed this value but never returned it,
        # so forward() always produced None.
        return self.tail(x)
        