<a href="https://colab.research.google.com/github/lqiang67/rectified-flow/blob/main/examples/natural_euler_sampler.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the repository once and make it the working directory so `rectified_flow` is
# importable (IPython keeps the current directory on sys.path).
# The guards make this cell safe to re-run:
# - a second run in the same kernel does not nest a clone inside the checkout;
# - a Restart & Run All in a persistent Colab session does not try to clone into an
#   existing directory.
# A failed clone now raises instead of being silently ignored.
import os
import subprocess

REPO_URL = "https://github.com/lqiang67/rectified-flow.git"
REPO_DIR = "rectified-flow"

if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.isdir(REPO_DIR):
        subprocess.run(["git", "clone", REPO_URL], check=True)
    os.chdir(REPO_DIR)
print(os.getcwd())

In [1]:
import copy
import os
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
import torch.distributions as dist

from rectified_flow.datasets.toy_gmm import TwoPointGMM
from rectified_flow.flow_components.interpolation_convertor import AffineInterpConverter
from rectified_flow.flow_components.interpolation_solver import AffineInterp
from rectified_flow.models.toy_mlp import MLPVelocityConditioned, MLPVelocity
from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.samplers import CurvedEulerSampler, EulerSampler
from rectified_flow.utils import match_dim_with_data
from rectified_flow.utils import set_seed
from rectified_flow.utils import visualize_2d_trajectories_plotly

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Source (standard 2-D Gaussian) and target (two-point GMM) distributions, plus a large
# reference sample from each. The RNG call order below is kept fixed so the samples stay
# reproducible under set_seed(0).
set_seed(0)
n_samples = 50000
pi_0 = dist.MultivariateNormal(torch.zeros(2, device=device), torch.eye(2, device=device))
pi_1 = TwoPointGMM(x=15.0, y=2, std=0.3)
D0 = pi_0.sample([n_samples])
D1, labels = pi_1.sample_with_labels([n_samples])

# Compare the straight and spherical affine interpolations on the same (x_0, x_1) pairs.
straight_interp = AffineInterp("straight")
spherical_interp = AffineInterp("spherical")

n_pairs = 1000       # number of random (x_0, x_1) pairs to interpolate
n_time_points = 50   # time points per interpolated path, evenly spaced in [0, 1]

idx = torch.randperm(n_samples)[:n_pairs]
x_0 = D0[idx]
# NOTE(review): the training loop moves pi_1 samples to the device explicitly, which
# suggests pi_1 may sample on CPU; align x_1 with x_0 so the interpolation never mixes devices.
x_1 = D1[idx].to(x_0.device)

straight_interp_list = []
spherical_interp_list = []
for t in np.linspace(0, 1, n_time_points):
    # forward() returns (x_t, dot_x_t); only the positions are visualized here.
    x_t_straight, _ = straight_interp.forward(x_0, x_1, t)
    x_t_spherical, _ = spherical_interp.forward(x_0, x_1, t)
    straight_interp_list.append(x_t_straight)
    spherical_interp_list.append(x_t_spherical)

visualize_2d_trajectories_plotly(
    trajectories_dict={"straight interp": straight_interp_list, "spherical interp": spherical_interp_list},
    D1_gt_samples=D1[:5000],
    num_trajectories=50,
    title="Interpolated Trajectories Visualization",
)

In [None]:
from rectified_flow.flow_components.interpolation_convertor import AffineInterpConverter

def rf_trainer(
    rectified_flow,
    label="loss",
    batch_size=1024,
    num_steps=5000,
    lr=1e-2,
    log_every=1000,
    source=None,
    target=None,
    plot=True,
):
    """Train ``rectified_flow.velocity_field`` with Adam on freshly sampled (x_0, x_1) pairs.

    Each step draws ``batch_size`` samples from the source and target distributions,
    computes ``rectified_flow.get_loss(x_0, x_1)`` and takes one optimizer step.

    Args:
        rectified_flow: object exposing ``velocity_field`` (a torch module whose parameters
            are optimized) and ``get_loss(x_0, x_1)`` returning a scalar loss tensor.
        label: legend label for the plotted loss curve.
        batch_size: number of (x_0, x_1) pairs drawn per step.
        num_steps: number of optimizer steps.
        lr: Adam learning rate.
        log_every: print the current loss every ``log_every`` steps; ``<= 0`` disables logging.
        source: distribution with ``.sample(shape)`` for x_0; defaults to the notebook's ``pi_0``.
        target: distribution with ``.sample(shape)`` for x_1; defaults to the notebook's ``pi_1``.
        plot: if True, add the loss curve (with legend) to the current matplotlib figure.

    Returns:
        list[float]: the training loss recorded at every step.
    """
    source = pi_0 if source is None else source
    target = pi_1 if target is None else target

    model = rectified_flow.velocity_field
    # Place samples wherever the model's parameters live instead of relying on a global `device`.
    model_device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    losses = []
    for step in range(num_steps):
        optimizer.zero_grad()
        x_0 = source.sample([batch_size]).to(model_device)
        x_1 = target.sample([batch_size]).to(model_device)

        loss = rectified_flow.get_loss(x_0, x_1)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if log_every > 0 and step % log_every == 0:
            # One iteration is one optimizer step on a fresh batch, not a pass over a dataset.
            print(f"Step {step}, Loss: {losses[-1]}")

    if plot:
        plt.plot(losses, label=label)
        plt.legend()
    return losses

# Straight-interpolation rectified flow with a 3x128 MLP velocity field.
# (MLPVelocity is already imported in the top import cell.)
set_seed(0)
straight_rf = RectifiedFlow(
    data_shape=(2,),
    velocity_field=MLPVelocity(2, hidden_sizes=[128, 128, 128]).to(device),
    interp=straight_interp,
    source_distribution=pi_0,
    device=device,
)

# Re-seed so the training batches are reproducible independently of the model initialization.
set_seed(0)
rf_trainer(rectified_flow=straight_rf, label="straight interp")

# Build the spherical-interpolation counterpart of the trained flow with AffineInterpConverter.
spherical_rf = AffineInterpConverter(straight_rf, AffineInterp("spherical")).transform_rectified_flow()

In [None]:
# Plain (vanilla) Euler integration of both flows with the same number of steps:
# the final generated samples differ noticeably between the two interpolations.
from rectified_flow.samplers import EulerSampler

num_samples = 300
num_steps = 10

euler_sampler_straight = EulerSampler(straight_rf, num_steps=num_steps)
euler_sampler_straight.sample_loop(seed=0, num_samples=num_samples)

euler_sampler_spherical = EulerSampler(spherical_rf, num_steps=num_steps)
euler_sampler_spherical.sample_loop(seed=0, num_samples=num_samples)

euler_trajectories = {
    "straight rf": euler_sampler_straight.trajectories,
    "spherical rf": euler_sampler_spherical.trajectories,
}
visualize_2d_trajectories_plotly(
    trajectories_dict=euler_trajectories,
    D1_gt_samples=D1[: 3 * num_samples],
    num_trajectories=50,
    title="Euler Sampler, straight rf vs spherical rf",
)

In [None]:
# Unmatched time grid: both flows use the natural (curved) Euler sampler with the same
# number of steps, and the final generated samples are already nearly identical.

num_samples = 300
num_steps = 10

natural_euler_sampler_straight = CurvedEulerSampler(straight_rf, num_steps=num_steps)
natural_euler_sampler_straight.sample_loop(seed=0, num_samples=num_samples)

natural_euler_sampler_spherical = CurvedEulerSampler(spherical_rf, num_steps=num_steps)
natural_euler_sampler_spherical.sample_loop(seed=0, num_samples=num_samples)

visualize_2d_trajectories_plotly(
    trajectories_dict={
        "straight rf": natural_euler_sampler_straight.trajectories,
        "spherical rf": natural_euler_sampler_spherical.trajectories,
    },
    D1_gt_samples=D1[:num_samples * 3],
    num_trajectories=50,
    title="Natural Euler Sampler, straight rf vs spherical rf, unmatched time grid",
)

In [None]:
# Matched time grid, both natural euler samplers, exactly the same final generated samples

def convert_time(t, affine_interp):
    """Map each time in ``t`` to ``alpha(s) / (alpha(s) + beta(s))`` under ``affine_interp``.

    The result is the normalized weight that ``alpha`` carries at each time. Assuming the
    straight interpolation uses ``alpha(s) = s`` and ``beta(s) = 1 - s`` (TODO confirm against
    AffineInterp), this is the straight-interpolation time with the same x_0/x_1 mixing ratio.

    Args:
        t: iterable of time values (e.g. a sampler's ``time_grid``).
        affine_interp: object exposing ``alpha(s)`` and ``beta(s)``.

    Returns:
        list: one converted time per entry of ``t``, in the same order.
    """
    converted = []
    for s in t:
        # Evaluate alpha once per time point; it appears in both numerator and denominator.
        a = affine_interp.alpha(s)
        converted.append(a / (a + affine_interp.beta(s)))
    return converted

# Map the spherical sampler's time grid through convert_time and drive the straight-rf
# sampler with it (the "matched time grid" comparison).
converted_time_grid = convert_time(natural_euler_sampler_spherical.time_grid, spherical_interp)
print(converted_time_grid)

natural_euler_sampler_straight = CurvedEulerSampler(straight_rf, time_grid=converted_time_grid)
natural_euler_sampler_straight.sample_loop(seed=0, num_samples=num_samples)

visualize_2d_trajectories_plotly(
    trajectories_dict={
        "straight rf": natural_euler_sampler_straight.trajectories,
        "spherical rf": natural_euler_sampler_spherical.trajectories,
    },
    D1_gt_samples=D1[:num_samples * 3],
    num_trajectories=50,
    title="Natural Euler Sampler, straight rf vs spherical rf, matched time grid",
)