<a href="https://colab.research.google.com/github/qinliuliuqin/GenerativeModels/blob/main/vanilla_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import os
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F


# Build the MNIST datasets. Both splits are stored under ./data (downloaded
# if missing) and every image goes through transforms.ToTensor().
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
mnist_testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Mini-batch size shared by both loaders and by the fixed label tensors.
mb_size = 64

# Training batches are reshuffled; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(
    mnist_trainset,
    batch_size=mb_size,
    shuffle=True,
    num_workers=2,
)

test_loader = torch.utils.data.DataLoader(
    mnist_testset,
    batch_size=mb_size,
    shuffle=False,
    num_workers=2,
)

# Model / run hyperparameters.
Z_dim = 100   # length of the noise vector fed to the generator
X_dim = 28    # X_dim * Y_dim is the flattened image size used by the layers
Y_dim = 28
h_dim = 128   # width of the single hidden layer in both networks
c = 0         # running index used to name the saved sample images
lr = 1e-3     # learning rate (1e-3, the value the Adam solvers use)

def xavier_init(size):
    """Return a randomly initialised weight tensor of shape `size`.

    Entries are drawn from a standard normal distribution and scaled by
    ``1 / sqrt(size[0] / 2)``, i.e. ``sqrt(2 / size[0])``, where ``size[0]``
    is treated as the layer's input dimension.

    Args:
        size: Sequence of ints giving the tensor shape; the first entry is
            the input dimension.

    Returns:
        A ``torch.Tensor`` of shape ``size``.
    """
    fan_in = size[0]
    scale = 1. / np.sqrt(fan_in / 2.)
    noise = torch.randn(*size)
    return noise * scale


""" ==================== GENERATOR ======================== """

# Generator parameters: a two-layer fully connected network.
# Layer 1 maps noise (Z_dim) to the hidden layer (h_dim).
Wzh = xavier_init([Z_dim, h_dim])
bzh = torch.zeros(h_dim)

# Layer 2 maps the hidden layer (h_dim) to a flattened image (X_dim * Y_dim).
Whx = xavier_init([h_dim, X_dim * Y_dim])
bhx = torch.zeros(X_dim * Y_dim)


def G(z):
    """Generate a batch of flattened images from noise.

    Args:
        z: Noise tensor; its first dimension is the batch size and it is
            matrix-multiplied with ``Wzh`` (shape ``[Z_dim, h_dim]``).

    Returns:
        Tensor with one row of ``X_dim * Y_dim`` sigmoid activations per
        input row.
    """
    batch = z.size(0)
    hidden_pre = z @ Wzh + bzh.repeat(batch, 1)
    hidden = F.relu(hidden_pre)

    out_pre = hidden @ Whx + bhx.repeat(hidden.size(0), 1)
    return torch.sigmoid(out_pre)


""" ==================== DISCRIMINATOR ======================== """

# Discriminator parameters: a two-layer fully connected network.
# Layer 1 maps a flattened image (X_dim * Y_dim) to the hidden layer (h_dim).
Wxh = xavier_init([X_dim * Y_dim, h_dim])
bxh = torch.zeros(h_dim)

# Layer 2 maps the hidden layer (h_dim) to a single score.
Why = xavier_init([h_dim, 1])
bhy = torch.zeros(1)


def D(X):
    """Score a batch of images with the discriminator.

    Args:
        X: Image batch; every dimension after the first is flattened before
            the first layer, so each sample must hold ``X_dim * Y_dim``
            values.

    Returns:
        Tensor with one sigmoid output per sample (one column).
    """
    flat = X.flatten(start_dim=1)
    hidden_pre = flat @ Wxh + bxh.repeat(flat.size(0), 1)
    hidden = F.relu(hidden_pre)

    score_pre = hidden @ Why + bhy.repeat(hidden.size(0), 1)
    return torch.sigmoid(score_pre)

# Trainable tensors grouped per network, plus the combined list that
# reset_grad() walks over.
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = [*G_params, *D_params]

# The tensors were created as plain tensors; switch on gradient tracking.
for param in params:
    param.requires_grad = True


""" ===================== TRAINING ======================== """


def reset_grad():
    """Replace every existing gradient in `params` with a fresh zero tensor.

    Parameters whose ``.grad`` is still ``None`` are left untouched.
    """
    for tensor in params:
        grad = tensor.grad
        if grad is None:
            continue
        # New zero-filled tensor with the gradient's shape, dtype and device.
        tensor.grad = grad.data.new_zeros(grad.size())


# One Adam optimizer per network so each step only updates that network's
# own parameters. Both take the shared `lr` hyperparameter instead of
# repeating the literal value.
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)

# Fixed binary-cross-entropy targets, one row per sample of a full
# mini-batch: ones are the "real" target, zeros the "fake" target.
ones_label = torch.ones(mb_size, 1)
zeros_label = torch.zeros(mb_size, 1)


def _save_sample_grid(samples, path):
    """Draw the given flattened 28x28 samples on a 4x4 grid and save to `path`."""
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    plt.savefig(path, bbox_inches='tight')
    plt.close(fig)


# Create the output directory once instead of checking on every plot.
os.makedirs('out/', exist_ok=True)

# Keep one iterator alive across steps. Calling iter(train_loader) on every
# step would restart the loader (and its worker processes) each time.
data_iter = iter(train_loader)

for it in range(100000):
    # Sample data
    z = torch.randn(mb_size, Z_dim)
    try:
        X, _ = next(data_iter)
    except StopIteration:
        # Epoch finished: start a new pass over the training set.
        data_iter = iter(train_loader)
        X, _ = next(data_iter)
    X = X.squeeze(dim=1)
    # The last batch of an epoch can be smaller than mb_size, so the
    # "real" targets are sliced to the actual batch size.
    n_real = X.size(0)

    # Discriminator forward-loss-backward-update
    # Freeze G so the D step only accumulates gradients for D's parameters.
    for param in G_params:
        param.requires_grad = False
    for param in D_params:
        param.requires_grad = True

    G_sample = G(z)
    D_real = D(X)
    D_fake = D(G_sample)

    # D is trained to output 1 on real images and 0 on generated ones.
    D_loss_real = F.binary_cross_entropy(D_real, ones_label[:n_real])
    D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
    D_loss = D_loss_real + D_loss_fake

    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Generator forward-loss-backward-update
    # Freeze D so the G step only accumulates gradients for G's parameters.
    for param in G_params:
        param.requires_grad = True
    for param in D_params:
        param.requires_grad = False

    z = torch.randn(mb_size, Z_dim)
    G_sample = G(z)
    D_fake = D(G_sample)

    # G is trained so that D outputs 1 on generated images.
    G_loss = F.binary_cross_entropy(D_fake, ones_label)

    G_loss.backward()
    G_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Print and plot every now and then
    if it % 100 == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss.detach().numpy(),
                                                       G_loss.detach().numpy()))

        # Preview only: no autograd graph is needed for these samples.
        with torch.no_grad():
            samples = G(z).numpy()[:16]

        _save_sample_grid(samples, 'out/{}.png'.format(str(c).zfill(3)))
        c += 1

Iter-0; D_loss: 1.7111154794692993; G_loss: 2.3246514797210693
Iter-100; D_loss: 0.24809274077415466; G_loss: 3.9886956214904785
Iter-200; D_loss: 0.05050894245505333; G_loss: 5.8243608474731445
Iter-300; D_loss: 0.07821844518184662; G_loss: 5.957636833190918
Iter-400; D_loss: 0.02358751744031906; G_loss: 5.9996185302734375
Iter-500; D_loss: 0.011351916939020157; G_loss: 6.457283973693848
Iter-600; D_loss: 0.004365156404674053; G_loss: 7.674680709838867
Iter-700; D_loss: 0.020695002749562263; G_loss: 9.446849822998047
Iter-800; D_loss: 0.008205423131585121; G_loss: 7.852683067321777
Iter-900; D_loss: 0.0062996563501656055; G_loss: 7.8452630043029785
Iter-1000; D_loss: 0.0038198246620595455; G_loss: 8.367925643920898
Iter-1100; D_loss: 0.004194531124085188; G_loss: 8.40261173248291
Iter-1200; D_loss: 0.0043449099175632; G_loss: 7.920822620391846
Iter-1300; D_loss: 0.0031598899513483047; G_loss: 7.968366622924805
Iter-1400; D_loss: 0.006255622021853924; G_loss: 8.467569351196289
Iter-150