In [1]:
# Use GAN to generate synthesized characters on MNIST dataset

In [2]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary as PyTorchSummary

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [4]:
latent_size = 16
hidden_size1 = 128
hidden_size2 = 64
image_size = 28*28
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

In [5]:
# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [6]:
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean = (0.5,0.5,0.5),
                                     std = (0.5, 0.5, 0.5))])

In [7]:
# MNIST dataset

In [8]:
mnist = torchvision.datasets.MNIST(root='./../data',
                                          train = True,
                                          transform = transform,
                                          download = True)

In [9]:
data_loader = torch.utils.data.DataLoader(dataset = mnist,
                                          batch_size = batch_size,
                                          shuffle = True)

## Build model

In [10]:
# Discriminator: a simple classification model

In [11]:
D = nn.Sequential(
    nn.Linear(image_size, hidden_size1),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size1, hidden_size2),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size2, 1),
    nn.Sigmoid())

In [12]:
PyTorchSummary(D, input_size = (1,image_size))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1               [-1, 1, 128]         100,480
         LeakyReLU-2               [-1, 1, 128]               0
            Linear-3                [-1, 1, 64]           8,256
         LeakyReLU-4                [-1, 1, 64]               0
            Linear-5                 [-1, 1, 1]              65
           Sigmoid-6                 [-1, 1, 1]               0
Total params: 108,801
Trainable params: 108,801
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.42
Estimated Total Size (MB): 0.42
----------------------------------------------------------------


In [13]:
# Generator: like an autoencoder

In [14]:
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size2),
    nn.ReLU(),
    nn.Linear(hidden_size2, hidden_size1),
    nn.ReLU(),
    nn.Linear(hidden_size1, image_size),
    nn.Tanh())

In [15]:
PyTorchSummary(G, input_size = (1, latent_size))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                [-1, 1, 64]           1,088
              ReLU-2                [-1, 1, 64]               0
            Linear-3               [-1, 1, 128]           8,320
              ReLU-4               [-1, 1, 128]               0
            Linear-5               [-1, 1, 784]         101,136
              Tanh-6               [-1, 1, 784]               0
Total params: 110,544
Trainable params: 110,544
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.42
Estimated Total Size (MB): 0.44
----------------------------------------------------------------


In [16]:
# Device setting
D = D.to(device)
G = G.to(device)


In [17]:
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0004)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0004)

In [18]:
def denorm(x):
    out = (x+1)/2
    return out.clamp(0, 1)

In [19]:
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

## Training

In [20]:
total_step = len(data_loader)

In [21]:
for epoch in range(num_epochs):
    for i, (images,_) in enumerate(data_loader):
        images = images.reshape(batch_size,-1).to(device)
        
        #create labels with are later used as input for BCE loss
        real_labels = torch.ones(batch_size,1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        
        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #
        
        #compute BCE_loss using real images where BCE_loss(x,y) = -y*log(D(x)) - (1-y)*log(1-D(x))
        #second term of the loss is always zero for real cases because y=1 for real case
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        #compute BCEloss using fake images
        #First term of the loss is always zero for fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        #backprop and optimize
        d_loss = 0.5*d_loss_real + 0.5*d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        #train G to maximize log(D(G(z))) instead of minimizing log(1-D(G(z)))
        #use real label to keep the first term so loss = -log(D(G(z)))
        g_loss = criterion(outputs, real_labels)
        
        #back_prop
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
            
    # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
        
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
    

Epoch [0/200], Step [200/600], d_loss: 0.0678, g_loss: 2.7981, D(x): 0.99, D(G(z)): 0.11
Epoch [0/200], Step [400/600], d_loss: 0.0393, g_loss: 4.4457, D(x): 1.00, D(G(z)): 0.07
Epoch [0/200], Step [600/600], d_loss: 0.0187, g_loss: 5.6974, D(x): 0.99, D(G(z)): 0.02
Epoch [1/200], Step [200/600], d_loss: 0.0134, g_loss: 5.5498, D(x): 0.98, D(G(z)): 0.01
Epoch [1/200], Step [400/600], d_loss: 0.0194, g_loss: 8.9748, D(x): 0.98, D(G(z)): 0.02
Epoch [1/200], Step [600/600], d_loss: 0.0556, g_loss: 3.4806, D(x): 0.97, D(G(z)): 0.07
Epoch [2/200], Step [200/600], d_loss: 0.1087, g_loss: 3.3210, D(x): 0.95, D(G(z)): 0.12
Epoch [2/200], Step [400/600], d_loss: 0.3090, g_loss: 5.6791, D(x): 0.92, D(G(z)): 0.27
Epoch [2/200], Step [600/600], d_loss: 0.2875, g_loss: 5.4709, D(x): 0.82, D(G(z)): 0.10
Epoch [3/200], Step [200/600], d_loss: 0.0936, g_loss: 3.3314, D(x): 0.94, D(G(z)): 0.07
Epoch [3/200], Step [400/600], d_loss: 0.2984, g_loss: 4.1858, D(x): 0.90, D(G(z)): 0.24
Epoch [3/200], Step [

Epoch [30/200], Step [600/600], d_loss: 0.2554, g_loss: 4.2133, D(x): 0.85, D(G(z)): 0.12
Epoch [31/200], Step [200/600], d_loss: 0.1566, g_loss: 4.3446, D(x): 0.92, D(G(z)): 0.13
Epoch [31/200], Step [400/600], d_loss: 0.1619, g_loss: 3.4819, D(x): 0.90, D(G(z)): 0.12
Epoch [31/200], Step [600/600], d_loss: 0.1719, g_loss: 4.1753, D(x): 0.93, D(G(z)): 0.15
Epoch [32/200], Step [200/600], d_loss: 0.1855, g_loss: 3.8644, D(x): 0.86, D(G(z)): 0.07
Epoch [32/200], Step [400/600], d_loss: 0.2484, g_loss: 3.9895, D(x): 0.84, D(G(z)): 0.09
Epoch [32/200], Step [600/600], d_loss: 0.2623, g_loss: 3.9333, D(x): 0.82, D(G(z)): 0.05
Epoch [33/200], Step [200/600], d_loss: 0.2431, g_loss: 3.6682, D(x): 0.88, D(G(z)): 0.15
Epoch [33/200], Step [400/600], d_loss: 0.1845, g_loss: 4.3754, D(x): 0.91, D(G(z)): 0.14
Epoch [33/200], Step [600/600], d_loss: 0.2494, g_loss: 3.4930, D(x): 0.84, D(G(z)): 0.11
Epoch [34/200], Step [200/600], d_loss: 0.2711, g_loss: 3.1319, D(x): 0.85, D(G(z)): 0.20
Epoch [34/

Epoch [61/200], Step [400/600], d_loss: 0.3246, g_loss: 2.8408, D(x): 0.80, D(G(z)): 0.20
Epoch [61/200], Step [600/600], d_loss: 0.3627, g_loss: 2.0886, D(x): 0.77, D(G(z)): 0.24
Epoch [62/200], Step [200/600], d_loss: 0.4890, g_loss: 1.4197, D(x): 0.75, D(G(z)): 0.32
Epoch [62/200], Step [400/600], d_loss: 0.3465, g_loss: 2.1262, D(x): 0.81, D(G(z)): 0.26
Epoch [62/200], Step [600/600], d_loss: 0.3653, g_loss: 2.4388, D(x): 0.72, D(G(z)): 0.17
Epoch [63/200], Step [200/600], d_loss: 0.3358, g_loss: 2.1180, D(x): 0.76, D(G(z)): 0.19
Epoch [63/200], Step [400/600], d_loss: 0.2913, g_loss: 2.4262, D(x): 0.76, D(G(z)): 0.14
Epoch [63/200], Step [600/600], d_loss: 0.4219, g_loss: 2.1263, D(x): 0.74, D(G(z)): 0.22
Epoch [64/200], Step [200/600], d_loss: 0.3512, g_loss: 2.2856, D(x): 0.75, D(G(z)): 0.19
Epoch [64/200], Step [400/600], d_loss: 0.3226, g_loss: 2.1820, D(x): 0.77, D(G(z)): 0.21
Epoch [64/200], Step [600/600], d_loss: 0.2946, g_loss: 1.8965, D(x): 0.84, D(G(z)): 0.25
Epoch [65/

Epoch [92/200], Step [200/600], d_loss: 0.4931, g_loss: 1.6364, D(x): 0.75, D(G(z)): 0.33
Epoch [92/200], Step [400/600], d_loss: 0.4927, g_loss: 2.0476, D(x): 0.64, D(G(z)): 0.21
Epoch [92/200], Step [600/600], d_loss: 0.3660, g_loss: 2.2112, D(x): 0.71, D(G(z)): 0.16
Epoch [93/200], Step [200/600], d_loss: 0.3033, g_loss: 2.2934, D(x): 0.79, D(G(z)): 0.21
Epoch [93/200], Step [400/600], d_loss: 0.4117, g_loss: 1.7223, D(x): 0.74, D(G(z)): 0.27
Epoch [93/200], Step [600/600], d_loss: 0.3396, g_loss: 2.4220, D(x): 0.74, D(G(z)): 0.20
Epoch [94/200], Step [200/600], d_loss: 0.4451, g_loss: 1.8214, D(x): 0.82, D(G(z)): 0.38
Epoch [94/200], Step [400/600], d_loss: 0.4386, g_loss: 2.0071, D(x): 0.74, D(G(z)): 0.31
Epoch [94/200], Step [600/600], d_loss: 0.4349, g_loss: 2.2797, D(x): 0.73, D(G(z)): 0.27
Epoch [95/200], Step [200/600], d_loss: 0.4101, g_loss: 1.7933, D(x): 0.74, D(G(z)): 0.29
Epoch [95/200], Step [400/600], d_loss: 0.3402, g_loss: 1.7571, D(x): 0.74, D(G(z)): 0.20
Epoch [95/

Epoch [122/200], Step [400/600], d_loss: 0.3974, g_loss: 2.0373, D(x): 0.70, D(G(z)): 0.21
Epoch [122/200], Step [600/600], d_loss: 0.5286, g_loss: 1.7015, D(x): 0.62, D(G(z)): 0.29
Epoch [123/200], Step [200/600], d_loss: 0.4659, g_loss: 1.4024, D(x): 0.74, D(G(z)): 0.35
Epoch [123/200], Step [400/600], d_loss: 0.5346, g_loss: 1.4817, D(x): 0.66, D(G(z)): 0.31
Epoch [123/200], Step [600/600], d_loss: 0.4771, g_loss: 1.6868, D(x): 0.71, D(G(z)): 0.29
Epoch [124/200], Step [200/600], d_loss: 0.5144, g_loss: 1.5466, D(x): 0.70, D(G(z)): 0.34
Epoch [124/200], Step [400/600], d_loss: 0.4012, g_loss: 1.3032, D(x): 0.74, D(G(z)): 0.30
Epoch [124/200], Step [600/600], d_loss: 0.4992, g_loss: 1.5007, D(x): 0.64, D(G(z)): 0.27
Epoch [125/200], Step [200/600], d_loss: 0.5994, g_loss: 1.6585, D(x): 0.59, D(G(z)): 0.33
Epoch [125/200], Step [400/600], d_loss: 0.4210, g_loss: 2.0550, D(x): 0.72, D(G(z)): 0.27
Epoch [125/200], Step [600/600], d_loss: 0.4350, g_loss: 1.8755, D(x): 0.69, D(G(z)): 0.25

Epoch [152/200], Step [600/600], d_loss: 0.4693, g_loss: 1.4709, D(x): 0.73, D(G(z)): 0.38
Epoch [153/200], Step [200/600], d_loss: 0.4699, g_loss: 1.4175, D(x): 0.70, D(G(z)): 0.33
Epoch [153/200], Step [400/600], d_loss: 0.5073, g_loss: 1.3271, D(x): 0.71, D(G(z)): 0.37
Epoch [153/200], Step [600/600], d_loss: 0.4978, g_loss: 1.8953, D(x): 0.60, D(G(z)): 0.26
Epoch [154/200], Step [200/600], d_loss: 0.5368, g_loss: 1.3654, D(x): 0.66, D(G(z)): 0.38
Epoch [154/200], Step [400/600], d_loss: 0.5139, g_loss: 1.2412, D(x): 0.70, D(G(z)): 0.38
Epoch [154/200], Step [600/600], d_loss: 0.5602, g_loss: 1.3278, D(x): 0.65, D(G(z)): 0.37
Epoch [155/200], Step [200/600], d_loss: 0.5732, g_loss: 1.1813, D(x): 0.65, D(G(z)): 0.39
Epoch [155/200], Step [400/600], d_loss: 0.5429, g_loss: 1.5184, D(x): 0.64, D(G(z)): 0.34
Epoch [155/200], Step [600/600], d_loss: 0.3765, g_loss: 1.6337, D(x): 0.77, D(G(z)): 0.31
Epoch [156/200], Step [200/600], d_loss: 0.5676, g_loss: 1.1725, D(x): 0.65, D(G(z)): 0.37

Epoch [183/200], Step [200/600], d_loss: 0.5479, g_loss: 1.5686, D(x): 0.60, D(G(z)): 0.31
Epoch [183/200], Step [400/600], d_loss: 0.4998, g_loss: 1.4488, D(x): 0.71, D(G(z)): 0.35
Epoch [183/200], Step [600/600], d_loss: 0.5366, g_loss: 1.2825, D(x): 0.65, D(G(z)): 0.36
Epoch [184/200], Step [200/600], d_loss: 0.4480, g_loss: 1.6191, D(x): 0.67, D(G(z)): 0.29
Epoch [184/200], Step [400/600], d_loss: 0.5245, g_loss: 1.2576, D(x): 0.70, D(G(z)): 0.42
Epoch [184/200], Step [600/600], d_loss: 0.4087, g_loss: 1.2268, D(x): 0.73, D(G(z)): 0.32
Epoch [185/200], Step [200/600], d_loss: 0.5076, g_loss: 1.5113, D(x): 0.63, D(G(z)): 0.33
Epoch [185/200], Step [400/600], d_loss: 0.5537, g_loss: 1.2256, D(x): 0.64, D(G(z)): 0.38
Epoch [185/200], Step [600/600], d_loss: 0.4275, g_loss: 1.8568, D(x): 0.65, D(G(z)): 0.24
Epoch [186/200], Step [200/600], d_loss: 0.5452, g_loss: 1.3895, D(x): 0.60, D(G(z)): 0.31
Epoch [186/200], Step [400/600], d_loss: 0.4433, g_loss: 1.2116, D(x): 0.67, D(G(z)): 0.29

In [22]:

# Save the model checkpoints 
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')