In [1]:
import torch
import torch.nn as nn

# Generator

A GAN generator takes a random noise vector as input and produces a generated image. To make its architecture more reusable, you will pass both input and output shapes as parameters to the model. This way, you can use the same model with different sizes of input noise and images of varying shapes.
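
Because the generator works on flat vectors, an image shape has to be flattened into `out_dim`, and the generator's flat output has to be reshaped back into an image. The sketch below assumes hypothetical sizes: 64-dimensional noise and 1×28×28 grayscale images, as in MNIST. It uses random values as a stand-in for a generator's output, so it illustrates only the shape bookkeeping.

```python
import math
import torch

z_dim = 64                        # hypothetical noise size (the generator's in_dim)
image_shape = (1, 28, 28)         # hypothetical (channels, height, width), e.g. MNIST digits
out_dim = math.prod(image_shape)  # flattened size the generator must produce: 784

batch_size = 16
noise = torch.randn(batch_size, z_dim)               # one standard-normal noise vector per sample
flat_output = torch.rand(batch_size, out_dim)        # stand-in for a generator's flat output in [0, 1]
images = flat_output.view(batch_size, *image_shape)  # reshape back to image form

print(out_dim, noise.shape, images.shape)
# 784 torch.Size([16, 64]) torch.Size([16, 1, 28, 28])
```

Changing `z_dim` or `image_shape` changes only the two arguments passed to the model, which is the reuse the parameterized design is meant to allow.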

* Define self.generator as a sequential model.
* After the last gen_block, add a linear layer whose input size matches that block's output width (1024) and whose output size is out_dim.
* Add a sigmoid activation after the linear layer.
* In the forward() method, pass the model's input through self.generator.
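
The linear layer added after the last gen_block must take that block's output width (1024) as its input size, not `in_dim`. Here is a minimal check. It uses a random 1024-wide tensor as a stand-in for that block's output and assumes a noise size of 64 and an `out_dim` of 784. Both sizes are hypothetical.

```python
import torch
import torch.nn as nn

features = torch.randn(8, 1024)  # stand-in for the last gen_block's output (batch of 8)

correct = nn.Linear(1024, 784)   # input size matches the 1024-wide block
print(correct(features).shape)   # torch.Size([8, 784])

wrong = nn.Linear(64, 784)       # input size mistakenly set to the noise size
try:
    wrong(features)
    mismatch_raised = False
except RuntimeError as err:      # PyTorch refuses to multiply (8, 1024) inputs by a 64-input weight
    mismatch_raised = True
    print("shape mismatch:", err)
```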

In [3]:
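class Generator(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(Generator, self).__init__()
        # Define generator: three gen_blocks widen the noise in_dim -> 256 -> 512 -> 1024
        # (gen_block is defined in a later cell; it only needs to exist before Generator is instantiated)
        self.generator = nn.Sequential(
            gen_block(in_dim, 256),
            gen_block(256, 512),
            gen_block(512, 1024),
            # Add linear layer: its input size must match the last block's width (1024)
            nn.Linear(1024, out_dim),
            # Add activation: sigmoid squashes each output value into [0, 1]
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Pass input through generator
        return self.generator(x)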

In [4]:
def gen_block(in_dim, out_dim):
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        nn.BatchNorm1d(out_dim),
        nn.ReLU(inplace=True)
    )

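That's a neat generator! It accepts a random noise vector of size in_dim and produces a flattened image of size out_dim, with every pixel value squashed into [0, 1] by the final sigmoid. Once trained, those outputs should resemble real images.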
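
A quick shape check of the finished generator. So that it runs on its own, this sketch re-declares `gen_block` and `Generator`, with the final linear layer taking the 1024-wide block output. The sizes 64 and 784 are hypothetical. `BatchNorm1d` needs more than one sample per batch in training mode, so the check uses a batch of 16.

```python
import torch
import torch.nn as nn

def gen_block(in_dim, out_dim):
    # Linear -> BatchNorm -> ReLU, mirroring the notebook's gen_block
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        nn.BatchNorm1d(out_dim),
        nn.ReLU(inplace=True),
    )

class Generator(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.generator = nn.Sequential(
            gen_block(in_dim, 256),
            gen_block(256, 512),
            gen_block(512, 1024),
            nn.Linear(1024, out_dim),  # input size = width of the last block
            nn.Sigmoid(),              # squash outputs into [0, 1]
        )

    def forward(self, x):
        return self.generator(x)

torch.manual_seed(0)
gen = Generator(in_dim=64, out_dim=784)  # hypothetical noise and image sizes
noise = torch.randn(16, 64)              # batch of 16; BatchNorm1d needs >1 sample in training mode
fake = gen(noise)
print(fake.shape)                                           # torch.Size([16, 784])
print(fake.min().item() >= 0.0, fake.max().item() <= 1.0)  # True True
```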

# Discriminator

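With the generator defined, the next step in building a GAN is to construct the discriminator. It takes an image as input, either a real one from the dataset or one produced by the generator, and makes a binary prediction: is the input real or generated? Because no sigmoid follows its final linear layer, it outputs one raw score (a logit) per image rather than a probability.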


In [5]:
class Discriminator(nn.Module):
    def __init__(self, im_dim):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            disc_block(im_dim, 1024),
            disc_block(1024, 512),
            # Define last discriminator block
            disc_block(512, 256),
            # Add a linear layer
            nn.Linear(256, 1),
        )

    def forward(self, x):
        # Define the forward method
        return self.disc(x)

In [6]:
def disc_block(in_dim, out_dim):
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        nn.LeakyReLU(0.2)
    )

Well done, a perfect discriminator! 
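
Since the discriminator ends in a plain linear layer, it returns one logit per image. The sketch below shows one common way to use those logits: pair them with `nn.BCEWithLogitsLoss`, which applies the sigmoid internally. It assumes a hypothetical image size of 784 and re-declares `disc_block` and `Discriminator` so it runs on its own. The random "real" and "fake" batches are stand-ins, not data.

```python
import torch
import torch.nn as nn

def disc_block(in_dim, out_dim):
    # Linear -> LeakyReLU(0.2), mirroring the notebook's disc_block
    return nn.Sequential(nn.Linear(in_dim, out_dim), nn.LeakyReLU(0.2))

class Discriminator(nn.Module):
    def __init__(self, im_dim):
        super().__init__()
        self.disc = nn.Sequential(
            disc_block(im_dim, 1024),
            disc_block(1024, 512),
            disc_block(512, 256),
            nn.Linear(256, 1),  # one logit per input image
        )

    def forward(self, x):
        return self.disc(x)

torch.manual_seed(0)
disc = Discriminator(im_dim=784)  # hypothetical flattened 28x28 images
real = torch.rand(16, 784)        # stand-in for a batch of real images
fake = torch.rand(16, 784)        # stand-in for a batch of generated images

real_logits = disc(real)                 # shape (16, 1)
fake_logits = disc(fake)
probs_real = torch.sigmoid(real_logits)  # logits -> "probability the input is real"

criterion = nn.BCEWithLogitsLoss()  # sigmoid + binary cross-entropy in one numerically stable step
loss = (criterion(real_logits, torch.ones_like(real_logits))        # real images labeled 1
        + criterion(fake_logits, torch.zeros_like(fake_logits))) / 2  # generated images labeled 0
print(real_logits.shape, loss.item())
```

Keeping the sigmoid out of the model and inside the loss is a design choice: `BCEWithLogitsLoss` combines the two operations in a numerically stable way.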

