In [1]:
import os
os.environ["HF_HOME"] = "/tmp/wendler/hf_cache"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import argparse
import sys
import torch
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path

# Add necessary paths for custom modules
sys.path.append("/share/u/wendler/code/my-sdxl-unbox")

from SDLens import HookedStableDiffusionXLPipeline
from SAE import SparseAutoencoder
from utils import add_feature_on_area_turbo

import supervision as sv
import pycocotools.mask as mask_util
from torchvision.ops import box_convert

# Grounded SAM2 and Grounding DINO imports
sys.path.append("/share/u/wendler/code/Grounded-SAM-2")
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, predict
import grounding_dino.groundingdino.datasets.transforms as T


In [2]:
# --- Experiment configuration ---
# NOTE(review): k, exp, m1, k_transfer, the use_* flags, path_to_checkpoints,
# mode and keep_spatial_info are not read by any cell visible in this notebook;
# presumably they mirror the settings of a related experiment — confirm before
# removing them.
k = 160
exp = 4
n_steps = 4  # forwarded as num_inference_steps when the reference images are generated
m1 = 1.
k_transfer = 80*4
use_down = True
use_up = True
use_up0 = True
use_mid = True
n_examples_per_edit = 1000  # per-editing-type cap checked by the generation loop
prefix = '../results/PIE-Bench-final'  # root directory under which reference/<type>/ is created
path_to_checkpoints = "/share/u/wendler/code/my-sdxl-unbox/hparam_study/"
dtype = "float32"  # string flag; turned into a torch dtype in the pipeline-loading cell
mode = "sae_1"
keep_spatial_info = False

In [3]:
# Short block codes -> dotted module paths inside the pipeline's UNet.
# NOTE(review): not referenced by any cell visible here; presumably used to pick
# hook positions elsewhere — confirm.
code_to_block = {
        "down.2.1": "unet.down_blocks.2.attentions.1",
        "up.0.1": "unet.up_blocks.0.attentions.1",
        "up.0.0": "unet.up_blocks.0.attentions.0",
        "mid.0": "unet.mid_block.attentions.0",
    }

In [4]:
# Add SDLens/src to sys.path at the top of the script

# --- Utility functions ---
def resize_mask(mask, size=(16, 16)):
    """Downsample a binary mask to a coarse boolean grid.

    The mask is resized with Lanczos interpolation and thresholded at ``> 0``.
    If that leaves no active cell, a block-coverage fallback is tried: the mask
    is cut into ``size`` grid cells and the mask values are summed per cell.
    When no cell is completely covered, every cell with a non-zero sum is
    switched on and that grid is returned; otherwise the (empty) resized mask
    is returned unchanged.

    The fallback grid is derived from ``size`` and ``mask.shape`` instead of
    assuming a 512x512 mask and a 16x16 grid, so other resolutions work too.

    Args:
        mask: Array of shape ``(H, W)`` (bool or numeric).
        size: Target grid as ``(width, height)``, the order ``cv2.resize`` uses.

    Returns:
        Boolean array of shape ``(height, width)``.

    Raises:
        ValueError: If the fallback is needed and ``H`` / ``W`` are not
            multiples of the grid height / width.
    """
    grid_w, grid_h = size
    small = cv2.resize(mask.astype(np.float32), size, interpolation=cv2.INTER_LANCZOS4) > 0
    if small.any():
        return small

    # Fallback: aggregate the mask over the grid cells.
    height, width = mask.shape[:2]
    if height % grid_h or width % grid_w:
        raise ValueError(
            f"mask of shape {mask.shape} cannot be split evenly into a "
            f"{grid_h}x{grid_w} grid"
        )
    cell_h, cell_w = height // grid_h, width // grid_w
    coverage = (
        mask.reshape(grid_h, cell_h, grid_w, cell_w)
        .astype(np.float32)
        .sum(axis=(1, 3))
    )
    if (coverage >= cell_h * cell_w).any():
        # At least one cell is fully covered: keep the resized result as-is.
        return small

    print("trying to fix the mask...")
    # NOTE(review): for an all-zero mask argmax() is 0, so the top-left cell is
    # switched on; for any other mask this write does not change the boolean
    # result. Behaviour kept as-is — confirm whether an empty mask should
    # rather stay empty.
    coverage[np.unravel_index(coverage.argmax(), coverage.shape)] = 1
    return coverage.astype(bool)

from matplotlib import pyplot as plt
from typing import Tuple
import grounding_dino.groundingdino.datasets.transforms as T

def sam_mask(img, prompt, sam2_predictor, grounding_model, BOX_THRESHOLD, TEXT_THRESHOLD):
    """Detect ``prompt`` in ``img`` with Grounding DINO and segment each box with SAM 2.

    Args:
        img: Image object supporting ``.convert("RGB")`` and ``.copy()``
            (a PIL image in this notebook).
        prompt: Text caption handed to Grounding DINO.
        sam2_predictor: Predictor exposing ``set_image`` and ``predict``.
        grounding_model: Model passed to ``predict`` for box detection.
        BOX_THRESHOLD: Forwarded as ``box_threshold``.
        TEXT_THRESHOLD: Forwarded as ``text_threshold``.

    Returns:
        ``(detections, labels, annotated_frame)``: an ``sv.Detections`` holding
        pixel-space xyxy boxes and boolean masks, one ``"<phrase> <confidence>"``
        string per detection, and a copy of ``img`` with boxes, labels and
        masks drawn on it.
    """
    # Preprocessing applied to the image before it is given to Grounding DINO.
    preprocess = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    rgb_image = img.convert("RGB")
    rgb_array = np.asarray(rgb_image)
    dino_input, _ = preprocess(rgb_image, None)

    # SAM 2 works on the raw RGB array, Grounding DINO on the transformed tensor.
    sam2_predictor.set_image(rgb_array)
    raw_boxes, box_scores, phrases = predict(
        model=grounding_model,
        image=dino_input,
        caption=prompt,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD,
    )

    # Scale the boxes by the image size and convert cxcywh -> xyxy for SAM 2.
    height, width = rgb_array.shape[:2]
    scaled_boxes = raw_boxes * torch.Tensor([width, height, width, height])
    input_boxes = box_convert(boxes=scaled_boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

    # FIXME: figure out how bfloat16 autocast / TF32 would influence the G-DINO
    # model before enabling either of them here.

    masks, scores, logits = sam2_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )

    # Bring the masks to shape (n, H, W).
    if masks.ndim == 4:
        masks = masks.squeeze(1)

    score_values = box_scores.numpy().tolist()
    captions = [
        f"{phrase} {score:.2f}"
        for phrase, score in zip(phrases, score_values)
    ]

    detections = sv.Detections(
        xyxy=input_boxes,  # (n, 4)
        mask=masks.astype(bool),  # (n, h, w)
        class_id=np.array([index for index in range(len(phrases))]),
    )

    # Draw boxes, then labels, then masks onto a copy of the input image.
    canvas = sv.BoxAnnotator().annotate(scene=img.copy(), detections=detections)
    canvas = sv.LabelAnnotator().annotate(scene=canvas, detections=detections, labels=captions)
    canvas = sv.MaskAnnotator().annotate(scene=canvas, detections=detections)
    return detections, captions, canvas

In [5]:
# Resolve the configured dtype flag to a torch dtype. The torch dtype itself is
# accepted as well so that re-running this cell is idempotent; comparing only
# against the string would silently turn a float16 run into float32 on re-run.
if dtype in ("float16", torch.float16):
    dtype = torch.float16
else:
    dtype = torch.float32
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load SDXL-Turbo through the hooked pipeline wrapper.
pipe = HookedStableDiffusionXLPipeline.from_pretrained(
    'stabilityai/sdxl-turbo',
    torch_dtype=dtype,
    device_map="balanced",
    variant=("fp16" if dtype == torch.float16 else None)
)
if dtype == torch.float32:
    # Explicitly cast the second text encoder when running in full precision.
    pipe.text_encoder_2.to(dtype=dtype)
pipe.set_progress_bar_config(disable=True)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [6]:
# Checkpoint and config locations for SAM 2 and Grounding DINO.
SAM2_CHECKPOINT = "/share/u/wendler/code/Grounded-SAM-2/checkpoints/sam2.1_hiera_large.pt"
# NOTE(review): relative path; presumably resolved by sam2's own config loader
# rather than against the working directory — confirm.
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GROUNDING_DINO_CONFIG = "/share/u/wendler/code/Grounded-SAM-2/grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "/share/u/wendler/code/Grounded-SAM-2/gdino_checkpoints/groundingdino_swint_ogc.pth"
# Thresholds forwarded to Grounding DINO's predict() as box_threshold / text_threshold.
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
# Build the SAM 2 image predictor and the Grounding DINO model on `device`.
sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=device)
sam2_predictor = SAM2ImagePredictor(sam2_model)
grounding_model = load_model(
    model_config_path=GROUNDING_DINO_CONFIG,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
    device=device
)



final text_encoder_type: bert-base-uncased


In [7]:
from functools import partial
import json
import pickle

def _reference_mask(img, gsam_prompt, sam_predictor, grounding_model, verbose=False):
    """Build the 16x16 boolean mask for one image from its Grounded-SAM prompt.

    Prompt conventions:
        * ``"#everything"``: full mask, no detection is run.
        * contains ``"background"``: segment ``"foreground"`` and invert.
        * contains ``"~"``: segment the prompt without ``"~"`` and invert.
        * anything else: segment the prompt as given.

    Returns:
        ``(mask, detections, labels, annotated_frame)``. The last three are
        ``None`` for ``"#everything"``.

    Raises:
        ValueError: If the detector returns no masks for the prompt.
    """
    if gsam_prompt == "#everything":
        return np.ones((16, 16), dtype=bool), None, None, None

    is_background = "background" in gsam_prompt
    if is_background:
        query, invert = "foreground", True
    elif "~" in gsam_prompt:
        query, invert = gsam_prompt.replace("~", ""), True
    else:
        query, invert = gsam_prompt, False

    detections, labels, annotated_frame = sam_mask(
        img, query, sam_predictor, grounding_model, BOX_THRESHOLD, TEXT_THRESHOLD
    )
    small_masks = [resize_mask(bigmask).astype(np.float32) for bigmask in detections.mask]
    if not small_masks:
        # np.stack would fail on an empty list with a far less helpful message.
        raise ValueError(f"no detections for prompt {query!r}")

    # Union of all per-detection masks on the coarse grid.
    mask = np.stack(small_masks, axis=0).sum(axis=0).astype(bool)
    if invert:
        mask = np.logical_not(mask)
    if verbose and is_background:
        plt.imshow(mask)
        plt.show()
    return mask, detections, labels, annotated_frame


def create_reference_images(prompt1, prompt2, gsam_prompt1, gsam_prompt2, pipe=pipe,
         n_steps=1, verbose=False,
         sam_predictor=sam2_predictor, grounding_model=grounding_model,
         result_name=None):
    """Generate two reference images and store their Grounded-SAM detections.

    Both prompts are rendered with the same fixed seed. The images are written
    to ``{result_name}_img1.png`` / ``{result_name}_img2.png``. The detections,
    labels and annotated frames of both images are pickled next to them as
    ``{result_name}_<artifact><1|2>.pkl``.

    Args:
        prompt1, prompt2: Text prompts for the two images.
        gsam_prompt1, gsam_prompt2: Grounded-SAM prompts for the two images;
            see ``_reference_mask`` for the supported conventions.
        pipe: Pipeline exposing ``run_with_cache``.
        n_steps: Number of inference steps.
        verbose: Enables DEBUG logging and plots background masks.
        sam_predictor: SAM 2 predictor used for segmentation.
        grounding_model: Grounding DINO model used for detection.
        result_name: Path prefix of every output file. With the default
            ``None`` the files are literally named ``None_...``.

    Returns:
        None. All results are written to disk.

    Raises:
        ValueError: If either mask is empty or a prompt yields no detections.
    """
    import logging
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    seed = 42
    images = []
    for prompt in (prompt1, prompt2):
        # A fresh generator per prompt gives both images the same initial noise.
        outputs, _ = pipe.run_with_cache(
            prompt,
            positions_to_cache=[],
            num_inference_steps=n_steps,
            guidance_scale=0.0,
            generator=torch.Generator(device='cpu').manual_seed(seed),
            save_input=True,
        )
        images.append(outputs[0][0])
    img1, img2 = images

    # save to path
    img1.save(f"{result_name}_img1.png")
    img2.save(f"{result_name}_img2.png")

    # Use the predictor passed in by the caller (not the notebook global), and
    # get None placeholders for "#everything" prompts so pickling cannot hit
    # unbound names.
    mask1, detections1, labels1, annotated_frame1 = _reference_mask(
        img1, gsam_prompt1, sam_predictor, grounding_model, verbose
    )
    mask2, detections2, labels2, annotated_frame2 = _reference_mask(
        img2, gsam_prompt2, sam_predictor, grounding_model, verbose
    )
    if mask1.sum() == 0 or mask2.sum() == 0:
        raise ValueError("one of the masks is empty")

    # Persist detections, labels and annotated frames of both images.
    artifacts = {
        "detections1": detections1,
        "detections2": detections2,
        "labels1": labels1,
        "labels2": labels2,
        "annotated_frame1": annotated_frame1,
        "annotated_frame2": annotated_frame2,
    }
    for suffix, artifact in artifacts.items():
        with open(f"{result_name}_{suffix}.pkl", "wb") as f:
            pickle.dump(artifact, f)



In [None]:
import os
import json 
from collections import defaultdict
# Load the benchmark entries. The generation loop reads the keys "id",
# "editing_type_id", "original_prompt", "editing_prompt", "edit_target" and
# "edit_source" from each entry.
with open("./generated_piebench.json", "r") as f:
    pb = json.load(f)


# Editing-type id -> human-readable name; trailing comments are status notes.
expid2name= {"0":"random", # -> remove
"1":"change object", # default
"2":"add object", # done
"3":"delete object", # done
"4":"change content", # possible -> (this should apply to all edits, as long as the two versions of the image are similar to each other) take the edit mask in the original, boost the editing mask and suppress the original mask
"5":"change pose", # somewhat possible -> switch
"6":"change color", # done
"7":"change material",# done
"8":"change background", # done
"9":"change style" # done
}

def remove_brakets(txt):
    """Return ``txt`` with every ``[`` and ``]`` character removed."""
    without_brackets = txt.translate(str.maketrans("", "", "[]"))
    return without_brackets

# Number of successfully generated examples per editing type.
cnt = defaultdict(int)

# Editing types to leave out entirely (empty by default; add ids to skip them).
SKIP_EDIT_TYPES = set()

for d in pb:
    key = None  # keeps the error report usable if reading the entry itself fails
    try:
        edit_type = d["editing_type_id"]
        # ">=" caps each type at exactly n_examples_per_edit examples
        # (a ">" comparison lets one extra example through).
        if edit_type in SKIP_EDIT_TYPES or cnt[edit_type] >= n_examples_per_edit:
            continue
        key = d["id"]
        print(key)
        path = os.path.join(prefix, f"reference/{edit_type}")
        original_prompt = remove_brakets(d["original_prompt"])
        editing_prompt = remove_brakets(d["editing_prompt"])
        os.makedirs(path, exist_ok=True)
        if edit_type in ['0']:
            # "random" edits have no reference images.
            continue
        # Image 1 is rendered from the editing prompt, image 2 from the original one.
        create_reference_images(
            editing_prompt,
            original_prompt,
            d["edit_target"],
            d["edit_source"],
            result_name=f"{path}/{key}",
            n_steps=n_steps,
        )
        cnt[edit_type] += 1
    except Exception as e:
        # Best effort: report which example failed and why, then move on.
        print(f"[{key}] {type(e).__name__}: {e}")
        continue
          

000000000000
000000000001
000000000002
000000000003
000000000004
000000000005
000000000006
000000000007
000000000008
000000000009
000000000010
000000000011
000000000012
000000000013
000000000014
000000000015
000000000016
000000000017
000000000018
000000000019
000000000020
000000000021
000000000022
000000000023
000000000024
000000000025
000000000026
000000000027
000000000028
000000000029
000000000030
000000000031
000000000032
000000000033
000000000034
000000000035
000000000036
000000000037
000000000038
000000000039
000000000040
000000000041
000000000042
000000000043
000000000044
000000000045
000000000046
000000000047
000000000048
000000000049
000000000050
000000000051
000000000052
000000000053
000000000054
000000000055
000000000056
000000000057
000000000058
000000000059
000000000060
000000000061
000000000062
000000000063
000000000064
000000000065
000000000066
000000000067
000000000068
000000000069
000000000070
000000000071
000000000072
000000000073
000000000074
000000000075
000000000076

Falling back to all available kernels for scaled_dot_product_attention (which may have a slower speed).


trying to fix the mask...
trying to fix the mask...
trying to fix the mask...
111000000001
