## IMPORTS

In [None]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import os
import time

## DEVICE

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device

## GLOBAL HYPERPARAMS

In [None]:
#CUSTOM TRANSFORM
class FlattenTransform:
    """Callable transform that collapses every dimension after the first one.

    Uses ``Tensor.view``, so the result shares storage with ``inputs`` and a
    non-contiguous tensor raises rather than being copied. (The ``transforms``
    pipeline in GLOBAL HYPERPARAMS does not include this transform.)
    """

    def __call__(self, inputs):
        leading_dim = inputs.size(0)
        return inputs.view(leading_dim, -1)

In [None]:
batch_size = 96
nz = 128  # length of the latent noise vector fed to the generator

# ImageFolder's default loader returns RGB images, so collapse them to 1 channel.
# The generator ends in Tanh (outputs in [-1, 1]); Normalize(mean=0.5, std=0.5)
# maps ToTensor's [0, 1] range onto that same [-1, 1] range. Without it the
# discriminator could tell real from fake simply by the sign of the pixels.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(num_output_channels=1),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

lr = 0.0002
# SGD-style settings; the Adam optimizers created later only receive `lr`.
momentum = 0.5
dampening = 0

epochs = 200
k = 1  # discriminator updates per generator update

## DATASET

In [None]:
path = '../storage/data/mnist_png/training/'

# ImageFolder treats each sub-directory of `path` as one class; every image
# goes through the `transforms` pipeline defined above.
ds = torchvision.datasets.ImageFolder(root=path, transform=transforms)
dl = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Peek at one shuffled batch: print the image/label tensor shapes, then tile
# the first 16 images into a grid with 8 images per row.
for xb, yb in dl:
    shapes = (xb.shape, yb.shape)
    print('xb: {}, \nyb: {}'.format(*shapes))
    grid = torchvision.utils.make_grid(xb[:16], nrow=8)
    plt.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for matplotlib
    break

## DEFINE MODELS

### MODEL VARIATIONS

#### GENERATOR
1. VANILLA I
    - fc1: Linear(nz,256) LeakyReLU(0.2)
    - fc2: Linear(256,512) LeakyReLU(0.2)
    - fc3: Linear(512,1024) LeakyReLU(0.2)
    - fc4: Linear(1024,784) Tanh()
    

#### DISCRIMINATOR
1. VANILLA I
    - fc1: Linear(784,1024) LeakyReLU(0.2) Dropout(0.3)
    - fc2: Linear(1024,512) LeakyReLU(0.2) Dropout(0.3)
    - fc3: Linear(512,256) LeakyReLU(0.2) Dropout(0.3)
    - fc4: Linear(256,1) Sigmoid()

In [None]:
#CUSTOM ACTIVATION LAYER

class Maxout(torch.nn.Module):
    """Maxout activation: group the feature dimension and keep each group's max.

    For an input of shape ``(batch, features, *rest)`` the output has shape
    ``(batch, features // num_pieces, *rest)``; each run of ``num_pieces``
    consecutive features forms one group. (Not used by the models below.)

    Args:
        num_pieces: number of linear pieces per output unit; must be >= 1.

    Raises:
        ValueError: if ``num_pieces`` < 1, or if the feature dimension is not
            divisible by ``num_pieces``.
    """

    def __init__(self, num_pieces):
        super().__init__()
        if num_pieces < 1:
            raise ValueError(f'num_pieces must be a positive integer, got {num_pieces}')
        self.num_pieces = num_pieces

    def forward(self, x):
        batch, features = x.shape[0], x.shape[1]
        # Explicit check rather than `assert`, so it still runs under `python -O`.
        if features % self.num_pieces != 0:
            raise ValueError(
                f'feature dimension {features} is not divisible by '
                f'num_pieces={self.num_pieces}'
            )
        # (batch, features, *rest) -> (batch, features // k, k, *rest)
        grouped = x.view(batch, features // self.num_pieces, self.num_pieces, *x.shape[2:])
        # max along a dim returns (values, indices); only the values are needed.
        values, _ = grouped.max(dim=2)
        return values

In [None]:
class Generator(nn.Module):
    """Fully connected GAN generator: latent noise -> flattened 28x28 image.

    Architecture (VANILLA I): Linear(latent_dim, 256) -> Linear(256, 512)
    -> Linear(512, 1024), each followed by LeakyReLU(0.2), then
    Linear(1024, 784) -> Tanh, so every output pixel lies in [-1, 1].

    Args:
        latent_dim: length of the input noise vector. Defaults to the
            notebook-level ``nz``, so existing ``Generator()`` calls behave
            exactly as before.
    """

    def __init__(self, latent_dim=None):
        super().__init__()
        if latent_dim is None:
            latent_dim = nz  # notebook-level hyperparameter

        # Sub-module names/indices (fc1.0 ... fc4.0) are kept stable so that
        # previously saved state_dicts still load.
        self.fc1 = nn.Sequential(nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2))
        self.fc2 = nn.Sequential(nn.Linear(256, 512), nn.LeakyReLU(0.2))
        self.fc3 = nn.Sequential(nn.Linear(512, 1024), nn.LeakyReLU(0.2))
        self.fc4 = nn.Sequential(nn.Linear(1024, 784), nn.Tanh())

    def forward(self, x):
        """Map noise of shape (batch, latent_dim) to images of shape (batch, 784)."""
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        # fc4 already yields (batch, 784); the view guarantees a 2-D result.
        return x.view(x.size(0), -1)

In [None]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator: 28x28 image -> probability it is real.

    Architecture (VANILLA I): three hidden stages of
    Linear -> LeakyReLU(0.2) -> Dropout(0.3) (784 -> 1024 -> 512 -> 256),
    then Linear(256, 1) -> Sigmoid.
    """

    @staticmethod
    def _hidden(in_features, out_features):
        # Linear / LeakyReLU / Dropout registered as indices 0 / 1 / 2.
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )

    def __init__(self):
        super().__init__()
        self.fc1 = self._hidden(784, 1024)
        self.fc2 = self._hidden(1024, 512)
        self.fc3 = self._hidden(512, 256)
        self.fc4 = nn.Sequential(nn.Linear(256, 1), nn.Sigmoid())

    def forward(self, x):
        """Score a batch; each sample is flattened (it must hold 784 values)."""
        out = x.view(x.size(0), -1)
        for stage in (self.fc1, self.fc2, self.fc3, self.fc4):
            out = stage(out)
        return out  # (batch, 1), sigmoid scores in [0, 1]

In [None]:
# Build both networks (generator first) and move their parameters to `device`.
generator, discriminator = Generator().to(device), Discriminator().to(device)

## DEFINE OPTIMIZERS

In [None]:
# One Adam optimizer per player, both at learning rate `lr` with default betas.
opt_d = torch.optim.Adam(params=discriminator.parameters(), lr=lr)
opt_g = torch.optim.Adam(params=generator.parameters(), lr=lr)

In [None]:
#CUSTOM OPTIMIZER SCHEDULER
# Learning-rate schedules are DISABLED: no scheduler is created here and the
# training loop never calls `scheduler.step()`. To enable one, uncomment a
# pair of schedulers below and call `.step()` on both once per epoch.
#
# Option 1 - multiply both learning rates by 0.99 after every epoch:
# discriminator_scheduler = torch.optim.lr_scheduler.StepLR(
#     optimizer=opt_d, step_size=1, gamma=0.99, last_epoch=-1)
# generator_scheduler = torch.optim.lr_scheduler.StepLR(
#     optimizer=opt_g, step_size=1, gamma=0.99, last_epoch=-1)

#CUSTOM LAMBDA DECAY LR

class DecayLR:
    """Stateless LambdaLR factor: ``_lr * 0.1 ** (epoch // _step_size)``.

    LambdaLR multiplies each parameter group's initial learning rate by the
    value this returns. The factor depends only on ``_epoch`` (no state is
    mutated), so evaluating it several times for the same epoch - as LambdaLR
    does at construction and on each ``step()`` - always gives the same value,
    and the first 10x decay happens at epoch ``_step_size``, not at epoch 0.

    Args:
        _lr: factor used for epochs ``0 .. _step_size - 1``.
        _step_size: number of epochs between successive 10x decays.
    """

    def __init__(self, _lr, _step_size):
        self.lr = _lr
        self.step_size = _step_size

    def __call__(self, _epoch):
        return self.lr * 0.1 ** (_epoch // self.step_size)

# Option 2 - 0.9x factor, decayed 10x every 100 epochs:
# discriminator_scheduler = torch.optim.lr_scheduler.LambdaLR(
#     opt_d, DecayLR(_lr=0.9, _step_size=100))
# generator_scheduler = torch.optim.lr_scheduler.LambdaLR(
#     opt_g, DecayLR(_lr=0.9, _step_size=100))

## DEFINE LOSS

In [None]:
criterion = torch.nn.BCELoss(reduction='mean')  # binary cross-entropy on the discriminator's sigmoid scores, batch-averaged

## PRE-TRAIN TESTS

In [None]:
#SAMPLE BATCH TEST
# Smoke-test the untrained models on one batch: print the shapes of the noise,
# the generated images and the discriminator scores (fake and real), then show
# the generated images as a grid.
with torch.no_grad():  # inference only; no autograd graph is needed
    for images, _ in dl:
        images = images.to(device)

        sample = torch.randn(32, nz).to(device)
        print("RANDOM SAMPLE TENSOR: ", sample.shape)

        sample = generator(sample)
        print("GENERATED IMAGES:", sample.shape)

        d_g_z = discriminator(sample)
        print("DISCRIMINATOR OUTPUTS:", d_g_z.shape)

        # Also push a real batch through D to confirm its shape is accepted.
        d_x = discriminator(images)
        print("DISCRIMINATOR OUTPUTS (REAL):", d_x.shape)

        # Generator output is Tanh in [-1, 1]; shift to [0, 1] so imshow doesn't clip it.
        grid = torchvision.utils.make_grid(sample.view(-1, 1, 28, 28).add(1).div(2), nrow=8, pad_value=1, normalize=False)
        plt.imshow(grid.cpu().permute(1, 2, 0))
        break

## CREATE VISUALIZATION OUTPUT FOLDER

In [None]:
# Output folder for the sample grids saved during training. exist_ok=True makes
# this a no-op when the folder already exists, with no check-then-create race.
os.makedirs('visuals', exist_ok=True)

## TRAIN

In [None]:
# 32 latent vectors drawn once and reused for every saved sample grid.
fixed_noise = torch.randn(32, nz).to(device)
# Full-batch BCE targets: 1.0 marks real images, 0.0 marks generated ones.
real_labels, fake_labels = (
    torch.full((batch_size, 1), fill, device=device) for fill in (1.0, 0.0)
)
print(real_labels.shape, fake_labels.shape)

In [None]:
start_time = time.time()

losses_g = []  # mean generator loss per epoch (Python floats)
losses_d = []  # mean discriminator loss (real + fake BCE) per epoch (Python floats)
d_lr_ls = []   # discriminator learning rate in effect during each epoch
g_lr_ls = []   # generator learning rate in effect during each epoch

generator.train()
discriminator.train()

for e in range(epochs):
    d_lr_ls.append(opt_d.param_groups[0]['lr'])
    g_lr_ls.append(opt_g.param_groups[0]['lr'])

    _loss_g = 0.0
    _loss_d = 0.0
    n_batches = 0
    for images, _ in dl:
        real_images = images.to(device)
        # Size noise and targets from the actual batch so a short final batch
        # (dataset size not divisible by batch_size) still lines up.
        b = real_images.size(0)
        real_targets = torch.ones(b, 1, device=device)
        fake_targets = torch.zeros(b, 1, device=device)

        #DISCRIMINATOR: k updates per generator update
        for _ in range(k):
            opt_d.zero_grad()

            # detach(): the discriminator step must not backpropagate into G.
            fake_images = generator(torch.randn(b, nz).to(device)).detach()

            real_loss = criterion(discriminator(real_images), real_targets)
            fake_loss = criterion(discriminator(fake_images), fake_targets)

            (real_loss + fake_loss).backward()
            opt_d.step()

            # .item() yields a plain float, so no autograd history is kept
            # alive across the epoch.
            _loss_d += real_loss.item() + fake_loss.item()

        #GENERATOR: non-saturating loss - fakes are scored against "real" targets
        opt_g.zero_grad()

        generated_images = generator(torch.randn(b, nz).to(device))
        loss = criterion(discriminator(generated_images), real_targets)

        loss.backward()
        opt_g.step()

        _loss_g += loss.item()
        n_batches += 1

    # Mean over batches (the batch count, not the last index), and over the
    # k discriminator updates per batch; max(..., 1) guards an empty loader.
    e_loss_g = _loss_g / max(n_batches, 1)
    e_loss_d = _loss_d / max(n_batches * k, 1)
    losses_g.append(e_loss_g)
    losses_d.append(e_loss_d)
    print(f"Epoch {e + 1} of {epochs}")
    print(f"Generator loss: {e_loss_g:.8f}, Discriminator loss: {e_loss_d:.8f}")
    print(f'Duration: {time.time() - start_time:.0f} seconds') # print the time elapsed

    # Save a grid every 10 epochs and after the final epoch.
    if e % 10 == 0 or e == epochs - 1:
        with torch.no_grad():
            sample = generator(fixed_noise)
        # Tanh output is in [-1, 1]; map to [0, 1] so save_image doesn't clip half of it.
        grid = torchvision.utils.make_grid(sample.view(-1, 1, 28, 28).add(1).div(2), nrow=8, pad_value=1, normalize=False)
        torchvision.utils.save_image(grid.cpu(), os.path.join('visuals', 'MNIST_VANILLA_GAN_{}.jpg'.format(e)))


print(f'\nTOTAL DURATION: {time.time() - start_time:.0f} seconds') # print the time elapsed

## EVALUATIONS

In [None]:
#LOSS PLOT
# float() accepts Python floats as well as 0-dim tensors (on any device, with or
# without requires_grad), so the history plots however it was accumulated.
plt.figure()
plt.plot([float(v) for v in losses_g], label="LOSS G")
plt.plot([float(v) for v in losses_d], label="LOSS D")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()

In [None]:
# LEARNING-RATE PLOT
# d_lr_ls / g_lr_ls are per-epoch learning-rate histories. If the training
# loop did not record them, fall back to each optimizer's current learning
# rate for every epoch (assumes no scheduler changed it during training).
try:
    d_lr_hist, g_lr_hist = d_lr_ls, g_lr_ls
except NameError:
    d_lr_hist = [opt_d.param_groups[0]['lr']] * epochs
    g_lr_hist = [opt_g.param_groups[0]['lr']] * epochs

fig = plt.figure(figsize=(16, 10))
plt.plot(d_lr_hist, label='D LR')
plt.plot(g_lr_hist, label='G LR')
plt.legend()
plt.show();

In [None]:
# # LOAD MODEL
# Restores the state_dicts written by the SAVING cell at the end of this notebook.
# NOTE(review): SAVING writes './mnist_vanilla_gan_advanced_*.pt', so the paths
# below use the "advanced" names; map_location lets a GPU-saved checkpoint load on CPU.
# discriminator.load_state_dict(torch.load('./mnist_vanilla_gan_advanced_discriminator.pt', map_location=device))
# generator.load_state_dict(torch.load('./mnist_vanilla_gan_advanced_generator.pt', map_location=device))

In [None]:

#IMAGES FROM LAST EPOCH (trained generator, fresh random noise)
with torch.no_grad():  # inference only; no autograd graph needed
    sample = generator(torch.randn(32, nz).to(device))
# Tanh output is in [-1, 1]; shift to [0, 1] so imshow doesn't clip the lower half.
grid = torchvision.utils.make_grid(sample.view(-1, 1, 28, 28).add(1).div(2), nrow=8, pad_value=1, normalize=False)
plt.imshow(grid.cpu().permute(1, 2, 0))

## SAVING

In [None]:
# Persist only the learned parameters (state_dicts), generator first.
for net, ckpt_path in (
    (generator, './mnist_vanilla_gan_advanced_generator.pt'),
    (discriminator, './mnist_vanilla_gan_advanced_discriminator.pt'),
):
    torch.save(net.state_dict(), ckpt_path)