# Setup

In [None]:
try:
    import google.colab
    !pip install diffusers==0.30.0 transformers accelerate scipy omegaconf dotenv loguru ipywidgets
except:
    pass

In [None]:
import os
import torch
import urllib.request

from IPython.display import display
from transformers import pipeline
from diffusers import StableDiffusionPipeline
from pathlib import Path

In [None]:
root_dir = Path(os.getcwd()).parent
try:
    import google.colab
    !git clone https://github.com/tweks/sae-sd.git
    root_dir = Path(os.path.join(os.getcwd(), 'sae-sd'))
except:
    pass

In [None]:
model_url = 'https://github.com/tweks/sae-sd/releases/download/model-v0.0.1/12288_768_TopKReLU_64_False_False_0.0_CC3M_15_train_target_223758458_768.pt'
model_path = root_dir/'model'/'data'/os.path.basename(model_url)
if not model_path.exists():
    model_path.parent.mkdir(parents=True, exist_ok=True)
    # Download to a temporary ".part" file and rename only once it completes, so an
    # interrupted download never leaves a truncated checkpoint that the exists()
    # check above would accept on the next run.
    partial_path = model_path.with_name(model_path.name + '.part')
    urllib.request.urlretrieve(model_url, partial_path)
    partial_path.replace(model_path)

In [None]:
# Run on the GPU when one is available; weights are loaded in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float32,
).to(device)

In [None]:
# Zero-shot CLIP classifier, used further down to score original vs. edited images.
clip = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-base-patch32",
    torch_dtype=torch.float32,
    device=device,
)

# Standard SD

In [None]:
# Default generation seed; Colab runs use a different default.
DEFAULT_SEED = 0
try:
    import google.colab
    DEFAULT_SEED = 26
except ImportError:
    pass

In [None]:
def set_seed(seed=DEFAULT_SEED):
    """Seed torch's global RNG and the RNGs of all CUDA devices.

    Args:
        seed: Seed value; defaults to ``DEFAULT_SEED`` as it was when this cell ran.
    """
    print(f'Setting seed to {seed}')
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

In [None]:
set_seed(seed=DEFAULT_SEED)

In [None]:
prompt = 'a photo of an astronaut riding a horse on mars'

In [None]:
# Keep only the first element returned by encode_prompt (the prompt embedding).
prompt_embed = pipe.encode_prompt(
    prompt,
    device=device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)[0]
prompt_embed.shape

In [None]:
def generate_image(prompt, seed=DEFAULT_SEED):
    """Generate one image for ``prompt`` using a generator seeded with ``seed``.

    The prompt is encoded explicitly and passed as ``prompt_embeds``, the same
    entry point generate_modified_image uses for its edited embeddings.

    Returns: the first image produced by the pipeline.
    """
    generator = torch.Generator(device).manual_seed(seed)
    embeds = pipe.encode_prompt(prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=True)[0]
    output = pipe(prompt_embeds=embeds, generator=generator)
    return output.images[0]

## Other seeds

In [None]:
def generate_images(prompt, seed_start, seed_end):
    """Display one generated image per seed in the inclusive range [seed_start, seed_end].

    The global torch seed is set to each seed before generating. It is reset to
    the default afterwards even if generation fails or is interrupted, so later
    cells do not inherit leftover global RNG state.
    """
    print(f'Generating images for seeds [{seed_start}, {seed_end}]')
    try:
        for seed in range(seed_start, seed_end + 1):
            print(f'Seed: {seed}')
            set_seed(seed)
            display(generate_image(prompt, seed=seed))
    finally:
        set_seed()

In [None]:
#for seed in range(20, 31):
#    print(f"Seed: {seed}")
#    set_seed(seed)
#    display(generate_image(prompt, seed=seed))
#set_seed()

# Modified SD

## Setup

In [None]:
import sys
# Make the repository root importable so the local `model` package resolves.
# The membership check keeps re-runs from appending duplicate sys.path entries.
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))
import json
import numpy as np
# This import relies on the sys.path entry added above.
from model.models import load_model

In [None]:
# Load the SAE checkpoint fetched in the setup section. Reusing `model_path`
# keeps this in sync with `model_url` instead of repeating the filename here.
model, dataset_normalize, dataset_target_norm, dataset_mean = load_model(str(model_path))
model.to(device)
model.eval()
model, dataset_normalize, dataset_target_norm, dataset_mean.shape

In [None]:
# Dataset metadata saved alongside the SAE; the context manager closes the file.
with open(root_dir/"model/data/CC3M_15_train_target_sds_train_dataset_metadata.json", encoding="utf-8") as metadata_file:
    ds_info = json.load(metadata_file)
ds_info.keys()

In [None]:
# Globals read by process_data / invert_preprocess.
dataset_scaling_factor = ds_info["scaling_factor"]
# With lenses = None no per-position lens is applied, and seq_id / seq_len are
# never consulted.
lenses, seq_id, seq_len = None, 0, 0
dataset_scaling_factor

In [None]:
def process_data(data: np.ndarray | torch.Tensor, idx: int | None=None) -> torch.Tensor:
    """Process data into the correct format."""
    X = data.to(torch.float32)
    X = X.sub(dataset_mean)
    X = X.mul(dataset_scaling_factor)

    if lenses is not None and idx is not None:
        current_seq_id = idx % seq_len
        if current_seq_id != seq_id:
            lens = lenses[current_seq_id]
            X = lens(X)

    return X

def invert_preprocess(data: torch.Tensor, idx: int | None=None) -> torch.Tensor:
    """Inverse process data."""
    if lenses is not None and idx is not None:
        current_seq_id = idx % seq_len
        if current_seq_id != seq_id:
            lens = lenses[current_seq_id]
            X = lens.invert(data)
        else:
            X = data
    else:
        X = data

    X = X.div(dataset_scaling_factor)
    X = X.add(dataset_mean)

    return X

## Image

In [None]:
# Drop the leading batch dimension (no-op unless it has size 1).
prompt_embed_to_sae = torch.squeeze(prompt_embed, 0)
prompt_embed_to_sae.shape

In [None]:
# Sanity check: process_data followed by invert_preprocess should give back the
# original embedding up to floating-point error.
prompt_embed_to_sae_pre = process_data(prompt_embed_to_sae)
prompt_embed_to_sae_post = invert_preprocess(prompt_embed_to_sae_pre)
prompt_embed_to_sae_pre.shape, prompt_embed_to_sae_post.shape, torch.allclose(prompt_embed_to_sae, prompt_embed_to_sae_post, atol=1e-5)

# Concepts

In [None]:
def get_tokens(prompt, padding='max_length'):
    """Return the pipeline tokenizer's token strings for ``prompt``.

    Args:
        prompt: Text to tokenize.
        padding: Forwarded to ``tokenizer.encode``. ``'max_length'`` pads to the
            77-token length; ``False`` returns only the real tokens.

    Returns: list of token strings, truncated to at most 77 ids.
    """
    tokenizer = pipe.tokenizer
    token_ids = tokenizer.encode(prompt, max_length=77, truncation=True, padding=padding)
    return tokenizer.convert_ids_to_tokens(token_ids)

In [None]:
# One concept per line. Every line is kept, so list positions match the file.
# An explicit encoding avoids platform-dependent decoding of non-ASCII concepts.
with open(root_dir/"model/data/clip_disect_20k.txt", encoding="utf-8") as concepts_file:
    concepts = [line.strip() for line in concepts_file]

In [None]:
def get_sae_representations(prompt):
    """Return the SAE latents at three token positions of ``prompt``.

    The rows are, in order:
    - position 0 (presumably the BOS token);
    - the last real prompt token;
    - the position right after it (presumably EOS).

    Returns: a tensor with one row per selected position.
    """
    # Count tokens with truncation to 77, matching the padded 77-position encoding
    # used elsewhere in this notebook. Without it, prompts longer than the context
    # window produced indices past the last latent row.
    num_tokens = len(pipe.tokenizer.encode(prompt, max_length=77, truncation=True))
    embeds = pipe.encode_prompt(prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=True)[0]
    embeds = embeds.squeeze(0)
    with torch.no_grad():
        _, latents, _ = model.encode(process_data(embeds))
    return latents[[0, num_tokens - 2, num_tokens - 1]]

In [None]:
def get_sae_representations_batch(prompts):
    """Stack get_sae_representations(prompt) for each prompt along a new leading axis."""
    per_prompt = [get_sae_representations(p) for p in prompts]
    return torch.stack(per_prompt, dim=0)

In [None]:
get_sae_representations(prompt='horse')

In [None]:
get_sae_representations_batch(prompts=concepts[0:20])

## New

In [None]:
get_tokens('horse', padding=False)

In [None]:
get_tokens('horse', padding='max_length')

In [None]:
def _active_latents(latent):
    """Map each nonzero latent index of a 1-D ``latent`` to its value formatted '%.4f', highest value first."""
    # as_tuple=True always yields a 1-D index tensor. `.nonzero().squeeze()`
    # collapses to 0-d when exactly one latent is active, which cannot then be
    # indexed by `order` below.
    (active_ids,) = latent.nonzero(as_tuple=True)
    active_values = latent[active_ids]
    sorted_values, order = torch.sort(active_values, descending=True)
    return {latent_id.item(): f"{value.item():.4f}" for latent_id, value in zip(active_ids[order], sorted_values)}


def get_latent_ids(prompt, num_padding=0):
    """Print the active SAE latents for each token position of ``prompt``.

    Prints one line per position: the token string, left-justified to the longest
    token, followed by ``{latent_id: value}`` for every nonzero latent, highest
    value first.

    Args:
        prompt: Text to encode.
        num_padding: Extra positions past the real tokens to print as well; at
            most 77 positions are printed in total.
    """
    num_tokens = len(pipe.tokenizer.encode(prompt))
    tokens = get_tokens(prompt)
    max_tok_len = max(len(t) for t in tokens)
    embeds = pipe.encode_prompt(prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=True)[0]
    embeds = embeds.squeeze(0)
    with torch.no_grad():
        _, latents, _ = model.encode(process_data(embeds))
    num_print = min(77, num_tokens + num_padding)
    for tok, latent in zip(tokens[:num_print], latents[:num_print]):
        print(f'{tok}'.ljust(max_tok_len), _active_latents(latent))

In [None]:
get_latent_ids('horse', num_padding=0)

In [None]:
get_latent_ids('horse', num_padding=77)

In [None]:
get_latent_ids(prompt=prompt)

In [None]:
def generate_modified_image(prompt, sae_latent_ids, scale, seed=DEFAULT_SEED):
    """Generate an image after scaling selected SAE latents of the prompt embedding.

    The prompt embedding goes through the SAE. The chosen latents are multiplied
    by ``scale`` at every token position and the result is decoded. The SAE's
    reconstruction error (original minus unmodified reconstruction) is added back,
    so the only difference from the original embedding is the latent edit.

    Args:
        prompt: Text prompt.
        sae_latent_ids: Iterable of latent indices to scale, or a single int.
            A repeated id is scaled once per occurrence.
        scale: Multiplier for the selected latents (0 zeroes them).
        seed: Seed for the diffusion generator.

    Returns: the first image produced by the pipeline.
    """
    if isinstance(sae_latent_ids, int):
        sae_latent_ids = [sae_latent_ids]

    # All preparation is inference-only. Running it under no_grad avoids any
    # autograd graph that encode_prompt and the `diff` subtraction might otherwise
    # record (they sat outside the no_grad scope before).
    with torch.no_grad():
        prompt_embed = pipe.encode_prompt(prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=True)[0]
        prompt_embed_to_sae = prompt_embed.squeeze(0)
        _, sae_latents, info = model.encode(process_data(prompt_embed_to_sae))
        reconstructed = model.decode(sae_latents, info)
        for sae_latent_id in sae_latent_ids:
            sae_latents[:, sae_latent_id] *= scale
        reconstructed_modified = model.decode(sae_latents, info)

        diff = prompt_embed_to_sae - invert_preprocess(reconstructed)
        modified_embed = invert_preprocess(reconstructed_modified) + diff

    generator = torch.Generator(device).manual_seed(seed)
    return pipe(prompt_embeds=modified_embed.unsqueeze(0), generator=generator).images[0]

In [None]:
def generate_modified_images(prompt, sae_latent_ids, scale, seed=DEFAULT_SEED):
    """Print ``prompt``, then display one image per latent id, each id edited on its own."""
    print(prompt)
    for sae_latent_id in sae_latent_ids:
        print(f'SAE latent id: {sae_latent_id}, scale: {scale}')
        image = generate_modified_image(prompt, [sae_latent_id], scale, seed=seed)
        display(image)

In [None]:
original_image = generate_image(prompt, seed=DEFAULT_SEED)
original_image

In [None]:
#generate_modified_images(prompt, [12114, 5722, 1962, 3079, 2461, 7928, 5482, 3791], 0)

In [None]:
generate_modified_image(prompt, sae_latent_ids=[9515], scale=0)

In [None]:
generate_modified_image(prompt, sae_latent_ids=[12214], scale=0)

In [None]:
generate_modified_image(prompt, sae_latent_ids=[1825], scale=0)

In [None]:
# NOTE(review): identical to the [12214] cell a few cells above; possibly 12114 was
# intended (12114 appears in the commented id list and in the [12114, 9515] cell).
generate_modified_image(prompt, [12214], 0)

In [None]:
generate_modified_image(prompt, sae_latent_ids=[12114, 9515], scale=0)

In [None]:
generate_modified_image(prompt, sae_latent_ids=[5722], scale=2)

In [None]:
# CLIP zero-shot check for the horse edit. The edited image is generated here
# because `image_from_embed` is only assigned further down (green section), so
# this cell raised NameError on a fresh kernel.
# NOTE(review): [12114, 9515] is the combined edit tried above — confirm this is
# the edit meant to be evaluated.
horse_removed_image = generate_modified_image(prompt, [12114, 9515], 0)
display(original_image, horse_removed_image)
labels = ["an image containing a horse", "an image without a horse"]
original_predictions = clip(original_image, candidate_labels=labels)
modified_predictions = clip(horse_removed_image, candidate_labels=labels)
print("Original image predictions:", original_predictions)
print("Modified image predictions:", modified_predictions)

# SD comparison

In [None]:
generate_image('a photo of an astronaut on mars', seed=DEFAULT_SEED)

In [None]:
generate_image('a photo of an astronaut not riding a horse on mars', seed=DEFAULT_SEED)

In [None]:
generate_image('a photo of an astronaut riding a on mars', seed=DEFAULT_SEED)

# Green

In [None]:
# Inspect which SAE latents are active for the single concept "green".
idx = concepts.index("green")  # raises ValueError if "green" is not in the list
a = pipe.encode_prompt(concepts[idx], device=device, num_images_per_prompt=1, do_classifier_free_guidance=True)[0]
a = a.squeeze(0)
a_proc = process_data(a)
with torch.no_grad():
    _, latents, _ = model.encode(a_proc)
# NOTE(review): positions 1..9 are printed — presumably the "green" token, EOS and
# the first padding positions; confirm the intended range.
for latent in latents[1:10]:
    # as_tuple=True keeps a 1-D index tensor even when exactly one latent is active
    # (`.nonzero().squeeze()` would collapse it to 0-d and break the indexing below).
    (nonzero_indices,) = latent.nonzero(as_tuple=True)
    nonzero_values = latent[nonzero_indices]
    sorted_values, sort_indices = torch.sort(nonzero_values, descending=True)
    sorted_indices = nonzero_indices[sort_indices]
    print("Sorted values:", {latent_id.item(): f"{value.item():.4f}" for latent_id, value in zip(sorted_indices, sorted_values)})
print("----------------------------------")


In [None]:
# Edit the astronaut prompt embedding: set SAE latent GREEN_LATENT_ID (presumably
# the "green" latent found above) to a fixed activation at every token position.
# Decode the result and add back the reconstruction error, so the latent edit is
# the only difference from the original embedding.
GREEN_LATENT_ID = 2649
GREEN_LATENT_VALUE = 9
with torch.no_grad():
    _, sae_latents, info = model.encode(prompt_embed_to_sae_pre)
    prompt_embed_to_sae_reconstructed = model.decode(sae_latents, info)
    print("Before modification", sae_latents[:, GREEN_LATENT_ID])
    sae_latents[:, GREEN_LATENT_ID] = GREEN_LATENT_VALUE
    prompt_embed_to_sae_reconstructed_modified = model.decode(sae_latents, info)

prompt_embed_to_sae_reconstructed_post = invert_preprocess(prompt_embed_to_sae_reconstructed)
diff = prompt_embed_to_sae - prompt_embed_to_sae_reconstructed_post
prompt_embed_to_sae_reconstructed_post_modified = invert_preprocess(prompt_embed_to_sae_reconstructed_modified)

# No global `seed` exists in this notebook; use the same default seed as
# generate_image so the result is comparable with `original_image`.
generator = torch.Generator(device).manual_seed(DEFAULT_SEED)
image_from_embed = pipe(prompt_embeds=(prompt_embed_to_sae_reconstructed_post_modified + diff).unsqueeze(0), generator=generator).images[0]
image_from_embed

In [None]:
# CLIP zero-shot check for the green edit: show both images, then score each
# against the two labels.
display(original_image, image_from_embed)
labels = ["an image containing green", "an image without green"]
original_predictions = clip(original_image, candidate_labels=labels)
modified_predictions = clip(image_from_embed, candidate_labels=labels)
print("Original image predictions:", original_predictions)
print("Modified image predictions:", modified_predictions)

SAE latent ids that are not useful when interpreting the model:

11114: associated with padding tokens

3678: associated with the beginning-of-sentence token