In [1]:
import os

# Move the working directory one level up (presumably to the project root, so that
# relative paths such as "custom_data/..." and the `helpers` package resolve).
# NOTE: the target is relative to the current directory, so re-running this cell
# moves up one more level each time -- run it only once per kernel session.
os.chdir("..")

In [2]:
import io
import random
import shutil

import torch
from diffusers import FluxKontextPipeline
from PIL import Image
import numpy as np
import cv2
from rembg import remove

import helpers.drawing as drawing

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [3]:
# !export HF_HUB_CACHE=./cache

## Note
Here, we start with character images and add a background to each of them. The result is the target image.
Then we remove the character from the target image. This gives the input base image.
Finally, we place the character on a canvas with a white background. This gives the reference image.

## Normal FluxKontext Generation

In [5]:
# Load the FLUX.1 Kontext (dev) pipeline in bfloat16 and move it to the GPU.
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")


Loading pipeline components...:   0%|                                                 | 0/7 [00:00<?, ?it/s]
Loading checkpoint shards:   0%|                                                      | 0/3 [00:00<?, ?it/s][A
Loading checkpoint shards: 100%|██████████████████████████████████████████████| 3/3 [00:00<00:00, 29.36it/s][A
Loading pipeline components...:  29%|███████████▋                             | 2/7 [00:00<00:00,  5.39it/s]
Loading checkpoint shards:   0%|                                                      | 0/2 [00:00<?, ?it/s][A
Loading checkpoint shards:  50%|███████████████████████                       | 1/2 [00:00<00:00,  6.99it/s][A
Loading checkpoint shards: 100%|██████████████████████████████████████████████| 2/2 [00:00<00:00,  7.01it/s][A
Loading pipeline components...:  57%|███████████████████████▍                 | 4/7 [00:00<00:00,  5.46it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline componen

## Extract the subject/character

In [6]:

def generate_image(base_scene_path, save_path, prompt = "Add character to the image.", width=None, height=None, seed=None):
    """Edit a base image with the module-level ``pipe`` and save the result.

    Args:
        base_scene_path: Path of the base image, or an already loaded image
            object (anything exposing ``convert("RGB")``, e.g. a PIL image).
        save_path: Location the generated image is written to.
        prompt: Edit instruction passed to the pipeline.
        width: Output width in pixels. Defaults to the base image's width.
        height: Output height in pixels. Defaults to the base image's height.
        seed: Optional integer seed. When given, a ``torch.Generator`` seeded
            with it is passed to the pipeline; when ``None`` (default) no
            generator is passed, as before.

    Returns:
        The first image returned by the pipeline.
    """
    if isinstance(base_scene_path, str):
        base_scene = Image.open(base_scene_path).convert("RGB")
    else:
        base_scene = base_scene_path.convert("RGB")

    # Resolve each dimension on its own, so passing only one of width/height
    # neither leaves the other as None nor discards the value that was given.
    base_width, base_height = base_scene.size
    if width is None:
        width = base_width
    if height is None:
        height = base_height

    # Only forward a generator when a seed was requested; the default call is
    # therefore identical to the unseeded call used so far.
    extra_kwargs = {}
    if seed is not None:
        extra_kwargs["generator"] = torch.Generator().manual_seed(seed)

    result_img_base = pipe(
        prompt=prompt,
        image=base_scene,
        num_inference_steps=28,
        height=height,
        width=width,
        _auto_resize=False,
        max_area=width*height,
        **extra_kwargs,
    ).images[0]
    result_img_base.save(save_path)
    return result_img_base

100%|███████████████████████████████████████████████████████████████████████| 14/14 [00:17<00:00,  1.23s/it]


In [3]:
def remove_background(
    image: Image.Image,
    cropped: bool = True,
    padding: int = 10
) -> Image.Image:
    """Strip the background from ``image`` and optionally trim to the subject.

    Args:
        image: Input image; it is PNG-encoded and handed to rembg's ``remove``.
        cropped: When True, crop the result to the bounding box of its alpha
            channel (as reported by ``getbbox``), enlarged by ``padding``.
        padding: Pixels added on every side of the bounding box, clamped to
            the image borders.

    Returns:
        The background-removed image in RGBA mode, cropped if requested and a
        bounding box was found.
    """
    # `remove` is fed the PNG-encoded bytes of the input image.
    png_stream = io.BytesIO()
    image.save(png_stream, format="PNG")
    cutout_bytes = remove(png_stream.getvalue())
    cutout = Image.open(io.BytesIO(cutout_bytes)).convert("RGBA")

    if not cropped:
        return cutout

    # The alpha band is the fourth channel of the RGBA result.
    _, _, _, alpha_band = cutout.split()
    content_box = alpha_band.getbbox()
    if not content_box:
        # No bounding box was reported: return the image uncropped.
        return cutout

    x0, y0, x1, y1 = content_box
    canvas_w, canvas_h = cutout.size
    padded_box = (
        max(0, x0 - padding),
        max(0, y0 - padding),
        min(canvas_w, x1 + padding),
        min(canvas_h, y1 + padding),
    )
    return cutout.crop(padded_box)


In [None]:
# Load the character prompts from the prompt file.
prompt_src = "custom_data/character-spatial/character.txt"

with open(prompt_src, "r") as f:
    prompts = f.readlines()

# Join the stripped lines with spaces, then cut the text at every "---" marker;
# each resulting piece (stripped again) is one prompt.
joined_text = " ".join(line.strip() for line in prompts)
all_prompts = [segment.strip() for segment in joined_text.split("---")]

print(len(all_prompts))

144 105


In [None]:
src = "custom_data/character-spatial"

end_dir = os.path.join(src, "end")
image_extensions = (".png", ".jpg", ".jpeg", ".webp")

# Create the end, character, reference and start folders if they do not exist.
# This must happen before listing end_dir, otherwise os.listdir raises
# FileNotFoundError on a fresh dataset.
for subfolder in ("end", "character", "reference", "start"):
    os.makedirs(os.path.join(src, subfolder), exist_ok=True)

# Collect every image file in the end folder, in sorted order.
image_paths = sorted(
    os.path.join(end_dir, fname)
    for fname in os.listdir(end_dir)
    if fname.lower().endswith(image_extensions)
)

len(image_paths)

45

In [None]:

# Do this only once. Ignore if your dataset is already named properly
# Rename all images in end_dir as 000.<ext>, 001.<ext>, etc.
rename_plan = []
for idx, old_path in enumerate(sorted(image_paths)):
    ext = os.path.splitext(old_path)[1].lower()
    new_name = f"{idx:03d}{ext}"
    new_path = os.path.join(end_dir, new_name)
    if old_path != new_path:
        rename_plan.append((old_path, new_path))

# Phase 1: park every file that needs renaming under a temporary name.
# Renaming straight to the final name could overwrite a file that has not been
# renamed yet (e.g. "-a.png" sorts before an existing "000.png").
staged_moves = []
for idx, (old_path, new_path) in enumerate(rename_plan):
    tmp_path = os.path.join(end_dir, f"__renaming_{idx:06d}.tmp")
    shutil.move(old_path, tmp_path)
    staged_moves.append((tmp_path, new_path))

# Phase 2: move the parked files to their final numbered names.
for tmp_path, new_path in staged_moves:
    shutil.move(tmp_path, new_path)

# Update image_paths to reflect new names
image_paths = [
    os.path.join(end_dir, fname)
    for fname in sorted(os.listdir(end_dir))
    if fname.lower().endswith(image_extensions)
]

In [None]:
# Scratch index; the generation loops in the following cells rebind `i` themselves.
i=0

In [None]:
# Create target from the prompt
end_dir = os.path.join(src, "end")
os.makedirs(end_dir, exist_ok=True)

for i, end_prompt in enumerate(all_prompts):
    # Use the square size for roughly half of the samples, the wide size otherwise.
    if random.random() < 0.5:
        w, h = 768, 768
    else:
        w, h = 1392, 756

    end_save_path = os.path.join(end_dir, f"boy_{i:03d}.png")

    # generate_image expects an image file path or a PIL image as its base, not
    # a directory, so start from a blank canvas of the requested size.
    # NOTE(review): confirm that a white canvas is the intended starting image.
    base_canvas = Image.new("RGB", (w, h), "white")
    end = generate_image(base_canvas, end_save_path, end_prompt, w, h)

In [None]:
# Create character image
for i in range(len(image_paths)):
    image_path = image_paths[i]
    file_name = os.path.splitext(os.path.basename(image_path))[0]
    # "boy_000" -> "boy". Purely numeric names such as "000" carry no subject
    # name, so fall back to the generic word "character".
    names = " ".join(file_name.split("_")[:-1]) or "character"
    end_prompt = f"Extract the {names} from the image, in a white background."

    w = 768
    h = 768

    # Save into the sibling "character" folder. Writing back to the source path
    # would overwrite the target image with the extracted character.
    # The variable names `end_save_path` / `end` are kept because later cells use them.
    character_dir = os.path.join(os.path.dirname(os.path.dirname(image_path)), "character")
    os.makedirs(character_dir, exist_ok=True)
    end_save_path = os.path.join(character_dir, os.path.basename(image_path))
    end = generate_image(image_path, end_save_path, end_prompt, w, h)

In [None]:
# For more editing
# Apply a follow-up edit to the most recent result (`end`) and overwrite the
# file at `end_save_path` with the edited image.
reference_prompt = "rotate the cup"

# Pass the follow-up prompt defined above rather than re-running the previous one.
generate_image(end, end_save_path, reference_prompt)

In [None]:
# Create the start (base) image for every target by removing the character.
start_prompt = "Remove the boy from the image."
for i, image_path in enumerate(image_paths):
    # Build the path from its components: str.replace("end", "start") would also
    # rewrite any other "end" substring in the folder or file name.
    start_dir = os.path.join(os.path.dirname(os.path.dirname(image_path)), "start")
    os.makedirs(start_dir, exist_ok=True)
    start_save_path = os.path.join(start_dir, os.path.basename(image_path))
    start = generate_image(image_path, start_save_path, start_prompt)


In [None]:
# Print the `.size` of the current `end` and `start` images for a quick comparison.
print(end.size, start.size)

In [None]:

# Bring `end` to the size and mode of `start` when they differ
needs_alignment = start.size != end.size or start.mode != end.mode
end_resized = end.resize(start.size).convert(start.mode) if needs_alignment else end

# Pixel data as numpy arrays
start_np = np.array(start)
end_np = np.array(end_resized)

# Absolute per-pixel difference; computed in int16 so the subtraction cannot
# wrap around in uint8, then cast back to uint8
signed_delta = start_np.astype(np.int16) - end_np.astype(np.int16)
diff_np = np.abs(signed_delta).astype(np.uint8)

# Back to a PIL image
diff_img = Image.fromarray(diff_np)

# Display the difference image at one third of its size
preview_size = (diff_img.size[0]//3, diff_img.size[1]//3)
diff_img.resize(preview_size)


In [None]:
import zipfile
import os

def zip_src_folder(src, zip_path):
    """
    Create a zip archive from the folder specified by src.

    Files are stored under their path relative to ``src``. If ``zip_path``
    lies inside ``src``, the archive being written is skipped so that it is
    never added to itself. Entries are written in sorted order.

    Args:
        src (str): Path to the source directory to zip.
        zip_path (str): Path to the output zip file.
    """
    # Resolve once so the archive can be recognised while walking src.
    archive_real_path = os.path.realpath(zip_path)

    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(src):
            # Sorting in place makes os.walk descend in a deterministic order.
            dirs.sort()
            for file in sorted(files):
                file_path = os.path.join(root, file)
                # The output file already exists at this point (opened for
                # writing above); do not pack the half-written archive.
                if os.path.realpath(file_path) == archive_real_path:
                    continue
                # Write file to zip, preserving folder structure relative to src
                arcname = os.path.relpath(file_path, start=src)
                zipf.write(file_path, arcname=arcname)

# Example usage:
# src = "path/to/source_folder"
# zip_path = "output.zip"
# zip_src_folder(src, zip_path)



In [None]:
# Index into `image_paths` of the sample to preview; the next cell reads `image_paths[i]`.
i = 0

In [None]:
# Preview one sample as a side-by-side strip: reference | start | end.
image_path = image_paths[i]

# Build sibling paths from path components: str.replace("end", ...) would also
# rewrite any other "end" substring in the folder or file name.
sample_name = os.path.basename(image_path)
dataset_root = os.path.dirname(os.path.dirname(image_path))
start_path = os.path.join(dataset_root, "start", sample_name)
character_path = os.path.join(dataset_root, "character", sample_name)

reference = Image.open(character_path)
start = Image.open(start_path)
end = Image.open(image_path)

# Downscale every panel by the same integer factor.
s = 2
reference = reference.resize((reference.width//s, reference.height//s))
start = start.resize((start.width//s, start.height//s))
end = end.resize((end.width//s, end.height//s))

# Size the canvas to the tallest panel so none of them gets cut off when the
# three images differ in height.
canvas_width = reference.width + start.width + end.width
canvas_height = max(reference.height, start.height, end.height)
merged_image = Image.new("RGB", (canvas_width, canvas_height))
merged_image.paste(reference, (0, 0))
merged_image.paste(start, (reference.width, 0))
merged_image.paste(end, (reference.width + start.width, 0))

merged_image

NameError: name 'image_paths' is not defined