In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.autograd import Variable
import os
import numpy as np
import math

In [24]:
opt={
    "n_epochs":200,
    "batch_size":128,
    "lr":2e-4,
    "b1":.5,
    "b2":.999,
    "n_cpu":16,
    "latent_dim":100,
    "img_size":28,
    "channels":1,
    "sample_interval":400
}

In [25]:
os.makedirs("images", exist_ok=True)

In [None]:
img_shape = (opt["channels"], opt["img_size"], opt["img_size"])

cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt["latent_dim"], 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt["img_size"]), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt["batch_size"],
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt["lr"], betas=(opt["b1"], opt["b2"]))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt["lr"], betas=(opt["b1"], opt["b2"]))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt["n_epochs"]):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt["latent_dim"]))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt["n_epochs"], i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt["sample_interval"] == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

[Epoch 0/200] [Batch 0/469] [D loss: 0.684953] [G loss: 0.705275]
[Epoch 0/200] [Batch 1/469] [D loss: 0.595207] [G loss: 0.701892]
[Epoch 0/200] [Batch 2/469] [D loss: 0.526325] [G loss: 0.698916]
[Epoch 0/200] [Batch 3/469] [D loss: 0.473177] [G loss: 0.695554]
[Epoch 0/200] [Batch 4/469] [D loss: 0.432902] [G loss: 0.691464]
[Epoch 0/200] [Batch 5/469] [D loss: 0.405615] [G loss: 0.686657]
[Epoch 0/200] [Batch 6/469] [D loss: 0.391588] [G loss: 0.680685]
[Epoch 0/200] [Batch 7/469] [D loss: 0.383734] [G loss: 0.672799]
[Epoch 0/200] [Batch 8/469] [D loss: 0.380953] [G loss: 0.664375]
[Epoch 0/200] [Batch 9/469] [D loss: 0.381949] [G loss: 0.653376]
[Epoch 0/200] [Batch 10/469] [D loss: 0.385340] [G loss: 0.642210]
[Epoch 0/200] [Batch 11/469] [D loss: 0.390980] [G loss: 0.630252]
[Epoch 0/200] [Batch 12/469] [D loss: 0.397931] [G loss: 0.617112]
[Epoch 0/200] [Batch 13/469] [D loss: 0.403277] [G loss: 0.608690]
[Epoch 0/200] [Batch 14/469] [D loss: 0.410715] [G loss: 0.598109]
[Epoc

[Epoch 0/200] [Batch 124/469] [D loss: 0.460875] [G loss: 0.989114]
[Epoch 0/200] [Batch 125/469] [D loss: 0.461803] [G loss: 0.927798]
[Epoch 0/200] [Batch 126/469] [D loss: 0.489956] [G loss: 0.837864]
[Epoch 0/200] [Batch 127/469] [D loss: 0.518476] [G loss: 0.774376]
[Epoch 0/200] [Batch 128/469] [D loss: 0.507915] [G loss: 0.834294]
[Epoch 0/200] [Batch 129/469] [D loss: 0.506972] [G loss: 0.768824]
[Epoch 0/200] [Batch 130/469] [D loss: 0.522246] [G loss: 0.853189]
[Epoch 0/200] [Batch 131/469] [D loss: 0.556094] [G loss: 0.637283]
[Epoch 0/200] [Batch 132/469] [D loss: 0.597511] [G loss: 1.094806]
[Epoch 0/200] [Batch 133/469] [D loss: 0.785582] [G loss: 0.288855]
[Epoch 0/200] [Batch 134/469] [D loss: 0.678117] [G loss: 1.288688]
[Epoch 0/200] [Batch 135/469] [D loss: 0.661765] [G loss: 0.425797]
[Epoch 0/200] [Batch 136/469] [D loss: 0.546286] [G loss: 0.986049]
[Epoch 0/200] [Batch 137/469] [D loss: 0.552702] [G loss: 0.815282]
[Epoch 0/200] [Batch 138/469] [D loss: 0.576579]

[Epoch 0/200] [Batch 248/469] [D loss: 0.584051] [G loss: 0.588089]
[Epoch 0/200] [Batch 249/469] [D loss: 0.597600] [G loss: 0.841830]
[Epoch 0/200] [Batch 250/469] [D loss: 0.611576] [G loss: 0.567790]
[Epoch 0/200] [Batch 251/469] [D loss: 0.583429] [G loss: 0.824349]
[Epoch 0/200] [Batch 252/469] [D loss: 0.602883] [G loss: 0.623263]
[Epoch 0/200] [Batch 253/469] [D loss: 0.578115] [G loss: 0.802015]
[Epoch 0/200] [Batch 254/469] [D loss: 0.571961] [G loss: 0.715504]
[Epoch 0/200] [Batch 255/469] [D loss: 0.566488] [G loss: 0.760718]
[Epoch 0/200] [Batch 256/469] [D loss: 0.591833] [G loss: 0.721573]
[Epoch 0/200] [Batch 257/469] [D loss: 0.569360] [G loss: 0.752899]
[Epoch 0/200] [Batch 258/469] [D loss: 0.576980] [G loss: 0.744728]
[Epoch 0/200] [Batch 259/469] [D loss: 0.574560] [G loss: 0.748332]
[Epoch 0/200] [Batch 260/469] [D loss: 0.577480] [G loss: 0.758214]
[Epoch 0/200] [Batch 261/469] [D loss: 0.605686] [G loss: 0.733528]
[Epoch 0/200] [Batch 262/469] [D loss: 0.599810]

[Epoch 0/200] [Batch 372/469] [D loss: 0.447384] [G loss: 0.853996]
[Epoch 0/200] [Batch 373/469] [D loss: 0.442838] [G loss: 0.823798]
[Epoch 0/200] [Batch 374/469] [D loss: 0.452636] [G loss: 1.109671]
[Epoch 0/200] [Batch 375/469] [D loss: 0.501341] [G loss: 0.617981]
[Epoch 0/200] [Batch 376/469] [D loss: 0.479109] [G loss: 1.311001]
[Epoch 0/200] [Batch 377/469] [D loss: 0.498507] [G loss: 0.549597]
[Epoch 0/200] [Batch 378/469] [D loss: 0.427277] [G loss: 1.443653]
[Epoch 0/200] [Batch 379/469] [D loss: 0.414831] [G loss: 0.734664]
[Epoch 0/200] [Batch 380/469] [D loss: 0.380158] [G loss: 1.235152]
[Epoch 0/200] [Batch 381/469] [D loss: 0.388708] [G loss: 0.877764]
[Epoch 0/200] [Batch 382/469] [D loss: 0.378177] [G loss: 1.045287]
[Epoch 0/200] [Batch 383/469] [D loss: 0.420257] [G loss: 0.931670]
[Epoch 0/200] [Batch 384/469] [D loss: 0.429805] [G loss: 0.805881]
[Epoch 0/200] [Batch 385/469] [D loss: 0.444633] [G loss: 1.126438]
[Epoch 0/200] [Batch 386/469] [D loss: 0.515187]

[Epoch 1/200] [Batch 27/469] [D loss: 0.564573] [G loss: 1.378956]
[Epoch 1/200] [Batch 28/469] [D loss: 0.680173] [G loss: 0.364887]
[Epoch 1/200] [Batch 29/469] [D loss: 0.578143] [G loss: 1.582256]
[Epoch 1/200] [Batch 30/469] [D loss: 0.524668] [G loss: 0.546457]
[Epoch 1/200] [Batch 31/469] [D loss: 0.428112] [G loss: 1.387575]
[Epoch 1/200] [Batch 32/469] [D loss: 0.412115] [G loss: 0.955564]
[Epoch 1/200] [Batch 33/469] [D loss: 0.411605] [G loss: 0.973632]
[Epoch 1/200] [Batch 34/469] [D loss: 0.437261] [G loss: 1.159006]
[Epoch 1/200] [Batch 35/469] [D loss: 0.488479] [G loss: 0.717574]
[Epoch 1/200] [Batch 36/469] [D loss: 0.502585] [G loss: 1.392196]
[Epoch 1/200] [Batch 37/469] [D loss: 0.586291] [G loss: 0.480192]
[Epoch 1/200] [Batch 38/469] [D loss: 0.537413] [G loss: 1.704894]
[Epoch 1/200] [Batch 39/469] [D loss: 0.574701] [G loss: 0.507015]
[Epoch 1/200] [Batch 40/469] [D loss: 0.411582] [G loss: 1.445498]
[Epoch 1/200] [Batch 41/469] [D loss: 0.416414] [G loss: 1.037

[Epoch 1/200] [Batch 151/469] [D loss: 0.581541] [G loss: 0.558234]
[Epoch 1/200] [Batch 152/469] [D loss: 0.478477] [G loss: 1.666342]
[Epoch 1/200] [Batch 153/469] [D loss: 0.514942] [G loss: 0.675241]
[Epoch 1/200] [Batch 154/469] [D loss: 0.432965] [G loss: 1.417510]
[Epoch 1/200] [Batch 155/469] [D loss: 0.454810] [G loss: 0.915832]
[Epoch 1/200] [Batch 156/469] [D loss: 0.448786] [G loss: 1.104964]
[Epoch 1/200] [Batch 157/469] [D loss: 0.403346] [G loss: 0.998788]
[Epoch 1/200] [Batch 158/469] [D loss: 0.383357] [G loss: 1.322975]
[Epoch 1/200] [Batch 159/469] [D loss: 0.372325] [G loss: 1.038015]
[Epoch 1/200] [Batch 160/469] [D loss: 0.371025] [G loss: 1.453928]
[Epoch 1/200] [Batch 161/469] [D loss: 0.397480] [G loss: 0.920928]
[Epoch 1/200] [Batch 162/469] [D loss: 0.422998] [G loss: 1.757644]
[Epoch 1/200] [Batch 163/469] [D loss: 0.585417] [G loss: 0.479358]
[Epoch 1/200] [Batch 164/469] [D loss: 0.552208] [G loss: 2.262055]
[Epoch 1/200] [Batch 165/469] [D loss: 0.615898]

[Epoch 1/200] [Batch 275/469] [D loss: 0.300341] [G loss: 1.550117]
[Epoch 1/200] [Batch 276/469] [D loss: 0.307286] [G loss: 1.159968]
[Epoch 1/200] [Batch 277/469] [D loss: 0.277186] [G loss: 1.502181]
[Epoch 1/200] [Batch 278/469] [D loss: 0.310129] [G loss: 1.258792]
[Epoch 1/200] [Batch 279/469] [D loss: 0.301824] [G loss: 1.309165]
[Epoch 1/200] [Batch 280/469] [D loss: 0.318467] [G loss: 1.279074]
[Epoch 1/200] [Batch 281/469] [D loss: 0.318375] [G loss: 1.290417]
[Epoch 1/200] [Batch 282/469] [D loss: 0.278531] [G loss: 1.376339]
[Epoch 1/200] [Batch 283/469] [D loss: 0.303777] [G loss: 1.596535]
[Epoch 1/200] [Batch 284/469] [D loss: 0.371824] [G loss: 0.887641]
[Epoch 1/200] [Batch 285/469] [D loss: 0.361674] [G loss: 2.245590]
[Epoch 1/200] [Batch 286/469] [D loss: 0.543503] [G loss: 0.505490]
[Epoch 1/200] [Batch 287/469] [D loss: 0.545353] [G loss: 3.278408]
[Epoch 1/200] [Batch 288/469] [D loss: 0.378859] [G loss: 0.722869]
[Epoch 1/200] [Batch 289/469] [D loss: 0.219131]

[Epoch 1/200] [Batch 399/469] [D loss: 0.402157] [G loss: 1.091816]
[Epoch 1/200] [Batch 400/469] [D loss: 0.415010] [G loss: 1.168470]
[Epoch 1/200] [Batch 401/469] [D loss: 0.401091] [G loss: 0.994620]
[Epoch 1/200] [Batch 402/469] [D loss: 0.412892] [G loss: 1.358288]
[Epoch 1/200] [Batch 403/469] [D loss: 0.458060] [G loss: 0.705364]
[Epoch 1/200] [Batch 404/469] [D loss: 0.580746] [G loss: 2.087500]
[Epoch 1/200] [Batch 405/469] [D loss: 0.785863] [G loss: 0.279715]
[Epoch 1/200] [Batch 406/469] [D loss: 0.589682] [G loss: 2.412282]
[Epoch 1/200] [Batch 407/469] [D loss: 0.445458] [G loss: 0.733432]
[Epoch 1/200] [Batch 408/469] [D loss: 0.361296] [G loss: 1.068681]
[Epoch 1/200] [Batch 409/469] [D loss: 0.424015] [G loss: 1.560939]
[Epoch 1/200] [Batch 410/469] [D loss: 0.550027] [G loss: 0.577845]
[Epoch 1/200] [Batch 411/469] [D loss: 0.438800] [G loss: 1.535173]
[Epoch 1/200] [Batch 412/469] [D loss: 0.445279] [G loss: 0.782039]
[Epoch 1/200] [Batch 413/469] [D loss: 0.472981]

[Epoch 2/200] [Batch 54/469] [D loss: 0.449584] [G loss: 0.752058]
[Epoch 2/200] [Batch 55/469] [D loss: 0.411454] [G loss: 1.395049]
[Epoch 2/200] [Batch 56/469] [D loss: 0.420336] [G loss: 0.855107]
[Epoch 2/200] [Batch 57/469] [D loss: 0.412700] [G loss: 1.196575]
[Epoch 2/200] [Batch 58/469] [D loss: 0.401552] [G loss: 1.015605]
[Epoch 2/200] [Batch 59/469] [D loss: 0.400188] [G loss: 1.148154]
[Epoch 2/200] [Batch 60/469] [D loss: 0.369931] [G loss: 1.016548]
[Epoch 2/200] [Batch 61/469] [D loss: 0.429617] [G loss: 1.336424]
