# Wasserstein GAN in Pytorch

In [2]:
%matplotlib inline
from importlib import reload
import utils2; reload(utils2)
from utils2 import *

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.
  return f(*args, **kwds)


In [3]:
import torch_utils; reload(torch_utils)
from torch_utils import *

The good news is that in the last month the GAN training problem has been solved! [This paper](https://arxiv.org/abs/1701.07875) shows a minor change to the loss function and constraining the weights allows a GAN to reliably learn following a consistent loss schedule.

First, we, set up batch size, image size, and size of noise vector:

In [4]:
bs, sz, nz = 64, 64, 100

Pytorch has the handy [torch-vision](https://github.com/pytorch/vision) library which makes handling images fast and easy.

In [27]:
PATH = '../data2/cifar10/'
transform = transforms.Compose([
    transforms.Scale(sz),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

data = datasets.CIFAR10(root=PATH, download=True,
                       transform=transform)



Files already downloaded and verified


In [34]:
PATH = '../data2/lsun'
#os.makedirs(PATH+'bedroom_train_lmdb', exist_ok=True)
transform=transforms.Compose([
        transforms.Scale(sz),
        transforms.CenterCrop(sz),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

data = datasets.LSUN(db_path=PATH, classes=['bedroom_train'], transform=transform)



Error: ../data2/lsun/bedroom_train_lmdb: No such file or directory

Even parallel processing is handling automatically by torch-vision.

In [28]:
dataloader = torch.utils.data.DataLoader(data, batch_size=bs,
                                          shuffle=True, num_workers=8)
n = len(dataloader); n

782

Our activation function will be `tanh`, so we need to do some processing to view the generated images.

In [35]:
def show(img, fs=(6,6)):
    plt.figure(figsize = fs)
    plt.imshow(np.transpose((img/2+0.5).clamp(0,1).numpy(), (1, 2, 0)), interpolation='nearest')

## Create model

The CNN definitions are a little big for a notebook, so we import them.

In [36]:
import dcgan_prac1; reload(dcgan_prac1)
from dcgan_prac1 import DCGAN_D, DCGAN_G

Pytorch uses `module.apply()` for picking an initializer.

In [38]:
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        m.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.normal_(mean=1.0, std=0.02)
        m.bias.data.fill_(0)

In [39]:
# DCGAN_G(isize, nz, nc, ngf, ngpu, n_extra_layers)
netG = DCGAN_G(sz, nz, 3, 64, 1, 1).cuda()
netG.apply(weights_init)

DCGAN_G(
  (main): Sequential(
    (initial-100.512.convt): ConvTranspose2d (100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (initial-512.batchnorm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (initial-512.relu): ReLU(inplace)
    (pyramid-512.256.convt): ConvTranspose2d (512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid-256.batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (pyramid-256.relu): ReLU(inplace)
    (pyramid-256.128.convt): ConvTranspose2d (256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid-128.batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (pyramid-128.relu): ReLU(inplace)
    (pyramid-128.64.convt): ConvTranspose2d (128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid-64.batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (pyramid-64.relu): ReLU(inplace)
    (extra-0-64.64.convt):

In [40]:
# DCGAN_D(isize, nc, ndf, ngpu, n_extra_layers)
netD = DCGAN_D(sz, 3, 64, 1, 1).cuda()
netD.apply(weights_init)

DCGAN_D(
  (main): Sequential(
    (initial-3.64.conv): Conv2d (3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (initial-64.batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (initial-64.relu): LeakyReLU(0.2, inplace)
    (extra-0-64.64.conv): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (extra-0-64.batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (extra-0-64.relu): LeakyReLU(0.2, inplace)
    (pyramid-64.128.conv): Conv2d (64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid-128.batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (pyramid-128.relu): LeakyReLU(0.2, inplace)
    (pyramid-128.256.conv): Conv2d (128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid-256.batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (pyramid-256.relu): LeakyReLU(0.2, inplace)
    (pyramid-256.512.co

Just some shortcuts to create tensors and variables.

In [44]:
from torch import FloatTensor as FT
def Var(*params): return Variable(FT(*params).cuda())

In [45]:
def create_noise(b):
    return Variable(FT(b, nz, 1, 1).cuda().normal_(0,1))

In [46]:
# Input placeholder
input = Var(bs, 3, sz, nz)
# Fixed noise used just for visualizing images when done
fixed_noise = create_noise(bs)
# The numbers 0 and -1
one = torch.FloatTensor([1]).cuda()
mone = one * -1

An optimizer needs to be told what variables to optimize. A module automatically keeps track of its variables.

In [48]:
optimizerD = optim.RMSprop(netD.parameters(), lr = 1e-4)
optimizerG = optim.RMSprop(netG.parameters(), lr = 1e-4)

One forward step and one backward step for D

In [55]:
def step_D(v, init_grad):
    err = netD(v)
    err.backward(init_grad)
    return err

In [50]:
def make_trainable(net, val):
    for p in net.parameters(): p.requires_grad = True

In [56]:
def train(niter, first=True):
    gen_iterations = 0
    for epoch in range(niter):
        data_iter = iter(dataloader)
        i = 0
        while i < n:
            make_trainable(netD, True)
            d_iters = (100 if first and (gen_iterations<25) or (gen_iterations % 500 == 0) 
                      else 5) 
            
            j = 0
            while j < d_iters and i < n:
                j += 1; i += 1
                for p in netD.parameters(): p.data.clamp_(-0.01, 0.01)
                real = Variable(next(data_iter)[0].cuda())
                netD.zero_grad()
                errD_real = step_D(real, one)
                
                fake = netG(create_noise(real.size()[0]))
                input.data.resize_(real.size()).copy_(fake.data)
                errD_fake = step_D(input, mone)
                errD = errD_real - errD_fake
                optimizerD.step()
            
            make_trainable(netD, False)
            netG.zero_grad()
            errG = step_D(netG(create_noise(bs)), one)
            optimizerG.step()
            gen_iterations += 1
        
        print('[%d/%d][%d/%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f' % (
            epoch, niter, gen_iterations, n,
            errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))

In [None]:
%time train(200, True)

[0/200][8/782] Loss_D: -1.556155 Loss_G: 0.755735 Loss_D_real: -0.814974 Loss_D_fake 0.741182
[1/200][16/782] Loss_D: -1.563625 Loss_G: 0.758643 Loss_D_real: -0.819034 Loss_D_fake 0.744591
[2/200][24/782] Loss_D: -1.567824 Loss_G: 0.758920 Loss_D_real: -0.823326 Loss_D_fake 0.744498
[3/200][162/782] Loss_D: -1.441777 Loss_G: 0.676901 Loss_D_real: -0.780393 Loss_D_fake 0.661383
[4/200][319/782] Loss_D: -1.505537 Loss_G: 0.700469 Loss_D_real: -0.819340 Loss_D_fake 0.686197
[5/200][476/782] Loss_D: -1.484864 Loss_G: 0.707040 Loss_D_real: -0.796190 Loss_D_fake 0.688674
[6/200][614/782] Loss_D: -1.455112 Loss_G: 0.716488 Loss_D_real: -0.795500 Loss_D_fake 0.659612
[7/200][771/782] Loss_D: -1.510170 Loss_G: 0.724779 Loss_D_real: -0.793192 Loss_D_fake 0.716978
[8/200][928/782] Loss_D: -0.027120 Loss_G: -0.540400 Loss_D_real: -0.668299 Loss_D_fake -0.641179
[9/200][1066/782] Loss_D: -0.466443 Loss_G: -0.284900 Loss_D_real: 0.172127 Loss_D_fake 0.638570
[10/200][1223/782] Loss_D: -1.166974 Loss

## View

In [None]:
fake = netG(fixed_noise).data.cpu()

In [None]:
show(vutils.make_grid(fake))

In [None]:
show(vutils.make_grid(iter(dataloader).next()[0]))

In [None]:
show(vutils.make_grid(fake))

In [None]:
show(vutils.make_grid(iter(dataloader).next()[0]))