In [None]:
import torch
import numpy as np

np.random.seed(43)  # for NumPy
torch.manual_seed(43)  # for PyTorch

from torch.utils.data import DataLoader
from icnnet import ICNNet #PICCN architecture
from toy_data_dataloader_gaussian import generate_dataset #simulated data
from gaussian_transport import get_gaussian_transport, get_mean, get_covariance #closed form gaussian transport
from mydataset import MyDataset, get_gaussian_dataset, get_gaussian_transport_dataset #dataset
from train_picnn import PICNNtrain #initial training
from train_makkuva import train_makkuva_epoch #Makkuva's training
from visualization import plot_transport #load transport map vizualization
import matplotlib.pyplot as plt

import ot
import ot.plot

In [None]:
# Automatically reload imported modules before each cell runs, so edits to the
# local helper modules (icnnet, train_picnn, train_makkuva, ...) are picked up.
%load_ext autoreload
%autoreload 2

__Settings__

In [None]:
import os

# --- Simulated data ---
r = 500  # sample size for a given distribution
N = 50  # number of distributions

# Initialization of the potentials; choose from 'Amos', 'gaussian01_map', 'empiric_gaussian_map'
init_z_f = 'empiric_gaussian_map'
init_z_g = 'empiric_gaussian_map'

# --- PICNN architecture ---
input_size = 2
layer_sizes = [input_size, 64, 64, 64, 1]
n_layers = len(layer_sizes)
context_layer_sizes = [12, 12, 12, 12, 1]

# --- Initial (PICNN) training ---
init_bunne = True  # approximate the initialization at instantiation
n_epoch_training_init = 0
lr_training_init = 0.001

# --- Output directories for checkpoints (.pth) and plots (.png) ---
filepath_pth_f = 'trained_models/training/models_f_'
filepath_pth_g = 'trained_models/training/models_g_'
filepath_plt_f = 'trained_models/training/plots_f_'
filepath_plt_g = 'trained_models/training/plots_g_'
for _output_dir in (filepath_pth_f, filepath_pth_g, filepath_plt_f, filepath_plt_g):
    os.makedirs(_output_dir, exist_ok=True)

batch_number = 0

# --- Makkuva training ---
train_freq_f = 1
train_freq_g = 5
lr_makkuva = 0.0001
regularize_f = False
regularize_g = True
lambda_proximal = 0.0001

__Simulate transport data__

In [None]:
# Simulate N distributions with r samples each, using the sizes from the
# settings cell rather than re-hard-coding them here.
dataset = generate_dataset(r=r, N=N)
# Derived datasets built by the mydataset helpers: the Gaussian version of the
# simulated data and its closed-form Gaussian transport.
gaussian_dataset = get_gaussian_dataset(dataset)
gaussian_transport_dataset = get_gaussian_transport_dataset(gaussian_dataset)

__Define initialization__

In [None]:
# Turn the mode strings chosen in the settings cell into the callables passed to
# the models as `init_z`. Each resolved callable records its mode in an
# `init_mode` attribute, so re-running this cell re-resolves it (e.g. against a
# regenerated dataset) instead of silently keeping a stale function.
_VALID_INIT_MODES = ('Amos', 'gaussian01_map', 'empiric_gaussian_map')

# Empirical means / covariances of the source (X) and target (Y) samples,
# used by the 'empiric_gaussian_map' initialization.
mean1 = get_mean(dataset.X)
cov1 = get_covariance(dataset.X)
mean2 = get_mean(dataset.Y)
cov2 = get_covariance(dataset.Y)


def _resolve_init_z(spec, m_src, c_src, m_tgt, c_tgt):
    """Return the initialization callable described by ``spec``.

    spec: one of _VALID_INIT_MODES, or a callable produced by an earlier run of
        this cell (its recorded ``init_mode`` is re-resolved).
    m_src, c_src / m_tgt, c_tgt: mean and covariance of the source / target
        samples; only used by 'empiric_gaussian_map'.

    Raises ValueError for an unknown mode, instead of leaving a string in place
    that would only fail later inside the model.
    """
    mode = getattr(spec, 'init_mode', spec)
    if callable(mode):
        # A user-supplied callable without a recorded mode: keep it as is.
        return mode
    if mode == 'Amos':
        def init_z(x):
            return x
    elif mode == 'gaussian01_map':
        def init_z(x):
            # 0.5 * ||x||^2 over the last axis, keeping that axis (size 1).
            # NOTE(review): 'Amos' returns x itself while this returns a (..., 1)
            # value — confirm which form ICNNet expects for init_z.
            return (1/2) * torch.norm(x, dim=-1, keepdim=True)**2
    elif mode == 'empiric_gaussian_map':
        def init_z(x):
            # Closed-form Gaussian transport between the empirical Gaussian fits.
            return get_gaussian_transport(u=x, cov1=c_src, cov2=c_tgt, m1=m_src, m2=m_tgt)
    else:
        raise ValueError(f'unknown init mode {mode!r}; choose from {_VALID_INIT_MODES}')
    init_z.init_mode = mode
    return init_z


# f maps X-like inputs (source -> target statistics); g uses the reverse direction.
init_z_f = _resolve_init_z(init_z_f, mean1, cov1, mean2, cov2)
init_z_g = _resolve_init_z(init_z_g, mean2, cov2, mean1, cov1)


__Build model__

In [None]:
# Two PICNNs with the same architecture: ICNNf is applied to X samples and
# ICNNg to Y samples in the cells below.
_picnn_config = dict(
    layer_sizes=layer_sizes,
    context_layer_sizes=context_layer_sizes,
    init_bunne=init_bunne,
)
ICNNf = ICNNet(**_picnn_config)
ICNNg = ICNNet(**_picnn_config)

__PICNN initial training__

In [None]:
# Pre-train both PICNNs on the closed-form Gaussian transport data. ICNNg uses
# the same Gaussian data with the X and Y roles swapped (MyDataset(Y, C, X)).
batch_size_init = 250  # mini-batch size for the initial PICNN training

print('training f')
gaussian_transport_dataloader = DataLoader(gaussian_transport_dataset, batch_size=batch_size_init, shuffle=True)
PICNNtrain(ICNNf, gaussian_transport_dataloader, init_z_f, lr=lr_training_init, epochs=n_epoch_training_init)

print('training g')
reversed_gaussian_dataset = MyDataset(gaussian_dataset.Y, gaussian_dataset.C, gaussian_dataset.X)
gaussian_transport_dataset_reversed = get_gaussian_transport_dataset(reversed_gaussian_dataset)
gaussian_transport_dataloader_reversed = DataLoader(gaussian_transport_dataset_reversed, batch_size=batch_size_init, shuffle=True)
PICNNtrain(ICNNg, gaussian_transport_dataloader_reversed, init_z_g, lr=lr_training_init, epochs=n_epoch_training_init)

__Save and plot the initial models__

In [None]:
# Checkpoint both models before any Makkuva epoch (tag "0_init"), then draw
# them with plot_transport, saving one figure per model.
filename_pth_f = f'{filepath_pth_f}/0_init.pth'
filename_pth_g = f'{filepath_pth_g}/0_init.pth'
for _net, _ckpt_path in ((ICNNf, filename_pth_f), (ICNNg, filename_pth_g)):
    torch.save(_net.state_dict(), _ckpt_path)

filename_plt_f = f'{filepath_plt_f}/0_init.png'
filename_plt_g = f'{filepath_plt_g}/0_init.png'
plot_transport(
    dataset, batch_number, ICNNf, ICNNg,
    init_z_f=init_z_f, init_z_g=init_z_g,
    filename_f=filename_plt_f, filename_g=filename_plt_g,
)

__Train using Makkuva's method__

In [None]:
# Train f and g with Makkuva's method (train_makkuva_epoch), saving a checkpoint
# and a plot for each model after every epoch.
n_epoch_makkuva = 100  # number of epochs; epochs are numbered 1..n_epoch_makkuva
batch_size_makkuva = 500  # mini-batch size for the Makkuva training

dataloader = DataLoader(dataset, batch_size=batch_size_makkuva, shuffle=True)

for epoch in range(1, n_epoch_makkuva + 1):
    print('epoch :', epoch, end=' - ')
    train_makkuva_epoch(
        ICNNf=ICNNf, ICNNg=ICNNg, dataloader=dataloader,
        init_z_f=init_z_f, init_z_g=init_z_g,
        lr=lr_makkuva, train_freq_g=train_freq_g, train_freq_f=train_freq_f,
        regularize_f=regularize_f, regularize_g=regularize_g,
        lambda_proximal=lambda_proximal,
    )

    filename_pth_f = f'{filepath_pth_f}/{epoch}.pth'
    filename_pth_g = f'{filepath_pth_g}/{epoch}.pth'
    torch.save(ICNNf.state_dict(), filename_pth_f)
    torch.save(ICNNg.state_dict(), filename_pth_g)

    filename_plt_f = f'{filepath_plt_f}/{epoch}.png'
    filename_plt_g = f'{filepath_plt_g}/{epoch}.png'
    # NOTE(review): this produces 2 figures per epoch; if plot_transport does not
    # close them, consider plt.close('all') here to avoid accumulating figures.
    plot_transport(
        dataset, batch_number, ICNNf, ICNNg,
        init_z_f=init_z_f, init_z_g=init_z_g,
        filename_f=filename_plt_f, filename_g=filename_plt_g,
    )

__Calculating the transport cost on one test distribution__

In [None]:

test = 0  # index of the distribution to evaluate

# Work on detached copies so that enabling gradients (needed below to
# differentiate the potentials w.r.t. their inputs) does not touch the tensors
# held by `dataset`.
x_i = dataset.X[test, :, :].detach().clone().requires_grad_(True)
y_i = dataset.Y[test, :, :].detach().clone().requires_grad_(True)
c_i = dataset.C[test, :, :]

# NOTE(review): locs / scales are not used in the visible cells — presumably the
# context parameters stored in columns 0 and 1 of c_i; confirm or remove.
locs = c_i[:, 0]
scales = c_i[:, 1]



def calculating_OT(mu, nu, f_mu_c):
    """Evaluate a fitted transport map on one pair of empirical distributions.

    Parameters
    ----------
    mu : np.ndarray, shape (n, d)
        Source samples.
    nu : np.ndarray, shape (m, d)
        Target samples.
    f_mu_c : np.ndarray, shape (n, d)
        Images of the source samples under the fitted map (row i is the image
        of mu[i]).

    Returns
    -------
    G0 : float
        Exact OT cost (squared Euclidean ground cost, uniform weights) between
        the target samples ``nu`` and the predicted samples ``f_mu_c``.
    delta : float
        Mean squared displacement ``mean_i ||mu_i - f_mu_c_i||^2`` minus the
        exact OT cost between ``mu`` and ``f_mu_c``. It is >= 0 up to solver
        tolerance and equals 0 when pairing mu_i with f_mu_c_i is an optimal
        coupling.
    """
    n_src, n_tgt = mu.shape[0], nu.shape[0]
    w_src = np.full(n_src, 1.0 / n_src)  # uniform weights on mu and on f_mu_c
    w_tgt = np.full(n_tgt, 1.0 / n_tgt)  # uniform weights on nu

    # Each weight vector must follow the rows / columns of its cost matrix;
    # they differ whenever mu and nu have different sample sizes.
    M = ot.dist(nu, f_mu_c)  # (n_tgt, n_src) squared Euclidean costs
    G0 = ot.emd2(w_tgt, w_src, M)

    M2 = ot.dist(mu, f_mu_c)  # (n_src, n_src)
    G1 = ot.emd2(w_src, w_src, M2)

    # Cost of the coupling mu_i -> f_mu_c_i induced by the map itself.
    G2 = np.mean(np.sum((mu - f_mu_c) ** 2, axis=1))

    return G0, G2 - G1

# The fitted maps are the input-gradients of the learned potentials: grad f is
# applied to x_i and compared against y_i; grad g (presumably the reverse map)
# is applied to y_i and compared against x_i. create_graph is not needed since
# the gradients are detached immediately.
output_model_f = ICNNf(x_i, c_i, init_z=init_z_f)
grad_model_f = torch.autograd.grad(
    outputs=output_model_f, inputs=x_i, grad_outputs=torch.ones_like(output_model_f)
)[0].detach().cpu().numpy()

output_model_g = ICNNg(y_i, c_i, init_z=init_z_g)
grad_model_g = torch.autograd.grad(
    outputs=output_model_g, inputs=y_i, grad_outputs=torch.ones_like(output_model_g)
)[0].detach().cpu().numpy()

mu = x_i.detach().cpu().numpy()
nu = y_i.detach().cpu().numpy()

G0, delta = calculating_OT(mu, nu, grad_model_f)

print('Cost of transport between the target and the predicted distribution :', G0)
print('Difference between the cost of transport and the Wasserstein distance:', delta)

# Same diagnostics for g, with the roles of source and target swapped.
G0_g, delta_g = calculating_OT(nu, mu, grad_model_g)

print('Cost of transport between the source and the predicted distribution (g) :', G0_g)
print('Difference between the cost of transport and the Wasserstein distance (g):', delta_g)