<a href="https://colab.research.google.com/github/lqiang67/rectified-flow/blob/main/examples/inference_flux_dev_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://lqiang67: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@github.com/lqiang67/rectified-flow.git
%cd rectified-flow/

In [None]:
import torch
import math
import matplotlib.pyplot as plt
import sys
import os

from PIL import Image
from diffusers import FluxPipeline
from torch import Tensor
from torchvision import transforms
from IPython.display import clear_output

DTYPE = torch.bfloat16
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fetch the FLUX.1-dev pipeline with bfloat16 weights, then move it to `device`.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=DTYPE,
)
pipe.to(device)
print(f"Loaded pipeline to device {device}")

In [None]:
from rectified_flow.models.flux_dev import FluxWrapper
from rectified_flow.rectified_flow import RectifiedFlow

# diffusers' FluxPipeline expects height and width to be multiples of 16
# (vae_scale_factor * 2: 8x VAE downsampling followed by 2x2 latent packing).
# An odd size such as 1025 does not satisfy this, so use the pipeline's
# default 1024x1024 and fail fast if someone edits these to an invalid value.
height = 1024
width = 1024
assert height % 16 == 0 and width % 16 == 0, "Flux height/width must be divisible by 16"

# Wrap the loaded pipeline so RectifiedFlow can call it as its `model` and draw
# source samples from it. NOTE(review): `seed` presumably seeds the wrapper's
# source-noise generator — confirm in FluxWrapper.
flux_model = FluxWrapper(
    pipeline=pipe,
    height=height,
    width=width,
    dtype=DTYPE,
    device=device,
    seed=0,
)

# Rectified-flow object using the "straight" interpolation path, with the
# wrapper supplying both the data shape and the source distribution.
rf_func = RectifiedFlow(
    data_shape=flux_model.data_shape,
    model=flux_model,
    interp="straight",
    source_distribution=flux_model.sample_source_distribution,
    device=device,
    dtype=DTYPE,
)

In [None]:
from rectified_flow.samplers import rf_samplers_dict

# Draw one source sample and one time grid up front. Every sampler cell below
# reuses both, so the samplers start from the same noise and the same schedule.
X_0 = rf_func.sample_source_distribution(batch_size=1)
time_grid = flux_model.prepare_time_grid(num_steps=50)
print(f"X_0: {X_0.shape}")
print(f"time_grid: {time_grid}")


def print_time_callback(sampler):
    """Display the sampler's current time, overwriting the previous cell output.

    Demo callback; a richer one could, e.g., render a one-step prediction.

    Args:
        sampler: the sampler invoking the callback; only its ``t`` attribute
            is read.
    """
    clear_output(wait=True)
    t_now = sampler.t
    print(f"Current time: {t_now:.4f}")


# Samplers receive their callbacks as a list.
my_callback = [print_time_callback]

In [None]:
# Sampler registered as "euler", integrated over the shared `time_grid`
# starting from the shared X_0.
euler_sampler = rf_samplers_dict["euler"](
    rectified_flow=rf_func,
    callbacks=my_callback,
)

euler_sampler.sample_loop(
    X_0=X_0,
    time_grid=time_grid,
    prompt="A photo of a cat holding a camera",
    guidance_scale=3.5,
)

# Take the last recorded trajectory state as the generated latent.
X_1 = euler_sampler.trajectories[-1]
print(X_1.shape)

img = flux_model.decode(X_1)

# Show the decoded image with a title and without pixel-index axes;
# plt.show() also stops the cell from echoing the AxesImage repr.
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title("Euler sampler")
ax.axis("off")
plt.show()

In [None]:
# Sampler registered as "noise_refresh", run from the same X_0 and time grid
# as the other sampler cells so the results are directly comparable.
noise_refresh_sampler = rf_samplers_dict["noise_refresh"](
    rectified_flow=rf_func,
    callbacks=my_callback,
)

noise_refresh_sampler.sample_loop(
    X_0=X_0,
    time_grid=time_grid,
    prompt="A photo of a cat holding a camera",
    guidance_scale=3.5,
)

# Take the last recorded trajectory state as the generated latent.
X_1 = noise_refresh_sampler.trajectories[-1]
print(X_1.shape)

img = flux_model.decode(X_1)

# Show the decoded image with a title and without pixel-index axes;
# plt.show() also stops the cell from echoing the AxesImage repr.
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title("Noise-refresh sampler")
ax.axis("off")
plt.show()

In [None]:
# Sampler registered as "overshooting", run from the same X_0 and time grid
# as the other sampler cells so the results are directly comparable.
overshoot_sampler = rf_samplers_dict["overshooting"](
    rectified_flow=rf_func,
    callbacks=my_callback,
)

overshoot_sampler.sample_loop(
    X_0=X_0,
    time_grid=time_grid,
    prompt="A photo of a cat holding a camera",
    guidance_scale=3.5,
)

# Take the last recorded trajectory state as the generated latent.
X_1 = overshoot_sampler.trajectories[-1]
print(X_1.shape)

# Decode with `flux_model.decode`, the same call the Euler and noise-refresh
# cells use. NOTE(review): confirm FluxWrapper does not require a separate
# `unpack_and_decode` step for this sampler's output.
img = flux_model.decode(X_1)

# Show the decoded image with a title and without pixel-index axes;
# plt.show() also stops the cell from echoing the AxesImage repr.
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title("Overshooting sampler")
ax.axis("off")
plt.show()