In [None]:
import matplotlib.pyplot as plt
import torch
from torch.autograd import grad, Variable
import autograd
import autograd.numpy as np
import copy
import scipy as sp
from scipy import stats
from sklearn import metrics
import sys
import ot
import gwot
from gwot import models, sim, ts, util
import gwot.bridgesampling as bs
import dcor

import importlib
import models
importlib.reload(models)
import random
import model_lenaic as model_sim

In [None]:
import os

# Cap the thread pools of every numerical backend at a single value.
# NOTE(review): OpenMP/BLAS runtimes usually read these variables when the
# library is first loaded; numpy and torch are already imported in the cell
# above, so these env vars most likely do not affect them. Consider moving this
# cell before the imports. TODO confirm.
num_threads = "8"
os.environ.update({
    name: num_threads
    for name in (
        "OMP_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "MKL_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
        "NUMEXPR_NUM_THREADS",
    )
})
# Derive torch's intra-op thread count from the same setting rather than a
# separately hard-coded literal, so the two cannot drift apart.
torch.set_num_threads(int(num_threads))

In [None]:
# All tensors in this notebook are double precision and live on the CPU.
# To run on a GPU when one is available, replace the device line with:
#   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_default_dtype(torch.float64)
device = torch.device("cpu")

In [None]:
# Size of one subplot panel; every figsize below is a multiple of this value.
PLT_CELL: float = 2.5

In [None]:
# Seed every random number generator the notebook has access to, so that
# re-running from a fresh kernel reproduces the same results. Python's `random`
# is imported at the top of the notebook and is seeded here as well.
SRAND = 0
random.seed(SRAND)
torch.manual_seed(SRAND)
np.random.seed(SRAND)

In [None]:
# M: particles per time point in the fitted TrajLoss models below
#    (their initial clouds are drawn with shape (T, M, dim)).
# N: per-time-point sample count handed to the simulator (presumably the number
#    of observed cells per snapshot — TODO confirm against gwot.sim).
M = 100
N = 64
# NOTE(review): this rebinds `sim`, shadowing the `gwot.sim` module imported at
# the top of the notebook; the module is still reachable as `gwot.sim`.
sim = gwot.sim.Simulation(
    V=model_sim.Psi,
    dV=model_sim.dPsi,
    birth_death=False,
    N=np.full(model_sim.T, N),
    T=model_sim.T,
    d=model_sim.dim,
    D=model_sim.D,
    t_final=model_sim.t_final,
    ic_func=model_sim.ic_func,
    pool=None,
)
# steps_scale: the total `sim_steps` divided across the sim.T time points,
# truncated to an integer.
sim.sample(steps_scale=int(model_sim.sim_steps / sim.T));

In [None]:
# Quick look at the simulated samples, coloured by their time index.
sample_points = plt.scatter(sim.x[:, 0], sim.x[:, 1], alpha = 0.25, c = sim.t_idx)
plt.colorbar(sample_points, label = "time index")
plt.title("Simulated samples")
plt.xlabel("x"); plt.ylabel("y");

In [None]:
# Re-execute the local `models` module so edits made since import are picked up.
# importlib.reload returns the same module object, so rebinding is a no-op.
models = importlib.reload(models)

In [None]:
# Baseline MFL model. The first argument is the initial particle cloud of shape
# (T, M, dim), drawn from a standard normal; the observed samples and their time
# indices come from the simulation above.
model = models.TrajLoss(
    torch.randn(model_sim.T, M, model_sim.dim) * 1.0,
    torch.tensor(sim.x, device=device),
    torch.tensor(sim.t_idx, device=device),
    dt=model_sim.t_final / model_sim.T,
    tau=model_sim.D,
    sigma=None,
    M=M,
    lamda_reg=0.05,
    lamda_cst=1.0,
    sigma_cst=5.0,
    branching_rate_fn=model_sim.branching_rate,
    sinkhorn_iters=250,
    device=device,
    warm_start=True,
)

In [None]:
# Plain MFL run: temp_init = temp_ratio = 1.0, i.e. no annealing schedule.
output = models.optimize(
    model,
    n_iter=2500,
    eta_final=0.25,
    tau_final=model_sim.D,
    sigma_final=0.5,
    temp_init=1.0,
    temp_ratio=1.0,
    N=M,
    dim=model_sim.dim,
    tloss=model,
    print_interval=50,
)

In [None]:
# Reload `models` again before building the annealed model, in case the module
# was edited after the baseline run. reload returns the same module object.
models = importlib.reload(models)

In [None]:
# Same model as the baseline, re-initialised from a fresh random particle cloud.
model_anneal = models.TrajLoss(
    torch.randn(model_sim.T, M, model_sim.dim) * 1.0,
    torch.tensor(sim.x, device=device),
    torch.tensor(sim.t_idx, device=device),
    dt=model_sim.t_final / model_sim.T,
    tau=model_sim.D,
    sigma=None,
    M=M,
    lamda_reg=0.05,
    lamda_cst=1.0,
    sigma_cst=5.0,
    branching_rate_fn=model_sim.branching_rate,
    sinkhorn_iters=250,
    device=device,
    warm_start=True,
)

# Annealed run: the temperature starts at 5 and is multiplied by
# (1/5)**(1/500) each iteration, i.e. it reaches 1 after 500 iterations.
output_anneal = models.optimize(
    model_anneal,
    n_iter=2500,
    eta_final=0.25,
    tau_final=model_sim.D,
    sigma_final=0.5,
    temp_init=5,
    temp_ratio=(1 / 5) ** (1 / 500),
    N=M,
    dim=model_sim.dim,
    tloss=model_anneal,
    print_interval=50,
)

In [None]:
def optimize2(model, n_iter, eps_final, eta, temp_init, temp_ratio, dim, print_interval = 50, **kwargs):
    """Langevin optimisation of `model` that holds tau fixed and anneals only eps.

    Modified version of `models.optimize`. The Langevin noise variance is
    ``2 * (model.tau + eps_t) * model.lamda_reg`` with
    ``eps_t = eps_final * max(1, temp)``. ``temp`` starts at `temp_init` and is
    multiplied by `temp_ratio` after every iteration.

    Args:
        model: trajectory model being optimised (a `models.TrajLoss` in this
            notebook). It must expose `x`, `tau`, `lamda_reg`,
            `forward_primal()` and `loss_reg.ot_losses`.
        n_iter: maximum number of iterations.
        eps_final: eps value once the temperature has decayed to 1.
        eta: step size passed to `models.LangevinGD`.
        temp_init: initial temperature.
        temp_ratio: multiplicative temperature factor applied per iteration.
        dim: dimension passed to `models.entropy_est_knn`.
        print_interval: print progress every `print_interval` iterations.
        **kwargs: forwarded unchanged to `models.LangevinGD`.

    Returns:
        Tuple ``(obj, obj_primal, x_save)``:
            obj: per-iteration loss values.
            obj_primal: per-iteration primal objective values
                (forward_primal + tau * lamda_reg * kNN entropy).
            x_save: tensor stacking `model.x` as recorded at the start of
                each iteration.
        If the loss becomes NaN, the loop stops and `x_save` is truncated to
        the iterates actually recorded rather than left zero-padded.
    """
    obj = []
    obj_primal = []
    temp_curr = temp_init
    # Clamp the starting temperature exactly like the per-iteration update at
    # the end of the loop. Otherwise a temp_init below 1 would make the first
    # step use a smaller eps than every later step.
    eps_t = eps_final*max(1, temp_curr)
    optim = models.LangevinGD(model.parameters(), eta = eta, sigma2 = 2*(model.tau + eps_t)*model.lamda_reg, **kwargs)
    # save all iterates for animations / later re-evaluation
    x_save = torch.zeros((n_iter, ) + tuple(model.x.shape))
    n_saved = n_iter
    for i in range(n_iter):
        with torch.no_grad():
            x_save[i, :, :, :] = model.x.data.clone()
        # set the noise level for this iteration
        optim.update_sigma2(2*(model.tau + eps_t)*model.lamda_reg)
        # optimise the whole model
        loss = model()
        if torch.isnan(loss):
            print("Iteration %d: loss is NaN, stopping early" % i)
            # keep x_save[i] (the last finite iterate), drop the unfilled tail
            n_saved = i + 1
            break
        with torch.no_grad():
            # Compute the primal objective before the step, since positions will
            # be updated. detach() guards against model.x requiring grad
            # (it is updated by the optimiser, presumably as a Parameter):
            # .numpy() rejects such tensors even inside no_grad().
            x = model.x.detach().cpu().numpy()
            primal = model.forward_primal()
            entropy = sum([models.entropy_est_knn(x[t, :, :], d = dim, k = 2) for t in range(x.shape[0])])
            loss_primal = primal + model.tau*model.lamda_reg*entropy
        # Langevin step
        optim.zero_grad()
        loss.backward()
        optim.step()

        with torch.no_grad():
            obj_primal.append(loss_primal.item())
            obj.append(loss.item())

        if i % print_interval == 0:
            avg_iters = np.array([l.iters_used for l in model.loss_reg.ot_losses]).mean()
            print("Iteration %d, Loss = %0.3f, Primal loss = %0.3f, Avg. iters = %0.3f, eta = %0.3f, eps = %0.3f, temp = %0.3f" % (i, loss, loss_primal, avg_iters, eta, eps_t, max(1, temp_curr)))

        # update the noise level
        temp_curr *= temp_ratio
        eps_t = eps_final*max(1, temp_curr)

    return obj, obj_primal, x_save[:n_saved]

In [None]:
# Anneal eps while holding tau fixed, using optimize2. The temperature decays
# geometrically from 250 to 1 over the first 1000 iterations; after that,
# optimize2 clamps it at 1, so eps stays at eps_final.
model_anneal2 = models.TrajLoss(torch.randn(model_sim.T, M, model_sim.dim)*1.0,
                        torch.tensor(sim.x, device = device), 
                        torch.tensor(sim.t_idx, device = device), 
                        dt = model_sim.t_final/model_sim.T, tau = model_sim.D, sigma = 0.5, M = M,
                        lamda_reg = 0.05, lamda_cst = 1.0, sigma_cst = 5.0,
                        branching_rate_fn = model_sim.branching_rate,
                        sinkhorn_iters = 1_000, device = device, warm_start = True)

# `tloss` is the model being optimised, matching the other optimize calls
# (tloss = model, tloss = model_anneal); pointing it at model_anneal here was a
# copy-paste slip.
output_anneal2 = optimize2(model_anneal2, n_iter = 2500, eps_final = 0.001, eta = 0.25, temp_init = 250, temp_ratio = (1/250)**(1/1000), N = M, dim = model_sim.dim, tloss = model_anneal2, print_interval = 50)

In [None]:
# Re-evaluate the primal objective F (forward_primal + tau * lamda_reg * kNN
# entropy) at every stored iterate of the (tau, sigma, eta)-annealed run. It uses
# an evaluation model with sigma = 0.5 and 1000 Sinkhorn iterations.
# `tqdm.auto` renders a notebook widget when available. The former
# `import tqdm` was immediately shadowed by the function import, so it is gone.
from tqdm.auto import tqdm

primal_anneal = []
for it in tqdm(range(len(output_anneal[2]))):
    model_tmp = models.TrajLoss(output_anneal[2][it],
                            torch.tensor(sim.x, device = device), 
                            torch.tensor(sim.t_idx, device = device), 
                            dt = model_sim.t_final/model_sim.T, tau = model_sim.D, sigma = 0.5, M = M,
                            lamda_reg = 0.05, lamda_cst = 1.0, sigma_cst = 5.0,
                            branching_rate_fn = model_sim.branching_rate,
                            sinkhorn_iters = 1_000, device = device, warm_start = True)
    model_tmp.forward()
    with torch.no_grad():
        primal = model_tmp.forward_primal().item()
        # one kNN entropy estimate per time point (first axis of x)
        entropy = sum([models.entropy_est_knn(model_tmp.x[t, :, :], d = model_tmp.d, k = 2) for t in range(model_tmp.x.shape[0])])
        primal_anneal.append(primal + model_tmp.tau*model_tmp.lamda_reg*entropy)

In [None]:
def _primal_objective(x_iterate):
    """Primal objective F of one stored iterate of shape (T, M, dim).

    Builds an evaluation TrajLoss (sigma = 0.5, 1000 Sinkhorn iterations)
    initialised at `x_iterate`, runs its forward pass, and returns
    forward_primal + tau * lamda_reg * (sum of per-time-point kNN entropies).
    """
    evaluator = models.TrajLoss(x_iterate,
                                torch.tensor(sim.x, device = device),
                                torch.tensor(sim.t_idx, device = device),
                                dt = model_sim.t_final/model_sim.T, tau = model_sim.D, sigma = 0.5, M = M,
                                lamda_reg = 0.05, lamda_cst = 1.0, sigma_cst = 5.0,
                                branching_rate_fn = model_sim.branching_rate,
                                sinkhorn_iters = 1_000, device = device, warm_start = True)
    evaluator.forward()
    with torch.no_grad():
        primal_term = evaluator.forward_primal().item()
        entropy_term = sum([models.entropy_est_knn(evaluator.x[t, :, :], d = evaluator.d, k = 2)
                            for t in range(evaluator.x.shape[0])])
        return primal_term + evaluator.tau*evaluator.lamda_reg*entropy_term


# F along every stored iterate of the eps-annealed run.
primal_anneal2 = [_primal_objective(output_anneal2[2][k])
                  for k in tqdm(range(len(output_anneal2[2])))]

In [None]:
def _mean_energy_distance_to_final(x_iterates):
    """Energy distance of each iterate to the last one, averaged over time points.

    `x_iterates` stacks iterates along its first axis, each of shape (T, M, dim).
    Entry k of the result is the mean over the T time points of
    dcor.energy_distance(iterate_k[t], final[t]). The final iterate itself is
    not included.
    """
    final = x_iterates[-1]
    traces = []
    for k in range(len(x_iterates) - 1):
        per_time = [dcor.energy_distance(x_t, y_t) for (x_t, y_t) in zip(x_iterates[k], final)]
        traces.append(np.mean(per_time))
    return traces


err = _mean_energy_distance_to_final(output[2])
err_anneal = _mean_energy_distance_to_final(output_anneal[2])
err_anneal2 = _mean_energy_distance_to_final(output_anneal2[2])

In [None]:
# Left: reduced objective F per iteration. Right: square root of the mean energy
# distance to each run's final iterate (log scale). One curve per method.
fig_conv, (ax_obj, ax_dist) = plt.subplots(1, 2, figsize = (3*PLT_CELL, PLT_CELL))
convergence_runs = [
    ("MFL", output[1], err),
    ("MFL + Annealing (ε)", primal_anneal2, err_anneal2),
    ("MFL + Annealing (τ, σ, η)", primal_anneal, err_anneal),
]
for run_label, objective_trace, distance_trace in convergence_runs:
    ax_obj.plot(objective_trace, label = run_label)
    ax_dist.plot(np.sqrt(np.array(distance_trace)), label = run_label)

ax_obj.set_ylim(2.2, 2.5)
ax_obj.set_title("Reduced objective $F$")
ax_obj.set_xlabel("Iteration")
ax_obj.set_ylabel("$F$")

ax_dist.set_title("Energy distance to final iterate")
ax_dist.set_xlabel("Iteration")
ax_dist.set_ylabel("energy distance")
ax_dist.set_yscale("log")
ax_dist.legend()

fig_conv.tight_layout()
fig_conv.savefig("appendix_annealing_a_final.pdf")

In [None]:
# Final particle positions of the three runs side by side, coloured by time index.
final_time_colors = np.kron(np.arange(model_sim.T), np.ones(M))
final_panels = [
    (model, "MFL"),
    (model_anneal, "MFL + Annealing (τ, σ, η)"),
    (model_anneal2, "MFL + Annealing (ε)"),
]
plt.figure(figsize = (3*PLT_CELL, 1*PLT_CELL))
for panel_idx, (fitted_model, panel_title) in enumerate(final_panels, start = 1):
    plt.subplot(1, 3, panel_idx)
    with torch.no_grad():
        flat_points = fitted_model.x.reshape(-1, model_sim.dim)
        plt.scatter(flat_points[:, 0], flat_points[:, 1], c = final_time_colors, alpha = 1, marker = ".")
    plt.title(panel_title)
    plt.xlabel("x"); plt.ylabel("y")
    plt.xlim(-2.5, 2.5); plt.ylim(-1.5, 0.5)
plt.tight_layout()
plt.savefig("appendix_annealing_b.pdf")

In [None]:
# Snapshots of the plain MFL run at selected (1-based) iterations,
# coloured by time t in [0, t_final]. Only the first panel keeps its y axis.
snapshot_colors = np.kron(np.linspace(0, model_sim.t_final, model_sim.T), np.ones(M))
fig = plt.figure(figsize = (4*PLT_CELL, 1.125*PLT_CELL))
for panel, iteration in enumerate([1, 50, 250, 500, 2500]):
    plt.subplot(1, 5, panel + 1)
    snapshot = output[2][iteration - 1].reshape(-1, model_sim.dim)
    im = plt.scatter(snapshot[:, 0], snapshot[:, 1], c = snapshot_colors, alpha = 0.25, marker = ".")
    plt.ylim(-1.75, 0.5); plt.xlim(-2, 2)
    plt.title("Iter %d" % iteration)
    plt.xlabel("x")
    plt.ylabel("y")
    if panel > 0:
        plt.gca().get_yaxis().set_visible(False)
plt.suptitle("MFL")
plt.tight_layout()
plt.savefig("appendix_annealing_b_new_final.pdf")

# Optional shared colorbar for t (disabled):
# fig.subplots_adjust(right=0.9)
# cbar_ax = fig.add_axes([0.925, 0.15, 0.025, 0.7])
# cb = fig.colorbar(im, cax=cbar_ax)
# cb.set_alpha(1)
# cb.draw_all()
# cbar_ax.set_title("$t$")

In [None]:
# Snapshots of the (tau, sigma, eta)-annealed run at selected (1-based) iterations.
anneal_colors = np.kron(np.linspace(0, model_sim.t_final, model_sim.T), np.ones(M))
fig = plt.figure(figsize = (4*PLT_CELL, 1.125*PLT_CELL))
for panel, iteration in enumerate([1, 50, 250, 500, 2500]):
    plt.subplot(1, 5, panel + 1)
    snapshot = output_anneal[2][iteration - 1].reshape(-1, model_sim.dim)
    im = plt.scatter(snapshot[:, 0], snapshot[:, 1], c = anneal_colors, alpha = 0.25, marker = ".")
    plt.ylim(-1.75, 0.5); plt.xlim(-2, 2)
    plt.title("Iter %d" % iteration)
    plt.xlabel("x")
    plt.ylabel("y")
    # only the leftmost panel shows its y axis
    if panel > 0:
        plt.gca().get_yaxis().set_visible(False)
plt.suptitle("MFL + Annealing (τ, σ, η)")
plt.tight_layout()
plt.savefig("appendix_annealing_c_new_final.pdf")

In [None]:
# Snapshots of the eps-annealed run at selected (1-based) iterations.
anneal2_colors = np.kron(np.linspace(0, model_sim.t_final, model_sim.T), np.ones(M))
fig = plt.figure(figsize = (4*PLT_CELL, 1.125*PLT_CELL))
for panel, iteration in enumerate([1, 50, 250, 500, 2500]):
    plt.subplot(1, 5, panel + 1)
    snapshot = output_anneal2[2][iteration - 1].reshape(-1, model_sim.dim)
    im = plt.scatter(snapshot[:, 0], snapshot[:, 1], c = anneal2_colors, alpha = 0.25, marker = ".")
    plt.ylim(-1.75, 0.5); plt.xlim(-2, 2)
    plt.title("Iter %d" % iteration)
    plt.xlabel("x")
    plt.ylabel("y")
    # only the leftmost panel shows its y axis
    if panel > 0:
        plt.gca().get_yaxis().set_visible(False)
plt.suptitle("MFL + Annealing (ε)")
plt.tight_layout()

plt.savefig("appendix_annealing_d_new_final.pdf")