For affine interpolations, 

$$
\begin{aligned}
X_t &= \alpha_t X_1 + \beta_t X_0, \\
\dot X_t &= \dot{\alpha}_t X_1 + \dot \beta_t X_0,
\end{aligned}
$$

we have:

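$$
\begin{aligned}
v(X_t, t) &= \mathbb E_{\substack{X_0 \sim \pi_0 \\ X_1 \sim \pi_1}} \left[\dot X_t \mid X_t= \alpha_t X_1 + \beta_t X_0\right]\\
&= \mathbb E_{X_1 \sim p(\cdot \mid X_t)} \left[
\mathbb E \left[\dot X_t \mid X_t, X_1\right]
\right] \\
&= \mathbb E_{X_1 \sim p(\cdot \mid X_t)} \left[
\dot{\alpha}_t X_1 + \dot \beta_t \frac{X_t - \alpha_t X_1}{\beta_t}
\right] \\
&= \mathbb E_{X_1\sim \pi_1} \left[
\frac{\pi_0(\frac{X_t - \alpha_t X_1}{\beta_t})}{Z} \cdot \left(\dot \alpha_t X_1 + \dot \beta_t \frac{(X_t - \alpha_t X_1)}{\beta_t}\right)
\right]
\end{aligned}
$$

where $Z = \mathbb E_{X_1 \sim \pi_1}\left[\pi_0\left(\dfrac{X_t-\alpha_tX_1}{\beta_t}\right)\right]$ is a normalizing constant, and $X_0 \sim \pi_0$, $X_1 \sim \pi_1$ are independent.

- **Second line.** This uses the tower property, conditioning additionally on $X_1$. Here $p(\cdot \mid X_t)$ denotes the posterior of $X_1$ given $X_t$.
- **Third line.** Once $X_t$ and $X_1$ are fixed, $X_0 = \frac{X_t - \alpha_t X_1}{\beta_t}$ is determined (for $\beta_t \neq 0$). The inner expectation is therefore just $\dot X_t$ evaluated at that $X_0$.
- **Last line.** This is Bayes' rule: $p(X_1 \mid X_t) \propto \pi_1(X_1)\,\pi_0\left(\frac{X_t - \alpha_t X_1}{\beta_t}\right)$. The Jacobian factor $\beta_t^{-d}$ does not depend on $X_1$, so it cancels in the normalization by $Z$.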

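Suppose $\pi_0$ is a standard Gaussian and we replace $\pi_1$ with the empirical distribution $\frac 1 N \sum_{i=1}^N \delta_{X_1^{(i)}}$ of the $N$ data points $X_1^{(1)}, \dots, X_1^{(N)}$. Then the expectation over $X_1$, and likewise $Z$, becomes an average over the data points:
$$
\begin{aligned}
\tilde v(X_t,t)&= \frac 1 N \sum_{i=1}^{N}  \frac{\pi_0(\frac{X_t - \alpha_t X_1^{(i)}}{\beta_t})}{Z} \cdot \left(\dot \alpha_t X_1^{(i)} + \dot \beta_t \frac{(X_t - \alpha_t X_1^{(i)})}{\beta_t}\right), \\
Z &= \frac 1 N \sum_{j=1}^{N} \pi_0\left(\frac{X_t - \alpha_t X_1^{(j)}}{\beta_t}\right)
\end{aligned}
$$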
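Denote $X_0^{(i)}=\dfrac{X_t - \alpha_t X_1^{(i)}}{\beta_t}$ and $\dot X_t^{(i)} = \dot \alpha_t X_1^{(i)} + \dot \beta_t X_0^{(i)}$. For a standard Gaussian in $d$ dimensions, $\pi_0(x) = (2\pi)^{-d/2}\exp\left(-\frac{\|x\|^2}{2}\right)$. The constant $(2\pi)^{-d/2}$ cancels between the numerator and $Z$, so the velocity is a softmax-weighted average of the per-point velocities $\dot X_t^{(i)}$:
$$
\begin{aligned}
\tilde v(X_t,t)&= \frac 1 N \sum_{i=1}^{N}  \frac{\pi_0(X_0^{(i)})}{Z} \cdot \left(\dot \alpha_t X_1^{(i)} + \dot \beta_t X_0^{(i)}\right) \\
&= \frac 1 N \sum_{i=1}^{N} \frac{\exp(-\frac{\|X_0^{(i)}\|^2}{2})}{\frac{1}{N}\sum_{j=1}^N\exp(-\frac{\|X_0^{(j)}\|^2}{2})} \left(\dot \alpha_t X_1^{(i)} + \dot \beta_t X_0^{(i)}\right) \\
&= \sum_{i=1}^{N} \operatorname{softmax}_i\left(-\frac{\|X_0^{(\cdot)}\|^2}{2}\right) \dot X_t^{(i)}
\end{aligned}
$$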

In [1]:
import torch
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import warnings
import copy
import plotly.graph_objects as go

import torch.distributions as dist

from rectified_flow.utils import set_seed
from rectified_flow.utils import match_dim_with_data

from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.flow_components.interpolation_solver import AffineInterp

from rectified_flow.utils import visualize_2d_trajectories_plotly

# Target device for the distributions and models created in the cells below.
device = torch.device("cpu")

In [2]:
from rectified_flow.datasets.toy_gmm import TwoPointGMM

set_seed(0)

# Source distribution pi_0: 2-D Gaussian centred at the origin with covariance
# 0.08 * I (per-coordinate std = sqrt(0.08) ~ 0.28).
# NOTE(review): the derivation in the markdown assumes a *standard* Gaussian
# pi_0, but this covariance is 0.08 * I. If the analytic velocity field assumes
# N(0, I), sampling starts from a narrower distribution than the one the
# velocity was derived for -- confirm this is intentional.
pi_0 = dist.MultivariateNormal(
    loc=torch.zeros(2, device=device),
    covariance_matrix=0.08 * torch.eye(2, device=device),
)

n_samples = 50_000
# Samples of shape [n_samples, 2]; appears unused in the cells shown here.
D0 = pi_0.sample((n_samples,))

In [None]:
import math
import matplotlib.pyplot as plt

N_categories = 8

# Angles of N_categories points spaced evenly on the circle; the last linspace
# value (2*pi) is dropped because it coincides with angle 0.
theta = torch.linspace(0.0, 2.0 * math.pi, steps=N_categories + 1)[:-1]
x = 5 * torch.cos(theta)
y = 5 * torch.sin(theta)

# Target dataset D1: one (x, y) row per circle point of radius 5,
# shape [N_categories, 2].
data = torch.stack((x, y), dim=-1)

plt.figure(figsize=(5, 5))
# Transposing gives two rows (all x, all y) that unpack into scatter(x, y).
plt.scatter(*data.numpy().T, color='red')
plt.gca().set_aspect('equal', adjustable='box')
plt.title("Data on a Circle")
plt.xlabel("x")
plt.ylabel("y")
plt.grid(True)
plt.show()

print(data.shape)

In [4]:
from rectified_flow.models.gauss_analytic import AnalyticGaussianVelocity

# Velocity field built from the empirical target `data` with the "spherical"
# affine interpolation (presumably the closed-form velocity derived in the
# markdown above -- confirm against AnalyticGaussianVelocity).
velocity = AnalyticGaussianVelocity(interp=AffineInterp("spherical"), dataset=data)

# The flow reuses the velocity's own interpolation object rather than creating
# a second one, so both use the same interpolation.
# NOTE: the variable name is misspelled ("rectfied"); the sampler cell refers
# to it by this spelling, so it is kept as-is.
rectfied_flow = RectifiedFlow(
    velocity_field=velocity,
    interp=velocity.interp,
    source_distribution=pi_0,
    data_shape=(2,),
    device=device,
)

In [None]:
from rectified_flow.samplers.euler_sampler import EulerSampler

# Euler sampler over the flow: 1000 samples, 100 integration steps.
sampler = EulerSampler(rectfied_flow, num_samples=1000, num_steps=100)
sampler.sample_loop(seed=0)

# Recorded states, indexed per step (traj[i]) in later cells;
# traj[-1] holds the final positions.
traj = sampler.trajectories
print(traj[-1].shape)

In [6]:
# Label every trajectory by the circle point (category) closest to where it
# ends, then split each recorded step's points by that label.
ref_points = data  # shape [N_categories, 2]

points_last = traj[-1]  # shape [B, 2]

# [B, 1, 2] - [N_categories, 2] broadcasts to [B, N_categories, 2];
# the norm over the last axis gives Euclidean distances [B, N_categories].
distances = (points_last.unsqueeze(1) - ref_points).norm(dim=2)

# Index of the nearest reference point for each trajectory, shape [B].
categories = distances.argmin(dim=1)

# type_list[c][i] holds the step-i positions of every trajectory whose endpoint
# is in category c, with shape [Nc, 2] (Nc may be 0).
# Boolean-mask indexing keeps the original sample order and produces the same
# rows as the former per-sample Python loop. It does this with one tensor op
# per (category, step) instead of B Python iterations per step. Empty
# categories automatically get a [0, 2] tensor with the trajectory's own
# dtype and device.
type_list = [
    [points_i[categories == c] for points_i in traj]
    for c in range(N_categories)
]

In [None]:
# Plot the trajectories grouped by the category they end in (one dict entry
# per category, keyed by its index), together with the target circle points.
visualize_2d_trajectories_plotly(
    trajectories_dict={
        f"{i}": _type_list for i, _type_list in enumerate(type_list)
    },
    D1_gt_samples=data,
    num_trajectories=300,
    title="Decision Boundaries",
    show_legend=False,
)