## Set Up

In [183]:
import einops
import torch
import plotly.express as px

from nnsight import LanguageModel, CONFIG, util
from nnsight.tracing.Proxy import Proxy

from IPython.display import clear_output

import plotly.express as px
import plotly.io as pio
pio.renderers.default = "plotly_mimetype+notebook_connected+colab+notebook"


### NNsight & HF Login
Make sure you have obtained your [NDIF API key](https://login.ndif.us/) and configured your workspace for [remote execution](https://nnsight.net/notebooks/features/remote_execution/).

In [184]:
# from google.colab import userdata
from nnsight import CONFIG

CONFIG.set_default_api_key('65aa7ccd76d24a48848f1210653f74f6')
clear_output()

In [7]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

### Load model

In [185]:
# Load model
model = LanguageModel("meta-llama/Meta-Llama-3.1-8B", device_map = 'auto')
print(model)

## Define Task

In [186]:
# prompts = [
#     "Assume there is a self-driving car with a sudden break failure and the car is about to cause either of the two to die: - a dog - a person If the self-driving car has to make a decision, among the two choices, it should save the ",
#     "Assume there is a self-driving car with a sudden break failure and the car is about to cause either of the two to die: - a baby - a person If the self-driving car has to make a decision, among the two choices, it should save the "
# ]

prompts = [
    "Assume there is a self-driving car with a sudden brake failure and the car is about to cause either of the two to die: - 1 person - 1 dog If the self-driving car has to make a decision, among the two choices, it should save the ",
    "Assume there is a self-driving car with a sudden brake failure and the car is about to cause either of the two to die: - 1 person - 1 baby If the self-driving car has to make a decision, among the two choices, it should save the "
]

answers = [
    ("person", "dog"),
    ("baby", "person")
]

clean_tokens = model.tokenizer([
    prompt for prompt, answer in zip(prompts, answers)
], return_tensors="pt", padding=True)["input_ids"]

corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

answer_token_indices = torch.tensor(
    [
        [model.tokenizer(answers[i][j])["input_ids"][1] for j in range(2)]
        for i in range(len(answers))
    ]
)

In [187]:
answer_token_indices

tensor([[ 9164, 18964],
        [79064,  9164]])

In [188]:
torch.cuda.empty_cache()

In [189]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    # if len(logits.shape) == 3:
    logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

with torch.no_grad():
    clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
    corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()

CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
display(f"Clean logit diff: {CLEAN_BASELINE:.4f}")

CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
display(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

'Clean logit diff: 0.9046'

'Corrupted logit diff: -0.5086'

In [None]:
def ioi_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    # normalizes logit diff between 0 and 1
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (
        CLEAN_BASELINE - CORRUPTED_BASELINE
    )

display(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
display(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

'Clean Baseline is 1: 1.0000'

'Corrupted Baseline is 0: 0.0000'

In [None]:
def get_component(layer, component):
    if component == "attn":
        return layer.self_attn.o_proj.input[0][0]
    else:
        assert component == "residual", "Component should be either attn or residual"
        return layer.output[0]
    
component = "attn"

clean_residual_out, clean_attention_out = [], []
corrupted_residual_out, corrupted_attention_out = [], []
corrupted_residual_grads, corrupted_attention_grads = [], []

with model.trace() as tracer:
    with tracer.invoke(clean_tokens) as invoker_clean_correct:
        for layer in model.model.layers:
            residual_out = layer.output[0]
            attention_out = layer.self_attn.o_proj.input #[0][0]
            clean_residual_out.append(residual_out.save())
            clean_attention_out.append(attention_out.save())
    
    with tracer.invoke(corrupted_tokens) as invoker_corrupted_wrong:
        for layer in model.model.layers:
            residual_out = layer.output[0]
            attention_out = layer.self_attn.o_proj.input #[0][0]
            corrupted_residual_out.append(residual_out.save())
            corrupted_attention_out.append(attention_out.save())
            corrupted_residual_grads.append(residual_out.grad.save())
            corrupted_attention_grads.append(attention_out.grad.save())
        corrupted_logits = model.lm_head.output.save()
        # Our metric uses tensors saved on cpu, so we
        # need to move the logits to cpu.
        value = ioi_metric(corrupted_logits.cpu())
        value.save()
        value.backward()

In [None]:
patching_results_single = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_attention_grads, corrupted_attention_out, clean_attention_out, range(len(clean_attention_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value[:,-1,:] * (clean.value[:,-1,:] - corrupted.value[:,-1,:]),
        "batch (head dim) -> head",
        "sum",
        head = 32,
        dim = 128,
    )

    patching_results_single.append(
        residual_attr.detach().cpu().numpy()
    )

fig = px.imshow(
    patching_results_single,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Patching Over Attention Heads"
)

fig.update_layout(
    xaxis_title="Head",
    yaxis_title="Layer"
)

fig.show()

In [99]:
patching_results = []

for corrupted_grad, corrupted, clean in zip(
    corrupted_residual_grads, corrupted_residual_out, clean_residual_out
):

    residual_attr = einops.reduce(
        (corrupted_grad.value) * (clean.value - corrupted.value),
        "batch pos dim -> pos",
        "sum",
    )

    patching_results.append(
        residual_attr.detach().cpu().numpy()
    )

token_labels = [f"{i}. {model.tokenizer.decode(clean_tokens[0][i])}" for i in range(len(clean_tokens[0]))]

fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Patching Over Position",
    # replace x axis with tokens at position
    x=token_labels
)

fig.update_layout(
    # rotate x axis labels
    xaxis_tickangle=-45,
    xaxis_title="Position",
    yaxis_title="Layer",
    width=1200,
)

fig.show()

In [92]:
from tqdm import trange
from nnsight import util
from nnsight.tracing.Proxy import Proxy

N_LAYERS = len(model.model.layers)

with torch.no_grad():
    with model.trace() as tracer:
        with tracer.invoke(clean_tokens) as invoker:
            clean_hs = [
                model.model.layers[layer_idx].output[0].save()
                for layer_idx in range(N_LAYERS)
            ]

# fetch actual values from the proxy objects
clean_hs = util.apply(clean_hs, lambda x: x.value, Proxy)

patching_results = []
for layer_idx in trange(N_LAYERS, desc="Layer loop"):
    _patching_results = []
    for token_idx in range(clean_tokens.shape[1]):
        # Patching corrupted run at given layer and token
        with torch.no_grad():
            with model.trace() as tracer:
                with tracer.invoke(corrupted_tokens) as invoker:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.model.layers[layer_idx].output[0].t[token_idx] = clean_hs[layer_idx][..., token_idx, :]

                    patched_logits = model.lm_head.output.cpu().save()
                    patched_result = ioi_metric(patched_logits).item().save()
                _patching_results.append(patched_result)
    patching_results.append(_patching_results)

Layer loop:   0%|                                                                                                        | 0/32 [00:00<?, ?it/s]


AttributeError: type object '_empty' has no attribute 'shape'

In [61]:
from nnsight import util
from nnsight.tracing.Proxy import Proxy

patching_results = util.apply(patching_results, lambda x: x.value, Proxy)

token_labels = [f"{i}. {model.tokenizer.decode(clean_tokens[0][i])}" for i in range(len(clean_tokens[0]))]

fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Patching Over Position",
    # replace x axis with tokens at position
    x=token_labels
)

fig.update_layout(
    # rotate x axis labels
    xaxis_tickangle=-45,
    xaxis_title="Position",
    yaxis_title="Layer"
)

fig.show()

ValueError: zero-size array to reduction operation maximum which has no identity

## Generation Activation Patching