In [7]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [8]:
os.makedirs('images', exist_ok=True)

class Opt(object):
    def __init__(self):
        self.n_epochs = 200
        self.batch_size = 64
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_cpu = 8
        self.latent_dim = 100
        self.img_size = 28
        self.channels = 1
        self.sample_interval = 400
        
opt = Opt()
# parser = argparse.ArgumentParser()
# parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
# parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
# parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
# parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
# parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
# parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
# parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
# parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension')
# parser.add_argument('--channels', type=int, default=1, help='number of image channels')
# parser.add_argument('--sample_interval', type=int, default=400, help='interval betwen image samples')
# opt = parser.parse_args()
# print(opt)

In [9]:
img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False

In [11]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

In [12]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

In [13]:
# Loss function
adversarial_loss = torch.nn.BCELoss()

In [14]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

In [15]:
# Configure data loader
os.makedirs('../../data/mnist', exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('../../data/mnist', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ])),
    batch_size=opt.batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [16]:
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [17]:
# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader),
                                                            d_loss.item(), g_loss.item()))

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)

[Epoch 0/200] [Batch 0/938] [D loss: 0.668664] [G loss: 0.711462]
[Epoch 0/200] [Batch 1/938] [D loss: 0.584432] [G loss: 0.711248]
[Epoch 0/200] [Batch 2/938] [D loss: 0.514955] [G loss: 0.711100]
[Epoch 0/200] [Batch 3/938] [D loss: 0.459773] [G loss: 0.710947]
[Epoch 0/200] [Batch 4/938] [D loss: 0.418679] [G loss: 0.710684]
[Epoch 0/200] [Batch 5/938] [D loss: 0.391517] [G loss: 0.710402]
[Epoch 0/200] [Batch 6/938] [D loss: 0.374058] [G loss: 0.710018]
[Epoch 0/200] [Batch 7/938] [D loss: 0.364534] [G loss: 0.709588]
[Epoch 0/200] [Batch 8/938] [D loss: 0.357003] [G loss: 0.708708]
[Epoch 0/200] [Batch 9/938] [D loss: 0.353800] [G loss: 0.707378]
[Epoch 0/200] [Batch 10/938] [D loss: 0.351912] [G loss: 0.705806]
[Epoch 0/200] [Batch 11/938] [D loss: 0.350105] [G loss: 0.703808]
[Epoch 0/200] [Batch 12/938] [D loss: 0.348722] [G loss: 0.701749]
[Epoch 0/200] [Batch 13/938] [D loss: 0.349930] [G loss: 0.698055]
[Epoch 0/200] [Batch 14/938] [D loss: 0.352046] [G loss: 0.695377]
[Epoc

[Epoch 0/200] [Batch 124/938] [D loss: 0.396913] [G loss: 0.813594]
[Epoch 0/200] [Batch 125/938] [D loss: 0.322828] [G loss: 1.120462]
[Epoch 0/200] [Batch 126/938] [D loss: 0.328653] [G loss: 1.095278]
[Epoch 0/200] [Batch 127/938] [D loss: 0.339832] [G loss: 0.922403]
[Epoch 0/200] [Batch 128/938] [D loss: 0.338742] [G loss: 1.038011]
[Epoch 0/200] [Batch 129/938] [D loss: 0.371686] [G loss: 0.994977]
[Epoch 0/200] [Batch 130/938] [D loss: 0.385247] [G loss: 0.858124]
[Epoch 0/200] [Batch 131/938] [D loss: 0.402340] [G loss: 1.086600]
[Epoch 0/200] [Batch 132/938] [D loss: 0.411820] [G loss: 0.790037]
[Epoch 0/200] [Batch 133/938] [D loss: 0.363549] [G loss: 1.013990]
[Epoch 0/200] [Batch 134/938] [D loss: 0.371782] [G loss: 1.037991]
[Epoch 0/200] [Batch 135/938] [D loss: 0.369179] [G loss: 0.924047]
[Epoch 0/200] [Batch 136/938] [D loss: 0.367558] [G loss: 1.137531]
[Epoch 0/200] [Batch 137/938] [D loss: 0.363116] [G loss: 0.877165]
[Epoch 0/200] [Batch 138/938] [D loss: 0.327101]

[Epoch 0/200] [Batch 247/938] [D loss: 0.636257] [G loss: 1.019192]
[Epoch 0/200] [Batch 248/938] [D loss: 0.680685] [G loss: 0.433722]
[Epoch 0/200] [Batch 249/938] [D loss: 0.621521] [G loss: 0.922275]
[Epoch 0/200] [Batch 250/938] [D loss: 0.610695] [G loss: 0.550358]
[Epoch 0/200] [Batch 251/938] [D loss: 0.538945] [G loss: 0.900711]
[Epoch 0/200] [Batch 252/938] [D loss: 0.513345] [G loss: 0.766006]
[Epoch 0/200] [Batch 253/938] [D loss: 0.456191] [G loss: 0.914497]
[Epoch 0/200] [Batch 254/938] [D loss: 0.442335] [G loss: 0.895324]
[Epoch 0/200] [Batch 255/938] [D loss: 0.421176] [G loss: 0.866758]
[Epoch 0/200] [Batch 256/938] [D loss: 0.404367] [G loss: 1.144556]
[Epoch 0/200] [Batch 257/938] [D loss: 0.420488] [G loss: 0.756900]
[Epoch 0/200] [Batch 258/938] [D loss: 0.478866] [G loss: 1.373667]
[Epoch 0/200] [Batch 259/938] [D loss: 0.536227] [G loss: 0.511418]
[Epoch 0/200] [Batch 260/938] [D loss: 0.478676] [G loss: 1.440284]
[Epoch 0/200] [Batch 261/938] [D loss: 0.515656]

[Epoch 0/200] [Batch 370/938] [D loss: 0.337967] [G loss: 0.830118]
[Epoch 0/200] [Batch 371/938] [D loss: 0.407885] [G loss: 1.418488]
[Epoch 0/200] [Batch 372/938] [D loss: 0.479419] [G loss: 0.663531]
[Epoch 0/200] [Batch 373/938] [D loss: 0.379963] [G loss: 1.131680]
[Epoch 0/200] [Batch 374/938] [D loss: 0.395456] [G loss: 0.877015]
[Epoch 0/200] [Batch 375/938] [D loss: 0.404186] [G loss: 0.830393]
[Epoch 0/200] [Batch 376/938] [D loss: 0.448899] [G loss: 1.027930]
[Epoch 0/200] [Batch 377/938] [D loss: 0.484874] [G loss: 0.621787]
[Epoch 0/200] [Batch 378/938] [D loss: 0.469776] [G loss: 1.201019]
[Epoch 0/200] [Batch 379/938] [D loss: 0.509011] [G loss: 0.568404]
[Epoch 0/200] [Batch 380/938] [D loss: 0.400680] [G loss: 1.112690]
[Epoch 0/200] [Batch 381/938] [D loss: 0.419729] [G loss: 0.943965]
[Epoch 0/200] [Batch 382/938] [D loss: 0.400793] [G loss: 0.792278]
[Epoch 0/200] [Batch 383/938] [D loss: 0.323846] [G loss: 1.200377]
[Epoch 0/200] [Batch 384/938] [D loss: 0.300271]

[Epoch 0/200] [Batch 493/938] [D loss: 0.568862] [G loss: 0.460026]
[Epoch 0/200] [Batch 494/938] [D loss: 0.647703] [G loss: 2.094856]
[Epoch 0/200] [Batch 495/938] [D loss: 0.665535] [G loss: 0.378429]
[Epoch 0/200] [Batch 496/938] [D loss: 0.361404] [G loss: 1.413469]
[Epoch 0/200] [Batch 497/938] [D loss: 0.343862] [G loss: 1.455541]
[Epoch 0/200] [Batch 498/938] [D loss: 0.357597] [G loss: 0.863522]
[Epoch 0/200] [Batch 499/938] [D loss: 0.332743] [G loss: 1.540942]
[Epoch 0/200] [Batch 500/938] [D loss: 0.358631] [G loss: 0.994824]
[Epoch 0/200] [Batch 501/938] [D loss: 0.335342] [G loss: 1.245835]
[Epoch 0/200] [Batch 502/938] [D loss: 0.357750] [G loss: 1.152312]
[Epoch 0/200] [Batch 503/938] [D loss: 0.347102] [G loss: 1.062285]
[Epoch 0/200] [Batch 504/938] [D loss: 0.388236] [G loss: 1.296670]
[Epoch 0/200] [Batch 505/938] [D loss: 0.421370] [G loss: 0.753871]
[Epoch 0/200] [Batch 506/938] [D loss: 0.449286] [G loss: 1.721905]
[Epoch 0/200] [Batch 507/938] [D loss: 0.523782]

[Epoch 0/200] [Batch 616/938] [D loss: 0.227105] [G loss: 1.909815]
[Epoch 0/200] [Batch 617/938] [D loss: 0.304615] [G loss: 1.003639]
[Epoch 0/200] [Batch 618/938] [D loss: 0.267106] [G loss: 1.746211]
[Epoch 0/200] [Batch 619/938] [D loss: 0.302701] [G loss: 1.120270]
[Epoch 0/200] [Batch 620/938] [D loss: 0.315715] [G loss: 1.499144]
[Epoch 0/200] [Batch 621/938] [D loss: 0.373961] [G loss: 0.896575]
[Epoch 0/200] [Batch 622/938] [D loss: 0.435250] [G loss: 1.742104]
[Epoch 0/200] [Batch 623/938] [D loss: 0.669013] [G loss: 0.369772]
[Epoch 0/200] [Batch 624/938] [D loss: 0.671849] [G loss: 2.882887]
[Epoch 0/200] [Batch 625/938] [D loss: 0.614325] [G loss: 0.449616]
[Epoch 0/200] [Batch 626/938] [D loss: 0.243455] [G loss: 1.562909]
[Epoch 0/200] [Batch 627/938] [D loss: 0.310735] [G loss: 2.197899]
[Epoch 0/200] [Batch 628/938] [D loss: 0.381127] [G loss: 0.811134]
[Epoch 0/200] [Batch 629/938] [D loss: 0.256430] [G loss: 1.594639]
[Epoch 0/200] [Batch 630/938] [D loss: 0.332415]

[Epoch 0/200] [Batch 739/938] [D loss: 0.334191] [G loss: 1.302625]
[Epoch 0/200] [Batch 740/938] [D loss: 0.400838] [G loss: 1.209828]
[Epoch 0/200] [Batch 741/938] [D loss: 0.353809] [G loss: 1.070068]
[Epoch 0/200] [Batch 742/938] [D loss: 0.307939] [G loss: 1.591563]
[Epoch 0/200] [Batch 743/938] [D loss: 0.338101] [G loss: 0.979475]
[Epoch 0/200] [Batch 744/938] [D loss: 0.308457] [G loss: 1.926819]
[Epoch 0/200] [Batch 745/938] [D loss: 0.416013] [G loss: 0.716929]
[Epoch 0/200] [Batch 746/938] [D loss: 0.463838] [G loss: 2.741468]
[Epoch 0/200] [Batch 747/938] [D loss: 0.698749] [G loss: 0.332816]
[Epoch 0/200] [Batch 748/938] [D loss: 0.552742] [G loss: 3.298922]
[Epoch 0/200] [Batch 749/938] [D loss: 0.383815] [G loss: 0.766526]
[Epoch 0/200] [Batch 750/938] [D loss: 0.192032] [G loss: 1.820749]
[Epoch 0/200] [Batch 751/938] [D loss: 0.200122] [G loss: 2.133663]
[Epoch 0/200] [Batch 752/938] [D loss: 0.237918] [G loss: 1.344108]
[Epoch 0/200] [Batch 753/938] [D loss: 0.280531]

[Epoch 0/200] [Batch 862/938] [D loss: 0.307848] [G loss: 1.261026]
[Epoch 0/200] [Batch 863/938] [D loss: 0.326571] [G loss: 1.268378]
[Epoch 0/200] [Batch 864/938] [D loss: 0.330730] [G loss: 1.418857]
[Epoch 0/200] [Batch 865/938] [D loss: 0.334312] [G loss: 1.003247]
[Epoch 0/200] [Batch 866/938] [D loss: 0.528911] [G loss: 2.254957]
[Epoch 0/200] [Batch 867/938] [D loss: 1.002672] [G loss: 0.174481]
[Epoch 0/200] [Batch 868/938] [D loss: 0.673067] [G loss: 3.355190]
[Epoch 0/200] [Batch 869/938] [D loss: 0.437421] [G loss: 0.812459]
[Epoch 0/200] [Batch 870/938] [D loss: 0.217596] [G loss: 1.780616]
[Epoch 0/200] [Batch 871/938] [D loss: 0.288327] [G loss: 1.959931]
[Epoch 0/200] [Batch 872/938] [D loss: 0.392826] [G loss: 0.757249]
[Epoch 0/200] [Batch 873/938] [D loss: 0.349561] [G loss: 2.135210]
[Epoch 0/200] [Batch 874/938] [D loss: 0.386424] [G loss: 0.824310]
[Epoch 0/200] [Batch 875/938] [D loss: 0.375300] [G loss: 2.071775]
[Epoch 0/200] [Batch 876/938] [D loss: 0.431084]

KeyboardInterrupt: 