# 实验7 WGAN

Author: 高鹏昺

Email: nbgao@126.com

Reference: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html?highlight=gan

In [32]:
import torch
import torch.nn as nn
import torch.nn.parallel
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utils
from torchvision.datasets import CIFAR10

import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

In [14]:
torch.cuda.is_available()

True

In [53]:
workers = 2
batch_size = 128
image_size = 64
nc = 3
nz = 100
ngf = 64
ndf = 64
num_epochs = 8
lr = 0.00005
beta1 = 0.5
clip = 0.01
ngpu = 1

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu>0) else "cpu")
print(device)

cuda:0


In [16]:
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

In [17]:
dataset = CIFAR10(root='../data/cifar-10-python/', transform=transform, download=False)
print(dataset)
dataloader = data.DataLoader(dataset, batch_size, shuffle=True,  num_workers=workers)

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ../data/cifar-10-python/
    Split: Train


In [18]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [54]:
# Generator
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.generator = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*8),
            nn.ReLU(True),
            # [nfg*8, 4, 4]
            nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            # [ngf*4, 8, 8]
            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),
            # [ngf*2, 16, 16]
            nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            # [ngf, 32, 32]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # [nc, 64, 64]
        )
    
    def forward(self, x):
        return self.generator(x)

In [55]:
netG = Generator(ngpu).to(device)

if (device.type=='cuda' and ngpu>1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

netG.apply(weights_init)

print(netG)

Generator(
  (generator): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (12): Tanh()
  )
)


In [56]:
# Discriminator
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.discriminator = nn.Sequential(
            # [nc, 64, 64]
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # [ndf, 32, 32]
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            # [ndf*2, 16, 16]
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True),
            # [ndf*4, 8, 8]
            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2, inplace=True),
            # [ndf*8 ,4, 4]
            nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=True),
            # [WGAN] Modification 1: remove sigmoid
            # nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.discriminator(x)

In [57]:
netD = Discriminator(ngpu).to(device)

if device.type=='cuda' and ngpu>1:
    netD = nn.DataParallel(netD, list(range(ngpu)))

netD.apply(weights_init)

print(netD)

Discriminator(
  (discriminator): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
  )
)


In [58]:
# [WGAN] Modification 2: remove log
#criterion = nn.BCELoss()

fixed_noise = torch.randn(64, nz, 1, 1, device=device)

real_label = 1
fake_label = 0

# [WGAN] Modification 3: Use RMSProp optimizzer
optimizerD = optim.RMSprop(netD.parameters(), lr=lr)
optimizerG = optim.RMSprop(netG.parameters(), lr=lr)

In [None]:
img_list = []
G_losses = []
D_losses = []
iters = 0

time_start = time.time()
print('Starting Training Loop...')

one = torch.FloatTensor([1])
mone = -1 * one
one = one.cuda()
mone = mone.cuda()

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # (1) Update D network: maximize log(D(x)) + log(1-D(G(z)))
        
        # [WGAN] Modification 4: Clip parameters for Discriminator
        for parm in netD.parameters():
            parm.data.clamp_(-clip, clip)
        
        real, _ = data
        real = Variable(real).cuda()
        noise = torch.randn(real.size(0), nz, 1, 1)
        noise = Variable(noise).cuda()

        netD.zero_grad()
        lossD_real = netD(real)
        #lossD_real = criterion(output, label)
        lossD_real.backward(torch.ones_like(lossD_real))
        D_x = lossD_real.mean().item()

    
        # Generate fake image batch with G
        fake = netG(noise).detach()
        lossD_fake = netD(fake)
        lossD_fake.backward(-torch.ones_like(lossD_fake))
        lossD = lossD_real + lossD_fake
        optimizerD.step()
#         output = netD(fake.detach()).view(-1)
#         lossD_fake = criterion(output, label)
#         lossD_fake.backward()
        D_G_z1 = lossD_fake.view(-1).mean().item()


        # (2) Update G network: maxmize log(D(G(z)))
        if i%5 == 0:
            netG.zero_grad()
            noise.data.normal_(0,1)
            fake = netG(noise)
            lossG = netD(fake)
            lossG.backward(torch.ones_like(lossG))
            D_G_z2 = lossG.view(-1).mean().item()
            optimizerG.step()
        
        if i%50==0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' \
                % (epoch, num_epochs, i, len(dataloader), lossD.mean().item(), lossG.mean().item(), D_x, D_G_z1, D_G_z2))
        
        G_losses.append(lossG.mean().item())
        D_losses.append(lossD.mean().item())

        if (iters%500==0) or((epoch==num_epochs-1) and (i==len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(utils.make_grid(fake, padding=2, normalize=True))
        
        iters += 1
        
time_end = time.time()
print('Finished! Time: %.2fs' % (time_end-time_start))

Starting Training Loop...
[0/8][0/391]	Loss_D: -0.0130	Loss_G: 0.0436	D(x): -0.0067	D(G(z)): -0.0063 / 0.0436
[0/8][50/391]	Loss_D: -0.0301	Loss_G: 0.4533	D(x): -0.4651	D(G(z)): 0.4350 / 0.4533


In [None]:
plt.figure(figsize=(12, 6))
plt.plot(G_losses, label='G')
plt.plot(D_losses, label='D')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.grid()
plt.show()

In [None]:
fig = plt.figure(figsize=(8,8))
plt.axis('off')
ims = [[plt.imshow(np.transpose(i, (1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=500, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())