In [4]:
# MNIST image generation using DCGAN
import torch
from torch.autograd import Variable
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import imageio

# Hyper-parameters and paths
image_size = 64                      # images are resized to this size before training
G_input_dim = 100                    # channels of the generator's noise input
G_output_dim = 1                     # channels of the generated image
D_input_dim = 1                      # channels of the discriminator's input image
D_output_dim = 1                     # channels of the discriminator's output
num_filters = [1024, 512, 256, 128]  # generator layer widths (reversed for the discriminator)

learning_rate = 2e-4                 # base Adam learning rate
betas = (0.5, 0.999)                 # Adam betas
batch_size = 128
num_epochs = 40
data_dir = 'mnist_data'              # root directory of the MNIST data
save_dir = 'MNIST_DCGAN_results/'    # where the plots and GIFs are written

# MNIST training data: resize to image_size, convert to a tensor, then
# normalize with mean 0.5 / std 0.5 (denorm() below undoes this for plotting).
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, ), std=(0.5,)),
])

# download=False: the dataset is expected to already exist under data_dir.
mnist_data = dsets.MNIST(root=data_dir, train=True, transform=transform, download=False)

# Shuffled mini-batches of batch_size images.
data_loader = torch.utils.data.DataLoader(dataset=mnist_data, batch_size=batch_size, shuffle=True)


# De-normalization
def denorm(x):
    """Shift and scale ``x`` via ``(x + 1) / 2`` and clamp the result to [0, 1]."""
    rescaled = (x + 1) / 2
    return rescaled.clamp(0, 1)


# Generator model
class Generator(torch.nn.Module):
    """DCGAN generator built from transposed convolutions.

    The hidden stack has one ``ConvTranspose2d -> BatchNorm2d -> ReLU`` group
    per entry of ``num_filters``. The first group uses kernel 4 / stride 1 /
    padding 0; every later group and the output layer use kernel 4 /
    stride 2 / padding 1. The output layer is followed by ``Tanh``.

    Args:
        input_dim: number of channels of the input noise tensor.
        num_filters: non-empty sequence with the output channels of each
            hidden layer, in order.
        output_dim: number of channels of the generated image.

    Raises:
        ValueError: if ``num_filters`` is empty.
    """

    def __init__(self, input_dim, num_filters, output_dim):
        super().__init__()

        # The output layer is sized from the last hidden width, so at least
        # one hidden layer is required.
        if len(num_filters) == 0:
            raise ValueError('num_filters must contain at least one entry')

        # Hidden layers
        self.hidden_layer = torch.nn.Sequential()
        in_channels = input_dim
        for idx, out_channels in enumerate(num_filters, start=1):
            # Deconvolutional layer (the first one has its own stride/padding)
            if idx == 1:
                deconv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=1, padding=0)
            else:
                deconv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
            self._init_conv(deconv)

            self.hidden_layer.add_module('deconv' + str(idx), deconv)
            self.hidden_layer.add_module('bn' + str(idx), torch.nn.BatchNorm2d(out_channels))
            self.hidden_layer.add_module('act' + str(idx), torch.nn.ReLU())

            in_channels = out_channels

        # Output layer
        self.output_layer = torch.nn.Sequential()
        out = torch.nn.ConvTranspose2d(in_channels, output_dim, kernel_size=4, stride=2, padding=1)
        self._init_conv(out)
        self.output_layer.add_module('out', out)
        self.output_layer.add_module('act', torch.nn.Tanh())

    @staticmethod
    def _init_conv(layer):
        """Initialize weights from N(0, 0.02) and biases to zero."""
        torch.nn.init.normal_(layer.weight, mean=0.0, std=0.02)
        torch.nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        """Run ``x`` through the hidden stack and the output layer."""
        return self.output_layer(self.hidden_layer(x))


# Discriminator model
class Discriminator(torch.nn.Module):
    """DCGAN discriminator built from strided convolutions.

    The hidden stack has one ``Conv2d -> [BatchNorm2d] -> LeakyReLU(0.2)``
    group per entry of ``num_filters`` (kernel 4 / stride 2 / padding 1);
    batch normalization is skipped for the first group. The output layer is
    a kernel 4 / stride 1 / padding 0 convolution followed by ``Sigmoid``.

    Args:
        input_dim: number of channels of the input image.
        num_filters: non-empty sequence with the output channels of each
            hidden layer, in order.
        output_dim: number of channels of the decision output.

    Raises:
        ValueError: if ``num_filters`` is empty.
    """

    def __init__(self, input_dim, num_filters, output_dim):
        super().__init__()

        # The output layer is sized from the last hidden width, so at least
        # one hidden layer is required.
        if len(num_filters) == 0:
            raise ValueError('num_filters must contain at least one entry')

        # Hidden layers
        self.hidden_layer = torch.nn.Sequential()
        in_channels = input_dim
        for idx, out_channels in enumerate(num_filters, start=1):
            # Convolutional layer
            conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
            self._init_conv(conv)
            self.hidden_layer.add_module('conv' + str(idx), conv)

            # Batch normalization on every hidden layer except the first
            if idx != 1:
                self.hidden_layer.add_module('bn' + str(idx), torch.nn.BatchNorm2d(out_channels))

            # Activation
            self.hidden_layer.add_module('act' + str(idx), torch.nn.LeakyReLU(0.2))

            in_channels = out_channels

        # Output layer
        self.output_layer = torch.nn.Sequential()
        out = torch.nn.Conv2d(in_channels, output_dim, kernel_size=4, stride=1, padding=0)
        self._init_conv(out)
        self.output_layer.add_module('out', out)
        self.output_layer.add_module('act', torch.nn.Sigmoid())

    @staticmethod
    def _init_conv(layer):
        """Initialize weights from N(0, 0.02) and biases to zero."""
        torch.nn.init.normal_(layer.weight, mean=0.0, std=0.02)
        torch.nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        """Run ``x`` through the hidden stack and the output layer."""
        return self.output_layer(self.hidden_layer(x))


# Plot losses
def plot_loss(d_losses, g_losses, num_epoch, save=False, save_dir='MNIST_DCGAN_results/', show=False,
              total_epochs=None):
    """Plot the discriminator and generator loss curves.

    Args:
        d_losses: sequence of discriminator loss values, one per epoch so far.
        g_losses: sequence of generator loss values, one per epoch so far.
        num_epoch: zero-based index of the current epoch; used (plus one) in
            the axis label and the file name.
        save: when true, write the figure as a PNG into ``save_dir``.
        save_dir: output directory; created if it does not exist.
        show: when true, display the figure; otherwise close it.
        total_epochs: upper limit of the x axis. Defaults to the
            module-level ``num_epochs``.
    """
    if total_epochs is None:
        # Fall back to the script-wide training length so the axis stays
        # fixed while the curves grow epoch by epoch.
        total_epochs = num_epochs

    fig, ax = plt.subplots()
    ax.set_xlim(0, total_epochs)
    ax.set_ylim(0, max(np.max(g_losses), np.max(d_losses)) * 1.1)
    ax.set_xlabel('Epoch {0}'.format(num_epoch + 1))
    ax.set_ylabel('Loss values')
    ax.plot(d_losses, label='Discriminator')
    ax.plot(g_losses, label='Generator')
    ax.legend()

    # save figure
    if save:
        # makedirs(exist_ok=True) handles nested paths and avoids the
        # check-then-create race of os.path.exists + os.mkdir.
        os.makedirs(save_dir, exist_ok=True)
        save_fn = os.path.join(save_dir, 'MNIST_DCGAN_losses_epoch_{:d}'.format(num_epoch + 1) + '.png')
        fig.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close(fig)


def plot_result(generator, noise, num_epoch, save=False, save_dir='MNIST_DCGAN_results/', show=False, fig_size=(5, 5)):
    """Generate images from ``noise`` and draw them on a square grid.

    Args:
        generator: the generator module; it is run in eval mode and its
            previous train/eval mode is restored afterwards.
        noise: batch of noise inputs; it is moved to the generator's device.
            The grid has ``int(sqrt(batch))`` rows and columns, so a batch
            size that is not a perfect square leaves the extra samples
            undrawn.
        num_epoch: zero-based index of the current epoch; used (plus one) in
            the title and the file name.
        save: when true, write the figure as a PNG into ``save_dir``.
        save_dir: output directory; created if it does not exist.
        show: when true, display the figure; otherwise close it.
        fig_size: figure size passed to ``plt.subplots``.
    """
    # Run on whatever device the generator lives on instead of assuming CUDA.
    device = next(generator.parameters()).device

    was_training = generator.training
    generator.eval()
    # No gradients are needed to render samples.
    with torch.no_grad():
        gen_image = denorm(generator(noise.to(device)))
    # Restore the caller's mode rather than unconditionally enabling training.
    generator.train(was_training)

    grid = int(np.sqrt(noise.size()[0]))
    # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(grid, grid, figsize=fig_size, squeeze=False)
    for ax, img in zip(axes.flatten(), gen_image):
        ax.axis('off')
        ax.set_adjustable('box')
        # Each sample is expected to be single-channel; drop the channel axis
        # using the image's own height/width.
        height, width = img.shape[-2], img.shape[-1]
        ax.imshow(img.cpu().reshape(height, width).numpy(), cmap='gray', aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)
    title = 'Epoch {0}'.format(num_epoch+1)
    fig.text(0.5, 0.04, title, ha='center')

    # save figure
    if save:
        os.makedirs(save_dir, exist_ok=True)
        save_fn = os.path.join(save_dir, 'MNIST_DCGAN_epoch_{:d}'.format(num_epoch+1) + '.png')
        fig.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close(fig)


# Models (this script trains on the GPU)
device = torch.device('cuda')
G = Generator(G_input_dim, num_filters, G_output_dim).to(device)
D = Discriminator(D_input_dim, num_filters[::-1], D_output_dim).to(device)

# Loss function
criterion = torch.nn.BCELoss()

# Optimizers (the discriminator uses a 10x smaller learning rate)
G_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate, betas=betas)
D_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate*0.1, betas=betas)

# Per-epoch average losses, stored as plain floats for plotting
D_avg_losses = []
G_avg_losses = []

# Fixed noise for test: the same inputs are rendered after every epoch
num_test_samples = 5*5
fixed_noise = torch.randn(num_test_samples, G_input_dim, 1, 1)

for epoch in range(num_epochs):
    D_losses = []
    G_losses = []

    # minibatch training
    for i, (images, _) in enumerate(data_loader):

        # image data
        mini_batch = images.size(0)
        x_ = images.to(device)

        # labels: 1 for real, 0 for fake
        y_real_ = torch.ones(mini_batch, device=device)
        y_fake_ = torch.zeros(mini_batch, device=device)

        # ---- Train discriminator ----
        # reshape(-1) keeps a 1-D decision vector even when the batch holds
        # a single image, so it always matches the label shape.
        D_real_decision = D(x_).reshape(-1)
        D_real_loss = criterion(D_real_decision, y_real_)

        z_ = torch.randn(mini_batch, G_input_dim, 1, 1, device=device)
        # detach(): the discriminator step must not backpropagate into G.
        gen_image = G(z_).detach()
        D_fake_decision = D(gen_image).reshape(-1)
        D_fake_loss = criterion(D_fake_decision, y_fake_)

        D_loss = D_real_loss + D_fake_loss
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # ---- Train generator ----
        z_ = torch.randn(mini_batch, G_input_dim, 1, 1, device=device)
        gen_image = G(z_)

        # The generator is rewarded when D labels its samples as real.
        D_fake_decision = D(gen_image).reshape(-1)
        G_loss = criterion(D_fake_decision, y_real_)

        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # loss values (read once per step)
        d_loss_value = D_loss.item()
        g_loss_value = G_loss.item()
        D_losses.append(d_loss_value)
        G_losses.append(g_loss_value)

        print('Epoch [%d/%d], Step [%d/%d], D_loss: %.4f, G_loss: %.4f'
              % (epoch+1, num_epochs, i+1, len(data_loader), d_loss_value, g_loss_value))

    # avg loss values for plot
    D_avg_losses.append(float(np.mean(D_losses)))
    G_avg_losses.append(float(np.mean(G_losses)))

    plot_loss(D_avg_losses, G_avg_losses, epoch, save=True)

    # Show result for fixed noise
    plot_result(G, fixed_noise, epoch, save=True, fig_size=(5, 5))

# Make gif from the per-epoch PNGs written by plot_loss / plot_result
loss_plots = []
gen_image_plots = []
for epoch in range(num_epochs):
    # plot for generating gif
    save_fn1 = save_dir + 'MNIST_DCGAN_losses_epoch_{:d}'.format(epoch + 1) + '.png'
    loss_plots.append(imageio.imread(save_fn1))

    save_fn2 = save_dir + 'MNIST_DCGAN_epoch_{:d}'.format(epoch + 1) + '.png'
    gen_image_plots.append(imageio.imread(save_fn2))

imageio.mimsave(save_dir + 'MNIST_DCGAN_losses_epochs_{:d}'.format(num_epochs) + '.gif', loss_plots, fps=5)
imageio.mimsave(save_dir + 'MNIST_DCGAN_epochs_{:d}'.format(num_epochs) + '.gif', gen_image_plots, fps=5)

Epoch [1/40], Step [1/469], D_loss: 1.6848, G_loss: 1.6014
Epoch [1/40], Step [2/469], D_loss: 8.7140, G_loss: 0.0072
Epoch [1/40], Step [3/469], D_loss: 9.1787, G_loss: 0.0048
Epoch [1/40], Step [4/469], D_loss: 7.3978, G_loss: 0.0333
Epoch [1/40], Step [5/469], D_loss: 6.1375, G_loss: 0.1754
Epoch [1/40], Step [6/469], D_loss: 7.0806, G_loss: 0.0867
Epoch [1/40], Step [7/469], D_loss: 6.9031, G_loss: 0.0931
Epoch [1/40], Step [8/469], D_loss: 6.6143, G_loss: 0.1198
Epoch [1/40], Step [9/469], D_loss: 6.9993, G_loss: 0.1516
Epoch [1/40], Step [10/469], D_loss: 7.6367, G_loss: 0.0970
Epoch [1/40], Step [11/469], D_loss: 7.3100, G_loss: 0.1710
Epoch [1/40], Step [12/469], D_loss: 7.8383, G_loss: 0.1200
Epoch [1/40], Step [13/469], D_loss: 7.6054, G_loss: 0.2156
Epoch [1/40], Step [14/469], D_loss: 8.1108, G_loss: 0.1689
Epoch [1/40], Step [15/469], D_loss: 8.1218, G_loss: 0.2812
Epoch [1/40], Step [16/469], D_loss: 8.1167, G_loss: 0.2617
Epoch [1/40], Step [17/469], D_loss: 8.1437, G_lo

Epoch [1/40], Step [138/469], D_loss: 1.2605, G_loss: 4.4080
Epoch [1/40], Step [139/469], D_loss: 1.3498, G_loss: 4.0295
Epoch [1/40], Step [140/469], D_loss: 1.5634, G_loss: 4.4085
Epoch [1/40], Step [141/469], D_loss: 1.6456, G_loss: 4.7846
Epoch [1/40], Step [142/469], D_loss: 1.1488, G_loss: 4.6242
Epoch [1/40], Step [143/469], D_loss: 0.8718, G_loss: 4.1446
Epoch [1/40], Step [144/469], D_loss: 1.8600, G_loss: 4.0397
Epoch [1/40], Step [145/469], D_loss: 2.9632, G_loss: 4.3334
Epoch [1/40], Step [146/469], D_loss: 2.2362, G_loss: 5.6409
Epoch [1/40], Step [147/469], D_loss: 1.3971, G_loss: 5.5184
Epoch [1/40], Step [148/469], D_loss: 1.3367, G_loss: 4.7973
Epoch [1/40], Step [149/469], D_loss: 1.6336, G_loss: 4.3703
Epoch [1/40], Step [150/469], D_loss: 1.7321, G_loss: 4.8616
Epoch [1/40], Step [151/469], D_loss: 2.5118, G_loss: 5.1771
Epoch [1/40], Step [152/469], D_loss: 1.7493, G_loss: 5.6698
Epoch [1/40], Step [153/469], D_loss: 0.9481, G_loss: 5.3966
Epoch [1/40], Step [154/

Epoch [1/40], Step [273/469], D_loss: 0.5801, G_loss: 4.4014
Epoch [1/40], Step [274/469], D_loss: 0.5288, G_loss: 4.2096
Epoch [1/40], Step [275/469], D_loss: 0.4596, G_loss: 4.0351
Epoch [1/40], Step [276/469], D_loss: 0.4877, G_loss: 4.0815
Epoch [1/40], Step [277/469], D_loss: 0.3432, G_loss: 4.0138
Epoch [1/40], Step [278/469], D_loss: 0.6400, G_loss: 3.9511
Epoch [1/40], Step [279/469], D_loss: 0.6011, G_loss: 4.0969
Epoch [1/40], Step [280/469], D_loss: 0.4717, G_loss: 4.2265
Epoch [1/40], Step [281/469], D_loss: 0.5569, G_loss: 4.0501
Epoch [1/40], Step [282/469], D_loss: 0.4205, G_loss: 4.0000
Epoch [1/40], Step [283/469], D_loss: 0.4820, G_loss: 4.1121
Epoch [1/40], Step [284/469], D_loss: 0.4812, G_loss: 4.1952
Epoch [1/40], Step [285/469], D_loss: 0.4843, G_loss: 3.9468
Epoch [1/40], Step [286/469], D_loss: 0.5904, G_loss: 4.0976
Epoch [1/40], Step [287/469], D_loss: 0.6038, G_loss: 4.4915
Epoch [1/40], Step [288/469], D_loss: 0.4320, G_loss: 4.4453
Epoch [1/40], Step [289/

Epoch [1/40], Step [408/469], D_loss: 1.5585, G_loss: 3.7590
Epoch [1/40], Step [409/469], D_loss: 1.2965, G_loss: 4.2483
Epoch [1/40], Step [410/469], D_loss: 1.6094, G_loss: 3.6255
Epoch [1/40], Step [411/469], D_loss: 1.4084, G_loss: 4.1550
Epoch [1/40], Step [412/469], D_loss: 1.4256, G_loss: 4.0311
Epoch [1/40], Step [413/469], D_loss: 1.2014, G_loss: 4.4309
Epoch [1/40], Step [414/469], D_loss: 1.2458, G_loss: 3.3346
Epoch [1/40], Step [415/469], D_loss: 1.4048, G_loss: 4.0348
Epoch [1/40], Step [416/469], D_loss: 1.1131, G_loss: 3.4229
Epoch [1/40], Step [417/469], D_loss: 1.3749, G_loss: 5.2655
Epoch [1/40], Step [418/469], D_loss: 1.1387, G_loss: 3.8872
Epoch [1/40], Step [419/469], D_loss: 1.1037, G_loss: 3.1650
Epoch [1/40], Step [420/469], D_loss: 1.3726, G_loss: 4.5693
Epoch [1/40], Step [421/469], D_loss: 1.4042, G_loss: 3.5933
Epoch [1/40], Step [422/469], D_loss: 1.4185, G_loss: 3.6190
Epoch [1/40], Step [423/469], D_loss: 1.0793, G_loss: 3.3859
Epoch [1/40], Step [424/

KeyboardInterrupt: 