This Jupyter notebook implements the "Uniform distribution on sphere" example in Appendix D of the paper "Neural Monge Map estimation and its applications". The code runs on Google Colab.

In [None]:
############################
# import necessary modules #
############################


import os
import argparse

import math
import random
import scipy
import numpy as np
from numpy import linalg

import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import matplotlib as mpl
from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from numpy import *



In [None]:
##################################
#            Samplers            #
##################################


def Sampler_a(n, dim):
    """Draw ``n`` samples of the source distribution rho_a.

    rho_a is uniform on [0, 2*pi) x [0, pi/4] in the spherical coordinates
    (theta, phi) of the unit sphere (rho = 1). Column 0 holds theta, column 1
    holds phi, and any further columns are left at zero. Note that being
    uniform in (theta, phi) is not the same as being uniform with respect to
    surface area on the sphere.

    Returns:
        torch.Tensor of shape (n, dim).
    """
    uniform = torch.rand(n, dim)
    # Per-column scale applied to U[0, 1) draws: (theta range, phi range).
    scales = (2 * np.pi, np.pi / 4)
    result = torch.zeros(n, dim)
    for col, span in enumerate(scales):
        result[:, col] = span * uniform[:, col]
    return result


########### Generate n samples from target distribution rho_b ######
def Sampler_b(n, dim):
    """Draw ``n`` samples of the target distribution rho_b.

    rho_b is uniform on [0, 2*pi) x [3*pi/4, pi] in the spherical coordinates
    (theta, phi) of the unit sphere (rho = 1). Column 0 holds theta, column 1
    holds phi, and any further columns are left at zero.

    Returns:
        torch.Tensor of shape (n, dim).
    """
    uniform = torch.rand(n, dim)
    theta_span = 2 * np.pi
    phi_offset = 3 * np.pi / 4
    phi_width = np.pi / 4

    result = torch.zeros(n, dim)
    result[:, 0] = theta_span * uniform[:, 0]
    result[:, 1] = phi_offset + phi_width * uniform[:, 1]
    return result



In [None]:
########################################################
#    Plotting Functions (under spherical coordinate)   #
########################################################


# NOTE: the hyper-parameter cell later rebinds batch_size to 4000; this value
# is only in effect until that cell runs.
batch_size = 200

# Axis limits of the (theta, phi) plots: theta on x in [0, 2*pi], phi on y in [0, pi].
min_x, max_x = 0., 2 * np.pi
min_y, max_y = 0., np.pi

# Tick spacing along x and y.
xspace = yspace = 1


def plot_pushforwarded_samples(M, iter_number, netf):
    """Scatter-plot the push-forward of ``M`` rho_a samples under x -> x + netf(x).

    Points are drawn in the (theta, phi) coordinate plane. The figure is saved
    to the working directory as
    'outer_iteration = {iter_number}, samples at t=1.0.jpg'.

    Args:
        M: number of samples drawn with ``Sampler_a``.
        iter_number: outer-iteration index, used in the title and file name.
        netf: network whose output is added to each sample.
    """
    # Labelling constants kept from the original script: the label/file name
    # report t = 0.1 * i = 1.0, and the map applied is x + t * netf(x) with t = 1.
    i = 10
    t = 1

    samples = Sampler_a(M, 2)
    with torch.no_grad():  # plotting only: no autograd graph is needed
        transported_sample = samples.numpy() + t * netf(samples).numpy()

    name = 'outer_iteration = {}, samples at t={}'.format(iter_number, 0.1 * i)

    figure(num=None, figsize=(10, 10), dpi=80, facecolor='w', edgecolor='k')
    plt.xticks(np.arange(min_x, max_x, xspace))
    plt.yticks(np.arange(min_y, max_y, yspace))
    plt.scatter(transported_sample[:, 0], transported_sample[:, 1], c='blue', marker='.', label='test' + str(i))
    plt.legend()
    plt.title(name)  # the title used to end with a stray '.jpg'
    plt.savefig(name + '.jpg')
    plt.close()


def _push_forward(samples, netf):
    """Return ``(samples, samples + netf(samples))`` as NumPy arrays, without autograd."""
    with torch.no_grad():
        source = samples.numpy()
        return source, source + 1. * netf(samples).numpy()


def _draw_map_panel(source, target, filename, fix_limits=True):
    """Draw sources (blue), their images (orange) and joining segments; save to ``filename``."""
    plt.subplots(figsize=(10, 10))
    if fix_limits:
        plt.xlim(min_x, max_x)
        plt.ylim(min_y, max_y)
    plt.gca().set_aspect('equal', adjustable='box')
    plt.scatter(source[:, 0], source[:, 1], c='blue', s=0.5, marker='o')
    plt.scatter(target[:, 0], target[:, 1], c='orange', s=0.1, marker='o')
    # A single call draws all segments: each column of the (2, M) arrays is one line.
    plt.plot(np.vstack([source[:, 0], target[:, 0]]),
             np.vstack([source[:, 1], target[:, 1]]),
             'lightgrey', linewidth=0.4, alpha=0.8)
    plt.savefig(filename)
    plt.close()


def plot_map(M, iter_number, netf):
    """Visualise the map x -> x + netf(x) in the (theta, phi) plane.

    Saves five .jpg figures, each built from a fresh draw of ``M`` samples
    from ``Sampler_a``:

    * the map with matplotlib's automatic axis limits;
    * the map on the fixed window [min_x, max_x] x [min_y, max_y];
    * the images of points whose phi is pinned to 0, pi/8 and pi/4
      ("line 0/1/2"). Their theta values stay random.

    All panels now share figsize (10, 10). The phi = 0 panel previously used
    matplotlib's default size.

    Args:
        M: number of points per panel.
        iter_number: outer-iteration index, used in the file names.
        netf: network whose output is added to each sample.
    """
    # (file-name format, phi value to pin or None, use fixed axis limits)
    panels = (
        ('non restriction map and pushforward samples at iteration {} ', None, False),
        ('map and pushforward samples at iteration {} ', None, True),
        ('0map and pushforward line 0 at iteration {} ', 0., True),
        ('1map and pushforward line 1 at iteration {} ', np.pi / 8, True),
        ('2map and pushforward line 2 at iteration {} ', np.pi / 4, True),
    )
    for name_fmt, phi, fix_limits in panels:
        samples = Sampler_a(M, 2)
        if phi is not None:
            samples[:, 1] = phi
        source, target = _push_forward(samples, netf)
        _draw_map_panel(source, target, name_fmt.format(iter_number) + '.jpg', fix_limits)



In [None]:
########################################################
#    Plotting Functions (under Cartesian coordinate)   #
########################################################


# plot initial samples and the pushforwarded samples on sphere
def scatter_sphere(Angle_list_a, Angle_list_b, R, iter_num):
    """Scatter two point sets, given in spherical coordinates, on a sphere of radius R.

    Args:
        Angle_list_a: array of shape (n, 2) with rows (theta, phi), drawn in blue.
        Angle_list_b: array whose first n rows (theta, phi) are drawn in orange.
        R: sphere radius.
        iter_num: iteration index, used in the output file name.

    The plot also shows a translucent sphere, its centre (red), 8 meridian
    great circles and latitude circles at phi = pi/4, pi/2, 3*pi/4. It is saved
    as 'Scatter original and transported samples on sphere iteration = {iter_num}.jpg'.
    """
    fig = plt.figure(figsize=(20, 20), dpi=180)
    ax = fig.add_subplot(111, projection='3d')

    num_of_particles = Angle_list_a.shape[0]
    # One vectorised scatter per point set instead of one call per point.
    # depthshade=False matches the old one-point-per-call rendering, which
    # never applied depth fading.
    for angles, colour in ((Angle_list_a, 'blue'), (Angle_list_b[:num_of_particles], 'orange')):
        theta, aphi = angles[:, 0], angles[:, 1]
        ax.scatter(R * np.sin(aphi) * np.cos(theta),
                   R * np.sin(aphi) * np.sin(theta),
                   R * np.cos(aphi), c=colour, depthshade=False)

    #### draw sphere ####
    pi = np.pi
    cos = np.cos
    sin = np.sin
    phi, theta = np.mgrid[0.0:pi:100j, 0.0:2.0 * pi:100j]
    x = R * sin(phi) * cos(theta)
    y = R * sin(phi) * sin(theta)
    z = R * cos(phi)

    ax.plot_surface(x, y, z, rstride=1, cstride=1, color='lightgrey', alpha=0.2, linewidth=0)

    # draw center
    ax.scatter(0., 0., 0., c='r')

    ### meridians & latitudes ###
    # draw meridians: full great circles through both poles at azimuth beta
    for k in range(8):
        beta = k * 2 * np.pi / 8
        gamma = np.linspace(0, 2 * np.pi, 200)
        cos_gamma = np.cos(gamma)
        sin_gamma = np.sin(gamma)
        ax.plot(R * (np.cos(beta) * sin_gamma),
                R * (np.sin(beta) * sin_gamma),
                R * (cos_gamma), c='grey', alpha=0.4)

    # circles of latitude at phi = pi/4, pi/2, 3*pi/4
    for k in range(3):
        gamma = (k + 1) * np.pi / 4
        beta = np.linspace(0, 2 * np.pi, 200)
        cos_beta = np.cos(beta)
        sin_beta = np.sin(beta)
        ax.plot(R * np.sin(gamma) * cos_beta,
                R * np.sin(gamma) * sin_beta,
                R * np.cos(gamma), c='grey', alpha=0.4)

    ax.set_xlim([-R - 1, R + 1])
    ax.set_ylim([-R - 1, R + 1])
    ax.set_zlim([-R - 1, R + 1])

    # Set the x axis label of the current axis.
    plt.xlabel('x - axis')
    # Set the y axis label of the current axis.
    plt.ylabel('y - axis')

    plt.savefig('Scatter original and transported samples on sphere iteration = {}'.format(iter_num) + '.jpg')
    plt.close()


# plot pushforwarded samples and the corresponding geodesics on sphere
def plot_on_sphere_with_geodesics(Angle_list_1, Angle_list_2, R, N, iter_num, flag):
    """Plot point pairs on a sphere of radius R joined by great-circle arcs.

    Row l of ``Angle_list_1`` (blue) is joined to row l of ``Angle_list_2``
    (orange) by the shorter great-circle arc, sampled at N + 1 points. Rows
    are (theta, phi) spherical coordinates.

    Pairs that coincide or are antipodal have no well-defined arc. For those
    pairs only the endpoints are drawn. Previously they produced nan (division
    by a zero norm).

    The figure is saved as
    'transport map on sphere [{iter_num} iteration] ( {flag} ).jpg'.
    """
    fig = plt.figure(figsize=(20, 20), dpi=180)
    ax = fig.add_subplot(111, projection='3d')
    num_of_particles = Angle_list_1.shape[0]

    for l in range(num_of_particles):
        theta1, aphi1 = Angle_list_1[l, 0], Angle_list_1[l, 1]
        theta2, aphi2 = Angle_list_2[l, 0], Angle_list_2[l, 1]

        u = R * np.array([np.sin(aphi1) * np.cos(theta1), np.sin(aphi1) * np.sin(theta1), np.cos(aphi1)])
        v = R * np.array([np.sin(aphi2) * np.cos(theta2), np.sin(aphi2) * np.sin(theta2), np.cos(aphi2)])

        ax.scatter(u[0], u[1], u[2], c='blue')
        ax.scatter(v[0], v[1], v[2], c='orange')

        ### draw the arc that joins the pair on the sphere ###
        # Cosine of the central angle. Round-off can push it slightly outside
        # [-1, 1], where arccos returns nan, so clip it.
        cos_angle = np.clip(np.sin(aphi1) * np.sin(aphi2) * np.cos(theta1 - theta2)
                            + np.cos(aphi1) * np.cos(aphi2), -1.0, 1.0)
        angle = np.arccos(cos_angle)

        # Component of v orthogonal to u, rescaled to length R; u and u_perp
        # span the plane of the great circle through both points.
        u_perp_unnorm = v - cos_angle * u
        perp_norm = linalg.norm(u_perp_unnorm)
        if perp_norm <= 1e-12 * abs(R):
            continue  # coincident or antipodal pair: no unique arc to draw
        u_perp = R * u_perp_unnorm / perp_norm

        t = np.linspace(0, angle, N + 1)
        arc = np.outer(np.cos(t), u) + np.outer(np.sin(t), u_perp)
        ax.plot(arc[:, 0], arc[:, 1], arc[:, 2], c='dimgrey')

    ### draw sphere ###
    pi = np.pi
    cos = np.cos
    sin = np.sin
    phi, theta = np.mgrid[0.0:pi:100j, 0.0:2.0 * pi:100j]
    x = R * sin(phi) * cos(theta)
    y = R * sin(phi) * sin(theta)
    z = R * cos(phi)
    ax.plot_surface(x, y, z, rstride=1, cstride=1, color='lightgrey', alpha=0.2, linewidth=0)

    ### draw center of the sphere ###
    ax.scatter(0., 0., 0., c='r')

    ### meridians & latitudes ###
    # draw meridians: full great circles through both poles at azimuth beta
    for k in range(8):
        beta = k * 2 * np.pi / 8
        gamma = np.linspace(0, 2 * np.pi, 200)
        cos_gamma = np.cos(gamma)
        sin_gamma = np.sin(gamma)
        ax.plot(R * (np.cos(beta) * sin_gamma),
                R * (np.sin(beta) * sin_gamma),
                R * (cos_gamma), c='grey', alpha=0.4)

    # draw circles of latitude at phi = pi/4, pi/2, 3*pi/4
    for k in range(3):
        gamma = (k + 1) * np.pi / 4
        beta = np.linspace(0, 2 * np.pi, 200)
        cos_beta = np.cos(beta)
        sin_beta = np.sin(beta)
        ax.plot(R * np.sin(gamma) * cos_beta,
                R * np.sin(gamma) * sin_beta,
                R * np.cos(gamma), c='grey', alpha=0.4)

    ax.set_xlim([-R - 1, R + 1])
    ax.set_ylim([-R - 1, R + 1])
    ax.set_zlim([-R - 1, R + 1])

    # Set the x axis label of the current axis.
    plt.xlabel('x - axis')
    # Set the y axis label of the current axis.
    plt.ylabel('y - axis')

    plt.savefig('transport map on sphere [{} iteration] ( {} )'.format(iter_num, flag) + '.jpg')
    plt.close()





In [None]:
###############################
#   Define Neural Networks    #
###############################


################# define network F #################
class networkf(nn.Module):
    """Fully connected PReLU network used as the map network netf.

    Architecture: Linear(input_size, hidden_size), then three
    Linear(hidden_size, hidden_size) layers, each followed by PReLU, and a
    final Linear(hidden_size, output_size). In this notebook its output is
    added to the input to form the transported point x + netf(x).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layers are created in the same order as a literal nn.Sequential, so
        # parameter initialisation and state_dict keys ('main.<idx>.*') match.
        layers = [nn.Linear(input_size, hidden_size), nn.PReLU()]
        for _ in range(3):
            layers += [nn.Linear(hidden_size, hidden_size), nn.PReLU()]
        layers.append(nn.Linear(hidden_size, output_size))
        self.main = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.main(inputs)


################# define network phi ################
class networkphi(nn.Module):
    """Fully connected PReLU network used as the potential / multiplier netphi.

    Architecture: Linear(input_size, hidden_size), then four
    Linear(hidden_size, hidden_size) layers, each followed by PReLU, and a
    final Linear(hidden_size, output_size).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Built in literal nn.Sequential order so parameter initialisation and
        # state_dict keys ('main.<idx>.*') are unchanged.
        layers = [nn.Linear(input_size, hidden_size), nn.PReLU()]
        for _ in range(4):
            layers += [nn.Linear(hidden_size, hidden_size), nn.PReLU()]
        layers.append(nn.Linear(hidden_size, output_size))
        self.main = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.main(inputs)


def weights_init(m):
    """Initialise the parameters of one module in place (intended for ``net.apply``).

    * ``nn.Linear``: Xavier-normal weight (gain 1.0) and bias filled with 2.0.
    * ``nn.BatchNorm1d``: weight filled with 1.0 and bias with 2.0.

    Other module types are left untouched. Missing parameters, such as the
    bias of a ``bias=False`` Linear or the affine parameters of an
    ``affine=False`` BatchNorm1d, are skipped instead of raising
    AttributeError.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_normal_(m.weight, gain=1.0)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 2.0)

    elif isinstance(m, nn.BatchNorm1d):
        if m.weight is not None:
            torch.nn.init.constant_(m.weight, 1.0)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 2.0)



In [None]:
################################
#     Hyper parameters         #
################################

# Fix the random seeds.
# NOTE: `from numpy import *` in the import cell rebinds the name `random` to
# `numpy.random` (numpy lists its `random` submodule in `__all__`), so the old
# `random.seed(8)` seeded NumPy, not Python's stdlib RNG. Seed each generator
# explicitly; torch's RNG is the one the samplers and network init use.
import random as _stdlib_random
_stdlib_random.seed(8)
np.random.seed(8)
torch.manual_seed(8)

# radius of plotted sphere (also scales the reported OT distance)
R = 3.

# Set up netf and netphi. `weights_init` is not applied here, so both networks
# keep PyTorch's default layer initialisation.
netf = networkf(2, 8, 2)
netphi = networkphi(2, 8, 1)

# set up the optimizer for netf and netphi
optimizerf = optim.Adam(netf.parameters(), lr=0.9*1e-4)
optimizerphi = optim.Adam(netphi.parameters(), lr=0.9*1e-4)

# some other hyper parameters
iter_Num = 10000   # outer iterations
iter_phi = 8       # netphi steps per outer iteration
iter_F = 4         # netf steps per outer iteration
batch_size = 4000  # samples per optimisation step
plot_period = 5000 # plot every this many outer iterations



In [None]:
############################################
#  main algorithm (max_netphi min_netf)    #
############################################
# Alternating optimisation of
#     min_netf max_netphi  E_a[0.1 * d(x, T(x)) + netphi(T(x))] - E_b[netphi(y)]
# with T(x) = x + netf(x) and d approximately the great-circle angle (see
# truncate_const below). x ~ rho_a (Sampler_a), y ~ rho_b (Sampler_b).


def _sphere_cos_angle(p, q):
    """Cosine of the central angle between rows of p and q, given as (theta, phi)."""
    return (torch.sin(p[:, 1]) * torch.sin(q[:, 1]) * torch.cos(p[:, 0] - q[:, 0])
            + torch.cos(p[:, 1]) * torch.cos(q[:, 1]))


def _set_requires_grad(net, flag):
    """Enable or disable gradient tracking for every parameter of ``net``."""
    for para in net.parameters():
        para.requires_grad = flag


# acos'(x) blows up at x = +-1. The training loss multiplies the cosine by this
# factor so the argument stays in [-0.99, 0.99] and gradients stay bounded
# (at the cost of a slightly biased distance).
truncate_const = 0.99

# record outer loss:
OT_distance = []

######### Outer loop (netphi-optimization) ########
for i in range(iter_Num):

    ######## Inner loop (netf-optimization) #######
    # Flip requires_grad BEFORE any forward pass. Autograd records which
    # parameters require grad when the graph is built. The old code set the
    # flags after computing x_k, so on the first inner step of every outer
    # iteration after the first, netf (still frozen from the netphi phase) was
    # not in the graph and got no real update.
    _set_requires_grad(netf, True)
    _set_requires_grad(netphi, False)
    for _ in range(iter_F):
        sample_a = Sampler_a(batch_size, 2)

        netf.zero_grad()
        netphi.zero_grad()

        x_k = sample_a + netf(sample_a)

        # geodesic distance on sphere (weighted by 0.1)
        geodesic_dist = 0.1 * torch.acos(truncate_const * _sphere_cos_angle(sample_a, x_k))

        netphi_xk = netphi(x_k)
        Lagrange_Multiplier_part = netphi_xk.mean()
        in_loss = geodesic_dist.mean() + Lagrange_Multiplier_part

        in_loss.backward()  # graph is rebuilt every step; no need to retain it
        optimizerf.step()

    ###### Outer optimization (netphi-optimization) #######
    _set_requires_grad(netf, False)
    _set_requires_grad(netphi, True)
    for _ in range(iter_phi):
        sample_a = Sampler_a(batch_size, 2)
        sample_b = Sampler_b(batch_size, 2)

        netf.zero_grad()
        netphi.zero_grad()

        x_k = sample_a + netf(sample_a)  # netf is frozen: x_k carries no graph

        outer_netphi_xk = netphi(x_k)
        outer_netphi_samp_b = netphi(sample_b)
        outer_loss = outer_netphi_samp_b.mean() - outer_netphi_xk.mean()

        outer_loss.backward()
        optimizerphi.step()

    if (i + 1) % 100 == 0:
        print(i)
        ########## compute the OT distance ##########
        with torch.no_grad():  # evaluation only
            sample_a = Sampler_a(50000, 2)  # we use 50000 samples to estimate the OT distance
            x_k = sample_a + netf(sample_a)
            # Clamp: round-off can push the cosine slightly outside [-1, 1],
            # where acos returns nan and would poison the mean.
            cos_angle = torch.clamp(_sphere_cos_angle(sample_a, x_k), -1.0, 1.0)
            geodesic_dist = R * torch.acos(cos_angle)
            Approximated_OT_dist = geodesic_dist.mean().item()
        print("The OT distance estimated equals: {}".format(Approximated_OT_dist))
        Approximated_OT_dist_divided_by_pi = Approximated_OT_dist / math.pi
        print("The OT distance estimated  over pi equals: {}".format(Approximated_OT_dist_divided_by_pi))
        OT_distance.append(Approximated_OT_dist_divided_by_pi)

    #  plot samples every period of iterations
    if (i + 1) % plot_period == 0:
        with torch.no_grad():  # plotting only: no autograd graph needed
            ### plot samples under spherical coordinate ##
            plot_pushforwarded_samples(1000, i + 1, netf)
            plot_map(1000, i + 1, netf)

            ### plot samples under Cartesian coordinate on sphere ##
            # scatter the initial samples and the pushforwarded samples on sphere
            Angle_list_1 = Sampler_a(200, 2)
            Angle_list_2 = (Angle_list_1 + netf(Angle_list_1)).numpy()
            scatter_sphere(Angle_list_1.numpy(), Angle_list_2, R, i + 1)

            # map from samples of rhoa to samples of rhob
            Angle_list_1 = Sampler_a(200, 2)
            Angle_list_2 = (Angle_list_1 + netf(Angle_list_1)).numpy()
            plot_on_sphere_with_geodesics(Angle_list_1.numpy(), Angle_list_2, R, 200, i + 1, 0)

            # map from rings of constant phi (flag 1: phi = pi/8, flag 2: phi = pi/4)
            for flag, phi0 in ((1, np.pi / 8), (2, np.pi / 4)):
                Angle_list_1 = np.zeros([200, 2])
                Angle_list_1[:, 0] = np.linspace(0, 2 * np.pi, 200)
                Angle_list_1[:, 1] = phi0
                Angle_list_1_ts = torch.tensor(Angle_list_1).float()
                Angle_list_2 = (Angle_list_1_ts + netf(Angle_list_1_ts)).numpy()
                plot_on_sphere_with_geodesics(Angle_list_1, Angle_list_2, R, 200, i + 1, flag)



######### plot estimated OT distance #########
# OT_distance holds one entry per evaluation in the training loop (every 100
# outer iterations). Each entry is the estimated OT distance divided by pi.
fig_entire_loss_plot = plt.figure(figsize=(10, 10), dpi=180)
ax = fig_entire_loss_plot.add_subplot(111)
ax.plot([float(d) for d in OT_distance])  # float(): tolerate 0-d tensors as entries
ax.set_title("plot for OT distance vs iteration number")
ax.set_xlabel("evaluation index (one evaluation every 100 outer iterations)")
ax.set_ylabel("estimated OT distance / pi")
fig_entire_loss_plot.savefig("Plot of OT distances") # we do observe rather strong oscillations in the estimated OT distances vs iterations
plt.show()
plt.close()


########## compute the OT distance ##########
# Monte-Carlo estimate of E[R * angle(x, x + netf(x))] for x ~ rho_a, where
# angle is the great-circle angle. It is reported as is and divided by pi.
with torch.no_grad():  # evaluation only: no autograd graph over 50000 samples
    sample_a = Sampler_a(50000, 2) # we use 50000 samples to estimate the OT distance
    x_k = sample_a + netf(sample_a)
    cos_angle = (torch.sin(sample_a[:, 1]) * torch.sin(x_k[:, 1]) * torch.cos(sample_a[:, 0] - x_k[:, 0])
                 + torch.cos(sample_a[:, 1]) * torch.cos(x_k[:, 1]))
    # Clamp: round-off can push the cosine slightly outside [-1, 1], where acos
    # returns nan and would make the whole mean nan.
    geodesic_dist = R * torch.acos(torch.clamp(cos_angle, -1.0, 1.0))
    Approximated_OT_dist = geodesic_dist.mean().item()
print("The OT distance estimated equals: {}".format(Approximated_OT_dist))
Approximated_OT_dist_divided_by_pi = Approximated_OT_dist / math.pi
print("The OT distance estimated  over pi equals: {}".format(Approximated_OT_dist_divided_by_pi))




