In [None]:
import torch
import pandas as pd
import numpy as np
import math

import textwrap

from typing import Optional
from diffusers import StableDiffusionPipeline
from huggingface_hub import notebook_login

import matplotlib.pyplot as plt
from PIL import Image
from collections import defaultdict

# Interactive Hugging Face Hub login (imported from huggingface_hub above); run before the
# model download in the next cell. NOTE(review): presumably only required when the checkpoint
# needs an authenticated account - confirm whether it can be dropped.
notebook_login()

In [None]:
# Load the Stable Diffusion pipeline that every later cell hooks into.
# - safety_checker=None: no safety checker is passed to the pipeline
#   (NOTE(review): presumably so generated images are returned unfiltered - confirm this is intended).
# - torch_dtype=torch.float16 and .to("cuda"): half-precision weights on the GPU; a CUDA device is required.
# The trailing comment on the model id lists alternative checkpoints that were tried.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", # "Manojb/stable-diffusion-2-1-base" # "CompVis/stable-diffusion-v1-4",
    safety_checker=None,
    torch_dtype=torch.float16
).to("cuda")

In [None]:
def get_unet_resnets(pipe):
    """Map a readable name to every resnet module of the pipeline's UNet.

    Names follow the traversal order down blocks -> mid block -> up blocks:
    ``down_block_{i}_resnet_{j}``, ``mid_block_resnet_{j}``, ``up_block_{i}_resnet_{j}``.

    Args:
        pipe: object exposing ``unet.down_blocks``, ``unet.mid_block`` and ``unet.up_blocks``,
            each block exposing a ``resnets`` sequence.

    Returns:
        dict: name -> resnet module, in traversal order.
    """
    unet = pipe.unet

    named = {
        f"down_block_{block_idx}_resnet_{res_idx}": module
        for block_idx, stage in enumerate(unet.down_blocks)
        for res_idx, module in enumerate(stage.resnets)
    }
    named.update(
        (f"mid_block_resnet_{res_idx}", module)
        for res_idx, module in enumerate(unet.mid_block.resnets)
    )
    named.update(
        (f"up_block_{block_idx}_resnet_{res_idx}", module)
        for block_idx, stage in enumerate(unet.up_blocks)
        for res_idx, module in enumerate(stage.resnets)
    )
    return named

def get_unet_trans_ff(pipe):
    """Map a readable name to the feed-forward (``ff``) module of every transformer block.

    Blocks without an ``attentions`` attribute are skipped. Names follow the traversal order
    down blocks -> mid block -> up blocks:
    ``down_block_{i}_attn_{j}_trans_{k}_ff``, ``mid_block_attn_{j}_trans_{k}_ff``,
    ``up_block_{i}_attn_{j}_trans_{k}_ff``.

    Args:
        pipe: object exposing ``unet.down_blocks``, ``unet.mid_block`` and ``unet.up_blocks``.

    Returns:
        dict: name -> ``transformer.ff`` module, in traversal order.
    """
    unet = pipe.unet
    collected = {}

    def _register(prefix, block):
        # Some blocks carry no attention layers at all; they contribute nothing.
        if not hasattr(block, "attentions"):
            return
        for attn_idx, attention in enumerate(block.attentions):
            for trans_idx, transformer in enumerate(attention.transformer_blocks):
                collected[f"{prefix}_attn_{attn_idx}_trans_{trans_idx}_ff"] = transformer.ff

    for block_idx, block in enumerate(unet.down_blocks):
        _register(f"down_block_{block_idx}", block)

    _register("mid_block", unet.mid_block)

    for block_idx, block in enumerate(unet.up_blocks):
        _register(f"up_block_{block_idx}", block)

    return collected

# Build the name -> module registries used by the hook-based cells below and print the
# available layer names (these are the keys accepted by the `layers` arguments).
resnets = get_unet_resnets(pipe)
ff_nets = get_unet_trans_ff(pipe)
print(list(resnets.keys()))
print(list(ff_nets.keys()))

In [None]:
# Load the paired prompt dataset. Each row supplies one 'positive' and one 'negative' prompt.
# NOTE(review): the variables are named after dogs but the file read here is nudity.csv -
# the names look like leftovers from an earlier experiment. They are kept because later
# cells reference them (`dog_prompts` is passed as the forget set, `non_dog_prompts` as the
# retain set).
# NOTE(review): the path is an absolute Kaggle input path and will not resolve elsewhere.
dogs_dataset = pd.read_csv('/kaggle/input/new-nudity-steering/nudity.csv')

dog_prompts = dogs_dataset['positive'].tolist()
non_dog_prompts = dogs_dataset['negative'].tolist()

In [None]:

def collect_residual_streams(
    pipe: StableDiffusionPipeline,
    forget_set: list[str],
    retain_set: list[str],
    guidance: float,
    resnets: dict,
    layers: list[str],
    timesteps: list[int]
):
    """Collect hooked activations for paired forget/retain prompts.

    Prompts are consumed pairwise with ``zip``, so surplus prompts in the longer list are
    ignored. Each prompt is run through ``get_unet_residual_stream`` on its own.

    Args:
        pipe: pipeline forwarded to ``get_unet_residual_stream``.
        forget_set: prompts whose activations form the "forget" side.
        retain_set: prompts whose activations form the "retain" side.
        guidance: guidance scale forwarded to the pipeline.
        resnets: name -> module registry containing every entry of ``layers``.
        layers: names of the modules to record.
        timesteps: step indices forwarded to ``get_unet_residual_stream``.

    Returns:
        tuple(dict, dict): ``(forget_layers, retain_layers)``; each maps a layer name to the
        per-prompt activations stacked along a new leading dimension.
    """
    per_prompt_forget = []
    per_prompt_retain = []

    pairs = zip(forget_set, retain_set)
    for number, (forget_prompt, retain_prompt) in enumerate(pairs, start=1):
        print(f'[{number}] Extracting acts for forget prompt: {forget_prompt}')
        forget_act = get_unet_residual_stream(pipe, forget_prompt, guidance, resnets, layers, timesteps)

        print(f'[{number}] Extracting acts for retain prompt: {retain_prompt}')
        retain_act = get_unet_residual_stream(pipe, retain_prompt, guidance, resnets, layers, timesteps)

        per_prompt_forget.append(forget_act)
        per_prompt_retain.append(retain_act)

    # Regroup from "list of per-prompt dicts" to "dict of stacked tensors", one entry per layer.
    forget_layers = {}
    retain_layers = {}
    for name in layers:
        forget_layers[name] = torch.stack([acts[name] for acts in per_prompt_forget], dim=0)
        retain_layers[name] = torch.stack([acts[name] for acts in per_prompt_retain], dim=0)

    return forget_layers, retain_layers



def get_unet_residual_stream(
    pipe: StableDiffusionPipeline,
    prompt: str,
    guidance: float,
    resnets: dict,
    layers: list[str],
    timesteps: list[int],
    seed: Optional[int] = None
):
    """Run one generation and record the outputs of the selected modules.

    A forward hook is attached to every module named in ``layers``. While the pipeline runs,
    each hook stores batch entry 1 of the module output (moved to the CPU) whenever the
    current step counter is contained in ``timesteps``. All hooks are removed before the
    function returns, including when hook registration or the pipeline call raises.

    Designed to be simple: one prompt per call, because batching prompts would mix samples
    along the batch dimension the hooks index into.

    Args:
        pipe: the Stable Diffusion pipeline to run.
        prompt: text prompt for this run.
        guidance: guidance scale. The hooks read batch index 1, so the UNet batch must hold
            at least two entries (NOTE(review): this presumably requires classifier-free
            guidance to be active, with index 1 being the text-conditioned sample - confirm).
        resnets: name -> module registry; must contain every entry of ``layers``.
        layers: names of the modules to hook.
        timesteps: step-counter values at which outputs are recorded; ``timesteps[-1]`` is
            also used as ``num_inference_steps``.
        seed: optional seed for the pipeline's random generator. ``None`` (default) keeps the
            previous unseeded behaviour.

    Returns:
        dict: layer name -> recorded outputs stacked along a new leading dimension
        (one entry per recorded forward pass, e.g. ``[T, C, H, W]`` for 4-D module outputs).
        Layers whose hook never recorded anything are absent from the dict.
    """
    residuals_dict = {}
    handles = []

    # NOTE: updated by `callback`, which the pipeline invokes at the END of a step. While the
    # UNet runs, this therefore still holds the index of the last completed step (and the
    # initial 0 during the very first forward pass).
    current_step = 0

    def save_residuals(name):
        def hook(module, input, output):
            if current_step in timesteps:
                # Some modules return a tuple; the tensor of interest is its second element.
                residual = output[1] if isinstance(output, tuple) else output
                # Keep only batch entry 1 and move it off the GPU straight away.
                residuals_dict.setdefault(name, []).append(residual[1].detach().cpu())

        return hook

    def callback(pipeline, step_index, timestep, callback_kwargs):
        nonlocal current_step
        current_step = step_index

        return callback_kwargs

    generator = None
    if seed is not None:
        generator = torch.Generator(device=pipe.device).manual_seed(seed)

    try:
        # Registration happens inside the try block so that a failure part-way through
        # (e.g. an unknown layer name) cannot leave hooks attached to the shared UNet.
        for l in layers:
            handles.append(
                resnets[l].register_forward_hook(save_residuals(l))
            )

        pipe(
            prompt,
            num_inference_steps=timesteps[-1],
            guidance_scale=guidance,
            callback_on_step_end=callback,
            generator=generator
        )

        return {
            layer: torch.stack(tensors, dim=0)
            for layer, tensors in residuals_dict.items()
        }
    finally:
        for h in handles:
            h.remove()

def show_images(images: list[Image.Image], prompts: list[str], cols: int = 2, width: int = 40) -> None:
    """Show images in a grid, each with a wrapped caption underneath and a thin black border.

    Args:
        images: images to display.
        prompts: one caption per image (same length as ``images``).
        cols: number of grid columns.
        width: maximum caption line width in characters before wrapping.

    Raises:
        AssertionError: if ``images`` and ``prompts`` differ in length.
    """
    assert len(images) == len(prompts), "images and prompts must have the same length"

    if len(images) == 0:
        # A grid with zero rows cannot be created; there is nothing to draw.
        return

    rows = math.ceil(len(images) / cols)
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))

    # plt.subplots returns a bare Axes for a 1x1 grid and an ndarray otherwise.
    if isinstance(axes, np.ndarray):
        axes = axes.flatten()
    else:
        axes = [axes]

    # Blank out the unused cells of the last row.
    for ax in axes[len(images):]:
        ax.axis('off')

    for ax, img, prompt in zip(axes, images, prompts):
        ax.imshow(img)
        # Hide only the ticks. Turning the whole axis off would also suppress the spines,
        # so the border configured below would never be drawn.
        ax.set_xticks([])
        ax.set_yticks([])
        wrapped_prompt = "\n".join(textwrap.wrap(prompt, width=width))
        ax.text(0.5, -0.05, wrapped_prompt, fontsize=10, ha='center', va='top', transform=ax.transAxes)
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_color('black')
            spine.set_linewidth(1)

    plt.tight_layout()
    plt.show()

**Extract raw activations and print each layer's name and activation shape**

In [None]:
#print(pipe.unet)
# Experiment configuration followed by the (slow) activation collection run.
# ALL_LAYERS: every hookable layer name, resnets first and feed-forward nets after
# (dict union keeps insertion order).
ALL_LAYERS = list((resnets | ff_nets).keys()) # list(resnets.keys())
GUIDANCE = 7.5
LAYERS = ALL_LAYERS
STEPS = 30
# 1-based step values; the last entry (STEPS) is what the collection function passes to the
# pipeline as num_inference_steps.
TIMESTEPS = list(range(1, STEPS+1))
# Only the first 5 prompt pairs are used; one full generation is run per prompt.
forget_acts, retain_acts = collect_residual_streams(
    pipe,
    dog_prompts[:5],
    non_dog_prompts[:5],
    guidance=GUIDANCE,
    resnets=(resnets | ff_nets),
    layers=LAYERS,
    timesteps=TIMESTEPS
)

# Sanity check: one line per layer with the shape of the stacked forget activations.
for layer, act in forget_acts.items():
    print(f'Layer {layer}: {act.shape}')

In [None]:
def compute_mean_differences(forget_layers_act, retain_layers_act, layer_nav=False):
    """Mean difference between forget and retain activations.

    Args:
        forget_layers_act: with ``layer_nav=False`` a dict layer name -> tensor; with
            ``layer_nav=True`` a single tensor.
        retain_layers_act: same structure as ``forget_layers_act``.
        layer_nav: when True, treat both inputs as plain tensors and return
            ``(forget - retain).mean(dim=0)``.

    Returns:
        ``layer_nav=True``: a tensor, the difference averaged over dim 0.
        ``layer_nav=False``: dict layer name -> tensor. 5-D inputs (resnet activations
        ``[N, T, C, H, W]``) are averaged over dims ``(0, 3, 4)`` giving ``[T, C]``; any other
        rank (feed-forward activations) is averaged over dims ``(0, 2)``.

    Raises:
        KeyError: if a layer of ``forget_layers_act`` is missing from ``retain_layers_act``.
    """
    if layer_nav:
        return (forget_layers_act - retain_layers_act).mean(dim=0)

    missing = [layer for layer in forget_layers_act if layer not in retain_layers_act]
    if missing:
        raise KeyError(f"retain activations are missing layers: {missing}")

    result = {}
    for layer, X in forget_layers_act.items():
        # Pair the two sides by layer name, not by dict position: the dicts are not
        # guaranteed to share an insertion order.
        Y = retain_layers_act[layer]
        if X.ndim == 5: # Resnet activation
            result[layer] = (X - Y).mean(dim=(0, 3, 4)) # [T, C]
        else: # FF net activation
            result[layer] = (X - Y).mean(dim=(0, 2))
    return result

In [None]:
def process_acts(acts_dict, device=None):
    """Average activations over their spatial/sequence dimension and cast to float32.

    Args:
        acts_dict: dict layer name -> tensor. 5-D tensors are treated as resnet activations
            ``(N, steps, C, H, W)`` and averaged over dims ``(3, 4)``; any other rank is
            treated as a feed-forward activation and averaged over dim ``2``.
        device: device the tensors are moved to before averaging. ``None`` (default) selects
            ``'cuda'`` when available and falls back to ``'cpu'`` otherwise, so the function
            no longer fails on machines without a GPU.

    Returns:
        dict: layer name -> float32 tensor of shape ``(N, steps, C)`` on ``device``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    pooled = {}
    for layer_name, act in acts_dict.items():
        act = act.to(device) # (N, Steps, C, H, W) for resnets

        if act.ndim == 5: # Resnet activation
            spatial_averaged = act.mean(dim=(3, 4)) # (N, steps, C)
        else: # Feed forward attention activation
            spatial_averaged = act.mean(dim=2) # (N, steps, C)

        pooled[layer_name] = spatial_averaged.float()

    return pooled


def compute_scores(retain_acts, forget_acts, timesteps, top_k):
    """Rank layers per timestep by how well they separate retain from forget activations.

    Both activation dicts are first pooled with ``process_acts`` to ``(N, steps, D)``. For each
    timestep ``t`` the slice at index ``t - timesteps[0]`` is scored per layer:

    * discriminability ``D_l``: ``Sb / (Sb + Sw + 1e-8)``, where ``Sb`` and ``Sw`` are the
      between-class and within-class scatter of the standardised data projected onto the
      mean-difference direction ``v_l`` (forget minus retain, computed on the raw data);
    * consistency ``C_l``: mean cosine similarity between each standardised pair difference
      (forget minus retain) and ``v_l``;
    * score ``S_l = D_l + C_l``.

    Args:
        retain_acts: dict layer name -> retain activations (the "positive" class).
        forget_acts: dict layer name -> forget activations (the "negative" class).
        timesteps: timestep labels to score; also the keys of the result.
        top_k: number of best-scoring layers kept per timestep.

    Returns:
        dict: timestep -> list of ``(layer, {"score", "discriminability", "consistency"})``
        tuples, sorted by descending score and truncated to ``top_k``.
    """
    pooled_retain = process_acts(retain_acts)
    pooled_forget = process_acts(forget_acts)

    def _layer_metrics(layer, step_index):
        P = pooled_retain[layer][:, step_index, :]  # Positive (N, D)
        N = pooled_forget[layer][:, step_index, :]  # Negative (N, D)

        if P.shape != N.shape:
            print(f'P shape and N shape differs in {layer}')

        n_samples = P.shape[0]

        # Standardise both classes with the statistics of their union.
        stacked = torch.cat([P, N], dim=0)  # (2N, D)
        mu_l = stacked.mean(dim=0, keepdim=True)  # (1, D)
        sigma_l = stacked.std(dim=0, keepdim=True) + 1e-8  # (1, D)
        P_tilde = (P - mu_l) / sigma_l
        N_tilde = (N - mu_l) / sigma_l

        # Direction of interest: mean forget-minus-retain difference on the raw data.
        v_l = compute_mean_differences(N, P, layer_nav=True)  # (D)

        mu_pos = P_tilde.mean(dim=0)  # (D)
        mu_neg = N_tilde.mean(dim=0)  # (D)

        # Between-class scatter along v_l: v^T Sb v = N * (proj_pos^2 + proj_neg^2),
        # evaluated without materialising a (D, D) matrix.
        proj_pos = torch.dot(mu_pos, v_l)
        proj_neg = torch.dot(mu_neg, v_l)
        between = n_samples * (proj_pos**2 + proj_neg**2)

        # Within-class scatter along v_l: project the class-centred samples onto v_l.
        pos_proj = torch.mv(P_tilde - mu_pos.unsqueeze(0), v_l)  # (N)
        neg_proj = torch.mv(N_tilde - mu_neg.unsqueeze(0), v_l)  # (N)
        within = torch.sum(pos_proj**2) + torch.sum(neg_proj**2)

        discriminability = (between / (between + within + 1e-8)).item()

        # Agreement of every individual pair difference with v_l.
        pair_diffs = N_tilde - P_tilde  # (N, D)
        alignment = torch.mv(pair_diffs, v_l)  # (N)
        pair_norms = torch.norm(pair_diffs, dim=1)  # (N,)
        v_norm = torch.norm(v_l)
        cosine_sims = alignment / (pair_norms * v_norm + 1e-8)
        consistency = cosine_sims.mean().item()

        return {
            "score": discriminability + consistency,
            "discriminability": discriminability,
            "consistency": consistency
        }

    results = {}
    for timestep in timesteps:
        step_index = timestep - timesteps[0]
        per_layer = {}
        for layer in pooled_retain:
            per_layer[layer] = _layer_metrics(layer, step_index)
            # The helper's intermediates are out of scope here; release cached GPU blocks.
            torch.cuda.empty_cache()

        ranked = sorted(per_layer.items(), key=lambda item: item[1]['score'], reverse=True)
        results[timestep] = ranked[:top_k]

    return results

def get_top_k_layers(results):
    """Reduce the output of ``compute_scores`` to layer names only.

    Args:
        results: dict timestep -> sequence of entries whose first element is the layer name.

    Returns:
        dict: timestep -> list of layer names, in the same order as the input entries.
    """
    return {
        timestep: [entry[0] for entry in ranked]
        for timestep, ranked in results.items()
    }

def print_report(results):
    """Print the ranked layers of every timestep, one line per layer.

    Args:
        results: dict timestep -> sequence of ``(layer, metrics)`` pairs, where ``metrics``
            has the keys ``'score'``, ``'discriminability'`` and ``'consistency'``.
    """
    for timestep, ranked in results.items():
        print(f"Timestep: {timestep}")
        for layer, metrics in ranked:
            score = metrics['score']
            disc = metrics['discriminability']
            cons = metrics['consistency']
            line = f'\tLayer: {layer} | Score: {score} | Disc: {disc} | Cons: {cons}'
            print(line)

In [None]:
# Score every layer at every timestep and keep the 5 best per timestep.
# TIMESTEPS[:-1] drops the final step value.
# NOTE(review): presumably because the collected activations hold one step fewer than
# len(TIMESTEPS), so the last index would be out of range - confirm against the shapes
# printed by the collection cell.
results = compute_scores(retain_acts, forget_acts, timesteps=TIMESTEPS[:-1], top_k=5)

# timestep -> list of layer names; consumed by generate_with_steering.
top_k_per_timestep = get_top_k_layers(results)
print_report(results)

In [None]:
# Calculate steering vectors: per layer, the forget-minus-retain activation difference
# averaged over prompts and spatial/sequence positions (one vector per recorded step).
steering_vectors = compute_mean_differences(forget_acts, retain_acts)

# Sanity check: list each layer together with the shape of its steering tensor.
print([(layer, steering_vector.shape) for (layer, steering_vector) in steering_vectors.items()])

In [None]:
def steer_activations(x, r, lam=-1.0):
    """Shift ``x`` along the unit direction of ``r`` in proportion to its projection.

    Computes ``x + lam * <x, r_hat> * r_hat`` over the channel dimension, where ``r_hat`` is
    ``r`` normalised to unit length. ``lam = -1`` removes the component of ``x`` along ``r``;
    positive values amplify it.

    Args:
        x: activation tensor. Supported layouts: ``[C, H, W]`` (3-D), ``[1, C, H, W]`` (4-D),
            and channels-last tensors of any other rank such as ``[S, C]`` (feed-forward
            layers).
        r: 1-D steering vector with ``C`` entries.
        lam: steering strength.

    Returns:
        A new tensor with the same shape and dtype as ``x``. A zero (or numerically zero)
        ``r`` leaves the values of ``x`` unchanged instead of producing NaNs.
    """
    # Normalise in at least float32: activations are usually half precision, and casting r
    # to half before normalising would throw away precision of the direction.
    work_dtype = torch.float32 if x.dtype in (torch.float16, torch.bfloat16) else x.dtype
    r = r.to(device=x.device, dtype=work_dtype)
    # The epsilon turns an all-zero r into a zero direction (no steering) rather than 0/0.
    r = (r / (r.norm() + 1e-12)).to(x.dtype)

    if x.ndim == 3:  # [C, H, W]
        r = r[:, None, None]      # shape [C, 1, 1]
        channel_dim = 0
    elif x.ndim == 4:  # [1, C, H, W]
        r = r[None, :, None, None] # shape [1, C, 1, 1]
        channel_dim = 1
    else: # channels-last, e.g. [S, C] (ff layers); r broadcasts against the last dim
        channel_dim = -1

    dot_product = (x * r).sum(dim=channel_dim, keepdim=True)

    return x + lam * (dot_product * r)


def generate_with_steering(
    pipe: StableDiffusionPipeline,
    prompt: str,
    guidance: float,
    resnets: dict,
    steering_vectors: dict[str, torch.Tensor],
    timesteps: list[int],
    lam: float,
    top_layers_per_timestep: dict[int, list[str]],
    seed: int = 362
):
    """Generate images while steering selected module outputs along their steering vectors.

    A forward hook is attached to every module named in ``steering_vectors``. Whenever the
    current step counter is in ``timesteps`` and the layer is selected for that step, batch
    entry 1 of the module output is replaced in place by
    ``steer_activations(entry, steering_vector[ts_index], lam)``. ``ts_index`` is a per-layer
    counter of how many forward passes have fallen inside ``timesteps`` so far. All hooks are
    removed before returning, including when hook registration or the pipeline call raises.

    Args:
        pipe: the Stable Diffusion pipeline to run.
        prompt: text prompt.
        guidance: guidance scale. The hooks modify batch index 1, so the UNet batch must
            hold at least two entries (NOTE(review): presumably requires classifier-free
            guidance, with index 1 being the text-conditioned sample - confirm).
        resnets: name -> module registry; must contain every key of ``steering_vectors``.
        steering_vectors: layer name -> tensor indexed by ``ts_index`` along dim 0.
        timesteps: step-counter values at which steering may be applied; ``timesteps[-1]``
            is also used as ``num_inference_steps``.
        lam: steering strength passed to ``steer_activations``.
        top_layers_per_timestep: either a dict step -> list of layer names to steer at that
            step, or a plain collection of layer names to steer at every step. With a dict,
            steps that have no entry are left unsteered.
        seed: seed of the CUDA generator used for the run (default 362, the value that was
            previously hard-coded).

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    handles = []

    # NOTE: updated by `callback`, which the pipeline invokes at the END of a step. While the
    # UNet runs, this therefore still holds the index of the last completed step (and the
    # initial 0 during the very first forward pass).
    current_step = 0

    def steering_hook(layer: str, steering_vector: torch.Tensor):
        ts_index = 0

        def hook(module, inp, out):
            nonlocal ts_index
            #print(f"[STEERING] layer={layer} step={current_step}")

            # out can be tensor or (hidden, tensor)
            if isinstance(out, tuple):
                hidden, residual = out
            else:
                hidden, residual = None, out  # residual: [B, C, H, W]

            if current_step in timesteps:
                # If a dict with top layers for each timestep has been passed use that
                if isinstance(top_layers_per_timestep, dict):
                    # Steps without ranked layers simply get no steering.
                    current_top_layers = top_layers_per_timestep.get(current_step, ())
                else: # Otherwise assume ALL_LAYERS has been passed
                    current_top_layers = top_layers_per_timestep

                if layer in current_top_layers:
                    #print(f'[{layer}] -> Step {current_step}, ts_index {ts_index}')

                    x = residual[1]
                    x_steered = steer_activations(x, steering_vector[ts_index], lam)
                    residual[1] = x_steered

                # Advance for every in-range pass, steered or not, so the index keeps
                # tracking the position along dim 0 of the steering tensor.
                ts_index += 1

            if hidden is None:
                return residual
            else:
                return (hidden, residual)

        return hook

    def callback(pipeline, step_index, timestep, callback_kwargs):
        nonlocal current_step
        current_step = step_index

        return callback_kwargs

    try:
        # Registration happens inside the try block so that a failure part-way through
        # (e.g. an unknown layer name) cannot leave steering hooks attached to the shared UNet.
        for layer, steering_vector in steering_vectors.items():
            handles.append(
                resnets[layer].register_forward_hook(steering_hook(layer, steering_vector))
            )

        return pipe(
            prompt,
            num_inference_steps=timesteps[-1],
            guidance_scale=guidance,
            callback_on_step_end=callback,
            generator=torch.Generator(device="cuda").manual_seed(seed)
        ).images
    finally:
        for h in handles:
            h.remove()


# Run generation with steering: sweep the steering strength for a single forget prompt.
prompt = dog_prompts[1]

all_images = []
lambdas = []
print(prompt)
# Strengths -3.0, -2.5, ..., 2.5 (the upper bound 3 is exclusive); 0.0 is the unsteered baseline.
for lam_tensor in torch.arange(-3, 3, 0.5):
    # generate_with_steering declares `lam: float`; hand it a plain Python float rather than
    # a 0-dim tensor that would be captured by every forward hook.
    lam = lam_tensor.item()
    steered_images = generate_with_steering(
        pipe,
        prompt,
        GUIDANCE,
        (resnets | ff_nets),
        steering_vectors,
        timesteps=TIMESTEPS,
        lam=lam,
        top_layers_per_timestep= top_k_per_timestep # ALL_LAYERS
    )
    all_images.extend(steered_images)
    lambdas.append(str(lam))

# Visualize: one panel per strength, captioned with the lambda value.
show_images(all_images, lambdas, cols=3)