In [None]:
import torch
import pandas as pd
import numpy as np
import math

import textwrap

from typing import Optional
from diffusers import StableDiffusionPipeline
from huggingface_hub import notebook_login

import matplotlib.pyplot as plt
from PIL import Image
from collections import defaultdict
import os
# Log in to the Hugging Face Hub from inside the notebook.
# NOTE(review): presumably required so the model weights loaded below can be downloaded — confirm.
notebook_login()

In [None]:
# Load the Stable Diffusion pipeline in half precision, with the safety checker disabled,
# and move it to the GPU. The commented-out ids on the model line are alternative checkpoints.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", # "Manojb/stable-diffusion-2-1-base" # "CompVis/stable-diffusion-v1-4",
    safety_checker=None,
    torch_dtype=torch.float16
).to("cuda")

In [None]:
def _add_block_layers(nets, block, prefix, extract_resnet, extract_attentions):
    """Register the ResNet and/or cross-attention sub-modules of one UNet block in ``nets``.

    ResNets are stored as ``{prefix}_resnet_{j}`` and the ``attn2`` module of every
    transformer block as ``{prefix}_attn_{j}_trans_{k}``. Blocks that have no
    ``attentions`` attribute simply contribute no attention entries.
    """
    if extract_resnet:
        for j, resnet in enumerate(block.resnets):
            nets[f"{prefix}_resnet_{j}"] = resnet

    if extract_attentions and hasattr(block, "attentions"):
        for j, attn in enumerate(block.attentions):
            for k, transformer in enumerate(attn.transformer_blocks):
                nets[f"{prefix}_attn_{j}_trans_{k}"] = transformer.attn2


def get_unet_layers(pipe, extract_resnet=False, extract_attentions=True):
    """Map readable layer names to the hookable sub-modules of ``pipe.unet``.

    Args:
        pipe: Object exposing ``unet.down_blocks``, ``unet.mid_block`` and ``unet.up_blocks``.
        extract_resnet: Include every module found in each block's ``resnets``.
        extract_attentions: Include ``attn2`` of every transformer block found in each
            block's ``attentions`` (when the block has that attribute).

    Returns:
        dict: name -> module, in traversal order (down blocks, mid block, up blocks;
        within a block the ResNets come before the attentions).

    Raises:
        AssertionError: If both ``extract_resnet`` and ``extract_attentions`` are false.
    """
    assert extract_resnet or extract_attentions

    nets = {}

    for i, block in enumerate(pipe.unet.down_blocks):
        _add_block_layers(nets, block, f"down_block_{i}", extract_resnet, extract_attentions)

    _add_block_layers(nets, pipe.unet.mid_block, "mid_block", extract_resnet, extract_attentions)

    for i, block in enumerate(pipe.unet.up_blocks):
        _add_block_layers(nets, block, f"up_block_{i}", extract_resnet, extract_attentions)

    return nets


# Collect every hookable ResNet and cross-attention module of the UNet, keyed by name.
nets = get_unet_layers(pipe, extract_resnet=True, extract_attentions=True)

# Layer-name groups, in the same order as the keys of `nets`.
ALL_LAYERS = [*nets]
RESNET_LAYERS = [name for name in ALL_LAYERS if 'resnet_' in name]
ATTENTION_LAYERS = [name for name in ALL_LAYERS if 'attn_' in name]

print('Resnet layers:', RESNET_LAYERS, end='\n\n')
print('Attention layers:', ATTENTION_LAYERS)

In [4]:
# Prompt dataset with one 'positive' and one 'negative' prompt column.
# NOTE(review): the path is specific to a Kaggle input mount; adjust it when running elsewhere.
dogs_dataset = pd.read_csv('/kaggle/input/prompts-steering/dogs.csv')

# The 'positive' column is used as the forget set and the 'negative' column as the
# retain set when activations are collected further down.
dog_prompts = dogs_dataset['positive'].tolist()
non_dog_prompts = dogs_dataset['negative'].tolist()

In [None]:
def collect_average_activations(
    pipe: StableDiffusionPipeline,
    forget_set: list[str],
    retain_set: list[str],
    total_steps: int,
    guidance: float,
    nets: dict,
    layers: list[str],
    timesteps: list[int]
):
    """Collect per-layer activations for paired forget/retain prompts.

    The two prompt lists are walked in lockstep (``zip``, so the longer list is
    truncated); each prompt is run through ``get_average_activations`` on its own.

    Args:
        pipe: Pipeline forwarded to ``get_average_activations``.
        forget_set: Prompts of the forget set.
        retain_set: Prompts of the retain set, paired by position with ``forget_set``.
        total_steps: Number of inference steps per generation.
        guidance: Guidance scale per generation.
        nets: Mapping of layer name to hookable module.
        layers: Names of the layers to record.
        timesteps: Step indices at which activations are recorded.

    Returns:
        tuple[dict, dict]: ``(forget_layers, retain_layers)``; each maps a layer name
        to the per-prompt results stacked along a new leading dimension.
    """
    def run(prompt):
        return get_average_activations(pipe, prompt, total_steps, guidance, nets, layers, timesteps)

    per_pair = []
    for pair_no, (forget_prompt, retain_prompt) in enumerate(zip(forget_set, retain_set), start=1):
        print(f'[{pair_no}] Extracting acts for forget prompt: {forget_prompt}')
        forget_result = run(forget_prompt)

        print(f'[{pair_no}] Extracting acts for retain prompt: {retain_prompt}')
        retain_result = run(retain_prompt)

        per_pair.append((forget_result, retain_result))

    forget_layers, retain_layers = {}, {}
    for name in layers:
        forget_layers[name] = torch.stack([forget_result[name] for forget_result, _ in per_pair], dim=0)
        retain_layers[name] = torch.stack([retain_result[name] for _, retain_result in per_pair], dim=0)

    return forget_layers, retain_layers



def get_average_activations(
    pipe: StableDiffusionPipeline,
    prompt: str,
    total_steps: int,
    guidance: float,
    nets: dict,
    layers: list[str],
    timesteps: list[int]
):
    """Run one generation and record spatially/token-averaged activations of ``layers``.

    Forward hooks are attached to ``nets[layer]`` for every requested layer, the
    pipeline is run once for ``prompt``, and the hooks are always removed afterwards.

    Args:
        pipe: Callable pipeline accepting ``num_inference_steps``, ``guidance_scale``
            and ``callback_on_step_end``.
        prompt: Prompt to generate from.
        total_steps: Value passed as ``num_inference_steps``.
        guidance: Value passed as ``guidance_scale``.
        nets: Mapping of layer name to hookable module.
        layers: Names of the layers to record.
        timesteps: Zero-based denoising step indices whose forward pass is recorded.

    Returns:
        dict: layer name -> tensor of shape ``[T, C]`` (one averaged activation per
        recorded step, on CPU). Layers that recorded nothing are absent.

    Raises:
        KeyError: If a name in ``layers`` is missing from ``nets``.
        Exception: If a hooked module yields an activation of unexpected rank.
    """
    # designed to be simple, using batches would cause coherence issues when collecting acts.
    result = {}
    handles = []
    wanted_steps = set(timesteps)

    # Index of the denoising step whose UNet forward pass is currently running.
    current_step = 0

    def save_act(name):
        def hook(module, inputs, output):
            if current_step not in wanted_steps:
                return

            residual = output[1] if isinstance(output, tuple) else output

            # UNet calculates noise prediction for both conditioned and unconditioned input,
            # so we take the second batch element.
            # NOTE(review): assumes a batch of exactly [uncond, cond], i.e. a single prompt
            # with classifier-free guidance enabled — confirm.
            sample = residual[1]

            if sample.ndim == 3: # Channels x Width x Height
                act = sample.mean(dim=(1, 2)).detach().cpu()
            elif sample.ndim == 2: # Tokens x Context
                act = sample.mean(dim=0).detach().cpu()
            else:
                raise Exception(f'Unexpected activation shape {sample.shape} for {name}')

            result.setdefault(name, []).append(act)

        return hook

    def callback(pipeline, step_index, timestep, callback_kwargs):
        nonlocal current_step
        # This callback fires *after* step `step_index` has finished, so the next
        # forward pass the hooks will see belongs to step `step_index + 1`.
        current_step = step_index + 1

        return callback_kwargs

    try:
        # Register inside the try block so a failure half-way through still removes
        # the hooks that were already attached.
        for l in layers:
            handles.append(
                nets[l].register_forward_hook(save_act(l))
            )

        pipe(
            prompt,
            num_inference_steps=total_steps,
            guidance_scale=guidance,
            callback_on_step_end=callback
        )

        return {
            layer: torch.stack(tensors, dim=0)
            for layer, tensors in result.items()
        } # [T, C]
    finally:
        for h in handles:
            h.remove()

def show_images(images: list[Image.Image], prompts: list[str], cols: int = 2, width: int = 40) -> None:
    """Display images in a grid, each with its wrapped prompt printed underneath.

    Args:
        images: Images to draw, one per grid cell.
        prompts: Captions, same length as ``images``.
        cols: Number of grid columns.
        width: Maximum caption line length passed to ``textwrap.wrap``.

    Raises:
        AssertionError: If ``images`` and ``prompts`` differ in length.
    """
    assert len(images) == len(prompts)

    n_rows = math.ceil(len(images) / cols)
    fig, grid = plt.subplots(n_rows, cols, figsize=(3 * cols, 3 * n_rows))

    # plt.subplots returns a bare Axes for a 1x1 grid and an array otherwise.
    axes = grid.flatten() if isinstance(grid, np.ndarray) else [grid]

    # Hide the cells left over when the grid is larger than the image count.
    for spare in axes[len(images):]:
        spare.axis('off')

    for panel, picture, caption in zip(axes, images, prompts):
        panel.imshow(picture)
        panel.axis('off')
        panel.text(
            0.5,
            -0.05,
            "\n".join(textwrap.wrap(caption, width=width)),
            fontsize=10,
            ha='center',
            va='top',
            transform=panel.transAxes,
        )
        for border in panel.spines.values():
            border.set_visible(True)
            border.set_color('black')
            border.set_linewidth(1)

    plt.tight_layout()
    plt.show()

def get_timesteps_by_name(steps, name: str):
    """Return the list of step indices for a named window.

    ``'f'`` covers ``0 .. steps // 2``, ``'l'`` covers ``steps // 2 .. steps`` and
    ``'all'`` covers ``0 .. steps`` (all bounds inclusive).

    Raises:
        ValueError: If ``name`` is not one of ``'f'``, ``'l'`` or ``'all'``.
    """
    if name not in ('f', 'l', 'all'):
        raise ValueError(f'{name}')

    first = steps // 2 if name == 'l' else 0
    stop = steps // 2 + 1 if name == 'f' else steps + 1

    return list(range(first, stop))

In [None]:
# Experiment configuration shared by the collection, scoring and generation cells.
GUIDANCE = 7.5  # guidance scale passed to the pipeline
LAYERS = ATTENTION_LAYERS  # layers whose activations are collected
STEPS = 30  # number of inference steps per generation
TIMESTEPS_NAME = 'all'  # window name understood by get_timesteps_by_name ('f', 'l' or 'all')
TIMESTEPS = get_timesteps_by_name(STEPS, TIMESTEPS_NAME)
LAYER_NAV_K = len(LAYERS)  # top_k given to compute_scores; len(LAYERS) keeps every layer

In [None]:
# Run every dog / non-dog prompt pair through the pipeline and record the activations
# of LAYERS at TIMESTEPS. Dog prompts form the forget set, non-dog prompts the retain set.
forget_acts, retain_acts = collect_average_activations(
    pipe,
    dog_prompts,
    non_dog_prompts,
    total_steps=STEPS,
    guidance=GUIDANCE,
    nets=nets,
    layers=LAYERS,
    timesteps=TIMESTEPS
)

#for layer, act in forget_acts.items():
#    print(f'Layer {layer}: {act.shape}')

In [None]:
def _score_layer(P, N):
    """Score one layer at one timestep from retain (``P``) and forget (``N``) activations.

    Both inputs are ``(num_samples, D)`` float tensors whose rows are paired by position.
    Returns a dict with ``discriminability``, ``consistency`` and their sum ``score``.
    """
    n_samples = P.shape[0]

    # Standardise both classes with the statistics of the pooled data.
    all_acts = torch.cat([P, N], dim=0) # (2N, D)
    mu_l = all_acts.mean(dim=0, keepdim=True)  # (1, D)
    sigma_l = all_acts.std(dim=0, keepdim=True) + 1e-8 # (1, D)

    P_tilde = (P - mu_l) / sigma_l
    N_tilde = (N - mu_l) / sigma_l

    # Candidate steering direction: mean forget-minus-retain difference (raw, unnormalised data).
    v_l = (N - P).mean(dim=0)

    # Calculate means of normalized data
    mu_pos = P_tilde.mean(dim=0) # (D)
    mu_neg = N_tilde.mean(dim=0) # (D)

    # Instead of creating (D, D) matrix, project means onto v_l
    proj_pos = torch.dot(mu_pos, v_l)
    proj_neg = torch.dot(mu_neg, v_l)

    # v^T Sb v = N * (proj_pos^2 + proj_neg^2)
    sb_val = n_samples * (proj_pos**2 + proj_neg**2)

    # Center the data class-wise
    P_centered = P_tilde - mu_pos.unsqueeze(0) # (N, D)
    N_centered = N_tilde - mu_neg.unsqueeze(0) # (N, D)

    # Instead of creating (D, D) covariance, project data onto v_l
    # This calculates the variance of the data along the direction of v_l
    p_proj = torch.mv(P_centered, v_l) # (N)
    n_proj = torch.mv(N_centered, v_l) # (N)

    sw_val = torch.sum(p_proj**2) + torch.sum(n_proj**2)

    # Between-class share of the total scatter along v_l.
    D_l = (sb_val / (sb_val + sw_val + 1e-8)).item()

    # Mean cosine similarity between each pair's difference and v_l.
    pair_diffs = N_tilde - P_tilde # (N, D)
    dot_products = torch.mv(pair_diffs, v_l) # (N)
    pair_norms = torch.norm(pair_diffs, dim=1) # (N,)
    v_norm = torch.norm(v_l)

    cosine_sims = dot_products / (pair_norms * v_norm + 1e-8)
    C_l = cosine_sims.mean().item()

    return {
        "score": D_l + C_l,
        "discriminability": D_l,
        "consistency": C_l
    }


def compute_scores(retain_acts, forget_acts, timesteps, top_k):
    """Rank layers per timestep by how well they separate forget from retain activations.

    Args:
        retain_acts: Mapping layer name -> tensor indexed as ``[sample, time, feature]``.
        forget_acts: Same layout as ``retain_acts`` for the forget prompts.
        timesteps: Step labels, one per position along the time axis of the
            activation tensors and in the same order.
        top_k: Number of best-scoring layers kept per timestep.

    Returns:
        dict: timestep -> list of ``(layer, metrics)`` tuples sorted by descending
        ``metrics['score']`` and truncated to ``top_k`` entries.
    """
    results = {}

    # The time axis is indexed by position in `timesteps`, not by the step value, so
    # windows that do not start at 0 or are not contiguous are sliced correctly.
    for ts_index, timestep in enumerate(timesteps):
        timestep_dict = {}
        for layer in retain_acts:
            P = retain_acts[layer][:, ts_index, :].float()  # Positive (N, D)
            N = forget_acts[layer][:, ts_index, :].float()  # Negative (N, D)

            if P.shape != N.shape:
                print(f'P shape and N shape differs in {layer}')

            timestep_dict[layer] = _score_layer(P, N)

            # Intermediates died with the helper's frame; release any cached GPU memory.
            torch.cuda.empty_cache()

        sorted_layers = sorted(timestep_dict.items(), key=lambda x: x[1]['score'], reverse=True)
        results[timestep] = sorted_layers[:top_k]

    return results

def get_top_k_layers(results, k):
    """Reduce ranked per-timestep results to the names of the first ``k`` layers.

    Args:
        results: Mapping timestep -> ranked sequence whose entries carry the layer
            name at index 0.
        k: Number of leading entries to keep per timestep.

    Returns:
        dict: timestep -> list of layer names.
    """
    return {
        timestep: [entry[0] for entry in ranked[:k]]
        for timestep, ranked in results.items()
    }

def print_report(results):
    """Print, for every timestep, one line per ranked layer with its three metrics.

    Args:
        results: Mapping timestep -> iterable of ``(layer, metrics)`` pairs where
            ``metrics`` has the keys ``score``, ``discriminability`` and ``consistency``.
    """
    for timestep, ranked in results.items():
        print(f"Timestep: {timestep}")
        for layer, metrics in ranked:
            score = metrics['score']
            disc = metrics['discriminability']
            cons = metrics['consistency']
            print(f'\tLayer: {layer} | Score: {score} | Disc: {disc} | Cons: {cons}')

def mask_vectors_by_top_k(steering_vectors, timesteps, top_k_per_timestep):
    """Return copies of the steering vectors with non-selected (layer, step) rows zeroed.

    For the step at position ``i`` of ``timesteps``, row ``i`` of a layer's vector is
    set to zero unless the layer is listed in ``top_k_per_timestep`` for that step.
    Steps without an entry select no layer. The input tensors are not modified.

    NOTE(review): rows are addressed by position within ``timesteps``, not by step
    value — verify this matches how the vectors were collected when the window does
    not start at step 0.

    Args:
        steering_vectors: Mapping layer name -> tensor whose first dimension is time.
        timesteps: Step labels; their positions index the tensors' first dimension.
        top_k_per_timestep: Mapping step -> collection of active layer names.

    Returns:
        dict: layer name -> masked clone.
    """
    indexed_steps = list(enumerate(timesteps))
    masked_vectors = {}

    for layer_name, vector in steering_vectors.items():
        masked = vector.clone()
        for row, step in indexed_steps:
            # If this layer is NOT in the active list for this step, zero it out
            if layer_name not in top_k_per_timestep.get(step, []):
                masked[row] = 0.0
        masked_vectors[layer_name] = masked

    return masked_vectors

In [None]:
# Rank the layers at every timestep. With top_k=LAYER_NAV_K (= len(LAYERS)) no layer is
# dropped here; smaller selections are taken later with get_top_k_layers.
results = compute_scores(retain_acts, forget_acts, timesteps=TIMESTEPS, top_k=LAYER_NAV_K)

#top_k_per_timestep = get_top_k_layers(results, k=LAYER_NAV_K)
#print_report(results)

In [None]:
# Unique layer names appearing in any timestep's ranking. dict.fromkeys de-duplicates
# while keeping first-seen order, so the list is identical on every kernel restart
# (a set of strings is ordered differently from run to run).
BEST_LAYERS = list(dict.fromkeys(entry[0] for ranked in results.values() for entry in ranked))

#BEST_LAYERS

In [None]:
def contrastive_pca(X, Y, alpha: float = 1.0, threshold: float = 0.95, min_eigen: float = 0.5):
    """Contrastive PCA: directions with high variance in ``X`` and low variance in ``Y``.

    Eigen-decomposes ``cov(X) - alpha * cov(Y)``, keeps the eigenvalues above
    ``min_eigen`` and then the leading ones needed to reach ``threshold`` of their
    cumulative sum.

    Args:
        X: 2-D target data, one sample per row.
        Y: 2-D background data, one sample per row.
        alpha: Weight of the background covariance.
        threshold: Cumulative share of the retained eigenvalues to cover.
        min_eigen: Eigenvalues at or below this value are discarded.

    Returns:
        tuple: ``(eigvecs, eigvals, X_center)`` on CPU; eigenvectors are columns in
        descending eigenvalue order and ``X_center`` is the row mean of ``X`` with
        shape ``(1, D)``.
    """
    X, Y = X.float(), Y.float()

    X_center = X.mean(dim=0, keepdim=True)
    target = X - X_center
    background = Y - Y.mean(dim=0, keepdim=True)

    n_target, _ = target.shape
    n_background, _ = background.shape

    # Unbiased sample covariances and their contrast.
    cov_target = (target.T @ target) / (n_target - 1)
    cov_background = (background.T @ background) / (n_background - 1)
    contrast = cov_target - alpha * cov_background

    eigvals, eigvecs = torch.linalg.eigh(contrast)

    # Largest eigenvalue first.
    order = torch.argsort(eigvals, descending=True)
    eigvals, eigvecs = eigvals[order], eigvecs[:, order]

    keep = eigvals > min_eigen
    eigvals, eigvecs = eigvals[keep], eigvecs[:, keep]

    # Smallest number of leading components whose cumulative share reaches the threshold.
    explained = torch.cumsum(eigvals, dim=0) / eigvals.sum()
    num_components = (explained < threshold).sum() + 1

    return eigvecs[:, :num_components].cpu(), eigvals[:num_components].cpu(), X_center.cpu()


def compute_principal_components(forget_acts, retain_acts, layers: list[str], min_subspace_dim: int = 5, alpha: float = 1.0, threshold: float = 0.95, min_eigen: float = 0.5):
    """Run contrastive PCA per layer and per time index on forget vs. retain activations.

    Args:
        forget_acts: Mapping layer name -> tensor indexed as ``[sample, time, feature]``.
        retain_acts: Same layout; looked up by layer name.
        layers: Names of the layers to process. Layers of ``forget_acts`` that are not
            listed are skipped; ``None`` processes every layer.
        min_subspace_dim: Minimum number of retained components for a result to be kept.
        alpha: Forwarded to ``contrastive_pca``.
        threshold: Forwarded to ``contrastive_pca``.
        min_eigen: Forwarded to ``contrastive_pca``.

    Returns:
        defaultdict: ``result[layer][time_index] = (eigvecs, mean)`` for every
        layer/time index whose subspace has at least ``min_subspace_dim`` components.

    Raises:
        KeyError: If a processed layer is missing from ``retain_acts``.
        AssertionError: If the forget and retain tensors of a layer differ in shape.
    """
    selected = None if layers is None else set(layers)
    result = defaultdict(lambda: defaultdict(int))

    for layer, X in forget_acts.items():
        if selected is not None and layer not in selected:
            continue

        # Pair by layer name rather than by dict position.
        Y = retain_acts[layer]

        assert X.shape == Y.shape

        for ts in range(X.size(1)):
            eigvecs, eigvals, mean = contrastive_pca(X[:, ts, :], Y[:, ts, :], alpha, threshold, min_eigen)

            if len(eigvals) >= min_subspace_dim: # require a minimum number of principal components
                result[layer][ts] = (eigvecs, mean)

    return result

# Contrastive PCA subspaces per layer and time index, with a strongly weighted
# background (alpha=10), a 90% cumulative-eigenvalue threshold and at least 5 components.
layer_principal_components = compute_principal_components(
    forget_acts,
    retain_acts,
    layers=BEST_LAYERS,
    min_subspace_dim=5,
    alpha=10,
    threshold=0.90,
    min_eigen=0.01
)

In [None]:
#layer_principal_components['down_block_2_attn_0_trans_0'][1][1].shape

In [None]:
# Calculate steering vectors

def compute_mean_differences(forget_layers_act, retain_layers_act):
    """Compute one steering vector per layer as the mean forget-minus-retain activation.

    Layers are matched by name, so the two mappings may list them in different orders.

    Args:
        forget_layers_act: Mapping layer name -> tensor with samples on dimension 0.
        retain_layers_act: Mapping with the same layer names and matching shapes.

    Returns:
        dict: layer name -> ``(forget - retain).mean(dim=0)``, in the key order of
        ``forget_layers_act``.

    Raises:
        KeyError: If a layer of ``forget_layers_act`` is missing from ``retain_layers_act``.
    """
    return {
        layer: (forget_act - retain_layers_act[layer]).mean(dim=0)
        for layer, forget_act in forget_layers_act.items()
    }

# One steering vector per layer and recorded time index: mean(forget - retain) over the prompt pairs.
steering_vectors = compute_mean_differences(forget_acts, retain_acts)

#filtered_vectors = mask_vectors_by_top_k(
#    steering_vectors, 
#    timesteps=TIMESTEPS, 
#    top_k_per_timestep=top_k_per_timestep
#)
#print([(layer, steering_vector.shape) for (layer, steering_vector) in steering_vectors.items()])

In [None]:
def steer_activations(x, r, lam=-1.0):
    """Shift activations along the unit direction of ``r`` in proportion to their projection.

    Computes ``x + lam * <x, r_hat> * r_hat`` over the channel dimension, where
    ``r_hat = r / ||r||``. With ``lam = -1`` the component along ``r`` is removed.

    Args:
        x: Activations of shape ``[B, C, H, W]`` or ``[B, T, C]``.
        r: Steering vector of shape ``[C]``; it is never modified.
        lam: Steering strength.

    Returns:
        Tensor of the same shape as ``x``; ``x`` itself when ``r`` is all zeros.

    Raises:
        ValueError: If ``x`` is neither 3- nor 4-dimensional.
    """
    if torch.all(r == 0).item():
        return x

    r = r.to(x.device, x.dtype)
    # Out-of-place on purpose: .to() returns the caller's own tensor when device and
    # dtype already match, and an in-place division would rescale it permanently.
    r = r / r.norm()

    if x.ndim == 4:  # [B, C, H, W]
        r = r[None, :, None, None] # shape [1, C, 1, 1]
        channel_dim = 1
    elif x.ndim == 3: # [B, T, C] (ff layers)
        r = r[None, None, :] # shape [1, 1, C]
        channel_dim = 2
    else:
        raise ValueError(f'Unsupported activation rank {x.ndim}; expected 3 or 4 dimensions')

    dot_product = (x * r).sum(dim=channel_dim, keepdim=True)

    return x + (lam * dot_product * r)


def steer_activations_pca(x, r, pca_data: tuple[torch.Tensor, torch.Tensor], lam: float = -1.0):
    """Steer activations using the projection measured inside a PCA subspace.

    ``x`` and ``r`` are projected onto the principal components; the projection of the
    compressed activations onto the normalised compressed direction sets the step
    size, and the step is taken along the normalised reconstruction of ``r``
    (``r @ pcs @ pcs.T + mean``).

    NOTE(review): the activations are projected without subtracting ``mean``; the
    original code computed a centred projection and immediately overwrote it, so
    the un-centred form is what was effectively in use — confirm the intent.

    Args:
        x: Activations of shape ``[B, T, C]`` or ``[B, C, H, W]``.
        r: Steering vector of shape ``[C]``; it is never modified.
        pca_data: ``(pcs, mean)`` with ``pcs`` of shape ``[C, K]`` and ``mean``
            broadcastable to ``[C]``.
        lam: Steering strength.

    Returns:
        Tensor of the same shape as ``x``; ``x`` itself when ``r`` is all zeros.

    Raises:
        ValueError: If ``x`` is neither 3- nor 4-dimensional.
    """
    if torch.all(r == 0).item():
        return x

    if x.ndim not in (3, 4):
        raise ValueError(f'Unsupported activation rank {x.ndim}; expected 3 or 4 dimensions')

    # Follow the dtype of the activations instead of assuming half precision.
    r = r.to(x.device, x.dtype)
    pcs = pca_data[0].to(x.device, x.dtype)
    mean = pca_data[1].to(x.device, x.dtype)

    # Work channels-last so the matmul always contracts the channel dimension.
    channels_first = x.ndim == 4
    features = x.movedim(1, -1) if channels_first else x

    x_compressed = features @ pcs
    r_compressed = r @ pcs
    r_constructed = (r_compressed @ pcs.T) + mean

    r_compressed = r_compressed / r_compressed.norm()
    r_constructed = r_constructed / r_constructed.norm()

    dot_product = (x_compressed * r_compressed).sum(dim=-1, keepdim=True)

    steered = features + (lam * dot_product * r_constructed)

    return steered.movedim(-1, 1) if channels_first else steered


def generate_with_steering(
    pipe: StableDiffusionPipeline,
    prompt: str,
    guidance: float,
    nets: dict,
    steering_vectors: dict[str, torch.Tensor],
    timesteps: list[int],
    inference_steps: int,
    lam: float,
    pca_dict: dict[str, dict[int, tuple[torch.Tensor, torch.Tensor]]],
    seed: int = 362
):
    """Generate images while steering the second half of the batch in the hooked layers.

    A forward hook is attached to ``nets[layer]`` for every entry of
    ``steering_vectors``. During the denoising steps listed in ``timesteps`` the hook
    rewrites, in place, the second half of the layer's output batch. All hooks are
    removed before returning, also on failure.

    NOTE(review): ``steering_vector`` rows are consumed by position (one row per
    steered forward pass, starting at row 0), so a window that does not start at step
    0 uses the first rows of the vector — confirm this matches how the vectors were
    collected.

    Args:
        pipe: Callable pipeline accepting ``num_inference_steps``, ``guidance_scale``,
            ``callback_on_step_end`` and ``generator``; its result exposes ``.images``.
        prompt: Prompt (the notebook also passes a list of prompts) forwarded to ``pipe``.
        guidance: Value passed as ``guidance_scale``.
        nets: Mapping of layer name to hookable module.
        steering_vectors: Mapping layer name -> tensor with one row per steered step.
        timesteps: Zero-based denoising step indices during which steering is applied.
        inference_steps: Value passed as ``num_inference_steps``.
        lam: Steering strength forwarded to the steering functions.
        pca_dict: Optional ``layer -> step -> (pcs, mean)`` mapping; when an entry exists
            for the current layer and step, PCA steering is used. May be ``None``.
        seed: Seed of the CUDA generator used for the initial noise.

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        KeyError: If a layer of ``steering_vectors`` is missing from ``nets``.
        AssertionError: If a hooked activation has an odd batch size.
    """
    handles = []
    active_steps = set(timesteps)

    # Index of the denoising step whose UNet forward pass is currently running.
    current_step = 0

    def steering_hook(layer: str, steering_vector: torch.Tensor):
        # Number of steered forward passes seen so far by this layer.
        ts_index = 0

        def hook(module, inp, out):
            nonlocal ts_index

            # out can be tensor or (hidden, tensor)
            if isinstance(out, tuple):
                hidden, activation = out
            else:
                hidden, activation = None, out

            if current_step in active_steps:
                B = activation.size(0)

                assert B % 2 == 0

                # Only the second half of the batch is steered.
                # NOTE(review): presumably the text-conditioned half under
                # classifier-free guidance — confirm.
                x = activation[B // 2 :, :, :]

                pca_data = pca_dict.get(layer, {}).get(current_step, None) if pca_dict is not None else None

                if pca_data is not None:
                    x_steered = steer_activations_pca(x, steering_vector[ts_index], pca_data, lam)
                else:
                    x_steered = steer_activations(x, steering_vector[ts_index], lam)

                activation[B // 2 :] = x_steered

                ts_index += 1

            if hidden is None:
                return activation
            else:
                return (hidden, activation)

        return hook

    def callback(pipeline, step_index, timestep, callback_kwargs):
        nonlocal current_step
        # This callback fires *after* step `step_index` has finished, so the next
        # forward pass the hooks will see belongs to step `step_index + 1`.
        current_step = step_index + 1

        return callback_kwargs

    try:
        # Register inside the try block so a failure half-way through still removes
        # the hooks that were already attached.
        for layer, steering_vector in steering_vectors.items():
            handles.append(
                nets[layer].register_forward_hook(steering_hook(layer, steering_vector))
            )

        return pipe(
            prompt,
            num_inference_steps=inference_steps,
            guidance_scale=guidance,
            callback_on_step_end=callback,
            generator=torch.Generator(device="cuda").manual_seed(seed)
        ).images
    finally:
        for h in handles:
            h.remove()


# Run generation with steering

all_images = []
lambdas = []


BATCH_SIZE = 5

LAMBDA_PARAMS = [-4, -3.5, -3, -2.5, -2, -1.5, -1, 0]
K_PARAMS = [1, 3, 5, len(LAYERS)]
TIMESTEP_PARAMS = ['f', 'l', 'all']
OUTPUT_DIR = './'
os.makedirs(f'{OUTPUT_DIR}/base_imgs', exist_ok=True)
os.makedirs(f'{OUTPUT_DIR}/steered_imgs', exist_ok=True)

# The top-k selection and the masked vectors depend only on (k, window), not on the
# prompt batch or on lambda, so build them once instead of once per generation.
STEERING_CONFIGS = []
for k in K_PARAMS:
    top_k_per_timestep = get_top_k_layers(results, k=k)
    for t in TIMESTEP_PARAMS:
        window = get_timesteps_by_name(STEPS, t)
        filtered_vectors = mask_vectors_by_top_k(
            steering_vectors,
            timesteps=window,
            top_k_per_timestep=top_k_per_timestep
        )
        STEERING_CONFIGS.append((k, t, window, filtered_vectors))


for i in range(0, len(dog_prompts), BATCH_SIZE):
    prompt_batch = dog_prompts[i:i + BATCH_SIZE]
    print(prompt_batch)
    for lam in LAMBDA_PARAMS:
        if lam == 0:
            # Baseline: no steering vectors means no hooks are attached, so the images
            # come from the untouched pipeline (same seed as the steered runs).
            images = generate_with_steering(
                pipe,
                prompt_batch,
                GUIDANCE,
                nets,
                {},
                timesteps=get_timesteps_by_name(STEPS, 'all'),
                inference_steps=STEPS,
                lam=lam,
                pca_dict=None
            )
            for idx, image in enumerate(images):
                filename = f'{OUTPUT_DIR}/base_imgs/{i + idx}.png'
                image.save(filename)
        else:
            for k, t, window, filtered_vectors in STEERING_CONFIGS:
                steered_images = generate_with_steering(
                    pipe,
                    prompt_batch,
                    GUIDANCE,
                    nets,
                    filtered_vectors,
                    timesteps=window,
                    inference_steps=STEPS,
                    lam=lam,
                    pca_dict=None
                )

                for idx, image in enumerate(steered_images):
                    filename = f'{OUTPUT_DIR}/steered_imgs/{i + idx}_lambda={lam}_k={k}_t={t}.png'
                    image.save(filename)


In [None]:
# Bundle both image folders into a single archive.
# NOTE(review): the folder names are relative to the working directory, which matches
# OUTPUT_DIR = './' above — update this line if OUTPUT_DIR changes.
!zip -r output_images.zip base_imgs steered_imgs