# Models

In [1]:
import torch
import torch.nn as nn

In [None]:
# Default image shape; presumably (channels, height, width), matching the
# "3x256x256" input noted on the DC-GAN discriminator -- TODO confirm.
IMG_SHAPE = (3, 256, 256)


# Vanilla GAN

In [2]:
class LinearBlock(nn.Module):
    """
    Linear block for MLP.

    Applies, in order: a fully connected layer, an optional LeakyReLU(0.2),
    optional 1-D batch normalization and optional dropout.
    """
    def __init__(self, in_features, out_features, activation=True, batch_norm=True, dropout=0.0):
        """
        Args:
            in_features (int): Number of input features.
            out_features (int): Number of output features.
            activation (bool): If True, LeakyReLU with negative slope 0.2 is used.
            batch_norm (bool): If True, batch normalization is used.
            dropout (float): Dropout probability. A value of 0 (the default)
                or less disables dropout entirely.
        """
        super(LinearBlock, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.activation = nn.LeakyReLU(0.2) if activation else None
        self.batch_norm = nn.BatchNorm1d(out_features) if batch_norm else None
        # forward() reads self.dropout, so the attribute must always exist;
        # previously it was never assigned and every forward pass raised
        # AttributeError.
        self.dropout = nn.Dropout(dropout) if dropout and dropout > 0 else None

    def forward(self, x):
        """
        Run the block on ``x`` and return the transformed tensor.

        Disabled stages (stored as None) are skipped.
        """
        x = self.linear(x)
        # Explicit None checks: the optional stages are either a module or None.
        if self.activation is not None:
            x = self.activation(x)
        if self.batch_norm is not None:
            x = self.batch_norm(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return x


class Generator(nn.Module):
    """
    Fully connected generator for the vanilla GAN.

    Feeds a latent vector through a stack of linear blocks and a final
    Tanh-activated linear layer, then reshapes the flat result to
    ``img_shape``.
    """
    def __init__(self, latent_dim, img_shape):
        """
        Args:
            latent_dim (int): Number of features in the latent input vector.
            img_shape (tuple): Shape of one generated image (without batch axis).
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = img_shape

        # Number of values in one flattened image.
        flat_size = int(torch.prod(torch.tensor(img_shape)))

        hidden_layers = [
            LinearBlock(latent_dim, 256, batch_norm=False),
            LinearBlock(256, 512),
            LinearBlock(512, 1024),
            LinearBlock(1024, 2048),
        ]
        self.model = nn.Sequential(
            *hidden_layers,
            nn.Linear(2048, flat_size),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate a batch of images shaped ``(batch, *img_shape)`` from ``x``."""
        flat = self.model(x)
        batch = flat.size(0)
        return flat.view(batch, *self.img_shape)


class Discriminator(nn.Module):
    """
    Fully connected discriminator for the vanilla GAN.

    Flattens each image and maps it through linear blocks to a single
    Sigmoid-activated output per sample.
    """
    def __init__(self, img_shape):
        """
        Args:
            img_shape (tuple): Shape of one input image (without batch axis).
        """
        super().__init__()

        # Number of values in one flattened image.
        flat_size = int(torch.prod(torch.tensor(img_shape)))

        layers = [
            LinearBlock(flat_size, 512, batch_norm=False),
            LinearBlock(512, 256),
            LinearBlock(256, 1, activation=False, batch_norm=False),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` to ``(batch, -1)`` and return the model output."""
        batch = x.size(0)
        flat = x.view(batch, -1)
        return self.model(flat)

# DC-GAN

In [None]:
class ConvTransposedBlock(nn.Module):
    """
    Transposed-convolution block for DCGAN.

    Bias-free ConvTranspose2d, followed by optional BatchNorm2d and an
    in-place ReLU.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, batch_norm=True):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size.
            stride (int): Stride.
            padding (int): Padding.
            batch_norm (bool): If True, batch normalization is used.
        """
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        if batch_norm:
            self.batch_norm = nn.BatchNorm2d(out_channels)
        else:
            self.batch_norm = None
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply transposed convolution, optional batch norm, then ReLU."""
        out = self.conv_transpose(x)
        if self.batch_norm is not None:
            out = self.batch_norm(out)
        return self.activation(out)


class ConvBlock(nn.Module):
    """
    Convolutional block for DCGAN.

    Bias-free Conv2d, followed by optional BatchNorm2d and an in-place
    LeakyReLU with negative slope 0.2.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, batch_norm=True):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size.
            stride (int): Stride.
            padding (int): Padding.
            batch_norm (bool): If True, batch normalization is used.
        """
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.batch_norm = nn.BatchNorm2d(out_channels) if batch_norm else None
        self.activation = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Apply convolution, optional batch norm, then LeakyReLU."""
        x = self.conv(x)
        # batch_norm is either a module or None when disabled.
        if self.batch_norm is not None:
            x = self.batch_norm(x)
        x = self.activation(x)
        return x


class Generator(nn.Module):
    """
    DCGAN generator.

    Upsamples a latent vector, viewed as a ``z x 1 x 1`` feature map, through
    a stack of transposed convolutions to a Tanh-activated image.
    """
    def __init__(self, z_latent_vector_size, feature_maps_size, num_channels=3):
        """
        Args:
            z_latent_vector_size (int): Number of features in the latent vector.
            feature_maps_size (int): Base number of feature maps; earlier
                layers use multiples of it.
            num_channels (int): Number of channels in the generated image.
        """
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            # input: z_latent_vector_size x 1 x 1
            ConvTransposedBlock(z_latent_vector_size, feature_maps_size * 32, 4, 1, 0),
            # (feature_maps_size*32) x 4 x 4
            # Stride 2 / padding 1 doubles the spatial size (4 -> 8). The
            # previous stride 1 / padding 0 produced 7 x 7, which made the
            # final image 224 x 224 instead of 256 x 256.
            ConvTransposedBlock(feature_maps_size * 32, feature_maps_size * 16, 4, 2, 1),
            # (feature_maps_size*16) x 8 x 8
            ConvTransposedBlock(feature_maps_size * 16, feature_maps_size * 8, 4, 2, 1),
            # (feature_maps_size*8) x 16 x 16
            ConvTransposedBlock(feature_maps_size * 8, feature_maps_size * 4, 4, 2, 1),
            # (feature_maps_size*4) x 32 x 32
            ConvTransposedBlock(feature_maps_size * 4, feature_maps_size * 2, 4, 2, 1),
            # (feature_maps_size*2) x 64 x 64
            ConvTransposedBlock(feature_maps_size * 2, feature_maps_size, 4, 2, 1),
            # (feature_maps_size) x 128 x 128
            nn.ConvTranspose2d(feature_maps_size, num_channels, 4, 2, 1),
            nn.Tanh()
            # (num_channels) x 256 x 256
        )

    def forward(self, input):
        """
        Generate images from a batch of latent vectors.

        ``input`` is viewed as ``(batch, features, 1, 1)`` before being fed
        to the network.
        """
        input = input.view(input.size(0), input.size(1), 1, 1)
        return self.model(input)
    

class Discriminator(nn.Module):
    """
    DCGAN discriminator.

    Downsamples an image through strided convolution blocks and ends with a
    single-channel convolution followed by a Sigmoid.
    """
    def __init__(self, num_channels, feature_maps_size):
        """
        Args:
            num_channels (int): Number of channels in the input image.
            feature_maps_size (int): Base number of feature maps; deeper
                layers use multiples of it.
        """
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            # input: 3x256x256
            ConvBlock(num_channels, feature_maps_size, 4, 2, 1, batch_norm=False),
            # (feature_maps_size) x 128 x 128
            ConvBlock(feature_maps_size, feature_maps_size * 2, 4, 2, 1),
            # (feature_maps_size*2) x 64 x 64
            ConvBlock(feature_maps_size * 2, feature_maps_size * 4, 4, 2, 1),
            # (feature_maps_size*4) x 32 x 32
            ConvBlock(feature_maps_size * 4, feature_maps_size * 8, 4, 2, 1),
            # (feature_maps_size*8) x 16 x 16
            # NOTE(review): kernel 4 / stride 1 / padding 0 on a 16 x 16 map
            # gives a 1 x 13 x 13 score map, not a single value per image --
            # confirm the training loss expects that shape.
            nn.Conv2d(feature_maps_size * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return the Sigmoid-activated network output for ``input``."""
        # The network is stored as self.model; self.main was never defined.
        return self.model(input)
