In [None]:
import sys
sys.path.append("NewtonRaphsonInversion")

import torch
from src.config import RunConfig
from ipywidgets import Text, VBox
import PIL
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
from PIL import Image
from diffusers.utils.torch_utils import randn_tensor
from IPython.display import display


In [2]:
def inversion_callback(pipe, step, timestep, callback_kwargs):
    """Step-end callback for the inversion pipeline.

    Ignores `pipe`, `step` and `timestep` and hands `callback_kwargs` back
    untouched.
    """
    unchanged_kwargs = callback_kwargs
    return unchanged_kwargs


def inference_callback(pipe, step, timestep, callback_kwargs):
    """Step-end callback for the inference pipeline.

    Ignores `pipe`, `step` and `timestep` and hands `callback_kwargs` back
    untouched.
    """
    unchanged_kwargs = callback_kwargs
    return unchanged_kwargs

def center_crop(im):
    """Crop the largest centered square out of `im`.

    Args:
        im: an object exposing `size` as `(width, height)` and a
            `crop((left, top, right, bottom))` method (e.g. a PIL image).

    Returns:
        Whatever `im.crop` returns for the centered square box.
    """
    width, height = im.size
    min_dim = min(width, height)

    # Integer offsets: with float edges each side of the box gets rounded on
    # its own, which can yield a crop one pixel wider than `min_dim` when the
    # size difference is odd. Deriving right/bottom from left/top keeps the
    # box exactly `min_dim` x `min_dim`.
    left = (width - min_dim) // 2
    top = (height - min_dim) // 2
    right = left + min_dim
    bottom = top + min_dim

    return im.crop((left, top, right, bottom))


def load_im_into_format_from_path(im_path, size=(1024, 1024)):
    """Open the image at `im_path`, center-crop it with `center_crop`, and resize it to `size`."""
    opened = PIL.Image.open(im_path)
    square = center_crop(opened)
    return square.resize(size)

In [3]:
# Model selection: the second assignment overrides the first, so SDXL-Turbo is the
# active checkpoint. Comment out the second line to run the SDXL base model instead.
model = "stabilityai/stable-diffusion-xl-base-1.0"
model = "stabilityai/sdxl-turbo"

class ImageEditorDemo:
    """Invert one input image, then regenerate it under edited prompts.

    Construction loads the image, installs one seeded list of noise tensors on
    the schedulers of both pipelines, and runs `pipe_inversion` on the image.
    The inversion result is stored as `last_latent` / `original_latent` and is
    the starting point of every `edit` / `edit_with_hooks` call.
    """

    def __init__(self, pipe_inversion, pipe_inference, input_image, description_prompt, cfg, edit_cfg=1.2,
                 img_size=None, device=None):
        """
        Args:
            pipe_inversion: pipeline used by `invert`; must expose `unet`,
                `scheduler`, `scheduler_inference` and `set_progress_bar_config`.
            pipe_inference: pipeline used by `edit` / `edit_with_hooks`.
            input_image: path of the image to invert.
            description_prompt: prompt passed to the inversion pipeline.
            cfg: run configuration; `num_inversion_steps`, `num_inference_steps`,
                `guidance_scale` and `inversion_max_step` are read from it.
            edit_cfg: guidance scale used by the edit methods when none is given.
            img_size: `(width, height)` the input image is resized to. When
                omitted it is derived from the notebook-level `model` id:
                (1024, 1024) for the SDXL base model, (512, 512) otherwise.
            device: device the noise tensors are created on; defaults to
                `torch.device("cuda:0")`.
        """
        self.pipe_inversion = pipe_inversion
        self.pipe_inference = pipe_inference
        self.load_image = True

        if img_size is None:
            # Backward-compatible fallback on the notebook global `model`.
            is_base = model == "stabilityai/stable-diffusion-xl-base-1.0"
            img_size = (1024, 1024) if is_base else (512, 512)
        if device is None:
            device = torch.device("cuda:0")
        self.img_size = img_size
        self.original_image = load_im_into_format_from_path(input_image, img_size).convert("RGB")

        # Each image side is divided by 8 to get the matching latent side.
        VQAE_SCALE = 8
        # PIL sizes are (width, height) while the latent tensor is laid out
        # (batch, channels, height, width); identical for square sizes.
        width, height = img_size
        latents_size = (1, 4, height // VQAE_SCALE, width // VQAE_SCALE)

        # One fixed-seed noise tensor per inversion step, shared by every scheduler.
        g_cpu = torch.Generator().manual_seed(7865)
        noise = [randn_tensor(latents_size, dtype=pipe_inversion.unet.dtype, device=device, generator=g_cpu)
                 for _ in range(cfg.num_inversion_steps)]
        pipe_inversion.scheduler.set_noise_list(noise)
        pipe_inference.scheduler.set_noise_list(noise)
        pipe_inversion.scheduler_inference.set_noise_list(noise)
        pipe_inversion.set_progress_bar_config(disable=True)
        pipe_inference.set_progress_bar_config(disable=True)

        self.cfg = cfg
        self.pipe_inversion.cfg = cfg
        self.pipe_inference.cfg = cfg
        self.inv_hp = [2, 0.1, 0.2]  # niter, alpha, lr; 2, 0.1, 0.2 is default
        self.edit_cfg = edit_cfg

        self.last_latent = self.invert(self.original_image, description_prompt)
        self.original_latent = self.last_latent

    def invert(self, init_image, base_prompt):
        """Run the inversion pipeline on `init_image` and return element `[0][0]` of its result."""
        res = self.pipe_inversion(prompt=base_prompt,
                                  num_inversion_steps=self.cfg.num_inversion_steps,
                                  num_inference_steps=self.cfg.num_inference_steps,
                                  image=init_image,
                                  guidance_scale=self.cfg.guidance_scale,
                                  callback_on_step_end=inversion_callback,
                                  strength=self.cfg.inversion_max_step,
                                  denoising_start=1.0 - self.cfg.inversion_max_step,
                                  inv_hp=self.inv_hp)[0][0]
        return res

    def _inference_kwargs(self, target_prompt, guidance_scale):
        """Build the keyword arguments shared by `edit` and `edit_with_hooks`."""
        if guidance_scale is None:
            guidance_scale = self.edit_cfg
        return dict(prompt=target_prompt,
                    num_inference_steps=self.cfg.num_inference_steps,
                    negative_prompt="",
                    callback_on_step_end=inference_callback,
                    image=self.last_latent,
                    strength=self.cfg.inversion_max_step,
                    denoising_start=1.0 - self.cfg.inversion_max_step,
                    guidance_scale=guidance_scale)

    def edit(self, target_prompt, guidance_scale=None):
        """Regenerate from `last_latent` under `target_prompt`.

        `guidance_scale` falls back to `self.edit_cfg` when it is None.
        Returns the first generated image resized to 512x512.
        """
        image = self.pipe_inference(**self._inference_kwargs(target_prompt, guidance_scale)).images[0]
        # NOTE(review): this always resizes to 512x512 while `edit_with_hooks`
        # returns the image at its generated size -- confirm the asymmetry is intended.
        return image.resize((512, 512))

    def edit_with_hooks(self, target_prompt, position_hook_dict, guidance_scale=None):
        """Like `edit`, but calls `pipe_inference.run_with_hooks` with `position_hook_dict`.

        `guidance_scale` falls back to `self.edit_cfg` when it is None.
        Returns the first generated image without resizing it.
        """
        image = self.pipe_inference.run_with_hooks(
            position_hook_dict=position_hook_dict,
            **self._inference_kwargs(target_prompt, guidance_scale)).images[0]
        return image

In [None]:
import os
# Show the Hugging Face cache root; this raises KeyError when HF_HOME is not set.
os.environ["HF_HOME"]

In [5]:
from SDLens import HookedStableDiffusionXLImg2ImgPipeline
from utils import add_feature, add_feature_on_area_turbo, add_feature_on_area_base

In [None]:
# Print the directory that the pipeline-loading cell passes as `cache_dir`.
print(os.environ["HF_HOME"])

In [None]:
import torch

# Load the inversion and inference pipelines for the selected `model` and give
# both of them fresh instances of the custom Euler-ancestral scheduler.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
is_sdxl_base = model == "stabilityai/stable-diffusion-xl-base-1.0"
image_size = (1024, 1024) if is_sdxl_base else (512, 512)
dtype = torch.float32
scheduler_class = MyEulerAncestralDiscreteScheduler

if is_sdxl_base:
    # The base model is spread over the available devices via device_map="balanced".
    load_kwargs = dict(
        torch_dtype=dtype,
        device_map="balanced",
        variant=("fp16" if dtype == torch.float16 else None),
        cache_dir=os.path.join(os.environ["HF_HOME"], "tmp"),
    )
    pipe_inversion = SDXLDDIMPipeline.from_pretrained(model, **load_kwargs)
    pipe_inference = HookedStableDiffusionXLImg2ImgPipeline.from_pretrained(model, **load_kwargs)
    if dtype == torch.float32:
        pipe_inversion.text_encoder_2.to(dtype)
        pipe_inference.text_encoder_2.to(dtype)
else:
    # `torch_dtype` is passed here as well so that changing `dtype` above takes
    # effect for every checkpoint, not only for the base model.
    load_kwargs = dict(
        torch_dtype=dtype,
        use_safetensors=True,
        safety_checker=None,
        cache_dir=os.environ["HF_HOME"],
    )
    pipe_inversion = SDXLDDIMPipeline.from_pretrained(model, **load_kwargs).to(device)
    pipe_inference = HookedStableDiffusionXLImg2ImgPipeline.from_pretrained(model, **load_kwargs).to(device)

#pipe_inference = AutoPipelineForImage2Image.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True, safety_checker= None, cache_dir=os.environ["HF_HOME"]).to(device)
pipe_inference.scheduler            = scheduler_class.from_config(pipe_inference.scheduler.config)
pipe_inversion.scheduler            = scheduler_class.from_config(pipe_inversion.scheduler.config)
# The inversion pipeline also carries a scheduler built from the inference pipeline's config.
pipe_inversion.scheduler_inference  = scheduler_class.from_config(pipe_inference.scheduler.config)

In [8]:
# The SDXL base model runs 20 inversion/inference steps; any other checkpoint
# (SDXL-Turbo in this notebook) runs 4. Everything else is shared.
# Author's note: 4,4,0,0.6 is default settings; 0.6 and 0.7 look the same.
num_steps = 20 if model == "stabilityai/stable-diffusion-xl-base-1.0" else 4
config = RunConfig(
    num_inference_steps=num_steps,
    num_inversion_steps=num_steps,
    guidance_scale=0.0,
    inversion_max_step=0.6,
)

In [None]:
# Inspect the dtype of the inversion pipeline's second text encoder.
pipe_inversion.text_encoder_2.dtype

In [10]:
# Set up the SAE intervention: for every hooked UNet block load its sparse
# autoencoder and the `mean.pt` tensor stored next to it.
from SAE import SparseAutoencoder
dtype = torch.float32

path_to_checkpoints = './checkpoints/'

# Short code -> hook position inside the pipeline.
code_to_block = {
    "down.2.1": "unet.down_blocks.2.attentions.1",
    "mid.0": "unet.mid_block.attentions.0",
    "up.0.1": "unet.up_blocks.0.attentions.1",
    "up.0.0": "unet.up_blocks.0.attentions.0"
}

saes_dict = {}
means_dict = {}

for code, block in code_to_block.items():
    # Both artifacts live in the same "<block>_<hyperparameters>/final" directory.
    final_dir = os.path.join(
        path_to_checkpoints, f"{block}_k10_hidden5120_auxk256_bs4096_lr0.0001", "final"
    )
    sae = SparseAutoencoder.load_from_disk(final_dir)
    means = torch.load(os.path.join(final_dir, "mean.pt"), weights_only=True)
    saes_dict[code] = sae.to('cuda', dtype=dtype)
    means_dict[code] = means.to('cuda', dtype=dtype)

In [None]:
# Inspect the shape of the means tensor loaded for the down.2.1 block.
means_dict["down.2.1"].shape

In [None]:
# Inspect the dtype currently in use.
dtype

# Back to the inversion part of the notebook.

In [None]:
# Display handle with a fixed display_id; the cells below render their results through it.
h = display(display_id='my-display')
# Path of the image to invert.
input_image = "resourses/chris.png"
# Prompt describing the input image; ImageEditorDemo passes it to the inversion pipeline.
description_prompt = 'a guy standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background'
# Constructing the editor already runs the inversion of the input image.
editor = ImageEditorDemo(pipe_inversion, pipe_inference, input_image, description_prompt, config, edit_cfg=1.2) #1.2

In [14]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# Text box pre-filled with the description prompt.
text_input = widgets.Text(
        value=description_prompt,
        description="Prompt:",
        style={"description_width": "initial"},
        layout=widgets.Layout(width='70%'), 
    )

def f(x):
    """Run `editor.edit` on the text box's current value and update display `h`; `x` is unused."""
    # NOTE(review): nothing in the visible cells registers `f` on `text_input`
    # or displays the widget -- confirm whether it is meant to be wired up.
    h.update(editor.edit(text_input.value))


In [None]:
# Prompt-only edit (no hooks): adds "with a mustache" to the description prompt.
h.display(editor.edit("a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"))

In [None]:
# Global steering at down.2.1: `add_feature` receives the SAE, the index 2301 and
# +10 x the scalar mean of means_dict["down.2.1"].
mustache_prompt = "a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def steer_hook(*args, **kwargs):
    return add_feature(
        saes_dict["down.2.1"],
        2301,
        10 * means_dict["down.2.1"].mean(),
        *args, **kwargs)


position_hook_dict = {"unet.down_blocks.2.attentions.1": steer_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict, guidance_scale=0.0))

In [None]:
# Global steering at down.2.1: `add_feature` receives the SAE, the index 2301 and
# -10 x the scalar mean of means_dict["down.2.1"].
mustache_prompt = "a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def steer_hook(*args, **kwargs):
    return add_feature(
        saes_dict["down.2.1"],
        2301,
        -10 * means_dict["down.2.1"].mean(),
        *args, **kwargs)


position_hook_dict = {"unet.down_blocks.2.attentions.1": steer_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict, guidance_scale=0.0))

In [None]:
# Global steering at down.2.1: `add_feature` receives the SAE, the index 4998 and
# +10 x the scalar mean of means_dict["down.2.1"].
mustache_prompt = "a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def steer_hook(*args, **kwargs):
    return add_feature(
        saes_dict["down.2.1"],
        4998,
        10 * means_dict["down.2.1"].mean(),
        *args, **kwargs)


position_hook_dict = {"unet.down_blocks.2.attentions.1": steer_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict, guidance_scale=0.0))

In [None]:
# Global steering at down.2.1: `add_feature` receives the SAE, the index 4998 and
# -10 x the scalar mean of means_dict["down.2.1"].
mustache_prompt = "a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def steer_hook(*args, **kwargs):
    return add_feature(
        saes_dict["down.2.1"],
        4998,
        -10 * means_dict["down.2.1"].mean(),
        *args, **kwargs)


position_hook_dict = {"unet.down_blocks.2.attentions.1": steer_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict, guidance_scale=0.0))

In [None]:
# Global steering at up.0.1: `add_feature` receives the SAE, the index 90 and
# +15 x the scalar mean of means_dict["up.0.1"].
mustache_prompt = "a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def steer_hook(*args, **kwargs):
    return add_feature(
        saes_dict["up.0.1"],
        90,
        15 * means_dict["up.0.1"].mean(),
        *args, **kwargs)


position_hook_dict = {code_to_block["up.0.1"]: steer_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict, guidance_scale=0.0))

In [None]:
# Same hook position and index as the +15 run on up.0.1, but the value handed to
# `add_feature` is multiplied by 0.
mustache_prompt = "a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def steer_hook(*args, **kwargs):
    return add_feature(
        saes_dict["up.0.1"],
        90,
        0 * means_dict["up.0.1"].mean(),
        *args, **kwargs)


position_hook_dict = {code_to_block["up.0.1"]: steer_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict, guidance_scale=0.0))

In [None]:
# Global steering at up.0.0: `add_feature` receives the SAE, the index 4594 and
# -8 x the scalar mean of means_dict["up.0.0"].
mustache_prompt = "a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def steer_hook(*args, **kwargs):
    return add_feature(
        saes_dict["up.0.0"],
        4594,
        -8 * means_dict["up.0.0"].mean(),
        *args, **kwargs)


position_hook_dict = {code_to_block["up.0.0"]: steer_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict, guidance_scale=0.0))

In [None]:
# Global steering at up.0.0: `add_feature` receives the SAE, the index 4594 and
# +8 x the scalar mean of means_dict["up.0.0"].
mustache_prompt = "a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def steer_hook(*args, **kwargs):
    return add_feature(
        saes_dict["up.0.0"],
        4594,
        8 * means_dict["up.0.0"].mean(),
        *args, **kwargs)


position_hook_dict = {code_to_block["up.0.0"]: steer_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict, guidance_scale=0.0))

In [None]:
# Build a spatial mask: 32x32 for the SDXL base model, 16x16 otherwise.
# Feature indices: 90 fur, 3718 giraffe

def myhook(module, inputs, outputs):
    """Debug forward hook: print the shape of the first input and of the first output.

    `outputs` is returned unchanged; `module` is not used.
    """
    for tensors in (inputs, outputs):
        print(tensors[0].shape)
    return outputs

# The mask side and the selected region depend on the checkpoint.
if model == "stabilityai/stable-diffusion-xl-base-1.0":
    mask_side, first_row, mask_cols = 32, 6, slice(8, 24)
else:
    mask_side, first_row, mask_cols = 16, 3, slice(4, 12)
mask = torch.zeros((1, 1, mask_side, mask_side))
mask[:, :, first_row:, mask_cols] = 1

block = "up.0.1"
fidx = 90
strength = 12

mustache_prompt = "a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def turbo_area_hook(*args, **kwargs):
    # `block`, `fidx`, `mask` and `strength` are read when the hook fires.
    return add_feature_on_area_turbo(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


# The up.0.0 block additionally gets the shape-printing debug hook.
position_hook_dict = {
    code_to_block[block]: turbo_area_hook,
    code_to_block["up.0.0"]: myhook,
}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict, guidance_scale=0.0))

In [None]:
# let's create a 16x16 mask
# up.0.1: 90 fur, 3718 giraffe, 4977 tiger, 1393 leopard, 4197 green

def myhook(module, inputs, outputs):
    """Debug forward hook: print the shape of the first input and of the first output.

    `outputs` is returned unchanged; `module` is not used. The hook is defined
    here but not referenced by the rest of this cell.
    """
    for tensors in (inputs, outputs):
        print(tensors[0].shape)
    return outputs


# 16x16 mask selecting rows 10-12, columns 5-10.
mask = torch.zeros((1, 1, 16, 16))
# alternative region: mask[:, :, 3:, 4:12] = 1
mask[:, :, 10:13, 5:11] = 1
block = "up.0.0"
fidx = 4594  # moustache feature # 3742 beard?
strength = 10

plain_prompt = "a guy standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"
mustache_prompt = "a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def turbo_area_hook(*args, **kwargs):
    # `block`, `fidx`, `mask` and `strength` are read when the hook fires.
    return add_feature_on_area_turbo(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


def base_area_hook(*args, **kwargs):
    return add_feature_on_area_base(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


# 1) masked intervention through the "turbo" helper, guidance scale 0
position_hook_dict = {code_to_block[block]: turbo_area_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(plain_prompt, position_hook_dict, guidance_scale=0.0))

# 2) masked intervention through the "base" helper, editor's default guidance scale
position_hook_dict = {code_to_block[block]: base_area_hook}
h.display(editor.edit_with_hooks(plain_prompt, position_hook_dict))

# 3) prompt-only edit for comparison
h.display(editor.edit(mustache_prompt))

In [None]:
# 16x16 mask selecting rows 7-8, columns 4-11.
# up.0.1: 90 fur, 3718 giraffe, 4977 tiger, 1393 leopard, 4197 green

mask = torch.zeros((1, 1, 16, 16))
# alternative region: mask[:, :, 3:, 4:12] = 1
mask[:, :, 7:9, 4:12] = 1
block = "up.0.0"
fidx = 2638  # sunglasses # 3742 beard?
strength = 10

plain_prompt = "a guy standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"
wearing_prompt = "a guy wearing standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background."
sunglasses_prompt = "a guy with dark black sunglasses standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def turbo_area_hook(*args, **kwargs):
    # `block`, `fidx`, `mask` and `strength` are read when the hook fires.
    return add_feature_on_area_turbo(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


def base_area_hook(*args, **kwargs):
    return add_feature_on_area_base(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


# 1) masked intervention through the "turbo" helper, guidance scale 0
position_hook_dict = {code_to_block[block]: turbo_area_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(plain_prompt, position_hook_dict, guidance_scale=0.0))

# 2) masked intervention through the "base" helper, editor's default guidance scale
position_hook_dict = {code_to_block[block]: base_area_hook}
h.display(editor.edit_with_hooks(wearing_prompt, position_hook_dict))

# 3) prompt-only edit for comparison
h.display(editor.edit(sunglasses_prompt))

In [None]:
# 16x16 mask selecting rows 10-13, columns 5-10.
# up.0.1: 90 fur, 3718 giraffe, 4977 tiger, 1393 leopard, 4197 green

mask = torch.zeros((1, 1, 16, 16))
# alternative region: mask[:, :, 3:, 4:12] = 1
mask[:, :, 10:14, 5:11] = 1
block = "up.0.0"
fidx = 2937  # 4161 # smile # 5048 # sad face # 2937 # shouting # 3742 beard?
strength = 10

plain_prompt = "a guy standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"
while_prompt = "a guy while standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"
shouting_prompt = "a guy shouting at the camera while standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def turbo_area_hook(*args, **kwargs):
    # `block`, `fidx`, `mask` and `strength` are read when the hook fires.
    return add_feature_on_area_turbo(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


def base_area_hook(*args, **kwargs):
    return add_feature_on_area_base(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


# 1) masked intervention through the "turbo" helper, guidance scale 0
position_hook_dict = {code_to_block[block]: turbo_area_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(plain_prompt, position_hook_dict, guidance_scale=0.0))

# 2) masked intervention through the "base" helper, editor's default guidance scale
position_hook_dict = {code_to_block[block]: base_area_hook}
h.display(editor.edit_with_hooks(while_prompt, position_hook_dict))

# 3) prompt-only edit for comparison, with an explicit guidance scale of 1.2
h.display(editor.edit(shouting_prompt, guidance_scale=1.2))

In [None]:
# 16x16 mask selecting rows 6-7, columns 4-11.
# up.0.1: 90 fur, 3718 giraffe, 4977 tiger, 1393 leopard, 4197 green

mask = torch.zeros((1, 1, 16, 16))
# alternative region: mask[:, :, 3:, 4:12] = 1
mask[:, :, 6:8, 4:12] = 1
block = "up.0.1"
fidx = 90
strength = 15

plain_prompt = "a guy standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"
wearing_prompt = "a guy wearing standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background."
sunglasses_prompt = "a guy with dark black sunglasses standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def turbo_area_hook(*args, **kwargs):
    # `block`, `fidx`, `mask` and `strength` are read when the hook fires.
    return add_feature_on_area_turbo(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


def base_area_hook(*args, **kwargs):
    return add_feature_on_area_base(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


# 1) masked intervention through the "turbo" helper, guidance scale 0
position_hook_dict = {code_to_block[block]: turbo_area_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(plain_prompt, position_hook_dict, guidance_scale=0.0))

# 2) masked intervention through the "base" helper, editor's default guidance scale
position_hook_dict = {code_to_block[block]: base_area_hook}
h.display(editor.edit_with_hooks(wearing_prompt, position_hook_dict))

# 3) prompt-only edit for comparison
h.display(editor.edit(sunglasses_prompt))

In [None]:
# 16x16 mask selecting rows 3-15, columns 4-11.
# 90 fur, 3718 giraffe, 4977 tiger, 1393 leopard

mask = torch.zeros((1, 1, 16, 16))
mask[:, :, 3:, 4:12] = 1
block = "down.2.1"
fidx = 527  # 2301 evil # 89 muscular # 4074 anime # 179 horse
strength = 25

mustache_prompt = "a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"
comparison_prompt = "a black guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def turbo_area_hook(*args, **kwargs):
    # `block`, `fidx`, `mask` and `strength` are read when the hook fires.
    return add_feature_on_area_turbo(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


def base_area_hook(*args, **kwargs):
    return add_feature_on_area_base(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


# 1) masked intervention through the "turbo" helper, guidance scale 0
position_hook_dict = {code_to_block[block]: turbo_area_hook}
# this type of intervention should be adapted before using it with guidance scale
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict, guidance_scale=0.0))

# 2) masked intervention through the "base" helper, editor's default guidance scale
position_hook_dict = {code_to_block[block]: base_area_hook}
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict))

# 3) prompt-only edit for comparison
h.display(editor.edit(comparison_prompt))

In [None]:
# 16x16 mask selecting rows 3-15, columns 4-11.
# 90 fur, 3718 giraffe, 4977 tiger, 1393 leopard

mask = torch.zeros((1, 1, 16, 16))
mask[:, :, 3:, 4:12] = 1
block = "down.2.1"
fidx = 349  # kid # 527 # black # 2301 evil # 89 muscular # 4074 anime # 179 horse
strength = 9

mustache_prompt = "a guy with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"
kid_prompt = "a photo of a kid with a mustache standing on a mountain in Hokaidoo with brown hair and a blue hiking shirt, there is a city in the background"


def turbo_area_hook(*args, **kwargs):
    # `block`, `fidx`, `mask` and `strength` are read when the hook fires.
    return add_feature_on_area_turbo(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


def base_area_hook(*args, **kwargs):
    return add_feature_on_area_base(
        saes_dict[block],
        fidx,
        mask.cuda() * strength * means_dict[block][fidx],
        *args, **kwargs)


# 1) masked intervention through the "turbo" helper
# NOTE(review): unlike the other masked cells, no guidance_scale=0.0 is passed
# here, so the editor's default guidance scale applies -- confirm this is intended.
position_hook_dict = {code_to_block[block]: turbo_area_hook}
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict))

# 2) masked intervention through the "base" helper, editor's default guidance scale
position_hook_dict = {code_to_block[block]: base_area_hook}
h.display(editor.edit_with_hooks(mustache_prompt, position_hook_dict))

# 3) prompt-only edit for comparison
h.display(editor.edit(kid_prompt))