In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
# from torchcfm.optimal_transport import OTPlanSampler

from typing import List
import time
from torchdyn.core import NeuralODE

from tqdm import tqdm
from torch.distributions.multivariate_normal import MultivariateNormal
import ot
import ot.plot
import pickle
from copy import deepcopy

# import warnings
# warnings.filterwarnings('ignore')

# 0. Data Generation

In [None]:
# Toy problem: transport a standard 2-D Gaussian (x0) onto a two-component Gaussian
# mixture (x1) whose modes sit at (-1.5, yend) and (1.5, yend).
N = 100      # number of training samples
yend = 5     # vertical location of both target modes
p = .5       # probability that a target point is drawn from the left mode

np.random.seed(0)

# Source q(x0): standard bivariate normal.
x0 = np.random.multivariate_normal([0, 0], np.eye(2), N)

# Target q(x1): a per-point Bernoulli(p) label selects the left (1) or right (0) mode,
# each an isotropic Gaussian with variance 0.05. The left mode is drawn before the right.
z_id1 = np.random.binomial(1, p, N)[:, None]
x1 = (z_id1 * np.random.multivariate_normal([-1.5, yend], .05 * np.eye(2), N)
      + (1 - z_id1) * np.random.multivariate_normal([1.5, yend], .05 * np.eye(2), N))

x0, x1 = (torch.as_tensor(arr, dtype=torch.float32) for arr in (x0, x1))

# 1. Functions

## 1.1 Common Functions

In [None]:
class MLP(torch.nn.Module):
    """Fully connected network with three SELU-activated hidden layers of width ``w``.

    Args:
        dim: number of state features.
        out_dim: output size; defaults to ``dim``.
        w: hidden-layer width.
        time_varying: if True the first layer expects one extra input column
            (the time coordinate). The flag only changes the input size; ``forward``
            does not append time itself, so callers must concatenate it to ``x``.
    """

    def __init__(self, dim, out_dim=None, w=64, time_varying=False):
        super().__init__()
        self.time_varying = time_varying
        in_features = dim + 1 if time_varying else dim
        final_features = dim if out_dim is None else out_dim

        # Layers are created in input-to-output order so parameter initialisation
        # consumes the RNG exactly as a literal Sequential(...) would.
        layers = [torch.nn.Linear(in_features, w), torch.nn.SELU()]
        for _ in range(2):
            layers += [torch.nn.Linear(w, w), torch.nn.SELU()]
        layers.append(torch.nn.Linear(w, final_features))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` (last axis = input features)."""
        return self.net(x)

In [None]:
class torch_wrapper(torch.nn.Module):
    """Wraps model to torchdyn compatible format.

    Adapts a network that takes ``[x, t]`` as one concatenated input to the
    ``forward(t, x)`` calling convention used by ``NeuralODE``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Append ``t`` as an extra column of ``x`` and evaluate the wrapped model.

        ``t`` is presumably a 0-d (or single-element) time tensor supplied by the
        ODE solver; it is repeated once per row of ``x``. Extra arguments are ignored.
        """
        n_rows = x.shape[0]
        time_column = t.repeat(n_rows).reshape(n_rows, 1)
        inputs = torch.cat((x, time_column), 1)
        return self.model(inputs)

In [None]:
def _infer_state_dim(model):
    """Infer the state dimension of a time-conditioned model from its first Linear layer.

    The ODE wrapper (``torch_wrapper``) appends one time column to the state, so the
    first Linear layer sees ``dim + 1`` input features.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            return module.in_features - 1
    raise ValueError("cannot infer the state dimension from model; pass dim explicitly")


def gen_traj(model, n_samp, nt_gen, seed, x_start = None, dim = None):
    """Integrate the learned vector field from t = 0 to t = 1 and return the trajectory.

    Args:
        model: network mapping ``[x, t]`` (state with a time column appended) to a velocity.
        n_samp: number of starting points drawn when ``x_start`` is None.
        nt_gen: number of equally spaced times in [0, 1] at which the state is recorded.
        seed: torch seed used to draw ``x_start`` when it is None (ignored otherwise).
        x_start: optional tensor of starting points. If None, ``n_samp`` standard-normal
            points of dimension ``dim`` are drawn.
        dim: state dimension used to draw ``x_start``. If None it is inferred from the
            first ``torch.nn.Linear`` layer of ``model`` (its input size minus the time
            column). Previously this silently read a notebook-global ``dim``.

    Returns:
        The tensor returned by ``NeuralODE.trajectory``. Callers in this notebook index it
        as ``traj[time_index, sample, coordinate]``.

    Raises:
        ValueError: if ``dim`` is None, ``x_start`` is None and no Linear layer is found.
    """
    node = NeuralODE(torch_wrapper(model), solver="dopri5",
                     sensitivity="adjoint", atol=1e-4, rtol=1e-4)
    if x_start is None:
        if dim is None:
            dim = _infer_state_dim(model)
        torch.manual_seed(seed)
        x_start = torch.randn(n_samp, dim)

    # Pure inference: no gradients are needed through the ODE solve.
    with torch.no_grad():
        traj = node.trajectory(x_start, t_span=torch.linspace(0, 1, nt_gen))

    return traj

In [None]:
def plot_traj(traj, nt_gen, mid_pts = True, start_color = "black", end_color = "orange"):
    """Scatter-plot a 2-D trajectory on the current matplotlib axes.

    ``traj`` is indexed as ``traj[time, sample, coord]``; coordinates 0 and 1 are
    used as x and y. The start states (first time) and end states (last time) are
    drawn prominently, and every recorded state faintly in blue. If ``mid_pts`` is
    set, the states at time index ``int(nt_gen / 2)`` are also drawn in red.
    """
    first, last = traj[0], traj[-1]
    labels = ["x0"]

    plt.scatter(first[:, 0], first[:, 1], s=4, alpha=1, c=start_color)
    if mid_pts:
        middle = traj[int(nt_gen / 2)]
        plt.scatter(middle[:, 0], middle[:, 1], s=4, alpha=1, c="red")
        labels.append("x_05")
    plt.scatter(traj[:, :, 0], traj[:, :, 1], s=1, alpha=0.1, c="blue")
    plt.scatter(last[:, 0], last[:, 1], s=4, alpha=1, c=end_color)

    # Legend entries follow the order in which the scatter calls were made.
    plt.legend(labels + ["Flow", "x1"])
    plt.xlabel("x")
    plt.ylabel("y")

## 1.2 GP-ICFM

In [None]:
def calc_r(ti, tj):
    """Pairwise time differences ``r[..., i, j] = ti[..., i] - tj[..., j]``.

    Entries that are exactly zero are replaced by 1e-15 (in place on the new tensor).
    """
    diffs = ti.unsqueeze(-1) - tj.unsqueeze(-2)
    diffs.masked_fill_(diffs == 0, 1e-15)
    return diffs
def k11(r, alpha, l):
    """Squared-exponential kernel ``alpha^2 * exp(-r^2 / (2 l^2))`` (value/value covariance)."""
    scaled = r / l
    return alpha ** 2 * torch.exp(-0.5 * scaled ** 2)
def k12(r, alpha, l):
    """Value/derivative covariance: ``(alpha^2 / l^2) * r * exp(-r^2 / (2 l^2))``."""
    coeff = alpha ** 2 / l ** 2
    return coeff * r * torch.exp(-0.5 * (r / l) ** 2)
def k22(r, alpha, l):
    """Derivative/derivative covariance: ``(alpha^2 / l^4) * (l^2 - r^2) * exp(-r^2 / (2 l^2))``."""
    coeff = alpha ** 2 / l ** 4
    return coeff * (l ** 2 - r ** 2) * torch.exp(-0.5 * (r / l) ** 2)

In [None]:
def cov_mat2(ti, tj, alpha, l, beta = 1e-3, decrease = None):
    """Batched joint covariance of a GP path and its time derivative.

    For each batch entry, builds the (2*nt, 2*nt) block matrix

        [[Sig11, Sig12],
         [Sig21, Sig22]]

    with Sig11 from k11 (value/value), Sig12 from k12 (value/derivative),
    Sig21 = Sig12^T and Sig22 from k22 (derivative/derivative). A beta-scaled
    extra term is added, chosen by ``decrease``:

    * ``None``  -> ``beta * I`` added to Sig11 only (white noise on the values);
    * ``True``  -> ``beta * (ti-1)(tj-1)^T`` added to Sig11 and ``beta * (ti-1)``
      terms added to Sig12 (all of these vanish at t = 1);
    * ``False`` -> ``beta * ti tj^T`` added to Sig11 and ``beta * ti`` terms added
      to Sig12 (all of these vanish at t = 0).

    Args:
        ti, tj: time tensors. ``calc_r(ti, tj)`` must be 3-D (nB, nt, nt), and the
            ``None`` branch adds a square ``eye(nt)``, so presumably both are (nB, nt)
            and identical (the caller in this notebook passes the same ``t_mat`` twice).
        alpha, l: kernel amplitude and length scale forwarded to k11/k12/k22.
        beta: scale of the extra term.
        decrease: ``None``, ``True`` or ``False``; selects the extra term as above.

    Returns:
        Tensor of shape (nB, 2*nt, 2*nt), symmetrized as ``(Sig + Sig^T) / 2``.

    NOTE(review): in the ``True``/``False`` branches Sig22 receives no beta term, although
    the mixed second derivative of ``beta*(ti-1)(tj-1)`` (or ``beta*ti*tj``) is ``beta``.
    Sig12 also gets an extra ``diag_embed(beta*(ti-1))`` (or ``beta*ti``) term that is
    not a derivative of the Sig11 term. Confirm both are intentional.
    """
    
    r = calc_r(ti, tj)  # pairwise differences ti - tj, shape (nB, nt, nt)
    nB = r.shape[0]
    nt = r.shape[1]
    
    if decrease is None:
        # Independent noise on the path values only.
        Sig11 = k11(r, alpha, l) + (torch.eye(nt)*beta).repeat(nB,1,1)
        Sig12 = k12(r, alpha, l)
    elif decrease:
        # Extra linear-in-(t-1) component; zero at t = 1.
        Sig11 = k11(r, alpha, l) + beta*torch.bmm((ti-1).unsqueeze(2), (tj-1).unsqueeze(1))
        Sig12 = k12(r, alpha, l) + beta*(ti-1)[:,:,None].repeat(1,1,nt) + torch.diag_embed(beta*(ti-1))
    else:
        # Extra linear-in-t component; zero at t = 0.
        Sig11 = k11(r, alpha, l) + beta*torch.bmm(ti.unsqueeze(2), tj.unsqueeze(1))
        Sig12 = k12(r, alpha, l) + beta*ti[:,:,None].repeat(1,1,nt) + torch.diag_embed(beta*ti)
    
    # Derivative/value block is the transpose of the value/derivative block.
    Sig21 = Sig12.permute(0, 2, 1)
    Sig22 = k22(r, alpha, l)
    
    block_row1 = torch.cat([Sig11, Sig12], dim=2)
    block_row2 = torch.cat([Sig21, Sig22], dim=2)
    Sig = torch.cat([block_row1, block_row2], dim = 1)
    # Enforce exact symmetry against floating-point asymmetry.
    Sig = (Sig + Sig.permute(0, 2, 1))/2
    
    return Sig

In [None]:
def _sample_mvn_numpy(mu, cov):
    """Draw one NumPy multivariate-normal sample per batch row.

    Args:
        mu: (nB, d) tensor of means.
        cov: (nB, d, d) tensor of covariances.

    Returns:
        float64 tensor of shape (nB, d).
    """
    mu_np = mu.detach().cpu().numpy()
    cov_np = cov.detach().cpu().numpy()
    out = np.zeros(mu_np.shape)
    for b in range(mu_np.shape[0]):
        out[b, :] = np.random.multivariate_normal(mu_np[b, :], cov_np[b, :, :])
    return torch.from_numpy(out)


def _repair_psd(Sig_cond):
    """Replace batch entries that fail a non-negative-eigenvalue check by a PSD surrogate.

    Each flagged (symmetric) matrix is rebuilt from its SVD as ``Vh^T diag(S + 1e-8) Vh``,
    which for a symmetric matrix replaces the eigenvalues by their absolute values plus a
    small jitter. The result is then re-symmetrized. Entries that pass the check are left
    untouched. ``Sig_cond`` is modified in place and also returned.
    """
    eigs = torch.linalg.eigvalsh(Sig_cond)
    bad = ~(eigs >= 0).all(dim=-1)
    if bad.any():
        _, S, Vh = torch.linalg.svd(Sig_cond[bad])
        V = Vh.transpose(-2, -1)
        repaired = V @ torch.diag_embed(S + 1e-8) @ Vh
        Sig_cond[bad] = (repaired + repaired.transpose(-2, -1)) / 2
    return Sig_cond


def samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, beta = 1e-3, decrease = None):
    """Sample GP paths and their time derivatives conditioned on observed points.

    Under a zero-mean GP prior whose value/derivative covariances come from
    k11/k12/k22 (via ``cov_mat2``), the joint vector [x(t_mat), x'(t_mat)] is
    conditioned on x(t_obs) = x_obs. One sample is drawn per batch entry and per
    coordinate; coordinates are sampled independently with a shared covariance.

    Args:
        t_mat: (nB, nt) times at which each path is sampled.
        alpha, l: kernel amplitude and length scale.
        x_obs: (nB, nt_obs, dim) observed positions.
        t_obs: (nt_obs,) observation times, shared by the whole batch.
        beta: scale of the extra covariance term.
        decrease: ``None`` adds beta white noise (on the observation block as well),
            ``True``/``False`` add beta-scaled linear terms in (t - 1) / t.

    Returns:
        ``(x_samps, dx_samps)``, each a float tensor of shape (nB, nt, dim).
    """
    nB = x_obs.shape[0]
    dim = x_obs.shape[2]
    nt = t_mat.shape[1]
    nt_obs = t_obs.shape[0]

    r_obs_x = calc_r(t_obs, t_mat)      # (nB, nt_obs, nt)
    r_obs_obs = calc_r(t_obs, t_obs)    # (nt_obs, nt_obs)

    # Prior covariance of [x(t_mat), x'(t_mat)] and cross-covariances with x(t_obs).
    Sig_11 = cov_mat2(t_mat, t_mat, alpha, l, beta, decrease)
    if decrease is None:
        k_obs_x = k11(r_obs_x, alpha, l)
        k_obs_dx = k12(r_obs_x, alpha, l)
        Sig_22_sing = k11(r_obs_obs, alpha, l) + torch.eye(nt_obs)*beta
    elif decrease:
        k_obs_x = k11(r_obs_x, alpha, l) + beta*torch.bmm((t_obs.repeat(nB,1)-1).unsqueeze(2),
                                                      (t_mat-1).unsqueeze(1))
        k_obs_dx = k12(r_obs_x, alpha, l) + beta*(t_obs.repeat(nB,1)-1)[:,:,None].repeat(1,1,nt)
        Sig_22_sing = k11(r_obs_obs, alpha, l) + beta*torch.outer((t_obs-1), (t_obs-1))
    else:
        k_obs_x = k11(r_obs_x, alpha, l) + beta*torch.bmm(t_obs.repeat(nB,1).unsqueeze(2),
                                                      t_mat.unsqueeze(1))
        k_obs_dx = k12(r_obs_x, alpha, l) + beta*t_obs.repeat(nB,1)[:,:,None].repeat(1,1,nt)
        Sig_22_sing = k11(r_obs_obs, alpha, l) + beta*torch.outer(t_obs, t_obs)

    Sig_21 = torch.cat([k_obs_x, k_obs_dx], dim=2)   # (nB, nt_obs, 2*nt)
    Sig_12 = Sig_21.permute(0, 2, 1)

    # The observation covariance is the same for every batch entry: invert it once.
    Sig_22_inv = torch.linalg.inv(Sig_22_sing).repeat(nB, 1, 1)

    # Gaussian conditioning: mean = mu_A @ x_obs, cov = Sig_11 - mu_A @ Sig_21.
    mu_A = torch.bmm(Sig_12, Sig_22_inv)
    Sig_cond = Sig_11 - torch.bmm(mu_A, Sig_21)
    Sig_cond = (Sig_cond + Sig_cond.permute(0, 2, 1))/2
    Sig_cond = _repair_psd(Sig_cond)

    x_samps = torch.zeros((nB, nt, dim))
    dx_samps = torch.zeros((nB, nt, dim))

    for dd in range(dim):
        x_obs_batch = torch.reshape(x_obs[:, :, dd], (nB, nt_obs, 1))
        mu_new = torch.bmm(mu_A, x_obs_batch).reshape((nB, 2*nt))
        try:
            x_dx_samps_tmp = MultivariateNormal(loc=mu_new, covariance_matrix=Sig_cond).rsample()
        except (ValueError, RuntimeError):
            # torch rejected the covariance (argument validation or a failed Cholesky,
            # whose LinAlgError subclasses RuntimeError). Fall back to NumPy's sampler.
            x_dx_samps_tmp = _sample_mvn_numpy(mu_new, Sig_cond)

        x_samps[:, :, dd] = x_dx_samps_tmp[:, 0:nt]
        dx_samps[:, :, dd] = x_dx_samps_tmp[:, nt:(2*nt)]

    return x_samps, dx_samps

In [None]:
def GP_FM2(x_data, model, optimizer, alpha, l,
          nt, batch_size, t_obs, n_epochs, beta = 0, decrease = None,
          ImpSamp = False, beta_a = 1.0, beta_b = 0.5, storeCheck = False, epoch_check_step = 100):
    """Train ``model`` by flow matching on conditional GP paths.

    For every minibatch of paths in ``x_data``:
    1. The observation at index 0 (along axis 1) is replaced by fresh standard-normal noise.
    2. ``nt`` times per path are drawn, uniformly or, with ``ImpSamp``, from Beta(beta_a, beta_b).
    3. Positions and velocities are sampled with ``samp_x_dx2``.
    4. ``model([x, t])`` is regressed onto the velocities with a squared-error loss.

    Args:
        x_data: (N, len(t_obs), dim) tensor of observed paths. Column 0 is overwritten
            with noise each batch; ``x_data`` itself is not modified (indexing copies).
        model: network taking ``[x, t]`` (dim + 1 features) and returning dim values.
        optimizer: optimizer over ``model``'s parameters.
        alpha, l, beta, decrease: GP covariance settings forwarded to ``samp_x_dx2``.
        nt: number of times sampled per path per minibatch.
        batch_size: paths per minibatch. ``N // batch_size`` fixed, contiguous minibatches
            are used every epoch; leftover paths that do not fill a batch are ignored.
        t_obs: observation times forwarded to ``samp_x_dx2``.
        n_epochs: number of passes over the minibatches.
        ImpSamp: draw t from Beta(beta_a, beta_b) and weight each squared error by 1/density.
        beta_a, beta_b: Beta proposal parameters.
        storeCheck: also collect deep copies of ``model.state_dict()``.
        epoch_check_step: checkpoint after epochs 0, epoch_check_step, 2*epoch_check_step, ...

    Returns:
        ``(model, losses)``, or ``(model, losses, check_pts, check_steps)`` when ``storeCheck``.
        ``losses`` has one entry per successful minibatch step.

    Raises:
        ValueError: if ``batch_size`` exceeds the number of paths N.
    """
    import warnings

    N = x_data.shape[0]
    dim = x_data.shape[2]

    if ImpSamp:
        m = torch.distributions.beta.Beta(torch.tensor([beta_a]), torch.tensor([beta_b])) # put more weight on t = 1

    nbatch = N // batch_size
    if nbatch == 0:
        raise ValueError(f"batch_size={batch_size} exceeds the number of paths N={N}")
    batch_idx = np.arange(nbatch * batch_size).reshape(nbatch, batch_size)

    losses: List[float] = []
    check_pts = []
    check_steps = []
    n_failed = 0

    model.train()
    for k in tqdm(range(n_epochs)):
        for bb in range(nbatch):
            x0 = torch.randn((batch_size, dim))
            x_obs = x_data[batch_idx[bb, :], :, :]   # advanced indexing returns a copy
            x_obs[:, 0, :] = x0

            if ImpSamp:
                t_mat = m.sample((batch_size, nt))[:, :, 0]
            else:
                t_mat = torch.rand((batch_size, nt))

            try:
                xt_batch, ut_batch = samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, beta, decrease)
            except (ValueError, RuntimeError):
                # Skip the batch: reusing a previous batch's samples would pair them
                # with the freshly drawn t_mat and corrupt the regression target.
                n_failed += 1
                continue

            t = torch.reshape(t_mat, (-1, 1))
            xt = torch.reshape(xt_batch, (-1, dim))
            ut = torch.reshape(ut_batch, (-1, dim))

            vt = model(torch.cat([xt, t], dim=-1))
            sq_err = (vt - ut) ** 2
            if ImpSamp:
                # Row weights 1/q(t) have shape (batch_size*nt, 1) and broadcast over the
                # dim columns of sq_err (a trailing [:, None] would form a B x B outer product).
                weights = 1 / torch.exp(m.log_prob(t))
                loss = torch.mean(weights * sq_err)
            else:
                loss = torch.mean(sq_err)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Logging
            losses.append(loss.item())

        # Checkpoint once per epoch, after all of its minibatches.
        if storeCheck and k % epoch_check_step == 0:
            check_pts.append(deepcopy(model.state_dict()))
            check_steps.append(k)

    if n_failed:
        warnings.warn(f"GP_FM2: skipped {n_failed} minibatch(es) because samp_x_dx2 failed")

    if storeCheck:
        return model, losses, check_pts, check_steps
    return model, losses

## 1.3 Wasserstein Distance (W2)

In [None]:
def w_mat_dist(x1_test, x1_gen, p = 2, ot_mat = False):
    """Exact optimal-transport cost between two empirical samples with uniform weights.

    Args:
        x1_test: (n_test, d) array of reference samples.
        x1_gen: (n_gen, d) array of generated samples.
        p: ground-cost exponent.
            * p = 1 uses the Euclidean distance.
            * p = 2 uses ``ot.dist``'s default metric (documented by POT as squared
              Euclidean).
            * Any other p > 0 uses the Euclidean distance raised to the power p.
            The returned cost is therefore W_p^p rather than W_p.
        ot_mat: if True, also compute and return the optimal transport plan.

    Returns:
        ``(G0, d)``: the (n_test, n_gen) plan (None unless ``ot_mat``) and the transport cost.

    Raises:
        ValueError: if ``p`` is not positive (previously this surfaced as an
            UnboundLocalError for any p other than 1 or 2).
    """
    if p <= 0:
        raise ValueError(f"p must be positive, got {p}")

    n_test = x1_test.shape[0]
    n_gen = x1_gen.shape[0]

    # Uniform weights on both empirical samples.
    a = np.ones((n_test,)) / n_test
    b = np.ones((n_gen,)) / n_gen

    if p == 1:
        M = ot.dist(x1_test, x1_gen, metric='euclidean')
    elif p == 2:
        M = ot.dist(x1_test, x1_gen)
    else:
        M = ot.dist(x1_test, x1_gen, metric='euclidean') ** p

    G0 = None
    if ot_mat:
        # Solve once and read the cost off the plan instead of solving again with emd2.
        G0 = ot.emd(a, b, M)
        d = float(np.sum(G0 * M))
    else:
        d = ot.emd2(a, b, M)

    return G0, d

# 2. Example Conditional GP Paths

In [None]:
# Visual check of the four covariance variants. Draw 50 conditional GP paths between
# the first coordinates of two (x0, x1) pairs and plot the first pair's paths against t.
torch.manual_seed(0)  # fixes t_mat and the GP draws so the figure is reproducible
t_mat = torch.rand((2,100))
t_obs = torch.tensor([0., 1.])
x_obs = torch.zeros((2, 2, 1))
x_obs[:, 0, 0] = x0[:2, 0]   # path starts (t = 0)
x_obs[:, 1, 0] = x1[:2, 0]   # path ends (t = 1)

alpha = 1
l = 1
beta_all = [0, 1e-2, 1e-2, 1e-2]
decrease_all = [None, None, True, False]
variant_labels = ["0", "noise", "decrease", "increase"]

plt.rcParams['figure.figsize'] = [10, 3]
fig, axs = plt.subplots(1, 4)
for tt, ax in enumerate(axs.flatten()):
    for _ in range(50):
        x_samp, dx_samp = samp_x_dx2(t_mat, alpha, l, x_obs,
                                     t_obs, beta = beta_all[tt], decrease = decrease_all[tt])
        ax.scatter(t_mat[0,:], x_samp[0,:,:], s = 2)
    ax.set_title(f"{variant_labels[tt]} (beta={beta_all[tt]:g})")
    ax.set_xlabel("t")
axs[0].set_ylabel("x(t), first coordinate")
plt.tight_layout()
plt.show()
plt.rcParams['figure.figsize'] = [6, 4]
plt.rcParams['figure.figsize'] = [6, 4]

# 3. Fitting

In [None]:
# ---- Problem / sampling sizes ----
dim = x1.shape[1]
sigma = 1e-2     # NOTE(review): not referenced by any other cell shown here
n_samp = 1000    # NOTE(review): not referenced by any other cell shown here
nt_gen = 100     # NOTE(review): only same-named function parameters appear elsewhere

# x_data[:, 0] is a placeholder for the source point (GP_FM2 overwrites it with fresh
# noise every minibatch); x_data[:, 1] holds the target samples x1.
x_data = torch.stack([torch.zeros(N, dim), x1], dim=1)

# ---- GP path prior ----
alpha = 1                        # kernel amplitude (k11 = alpha^2 * exp(...))
l = 1                            # kernel length scale
t_obs = torch.tensor([0., 1.])   # paths are pinned at t = 0 (source) and t = 1 (target)

# ---- Optimisation ----
nt = 10             # times sampled per path per minibatch
batch_size = 100
n_epochs = 5000
lr_GP = 2e-3

In [None]:
def _fit_gp_variant(beta, decrease):
    """Build a fresh time-conditioned MLP and Adam optimiser, then train with GP_FM2.

    Reads the notebook configuration (dim, lr_GP, x_data, alpha, l, nt, batch_size,
    t_obs, n_epochs) at call time. Returns (trained model, its optimiser, losses).
    """
    net = MLP(dim=dim, time_varying=True)
    opt = torch.optim.Adam(net.parameters(), lr=lr_GP)
    net, losses = GP_FM2(x_data, net, opt, alpha,
                         l, nt, batch_size, t_obs, n_epochs, beta = beta, decrease = decrease)
    return net, opt, losses


# Train the four GP-ICFM covariance variants in the same order as before.
model_GP0, optimizer_GP0, losses_GP0 = _fit_gp_variant(beta = 0, decrease = None)
model_GP_noise, optimizer_GP_noise, losses_GP_noise = _fit_gp_variant(beta = 1e-2, decrease = None)
model_GP_dec, optimizer_GP_dec, losses_GP_dec = _fit_gp_variant(beta = 1e-2, decrease = True)
model_GP_inc, optimizer_GP_inc, losses_GP_inc = _fit_gp_variant(beta = 1e-2, decrease = False)

In [None]:
# Checkpoint locations. Override the base directory with the FM_VAR_CHANGE_DIR
# environment variable instead of editing this absolute HPC path.
import os

saveFolder = os.path.join(
    os.environ.get("FM_VAR_CHANGE_DIR", "/hpc/group/mastatlab/gw74/fm_var_change/"), "")
rep_saveFolder = os.path.join(saveFolder, "100_seeds", "")

# The replicate-training cell loads these four base checkpoints from saveFolder.
# Set SAVE_BASE_MODELS = True (after running the training cell) to (re)create them.
SAVE_BASE_MODELS = False
if SAVE_BASE_MODELS:
    for ckpt_name, trained in [("model_GP0", model_GP0), ("model_GP_noise", model_GP_noise),
                               ("model_GP_dec", model_GP_dec), ("model_GP_inc", model_GP_inc)]:
        torch.save(trained.state_dict(), saveFolder + ckpt_name + ".pt")

In [None]:
%%capture output
nSeeds = 100
n_epochs = 10000

for ll in range(0, nSeeds):
    
    model_GP0 = MLP(dim=dim, time_varying=True)
    optimizer_GP0 = torch.optim.Adam(model_GP0.parameters(), lr=lr_GP)
    model_GP0.load_state_dict(torch.load(saveFolder + "model_GP0.pt"))
    model_GP0, losses_GP0 = GP_FM2(x_data, model_GP0, optimizer_GP0, alpha,
                               l,nt, batch_size, t_obs, n_epochs, beta = 0, decrease = None)
    torch.save(model_GP0.state_dict(), rep_saveFolder + "model_GP0_" + str(ll) + ".pt")
    
    model_GP_noise = MLP(dim=dim, time_varying=True)
    optimizer_GP_noise = torch.optim.Adam(model_GP_noise.parameters(), lr=lr_GP)
    model_GP_noise.load_state_dict(torch.load(saveFolder + "model_GP_noise.pt"))
    model_GP_noise, losses_GP_noise = GP_FM2(x_data, model_GP_noise, optimizer_GP_noise, alpha,
                                         l,nt, batch_size, t_obs, n_epochs,
                                             beta = 1e-2, decrease = None)
    torch.save(model_GP_noise.state_dict(), rep_saveFolder + "model_GP_noise_" + str(ll) + ".pt")
    
    model_GP_dec = MLP(dim=dim, time_varying=True)
    optimizer_GP_dec = torch.optim.Adam(model_GP_dec.parameters(), lr=lr_GP)
    model_GP_dec.load_state_dict(torch.load(saveFolder + "model_GP_dec.pt"))
    model_GP_dec, losses_GP_dec = GP_FM2(x_data, model_GP_dec, optimizer_GP_dec, alpha,
                                   l,nt, batch_size, t_obs, n_epochs, beta = 1e-2, decrease = True)
    torch.save(model_GP_dec.state_dict(), rep_saveFolder + "model_GP_dec_" + str(ll) + ".pt")
    
    model_GP_inc = MLP(dim=dim, time_varying=True)
    optimizer_GP_inc = torch.optim.Adam(model_GP_inc.parameters(), lr=lr_GP)
    model_GP_inc.load_state_dict(torch.load(saveFolder + "model_GP_inc.pt"))
    model_GP_inc, losses_GP_inc = GP_FM2(x_data, model_GP_inc, optimizer_GP_inc, alpha,
                                   l,nt, batch_size, t_obs, n_epochs, beta = 1e-2, decrease = False)
    torch.save(model_GP_inc.state_dict(), rep_saveFolder + "model_GP_inc_" + str(ll) + ".pt")

# 4. W2 Evaluation on Held-out Target Samples

In [None]:
# Held-out target sample from the same two-mode mixture as x1, drawn with a different
# NumPy seed (1 instead of 0). The left mode is drawn before the right, as for x1.
N_test = 1000
np.random.seed(1)
z_id1_test = np.random.binomial(1, p, N_test)[:, None]
x1_test = torch.as_tensor(
    z_id1_test * np.random.multivariate_normal([-1.5, yend], .05 * np.eye(2), N_test)
    + (1 - z_id1_test) * np.random.multivariate_normal([1.5, yend], .05 * np.eye(2), N_test),
    dtype=torch.float32,
)

In [None]:
# For every replicate seed and GP-ICFM variant, push x1_test.shape[0] standard-normal
# points through the trained flow and score the endpoints against x1_test. The starting
# noise is seeded by the replicate index, so all variants share it within a seed.
gp_variant_names = ["GP0", "GP_noise", "GP_dec", "GP_inc"]
x1_test_np = x1_test.numpy()   # loop-invariant
dist_by_variant = {name: np.zeros(nSeeds) for name in gp_variant_names}

for ss in range(nSeeds):
    for name in gp_variant_names:
        model = MLP(dim=dim, time_varying=True)
        model.load_state_dict(torch.load(rep_saveFolder + "model_" + name + "_" + str(ss) + ".pt"))
        # nt_gen = 2: only the states at t = 0 and t = 1 are recorded.
        traj = gen_traj(model, x1_test.shape[0], 2, ss)
        _, dist_by_variant[name][ss] = w_mat_dist(x1_test_np, traj[-1, :, :].numpy(), p = 2)

dAll_GP0 = dist_by_variant["GP0"]
dAll_GP_noise = dist_by_variant["GP_noise"]
dAll_GP_dec = dist_by_variant["GP_dec"]
dAll_GP_inc = dist_by_variant["GP_inc"]

In [None]:
# Mean +- standard deviation, across replicate seeds, of the cost returned by w_mat_dist.
# NOTE(review): w_mat_dist(p=2) uses ot.dist's default metric, which POT documents as
# squared Euclidean, so these numbers are presumably W2^2 rather than W2. Confirm.
summary_rows = [
    ('I-GP-CFM, 0: {:.3f} +- {:.3f}', dAll_GP0),
    ('I-GP-CFM, noise: {:.3f} +- {:.3f}', dAll_GP_noise),
    ('I-GP-CFM, decrease: {:.3f} +- {:.3f}', dAll_GP_dec),
    ('I-GP-CFM, increase: {:.3f} +- {:.3f}', dAll_GP_inc),
]
for template, dists in summary_rows:
    print(template.format(np.mean(dists), np.std(dists)))