In [1]:
# Use GAN to generate synthesized characters on MNIST dataset

In [2]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary as PyTorchSummary

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [4]:
latent_size = 64
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

In [5]:
# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [6]:
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean = (0.5,0.5,0.5),
                                     std = (0.5, 0.5, 0.5))])

In [7]:
# MNIST dataset

In [8]:
mnist = torchvision.datasets.MNIST(root='./../data',
                                          train = True,
                                          transform = transform,
                                          download = True)

In [9]:
data_loader = torch.utils.data.DataLoader(dataset = mnist,
                                          batch_size = batch_size,
                                          shuffle = True)

## Build model

In [10]:
image_size = (1,28,28)

In [11]:
# Discriminator: a simple classification model
#Input: (1,28,28) image
#output: 0 or 1 (true of fake image)

In [12]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=5, stride = 1), #1x28x28->8x24x24
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.MaxPool2d(2, stride = 2))#8x12x12
        self.layer2 = nn.Sequential(
            nn.Conv2d(8, 16, kernel_size=5, stride = 1), #16x8x8
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2, stride = 2)) #16x4x4
        self.layer3 = nn.Sequential(
            nn.Conv2d(16, 64, kernel_size=3, stride = 1), #64x2x2
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, stride = 2)) #64x1x1
        self.fc = nn.Sequential(
            nn.Linear(64, 1),
            nn.Sigmoid())
    def forward(self,x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

In [13]:
# Generator: an upsample model
#Input: (latent size,) vector
#output: (1,28,28) fake image

In [14]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 5), #64x1x1 -> 32x5x5
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 5, stride = 2), #-> 16x13x13
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 4, 3, stride = 2), #-> 4x27x27
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.ConvTranspose2d(4, 1, 2), #-> 1x28x28
            nn.Tanh())
        
    def forward(self,x):
        out = x.reshape(x.size(0),x.size(1),1,1)
        out = self.upsample(out)
        return out

In [15]:
D = Discriminator().to(device)
G = Generator().to(device)

In [16]:
PyTorchSummary(D, input_size = (1,28,28))
PyTorchSummary(G, input_size = (latent_size,1))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 24, 24]             208
       BatchNorm2d-2            [-1, 8, 24, 24]              16
              ReLU-3            [-1, 8, 24, 24]               0
         MaxPool2d-4            [-1, 8, 12, 12]               0
            Conv2d-5             [-1, 16, 8, 8]           3,216
       BatchNorm2d-6             [-1, 16, 8, 8]              32
              ReLU-7             [-1, 16, 8, 8]               0
         MaxPool2d-8             [-1, 16, 4, 4]               0
            Conv2d-9             [-1, 64, 2, 2]           9,280
      BatchNorm2d-10             [-1, 64, 2, 2]             128
             ReLU-11             [-1, 64, 2, 2]               0
        MaxPool2d-12             [-1, 64, 1, 1]               0
           Linear-13                    [-1, 1]              65
          Sigmoid-14                   

In [17]:
#Check Discriminator shape transform
x = torch.randn(100,1,28,28)
D(x).shape

torch.Size([100, 1])

In [18]:
#Check Generator shape transform
x = torch.randn(100,latent_size)
G(x).shape

torch.Size([100, 1, 28, 28])

In [19]:
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0004)

In [20]:
def denorm(x):
    out = (x+1)/2
    return out.clamp(0, 1)

In [21]:
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

## Training

In [22]:
total_step = len(data_loader)

In [27]:
for epoch in range(10,num_epochs):
    for i, (images,_) in enumerate(data_loader):
        images = images.to(device)
        
        #create labels with are later used as input for BCE loss
        real_labels = torch.ones(batch_size).to(device)
        fake_labels = torch.zeros(batch_size).to(device)
        
        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #
        
        #compute BCE_loss using real images where BCE_loss(x,y) = -y*log(D(x)) - (1-y)*log(1-D(x))
        #second term of the loss is always zero for real cases because y=1 for real case
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        #compute BCEloss using fake images
        #First term of the loss is always zero for fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        #backprop and optimize
        d_loss = 0.5*d_loss_real + 0.5*d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        #train G to maximize log(D(G(z))) instead of minimizing log(1-D(G(z)))
        #use real label to keep the first term so loss = -log(D(G(z)))
        g_loss = criterion(outputs, real_labels)
        
        #back_prop
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 50 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
            
    # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
        
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
    

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [10/200], Step [50/600], d_loss: 0.3875, g_loss: 1.7711, D(x): 0.66, D(G(z)): 0.25
Epoch [10/200], Step [100/600], d_loss: 0.2571, g_loss: 1.5961, D(x): 0.80, D(G(z)): 0.22
Epoch [10/200], Step [150/600], d_loss: 0.4535, g_loss: 1.5719, D(x): 0.62, D(G(z)): 0.27
Epoch [10/200], Step [200/600], d_loss: 0.2656, g_loss: 1.6594, D(x): 0.80, D(G(z)): 0.23
Epoch [10/200], Step [250/600], d_loss: 0.3007, g_loss: 1.1759, D(x): 0.80, D(G(z)): 0.27
Epoch [10/200], Step [300/600], d_loss: 0.3576, g_loss: 1.4825, D(x): 0.77, D(G(z)): 0.31
Epoch [10/200], Step [350/600], d_loss: 0.4488, g_loss: 1.7201, D(x): 0.77, D(G(z)): 0.39
Epoch [10/200], Step [400/600], d_loss: 0.3832, g_loss: 1.6209, D(x): 0.74, D(G(z)): 0.31
Epoch [10/200], Step [450/600], d_loss: 0.3943, g_loss: 1.6288, D(x): 0.71, D(G(z)): 0.30
Epoch [10/200], Step [500/600], d_loss: 0.3904, g_loss: 1.4326, D(x): 0.75, D(G(z)): 0.33
Epoch [10/200], Step [550/600], d_loss: 0.3559, g_loss: 1.4385, D(x): 0.75, D(G(z)): 0.30
Epoch [10/2

Epoch [17/200], Step [450/600], d_loss: 0.3384, g_loss: 2.1257, D(x): 0.68, D(G(z)): 0.19
Epoch [17/200], Step [500/600], d_loss: 0.2695, g_loss: 1.5804, D(x): 0.78, D(G(z)): 0.21
Epoch [17/200], Step [550/600], d_loss: 0.3916, g_loss: 2.0652, D(x): 0.63, D(G(z)): 0.19
Epoch [17/200], Step [600/600], d_loss: 0.4307, g_loss: 1.6960, D(x): 0.71, D(G(z)): 0.33
Epoch [18/200], Step [50/600], d_loss: 0.3433, g_loss: 1.6100, D(x): 0.77, D(G(z)): 0.29
Epoch [18/200], Step [100/600], d_loss: 0.3669, g_loss: 1.9429, D(x): 0.72, D(G(z)): 0.28
Epoch [18/200], Step [150/600], d_loss: 0.3532, g_loss: 1.7873, D(x): 0.77, D(G(z)): 0.29
Epoch [18/200], Step [200/600], d_loss: 0.4261, g_loss: 1.5741, D(x): 0.69, D(G(z)): 0.31
Epoch [18/200], Step [250/600], d_loss: 0.3325, g_loss: 1.4567, D(x): 0.78, D(G(z)): 0.28
Epoch [18/200], Step [300/600], d_loss: 0.3245, g_loss: 1.4491, D(x): 0.77, D(G(z)): 0.25
Epoch [18/200], Step [350/600], d_loss: 0.2947, g_loss: 1.4242, D(x): 0.80, D(G(z)): 0.26
Epoch [18/2

Epoch [25/200], Step [250/600], d_loss: 0.2592, g_loss: 1.7561, D(x): 0.79, D(G(z)): 0.19
Epoch [25/200], Step [300/600], d_loss: 0.3808, g_loss: 1.9944, D(x): 0.68, D(G(z)): 0.23
Epoch [25/200], Step [350/600], d_loss: 0.3326, g_loss: 1.8937, D(x): 0.76, D(G(z)): 0.28
Epoch [25/200], Step [400/600], d_loss: 0.4577, g_loss: 2.0802, D(x): 0.65, D(G(z)): 0.29
Epoch [25/200], Step [450/600], d_loss: 0.2780, g_loss: 2.0294, D(x): 0.81, D(G(z)): 0.24
Epoch [25/200], Step [500/600], d_loss: 0.3135, g_loss: 1.7066, D(x): 0.82, D(G(z)): 0.30
Epoch [25/200], Step [550/600], d_loss: 0.3024, g_loss: 2.2239, D(x): 0.76, D(G(z)): 0.22
Epoch [25/200], Step [600/600], d_loss: 0.3369, g_loss: 1.7017, D(x): 0.69, D(G(z)): 0.19
Epoch [26/200], Step [50/600], d_loss: 0.4111, g_loss: 1.9638, D(x): 0.70, D(G(z)): 0.28
Epoch [26/200], Step [100/600], d_loss: 0.2112, g_loss: 1.5541, D(x): 0.82, D(G(z)): 0.16
Epoch [26/200], Step [150/600], d_loss: 0.2945, g_loss: 1.5472, D(x): 0.82, D(G(z)): 0.28
Epoch [26/2

Epoch [33/200], Step [50/600], d_loss: 0.3813, g_loss: 1.9439, D(x): 0.70, D(G(z)): 0.26
Epoch [33/200], Step [100/600], d_loss: 0.2721, g_loss: 2.0062, D(x): 0.85, D(G(z)): 0.26
Epoch [33/200], Step [150/600], d_loss: 0.3273, g_loss: 2.1747, D(x): 0.68, D(G(z)): 0.16
Epoch [33/200], Step [200/600], d_loss: 0.2628, g_loss: 1.8408, D(x): 0.83, D(G(z)): 0.22
Epoch [33/200], Step [250/600], d_loss: 0.2176, g_loss: 1.6243, D(x): 0.84, D(G(z)): 0.18
Epoch [33/200], Step [300/600], d_loss: 0.2637, g_loss: 2.2238, D(x): 0.82, D(G(z)): 0.23
Epoch [33/200], Step [350/600], d_loss: 0.3512, g_loss: 1.5757, D(x): 0.69, D(G(z)): 0.22
Epoch [33/200], Step [400/600], d_loss: 0.2394, g_loss: 1.8640, D(x): 0.88, D(G(z)): 0.24
Epoch [33/200], Step [450/600], d_loss: 0.4168, g_loss: 2.0478, D(x): 0.71, D(G(z)): 0.31
Epoch [33/200], Step [500/600], d_loss: 0.2347, g_loss: 2.1236, D(x): 0.83, D(G(z)): 0.20
Epoch [33/200], Step [550/600], d_loss: 0.3096, g_loss: 2.2722, D(x): 0.71, D(G(z)): 0.18
Epoch [33/2

Epoch [40/200], Step [450/600], d_loss: 0.3220, g_loss: 2.8006, D(x): 0.78, D(G(z)): 0.24
Epoch [40/200], Step [500/600], d_loss: 0.2967, g_loss: 1.8832, D(x): 0.77, D(G(z)): 0.23
Epoch [40/200], Step [550/600], d_loss: 0.3093, g_loss: 2.5733, D(x): 0.72, D(G(z)): 0.17
Epoch [40/200], Step [600/600], d_loss: 0.3289, g_loss: 1.3424, D(x): 0.80, D(G(z)): 0.27
Epoch [41/200], Step [50/600], d_loss: 0.4305, g_loss: 1.9517, D(x): 0.81, D(G(z)): 0.40
Epoch [41/200], Step [100/600], d_loss: 0.4475, g_loss: 1.9113, D(x): 0.67, D(G(z)): 0.28
Epoch [41/200], Step [150/600], d_loss: 0.2874, g_loss: 1.7895, D(x): 0.89, D(G(z)): 0.31
Epoch [41/200], Step [200/600], d_loss: 0.3776, g_loss: 1.4294, D(x): 0.79, D(G(z)): 0.30
Epoch [41/200], Step [250/600], d_loss: 0.4116, g_loss: 2.1927, D(x): 0.72, D(G(z)): 0.32
Epoch [41/200], Step [300/600], d_loss: 0.3610, g_loss: 1.9147, D(x): 0.83, D(G(z)): 0.34
Epoch [41/200], Step [350/600], d_loss: 0.2323, g_loss: 1.8365, D(x): 0.85, D(G(z)): 0.21
Epoch [41/2

Epoch [48/200], Step [250/600], d_loss: 0.1758, g_loss: 2.2862, D(x): 0.83, D(G(z)): 0.12
Epoch [48/200], Step [300/600], d_loss: 0.2618, g_loss: 2.5007, D(x): 0.78, D(G(z)): 0.18
Epoch [48/200], Step [350/600], d_loss: 0.4178, g_loss: 1.8279, D(x): 0.82, D(G(z)): 0.38
Epoch [48/200], Step [400/600], d_loss: 0.2563, g_loss: 2.3056, D(x): 0.84, D(G(z)): 0.22
Epoch [48/200], Step [450/600], d_loss: 0.2058, g_loss: 1.6744, D(x): 0.82, D(G(z)): 0.15
Epoch [48/200], Step [500/600], d_loss: 0.2911, g_loss: 2.4487, D(x): 0.85, D(G(z)): 0.29
Epoch [48/200], Step [550/600], d_loss: 0.2714, g_loss: 1.9099, D(x): 0.77, D(G(z)): 0.14
Epoch [48/200], Step [600/600], d_loss: 0.3203, g_loss: 2.0668, D(x): 0.81, D(G(z)): 0.27
Epoch [49/200], Step [50/600], d_loss: 0.5245, g_loss: 1.9724, D(x): 0.77, D(G(z)): 0.41
Epoch [49/200], Step [100/600], d_loss: 0.3287, g_loss: 2.0643, D(x): 0.76, D(G(z)): 0.25
Epoch [49/200], Step [150/600], d_loss: 0.3069, g_loss: 2.2365, D(x): 0.74, D(G(z)): 0.19
Epoch [49/2

Epoch [56/200], Step [50/600], d_loss: 0.2476, g_loss: 1.5545, D(x): 0.87, D(G(z)): 0.24
Epoch [56/200], Step [100/600], d_loss: 0.2897, g_loss: 1.9237, D(x): 0.78, D(G(z)): 0.21
Epoch [56/200], Step [150/600], d_loss: 0.3502, g_loss: 2.4611, D(x): 0.75, D(G(z)): 0.26
Epoch [56/200], Step [200/600], d_loss: 0.3156, g_loss: 2.3015, D(x): 0.84, D(G(z)): 0.29
Epoch [56/200], Step [250/600], d_loss: 0.3032, g_loss: 2.5014, D(x): 0.82, D(G(z)): 0.28
Epoch [56/200], Step [300/600], d_loss: 0.3039, g_loss: 2.6448, D(x): 0.83, D(G(z)): 0.28
Epoch [56/200], Step [350/600], d_loss: 0.2662, g_loss: 2.1514, D(x): 0.73, D(G(z)): 0.13
Epoch [56/200], Step [400/600], d_loss: 0.2994, g_loss: 1.9672, D(x): 0.80, D(G(z)): 0.23
Epoch [56/200], Step [450/600], d_loss: 0.2931, g_loss: 2.3541, D(x): 0.94, D(G(z)): 0.36
Epoch [56/200], Step [500/600], d_loss: 0.2503, g_loss: 2.4577, D(x): 0.75, D(G(z)): 0.12
Epoch [56/200], Step [550/600], d_loss: 0.2620, g_loss: 2.5812, D(x): 0.85, D(G(z)): 0.22
Epoch [56/2

Epoch [63/200], Step [450/600], d_loss: 0.3637, g_loss: 1.6952, D(x): 0.78, D(G(z)): 0.28
Epoch [63/200], Step [500/600], d_loss: 0.1913, g_loss: 1.7691, D(x): 0.82, D(G(z)): 0.12
Epoch [63/200], Step [550/600], d_loss: 0.1923, g_loss: 1.8936, D(x): 0.88, D(G(z)): 0.17
Epoch [63/200], Step [600/600], d_loss: 0.2857, g_loss: 2.4976, D(x): 0.82, D(G(z)): 0.24
Epoch [64/200], Step [50/600], d_loss: 0.2130, g_loss: 2.1264, D(x): 0.80, D(G(z)): 0.15
Epoch [64/200], Step [100/600], d_loss: 0.2527, g_loss: 2.2278, D(x): 0.87, D(G(z)): 0.25
Epoch [64/200], Step [150/600], d_loss: 0.2407, g_loss: 1.8175, D(x): 0.78, D(G(z)): 0.14
Epoch [64/200], Step [200/600], d_loss: 0.1877, g_loss: 2.9974, D(x): 0.86, D(G(z)): 0.16
Epoch [64/200], Step [250/600], d_loss: 0.4162, g_loss: 2.9728, D(x): 0.64, D(G(z)): 0.20
Epoch [64/200], Step [300/600], d_loss: 0.3511, g_loss: 1.8238, D(x): 0.71, D(G(z)): 0.20
Epoch [64/200], Step [350/600], d_loss: 0.2151, g_loss: 2.2640, D(x): 0.86, D(G(z)): 0.18
Epoch [64/2

Epoch [71/200], Step [250/600], d_loss: 0.4154, g_loss: 1.8453, D(x): 0.87, D(G(z)): 0.40
Epoch [71/200], Step [300/600], d_loss: 0.1957, g_loss: 2.4662, D(x): 0.88, D(G(z)): 0.19
Epoch [71/200], Step [350/600], d_loss: 0.2680, g_loss: 2.3476, D(x): 0.76, D(G(z)): 0.14
Epoch [71/200], Step [400/600], d_loss: 0.1063, g_loss: 2.3623, D(x): 0.92, D(G(z)): 0.10
Epoch [71/200], Step [450/600], d_loss: 0.3034, g_loss: 2.4694, D(x): 0.71, D(G(z)): 0.16
Epoch [71/200], Step [500/600], d_loss: 0.2614, g_loss: 2.3699, D(x): 0.80, D(G(z)): 0.18
Epoch [71/200], Step [550/600], d_loss: 0.4064, g_loss: 1.8180, D(x): 0.69, D(G(z)): 0.23
Epoch [71/200], Step [600/600], d_loss: 0.4471, g_loss: 2.4275, D(x): 0.62, D(G(z)): 0.21
Epoch [72/200], Step [50/600], d_loss: 0.2764, g_loss: 2.1627, D(x): 0.76, D(G(z)): 0.17
Epoch [72/200], Step [100/600], d_loss: 0.3152, g_loss: 1.2454, D(x): 0.84, D(G(z)): 0.30
Epoch [72/200], Step [150/600], d_loss: 0.2175, g_loss: 2.1448, D(x): 0.90, D(G(z)): 0.22
Epoch [72/2

Epoch [79/200], Step [50/600], d_loss: 0.2614, g_loss: 2.4370, D(x): 0.76, D(G(z)): 0.17
Epoch [79/200], Step [100/600], d_loss: 0.2718, g_loss: 2.0625, D(x): 0.78, D(G(z)): 0.19
Epoch [79/200], Step [150/600], d_loss: 0.2930, g_loss: 2.8409, D(x): 0.71, D(G(z)): 0.12
Epoch [79/200], Step [200/600], d_loss: 0.2245, g_loss: 1.7334, D(x): 0.87, D(G(z)): 0.21
Epoch [79/200], Step [250/600], d_loss: 0.3044, g_loss: 2.2151, D(x): 0.82, D(G(z)): 0.25
Epoch [79/200], Step [300/600], d_loss: 0.2955, g_loss: 2.5227, D(x): 0.70, D(G(z)): 0.11
Epoch [79/200], Step [350/600], d_loss: 0.1578, g_loss: 1.7980, D(x): 0.86, D(G(z)): 0.13
Epoch [79/200], Step [400/600], d_loss: 0.3169, g_loss: 1.9951, D(x): 0.80, D(G(z)): 0.26
Epoch [79/200], Step [450/600], d_loss: 0.2024, g_loss: 2.2514, D(x): 0.90, D(G(z)): 0.22
Epoch [79/200], Step [500/600], d_loss: 0.2419, g_loss: 1.7416, D(x): 0.80, D(G(z)): 0.18
Epoch [79/200], Step [550/600], d_loss: 0.2958, g_loss: 1.8567, D(x): 0.73, D(G(z)): 0.16
Epoch [79/2

KeyboardInterrupt: 

In [None]:
G(x)

In [24]:
z = torch.randn(batch_size, latent_size,1,1).to(device)
fake_images = G(z)

In [28]:

# Save the model checkpoints 
torch.save(G.state_dict(), 'Gcnn.ckpt')
torch.save(D.state_dict(), 'Dcnn.ckpt')