## Conditional GAN (cgan)

Conditional Generative Adversarial Nets

Authors: Mehdi Mirza, Simon Osindero

Generative Adversarial Nets were recently introduced as a novel way to train generative models. In this work we introduce the conditional version of generative adversarial nets, which can be constructed by simply feeding the data, y, we wish to condition on to both the generator and discriminator. We show that this model can generate MNIST digits conditioned on class labels. We also illustrate how this model could be used to learn a multi-modal model, and provide preliminary examples of an application to image tagging in which we demonstrate how this approach can generate descriptive tags which are not part of training labels.

paper: https://arxiv.org/abs/1411.1784
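
For reference, the objective from the linked paper can be written as the two-player minimax game below, where $y$ is the conditioning information (here, the digit class label). Note that the code in this notebook trains with a least-squares loss (`MSELoss`) rather than this log-likelihood form, so the formula is background, not a description of the exact loss being optimized.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x \mid y)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z \mid y))\right)\right]
```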


In [1]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
os.makedirs("images", exist_ok=True)


opt_n_epochs = 200          # number of training epochs
opt_batch_size = 64         # size of the batches
opt_lr = 0.0002             # adam: learning rate
opt_b1 = 0.5                # adam: decay of first-order momentum of gradient
opt_b2 = 0.999              # adam: decay of second-order momentum of gradient
opt_n_cpu = 8               # number of CPU threads to use during batch generation
opt_latent_dim = 100        # dimensionality of the latent space
opt_n_classes = 10          # number of classes in the dataset
opt_img_size = 32           # size of each image dimension
opt_channels = 1            # number of image channels
opt_sample_interval = 400   # interval (in batches) between image sampling



In [3]:

img_shape = (opt_channels, opt_img_size, opt_img_size)

cuda = torch.cuda.is_available()
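
As a quick sanity check on the sizes used by the networks below, the following sketch recomputes the layer widths from the hyperparameters. The variable names `flat_img_dim`, `gen_input_dim`, and `disc_input_dim` are introduced here for illustration only and do not appear in the notebook.

```python
import numpy as np

# Values copied from the hyperparameter cell above
channels, img_size, latent_dim, n_classes = 1, 32, 100, 10
img_shape = (channels, img_size, img_size)

flat_img_dim = int(np.prod(img_shape))     # width of the generator's final Linear layer
gen_input_dim = latent_dim + n_classes     # noise concatenated with the label embedding
disc_input_dim = n_classes + flat_img_dim  # flattened image concatenated with the label embedding

print(flat_img_dim, gen_input_dim, disc_input_dim)
```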

In [4]:

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(opt_n_classes, opt_n_classes)

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # Note: the second positional argument of BatchNorm1d is eps, so 0.8 here sets eps, not momentum
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt_latent_dim + opt_n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # Concatenate label embedding and noise to produce the generator input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        img = img.view(img.size(0), *img_shape)
        return img
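
The conditioning step in `forward` can be inspected in isolation. The sketch below is a minimal illustration, assuming a stand-alone embedding that mirrors `label_emb`; the batch of four samples and the specific labels are made up for the example.

```python
import torch
import torch.nn as nn

n_classes, latent_dim = 10, 100

# Hypothetical stand-alone embedding, mirroring Generator.label_emb
label_emb = nn.Embedding(n_classes, n_classes)

noise = torch.randn(4, latent_dim)    # batch of 4 latent vectors
labels = torch.tensor([0, 3, 3, 9])   # one integer class label per sample

# Same concatenation as in Generator.forward
gen_input = torch.cat((label_emb(labels), noise), dim=-1)
print(gen_input.shape)

# Samples with the same label share the same embedding prefix
same_prefix = torch.equal(gen_input[1, :n_classes], gen_input[2, :n_classes])
```

Because the embedding is learned, the generator is free to choose how each class is encoded in the first `n_classes` input features.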

In [5]:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_embedding = nn.Embedding(opt_n_classes, opt_n_classes)

        self.model = nn.Sequential(
            nn.Linear(opt_n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        # Concatenate label embedding and image to produce input
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity
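
The discriminator conditions on the label the same way, but on the flattened image instead of noise. The sketch below checks the shapes involved, assuming random images in [-1, 1] and a single `Linear` layer standing in for the full network; both are placeholders for illustration.

```python
import torch
import torch.nn as nn

n_classes = 10
img_shape = (1, 32, 32)

# Stands in for Discriminator.label_embedding
label_embedding = nn.Embedding(n_classes, n_classes)

imgs = torch.rand(4, *img_shape) * 2 - 1   # fake batch with values in [-1, 1]
labels = torch.tensor([1, 2, 3, 4])

flat = imgs.view(imgs.size(0), -1)                         # (4, 1024)
d_in = torch.cat((flat, label_embedding(labels)), dim=-1)  # (4, 1034)

# A single Linear layer as a placeholder for the full model
score = nn.Linear(d_in.size(1), 1)(d_in)
print(flat.shape, d_in.shape, score.shape)
```

Note that the notebook's discriminator ends in a plain `Linear` layer with no sigmoid, so its output is an unbounded score, not a probability.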

In [6]:

# Loss function: MSE on the raw discriminator scores (a least-squares GAN objective)
adversarial_loss = torch.nn.MSELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Create the directory that the MNIST download will use
os.makedirs("../../data/mnist", exist_ok=True)
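
To make the role of `MSELoss` concrete, the sketch below evaluates the discriminator and generator losses on a handful of made-up scores and compares them with the same quantities written out by hand. The score values are invented for illustration and are not taken from a training run.

```python
import torch
import torch.nn as nn

adversarial_loss = nn.MSELoss()

# Made-up discriminator scores, for illustration only
scores_real = torch.tensor([[0.9], [1.1]])
scores_fake = torch.tensor([[0.2], [-0.1]])

valid = torch.ones_like(scores_real)   # target 1 for "real"
fake = torch.zeros_like(scores_fake)   # target 0 for "fake"

# Discriminator: push real scores toward 1 and fake scores toward 0
d_loss = (adversarial_loss(scores_real, valid) + adversarial_loss(scores_fake, fake)) / 2

# Generator: push the scores of its fakes toward 1
g_loss = adversarial_loss(scores_fake, valid)

# The same quantities written out by hand
d_manual = 0.5 * (((scores_real - 1) ** 2).mean() + (scores_fake ** 2).mean())
g_manual = ((scores_fake - 1) ** 2).mean()
```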

In [7]:

dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt_img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt_batch_size,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ../../data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to ../../data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../../data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ../../data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ../../data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ../../data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../../data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../../data/mnist/MNIST/raw
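
The `Normalize([0.5], [0.5])` transform in the data loader maps pixel values from [0, 1] to [-1, 1], which matches the range of the generator's final `Tanh`. The sketch below reproduces that arithmetic on a few sample values without touching the dataset; the variable names are illustrative.

```python
import torch

# Pixel values as ToTensor() would produce them, in [0, 1]
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

mean, std = 0.5, 0.5
normalized = (pixels - mean) / std   # what Normalize([0.5], [0.5]) computes per channel

# Inverse mapping, useful when viewing generated images
restored = normalized * std + mean
print(normalized, restored)
```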



In [8]:

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt_lr, betas=(opt_b1, opt_b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt_lr, betas=(opt_b1, opt_b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
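
The `FloatTensor`/`LongTensor` aliases and the `Variable` wrapper used in this notebook are legacy idioms; since PyTorch 0.4, plain tensors track gradients themselves. As a hedged alternative, the sketch below creates the same kinds of tensors with a `torch.device`. It is not what the notebook runs, only one way the setup could be written today.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size, latent_dim, n_classes = 8, 100, 10

# Noise and random labels for the generator, created directly on the target device
z = torch.randn(batch_size, latent_dim, device=device)
gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)

# Adversarial targets; new tensors do not require gradients by default
valid = torch.ones(batch_size, 1, device=device)
fake = torch.zeros(batch_size, 1, device=device)
```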

In [9]:

def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt_latent_dim))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
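
The label layout built by the list comprehension in `sample_image` is easier to see with a small `n_row`. The sketch below uses `n_row = 3` purely for readability (the notebook calls the function with `n_row=10`) and reshapes the labels into the grid that `save_image(..., nrow=n_row)` would display.

```python
import numpy as np

n_row = 3  # small value for readability; the notebook uses n_row=10
labels = np.array([num for _ in range(n_row) for num in range(n_row)])

# save_image places nrow images per row, so this is the on-screen layout
grid = labels.reshape(n_row, n_row)
print(grid)
```

Each row of the saved image therefore runs through the classes in order, and each column shows different noise samples for one class.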



In [13]:


# ----------
#  Training
# ----------

for epoch in range(opt_n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake  = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels    = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt_latent_dim))))
        gen_labels = Variable(LongTensor(np.random.randint(0, opt_n_classes, batch_size)))

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Loss measures generator's ability to fool the discriminator
        validity = discriminator(gen_imgs, gen_labels)
        g_loss   = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss   = adversarial_loss(validity_real, valid)

        # Loss for fake images
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss   = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        if i == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt_n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt_sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

[Epoch 0/200] [Batch 0/938] [D loss: 0.232328] [G loss: 0.282834]
[Epoch 1/200] [Batch 0/938] [D loss: 0.250522] [G loss: 0.172114]
[Epoch 2/200] [Batch 0/938] [D loss: 0.196647] [G loss: 0.368290]
[Epoch 3/200] [Batch 0/938] [D loss: 0.224032] [G loss: 0.273425]
[Epoch 4/200] [Batch 0/938] [D loss: 0.206738] [G loss: 0.436585]
[Epoch 5/200] [Batch 0/938] [D loss: 0.191776] [G loss: 0.371070]
[Epoch 6/200] [Batch 0/938] [D loss: 0.203678] [G loss: 0.420574]
[Epoch 7/200] [Batch 0/938] [D loss: 0.203995] [G loss: 0.329888]
[Epoch 8/200] [Batch 0/938] [D loss: 0.188743] [G loss: 0.361466]
[Epoch 9/200] [Batch 0/938] [D loss: 0.220097] [G loss: 0.250613]
[Epoch 10/200] [Batch 0/938] [D loss: 0.172530] [G loss: 0.388423]
[Epoch 11/200] [Batch 0/938] [D loss: 0.180064] [G loss: 0.445653]
[Epoch 12/200] [Batch 0/938] [D loss: 0.228379] [G loss: 0.301111]
[Epoch 13/200] [Batch 0/938] [D loss: 0.178316] [G loss: 0.478501]
[Epoch 14/200] [Batch 0/938] [D loss: 0.226013] [G loss: 0.448777]
[Epoc

[Epoch 123/200] [Batch 0/938] [D loss: 0.209910] [G loss: 0.842749]
[Epoch 124/200] [Batch 0/938] [D loss: 0.065047] [G loss: 0.756783]
[Epoch 125/200] [Batch 0/938] [D loss: 0.091710] [G loss: 0.601047]
[Epoch 126/200] [Batch 0/938] [D loss: 0.108570] [G loss: 0.760149]
[Epoch 127/200] [Batch 0/938] [D loss: 0.070660] [G loss: 0.869932]
[Epoch 128/200] [Batch 0/938] [D loss: 0.175546] [G loss: 0.317292]
[Epoch 129/200] [Batch 0/938] [D loss: 0.098868] [G loss: 0.480831]
[Epoch 130/200] [Batch 0/938] [D loss: 0.162887] [G loss: 0.313130]
[Epoch 131/200] [Batch 0/938] [D loss: 0.090493] [G loss: 0.669284]
[Epoch 132/200] [Batch 0/938] [D loss: 0.059231] [G loss: 0.665911]
[Epoch 133/200] [Batch 0/938] [D loss: 0.062324] [G loss: 0.806120]
[Epoch 134/200] [Batch 0/938] [D loss: 0.092142] [G loss: 0.632135]
[Epoch 135/200] [Batch 0/938] [D loss: 0.072734] [G loss: 0.805545]
[Epoch 136/200] [Batch 0/938] [D loss: 0.107201] [G loss: 0.495470]
[Epoch 137/200] [Batch 0/938] [D loss: 0.293813]
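
The training loop above calls `gen_imgs.detach()` before scoring fakes in the discriminator step. The sketch below illustrates why, using two toy `Linear` layers as hypothetical stand-ins for the generator and discriminator: with `detach()`, the discriminator loss sends no gradient back into the generator.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (hypothetical, not the notebook's models)
G = nn.Linear(4, 6)
D = nn.Linear(6, 1)
loss_fn = nn.MSELoss()

z = torch.randn(5, 4)
gen_imgs = G(z)

# Discriminator step on detached fakes: the graph back to G is cut
d_fake_loss = loss_fn(D(gen_imgs.detach()), torch.zeros(5, 1))
d_fake_loss.backward()

g_grad_after_d_step = G.weight.grad          # still None: G received no gradient
d_grad_after_d_step = D.weight.grad.clone()  # D did receive a gradient

# Generator step without detach: gradients flow through D into G
g_loss = loss_fn(D(gen_imgs), torch.ones(5, 1))
g_loss.backward()
g_has_grad = G.weight.grad is not None
```

The order of the two steps is swapped relative to the notebook only to keep the example short; the effect of `detach()` is the same either way.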