In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
# from torchcfm.optimal_transport import OTPlanSampler

from typing import List
import time
from torchdyn.core import NeuralODE

from tqdm import tqdm
from torch.distributions.multivariate_normal import MultivariateNormal

# 0. Data Generation

In [None]:
# Simulation settings: number of paths and the vertical position of the end points.
N = 100
yend = 20

np.random.seed(0)
p = .5
# One Bernoulli(p) label per path, kept as a column so it broadcasts over (N, 2) draws.
# The same labels are reused for all three time points.
z_id0 = np.random.binomial(1, p, N)[:, None]


def _two_cluster_draw(mean_if_one, mean_if_zero):
    """Draw N points, taking each row from one of two unit-covariance Gaussians.

    Rows with label 1 in ``z_id0`` come from the Gaussian centred at
    ``mean_if_one``; rows with label 0 from the one centred at ``mean_if_zero``.
    Both candidate draws are always generated (label-1 component first).
    """
    unit_cov = [[1, 0], [0, 1]]
    draw_one = np.random.multivariate_normal(mean_if_one, unit_cov, N)
    draw_zero = np.random.multivariate_normal(mean_if_zero, unit_cov, N)
    return z_id0 * draw_one + (1 - z_id0) * draw_zero


# Start points q(x0), intermediate points and end points, drawn in this order.
# The intermediate clusters sit on the opposite side in x from the start/end clusters.
x0 = _two_cluster_draw([-4, 0], [4, 0])
x_05 = _two_cluster_draw([3, yend/2], [-3, yend/2])
x1 = _two_cluster_draw([-4, yend], [4, yend])

# Convert the float64 NumPy arrays to float32 tensors.
x0, x1, x_05 = (torch.from_numpy(arr).to(torch.float32) for arr in (x0, x1, x_05))

In [None]:
# Figure styling: keep SVG text as text, no LaTeX rendering, 12pt font, 4x3 figure.
plt.rcParams.update({
    'svg.fonttype': 'none',
    'text.usetex': False,
    'font.size': 12,
    'figure.figsize': [4, 3],
})

# Observed samples at the three time points (start, intermediate, end).
for pts, colour in ((x0, "black"), (x_05, "red"), (x1, "orange")):
    plt.scatter(pts[:, 0], pts[:, 1], s=4, c=colour)

# Dashed segments joining each path's start -> intermediate -> end point;
# the second leg is drawn fainter than the first.
for start, mid, end in zip(x0, x_05, x1):
    plt.plot(torch.stack((start[0], mid[0])), torch.stack((start[1], mid[1])),
             c='black', alpha=0.2, linestyle='dashed')
    plt.plot(torch.stack((mid[0], end[0])), torch.stack((mid[1], end[1])),
             c='black', alpha=0.1, linestyle='dashed')

plt.plot()
plt.xlabel("x")
plt.ylabel("y")
plt.xlim([-8, 8])
plt.ylim([-6, 26]);
# plt.savefig("1_sim_samp.svg")

# 1. Functions

## 1.1 Common Functions

In [None]:
class MLP(torch.nn.Module):
    """Fully connected network with three SELU hidden layers of equal width.

    Args:
        dim: number of input features, not counting the optional time column.
        out_dim: number of output features; defaults to ``dim`` when None.
        w: width of every hidden layer.
        time_varying: when True the first layer accepts one extra input
            feature (``dim + 1`` in total).
    """

    def __init__(self, dim, out_dim=None, w=64, time_varying=False):
        super().__init__()
        self.time_varying = time_varying

        in_features = dim + 1 if time_varying else dim
        out_features = dim if out_dim is None else out_dim

        # input layer, two hidden-to-hidden layers, then a linear read-out
        layers = [torch.nn.Linear(in_features, w), torch.nn.SELU()]
        for _ in range(2):
            layers += [torch.nn.Linear(w, w), torch.nn.SELU()]
        layers.append(torch.nn.Linear(w, out_features))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.net(x)

In [None]:
class torch_wrapper(torch.nn.Module):
    """Wraps model to torchdyn compatible format.

    The wrapped model takes a single tensor; this adapter exposes a
    ``forward(t, x)`` signature and appends the time value to ``x`` as an
    extra last column before calling the model.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Call the wrapped model on ``x`` with ``t`` appended as a column.

        Extra positional and keyword arguments are accepted and ignored.
        """
        n_rows = x.shape[0]
        # repeat the time value once per row and make it a column
        t_column = t.repeat(n_rows)[:, None]
        model_input = torch.cat([x, t_column], 1)
        return self.model(model_input)

In [None]:
def gen_traj(model, n_samp, nt_gen, seed, x_start = None, n_dim = None):
    """Integrate the model's vector field over t in [0, 1] and return the trajectory.

    The model is wrapped with ``torch_wrapper`` and solved with a torchdyn
    ``NeuralODE`` (dopri5, atol = rtol = 1e-4) under ``torch.no_grad()``.

    Args:
        model: network called on ``[x, t]`` rows (see ``torch_wrapper``).
        n_samp: number of start points drawn when ``x_start`` is None.
        nt_gen: number of points in the time grid ``torch.linspace(0, 1, nt_gen)``.
        seed: value passed to ``torch.manual_seed`` before drawing start
            points; unused when ``x_start`` is given.
        x_start: optional tensor of initial states. When None, standard-normal
            start points of shape ``(n_samp, n_dim)`` are drawn.
        n_dim: dimensionality of the drawn start points. When None, the
            notebook-level variable ``dim`` is used, as before, so that
            variable must exist by the time the function is called.

    Returns:
        The result of ``node.trajectory``; the notebook indexes it as
        ``traj[time_index, sample_index, coordinate]``.
    """
    node = NeuralODE(torch_wrapper(model), solver="dopri5",
                     sensitivity="adjoint", atol=1e-4, rtol=1e-4)

    if x_start is None:
        if n_dim is None:
            # Backward-compatible fallback to the notebook-global `dim`.
            n_dim = dim
        torch.manual_seed(seed)
        x_start = torch.randn(n_samp, n_dim)

    with torch.no_grad():
        traj = node.trajectory(x_start, t_span=torch.linspace(0, 1, nt_gen))

    return traj

In [None]:
def plot_traj(traj, nt_gen, mid_pts = True, start_color = "black", end_color = "orange"):
    """Scatter-plot a trajectory array on the current matplotlib axes.

    ``traj`` is indexed as ``traj[time_index, sample_index, coordinate]`` and
    only coordinates 0 and 1 are drawn.

    Args:
        traj: trajectory array/tensor.
        nt_gen: number of time steps; only used to locate the middle step
            ``int(nt_gen/2)``.
        mid_pts: when True, also highlight the middle time step in red.
        start_color: colour of the first time step.
        end_color: colour of the last time step.
    """
    marker_kw = dict(s=4, alpha=1)
    labels = ["x0"]

    plt.scatter(traj[0, :, 0], traj[0, :, 1], c=start_color, **marker_kw)
    if mid_pts:
        half_idx = int(nt_gen/2)
        plt.scatter(traj[half_idx, :, 0], traj[half_idx, :, 1], c="red", **marker_kw)
        labels.append("x_05")

    # every time step, drawn small and faint, to show the flow
    plt.scatter(traj[:, :, 0], traj[:, :, 1], s=1, alpha=0.1, c="blue")
    plt.scatter(traj[-1, :, 0], traj[-1, :, 1], c=end_color, **marker_kw)
    labels += ["Flow", "x1"]

    # legend entries follow the order in which the scatters were added
    plt.legend(labels)

    plt.xlabel("x")
    plt.ylabel("y")

## 1.2 GP-ICFM

In [None]:
def calc_r(ti, tj):
    """Return the matrix of pairwise differences ``ti[i] - tj[j]``.

    Entries that are exactly zero are replaced by 1e-15, so the result
    contains no exact zeros for floating-point inputs.
    """
    diff = ti.unsqueeze(-1) - tj.unsqueeze(-2)
    is_zero = diff == 0
    diff[is_zero] = 1e-15
    return diff
def k11(r, alpha, l):
    """Squared-exponential kernel ``alpha^2 * exp(-r^2 / (2 l^2))`` at lag ``r``."""
    amplitude = alpha**2
    decay = torch.exp(-0.5 * ((r/l)**2))
    return amplitude * decay
def k12(r, alpha, l):
    """Return ``(alpha^2 / l^2) * r * exp(-r^2 / (2 l^2))``.

    This equals the derivative of ``k11`` with respect to the second time
    argument when ``r`` is the first time minus the second.
    """
    scale = alpha**2/l**2
    decay = torch.exp(-0.5*((r/l)**2))
    return scale*r*decay
def k22(r, alpha, l):
    """Return ``(alpha^2 / l^4) * (l^2 - r^2) * exp(-r^2 / (2 l^2))``.

    This equals the mixed second derivative of ``k11`` with respect to both
    time arguments.
    """
    scale = alpha**2/l**4
    curvature = l**2 - r**2
    decay = torch.exp(-0.5*((r/l)**2))
    return scale*curvature*decay

In [None]:
def cov_mat(ti, tj, alpha, l, sig2_diag = 1e-8):
    """Assemble the 2n x 2n block matrix ``[[k11 + jitter, k12], [k12^T, k22]]``.

    The kernels are evaluated on the lag matrix ``calc_r(ti, tj)``. The jitter
    is ``sig2_diag`` times an identity sized by the number of rows of the lag
    matrix, so ``ti`` and ``tj`` must have the same length. The assembled
    matrix is averaged with its transpose before being returned.

    Args:
        ti, tj: 1-D time tensors (the notebook always passes the same tensor twice).
        alpha, l: kernel amplitude and length-scale forwarded to ``k11``/``k12``/``k22``.
        sig2_diag: value added to the diagonal of the ``k11`` block.
    """
    lag = calc_r(ti, tj)
    jitter = torch.eye(lag.shape[0])*sig2_diag

    val_val = k11(lag, alpha, l) + jitter
    val_der = k12(lag, alpha, l)
    der_der = k22(lag, alpha, l)

    top_row = torch.cat([val_val, val_der], dim=1)
    bottom_row = torch.cat([val_der.T, der_der], dim=1)
    joint = torch.cat([top_row, bottom_row], dim=0)

    # remove any asymmetry left after assembling the blocks
    return (joint + joint.T)/2

In [None]:
def samp_x_dx(t, alpha, l, x_obs, t_obs, sig2_diag = 1e-8):
    """Sample GP path values and time-derivatives at ``t`` given observed points.

    The joint covariance of values and derivatives at ``t`` (from ``cov_mat``)
    is conditioned on the values observed at ``t_obs``. Each coordinate of
    ``x_obs`` is treated independently and shares the same conditional
    covariance; only the conditional mean differs per batch row and coordinate.

    Args:
        t: 1-D tensor of ``nt`` query times.
        alpha, l: kernel amplitude and length-scale forwarded to the kernel helpers.
        x_obs: tensor of shape ``(nB, nt_obs, dim)`` with the observed values.
        t_obs: 1-D tensor of the ``nt_obs`` observation times.
        sig2_diag: jitter added to the diagonal of the value-value covariance blocks.

    Returns:
        ``(x_samps, dx_samps)``, each of shape ``(nB, nt, dim)``: sampled
        values and sampled derivatives at the query times.
    """
    nB = x_obs.shape[0]
    dim = x_obs.shape[2]
    nt = t.shape[0]
    nt_obs = t_obs.shape[0]

    r_obs_x = calc_r(t_obs, t)
    r_obs_obs = calc_r(t_obs, t_obs)

    # Prior covariance of [values; derivatives] at the query times.
    Sig_11 = cov_mat(t, t, alpha, l, sig2_diag)
    # Cross-covariance between observed values and [values, derivatives] at t.
    k_obs_x = k11(r_obs_x, alpha, l)
    k_obs_dx = k12(r_obs_x, alpha, l)
    Sig_21 = torch.cat([k_obs_x, k_obs_dx], dim=1)
    Sig_12 = Sig_21.T

    Sig_22 = k11(r_obs_obs, alpha, l) + torch.eye(nt_obs)*sig2_diag
    Sig_22_inv = torch.linalg.inv(Sig_22)

    # Conditional covariance, symmetrised against round-off.
    Sig_cond = Sig_11 - Sig_12 @ Sig_22_inv @ Sig_21
    Sig_cond = (Sig_cond + Sig_cond.T)/2
    if not bool((torch.linalg.eigvals(Sig_cond).real>=0).all()):
        # Negative eigenvalues found: rebuild the matrix from its singular
        # values (shifted by 1e-6) so that it is positive definite.
        _, S, Vh = torch.linalg.svd(Sig_cond)
        Sig_cond  = Vh.T @ torch.diag(S + 1e-6) @ Vh
        Sig_cond = (Sig_cond + Sig_cond.T)/2

    # Linear map from observed values to the conditional mean, copied per batch row.
    mu_A = Sig_12 @ Sig_22_inv
    mu_A_expand = mu_A.repeat(nB,1,1)

    x_samps = torch.zeros((nB, nt, dim))
    dx_samps = torch.zeros((nB, nt, dim))
    for dd in range(dim):
        x_obs_tmp_batch = torch.reshape(x_obs[:,:,dd], (nB, nt_obs, 1))
        mu_new = torch.bmm(mu_A_expand, x_obs_tmp_batch).reshape((nB, 2*nt))
        try:
            x_dx_samps_tmp = MultivariateNormal(loc=mu_new, covariance_matrix=Sig_cond).rsample()
        except (ValueError, RuntimeError):
            # torch rejected the covariance (constraint check or Cholesky
            # failure). Fall back to NumPy's sampler, row by row.
            # Sig_cond is a single 2-D matrix shared by all rows; the original
            # code indexed it as Sig_cond[bb,:,:], which raised IndexError.
            mu_np = mu_new.detach().cpu().numpy()
            cov_np = Sig_cond.detach().cpu().numpy()
            draws = np.zeros((nB, 2*nt))
            for bb in range(nB):
                draws[bb,:] = np.random.multivariate_normal(mu_np[bb,:], cov_np)
            x_dx_samps_tmp = torch.from_numpy(draws).to(mu_new.dtype)

        # First nt columns are values, last nt columns are derivatives.
        x_samps[:,:,dd] = x_dx_samps_tmp[:,0:nt]
        dx_samps[:,:,dd] = x_dx_samps_tmp[:,nt:(2*nt)]

    return x_samps, dx_samps

In [None]:
def GP_FM(model, optimizer, x_data, alpha, l, nt, batch_size, t_obs, n_epochs, sig2_diag = 1e-8):
    """Train ``model`` by regressing onto GP-sampled path derivatives.

    For every mini-batch, ``nt`` times are drawn uniformly from [0, 1),
    ``samp_x_dx`` draws path values and derivatives at those times conditioned
    on the batch's observations, and the model is fitted with a mean-squared
    error between its output at ``[x_t, t]`` and the sampled derivative.

    Args:
        model: network called on rows ``[x_t, t]``.
        optimizer: optimizer over the model's parameters.
        x_data: tensor of shape ``(N, nt_obs, dim)`` with observed paths.
        alpha, l: kernel amplitude and length-scale forwarded to ``samp_x_dx``.
        nt: number of random times drawn per mini-batch.
        batch_size: paths per mini-batch. Batches are consecutive index blocks;
            if ``N`` is not a multiple of ``batch_size`` the trailing paths are
            not used (previously this raised an error).
        t_obs: 1-D tensor of observation times.
        n_epochs: number of passes over the batches.
        sig2_diag: jitter forwarded to ``samp_x_dx``.

    Returns:
        ``(model, losses)`` where ``losses`` holds one float per optimisation step.

    Raises:
        ValueError: if ``batch_size`` exceeds the number of paths.
    """
    import warnings

    N = x_data.shape[0]
    dim = x_data.shape[2]

    nbatch = N // batch_size
    if nbatch == 0:
        raise ValueError(
            f"batch_size ({batch_size}) is larger than the number of paths ({N})")
    # Consecutive index blocks; any remainder beyond nbatch*batch_size is dropped.
    batch_idx = np.arange(nbatch*batch_size).reshape(nbatch, batch_size)

    losses: List[float] = []
    model.train()
    for k in tqdm(range(n_epochs)):

        for bb in range(nbatch):
#             x0 = torch.randn((batch_size,dim))
            x_obs = x_data[batch_idx[bb,:],:,:]
#             x_obs[:,0,:] = x0

            t_batch = torch.rand(nt)
            try:
                xt_batch, ut_batch = samp_x_dx(t_batch, alpha, l, x_obs, t_obs, sig2_diag)
            except Exception as err:
                # Previously the error was swallowed and training continued
                # with the previous batch's samples paired with the new times
                # (or failed with NameError on the first step). Skip instead.
                warnings.warn(f"GP sampling failed; skipping batch {bb} of epoch {k}: {err!r}")
                continue

            # xt_batch is (batch, nt, dim) and flattens batch-major, so the
            # time vector is tiled once per path to line up row by row.
            t = t_batch.repeat(1,batch_size).T
            xt = torch.reshape(xt_batch, (-1,dim))
            ut = torch.reshape(ut_batch, (-1,dim))

            vt = model(torch.cat([xt, t], dim=-1))
            loss = torch.mean((vt - ut) ** 2)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Logging
            losses.append(loss.item())
    return model, losses

## 1.3 ICFM

In [None]:
def sample_conditional_pt(x0, x1, t, sigma):
    """Draw ``x_t`` around the straight line from ``x0`` to ``x1``.

    Returns ``t * x1 + (1 - t) * x0 + sigma * eps`` with ``eps`` standard
    normal noise shaped like ``x0``. ``t`` holds one value per row and is
    reshaped to broadcast over the remaining dimensions of ``x0``.
    """
    trailing_ones = (1,) * (x0.dim() - 1)
    t = t.reshape(-1, *trailing_ones)
    line_point = t * x1 + (1 - t) * x0
    noise = torch.randn_like(x0)
    return line_point + sigma * noise

def compute_conditional_vector_field(x0, x1):
    """Return the displacement ``x1 - x0`` used as the regression target."""
    displacement = x1 - x0
    return displacement

In [None]:
def I_FM(x1, model, optimizer, sigma = 1e-1, n_epochs = 10000, x0 = None):
    """Train ``model`` on straight-line paths from ``x0`` to ``x1``.

    Each epoch draws one time per row uniformly from [0, 1), samples ``x_t``
    with ``sample_conditional_pt`` and minimises the mean-squared error between
    the model output at ``[x_t, t]`` and the target ``x1 - x0``.

    Args:
        x1: tensor of target samples.
        model: network called on rows ``[x_t, t]``.
        optimizer: optimizer over the model's parameters.
        sigma: noise scale forwarded to ``sample_conditional_pt``.
        n_epochs: number of optimisation steps.
        x0: optional fixed source samples shaped like ``x1``. When None, fresh
            standard-normal source samples are drawn in every epoch.

    Returns:
        ``(model, losses)`` where ``losses`` holds one float per epoch.
    """
    losses: List[float] = []

    model.train()
    for k in tqdm(range(n_epochs)):
        # Use a per-epoch local so the `x0 is None` test stays true on later
        # epochs. The original rebound `x0` on the first pass and therefore
        # reused that single noise draw for the whole run.
        x0_epoch = torch.randn_like(x1) if x0 is None else x0

        # x0, x1 = ot_sampler.sample_plan(x0, y_train)
        # x1 = y_train
        # x0_ot, x1_ot = ot_sampler.sample_plan(x0, x1)

        t = torch.rand(x0_epoch.shape[0]).type_as(x0_epoch)
        xt = sample_conditional_pt(x0_epoch, x1, t, sigma=sigma)
        ut = compute_conditional_vector_field(x0_epoch, x1)
        vt = model(torch.cat([xt, t[:, None]], dim=-1))
        loss = torch.mean((vt - ut) ** 2)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Logging
        losses.append(loss.item())
    return model, losses

# 2. Fitting

In [None]:
# Dimensionality of a single point.
dim = x1.shape[1]

# Observed paths as one (N, 3, dim) tensor: start, intermediate and end point
# along axis 1, in the same order as the observation times in t_obs.
x_data = torch.stack((x0, x_05, x1), dim=1)
t_obs = torch.tensor([0, 0.5, 1])

# Kernel amplitude and length-scale.
alpha, l = 2, 1
# Random times per mini-batch and paths per mini-batch.
nt = 10
batch_size = 20

# Trajectory generation: number of samples, time-grid size, and seed.
n_samp = 100
nt_gen = 100
seed = 0

In [None]:
# Straight-line flow from standard-normal noise to the observed start points,
# trained with sigma = 0. The data x0 is passed as I_FM's target argument `x1`.
model_icfm0 = MLP(dim=dim, out_dim=dim, time_varying=True)
optimizer_icfm0 = torch.optim.Adam(model_icfm0.parameters(), lr=1e-3)
model_icfm0, _ = I_FM(x1=x0, model=model_icfm0, optimizer=optimizer_icfm0,
                      sigma=0, n_epochs=10000)

In [None]:
# Push n_samp noise draws (seed 6) through the fitted flow on a two-point time
# grid; the last time step gives the generated start points.
traj_start = gen_traj(model_icfm0, n_samp, nt_gen=2, seed=6)
x0_gen = traj_start[-1]

# import pickle
# with open("x0_gen", "wb") as fp: pickle.dump(x0_gen, fp);

## 2.1 Unconditional GP-ICFM

In [None]:
# GP flow-matching fit with kernel amplitude 1, length-scale 5 and jitter 1e-6,
# trained for 10000 epochs.
model_1_10000 = MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model_1_10000.parameters(), lr=1e-3)
model_1_10000, losses_1_10000 = GP_FM(
    model=model_1_10000,
    optimizer=optimizer,
    x_data=x_data,
    alpha=1,
    l=5,
    nt=nt,
    batch_size=batch_size,
    t_obs=t_obs,
    n_epochs=10000,
    sig2_diag=1e-6,
)

In [None]:
# Integrate the fitted flow from the generated start points; the seed is not
# used because explicit start points are supplied.
traj_1_10000 = gen_traj(model_1_10000, n_samp, nt_gen, seed=1, x_start=x0_gen)

In [None]:
# Draw the trajectories on the same axis window as the data plot and export as SVG.
plot_traj(traj=traj_1_10000, nt_gen=nt_gen)
plt.xlim(-8, 8)
plt.ylim(-6, 26)
plt.savefig("2_GP_path_un.svg")

## 2.2 Conditional GP-ICFM

In [None]:
# Second GP flow-matching fit: kernel amplitude 2, length-scale 1, jitter 1e-4.
# NOTE(review): with dim=dim + dim and time_varying=True the first Linear layer
# of this MLP expects 2*dim + 1 input features, but GP_FM calls the model on
# torch.cat([xt, t]) which has only dim + 1 columns, so training fails with a
# shape mismatch as written. The extra `dim` inputs presumably carry the
# conditioning start point (compare the commented-out x0 lines in GP_FM), but
# nothing supplies them yet. TODO: confirm the intended conditioning input.
# NOTE: this rebinds model_1_10000, optimizer and losses_1_10000, discarding the
# objects fitted in section 2.1.
model_1_10000 = MLP(dim=dim + dim, out_dim = dim, time_varying=True)
optimizer = torch.optim.Adam(model_1_10000.parameters(), lr=1e-3)

model_1_10000, losses_1_10000 = GP_FM(model_1_10000, optimizer, x_data, 2, 1, nt,
                                      batch_size, t_obs, 10000, sig2_diag = 1e-4)

In [None]:
# Integrate the model fitted in this section from the generated start points;
# the seed is not used because explicit start points are supplied.
traj_1_10000 = gen_traj(model_1_10000, n_samp, nt_gen, seed=1, x_start=x0_gen)

In [None]:
# Same figure styling as the data plot: SVG text kept as text, no LaTeX,
# 12pt font, 4x3 figure.
plt.rcParams.update({
    'svg.fonttype': 'none',
    'text.usetex': False,
    'font.size': 12,
    'figure.figsize': [4, 3],
})

plot_traj(traj=traj_1_10000, nt_gen=nt_gen)