In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import argparse
import os
import numpy as np
import math
import sys


import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch

from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader

from torch.autograd import Variable

import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

from ipywidgets import IntProgress

In [None]:
# Hyperparameters. `parse_args([])` parses an empty argument list, so inside
# Jupyter the kernel's own command-line flags are ignored and only the defaults
# below are used.
# NOTE(review): the Generator/Discriminator architecture produces/consumes
# 32x32 images, so --img_size other than 32 will not work without changing it.
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0001, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=128, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
parser.add_argument("--sample_interval", type=int, default=800, help="interval between image samples")
opt = parser.parse_args([])
print(opt)

In [None]:
# Shape of a single image sample as (C, H, W).
img_shape = (opt.channels, opt.img_size, opt.img_size)

# Run on the GPU whenever PyTorch reports one is available.
cuda = torch.cuda.is_available()

# Loss weight for gradient penalty (lambda in the critic objective).
lambda_gp = 10

# Float tensor constructor matching the selected device.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

In [None]:
# Directory for generated sample images; created on first run.
save_dir = "./images/wgan-gp_celeba"
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)

In [None]:
class Generator(nn.Module):
    """Generator mapping a latent vector to a 32x32 image with values in [-1, 1].

    Args:
        DIM: base channel width; the stages use 4*DIM, 2*DIM and DIM channels.
        latent_dim: length of the input noise vector. Defaults to 128, the value
            that used to be hard-coded in the first Linear layer.
        channels: number of output image channels (3 for RGB).

    ``forward(z)`` takes ``z`` of shape (N, latent_dim) and returns a tensor of
    shape (N, channels, 32, 32).

    Submodule names (preprocess, block1, block2, deconv_out, tanh) are part of
    the saved ``state_dict`` keys and must stay stable for checkpoints to load.
    """

    def __init__(self, DIM=128, latent_dim=128, channels=3):
        super(Generator, self).__init__()
        self.DIM = DIM
        self.latent_dim = latent_dim
        self.channels = channels

        # Project z to 4*DIM feature maps of size 4x4 (reshaped in forward()).
        self.preprocess = nn.Sequential(
            nn.Linear(latent_dim, 4 * 4 * 4 * DIM),
            nn.BatchNorm1d(4 * 4 * 4 * DIM),
            nn.ReLU(True),
        )

        # Each kernel-2 / stride-2 transposed conv doubles the spatial size:
        # 4x4 -> 8x8 -> 16x16 -> 32x32 (the last step is deconv_out).
        self.block1 = nn.Sequential(
            nn.ConvTranspose2d(4 * DIM, 2 * DIM, 2, stride=2),
            nn.BatchNorm2d(2 * DIM),
            nn.ReLU(True),
        )
        self.block2 = nn.Sequential(
            nn.ConvTranspose2d(2 * DIM, DIM, 2, stride=2),
            nn.BatchNorm2d(DIM),
            nn.ReLU(True),
        )
        self.deconv_out = nn.ConvTranspose2d(DIM, channels, 2, stride=2)
        # tanh bounds the output to [-1, 1], matching Normalize((0.5,..), (0.5,..)) data.
        self.tanh = nn.Tanh()

    def forward(self, z):
        output = self.preprocess(z)
        output = output.view(-1, 4 * self.DIM, 4, 4)
        output = self.block1(output)
        output = self.block2(output)
        output = self.deconv_out(output)
        # Already (N, channels, 32, 32); no reshape needed.
        return self.tanh(output)


class Discriminator(nn.Module):
    """Critic network: scores a batch of 32x32 images with one unbounded value each.

    Args:
        DIM: base channel width; the convs produce DIM, 2*DIM and 4*DIM channels.
        channels: number of input image channels (3 for RGB).

    ``forward(img)`` takes ``img`` of shape (N, channels, 32, 32) and returns
    (N, 1) raw scores (no sigmoid). Inputs of any other spatial size raise an
    error in the final Linear layer instead of being silently reshaped.
    """

    def __init__(self, DIM=128, channels=3):
        super(Discriminator, self).__init__()

        self.DIM = DIM

        # Three stride-2, padding-1, kernel-3 convs halve the size each time:
        # 32x32 -> 16x16 -> 8x8 -> 4x4.
        self.main = nn.Sequential(
            nn.Conv2d(channels, DIM, 3, 2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(DIM, 2 * DIM, 3, 2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(2 * DIM, 4 * DIM, 3, 2, padding=1),
            nn.LeakyReLU(),
        )
        self.linear = nn.Linear(4 * 4 * 4 * DIM, 1)

    def forward(self, img):
        output = self.main(img)
        # flatten(1) preserves the batch dimension. The previous view(-1, ...)
        # would fold extra spatial elements into extra "batch" rows.
        output = output.flatten(1)
        return self.linear(output)

In [None]:
# Channel-width multiplier shared by both networks. This used to be
# `opt.latent_dim`, which conflated the noise size with the model width. The
# loading cell rebuilds `Generator()` with its default DIM=128, so the width
# used here must match that for saved checkpoints to load.
MODEL_DIM = 128
generator = Generator(DIM=MODEL_DIM)
discriminator = Discriminator(DIM=MODEL_DIM)

if cuda:
    generator.cuda()
    discriminator.cuda()

In [None]:
# Generator and critic each get their own Adam optimizer with shared settings.
adam_betas = (opt.b1, opt.b2)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=adam_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=adam_betas)

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP.

    The critic ``D`` is evaluated at random points on the straight lines between
    paired real and fake samples. The penalty is the batch mean of the squared
    deviation of the per-sample gradient L2 norm from 1.

    Args:
        D: critic (module or callable) applied to a batch shaped like
            ``real_samples``.
        real_samples: real batch of shape (N, ...); expected not to require grad.
        fake_samples: generated batch with the same shape as ``real_samples``.

    Returns:
        Scalar tensor: mean over the batch of (||dD/dx_hat||_2 - 1)^2. The
        gradient graph is kept (``create_graph=True``), so the penalty can be
        backpropagated into D's parameters.
    """
    batch_size = real_samples.size(0)
    # One interpolation weight per sample, broadcast over all remaining dims.
    # It is created on the inputs' device/dtype instead of via the global
    # `Tensor` constructor.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    # A ones seed makes autograd return d(sum of D outputs)/d(interpolates).
    # That equals the per-sample gradient because samples are independent in D.
    grad_outputs = torch.ones_like(d_interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

In [None]:
# CelebA training split (expected pre-downloaded under /data), resized and
# centre-cropped to opt.img_size, then normalised per channel from [0, 1] to
# [-1, 1] to match the generator's tanh output range.
dataloader = torch.utils.data.DataLoader(
    datasets.CelebA(
        "/data",
        split='train',
        download=False,
        transform=transforms.Compose([
            transforms.Resize(opt.img_size),
            transforms.CenterCrop(opt.img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,) * opt.channels, (0.5,) * opt.channels),
        ]),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
    # opt.n_cpu ("number of cpu threads to use during batch generation") was
    # declared but never wired in; loading ran in the main process only.
    num_workers=opt.n_cpu,
    # Page-locked host memory speeds up host->GPU copies; only useful on CUDA.
    pin_memory=cuda,
)

In [None]:
# TensorBoard logger for loss curves and sample grids. With log_dir=None the
# library picks its default run directory (presumably ./runs/<timestamp> — confirm).
writer = SummaryWriter(log_dir=None)

In [None]:
# Per-channel mean (dm) and std (ds) of the Normalize transform. Training
# samples are mapped back to [0, 1] via x * ds + dm before logging. They are
# placed on the models' device; the previous unconditional .cuda() crashed on
# CPU-only machines.
sample_device = torch.device("cuda" if cuda else "cpu")
dm = torch.as_tensor([0.5, 0.5, 0.5], device=sample_device)[:, None, None]
ds = torch.as_tensor([0.5, 0.5, 0.5], device=sample_device)[:, None, None]

# Fixed batch of 64 latent vectors reused at every logging step, so the image
# grids are comparable across training.
z_static = Tensor(np.random.normal(0, 1, (64, opt.latent_dim)))

In [None]:
# ----------
#  Training
# ----------
# The critic (discriminator) is updated on every batch; the generator once
# every `opt.n_critic` batches. `batches_done` advances by n_critic per
# generator step, so it approximates the number of critic batches seen.

batches_done = 0
try:
    for epoch in tqdm(range(opt.n_epochs)):
        for i, (imgs, _) in enumerate(dataloader):

            # Configure input
            real_imgs = imgs.type(Tensor)

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Sample noise as generator input
            z = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))

            # Generate a batch of images. Detached so the critic's backward pass
            # does not waste work propagating into the generator.
            fake_imgs = generator(z).detach()

            # Real images
            real_validity = discriminator(real_imgs)
            # Fake images
            fake_validity = discriminator(fake_imgs)
            # Gradient penalty
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.detach(), fake_imgs)
            # Critic loss: mean D(fake) - mean D(real) + lambda_gp * GP
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

            d_loss.backward()
            optimizer_D.step()

            # Train the generator every n_critic steps
            if i % opt.n_critic == 0:

                # -----------------
                #  Train Generator
                # -----------------

                optimizer_G.zero_grad()

                # Regenerate from the same noise, this time tracking gradients.
                fake_imgs = generator(z)
                # Loss measures generator's ability to fool the discriminator
                fake_validity = discriminator(fake_imgs)
                g_loss = -torch.mean(fake_validity)

                g_loss.backward()
                optimizer_G.step()

                if batches_done % opt.sample_interval == 0:
                    # Sampling only: no autograd graph is recorded. This also
                    # makes the in-place de-normalisation of the tanh output safe.
                    with torch.no_grad():
                        fake_imgs_static = generator(z_static)
                        fake_imgs_static.mul_(ds).add_(dm).clamp_(0, 1)
                    grid = make_grid(fake_imgs_static)
                    writer.add_image('images', grid, batches_done)
                    writer.add_scalar('Loss/d_loss', d_loss.item(), batches_done)
                    writer.add_scalar('Loss/g_loss', g_loss.item(), batches_done)

                batches_done += opt.n_critic
finally:
    # Flush and close TensorBoard logs even if training is interrupted.
    writer.close()

In [None]:
# Draw 64 fresh latent vectors and generate images without recording an
# autograd graph (inference only).
# NOTE(review): the generator is still in training mode here, so its BatchNorm
# layers use batch statistics; call generator.eval() first if that is not intended.
with torch.no_grad():
    z = Tensor(np.random.normal(0, 1, (64, opt.latent_dim)))
    fake_imgs = generator(z)

In [None]:
# Undo Normalize((0.5,)*3, (0.5,)*3): map generator output from [-1, 1] back to
# [0, 1] for display. A new tensor is built instead of mutating fake_imgs in
# place, so re-running this cell does not de-normalise twice. The device is
# taken from fake_imgs instead of assuming CUDA.
dm = torch.as_tensor([0.5, 0.5, 0.5], device=fake_imgs.device)[:, None, None]
ds = torch.as_tensor([0.5, 0.5, 0.5], device=fake_imgs.device)[:, None, None]
fake_imgs_display = (fake_imgs.detach() * ds + dm).clamp(0, 1)
img = make_grid(fake_imgs_display)

In [None]:
# Display the sample grid; matplotlib expects channels-last (H, W, C).
fig, ax = plt.subplots(figsize=(14, 14))
grid_hwc = img.detach().cpu().permute(1, 2, 0).numpy()
ax.imshow(grid_hwc.clip(0, 1), interpolation='nearest')

In [None]:
# Persist the generator weights. Nothing else creates ./models, and torch.save
# does not create parent directories, so ensure it exists first.
# Note: `save_dir` (earlier the image directory) is reused here as the
# checkpoint *file* path; the loading cell reads it under this name.
save_dir = './models/celeba_wgan-gp_generator_32.pth.tar'
os.makedirs(os.path.dirname(save_dir), exist_ok=True)
torch.save({'state_dict': generator.state_dict()}, save_dir)

In [None]:
# ---------
#  Loading
# ---------
# Rebuild a generator with the default width (DIM=128) and restore the saved
# weights. map_location="cpu" lets a checkpoint written from a GPU run load on
# a CPU-only machine. The freshly built model lives on the CPU, so
# load_state_dict copies the weights into CPU parameters either way.
model = Generator()
checkpoint = torch.load(save_dir, map_location="cpu")
model.load_state_dict(checkpoint['state_dict'])