<a href="https://colab.research.google.com/github/lqiang67/rectified-flow/blob/main/examples/train_2d_toys.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/lqiang67/rectified-flow.git
%cd rectified-flow/

# Rectified Flow: 2D Toy Example

This notebook provides an example illustrating the basic concept of Rectified Flow and demonstrates training on a 2D toy example. 

See [this blog post](https://rectifiedflow.github.io/blog/2024/intro/) for a quick introduction.

Rectified Flow learns an ordinary differential equation (ODE), $ \dot{Z}_t = v(Z_t, t) $, to transfer data from a source distribution, $ \pi_0 $, to a target distribution, $ \pi_1 $, given limited observed data points sampled from $ \pi_1 $.

In [1]:
import torch
import numpy as np
import os
import sys
import matplotlib.pyplot as plt

import torch.distributions as dist

from rectified_flow.utils import set_seed
from rectified_flow.utils import visualize_2d_trajectories_plotly

from rectified_flow.rectified_flow import RectifiedFlow

set_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Generate Distributions $ \pi_0 $ and $ \pi_1 $

In this section, we generate synthetic $ \pi_0 $ and $ \pi_1 $ as two Gaussian mixture models (GMM).

We sample $50,000$ data points from each distribution and store them as `D0` and `D1`. Additionally, we store the labels for $ \pi_1 $ to differentiate whether the points belong to the upper or lower part of the $\pi_1$ GMM.
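For readers without the package at hand, a two-mode mixture of this kind can be sketched with `torch.distributions` alone. This is only an illustration: the helper name `two_mode_gmm`, the placement of the modes at $(x, y)$ and $(x, -y)$, the equal weights, and the sign-based labeling are assumptions of this sketch, not a description of how `TwoPointGMM` is implemented.

```python
import torch
import torch.distributions as dist

def two_mode_gmm(x, y, std):
    """Equal-weight 2D mixture with modes assumed at (x, y) and (x, -y)."""
    means = torch.tensor([[x, y], [x, -y]])
    # Independent(..., 1) treats the last dimension as one 2D event per component.
    components = dist.Independent(dist.Normal(means, std * torch.ones_like(means)), 1)
    weights = dist.Categorical(probs=torch.tensor([0.5, 0.5]))
    return dist.MixtureSameFamily(weights, components)

torch.manual_seed(0)
toy_pi_1 = two_mode_gmm(x=15.0, y=7.5, std=0.5)
toy_samples = toy_pi_1.sample([1000])          # shape (1000, 2)
# One possible labeling convention: 1 for the upper mode, 0 for the lower mode.
toy_labels = (toy_samples[:, 1] > 0).long()
```

Because the two modes are many standard deviations apart, labeling by the sign of the second coordinate separates them cleanly in this toy setting.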

In [None]:
from rectified_flow.datasets.toy_gmm import TwoPointGMM

n_samples = 50000
pi_0 = TwoPointGMM(x=0.0, y=7.5, std=0.5, device=device)
pi_1 = TwoPointGMM(x=15.0, y=7.5, std=0.5, device=device)
D0 = pi_0.sample([n_samples])
D1, labels = pi_1.sample_with_labels([n_samples])  # labels stay a tensor so they can be indexed during training

plt.figure(figsize=(3, 3))
plt.title(r'Samples from $\pi_0$ and $\pi_1$')
plt.scatter(D0[:, 0].cpu(), D0[:, 1].cpu(), alpha=0.5, label=r'$\pi_0$')
plt.scatter(D1[:, 0].cpu(), D1[:, 1].cpu(), alpha=0.5, label=r'$\pi_1$')
plt.legend()

## 1-Rectified Flow

Given observed samples $X_0 \sim \pi_0$ from the source distribution and $X_1 \sim \pi_1$ from the target distribution, the *rectified flow* induced by the coupling $(X_0, X_1)$ is the time-differentiable process $\mathbf{Z} = \{Z_t: t \in [0, 1]\}$ defined by the ODE:

$$
\mathrm{d}Z_t = v(Z_t, t) \, \mathrm{d}t, \quad t \in [0, 1], \quad \text{starting from } Z_0 = X_0.
$$

Here, $v: \mathbb{R}^d \times [0, 1] \to \mathbb{R}^d$ is set in a way that ensures that $Z_1$ follows $\pi_1$ when $Z_0 \sim \pi_0$.

Denote $X_t = \alpha_t \cdot X_1 + \beta_t \cdot X_0$ as an interpolation of samples $X_0$ and $X_1$. The velocity field is given by:

$$
v(z, t) = \mathbb{E}[ \dot X_t \mid X_t = z] = \arg \min_{v} \int_0^1 \mathbb{E}\left[\lVert  \dot \alpha_t X_1 + \dot \beta_t X_0 - v(X_t, t) \rVert^2\right] \mathrm{d}t.
$$
Here, $\alpha_t, \beta_t$ are any differentiable functions of time $t$ that satisfy $\alpha_0=\beta_1=0$ and $\alpha_1 = \beta_0 = 1$, and $v(z,t)$ is the conditional expectation of $\dot X_t$ over all interpolation paths that pass through $z$ at time $t$.

We call the process $\{Z_t\}$ the **rectified flow** induced from the interpolation $\{X_t\}$.

The default choice is the straight interpolation:
$$
X_t = t X_1 + (1-t) X_0, \quad \quad \dot X_t = X_1 - X_0.
$$
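As a minimal sketch, the straight interpolation and its time derivative can be computed directly and the boundary conditions checked numerically. The helper name `straight_interp` and the demo tensors are illustrative and not part of the `rectified_flow` package.

```python
import torch

def straight_interp(x_0, x_1, t):
    """Return X_t = t * X_1 + (1 - t) * X_0 and its time derivative X_1 - X_0.

    Hypothetical helper for illustration; `t` has shape (batch,).
    """
    t = t.view(-1, 1)                # broadcast over the feature dimension
    x_t = t * x_1 + (1.0 - t) * x_0
    dot_x_t = x_1 - x_0              # constant in t for the straight interpolation
    return x_t, dot_x_t

torch.manual_seed(0)
x_0_demo = torch.randn(4, 2)
x_1_demo = torch.randn(4, 2) + torch.tensor([15.0, 0.0])

# At t = 0 the interpolant equals X_0, and at t = 1 it equals X_1.
x_start, _ = straight_interp(x_0_demo, x_1_demo, torch.zeros(4))
x_end, dot_x = straight_interp(x_0_demo, x_1_demo, torch.ones(4))
```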

Let's first visualize this straight interpolation.

In [None]:
x_0 = pi_0.sample([500])
x_0_upper = x_0.clone()
x_0_upper[:, 1] = torch.abs(x_0_upper[:, 1])
x_0_lower = x_0.clone()
x_0_lower[:, 1] = -torch.abs(x_0_lower[:, 1])

x_1_upper = pi_1.sample([500])
x_1_lower = pi_1.sample([500])

interp_upper = []
interp_lower = []

for t in np.linspace(0, 1, 100):
    # Straight interpolation X_t = (1 - t) * X_0 + t * X_1 for each group
    x_t_upper = (1 - t) * x_0_upper + t * x_1_upper
    x_t_lower = (1 - t) * x_0_lower + t * x_1_lower
    interp_upper.append(x_t_upper)
    interp_lower.append(x_t_lower)

visualize_2d_trajectories_plotly(
    trajectories_dict={
        "upper": interp_upper,
        "lower": interp_lower,
    },
    D1_gt_samples=torch.cat([x_1_upper, x_1_lower], dim=0),
    num_trajectories=100,
    title="Straight Interpolation",
)

This straight interpolation successfully constructs paths to transport $\pi_0$ to $\pi_1$. However, we cannot "simulate" these paths from $X_0$ because:

- The updates at each position $X_t$ depend on the final state $X_1$, which is inaccessible at intermediate times ($t < 1$).

In the figure above, trajectories intersect in the middle (drag `step` to $50$), indicating that there are **multiple possible directions** toward $\pi_1$ from the same point.

Such behavior cannot be simulated with an ODE, as an ODE requires a unique direction (or velocity) $v(X_t, t)$ for the given current state $(X_t, t)$.
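A small numeric check makes this concrete. The sketch below uses two hand-picked pairs (illustrative values, not samples drawn from `pi_0` and `pi_1`) whose straight paths cross at $t = 0.5$: both interpolants occupy the same point, yet their velocities $\dot X_t = X_1 - X_0$ differ.

```python
import torch

# Two illustrative (X_0, X_1) pairs whose straight paths cross.
x_0_pair = torch.tensor([[0.0, 7.5], [0.0, -7.5]])
x_1_pair = torch.tensor([[15.0, -7.5], [15.0, 7.5]])

t_mid = 0.5
x_mid = t_mid * x_1_pair + (1 - t_mid) * x_0_pair  # X_t for both pairs
velocities = x_1_pair - x_0_pair                    # dot X_t for both pairs

same_position = torch.allclose(x_mid[0], x_mid[1])
same_velocity = torch.allclose(velocities[0], velocities[1])
# Average velocity at the crossing point, weighting the two paths equally.
mean_velocity = velocities.mean(dim=0)
```

Both paths pass through $(7.5, 0)$ with velocities $(15, -15)$ and $(15, 15)$. An ODE can follow only one direction there; if the two paths are equally likely, the conditional expectation $\mathbb{E}[\dot X_t \mid X_t = z]$ at that point is their average, $(15, 0)$.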

### Learning a Rectified Flow Velocity with MLP

We parameterize the velocity field using a small unconditional MLP $v_\theta$.

The model is then passed to the `RectifiedFlow` class. In this 2D toy example, the data shape is `(2,)`, and we use the `"straight"` interpolation mode:

In [4]:
from rectified_flow.models.toy_mlp import MLPVelocity

model = MLPVelocity(2, hidden_sizes = [128, 128, 128]).to(device)

rectified_flow = RectifiedFlow(
    data_shape=(2,),
    velocity_field=model,
    interp="straight",
    source_distribution=pi_0,
    device=device,
)

During training, the model samples data points from the source ($\pi_0$) and target ($\pi_1$) distributions to compute the loss for optimizing the velocity field:

$$
\ell = \min_{\theta} 
\int_0^1 \mathbb{E}_{X_0 \sim \pi_0, X_1 \sim \pi_1} \left [ \left\| (X_1 - X_0) - v_\theta(X_t, t) \right\|^2 \right ] \mathrm{d}t, 
\quad \text{where} \quad
X_t = t X_1 + (1-t) X_0.
$$

The `get_loss` method computes the rectified flow loss using:
- **Inputs**:
  - `x_0`: Samples from $\pi_0$.
  - `x_1`: Samples from $\pi_1$.
  - `labels` (optional): Provides conditional information, e.g., the GMM component index.
- **Steps**:
  1. **Interpolation**: Computes intermediate states $X_t$ and derivatives $\dot{X}_t$.
  2. **Prediction**: Predicts $v_\theta(X_t, t)$ using the velocity model.
  3. **Loss**: Measures the loss between $v_\theta(X_t, t)$ and $\dot{X}_t$, with time-dependent weighting.
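The three steps above can be sketched in plain PyTorch. This is not the library's `get_loss` implementation: the function name `rf_loss_sketch`, the uniform sampling of $t$, the unit time weighting, and the sum-then-mean reduction are all assumptions made for illustration.

```python
import torch

def rf_loss_sketch(velocity, x_0, x_1):
    """Monte Carlo estimate of the straight-interpolation rectified flow loss.

    `velocity` is any callable (x_t, t) -> predicted velocity. Uniform time
    sampling and unit weighting are assumptions of this sketch.
    """
    t = torch.rand(x_0.shape[0], device=x_0.device)   # t ~ Uniform[0, 1]
    t_col = t.view(-1, 1)
    x_t = t_col * x_1 + (1.0 - t_col) * x_0           # 1. interpolation
    target = x_1 - x_0                                #    and its time derivative
    pred = velocity(x_t, t)                           # 2. prediction
    return ((pred - target) ** 2).sum(dim=1).mean()   # 3. squared-error loss

torch.manual_seed(0)
x_0_batch = torch.randn(8, 2)
shift = torch.tensor([15.0, 0.0])
x_1_batch = x_0_batch + shift  # every pair shares the same displacement

perfect = lambda x_t, t: shift.expand_as(x_t)  # predicts the true velocity
zero = lambda x_t, t: torch.zeros_like(x_t)    # predicts no motion

loss_perfect = rf_loss_sketch(perfect, x_0_batch, x_1_batch)
loss_zero = rf_loss_sketch(zero, x_0_batch, x_1_batch)
```

When every pair shares the same displacement, a velocity that always predicts that displacement drives the loss to zero (up to floating-point error), while predicting no motion leaves a loss equal to the squared displacement norm.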

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch_size = 1024

losses = []

for step in range(5000):
    optimizer.zero_grad()
    # Draw a random mini-batch; D0 and D1 are independent, so sharing idx is fine
    idx = torch.randperm(n_samples)[:batch_size]
    x_0 = D0[idx].to(device)
    x_1 = D1[idx].to(device)

    loss = rectified_flow.get_loss(x_0, x_1)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if step % 1000 == 0:
        print(f"Step {step}, Loss: {loss.item()}")

plt.plot(losses)

With the rectified flow trained using straight interpolation, we can now visualize the trajectories to observe how the rectified flow effectively **'causalizes'** the interpolation process.

We split $X_0 \sim \pi_0$ into two categories: points above and below the $X$-axis. We then use these subsets of $X_0$ as starting points for sampling with the rectified flow.

From the visualization, we can observe that the trajectories are "rewired" in the middle: there are no intersections between trajectories, and the blue and pink dots evolve separately, meaning that they are now "simulatable".

This reflects how rectified flow learns the average direction at points of intersection:

$$
v(z, t) = \mathbb{E}[\dot{X}_t \mid X_t = z].
$$

**Intuition of "average"**

A critical intuition here is that this average does not change the total amount of mass or the number of "particles" passing through the intersection, thereby preserving the same distribution as $\{X_t\}$ at every time $t$. 
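One way to see this, sketched here under the assumption that $X_t$ has a smooth density $\rho_t$ and that derivatives and expectations can be interchanged: for any smooth test function $h$,

$$
\frac{\mathrm{d}}{\mathrm{d}t} \mathbb{E}[h(X_t)]
= \mathbb{E}\left[\nabla h(X_t)^\top \dot X_t\right]
= \mathbb{E}\left[\nabla h(X_t)^\top \, \mathbb{E}[\dot X_t \mid X_t]\right]
= \mathbb{E}\left[\nabla h(X_t)^\top v(X_t, t)\right].
$$

This is the weak form of the continuity equation $\partial_t \rho_t + \nabla \cdot (v \rho_t) = 0$. The law of $Z_t$ satisfies the same equation with the same velocity field $v$ and the same initial distribution $\pi_0$, so, provided the equation has a unique solution, $Z_t$ and $X_t$ share the same marginal distribution at every $t$.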

Check the trajectories below: when $t$ is around $0.5$, the number of particles moving to the right side has not changed; they have merely swapped trajectories. The number of particles on the interpolation path is approximately the same as the number on the learned rectified flow path.

![cross](../assets/flow_in_out.png)

Note: Due to the error introduced by Euler discretization, some particles may have moved to the other side.

In [None]:
from rectified_flow.samplers import EulerSampler
from rectified_flow.utils import visualize_2d_trajectories_plotly

euler_sampler_1rf_unconditional = EulerSampler(
    rectified_flow=rectified_flow,
    num_steps=100,
)

traj_upper = euler_sampler_1rf_unconditional.sample_loop(x_0=x_0_upper).trajectories
traj_lower = euler_sampler_1rf_unconditional.sample_loop(x_0=x_0_lower).trajectories

visualize_2d_trajectories_plotly(
    trajectories_dict={"upper": traj_upper, "lower": traj_lower},
    D1_gt_samples=D1[:1000],
    num_trajectories=200,
    title="Unconditional 1-Rectified Flow",
)

The trajectories of $\{Z_t\}$ are not straight, so a single Euler step from $t=0$ to $t=1$ does not yield accurate samples:

$$
\hat{X}_1 = X_0 + v(X_0, 0).
$$
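The effect of curvature on a single Euler step can be illustrated without any trained model. In the sketch below, `euler_integrate` is a hypothetical helper (not the package's `EulerSampler`), and the rotation field is only a stand-in for a velocity field with curved trajectories. It is not a learned rectified flow.

```python
import math
import torch

def euler_integrate(velocity, z_0, num_steps):
    """Integrate dZ_t = v(Z_t, t) dt from t = 0 to t = 1 with Euler steps."""
    z = z_0.clone()
    dt = 1.0 / num_steps
    for i in range(num_steps):
        z = z + dt * velocity(z, i * dt)
    return z

# Straight case: constant velocity, so a single step is exact.
shift = torch.tensor([15.0, 0.0])
straight_v = lambda z, t: shift.expand_as(z)

# Curved case: rotation with angular speed pi/2, so exact paths are quarter circles.
omega = math.pi / 2
curved_v = lambda z, t: omega * torch.stack([-z[:, 1], z[:, 0]], dim=1)

z_0 = torch.tensor([[1.0, 0.0]])
exact_curved = torch.tensor([[0.0, 1.0]])  # (1, 0) rotated by 90 degrees

err_straight = (euler_integrate(straight_v, z_0, 1) - (z_0 + shift)).norm()
err_one_step = (euler_integrate(curved_v, z_0, 1) - exact_curved).norm()
err_many_steps = (euler_integrate(curved_v, z_0, 1000) - exact_curved).norm()
```

With a constant velocity the one-step result is exact, whereas on the curved field one step lands far from the true endpoint and many small steps are needed to follow the path.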

In [None]:
euler_sampler_1rf_unconditional.sample_loop(num_steps=1, seed=0)

visualize_2d_trajectories_plotly(
    {"1rf uncond one-step": euler_sampler_1rf_unconditional.trajectories},
    D1[:1000],
    num_trajectories=200,
    title="Unconditional 1-Rectified Flow, 1-step",
)

### Learning a Conditional Rectified Flow

The rectified flow model can be extended to include class conditioning. By passing class information $c \in \{0, 1\}$ (e.g., the GMM component index) to the model, the velocity field becomes class-dependent:

$$
\ell = \min_{\theta}
\int_0^1 \mathbb{E}_{X_0 \sim \pi_0, (X_1,c) \sim \pi_1} \left [ \left\| (X_1 - X_0) - v_\theta(X_t, t, c) \right\|^2 \right ] \mathrm{d}t,
\quad \text{where} \quad
X_t = t X_1 + (1-t) X_0.
$$

In this case, $(X_1, c)\sim \pi_1$ is the distribution of data-label pairs.
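To make the role of $c$ concrete, here is a hypothetical class-conditional velocity network. It is an illustration only: the class name `TinyConditionalVelocity`, the embedding-and-concatenate design, and the layer sizes are assumptions of this sketch, and the package's `MLPVelocityConditioned` may condition on the label differently.

```python
import torch
import torch.nn as nn

class TinyConditionalVelocity(nn.Module):
    """Hypothetical class-conditional velocity v(x_t, t, c) for 2D data.

    The label is embedded and concatenated with (x_t, t) before the MLP.
    """

    def __init__(self, dim=2, num_classes=2, hidden=64):
        super().__init__()
        self.label_embed = nn.Embedding(num_classes, hidden)
        self.net = nn.Sequential(
            nn.Linear(dim + 1 + hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x_t, t, labels):
        c = self.label_embed(labels.long())
        inputs = torch.cat([x_t, t.view(-1, 1), c], dim=1)
        return self.net(inputs)

torch.manual_seed(0)
toy_model = TinyConditionalVelocity()
x_t_demo = torch.zeros(4, 2)               # identical positions ...
t_demo = torch.full((4,), 0.5)             # ... at identical times ...
labels_demo = torch.tensor([0, 0, 1, 1])   # ... but different classes
v_demo = toy_model(x_t_demo, t_demo, labels_demo)
```

Rows with the same label receive the same velocity, while rows with different labels can receive different velocities at the same $(X_t, t)$. This is what lets a conditional model avoid averaging over both components at a crossing point.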

In [8]:
from rectified_flow.models.toy_mlp import MLPVelocityConditioned

model_cond = MLPVelocityConditioned(2, hidden_sizes = [128, 128, 128]).to(device)

rectified_flow_cond = RectifiedFlow(
    data_shape=(2,),
    velocity_field=model_cond,
    interp="straight",
    source_distribution=pi_0,
    device=device,
)

In [None]:
optimizer = torch.optim.Adam(model_cond.parameters(), lr=1e-3)
batch_size = 1024

losses = []

for step in range(5000):
    optimizer.zero_grad()
    idx = torch.randperm(n_samples)[:batch_size]
    x_0 = D0[idx].to(device)
    x_1 = D1[idx].to(device)
    # `labels` is already a tensor, so index it directly instead of re-wrapping it with torch.tensor
    cond = labels[idx].to(device)

    loss = rectified_flow_cond.get_loss(x_0, x_1, labels=cond)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if step % 1000 == 0:
        print(f"Step {step}, Loss: {loss.item()}")

plt.plot(losses)

By incorporating class information, the model can better capture the structure of conditional distributions. 

For instance, given a specific class (e.g., sampling only the upper part of the right-side distribution), the trajectories do not intersect. 

As a result, the learned trajectories are nearly straight, and even a one-step result performs remarkably well.

In [None]:
from rectified_flow.samplers import EulerSampler
from rectified_flow.utils import visualize_2d_trajectories_plotly

euler_sampler_1rf_conditional = EulerSampler(
    rectified_flow=rectified_flow_cond,
    num_steps=1,
    num_samples=500,
)

cond = torch.zeros((500,), device=device)
euler_sampler_1rf_conditional.sample_loop(seed=0, labels=cond)

visualize_2d_trajectories_plotly(
    {"1rf cond": euler_sampler_1rf_conditional.trajectories},
    D1[:1000],
    num_trajectories=200,
    title="Conditional 1-Rectified Flow",
)

## Reflow for 2-Rectified Flow

Now let's try the *reflow* procedure to get a straightened rectified flow, denoted as 2-Rectified Flow, by repeating the same procedure with $(X_0,X_1)$ replaced by $(Z_0^1, Z_1^1)$, where $(Z_0^1, Z_1^1)$ is the coupling simulated from 1-Rectified Flow.

We draw $50,000$ samples of $Z_0^1$ and generate their corresponding $Z_1^1$ by simulating 1-Rectified Flow.

In [11]:
Z_0 = rectified_flow.sample_source_distribution(batch_size=50000)

Z_1 = euler_sampler_1rf_unconditional.sample_loop(x_0=Z_0, num_steps=1000).trajectories[-1]

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch_size = 1024

losses = []

for step in range(5000):
    optimizer.zero_grad()
    # Use the same idx for Z_0 and Z_1 so each pair stays coupled
    idx = torch.randperm(n_samples)[:batch_size]
    x_0 = Z_0[idx].to(device)
    x_1 = Z_1[idx].to(device)

    loss = rectified_flow.get_loss(x_0, x_1)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if step % 1000 == 0:
        print(f"Step {step}, Loss: {loss.item()}")

plt.plot(losses)

In [None]:
euler_sampler_2rf = EulerSampler(
    rectified_flow=rectified_flow,
    num_samples=1000,
)

euler_sampler_2rf.sample_loop(num_steps=100, seed=0)

visualize_2d_trajectories_plotly(
    {"2rf": euler_sampler_2rf.trajectories}, 
    D1[:1000],
    num_trajectories=200,
    title="Reflow Trajectories, 100-step",
)

In [None]:
euler_sampler_2rf.sample_loop(num_steps=1, seed=0)

visualize_2d_trajectories_plotly(
    {"2rf one-step": euler_sampler_2rf.trajectories},
    D1[:1000],
    num_trajectories=200,
    title="Reflow Trajectories, 1-step",
)