In [12]:

import os
import time
import math
import random

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchdyn

import ot as pot

from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons

from torchcfm.conditional_flow_matching import *
from torchcfm.models.models import *
from torchcfm.utils import *

from random_processes import *
from models import *


# Output directory for artifacts of this notebook; created on first run, reused afterwards.
savedir = "savedir/MaskedCFM"
os.makedirs(name=savedir, exist_ok=True)

In [2]:
# Seed initialization
def init_seed(use_fixed=True, fixed_seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Parameters
    ----------
    use_fixed : bool, default True
        If True, use ``fixed_seed``; otherwise draw a fresh seed from ``os.urandom``.
    fixed_seed : int, default 42
        Seed applied when ``use_fixed`` is True. Must be in ``[0, 2**32 - 1]``
        because ``np.random.seed`` rejects anything outside that range.

    Returns
    -------
    int
        The seed that was applied. It is also printed, so a random run can be reproduced.
    """
    if use_fixed:
        seed = fixed_seed
    else:
        # 4 random bytes give a seed in [0, 2**32 - 1], the range np.random.seed accepts.
        # A wider draw (e.g. 64 bytes) makes np.random.seed / torch.manual_seed raise.
        seed = int.from_bytes(os.urandom(4), "little")
    print(f"Using seed: {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed


init_seed(use_fixed=True, fixed_seed=42)

Using seed: 42


42

In [None]:
# Plotting utilities
xlim = (-8, 8)
ylim = (-8, 8)


def plot_trajectories_v2(traj, n_max=2000, title=None):
    """Scatter-plot 2-D flow trajectories, e.g. the output of ``NeuralODE.trajectory``.

    Parameters
    ----------
    traj : array, shape (n_steps, n_samples, >=2)
        ``traj[0]`` is the state at the first time of ``t_span`` and ``traj[-1]``
        the state at the last one. Only the first two coordinates are plotted.
        With ``t_span = linspace(0, 1, ...)`` as used in this notebook, ``traj[0]``
        holds the prior draws (t = 0) and ``traj[-1]`` the generated samples (t = 1).
    n_max : int, default 2000
        At most this many trajectories (the first ``n_max``) are drawn.
    title : str, optional
        Axes title, e.g. to tell the causal and full models apart.
    """
    _, ax = plt.subplots(figsize=(6, 6))
    n = min(n_max, traj.shape[1])
    # The prior lives at t = 0 and the data at t = 1 in this notebook's path convention
    # (x_t = (1 - t) x0 + t x1), so the endpoint labels are z(0) -> z(1).
    ax.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label="Prior z(0)")
    ax.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c="olive", label="Flow")
    ax.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label="z(1)")
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.set_aspect("equal", adjustable="box")
    if title is not None:
        ax.set_title(title)
    ax.legend(loc="upper right")
    ax.set_xticks([])
    ax.set_yticks([])
    plt.show()

Conditional Sample Paths: $X_t \mid z = (x_0,x_1)$

1) Noisy Straight Paths: $X_t = (1-t) x_0 + t x_1 + \sigma \epsilon$ where $\epsilon \sim \mathcal{N}(0, I)$.

2) Brownian Bridge Paths: $X_t = (1-t) x_0 + t x_1 + \sigma Z_t$ where $Z_t = W_t - tW_1$ is the Brownian bridge and $W_t$ is the standard Wiener process. Alternatively, $dZ_t = \frac{-Z_t}{1-t} dt + dW_t$ with $Z_0 = 0$ so that $dX_t = \frac{x_1-X_t}{1-t} dt + \sigma dW_t$ with $X_0 = x_0$.

In [None]:
def SampleConditionalNoisyStraightPath(x0, x1, t, sigma):
    """
    Sample x_t = (1 - t) * x0 + t * x1 + sigma * eps with eps ~ N(0, I).

    The mean moves along the straight segment from x0 (t = 0) to x1 (t = 1).
    The noise level is the same constant ``sigma`` at every t.

    Parameters
    ----------
    x0 : Tensor, shape (bs, *dim)
        Source minibatch.
    x1 : Tensor, shape (bs, *dim)
        Target minibatch.
    t : FloatTensor, shape (bs)
        One time per sample. It is reshaped to (bs, 1, ..., 1) so it broadcasts
        over the trailing dimensions of x0.
    sigma : float
        Standard deviation of the additive Gaussian noise.

    Returns
    -------
    xt : Tensor, shape (bs, *dim)
    """
    singleton_dims = (1,) * (x0.dim() - 1)
    t_b = t.reshape(-1, *singleton_dims)
    noise = torch.randn_like(x0)
    mean = (1 - t_b) * x0 + t_b * x1
    return mean + sigma * noise


def SampleConditionalBrownianBridgePath(x0, x1, t, sigma):
    """
    Sample x_t = (1 - t) * x0 + t * x1 + sigma * (W_t - t * W_1), W a standard Brownian motion.

    Because Var(W_t - t W_1) = t - 2t^2 + t^2 = t (1 - t), the bridge term at a
    single time t is distributed as sqrt(t (1 - t)) * eps with eps ~ N(0, I).
    So x_t ~ N((1 - t) x0 + t x1, sigma^2 t (1 - t) I), which is what is sampled here.
    The noise vanishes at both endpoints t = 0 and t = 1.

    Parameters
    ----------
    x0 : Tensor, shape (bs, *dim)
        Source minibatch.
    x1 : Tensor, shape (bs, *dim)
        Target minibatch.
    t : FloatTensor, shape (bs)
        One time per sample in [0, 1]. It is reshaped to broadcast over the
        trailing dimensions of x0.
    sigma : float
        Diffusion scale of the bridge.

    Returns
    -------
    xt : Tensor, shape (bs, *dim)
    """
    # NOTE(review): each x_t is drawn independently from the marginal law. Pathwise-
    # correlated noise at several times on the same path would require simulating W_t itself.
    t_b = t.reshape(-1, *((1,) * (x0.dim() - 1)))
    noise = torch.randn_like(x0)
    bridge_std = sigma * torch.sqrt(t_b * (1 - t_b))
    mean = (1 - t_b) * x0 + t_b * x1
    return mean + bridge_std * noise

Conditional velocity field: $u(x,t \mid z)$ where $z= (x_0,x_1)$

1) Noisy Straight Paths: $u(x,t \mid z) = x_1 - x_0$
2) Brownian Bridge Paths: 
$$
u(x,t \mid z) = \frac{x_1 - x}{1-t} - \frac{\sigma^2}{2}\nabla \log p(x,t\mid z)
$$ 
where 
$$
p(x,t\mid z) = \mathcal{N}(x\mid m_t(z),\, \sigma^2 t(1-t) I),\quad m_t(z) = (1-t) x_0 + tx_1
$$
is the marginal distribution of the Brownian bridge between $x_0$ and $x_1$ at time $t\in [0,1]$. In closed form, this leads to 
$$
u(x,t \mid z) = (x_1-x_0) + \frac{1-2t}{2t(1-t)} (x - m_t(z)).
$$ 


In [None]:
# Conditional velocity field u(x, t | z), z = (x0, x1), for both path families above.

def ConditionalVelocityField(x0, x1, t=None, xt=None, eps=1e-8):
    """
    Compute the conditional vector field u_t(x | x0, x1).

    * Noisy straight path (default; ``t`` and ``xt`` omitted):
      u_t = x1 - x0, see Eq.(15) [1].
    * Brownian bridge path (``t`` and ``xt`` given):
      u_t(x) = (x1 - x0) + (1 - 2t) / (2 t (1 - t)) * (x - m_t),  m_t = (1 - t) x0 + t x1.
      This is the closed form of (x1 - x)/(1 - t) - sigma^2/2 * grad log p(x, t | z);
      it does not depend on sigma.

    Parameters
    ----------
    x0 : Tensor, shape (bs, *dim)
        represents the source minibatch
    x1 : Tensor, shape (bs, *dim)
        represents the target minibatch
    t : FloatTensor, shape (bs), optional
        Per-sample times; required (together with ``xt``) for the bridge field.
    xt : Tensor, shape (bs, *dim), optional
        Points at which to evaluate the bridge field, e.g. samples from
        SampleConditionalBrownianBridgePath(x0, x1, t, sigma).
    eps : float, default 1e-8
        Added to the denominator 2 t (1 - t), which vanishes at t = 0 and t = 1.

    Returns
    -------
    ut : Tensor, shape (bs, *dim)

    Raises
    ------
    ValueError
        If exactly one of ``t`` / ``xt`` is provided.

    References
    ----------
    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
    """
    straight = x1 - x0
    if t is None and xt is None:
        return straight
    if t is None or xt is None:
        raise ValueError("t and xt must be given together to evaluate the Brownian-bridge field")

    t_b = t.reshape(-1, *([1] * (x0.dim() - 1)))
    mu_t = t_b * x1 + (1 - t_b) * x0
    # sigma_t' / sigma_t for sigma_t = sigma * sqrt(t (1 - t)); eps keeps it finite at t in {0, 1}.
    coef = (1 - 2 * t_b) / (2 * t_b * (1 - t_b) + eps)
    return straight + coef * (xt - mu_t)

In [13]:
# Stochastic Lorenz system simulated with the project's SDE sampler.
# NOTE(review): the resulting `traj` is 3-D and is not used by the 2-D training below.
lorenz_drift, lorenz_dif = make_lorenz_system(
    sigma=10.0,
    rho=28.0,
    beta=8 / 3,
    noise_scale=0.5,
)


def lorenz_init_sampler(batch_size):
    """Return a (batch_size, 3) tensor of initial states drawn uniformly from [-15, 15]^3."""
    init_states = torch.empty(batch_size, 3)
    init_states.uniform_(-15, 15)
    return init_states


lorenz_sampler = TorchSDEDiffusionSampler(
    lorenz_drift,
    lorenz_dif,
    init_sampler=lorenz_init_sampler,
    t0=0.0,
    t1=10.0,
    steps=2000,
    method="euler",
)

# NOTE(review): return triple assumed to be (paths, time grid, mean path) — confirm in random_processes.
traj, times, mean_traj = lorenz_sampler.sample(num_paths=256)



Independent coupling:
$$ \pi(dx_0, dx_1) = q_0(dx_0) q_1(dx_1)$$

In [None]:
%%time
# Train a causal and an unconstrained MaskedBlockMLP side by side on the same
# minibatches, with independent coupling x0 ~ 8 Gaussians, x1 ~ moons.
checkpoints = []  # reset if rerunning; entries: (step, loss_causal, traj_causal, loss_full, traj_full)
sigma = 0.1       # NOTE(review): not used by this loop (path noise is `path_sigma`); kept for later cells
dim = 2           # NOTE(review): not used by this loop; kept for later cells
var = 0.01        # NOTE(review): not used by this loop; kept for later cells
batch_size = 256
dim_w = 32
num_layers = 3
path_sigma = 0.01  # std of the noise on the straight conditional path x_t
n_iters = 20000
eval_every = 5000
n_eval = 1024                       # prior draws integrated at each evaluation
t_eval = torch.linspace(0, 1, 100)  # integration grid, t = 0 (prior) -> t = 1 (data)

model_causal = MaskedBlockMLP(T=2, in_dim=1, out_dim=1, hidden_per_t=(dim_w,) * num_layers, causal=True, time_varying=True)
model_full = MaskedBlockMLP(T=2, in_dim=1, out_dim=1, hidden_per_t=(dim_w,) * num_layers, causal=False, time_varying=True)

optimizer_causal = torch.optim.Adam(model_causal.parameters())
optimizer_full = torch.optim.Adam(model_full.parameters())


def _make_node(model):
    """Wrap a velocity network (input: state followed by t) as a NeuralODE sampler."""
    return NeuralODE(torch_wrapper(model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)


start = time.time()
for k in range(n_iters):
    optimizer_causal.zero_grad()
    optimizer_full.zero_grad()

    x0 = sample_8gaussians(batch_size)
    x1 = sample_moons(batch_size)

    t = torch.rand(x0.shape[0]).type_as(x0)
    xt = SampleConditionalNoisyStraightPath(x0, x1, t, sigma=path_sigma)
    ut = ConditionalVelocityField(x0, x1)

    xt_in = torch.cat([xt, t[:, None]], dim=-1)  # (B, T*in_dim + 1): state then time
    vt_causal = model_causal(xt_in)              # returns (B, T*out_dim)
    vt_full = model_full(xt_in)

    loss_causal = torch.mean((vt_causal - ut) ** 2)
    loss_full = torch.mean((vt_full - ut) ** 2)

    loss_causal.backward()
    loss_full.backward()
    optimizer_causal.step()
    optimizer_full.step()

    if (k + 1) % eval_every == 0:
        end = time.time()
        print(f"{k+1}: causal loss {loss_causal.item():0.3f} time {(end - start):0.2f}")
        print(f"{k+1}: full loss {loss_full.item():0.3f} time {(end - start):0.2f}")

        node_causal = _make_node(model_causal)
        node_full = _make_node(model_full)

        # Integrate both models from the SAME prior draws so their plots are directly comparable.
        x_eval = sample_8gaussians(n_eval)
        with torch.no_grad():
            traj_causal = node_causal.trajectory(x_eval, t_span=t_eval)
            traj_full = node_full.trajectory(x_eval, t_span=t_eval)

        traj_causal_np = traj_causal.cpu().numpy()
        traj_full_np = traj_full.cpu().numpy()
        checkpoints.append((k + 1, loss_causal.item(), traj_causal_np, loss_full.item(), traj_full_np))
        plot_trajectories_v2(traj_causal_np)
        plot_trajectories_v2(traj_full_np)

        # Restart the clock after evaluation so the reported time covers training only.
        start = time.time()

In [None]:
# Debug helper (disabled): print raw weights, masks and masked weights of each MaskedLinear layer.
# NOTE(review): `model` is not defined by the training cell, which builds `model_causal`
# and `model_full`; substitute one of those before uncommenting.
# with torch.no_grad():
#     for idx, layer in enumerate(model.layers):
#         if isinstance(layer, MaskedLinear):
#             print(f"\nLayer {idx}")
#             print("weight =\n", layer.weight)
#             print("mask =\n", layer.mask)
#             print("effective weight =\n", layer.weight * layer.mask)
#             if layer.bias is not None:
#                 print("bias =\n", layer.bias)

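Optimal transport coupling:
$$ \pi_{\mathsf{OT}}(dx_0, dx_1) \triangleq \operatorname*{arg\,min}_{\pi \in \mathscr{C}(\mathbb{Q}_0, \mathbb{Q}_1)} \frac{1}{2}\int \|x_0 - x_1\|^2 \pi(dx_0, dx_1) $$
where $\mathscr{C}(\mathbb{Q}_0, \mathbb{Q}_1)$ denotes the set of couplings (joint distributions) with marginals $\mathbb{Q}_0$ and $\mathbb{Q}_1$.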

In [None]:
# %%time
# Minibatch-OT variant of the training loop (disabled): each minibatch (x0, x1) is re-paired
# with ot_sampler.sample_plan before the conditional path is sampled.
# NOTE(review): before uncommenting, confirm where these names come from:
#   - `sample_conditional_pt` / `compute_conditional_vector_field` are not defined in the visible
#     cells (this notebook defines SampleConditionalNoisyStraightPath / ConditionalVelocityField);
#     they may come from the torchcfm star-imports.
#   - `plot_trajectories` is not defined here (plot_trajectories_v2 is); it may come from torchcfm.utils.
#   - `FM` is constructed but never used in this loop.
# from torchcfm.optimal_transport import OTPlanSampler

# ot_sampler = OTPlanSampler(method="exact")
# sigma = 0.1
# dim = 2
# batch_size = 256
# model = MLP(dim=dim, time_varying=True)
# optimizer = torch.optim.Adam(model.parameters())
# FM = ConditionalFlowMatcher(sigma=sigma)

# start = time.time()
# for k in range(20000):
#     optimizer.zero_grad()

#     x0 = sample_8gaussians(batch_size)
#     x1 = sample_moons(batch_size)

#     # Draw samples from OT plan
#     x0, x1 = ot_sampler.sample_plan(x0, x1)

#     t = torch.rand(x0.shape[0]).type_as(x0)
#     xt = sample_conditional_pt(x0, x1, t, sigma=0.01)
#     ut = compute_conditional_vector_field(x0, x1)

#     vt = model(torch.cat([xt, t[:, None]], dim=-1))
#     loss = torch.mean((vt - ut) ** 2)

#     loss.backward()
#     optimizer.step()

#     if (k + 1) % 5000 == 0:
#         end = time.time()
#         print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
#         start = end
#         node = NeuralODE(
#             torch_wrapper(model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
#         )
#         with torch.no_grad():
#             traj = node.trajectory(
#                 sample_8gaussians(1024),
#                 t_span=torch.linspace(0, 1, 100),
#             )
#             plot_trajectories(traj.cpu().numpy())


