# Rectified Flow: 2D Toy Example

This notebook illustrates the basic concept of Rectified Flow and demonstrates training on a 2D toy example. For more details, see [Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow](https://arxiv.org/abs/2209.03003).

Rectified Flow learns an ordinary differential equation (ODE), $ \dot{Z}_t = v(Z_t, t) $, to transfer data from a source distribution, $ \pi_0 $, to a target distribution, $ \pi_1 $, given limited observed data points sampled from $ \pi_1 $.

In [None]:
import torch
import os
import sys
import matplotlib.pyplot as plt

import torch.distributions as dist

from rectified_flow.utils import set_seed

from rectified_flow.rectified_flow import RectifiedFlow

set_seed(0)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Generate Distributions $ \pi_0 $ and $ \pi_1 $

In this section, we generate synthetic $ \pi_0 $ and $ \pi_1 $ as two Gaussian mixture models (GMM).

We sample 50,000 data points from each distribution and store them as `D0` and `D1`. We also store the labels for $ \pi_1 $, which indicate whether each point belongs to the upper or lower component of the $\pi_1$ GMM.

In [None]:
from rectified_flow.datasets.toy_gmm import TwoPointGMM

n_samples = 50000
pi_0 = TwoPointGMM(x=0.0, y=7.5, std=0.5, device=device)
pi_1 = TwoPointGMM(x=15.0, y=7.5, std=0.5, device=device)
D0 = pi_0.sample([n_samples])
D1, labels = pi_1.sample_with_labels([n_samples])  # labels mark the GMM component of each D1 sample

plt.figure(figsize=(5, 5))
plt.title(r'Samples from $\pi_0$ and $\pi_1$')
plt.scatter(D0[:, 0].cpu(), D0[:, 1].cpu(), alpha=0.5, label=r'$\pi_0$')
plt.scatter(D1[:, 0].cpu(), D1[:, 1].cpu(), alpha=0.5, label=r'$\pi_1$')
plt.legend()

## 1-Rectified Flow

Given observed samples $X_0 \sim \pi_0$ and $X_1 \sim \pi_1$, the *rectified flow* induced by $(X_0, X_1)$ is the time-differentiable process $\mathbf{Z} = \{Z_t: t \in [0, 1]\}$ that follows the ODE:

$$
\mathrm{d}Z_t = v(Z_t, t) \, \mathrm{d}t, \quad t \in [0, 1], \quad \text{starting from } Z_0 = X_0.
$$

Here, $v: \mathbb{R}^d \times [0, 1] \to \mathbb{R}^d$ is set in a way that ensures that $Z_1$ follows $\pi_1$ when $Z_0 \sim \pi_0$.

Denote $X_t = \alpha_t \cdot X_1 + \beta_t \cdot X_0$ as an interpolation of samples $X_0$ and $X_1$. The velocity field is given by:

$$
v(z, t) = \mathbb{E}[ \dot X_t \mid X_t = z] = \arg \min_{v} \int_0^1 \mathbb{E}\left[\lVert  \dot \alpha_t X_1 + \dot \beta_t X_0 - v(X_t, t) \rVert^2\right] \mathrm{d}t,
$$
where $\alpha_t, \beta_t$ are any differentiable functions of time $t$ that satisfy $\alpha_0=\beta_1=0$ and $\alpha_1 = \beta_0 = 1$.

The default choice is the straight interpolation:
$$
X_t = t X_1 + (1-t) X_0, \quad \quad \dot X_t = X_1 - X_0.
$$
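
The straight interpolation above can be written in a few lines of plain PyTorch. This is a minimal sketch for illustration only; the helper name `straight_interpolation` is hypothetical and is not part of the `rectified_flow` package.

```python
import torch

def straight_interpolation(x_0, x_1, t):
    """Return X_t = t * X_1 + (1 - t) * X_0 and its time derivative X_1 - X_0."""
    t = t.view(-1, 1)  # broadcast one time value per sample over the feature dimension
    x_t = t * x_1 + (1.0 - t) * x_0
    dot_x_t = x_1 - x_0  # constant in t for the straight interpolation
    return x_t, dot_x_t

# Two example pairs: the first evaluated at t = 0, the second at t = 0.5.
x_0 = torch.tensor([[0.0, 7.5], [0.0, -7.5]])
x_1 = torch.tensor([[15.0, 7.5], [15.0, 7.5]])
t = torch.tensor([0.0, 0.5])

x_t, dot_x_t = straight_interpolation(x_0, x_1, t)
print(x_t)      # first row equals x_0[0]; second row is the midpoint of its pair
print(dot_x_t)  # displacement X_1 - X_0 for each pair
```

At $t = 0$ the interpolation returns $X_0$ and at $t = 1$ it returns $X_1$, matching the boundary conditions $\alpha_0 = \beta_1 = 0$ and $\alpha_1 = \beta_0 = 1$.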

### Learning an Unconditional Rectified Flow

We parameterize the velocity field using a small unconditional MLP $v_\theta$.

The model is then passed to the `RectifiedFlow` class. Since this is a 2D toy example, the data shape is `(2,)`, and we use the `"straight"` interpolation mode:

In [None]:
from rectified_flow.models.toy_mlp import MLPVelocity

model = MLPVelocity(2, hidden_sizes = [128, 128, 128]).to(device)

rectified_flow = RectifiedFlow(
    data_shape=(2,),
    velocity_field=model,
    interp="straight",
    source_distribution=pi_0,
    device=device,
)

During training, the model samples data points from the source ($\pi_0$) and target ($\pi_1$) distributions to compute the loss for optimizing the velocity field:

$$
\ell = \min_{\theta} 
\int_0^1 \mathbb{E}_{X_0 \sim \pi_0, X_1 \sim \pi_1} \left [ \left\| (X_1 - X_0) - v_\theta(X_t, t) \right\|^2 \right ] \mathrm{d}t, 
\quad \text{where} \quad
X_t = t X_1 + (1-t) X_0.
$$

The `get_loss` method computes the rectified flow loss using:
- **Inputs**:
  - `x_0`: Samples from $\pi_0$.
  - `x_1`: Samples from $\pi_1$.
  - `labels` (optional): Provides conditional information, e.g., the GMM component index.
- **Steps**:
  1. **Interpolation**: Computes intermediate states $X_t$ and derivatives $\dot{X}_t$.
  2. **Prediction**: Predicts $v_\theta(X_t, t)$ using the velocity model.
  3. **Loss**: Measures the loss between $v_\theta(X_t, t)$ and $\dot{X}_t$, with time-dependent weighting.
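
The three steps above can be sketched in plain PyTorch as follows. This is not the library's `get_loss` implementation: `TinyVelocity` and `rectified_flow_loss` are hypothetical names, time is assumed to be sampled uniformly on $[0, 1]$, and the weighting is assumed to be uniform.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyVelocity(nn.Module):
    """Hypothetical stand-in for the velocity model: maps (x_t, t) to a 2D velocity."""
    def __init__(self, dim=2, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.SiLU(), nn.Linear(hidden, dim)
        )

    def forward(self, x_t, t):
        # Concatenate the time as an extra input feature.
        return self.net(torch.cat([x_t, t.view(-1, 1)], dim=1))

def rectified_flow_loss(velocity, x_0, x_1):
    t = torch.rand(x_0.shape[0])               # assumption: t ~ Uniform[0, 1]
    tt = t.view(-1, 1)
    x_t = tt * x_1 + (1.0 - tt) * x_0          # 1. interpolation
    dot_x_t = x_1 - x_0
    v_pred = velocity(x_t, t)                  # 2. prediction
    # 3. squared error, summed over coordinates and averaged over the batch
    return ((v_pred - dot_x_t) ** 2).sum(dim=1).mean()

velocity = TinyVelocity()
x_0 = torch.randn(256, 2)
x_1 = torch.randn(256, 2) + torch.tensor([15.0, 0.0])

loss = rectified_flow_loss(velocity, x_0, x_1)
loss.backward()  # gradients flow to the velocity model parameters
```

Averaging over randomly drawn $t$ gives a Monte Carlo estimate of the time integral in the objective.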

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch_size = 1024

losses = []

for step in range(5000):
    optimizer.zero_grad()
    # Draw a random mini-batch; D0 and D1 were sampled independently,
    # so sharing idx still yields an independent pairing.
    idx = torch.randperm(n_samples)[:batch_size]
    x_0 = D0[idx].to(device)
    x_1 = D1[idx].to(device)

    loss = rectified_flow.get_loss(x_0, x_1)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if step % 200 == 0:
        print(f"Step {step}, Loss: {loss.item()}")

plt.plot(losses)

Then we run the Euler method to solve the ODE with $N = 100$ steps to generate samples from 1-Rectified Flow.

We can see that the trajectories are "rewired" at the points where the linear interpolation paths intersect.
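
The Euler update behind this sampling step can be sketched as follows. This is a minimal illustration, not the `EulerSampler` implementation; `euler_sample` and `constant_velocity` are hypothetical names, and a constant velocity field is used so that the result is easy to verify by hand.

```python
import torch

def euler_sample(velocity, x_0, num_steps):
    """Integrate dZ_t = v(Z_t, t) dt from t = 0 to t = 1 with fixed-step Euler."""
    dt = 1.0 / num_steps
    z = x_0.clone()
    trajectory = [z.clone()]
    for k in range(num_steps):
        t = torch.full((z.shape[0],), k * dt)  # current time for every sample
        z = z + dt * velocity(z, t)            # Euler update
        trajectory.append(z.clone())
    return trajectory

def constant_velocity(z, t):
    """Toy field: every point moves with velocity (15, 0), so paths are straight."""
    return torch.tensor([15.0, 0.0]).expand_as(z)

x_0 = torch.zeros(4, 2)
traj_100 = euler_sample(constant_velocity, x_0, num_steps=100)
traj_1 = euler_sample(constant_velocity, x_0, num_steps=1)
```

Because the toy paths are straight, one Euler step reaches the same endpoint as 100 steps, up to floating-point error. A learned velocity field with curved paths generally needs more steps.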

In [None]:
from rectified_flow.samplers import EulerSampler
from rectified_flow.utils import visualize_2d_trajectories_plotly

euler_sampler_1rf_unconditional = EulerSampler(
    rectified_flow=rectified_flow,
    num_steps=100,
    num_samples=500,
)

euler_sampler_1rf_unconditional.sample_loop(seed=0)

visualize_2d_trajectories_plotly(
    trajectories_dict={"1rf uncond": euler_sampler_1rf_unconditional.trajectories},
    D1_gt_samples=D1[:1000],
    num_trajectories=200,
    title="Unconditional 1-Rectified Flow",
)

In [None]:
euler_sampler_1rf_unconditional.sample_loop(num_steps=1, seed=0)

visualize_2d_trajectories_plotly(
    {"1rf uncond one-step": euler_sampler_1rf_unconditional.trajectories}, 
    D1[:1000],
    num_trajectories=200,
    title="Unconditional 1-Rectified Flow, 1-step",
)

### Learning a Conditional Rectified Flow

The rectified flow model can be extended to include class conditioning. By passing class information $c \in \{0, 1\}$ (e.g., the GMM component), the velocity field becomes class-dependent, and the training objective becomes:

$$
\ell = \min_{\theta}
\int_0^1 \mathbb{E}_{X_0 \sim \pi_0, (X_1,c) \sim \pi_1} \left [ \left\| (X_1 - X_0) - v_\theta(X_t, t, c) \right\|^2 \right ] \mathrm{d}t,
\quad \text{where} \quad
X_t = t X_1 + (1-t) X_0.
$$

In this case, $\pi_1$ denotes the joint distribution of data-label pairs $(X_1, c)$.
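
One way to make a velocity model depend on $c$ is to add a learned label embedding to the hidden features. The sketch below is an assumption for illustration; `TinyConditionalVelocity` is a hypothetical class and may differ from how `MLPVelocityConditioned` handles labels.

```python
import torch
import torch.nn as nn

class TinyConditionalVelocity(nn.Module):
    """Hypothetical class-conditional velocity v(x_t, t, c) for c in {0, 1}."""
    def __init__(self, dim=2, num_classes=2, hidden=32):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, hidden)
        self.input_layer = nn.Linear(dim + 1, hidden)
        self.output_layer = nn.Sequential(nn.SiLU(), nn.Linear(hidden, dim))

    def forward(self, x_t, t, labels):
        h = self.input_layer(torch.cat([x_t, t.view(-1, 1)], dim=1))
        h = h + self.label_embedding(labels.long())  # inject the class information
        return self.output_layer(h)

torch.manual_seed(0)
velocity = TinyConditionalVelocity()
x_t = torch.randn(8, 2)
t = torch.rand(8)

# The same states and times produce different velocities for different classes.
v_class0 = velocity(x_t, t, torch.zeros(8, dtype=torch.long))
v_class1 = velocity(x_t, t, torch.ones(8, dtype=torch.long))
```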

In [None]:
from rectified_flow.models.toy_mlp import MLPVelocityConditioned

model_cond = MLPVelocityConditioned(2, hidden_sizes = [128, 128, 128]).to(device)

rectified_flow_cond = RectifiedFlow(
    data_shape=(2,),
    velocity_field=model_cond,
    interp="straight",
    source_distribution=pi_0,
    device=device,
)

In [None]:
optimizer = torch.optim.Adam(model_cond.parameters(), lr=1e-3)
batch_size = 1024

losses = []

for step in range(5000):
    optimizer.zero_grad()
    idx = torch.randperm(n_samples)[:batch_size]
    x_0 = D0[idx].to(device)
    x_1 = D1[idx].to(device)
    # as_tensor avoids the copy warning that torch.tensor raises
    # when the labels are already a tensor.
    cond = torch.as_tensor(labels[idx]).to(device)

    loss = rectified_flow_cond.get_loss(x_0, x_1, labels=cond)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if step % 200 == 0:
        print(f"Step {step}, Loss: {loss.item()}")

plt.plot(losses)

By incorporating class information, the model can better capture the structure of conditional distributions. This ensures that the velocity fields for different classes (e.g., $c \in \{0, 1\}$) remain distinct, avoiding intersections in the middle of trajectories.

In [None]:
from rectified_flow.samplers import EulerSampler
from rectified_flow.utils import visualize_2d_trajectories_plotly

euler_sampler_1rf_conditional = EulerSampler(
    rectified_flow=rectified_flow_cond,
    num_steps=100,
    num_samples=500,
)

cond = torch.zeros((500,), device=device)  # condition every sample on class 0
euler_sampler_1rf_conditional.sample_loop(seed=0, labels=cond)

visualize_2d_trajectories_plotly(
    {"1rf cond": euler_sampler_1rf_conditional.trajectories},
    D1[:1000],
    num_trajectories=200,
    title="Conditional 1-Rectified Flow",
)

## Reflow for 2-Rectified Flow

Now let's apply the *reflow* procedure to obtain a straightened rectified flow, denoted 2-Rectified Flow: we repeat the same training procedure with $(X_0, X_1)$ replaced by $(Z_0^1, Z_1^1)$, the coupling simulated from 1-Rectified Flow.

We sample 50,000 points $Z_0^1$ from the source distribution and generate the corresponding $Z_1^1$ by simulating 1-Rectified Flow.
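
The structure of this coupling can be sketched with a toy velocity field. In the sketch below, `rotation_velocity` and `simulate_coupling` are hypothetical names, and a rotation field stands in for the trained 1-Rectified Flow because its curved paths have a known endpoint.

```python
import math
import torch

def rotation_velocity(z, t):
    """Toy field with curved paths: rotates each point by 90 degrees over t in [0, 1]."""
    return (math.pi / 2) * torch.stack([-z[:, 1], z[:, 0]], dim=1)

@torch.no_grad()  # the coupling is training data, so no gradients are tracked
def simulate_coupling(velocity, z_0, num_steps):
    """Return the pairs (Z_0, Z_1), where Z_1 is the Euler ODE endpoint started at Z_0."""
    dt = 1.0 / num_steps
    z = z_0.clone()
    for k in range(num_steps):
        t = torch.full((z.shape[0],), k * dt)
        z = z + dt * velocity(z, t)
    return z_0, z

torch.manual_seed(0)
z_0 = torch.randn(1000, 2)
z_0, z_1 = simulate_coupling(rotation_velocity, z_0, num_steps=1000)

# Row i of z_1 is the endpoint for row i of z_0, so the pairing must be kept in order.
# Reflow then regresses the velocity onto the straight displacement of each pair.
reflow_targets = z_1 - z_0
```

Unlike the independent pairing of `D0` and `D1`, this coupling is deterministic: each $Z_1^1$ is a function of its own $Z_0^1$.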

In [None]:
Z_0 = rectified_flow.sample_source_distribution(batch_size=50000)

Z_1 = euler_sampler_1rf_unconditional.sample_loop(x_0=Z_0, num_steps=1000).trajectories[-1]

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch_size = 1024

losses = []

for step in range(5000):
    optimizer.zero_grad()
    # Z_0[i] and Z_1[i] form a coupled pair, so both must use the same idx.
    idx = torch.randperm(n_samples)[:batch_size]
    x_0 = Z_0[idx].to(device)
    x_1 = Z_1[idx].to(device)

    loss = rectified_flow.get_loss(x_0, x_1)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if step % 200 == 0:
        print(f"Step {step}, Loss: {loss.item()}")

plt.plot(losses)

In [None]:
euler_sampler_2rf = EulerSampler(
    rectified_flow=rectified_flow,
    num_samples=1000,
)

euler_sampler_2rf.sample_loop(num_steps=100, seed=0)

visualize_2d_trajectories_plotly(
    {"2rf": euler_sampler_2rf.trajectories}, 
    D1[:1000],
    num_trajectories=200,
    title="Reflow Trajectories, 100-step",
)

In [None]:
euler_sampler_2rf.sample_loop(num_steps=1, seed=0)

visualize_2d_trajectories_plotly(
    {"2rf one-step": euler_sampler_2rf.trajectories},
    D1[:1000],
    num_trajectories=200,
    title="Reflow Trajectories, 1-step",
)