# Logit Lens and Direct Logit Attribution

In this notebook, we will apply direct logit attribution to our maze-transformer model. 

This is mostly just getting the data in an appropriate format and using techniques from Neel's explanatory analysis.

## Setup

In [1]:
# Generic
import os
from pathlib import Path
from copy import deepcopy
import typing

from jaxtyping import Float
from fancy_einsum import einsum
import torch
import torch.nn.functional as F

# Numerical Computing
import numpy as np
import torch
import einops

from muutils.misc import shorten_numerical_to_str
from muutils.nbutils.configure_notebook import configure_notebook
# TransformerLens imports
from transformer_lens import ActivationCache

# Our Code
from maze_dataset import MazeDataset, MazeDatasetConfig, SolvedMaze, LatticeMaze
from maze_transformer.tokenizer import HuggingMazeTokenizer
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer, BaseGPTConfig
from maze_transformer.evaluation.eval_model import  load_model_with_configs



In [2]:
# Setup
DEVICE: torch.device = configure_notebook(seed=42, dark_mode=True)

# We won't be training any models
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x1d9106be650>

## Loading the model in

Loading in Alex's successfully trained model, it uses an embedding dimension of 384 with 6 heads and 12 layers

In [3]:
MODEL: ZanjHookedTransformer = ZanjHookedTransformer.read("../examples/hallway-medium_2023-06-16-03-40-47.iter_26554.zanj")
num_params: int = MODEL.num_params()
print(f"loaded model with {shorten_numerical_to_str(num_params)} params ({num_params = })")
GPT_CONFIG: BaseGPTConfig = MODEL.zanj_model_config.model_cfg

loaded model with 1.3M params (num_params = 1274699)


## Dataset Creation

Creating a collection of mazes to have the model predict on

In [4]:
# get 100 mazes and pass into model, storing logits and cache
n_examples: int = 100
TEST_DATASET_CFG: MazeDatasetConfig = deepcopy(MODEL.zanj_model_config.dataset_cfg)
TEST_DATASET_CFG.n_mazes = n_examples
# hacky: adjust things to work
for filter in TEST_DATASET_CFG.applied_filters:
	if len(filter["args"]) == 0:
		filter["args"] = tuple()

DATASET: MazeDataset = MazeDataset.from_config(TEST_DATASET_CFG)
DATASET_TOKENS_UNJOINED: list[list[str]] = DATASET.as_tokens(join_tokens_individual_maze=False)

get the data up to the path start token, as well as the path start tokens themselves

In [None]:
get_token_first_index: typing.Callable[[str, list[str]], int] = lambda search_token, token_list: token_list.index(search_token)

DATASET_TOKENS_UNJOINED: list[str] = [

]

In [7]:
print(DATASET_TOKENS)

['<ADJLIST_START> (6,4) <--> (6,3) ; (4,7) <--> (3,7) ; (3,4) <--> (3,5) ; (3,3) <--> (4,3) ; (4,4) <--> (4,5) ; (3,7) <--> (3,6) ; (3,3) <--> (3,4) ; (4,7) <--> (4,6) ; (4,2) <--> (4,1) ; (6,5) <--> (5,5) ; (3,5) <--> (3,6) ; (4,6) <--> (4,5) ; (5,1) <--> (4,1) ; (6,4) <--> (6,5) ; (5,1) <--> (5,2) ; (5,2) <--> (5,3) ; (5,3) <--> (5,4) ; (4,3) <--> (4,2) ; (5,4) <--> (5,5) ; <ADJLIST_END> <ORIGIN_START> (3,4) <ORIGIN_END> <TARGET_START> (4,1) <TARGET_END> <PATH_START> (3,4) (3,3) (4,3) (4,2) (4,1) <PATH_END>', '<ADJLIST_START> (3,2) <--> (2,2) ; (1,7) <--> (0,7) ; (4,7) <--> (5,7) ; (5,6) <--> (4,6) ; (4,4) <--> (4,3) ; (6,6) <--> (5,6) ; (6,2) <--> (6,3) ; (6,0) <--> (5,0) ; (3,2) <--> (3,1) ; (4,0) <--> (3,0) ; (1,5) <--> (2,5) ; (6,5) <--> (7,5) ; (6,7) <--> (5,7) ; (6,6) <--> (7,6) ; (4,6) <--> (3,6) ; (7,6) <--> (7,7) ; (2,7) <--> (3,7) ; (6,5) <--> (5,5) ; (7,2) <--> (7,1) ; (7,5) <--> (7,4) ; (7,1) <--> (7,0) ; (3,7) <--> (3,6) ; (3,4) <--> (3,5) ; (0,6) <--> (0,7) ; (1,6) <-->

In [5]:
LOGITS, CACHE = MODEL.run_with_cache(DATASET_TOKENS)

In [6]:
def token_id_find(mazes_tokens, specific_tok):
    '''
    Returns a list of token indexes for a specific token over a batch of mazes.
    '''

    path_start = []
    for maze in mazes_tokens:
        for idx, tok in enumerate(maze):
            if tok == specific_tok:
                path_start.append(idx)
    
    if len(path_start) > mazes_tokens.shape[0]:
        return f'More than one intance of token in one or more mazes'
    else:
        return path_start

def pad_tensor_to_shape(tensor, target_shape, pad_value=10):
    '''
    Used to pad the front of a tokenized maze so examples are all the same length
    '''
    padding = []
    for i in range(len(tensor.shape)-1, -1, -1):
        total_padding = target_shape[i] - tensor.shape[i]
        padding.extend([total_padding, 0])

    return F.pad(tensor, padding, value=pad_value)

# Get a list of the <PATH_START> index for each maze example
path_start = token_id_find(mazes_tokens=mazes_tokens, specific_tok=6)

# Using the <PATH_START> index, strip the tokenized maze of everything after <PATH_START> (so this is the last token)
maze_only_tokens = []
for idx, tok in enumerate(path_start):
    maze_only_tokens.append(mazes_tokens[idx][:tok+1])

# Pad the front of this stripped tokenized maze so that they are all the same length
padded_tokens = []
for maze in maze_only_tokens:
    padded_tokens.append(pad_tensor_to_shape(maze, (166,), 10))

# Make this list of examples into a stacked tensor on the correct device and appropriate dtype
padded_tokens = torch.stack(padded_tokens).long().to(device)

NameError: name 'mazes_tokens' is not defined

## Model Predictions
Using the model to make predictions and caching associated activations

In [None]:
# Have the model predict on the maze examples, storing logits and activations in cache
with torch.no_grad():
	logits, cache = MODEL.run_with_cache(padded_tokens)

In [None]:
# For this architecture, this should of length 208
len(cache.cache_dict.keys())

Logits are in shape [100, 166, 47] corresponding to batch = 100 (100 maze examples), sequence_length = 166 and vocab size = 47.

We want to predict on the next token (the first path coordinate), thus logits associated with final token.

In [None]:
# Get the last token prediction from the model
last_token_logits = logits[:, -1, :]
predictions = []
for sample in last_token_logits:
    last_token_pred = torch.argmax(sample).item()
    predictions.append(last_token_pred)
predictions = torch.tensor(predictions)
print(f'Prediction from first maze: {predictions[0]}, shape of predictions: {predictions.shape}')

In [None]:
# Are the maze tokens the same?
print(f'Are all mazes the same in the maze tokens dataset? {padded_tokens.all()}')

## Measuring model performance

**The rest of this notebook is mainly from https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/main/Exploratory_Analysis_Demo.ipynb just applying it to our mazes**

`answer_tokens` is just a list of [correct, incorrect] tokens.

In [None]:
# Lets create an list of for these predictions in the [correct, incorrect] format.
# The correct token is taken from the maze definition as first token after <PATH_START>, the incorrect token is set to the <TARGET>.

answer_tokens = []
for maze in mazes_tokens:
    for idx, tok in enumerate(maze):
        if tok == 6:
            answer_tokens.append([maze[idx+1], maze[len(maze)-2]])
    
answer_tokens = torch.tensor(answer_tokens).to(device=device)
answer_tokens.shape

It shows us the models performance by comparing the logits associated with a correct response minus those with an incorrect response in this case (1, 0).

In [None]:
# From Neels explanatory notebook: https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/main/Exploratory_Analysis_Demo.ipynb
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()

print("Per prompt logit difference:", logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=True))
original_average_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)
print("Average logit difference:", logits_to_ave_logit_diff(logits, answer_tokens).item())

## Mapping tokens into the model's residual stream

In [None]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print(f"Answer residual directions shape: {answer_residual_directions.shape}")

In [None]:
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print(f"Logit diff directions shape: {logit_diff_directions.shape}")

In [None]:
# cache the values at the end of the residual stream
final_residual_stream = cache["resid_post", -1]
print(f"Final reisudal stream shape: {final_residual_stream.shape}")

In [None]:
# Get the final token resid stream values (like we did above with last_pred_token)
final_token_residual_stream = final_residual_stream[:, -1, :]
print(f'Final token residual stream value shape: {final_token_residual_stream.shape}')

In [None]:
# Scaling the values in residual stream with layer norm
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

In [None]:
# Average logit diff from residual stream method
average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(answer_tokens)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",original_average_logit_diff.item())

These match quite closely meaning that the residual stream has been correctly scaled.

## Logit Lens

This implementation is directly from Neel's Exploratory Analysis notebook, found here:
https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/main/Exploratory_Analysis_Demo.ipynb

In [None]:
def residual_stack_to_logit_diff(residual_stack: Float[torch.Tensor, "components batch d_model"], cache: ActivationCache) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(answer_tokens)

In [None]:
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
logit_lens_cpu = logit_lens_logit_diffs.to("cpu")
y = logit_lens_cpu.numpy()
x = np.arange(GPT_CONFIG.n_layers*2+1)/2
print(type(x), type(y))

In [None]:
import matplotlib.pyplot as plt
plt.plot(x, y)
plt.title("Logit Difference from Accumulated Residual Stream")
plt.show()

In [None]:
# Layer Attribution
per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
y = per_layer_logit_diffs.to("cpu").numpy()
x = np.arange(len(y))
print(type(x), type(y))

In [None]:
import matplotlib.pyplot as plt
plt.plot(x, y)
plt.title("Logit Difference for each layer")
plt.show()

## Direct Logit Attribution
Again from Neel's exploratory analysis notebook.

In [None]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=GPT_CONFIG.n_layers, head_index=GPT_CONFIG.n_heads)
data = per_head_logit_diffs.to("cpu").numpy()
plt.imshow(data, cmap = "RdBu")
plt.colorbar()
plt.title("Logit Difference from each head")
plt.show()