In [4]:
import torch 
import pdb
import torch.nn as nn 
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST 
from torchvision.utils import make_grid
from tqdm.auto import tqdm 
import matplotlib.pyplot as plt

In [7]:
# visualization function 
def show(tensor, colour_channels = 1, size = (28, 28), num_images_to_display = 16):
    """Display the first images of a batch as one image grid.

    Args:
        tensor: Batch of images. It is reshaped to
            (-1, colour_channels, *size), so a flattened batch is accepted.
        colour_channels: Number of channels per image.
        size: (height, width) of each image.
        num_images_to_display: How many images, counted from the start of the
            batch, are placed in the grid.
    """

    # Detach from the autograd graph and move to the CPU: only the values are
    # needed for plotting.
    data = tensor.detach().cpu().view(-1, colour_channels, *size)

    # make_grid's keyword is `nrow` (number of images per grid row), not `nrows`.
    # permute reorders (C, H, W) -> (H, W, C), the layout imshow expects.
    grid = make_grid(data[:num_images_to_display], nrow = 4).permute(1, 2, 0)

    plt.imshow(grid)
    plt.show()


In [8]:
# parameters / hyperparameters
# NOTE(review): the code that consumes these values (presumably a training
# loop) is outside this section — confirm each is still used.

# schedule and step bookkeeping
epochs = 250
summary_step = 50
current_step = 0

# running loss accumulators, both start at zero
mean_generator_loss = 0
mean_discriminator_loss = 0

# model / optimisation settings
generator_input_size = 64   # matches the Generator's default noise length
batch_size = 128
learning_rate = 1e-5
device = "cpu"

# BCE-with-logits applies the sigmoid internally, so it takes raw scores.
loss_function = nn.BCEWithLogitsLoss()

# MNIST is downloaded into the current directory on first use.
dataloader = DataLoader(
    MNIST(".", download = True, transform = transforms.ToTensor()),
    batch_size = batch_size,
    shuffle = True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to .\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:04<00:00, 1989623.47it/s]


Extracting .\MNIST\raw\train-images-idx3-ubyte.gz to .\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to .\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 137075.21it/s]


Extracting .\MNIST\raw\train-labels-idx1-ubyte.gz to .\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to .\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:02<00:00, 680055.40it/s]


Extracting .\MNIST\raw\t10k-images-idx3-ubyte.gz to .\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to .\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2192740.42it/s]

Extracting .\MNIST\raw\t10k-labels-idx1-ubyte.gz to .\MNIST\raw






In [16]:
# GENERATOR

def generator_block(input, output):
    """Build one generator stage: Linear -> BatchNorm1d -> ReLU.

    Args:
        input: Number of input features of the linear layer.
        output: Number of output features (also the BatchNorm feature count).

    Returns:
        An nn.Sequential holding the three layers in that order.
    """
    stage = [
        nn.Linear(input, output),
        nn.BatchNorm1d(output),
        nn.ReLU(inplace = True),
    ]
    return nn.Sequential(*stage)


def generator_noise_vector(number, generator_input_size):
    """Sample a batch of standard-normal noise vectors.

    Args:
        number: How many noise vectors to draw (rows of the result).
        generator_input_size: Length of each noise vector (columns).

    Returns:
        A (number, generator_input_size) tensor moved to the module-level
        `device`.
    """
    noise = torch.randn(number, generator_input_size)
    return noise.to(device)

In [17]:
class Generator(nn.Module):
    """Fully connected generator that maps noise vectors to flattened images.

    The network is four `generator_block` stages that widen from
    `hidden_layer_dimension` to `hidden_layer_dimension * 8`, followed by a
    linear projection to `image_dimension` and a sigmoid, so every output
    value lies in (0, 1).

    Args:
        generator_input_size: Length of each input noise vector.
        image_dimension: Number of values in one generated (flattened) image.
        hidden_layer_dimension: Width of the first hidden stage.
    """

    def __init__(self, generator_input_size = 64, image_dimension = 784, hidden_layer_dimension = 128):

        super().__init__()

        self.generator_ = nn.Sequential(

            generator_block(generator_input_size, hidden_layer_dimension),
            generator_block(hidden_layer_dimension, hidden_layer_dimension * 2),
            generator_block(hidden_layer_dimension * 2, hidden_layer_dimension * 4),
            generator_block(hidden_layer_dimension * 4, hidden_layer_dimension * 8),
            nn.Linear(hidden_layer_dimension * 8, image_dimension),
            nn.Sigmoid()
        )


    def forward(self, noise_vector):
        """Generate images from the supplied noise.

        Args:
            noise_vector: Tensor of shape (batch, generator_input_size).

        Returns:
            Tensor of shape (batch, image_dimension).
        """

        # Use the caller's noise as given. Sampling is the caller's job, so the
        # batch size and the noise itself stay under the caller's control.
        return self.generator_(noise_vector)
    

In [18]:
# DISCRIMINATOR 

def discriminator_block(input, output):
    """Build one discriminator stage: Linear followed by LeakyReLU(0.2).

    Args:
        input: Number of input features of the linear layer.
        output: Number of output features of the linear layer.

    Returns:
        An nn.Sequential holding the two layers in that order.
    """
    linear = nn.Linear(input, output)
    activation = nn.LeakyReLU(0.2)
    return nn.Sequential(linear, activation)

In [19]:
class Discriminator(nn.Module):
    """Fully connected discriminator that scores flattened images.

    Four `discriminator_block` stages (image -> h -> 4h -> 2h -> h, where h is
    `hidden_layer_dimension`) are followed by a linear layer with a single
    output. No sigmoid is applied, so the output is a raw score (logit).

    Args:
        image_dimension: Number of values in one flattened input image.
        hidden_layer_dimension: Base width h of the hidden stages.
    """

    def __init__(self, image_dimension = 784, hidden_layer_dimension = 256):

        super().__init__()

        # NOTE: the attribute name keeps its original spelling ("discrimator_")
        # so state_dict keys and external references stay valid.
        self.discrimator_ = nn.Sequential(

            discriminator_block(image_dimension, hidden_layer_dimension),
            discriminator_block(hidden_layer_dimension, hidden_layer_dimension * 4),
            # Each stage's input width must equal the previous stage's output
            # width, so these are multiples of the hidden size, not the image size.
            discriminator_block(hidden_layer_dimension * 4, hidden_layer_dimension * 2),
            discriminator_block(hidden_layer_dimension * 2, hidden_layer_dimension),
            nn.Linear(hidden_layer_dimension, 1)

        )

    def forward(self, image):
        """Score a batch of flattened images.

        Args:
            image: Tensor of shape (batch, image_dimension).

        Returns:
            Tensor of shape (batch, 1) holding one raw score per image.
        """

        return self.discrimator_(image)