<a href="https://colab.research.google.com/github/lqiang67/rectified-flow/blob/main/examples/inference_flux_dev_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://lqiang67: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@github.com/lqiang67/rectified-flow.git
%cd rectified-flow/

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# Authenticate with the Hugging Face Hub so the FLUX.1-dev weights loaded below can be
# downloaded. The token comes from the HF_TOKEN environment variable, or is typed in
# interactively; it is never hardcoded or kept in a notebook variable.
login(token=os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: "))

In [None]:
import torch
import math
import matplotlib.pyplot as plt
import sys
import os

from PIL import Image
from diffusers import FluxPipeline
from torch import Tensor
from torchvision import transforms
from IPython.display import clear_output

# Precision used for every FLUX component, and the device to run on:
# the default CUDA device when PyTorch can see one, otherwise the CPU.
DTYPE = torch.bfloat16
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the FLUX.1-dev diffusers pipeline in DTYPE, then move it onto `device`.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=DTYPE,
)
pipe.to(device)
print(f"Loaded pipeline to device {device}")

In [None]:
from rectified_flow.models.flux_dev import FluxWrapper
from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.utils import set_seed

# Seed the RNGs via the project's helper so the noise sampled below is reproducible.
set_seed(0)

# Output image resolution in pixels. FLUX's VAE downsamples by 8 and its transformer
# packs latents into 2x2 patches, so both sides should be multiples of 16.
height = 1024
width = 1024

# Adapter around the diffusers pipeline; its `dit_latent_shape` defines the shape of
# the samples drawn by `rectified_flow` below.
flux_model = FluxWrapper(
    pipeline=pipe,
    height=height,
    width=width,
    dtype=DTYPE,
    device=device,
)

# Rectified-flow wrapper using the straight interpolation path; all samplers in this
# notebook are built on top of it.
rectified_flow = RectifiedFlow(
    data_shape=flux_model.dit_latent_shape,
    model=flux_model,
    interp="straight",
    device=device,
    dtype=DTYPE,
)

In [None]:
# Sample and store sampling info for several following samplers
# One initial noise sample and one 50-step time grid, shared by every sampler below
# so their results start from identical conditions and can be compared directly.
x_0 = rectified_flow.sample_source_distribution(batch_size=1)
time_grid = flux_model.prepare_time_grid(num_steps=50)

print(f"x_0: {x_0.shape}")
print(f"time_grid: {time_grid}")


def print_time_callback(sampler):
    """Display the sampler's current time `t`, replacing this cell's previous output.

    Demo callback: samplers accept a list of such callables and (presumably) call each
    one with the sampler itself during sampling. Any other per-step hook that reads
    sampler state could be plugged in the same way.
    """
    # wait=True defers clearing until new output arrives, which avoids flicker.
    clear_output(wait=True)
    print(f"Current time: {sampler.t:.4f}")


my_callback = [print_time_callback]

In [None]:
from rectified_flow.samplers import rf_samplers_dict

# Show each registered sampler's short name next to the class that implements it.
registered_samplers = rf_samplers_dict.items()
for name, sampler_cls in registered_samplers:
    print(f"Sampler: {name}, class: {sampler_cls}")

## Euler Sampler

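Starting from the noise sample $X_0$, each step moves the state along the model's predicted velocity for one interval $\Delta t$ of the time grid:

$$
X_{t+\Delta t} = X_{t} + \Delta t \cdot v_{\theta}(X_t, t)
$$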

In [None]:
from rectified_flow.samplers import EulerSampler

# Euler integration of the learned velocity field over `time_grid`,
# starting from the shared noise sample `x_0`.
euler_sampler = EulerSampler(
    rectified_flow=rectified_flow,
    callbacks=my_callback,
)

euler_sampler.sample_loop(
    x_0=x_0,
    time_grid=time_grid,
    prompt="A photo of a cat holding a camera",
    guidance_scale=3.5,
)

# The last entry of the recorded trajectory is the final latent.
x_1 = euler_sampler.trajectories[-1]
print(x_1.shape)

# Decode the latent to an image and show it as a standalone, titled figure
# (plt.show() also suppresses the AxesImage repr that imshow would leave as output).
img = flux_model.decode(x_1)
plt.imshow(img)
plt.axis("off")
plt.title("Euler sampler")
plt.show()

## SDE Sampler

In [None]:
from rectified_flow.samplers import SDESampler

# Stochastic sampler: integrates the flow while injecting noise. The noise_* and
# ode_method arguments configure how much noise is added and how it is combined
# with the ODE step — see the SDESampler docs for their exact semantics.
sde_sampler = SDESampler(
    rectified_flow=rectified_flow,
    callbacks=my_callback,
    noise_scale=5.0,
    noise_decay_rate=1.0,
    noise_method="stable",
    ode_method="curved",
)

sde_sampler.sample_loop(
    x_0=x_0,
    time_grid=time_grid,
    prompt="A photo of a cat holding a sign, say 'Rectified Flow'",
    guidance_scale=3.5,
)

# The last entry of the recorded trajectory is the final latent.
x_1 = sde_sampler.trajectories[-1]

# Decode the latent to an image and show it as a standalone, titled figure.
img = flux_model.decode(x_1)
plt.imshow(img)
plt.axis("off")
plt.title("SDE sampler")
plt.show()

## Noise Refresh Sampler

In [None]:
from rectified_flow.samplers import NoiseRefreshSampler

# Noise-refresh sampler. `noise_replacement_rate` is a function of time t; here it is
# the constant 0.3 at every step of the grid.
noise_refresh_sampler = NoiseRefreshSampler(
    rectified_flow=rectified_flow,
    callbacks=my_callback,
    noise_replacement_rate=lambda t: 0.3,
)

noise_refresh_sampler.sample_loop(
    x_0=x_0,
    time_grid=time_grid,
    prompt="A photo of a cat holding a camera",
    guidance_scale=3.5,
)

# The last entry of the recorded trajectory is the final latent.
x_1 = noise_refresh_sampler.trajectories[-1]
print(x_1.shape)

# Decode the latent to an image and show it as a standalone, titled figure.
img = flux_model.decode(x_1)
plt.imshow(img)
plt.axis("off")
plt.title("Noise refresh sampler")
plt.show()

In [None]:
from rectified_flow.samplers import OverShootingSampler

# --- OverShooting Sampler ---
# Built with the library's default settings; only the shared callbacks are passed.
overshoot_sampler = OverShootingSampler(
    rectified_flow=rectified_flow,
    callbacks=my_callback,
)

overshoot_sampler.sample_loop(
    x_0=x_0,
    time_grid=time_grid,
    prompt="A photo of a cat holding a camera",
    guidance_scale=3.5,
)

# The last entry of the recorded trajectory is the final latent.
x_1 = overshoot_sampler.trajectories[-1]
print(x_1.shape)

# Decode the latent to an image and show it as a standalone, titled figure.
img = flux_model.decode(x_1)
plt.imshow(img)
plt.axis("off")
plt.title("OverShooting sampler")
plt.show()