In [None]:
import torch
import math
import matplotlib.pyplot as plt
import sys
import os

from PIL import Image
from diffusers import FluxPipeline
from torch import Tensor
from torchvision import transforms

# Make the parent of the current working directory importable so that
# project-local packages imported in later cells can be found.
REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if REPO_ROOT not in sys.path:
    # Guarded so that re-running this cell does not stack duplicate entries.
    sys.path.append(REPO_ROOT)

# Precision and device shared by every model and tensor created below.
DTYPE = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Location of the Flux-dev weights. Defaults to the path this notebook was
# originally written against; set the FLUX_MODEL_PATH environment variable to
# run on another machine without editing the notebook.
FLUX_MODEL_PATH = os.environ.get("FLUX_MODEL_PATH", "/root/autodl-tmp/Flux-dev")

# Load the pipeline in DTYPE and move it onto the selected device.
pipe = FluxPipeline.from_pretrained(FLUX_MODEL_PATH, torch_dtype=DTYPE)
pipe.to(device)

In [None]:
from models.flux_dev import FluxWrapper, encode_imgs, decode_imgs, get_time_grid, prepare_packed_latents

# Requested image size in pixels, rounded down to the nearest multiple of 16.
height, width = 1024, 1024
height -= height % 16
width -= width % 16

# prepare_packed_latents returns two values for a single-sample batch; they are
# kept as the packed latents and their latent image ids.
packed_latents, latent_img_ids = prepare_packed_latents(
    batch_size=1,
    height=height,
    width=width,
    dtype=DTYPE,
    device=device,
)

# The time grid is parameterised by the size of dimension 1 of the latents.
img_seq_len = packed_latents.shape[1]

# Build a 50-step grid and map every value t to 1 - t.
# NOTE(review): presumably this converts the grid from Flux's time direction
# to the one the rectified-flow sampler expects -- confirm.
time_grid = [
    1.0 - t
    for t in get_time_grid(num_steps=50, image_seq_len=img_seq_len)
]

# Wrap the loaded pipeline together with the latent image ids so it can be
# handed to the rectified flow as its model.
flux_model = FluxWrapper(
    pipeline=pipe,
    latent_image_ids=latent_img_ids,
    DTYPE=DTYPE,
    device=device,
)

In [None]:
from rectified_flow.rectified_flow import RectifiedFlow

# The flow's data_shape is taken directly from the packed latents, and the
# wrapped Flux model is used as its model with the "straight" interpolation.
latent_shape = packed_latents.shape

rf_func = RectifiedFlow(
    data_shape=latent_shape,
    model=flux_model,
    interp="straight",
    device=device,
    dtype=DTYPE,
)

In [None]:
from rectified_flow.samplers import EulerSampler

# Generation settings forwarded to the sampler as extra keyword arguments.
PROMPT = "A photo of a cat holding a camera"
GUIDANCE_SCALE = 3.5

# NOTE(review): the sampler is constructed with num_steps=100, while the
# time_grid passed to sample_loop below was generated with num_steps=50.
# Confirm which of the two the sampler actually uses.
euler_sampler = EulerSampler(rectified_flow=rf_func, num_steps=100)

# Run sampling starting from the packed latents along the prepared time grid.
euler_sampler.sample_loop(
    X_0=packed_latents,
    time_grid=time_grid,
    prompt=PROMPT,
    guidance_scale=GUIDANCE_SCALE,
)