<a href="https://colab.research.google.com/github/lqiang67/rectified-flow/blob/main/examples/train_2d_toys_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://lqiang67: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@github.com/lqiang67/rectified-flow.git
%cd rectified-flow/

In [None]:
import torch
import os
import sys
import matplotlib.pyplot as plt

import torch.distributions as dist

from rectified_flow.utils import set_seed
from rectified_flow.datasets.toy_gmm import TwoPointGMM

from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.models.toy_mlp import MLPVelocityConditioned

# Seed before any sampling so the datasets and training below are reproducible
# (set_seed comes from rectified_flow.utils; presumably seeds torch/numpy/random — confirm).
set_seed(0)

# Prefer the GPU when present; later cells move tensors and the model to `device`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Draw reference samples from both ends of the flow:
#   pi0 — standard 2-D Gaussian (source),
#   pi1 — TwoPointGMM target (presumably a two-mode Gaussian mixture), sampled
#         together with per-sample labels that are later used as conditioning.
n_samples = 50000
pi0 = dist.MultivariateNormal(torch.zeros(2), torch.eye(2))
pi1 = TwoPointGMM(x=15.0, y=7.5, std=0.5)
D0 = pi0.sample([n_samples])
D1, labels = pi1.sample_with_labels([n_samples])

# Invariant: exactly one label per target sample.
assert len(labels) == n_samples
print(labels.shape)

In [None]:
# Visual check of the source (D0) and target (D1) samples before training.
fig, ax = plt.subplots(figsize=(6, 6))
# Small markers: with 50k points per set the default size merges into solid blobs.
ax.scatter(D0[:, 0], D0[:, 1], s=2, alpha=0.5, label='D0')
ax.scatter(D1[:, 0], D1[:, 1], s=2, alpha=0.5, label='D1')
ax.set(title='Source samples D0 vs. target samples D1', xlabel='x', ylabel='y')
# Equal aspect so distances in x and y are comparable.
ax.set_aspect('equal')
ax.legend();

In [None]:
# Conditional velocity model for 2-D points; the conditioning labels come from pi1.
model = MLPVelocityConditioned(2, hidden_sizes=[128, 128, 128]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch_size = 1024
num_train_steps = 1500
log_every = 100

# NOTE(review): `(2)` is the plain int 2, not a tuple. If RectifiedFlow expects a
# data-shape tuple this should be `(2,)` — confirm against its signature.
rf_func = RectifiedFlow(
    (2),
    model=model,
    interp="sin",
    source_distribution=pi0,
    device=device,
)

losses = []

for step in range(num_train_steps):
    optimizer.zero_grad()
    # Fresh independent draws every step: X_0 from pi0, (X_1, cond) from pi1.
    X_0 = pi0.sample([batch_size]).to(device)
    X_1, cond = pi1.sample_with_labels([batch_size])
    X_1 = X_1.to(device)
    # torch.as_tensor accepts tensors, numpy arrays and lists without the
    # copy-construct warning torch.tensor() emits when `cond` is already a tensor.
    cond = torch.as_tensor(cond, device=device)

    loss = rf_func.get_loss(X_0, X_1, labels=cond)
    loss.backward()
    optimizer.step()

    # Single .item() call per step (each one is a device sync on GPU).
    loss_value = loss.item()
    losses.append(loss_value)

    if step % log_every == 0:
        # The loop counts optimizer steps, not epochs.
        print(f"Step {step}, Loss: {loss_value}")

fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title='Training loss', xlabel='step', ylabel='loss');

In [None]:
from rectified_flow.samplers import rf_samplers_dict

# Number of generated samples; also the length of the label vector, so the two
# values can no longer drift apart.
n_eval = 500

euler_sampler = rf_samplers_dict["euler"](
    rectified_flow=rf_func,
    num_steps=100,
    num_samples=n_eval,
)

# Condition every sample on label 0. The tensor is created directly on `device`.
# A distinct name keeps the dataset `labels` from the sampling cell intact.
# NOTE(review): this is float32; training labels came from pi1.sample_with_labels —
# confirm the model treats both dtypes the same way.
euler_labels = torch.zeros(n_eval, device=device)

euler_sampler.sample_loop(labels=euler_labels)

In [None]:
from rectified_flow.utils import visualize_2d_trajectories

# Plot 200 of the Euler sampler's trajectories against the first 5,000 target samples.
euler_trajectories = euler_sampler.trajectories
target_subset = D1[:5000]
visualize_2d_trajectories(euler_trajectories, target_subset, num_trajectories=200)

In [None]:
# Curved-Euler sampler, conditioned on label 1.
# NOTE(review): the Euler run used label 0, so the two trajectory plots differ in
# both sampler and conditioning label — confirm that is intended.
# Number of generated samples; shared with the label vector so they stay in sync.
n_eval = 500

curved_sampler = rf_samplers_dict["curved_euler"](
    rectified_flow=rf_func,
    num_steps=100,
    num_samples=n_eval,
)

# Created directly on `device`. A distinct name avoids clobbering the dataset `labels`.
curved_labels = torch.ones(n_eval, device=device)

curved_sampler.sample_loop(labels=curved_labels)

In [None]:
# Plot 200 of the curved sampler's trajectories against the first 5,000 target samples.
curved_trajectories = curved_sampler.trajectories
target_subset = D1[:5000]
visualize_2d_trajectories(curved_trajectories, target_subset, num_trajectories=200)