In [1]:
import os

# Expose only GPU 0 to this process. This is set here, ahead of the `import torch`
# in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import torch
import numpy as np
import pandas as pd
from diffusers.utils import make_image_grid
from transformers import (
    SamModel,
    SamProcessor,
    #     Blip2Processor,
    #     Blip2ForConditionalGeneration,
    AutoProcessor,
    LlavaForConditionalGeneration,
)
from PIL import Image

from src.eunms import Model_Type, Scheduler_Type
from src.utils.enums_utils import get_pipes
from src.config import RunConfig
from main import run as invert

from attention_maps_utils_by_timesteps import (
    get_attn_maps,
    cross_attn_init,
    register_cross_attention_hook,
    set_layer_with_name_and_path,
    preprocess,
    visualize_and_save_attn_map,
)

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = Model_Type.SDXL_Turbo
scheduler_type = Scheduler_Type.EULER
# Inversion / inference pipeline pair, built with is_optimize_z=True.
pipe_inversion, pipe_inference = get_pipes(
    model_type, scheduler_type, device=device, is_optimize_z=True
)
# Second pair built with get_pipes' default options; only the second pipeline is
# kept and it is used later solely to extract cross-attention maps.
# NOTE(review): the first pipeline of this pair is discarded — confirm that
# get_pipes does not leave it resident on the device.
_, pipe_extract_attn_maps = get_pipes(model_type, scheduler_type, device=device)

Keyword arguments {'safety_checker': None} are not expected by StableDiffusionXLImg2ImgOptimizeZPipeline and will be ignored.
  torch.utils._pytree._register_pytree_node(
Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  5.88it/s]
Keyword arguments {'safety_checker': None} are not expected by StableDiffusionXLImg2ImgPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00,  8.35it/s]


In [4]:
# Shared run configuration. Besides being passed to `invert`, it is attached to the
# pipelines as `.cfg` before every generation call further down.
# NOTE(review): the two noise-regularization weights are hand-picked values;
# their meaning is defined by RunConfig, not in this notebook.
config = RunConfig(
    model_type=model_type,
    scheduler_type=scheduler_type,
    noise_regularization_lambda_kl=0.08,
    noise_regularization_lambda_ac=40,
    num_inversion_steps=4,
    num_inference_steps=4,
)

In [5]:
# Benchmark table pairing a background image id with a reference (object) image id.
# Both id columns are read as strings.
data_path = (
    "/home/lab/yairshp/projects/insert_object/benchmark/object_placement_images.csv"
)
data = pd.read_csv(data_path, dtype={"bg_img_id": str, "ref_img_id": str})

# Dataset roots: backgrounds come from OpenImages, reference objects from COCOEE.
open_images_path = "/cortex/data/images/OpenImagesV6/images"
cocoee_path = "/cortex/data/images/COCOEE/test_bench/Ref_3500"

# Background path: <open_images_path>/<image_id>.jpg
data["bg_img_path"] = data["bg_img_id"].map(
    lambda img_id: f"{open_images_path}/{img_id}.jpg"
)

# Reference path: <cocoee_path>/<image_id>_ref.png
data["ref_img_path"] = data["ref_img_id"].map(
    lambda img_id: f"{cocoee_path}/{img_id}_ref.png"
)

In [6]:
# Extra hand-picked samples appended to the benchmark. The two lists are paired by
# position: the i-th object is meant for the i-th background.
_BENCHMARK_DIR = "/home/lab/yairshp/projects/insert_object/benchmark"

my_bg_images = [
    f"{_BENCHMARK_DIR}/{file_name}"
    for file_name in ("bed.jpeg", "desk.jpeg", "cabinet.jpeg", "face.jpg")
]

my_object_images = [
    f"{_BENCHMARK_DIR}/objects/{relative_path}"
    for relative_path in (
        "pillow/pillow.jpeg",
        "plant/plant.jpg",
        "vase/vase.jpeg",
        "hat/hat.png",
    )
]

In [7]:
# Benchmark entries first, then the hand-picked extras, keeping both lists aligned.
bg_images = [*data["bg_img_path"], *my_bg_images]
ref_images = [*data["ref_img_path"], *my_object_images]

In [8]:
def preprocess_image(image_path, size=(512, 512)):
    """Load an image from disk as an RGB image resized to ``size``.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        size: Target ``(width, height)``; defaults to the 512x512 used
            throughout this notebook.

    Returns:
        A new in-memory ``PIL.Image.Image`` in RGB mode.
    """
    # Image.open is lazy and keeps the underlying file open. The context manager
    # closes it as soon as convert() has produced an independent copy, so
    # preprocessing many images does not accumulate open file handles.
    with Image.open(image_path) as image:
        return image.convert("RGB").resize(size)


# Replace the path lists with the loaded images.
# NOTE(review): this rebinds `bg_images` / `ref_images` from paths to images, so the
# cell cannot be re-run without first re-running the cell that builds the path lists.
bg_images = list(map(preprocess_image, bg_images))
ref_images = list(map(preprocess_image, ref_images))

In [9]:
# Vision-language model used to name the reference objects (see get_prompts).
# The BLIP-2 and spaCy lines below are an earlier setup, kept commented out.
# No torch_dtype is passed for LLaVA, so the checkpoint's default dtype is used.
# blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
# blip_model = Blip2ForConditionalGeneration.from_pretrained(
#     "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
# ).to(device)
llava_model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf"
).to(device)
llava_processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
# nlp = spacy.load("en_core_web_sm")

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.06it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
def get_prompts(processor, model, images, object_names=None):
    """Ask a vision-language model one short question per image.

    Args:
        processor: Callable that builds model inputs from ``text`` / ``images``
            and provides ``batch_decode``.
        model: Generative model exposing ``generate`` and ``device``.
        images: Images to query; one prompt is issued per image.
        object_names: Optional object names. When given, the i-th question asks
            what in the i-th image that object could be placed on. Prompts and
            images are paired positionally, and pairing stops at the shorter of
            the two sequences.

    Returns:
        list[str]: For every processed image, the decoded text that follows the
        final ``ASSISTANT:`` marker, stripped of surrounding whitespace.
    """
    assistant_marker = "ASSISTANT:"
    if object_names is None:
        prompt = "USER: <image>\nWhat's in the image (answer in shortest way possible)? ASSISTANT:"
        prompts = [prompt] * len(images)
    else:
        prompts = [
            f"USER: <image>\nWhat's in the image that a {ref} can be on (answer in shortest way possible)? ASSISTANT:"
            for ref in object_names
        ]

    answers = []
    for prompt, image in zip(prompts, images):
        # Move the inputs to wherever the model lives instead of relying on a
        # notebook-level `device` variable.
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(
            model.device
        )
        generate_ids = model.generate(**inputs, max_new_tokens=15)
        decoded = processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        # The decoded text echoes the prompt. Cut at the assistant marker rather
        # than at the last ":" so answers that contain a colon (e.g. "10:30")
        # are not truncated.
        answers.append(decoded.rpartition(assistant_marker)[2].strip())
    return answers


def get_edit_prompts(inversion_prompts, objects):
    """Build one lower-cased edit prompt per (inversion prompt, object) pair.

    Each prompt has the form ``"a <inversion prompt> and a <object>"``. Pairing
    is positional and stops at the shorter of the two inputs.
    """
    return [
        f"a {scene} and a {item}".lower()
        for scene, item in zip(inversion_prompts, objects)
    ]

In [11]:
# Name every reference object with the vision-language model.
object_names = get_prompts(llava_processor, llava_model, ref_images)
# The same generic prompt is used to invert every background image.
inversion_prompts = ["an image"] * len(bg_images)

# Edit prompt: "an image of a/an <object>". The article is picked from the first
# letter so vowel-initial names read "an orange" rather than "a orange".
edit_prompts = []
for name in object_names:
    lowered = name.lower()
    article = "an" if lowered[:1] in ("a", "e", "i", "o", "u") else "a"
    edit_prompts.append(f"an image of {article} {lowered}")

In [12]:
# Show, per sample, the inversion prompt, the detected object, and the edit prompt.
prompt_rows = zip(inversion_prompts, object_names, edit_prompts)
for inversion_prompt, object_name, edit_prompt in prompt_rows:
    print(f"Prompt: {inversion_prompt} + {object_name} -> {edit_prompt}")

Prompt: an image + Glass -> an image of a glass
Prompt: an image + Glass -> an image of a glass
Prompt: an image + Pepsi -> an image of a pepsi
Prompt: an image + Cat -> an image of a cat
Prompt: an image + Vase -> an image of a vase
Prompt: an image + Orange -> an image of a orange
Prompt: an image + Glass -> an image of a glass
Prompt: an image + Plant -> an image of a plant
Prompt: an image + Suitcase -> an image of a suitcase
Prompt: an image + Teddy bear -> an image of a teddy bear
Prompt: an image + Dog -> an image of a dog
Prompt: an image + Glass -> an image of a glass
Prompt: an image + Pillow -> an image of a pillow
Prompt: an image + Plant -> an image of a plant
Prompt: an image + Vase -> an image of a vase
Prompt: an image + Hat -> an image of a hat


In [13]:
def invert_images(images, prompts):
    """Invert every image with its prompt and collect latents and noise lists.

    Calls ``invert`` once per (image, prompt) pair using the notebook-level
    ``config``, ``pipe_inversion`` and ``pipe_inference``, with reconstruction
    disabled. Pairing is positional and stops at the shorter input.

    Returns:
        tuple[list, list]: ``(inv_latents, noises)`` — the second and third
        values returned by ``invert`` for each pair, in input order.
    """
    latents_per_image = []
    noise_per_image = []
    for source_image, source_prompt in zip(images, prompts):
        inversion_result = invert(
            source_image,
            source_prompt,
            config,
            pipe_inversion=pipe_inversion,
            pipe_inference=pipe_inference,
            do_reconstruction=False,
        )
        # invert returns a 4-tuple; only the latent and the noise are needed.
        _, latent, noise, _ = inversion_result
        latents_per_image.append(latent)
        noise_per_image.append(noise)
    return latents_per_image, noise_per_image

In [14]:
# Invert all background images with the generic inversion prompt.
inv_latents, noises = invert_images(bg_images, inversion_prompts)

Inverting...


100%|██████████| 4/4 [00:05<00:00,  1.36s/it]


Inverting...


100%|██████████| 4/4 [00:05<00:00,  1.28s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.22s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.22s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.22s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.22s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.22s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.23s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.24s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.23s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.24s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.23s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.23s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.24s/it]


Inverting...


100%|██████████| 4/4 [00:04<00:00,  1.24s/it]


Inverting...


100%|██████████| 4/4 [00:05<00:00,  1.29s/it]


In [15]:
# Set up cross-attention capture, then prepare both UNets (inference and
# attention-extraction) the same way: name/path annotation followed by hook
# registration.
# NOTE(review): both UNets are hooked, so runs of either pipeline presumably write
# to the store returned by get_attn_maps() — confirm in
# attention_maps_utils_by_timesteps.
cross_attn_init()
pipe_inference.unet = set_layer_with_name_and_path(pipe_inference.unet)
pipe_inference.unet = register_cross_attention_hook(pipe_inference.unet)
pipe_extract_attn_maps.unet = set_layer_with_name_and_path(pipe_extract_attn_maps.unet)
pipe_extract_attn_maps.unet = register_cross_attention_hook(pipe_extract_attn_maps.unet)

In [16]:
def prompt2tokens(tokenizer, prompt):
    """Tokenize ``prompt`` and return the token strings of the first sequence.

    The prompt is padded and truncated to ``tokenizer.model_max_length``, so the
    returned list always has that many entries (padding included). Ids are mapped
    to strings through ``tokenizer.decoder``.
    """
    encoded = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    first_sequence = encoded.input_ids[0]
    return [tokenizer.decoder[token_id.item()] for token_id in first_sequence]


def process_attn_map(timestep_attn_maps, tokenizer, prompt):
    """Reduce one timestep's attention maps to the map of the first EOS token.

    Every value of ``timestep_attn_maps`` is moved to the CPU, averaged over its
    first axis and squeezed on the resulting leading axis. Only maps whose last
    dimension is 16 are kept; these are stacked and averaged. The result is
    iterated alongside the prompt's tokens and the slice paired with the first
    token equal to ``tokenizer.eos_token`` is returned.

    Returns ``None`` implicitly when no token matches the EOS token.

    NOTE(review): the averaged first axis is presumably the attention-head axis
    and the leading axis of the stacked result the token axis — confirm against
    the hook implementation.
    """
    kept_maps = []
    for layer_map in timestep_attn_maps.values():
        reduced = layer_map.cpu().mean(dim=0).squeeze(0)
        if reduced.shape[-1] == 16:
            kept_maps.append(reduced)
    averaged = torch.stack(kept_maps, dim=0).mean(dim=0)

    prompt_tokens = prompt2tokens(tokenizer, prompt)
    eos = tokenizer.eos_token
    for token, token_map in zip(prompt_tokens, averaged):
        if token == eos:
            return token_map


def get_attn_map(edit_prompt, tokenizer):
    """Return the processed attention map for the most recent store entry.

    Takes the last element of ``get_attn_maps()`` and reduces it with
    ``process_attn_map`` for ``edit_prompt``.
    """
    latest_entry = get_attn_maps()[-1]
    return process_attn_map(latest_entry, tokenizer, edit_prompt)


def reset_attn_maps():
    """Empty, in place, the attention-map store returned by ``get_attn_maps()``."""
    get_attn_maps().clear()


def extract_attn_maps(inv_latents, noises, inversion_prompts, edit_prompts):
    """Run the extraction pipeline per sample and collect one attention map each.

    For every (latent, noise list, inversion prompt, edit prompt) tuple the
    notebook-level ``pipe_extract_attn_maps`` is run with the edit prompt as the
    prompt and the inversion prompt as the negative prompt. The generated image
    is not needed; the run only serves to fill the attention-map store, from
    which ``get_attn_map`` reads the most recent entry. Inputs are paired
    positionally and pairing stops at the shortest.

    Returns:
        list: One attention map per processed sample, in input order.
    """
    last_timestep_attn_maps = []
    for inv_latent, noise, inversion_prompt, edit_prompt in zip(
        inv_latents, noises, inversion_prompts, edit_prompts
    ):
        pipe_extract_attn_maps.scheduler.set_noise_list(noise)
        pipe_extract_attn_maps.cfg = config
        try:
            # The output images are deliberately not kept: collecting them only
            # held every generated image in memory without ever returning it.
            pipe_extract_attn_maps(
                prompt=edit_prompt,
                num_inference_steps=config.num_inference_steps,
                negative_prompt=inversion_prompt,
                image=inv_latent,
                strength=config.inversion_max_step,
                denoising_start=1.0 - config.inversion_max_step,
                guidance_scale=config.guidance_scale,
            )
            last_timestep_attn_maps.append(
                get_attn_map(edit_prompt, pipe_extract_attn_maps.tokenizer)
            )
        finally:
            # Always clear the store, even if the run failed, so maps from this
            # sample cannot leak into the next one.
            reset_attn_maps()

    return last_timestep_attn_maps

In [17]:
# def get_attn_map(edit_prompt):
#     attn_maps = get_attn_maps()
#     attn_map = preprocess(attn_maps[-1], 512, 512)
#     attn_map_img = visualize_and_save_attn_map(
#         attn_map, pipe_inference.tokenizer, edit_prompt, edit_prompt.split()[-1].lower()
#     )
#     return attn_map_img


# def reset_attn_maps():
#     attn_maps = get_attn_maps()
#     attn_maps.clear()


def get_edit_images(
    inv_latents, noises, inversion_prompts, edit_prompts, all_attn_maps
):
    """Generate one edited image per sample, guided by a pre-extracted map.

    For every sample the notebook-level ``pipe_inference`` is run with the edit
    prompt, the inversion prompt as negative prompt, the inverted latent as
    input image and the sample's attention map passed as ``general_attn_map``.
    Inputs are paired positionally and pairing stops at the shortest.

    Returns:
        tuple[list, list]: ``(edit_images, last_timestep_attn_maps)``.
        NOTE(review): nothing is ever appended to the second list (the code
        that filled it is commented out), so it is always empty.
    """
    generated_images = []
    last_timestep_attn_maps = []
    samples = zip(inv_latents, noises, inversion_prompts, edit_prompts, all_attn_maps)
    for latent, noise_list, negative_prompt, positive_prompt, guidance_map in samples:
        pipe_inference.scheduler.set_noise_list(noise_list)
        pipe_inference.cfg = config
        result = pipe_inference(
            prompt=positive_prompt,
            general_attn_map=guidance_map,
            num_inference_steps=config.num_inference_steps,
            negative_prompt=negative_prompt,
            image=latent,
            strength=config.inversion_max_step,
            denoising_start=1.0 - config.inversion_max_step,
            guidance_scale=config.guidance_scale,
        )
        generated_images.append(result.images[0])

    return generated_images, last_timestep_attn_maps

In [19]:
# First pass: run the extraction pipeline on every sample to obtain one attention
# map per (background, edit prompt) pair.
extracted_attn_maps = extract_attn_maps(
    inv_latents, noises, inversion_prompts, edit_prompts
)

  0%|          | 0/4 [00:00<?, ?it/s]

100%|██████████| 4/4 [00:03<00:00,  1.24it/s]
100%|██████████| 4/4 [00:00<00:00, 10.86it/s]
100%|██████████| 4/4 [00:00<00:00, 11.01it/s]
100%|██████████| 4/4 [00:00<00:00, 10.70it/s]
100%|██████████| 4/4 [00:00<00:00, 10.94it/s]
100%|██████████| 4/4 [00:00<00:00, 11.13it/s]
100%|██████████| 4/4 [00:00<00:00, 10.93it/s]
100%|██████████| 4/4 [00:00<00:00, 11.15it/s]
100%|██████████| 4/4 [00:00<00:00, 11.12it/s]
100%|██████████| 4/4 [00:00<00:00, 11.15it/s]
100%|██████████| 4/4 [00:00<00:00, 11.20it/s]
100%|██████████| 4/4 [00:00<00:00, 11.19it/s]
100%|██████████| 4/4 [00:00<00:00, 11.06it/s]
100%|██████████| 4/4 [00:00<00:00, 10.88it/s]
100%|██████████| 4/4 [00:00<00:00, 10.94it/s]
100%|██████████| 4/4 [00:00<00:00, 10.91it/s]


In [22]:
# Sanity check: each extracted map is a single 2-D tensor (16 x 16 in the recorded output).
extracted_attn_maps[0].shape

torch.Size([16, 16])

In [None]:
# Second pass: generate the edited images, passing each sample's extracted map to
# the inference pipeline.
# NOTE(review): get_edit_images never fills its second return value, so
# `last_timestep_attn_maps` is an empty list after this cell.
edit_images, last_timestep_attn_maps = get_edit_images(
    inv_latents, noises, inversion_prompts, edit_prompts, extracted_attn_maps
)

In [None]:
# Segment Anything model and processor, used by get_placement_mask to turn a point
# prompt into a segmentation mask.
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

In [None]:
def get_max_x_y(attn_map):
    """Return the ``(x, y)`` position of the largest value in ``attn_map``.

    Args:
        attn_map: Array-like of shape ``(H, W)`` or ``(H, W, C)``. For a
            multi-channel input the per-pixel maximum over the channels is used,
            so the returned position is the pixel holding the global maximum.

    Returns:
        tuple[int, int]: ``(x, y)`` with ``x`` the column index and ``y`` the
        row index, as plain Python ints.

    Raises:
        ValueError: If the input is neither 2-D nor 3-D, or is empty.
    """
    values = np.asarray(attn_map)
    if values.ndim == 3:
        # Collapse the channel axis first. Taking argmax over the flattened
        # (H, W, C) data and dividing by W yields indices that do not address
        # pixels.
        values = values.max(axis=-1)
    if values.ndim != 2:
        raise ValueError(
            f"expected a 2-D or 3-D array, got {values.ndim} dimension(s)"
        )
    max_y, max_x = np.unravel_index(np.argmax(values), values.shape)
    return int(max_x), int(max_y)


# def get_placement_mask(edit_image):
def get_placement_mask(attn_map):
    """Segment the region around the strongest response of ``attn_map`` with SAM.

    The position of the maximum value is used as a single point prompt for the
    notebook-level ``sam_model`` / ``sam_processor``; the input itself, converted
    to RGB, is the image being segmented.

    Args:
        attn_map: A 2-D map or an ``(H, W, C)`` image, given as a NumPy array, a
            PIL image, or anything ``np.asarray`` accepts.

    Returns:
        PIL.Image.Image: The first of the masks predicted for the point, as an
        8-bit image with values 0 and 255.
    """
    image_arr = np.asarray(attn_map)

    # Locate the point prompt on a 2-D view so that the (x, y) result addresses
    # pixels even when the input has a channel axis.
    saliency = image_arr if image_arr.ndim == 2 else image_arr.max(axis=-1)
    max_x, max_y = get_max_x_y(saliency)
    sam_input_points = [[[int(max_x), int(max_y)]]]

    if image_arr.dtype != np.uint8:
        # Non-uint8 data (e.g. float attention weights in [0, 1]) is not rescaled
        # when Pillow converts it to RGB and would end up almost black, so
        # stretch it to the 0..255 range first.
        scaled = image_arr.astype(np.float32)
        low, high = float(scaled.min()), float(scaled.max())
        factor = 255.0 / (high - low) if high > low else 0.0
        image_arr = ((scaled - low) * factor).astype(np.uint8)

    sam_inputs = sam_processor(
        Image.fromarray(image_arr).convert("RGB"),
        input_points=sam_input_points,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        outputs = sam_model(**sam_inputs)

    # Resize the predicted masks back to the input image's resolution.
    masks = (
        sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            sam_inputs["original_sizes"].cpu(),
            sam_inputs["reshaped_input_sizes"].cpu(),
        )[0]
        .squeeze()
        .numpy()
    )

    # Keep the first candidate mask and encode it as 0 / 255.
    mask = Image.fromarray(masks[0].astype(np.uint8) * 255)
    return mask


def get_placement_mask_on_bg_image(bg_image, placement_mask):
    """Return a copy of ``bg_image`` with the masked pixels painted pure red.

    Pixels where ``placement_mask`` equals 255 are set to RGB ``(255, 0, 0)``;
    all other pixels are left unchanged. The input image is not modified.
    """
    # np.array() yields a fresh array, so painting on it leaves bg_image intact.
    overlay = np.array(bg_image)
    selected = np.array(placement_mask) == 255
    overlay[selected] = (255, 0, 0)  # red
    return Image.fromarray(overlay)

In [None]:
masked_images = []
masks = []
# Pair each background with its edited image only. `last_timestep_attn_maps` is
# intentionally left out of the zip: its items are not used in the loop body and
# get_edit_images returns it empty, which would stop the loop before it starts.
for bg_image, edit_image in zip(bg_images, edit_images):
    # get_placement_mask works on array data (it reads `.shape` and calls
    # Image.fromarray), so convert the PIL image first.
    placement_mask = get_placement_mask(np.array(edit_image))
    masked_image = get_placement_mask_on_bg_image(bg_image, placement_mask)
    masked_images.append(masked_image)
    masks.append(placement_mask)

In [None]:
anydoor_images_path = "/home/lab/yairshp/projects/insert_object/third_party/AnyDoor/results/seg".replace(
    "/insert_object/third_party/", "/third_party/"
)
# os.listdir returns entries in arbitrary order. Sort them so the pairing with the
# other per-sample lists is at least deterministic across runs and machines.
# NOTE(review): this assumes the file names sort in the same order as the samples —
# confirm against how the AnyDoor results were written.
anydoor_images_paths = [
    f"{anydoor_images_path}/{f}" for f in sorted(os.listdir(anydoor_images_path))
]
anydoor_images = [Image.open(image_path) for image_path in anydoor_images_paths]

In [None]:
# Build a comparison grid with one row per sample and six columns: background,
# reference object, edited image, attention map, masked background, AnyDoor result.
# NOTE(review): zip stops at its shortest input. `last_timestep_attn_maps` comes
# back empty from get_edit_images, so as written no rows are collected while the
# grid is still requested with len(bg_images) rows.
images_grid_arr = []
for bg_image, ref_image, edit_image, attn_map, masked_image, anydoor_image in zip(
    bg_images,
    ref_images,
    edit_images,
    last_timestep_attn_maps,
    masked_images,
    anydoor_images,
):
    images_grid_arr.append(bg_image)
    images_grid_arr.append(ref_image)
    images_grid_arr.append(edit_image)
    images_grid_arr.append(attn_map)
    images_grid_arr.append(masked_image)
    images_grid_arr.append(anydoor_image)
make_image_grid(images_grid_arr, len(bg_images), 6)

In [None]:
# Scratch: pick the last sample's attention map for inspection.
# NOTE(review): raises IndexError while `last_timestep_attn_maps` is empty.
attn_map = last_timestep_attn_maps[-1]

In [None]:
# Scratch: NumPy view of the attention map selected in the previous cell.
a = np.array(attn_map)

In [None]:
# For every mask, build a same-sized image whose filled rectangle is the tight
# bounding box of the mask's 255-valued pixels.
bounding_boxes = []
for mask in masks:
    mask_arr = np.array(mask)
    bbox = np.zeros_like(mask_arr)
    non_zero_rows, non_zero_cols = np.nonzero(mask_arr == 255)
    # An empty mask has no extent; keep the all-zero box instead of letting
    # np.min / np.max fail on an empty selection.
    if non_zero_rows.size:
        top_left_row = np.min(non_zero_rows)
        top_left_col = np.min(non_zero_cols)
        bottom_right_row = np.max(non_zero_rows)
        bottom_right_col = np.max(non_zero_cols)
        # Filled rectangle, both corner pixels included.
        bbox[
            top_left_row : bottom_right_row + 1, top_left_col : bottom_right_col + 1
        ] = 255
    bounding_boxes.append(Image.fromarray(bbox))

In [None]:
# Interleave every bounding box with its mask: one row per pair, two columns.
xxx = [image for pair in zip(bounding_boxes, masks) for image in pair]
make_image_grid(xxx, len(bounding_boxes), 2)

In [None]:
# Write each mask and its bounding box next to each other, indexed by sample.
output_dir = "/home/lab/yairshp/projects/insert_object/data"
os.makedirs(output_dir, exist_ok=True)
for i, (object_name, mask, bbox) in enumerate(zip(object_names, masks, bounding_boxes)):
    # object_name is free-form model output; keep a path separator in it from
    # redirecting the file into a (possibly missing) sub-directory.
    safe_name = object_name.replace(os.sep, "_")
    mask.save(f"{output_dir}/seg_mask_{safe_name}_{i}.png")
    bbox.save(f"{output_dir}/bbox_{safe_name}_{i}.png")