In [None]:
import os

# This notebook is expected to start one directory below the repository root,
# while later cells import `omini` and open `assets/...` relative to the root.
# Step up only when we are not already at the root, so re-running this cell
# is a no-op instead of walking further up the filesystem each time.
if not os.path.isdir("omini"):
    os.chdir("..")

In [None]:
import torch
from diffusers.pipelines import FluxPipeline
from PIL import Image

from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition

In [None]:
# Base model: FLUX.1-dev loaded in bfloat16 and moved onto the CUDA device.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

In [None]:
# Condition types that have an experimental OminiControl LoRA. Each one is
# loaded as a named adapter, then all of them are activated together. Keeping
# the list in one place stops the load loop and `set_adapters` from drifting.
CONDITION_TYPES = ["canny", "depth", "coloring", "deblurring"]

# Remove any previously loaded LoRAs first so that re-running this cell does
# not try to register adapters whose names are already in use.
pipe.unload_lora_weights()

for condition_type in CONDITION_TYPES:
    pipe.load_lora_weights(
        "Yuanshi/OminiControl",
        weight_name=f"experimental/{condition_type}.safetensors",
        adapter_name=condition_type,
    )

pipe.set_adapters(CONDITION_TYPES)

In [None]:
# Demo input: load the photo as RGB, take the largest centered square, and
# downscale it to 512x512.
source = Image.open("assets/coffee.png").convert("RGB")
width, height = source.size
side = min(width, height)
left, top = (width - side) // 2, (height - side) // 2
right, bottom = (width + side) // 2, (height + side) // 2
image = source.crop((left, top, right, bottom)).resize((512, 512))

prompt = "In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table."

In [None]:
# Canny conditioning: build the "canny" condition image from the photo
# (presumably an edge map) and generate from it.
canny_image = convert_to_condition("canny", image)
condition = Condition(canny_image, "canny")

# Re-seed before every generation so each condition type starts from the same
# RNG state.
seed_everything()

result_img = generate(
    pipe,
    prompt=prompt,
    conditions=[condition],
).images[0]

# Side-by-side comparison: source | condition | result. The canvas size and
# paste offsets come from the panels themselves instead of assuming 512x512
# tiles, so a different input or output resolution is not clipped.
panels = [image, condition.condition, result_img]
concat_image = Image.new(
    "RGB", (sum(p.width for p in panels), max(p.height for p in panels))
)
x_offset = 0
for panel in panels:
    concat_image.paste(panel, (x_offset, 0))
    x_offset += panel.width
concat_image

In [None]:
# Depth conditioning: build the "depth" condition image from the photo
# (presumably an estimated depth map) and generate from it.
depth_image = convert_to_condition("depth", image)
condition = Condition(depth_image, "depth")

# Re-seed before every generation so each condition type starts from the same
# RNG state.
seed_everything()

result_img = generate(
    pipe,
    prompt=prompt,
    conditions=[condition],
).images[0]

# Side-by-side comparison: source | condition | result. The canvas size and
# paste offsets come from the panels themselves instead of assuming 512x512
# tiles, so a different input or output resolution is not clipped.
panels = [image, condition.condition, result_img]
concat_image = Image.new(
    "RGB", (sum(p.width for p in panels), max(p.height for p in panels))
)
x_offset = 0
for panel in panels:
    concat_image.paste(panel, (x_offset, 0))
    x_offset += panel.width
concat_image

In [None]:
# Deblurring: build the "deblurring" condition image from the photo
# (presumably a blurred copy) and generate a sharp image from it.
blur_image = convert_to_condition("deblurring", image)
condition = Condition(blur_image, "deblurring")

# Re-seed before every generation so each condition type starts from the same
# RNG state.
seed_everything()

result_img = generate(
    pipe,
    prompt=prompt,
    conditions=[condition],
).images[0]

# Side-by-side comparison: source | condition | result. The canvas size and
# paste offsets come from the panels themselves instead of assuming 512x512
# tiles, so a different input or output resolution is not clipped.
panels = [image, condition.condition, result_img]
concat_image = Image.new(
    "RGB", (sum(p.width for p in panels), max(p.height for p in panels))
)
x_offset = 0
for panel in panels:
    concat_image.paste(panel, (x_offset, 0))
    x_offset += panel.width
concat_image

In [None]:
# Coloring: build the "coloring" condition image from the photo (presumably a
# grayscale version) and generate a colorized image from it.
condition_image = convert_to_condition("coloring", image)
condition = Condition(condition_image, "coloring")

# Re-seed before every generation so each condition type starts from the same
# RNG state.
seed_everything()

result_img = generate(
    pipe,
    prompt=prompt,
    conditions=[condition],
).images[0]

# Side-by-side comparison: source | condition | result. The canvas size and
# paste offsets come from the panels themselves instead of assuming 512x512
# tiles, so a different input or output resolution is not clipped.
panels = [image, condition.condition, result_img]
concat_image = Image.new(
    "RGB", (sum(p.width for p in panels), max(p.height for p in panels))
)
x_offset = 0
for panel in panels:
    concat_image.paste(panel, (x_offset, 0))
    x_offset += panel.width
concat_image