In [0]:
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.autograd import Variable
from torch import  optim
import ot
import ot.gpu
from scipy.stats import sem, t
from sklearn.datasets import make_moons, make_swiss_roll,make_circles
from itertools import combinations
import time

In [0]:
def generate_swiss_roll(num_samples):
    """Sample a 2-D Swiss roll with each axis scaled into [-1, 1].

    Draws a noiseless 3-D Swiss roll with scikit-learn, keeps coordinates 0
    and 2, and divides each kept column by its own maximum absolute value.

    Args:
        num_samples: number of points to draw.

    Returns:
        torch.FloatTensor of shape (num_samples, 2).
    """
    points_3d, _ = make_swiss_roll(n_samples=num_samples, noise=0.0, random_state=None)
    plane = points_3d[:, [0, 2]]
    # Column-wise scaling: each axis is divided by its own max |value|.
    plane = plane / np.abs(plane).max(axis=0)
    return torch.from_numpy(plane).type(torch.FloatTensor)

def generate_moons(num_samples):
    """Sample the scikit-learn two-moons dataset scaled by its global max |value|.

    Args:
        num_samples: number of points to draw.

    Returns:
        torch.FloatTensor of shape (num_samples, 2).
    """
    points, _ = make_moons(n_samples=num_samples, noise=0.0, random_state=None)
    scale = np.abs(points).max()
    return torch.from_numpy(points / scale).type(torch.FloatTensor)
def generate_circle(batch_size):
    """Sample `batch_size` points from one ring of scikit-learn's `make_circles`.

    Draws `2 * batch_size` points and keeps those labelled 0 (the outer circle
    in scikit-learn's implementation). Uses a boolean mask rather than
    `np.squeeze`, so the result is always 2-D, including for `batch_size == 1`.

    Args:
        batch_size: number of points to return.

    Returns:
        torch.FloatTensor of shape (batch_size, 2).
    """
    points, labels = make_circles(2 * batch_size, noise=.01)
    ring = points[labels == 0]
    return torch.from_numpy(ring).type(torch.FloatTensor)
def generate_rectangle(batch_size):
    """Sample points uniformly from the square [-1, 1) x [-1, 1).

    Args:
        batch_size: number of points to draw (uses the global NumPy RNG).

    Returns:
        torch.FloatTensor of shape (batch_size, 2).
    """
    unit = np.random.uniform(size=(batch_size, 2))
    centred = (unit - 0.5) * 2
    return torch.from_numpy(centred).type(torch.FloatTensor)
def generate_8gaussian(num_samples):
    """Sample a mixture of 8 tight Gaussians (variance 1e-4) on a circle of radius sqrt(2).

    Points are generated in consecutive chunks of at most 60. Chunk ``c`` is
    drawn around the mode at angle ``2*pi*(c % 8)/8``, placed at
    ``(sqrt(2)*sin(angle), sqrt(2)*cos(angle))``. Because of this chunking the
    modes are not equally populated unless num_samples is a multiple of 480.

    Args:
        num_samples: number of points to draw (uses the global NumPy RNG).

    Returns:
        torch.FloatTensor of shape (num_samples, 2).
    """
    angles = np.arange(0, 2 * np.pi, 2 * np.pi / 8)
    cov = [[1e-4, 0], [0, 1e-4]]
    radius = 2 ** 0.5
    out = np.zeros([num_samples, 2])
    chunk = 0
    start = 0
    while start < num_samples:
        size = min(60, num_samples - start)
        angle = angles[chunk % 8]
        centre = [radius * np.sin(angle), radius * np.cos(angle)]
        out[start:start + size] = np.random.multivariate_normal(centre, cov, size)
        start += size
        chunk += 1
    return torch.from_numpy(out).type(torch.FloatTensor)
def generate_25gaussian(num_samples):
    """Sample a 5x5 grid mixture of tight Gaussians (variance 1e-4) on [-sqrt(2), sqrt(2)]^2.

    Every one of the `num_samples` rows is drawn from a mode. When
    `num_samples` is not a multiple of 25, the first `num_samples % 25` modes
    (in x-major grid order) get one extra point. Without this, the leftover
    rows would remain exactly (0, 0). For multiples of 25 the draws are the
    same as before.

    Args:
        num_samples: number of points to draw (uses the global NumPy RNG).

    Returns:
        torch.FloatTensor of shape (num_samples, 2).
    """
    grid = np.linspace(-2 ** 0.5, 2 ** 0.5, 5)
    cov = [[1e-4, 0], [0, 1e-4]]
    base, extra = divmod(num_samples, 25)
    samples = np.zeros([num_samples, 2])
    start = 0
    index = 0
    for cx in grid:
        for cy in grid:
            count = base + (1 if index < extra else 0)
            if count:
                samples[start:start + count] = np.random.multivariate_normal([cx, cy], cov, count)
            start += count
            index += 1
    return torch.from_numpy(samples).type(torch.FloatTensor)
def generate_knot(num_samples):
    """Sample the curve r(l) = sin(pi*l) at polar angle l, for l on an even grid over [0, 2*pi).

    Uses `np.linspace(..., endpoint=False)` instead of a float-step
    `np.arange`. The latter can return `num_samples + 1` points through
    rounding, which broke the reshape.

    Args:
        num_samples: number of points.

    Returns:
        torch.FloatTensor of shape (num_samples, 2).
    """
    angles = np.linspace(0, 2 * np.pi, num_samples, endpoint=False)
    radius = np.sin(np.pi * angles)
    points = np.stack((radius * np.cos(angles), radius * np.sin(angles)), axis=1)
    return torch.from_numpy(points).type(torch.FloatTensor)
def generate_heart(num_samples):
    """Sample the parametric heart curve at `num_samples` evenly spaced angles in [0, 2*pi).

    x = 16 sin^3 t,  y = 13 cos t - 5 cos 2t - 2 cos 3t - cos 4t, each column
    divided by its max |value|. A zero column (e.g. `num_samples == 1`) is
    left unscaled instead of producing NaN. `linspace` guarantees exactly
    `num_samples` angles.

    Args:
        num_samples: number of points.

    Returns:
        torch.FloatTensor of shape (num_samples, 2).
    """
    angles = np.linspace(0, 2 * np.pi, num_samples, endpoint=False)
    x = 16 * np.sin(angles) ** 3
    y = 13 * np.cos(angles) - 5 * np.cos(2 * angles) - 2 * np.cos(3 * angles) - np.cos(4 * angles)
    x_scale = np.abs(x).max() or 1.0
    y_scale = np.abs(y).max() or 1.0
    samples = np.stack((x / x_scale, y / y_scale), axis=1)
    return torch.from_numpy(samples).type(torch.FloatTensor)
def rand_projections(dim, num_projections=1000):
    """Draw random unit directions in R^dim (normalised standard-normal rows).

    Args:
        dim: ambient dimension.
        num_projections: number of directions.

    Returns:
        CPU tensor of shape (num_projections, dim) whose rows have unit L2 norm.
    """
    raw = torch.randn((num_projections, dim))
    row_norms = raw.pow(2).sum(dim=1, keepdim=True).sqrt()
    return raw / row_norms
def sliced_wasserstein_distance(first_samples,
                                second_samples,
                                num_projection=1000,
                                p=2,
                                device='cuda'):
    """Monte-Carlo sliced p-Wasserstein distance between two point clouds.

    Projects both clouds onto `num_projection` random unit directions, sorts
    each 1-D projection, and compares the sorted values position by position.
    The result is ``(mean_l sum_i |a_l,(i) - b_l,(i)|^p)^(1/p)``; note the sum
    (not mean) over sorted positions. Assumes both clouds have the same number
    of rows.

    Args:
        first_samples: (N, d) tensor; gradients flow back to it.
        second_samples: (N, d) tensor.
        num_projection: number of random directions.
        p: order of the distance.
        device: device the projections are moved to (must match the samples).

    Returns:
        0-dim tensor.
    """
    dim = second_samples.size(1)
    # Use the parameter, not the notebook-global `num_projections`.
    projections = rand_projections(dim, num_projection).to(device)
    first_sorted = torch.sort(first_samples.matmul(projections.t()).t(), dim=1)[0]
    second_sorted = torch.sort(second_samples.matmul(projections.t()).t(), dim=1)[0]
    per_slice = torch.sum(torch.pow(torch.abs(first_sorted - second_sorted), p), dim=1)
    return torch.pow(per_slice.mean(), 1. / p)
def max_sliced_wasserstein_distance(first_samples,
                                second_samples,
                                num_projection=1000,
                                p=2,max_iter=10,
                                device='cuda'):
    """Max-sliced p-Wasserstein distance using one direction optimised by gradient ascent.

    A single random unit direction is refined for `max_iter` Adam steps to
    maximise the projected distance between the detached samples. The
    direction is renormalised after each step. The distance is then
    re-evaluated on the non-detached samples with the direction frozen, so
    gradients reach `first_samples` but not the direction.

    Args:
        first_samples, second_samples: (N, d) tensors with equal row counts.
        num_projection: ignored (always one direction); kept for signature parity.
        p: order of the distance.
        max_iter: number of ascent steps on the direction.
        device: device for the direction tensor.

    Returns:
        0-dim tensor.
    """
    dim = second_samples.size(1)
    first_fixed = first_samples.detach()
    second_fixed = second_samples.detach()
    direction = rand_projections(dim, 1).to(device)
    direction.requires_grad_()
    optimizer = optim.Adam([direction], lr=0.005, betas=(0.999, 0.999))

    def _projected_distance(first, second, theta):
        # Sort the 1-D projections and take the p-th root of the mean summed |gap|^p.
        first_sorted = torch.sort(first.matmul(theta.t()).t(), dim=1)[0]
        second_sorted = torch.sort(second.matmul(theta.t()).t(), dim=1)[0]
        per_slice = torch.sum(torch.pow(torch.abs(first_sorted - second_sorted), p), dim=1)
        return torch.pow(per_slice.mean(), 1. / p)

    for _ in range(max_iter):
        optimizer.zero_grad()
        # A fresh graph is built every iteration, so retain_graph is not needed.
        (-_projected_distance(first_fixed, second_fixed, direction)).backward()
        optimizer.step()
        with torch.no_grad():
            direction /= torch.sqrt(torch.sum(direction ** 2, dim=1, keepdim=True))
    return _projected_distance(first_samples, second_samples, direction.detach())
def poly_degree(degree,dim):
    """Exponent table of every monomial of total degree `degree` in `dim` variables.

    Uses stars and bars: each choice of `dim - 1` bar positions among
    1..degree+dim-1, padded with 0 on the left and `degree + dim` on the
    right, gives one exponent vector. Each exponent is the gap between
    consecutive bars minus one. Rows follow `itertools.combinations` order.

    Returns:
        torch.FloatTensor of shape (C(degree+dim-1, dim-1), dim).
    """
    rows = []
    for bars in combinations(np.arange(1, degree + dim), dim - 1):
        bounds = [0] + list(bars) + [degree + dim]
        rows.append([bounds[j + 1] - bounds[j] - 1 for j in range(dim)])
    table = np.asarray(rows, dtype=np.float64).reshape(len(rows), dim)
    return torch.from_numpy(table).type(torch.FloatTensor)
def polynomial_function(samples, parameters, device='cuda'):
    """Evaluate polynomial projections of `samples`.

    Args:
        samples: (N, d) tensor.
        parameters: ``[theta, degree_matrix]``. `degree_matrix` is (M, d), one
            exponent row per monomial (e.g. from `poly_degree`). `theta` is
            (M, K), the monomial coefficients of each of the K projections.
        device: unused; kept for backward compatibility. All work happens on
            the samples' device. The former unused buffer allocated here
            crashed on CPU-only machines.

    Returns:
        (N, K) tensor whose column k is
        ``sum_m theta[m, k] * prod_j samples[:, j] ** degree_matrix[m, j]``.
    """
    theta, degree_matrix = parameters[0], parameters[1]
    # (N, 1, d) ** (1, M, d) -> (N, M, d): each sample raised to each monomial's exponents.
    powers = samples.unsqueeze(1) ** degree_matrix.unsqueeze(0)
    monomials = torch.prod(powers, dim=-1)  # (N, M)
    return torch.matmul(monomials, theta)

def max_GSWD_polynomial_3(encoded_samples,distribution_samples,p=2,num_projection=1000,degree=3 ,device='cuda',max_iter=10):
    """Max generalized sliced p-Wasserstein distance with a polynomial defining function.

    One coefficient vector over all monomials of total degree `degree`
    (default 3) is optimised for `max_iter` Adam steps. The steps maximise the
    distance between the detached samples' projections, and the vector is
    renormalised after each step. The distance is then evaluated on the
    non-detached samples with the coefficients frozen.

    Args:
        encoded_samples, distribution_samples: (N, d) tensors with equal row counts.
        p: order of the distance.
        num_projection: ignored (always one projection); kept for signature parity.
        degree: total degree of the polynomial.
        device: device for the coefficients; also forwarded to `polynomial_function`.
        max_iter: number of ascent steps.

    Returns:
        0-dim tensor.
    """
    num_projection = 1
    exponents = poly_degree(degree, encoded_samples.shape[-1]).to(device)
    coefficients = torch.randn((exponents.shape[0], num_projection), device=device)
    coefficients = coefficients / torch.sqrt(torch.sum(coefficients ** 2, dim=0, keepdim=True))
    coefficients.requires_grad_()
    optimizer = optim.Adam([coefficients], lr=0.005, betas=(0.999, 0.999))

    def _distance(first, second, coeffs):
        first_proj = polynomial_function(first, [coeffs, exponents], device=device)
        second_proj = polynomial_function(second, [coeffs, exponents], device=device)
        gaps = torch.abs(torch.sort(first_proj.t(), dim=1)[0] - torch.sort(second_proj.t(), dim=1)[0])
        return torch.pow(torch.mean(torch.sum(torch.pow(gaps, p), dim=-1)), 1. / p)

    encoded_fixed = encoded_samples.detach()
    distribution_fixed = distribution_samples.detach()
    for _ in range(max_iter):
        optimizer.zero_grad()
        (-_distance(encoded_fixed, distribution_fixed, coefficients)).backward()
        optimizer.step()
        with torch.no_grad():
            coefficients /= torch.sqrt(torch.sum(coefficients ** 2, dim=0, keepdim=True))
    return _distance(encoded_samples, distribution_samples, coefficients.detach())

def max_GSWD_polynomial_5(encoded_samples,distribution_samples,p=2,num_projection=1000,degree=5 ,device='cuda',max_iter=10):
    """Max generalized sliced p-Wasserstein distance with a polynomial defining function.

    Same procedure as the degree-3 variant, with a default degree of 5. One
    coefficient vector over all monomials of total degree `degree` is
    optimised for `max_iter` Adam steps on the detached samples and
    renormalised after each step. The distance is then evaluated on the
    non-detached samples with the coefficients frozen.

    Args:
        encoded_samples, distribution_samples: (N, d) tensors with equal row counts.
        p: order of the distance.
        num_projection: ignored (always one projection); kept for signature parity.
        degree: total degree of the polynomial.
        device: device for the coefficients; also forwarded to `polynomial_function`.
        max_iter: number of ascent steps.

    Returns:
        0-dim tensor.
    """
    num_projection = 1
    exponents = poly_degree(degree, encoded_samples.shape[-1]).to(device)
    coefficients = torch.randn((exponents.shape[0], num_projection), device=device)
    coefficients = coefficients / torch.sqrt(torch.sum(coefficients ** 2, dim=0, keepdim=True))
    coefficients.requires_grad_()
    optimizer = optim.Adam([coefficients], lr=0.005, betas=(0.999, 0.999))

    def _distance(first, second, coeffs):
        first_proj = polynomial_function(first, [coeffs, exponents], device=device)
        second_proj = polynomial_function(second, [coeffs, exponents], device=device)
        gaps = torch.abs(torch.sort(first_proj.t(), dim=1)[0] - torch.sort(second_proj.t(), dim=1)[0])
        return torch.pow(torch.mean(torch.sum(torch.pow(gaps, p), dim=-1)), 1. / p)

    encoded_fixed = encoded_samples.detach()
    distribution_fixed = distribution_samples.detach()
    for _ in range(max_iter):
        optimizer.zero_grad()
        (-_distance(encoded_fixed, distribution_fixed, coefficients)).backward()
        optimizer.step()
        with torch.no_grad():
            coefficients /= torch.sqrt(torch.sum(coefficients ** 2, dim=0, keepdim=True))
    return _distance(encoded_samples, distribution_samples, coefficients.detach())

def GSWD_polynomial(encoded_samples,distribution_samples,p=2,num_projection=1000,degree=5 ,device='cuda'):
    """Generalized sliced p-Wasserstein distance with random polynomial defining functions.

    Draws `num_projection` random coefficient vectors over all monomials of
    total degree `degree`, each normalised to unit L2 norm. Both sample sets
    are projected through the resulting polynomials and the sorted 1-D
    projections are compared.

    Args:
        encoded_samples, distribution_samples: (N, d) tensors with equal row counts.
        p: order of the distance.
        num_projection: number of random polynomials.
        degree: total degree of the polynomials.
        device: device for the coefficients; also forwarded to `polynomial_function`.

    Returns:
        0-dim tensor ``(mean_l sum_i |gap|^p)^(1/p)``.
    """
    exponents = poly_degree(degree, encoded_samples.shape[-1]).to(device)
    # Coefficients are fixed random directions; no gradient is needed for them.
    coefficients = torch.randn((exponents.shape[0], num_projection), device=device)
    coefficients = coefficients / torch.sqrt(torch.sum(coefficients ** 2, dim=0, keepdim=True))
    encoded_proj = polynomial_function(encoded_samples, [coefficients, exponents], device=device)
    distribution_proj = polynomial_function(distribution_samples, [coefficients, exponents], device=device)
    gaps = torch.abs(torch.sort(encoded_proj.t(), dim=1)[0] - torch.sort(distribution_proj.t(), dim=1)[0])
    return torch.pow(torch.mean(torch.sum(torch.pow(gaps, p), dim=-1)), 1. / p)

def GSWD_circular(encoded_samples,distribution_samples,p=2,num_projection=1000,r=1,device='cuda'):
    """Generalized sliced p-Wasserstein distance with the circular defining function.

    Each random unit direction theta_l defines the slice g_l(x) = ||x - r*theta_l||_2.
    Samples are projected to their Euclidean distance from the point
    r*theta_l, and the sorted projections are compared.

    Args:
        encoded_samples, distribution_samples: (N, d) tensors with equal row counts.
        p: order of the distance.
        num_projection: number of random directions.
        r: radius scaling the directions.
        device: device for the directions.

    Returns:
        0-dim tensor.
    """
    directions = torch.randn((num_projection, encoded_samples.shape[1])).to(device)
    directions = directions / torch.sqrt(torch.sum(directions ** 2, dim=1, keepdim=True))
    centres = directions * r
    slices_first = torch.sqrt(cost_matrix(encoded_samples, centres)).t()
    slices_second = torch.sqrt(cost_matrix(distribution_samples, centres)).t()
    gaps = torch.abs(torch.sort(slices_first, dim=1)[0] - torch.sort(slices_second, dim=1)[0])
    per_slice = torch.sum(torch.pow(gaps, p), dim=1)
    return torch.pow(per_slice.mean(), 1. / p)

def max_GSWD_circular(encoded_samples,distribution_samples,p=2,num_projection=1,r=1,device='cuda',max_iter=10):
    """Max generalized sliced p-Wasserstein distance with the circular defining function.

    A single direction theta (slice g(x) = ||x - r*theta||_2) is optimised for
    `max_iter` Adam steps to maximise the distance between the detached
    samples, and renormalised after each step. The distance is then evaluated
    on the non-detached samples with theta frozen.

    Args:
        encoded_samples, distribution_samples: (N, d) tensors with equal row counts.
        p: order of the distance.
        num_projection: ignored (always one direction); kept for signature parity.
        r: radius scaling the direction.
        device: device for the direction.
        max_iter: number of ascent steps.

    Returns:
        0-dim tensor.
    """
    num_projection = 1
    direction = torch.randn((num_projection, encoded_samples.shape[1])).to(device)
    direction = direction / torch.sqrt(torch.sum(direction ** 2, dim=1, keepdim=True))
    direction.requires_grad_()
    optimizer = optim.Adam([direction], lr=0.005, betas=(0.999, 0.999))

    def _distance(first, second, theta):
        slices_first = torch.sqrt(cost_matrix(first, theta * r)).t()
        slices_second = torch.sqrt(cost_matrix(second, theta * r)).t()
        gaps = torch.abs(torch.sort(slices_first, dim=1)[0] - torch.sort(slices_second, dim=1)[0])
        return torch.pow(torch.sum(torch.pow(gaps, p), dim=1).mean(), 1. / p)

    encoded_fixed = encoded_samples.detach()
    distribution_fixed = distribution_samples.detach()
    for _ in range(max_iter):
        optimizer.zero_grad()
        (-_distance(encoded_fixed, distribution_fixed, direction)).backward()
        optimizer.step()
        with torch.no_grad():
            direction /= torch.sqrt(torch.sum(direction ** 2, dim=1, keepdim=True))
    return _distance(encoded_samples, distribution_samples, direction.detach())

def circular_function(samples, parameters):
    """Circular defining function: Euclidean distance from each sample to each point radial*theta_l.

    Args:
        samples: (N, d) tensor.
        parameters: ``[radial, theta]``, a radius and an (L, d) tensor of directions.

    Returns:
        (N, L) tensor of distances.
    """
    radius = parameters[0]
    directions = parameters[1]
    squared = cost_matrix(samples, directions * radius)
    return squared.sqrt()

def cost_matrix(encoded_samples, distribution_samples, p=2):
    """Pairwise squared Euclidean distances between two point sets.

    Args:
        encoded_samples: (N, d) tensor.
        distribution_samples: (M, d) tensor.
        p: accepted but ignored; the cost is always squared L2.

    Returns:
        (N, M) tensor with entry [i, j] = ||x_i - y_j||^2.
    """
    rows = encoded_samples.size(0)
    width = encoded_samples.size(1)
    lhs = encoded_samples.reshape(rows, width, 1)      # (N, d, 1)
    rhs = distribution_samples.t().unsqueeze(0)        # (1, d, M)
    return (lhs - rhs).pow(2).sum(1)

def GSWD_polynomial3(encoded_samples,distribution_samples,p=2,num_projection=1000,device='cuda'):
    """Generalized sliced p-Wasserstein distance with homogeneous cubic features (2-D input only).

    Each sample (x, y) is lifted to [y^3, x*y^2, x^2*y, x^3]. The lifted
    features are projected onto `num_projection` random unit directions in
    R^4, and the sorted projections are compared.

    Args:
        encoded_samples, distribution_samples: (N, 2) tensors with equal row counts.
        p: order of the distance.
        num_projection: number of random directions.
        device: device for the directions and lifted features.

    Returns:
        0-dim tensor.
    """
    directions = torch.randn((num_projection, 4)).to(device)
    directions = directions / torch.sqrt(torch.sum(directions ** 2, dim=1, keepdim=True))

    def _cubic_features(points):
        x, y = points[:, 0], points[:, 1]
        lifted = torch.stack((y ** 3, x * y ** 2, x ** 2 * y, x ** 3), dim=1)
        # Match a default-dtype buffer on `device`, as the features are consumed there.
        return lifted.to(device=device, dtype=torch.get_default_dtype())

    proj_first = _cubic_features(encoded_samples).matmul(directions.t())
    proj_second = _cubic_features(distribution_samples).matmul(directions.t())
    gaps = torch.abs(torch.sort(proj_first.t(), dim=1)[0] - torch.sort(proj_second.t(), dim=1)[0])
    per_slice = torch.sum(torch.pow(gaps, p), dim=-1)
    return torch.pow(torch.mean(per_slice), 1. / p)

from torch import nn
class Mapping(nn.Module):
    """Augmentation map x -> [x, W x + b] used by the augmented sliced Wasserstein distance.

    The input is concatenated with a single linear layer applied to it, so a
    (..., size) input yields a (..., 2 * size) output.
    """

    def __init__(self, size):
        super().__init__()
        self.size = size
        self.net = nn.Sequential(nn.Linear(self.size, self.size))

    def forward(self, inputs):
        transformed = self.net(inputs)
        return torch.cat([inputs, transformed], dim=-1)

def augmented_sliced_wassersten_distance(first_samples,second_samples,num_projections,phi,
                                                       phi_op,p=2,max_iter=10,lam=20,device='cuda',net_type='fc'):
    """Augmented sliced p-Wasserstein distance (ASWD).

    Runs two stages:
    1. Train the map `phi` for `max_iter` steps of `phi_op` on detached
       samples. Each step minimises
       ``lam * mean(||phi(x)|| + ||phi(y)||) - SW_p(phi(x), phi(y))``, so
       `phi` is pushed to separate the sets while keeping embedded norms small.
    2. Return the sliced p-Wasserstein distance between `phi(first_samples)`
       and `phi(second_samples)` under fresh random projections. The samples
       are not detached here, so gradients flow back to `first_samples`.

    Args:
        first_samples, second_samples: (N, d) tensors with equal row counts.
        num_projections: number of random directions per evaluation.
        phi: embedding module (e.g. `Mapping`).
        phi_op: optimiser over `phi`'s parameters.
        p: order of the distance.
        max_iter: number of training steps for `phi`.
        lam: weight of the embedded-norm penalty (float or 0-dim tensor).
        device: device for the random projections.
        net_type: unused; kept for signature compatibility.

    Returns:
        0-dim tensor.
    """
    first_fixed = first_samples.detach()
    second_fixed = second_samples.detach()

    def _per_slice(first_embedded, second_embedded):
        # Sum over sorted positions of |gap|^p, one value per random direction.
        projections = rand_projections(first_embedded.shape[-1], num_projections).to(device)
        first_sorted = torch.sort(first_embedded.matmul(projections.t()).t(), dim=1)[0]
        second_sorted = torch.sort(second_embedded.matmul(projections.t()).t(), dim=1)[0]
        return torch.sum(torch.pow(torch.abs(first_sorted - second_sorted), p), dim=1)

    for _ in range(max_iter):
        first_embedded = phi(first_fixed)
        second_embedded = phi(second_fixed)
        reg = lam * (torch.norm(first_embedded, p=2, dim=1) + torch.norm(second_embedded, p=2, dim=1)).mean()
        # Training-time distances are rescaled as if the batch had 512 points.
        # NOTE(review): 512 is presumably a reference batch size; confirm.
        per_slice = _per_slice(first_embedded, second_embedded) * 512 / first_fixed.shape[0]
        distance = torch.pow(per_slice.mean(), 1. / p)
        loss = reg - distance
        phi_op.zero_grad()
        # The graph is rebuilt every iteration, so retain_graph is not needed.
        loss.backward()
        phi_op.step()
    per_slice = _per_slice(phi(first_samples), phi(second_samples))
    return torch.pow(per_slice.mean(), 1. / p)

class TransformNet(nn.Module):
    """Linear layer followed by row-wise L2 normalisation.

    Maps input rows (e.g. random directions) to unit-norm rows; used as the
    learned direction generator for the distributional sliced distance.
    """

    def __init__(self, size):
        super().__init__()
        self.size = size
        self.net = nn.Sequential(nn.Linear(self.size, self.size))

    def forward(self, input):
        mapped = self.net(input)
        norms = mapped.pow(2).sum(dim=1, keepdim=True).sqrt()
        return mapped / norms

def cosine_distance_torch(x1, x2=None, eps=1e-8):
    """Mean absolute cosine similarity over all row pairs of `x1` and `x2`.

    Despite the name, larger values mean more aligned rows. The norm product
    in each denominator is clamped below by `eps`.

    Args:
        x1: (N, d) tensor.
        x2: (M, d) tensor; defaults to `x1`.
        eps: lower bound on the denominator.

    Returns:
        0-dim tensor.
    """
    other = x1 if x2 is None else x2
    norm_a = x1.norm(p=2, dim=1, keepdim=True)
    norm_b = norm_a if other is x1 else other.norm(p=2, dim=1, keepdim=True)
    denominator = (norm_a * norm_b.t()).clamp(min=eps)
    similarity = torch.mm(x1, other.t()) / denominator
    return similarity.abs().mean()

def distributional_sliced_wasserstein_distance(first_samples, second_samples, num_projections, f, f_op,
                                               p=2, max_iter=10, lam=1, device='cuda'):
    """Distributional sliced p-Wasserstein distance (DSW).

    Runs two stages:
    1. Train the direction generator `f` for `max_iter` steps of `f_op`. Each
       step minimises ``lam * meanAbsCos(dirs, dirs) - SW_p(dirs)`` on
       detached samples, where ``dirs = f(random unit directions)``. This
       favours directions that separate the sets while not collapsing onto
       each other.
    2. Return the sliced p-Wasserstein distance of the non-detached samples
       along fresh directions ``f(random)``.

    Args:
        first_samples, second_samples: (N, d) tensors with equal row counts.
        num_projections: number of directions per evaluation.
        f: module mapping (L, d) directions to (L, d) directions (e.g. `TransformNet`).
        f_op: optimiser over `f`'s parameters.
        p: order of the distance.
        max_iter: number of training steps for `f`.
        lam: weight of the direction-similarity penalty.
        device: device for the random directions.

    Returns:
        0-dim tensor.
    """
    embedding_dim = first_samples.size(1)
    first_fixed = first_samples.detach()
    second_fixed = second_samples.detach()

    def _distance(first, second, projections):
        first_sorted = torch.sort(first.matmul(projections.t()).t(), dim=1)[0]
        second_sorted = torch.sort(second.matmul(projections.t()).t(), dim=1)[0]
        per_slice = torch.sum(torch.pow(torch.abs(first_sorted - second_sorted), p), dim=1)
        return torch.pow(per_slice.mean(), 1. / p)

    for _ in range(max_iter):
        projections = f(rand_projections(embedding_dim, num_projections).to(device))
        reg = lam * cosine_distance_torch(projections, projections)
        loss = reg - _distance(first_fixed, second_fixed, projections)
        f_op.zero_grad()
        # The graph is rebuilt every iteration, so retain_graph is not needed.
        loss.backward()
        f_op.step()
    projections = f(rand_projections(embedding_dim, num_projections).to(device))
    return _distance(first_samples, second_samples, projections)
class MLP(nn.Module):
    """Fully connected network used as a learned defining function for GSW.

    Architecture: `depth` blocks of ``Linear -> LeakyReLU`` with `num_filters`
    hidden units, then a final ``Linear`` to `dout` outputs. Submodules are
    named ``linear01, activation01, ..., linear{depth+1:02d}``. With
    ``depth == 0`` the network is a single ``Linear(din, dout)``.
    """

    def __init__(self, din=2,dout=10, num_filters=32, depth=3):
        super(MLP, self).__init__()
        self.din=din
        self.dout=dout
        self.init_num_filters = num_filters
        self.depth=depth

        self.features = nn.Sequential()
        in_features = self.din
        for i in range(self.depth):
            self.features.add_module('linear%02d' % (i + 1), nn.Linear(in_features, self.init_num_filters))
            self.features.add_module('activation%02d' % (i + 1), nn.LeakyReLU(inplace=True))
            in_features = self.init_num_filters
        # Named from `depth`, not the loop variable, so depth == 0 is valid too.
        self.features.add_module('linear%02d' % (self.depth + 1), nn.Linear(in_features, self.dout))

    def forward(self, x):
        """Map (N, din) inputs to (N, dout) outputs."""
        return self.features(x)

    def init_weights(self,m):
        """Xavier-uniform weights and 0.01 biases for `nn.Linear` modules; others untouched."""
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def reset(self):
        """Re-initialise every linear layer with `init_weights`."""
        self.features.apply(self.init_weights)

def gsw_nn_1(first_samples, second_samples, net,net_op,max_iter=10,p=2):
    """Generalized sliced p-Wasserstein distance with a fixed neural defining function.

    Each output column of `net` is one slice. `net_op` and `max_iter` are
    unused; they keep the signature aligned with `max_gsw_nn_1`.

    Returns:
        0-dim tensor ``(mean_k sum_i |gap|^p)^(1/p)``.
    """
    sorted_first, _ = torch.sort(net(first_samples).t(), dim=1)
    sorted_second, _ = torch.sort(net(second_samples).t(), dim=1)
    per_slice = (sorted_first - sorted_second).abs().pow(p).sum(dim=1)
    return per_slice.mean().pow(1. / p)

def gsw_nn_3(first_samples, second_samples, net,net_op,max_iter=10,p=2):
    """Generalized sliced p-Wasserstein distance with a fixed neural defining function.

    Same computation as `gsw_nn_1`; the experiment pairs it with a depth-3
    MLP. Each output column of `net` is one slice. `net_op` and `max_iter`
    are unused.

    Returns:
        0-dim tensor ``(mean_k sum_i |gap|^p)^(1/p)``.
    """
    sorted_first, _ = torch.sort(net(first_samples).t(), dim=1)
    sorted_second, _ = torch.sort(net(second_samples).t(), dim=1)
    per_slice = (sorted_first - sorted_second).abs().pow(p).sum(dim=1)
    return per_slice.mean().pow(1. / p)

def max_gsw_nn_1(first_samples, second_samples, net, net_op, max_iter=10,p=2):
    """Max generalized sliced p-Wasserstein distance with a trainable neural defining function.

    First takes `max_iter` steps of `net_op`, each maximising the distance
    between `net(first)` and `net(second)` on detached samples. Then returns
    the distance on the non-detached samples.

    Returns:
        0-dim tensor.
    """
    def _distance(first, second):
        sorted_first = torch.sort(net(first).t(), dim=1)[0]
        sorted_second = torch.sort(net(second).t(), dim=1)[0]
        per_slice = torch.sum(torch.pow(torch.abs(sorted_first - sorted_second), p), dim=1)
        return torch.pow(per_slice.mean(), 1. / p)

    for _ in range(max_iter):
        objective = -_distance(first_samples.detach(), second_samples.detach())
        net_op.zero_grad()
        objective.backward(retain_graph=True)
        net_op.step()
    return _distance(first_samples, second_samples)

def max_gsw_nn_3(first_samples, second_samples, net, net_op, max_iter=10,p=2):
    """Max generalized sliced p-Wasserstein distance with a trainable neural defining function.

    Same procedure as `max_gsw_nn_1`; the experiment pairs it with a depth-3
    MLP. It takes `max_iter` ascent steps of `net_op` on detached samples,
    then returns the distance on the non-detached samples.

    Returns:
        0-dim tensor.
    """
    def _distance(first, second):
        sorted_first = torch.sort(net(first).t(), dim=1)[0]
        sorted_second = torch.sort(net(second).t(), dim=1)[0]
        per_slice = torch.sum(torch.pow(torch.abs(sorted_first - sorted_second), p), dim=1)
        return torch.pow(per_slice.mean(), 1. / p)

    for _ in range(max_iter):
        objective = -_distance(first_samples.detach(), second_samples.detach())
        net_op.zero_grad()
        objective.backward(retain_graph=True)
        net_op.step()
    return _distance(first_samples, second_samples)


In [0]:
# Experiment: push 500 standard-normal particles onto each toy target by minimising
# each sliced distance with Adam, logging the exact W2 distance to the target
# (computed with POT) after the first step and every `interval` steps.
n = 1
num_projections = 1000
num_iteration = 2000
num_experiments = 50
interval = 5
lr = 0.002
device = 'cuda'
num_target_samples = 500
functions = [augmented_sliced_wassersten_distance, sliced_wasserstein_distance, GSWD_polynomial3, GSWD_polynomial, GSWD_circular,
             distributional_sliced_wasserstein_distance, gsw_nn_1, max_gsw_nn_1, gsw_nn_3, max_gsw_nn_3]
generators = [generate_swiss_roll, generate_circle, generate_8gaussian, generate_moons, generate_knot, generate_heart, generate_rectangle, generate_25gaussian]

dataset = ['Swiss', 'Circle', '8gaussian', 'Moon', 'Knot', 'Heart', 'Rectangle', '25gaussian']


def _build_loss(function, target_distribution):
    """Return ``evolving -> loss`` for `function`, creating any per-run auxiliary net/optimiser."""
    if function is augmented_sliced_wassersten_distance:
        lam = 0.05 / target_distribution.abs().mean()
        phi = Mapping(2).to(device)
        phi_op = optim.Adam(phi.parameters(), lr=0.005, betas=(0.999, 0.999))
        return lambda evolving: function(evolving, target_distribution, num_projections, phi, phi_op,
                                         p=2, max_iter=10, lam=lam, device=device, net_type='fc')
    if function is distributional_sliced_wasserstein_distance:
        transform_net = TransformNet(2).to(device)
        op_trannet = optim.Adam(transform_net.parameters(), lr=0.005, betas=(0.999, 0.999))
        return lambda evolving: function(evolving, target_distribution, num_projections, transform_net,
                                         op_trannet, p=2, max_iter=10, lam=10, device=device)
    if function in (gsw_nn_1, max_gsw_nn_1, gsw_nn_3, max_gsw_nn_3):
        depth = 1 if function in (gsw_nn_1, max_gsw_nn_1) else 3
        net = MLP(din=2, dout=num_projections, num_filters=32, depth=depth).to(device)
        net_op = optim.Adam(net.parameters(), lr=0.005, betas=(0.999, 0.999))
        return lambda evolving: function(evolving, target_distribution, net, net_op, max_iter=10)
    return lambda evolving: function(evolving, target_distribution, num_projection=num_projections, device=device)


def _exact_w2(evolving_distribution, target_distribution):
    """Exact 2-Wasserstein distance between two uniform empirical measures via `ot.emd2`."""
    diff = evolving_distribution.detach().unsqueeze(2) - target_distribution.transpose(1, 0).unsqueeze(0)
    cost = (diff ** 2).sum(1).cpu().numpy()
    weights_evolving = np.ones((evolving_distribution.shape[0],)) / evolving_distribution.shape[0]
    weights_target = np.ones((target_distribution.shape[0],)) / target_distribution.shape[0]
    return ot.emd2(weights_evolving, weights_target, cost) ** 0.5


def _run_experiment(function, generator, recorder_row):
    """Run one particle flow; fill `recorder_row` in place and return (final particles, last W2).

    recorder_row[0] holds W2 after the first step; recorder_row[s] holds W2
    after step s * interval.
    """
    target_distribution = generator(num_target_samples).to(device)
    # randn_like allocates on the target's device, so this is an optimisable leaf.
    evolving_distribution = torch.randn_like(target_distribution).requires_grad_()
    optimizer = optim.Adam([evolving_distribution], lr=lr, betas=(0.9, 0.999))
    loss_fn = _build_loss(function, target_distribution)
    W2 = None
    for i in range(num_iteration):
        optimizer.zero_grad()
        loss = loss_fn(evolving_distribution)
        loss.backward()
        optimizer.step()
        is_first = i == 0
        is_logged = (i + 1) % interval == 0
        if is_first or is_logged:
            W2 = _exact_w2(evolving_distribution, target_distribution)
            if is_first:
                recorder_row[0] = W2
            if is_logged:
                recorder_row[(i + 1) // interval] = W2
    return evolving_distribution.detach().cpu().numpy(), W2


# W2 curves per dataset, each of shape (len(functions), num_experiments, num_iteration // interval + 1).
W2_results = {}
for d, generator in enumerate(generators):
    W2_recorder = np.zeros([len(functions), num_experiments, num_iteration // interval + 1])
    for k, function in enumerate(functions):
        for j in range(num_experiments):
            final_particles, W2 = _run_experiment(function, generator, W2_recorder[k, j])
            if j % 10 == 0:
                fig, ax = plt.subplots()
                ax.scatter(final_particles[:, 0], final_particles[:, 1])
                ax.set_title('%s: %s, run %d' % (dataset[d], function.__name__, j))
                plt.show()
                plt.close(fig)
            print(function.__name__, j, W2)
    # Store per dataset so later cells can compare all curves.
    W2_results[dataset[d]] = W2_recorder