In [1]:
import os
import sys
import math
import random
import argparse
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn.functional as F
import torch.autograd as autograd

# Distribution functions (from *samplers.py*)

In [2]:
def distribution1(x, batch_size=512):
    # Distribution defined as (x, U(0,1)). Can be used for question 3
    while True:
        yield (np.array(
            [(x, random.uniform(0, 1)) for _ in range(batch_size)]))

In [3]:
def distribution2(batch_size=512):
    # High dimension uniform distribution
    while True:
        yield (np.random.uniform(0, 1, (batch_size, 2)))

# Question 1

In [4]:
def JSD_objective(D, x, y):
    return math.log(2) + 0.5 * (torch.mean(
        torch.log(D.forward(x)))) + 0.5 * (torch.mean(torch.log(1 - D.forward(y))))

# Question 2

In [5]:
# Calculates the gradient penalty loss for WGAN GP
def compute_gradient_penalty(D, x, y):
    # Random weight term for interpolation between real and fake samples
    alpha = torch.empty_like(x).uniform_(0,1)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * x + ((1 - alpha) * y)).requires_grad_(True)
    interpolates_forward = D.forward(interpolates)
    # need a fake grad output
    fake = torch.ones_like(interpolates_forward)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=interpolates_forward,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        only_inputs=True)[0]

    return ((torch.norm(gradients, p=2, dim=1) - 1)**2).mean()

In [6]:
def wassertein_estimate(D, x, y):
    return torch.mean(D.forward(x)) - torch.mean(D.forward(y))

In [7]:
def WD_gp_objective(D, x, y):
    lambda_gp = 10
    return torch.mean(D.forward(x)) - torch.mean(D.forward(y)) - lambda_gp * compute_gradient_penalty(D, x, y)

# Question 3

In [8]:
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.map1(x))
        x = self.relu(self.map2(x))
        return self.sigmoid(self.map3(x))

In [9]:
d_input_size = 2
d_hidden_size = 64
d_output_size = 1

In [10]:
def distance_estimate(dist1, dist2, n_epochs=1000, Wasserstein=False):
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    D = Discriminator(
    input_size=d_input_size,
    hidden_size=d_hidden_size,
    output_size=d_output_size).to(device)
    
    #d_optimizer = optim.SGD(D.parameters(), lr=0.001)
    d_optimizer = optim.Adam(D.parameters(), lr=0.001)
    
    objective = WD_gp_objective if Wasserstein else JSD_objective
    
    samplesx, samplesy = [], []
    for i, (d1, d2) in enumerate(zip(dist1, dist2)):
        if i >= n_epochs:
            break
        samplesx.append(d1)
        samplesy.append(d2)

    samplesx = torch.from_numpy(np.asarray(samplesx)).float().to(device)
    samplesy = torch.from_numpy(np.asarray(samplesy)).float().to(device)
    
    for x, y in zip(samplesx, samplesy):
        d_optimizer.zero_grad()
        loss_D = -objective(D, x, y)
        loss_D.backward()
        d_optimizer.step()
        #print(wassertein_estimate(D, x, y).item())
        
    if Wasserstein:
        #return -loss_D.cpu().item()
        return wassertein_estimate(D, samplesx[-1], samplesy[-1]).item()
    else:
        return -loss_D.cpu().item()

In [16]:
print(distance_estimate(distribution1(0), distribution1(1), 1000, True))

0.2082553505897522


In [12]:
n_epochs = 500
JSD, W = [], []
phi = np.arange(-1, 1, 0.1)

for p in phi:
    JSD.append(distance_estimate(distribution1(0), distribution1(p), n_epochs, False))
    W.append(distance_estimate(distribution1(0), distribution1(p), n_epochs, True))

KeyboardInterrupt: 

In [None]:
plt.plot(phi, JSD, label="Jensen-Shannon")
plt.plot(phi, W, label="Wasserstein")
plt.legend(title="Divergence")
plt.xlabel("Phi")
plt.show()