In [None]:
import torch
import math
import matplotlib.pyplot as plt
import sys
import os

from PIL import Image
from diffusers import FluxPipeline
from torch import Tensor
from torchvision import transforms
from IPython.display import clear_output

# Global precision and device shared by every cell below.
DTYPE = torch.bfloat16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fetch the FLUX.1-dev pipeline weights in DTYPE and place the whole pipeline on `device`.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=DTYPE,
)
pipe.to(device)
print(f"Loaded pipeline to device {device}")

In [None]:
from rectified_flow.models.flux_dev import FluxWrapper
from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.utils import set_seed

set_seed(0)

# Output resolution the FLUX wrapper (and hence its latent shape) is configured for.
height = 1024
width = 1024

# Adapt the diffusers pipeline so it can act as the velocity field of a rectified flow.
flux_model = FluxWrapper(
    pipeline=pipe,
    height=height,
    width=width,
    dtype=DTYPE,
    device=device,
)

# Rectified flow over the wrapper's latent space with the "straight" interpolation.
# (Indentation normalised to spaces; the `dtype` line previously used a tab.)
rectified_flow = RectifiedFlow(
    data_shape=flux_model.dit_latent_shape,
    velocity_field=flux_model,
    interp="straight",
    device=device,
    dtype=DTYPE,
)

In [None]:
# Load the source image; force 3-channel RGB in case the PNG carries an alpha channel.
img_pil = Image.open("./example/cat.png").convert("RGB")

# Resize so the shorter side covers the target resolution, then center-crop to exactly
# (height, width) -- the resolution flux_model was configured with above.
train_transforms = transforms.Compose(
    [
        transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop((height, width)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # maps ToTensor's [0, 1] range to [-1, 1]
    ]
)

# Add a batch dimension and move to the pipeline's device/dtype. This used to be hard-coded
# to "cuda", which broke the CPU fallback selected by `device`.
img = train_transforms(img_pil).unsqueeze(0).to(device=device, dtype=DTYPE)

# Encode/decode round trip as a visual sanity check of the latent representation.
img_latent = flux_model.encode(img)
img_rec = flux_model.decode(img_latent)

plt.imshow(img_rec)

# Simple noisy-interpolation editing: partially noise the image latent, then denoise it with a new prompt

In [None]:
# Start from a copy of the clean image latent (the t = 1 end of the path).
x_1 = img_latent.clone()

# One sample from the source (noise) distribution (the t = 0 end of the path).
noise = rectified_flow.sample_source_distribution(1)

# Jump to an intermediate time on the straight path x_t = t * x_1 + (1 - t) * noise.
noisy_time = 0.4
# Build t on the latent's device: a CPU tensor cannot be combined with a CUDA latent of this shape.
t = torch.full((1,), noisy_time, device=x_1.device)
t = rectified_flow.match_dim_with_data(t, x_1.shape, expand_dim=True)

# Interpolate in full precision, then cast back so x_t keeps the latent's dtype
# (mixing with a float32 t would otherwise promote it away from DTYPE).
x_t = (t * x_1 + (1 - t) * noise).to(x_1.dtype)

# Integration times from noisy_time up to 1, used by the Euler sampler below.
time_grid = torch.linspace(noisy_time, 1, 50).tolist()

x_t_rec = flux_model.decode(x_t)
plt.imshow(x_t_rec)

In [None]:
from rectified_flow.samplers import EulerSampler

# Plain Euler integration from the noised latent x_t (at t = noisy_time) to t = 1,
# conditioned on the new prompt.
sampler = EulerSampler(
    rectified_flow=rectified_flow,
    time_grid=time_grid,
)

prompt = "a photo of a sitting tiger"

# Store the result under its own name: reassigning x_1 here would replace the clean image
# latent with the edited one, which silently corrupts any later cell that reads x_1.
x_1_edit = sampler.sample_loop(x_0=x_t, prompt=prompt, guidance_scale=3.5).trajectories[-1]

x_1_edit_rec = flux_model.decode(x_1_edit)
plt.imshow(x_1_edit_rec)

# Soft interpolation editing: blend a velocity toward the original image with the prompt-conditioned velocity

In [None]:
from rectified_flow.samplers import EulerSampler

import math

class EditingEulerSampler(EulerSampler):
    """Euler sampler whose velocity blends an analytic target velocity with the model velocity.

    At each step the velocity is

        eta(t) * analytic_velocity(x_t, t) + (1 - eta(t)) * rectified_flow.get_velocity(x_t, t, **kwargs)

    where eta(t) is nonzero only on [start_t, end_t] and is shaped by eta_trend
    (see get_eta_value).

    Args:
        start_t: Start of the blending window, in [0, 1).
        end_t: End of the blending window, in (start_t, 1].
        eta_base: Blend weight in [0, 1] at t == start_t.
        eta_trend: One of 'constant', 'linear_decrease', 'exponential_decrease'.
        analytic_velocity: Callable (x_t, t) -> velocity used as the blending target.
        alpha: Decay rate for 'exponential_decrease'; None selects DEFAULT_ALPHA.
        **kwargs: Forwarded to EulerSampler (e.g. rectified_flow, time_grid, num_samples).
    """

    # Default decay rate for 'exponential_decrease'; matches get_eta_value's default.
    DEFAULT_ALPHA = 2

    def __init__(self, start_t, end_t, eta_base, eta_trend, analytic_velocity, alpha=None, **kwargs):
        super().__init__(**kwargs)
        self.start_t = start_t
        self.end_t = end_t
        self.eta_base = eta_base
        self.eta_trend = eta_trend
        # Resolve alpha=None here: forwarding None verbatim made abs(alpha) raise TypeError
        # for 'exponential_decrease'.
        self.alpha = self.DEFAULT_ALPHA if alpha is None else alpha
        self.analytic_velocity = analytic_velocity

    @staticmethod
    def get_eta_value(t, start_t, end_t, eta, eta_trend, alpha=2):
        """Return the blend weight eta(t) for a scalar time t.

        Outside [start_t, end_t] the weight is 0. Inside, with tau = (t - start_t) / (end_t - start_t):
          - 'constant':             eta
          - 'linear_decrease':      eta * (1 - tau)
          - 'exponential_decrease': eta * (exp(-alpha * tau) - exp(-alpha)) / (1 - exp(-alpha)),
                                    using the linear form when |alpha| < 1e-5 (its alpha -> 0 limit).
        alpha=None is treated as DEFAULT_ALPHA.

        Raises:
            AssertionError: if the window or eta lies outside [0, 1].
            NotImplementedError: for an unknown eta_trend (checked for every t, not only inside the window).
        """
        assert 0 <= start_t < end_t <= 1.0, "start_t and end_t must be in [0, 1] range, and start_t < end_t"
        assert 0 <= eta <= 1.0, "eta must be in [0, 1] range"
        if eta_trend not in ('constant', 'linear_decrease', 'exponential_decrease'):
            raise NotImplementedError(f"Unsupported eta_trend: {eta_trend}")
        if t < start_t or t > end_t:
            return 0.

        tau = (t - start_t) / (end_t - start_t)
        if eta_trend == 'constant':
            return eta
        if eta_trend == 'linear_decrease':
            return eta * (1 - tau)

        # 'exponential_decrease'
        if alpha is None:
            alpha = EditingEulerSampler.DEFAULT_ALPHA
        if abs(alpha) < 1e-5:
            # The exponential profile tends to the linear one as alpha -> 0; avoids 0/0.
            return eta * (1 - tau)
        numerator = math.exp(-alpha * tau) - math.exp(-alpha)
        denominator = 1 - math.exp(-alpha)
        return eta * (numerator / denominator)

    def get_velocity(self, **kwargs):
        """Blended velocity at the current sampler state (self.x_t, self.t).

        kwargs (e.g. prompt, guidance_scale) are passed only to the model velocity;
        the analytic velocity is called as analytic_velocity(x_t, t).
        """
        x_t = self.x_t

        # eta is one scalar for the whole batch, so compute it from the scalar step time.
        # Feeding the batch-expanded t tensor made `t < start_t` ambiguous for batch size > 1.
        # NOTE(review): assumes self.t is a Python float or 1-element tensor
        # (the time grids in this notebook are Python lists of floats).
        eta = self.get_eta_value(
            float(self.t), self.start_t, self.end_t, self.eta_base, self.eta_trend, self.alpha
        )

        t = self.rectified_flow.match_dim_with_data(self.t, x_t.shape, expand_dim=False)
        original_velocity = self.rectified_flow.get_velocity(x_t, t, **kwargs)
        if eta == 0:
            # Outside the blending window: skip the analytic term entirely
            # (saves work and avoids 0 * inf -> NaN if it is non-finite).
            return original_velocity

        target_velocity = self.analytic_velocity(x_t, t)
        return eta * target_velocity + (1 - eta) * original_velocity

In [None]:
from rectified_flow.models.gauss_analytic import AnalyticGaussianVelocity

# Analytic velocity whose "dataset" is the single clean image latent, used as the soft-editing target.
# Reference img_latent directly rather than x_1: sampling cells may overwrite x_1 with edited latents,
# which would make the target the edited image instead of the original one.
analytic_velocity = AnalyticGaussianVelocity(dataset=img_latent, interp=rectified_flow.interp)

In [None]:
# Soft editing starts from fresh noise (seed=0), so integrate over the full [0, 1] interval.
# Reusing the [noisy_time, 1] grid would treat pure noise as a t=0.4 state, and the
# start_t=0 blending window could never be reached.
soft_time_grid = torch.linspace(0, 1, 50).tolist()

sampler = EditingEulerSampler(
    rectified_flow=rectified_flow,
    time_grid=soft_time_grid,
    num_samples=1,
    start_t=0.,
    end_t=0.7,
    eta_base=0.95,
    eta_trend='exponential_decrease',
    alpha=2,  # explicit decay rate; leaving it unset passed None into the exponential schedule
    analytic_velocity=analytic_velocity,  # required constructor argument
)

prompt = "a photo of a sitting tiger"

# Separate output name so the clean latent x_1 is not clobbered.
x_1_soft = sampler.sample_loop(seed=0, prompt=prompt, guidance_scale=3.5).trajectories[-1]

x_1_soft_rec = flux_model.decode(x_1_soft)
plt.imshow(x_1_soft_rec)