This Jupyter notebook implements the "Decreasing function as the cost" example with quadratic cost in Appendix D of the paper "Neural Monge Map estimation and its applications". The code runs on Google Colab.

In [None]:
############################
# import necessary modules #
############################


import os
import argparse

import math
import random
import scipy
import numpy as np
from numpy import *

import matplotlib
matplotlib.use('agg')
from matplotlib.pyplot import figure
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.nn import Parameter
from torch.optim.lr_scheduler import ExponentialLR





In [None]:
##################################
#            Samplers            #
##################################


########### Generate n samples from initial distribution \rho_a ######
# rho_a = uniform distribution on [4, 6]\times[0, 2\pi] (under polar coordinate)
def Sampler_a(n, dim):
    """Draw ``n`` samples from the source distribution rho_a.

    The radius is uniform on [4, 6] and the angle uniform on [0, 2*pi);
    the polar pair is then converted to Cartesian coordinates.

    Args:
        n: number of samples (rows of the result).
        dim: number of columns of the result. Columns 0 and 1 hold the
            x and y coordinates; any further columns stay zero.

    Returns:
        A float tensor of shape ``(n, dim)``.
    """
    inner_radius, outer_radius = 4, 6

    uniform = torch.rand(n, dim)
    radius = (outer_radius - inner_radius) * uniform[:, 0] + inner_radius
    angle = (2 * math.pi) * uniform[:, 1]

    points = torch.zeros(n, dim)
    points[:, :2] = torch.stack(
        (radius * torch.cos(angle), radius * torch.sin(angle)), dim=1
    )
    return points


########### Generate n samples from destination distribution \rho_b ######
def Sampler_b(n, dim):
    """Draw ``n`` samples from the target distribution rho_b.

    The radius is uniform on [1, 2] and the angle uniform on [0, 2*pi);
    the polar pair is then converted to Cartesian coordinates.

    Args:
        n: number of samples (rows of the result).
        dim: number of columns of the result. Columns 0 and 1 hold the
            x and y coordinates; any further columns stay zero.

    Returns:
        A float tensor of shape ``(n, dim)``.
    """
    inner_radius, outer_radius = 1, 2

    uniform = torch.rand(n, dim)
    radius = (outer_radius - inner_radius) * uniform[:, 0] + inner_radius
    angle = (2 * math.pi) * uniform[:, 1]

    points = torch.zeros(n, dim)
    points[:, :2] = torch.stack(
        (radius * torch.cos(angle), radius * torch.sin(angle)), dim=1
    )
    return points



In [None]:
###############################
#   Define Neural Networks    #
###############################


################# define F #############
class networkf(nn.Module):
    """Fully connected network F with PReLU activations.

    The stack is four ``Linear -> PReLU`` stages (the first maps
    ``input_size`` to ``hidden_size``, the rest keep ``hidden_size``)
    followed by a final ``Linear`` to ``output_size``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        # Widths seen by the four activated stages: in -> h -> h -> h -> h.
        widths = [input_size] + [hidden_size] * 4
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.PReLU())
        layers.append(nn.Linear(hidden_size, output_size))

        self.main = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply the layer stack to ``inputs`` and return the result."""
        return self.main(inputs)



################# define phi #############
class networkphi(nn.Module):
    """Fully connected network phi with Tanh activations.

    The stack is five ``Linear -> Tanh`` stages (the first maps
    ``input_size`` to ``hidden_size``, the rest keep ``hidden_size``)
    followed by a final ``Linear`` to ``output_size``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        layers = [nn.Linear(input_size, hidden_size), nn.Tanh()]
        for _ in range(4):
            layers += [nn.Linear(hidden_size, hidden_size), nn.Tanh()]
        layers.append(nn.Linear(hidden_size, output_size))

        self.main = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply the layer stack to ``inputs`` and return the result."""
        return self.main(inputs)


def weights_init(m):
    """Initialise one sub-module in place; meant for ``Module.apply``.

    * ``nn.Linear``: Xavier-normal weights (gain 1.0) and bias filled with 2.0.
    * ``nn.BatchNorm1d``: weight filled with 1.0 and bias filled with 2.0.
    * any other module is left untouched.

    Parameters that do not exist (``bias=False`` on a linear layer,
    ``affine=False`` on a batch-norm layer) are skipped instead of raising.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight, gain=1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 2.0)
    elif isinstance(m, nn.BatchNorm1d):
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 2.0)



In [None]:
############################
#    Plotting Functions    #
############################


# Axis limits (x and y) applied to every figure drawn by the plotting code.
min_x, max_x = -6.5, 6.5
min_y, max_y = -6.5, 6.5

# Spacing constants for the x and y axes.
# NOTE(review): not referenced by the code visible in this notebook — confirm
# before removing.
xspace = yspace = 1


def _annulus_samples(M, r_min, r_max):
    """Draw M points with radius uniform on [r_min, r_max] and angle on [0, 2*pi)."""
    uniform = torch.rand(M, 2)
    radius = (r_max - r_min) * uniform[:, 1] + r_min
    theta = (2 * math.pi) * uniform[:, 0]
    return torch.stack((radius * torch.cos(theta), radius * torch.sin(theta)), dim=1)


def _arc_samples(M, radius, theta_min=0.0, theta_max=2 * math.pi):
    """Draw M points on the circle of ``radius`` with angle uniform on [theta_min, theta_max)."""
    # An (M, 2) block is drawn although only column 0 is used, so the global
    # torch RNG advances by the same amount as for an annulus sample.
    uniform = torch.rand(M, 2)
    theta = (theta_max - theta_min) * uniform[:, 0] + theta_min
    return torch.stack((radius * torch.cos(theta), radius * torch.sin(theta)), dim=1)


def _push_forward(netf, samples):
    """Return ``samples + netf(samples)`` as a NumPy array, without tracking gradients."""
    with torch.no_grad():
        displacement = netf(samples)
    return samples.numpy() + displacement.numpy()


def _new_canvas():
    """Create a 15x15 figure with the shared axis limits and the two dashed target circles."""
    fig, ax = plt.subplots(figsize=(15, 15))
    ax.set_xlim(min_x, max_x)
    ax.set_ylim(min_y, max_y)
    ax.set_aspect('equal', adjustable='box')
    # Radii 2 and 1 are the outer and inner boundary of the target annulus.
    for radius in (2, 1):
        ax.add_artist(plt.Circle((0, 0), radius, color='navajowhite', fill=False, linestyle='--'))
    return fig, ax


def _save_and_close(fig, flag, stem, iter_number):
    """Save ``fig`` as '<cost prefix> <stem>.jpg' and close it."""
    prefix = '[Quadratic Reciprocal cost]' if flag == 1 else '[Quadratic cost]'
    fig.savefig('{} {}'.format(prefix, stem.format(iter_number)) + '.jpg')
    plt.close(fig)


def _plot_displacements(netf, sources, source_size, flag, stem, iter_number):
    """Plot source points, their images under the map, and the connecting segments."""
    targets = _push_forward(netf, sources)
    src = sources.numpy()

    fig, ax = _new_canvas()
    ax.scatter(src[:, 0], src[:, 1], c='blue', s=source_size, marker='o')
    ax.scatter(targets[:, 0], targets[:, 1], c='orange', s=0.1, marker='o')
    # With 2-D inputs every column is drawn as its own line, so the M
    # source -> target segments are added in a single call.
    ax.plot(np.stack((src[:, 0], targets[:, 0])),
            np.stack((src[:, 1], targets[:, 1])),
            color='lightgrey', linewidth=0.1, alpha=0.3)
    _save_and_close(fig, flag, stem, iter_number)


def plot_function(M, iter_number, netf, flag):
    """Save five diagnostic figures of the current transport map x -> x + netf(x).

    The figures are, in order:
      1. the pushforward of M points sampled from the source annulus
         (radius in [4, 6]);
      2. the circle of radius 6 with its image and displacement segments;
      3. the circle of radius 4, likewise;
      4. the arc of radius 6 with angle in [pi/4, 3*pi/4], likewise;
      5. the arc of radius 4 with angle in [pi/4, 3*pi/4], likewise.

    Each figure is written to the current working directory as a .jpg file.

    Args:
        M: number of points sampled for each figure.
        iter_number: current training iteration; only used in the file names.
        netf: the network F being trained; evaluated on CPU tensors of
            shape (M, 2).
        flag: 1 selects the '[Quadratic Reciprocal cost]' file-name prefix,
            any other value selects '[Quadratic cost]'.
    """
    # 1) pushforward of the whole source annulus
    annulus = _annulus_samples(M, 4, 6)
    pushed = _push_forward(netf, annulus)
    fig, ax = _new_canvas()
    ax.scatter(pushed[:, 0], pushed[:, 1], c='orange', s=1, marker='o')
    _save_and_close(fig, flag, 'pushforward samples  at iteration {} ', iter_number)

    # 2)-5) boundary curves of the annulus and where the map sends them.
    # Each entry: (radius, theta_min, theta_max, source marker size, file-name stem)
    quarter_lo, quarter_hi = 0.25 * math.pi, 0.75 * math.pi
    curves = (
        (6, 0.0, 2 * math.pi, 0.5, 'pushforward samples large circle at iteration {} '),
        (4, 0.0, 2 * math.pi, 0.1, 'pushforward samples small circle at iteration {} '),
        (6, quarter_lo, quarter_hi, 0.1, 'pushforward the large quarter circle at iteration {} '),
        (4, quarter_lo, quarter_hi, 0.1, 'pushforward the smaller quarter circle at iteration {} '),
    )
    for radius, theta_min, theta_max, source_size, stem in curves:
        sources = _arc_samples(M, radius, theta_min, theta_max)
        _plot_displacements(netf, sources, source_size, flag, stem, iter_number)



In [None]:
################################
#     Hyper parameters         #
################################

# fix the random seed
# NOTE(review): the import cell runs `from numpy import *` after `import random`,
# so `random` here may be numpy.random rather than the stdlib module — confirm.
# The samplers in this notebook only draw from torch's generator (torch.rand).
random.seed(12)
torch.manual_seed(12)

# set up networks F and phi and initialize their parameters
# netf maps 2 -> 2 and netphi maps 2 -> 1, both with hidden width 36.
# The construction/initialisation order below fixes how the seeded torch RNG
# is consumed, so reordering these lines changes the initial weights.
netf = networkf(2, 36, 2)
netphi = networkphi(2, 36, 1)
netf.apply(weights_init)
netphi.apply(weights_init)

# set up the optimizers for network F and phi
optimizerf = optim.Adam(netf.parameters(), lr = 1e-4)
optimizerphi = optim.Adam(netphi.parameters(), lr = 1e-4)

# some other hyper parameters
# iter:            number of outer iterations of the training loop
#                  (NOTE: this name shadows the builtin `iter`)
# iter_1:          netf updates per outer iteration (inner loop)
# iter_2:          netphi updates per outer iteration
# batch_size:      samples drawn from each sampler per update
# plot_batch_size: number of points passed to plot_function
# plot_period:     plot_function is called every `plot_period` outer iterations
iter = 8000
iter_1 = 8
iter_2 = 6
batch_size = 2000
plot_batch_size = 800
plot_period = 400



In [None]:
################################################################################
#  Main Algorithm (max_netphi min_netf)  cost= quadratic i.e, c(x,y)=|x-y|^2   #
################################################################################
# The implementation actually uses c(x, y) = 1/2 * |x - y|^2.
# In theory the factor 1/2 in front of the quadratic cost leaves the constrained
# optimization problem equivalent to the original one.
# In practice this 1/2 proves helpful: training is not very effective when the
# coefficient is close to 1.
#
# With T(x) = x + netf(x), each outer iteration alternates between
#   * iter_1 steps on netf minimising
#         E[1/2 |netf(x)|^2] + E[netphi(T(x))],   x ~ rho_a
#   * iter_2 steps on netphi minimising
#         E[netphi(y)] - E[netphi(T(x))],         y ~ rho_b, x ~ rho_a
# The gradient flags of both networks are switched BEFORE any forward pass:
# autograd decides what to record while the forward pass runs, so a tensor
# computed while netf is frozen carries no gradient path back to netf.


# record outer loss:
loss_list = []

######### Outer loop (netphi-optimization) ########
for i in range(iter):

    ######## Inner loop (netf-optimization) #######
    netf.requires_grad_(True)
    netphi.requires_grad_(False)

    inner_loss_list = []
    for j in range(iter_1):
        sample_a = Sampler_a(batch_size, 2)

        netf.zero_grad()

        # Evaluate netf once and reuse it for both the transported points
        # and the kinetic (transport cost) term.
        displacement = netf(sample_a)
        x_k = sample_a + displacement

        E_kinetic = 0.5 * displacement.pow(2).sum(dim=1).mean()
        Lagrange_Multiplier_part = netphi(x_k).mean()
        in_loss = E_kinetic + Lagrange_Multiplier_part
        in_loss.backward()

        optimizerf.step()

        # record inner loss data
        inner_loss_list.append(in_loss.item())

    # Plot the inner losses of this outer iteration once per plot period
    if i % plot_period == 0:
        fig_inner_loss_plot = plt.figure()
        ax = fig_inner_loss_plot.add_subplot(111)
        ax.scatter(range(len(inner_loss_list)), inner_loss_list, alpha=0.7, s=25)
        ax.set_title("[Quadratic cost] plot for inner losses at {} outer iteration".format(i))
        fig_inner_loss_plot.savefig("[Quadratic cost] Plot of inner losses at {} outer iteration".format(i) + '.jpg')

        plt.close(fig_inner_loss_plot)

    ###### Outer optimization (netphi-optimization) #######
    netf.requires_grad_(False)
    netphi.requires_grad_(True)

    for phi_step in range(iter_2):
        sample_a = Sampler_a(batch_size, 2)
        sample_b = Sampler_b(batch_size, 2)

        # netf is fixed during this phase, so the transported points are
        # plain inputs to netphi and need no autograd graph.
        with torch.no_grad():
            x_k = sample_a + netf(sample_a)

        netphi.zero_grad()

        outer_loss = netphi(sample_b).mean() - netphi(x_k).mean()
        outer_loss.backward()
        optimizerphi.step()

        # record loss data
        loss_list.append(outer_loss.item())

    #  plot samples every period of iterations
    if (i + 1) % plot_period == 0:
        flag = 0  # 0 selects the '[Quadratic cost]' file-name prefix
        plot_function(plot_batch_size, i + 1, netf, flag)


######### plot loss curve #########
fig_entire_loss_plot = plt.figure()
ax = fig_entire_loss_plot.add_subplot(111)
ax.scatter(range(len(loss_list)), loss_list, alpha=0.7, s=25)
ax.set_title("plot for losses")
fig_entire_loss_plot.savefig("Plot of losses")
plt.show()
plt.close()

