In [1]:
import torch
import torch.nn as nn
import numpy as np

opt = {
    "n_epochs": 10,
    "batch_size": 64,
    "lr": 0.0002,
    "b1": 0.5,
    "b2": 0.999,
    "n_cpu": 8,
    "latent_dim": 100,
    "img_size": 28,
    "channels": 1,
    "sample_interval": 400
}

img_shape = (opt['channels'], opt['img_size'], opt['img_size'])

In [2]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt['latent_dim'], 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

In [3]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


In [4]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets

dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt['img_size']), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt['batch_size'],
    shuffle=True,
)

  warn(


In [5]:
import os
from torch.autograd import Variable
from torchvision.utils import save_image

os.makedirs("images", exist_ok=True)

adversarial_loss = torch.nn.BCELoss()

generator = Generator()
discriminator = Discriminator()

optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt['lr'], betas=(opt['b1'], opt['b2']))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt['lr'], betas=(opt['b1'], opt['b2']))

Tensor = torch.FloatTensor

for epoch in range(opt['n_epochs']):
    for i, (imgs, _) in enumerate(dataloader):
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        real_imgs = Variable(imgs.type(Tensor))

        optimizer_G.zero_grad()

        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt['latent_dim']))))

        gen_imgs = generator(z)

        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        optimizer_D.zero_grad()

        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt['n_epochs'], i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt['sample_interval'] == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

[Epoch 0/10] [Batch 0/938] [D loss: 0.663720] [G loss: 0.694880]
[Epoch 0/10] [Batch 1/938] [D loss: 0.586710] [G loss: 0.692345]
[Epoch 0/10] [Batch 2/938] [D loss: 0.525301] [G loss: 0.689616]
[Epoch 0/10] [Batch 3/938] [D loss: 0.471464] [G loss: 0.686904]
[Epoch 0/10] [Batch 4/938] [D loss: 0.433823] [G loss: 0.683377]
[Epoch 0/10] [Batch 5/938] [D loss: 0.408808] [G loss: 0.679513]
[Epoch 0/10] [Batch 6/938] [D loss: 0.393595] [G loss: 0.674044]
[Epoch 0/10] [Batch 7/938] [D loss: 0.384827] [G loss: 0.667785]
[Epoch 0/10] [Batch 8/938] [D loss: 0.382117] [G loss: 0.660891]
[Epoch 0/10] [Batch 9/938] [D loss: 0.383501] [G loss: 0.649521]
[Epoch 0/10] [Batch 10/938] [D loss: 0.388412] [G loss: 0.638167]
[Epoch 0/10] [Batch 11/938] [D loss: 0.392198] [G loss: 0.628810]
[Epoch 0/10] [Batch 12/938] [D loss: 0.395439] [G loss: 0.621445]
[Epoch 0/10] [Batch 13/938] [D loss: 0.402948] [G loss: 0.609059]
[Epoch 0/10] [Batch 14/938] [D loss: 0.406372] [G loss: 0.604721]
[Epoch 0/10] [Batch 