In [27]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from IPython.display import HTML

In [28]:
n_epochs = 5
batch_size = 128
latent_dim = 100
img_size = 64
channels = 1
sample_interval = 400


In [29]:
transform = transforms.Compose([
                               transforms.Resize(img_size),
                               transforms.CenterCrop(img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ])
trainset = dset.CIFAR10(root='./data', train=True,
                                        download=True,transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

Files already downloaded and verified


In [30]:
image, label = trainset[0]

In [31]:
img_shape = image.shape

In [32]:
img_shape 

torch.Size([3, 64, 64])

In [33]:
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, latent_size, img_shape):
        super(Generator, self).__init__()
        self.latent_size = latent_size
        self.img_shape = img_shape

        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_size, 128, 4, 1, 0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, img_shape[0], 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        z = z.view(z.size(0), self.latent_size, 1, 1)
        img = self.model(z)
        return img

# Define the discriminator architecture
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.img_shape = img_shape

        self.model = nn.Sequential(
            nn.Conv2d(img_shape[0], 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 4, 1, 0, bias=False),
        )

    def forward(self, img):
        validity = self.model(img)
        return validity.view(-1)


def wasserstein_loss(real_scores, fake_scores):
    return -torch.mean(real_scores) + torch.mean(fake_scores)



In [34]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

def train_wgan(generator, discriminator, dataloader, num_epochs, batch_size, latent_size, lr, clip_value):
    G_losses=[]
    img_list=[]
    D_losses=[]
    cuda = True if torch.cuda.is_available() else False
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

 
    optimizer_G = optim.RMSprop(generator.parameters(), lr=lr)
    optimizer_D = optim.RMSprop(discriminator.parameters(), lr=lr)
    iters = 0

    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(train_loader):
            # Sample a batch of noise vectors for generator input
            z = Variable(Tensor(batch_size, latent_size).normal_())

            # Generate a batch of fake images
            gen_imgs = generator(z)

            # Clip the discriminator weights to enforce Lipschitz continuity
            for p in discriminator.parameters():
                p.data.clamp_(-clip_value, clip_value)

            # Train the discriminator
            optimizer_D.zero_grad()

            # Compute discriminator output on real and fake images
            real_scores = discriminator(imgs.type(Tensor))
            fake_scores = discriminator(gen_imgs.detach())

            # Compute Wasserstein distance loss and backpropagate
            d_loss = wasserstein_loss(real_scores, fake_scores)
            d_loss.backward()
            optimizer_D.step()
            img_list.append(vutils.make_grid(gen_imgs, padding=2, normalize=True))
            vutils.save_image(gen_imgs,
                '%s/results_epoch_%03d.png' % ('WGAN/', epoch))
            vutils.save_image(imgs,
                '%s/real_results_epoch_%03d.png' % ('WGAN/', epoch))
            # Train the generator
            if i % 100 == 0:
                optimizer_G.zero_grad()

                # Generate a new batch of fake images
                gen_imgs = generator(z)

                # Compute discriminator output on new fake images
                fake_scores = discriminator(gen_imgs)
                
                # Compute generator loss and backpropagate
                g_loss = -torch.mean(fake_scores)
                g_loss.backward()
                optimizer_G.step()
                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, num_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))
                vutils.save_image(gen_imgs,
                '%s/results_epoch_%03d.png' % ('WGAN/', epoch))
                vutils.save_image(imgs,
                '%s/real_results_epoch_%03d.png' % ('WGAN/', epoch))
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())   
           
            iters += 1

    plot_it(G_losses,D_losses)
    images(img_list)
    #fig = plt.figure(figsize=(8,8))
    #plt.axis("off")
    #ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
    #ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

    #HTML(ani.to_jshtml())


In [41]:

netG = Generator(latent_dim, img_shape).to('cuda')
netD = Discriminator(img_shape).to('cuda')

In [42]:
netG

Generator(
  (model): Sequential(
    (0): ConvTranspose2d(100, 128, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): Tanh()
  )
)

In [43]:
netD

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)

In [44]:
def plot_it(g_loss,d_loss):
    plt.figure(figsize=(10,5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(g_loss,label="G")
    plt.plot(d_loss,label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

In [None]:
train_wgan(netG, netD, train_loader, n_epochs, batch_size,latent_dim, 0.01, 0.001)

[Epoch 0/5] [Batch 0/391] [D loss: 0.000000] [G loss: 2.745760]
[Epoch 0/5] [Batch 100/391] [D loss: -0.000587] [G loss: 0.128273]
[Epoch 0/5] [Batch 200/391] [D loss: -0.000604] [G loss: 0.089055]
[Epoch 0/5] [Batch 300/391] [D loss: -0.000623] [G loss: 0.080351]


In [None]:
plot_it(GL,DL)

In [9]:
def plot_it(g_loss,d_loss):
    plt.figure(figsize=(10,5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(g_loss,label="G")
    plt.plot(d_loss,label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()