In [23]:
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision import transforms
import torch.optim as optim 

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [30]:
class MNISTGenerator(nn.Module):
    """Fully connected generator that maps latent noise vectors to flat MNIST images.

    Args:
        latent_dimension: length of the noise vector consumed by ``forward``.

    Each sample is projected to 28*28 = 784 pixel values, squashed into
    [-1, 1] by a final Tanh.
    """

    def __init__(self, latent_dimension):
        super().__init__()

        # Hidden widths grow 64 -> 128 -> 256, each Linear followed by a ReLU,
        # before the last Linear projects into pixel space.
        widths = [latent_dimension, 64, 128, 256]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers += [nn.Linear(widths[-1], 28 * 28), nn.Tanh()]

        # Attribute name and layer order are kept so state_dict keys
        # (generator.0.weight, ...) stay the same.
        self.generator = nn.Sequential(*layers)

    def forward(self, noise):
        """Map ``noise`` of shape (..., latent_dimension) to (..., 784) values in [-1, 1]."""
        return self.generator(noise)

class MNISTDiscriminator(nn.Module):
    """Fully connected discriminator that scores MNIST images as real or fake.

    Each input sample is flattened to 28*28 = 784 features. The network
    returns one unnormalized score (logit) per sample. No sigmoid is applied
    here, so a loss that expects logits (e.g. ``nn.BCEWithLogitsLoss``) is
    the matching choice.
    """

    def __init__(self):
        super().__init__()

        self.discriminator = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            # In-features must equal the previous Linear's 64 outputs,
            # otherwise forward raises a shape-mismatch error.
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: tensor whose first dimension is the batch; the remaining
               dimensions must hold 784 elements per sample
               (e.g. (B, 1, 28, 28) or (B, 784)).

        Returns:
            Tensor of shape (B, 1) with one raw score per sample.
        """
        batch_size = x.shape[0]
        # Flatten everything after the batch dimension into the 784 features.
        x = x.reshape(batch_size, -1)
        return self.discriminator(x)

In [31]:
# --- Model / training hyperparameters ---
latent_dimension = 64  # length of the noise vector fed to the generator
batch_size = 128

# Both networks currently use the same Adam learning rate.
generator_learning_rate = 1e-3
discriminator_learning_rate = 1e-3

# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [38]:
### Define Models ###
generator = MNISTGenerator(latent_dimension).to(device)
discriminator = MNISTDiscriminator().to(device)

### Define Optimizers ###
gen_optimizer = optim.Adam(generator.parameters(), generator_learning_rate)
disc_optimizer = optim.Adam(discriminator.parameters(), discriminator_learning_rate)

### Define Datasets ###
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) computes
# (t - 0.5) / 0.5 == 2t - 1, mapping them to [-1, 1] to match the generator's
# Tanh output range. Normalize is used instead of a lambda so the transform
# can be pickled by DataLoader worker processes (needed under the "spawn"
# multiprocessing start method).
tensor2image_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# download=True fetches MNIST only when it is not already present under
# ../../data, so the notebook also runs top to bottom in a fresh environment.
trainset = MNIST("../../data", train=True, download=True, transform=tensor2image_transforms)

In [13]:
# Show the generator's module repr to inspect its layer stack.
generator

MNISTGenerator(
  (generator): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=256, bias=True)
    (5): ReLU()
    (6): Linear(in_features=256, out_features=784, bias=True)
    (7): Tanh()
  )
)