In [None]:
%matplotlib inline
import skimage.io

In [None]:
# Sanity check: read one image from disk, print its array shape and render it inline.
# NOTE(review): the path is an absolute, machine-specific location — adjust for your setup.
img = skimage.io.imread("/home/santiago/Downloads/celebA/img_align_celeba/000001.jpg")
print(img.shape)
skimage.io.imshow(img)

In [None]:
# img

In [None]:
import argparse
import os
import numpy as np
import math
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch

from comet_ml import Experiment

In [None]:
# Create the Comet experiment used for metric logging during training.
# The API key is a credential and must not be committed with the notebook, so it
# is read from the COMET_API_KEY environment variable. When the variable is unset
# `None` is passed and Comet falls back to its own configuration lookup.
# NOTE(review): the key that used to be hard-coded here should be revoked/rotated.
experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"), project_name="pytorch-gans")

In [None]:
# Make sure every output folder used by this notebook exists (no error if it already does).
for _out_dir in ("../wgan/images", "../wgan/checkpoints", "../wgan/manifold_walk"):
    os.makedirs(_out_dir, exist_ok=True)

In [None]:
# Number of image channels produced by the generator / consumed by the discriminator.
channels = 3
# Side length (pixels) images are resized to; both networks assume it is divisible by 16.
img_size = 64

In [None]:
# Image shape in channel-first order: (channels, height, width).
img_shape = (channels, img_size, img_size)

In [None]:
# Flag telling the rest of the notebook whether a CUDA device can be used.
cuda = torch.cuda.is_available()

In [None]:
# Dimensionality of the noise vector fed to the generator.
latent_dim = 128

In [None]:
def weights_init_normal(m):
    """Re-initialise a single layer in place; intended for ``Module.apply``.

    * Layers whose class name contains ``Conv``: weights drawn from N(0, 0.02).
    * Layers whose class name contains ``BatchNorm2d``: weights drawn from
      N(1, 0.02) and biases set to zero.
    * Every other layer is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm2d' in layer_kind:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [None]:
class Generator(nn.Module):
    """Maps a latent vector to an image with values in [-1, 1] (final ``Tanh``).

    A linear layer projects the ``latent_dim``-sized input to a
    ``256 x init_size x init_size`` feature map (``init_size = img_size // 16``),
    which is then upsampled by 2 four times, ending with ``channels`` output
    channels. Reads the module-level ``img_size``, ``latent_dim`` and ``channels``.

    The layer order inside ``conv_blocks`` determines the ``state_dict`` keys, so
    reordering layers breaks loading of previously saved checkpoints.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Spatial size of the first feature map; four x2 upsamplings bring it back to img_size.
        self.init_size = img_size // 2**4
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 256*self.init_size**2))

        # NOTE(review): the second positional argument of nn.BatchNorm2d is `eps`,
        # so `BatchNorm2d(n, 0.8)` below sets eps=0.8 (not momentum). Confirm this is
        # intended before changing it — checkpoints trained with this value would
        # behave differently under another eps.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(256),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.Conv2d(256, 128, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.BatchNorm2d(128, 0.8),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.BatchNorm2d(64, 0.8),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 64, 3, stride=1, padding=1),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.BatchNorm2d(32, 0.8),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 32, 3, stride=1, padding=1),
            nn.Conv2d(32, channels, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate a batch of images from latent vectors ``z`` of shape (batch, latent_dim)."""
        out = self.l1(z)
        # Reshape the flat projection into a (batch, 256, init_size, init_size) feature map.
        out = out.view(out.shape[0], 256, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

In [None]:
class Discriminator(nn.Module):
    """Scores a batch of images with one value per image.

    Four stages of conv -> conv -> LeakyReLU -> 2x max-pool -> Dropout2d ->
    BatchNorm2d reduce the spatial size by 16 while growing the channels from
    ``channels`` to 256; the flattened features go through a linear layer and a
    ``Sigmoid``. Reads the module-level ``channels`` and ``img_size``.

    NOTE(review): the output is squashed to (0, 1) by the Sigmoid, while the
    training cell uses it in a Wasserstein-style loss with gradient penalty —
    confirm that this combination is intended.
    NOTE(review): as in the generator, `BatchNorm2d(n, 0.8)` sets `eps` to 0.8
    (the second positional argument is eps, not momentum).
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.conv_blocks = nn.Sequential(
            nn.Conv2d(channels, 32, 3, 1, 1),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(32, 0.8),
            
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(64, 0.8),
            
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(128, 0.8),
            
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(256, 0.8)
        )

        # The height and width of downsampled image (four 2x max-pools => img_size / 16)
        ds_size = img_size // 2**4
        self.adv_layer = nn.Sequential(
            nn.Linear(256*ds_size**2, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return a (batch, 1) tensor of scores for the image batch ``img``."""
        out = self.conv_blocks(img)
        # Flatten everything except the batch dimension for the linear head.
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity

In [None]:
# Loss weight for gradient penalty: multiplies the penalty term added to the
# discriminator loss in the training loop.
lambda_gp = 10

In [None]:
# Binary cross-entropy criterion. The training loop in this notebook computes its
# losses directly from the discriminator outputs and does not call it.
adversarial_loss = torch.nn.BCELoss()

# Instantiate the two networks.
generator = Generator()
discriminator = Discriminator()

# Move networks and criterion to the GPU when one is available.
if cuda:
    for _gpu_module in (generator, discriminator, adversarial_loss):
        _gpu_module.cuda()

In [None]:
# Show the layer layout of both networks, generator first.
print(generator, discriminator, sep="\n")

In [None]:
# Initialize weights: `apply` visits every sub-module and runs weights_init_normal on it.
# Parameters restored later via load_state_dict replace these initial values.
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

In [None]:
# Root folder handed to ImageFolder when building the dataset.
# NOTE(review): absolute, machine-specific path — change it to where your copy of the data lives.
dataroot = "/home/santiago/Downloads/celebA/"

In [None]:
# Mini-batch size and number of loader worker processes.
batchSize = 128
workers = 4

# Preprocessing: centre-crop to 128x128, resize to img_size, convert to a tensor
# and rescale each of the three channels with mean 0.5 / std 0.5 (i.e. [0, 1] -> [-1, 1]).
_preprocess = transforms.Compose([
    transforms.CenterCrop(128),
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = datasets.ImageFolder(root=dataroot, transform=_preprocess)
# Fail early when the folder yielded no samples.
assert dataset

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=int(workers),
)

In [None]:
# for i, (imgs, _) in enumerate(dataloader):
#     print(imgs[0])

In [None]:
# Adam hyper-parameters: (b1, b2) are passed as `betas` to both optimizers.
b1 = 0.5
b2 = 0.999
# Learning rates for the generator and the discriminator optimizers respectively.
g_lr = 0.00001
d_lr = 0.00001

In [None]:
# One Adam optimizer per network, each with its own learning rate and shared betas.
_adam_betas = (b1, b2)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=g_lr, betas=_adam_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=d_lr, betas=_adam_betas)

# Tensor constructor matching the device in use (GPU float tensor when CUDA is on).
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP.

    A random point is drawn on the straight line between each real sample and
    its fake counterpart, the critic ``D`` is evaluated there, and the squared
    deviation of the input-gradient's L2 norm from 1 is averaged over the batch.

    Args:
        D: callable mapping a batch of samples to a tensor of critic scores
            (any output shape; the batch dimension comes first).
        real_samples: tensor of real data, batch dimension first.
        fake_samples: tensor of generated data with the same shape as
            ``real_samples``.

    Returns:
        Scalar tensor holding the mean penalty; it stays attached to the graph
        (``create_graph=True``) so it can be back-propagated into ``D``.

    Notes:
        All helper tensors are created on the device / dtype of
        ``real_samples`` instead of going through a module-level tensor
        constructor, and the interpolation weights are drawn with
        ``torch.rand`` (uniform on [0, 1)).
    """
    batch_size = real_samples.size(0)
    # One interpolation weight per sample, broadcast over all remaining dimensions.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Upstream gradient of ones, shaped like whatever the critic returned.
    grad_outputs = torch.ones_like(d_interpolates)
    # Get gradient w.r.t. interpolates
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [None]:
# Number of passes over the dataloader.
n_epochs = 100
# The generator is updated once every `n_critic` batches; the discriminator on every batch.
n_critic = 5
# Samples and checkpoints are written whenever `batches_done` is a multiple of this value.
sample_interval = 100

In [None]:
# Restore both networks from previously saved state dicts before training.
# NOTE(review): the paths are absolute and machine-specific; torch.load is called
# without `map_location`, so a checkpoint saved on GPU presumably needs CUDA here — confirm.
discriminator.load_state_dict(torch.load("/home/santiago/Repos/pytorch-experiments/wgan/primed/netD_epoch_0.pth"))
generator.load_state_dict(torch.load("/home/santiago/Repos/pytorch-experiments/wgan/primed/netG_epoch_0.pth"))
# Alternative restore points (optimizer states / a VAE decoder as generator), currently disabled:
# optimizer_D.load_state_dict(torch.load("/home/santiago/Repos/pytorch-experiments/checkpoints/opt_discriminator_13240.pth"))
# optimizer_G.load_state_dict(torch.load("/home/santiago/Repos/pytorch-experiments/checkpoints/opt_generator_13240.pth"))
# generator.load_state_dict(torch.load("/home/santiago/Repos/pytorch-experiments/vae/checkpoints/decoder_3166.pth"))
# Global step counter used for logging and for naming samples / checkpoints.
batches_done = 0

In [None]:
# ----------
#  Training
# ----------
# Per batch: one discriminator (critic) update with gradient penalty; every
# `n_critic` batches additionally one generator update, metric logging and,
# when `batches_done` hits a multiple of `sample_interval`, a sample grid plus
# checkpoints of both networks and both optimizers.

with experiment.train():
    for epoch in range(n_epochs):
        for i, (imgs, _) in enumerate(dataloader):

            # Configure input
            real_imgs = Variable(imgs.type(Tensor))

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Sample noise as generator input
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))

            # Generate a batch of images. The critic update never uses generator
            # gradients (they are zeroed before the generator step), so the
            # generator graph is not recorded here: d_loss.backward() then stops
            # at the critic instead of also back-propagating through the generator.
            with torch.no_grad():
                fake_imgs = generator(z)

            # Real images
            real_validity = discriminator(real_imgs)
            # Fake images
            fake_validity = discriminator(fake_imgs)
            # Gradient penalty
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
            # Adversarial loss
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

            d_loss.backward()
            optimizer_D.step()

            optimizer_G.zero_grad()

            # Train the generator every n_critic steps
            if i % n_critic == 0:

                # -----------------
                #  Train Generator
                # -----------------

                # Generate a batch of images (this time with the graph, so the
                # loss can reach the generator parameters)
                fake_imgs = generator(z)
                # Loss measures generator's ability to fool the discriminator
                # Train on fake images
                fake_validity = discriminator(fake_imgs)
                g_loss = -torch.mean(fake_validity)

                g_loss.backward()
                optimizer_G.step()

                experiment.log_metric("d_loss", d_loss.item(), step=batches_done)
                experiment.log_metric("g_loss", g_loss.item(), step=batches_done)

                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                    % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
                )

                if batches_done % sample_interval == 0:
                    save_image(fake_imgs.data[:25], "../wgan/images/%d.png" % batches_done, nrow=5, normalize=True)
                    torch.save(generator.state_dict(), '../wgan/checkpoints/generator_%d.pth' % batches_done)
                    torch.save(optimizer_G.state_dict(), '../wgan/checkpoints/generator_opt_%d.pth' % batches_done)
                    torch.save(discriminator.state_dict(), '../wgan/checkpoints/discriminator_%d.pth' % batches_done)
                    torch.save(optimizer_D.state_dict(), '../wgan/checkpoints/discriminator_opt_%d.pth' % batches_done)

                # The counter only advances on generator steps, hence by n_critic at a time.
                batches_done += n_critic

In [None]:
# Drop the references to the tensors left over from the last training batch
# first, and only then ask the CUDA caching allocator to release its unused
# blocks: memory that is still referenced cannot be returned by empty_cache().
del real_imgs
del z
del fake_imgs
del real_validity
del fake_validity
del gradient_penalty
del d_loss
del g_loss
torch.cuda.empty_cache()

In [None]:
# Show the global step counter reached by training.
print(batches_done)

In [None]:
# Save the final network weights, tagged with the global step counter.
# The target folder is not one of the directories created at the top of the
# notebook, so make sure it exists before writing into it.
os.makedirs("../checkpoints", exist_ok=True)
torch.save(generator.state_dict(), '../checkpoints/generator_%d.pth' % batches_done)
torch.save(discriminator.state_dict(), '../checkpoints/discriminator_%d.pth' % batches_done)

In [None]:
# Save the final optimizer states, tagged with the global step counter.
# The target folder is not one of the directories created at the top of the
# notebook, so make sure it exists before writing into it.
os.makedirs("../checkpoints", exist_ok=True)
torch.save(optimizer_G.state_dict(), '../checkpoints/opt_generator_%d.pth' % batches_done)
torch.save(optimizer_D.state_dict(), '../checkpoints/opt_discriminator_%d.pth' % batches_done)

In [None]:
def reset_discriminator():
    """Re-initialise the discriminator weights and give it a fresh optimizer.

    Applies ``weights_init_normal`` to every sub-module of the module-level
    ``discriminator`` and rebinds the module-level ``optimizer_D`` to a new Adam
    optimizer over its parameters, using ``d_lr`` and ``(b1, b2)``.
    """
    # Without the `global` declaration the new optimizer would only be a local
    # variable and the training loop would keep using the stale one.
    global optimizer_D
    discriminator.apply(weights_init_normal)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=d_lr, betas=(b1, b2))

In [None]:
def reset_generator():
    """Re-initialise the generator weights and give it a fresh optimizer.

    Applies ``weights_init_normal`` to every sub-module of the module-level
    ``generator`` and rebinds the module-level ``optimizer_G`` to a new Adam
    optimizer over the generator's parameters, using ``g_lr`` and ``(b1, b2)``.
    """
    # Without the `global` declaration the new optimizer would only be a local
    # variable and the training loop would keep using the stale one.
    global optimizer_G
    generator.apply(weights_init_normal)
    # The generator optimizer must track the generator's own parameters.
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=g_lr, betas=(b1, b2))

In [None]:
# Start the discriminator over from freshly initialised weights.
reset_discriminator()

In [None]:
# Rebuild the generator's optimizer from scratch. It has to track the
# generator's parameters (not the discriminator's) and use the generator
# learning rate `g_lr` defined with the other hyper-parameters.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=g_lr, betas=(b1, b2))

In [None]:
# Show the current value of the global step counter.
print(batches_done)

In [None]:
# Draw a single latent point `z` from N(0, 1) and a small random direction `v`
# (standard normal scaled by 0.01) used to step through latent space.
_walk_shape = (1, latent_dim)
z = Variable(Tensor(np.random.normal(0, 1, _walk_shape)))
v = Variable(Tensor(0.01 * np.random.normal(0, 1, _walk_shape)))

In [None]:
# Display the sampled latent vector.
z

In [None]:
# Display the sampled step direction.
v

In [None]:
# Run the generator on the single latent vector; `img` is a batch containing one image.
img = generator(z)

In [None]:
# Inspect the shape of the generated batch.
img.data.shape

In [None]:
# Write the first (only) generated image of the batch to disk.
save_image(img.data[0], "../test.png")

In [None]:
# Number of frames rendered by the latent-space walk below.
steps = 200

In [None]:
# Latent-space walk: render one frame per step while moving along direction `v`,
# from z - (steps / 2) * v up to just below z + (steps / 2) * v.
# The output folder is not created anywhere else in the notebook, so create it here.
os.makedirs("../walk2", exist_ok=True)

# Pure inference: no autograd graph is needed for the generated frames.
with torch.no_grad():
    for i in range(steps):
        save_image(generator(z + (i - steps / 2) * v).data[0], "../walk2/manifold_walk_%03d.png" % i)