In [None]:
%matplotlib inline
import skimage.io
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
import torchvision.utils as vutils
import tqdm
from comet_ml import Experiment

In [None]:
# Comet.ml experiment tracking. The API key comes from the COMET_API_KEY
# environment variable; when it is unset, comet_ml falls back to its own
# configuration (e.g. ~/.comet.config) instead of a key committed here.
# NOTE(review): the key that used to be hard-coded in this cell is exposed in
# version history and should be revoked/rotated.
experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"), project_name="pytorch-gans")

In [None]:
# Request GPU training and let cuDNN benchmark/cache convolution algorithms
# (worthwhile because every batch here has the same image size).
cuda = True
cudnn.benchmark = True

# Nudge the user when a GPU is present but the flag above turns it off.
if not cuda and torch.cuda.is_available():
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

In [None]:
# Output folders: sample grids, training checkpoints and latent-space walks.
for _out_dir in ("../gan/images", "../gan/checkpoints", "../gan/manifold_walk"):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

In [None]:
# Image geometry: 3 colour channels, square images of img_size pixels per side.
channels, img_size = 3, 64

In [None]:
# (C, H, W) shape of a single image tensor.
# NOTE(review): not referenced by any other cell shown here.
img_shape = (channels,) + (img_size,) * 2

In [None]:
latent_dim: int = 128  # length of the noise vector z fed to the generator's first Linear layer

In [None]:
# Dataset root. Override with the CELEBA_ROOT environment variable so the
# notebook is not tied to one machine's home directory. torchvision's
# ImageFolder expects the images inside at least one sub-directory of the root.
dataroot = os.environ.get("CELEBA_ROOT", "/home/santiago/Downloads/celebA/")

batchSize = 256
workers = 4

# Crop the central 128x128 region, downscale to img_size, and map pixel values
# from [0, 1] to [-1, 1] to match the generator's Tanh output range.
dataset = datasets.ImageFolder(
    root=dataroot,
    transform=transforms.Compose([
        transforms.CenterCrop(128),
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)
assert len(dataset) > 0, "no images found under %s" % dataroot

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=workers,
    pin_memory=torch.cuda.is_available(),  # faster host-to-GPU copies when a GPU is used
)

In [None]:
# Use the first GPU when CUDA is requested AND available; otherwise fall back to
# the CPU instead of failing on the first tensor moved to "cuda:0".
device = torch.device("cuda:0" if cuda and torch.cuda.is_available() else "cpu")

In [None]:
# custom weights initialization called on netG and netD
def weights_init(m):
    """DCGAN-style initialisation, intended for ``module.apply(weights_init)``.

    Modules whose class name contains "Conv" get weights drawn from N(0, 0.02).
    Modules whose class name contains "BatchNorm" get scale weights from
    N(1, 0.02) and a zero bias. Every other module (Linear, activations, ...)
    keeps PyTorch's default initialisation.

    Args:
        m: a single ``nn.Module``; ``apply`` calls this on every sub-module.
    """
    classname = m.__class__.__name__
    # getattr guards: non-affine BatchNorm has weight/bias set to None, and a
    # container class whose name merely contains "Conv" may have no weight.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)

In [None]:
class Generator(nn.Module):
    """Map latent codes z to images of shape (channels, img_size, img_size).

    ``l1`` projects each z to a 256-channel feature map of side
    ``img_size // 16``. Four stages of (BatchNorm -> 2x upsample -> two 3x3
    convs -> activation) grow it back to ``img_size``. The last stage ends in
    Tanh, so pixel values lie in [-1, 1].

    NOTE(review): the layer order inside ``conv_blocks`` determines the
    state_dict keys ("conv_blocks.<index>.*") of the saved checkpoints; do not
    reorder or insert layers without migrating those files.
    """

    def __init__(self):
        super().__init__()

        self.init_size = img_size // 16  # spatial side before four 2x upsamplings
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 256 * self.init_size ** 2))

        # (in channels, out channels, BatchNorm eps) for each upsampling stage.
        # The first stage uses PyTorch's default eps (1e-05).
        # NOTE(review): 0.8 is BatchNorm2d's eps (the argument right after
        # num_features). It looks like it was meant as a momentum, but it is
        # kept so existing checkpoints behave exactly as they were trained.
        stages = [(256, 128, 1e-05), (128, 64, 0.8), (64, 32, 0.8), (32, channels, 0.8)]
        last_stage = len(stages) - 1

        layers = []
        for stage_idx, (c_in, c_out, bn_eps) in enumerate(stages):
            layers.extend([
                nn.BatchNorm2d(c_in, eps=bn_eps),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(c_in, c_in, 3, stride=1, padding=1),
                nn.Conv2d(c_in, c_out, 3, stride=1, padding=1),
                nn.Tanh() if stage_idx == last_stage else nn.LeakyReLU(0.2, inplace=True),
            ])
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, z):
        """Return generated images for latent codes z of shape (batch, latent_dim)."""
        projected = self.l1(z)
        side = self.init_size
        seed_map = projected.view(projected.size(0), 256, side, side)
        return self.conv_blocks(seed_map)

In [None]:
class Discriminator(nn.Module):
    """Score images of shape (channels, img_size, img_size) as real vs. generated.

    Four stages of (two 3x3 convs -> LeakyReLU -> 2x2 max-pool -> Dropout2d ->
    BatchNorm) halve the spatial size each time. ``adv_layer`` maps the
    flattened 256-channel features to one Sigmoid score in [0, 1] per image.

    NOTE(review): the layer order inside ``conv_blocks`` determines the
    state_dict keys of the saved checkpoints; keep it stable.
    """

    def __init__(self):
        super().__init__()

        # NOTE(review): eps=0.8 reproduces BatchNorm2d(n, 0.8), whose second
        # positional argument is eps rather than momentum. Kept for checkpoint parity.
        layers = []
        for c_in, c_out in [(channels, 32), (32, 64), (64, 128), (128, 256)]:
            layers.extend([
                nn.Conv2d(c_in, c_out, 3, 1, 1),
                nn.Conv2d(c_out, c_out, 3, 1, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.MaxPool2d(2),
                nn.Dropout2d(0.25),
                nn.BatchNorm2d(c_out, eps=0.8),
            ])
        self.conv_blocks = nn.Sequential(*layers)

        ds_size = img_size // 16  # side length after four 2x2 max-pools
        self.adv_layer = nn.Sequential(
            nn.Linear(256 * ds_size ** 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return a (batch, 1) tensor of Sigmoid scores for a batch of images."""
        features = self.conv_blocks(img)
        flat = features.view(features.size(0), -1)
        return self.adv_layer(flat)

In [None]:
# Build both networks on the selected device, then apply the DCGAN init.
# `.to(device)` (instead of `.cuda()`) matches how tensors are placed elsewhere
# in this notebook and lets the cell run when `device` is the CPU.
netG = Generator().to(device)
netG.apply(weights_init)
print(netG)
netD = Discriminator().to(device)
netD.apply(weights_init)
print(netD)

In [None]:
# Adam settings shared by both optimizers: learning rate and (beta1, beta2).
lr, b1, b2 = 1e-5, 0.5, 0.999

In [None]:
# Binary cross-entropy on the discriminator's Sigmoid output. The loss has no
# parameters or buffers, so it does not need to be moved to a device.
criterion = nn.BCELoss()

# Fixed latent batch reused for every sample grid, so progress is comparable.
fixed_noise = torch.randn(batchSize, latent_dim, device=device)
# Float targets: BCELoss needs floating-point labels, and torch.full infers an
# integer dtype from an int fill value in recent PyTorch releases.
real_label = 1.0
fake_label = 0.0

# setup optimizer
optD = optim.Adam(netD.parameters(), lr=lr, betas=(b1, b2))
optG = optim.Adam(netG.parameters(), lr=lr, betas=(b1, b2))

In [None]:
# Resume from a checkpoint written by the training loop, or set
# `resume_step = None` to start from freshly initialised weights.
# NOTE(review): this absolute directory differs from the relative
# "../gan/checkpoints" the training loop saves to; confirm they are the same.
resume_step = 83600
checkpoint_dir = "/home/santiago/Repos/pytorch-experiments/gan/checkpoints"

# Earlier starting points, kept for reference:
# netG.load_state_dict(torch.load("/home/santiago/Repos/pytorch-experiments/vae/checkpoints/decoder_3166.pth"))
# netG.load_state_dict(torch.load("/home/santiago/Repos/pytorch-experiments/gan/primed/netG_epoch_0.pth"))
# netD.load_state_dict(torch.load("/home/santiago/Repos/pytorch-experiments/gan/primed/netD_epoch_0.pth"))

if resume_step is None:
    batches_done = 0
else:
    for ckpt_name, ckpt_target in (("netD", netD), ("netG", netG), ("optD", optD), ("optG", optG)):
        ckpt_path = os.path.join(checkpoint_dir, "%s_step_%d.pth" % (ckpt_name, resume_step))
        # map_location="cpu" lets GPU-saved files load on CPU-only machines.
        # load_state_dict then copies or moves the tensors onto the parameters' device.
        ckpt_target.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    batches_done = resume_step

In [None]:
epochs: int = 100  # full passes over `dataloader` in the training cell

In [None]:
# Adversarial training loop (DCGAN recipe).
# `batches_done` counts completed D+G update pairs across runs, so a checkpoint
# saved as *_step_N.pth holds the weights after N updates. Incrementing before
# the checkpoint check means that resuming with batches_done = N continues at
# N + 1, instead of immediately overwriting the step-N files just loaded.
log_every = 50       # console progress interval in steps; Comet gets every step
sample_every = 400   # steps between sample grids and checkpoints

with experiment.train():
    for epoch in range(epochs):
        for i, (real_images, _) in enumerate(dataloader):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            # Float targets shaped (batch,) to match the flattened discriminator
            # output: BCELoss rejects integer targets and mismatched sizes.
            label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)

            output = netD(real_images).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            output = netD(fake.detach()).view(-1)  # detach: the D step must not update G
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optG.step()

            batches_done += 1

            d_loss, g_loss = errD.item(), errG.item()
            experiment.log_metric("d_loss", d_loss, step=batches_done)
            experiment.log_metric("g_loss", g_loss, step=batches_done)
            experiment.log_metric("d_x", D_x, step=batches_done)
            experiment.log_metric("d_g_z1", D_G_z1, step=batches_done)
            experiment.log_metric("d_g_z2", D_G_z2, step=batches_done)

            if batches_done % log_every == 0:
                print("[Epoch: {}/{}] [Batch: {}/{}] [Global step: {}] [D loss: {}] [G loss: {}]".format(
                    epoch, epochs, i, len(dataloader), batches_done, d_loss, g_loss
                ))

            if batches_done % sample_every == 0:
                vutils.save_image(real_images,
                        '../gan/images/real_samples.png',
                        normalize=True)
                # Preview only, so no autograd graph is needed. netG stays in
                # train mode (as before), so BatchNorm still uses batch statistics.
                with torch.no_grad():
                    samples = netG(fixed_noise)
                vutils.save_image(samples,
                        '../gan/images/fake_samples_step_%03d.png' % batches_done,
                        normalize=True)
                # do checkpointing
                for ckpt_name, ckpt_obj in (("netG", netG), ("netD", netD), ("optG", optG), ("optD", optD)):
                    torch.save(ckpt_obj.state_dict(),
                               '../gan/checkpoints/%s_step_%d.pth' % (ckpt_name, batches_done))

In [None]:
# Starting latent codes for the walk below: 128 independent standard-normal samples.
noise = torch.randn((128, latent_dim), device=device)

In [None]:
# One random walk direction per starting code (same shape as `noise`).
direction = torch.randn((128, latent_dim), device=device)

In [None]:
# Move each latent code in `noise` along its own `direction`, saving one image
# grid per frame. Frame i uses offset (i - n_frames // 2) * step_size, i.e. from
# -2.0 up to 1.99 with the defaults below.
# NOTE(review): nothing in this notebook calls netG.eval(). BatchNorm therefore
# normalises with each frame's batch statistics and keeps updating its running
# statistics here; call netG.eval() first if that is not intended.
walk_dir = "../gan/walk_grid"
n_frames = 400
step_size = 0.01
os.makedirs(walk_dir, exist_ok=True)  # this folder is not created by the setup cell
with torch.no_grad():  # inference only: avoid building an autograd graph per frame
    for i in range(n_frames):
        img = netG(noise + (i - n_frames // 2) * step_size * direction)
        vutils.save_image(img, os.path.join(walk_dir, "walk%03d.png" % i), normalize=True)

In [None]:
# Preview a few of the walk's starting codes instead of dumping all 128 x latent_dim values.
noise[:4]