<a href="https://colab.research.google.com/github/lqiang67/rectified-flow/blob/main/examples/samplers_2d_toys.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/lqiang67/rectified-flow.git
%cd rectified-flow/

This notebook explores the samplers available in this repository using a 2D toy example. It:

- illustrates the concepts and usage of both deterministic and stochastic samplers;
- demonstrates how to customize a sampler by inheriting from the `Sampler` base class;
- discusses the effects of employing stochastic samplers.

In [1]:
import torch
import os
import sys
import numpy as np
import matplotlib.pyplot as plt

import torch.distributions as dist

from rectified_flow.utils import set_seed
from rectified_flow.utils import visualize_2d_trajectories_plotly
from rectified_flow.datasets.toy_gmm import TwoPointGMM

from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.models.toy_mlp import MLPVelocityConditioned, MLPVelocity

set_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In this example, we use a 2D toy dataset consisting of:

- $\pi_0$: a standard 2D Gaussian distribution with zero mean and identity covariance matrix.
- $\pi_1$: a custom two-point Gaussian mixture model with centers located at $(15, 2)$ and $(15, -2)$ and a standard deviation of $0.3$.

In [None]:
n_samples = 50000
pi_0 = dist.MultivariateNormal(torch.zeros(2, device=device), torch.eye(2, device=device))
pi_1 = TwoPointGMM(x=15.0, y=2, std=0.3)
D0 = pi_0.sample([n_samples])
D1, labels = pi_1.sample_with_labels([n_samples])
labels.tolist()

plt.scatter(D0[:, 0].cpu().numpy(), D0[:, 1].cpu().numpy(), alpha=0.5, label='D0')
plt.scatter(D1[:, 0].cpu().numpy(), D1[:, 1].cpu().numpy(), alpha=0.5, label='D1')
plt.legend()
plt.xlim(-5, 18)
plt.ylim(-5, 5)
plt.gca().set_aspect('equal', adjustable='box')
plt.show()

In [None]:
def rf_trainer(rectified_flow, label = "loss", batch_size = 1024):
    model = rectified_flow.velocity_field
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    losses = []

    for step in range(5000):
        optimizer.zero_grad()
        idx = torch.randperm(n_samples)[:batch_size]
        x_0 = D0[idx].to(device)
        x_1 = D1[idx].to(device)

        loss = rectified_flow.get_loss(x_0, x_1)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if step % 1000 == 0:
            print(f"Step {step}, Loss: {loss.item()}")

    plt.plot(losses, label=label)
    plt.legend()
    
from rectified_flow.models.toy_mlp import MLPVelocity

straight_rf = RectifiedFlow(
    data_shape=(2,),
    velocity_field=MLPVelocity(2, hidden_sizes = [128, 128, 128]).to(device),
    interp="straight",
    source_distribution=pi_0,
    device=device,
)

rf_trainer(straight_rf, "straight interp")

# Samplers

`Rectified Flow` offers several off-the-shelf `Sampler`s for exploration and study. These prebuilt samplers are straightforward to use; the cell below lists the ones registered in `rf_samplers_dict`.

In [None]:
from rectified_flow.samplers import rf_samplers_dict

for key, rf_sampler in rf_samplers_dict.items():
	print(f"{key}: {rf_sampler}")

### Euler Sampler

The **Euler Sampler** is a simple, deterministic method. It updates each sample $X_t$ by moving along the direction of the velocity field $v(X_t, t)$:

$$
X_{t + \Delta t} = X_t + \Delta t \cdot v(X_t, t)
$$

To implement the **Euler Sampler**, we inherit from the `Sampler` base class and override its `step` method. The base class handles time-stepping and trajectory recording, so we only need to define how the state $x_t$ is updated in each step:

In [None]:
from rectified_flow.samplers.base_sampler import Sampler

class MyEulerSampler(Sampler):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def step(self, **model_kwargs):
        # Extract the current time, next time point, and current state
        t, t_next, x_t = self.t, self.t_next, self.x_t
        
        # Compute the velocity field at the current state and time
        v_t = self.rectified_flow.get_velocity(x_t=x_t, t=t, **model_kwargs)
        
        # Update the state using the Euler formula
        self.x_t = x_t + (t_next - t) * v_t
        
euler_sampler = MyEulerSampler(
    rectified_flow=straight_rf,
    num_steps=50,
    num_samples=100,
)

# Sample method 1)
# Will use the default num_steps and num_samples previously set in the Sampler class
traj1 = euler_sampler.sample_loop(seed=233)

# Sample method 2)
# We can pass in a custom x_0 to sample from
set_seed(233)
x_0 = straight_rf.sample_source_distribution(batch_size=100)
traj2 = euler_sampler.sample_loop(x_0=x_0)

# Sample method 3)
# If we pass in num_steps and num_samples, it will override the default values
traj3 = euler_sampler.sample_loop(seed=233, num_steps=50, num_samples=100)

# three trajectories are the same
visualize_2d_trajectories_plotly(
    trajectories_dict={
        "traj1": traj1.trajectories,
        "traj2": traj2.trajectories,
        "traj3": traj3.trajectories,
    },
    D1_gt_samples=D1[:1000],
    num_trajectories=100,
    title="My Euler Sampler",
)
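
As a dependency-free sanity check of the Euler update, the sketch below integrates a hand-written velocity field in plain Python. The helper names `euler_sample` and `point_target_velocity` and the target point are illustrative assumptions and are not part of the `rectified_flow` package. When every sample flows along a straight line to a single target point $c$, the velocity is $(c - x)/(1 - t)$, so the Euler iterates should stay on the straight segment and land on $c$ at $t = 1$.

```python
def euler_sample(velocity, x0, num_steps):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with uniform Euler steps."""
    x = list(x0)
    trajectory = [tuple(x)]
    for k in range(num_steps):
        t, t_next = k / num_steps, (k + 1) / num_steps
        v = velocity(x, t)
        # Same update as in MyEulerSampler.step: x <- x + (t_next - t) * v
        x = [xi + (t_next - t) * vi for xi, vi in zip(x, v)]
        trajectory.append(tuple(x))
    return trajectory


# Hypothetical target point used only for this illustration.
TARGET = (15.0, 2.0)


def point_target_velocity(x, t):
    """Velocity of a straight path from x toward TARGET, valid for t < 1."""
    return [(c - xi) / (1.0 - t) for c, xi in zip(TARGET, x)]


traj = euler_sample(point_target_velocity, x0=(0.3, -1.2), num_steps=5)
print(traj[-1])  # should be numerically equal to TARGET
```

Because the true paths are straight here, the number of steps does not affect the endpoint; with a learned, imperfect velocity field more steps generally help.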

### AffineInterp Solver

The interpolation process is governed by two equations:

$$
\begin{align}
    X_t &= \alpha_t \cdot X_1 + \beta_t \cdot X_0, \\
    \dot{X}_t &= \dot{\alpha}_t \cdot X_1 + \dot{\beta}_t \cdot X_0,
\end{align}
$$

where $X_t$ represents the interpolated state at time $t$, and $\dot{X}_t$ is its time derivative.

Given any two of the variables $X_0, X_1, X_t, \dot{X}_t$, the remaining two variables can be uniquely determined for a specific time $t$. 

The `AffineInterp` class provides a `solve` function, which incorporates precomputed symbolic solvers for all possible combinations of known and unknown variables, making it straightforward to compute the missing variables as needed.

In [6]:
x_0 = straight_rf.sample_source_distribution(batch_size=500)
x_1 = D1[:500].to(device)
t = straight_rf.sample_train_time(batch_size=500)
x_t, dot_x_t = straight_rf.get_interpolation(x_0, x_1, t)

# Given x_t and dot_x_t, solve for x_0 and x_1
result = straight_rf.interp.solve(t, x_t=x_t, dot_x_t=dot_x_t)
x_1_pred = result.x_1
x_0_pred = result.x_0

assert torch.allclose(x_0, x_0_pred, atol=1e-4)
assert torch.allclose(x_1, x_1_pred, atol=1e-4)
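
To make the "any two determine the other two" statement concrete, here is a minimal scalar sketch of one case: recovering $X_0$ and $X_1$ from $X_t$ and $\dot{X}_t$ by inverting the $2 \times 2$ linear system formed by the two interpolation equations. The function names `solve_endpoints` and `spherical_coeffs`, and the choice $\alpha_t = \sin(\pi t / 2)$, $\beta_t = \cos(\pi t / 2)$, are assumptions made for illustration; this is not how `AffineInterp.solve` is implemented.

```python
import math


def solve_endpoints(x_t, dot_x_t, alpha, beta, dot_alpha, dot_beta):
    """Recover (x_0, x_1) from x_t and dot_x_t by inverting the 2x2 system."""
    det = alpha * dot_beta - dot_alpha * beta
    if abs(det) < 1e-12:
        raise ValueError("the system is singular at this time t")
    x_1 = (dot_beta * x_t - beta * dot_x_t) / det
    x_0 = (alpha * dot_x_t - dot_alpha * x_t) / det
    return x_0, x_1


def spherical_coeffs(t):
    """Illustrative schedule: alpha = sin(pi t / 2), beta = cos(pi t / 2)."""
    a = math.pi / 2
    return math.sin(a * t), math.cos(a * t), a * math.cos(a * t), -a * math.sin(a * t)


# Build x_t and dot_x_t from known endpoints, then recover the endpoints.
x_0, x_1, t = -0.7, 15.0, 0.3
alpha, beta, dot_alpha, dot_beta = spherical_coeffs(t)
x_t = alpha * x_1 + beta * x_0
dot_x_t = dot_alpha * x_1 + dot_beta * x_0

x_0_rec, x_1_rec = solve_endpoints(x_t, dot_x_t, alpha, beta, dot_alpha, dot_beta)
print(x_0_rec, x_1_rec)
```

For the straight schedule $\alpha_t = t$, $\beta_t = 1 - t$, the same formulas reduce to $X_1 = X_t + (1 - t)\dot{X}_t$ and $X_0 = X_t - t\,\dot{X}_t$.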

### Curved Euler Sampler

The **Curved Euler Sampler** uses interpolation to trace a curved path rather than a straight line. It works as follows:

* Start from the current state $(X_t, t)$.
* Use the velocity model to predict the velocity at $(X_t, t)$ and derive two reference points $\hat{X}_0$ and $\hat{X}_1$ from it.
* Interpolate between $\hat{X}_0$ and $\hat{X}_1$ using the functions $\alpha(t)$ and $\beta(t)$ to get the next state:

$$
X_{t + \Delta t} = \alpha(t + \Delta t) \cdot \hat{X}_1 + \beta(t + \Delta t) \cdot \hat{X}_0
$$

An interesting observation about the Curved Euler Sampler is that different interpolation schemes turn out to be equivalent. See the blog post on [natural Euler samplers](https://rectifiedflow.github.io/blog/2024/discretization/) for more details.

In [None]:
from rectified_flow.samplers import CurvedEulerSampler

curved_euler_sampler_straight = CurvedEulerSampler(
    rectified_flow=straight_rf,
	num_steps=10,
	num_samples=100,
)

visualize_2d_trajectories_plotly(
    trajectories_dict={
		"straight": curved_euler_sampler_straight.sample_loop(seed=0).trajectories, 
	},
	D1_gt_samples=D1[:1000],
	num_trajectories=100,
	title="Curved Euler Sampler",
)
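
A single curved Euler step can also be written out by hand for one coordinate. The sketch below is an illustration under assumptions: `straight_coeffs` and `curved_euler_step` are hypothetical helpers, not the `CurvedEulerSampler` source. It solves for the two reference points from $(X_t, v_t)$ and re-interpolates at the next time. With the straight schedule $\alpha_t = t$, $\beta_t = 1 - t$, the result coincides with the plain Euler update $X_t + \Delta t \cdot v_t$.

```python
def straight_coeffs(t):
    """alpha_t = t, beta_t = 1 - t, with time derivatives 1 and -1."""
    return t, 1.0 - t, 1.0, -1.0


def curved_euler_step(x_t, v_t, t, t_next, coeffs):
    """One curved Euler step for a scalar state."""
    alpha, beta, dot_alpha, dot_beta = coeffs(t)
    det = alpha * dot_beta - dot_alpha * beta
    # Reference endpoints implied by the current state and velocity.
    x_1_hat = (dot_beta * x_t - beta * v_t) / det
    x_0_hat = (alpha * v_t - dot_alpha * x_t) / det
    # Re-interpolate the same endpoints at the next time.
    alpha_next, beta_next, _, _ = coeffs(t_next)
    return alpha_next * x_1_hat + beta_next * x_0_hat


x_t, v_t, t, t_next = 2.0, 13.5, 0.4, 0.5
curved = curved_euler_step(x_t, v_t, t, t_next, straight_coeffs)
euler = x_t + (t_next - t) * v_t
print(curved, euler)  # the two values agree up to floating-point rounding
```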

### Noise Refresh Sampler

This example demonstrates how to create a custom sampler that refreshes the noise component at each step. 

At each point, the sampler predicts the noise component $\hat{X}_0$ and refreshes it by blending it with new noise, using the formula:

$$
X_0' = \sqrt{1 - \eta^2} \cdot \hat{X}_0 + \eta \cdot \epsilon,
$$

where $\eta$ is the noise replacement rate, and $\epsilon$ is a random noise sample. 

In [None]:
class MyNoiseRefreshSampler(Sampler):
    def __init__(self, noise_replacement_rate = lambda t: 0.5, **kwargs):
        super().__init__(**kwargs)
        self.noise_replacement_rate = noise_replacement_rate
        assert (self.rectified_flow.independent_coupling and self.rectified_flow.is_pi_0_zero_mean_gaussian), \
            'pi0 must be a zero-mean Gaussian and the coupling must be independent'

    def step(self, **model_kwargs):
        t, t_next, x_t = self.t, self.t_next, self.x_t
        v_t = self.rectified_flow.get_velocity(x_t=x_t, t=t, **model_kwargs)

        # Given x_t and dot_x_t = vt, find the corresponding endpoints x_0 and x_1
        self.rectified_flow.interp.solve(t, x_t=x_t, dot_x_t=v_t)
        x_1_pred = self.rectified_flow.interp.x_1
        x_0_pred = self.rectified_flow.interp.x_0

        # Randomize x_0_pred by replacing part of it with new noise
        noise = self.rectified_flow.sample_source_distribution(self.num_samples)

        noise_replacement_factor = self.noise_replacement_rate(t)
        x_0_pred_refreshed = (
            (1 - noise_replacement_factor**2)**0.5 * x_0_pred +
            noise * noise_replacement_factor
        )

        # Interpolate to find x_t at t_next
        self.rectified_flow.interp.solve(t_next, x_0=x_0_pred_refreshed, x_1=x_1_pred)
        self.x_t = self.rectified_flow.interp.x_t
        
noise_refresh_sampler = MyNoiseRefreshSampler(
    rectified_flow=straight_rf,
    num_steps=50,
    num_samples=500,
)

visualize_2d_trajectories_plotly(
    trajectories_dict={
        "straight": noise_refresh_sampler.sample_loop(seed=0).trajectories,
    },
    D1_gt_samples=D1[:1000],
    num_trajectories=100,
    title="Noise Refresh Sampler",
)
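
The coefficients $\sqrt{1 - \eta^2}$ and $\eta$ in the refresh formula are chosen so that, if $\hat{X}_0$ is standard Gaussian and $\epsilon$ is an independent standard Gaussian, the refreshed $X_0'$ is again standard Gaussian, because $(1 - \eta^2) + \eta^2 = 1$. The short Monte Carlo check below illustrates this for one coordinate. The helper `refresh_noise` is a hypothetical name used only for this sketch.

```python
import math
import random


def refresh_noise(x_0_hat, eta, rng):
    """Blend the predicted noise with a fresh standard Gaussian draw."""
    eps = rng.gauss(0.0, 1.0)
    return math.sqrt(1.0 - eta ** 2) * x_0_hat + eta * eps


rng = random.Random(0)
eta = 0.5
samples = [refresh_noise(rng.gauss(0.0, 1.0), eta, rng) for _ in range(100_000)]

mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean={mean:.3f}, var={var:.3f}")  # both should be close to 0 and 1
```

Setting $\eta = 0$ keeps the predicted noise unchanged, while $\eta = 1$ replaces it entirely with fresh noise.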

### SDESampler

The **SDESampler** introduces stochasticity (randomness) into the sampling process. We control the noise at time $t$ using the following hyperparameters:

* **noise_scale**: Controls the amount of noise added at each step.

* **noise_decay_rate**: Controls how the noise changes over time. A decay rate of 0 means the noise level remains constant, while a decay rate of 1.0 means the noise decreases over time.

Mathematically, the effective noise at time $t$ is given by:

$$
\text{Effective Noise at time } t = \text{step\_size} \times \text{noise\_scale} \times \beta_t^{\text{noise\_decay\_rate}}
$$

Check out this [blog post](https://rectifiedflow.github.io/blog/2024/diffusion/) for more details.

In [None]:
from rectified_flow.samplers import SDESampler

sde_sampler_sphere = SDESampler(
    rectified_flow=straight_rf,
	num_steps=50,
	num_samples=500,
)

visualize_2d_trajectories_plotly(
    trajectories_dict={
		"straight rf sde": sde_sampler_sphere.sample_loop(seed=0).trajectories, 
	},
	D1_gt_samples=D1[:1000],
	num_trajectories=100,
	title="SDE Sampler",
)
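
To get a feel for the two hyperparameters, the effective-noise formula can be evaluated directly. The sketch below assumes the straight schedule $\beta_t = 1 - t$ and uses illustrative values for `noise_scale` and `noise_decay_rate`; these values are not claimed to be the defaults of `SDESampler`, and `effective_noise` is a hypothetical helper.

```python
def effective_noise(t, step_size, noise_scale, noise_decay_rate, beta=lambda t: 1.0 - t):
    """step_size * noise_scale * beta_t ** noise_decay_rate."""
    return step_size * noise_scale * beta(t) ** noise_decay_rate


num_steps = 50
step_size = 1.0 / num_steps
times = [k * step_size for k in range(num_steps)]

# noise_decay_rate = 0: the same noise level at every step.
constant = [effective_noise(t, step_size, 1.0, 0.0) for t in times]
# noise_decay_rate = 1: the noise shrinks linearly as t approaches 1.
decaying = [effective_noise(t, step_size, 1.0, 1.0) for t in times]

print(constant[0], constant[-1])
print(decaying[0], decaying[-1])
```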

We can easily implement a stochastic sampler that matches the **DDPM** sampling scheme:

In [None]:
from rectified_flow.samplers.stochastic_curved_euler_sampler import StochasticCurvedEulerSampler

stochastic_curved_euler_sampler = StochasticCurvedEulerSampler(
    rectified_flow=straight_rf,
	num_steps=100,
	num_samples=1000,
    noise_replacement_rate="ddpm",
)

visualize_2d_trajectories_plotly(
    trajectories_dict={
		"straight": stochastic_curved_euler_sampler.sample_loop(seed=0).trajectories, 
	},
	D1_gt_samples=D1[:1000],
	num_trajectories=100,
	title="Stochastic Curved Euler Sampler",
)

### OverShootingSampler

The **OverShootingSampler**, introduced in [our AMO Sampler paper](https://arxiv.org/abs/2411.19415), adds an extra "overshoot" step during sampling. At each step, it does not just move forward along the trajectory: it goes a bit beyond the next point and then comes back, which adds more stochasticity and can potentially find better paths. We can control the amount of noise added with:

* c: A parameter controlling how far we overshoot.
* overshooting_method: Determines the exact method used to overshoot (e.g., "t+dt").
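
The "overshoot, then come back" idea can be sketched for a single coordinate under straight interpolation, $X_t = t X_1 + (1 - t) X_0$. This is an illustrative sketch and not the library's implementation: the rule that maps `c` to the overshoot time `o`, and the helper names `overshoot_coeffs` and `overshoot_step`, are assumptions. The return step rescales the overshot state by `t_next / o` and adds just enough fresh Gaussian noise for the total noise level to equal `1 - t_next`.

```python
import math
import random


def overshoot_coeffs(t, t_next, c):
    """Return (o, a, b): overshoot time, rescaling factor, fresh-noise scale."""
    # Assumed rule: overshoot by c * dt beyond t_next, capped at the final time 1.
    o = min(t_next + c * (t_next - t), 1.0)
    a = t_next / o
    # Noise scale so that a * (1 - o) and b combine in quadrature to 1 - t_next.
    b_sq = (1.0 - t_next) ** 2 - (a * (1.0 - o)) ** 2
    return o, a, math.sqrt(max(b_sq, 0.0))


def overshoot_step(x_t, v_t, t, t_next, c, rng):
    """One overshooting step for a scalar state."""
    o, a, b = overshoot_coeffs(t, t_next, c)
    x_o = x_t + (o - t) * v_t  # deterministic Euler move to the overshoot time
    return a * x_o + b * rng.gauss(0.0, 1.0)  # return to t_next with fresh noise


rng = random.Random(0)
o, a, b = overshoot_coeffs(t=0.4, t_next=0.5, c=2.0)
print(f"o={o:.2f}, a={a:.3f}, b={b:.3f}")

x_next = overshoot_step(2.0, 13.5, 0.4, 0.5, 2.0, rng)
no_overshoot = overshoot_step(2.0, 13.5, 0.4, 0.5, 0.0, rng)  # c = 0: plain Euler
```

In this sketch, `c = 0` gives `o = t_next`, no noise is injected, and the step reduces to the deterministic Euler update; larger `c` pushes `o` toward 1 and injects more noise on the way back.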

In [None]:
from rectified_flow.samplers import OverShootingSampler


sde_sampler = OverShootingSampler(
    rectified_flow=straight_rf,
    num_steps=10,
    num_samples=1000,
    c=15.0,
    overshooting_method="t+dt"
)

sde_sampler.sample_loop()

# Plot OverShootingSampler results (the figure is rendered by Plotly,
# so the title is passed to the helper instead of using matplotlib)
visualize_2d_trajectories_plotly(
    {"overshooting": sde_sampler.trajectories},
    D1[:1000], # D1 defined previously
    num_trajectories=100,
    title="OverShootingSampler",
    show_legend=True
)