# VANILLA GAN - MNIST

## IMPORTS

In [1]:
import torch
import torchvision

import numpy as np
import matplotlib.pyplot as plt

import os
import time

In [2]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
device

'cuda'

## HYPERPARAMETERS

In [3]:
dataroot = '../storage/data/mnist_png/'
saveroot = '../storage/GAN_Images/'
test_number = 'Test_1'

batch_size = 32
workers = 4
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(1),
    torchvision.transforms.ToTensor()
])

nz = 100

lr = 0.0002
betas = (0.1, 0.999)

num_epochs = 10
num_steps = 1000
k = 1

## DATASET

In [4]:
ds = torchvision.datasets.ImageFolder(dataroot+'training', transforms)

In [5]:
len(ds)

60000

In [6]:
dl = torch.utils.data.DataLoader(ds, batch_size, shuffle=True, num_workers=workers, pin_memory=True)

## DATA ANALYSIS

In [7]:
for data in (dl):
    print(data[0].shape, data[1].shape)
    break

torch.Size([32, 1, 28, 28]) torch.Size([32])


## MODEL

In [8]:
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.fc1 = torch.nn.Sequential(
            torch.nn.Linear(nz, 256),
            torch.nn.LeakyReLU(0.2)
        )
        self.fc2 = torch.nn.Sequential(
            torch.nn.Linear(256, 512),
            torch.nn.LeakyReLU(0.2)
        )
        self.fc3 = torch.nn.Sequential(
            torch.nn.Linear(512, 1024),
            torch.nn.LeakyReLU(0.2)
        )
        self.fc4 = torch.nn.Sequential(
            torch.nn.Linear(1024, 784),
            torch.nn.Tanh()
        )
        
    def forward(self, x):
        # x.shape = batch_size x nz
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        return x

In [9]:
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.fc1 = torch.nn.Sequential(
            torch.nn.Linear(784 , 1024),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Dropout(0.3)
        )
        self.fc2 = torch.nn.Sequential(
            torch.nn.Linear(1024 , 512),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Dropout(0.3)
        )
        self.fc3 = torch.nn.Sequential(
            torch.nn.Linear(512 , 256),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Dropout(0.3)
        )
        self.fc4 = torch.nn.Sequential(
            torch.nn.Linear(256 , 1),
            torch.nn.Sigmoid()
        )
        
    def forward(self, x):
        # (batch_size, 784) or (batch_size, 1, 28, 28)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        return x

In [10]:
discriminator = Discriminator().to(device)
generator = Generator().to(device)

In [11]:
for params in discriminator.parameters():
    print(params.shape)

torch.Size([1024, 784])
torch.Size([1024])
torch.Size([512, 1024])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([1, 256])
torch.Size([1])


## LOSS AND OPTIMIZATION

In [12]:
opt_d = torch.optim.Adam(discriminator.parameters(), lr, betas)
opt_g = torch.optim.Adam(generator.parameters(), lr, betas)
criterion = torch.nn.BCELoss()

In [13]:
real_labels = torch.ones((batch_size, 1)).to(device)
fake_labels = torch.zeros((batch_size, 1)).to(device)
fixed_noise = torch.randn(batch_size, nz).to(device)

## PRE-TRAIN RUN

In [14]:
with torch.no_grad():
    #REAL BATCH
    for xb, _ in dl:
        sample = xb.to(device)
        break
    print(sample.shape)
    out = discriminator(sample)
    print(out.shape)
    loss = criterion(out, real_labels)
    print(loss)
    #FAKE BATCH
    sample = generator(fixed_noise)
    print(sample.shape)
    out = discriminator(sample)
    print(out.shape)
    loss = criterion(out, real_labels)
    print(loss)

torch.Size([32, 1, 28, 28])
torch.Size([32, 1])
tensor(0.7078, device='cuda:0')
torch.Size([32, 784])
torch.Size([32, 1])
tensor(0.7029, device='cuda:0')


## TRAIN

In [15]:
def train_discriminator(images):
    opt_d.zero_grad()
    
    real_images = images.to(device)
    fake_images = generator(torch.randn(batch_size, nz).to(device))
    
    real_out = discriminator(real_images)
    fake_out = discriminator(fake_images)
    
    real_loss = criterion(real_out, real_labels)
    fake_loss = criterion(fake_out, fake_labels)
    
    real_loss.backward()
    fake_loss.backward()
    opt_d.step()
    
    return real_loss + fake_loss

In [16]:
def train_generator():
    opt_g.zero_grad()

    out = discriminator(generator(torch.randn(batch_size, nz).to(device)))
    
    loss = criterion(out, real_labels)
    
    loss.backward()
    opt_g.step()
    
    return loss

In [None]:
start_time = time.time()
loss_D = []
loss_G = []

for epoch in range(num_epochs):
    epoch += 1
    loss_d = 0.0
    loss_g = 0.0
    
    for i,(images,_) in enumerate(dl):
        if i == num_steps:
            break
            
        for _ in range(k):
            loss_d += train_discriminator(images)
        loss_g += train_generator()
    
    loss_D.append(loss_d / i)
    loss_G.append(loss_g / i)
    
    print(f'{epoch}/{num_epochs} | loss_d: {loss_d/i:.8f} | loss_g: {loss_g/i:.8f} | Time: {time.time() - start_time:.0f} sec')
    if epoch%10==0:
        sample = generator(fixed_noise).detach()
        grid = torchvision.utils.make_grid(sample.view(-1, 1, 28, 28), nrow=8, pad_value=1, normalize=False)   
        torchvision.utils.save_image(grid.detach().cpu(), os.path.join(saveroot, '20210126_MNIST_VANILLA_GAN_{}_{}.jpg'.format(str(test_number), str(epoch).zfill(3))))

1/10 | loss_d: 2.38316011 | loss_g: 2.38316011 | Time: 16 sec
2/10 | loss_d: 2.33142662 | loss_g: 2.33142662 | Time: 28 sec
3/10 | loss_d: 1.95939195 | loss_g: 1.95939195 | Time: 38 sec
4/10 | loss_d: 1.59533298 | loss_g: 1.59533298 | Time: 47 sec
5/10 | loss_d: 1.70839918 | loss_g: 1.70839918 | Time: 57 sec
6/10 | loss_d: 1.89603281 | loss_g: 1.89603281 | Time: 67 sec
7/10 | loss_d: 2.02509904 | loss_g: 2.02509904 | Time: 77 sec
8/10 | loss_d: 2.14785838 | loss_g: 2.14785838 | Time: 86 sec
9/10 | loss_d: 2.27471638 | loss_g: 2.27471638 | Time: 95 sec


## TEST

## EVALUATE

In [None]:
plt.figure()

plt.plot(loss_D, label='LOSS D')
plt.plot(loss_G, label='LOSS G')
plt.legend()
plt.savefig('./MNIST_VANILLA_GAN_{}'.format(test_number))
plt.show()

## SAVE/ LOAD

In [None]:
torch.save(generator.state_dict(), './20210126_MNIST_VANILLA_GAN_G.pt')
torch.save(discriminator.state_dict(), './20210126_MNIST_VANILLA_GAN_D.pt')