In [1]:
import os
os.environ["HF_HOME"] = "/tmp/wendler/hf_cache"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import argparse
import sys
import torch
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path

# Add necessary paths for custom modules
sys.path.append("/share/u/wendler/code/my-sdxl-unbox")

from SDLens import HookedStableDiffusionXLPipeline
from SAE import SparseAutoencoder
from utils import add_feature_on_area_turbo

import supervision as sv
import pycocotools.mask as mask_util
from torchvision.ops import box_convert

# Grounded SAM2 and Grounding DINO imports
sys.path.append("/share/u/wendler/code/Grounded-SAM-2")
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, predict
import grounding_dino.groundingdino.datasets.transforms as T


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:

# Experiment configuration read by the PIE-Bench sweep cell at the bottom of the notebook.
n_steps = 4  # forwarded to main(), which passes it on as num_inference_steps
m1 = 1.  # multiplier main() applies to the transferred statistic (SAE feature value or activation difference)
k_transfer = 5  # number of features requested from best_features_saeuron for each direction
use_down = True  # when true, "down.2.1" is added to blocks_to_intervene
use_up = True  # when true, "up.0.1" is added to blocks_to_intervene
use_up0 = True  # when true, "up.0.0" is added to blocks_to_intervene
use_mid = True  # when true, "mid.0" is added to blocks_to_intervene

In [None]:
# Add SDLens/src to sys.path at the top of the script

# --- Utility functions ---
def resize_mask(mask, size=(16, 16)):
    """Resample `mask` to `size` and binarise the result.

    The mask is cast to uint8, resized by OpenCV with Lanczos interpolation,
    and every strictly positive output pixel is reported as True.

    Args:
        mask: array-like exposing `.astype` (e.g. a boolean numpy mask).
        size: target size forwarded to `cv2.resize` as its `dsize` argument.

    Returns:
        Boolean array holding the thresholded, resized mask.
    """
    mask_u8 = mask.astype(np.uint8)
    resized = cv2.resize(mask_u8, size, interpolation=cv2.INTER_LANCZOS4)
    return resized > 0

from matplotlib import pyplot as plt
from typing import Tuple
import grounding_dino.groundingdino.datasets.transforms as T

def sam_mask(img, prompt, sam2_predictor, grounding_model, BOX_THRESHOLD, TEXT_THRESHOLD):
    """Detect `prompt` in `img` with Grounding DINO, then segment each hit with SAM 2.

    Args:
        img: PIL image; it is converted to RGB for the models and copied for annotation.
        prompt: caption handed to Grounding DINO's `predict`.
        sam2_predictor: predictor exposing `set_image` and `predict`.
        grounding_model: model handed to Grounding DINO's `predict`.
        BOX_THRESHOLD: forwarded as `box_threshold`.
        TEXT_THRESHOLD: forwarded as `text_threshold`.

    Returns:
        Tuple `(detections, labels, annotated_frame)`: an `sv.Detections` holding
        boxes, masks and class ids; one "<phrase> <confidence>" string per detection;
        and a copy of `img` annotated with boxes, labels and masks.
    """
    # Preprocessing pipeline for Grounding DINO.
    preprocess = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    rgb = img.convert("RGB")
    image_source = np.asarray(rgb)      # raw pixels for SAM 2
    image, _ = preprocess(rgb, None)    # transformed tensor for Grounding DINO
    sam2_predictor.set_image(image_source)

    boxes, confidences, class_names = predict(
        model=grounding_model,
        image=image,
        caption=prompt,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD,
    )

    # Scale the boxes by the image width/height and convert cxcywh -> xyxy for SAM 2.
    h, w, _ = image_source.shape
    scaled_boxes = boxes * torch.Tensor([w, h, w, h])
    input_boxes = box_convert(boxes=scaled_boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

    # FIXME: figure how does this influence the G-DINO model
    # (bfloat16 autocast and TF32 matmuls were tried here and are left disabled).

    masks, scores, logits = sam2_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )

    # Post-process the model output for visualization: bring masks to (n, H, W).
    if masks.ndim == 4:
        masks = masks.squeeze(1)

    confidences = confidences.numpy().tolist()
    class_ids = np.array(list(range(len(class_names))))
    labels = [
        f"{class_name} {confidence:.2f}"
        for class_name, confidence in zip(class_names, confidences)
    ]

    detections = sv.Detections(
        xyxy=input_boxes,         # (n, 4)
        mask=masks.astype(bool),  # (n, h, w)
        class_id=class_ids,
    )

    # Draw boxes, then labels, then masks onto a copy of the input image.
    annotated_frame = sv.BoxAnnotator().annotate(scene=img.copy(), detections=detections)
    annotated_frame = sv.LabelAnnotator().annotate(
        scene=annotated_frame, detections=detections, labels=labels
    )
    annotated_frame = sv.MaskAnnotator().annotate(scene=annotated_frame, detections=detections)
    return detections, labels, annotated_frame
    

def best_features_saeuron(source_feats, target_feats, k=10):
    """Rank SAE features by how much more of the activation mass they hold in source than in target.

    Each input is averaged over its first two dimensions, giving one mean
    activation per feature; the means are normalised to sum to one, and the
    score of a feature is `source_share - target_share`.

    Args:
        source_feats: tensor whose last dimension indexes features and whose first
            two dimensions are averaged away (main() passes a 3-D tensor).
        target_feats: tensor with the same layout as `source_feats`.
        k: number of features returned for each direction.

    Returns:
        Tuple `(top, bottom)` of numpy index arrays: `top` holds the `k`
        highest-scoring features in descending score order, `bottom` the `k`
        lowest-scoring features in ascending score order.
    """
    def _activation_share(feats):
        # Mean activation per feature, normalised to sum to one.
        mean_act = feats.mean(dim=0).mean(dim=0)
        total = mean_act.sum()
        if total == 0:
            # Nothing is active in this region: 0/0 would turn every score into NaN
            # and make the ranking meaningless, so use an all-zero share instead.
            return torch.zeros_like(mean_act)
        return mean_act / total

    scores = _activation_share(source_feats) - _activation_share(target_feats)
    ascending = np.argsort(scores.cpu().detach().numpy())
    # .copy() materialises the reversed view so it has no negative strides.
    return ascending[::-1][:k].copy(), ascending[:k].copy()

# Load SDXL-Turbo through the hooked pipeline class; main() uses its
# run_with_cache / run_with_hooks methods to read and patch block activations.
dtype = torch.float16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pipe = HookedStableDiffusionXLPipeline.from_pretrained(
    'stabilityai/sdxl-turbo',
    torch_dtype=dtype,
    device_map="balanced",
    # Request the "fp16" weight variant only when running in half precision.
    variant=("fp16" if dtype==torch.float16 else None)
)
pipe.set_progress_bar_config(disable=True)

In [2]:
# Checkpoint / config locations for SAM 2 and Grounding DINO.
SAM2_CHECKPOINT = "/share/u/wendler/code/Grounded-SAM-2/checkpoints/sam2.1_hiera_large.pt"
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GROUNDING_DINO_CONFIG = "/share/u/wendler/code/Grounded-SAM-2/grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "/share/u/wendler/code/Grounded-SAM-2/gdino_checkpoints/groundingdino_swint_ogc.pth"
# Thresholds forwarded by sam_mask() to Grounding DINO's predict() as box_threshold / text_threshold.
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
# Build the segmentation predictor and the text-grounded detector on the same device as the pipeline.
sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=device)
sam2_predictor = SAM2ImagePredictor(sam2_model)
grounding_model = load_model(
    model_config_path=GROUNDING_DINO_CONFIG,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
    device=device
)

# Directory holding one sparse-autoencoder checkpoint folder per hooked UNet block.
path_to_checkpoints = '/share/u/wendler/code/my-sdxl-unbox/checkpoints/'
# Short block codes used throughout the notebook -> full module paths inside the pipeline.
code_to_block = {
        "down.2.1": "unet.down_blocks.2.attentions.1",
        "up.0.1": "unet.up_blocks.0.attentions.1",
        "up.0.0": "unet.up_blocks.0.attentions.0",
        "mid.0": "unet.mid_block.attentions.0",
    }
# Full module paths; main() passes this list as positions_to_cache.
blocks = list(code_to_block.values())
# Loaded SAEs keyed by short block code.
saes = {}
# k and exp only select which checkpoint directory is loaded:
# the folder name embeds "k{k}" and "hidden{exp*1280}".
k = 10
exp = 4
for shortcut in code_to_block.keys():
    block = code_to_block[shortcut]
    sae = SparseAutoencoder.load_from_disk(
        os.path.join(path_to_checkpoints, f"{block}_k{k}_hidden{exp*1280:d}_auxk256_bs4096_lr0.0001", "final")
    ).to(device, dtype=dtype)
    saes[shortcut] = sae



final text_encoder_type: bert-base-uncased


In [6]:
from functools import partial
# --- Main experiment ---
def add_featuremaps(sae, to_source_features, to_target_features, fmaps, target_mask, module, input, output):
    """Forward hook that edits a block's output in SAE feature space.

    The block's update (`output[0] - input[0]`) is encoded with `sae`. Inside
    `target_mask` the current coefficients of `to_target_features` are
    subtracted (removing them), `fmaps` is added to `to_source_features`, and
    the resulting coefficient delta is decoded with the SAE decoder weights
    and added to the block output.

    Args:
        sae: object exposing `encode(x)` and `decoder.weight`; `decoder.weight.shape[1]`
            is used as the number of SAE features.
        to_source_features: indices of the features to inject.
        to_target_features: indices of the features to remove inside the mask.
        fmaps: tensor of injected coefficients; its first three dimensions size the
            coefficient delta and its last dimension lines up with `to_source_features`.
        target_mask: boolean mask over the two spatial dimensions (numpy or torch).
        module: hooked module (unused; part of the hook signature).
        input: hook input tuple; only `input[0]` is read.
        output: hook output tuple; only `output[0]` is read.
            NOTE(review): assumes `input[0]`/`output[0]` are (batch, channels, H, W),
            as implied by the permutes below; only batch element 0 has features removed.

    Returns:
        One-element tuple holding the edited output tensor.
    """
    block_update = output[0] - input[0]
    coefs = sae.encode(block_update.permute(0, 2, 3, 1))

    n_latents = sae.decoder.weight.shape[1]
    delta = torch.zeros(
        [fmaps.shape[0], fmaps.shape[1], fmaps.shape[2], n_latents],
        device=input[0].device,
    )
    region = torch.as_tensor(target_mask, dtype=torch.bool, device=delta.device)
    source_idx = torch.as_tensor(to_source_features, dtype=torch.long, device=delta.device)
    target_idx = torch.as_tensor(to_target_features, dtype=torch.long, device=delta.device)

    # Remove the target features inside the mask. Indexing with a boolean mask
    # returns a copy, so the previous chained form `delta[0, mask][..., idx] -= ...`
    # modified a temporary and never touched `delta`; edit the copy and write it back.
    region_delta = delta[0, region]
    region_coefs = coefs[0, region.to(coefs.device)].to(region_delta)
    region_delta[:, target_idx] -= region_coefs[:, target_idx]
    delta[0, region] = region_delta

    # Inject the source features.
    delta[..., source_idx] += fmaps.to(delta.device)

    # Decode the coefficient delta back to activation space and add it to the output.
    to_add = delta.to(sae.decoder.weight.dtype) @ sae.decoder.weight.T
    return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
    
def activation_patching(mean, target_mask, module, input, output):
    """Forward hook that shifts a block's update by `mean` inside `target_mask`.

    Args:
        mean: 1-D tensor with one offset per channel; it is broadcast over the masked positions.
        target_mask: boolean mask over the two trailing (spatial) dimensions.
        module: hooked module (unused; part of the hook signature).
        input: hook input tuple; only `input[0]` is read.
        output: hook output tuple; only `output[0]` is read.

    Returns:
        One-element tuple holding `input[0]` plus the shifted update. Only batch
        element 0 is shifted.
    """
    block_in = input[0]
    update = output[0] - block_in
    per_channel_offset = mean[:, None]
    update[0, :, target_mask] += per_channel_offset
    return (update + block_in,)


def _region_mask(img, gsam_prompt, sam_predictor, grounding_model):
    """Return `(mask, annotated_frame)` for the region named by `gsam_prompt` in `img`.

    `mask` is a 16x16 boolean array. The prompt "background" selects the whole
    grid without running any model (`annotated_frame` is then None); otherwise
    all Grounded-SAM detections are downsized with `resize_mask` and OR-ed
    together, falling back to the whole grid when the union is empty.
    """
    if gsam_prompt == "background":
        return np.ones((16, 16), dtype=bool), None
    detections, _, annotated_frame = sam_mask(
        img, gsam_prompt, sam_predictor, grounding_model, BOX_THRESHOLD, TEXT_THRESHOLD
    )
    small_masks = [resize_mask(bigmask).astype(np.float32) for bigmask in detections.mask]
    mask = np.stack(small_masks, axis=0).sum(axis=0).astype(bool)
    if mask.sum() == 0:
        mask = np.ones((16, 16), dtype=bool)
    return mask, annotated_frame


def main(prompt1, prompt2, gsam_prompt1, gsam_prompt2, pipe=pipe, k=10, 
         blocks_to_intervene=["down.2.1", "up.0.1", "up.0.0", "mid.0"],
         n_steps=1, m1=2., k_transfer=10, stat="max", mode="sae", verbose=False,
         sam_predictor=sam2_predictor, grounding_model=grounding_model, saes=saes, 
         result_name=None):
    """Transfer a concept from the `prompt1` image into a region of the `prompt2` image.

    Both prompts are rendered with the same seed while block inputs/outputs are
    cached. Grounded SAM locates `gsam_prompt1` in image 1 and `gsam_prompt2`
    in image 2. For every block in `blocks_to_intervene` a hook is built from
    the cached activations, `prompt2` is re-rendered with the hooks installed,
    and a three-panel summary figure is shown or saved.

    Args:
        prompt1: prompt of the source image.
        prompt2: prompt of the target image; also used for the intervened run.
        gsam_prompt1: Grounded-SAM caption for the source region, or "background"
            to use the whole 16x16 grid.
        gsam_prompt2: same, for the target region.
        pipe: pipeline exposing `run_with_cache` and `run_with_hooks`.
        k: unused; kept so existing callers that pass it keep working.
        blocks_to_intervene: keys of `code_to_block` (and `saes`) to hook. Not mutated.
        n_steps: number of inference steps for every run.
        m1: multiplier applied to the transferred statistic.
        k_transfer: number of features selected per direction by `best_features_saeuron`.
        stat: "max" or "mean" — per-feature statistic of the source region that is injected.
        mode: "patch_max" / "patch_mean" patch raw activations; anything else uses the SAE hook.
        verbose: enable debug logging and diagnostic plots of the selected features.
        sam_predictor: SAM 2 predictor handed to `sam_mask`.
        grounding_model: Grounding DINO model handed to `sam_mask`.
        saes: mapping from block shortcut to its sparse autoencoder.
        result_name: path prefix; when given, the summary and the three images are
            written to disk instead of being shown.

    Raises:
        ValueError: if `stat` is neither "max" nor "mean".
    """
    import logging
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    # Fail before any image is generated. The original built this exception but
    # never raised it, so a bad `stat` surfaced later as an unbound `stat1_val`.
    if stat not in ("max", "mean"):
        raise ValueError(f"stat1 {stat} not recognized. Choose from: max, mean")

    logger.debug("[4/9] Generating images and caching activations...")
    seed = 42
    base_imgs1, cache1 = pipe.run_with_cache(
        prompt1,
        positions_to_cache=blocks,
        num_inference_steps=n_steps,
        guidance_scale=0.0,
        generator=torch.Generator(device='cpu').manual_seed(seed),
        save_input=True,
    )
    base_imgs2, cache2 = pipe.run_with_cache(
        prompt2,
        positions_to_cache=blocks,
        num_inference_steps=n_steps,
        guidance_scale=0.0,
        generator=torch.Generator(device='cpu').manual_seed(seed),
        save_input=True,
    )
    img1 = base_imgs1[0][0]
    img2 = base_imgs2[0][0]

    logger.debug("[5/9] Running Grounded SAM on generated images...")
    # Use the predictor passed by the caller (previously the global was always used).
    mask1, annotated_frame1 = _region_mask(img1, gsam_prompt1, sam_predictor, grounding_model)
    mask2, annotated_frame2 = _region_mask(img2, gsam_prompt2, sam_predictor, grounding_model)

    logger.debug("[6/9] Extracting latents and encoding features...")
    interventions = {}
    for shortcut in blocks_to_intervene:
        block = code_to_block[shortcut]
        # The SAE works on the block update (output minus input), taken from the first cached entry.
        diff1 = cache1['output'][block][0] - cache1['input'][block][0]
        diff2 = cache2['output'][block][0] - cache2['input'][block][0]
        source = diff1[:, :, mask1]
        target = diff2[:, :, mask2]
        sae = saes[shortcut]
        source_feats = sae.encode(source.permute(0, 2, 1))
        target_feats = sae.encode(target.permute(0, 2, 1))

        logger.debug("[7/9] Selecting best features...")
        to_source_features, to_target_features = best_features_saeuron(source_feats, target_feats, k=k_transfer)
        logger.debug(f"Source features shape: {source_feats.shape}")
        if verbose:
            n_selected = len(to_source_features)
            source_mask_f = mask1.astype(np.float32)

            # One histogram per selected feature. squeeze=False keeps `axs` 2-D so
            # indexing also works when a single feature is selected.
            fig, axs = plt.subplots(1, n_selected, figsize=(15, 5), squeeze=False)
            for i, feature in enumerate(to_source_features):
                ax = axs[0][i]
                ax.hist(source_feats[0][:, feature].cpu().detach().numpy(), bins="rice", edgecolor='black')
                ax.set_title(f"Feature {feature}")
                ax.set_xlabel("Value")
                ax.set_ylabel("Frequency")
            plt.show()

            all_source_feats = sae.encode(diff1.permute(0, 2, 3, 1))
            # Source mask, then each feature's activation map inside and outside the mask.
            plt.imshow(source_mask_f)
            plt.show()
            for region_weight in (source_mask_f, 1 - source_mask_f):
                fig, axs = plt.subplots(1, n_selected, figsize=(15, 5), squeeze=False)
                for i, feature in enumerate(to_source_features):
                    fmap = all_source_feats[0, :, :, feature].cpu().detach().numpy()
                    axs[0][i].imshow(region_weight * fmap)
                    axs[0][i].set_title(f"Feature {feature}")
                plt.show()
            logger.debug(f"source_feats shape: {source_feats[0].shape}")

        if stat == "max":
            # Per-feature maximum over the source region.
            stat1_val = source_feats.max(dim=0)[0].max(dim=0)[0][to_source_features]
        else:  # stat == "mean"
            logger.debug(f"source_feats shape: {source_feats.shape}")
            # Per-feature mean over the positions where the feature is active (> 1e-3).
            # The previous expression averaged all active coefficients of all features
            # into one scalar; a feature that is never active gets 0.
            selected = source_feats[..., to_source_features].float()
            active = selected > 1e-3
            n_active = active.sum(dim=(0, 1)).clamp_min(1)
            stat1_val = (selected * active).sum(dim=(0, 1)) / n_active
        logger.debug(f"mean_vals (max): {stat1_val}")

        logger.debug("[8/9] Preparing featuremaps for transfer...")
        # Constant coefficient per selected feature, written only inside the target mask.
        fmaps = torch.zeros((1, 16, 16, len(to_source_features)), device=device)
        fmaps[:, mask2] += (m1*stat1_val).unsqueeze(0).unsqueeze(0)

        logger.debug("[9/9] Running SDXL with feature injection...")
        logger.debug(f"Decoder weight shape: {sae.decoder.weight.shape}")
        logger.debug(f"Using mode: {mode}")
        if mode == "patch_max":
            logger.debug(f"Cat shape: {source.shape}")
            f = partial(activation_patching, m1*(source[0].max(dim=1)[0] - target[0].max(dim=1)[0]), mask2)
        elif mode == "patch_mean":
            f = partial(activation_patching, m1*(source[0].mean(dim=1) - target[0].mean(dim=1)), mask2)
        else:
            f = partial(add_featuremaps, sae, to_source_features, to_target_features, fmaps, mask2)
        interventions[block] = f

    result = pipe.run_with_hooks(
        prompt2,
        position_hook_dict=interventions,
        num_inference_steps=n_steps,
        guidance_scale=0.0,
        generator=torch.Generator(device='cpu').manual_seed(seed)
    ).images[0]

    # Summary figure: source image, target image (annotated when a mask was detected), result.
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    axs[0].imshow(annotated_frame1 if gsam_prompt1 != "background" else img1)
    axs[0].set_title(f"{prompt1}")
    axs[0].axis('off')

    axs[1].imshow(annotated_frame2 if gsam_prompt2 != "background" else img2)
    axs[1].set_title(f"{prompt2}")
    axs[1].axis('off')

    axs[2].imshow(result)
    axs[2].axis('off')

    if result_name is not None:
        plt.savefig(result_name + "_summary.png")
        plt.close()
        # NOTE(review): img1 is tagged with gsam_prompt2 and img2 with gsam_prompt1.
        # This looks swapped, but the file names are left as they were in case
        # downstream scripts rely on them — confirm before changing.
        img1.save(result_name + f"_{gsam_prompt2}_img1.png")
        img2.save(result_name + f"_{gsam_prompt1}_img2.png")
        result.save(result_name + ".png")
    else:
        plt.show()


In [4]:
import json 
from collections import defaultdict
# Load the PIE-Bench mapping file: sample id -> annotation dict.
with open("/share/u/wendler/PIE-Bench_v1/mapping_file.json", "r") as f:
    pb = json.load(f)

# Print a bounded preview instead of every key, which floods the cell output.
print(f"{len(pb)} entries; first keys: {list(pb.keys())[:5]}")
print(list(pb["000000000001"].keys()))
print(pb["000000000001"])

# Number of samples per editing type.
cnts = defaultdict(int)
for key, val in pb.items():
    cnts[val["editing_type_id"]]+=1
print(cnts)

['000000000000', '000000000001', '000000000002', '000000000003', '000000000004', '000000000005', '000000000006', '000000000007', '000000000008', '000000000009', '000000000010', '000000000011', '000000000012', '000000000013', '000000000014', '000000000015', '000000000016', '000000000017', '000000000018', '000000000019', '000000000020', '000000000021', '000000000022', '000000000023', '000000000024', '000000000025', '000000000026', '000000000027', '000000000028', '000000000029', '000000000030', '000000000031', '000000000032', '000000000033', '000000000034', '000000000035', '000000000036', '000000000037', '000000000038', '000000000039', '000000000040', '000000000041', '000000000042', '000000000043', '000000000044', '000000000045', '000000000046', '000000000047', '000000000048', '000000000049', '000000000050', '000000000051', '000000000052', '000000000053', '000000000054', '000000000055', '000000000056', '000000000057', '000000000058', '000000000059', '000000000060', '000000000061', '000000

In [7]:
import os
# Translate the config toggles into the list of block shortcuts that main() should hook.
blocks_to_intervene = [
    shortcut
    for enabled, shortcut in (
        (use_down, "down.2.1"),
        (use_up, "up.0.1"),
        (use_up0, "up.0.0"),
        (use_mid, "mid.0"),
    )
    if enabled
]

# One results directory per configuration; samples are grouped by editing type below it.
run_dir = f"../results/PIE-Bench/{n_steps}_{m1}_{k_transfer}_{use_down}_{use_up}_{use_up0}_{use_mid}"

# Every sample with a non-empty "blended_word" is processed. The former
# per-editing-type cap was disabled by an `if True or ...` guard, which is dropped here.
for key, val in pb.items():
    try:
        if val["blended_word"] == "":
            continue
        # "blended_word" holds two space-separated words: the first is used to locate
        # the region in the original-prompt image, the second in the editing-prompt image.
        words = val["blended_word"].split(" ")
        w1 = words[0]
        w2 = words[1]
        out_dir = f"{run_dir}/{val['editing_type_id']}"
        os.makedirs(out_dir, exist_ok=True)
        main(val["editing_prompt"], val["original_prompt"], w2, w1,
             blocks_to_intervene=blocks_to_intervene,
             n_steps=n_steps, m1=m1, k_transfer=k_transfer, stat="max", k=10, mode="sae",
             result_name=f"{out_dir}/{key}")
        cnts[val["editing_type_id"]]+=1
    except Exception as e:
        # Best effort: keep sweeping, but say which sample failed and with what kind of error.
        print(f"[{key}] {type(e).__name__}: {e}")
        continue
          





