In [1]:
import os, sys
sys.path.append(os.getcwd())

import time

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sklearn.datasets

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.utils import make_grid
import torchvision

import scipy.misc
from scipy.misc import imsave
from IPython.display import display, clear_output
from pyro.distributions.relaxed_straight_through import RelaxedBernoulliStraightThrough
from torch.distributions.bernoulli import Bernoulli
from torch.utils.data import RandomSampler, BatchSampler
from skimage.measure import compare_ssim as ssim

In [2]:
# Seed PyTorch's default random generator and detect whether a CUDA device can be used.
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
if use_cuda:
    gpu = 0  # device index passed to the `.cuda(gpu)` calls; only defined when CUDA is available

In [3]:
# Hyperparameters and run-time switches shared by the cells below.
DIM = 64 # Model dimensionality
BATCH_SIZE = 50 # Batch size
LAMBDA = 10 # Gradient penalty lambda hyperparameter
OUTPUT_DIM = 784 # Number of pixels in MNIST (28*28)
DOWNLOAD_MNIST = False # Forwarded as `download=` to the torchvision MNIST datasets

FEATURE_LENGTH = 256 # How many binary values the encoded stage has
LOAD_MODEL = False # Resume from a saved checkpoint instead of training from scratch

In [4]:
root = './datasets/mnist/'

train_data = datasets.MNIST(
    root=root,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
    train=True,
)
test_data = datasets.MNIST(
    root=root,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
    train=False,
)

# Data Loader for easy mini-batch return in training, the image batch shape will be (BATCH_SIZE, 1, 28, 28)
trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

testloader = torch.utils.data.DataLoader(dataset=test_data, batch_size=100)

# Fixed reference batch: the first 100 test images (the test loader is not
# shuffled), reused to render reconstructions during training.
display_data, y = next(iter(testloader))

# save_image does not create missing directories, so make sure the target exists.
os.makedirs("images", exist_ok=True)
save_image(display_data, "images/_original.png", nrow=10, normalize=False)

# Flatten every image to 784 values, the layout the generator's first layer expects.
display_data = display_data.view(-1, 1, 28*28)
if use_cuda:
    display_data = display_data.cuda(gpu)

In [6]:
class Generator(nn.Module):
    """Fully connected autoencoder with a stochastic binary bottleneck.

    The encoder maps a flattened 28*28 image to ``feature_length`` logits.
    Those logits parameterise a ``RelaxedBernoulliStraightThrough``
    distribution, one sample of which is fed to the decoder. The decoder maps
    the sample back to 28*28 values squashed into (0, 1) by a sigmoid.

    Args:
        feature_length: Number of values in the encoded (bottleneck) stage.
        temperature: Temperature handed to ``RelaxedBernoulliStraightThrough``.

    The defaults reproduce the original fixed architecture, including the
    order in which layers are created and the ``state_dict`` key names.
    """

    def __init__(self, feature_length=256, temperature=0.7):
        super(Generator, self).__init__()
        self.temperature = temperature

        self.encoder = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.Tanh(),
            nn.Linear(256, feature_length),
        )

        self.decoder = nn.Sequential(
            nn.Linear(feature_length, 256),
            nn.Tanh(),
            nn.Linear(256, 28*28),
            nn.Sigmoid(),       # compress to a range (0, 1)
        )

    def forward(self, image_batch):
        """Encode ``image_batch``, sample the bottleneck, and decode the sample.

        Args:
            image_batch: Tensor whose last dimension holds 28*28 values.

        Returns:
            Tensor shaped like ``image_batch`` with values in (0, 1).
        """
        encoded = self.encoder(image_batch)
        # A fresh sample is drawn on every call, so the output is stochastic.
        bernoulli_encoded = RelaxedBernoulliStraightThrough(self.temperature, logits=encoded).rsample()
        decoded = self.decoder(bernoulli_encoded)
        return decoded
    

In [7]:
class Discriminator(nn.Module):
    """Convolutional critic that assigns one unbounded score to each 28x28 image.

    Three stride-2 5x5 convolutions (28 -> 14 -> 7 -> 4 spatially) widen the
    channels to ``dim``, ``2*dim`` and ``4*dim``; a linear layer then maps the
    flattened ``4*4*4*dim`` features to a single value.

    Args:
        dim: Base channel width. ``None`` (the default) uses the notebook-level
            ``DIM`` constant, which matches the original behaviour.
    """

    def __init__(self, dim=None):
        super(Discriminator, self).__init__()
        # Resolved at construction time so the module-level DIM stays the default.
        dim = DIM if dim is None else dim

        self.main = nn.Sequential(
            nn.Conv2d(1, dim, 5, stride=2, padding=2),
            nn.ReLU(True),
            nn.Conv2d(dim, 2*dim, 5, stride=2, padding=2),
            nn.ReLU(True),
            nn.Conv2d(2*dim, 4*dim, 5, stride=2, padding=2),
            nn.ReLU(True),
        )
        # 4*dim channels on a 4x4 map after the three stride-2 convolutions.
        self.flat_features = 4*4*4*dim
        self.output = nn.Linear(self.flat_features, 1)

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: Tensor that can be viewed as ``(-1, 1, 28, 28)``, e.g.
                flattened ``(batch, 784)`` images.

        Returns:
            1-D tensor with one score per image.
        """
        input = input.view(-1, 1, 28, 28)
        out = self.main(input)
        out = out.view(-1, self.flat_features)
        out = self.output(out)
        return out.view(-1)


In [8]:
def calc_gradient_penalty(netD, real_data, fake_data, penalty_weight=None):
    """Gradient penalty on random interpolates between real and fake samples.

    For each sample a coefficient ``alpha ~ U[0, 1)`` mixes the real and fake
    inputs. The critic is evaluated on the mix, and the squared deviation of
    the per-sample gradient L2 norm from 1 is averaged over the batch and
    scaled by the penalty weight.

    Args:
        netD: Callable mapping a batch of inputs to one score per sample.
        real_data: Batch of real samples; the first dimension is the batch.
        fake_data: Batch of generated samples with the same shape and device
            as ``real_data``. It is treated as a constant (no gradient flows
            back into it).
        penalty_weight: Multiplier for the penalty. ``None`` (the default)
            uses the notebook-level ``LAMBDA`` constant.

    Returns:
        Scalar tensor, differentiable with respect to the parameters of ``netD``.
    """
    if penalty_weight is None:
        penalty_weight = LAMBDA

    # Size alpha from the actual batch so a short final batch also works, and
    # give it one trailing singleton per non-batch dimension for broadcasting.
    batch_size = real_data.size(0)
    # Drawn on the CPU and then moved, so the seeded random stream is the same
    # with and without CUDA.
    alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)))
    alpha = alpha.to(real_data.device).expand_as(real_data)

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    # Fresh leaf tensor: the gradient is taken with respect to the interpolates only.
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = netD(interpolates)

    # create_graph=True keeps the gradient itself differentiable, so the
    # penalty can be back-propagated into the critic's parameters.
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True)[0]

    # One norm per sample over all of its features, whatever the input rank.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * penalty_weight
    return gradient_penalty


In [10]:
def myplot(x,y, name):
    """Save a red-dot scatter of ``y`` against ``x`` to ``images/<name>.png``.

    Args:
        x: Iteration numbers for the horizontal axis.
        y: Values to plot, same length as ``x``.
        name: Used both as the y-axis label and as the output file stem.
    """
    # savefig does not create missing directories.
    os.makedirs('images', exist_ok=True)

    # Use a dedicated figure instead of pyplot's implicit current figure, and
    # close it afterwards so repeated calls do not accumulate figures.
    fig, ax = plt.subplots()
    try:
        ax.plot(x, y, 'ro')
        ax.set_xlabel('iteration')
        ax.set_ylabel(name)
        fig.savefig('images/'+name+'.png')
    finally:
        plt.close(fig)

In [13]:
# Trains one fresh Generator/Discriminator pair per adversarial weight and
# writes sample grids, loss plots and checkpoints for each:
# DISC_LAMBDA = 0 for increase == 0, then 2**-1 (0.5) down to 2**-6.

# save_image / savefig / torch.save do not create missing directories.
os.makedirs('images', exist_ok=True)
os.makedirs('models', exist_ok=True)

for increase in range(0, 7):
    # NOTE: this overrides the configuration cell, so the resume branch below
    # never runs as written.
    LOAD_MODEL = False
    #-----------------------------------------------------------
    netG = Generator()
    netD = Discriminator()

    if use_cuda:
        netD = netD.cuda(gpu)
        netG = netG.cuda(gpu)

    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))

    #-----------------------------------------------------------
    if LOAD_MODEL:
        checkpoint = torch.load('./models/lamb0epoch20.pt')
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        # NOTE(review): 'epoch' is stored as the index of the epoch that just
        # finished, so +1 would resume right after it -- confirm +2 is intended.
        loaded_epoch = checkpoint['epoch']+2
        iteration = checkpoint['iteration']

        # NOTE(review): eval mode is set although training continues; the
        # layers used by both networks in this notebook do not depend on it.
        netG.eval()
        netD.eval()
    #-----------------------------------------------------------

    # Per-run logs: iteration numbers and the values recorded at them.
    diterats = []
    giterats = []
    wdistarr = []
    dcostarr = []
    gcostarr = []
    trainloader_length = int(len(trainloader))
    if increase == 0:
        DISC_LAMBDA = 0
    else:
        DISC_LAMBDA = 2**(-increase)# Higher -> more discriminator

    DISC_GEN_TRAIN_RATIO = 5 # How many times the discriminator should be trained for one generator train
    EPOCHS = 50 # How many times the whole trainingset should be iterated over

    if not LOAD_MODEL:
        loaded_epoch = 0
        iteration = 0

    #-----------------------------------------------------------
    for epoch in range(loaded_epoch, EPOCHS):
        for x,y in trainloader:

            x = x.view(-1, 28*28)

            if use_cuda:
                x = x.cuda(gpu)

            # Out of every DISC_GEN_TRAIN_RATIO+1 iterations the last one
            # updates the generator; all others update the discriminator.
            if iteration %(DISC_GEN_TRAIN_RATIO+1) != DISC_GEN_TRAIN_RATIO:

                ############################
                # (1) Update D network
                ###########################

                # Clear gradients *before* the backward pass: the previous
                # generator step back-propagated through netD and left
                # gradients on its parameters, which would otherwise be added
                # to this update.
                optimizerD.zero_grad()

                # train with real
                D_real = netD(x).mean()

                # train with fake (detached: the generator is not updated here)
                fake = netG(x).detach()
                D_fake = netD(fake).mean()

                # train with gradient penalty
                gradient_penalty = calc_gradient_penalty(netD, x, fake)

                D_cost = D_fake - D_real + gradient_penalty

                D_cost.backward()

                Wasserstein_D = D_real - D_fake
                optimizerD.step()

                diterats.append(iteration)
                dcostarr.append(D_cost.item())
                wdistarr.append(Wasserstein_D.item())

            else:

                ############################
                # (2) Update G network
                ###########################

                optimizerG.zero_grad()

                fake = netG(x)
                G = netD(fake).mean()

                # Squared reconstruction error, summed per image and averaged over the batch.
                rec_x = ((fake-x)**2).sum(1).mean(0)
                # Blend of reconstruction and adversarial terms; DISC_LAMBDA = 0
                # is pure reconstruction.
                G_cost = (1 - DISC_LAMBDA) * rec_x - DISC_LAMBDA * G

                G_cost.backward()

                optimizerG.step()

                # Write logs and save samples
                giterats.append(iteration)
                gcostarr.append(G_cost.item())


            if iteration % 100 == 99:
                # Preview only: no gradients are needed for the reference batch.
                with torch.no_grad():
                    fake = netG(display_data).view(-1,1,28,28)
                if use_cuda:
                    fake = fake.cpu()
                save_image(fake, "images/%d%d.png" % (iteration, increase), nrow=10, normalize=False)

            clear_output(wait=True)
            display('Iteration '+str(iteration)+' , epoch '+str(epoch)+' ,step '+str(increase)+'/7 ,increase '+str(DISC_LAMBDA))
            iteration += 1

        # After every epoch: refresh the loss plots and overwrite this run's checkpoint.
        myplot(diterats, dcostarr, "dcost%d" % (increase))
        myplot(diterats, wdistarr, "wdist%d" % (increase))
        myplot(giterats, gcostarr, "gcost%d" % (increase))
        torch.save({
            'netG_state_dict': netG.state_dict(),
            'netD_state_dict': netD.state_dict(),
            'optimizerG_state_dict': optimizerG.state_dict(),
            'optimizerD_state_dict': optimizerD.state_dict(),
            'epoch': epoch,
            'iteration': iteration,
            }, './models/lamb%depoch50mnist.pt' % (increase))

'Iteration 59999 , epoch 49 ,step 6/7 ,increase 0.015625'