In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and local code edits are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Expose only CUDA device 1 to this process and redirect the Hugging Face
# cache directory. Both variables are set here, ahead of the `import torch`
# that follows further down in this cell.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "1",
    "HF_HOME": "/tmp/wendler/hf_cache",
})

import sys
import torch
from pathlib import Path

sys.path.append("./sdxl-unbox")

from SDLens import HookedStableDiffusionXLPipeline
from SAE import SparseAutoencoder

# Grounded SAM2 and Grounding DINO imports
sys.path.append("./Grounded-SAM-2")
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model
from interventions import code_to_block,run_feature_transport
import os
import json 
from collections import defaultdict

In [None]:
# --- SAE checkpoint selection -------------------------------------------------
# `k` and `exp` are interpolated into the checkpoint directory name when the
# SAEs are loaded (hidden size is exp * 1280).
k = 10
exp = 4
path_to_checkpoints = "./sdxl-unbox/checkpoints"  # "./sdxl-turbo-saes"
dtype = "float32"  # "float16" selects half precision; anything else means float32

# --- Settings forwarded to run_feature_transport -------------------------------
mode = "sae"  # one of: "sae", "neurons", "steering", "sae_adaptive"
stat = "mean"
aggregate_timesteps = "first"
n_steps = 4
m1 = 2.0
k_transfer = 80
keep_spatial_info = True
subtract_target_add_source = True
only_add = False

# --- Which blocks are intervened on -------------------------------------------
# down -> "down.2.1", up -> "up.0.1", up0 -> "up.0.0", mid -> "mid.0"
use_down = use_up = use_up0 = use_mid = True

# --- Benchmark run ------------------------------------------------------------
task_ids = "1,2,3,4,5,6,7,8,9"  # comma-separated editing_type_id values to run
n_examples_per_edit = 50        # cap on processed examples per editing type
prefix = './results/'           # root directory for all outputs
debug = False                   # True enables the hand-written debug cells

In [None]:
# Fail fast on an intervention mode this notebook does not know about.
assert mode in ["sae", "neurons", "steering", "sae_adaptive"], f"Unsupported mode: {mode!r}"

# Turn the comma-separated config string into a list of id strings.
# The isinstance guard keeps the cell idempotent: on a second run `task_ids`
# is already a list, and calling `.split` on it would raise AttributeError.
# Whitespace around ids is dropped so "1, 2" still matches the dataset ids.
if isinstance(task_ids, str):
    task_ids = [task_id.strip() for task_id in task_ids.split(',') if task_id.strip()]

In [None]:
# Resolve the configured precision to a torch dtype. The already-resolved
# torch.float16 is accepted as well, so that re-running this cell does not
# silently downgrade a float16 configuration to float32 (after the first run
# `dtype` is a torch.dtype and no longer equals the string "float16").
if dtype in ("float16", torch.float16):
    dtype = torch.float16
else:
    dtype = torch.float32

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the hooked SDXL-Turbo pipeline; the "fp16" weight variant is requested
# only when running in half precision.
pipe = HookedStableDiffusionXLPipeline.from_pretrained(
    'stabilityai/sdxl-turbo',
    torch_dtype=dtype,
    device_map="balanced",
    variant=("fp16" if dtype == torch.float16 else None)
)

# In float32 mode the second text encoder is cast explicitly as well.
# NOTE(review): presumably it is not covered by `torch_dtype` here; confirm.
if dtype == torch.float32:
    pipe.text_encoder_2.to(dtype=dtype)

pipe.set_progress_bar_config(disable=True)

In [None]:
# --- Grounded-SAM-2 / Grounding DINO assets -----------------------------------
SAM2_CHECKPOINT = "./Grounded-SAM-2/checkpoints/sam2.1_hiera_large.pt"
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GROUNDING_DINO_CONFIG = "./Grounded-SAM-2/grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "./Grounded-SAM-2/gdino_checkpoints/groundingdino_swint_ogc.pth"

# Detection thresholds; defined here but not referenced elsewhere in this
# notebook's visible cells.
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25

# Build the SAM2 image predictor and the Grounding DINO model on `device`.
sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=device)
sam2_predictor = SAM2ImagePredictor(sam2_model)
grounding_model = load_model(
    model_config_path=GROUNDING_DINO_CONFIG,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
    device=device,
)

# --- Sparse autoencoders, one per block shortcut in `code_to_block` -----------
blocks = list(code_to_block.values())

saes = {}
for shortcut, block in code_to_block.items():
    # The checkpoint directory name encodes the block plus the `k` / `exp`
    # settings from the configuration cell.
    checkpoint_dir = os.path.join(
        path_to_checkpoints,
        f"{block}_k{k}_hidden{exp*1280:d}_auxk256_bs4096_lr0.0001",
        "final",
    )
    sae = SparseAutoencoder.load_from_disk(checkpoint_dir).to(device, dtype=dtype)
    saes[shortcut] = sae

In [None]:
if debug:
    # Debug run: "steering" mode on the hat prompt pair, four blocks combined.
    out = run_feature_transport(
        "a cat in front of a grey background",
        "a cat with a hat in front of a grey background",
        "#everything",
        "hat",
        pipe,
        grounding_model,
        sam2_predictor,
        saes,
        blocks_to_intervene=["down.2.1", "mid.0", "up.0.0", "up.0.1"],
        combine_blocks=True,
        subtract_target_add_source=True,
        use_target_mask_in_both=True,
        maintain_spatial_info=True,
        n_steps=4,
        m1=1.,
        k_transfer=10,
        stat="mean",
        mode="steering",
        aggregate_timesteps="first",
        result_name=None,
    )

In [None]:
if debug:
    # Debug run: "steering" mode on the headphones prompt pair, four blocks combined.
    out = run_feature_transport(
        "a cat in front of a grey background",
        "a cat with a headphones in front of a grey background",
        "#everything",
        "headphones",
        pipe,
        grounding_model,
        sam2_predictor,
        saes,
        blocks_to_intervene=["down.2.1", "mid.0", "up.0.0", "up.0.1"],
        combine_blocks=True,
        subtract_target_add_source=True,
        use_target_mask_in_both=True,
        maintain_spatial_info=True,
        n_steps=4,
        m1=1.,
        k_transfer=10,
        stat="mean",
        mode="steering",
        aggregate_timesteps="first",
        result_name=None,
    )

In [None]:
if debug:
    # Debug run: "sae" mode, giraffe -> face, without spatial info and without
    # subtracting target features (k_transfer=200).
    out = run_feature_transport(
        "a photo of a giraffe",
        "a photo of a colorful model",
        "giraffe",
        "face",
        pipe,
        grounding_model,
        sam2_predictor,
        saes,
        blocks_to_intervene=["down.2.1", "mid.0", "up.0.0", "up.0.1"],
        combine_blocks=True,
        subtract_target_add_source=False,
        maintain_spatial_info=False,
        n_steps=4,
        m1=1.,
        k_transfer=200,
        stat="mean",
        mode="sae",
        aggregate_timesteps="first",
    )

In [None]:
if debug:
    # Debug run: "sae" mode, panther -> face, with spatial info kept and
    # target features subtracted (k_transfer=200).
    out = run_feature_transport(
        "a photo of a black panther",
        "a photo of a colorful model",
        "panther",
        "face",
        pipe,
        grounding_model,
        sam2_predictor,
        saes,
        blocks_to_intervene=["down.2.1", "mid.0", "up.0.0", "up.0.1"],
        combine_blocks=True,
        subtract_target_add_source=True,
        maintain_spatial_info=True,
        n_steps=4,
        m1=1.,
        k_transfer=200,
        stat="mean",
        mode="sae",
        aggregate_timesteps="first",
    )

In [None]:
if debug:
    # Debug run: "neurons" mode, giraffe -> face, with a very large
    # k_transfer (100000), no spatial info and no target subtraction.
    out = run_feature_transport(
        "a photo of a giraffe",
        "a photo of a colorful model",
        "giraffe",
        "face",
        pipe,
        grounding_model,
        sam2_predictor,
        saes,
        blocks_to_intervene=["down.2.1", "mid.0", "up.0.0", "up.0.1"],
        combine_blocks=True,
        subtract_target_add_source=False,
        maintain_spatial_info=False,
        n_steps=4,
        m1=1.,
        k_transfer=100000,
        stat="mean",
        mode="neurons",
        aggregate_timesteps="first",
    )

In [None]:
# Load the benchmark records. JSON text is UTF-8 by specification, so decode
# it explicitly instead of relying on the platform's locale default.
with open("./dataset/riebench.json", "r", encoding="utf-8") as f:
    rb = json.load(f)

# Display the first record as the cell output.
# NOTE: source and target are mixed up -- the benchmark loop in this notebook
# passes `editing_prompt` before `original_prompt` to run_feature_transport.
rb[0]

In [None]:

# Collect the blocks enabled by the `use_*` switches of the configuration
# cell, in the fixed order down, up.0.1, up.0.0, mid.
blocks_to_intervene = [
    block_name
    for enabled, block_name in (
        (use_down, "down.2.1"),
        (use_up, "up.0.1"),
        (use_up0, "up.0.0"),
        (use_mid, "mid.0"),
    )
    if enabled
]

# Human-readable label for every editing_type_id; keys are the ids as strings
# ("0" .. "9"), generated from the position of each label in the tuple.
expid2name = {
    str(type_id): label
    for type_id, label in enumerate((
        "random",
        "change object",
        "add object",
        "delete object",
        "change content",
        "change pose",
        "change color",
        "change material",
        "change background",
        "change style",
    ))
}

def remove_brakets(txt: str) -> str:
    """Return `txt` with every '[' and ']' character removed.

    All other characters, including whitespace, are left untouched.
    """
    return txt.translate(str.maketrans("", "", "[]"))

# Number of examples successfully processed so far, per editing_type_id.
cnt = defaultdict(int)

# The whole run configuration is encoded in the output directory name. It does
# not depend on the record, so it is built once instead of on every iteration.
run_dir = os.path.join(prefix, f"mode{mode}_onlyadd{only_add}_select{aggregate_timesteps}_spatial{keep_spatial_info}_subtract{subtract_target_add_source}_down{use_down}_up{use_up}_up0{use_up0}_mid{use_mid}_T{n_steps}_ktrans{k_transfer}_str{m1}")

for d in rb:
    try:
        edit_type = d["editing_type_id"]

        # Skip records whose type has reached its quota, was not requested,
        # or is the "random" type ('0'), which has no handler below. Filtering
        # '0' here avoids creating an empty output directory for it.
        if cnt[edit_type] >= n_examples_per_edit:
            continue
        if edit_type not in task_ids or edit_type in ['0']:
            continue

        # Prompts mark the edited span with square brackets; strip them.
        original_prompt = remove_brakets(d["original_prompt"])
        editing_prompt = remove_brakets(d["editing_prompt"])
        # Nothing to do when both prompts agree once the markup is removed.
        if original_prompt == editing_prompt:
            continue

        key = d["id"]
        print(key)
        path = f"{run_dir}/{edit_type}"
        os.makedirs(path, exist_ok=True)

        # In every call the editing prompt is passed first and the original
        # prompt second, followed by the matching grounding prompts
        # ("#everything" where no per-object prompt is used).
        if edit_type in ['2']:  # add object
            run_feature_transport(editing_prompt, original_prompt, d["editing_gsam_prompt"], "#everything",
                pipe, grounding_model, sam2_predictor, saes,
                use_source_mask_in_both=True,
                subtract_target_add_source=True,
                maintain_spatial_info=keep_spatial_info,
                blocks_to_intervene=blocks_to_intervene, combine_blocks=True,
                n_steps=n_steps, m1=m1, k_transfer=k_transfer, stat=stat, mode=mode,
                aggregate_timesteps=aggregate_timesteps, only_add=only_add,
                result_name=f"{path}/{key}")
        elif edit_type in ['3']:  # delete object
            # NOTE(review): `only_add` is not forwarded in this branch, unlike
            # the other two -- confirm this is intentional.
            run_feature_transport(editing_prompt, original_prompt, "#everything", d["original_gsam_prompt"],
                pipe, grounding_model, sam2_predictor, saes,
                blocks_to_intervene=blocks_to_intervene, combine_blocks=True,
                use_target_mask_in_both=True,
                subtract_target_add_source=True,
                maintain_spatial_info=keep_spatial_info,
                n_steps=n_steps, m1=m1, k_transfer=k_transfer, stat=stat, mode=mode,
                aggregate_timesteps=aggregate_timesteps,
                result_name=f"{path}/{key}")
        elif edit_type in ['4', '5'] + ['1'] + ['6', '7', '8', '9']:  # change
            run_feature_transport(editing_prompt, original_prompt, d["editing_gsam_prompt"], d["original_gsam_prompt"],
                pipe, grounding_model, sam2_predictor, saes,
                blocks_to_intervene=blocks_to_intervene, combine_blocks=True, subtract_target_add_source=subtract_target_add_source,
                maintain_spatial_info=keep_spatial_info, only_add=only_add,
                n_steps=n_steps, m1=m1, k_transfer=k_transfer, stat=stat, mode=mode,
                aggregate_timesteps=aggregate_timesteps,
                result_name=f"{path}/{key}")
        else:
            raise ValueError("Unsupported editing type.")

        # Only successfully processed examples count towards the quota.
        cnt[edit_type] += 1
    except Exception as e:
        # Best effort: one failing record must not abort the whole benchmark,
        # but the message names the record and the exception type so that
        # failures can be traced afterwards.
        record_id = d.get("id", "<unknown>") if isinstance(d, dict) else "<unknown>"
        print(f"[skipped] id={record_id}: {type(e).__name__}: {e}")
          