In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
import argparse
import os
import numpy as np
import math
import sys
import random
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch
import matplotlib.pyplot as plt


In [10]:
def distribution2(batch_size=512):
    # High dimension uniform distribution
    while True:
        yield(np.random.uniform(0, 1, (batch_size, 2)))
        

def distribution1(x, batch_size=512):
    # Distribution defined as (x, U(0,1)). Can be used for question 3
    while True:
        yield(np.array([(x, random.uniform(0, 1)) for _ in range(batch_size)]))

In [11]:
def JSD_objective(D,x,y):
    D_x = D.forward(x)
    D_y = D.forward(y)
    return math.log(2) + 0.5 * (torch.mean(torch.log(D_x))) +  0.5 * (torch.mean(torch.log(1 - D_y)))

cuda = True if torch.cuda.is_available() else False

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [12]:
def compute_gradient_penalty(D, x, y):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = Tensor(np.random.random((x.size(0), 1)))
    # Get random interpolation between real and fake samples
    interpolates = (alpha * x + ((1 - alpha) * y)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = Variable(Tensor(x.shape[0], 1).fill_(1.0), requires_grad=False)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [13]:
def WD_gp_objective(D, x , y):
       
    lambda_gp = 10
    D_x = D.forward(x)
    D_y = D.forward(y)
    return torch.mean(D_x) - torch.mean(D_y) - lambda_gp * compute_gradient_penalty(D, x , y)

In [14]:
d_input_size = 2
d_hidden_size = 75   
d_output_size = 1

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.elu = torch.nn.ELU()
    def forward(self, x):
        x = self.elu(self.map1(x))
        x = self.elu(self.map2(x))
        return torch.sigmoid( self.map3(x) )

In [15]:
D = Discriminator(input_size=d_input_size, hidden_size=d_hidden_size, output_size=d_output_size)
d_optimizer = optim.SGD(D.parameters(), lr=0.001 ) 
def distance_estimate(dist1, dist2, n_epochs = 1000, Wasserstein = False):
#    losses = []
    if Wasserstein:
        objective = WD_gp_objective
    else:
        objective = JSD_objective
        
    for epoch in range(n_epochs):
        samplesx = next(dist1)
        samplesy = next(dist2)
        
        x =torch.from_numpy(samplesx).float() #realdata
        y = torch.from_numpy(samplesy).float()
        d_optimizer.zero_grad()
        loss_D = -objective(D, x, y)
        loss_D.backward()
        d_optimizer.step()
        
#        losses.append(-loss_W.item())

    return -loss_D.item()



In [16]:
dist1 = iter(distribution1(0))

n_epochs = 400

y_axis_JSD = []
y_axis_W = []
phi = np.arange(-1,1,0.1)

for i in range(len(phi)):

    dist1 = iter(distribution1(0))
    dist2 = iter(distribution1(phi[i]))

    
    e = distance_estimate(dist1, dist2, n_epochs, False)
    print(e)
    
    y_axis_JSD.append(e)
    
#    y_axis_W.append(distance_estimate(dist1, dist2, n_epochs, True))

plt.plot(phi, y_axis_JSD)

0.0727398693561554
0.14052468538284302
0.1943548619747162
0.23627428710460663
0.26424404978752136
0.2766450047492981
0.2694860100746155
0.24174611270427704
0.18804693222045898
0.10565412044525146
-0.0002110004425048828
-0.10906785726547241
-0.18836838006973267
-0.2187577188014984
-0.20818832516670227
-0.16920223832130432
-0.11783340573310852
-0.0604877769947052
0.0035831034183502197
0.08092963695526123


[<matplotlib.lines.Line2D at 0x7f19e80db588>]