This Jupyter notebook implements the "Population transportation on the sphere" example with the $L^2$ cost, discussed in Appendix D of the paper "Neural Monge Map estimation and its applications". The code is written to run on Google Colab.

In [None]:
# Mount Google Drive at /content/drive (Colab only).
# NOTE(review): the population raster is later read from os.getcwd(), not from
# this mount point -- copy the file there or change the working directory.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install rasterio, which the cells below use to read the population .tif raster.
pip install rasterio

In [None]:
############################
# import necessary modules #
############################


import os
import argparse

import math
import numpy as np
from numpy import *
import scipy

import time

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import numpy.random as npr

import matplotlib
matplotlib.use('agg')
from matplotlib.pyplot import figure
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

import rasterio


The program runs on a CUDA GPU when one is available; the next cell checks whether it is.

In [None]:
# Show whether a CUDA device is visible; later cells move tensors and networks
# to the GPU (.cuda()) only when this is True.
torch.cuda.is_available()


In [None]:
#########################################
#            Processing Data            #
#########################################


# Reference: Amos, B., Cohen, S., Luise, G., and Redko, I. (2022). Meta optimal transport.
# Reference code: https://github.com/facebookresearch/meta-ot#spherical-transport.
# Download "2020 Tiff data at 15-minute resolution" data from https://sedac.ciesin.columbia.edu/data/set/gpw-v4-population-density-adjusted-to-2015-unwpp-country-totals-rev11/data-download#.
# Save the file as 'pop-15min.tif' in the current working directory, then run this cell.


# Load band 1 of the population raster. Negative cells (presumably the
# raster's no-data fill value) are clamped to 0. The dataset is opened as a
# context manager so the file handle is released after reading.
save_path = os.getcwd()
population_fname = os.path.join(save_path, 'pop-15min.tif')
with rasterio.open(population_fname) as src:
    P = src.read(1)  # P is a 720 by 1440 array (rows <-> phi, columns <-> theta).
P[P < 0] = 0.


# Non-uniform (population-weighted) distribution over the grid cells.
Pflat = P.ravel()
Pflat = Pflat / Pflat.max()
Pflat = Pflat / Pflat.sum()

# Uniform distribution over every cell with positive population.
Uflat = Pflat.copy()
Uflat[Uflat > 0] = 1.
Uflat /= Uflat.sum()

# Plot the binary land mask. Work on a copy: assigning `Umap = P` and then
# binarising would silently overwrite the population values stored in P.
Umap = P.copy()
Umap[Umap > 0] = 1.0
fig = plt.figure(figsize=(80, 40))
plt.imshow(Umap, cmap='tab20c', interpolation='nearest')
plt.savefig('world uniform population map'+'.jpg')
# Release the (very large) figure once it has been written to disk.
plt.close(fig)


def read_uniform_map():
    """Return the binary land mask built from the population raster.

    Reads band 1 of ``pop-15min.tif`` in the current working directory,
    clamps negative cells to 0 and sets every positive cell to 1.0. The
    result is therefore 1.0 on populated cells and 0.0 everywhere else.
    """
    population_fname = os.path.join(os.getcwd(), 'pop-15min.tif')
    # Context manager: the original opened the dataset and never closed it,
    # leaking one file handle per call (this is called twice per plot).
    with rasterio.open(population_fname) as src:
        Umap = src.read(1)
    Umap[Umap < 0] = 0.
    Umap[Umap > 0] = 1.0
    return Umap



In [None]:
##################################
#            Samplers            #
##################################


# Transform the spherical coordinates (\theta, \phi, \rho=1) to Cartesian coordinates (x, y, z).
def spherical_to_cartesian(spherical_coord):
    """Map spherical coordinates (theta, phi, rho=1) to Cartesian (x, y, z).

    Args:
        spherical_coord: tensor whose column 0 is the azimuth theta and whose
            column 1 is the polar angle phi; only these two columns are read.

    Returns:
        Tensor of shape (N, 3) with rows
        (cos(theta) sin(phi), sin(theta) sin(phi), cos(phi)), in torch's
        default dtype and on the same device as ``spherical_coord``.
    """
    theta = spherical_coord[:, 0]
    phi = spherical_coord[:, 1]
    sin_phi = torch.sin(phi)
    # Build the result directly on the input's device. The original allocated
    # a CPU buffer with torch.zeros and copied (possibly CUDA) values into it.
    cartesian_coord = torch.stack(
        (torch.cos(theta) * sin_phi, torch.sin(theta) * sin_phi, torch.cos(phi)),
        dim=1,
    )
    # The original buffer used torch's default dtype; keep that contract.
    return cartesian_coord.to(torch.get_default_dtype())


# Define the sampler that samples from the distribution p.
def sample(p, num_samples, seed=0):
    """Draw ``num_samples`` raster cells from the flattened distribution ``p``.

    Each drawn flat index I is converted to spherical coordinates using the
    shape (H rows, W columns) of the global raster ``P``:

        phi   = (I / W) / H * pi        (true division: row r plus the offset c / W < 1,
                                         so int(phi / pi * H) still recovers r)
        theta = (I % W) / W * 2 * pi

    Args:
        p: 1-D probability vector of length H * W.
        num_samples: number of cells to draw (with replacement).
        seed: seed for the draw; the same seed always yields the same samples.

    Returns:
        ``(samples_spherical, samples_euclidean)``. The first is a float64
        tensor of shape (num_samples, 2) with columns (theta, phi), placed on
        CUDA when available. The second comes from ``spherical_to_cartesian``.
    """
    # A private RandomState seeded with `seed` yields exactly the indices that
    # `npr.seed(seed); npr.choice(...)` did, without resetting NumPy's global
    # RNG as a side effect.
    # ATTENTION! this sampling takes around 0.02s for 800 samples.
    rng = npr.RandomState(seed)
    sample_Is = rng.choice(len(p), p=p, size=num_samples)
    n_rows, n_cols = P.shape[0], P.shape[1]
    samples_phi = (sample_Is / n_cols) / n_rows * np.pi
    samples_theta = (sample_Is % n_cols) / n_cols * 2 * np.pi
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    samples_spherical = torch.tensor(np.vstack((samples_theta, samples_phi)).T, device=device)
    samples_euclidean = spherical_to_cartesian(samples_spherical)

    return samples_spherical, samples_euclidean


# Generate n samples (in spherical coordinates) from the source distribution \rho_a
def Sampler_a_sph(n, seed=0):
    """Return ``n`` source samples (distribution ``Pflat``) in spherical coordinates.

    ``seed`` defaults to the original fixed value 0, so repeated default calls
    return identical samples. Pass another seed to draw a fresh set.
    """
    return sample(Pflat, n, seed=seed)[0]


def Sampler_a(n, seed=0):
    """Return ``(spherical, cartesian)`` coordinates of ``n`` source samples from ``Pflat``."""
    return sample(Pflat, n, seed=seed)


def Sampler_b_sph(n, seed=1):
    """Return ``n`` target samples (uniform-over-land ``Uflat``) in spherical coordinates.

    ``seed`` defaults to the original fixed value 1, so repeated default calls
    return identical samples. Pass another seed to draw a fresh set.
    """
    return sample(Uflat, n, seed=seed)[0]


def Sampler_b(n, seed=1):
    """Return ``(spherical, cartesian)`` coordinates of ``n`` target samples from ``Uflat``."""
    return sample(Uflat, n, seed=seed)



In [None]:
##########################################################
#  Functions that process the coordinates after mapping  #
##########################################################


# Fit sperical coordinates \theta, \phi into intervals [0, 2\pi) and [0, \pi]
def fit_coordinate(samples):
    """Wrap spherical coordinates into theta in [0, 2*pi) and phi in [0, pi].

    Both angles are first reduced modulo 2*pi. A polar angle phi > pi denotes
    the same point as 2*pi - phi on the opposite meridian, so for those rows
    phi is reflected and theta is shifted by pi (modulo 2*pi).

    Args:
        samples: tensor of shape (N, >=2) with column 0 = theta, column 1 = phi.

    Returns:
        A new tensor of the same shape, dtype and device with columns 0 and 1
        wrapped. The input is no longer modified in place, and the
        per-element Python loop (one GPU sync per element) is vectorised.
    """
    two_pi = 2 * math.pi
    theta = torch.remainder(samples[:, 0], two_pi)
    phi = torch.remainder(samples[:, 1], two_pi)
    flip = phi > math.pi
    phi = torch.where(flip, two_pi - phi, phi)
    shifted_theta = torch.where(theta < math.pi, theta + math.pi, theta - math.pi)
    theta = torch.where(flip, shifted_theta, theta)

    fit_samples = samples.clone()
    fit_samples[:, 0] = theta
    fit_samples[:, 1] = phi
    return fit_samples


# Pick k random positions uniformly over the land, return the closest position to the given location int_coordinate
# We use this method in the example presented in our paper.
def find_random(int_coordinate, k, P):
    """Return the candidate land cell closest to ``int_coordinate``.

    Draws ``k`` cells from the uniform-over-land distribution via
    ``Sampler_b_sph``. Its default seed is fixed, so every call sees the same
    ``k`` candidates. The candidates are converted to integer (row, column)
    indices of the ``P`` grid, and the one with the smallest squared Euclidean
    distance in index space is returned.

    Args:
        int_coordinate: (row, column) index pair of the query cell.
        k: number of candidate cells to draw.
        P: raster whose shape defines the index grid.

    Returns:
        float array ``[row, column]`` of the nearest candidate (the first one on ties).
    """
    candidate_points = Sampler_b_sph(k).cpu().detach().numpy()
    candidate_coordinate = np.zeros((k, 2))
    candidate_coordinate[:, 0] = (candidate_points[:, 1] / math.pi * P.shape[0]).astype(int)
    candidate_coordinate[:, 1] = (candidate_points[:, 0] / (2 * math.pi) * P.shape[1]).astype(int)

    # Vectorised nearest-candidate search. np.argmin returns the first
    # minimiser, matching the strict '<' scan it replaces. Values are
    # integer-valued floats, so the distances are exact.
    # NOTE(review): the distance ignores longitude periodicity, so a query near
    # column 0 never matches land near the last column -- confirm intended.
    query = np.array([int_coordinate[0], int_coordinate[1]], dtype=float)
    sq_dist = ((candidate_coordinate - query) ** 2).sum(axis=1)
    return candidate_coordinate[int(np.argmin(sq_dist))]


# Map the samples that are not located on land to the nearest point among k uniformly generated locations on landmass.
# If the sample is already on land, it will not be modified.
# We use this method in the example presented in our paper.
def map_to_land_randomk(samples, num_rand_loc, seed=1):
    """Move samples that fall on zero-population cells of ``P`` onto land.

    Each sample (theta, phi) is converted to a raster index
    row = int(phi / pi * H), col = int(theta / (2*pi) * W), clamped to the
    grid. If ``P`` is 0 there, the sample is replaced by the nearest of
    ``num_rand_loc`` candidate land cells returned by ``find_random``,
    converted back to spherical coordinates at that cell's top-left corner.
    Otherwise the sample is kept unchanged.

    Args:
        samples: array of shape (N, 2) with columns (theta, phi), presumably
            already wrapped by ``fit_coordinate``.
        num_rand_loc: number of candidate land cells.
        seed: passed to ``random.seed``.

    Returns:
        ``(samples, map2land_samples)``: the unchanged input and a float64
        copy with the off-land rows replaced.
    """
    # NOTE(review): this seeds Python's `random` module only. The candidate
    # draw in find_random is governed by the NumPy seed used inside `sample`.
    random.seed(seed)

    map2land_samples = np.array(samples, dtype=float)  # independent float64 copy
    n_rows, n_cols = P.shape[0], P.shape[1]

    for i in range(samples.shape[0]):
        # Clamp: phi == pi or theta == 2*pi (reachable after wrapping) would
        # otherwise index one past the grid and raise IndexError.
        row = min(max(int(samples[i, 1] / math.pi * n_rows), 0), n_rows - 1)
        col = min(max(int(samples[i, 0] / (2 * math.pi) * n_cols), 0), n_cols - 1)

        # if the point is mapped to sea, move it onto land
        if P[row, col] == 0.0:
            rand_land_loc = find_random(np.array([row, col]), num_rand_loc, P)
            # transform back to spherical coordinates (theta, phi)
            map2land_samples[i, 0] = (rand_land_loc[1] / n_cols) * 2 * np.pi
            map2land_samples[i, 1] = (rand_land_loc[0] / n_rows) * np.pi

    return samples, map2land_samples




In [None]:
##############################################################
#    Plotting samples and geodesics that join the samples    #
##############################################################


# Plot the samples under sperical coordinates and join each pair of samples (original one and the mapped one) with geodesics on sphere.
def plot_samples_spherical_coord_geodesics_continuousplot(batchsize, iteration, num_traj, num_rand_loc, num_join_arc):
    """Plot source samples, their transported images and the great-circle arcs joining them.

    ``batchsize`` source samples are pushed through the global map ``netf``
    (x -> x + netf(x)) and wrapped with ``fit_coordinate``. Samples that fall
    on a zero-population cell are replaced by ``map_to_land_randomk``. Two
    figures are saved over the land mask from ``read_uniform_map``:

    * the great-circle arcs of the first ``num_traj`` pairs, each drawn as
      ``num_join_arc`` segments and split where it crosses theta = 0 / 2*pi,
      plus up to 2000 enlarged points;
    * a scatter of all samples.

    Args:
        batchsize: number of source samples to draw.
        iteration: training iteration, used only in the saved file names.
        num_traj: maximum number of arcs to draw.
        num_rand_loc: candidate land cells passed to map_to_land_randomk.
        num_join_arc: number of straight segments used per arc.
    """
    # Pixel scale: phi in [0, pi] -> 720 rows, theta in [0, 2*pi] -> 1440 columns.
    scale = 720 / math.pi

    samples_source = Sampler_a_sph(batchsize).float()

    # Evaluate the learned map deterministically. In training mode netf's
    # dropout layers would randomly perturb the plotted displacements, and no
    # gradients are needed here. The caller's train/eval mode is restored.
    was_training = netf.training
    netf.eval()
    try:
        with torch.no_grad():
            samples_transported = samples_source + netf(samples_source)
    finally:
        netf.train(was_training)

    source_np = samples_source.cpu().numpy()
    fit_samples_transported = fit_coordinate(samples_transported).cpu().detach().numpy()
    _, map2landsample = map_to_land_randomk(fit_samples_transported, num_rand_loc, seed=1)

    scaled_samples_source = scale * source_np
    scaled_map2landsample = scale * map2landsample

    ############################## plotting samples and geodesics ################################
    fig = plt.figure(figsize=(80, 40))

    # plot world uniform population map
    U = read_uniform_map()
    plt.imshow(U, cmap='tab20c', interpolation='nearest', alpha = 0.5)

    color = 'orange'
    for indx in range(min(batchsize, num_traj)):
        theta1 = source_np[indx, 0]
        aphi1 = source_np[indx, 1]
        theta2 = map2landsample[indx, 0]
        aphi2 = map2landsample[indx, 1]

        # Great circle from u to v on the sphere of radius R:
        # u * cos(t) + u_perp * sin(t), for t in [0, angle].
        cos_angle = np.sin(aphi1) * np.sin(aphi2) * np.cos(theta1 - theta2) + np.cos(aphi1) * np.cos(aphi2)
        # Rounding can push the dot product just outside [-1, 1], where arccos is NaN.
        angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
        u = R * np.array([np.sin(aphi1) * np.cos(theta1), np.sin(aphi1) * np.sin(theta1), np.cos(aphi1)])
        v = R * np.array([np.sin(aphi2) * np.cos(theta2), np.sin(aphi2) * np.sin(theta2), np.cos(aphi2)])
        u_perp_unnorm = v - np.cos(angle) * u
        perp_norm = np.linalg.norm(u_perp_unnorm)
        if perp_norm == 0.0:
            # Identical or antipodal endpoints: no unique arc. The original
            # divided by zero here and produced only NaN coordinates.
            continue
        u_perp = R * u_perp_unnorm / perp_norm

        t = np.linspace(0, angle, num_join_arc + 1)
        x_coord = u[0] * np.cos(t) + u_perp[0] * np.sin(t)
        y_coord = u[1] * np.cos(t) + u_perp[1] * np.sin(t)
        z_coord = u[2] * np.cos(t) + u_perp[2] * np.sin(t)

        # Azimuth in [0, 2*pi) via arctan2, which is well defined at the
        # poles, unlike arccos(x / sqrt(x^2 + y^2)). The polar angle uses a
        # clipped z / r.
        theta_coord = np.mod(np.arctan2(y_coord, x_coord), 2 * math.pi)
        radius = np.sqrt(x_coord * x_coord + y_coord * y_coord + z_coord * z_coord)
        phi_coord = np.arccos(np.clip(z_coord / radius, -1.0, 1.0))

        scaled_theta_coord = scale * theta_coord
        scaled_phi_coord = scale * phi_coord

        # Segments that jump by more than pi in theta cross the map edge: draw
        # them as two pieces running to the left/right border instead.
        for c in range(num_join_arc):
            if np.abs(theta_coord[c] - theta_coord[c+1]) < math.pi:
                plt.plot([scaled_theta_coord[c], scaled_theta_coord[c+1]], [scaled_phi_coord[c], scaled_phi_coord[c+1]], c=color, alpha=0.9, linewidth=1.5)
            elif theta_coord[c] < math.pi:
                plt.plot([scaled_theta_coord[c], scale * 0], [scaled_phi_coord[c], scaled_phi_coord[c]], c=color, alpha=0.9, linewidth=1.5)
                plt.plot([scale * 2 * math.pi, scaled_theta_coord[c+1]], [scaled_phi_coord[c+1], scaled_phi_coord[c+1]], c=color, alpha=0.9, linewidth=1.5)
            else:
                plt.plot([scaled_theta_coord[c], scale * 2 * math.pi], [scaled_phi_coord[c], scaled_phi_coord[c]], c=color, alpha=0.9, linewidth=1.5)
                plt.plot([scale * 0, scaled_theta_coord[c+1]], [scaled_phi_coord[c+1], scaled_phi_coord[c+1]], c=color, alpha=0.9, linewidth=1.5)

    num_plot_sample = min(2000, batchsize)
    # Plot source samples
    plt.scatter(scaled_samples_source[:num_plot_sample, 0], scaled_samples_source[:num_plot_sample, 1], c='blue', s=20)
    # Plot the pushforwarded (transported) samples
    plt.scatter(scaled_map2landsample[:num_plot_sample, 0], scaled_map2landsample[:num_plot_sample, 1], c='green', s=20)

    plt.xlim([0, scale * 2 * math.pi])
    plt.ylim([scale * math.pi, 0])

    plt.xlabel('theta -axis')
    plt.ylabel('phi -axis')

    plt.savefig('Plot of samples under spherical coordinates (map locations on sea to land, with transport geodesic trajectory), at iteration {}'.format(iteration) + '.jpg')

    plt.close(fig)

    #################### plotting samples ###########################
    fig = plt.figure(figsize=(80, 40))
    # plot world uniform population map
    U = read_uniform_map()
    plt.imshow(U, cmap='tab20c', interpolation='nearest', alpha = 0.5)

    # Plot source samples
    plt.scatter(scaled_samples_source[:, 0], scaled_samples_source[:, 1], c='blue', s=1)
    # Plot the pushforwarded (transported) samples
    plt.scatter(scaled_map2landsample[:, 0], scaled_map2landsample[:, 1], c='green', s=1)

    plt.xlim([0, scale * 2 * math.pi])
    plt.ylim([scale * math.pi, 0])

    plt.xlabel('theta -axis')
    plt.ylabel('phi -axis')

    plt.savefig('Plot of samples under spherical coordinates (map locations on sea to land), at iteration {}'.format(iteration) + '.jpg')

    plt.close(fig)



In [None]:
###############################
#   Define Residue Network    #
###############################



# Step size h of the residual update x <- x + h * layer(x) in rnet.forward.
Resnet_stepsize = 0.5

class rnet(nn.Module):
    """Residual MLP: input Linear, residual hidden Linears with PReLU, bias-free output Linear.

    Args:
        network_length: total number of Linear layers (input + hidden + output);
            ``network_length - 2`` residual hidden layers are created.
        input_dimension: size of each input row.
        hidden_dimension: width of the hidden layers.
        output_dimension: size of each output row.

    Dropout with p = 0.24 follows every layer, including the output layer. It
    is only active while the module is in training mode.
    """

    def __init__(self, network_length, input_dimension, hidden_dimension, output_dimension):
        super().__init__()

        # Layer creation order is unchanged, so seeded initialisation is reproducible.
        self.linears = nn.ModuleList([nn.Linear(input_dimension, hidden_dimension)])
        self.linears.extend([nn.Linear(hidden_dimension, hidden_dimension) for _ in range(1, network_length-1)])
        self.linears.extend([nn.Linear(hidden_dimension, output_dimension, bias=False)])
        self.network_length = network_length
        self.dropout = nn.Dropout(0.24)
        self.prelu = nn.PReLU()  # a single PReLU shared by all hidden layers

    def initialization(self):
        """Re-initialise every weight and bias uniformly in [-0.1, 0.1]."""
        for l in self.linears:
            l.weight.data.uniform_(-0.1, 0.1)
            # The output layer is built with bias=False, so its bias is None.
            # Touching it unconditionally raised AttributeError.
            if l.bias is not None:
                l.bias.data.uniform_(-0.1, 0.1)

    def forward(self, x):
        """Apply input layer, residual hidden layers (x + h*W x, then PReLU) and output layer."""
        ll = self.linears[0]
        x = self.dropout(ll(x))

        for l in self.linears[1: self.network_length - 1]:
            x = x + Resnet_stepsize * l(x)
            x = self.dropout(self.prelu(x))

        ll = self.linears[-1]
        x = self.dropout(ll(x))

        return x



In [None]:
################################
#     Hyper parameters         #
################################


# random seed
# Seed Python's `random`, PyTorch's CPU generator and the CUDA generator so
# network initialisation is reproducible.
random.seed(8)
torch.manual_seed(8)
torch.cuda.manual_seed(8)

# Radius of the sphere (unit sphere by default).
R = 1.

# Transport map netf: (theta, phi) -> 2-d displacement; dual potential
# netphi: (theta, phi) -> scalar. Both are 4-layer residual nets of width 32.
# netf is constructed first, which keeps the seeded initialisation order.
netf = rnet(4, 2, 32, 2)
netphi = rnet(4, 2, 32, 1)
if torch.cuda.is_available():
    # Module.cuda() moves the parameters in place (and returns the module).
    netf.cuda()
    netphi.cuda()

# One Adam optimiser per network, with identical settings.
optimizerf = optim.Adam(
    netf.parameters(),
    lr=0.5 * 1e-4,
    betas=(0.9, 0.99),
)
optimizerphi = optim.Adam(
    netphi.parameters(),
    lr=0.5 * 1e-4,
    betas=(0.9, 0.99),
)

# Outer iterations, and phi / f steps per outer iteration.
iter_Num = 500_000
iter_phi = 1
iter_F = 8

# Mini-batch sizes drawn from the source and target pools.
batch_size_src = batch_size_tgt = 400

# Periods (in outer iterations) for plotting, saving and pool renewal.
period = period_rec = 100_000
priod_renew_pool = 40_000



This is the main training loop. Figures are produced periodically while it runs. Samples from the source distribution, together with their transported images, are saved as `.pt` files in the current working directory.

In [None]:
###########################################
#  Main Algorithm (max_netphi min_netf)   #
###########################################

# sample pool
# Alternating min-max training (max over netphi, min over netf).
# Each outer iteration i first runs iter_F Adam steps on netf. They minimise
#   mean_x 0.5 * ||x - T(x)||^2 - mean_x netphi(T(x)),   with T(x) = x + netf(x),
# where (theta, phi) are treated as plane coordinates. It then runs iter_phi
# Adam steps on netphi, which minimise mean netphi(T(x)) - mean netphi(y),
# with y drawn from the target pool.
# NOTE(review): the pools below are redrawn immediately at i = 0, because
# 0 % priod_renew_pool == 0.
sample_a_pool = Sampler_a_sph(40000)
sample_b_pool = Sampler_b_sph(40000)

# record outer loss:
loss_list = []

######### Outer loop (netphi-optimization) ########
for i in range(iter_Num):

    if (i % priod_renew_pool == 0):
       # renew the sample pool
       # NOTE(review): Sampler_a_sph / Sampler_b_sph are called with their fixed
       # seeds, so each "renewal" reproduces exactly the same pool.
       sample_a_pool = Sampler_a_sph(40000)
       sample_b_pool = Sampler_b_sph(40000)

    inner_loss_list = []

    for j in range(iter_F):
        ######## Inner loop (netf-optimization) #######
        # Draw mini-batches (without replacement) from the source/target pools.
        # start_time = time.time()
        indices_a = torch.randperm(len(sample_a_pool))[:batch_size_src]
        indices_b = torch.randperm(len(sample_b_pool))[:batch_size_tgt]
        # NOTE(review): sample_b is unused in this inner loop, and neither batch
        # appears to need requires_grad=True; autograd.Variable is a legacy alias.
        sample_a = autograd.Variable(sample_a_pool[indices_a].float(), requires_grad=True)
        sample_b = autograd.Variable(sample_b_pool[indices_b].float(), requires_grad=True)
        # end_time = time.time()
        # print(end_time - start_time)

        # Transported samples T(x) = x + netf(x). netf is never switched to eval
        # mode in this loop, so its dropout layers are active here.
        x_k = sample_a + netf(sample_a)

        # Clear gradients left over from the previous phi step, since
        # outer_loss.backward() also writes into netf's gradients.
        netf.zero_grad()
        netphi.zero_grad()

        # ###### Compute modified geodesic distance on sphere ######
        # # Under spherical coordinate, arcos is to sensitive when x is approaching 1, as the gradient will approach to infinity, causing the value increasing to NaN.
        # # We modify the arccos function by using its linear approximation: arccos(x) ~~ \pi/2 - x.
        # geodesic_dist = - (torch.sin(sample_a[:, 1]) * torch.sin(x_k[:, 1]) * torch.cos(sample_a[:, 0] - x_k[:, 0]) + torch.cos(sample_a[:, 1]) * torch.cos(x_k[:, 1]))
        ###### Compute l2 distance on spherical coordinate plane (choose Prime meridian as the longitude angle \theta = 0) ######
        # d((theta_1, phi_1), (theta_2, phi_2)) = ((theta_a - theta_b)^2 + (phi_1-phi_2)^2)/2 (suppose theta_a=(theta_1-pi)mod(2pi), similar for theta_b)
        # NOTE(review): the (theta - pi) mod 2*pi shift described above is not
        # applied below; the raw coordinate difference is used.
        l2_dist = (sample_a - x_k) * (sample_a - x_k)
        l2_dist = torch.sum(l2_dist, 1)/2

        netphi_xk = netphi(x_k)
        Lagrange_Multiplier_part = netphi_xk.mean()

        # in_loss = geodesic_dist.mean() - Lagrange_Multiplier_part
        # Inner objective: transport cost minus the dual term. Only optimizerf steps on it.
        in_loss = l2_dist.mean() - Lagrange_Multiplier_part
        # NOTE(review): retain_graph=True looks unnecessary, since the graph is
        # rebuilt every iteration.
        in_loss.backward(retain_graph=True)

        optimizerf.step()

        # record inner loss
        inner_loss_list.append(in_loss.cpu().data.numpy())

    # Plot the inner losses of this outer iteration every `period` iterations.
    # NOTE(review): with the 'agg' backend selected at import time, plt.show()
    # likely displays nothing; the figure is still written to disk.
    if i % period == 0:
        fig_inner_loss_plot = plt.figure()
        ax = fig_inner_loss_plot.add_subplot(111)
        ax.scatter(range(0, len(inner_loss_list)), inner_loss_list, alpha=0.7, s=25)
        ax.set_title("plot for inner losses at {} outer iteration".format(i) + '.jpg')
        fig_inner_loss_plot.savefig("Plot of inner losses at {} outer iteration".format(i) + '.jpg')
        plt.show()
        plt.close()

    ######### phi-optimization ########
    for l in range(iter_phi):

        indices_a = torch.randperm(len(sample_a_pool))[:batch_size_src]
        indices_b = torch.randperm(len(sample_b_pool))[:batch_size_tgt]
        sample_a = autograd.Variable(sample_a_pool[indices_a].float(), requires_grad=True)
        sample_b = autograd.Variable(sample_b_pool[indices_b].float(), requires_grad=True)

        x_k = sample_a + netf(sample_a)

        # Also clears the netphi gradients written by the last in_loss.backward().
        netf.zero_grad()
        netphi.zero_grad()

        # Dual objective: minimising mean phi(T(x)) - mean phi(y) maximises
        # mean phi(y) - mean phi(T(x)). Only optimizerphi steps on it.
        outer_netphi_xk = netphi(x_k)
        outer_netphi_samp_b = netphi(sample_b)
        outer_loss = outer_netphi_xk.mean() - outer_netphi_samp_b.mean()

        outer_loss.backward(retain_graph=True)
        optimizerphi.step()

        loss_list.append(outer_loss.cpu().data.numpy())

    # Print the latest inner/outer losses every 100 iterations
    if i % 100 == 0:
        print("Current iter = {}".format(i))
        print("innerloss:")
        print(in_loss.cpu().data.numpy())
        print("outerloss:")
        print(outer_loss.cpu().data.numpy())

    # Every `period` iterations: 20000 samples, 500 arcs, 2000 candidate land
    # cells, 200 segments per arc.
    if (i+1) % period == 0:
        plot_samples_spherical_coord_geodesics_continuousplot(20000, i, 500, 2000, 200)

    # Save the network parameters and the sample points
    # NOTE(review): `savepath` is unused. netf is still in training mode
    # (dropout active) here, and pushed_samples is saved without .detach().
    if (i+1) % period_rec == 0:
       samples = Sampler_a_sph(24000).float()
       pushed_samples = samples + netf(samples)
       samples = samples.cpu()
       pushed_samples = pushed_samples.cpu()

       savepath = os.getcwd()
       #  torch.save(netf, 'netf param iter={}.pt'.format(i))
       #  torch.save(netphi, 'netphi param iter={}.pt'.format(i))
       torch.save(samples, 'sample_from_source iter={}.pt'.format(i))
       torch.save(pushed_samples, 'transport_samples iter={}.pt'.format(i))




The following cell re-plots the results on the CPU from the `.pt` sample files saved by the training loop, so it can be run without a GPU or a trained network in memory.

In [None]:
######################################
#      Sampler (CPU version)         #
######################################

def sample_CPU(p, num_samples, seed=0):
    """CPU-only counterpart of ``sample``: draw ``num_samples`` cells from ``p``.

    The conversion from flat index I to spherical coordinates uses the shape
    (H, W) of the global raster ``P``:
    phi = (I / W) / H * pi and theta = (I % W) / W * 2 * pi.

    Args:
        p: 1-D probability vector of length H * W.
        num_samples: number of cells to draw (with replacement).
        seed: seed for the draw; the same seed always yields the same samples.

    Returns:
        ``(samples_spherical, samples_euclidean)``. The first is a float64 CPU
        tensor of shape (num_samples, 2) with columns (theta, phi).
    """
    # Private RandomState: the same draws as `npr.seed(seed); npr.choice(...)`,
    # without resetting NumPy's global RNG.
    # ATTENTION! this sampling takes around 0.02s for 800 samples.
    rng = npr.RandomState(seed)
    sample_Is = rng.choice(len(p), p=p, size=num_samples)
    n_rows, n_cols = P.shape[0], P.shape[1]
    samples_phi = (sample_Is / n_cols) / n_rows * np.pi
    samples_theta = (sample_Is % n_cols) / n_cols * 2 * np.pi
    samples_spherical = torch.tensor(np.vstack((samples_theta, samples_phi)).T)
    samples_euclidean = spherical_to_cartesian(samples_spherical)

    return samples_spherical, samples_euclidean


def Sampler_b_sph_CPU(n, seed=1):
    """Return ``n`` target samples (uniform-over-land ``Uflat``) as CPU spherical coordinates.

    ``seed`` defaults to the original fixed value 1, so repeated default calls
    return identical samples. Pass another seed to draw a fresh set.
    """
    return sample_CPU(Uflat, n, seed=seed)[0]



#######################################################################
#  Functions that process the coordinates after mapping (CPU version) #
#######################################################################


# Fit sperical coordinates \theta, \phi into intervals [0, 2\pi) and [0, \pi]
def fit_coordinate_cpu(samples):
    """Wrap spherical coordinates into theta in [0, 2*pi) and phi in [0, pi].

    Both angles are first reduced modulo 2*pi. A polar angle phi > pi denotes
    the same point as 2*pi - phi on the opposite meridian, so for those rows
    phi is reflected and theta is shifted by pi (modulo 2*pi).

    Args:
        samples: tensor of shape (N, >=2) with column 0 = theta, column 1 = phi.

    Returns:
        A new tensor with columns 0 and 1 wrapped. The input is not modified
        in place, and the per-element Python loop is vectorised.
    """
    two_pi = 2 * math.pi
    theta = torch.remainder(samples[:, 0], two_pi)
    phi = torch.remainder(samples[:, 1], two_pi)
    flip = phi > math.pi
    phi = torch.where(flip, two_pi - phi, phi)
    shifted_theta = torch.where(theta < math.pi, theta + math.pi, theta - math.pi)
    theta = torch.where(flip, shifted_theta, theta)

    fit_samples = samples.clone()
    fit_samples[:, 0] = theta
    fit_samples[:, 1] = phi
    return fit_samples


# Pick k random positions uniformly over the land, return the closest position to the given location int_coordinate
# We use this method in the example presented in our paper.
def find_random_cpu(int_coordinate, k, P):
    """Return the candidate land cell closest to ``int_coordinate`` (CPU sampler).

    Draws ``k`` cells via ``Sampler_b_sph_CPU``. Its default seed is fixed,
    so every call sees the same candidates. The candidates are converted to
    integer (row, column) indices of the ``P`` grid, and the one with the
    smallest squared Euclidean distance in index space is returned.

    Args:
        int_coordinate: (row, column) index pair of the query cell.
        k: number of candidate cells to draw.
        P: raster whose shape defines the index grid.

    Returns:
        float array ``[row, column]`` of the nearest candidate (the first one on ties).
    """
    candidate_points = Sampler_b_sph_CPU(k).cpu().detach().numpy()
    candidate_coordinate = np.zeros((k, 2))
    candidate_coordinate[:, 0] = (candidate_points[:, 1] / math.pi * P.shape[0]).astype(int)
    candidate_coordinate[:, 1] = (candidate_points[:, 0] / (2 * math.pi) * P.shape[1]).astype(int)

    # Vectorised nearest-candidate search. np.argmin returns the first
    # minimiser, matching the strict '<' scan it replaces.
    # NOTE(review): longitude periodicity is ignored -- confirm intended.
    query = np.array([int_coordinate[0], int_coordinate[1]], dtype=float)
    sq_dist = ((candidate_coordinate - query) ** 2).sum(axis=1)
    return candidate_coordinate[int(np.argmin(sq_dist))]


# Map the samples that are not located on land to the nearest point among k uniformly generated locations on landmass.
# If the sample is already on land, it will not be modified.
# We use this method in the example presented in our paper.
def map_to_land_randomk_cpu(samples, num_rand_loc, seed=1):
    """Move samples on zero-population cells of ``P`` onto land (CPU sampler).

    Each sample (theta, phi) is converted to a raster index
    row = int(phi / pi * H), col = int(theta / (2*pi) * W), clamped to the
    grid. If ``P`` is 0 there, the sample is replaced by the nearest of
    ``num_rand_loc`` candidate land cells returned by ``find_random_cpu``, at
    that cell's top-left corner. Otherwise it is kept unchanged.

    Args:
        samples: array of shape (N, 2) with columns (theta, phi), presumably
            already wrapped by ``fit_coordinate_cpu``.
        num_rand_loc: number of candidate land cells.
        seed: passed to ``random.seed``.

    Returns:
        ``(samples, map2land_samples)``: the unchanged input and a float64
        copy with the off-land rows replaced.
    """
    # NOTE(review): seeds Python's `random` only; the candidate draw is
    # governed by the NumPy seed used inside `sample_CPU`.
    random.seed(seed)

    map2land_samples = np.array(samples, dtype=float)  # independent float64 copy
    n_rows, n_cols = P.shape[0], P.shape[1]

    for i in range(samples.shape[0]):
        # Clamp: phi == pi or theta == 2*pi would otherwise index past the grid.
        row = min(max(int(samples[i, 1] / math.pi * n_rows), 0), n_rows - 1)
        col = min(max(int(samples[i, 0] / (2 * math.pi) * n_cols), 0), n_cols - 1)

        # if the point is mapped to sea, move it onto land
        if P[row, col] == 0.0:
            rand_land_loc = find_random_cpu(np.array([row, col]), num_rand_loc, P)
            # transform back to spherical coordinates (theta, phi)
            map2land_samples[i, 0] = (rand_land_loc[1] / n_cols) * 2 * np.pi
            map2land_samples[i, 1] = (rand_land_loc[0] / n_rows) * np.pi

    return samples, map2land_samples




###############################################################################
#    Plotting samples and geodesics that join the samples (from saved data)   #
###############################################################################


# Sphere radius used by the geodesic plotting functions below (unit sphere).
# Presumably redefined here so this cell does not depend on the
# hyper-parameter cell -- confirm.
R = 1


# Plot the samples under sperical coordinates and join each pair of samples (original one and the mapped one) with geodesics on sphere.
def plot_samples_spherical_coord(samples_source,samples_transported, batchsize, iteration, num_traj, num_rand_loc, num_join_arc):
    """Plot saved source/transported samples and the great-circle arcs joining them (CPU).

    ``samples_transported`` is wrapped with ``fit_coordinate_cpu`` (on a copy,
    so the caller's tensor is left untouched). Points on zero-population cells
    are moved by ``map_to_land_randomk_cpu``. Two figures are saved over the
    land mask:

    * great-circle arcs for the first ``num_traj`` pairs (``num_join_arc``
      segments each, split at theta = 0 / 2*pi) with up to 2000 enlarged points;
    * a scatter of all samples.

    Args:
        samples_source: tensor (N, 2) of source (theta, phi).
        samples_transported: tensor (N, 2) of transported (theta, phi).
        batchsize: number of samples considered for the arcs.
        iteration: used only in the saved file names.
        num_traj: maximum number of arcs to draw.
        num_rand_loc: candidate land cells passed to map_to_land_randomk_cpu.
        num_join_arc: number of straight segments per arc.
    """
    # Pixel scale: phi in [0, pi] -> 720 rows, theta in [0, 2*pi] -> 1440 columns.
    scale = 720 / math.pi

    source_np = samples_source.detach().cpu().numpy()
    fit_samples_transported = fit_coordinate_cpu(samples_transported.detach().clone()).detach().numpy()
    _, map2landsample = map_to_land_randomk_cpu(fit_samples_transported, num_rand_loc, seed=1)

    scaled_samples_source = scale * source_np
    scaled_map2landsample = scale * map2landsample

    ############################## plotting samples and geodesics ################################
    fig = plt.figure(figsize=(80, 40))

    # plot world uniform population map
    U = read_uniform_map()
    plt.imshow(U, cmap='tab20c', interpolation='nearest', alpha = 0.5)

    color = 'orange'
    for indx in range(min(batchsize, num_traj)):
        theta1 = source_np[indx, 0]
        aphi1 = source_np[indx, 1]
        theta2 = map2landsample[indx, 0]
        aphi2 = map2landsample[indx, 1]

        # Great circle from u to v on the sphere of radius R:
        # u * cos(t) + u_perp * sin(t), for t in [0, angle].
        cos_angle = np.sin(aphi1) * np.sin(aphi2) * np.cos(theta1 - theta2) + np.cos(aphi1) * np.cos(aphi2)
        # Rounding can push the dot product just outside [-1, 1], where arccos is NaN.
        angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
        u = R * np.array([np.sin(aphi1) * np.cos(theta1), np.sin(aphi1) * np.sin(theta1), np.cos(aphi1)])
        v = R * np.array([np.sin(aphi2) * np.cos(theta2), np.sin(aphi2) * np.sin(theta2), np.cos(aphi2)])
        u_perp_unnorm = v - np.cos(angle) * u
        perp_norm = np.linalg.norm(u_perp_unnorm)
        if perp_norm == 0.0:
            # Identical or antipodal endpoints: no unique arc (the original
            # divided by zero and produced only NaN coordinates).
            continue
        u_perp = R * u_perp_unnorm / perp_norm

        t = np.linspace(0, angle, num_join_arc + 1)
        x_coord = u[0] * np.cos(t) + u_perp[0] * np.sin(t)
        y_coord = u[1] * np.cos(t) + u_perp[1] * np.sin(t)
        z_coord = u[2] * np.cos(t) + u_perp[2] * np.sin(t)

        # Azimuth in [0, 2*pi) via arctan2 (well defined at the poles); the
        # polar angle uses a clipped z / r.
        theta_coord = np.mod(np.arctan2(y_coord, x_coord), 2 * math.pi)
        radius = np.sqrt(x_coord * x_coord + y_coord * y_coord + z_coord * z_coord)
        phi_coord = np.arccos(np.clip(z_coord / radius, -1.0, 1.0))

        scaled_theta_coord = scale * theta_coord
        scaled_phi_coord = scale * phi_coord

        # Segments jumping by more than pi in theta cross the map edge: draw
        # them as two pieces running to the left/right border.
        for c in range(num_join_arc):
            if np.abs(theta_coord[c] - theta_coord[c+1]) < math.pi:
                plt.plot([scaled_theta_coord[c], scaled_theta_coord[c+1]], [scaled_phi_coord[c], scaled_phi_coord[c+1]], c=color, alpha=0.9, linewidth=1.5)
            elif theta_coord[c] < math.pi:
                plt.plot([scaled_theta_coord[c], scale * 0], [scaled_phi_coord[c], scaled_phi_coord[c]], c=color, alpha=0.9, linewidth=1.5)
                plt.plot([scale * 2 * math.pi, scaled_theta_coord[c+1]], [scaled_phi_coord[c+1], scaled_phi_coord[c+1]], c=color, alpha=0.9, linewidth=1.5)
            else:
                plt.plot([scaled_theta_coord[c], scale * 2 * math.pi], [scaled_phi_coord[c], scaled_phi_coord[c]], c=color, alpha=0.9, linewidth=1.5)
                plt.plot([scale * 0, scaled_theta_coord[c+1]], [scaled_phi_coord[c+1], scaled_phi_coord[c+1]], c=color, alpha=0.9, linewidth=1.5)

    num_plot_sample = min(2000, batchsize)
    # Plot source samples
    plt.scatter(scaled_samples_source[:num_plot_sample, 0], scaled_samples_source[:num_plot_sample, 1], c='blue', s=20)
    # Plot the pushforwarded (transported) samples
    plt.scatter(scaled_map2landsample[:num_plot_sample, 0], scaled_map2landsample[:num_plot_sample, 1], c='green', s=20)

    plt.xlim([0, scale * 2 * math.pi])
    plt.ylim([scale * math.pi, 0])

    plt.xlabel('theta -axis')
    plt.ylabel('phi -axis')

    plt.savefig('Plot of samples under spherical coordinates (map locations on sea to land, with transport geodesic trajectory), at iteration {}'.format(iteration) + '.jpg')

    plt.close(fig)

    #################### plotting samples ###########################
    fig = plt.figure(figsize=(80, 40))
    # plot world uniform population map
    U = read_uniform_map()
    plt.imshow(U, cmap='tab20c', interpolation='nearest', alpha = 0.5)

    # Plot source samples
    plt.scatter(scaled_samples_source[:, 0], scaled_samples_source[:, 1], c='blue', s=1)
    # Plot the pushforwarded (transported) samples
    plt.scatter(scaled_map2landsample[:, 0], scaled_map2landsample[:, 1], c='green', s=1)

    plt.xlim([0, scale * 2 * math.pi])
    plt.ylim([scale * math.pi, 0])

    plt.xlabel('theta -axis')
    plt.ylabel('phi -axis')

    plt.savefig('Plot of samples under spherical coordinates (map locations on sea to land), at iteration {}'.format(iteration) + '.jpg')

    plt.close(fig)



# Plot the samples under sperical coordinates and join each pair of samples (original one and the mapped one) with geodesics on sphere.
def plot_samples_spherical_coord_L2_geodesic(samples_source,samples_transported, batchsize, iteration, num_traj, num_rand_loc, num_join_arc):
    """Plot saved samples joined by their L2 geodesics (straight lines in the (theta, phi) plane).

    ``samples_transported`` is wrapped with ``fit_coordinate_cpu`` (on a copy,
    so the caller's tensor is left untouched). Points on zero-population cells
    are moved by ``map_to_land_randomk_cpu``. One figure is saved over the
    land mask, with straight segments for the first ``num_traj`` pairs and up
    to 2000 enlarged points.

    Args:
        samples_source: tensor (N, 2) of source (theta, phi).
        samples_transported: tensor (N, 2) of transported (theta, phi).
        batchsize: number of samples considered for the segments.
        iteration: used only in the saved file name.
        num_traj: maximum number of segments to draw.
        num_rand_loc: candidate land cells passed to map_to_land_randomk_cpu.
        num_join_arc: kept for interface compatibility. The original split
            each straight line into this many collinear pieces; one segment
            draws the same line. No segments are drawn when it is < 1, as before.
    """
    # Pixel scale: phi in [0, pi] -> 720 rows, theta in [0, 2*pi] -> 1440 columns.
    scale = 720 / math.pi

    # The original also drew Sampler_b_sph_CPU(batchsize) target samples here
    # but never used them; that costly draw is dropped.
    source_np = samples_source.detach().cpu().numpy()
    fit_samples_transported = fit_coordinate_cpu(samples_transported.detach().clone()).detach().numpy()
    _, map2landsample = map_to_land_randomk_cpu(fit_samples_transported, num_rand_loc, seed=1)

    scaled_samples_source = scale * source_np
    scaled_map2landsample = scale * map2landsample

    ############################## plotting samples and geodesics ################################
    fig = plt.figure(figsize=(80, 40))

    # plot world uniform population map
    U = read_uniform_map()
    plt.imshow(U, cmap='tab20c', interpolation='nearest', alpha = 0.5)

    color = 'orange'
    if num_join_arc >= 1:
        for indx in range(min(batchsize, num_traj)):
            # L2 geodesic on the coordinate plane: the straight segment between
            # the source point and its (land-mapped) transported point.
            plt.plot([scaled_samples_source[indx, 0], scaled_map2landsample[indx, 0]],
                     [scaled_samples_source[indx, 1], scaled_map2landsample[indx, 1]],
                     c=color, alpha=0.9, linewidth=1.5)

    num_plot_sample = min(2000, batchsize)
    # Plot source samples
    plt.scatter(scaled_samples_source[:num_plot_sample, 0], scaled_samples_source[:num_plot_sample, 1], c='blue', s=20)
    # Plot the pushforwarded (transported) samples
    plt.scatter(scaled_map2landsample[:num_plot_sample, 0], scaled_map2landsample[:num_plot_sample, 1], c='green', s=20)

    plt.xlim([0, scale * 2 * math.pi])
    plt.ylim([scale * math.pi, 0])

    plt.xlabel('theta -axis')
    plt.ylabel('phi -axis')

    plt.savefig('Plot of samples under spherical coordinates (map locations on sea to land, with transport L2 geodesic trajectory), at iteration {}'.format(iteration) + '.jpg')

    plt.close(fig)



# Re-plot from the samples saved at the final training iteration.
# The loaded tensors use names that do not shadow the sampler function
# `sample` defined earlier. Rebinding `sample` to a tensor would break every
# later Sampler_a_sph / Sampler_b_sph call (and thus the GPU plotting cells).
# NOTE: torch.load unpickles its input -- only load files produced by this notebook.
final_iteration = 499999
source_samples = torch.load('sample_from_source iter={}.pt'.format(final_iteration)).detach()
transported_samples = torch.load('transport_samples iter={}.pt'.format(final_iteration)).detach()

# plot_samples_spherical_coord(source_samples, transported_samples, 20000, final_iteration, 500, 2000, 200)
plot_samples_spherical_coord_L2_geodesic(source_samples, transported_samples, 20000, final_iteration, 500, 2000, 200)



