In [None]:
import torch
import matplotlib.pyplot as plt

from rectified_flow.flow_components import utils
from rectified_flow.models.dit import DiT
from rectified_flow.rectified_flow import RectifiedFlow

# Fix the random seed via the project's helper so the noise drawn later is reproducible.
utils.set_seed(0)

# Run on the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Load the pretrained DiT velocity network and move it to the selected device.
# "PATH_TO_MODEL" is a placeholder: point it at the actual checkpoint location.
# NOTE(review): the meaning of the positional arguments "dit" and True is not visible
# here; prefer passing them by keyword once the from_pretrained signature is confirmed.
# .eval() makes the inference-only use explicit, so train-time behaviour
# (e.g. dropout, if the model has any) is disabled during sampling.
dit = DiT.from_pretrained("PATH_TO_MODEL", "dit", True).to(device).eval()

In [None]:
# Wrap the pretrained DiT in a RectifiedFlow object for 3x32x32 inputs on `device`.
rf_func = RectifiedFlow(
    model=dit,
    data_shape=(3, 32, 32),
    device=device,
)

In [None]:
from rectified_flow.samplers import rf_samplers_dict

# Sample count and step count are defined once so the noise batch and the
# sampler configuration cannot drift apart.
NUM_SAMPLES = 130
NUM_STEPS = 100

# Initial Gaussian noise; the trailing (3, 32, 32) must match the data_shape
# given to RectifiedFlow above.
X_0 = torch.randn(NUM_SAMPLES, 3, 32, 32, device=device)

euler_sampler = rf_samplers_dict["euler"](
    rectified_flow=rf_func,
    num_steps=NUM_STEPS,
    num_samples=NUM_SAMPLES,
)

euler_sampler.sample_loop(X_0=X_0)

In [None]:
traj = euler_sampler.trajectories
# NOTE(review): assumes the last stored entry is the sampler's final state — confirm.
X_1 = traj[-1]

# Quick sanity output: number of stored states, then the shape of the last one.
print(len(traj))
print(X_1.shape)

utils.plot_cifar_results(X_1)

In [None]:
from scipy.integrate import solve_ivp

@torch.inference_mode()
def rk45(f, z0, startT=0., endT=1.0, rtol=1e-3, atol=1e-3, device=None):
    """Integrate dz/dt = f(z, t) from ``startT`` to ``endT`` with SciPy's adaptive RK45.

    ``solve_ivp`` works on a flat NumPy vector, so on every right-hand-side
    evaluation the state is reshaped back to ``z0.shape``. The scalar time is
    broadcast to a length-``B`` tensor (``B = z0.shape[0]``) before calling ``f``.
    The number of function evaluations is printed once integration finishes.

    Args:
        f: Callable ``f(z, t)`` returning dz/dt with the same number of elements
            as ``z``. ``z`` has shape ``z0.shape`` (float32); ``t`` has shape ``(B,)``.
        z0: Initial state tensor; its first dimension is treated as the batch.
        startT: Integration start time.
        endT: Integration end time.
        rtol: Relative tolerance forwarded to ``solve_ivp``.
        atol: Absolute tolerance forwarded to ``solve_ivp``.
        device: Device on which ``f`` is evaluated and the result is placed.
            Defaults to ``z0.device``.

    Returns:
        float32 tensor of shape ``z0.shape`` holding the state at ``endT``.
        If the solver reports failure, a RuntimeWarning is emitted and the
        last state it reached is returned.
    """
    import warnings

    if device is None:
        device = z0.device
    shape = tuple(z0.shape)
    batch_size = shape[0]

    def f_np(t, z):
        # solve_ivp hands us a flat float64 vector and a scalar time.
        z_tensor = torch.tensor(z, dtype=torch.float32, device=device).reshape(shape)
        t_tensor = torch.full((batch_size,), float(t), device=device)
        dz_tensor = f(z_tensor, t_tensor)  # DiT expects X_t of (B, C, H, W), and t of (B,)
        return dz_tensor.detach().cpu().numpy().reshape(-1)

    z0_np = z0.detach().cpu().numpy().reshape(-1)
    sol = solve_ivp(f_np, (startT, endT), z0_np, method='RK45', t_eval=None, rtol=rtol, atol=atol)
    if not sol.success:
        # Without this, a failed integration would be indistinguishable from a result.
        warnings.warn(f"RK45 integration did not succeed: {sol.message}", RuntimeWarning)

    print(f"Number of function evaluations: {sol.nfev}")
    z_final_np = sol.y[:, -1]  # state at the last time point reached
    return torch.tensor(z_final_np, dtype=torch.float32, device=device).reshape(shape)

# Re-run sampling with adaptive RK45, starting from a copy of the sampler's stored X_0
# (presumably the same noise passed to sample_loop above) for a like-for-like comparison.
# NOTE(review): this integrates the raw `dit` rather than going through `rf_func`;
# confirm the DiT's (X_t, t) time convention matches what the Euler sampler used.
X_0 = euler_sampler.X_0.clone()
X_1 = rk45(f=dit, z0=X_0)
utils.plot_cifar_results(X_1)