In [2]:
# import the libraries
import torch,pdb
from torch.utils.data import DataLoader
from torch import nn
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm.auto import tqdm

import matplotlib.pyplot as plt


In [3]:
# visualization function

def show(tensor, ch=1, size=(28,28), num=16, nrow=4):
    """Display the first ``num`` images of a batch as a grid with matplotlib.

    Args:
        tensor: batch of images. It is reshaped to (-1, ch, *size), so either a
            flat (batch, ch*H*W) tensor (e.g. 128 x 784 for MNIST) or an already
            shaped (batch, ch, H, W) tensor works.
        ch: number of channels per image.
        size: (height, width) of each image.
        num: how many images, taken from the start of the batch, to show.
        nrow: number of images per grid row (passed to ``make_grid``).
    """
    # detach(): drop autograd history; cpu(): matplotlib cannot read GPU tensors.
    # reshape (unlike view) also accepts non-contiguous tensors.
    data = tensor.detach().cpu().reshape(-1, ch, *size)
    # make_grid's keyword is `nrow`; the previous `nrows=4` was not a recognized
    # argument (ignored or rejected depending on the torchvision version), so the
    # grid did not come out 4 images per row.
    grid = make_grid(data[:num], nrow=nrow)
    # PyTorch images are (C, H, W); matplotlib expects (H, W, C).
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()

In [4]:
# setup of the main parameters and hyperparameters

epochs = 500
# Training-loop bookkeeping (the loop itself is not in this cell):
# presumably a global step counter and a logging/visualization interval — TODO confirm.
cur_step = 0
info_step = 300


# Running loss accumulators, presumably averaged and reset every `info_step` steps.
mean_gen_loss = 0
mean_disc_loss = 0

z_dim = 64  # length of the generator's input noise vector
lr = 0.00001
# Takes raw logits: the Discriminator's last layer is a plain Linear with no sigmoid.
loss_func = nn.BCEWithLogitsLoss()

bs = 128
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

dataloader = DataLoader(MNIST('.', download=True, transform=transforms.ToTensor()), shuffle=True, batch_size=bs)

# batches per epoch -> ceil(60000/128) = 469 (the last batch holds the remaining 96 images)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 10778284.12it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 10422964.53it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3106078.28it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4501542.71it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw






In [None]:
# declare our models

#Generator
def genBlock(inp, out):
    """Build one Generator stage: Linear(inp -> out), BatchNorm1d(out), in-place ReLU."""
    layers = [
        nn.Linear(inp, out),
        nn.BatchNorm1d(out),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


class Generator(nn.Module):
    """MLP generator mapping a noise vector of length z_dim to a flat image of length i_dim.

    Four genBlock stages widen h_dim -> 2*h_dim -> 4*h_dim -> 8*h_dim, then a
    Linear projection to i_dim is followed by a Sigmoid, so outputs lie in [0, 1].
    """

    def __init__(self, z_dim=64, i_dim=784, h_dim=128):
        super().__init__()
        # With the defaults the stage widths are 64 -> 128 -> 256 -> 512 -> 1024.
        widths = [z_dim] + [h_dim * factor for factor in (1, 2, 4, 8)]
        stages = [genBlock(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])]
        self.gen = nn.Sequential(
            *stages,
            nn.Linear(widths[-1], i_dim),  # 1024 -> 784 with the defaults
            nn.Sigmoid(),
        )

    def forward(self, noise):
        """Run the noise batch (e.g. (batch, z_dim) from gen_noise) through the generator."""
        images = self.gen(noise)
        return images
    

def gen_noise(number, z_dim, dev=None):
    """Sample a batch of standard-normal noise vectors for the Generator.

    Args:
        number: number of noise vectors (rows) to draw.
        z_dim: length of each noise vector.
        dev: device to create the tensor on. Defaults to the notebook-level
            ``device`` setting when omitted.

    Returns:
        A float tensor of shape (number, z_dim) on the chosen device.
    """
    target = device if dev is None else dev
    # Sample directly on the target device rather than allocating on the CPU and
    # copying with .to(device) on every call (this runs every training step).
    return torch.randn(number, z_dim, device=target)


# Discriminator
def discBlock(inp, out):
    """Build one Discriminator stage: Linear(inp -> out) followed by LeakyReLU(0.2)."""
    projection = nn.Linear(inp, out)
    activation = nn.LeakyReLU(0.2)
    return nn.Sequential(projection, activation)

class Discriminator(nn.Module):
    """MLP discriminator mapping a flat image of length i_dim to one raw score (logit).

    Three discBlock stages narrow i_dim -> 4*h_dim -> 2*h_dim -> h_dim, then a
    Linear layer produces a single unnormalized score. No sigmoid is applied here.
    """

    def __init__(self, i_dim=784, h_dim=256):
        super().__init__()
        # With the defaults the stage widths are 784 -> 1024 -> 512 -> 256.
        widths = [i_dim, h_dim * 4, h_dim * 2, h_dim]
        stages = [discBlock(w_in, w_out) for w_in, w_out in zip(widths, widths[1:])]
        # Final scoring layer: 256 -> 1 with the defaults.
        self.disc = nn.Sequential(*stages, nn.Linear(h_dim, 1))

    def forward(self, image):
        """Score images whose last dimension is i_dim (flattened pixels); returns shape (..., 1)."""
        scores = self.disc(image)
        return scores

