In [None]:
import torch
import pandas as pd
import numpy as np
import math

import textwrap

from typing import Optional
from diffusers import StableDiffusionPipeline
from huggingface_hub import notebook_login

import matplotlib.pyplot as plt
from PIL import Image

# Interactive Hugging Face Hub login widget.
# NOTE(review): presumably required so the model weights can be downloaded below — confirm.
notebook_login()

In [24]:
# Load the Stable Diffusion pipeline with float16 weights and move it to the GPU.
# `pipe` is the notebook-wide pipeline object used by every cell below.
# The commented-out model id on the next line is an alternative checkpoint kept for reference.
pipe = StableDiffusionPipeline.from_pretrained(
    "Manojb/stable-diffusion-2-1-base", # "CompVis/stable-diffusion-v1-4",
    safety_checker=None,  # load the pipeline without a safety checker component
    torch_dtype=torch.float16
).to("cuda")

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [25]:
def collect_residual_streams(
    pipe: StableDiffusionPipeline,
    forget_set: list[str],
    retain_set: list[str],
    steps: int = 30,
    guidance: float = 7.5,
    from_timestamp: int = 25
):
    """Collect UNet resnet activations for paired forget/retain prompts.

    The two prompt lists are walked in lockstep (pairing stops at the shorter
    list). For every pair the forget prompt is processed first, then the
    retain prompt, each through `get_unet_residual_stream`.

    Args:
        pipe: pipeline forwarded to `get_unet_residual_stream`.
        forget_set: prompts containing the concept to forget.
        retain_set: matching prompts without that concept.
        steps: number of inference steps per generation.
        guidance: guidance scale per generation.
        from_timestamp: first recorded step index kept for each layer.

    Returns:
        `(forget_acts, retain_acts)`: two lists with one entry per prompt
        pair, each entry being the value returned by
        `get_unet_residual_stream` for that prompt.
    """
    forget_acts, retain_acts = [], []

    pair_no = 0
    for forget_prompt, retain_prompt in zip(forget_set, retain_set):
        pair_no += 1

        print(f'[{pair_no}] Extracting acts for forget prompt: {forget_prompt}')
        forget_acts.append(
            get_unet_residual_stream(pipe, forget_prompt, steps, guidance, from_timestamp)
        )

        print(f'[{pair_no}] Extracting acts for retain prompt: {retain_prompt}')
        retain_acts.append(
            get_unet_residual_stream(pipe, retain_prompt, steps, guidance, from_timestamp)
        )

    return forget_acts, retain_acts


def get_unet_residual_stream(
    pipe: StableDiffusionPipeline,
    prompt: str,
    steps: int = 30,
    guidance: float = 7.5,
    from_timestamp: int = 25
):
    """Run one generation and record the output of every UNet resnet block.

    Deliberately processes a single prompt per call: batching several prompts
    would mix their activations in the recorded tensors.

    A forward hook is attached to each resnet in `pipe.unet.down_blocks`,
    `pipe.unet.mid_block` and `pipe.unet.up_blocks`. Every forward pass of a
    resnet appends one activation (moved to CPU) to that layer's list. The
    hooks are always removed before returning, even if generation fails, so
    a failed run cannot leave stale hooks on the shared UNet.

    Args:
        pipe: pipeline whose UNet is instrumented and which is called to generate.
        prompt: the single text prompt to generate from.
        steps: passed to the pipeline as `num_inference_steps`.
        guidance: passed to the pipeline as `guidance_scale`.
        from_timestamp: index of the first recorded forward pass to keep;
            earlier passes are discarded. Must satisfy 0 < from_timestamp < steps.
            NOTE(review): this indexes recorded UNet calls, which is assumed to
            be one per inference step — confirm for the scheduler in use.

    Returns:
        dict mapping layer name (e.g. "down_block_0_resnet_1") to a tensor of
        the kept activations stacked along a new leading dimension,
        i.e. [T, C, H, W] for 4-D resnet outputs.

    Raises:
        AssertionError: if `from_timestamp` is outside (0, steps).
    """
    assert 0 < from_timestamp < steps, "expected 0 < from_timestamp < steps"

    residuals_dict = {}
    handles = []

    def save_residuals(name):
        def hook(module, input, output):
            residual = output[1] if isinstance(output, tuple) else output
            # With classifier-free guidance the UNet batch is assumed to be
            # [unconditional, conditional], so the last row is the conditioned
            # one (same as index 1). Using -1 also works when guidance is off
            # and the batch holds only the conditioned sample.
            residuals_dict.setdefault(name, []).append(residual[-1].detach().cpu())

        return hook

    # Enumerate every resnet together with the name its activations are stored under.
    named_resnets = []
    for i, block in enumerate(pipe.unet.down_blocks):
        for j, resnet in enumerate(block.resnets):
            named_resnets.append((f"down_block_{i}_resnet_{j}", resnet))

    for j, resnet in enumerate(pipe.unet.mid_block.resnets):
        named_resnets.append((f"mid_block_resnet_{j}", resnet))

    for i, block in enumerate(pipe.unet.up_blocks):
        for j, resnet in enumerate(block.resnets):
            named_resnets.append((f"up_block_{i}_resnet_{j}", resnet))

    try:
        for name, resnet in named_resnets:
            handles.append(resnet.register_forward_hook(save_residuals(name)))

        pipe(
            prompt,
            num_inference_steps=steps,
            guidance_scale=guidance
        )
    finally:
        # Always detach the hooks; otherwise a failed run would leave them on
        # the UNet and every later call would record duplicate activations.
        for h in handles:
            h.remove()

    residuals_by_timestep = {
        layer: torch.stack(tensors, dim=0)[from_timestamp:]
        for layer, tensors in residuals_dict.items()
    }

    return residuals_by_timestep # each value: [T, C, H, W]

def show_images(images: list[Image.Image], prompts: list[str], cols: int = 2, width: int = 40) -> None:
    """Display images in a grid, each captioned with its wrapped prompt.

    Args:
        images: images to show, in display order (row-major).
        prompts: one caption per image; must have the same length as `images`.
        cols: number of grid columns.
        width: maximum caption line length passed to `textwrap.wrap`.

    Returns:
        None. Does nothing when `images` is empty.

    Raises:
        AssertionError: if `images` and `prompts` differ in length.
    """
    assert len(images) == len(prompts)

    # An empty list would request a grid with zero rows, which plt.subplots rejects.
    if not images:
        return

    rows = math.ceil(len(images) / cols)
    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid.
    _fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
    axes = axes.flatten()

    # Hide the unused cells of the last row.
    for ax in axes[len(images):]:
        ax.axis('off')

    for ax, img, prompt in zip(axes, images, prompts):
        ax.imshow(img)
        ax.axis('off')
        wrapped_prompt = "\n".join(textwrap.wrap(prompt, width=width))
        # Caption is placed just below the image, in axes coordinates.
        ax.text(0.5, -0.05, wrapped_prompt, fontsize=10, ha='center', va='top', transform=ax.transAxes)
        # NOTE(review): ax.axis('off') above appears to suppress spine drawing,
        # so this border styling likely has no visible effect — confirm intent.
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_color('black')
            spine.set_linewidth(1)

    plt.tight_layout()
    plt.show()


def get_layers_activations_at_timestep(forget_acts, retain_acts, layers: list[str], ts_idx: int):
    """Select one timestep of the chosen layers and stack it across prompts.

    Args:
        forget_acts: list with one dict per forget prompt, mapping layer name
            to a tensor indexed by timestep along its first dimension.
        retain_acts: same structure for the retain prompts.
        layers: layer names to extract; each must be a key of every dict.
        ts_idx: index along the timestep dimension (negative values allowed,
            e.g. -1 for the last stored timestep).

    Returns:
        `(forget_layers, retain_layers)`: two dicts mapping each requested
        layer to a tensor whose first dimension indexes the prompts.
    """
    def _stack_layer(per_prompt_acts, layer):
        # One slice per prompt, stacked along a new leading (prompt) dimension.
        return torch.stack([acts[layer][ts_idx] for acts in per_prompt_acts], dim=0)

    forget_layers = {layer: _stack_layer(forget_acts, layer) for layer in layers}
    retain_layers = {layer: _stack_layer(retain_acts, layer) for layer in layers}

    return forget_layers, retain_layers

In [None]:
# Paired prompt dataset: each row holds a 'positive' prompt (with a dog, used
# as the forget set below) and a matching 'negative' prompt (without a dog,
# used as the retain set below).
dogs_dataset = pd.read_csv('/kaggle/input/prompts-steering/dogs.csv')

dog_prompts = dogs_dataset['positive'].tolist()
non_dog_prompts = dogs_dataset['negative'].tolist()

**Generate a bunch of images just to visualize them**

In [None]:
# Generate a small preview grid: for the first `images_no` prompt pairs, render
# the dog prompt and its non-dog counterpart and show them side by side.
batch_size = 2
images_no = 6 # len(dog_prompts)

all_images = []
all_prompts = []

for i in range(0, images_no, batch_size):
    # Same slice of both lists keeps each positive prompt aligned with its negative.
    pos_batch = dog_prompts[i:i+batch_size]
    neg_batch = non_dog_prompts[i:i+batch_size]

    pos_images = pipe(pos_batch, num_inference_steps=30, guidance_scale=8).images
    neg_images = pipe(neg_batch, num_inference_steps=30, guidance_scale=8).images

    # Interleave positive/negative so that, with cols=2, each grid row is one pair.
    for p_img, n_img, p_prompt, n_prompt in zip(pos_images, neg_images, pos_batch, neg_batch):
        all_images.append(p_img)
        all_images.append(n_img)
        all_prompts.append(p_prompt)
        all_prompts.append(n_prompt)

show_images(all_images, all_prompts, cols=2)

**Extract raw activations and print layer names**

In [5]:
#print(pipe.unet)
# Record resnet activations for the first five prompt pairs. With 30 steps and
# from_timestamp=29 only the activations from step index 29 onwards are kept.
forget_acts, retain_acts = collect_residual_streams(
    pipe,
    forget_set=dog_prompts[:5],
    retain_set=non_dog_prompts[:5],
    steps=30,
    guidance=10,
    from_timestamp=29,
)

# Every prompt yields the same layer names, so the first prompt's keys suffice.
layer_names = set(forget_acts[0])

layer_names

[1] Extracting acts for forget prompt: A playful golden retriever running in a sunny park, photorealistic


  0%|          | 0/30 [00:00<?, ?it/s]

[1] Extracting acts for retain prompt: A playful child running in a sunny park, photorealistic


  0%|          | 0/30 [00:00<?, ?it/s]

[2] Extracting acts for forget prompt: A group of dogs playing in the snow, winter scene, cinematic lighting


  0%|          | 0/30 [00:00<?, ?it/s]

[2] Extracting acts for retain prompt: A group of children playing in the snow, winter scene, cinematic lighting


  0%|          | 0/30 [00:00<?, ?it/s]

[3] Extracting acts for forget prompt: Close-up of a dog’s face, detailed fur and expressive eyes, portrait


  0%|          | 0/30 [00:00<?, ?it/s]

[3] Extracting acts for retain prompt: Close-up of a cat’s face, detailed fur and expressive eyes, portrait


  0%|          | 0/30 [00:00<?, ?it/s]

[4] Extracting acts for forget prompt: A dog catching a frisbee in mid-air, dynamic sports shot


  0%|          | 0/30 [00:00<?, ?it/s]

[4] Extracting acts for retain prompt: A boy catching a frisbee in mid-air, dynamic sports shot


  0%|          | 0/30 [00:00<?, ?it/s]

[5] Extracting acts for forget prompt: Watercolor painting of a happy dog in a garden, soft pastel colors


  0%|          | 0/30 [00:00<?, ?it/s]

[5] Extracting acts for retain prompt: Watercolor painting of a flower in a garden, soft pastel colors


  0%|          | 0/30 [00:00<?, ?it/s]

{'down_block_0_resnet_0',
 'down_block_0_resnet_1',
 'down_block_1_resnet_0',
 'down_block_1_resnet_1',
 'down_block_2_resnet_0',
 'down_block_2_resnet_1',
 'down_block_3_resnet_0',
 'down_block_3_resnet_1',
 'mid_block_resnet_0',
 'mid_block_resnet_1',
 'up_block_0_resnet_0',
 'up_block_0_resnet_1',
 'up_block_0_resnet_2',
 'up_block_1_resnet_0',
 'up_block_1_resnet_1',
 'up_block_1_resnet_2',
 'up_block_2_resnet_0',
 'up_block_2_resnet_1',
 'up_block_2_resnet_2',
 'up_block_3_resnet_0',
 'up_block_3_resnet_1',
 'up_block_3_resnet_2'}

**Daniele, insert your code here -> Layer Navigator**
The layer navigator should return a subset of the layers printed above, possibly as a dict with scores for reporting/debugging purposes.

example: layer_navigator(...) -> ['down_block_2_resnet_1', 'mid_block_resnet_0', 'mid_block_resnet_1', 'up_block_2_resnet_2']

In [None]:
# daniele code

In [27]:
# Get timestep specific activations for each specified layer.
# The layer list is hard-coded for now (placeholder until the layer navigator
# above provides the selection). Each resulting tensor is stacked across
# prompts along its first dimension.

forget_layers_act, retain_layers_act = get_layers_activations_at_timestep(
    forget_acts, 
    retain_acts, 
    ['down_block_2_resnet_1', 'mid_block_resnet_0', 'mid_block_resnet_1', 'up_block_2_resnet_2'],
    -1 # we take the last timestep
)

In [28]:
def mean_difference(X, Y, normalize=True):
    """Return the difference between the sample means of X and Y.

    Args:
        X: tensor whose first dimension indexes samples (e.g. [B, C, H, W]).
        Y: tensor with the same trailing shape as X; its sample count may differ.
        normalize: if True, rescale the difference to unit L2 norm.

    Returns:
        Tensor of shape `X.shape[1:]`. When `normalize` is True and the two
        means coincide, the all-zero difference is returned as is instead of
        being divided by zero (which would produce NaNs).
    """
    v = X.mean(dim=0) - Y.mean(dim=0) # [C, H, W]

    if normalize:
        norm = v.norm()
        # Skip the division for a zero vector; comparing against 0 (rather than
        # clamping to a tiny epsilon) also behaves correctly for float16 inputs.
        if norm > 0:
            v = v / norm

    return v


def compute_mean_differences(forget_layers_act, retain_layers_act, normalize=True):
    """Compute one mean-difference steering vector per layer.

    Layers are paired by name: for every layer in `forget_layers_act` the
    matching entry of `retain_layers_act` is looked up by key, so the result
    does not depend on the insertion order of the two dicts.

    Args:
        forget_layers_act: dict mapping layer name to stacked forget activations.
        retain_layers_act: dict mapping layer name to stacked retain activations.
        normalize: forwarded to `mean_difference`.

    Returns:
        dict mapping each layer name to the value of
        `mean_difference(forget, retain, normalize)`.

    Raises:
        KeyError: if a forget layer has no counterpart in `retain_layers_act`.
    """
    return {
        layer: mean_difference(forget_batch, retain_layers_act[layer], normalize)
        for layer, forget_batch in forget_layers_act.items()
    }


def contrastive_pca(X, Y, n_components=300, alpha=0.8):
    """Contrastive PCA computed in the subspace spanned by the centered X samples.

    Both inputs are flattened to [N, D] and centered. The right singular
    vectors V of centered X give a basis of at most min(Bx, D) directions; the
    covariance of X and of Y is computed inside that basis, and the leading
    eigenvectors of `Cx - alpha * Cy` are mapped back to the D-dimensional space.

    Args:
        X: foreground samples, first dimension indexes samples (e.g. [Bx, C, H, W]).
        Y: background samples with the same flattened size per sample.
        n_components: maximum number of directions to return; at most
            min(Bx, D) are available, so larger values are truncated.
        alpha: weight of the background covariance being subtracted.

    Returns:
        Float tensor of shape [k, D] with unit-norm rows, ordered by decreasing
        eigenvalue, where k = min(n_components, min(Bx, D)). The result lives
        on the GPU when CUDA is available, otherwise on X's device.

    Raises:
        ValueError: if X or Y has fewer than two samples (the unbiased
            covariance estimate would divide by zero).
    """
    n_x, n_y = X.shape[0], Y.shape[0]
    if n_x < 2 or n_y < 2:
        raise ValueError("contrastive_pca needs at least 2 samples in both X and Y")

    # Prefer the GPU as before, but fall back to the input's device so the
    # function also runs on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else X.device

    X_flat = X.float().reshape(n_x, -1).to(device) # (Bx, D)
    Y_flat = Y.float().reshape(n_y, -1).to(device) # (By, D)

    X_centered = X_flat - X_flat.mean(dim=0, keepdim=True)
    Y_centered = Y_flat - Y_flat.mean(dim=0, keepdim=True)

    # Reduced SVD: V has one column per available direction, each of length D.
    U, S, Vh = torch.linalg.svd(X_centered, full_matrices=False)
    V = Vh.T

    # Coordinates of both datasets in the V basis (U * S equals X_centered @ V).
    X_proj = U * S
    Y_proj = Y_centered @ V

    Cx_small = (X_proj.T @ X_proj) / (n_x - 1)
    Cy_small = (Y_proj.T @ Y_proj) / (n_y - 1)

    C_dual = Cx_small - alpha * Cy_small

    eigvals, eigvecs_small = torch.linalg.eigh(C_dual)

    # eigh returns ascending eigenvalues; keep the largest ones first.
    idx = torch.argsort(eigvals, descending=True)[:n_components]
    top_vecs_small = eigvecs_small[:, idx]

    # Map the selected directions back to the original D-dimensional space.
    components = V @ top_vecs_small

    components = components / components.norm(dim=0, keepdim=True)

    return components.T

def compute_principal_componets(forget_layers_act, retain_layers_act, n_components=10, alpha=1e-3, whiten=False):
    """Run `contrastive_pca` on every layer's forget/retain activations.

    Layers are paired by name: each forget layer is matched with the entry of
    the same key in `retain_layers_act`.

    Args:
        forget_layers_act: dict mapping layer name to stacked forget activations.
        retain_layers_act: dict mapping layer name to stacked retain activations.
        n_components: forwarded to `contrastive_pca`.
        alpha: forwarded to `contrastive_pca`.
        whiten: kept for interface compatibility. `contrastive_pca` has no
            whitening option, so only False is accepted.

    Returns:
        dict mapping each layer name to the value returned by `contrastive_pca`.

    Raises:
        NotImplementedError: if `whiten` is True.
        KeyError: if a forget layer has no counterpart in `retain_layers_act`.
    """
    if whiten:
        raise NotImplementedError("whiten=True is not supported by contrastive_pca")

    # `whiten` is intentionally not forwarded: contrastive_pca takes only
    # (X, Y, n_components, alpha), and a fifth positional argument is a TypeError.
    return {
        layer: contrastive_pca(forget_batch, retain_layers_act[layer], n_components, alpha)
        for layer, forget_batch in forget_layers_act.items()
    }


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
compute_principal_components = compute_principal_componets

In [18]:
# The problem with PCA here is that N <<< D (far fewer samples than dimensions).
# Requesting 300 components from a batch of only 5-30 samples does not make much sense.

# components = contrastive_pca(
#     forget_layers_act['down_block_2_resnet_1'],
#     retain_layers_act['down_block_2_resnet_1'],
#     300,
#     0.8
# )

# components.shape

In [36]:
# Build one normalized mean-difference (forget mean minus retain mean) steering
# vector per selected layer.
mean_differences = compute_mean_differences(forget_layers_act, retain_layers_act, True)

# Sanity check: list each layer with the shape of its steering vector.
print([(layer, steering_vector.shape) for (layer, steering_vector) in mean_differences.items()])

[('down_block_2_resnet_1', torch.Size([1280, 16, 16])), ('mid_block_resnet_0', torch.Size([1280, 8, 8])), ('mid_block_resnet_1', torch.Size([1280, 8, 8])), ('up_block_2_resnet_2', torch.Size([640, 32, 32]))]


**OLJA, inject your inference code here; `mean_differences` contains the steering vector for each layer specified above**
First try simple mean difference, then we can try to use PCA

In [None]:
# olja, code