In [None]:
# Generic
import html
import os
from pathlib import Path

# Transformers
from circuitsvis.attention import attention_heads
from circuitsvis.tokens import colored_tokens_multi

# Numerical Computing
import numpy as np
import torch

# Our Code
from maze_transformer.utils.notebook_utils import configure_notebook
from maze_transformer.generation.latticemaze import LatticeMaze
from maze_transformer.generation.generators import LatticeMazeGenerators
from maze_transformer.training.tokenizer import MazeTokenizer, SPECIAL_TOKENS
from maze_transformer.evaluation.plot_maze import plot_multi_paths, PathFormat
from maze_transformer.evaluation.eval_model import decode_maze_tokens_to_coords, load_model_with_configs

In [None]:
# Setup
device = configure_notebook(seed=42, dark_mode=True)
# We won't be training any models
torch.set_grad_enabled(False)

# Get latest model
# this should point towards a directory containing a run. If you don't have any runs, you can use `poetry run python scripts/create_dataset.py create ./data/maze 10 --grid_n=4`
run_path = Path("../data/maze/g4-n10")
assert run_path.exists(), f"Run path {run_path.as_posix()} does not exist"

# Collect every final checkpoint below the run directory and fail with a clear
# message if there is none (indexing an empty list would only give an IndexError).
model_candidates = sorted(run_path.glob("**/model.final.pt"), key=os.path.getmtime)
assert model_candidates, f"No model.final.pt found under {run_path.as_posix()}"

# candidates are sorted by modification time, so the last one is the most recent
model_path = model_candidates[-1].resolve()
model, cfg = load_model_with_configs(model_path)
maze_path = run_path / "maze_tokens.jsonl"

In [None]:
# generate a maze
grid_n: int = cfg.dataset_cfg.grid_n
maze: LatticeMaze = LatticeMazeGenerators.gen_dfs((grid_n, grid_n))
c_start = (0, 0)
c_end = (grid_n - 1, grid_n - 1)

# solve the maze explicitly, once; the same path is reused for the tokenizer
# below and for plotting the ground-truth solution later in the notebook
path_true = np.array(maze.find_shortest_path(
	c_start = c_start,
	c_end = c_end,
))

solved_maze: MazeTokenizer = MazeTokenizer(
	maze=maze,
	solution=path_true,
)

# tokenize the maze without its solution and append the start-of-path marker,
# so the model is prompted to produce the path itself
maze_only_tokens: list[str] = solved_maze.as_tokens(cfg.dataset_cfg.node_token_map , solution = False) + [ SPECIAL_TOKENS["start_path"] ]

print("maze tokens:", maze_only_tokens)

# map every token string to its integer id; no padding is applied
array_nopad = torch.tensor(
	[ cfg.dataset_cfg.tokenizer_map[t] for t in maze_only_tokens ],
	dtype=torch.int32,
	device="cpu",
)
array = array_nopad

In [None]:
# have the model predict some tokens
context_str: list[str] = maze_only_tokens

# escape for html
context_str = [ html.escape(t) for t in context_str ]

# `array` is already a tensor: convert it directly rather than wrapping it in
# `torch.tensor(...)`, which copy-constructs and emits a UserWarning.
# Cast to int64, add a leading batch dimension of size 1, move to the notebook device.
array_tensor = array.long().unsqueeze(0).to(device)
with torch.no_grad():
	logits, cache = model.run_with_cache(array_tensor)

# keep every cached activation whose hook name contains 'hook_pattern'
attentions = [w for k, w in cache.items() if 'hook_pattern' in k]
print(f"{logits.shape = }\n{len(attentions) = }\n{[x.shape for x in attentions] = }")

# `attentions` is a list of tensors, one per matching cache entry (presumably one per layer).
#  Each tensor is expected to have dimensions (1, n_heads, n_positions, n_positions) -- see the shapes printed above.


In [None]:
# Render every attention pattern, one panel per (layer, head) pair.
n_layers: int = len(attentions)
# dims 1 and 2 of a pattern tensor are read as the head count and the sequence length
n_heads, n_tokens = attentions[0].shape[1:3]

# join the per-layer tensors along dim 0, then collapse the leading dims so that
# each entry is a single (n_tokens, n_tokens) pattern
stacked_patterns = torch.concatenate(attentions, dim=0)
attention_to_plot = stacked_patterns.reshape(-1, n_tokens, n_tokens)

# labels are generated layer-major / head-minor, matching the order of the reshape above
attention_head_names = [
	f"Layer {layer_idx} Head {head_idx}"
	for layer_idx in range(n_layers)
	for head_idx in range(n_heads)
]
attention_heads(attention_to_plot, maze_only_tokens, attention_head_names)

In [None]:
#! ALEX note - an earlier version applied np.power(head_np, 1/4) here; the reason for it is unknown
# Position whose outgoing attention (towards the rest of the sequence) is visualised.
FROM_TOKEN = -1

# for every pattern tensor take batch element 0 and the row of the chosen position,
# then join the per-layer slices along dim 0
rows_per_layer = [w[0, :, FROM_TOKEN, :] for w in attentions]
attentions_from_token = torch.concatenate(rows_per_layer, dim=0)

# the transpose is passed so that tokens come first and heads second
colored_tokens_multi(
	context_str,
	attentions_from_token.T,
	labels=attention_head_names,
)

In [None]:
# Let the model continue the maze prompt; generation stops at the end-of-path
# token or after 50 new tokens, whichever comes first.
eos_id = cfg.dataset_cfg.tokenizer_map[SPECIAL_TOKENS["end_path"]]
predictions = model.generate(
	array_tensor,
	max_new_tokens=50,
	eos_token_id=eos_id,
	stop_at_eos=True,
)

# decode the tokens
# `.tolist()` turns the first (only) sequence into plain Python ints, so the
# lookup in `token_arr` does not depend on indexing with 0-d tensors
# NOTE(review): assumes `model.generate` returns a tensor of token ids -- confirm
predicted_tokens = [ cfg.dataset_cfg.token_arr[t] for t in predictions[0].tolist() ]

# drop the prompt and keep only what the model generated; tokens that are not
# coordinates are skipped
path_predicted: list[tuple[int,int]] = decode_maze_tokens_to_coords(
	predicted_tokens[len(maze_only_tokens):],
	mazedata_cfg = cfg.dataset_cfg, 
	when_noncoord = "skip",
)

# plot the maze and both solutions
plot_multi_paths(
	maze = maze,
	paths = [
		PathFormat(path_true, "true", "-", "red", {'width': 0.015}),
		PathFormat(np.array(path_predicted), "predicted", ":", "blue", {}),
	],
)