In [None]:
from matplotlib import image
import torch
import torch.nn as nn

import numpy as np

import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import torchvision.utils as vutils

class Discriminator(nn.Module):
    """Conditional DCGAN discriminator.

    The class label is embedded into one ``image_size x image_size`` plane, which is
    stacked onto the image as an extra channel before the convolutional stack.
    The per-layer size comments assume 64 x 64 inputs.

    Args:
        img_channels: channels of the input images (the conv stack sees img_channels + 1).
        features_d: base channel width; doubled by each downsampling block.
        num_classes: number of distinct integer labels accepted by ``forward``.
        image_size: side length of the label plane produced by the embedding.
    """

    def __init__(self, img_channels, features_d, num_classes, image_size):
        super().__init__()
        self.image_size = image_size
        widths = [features_d * 2 ** k for k in range(4)]  # d, 2d, 4d, 8d

        layers = [
            # (img_channels + 1) x 64 x 64 -> d x 32 x 32; the +1 is the label plane
            nn.Conv2d(img_channels + 1, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # 32x32 -> 16x16 -> 8x8 -> 4x4
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(self._block(c_in, c_out, 4, 2, 1))
        layers += [
            # 8d x 4 x 4 -> 1 x 1 x 1, squashed to a probability
            nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

        self.embed = nn.Embedding(num_classes, image_size * image_size)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv2d (bias-free, since BatchNorm follows) -> BatchNorm2d -> LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2))

    def forward(self, x, labels):
        """Score images ``x`` conditioned on integer class ``labels``.

        The label embedding is reshaped to (batch, 1, image_size, image_size) and
        concatenated to ``x`` along the channel dimension.
        """
        side = self.image_size
        label_plane = self.embed(labels).view(labels.shape[0], 1, side, side)
        return self.disc(torch.cat((x, label_plane), dim=1))


class Generator(nn.Module):
    """Conditional DCGAN generator.

    A latent vector (N x z_dim x 1 x 1) is concatenated along the channel axis with
    an ``embedding_size``-dimensional label embedding, then upsampled by transposed
    convolutions: 1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels, ending in Tanh.

    Args:
        z_dim: channels of the latent input.
        img_channels: channels of the generated images.
        features_g: base channel width (the first block uses 16x this).
        image_size: stored on the module; the layer stack itself is fixed at 64 x 64.
        num_classes: number of distinct integer labels accepted by ``forward``.
        embedding_size: dimensionality of the label embedding.
    """

    def __init__(self, z_dim, img_channels, features_g, image_size, num_classes, embedding_size):
        super().__init__()
        self.image_size = image_size
        widths = [features_g * m for m in (16, 8, 4, 2)]

        # 1x1 -> 4x4: stride 1, no padding, projects noise + label embedding
        stages = [self._block(z_dim + embedding_size, widths[0], 4, 1, 0)]
        # 4x4 -> 8x8 -> 16x16 -> 32x32
        stages.extend(self._block(a, b, 4, 2, 1) for a, b in zip(widths, widths[1:]))
        self.net = nn.Sequential(
            *stages,
            # 32x32 -> 64x64, output in [-1, 1]
            nn.ConvTranspose2d(widths[-1], img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

        self.embed = nn.Embedding(num_classes, embedding_size)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """ConvTranspose2d (bias-free, since BatchNorm follows) -> BatchNorm2d -> ReLU."""
        deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        return nn.Sequential(deconv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x, labels):
        """Generate images from latent ``x`` (N x z_dim x 1 x 1) conditioned on ``labels``."""
        cond = self.embed(labels)[:, :, None, None]  # (N, embedding_size, 1, 1)
        return self.net(torch.cat((x, cond), dim=1))


def initialize_weights(model, mean=0.0, std=0.02):
    """Re-initialise conv, transposed-conv and batch-norm weights of ``model`` in place.

    Every ``weight`` of an nn.Conv2d, nn.ConvTranspose2d or nn.BatchNorm2d found in
    ``model.modules()`` is drawn from N(mean, std). The defaults (0, 0.02) follow the
    DCGAN paper's zero-centred normal init. Biases and other layer types (e.g.
    nn.Embedding) keep their default initialisation.

    Args:
        model: any nn.Module; it is walked recursively.
        mean: mean of the normal distribution (default 0.0).
        std: standard deviation of the normal distribution (default 0.02).
    """
    for m in model.modules():
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            continue
        # BatchNorm2d(affine=False) has weight=None; skip it instead of crashing.
        if m.weight is None:
            continue
        # nn.init functions already run under no_grad, so `.data` is unnecessary.
        nn.init.normal_(m.weight, mean, std)

def test():
    """Smoke-test the output shapes of the conditional Discriminator and Generator.

    Builds both networks for 64 x 64 images with small widths and feeds them random
    images / latent vectors together with random class labels. Checks that the
    discriminator returns N x 1 x 1 x 1 and the generator N x C x 64 x 64.
    """
    N, in_channels, H, W = 8, 3, 64, 64
    noise_dim = 100
    num_classes = 10
    embedding_size = 100
    x = torch.randn((N, in_channels, H, W))
    labels = torch.randint(0, num_classes, (N,))
    # Both models are conditional: constructors need num_classes / image_size
    # (plus an embedding size for the generator), and forward() needs labels.
    disc = Discriminator(in_channels, 8, num_classes, H)
    assert disc(x, labels).shape == (N, 1, 1, 1), "Discriminator test failed"
    gen = Generator(noise_dim, in_channels, 8, H, num_classes, embedding_size)
    z = torch.randn((N, noise_dim, 1, 1))
    assert gen(z, labels).shape == (N, in_channels, H, W), "Generator test failed"
    print("All tests passed")


In [None]:
"""
    Train a DCGAN on MNIST.
"""




device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

LEARNING_RATE = 2e-4
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 1
NUM_CLASSES =10
GEN_EMBEDDING = 100

Z_DIM = 100
NUM_EPOCHS = 100
FEATURES_DISC = 64
FEATURES_GEN = 64


transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.5 for _ in range(CHANNELS_IMG)],[ 0.5 for _ in range(CHANNELS_IMG)]),

]
)

dataset = datasets.MNIST(root="./data", train=True, transform=transforms, download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

disc = Discriminator(CHANNELS_IMG, FEATURES_DISC, NUM_CLASSES,IMAGE_SIZE).to(device)
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN, IMAGE_SIZE, NUM_CLASSES,GEN_EMBEDDING).to(device)

initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)

img_list = []

gen.train()
disc.train()
iters = 0

for epoch in range(NUM_EPOCHS):
    for i, (imgs, labels) in enumerate(loader):
        iters += 1
        imgs = imgs.to(device)
        labels = labels.to(device)
        noise = torch.randn(BATCH_SIZE, Z_DIM, 1, 1).to(device)
        # Train Discriminator
        fake = gen(noise, labels)

        real_labels = torch.ones(BATCH_SIZE, 1, 1, 1).to(device)
        fake_labels = torch.zeros(BATCH_SIZE, 1, 1, 1).to(device)
        real_outputs = disc(imgs,labels)
        fake_outputs = disc(fake.detach(), labels) # detach to avoid backprop through generator
        real_loss = criterion(real_outputs, torch.ones_like(real_outputs))
        fake_loss = criterion(fake_outputs, torch.zeros_like(fake_outputs))
        d_loss = (real_loss + fake_loss) / 2
        disc.zero_grad()
        d_loss.backward()
        opt_disc.step()
        # Train Generator
        
        
        gen_outputs = disc(fake, labels)
        g_loss = criterion(gen_outputs, torch.ones_like(gen_outputs))
        gen.zero_grad()
        g_loss.backward()
        opt_gen.step()


        if (iters % 200 == 0) or ((epoch == NUM_EPOCHS-1) and (i == len(loader)-1)):
            with torch.no_grad():
                faker = gen(noise, labels)
            img_list.append(vutils.make_grid(faker, padding=2, normalize=True))



        if i == 0:
            print("Epoch: {}/{}".format(epoch, NUM_EPOCHS))
            print("Discriminator loss: {}".format(d_loss))
            print("Generator loss: {}".format(g_loss))
            # print("Real outputs: {}".format(real_outputs))
            # print("Fake outputs: {}".format(fake_outputs))
            # print("Real labels: {}".format(real_labels))
            # print("Fake labels: {}".format(fake_labels))
            # print("Generator output: {}".format(gen(noise)))
            # print("Discriminator output: {}".format(disc(imgs)))
            print("\n")






In [None]:
# Animate the generator snapshots collected during training: one frame per entry in img_list.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
# make_grid returns a C x H x W tensor; imshow needs an H x W x C numpy array.
# Moving to the CPU first also handles snapshots that were left on the GPU.
ims = [
    [plt.imshow(grid.detach().cpu().permute(1, 2, 0).numpy(), animated=True)]
    for grid in img_list
]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
plt.close(fig)  # avoid rendering a stray static copy of the last frame below the animation

HTML(ani.to_jshtml())