In [None]:
# Generic
from pathlib import Path
import typing
import json
from pathlib import Path
import typing
import html
import copy

# Visualization
import circuitsvis # keep this import since we need to disable circuitsvis in CI
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch

# Our Code
from muutils.nbutils.configure_notebook import configure_notebook
from muutils.mlutils import pprint_summary
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_transformer.mechinterp.plot_attention import ProcessedMazeAttention



In [None]:
# Notebook-wide setup: configure the environment, load a model from disk and
# build a small test dataset from the config stored with that model.
DEVICE = configure_notebook(seed=42, dark_mode=True)

# We won't be training any models, so autograd is switched off globally
torch.set_grad_enabled(False)

# relative directories for example models and for dataset files
PATH_EXAMPLES: Path = Path("../examples/")
PATH_DATA: Path = Path("../data/")

# Checkpoint to load. Alternative checkpoint path, currently disabled:
#   PATH_EXAMPLES / "multsrc_demo-g6-n10K-a_dfs-h92077_tiny-v1_sweep-v1_2023-05-20-21-30-02/model.final.zanj"
MODEL_PATH: Path = PATH_EXAMPLES.joinpath("model.hallway-jvq.final.zanj")
print(f"will try to get model from {MODEL_PATH.as_posix()}")

# read the model from the checkpoint file
MODEL: ZanjHookedTransformer = ZanjHookedTransformer.read(MODEL_PATH)
print(f"loaded model: {MODEL.zanj_model_config.name} with {MODEL.num_params()} parameters")

# Test dataset: deep-copy the dataset config held by the model (so the model's
# own config object is not modified), set it to 100 mazes, and build a dataset
# from that config.
DATASET_TEST_CFG: MazeDatasetConfig = copy.deepcopy(MODEL.zanj_model_config.dataset_cfg)
DATASET_TEST_CFG.n_mazes = 100
DATASET_TEST: MazeDataset = MazeDataset.from_config(
    DATASET_TEST_CFG, verbose=True, local_base_path=PATH_DATA
)
print(f"got test dataset: {DATASET_TEST.cfg.name} with {len(DATASET_TEST)} mazes")

# pretty-print the summary of the model config
pprint_summary(MODEL.zanj_model_config.model_cfg.summary())

In [None]:
# Process the attention data for a single maze: the model and test dataset are
# passed in with n_mazes=1, and the result is indexed with [0] to take its
# first entry.
ATTENTION_DATA: ProcessedMazeAttention = ProcessedMazeAttention.from_model_and_dataset(
    model=MODEL,
    dataset=DATASET_TEST,
    n_mazes=1,
)[0]

# print the summary of the attention data
pprint_summary(ATTENTION_DATA.summary())

# Show the actual maze we are looking at. The trailing semicolon stops the
# notebook from echoing the AxesImage repr underneath the figure.
plt.imshow(ATTENTION_DATA.input_maze.as_pixels());

In [None]:
# Plot the token sequence with a highlight showing which tokens are being
# attended to.
# NOTE(review): the "_multi" suffix presumably means several views (e.g. one
# per head/layer) are produced — confirm against ProcessedMazeAttention.
ATTENTION_DATA.plot_colored_tokens_multi()

In [None]:
# Going by the method name, this draws the attention values on top of the maze.
# The loaded model and the test dataset's config are passed through to it.
# NOTE(review): predict_path_len=10 presumably sets how many path tokens are
# predicted for the plot — confirm against plot_attentions_on_maze.
ATTENTION_DATA.plot_attentions_on_maze(
    predict_path_len=10,
    model=MODEL,
    dataset_cfg=DATASET_TEST.cfg,
)

In [None]:
# Plot the attention values for every pair of tokens, for each head on each
# layer.
ATTENTION_DATA.plot_attentions()