In [None]:
# Target google colab notebook
# This cell assumes it runs inside Google Colab: `import google.colab` raises
# ImportError anywhere else, and IN_COLAB is hard-coded to True rather than detected.
# NOTE(review): DEVELOPMENT_MODE is not read by any visible cell.
DEVELOPMENT_MODE = False
import google.colab
IN_COLAB = True
print("Running as a Colab notebook")
# Installs TransformerLens from the unpinned GitHub HEAD, so results can change
# between runs; consider pinning a release tag or commit for reproducibility.
%pip install git+https://github.com/neelnanda-io/TransformerLens.git


In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a plotly heatmap on a diverging RdBu scale centred at 0.

    Args:
        tensor: anything `utils.to_numpy` accepts.
        renderer: plotly renderer name forwarded to `Figure.show`.
        xaxis, yaxis: axis label text.
        **kwargs: extra keyword arguments forwarded to `px.imshow`.

    NOTE(review): `px` (plotly.express) is not imported in any visible cell;
    confirm it is imported before this helper is called.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a plotly line chart.

    Args:
        tensor: anything `utils.to_numpy` accepts.
        renderer: plotly renderer name forwarded to `Figure.show`.
        xaxis, yaxis: axis label text.
        **kwargs: extra keyword arguments forwarded to `px.line`.

    NOTE(review): `px` (plotly.express) is not imported in any visible cell;
    confirm it is imported before this helper is called.
    """
    values = utils.to_numpy(tensor)
    figure = px.line(values, labels=dict(x=xaxis, y=yaxis), **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a plotly scatter plot of `y` against `x`.

    Args:
        x, y: anything `utils.to_numpy` accepts.
        xaxis, yaxis, caxis: label text for the x axis, y axis and colour axis.
        renderer: plotly renderer name forwarded to `Figure.show`.
        **kwargs: extra keyword arguments forwarded to `px.scatter`.

    NOTE(review): `px` (plotly.express) is not imported in any visible cell;
    confirm it is imported before this helper is called.
    """
    x_values, y_values = (utils.to_numpy(series) for series in (x, y))
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    px.scatter(x=x_values, y=y_values, labels=axis_labels, **kwargs).show(renderer)

In [None]:
import torch
# NOTE(review): AutoTokenizer is imported but not used by any visible cell.
from transformers import GPTNeoXForCausalLM, AutoTokenizer

# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the Hugging Face Pythia-410m (deduped) model; the next cell passes it to
# HookedTransformer.from_pretrained as `hf_model`.
hfmodel = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")

In [None]:
# Wrap the Hugging Face weights loaded above in a hookable TransformerLens model on `device`.
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-410m-deduped",
    hf_model=hfmodel,
    device=device,
)

In [None]:
# Run the prompt once and record every intermediate activation in `cache`.
plaintext = "The capital of Ireland is called"
tokens = model.to_tokens(plaintext)

# remove_batch_dim=False keeps the leading batch axis on cached activations,
# so they can be fed straight back into individual layers later.
logits_out, cache = model.run_with_cache(
    tokens,
    remove_batch_dim=False,
)
logits_out.shape

In [None]:
#define a function to run a single transformer block
def run_block(model, tokens, block_idx, cache):
    """Run transformer block `block_idx` of `model` on a residual-stream tensor.

    The previous body indexed `model.hook_points`, called `HookPoint.run` and read
    `model.hf_model.transformer.h`. The TransformerLens API does not provide any of
    these as used, so the function could never execute. A HookedTransformer keeps its
    layers in `model.blocks`; calling one maps the block's input residual stream to
    its output residual stream.

    Args:
        model: a HookedTransformer (anything with an indexable `blocks` of callables).
        tokens: despite the name, the residual stream *entering* the block
            (e.g. `cache[f"blocks.{block_idx}.hook_resid_pre"]`), not token ids.
            Presumably shaped (batch, pos, d_model) — TODO confirm at call sites.
        block_idx: index into `model.blocks`.
        cache: not used by the computation; returned unchanged so callers that
            unpack `out, cache` keep working.

    Returns:
        `(out, cache)`, where `out` is the block's output residual stream.
    """
    block = model.blocks[block_idx]
    out = block(tokens)
    return out, cache



In [None]:
def make_logit_less_mean_embeds(target_position, target_tokenid, hooked_model=None, batch_idx=0):
    """Build an attribution target: one token's logit minus the mean logit at a position.

    The returned closure applies the model's final LayerNorm and unembedding to a
    final residual stream and returns
    `logits[batch_idx, target_position, target_tokenid] - mean(logits[batch_idx, target_position])`
    as a 1-element tensor. The result keeps that axis because the token is sliced,
    not indexed.

    Args:
        target_position: sequence position whose prediction is scored.
        target_tokenid: vocabulary id of the token whose logit is scored.
        hooked_model: object providing `ln_final` and `unembed`. When None (the
            default), the notebook-level `model` is used. It is looked up each time
            the closure is called, exactly as before.
        batch_idx: batch row to score. Defaults to 0, which was previously hard-coded.

    Returns:
        A function `last_residual -> tensor of shape (1,)`.
    """

    def logit_less_mean_embeds(last_residual):
        # Resolve lazily so a re-run of the model-loading cell is picked up.
        active_model = model if hooked_model is None else hooked_model
        normed_last_residual = active_model.ln_final(last_residual)
        # Call the module rather than `.forward` so registered nn.Module hooks also fire.
        lm_logits = active_model.unembed(normed_last_residual)
        logits_mean = torch.mean(lm_logits[batch_idx, target_position])
        return lm_logits[batch_idx, target_position, target_tokenid:target_tokenid + 1] - logits_mean

    return logit_less_mean_embeds

# Approach

The idea is, for each predicted token, to use attribution to identify a relevant subspace of the final-layer residual stream. Combining this subspace with the actual residual gives the relevant residual for that prediction. Because this residual comes after all attention blocks, only the position of the predicted token matters. Call the subspace `layername` and the relevant residual `instancename`. (Possible follow-up: check how many low-salience elements can be attributed while the predicted token is still maintained.)

Next, write a function that maps the residual stream of the next layer down (`blocks.nlayer-1.hook_resid_post`) to the final layer and returns a loss based on the distance from the relevant part of the top-layer residual.

In [None]:
def make_transformer_block_runner(block_idx):
    """Return a closure that runs transformer block `block_idx` on a residual tensor.

    The closure reads the notebook-level `model` and `cache` globals when it is
    called. It returns only the block output and discards the cache that
    `run_block` hands back.

    Args:
        block_idx: index of the block to run.

    Returns:
        A function `input_residuals -> output residuals`.
    """

    def transformer_block_runner(input_residuals):
        # Read `cache` without rebinding it. The former `out, cache = run_block(..., cache)`
        # made `cache` local to this closure, so passing it as an argument raised
        # UnboundLocalError on every call.
        out, _ = run_block(model, input_residuals, block_idx, cache)
        return out

    return transformer_block_runner

def run_block(model, tokens, block_idx, cache):
    """Run transformer block `block_idx` of `model` on a residual-stream tensor.

    NOTE(review): this cell redefines `run_block` from an earlier cell, and this later
    definition is the one in effect afterwards. Keep the two identical or delete one.

    The previous body indexed `model.hook_points`, called `HookPoint.run` and read
    `model.hf_model.transformer.h`. The TransformerLens API does not provide any of
    these as used, so the function could never execute. A HookedTransformer keeps its
    layers in `model.blocks`; calling one maps the block's input residual stream to
    its output residual stream.

    Args:
        model: a HookedTransformer (anything with an indexable `blocks` of callables).
        tokens: despite the name, the residual stream *entering* the block
            (e.g. `cache[f"blocks.{block_idx}.hook_resid_pre"]`), not token ids.
            Presumably shaped (batch, pos, d_model) — TODO confirm at call sites.
        block_idx: index into `model.blocks`.
        cache: not used by the computation; returned unchanged so callers that
            unpack `out, cache` keep working.

    Returns:
        `(out, cache)`, where `out` is the block's output residual stream.
    """
    block = model.blocks[block_idx]
    out = block(tokens)
    return out, cache