# Testing for backup behavior in general

# Imports

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor

import numpy as np
import pandas as pd
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
# import plotly.express as px
from torch.utils.data import DataLoader
from typing import Union, List, Optional, Callable, Tuple, Dict, Literal, Set
from jaxtyping import Float, Int
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.utils import to_numpy
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache, patching

import plotly.express as px
import circuitsvis as cv
import os, sys

In [2]:
# Fetch the `path_patching` helper modules (path_patching.py, ioi_dataset.py) from GitHub
# if they are not already in the working directory. Downloading and extracting in pure
# Python (instead of `!wget`/`!unzip` with hard-coded `/content/` paths) makes this work
# from any working directory, not only Colab's. The modules are written to the working
# directory, which is normally on sys.path for a notebook kernel.
if not os.path.exists("path_patching.py"):
    import shutil
    import urllib.request
    import zipfile

    # Without a target filename, urlretrieve downloads to a temporary file and returns its path.
    # (Indexing [0] avoids assigning to `_`, which IPython uses for the last output.)
    _zip_path = urllib.request.urlretrieve(
        "https://github.com/callummcdougall/path_patching/archive/refs/heads/main.zip"
    )[0]
    try:
        with zipfile.ZipFile(_zip_path) as _archive:
            for _module_file in ("ioi_dataset.py", "path_patching.py"):
                # GitHub branch archives nest every file under "<repo>-<branch>/".
                with _archive.open(f"path_patching-main/{_module_file}") as _src:
                    with open(_module_file, "wb") as _dst:
                        shutil.copyfileobj(_src, _dst)
    finally:
        # Always clean up the downloaded archive, even if extraction fails.
        os.remove(_zip_path)

from path_patching import Node, IterNode, path_patch, act_patch

In [3]:
from neel_plotly import imshow, line, scatter, histogram
import tqdm
torch.set_grad_enabled(False)  # analysis is inference-only, so skip building autograd graphs
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Show the device selected above ('cuda' when a GPU is available, otherwise 'cpu').
device

'cuda'

In [65]:
# Keyword-argument names that `imshow` routes to `fig.update_layout` instead of `px.imshow`.
# Per-axis settings are generated for both the x and y axis; the rest are figure-wide.
update_layout_set = {
    f"{axis}_{setting}"
    for axis in ("xaxis", "yaxis")
    for setting in (
        "range", "title", "tickformat", "tickmode", "tickangle",
        "showgrid", "gridwidth", "gridcolor", "visible",
    )
} | {
    "hovermode", "colorbar", "colorscale", "coloraxis",
    "title_x", "title_y", "legend_title_text", "showlegend",
    "bargap", "bargroupgap", "margin",
}

def imshow(tensor, return_fig = False, renderer=None, **kwargs):
    """Show `tensor` as a plotly-express heatmap and optionally return the figure.

    Keyword arguments whose name is in `update_layout_set` are applied afterwards with
    `fig.update_layout`; every other keyword goes straight to `px.imshow`. Extra options
    handled here:
      * facet_labels: list of strings written over the facet subplot titles, in order.
      * border: if True, draw a thin black frame around every subplot's axes.
      * margin: an int is expanded to the same margin on all four sides.
    Unless overridden, the colour scale is "RdBu" centred at 0.

    Returns the figure when `return_fig` is truthy, otherwise None.
    """
    px_kwargs, layout_kwargs = {}, {}
    for key, value in kwargs.items():
        (layout_kwargs if key in update_layout_set else px_kwargs)[key] = value

    facet_labels = px_kwargs.pop("facet_labels", None)
    border = px_kwargs.pop("border", False)
    px_kwargs.setdefault("color_continuous_scale", "RdBu")
    px_kwargs.setdefault("color_continuous_midpoint", 0.0)

    margin = layout_kwargs.get("margin")
    if isinstance(margin, int):
        layout_kwargs["margin"] = {side: margin for side in "tblr"}

    fig = px.imshow(utils.to_numpy(tensor), **px_kwargs)

    if facet_labels:
        for idx, label in enumerate(facet_labels):
            fig.layout.annotations[idx]["text"] = label

    if border:
        frame = dict(showline=True, linewidth=1, linecolor="black", mirror=True)
        fig.update_xaxes(**frame)
        fig.update_yaxes(**frame)

    # Facet subplots get extra x axes named xaxis2, xaxis3, ...; copy the requested
    # x tick angle onto each of them so every facet matches. Only tickangle is propagated.
    if "xaxis_tickangle" in layout_kwargs:
        axis_idx = 2
        while f"xaxis{axis_idx}" in fig["layout"]:
            layout_kwargs[f"xaxis{axis_idx}_tickangle"] = layout_kwargs["xaxis_tickangle"]
            axis_idx += 1

    fig.update_layout(**layout_kwargs)
    fig.show(renderer=renderer)
    return fig if return_fig else None

# def hist(tensor, renderer=None, **kwargs):
#     kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
#     kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
#     names = kwargs_pre.pop("names", None)
#     if "barmode" not in kwargs_post:
#         kwargs_post["barmode"] = "overlay"
#     if "bargap" not in kwargs_post:
#         kwargs_post["bargap"] = 0.0
#     if "margin" in kwargs_post and isinstance(kwargs_post["margin"], int):
#         kwargs_post["margin"] = dict.fromkeys(list("tblr"), kwargs_post["margin"])
#     fig = px.histogram(x=tensor, **kwargs_pre).update_layout(**kwargs_post)
#     if names is not None:
#         for i in range(len(fig.data)):
#             fig.data[i]["name"] = names[i // 2]
#     fig.show(renderer)

In [6]:
from plotly import graph_objects as go
from plotly.subplots import make_subplots

In [7]:
# Load GPT-2 small into TransformerLens with its weight-processing options enabled
# (LN folding, weight centring, factored-attention refactor — see the
# HookedTransformer.from_pretrained documentation for their exact effect).
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Set up OpenWebText and test for backup in general

In [8]:
# Load OpenWebText through TransformerLens; the output below shows it resolves to the
# locally cached `openwebtext-10k` dataset.
owt_dataset = utils.get_dataset("openwebtext")

Found cached dataset openwebtext-10k (/data/cody_rushing/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)


In [9]:
# Peek at the raw text of the first document.
owt_dataset[0]["text"]

"A magazine supplement with an image of Adolf Hitler and the title 'The Unreadable Book' is pictured in Berlin. No law bans “Mein Kampf” in Germany, but the government of Bavaria, holds the copyright and guards it ferociously. (Thomas Peter/REUTERS)\n\nThe city that was the center of Adolf Hitler’s empire is littered with reminders of the Nazi past, from the bullet holes that pit the fronts of many buildings to the hulking Luftwaffe headquarters that now house the Finance Ministry.\n\nWhat it doesn’t have, nor has it since 1945, are copies of Hitler’s autobiography and political manifesto, “Mein Kampf,” in its bookstores. The latest attempt to publish excerpts fizzled this week after the Bavarian government challenged it in court, although an expurgated copy appeared at newspaper kiosks around the country.\n\nBut in Germany — where keeping a tight lid on Hitler’s writings has become a rich tradition in itself — attitudes toward his book are slowly changing, and fewer people are objecti

In [117]:
# Tokenise the first 16 documents: rows 0-7 form the clean batch and rows 8-15 the
# corrupted batch used as the patching source in act_patch further down.
# NOTE(review): the documents differ in length, yet the batch is a single (8, 1024) tensor,
# so to_tokens presumably pads/truncates them; padded positions are not masked anywhere
# below (loss, accuracy, direct-effect histograms) — confirm this is intended.
all_owt_tokens = model.to_tokens(owt_dataset[0:16]["text"])
owt_tokens, corrupted_owt_tokens = all_owt_tokens[0:8], all_owt_tokens[8:16]

In [119]:
# Forward pass on the clean batch, keeping both the logits and the activation cache.
logits, cache = model.run_with_cache(owt_tokens)

In [120]:
# Model accuracy on the clean batch (a single scalar, per the output below — not a per-token loss).
utils.lm_accuracy(logits, owt_tokens)

tensor(0.2791, device='cuda:0')

In [121]:
# Cross-entropy loss of the model on the clean batch (a single scalar, per the output below).
utils.lm_cross_entropy_loss(logits, owt_tokens)

tensor(5.7209, device='cuda:0')

In [122]:
# (batch, seq) shape of the clean token batch.
owt_tokens.shape

torch.Size([8, 1024])

# Helper Functions

In [131]:
def residual_stack_to_direct_effect(
    residual_stack: Float[Tensor, "... batch d_model"],
    cache: ActivationCache,
    effect_directions: Float[Tensor, "batch d_model"],
) -> Float[Tensor, "..."]:
    '''
    Batch-averaged direct effect of each component of `residual_stack` along `effect_directions`.

    The stack is first normalised with `cache.apply_ln_to_stack` using the layer=-1 scale
    at the last position (pos_slice=-1), so it is intended for components read out at the
    final position. Each component is then dotted with the matching example's direction,
    summed over the batch and divided by the batch size.
    '''
    n_examples = residual_stack.size(-2)
    normalised_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    total_effect = einops.einsum(
        normalised_stack,
        effect_directions,
        "... batch d_model, batch d_model -> ...",
    )
    return total_effect / n_examples

def multiple_pos_residual_stack_to_direct_effect(
    residual_stack: Float[Tensor, "... pos batch d_model"],
    cache: ActivationCache,
    effect_directions: Float[Tensor, "pos batch d_model"],
) -> Float[Tensor, "... pos"]:
    '''
    Per-position, batch-averaged direct effect of each component of `residual_stack`
    along `effect_directions`.

    Each position of the stack is normalised with the layer=-1 LayerNorm scale that
    `cache` recorded at that same position, then dotted with the matching (pos, batch)
    direction, summed over the batch and divided by the batch size.

    Args:
        residual_stack: per-component contributions laid out (..., pos, batch, d_model).
            Assumed to cover the leading positions 0..pos-1 of the cached run.
        cache: activation cache of the run that produced `residual_stack`.
        effect_directions: one residual-stream direction per (pos, batch) entry.

    Returns:
        Tensor of shape (..., pos).
    '''
    n_pos, batch_size = residual_stack.size(-3), residual_stack.size(-2)
    # apply_ln_to_stack divides by a scale laid out (batch, pos, 1), so put batch before pos.
    stack_batch_pos = einops.rearrange(
        residual_stack, "... pos batch d_model -> ... batch pos d_model"
    )
    # Use each position's own scale (positions 0..n_pos-1); pos_slice=-1 would divide
    # every position by the last position's scale.
    scaled_residual_stack = cache.apply_ln_to_stack(
        stack_batch_pos, layer=-1, pos_slice=slice(0, n_pos)
    )

    return einops.einsum(
        scaled_residual_stack, effect_directions,
        "... batch pos d_model, pos batch d_model -> ... pos"
    ) / batch_size

In [132]:
def collect_direct_effect(cache, correct_tokens, title = "Direct Effect of Heads", display = True, display_positions = 100):
    '''
    Direct effect of every attention head on the logit of the token that actually comes next.

    The output of each head at position p is projected (via
    `multiple_pos_residual_stack_to_direct_effect`) onto the unembedding direction of
    correct_tokens[:, p + 1], the token that position p predicts. The last position has
    no next token and is dropped, so the position axis has length seq - 1.

    Args:
        cache: ActivationCache from running the global `model` on `correct_tokens`
            (or a patched variant of that run).
        correct_tokens: (batch, seq) input token ids.
        title: figure title (display mode only).
        display: if True, show a (layer, head) heatmap of the effect averaged over the
            first `display_positions` positions and return None.
        display_positions: number of leading positions averaged for the heatmap.

    Returns:
        None if `display` is True; otherwise a (n_layers * n_heads, seq - 1) tensor of
        batch-averaged direct effects, with heads ordered layer by layer.
    '''
    # (n_layers * n_heads, batch, seq, d_model): each head's output before the final LayerNorm.
    per_head_residual = cache.stack_head_results(layer = -1, return_labels = False, apply_ln = False)
    # Position p predicts token p + 1: keep the head outputs at positions 0..seq-2 ...
    per_head_residual = einops.rearrange(
        per_head_residual[:, :, :-1],
        "head batch seq_pos d_model -> head seq_pos batch d_model",
    )
    # ... and score them against the tokens at positions 1..seq-1.
    next_token_directions = model.tokens_to_residual_directions(correct_tokens[:, 1:])
    next_token_directions = einops.rearrange(next_token_directions, "batch seq_pos d_model -> seq_pos batch d_model")

    per_head_direct_effect: Float[Tensor, "heads pos"] = multiple_pos_residual_stack_to_direct_effect(
        per_head_residual, cache, next_token_directions
    )

    if not display:
        return per_head_direct_effect

    mean_direct_effect = per_head_direct_effect[:, :display_positions].mean(dim = 1)
    mean_direct_effect = einops.rearrange(
        mean_direct_effect,
        "(n_layer n_heads_per_layer) -> n_layer n_heads_per_layer",
        n_layer = model.cfg.n_layers,
        n_heads_per_layer = model.cfg.n_heads,
    )
    imshow(
        torch.stack([mean_direct_effect]),
        facet_col = 0,
        facet_labels = ["Direct Effect of Heads"],
        title=title,
        labels={"x": "Head", "y": "Layer", "color": "Logit Contribution"},
        border=True,
        width=500,
        margin={"r": 100, "l": 100},
    )
    return None

In [136]:
def return_item(item):
    """Identity function: hand `item` back unchanged.

    Passed to act_patch as its "metric" so that the patched result itself (the cache)
    is returned as-is.
    """
    return item


# Clean vs Ablated Direct Effects

In [133]:
# Heatmap of each head's direct effect on the clean run, averaged over the first 100 positions.
collect_direct_effect(cache, owt_tokens)

In [134]:
# display=False returns the flat (n_layers * n_heads, pos) effects with heads stacked layer
# by layer; unflattening the first axis recovers the (layer, head, pos) grid.
per_head_direct_effect = collect_direct_effect(cache, owt_tokens, display = False).unflatten(
    0, (model.cfg.n_layers, model.cfg.n_heads)
)

In [141]:
# Overlaid histograms of the per-position direct effect (clean run) of every head in the
# final layer (layer 11 for GPT-2 small), one trace per head.
final_layer = model.cfg.n_layers - 1

fig = go.Figure()
fig.update_layout(barmode='overlay')  # draw the histograms on top of each other

for head in range(model.cfg.n_heads):
    head_effect = per_head_direct_effect[final_layer, head, :]
    fig.add_trace(
        go.Histogram(
            x=head_effect.cpu().numpy(),
            name=f"({final_layer}.{head})",
            opacity=0.6,
            # Maximum bin count, scaled with this head's value range (~160 bins per logit).
            nbinsx=int(160 * (head_effect.max().item() - head_effect.min().item())),
        )
    )

fig.update_layout(
    title="Direct Effect of Heads",
    xaxis_title="Logit Contribution",
    yaxis_title="Count",
    width=1000,
    height=600,
)
fig.update_xaxes(range=[-1, 1])  # zoom to [-1, 1]; values outside this window are not shown
fig.show()

In [146]:
# Run the clean batch with the "z" activation of Node("z", 10, 5) — presumably head 10.5's
# hook_z — patched in from the corrupted batch. return_item is the identity, and with
# apply_metric_to_cache=True the patched ActivationCache itself comes back (see output below).
new_cache = act_patch(model, owt_tokens, [Node("z", 10, 5)], return_item, corrupted_owt_tokens, apply_metric_to_cache= True)

In [147]:
# Inspect the patched cache (its repr lists the cached hook names).
new_cache

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_re

In [148]:
# Heatmap of each head's direct effect after the 10.5 patch. In display mode
# collect_direct_effect only draws the figure and returns None, so nothing is assigned.
collect_direct_effect(new_cache, owt_tokens, display = True)


Tried to stack head results when they weren't cached. Computing head results now


In [149]:
# Same (layer, head, pos) grid as for the clean run, computed from the patched cache:
# the flat head axis (stacked layer by layer) is unflattened into (layer, head).
ablated_per_head_direct_effect = collect_direct_effect(new_cache, owt_tokens, display = False).unflatten(
    0, (model.cfg.n_layers, model.cfg.n_heads)
)

In [150]:
# Overlaid histograms of head 11.8's per-position direct effect on the clean run versus the
# run whose cache came from the act_patch call above (head 10.5's z patched from the
# corrupted batch).
histogram_series = {
    "(11.8)": per_head_direct_effect[11, 8, :],
    "Ablated (11.8)": ablated_per_head_direct_effect[11, 8, :],
}

fig = go.Figure()
fig.update_layout(barmode='overlay')  # draw the two histograms on top of each other

for series_name, series_values in histogram_series.items():
    value_span = series_values.max().item() - series_values.min().item()
    fig.add_trace(
        go.Histogram(
            x=series_values.cpu(),
            name=series_name,
            opacity=0.6,
            nbinsx=int(100 * value_span),  # maximum bin count scales with the spread of the values
        )
    )

fig.update_layout(
    title="Direct Effect of Heads",
    xaxis_title="Logit Contribution",
    yaxis_title="Count",
    width=1000,
    height=600,
)
fig.update_xaxes(range=[-0.2, 1])
fig.show()