LSGAN (Least Squares GAN) trained on the MNIST dataset.

In [None]:
%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Fixed seed shared by Python's `random` module and PyTorch's default
# generator, so repeated runs start from the same random state.
manualSeed = 999
# manualSeed = random.randint(1, 10000) # use if you want new results
random.seed(manualSeed)
print("Random Seed: ", manualSeed)
torch.manual_seed(manualSeed)

In [None]:
# directories (all resolved relative to the current working directory)

# Root folder handed to the MNIST dataset for download / caching.
data_root = os.path.join(os.getcwd(), 'data')

# Folder that receives the image grids written during training.
# `exist_ok=True` makes creation idempotent and avoids the race between a
# separate existence check and the actual `makedirs` call.
sample_dir = os.path.join(os.getcwd(), 'gan-04-images')
os.makedirs(sample_dir, exist_ok=True)

# Folder that receives the saved network weights.
# NOTE(review): the directory name is misspelled ("checkopints"). It is kept
# unchanged here because renaming it would move where checkpoints are written;
# confirm nothing else reads this path before fixing the spelling.
checkpoint_dir = os.path.join(os.getcwd(), 'gan-04-checkopints')
os.makedirs(checkpoint_dir, exist_ok=True)

# hyperparameters
workers = 2       # number of DataLoader worker processes
ngpu = 1          # GPUs to use; 0 forces CPU, >1 enables DataParallel
batch_size = 128  # samples per training batch

# NOTE(review): `image_size` is not referenced anywhere in the visible code and
# the data pipeline does not resize the images — confirm whether it is needed.
image_size = 64
nc = 1            # number of image channels
nz = 100          # length of the generator's latent input vector
ngf = 64          # base feature-map count of the generator
ndf = 64          # base feature-map count of the discriminator

num_epochs = 100  # full passes over the training set
lr = 0.0002       # Adam learning rate (both networks)
beta1 = 0.5       # first Adam beta coefficient (both networks)

In [None]:
# Pre-processing: convert each image to a tensor, then normalise its single
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# MNIST training split, downloaded into `data_root` when not already present.
dataset = dset.MNIST(root=data_root, train=True, transform=transform, download=True)

# Shuffled mini-batch iterator over the training split.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

# Use the first CUDA device when one is available and GPUs were requested.
device = torch.device("cuda:0" if torch.cuda.is_available() and ngpu > 0 else "cpu")

# Number of batches per epoch.
print(len(dataloader))

# Preview: draw one batch and show its first 64 images as a grid.
real_batch = next(iter(dataloader))
preview_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True)

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# make_grid returns channels-first data; imshow needs channels-last.
plt.imshow(np.transpose(preview_grid.cpu(), (1, 2, 0)))

In [None]:
def weights_init(m):
    """Re-initialise convolution and batch-norm parameters in place.

    Meant to be handed to ``nn.Module.apply``. The layer kind is chosen by a
    substring match on the module's class name:

    * names containing ``Conv``: weights drawn from N(0.0, 0.02)
    * names containing ``BatchNorm``: weights drawn from N(1.0, 0.02) and
      biases set to 0

    Any other module is left untouched.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent vector to an image.

    The layer widths come from the module-level hyperparameters ``nz``
    (latent channels), ``ngf`` (base feature maps) and ``nc`` (output
    channels). For a 1x1 spatial input the stages produce 4x4, 7x7, 14x14 and
    finally 28x28 feature maps; the closing ``Tanh`` bounds outputs to [-1, 1].
    """

    @staticmethod
    def _up_block(in_ch, out_ch, kernel, stride, padding):
        """Return [bias-free transposed conv, batch norm, in-place ReLU]."""
        return [
            nn.ConvTranspose2d(in_ch, out_ch, kernel, stride, padding, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        ]

    def __init__(self, ngpu):
        """Build the network; ``ngpu`` is only stored on the instance."""
        super().__init__()
        self.ngpu = ngpu

        hidden_stages = (
            self._up_block(nz, ngf * 4, 4, 1, 0)         # ngf*4 @ 4x4
            + self._up_block(ngf * 4, ngf * 2, 3, 2, 1)  # ngf*2 @ 7x7
            + self._up_block(ngf * 2, ngf, 4, 2, 1)      # ngf   @ 14x14
        )
        output_stage = [
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),  # nc @ 28x28
            nn.Tanh(),                                         # [-1, 1]
        ]
        # A single flat Sequential keeps the sub-module indices 0..10.
        self.main = nn.Sequential(*hidden_stages, *output_stage)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)

In [None]:
# Instantiate the generator and move its parameters to the training device.
netG = Generator(ngpu).to(device)

# Data-parallel wrapping only applies on CUDA with more than one GPU requested.
if ngpu > 1 and device.type == "cuda":
    netG = nn.DataParallel(netG, device_ids=list(range(ngpu)))

# Replace the default parameter initialisation with `weights_init`, applied
# recursively to every sub-module.
netG.apply(weights_init)
print(netG)

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one score per input image.

    The layer widths come from the module-level hyperparameters ``nc`` (input
    channels) and ``ndf`` (base feature maps). For a 28x28 input the two
    strided convolutions produce 14x14 and 7x7 maps, the last convolution a
    single 4x4 map, which is flattened to 16 values and reduced to one score.

    NOTE(review): the score passes through a sigmoid, so it lies in (0, 1).
    Least-squares GAN discriminators are commonly built without a final
    sigmoid — confirm that keeping it here is intentional.
    """

    def __init__(self, ngpu):
        """Build the network; ``ngpu`` is only stored on the instance."""
        super().__init__()
        self.ngpu = ngpu

        feature_layers = [
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),       # ndf @ 14x14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),  # ndf*2 @ 7x7
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        score_layers = [
            nn.Conv2d(ndf * 2, 1, kernel_size=4, stride=1, padding=0, bias=False),    # 1 @ 4x4
            nn.Flatten(1, -1),   # 4x4 map -> 16 features
            nn.Linear(16, 1),
            nn.Sigmoid(),
        ]
        # A single flat Sequential keeps the sub-module indices 0..8.
        self.main = nn.Sequential(*feature_layers, *score_layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the scores."""
        return self.main(input)

In [None]:
# Instantiate the discriminator and move its parameters to the training device.
netD = Discriminator(ngpu).to(device)

# Data-parallel wrapping only applies on CUDA with more than one GPU requested.
if ngpu > 1 and device.type == "cuda":
    netD = nn.DataParallel(netD, device_ids=list(range(ngpu)))

# Replace the default parameter initialisation with `weights_init`, applied
# recursively to every sub-module.
netD.apply(weights_init)

print(netD)

In [None]:
# Print a layer-by-layer summary of both networks.
# The input sizes are derived from the hyperparameters `nz` / `nc` instead of
# being hard-coded, so they stay correct if those values are changed.
# The device type is forwarded because torchsummary otherwise defaults to
# "cuda", which would not match models that were placed on the CPU.
summary(netG, (nz, 1, 1), device=device.type)
summary(netD, (nc, 28, 28), device=device.type)

In [None]:
# criterion = nn.BCELoss()
# One Adam optimiser per network, both configured with the same learning rate
# and beta coefficients.
g_optimizer = optim.Adam(params=netG.parameters(), lr=lr, betas=(beta1, 0.999))
d_optimizer = optim.Adam(params=netD.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
def denorm(x):
    """Undo the [-1, 1] scaling by returning ``(x + 1) / 2``.

    Values of -1 map to 0 and values of 1 map to 1. No clamping is applied, so
    inputs outside [-1, 1] yield outputs outside [0, 1].
    """
    shifted = x + 1
    return shifted / 2

In [None]:
# Training history: one generated-image grid per epoch and one loss value per
# iteration for each network.
img_list = []
g_losses = []
d_losses = []
total_step = len(dataloader)

# Target tensors are loop-invariant, so they are allocated once, directly on
# the training device, instead of being rebuilt and copied every iteration.
# They are sliced to the actual batch length where used, because the last
# batch of an epoch may be smaller than `batch_size`.
real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)

print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader):
        images = images.to(device)

        # ------------------------------------------------------------------
        # train netD: least-squares loss pushing D(x) to 1 and D(G(z)) to 0
        # ------------------------------------------------------------------
        netD.zero_grad()

        outputs = netD(images)
        # d_loss_real = criterion(outputs, real_labels[:len(outputs)])
        d_loss_real = torch.mean((outputs - real_labels[:len(outputs)]) ** 2)
        real_score = outputs

        # Noise is still sampled on the CPU and then moved, so the seeded CPU
        # random stream is consumed exactly as before.
        z = torch.randn(batch_size, nz, 1, 1).to(device)
        # detach(): the discriminator update needs no gradients with respect
        # to the generator, so the backward pass stops at the fake images.
        fake_images = netG(z).detach()
        outputs = netD(fake_images)
        # d_loss_fake = criterion(outputs, fake_labels[:len(outputs)])
        d_loss_fake = torch.mean(outputs ** 2)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

        # ------------------------------------------------------------------
        # train netG: least-squares loss pushing D(G(z)) to 1
        # ------------------------------------------------------------------
        netG.zero_grad()

        z = torch.randn(batch_size, nz, 1, 1).to(device)
        fake_images = netG(z)
        outputs = netD(fake_images)

        # g_loss = criterion(outputs, real_labels[:len(outputs)])
        g_loss = torch.mean((outputs - real_labels[:len(outputs)]) ** 2)
        g_loss.backward()
        g_optimizer.step()

        # output training data (once per epoch, on the final batch)
        if (i+1) == total_step:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch+1, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

    # Save the last real batch once, as a visual reference.
    if epoch == 0:
        vutils.save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Samples come from the last generator step of the epoch. They are
    # detached so the stored grid cannot keep an autograd graph alive, and the
    # grid is moved to the CPU so `img_list` does not accumulate device memory
    # over the epochs.
    samples = fake_images.detach()
    # Only the file name is formatted, so braces in the directory path are safe.
    vutils.save_image(denorm(samples),
                      os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
    img_list.append(vutils.make_grid(samples, padding=2, normalize=True).cpu())

In [None]:
# Write the parameters (state_dict) of both networks into `checkpoint_dir`.
for _net, _ckpt_name in ((netG, 'G.ckpt'), (netD, 'D.ckpt')):
    torch.save(_net.state_dict(), os.path.join(checkpoint_dir, _ckpt_name))