# AIPI 590 - XAI | Assignment #09
### Description
### Your Name: Wilson Tseng

#### Assignment 9 - Mechanistic Interpretability:
[GitHub Link](https://github.com/smilewilson1999/XAI/blob/9ea04d05a57738a723edc637482bd56fa2d59fc9/Assignment%209%20-%20Mechanistic%20Interpretability/transformerLens_demo_ext.ipynb)


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/smilewilson1999/XAI/blob/main/Assignment%209%20-%20Mechanistic%20Interpretability/transformerLens_demo_ext.ipynb)

## DO:
* Use markdown and comments effectively
* Pull out classes and functions into scripts
* Ensure cells are executed in order and avoid skipping cells to maintain reproducibility
* Choose the appropriate runtime (i.e. GPU) if needed
* If you are using a dataset that is too large to put in your GitHub repository, you must either pull it in via Hugging Face Datasets or put it in an S3 bucket and use boto3 to pull from there.
* Use versioning on all installs (ie pandas==1.3.0) to ensure consistency across versions
* Implement error handling where appropriate

## DON'T:
* Absolutely NO sending us Google Drive links or zip files with data (see above).
* Load packages throughout the notebook. Please load all packages in the first code cell in your notebook.
* Add API keys or tokens directly to your notebook!!!! EVER!!!
* Include cells that you used for testing or debugging. Delete these before submission
* Have errors rendered in your notebook. Fix errors prior to submission.

In [None]:
# Colab bootstrap: clone the GitHub repository into the Colab workspace and
# switch the working directory to this assignment's folder.
# NOTE: the lines starting with `!` / `%` are IPython shell/magic commands, so
# this cell only runs inside a notebook kernel, not as a plain Python script.
import os

# Remove Colab default sample_data
!rm -r ./sample_data

# Clone GitHub files to colab workspace
repo_name = "XAI" # Change to your repo name
git_path = 'https://github.com/smilewilson1999/XAI.git' #Change to your path
!git clone "{git_path}"

# Install dependencies from requirements.txt file
#!pip install -r "{os.path.join(repo_name,'requirements.txt')}" #Add if using requirements.txt

# Change working directory to location of notebook
# (the folder name contains spaces, hence the quoting in the %cd call below)
notebook_dir = 'Assignment 9 - Mechanistic Interpretability'
path_to_notebook = os.path.join(repo_name, notebook_dir)
%cd "{path_to_notebook}"
# List the directory contents as a quick check that the clone and cd worked
%ls

In [None]:
# Install the third-party packages used by this notebook.
# NOTE(review): none of these installs is version-pinned, although the
# guidelines at the top of this notebook ask for pinned installs
# (e.g. pandas==1.3.0) -- TODO pin to the versions this notebook was run with.
# NOTE(review): the extra index below is PyTorch's CPU wheel index, while a
# later cell selects "cuda" when it is available -- confirm this is intended.
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
!pip install plotly
!pip install transformer_lens

In [None]:
import torch
import plotly.express as px
import plotly.graph_objects as go
from transformer_lens import HookedTransformer
from transformer_lens.utils import to_numpy

# Pick the compute device: use CUDA when torch reports it, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load pre-trained GPT-2 Small as a HookedTransformer placed on that device
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
)

In [4]:
# Clean prompt: Mary is the giver, so the indirect object to be predicted is
# " John" (the default `correct_answer` of logits_to_logit_diff).
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"

# Corrupted prompt: identical except that the giver is John, which makes
# " John" the wrong indirect object for this sentence.
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

In [5]:
# Convert both prompts to token tensors with the model's tokenizer.
# NOTE(review): the patching sweep later copies activations position by
# position, so both prompts presumably must tokenize to the same length --
# confirm if the prompts are ever changed.
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

In [6]:
def logits_to_logit_diff(logits, correct_answer=" John", incorrect_answer=" Mary"):
    """Return the logit gap between two candidate next tokens.

    Args:
        logits: Model output; only element ``[0, -1]`` (first item, last
            position) is read, indexed by vocabulary id on the final axis.
        correct_answer: String that ``model.to_single_token`` maps to the id
            of the expected answer.
        incorrect_answer: String mapped the same way for the competing answer.

    Returns:
        ``logit(correct_answer) - logit(incorrect_answer)`` at the last
        position; positive values mean the model prefers ``correct_answer``.
    """
    # Resolve both answer strings to vocabulary ids using the global model
    correct_id = model.to_single_token(correct_answer)
    incorrect_id = model.to_single_token(incorrect_answer)

    # Logits for the token following the prompt (first batch item only)
    final_logits = logits[0, -1]
    return final_logits[correct_id] - final_logits[incorrect_id]

In [7]:
# Forward passes: the clean run also records its activations in a cache
# (used later as the source for patching); the corrupted run needs logits only
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits = model(corrupted_tokens)

# Baseline logit differences for both prompts
clean_logit_diff = logits_to_logit_diff(clean_logits)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)

# Report both baselines rounded to three decimals
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")

Clean logit difference: 4.276
Corrupted logit difference: -2.738


In [8]:
from functools import partial

import numpy as np

# Get the number of layers and tokens of the model
num_layers = model.cfg.n_layers
num_positions = clean_tokens.shape[1]

# Result matrix: one normalized patching effect per (layer, position)
ioi_patching_result = np.zeros((num_layers, num_positions))

# Gap between the clean and corrupted baselines. It does not change during the
# sweep, so compute it once instead of once per (layer, position).
logit_diff_range = clean_logit_diff - corrupted_logit_diff


def residual_stream_patching_hook(resid_pre, hook, position):
    """Overwrite one position of the hooked activation with the clean run's.

    Args:
        resid_pre: Activation passed to the hook; indexed as a rank-3 tensor
            with ``position`` on the middle axis. It is modified in place.
        hook: Hook point object; ``hook.name`` is the key used to look up the
            matching activation in ``clean_cache``.
        position: Index along the middle axis to overwrite.

    Returns:
        The same ``resid_pre`` tensor after patching.
    """
    resid_pre[:, position, :] = clean_cache[hook.name][:, position, :]
    return resid_pre


# Only forward passes are needed here, so skip autograd bookkeeping to save
# memory and time; the computed values are unchanged.
with torch.no_grad():
    # Iterate over each layer and each position
    for layer in range(num_layers):
        hook_name = f'blocks.{layer}.hook_resid_pre'
        for position in range(num_positions):
            # Bind the current position; hooks are called with (activation, hook)
            hook_fn = partial(residual_stream_patching_hook, position=position)
            # Run the corrupted prompt with the clean activation patched in
            patched_logits = model.run_with_hooks(
                corrupted_tokens,
                fwd_hooks=[(hook_name, hook_fn)]
            )
            # Logit difference after patching
            patched_logit_diff = logits_to_logit_diff(patched_logits)
            # Normalize: 0 == corrupted baseline, 1 == clean baseline
            result = (patched_logit_diff - corrupted_logit_diff) / logit_diff_range
            ioi_patching_result[layer, position] = result.item()

In [9]:
import pandas as pd

layer_numbers = np.arange(num_layers)
position_numbers = np.arange(num_positions)
# np.meshgrid(x, y) returns the x-grid first and the y-grid second, both with
# shape (len(y), len(x)) == (num_layers, num_positions). The position grid
# therefore comes first; unpacking in the other order swaps the two columns.
position_grid, layer_grid = np.meshgrid(position_numbers, layer_numbers)

# Long-form table: one row per (layer, position) cell, flattened in the same
# row-major order as ioi_patching_result
df = pd.DataFrame({
    'Layer': layer_grid.flatten(),
    'Position': position_grid.flatten(),
    'Effect': ioi_patching_result.flatten()
})

# Heat map of the normalized patching effect (rows = layers, columns = positions)
fig = px.imshow(
    ioi_patching_result,
    labels=dict(x="Position", y="Layer", color="Normalized Logit Difference"),
    x=position_numbers,
    y=layer_numbers,
    color_continuous_scale='RdBu',
    origin='lower'
)

# Request one tick per position and per layer, and fix the figure size
fig.update_layout(
    title='Activation Patching Results',
    xaxis_nticks=num_positions,
    yaxis_nticks=num_layers,
    width=800,
    height=600
)

fig.show()

In [14]:
# 3D surface of the patching effect: X-axis is position, Y-axis is layer,
# Z-axis is the effect value.
# [Cited: Partially supported by GPT-4o and reference on my previous project]
def plot_3d_surface():
    """Show ``ioi_patching_result`` as an interactive 3D surface.

    Reads the notebook-level ``ioi_patching_result``, ``num_positions`` and
    ``num_layers``; displays the figure and returns nothing.
    """
    surface = go.Surface(
        x=np.arange(num_positions),
        y=np.arange(num_layers),
        z=ioi_patching_result,
    )
    axis_titles = dict(
        xaxis_title='Position',
        yaxis_title='Layer',
        zaxis_title='Effect',
    )

    figure = go.Figure(data=[surface])
    # Fixed-size square canvas with explicit margins
    figure.update_layout(
        title='Activation Patching Effect Surface',
        scene=axis_titles,
        autosize=False,
        width=800,
        height=800,
        margin=dict(l=65, r=50, b=65, t=90),
    )
    figure.show()


plot_3d_surface()

This 3D plot provides a global view of the patching results, showing at a glance which layers and token positions have the largest effect on the model's prediction.

In [15]:
# Show each layer dynamically (Click play!) [Cited: Inspired by Claude-3.5 Sonnet]
def create_animation():
    """Show an animated bar chart of the patching effect, one frame per layer.

    Reads the notebook-level ``ioi_patching_result``, ``num_positions`` and
    ``num_layers``; displays the figure and returns nothing.
    """
    positions = np.arange(num_positions)

    def layer_bars(layer_index):
        # One bar per position, with bar height and colour both taken from
        # that layer's row of effect values
        effects = ioi_patching_result[layer_index]
        return go.Bar(
            x=positions,
            y=effects,
            marker_color=effects,
            marker_colorscale='RdBu',
            marker_reversescale=True,
        )

    # One named animation frame per layer
    frames = [
        go.Frame(data=[layer_bars(layer)], name=f'Layer {layer}')
        for layer in range(num_layers)
    ]

    # "Play" button: 500 ms per frame with a full redraw, starting from the
    # currently displayed frame
    play_button = dict(
        label="Play",
        method="animate",
        args=[None, {"frame": {"duration": 500, "redraw": True},
                     "fromcurrent": True}],
    )
    layout = go.Layout(
        title='Activation Effect per Position',
        xaxis=dict(title='Position'),
        yaxis=dict(title='Effect'),
        updatemenus=[dict(type="buttons", buttons=[play_button])],
    )

    # The figure initially displays layer 0
    fig = go.Figure(data=[layer_bars(0)], layout=layout, frames=frames)
    fig.show()


create_animation()

The animation steps through the layers one at a time, showing how the per-position effect values change from layer to layer and making it easier to see where in the model the relevant information is processed.