In [8]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pickle


# Add necessary paths for custom modules
sys.path.append("./sdxl-unbox")

from SDLens import HookedStableDiffusionXLPipeline

# Grounded SAM2 and Grounding DINO imports
sys.path.append("./Grounded-SAM-2")
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model
from gsam_utils import sam_mask, resize_mask

In [None]:
# --- Run configuration -------------------------------------------------------
# Number of diffusion steps; forwarded to create_reference_images() in the
# generation loop at the bottom of the notebook.
n_steps = 4
# Upper bound on generated samples per editing_type_id.
n_examples_per_edit = 50
# Root output directory; reference images go to <prefix>/reference/<type_id>/.
prefix = "./results"
# Precision name; resolved to a torch dtype in the pipeline-loading cell.
dtype = "float32"
# Block-selection flags. NOTE(review): not referenced by any cell visible in
# this notebook -- presumably consumed elsewhere; confirm before removing.
use_down = use_up = use_up0 = use_mid = True

In [None]:
# Resolve the configured precision name into a torch dtype.
# The conversion is guarded so the cell is idempotent: on a re-run `dtype` is
# already a torch.dtype, and comparing it to the string "float16" would be
# False, silently downgrading a float16 configuration to float32.
if not isinstance(dtype, torch.dtype):
    dtype = torch.float16 if dtype == "float16" else torch.float32

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load SDXL-Turbo through the hooked pipeline wrapper. The fp16 weight variant
# is requested only when running in half precision.
pipe = HookedStableDiffusionXLPipeline.from_pretrained(
    'stabilityai/sdxl-turbo',
    torch_dtype=dtype,
    device_map="balanced",
    variant=("fp16" if dtype == torch.float16 else None)
)
# NOTE(review): the second text encoder is cast explicitly in float32 mode --
# presumably it is not converted by from_pretrained alone; confirm.
if dtype == torch.float32:
    pipe.text_encoder_2.to(dtype=dtype)
pipe.set_progress_bar_config(disable=True)

In [None]:
# Checkpoint / config locations for SAM2 and Grounding DINO, relative to the
# notebook's working directory.
SAM2_CHECKPOINT = "./Grounded-SAM-2/checkpoints/sam2.1_hiera_large.pt"
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GROUNDING_DINO_CONFIG = "./Grounded-SAM-2/grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "./Grounded-SAM-2/gdino_checkpoints/groundingdino_swint_ogc.pth"
# Thresholds forwarded to sam_mask() by create_reference_images().
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
# Build the SAM2 model and wrap it in an image predictor; `device` is defined
# in the pipeline-loading cell above.
sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=device)
sam2_predictor = SAM2ImagePredictor(sam2_model)
# Load the Grounding DINO model on the same device.
grounding_model = load_model(
    model_config_path=GROUNDING_DINO_CONFIG,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
    device=device
)

In [11]:
def _gsam_mask_for_prompt(img, gsam_prompt, sam_predictor, grounding_model, verbose=False):
    """Turn one GSAM prompt into a boolean mask for ``img``.

    Prompt conventions (checked in this order):
      * ``"#everything"``      -> all-True 16x16 mask, no segmentation is run.
      * contains "background"  -> segment "foreground" and invert the result.
      * contains "~"           -> segment the prompt with "~" removed and invert.
      * anything else          -> segment the prompt as-is.

    Returns:
        ``(mask, detections, labels, annotated_frame)``. The last three are
        whatever ``sam_mask`` returned, or ``None`` for ``"#everything"``.

    Raises:
        ValueError: if segmentation produced no masks for the query.
    """
    if gsam_prompt == "#everything":
        # NOTE(review): 16x16 presumably matches the output size of
        # resize_mask() -- confirm against gsam_utils.
        return np.ones((16, 16), dtype=bool), None, None, None

    is_background = "background" in gsam_prompt
    invert = is_background or "~" in gsam_prompt
    # replace() is a no-op for plain prompts, which contain no "~".
    query = "foreground" if is_background else gsam_prompt.replace("~", "")

    detections, labels, annotated_frame = sam_mask(
        img, query, sam_predictor, grounding_model, BOX_THRESHOLD, TEXT_THRESHOLD
    )
    masks = [resize_mask(bigmask).astype(np.float32) for bigmask in detections.mask]
    if not masks:
        # np.stack([]) would raise an opaque ValueError; name the query instead.
        raise ValueError(f"no detections for GSAM prompt {query!r}")

    # Union of all detected instance masks.
    mask = np.stack(masks, axis=0).sum(axis=0).astype(bool)
    if invert:
        mask = np.logical_not(mask)
    if is_background and verbose:
        plt.imshow(mask)
        plt.show()
    return mask, detections, labels, annotated_frame


def create_reference_images(prompt1, prompt2, gsam_prompt1, gsam_prompt2, pipe=pipe,
         n_steps=1, verbose=False,
         sam_predictor=sam2_predictor, grounding_model=grounding_model,
         result_name=None):
    """Generate two reference images and save them with their segmentation artifacts.

    Both prompts are rendered with the same fixed seed (42) and
    ``guidance_scale=0.0``. Each image is then segmented according to its GSAM
    prompt (see ``_gsam_mask_for_prompt``).

    Files written, all prefixed with ``result_name``:
      ``_img1.png``, ``_img2.png`` and pickles ``_detections{1,2}.pkl``,
      ``_labels{1,2}.pkl``, ``_annotated_frame{1,2}.pkl``. For a
      ``"#everything"`` prompt the corresponding pickles contain ``None``.
      The PNGs are written before the masks are validated.

    Args:
        prompt1, prompt2: text prompts for the two images.
        gsam_prompt1, gsam_prompt2: segmentation prompts for image 1 / image 2.
        pipe: pipeline exposing ``run_with_cache``.
        n_steps: number of inference steps.
        verbose: enable DEBUG logging and show the inverted background mask.
        sam_predictor, grounding_model: models forwarded to ``sam_mask``.
        result_name: output path prefix (formatted into the file names as-is).

    Raises:
        ValueError: if either mask is empty or segmentation finds nothing.
    """
    import logging
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    seed = 42

    def _generate(prompt):
        # A fresh CPU generator per call so both images start from the same noise.
        images, _ = pipe.run_with_cache(
            prompt,
            positions_to_cache=[],
            num_inference_steps=n_steps,
            guidance_scale=0.0,
            generator=torch.Generator(device='cpu').manual_seed(seed),
            save_input=True,
        )
        return images[0][0]

    img1 = _generate(prompt1)
    img2 = _generate(prompt2)

    # save to path
    img1.save(f"{result_name}_img1.png")
    img2.save(f"{result_name}_img2.png")

    # Use the sam_predictor argument (the original ignored it and always used
    # the global predictor).
    mask1, detections1, labels1, annotated_frame1 = _gsam_mask_for_prompt(
        img1, gsam_prompt1, sam_predictor, grounding_model, verbose
    )
    mask2, detections2, labels2, annotated_frame2 = _gsam_mask_for_prompt(
        img2, gsam_prompt2, sam_predictor, grounding_model, verbose
    )
    logger.debug("mask coverage: %d / %d cells", int(mask1.sum()), int(mask2.sum()))
    if not mask1.any() or not mask2.any():
        raise ValueError("one of the masks is empty")

    # Persist detections, labels and annotated frames with pickle.
    artifacts = (
        ("detections1", detections1),
        ("detections2", detections2),
        ("labels1", labels1),
        ("labels2", labels2),
        ("annotated_frame1", annotated_frame1),
        ("annotated_frame2", annotated_frame2),
    )
    for suffix, obj in artifacts:
        with open(f"{result_name}_{suffix}.pkl", "wb") as f:
            pickle.dump(obj, f)



In [None]:
import os
import json 
from collections import defaultdict
# Load the benchmark records iterated by the generation loop below.
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding.
with open("./dataset/riebench.json", "r", encoding="utf-8") as f:
    rb = json.load(f)


def remove_brakets(txt):
    """Return ``txt`` with every '[' and ']' character stripped out."""
    return txt.translate(str.maketrans("", "", "[]"))

# Editing types for which no reference images are generated.
skipped_editing_types = {'0'}

# Number of successfully generated samples per editing_type_id.
cnt = defaultdict(int)

for d in rb:
    key = None  # kept outside the try so the error report can name the sample
    try:
        editing_type = d["editing_type_id"]
        # Skip excluded types and types that already reached their quota
        # before printing or creating any directory.
        if editing_type in skipped_editing_types or cnt[editing_type] >= n_examples_per_edit:
            continue
        key = d["id"]
        print(key)
        original_prompt = remove_brakets(d["original_prompt"])
        editing_prompt = remove_brakets(d["editing_prompt"])
        path = os.path.join(prefix, f"reference/{editing_type}")
        os.makedirs(path, exist_ok=True)
        # The edited prompt is image 1 and the original prompt is image 2.
        create_reference_images(
            editing_prompt,
            original_prompt,
            d["editing_gsam_prompt"],
            d["original_gsam_prompt"],
            result_name=f"{path}/{key}",
            n_steps=n_steps,
        )
        # Only successful generations count toward the per-type quota.
        cnt[editing_type] += 1
    except Exception as e:
        # Best-effort batch run: report which sample failed and why, then move on.
        print(f"[skipped] sample {key!r}: {type(e).__name__}: {e}")
        continue
          