In [None]:
# Generic
import html
import os
from pathlib import Path

# Transformers
from circuitsvis.attention import attention_heads
from circuitsvis.tokens import colored_tokens_multi
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch

# Our Code
from maze_transformer.utils.notebook_utils import configure_notebook
from maze_transformer.generation.lattice_maze import LatticeMaze, SolvedMaze
from maze_transformer.generation.generators import LatticeMazeGenerators
from maze_transformer.training.tokenizer import SPECIAL_TOKENS, HuggingMazeTokenizer
from maze_transformer.evaluation.plot_maze import MazePlot, PathFormat
from maze_transformer.evaluation.eval_model import load_model_with_configs
from maze_transformer.utils.token_utils import tokens_to_coords
from maze_transformer.generation.constants import SPECIAL_TOKENS

In [None]:
# Setup
device = configure_notebook(seed=42, dark_mode=True)
# We won't be training any models
torch.set_grad_enabled(False)


# Get latest model
# `run_path` should point towards a directory containing a run.
# If you don't have any runs, you can create a dataset with `poetry run python scripts/create_dataset.py create ./data/maze 10 --grid_n=4`
# Then train a model with `poetry run python scripts/train_model.py ./data/maze/g4-n10`
run_path = Path("../examples/maze/g4-n10")
assert run_path.exists(), f"Run path {run_path.as_posix()} does not exist"

# Collect every final checkpoint below the run directory and fail with a clear
# message (rather than a bare IndexError) when training has not produced one yet.
model_checkpoints = sorted(run_path.glob("**/model.final.pt"), key=os.path.getmtime)
assert model_checkpoints, f"No model.final.pt found under {run_path.as_posix()}"

# The list is sorted by modification time, so the last entry is the newest checkpoint.
model_path = model_checkpoints[-1].resolve()
model, cfg = load_model_with_configs(model_path)
maze_path = run_path / "maze_tokens.jsonl"

In [None]:
# Build a square maze whose side length matches the dataset the model was trained on
grid_n: int = cfg.dataset_cfg.grid_n
maze: LatticeMaze = LatticeMazeGenerators.gen_dfs((grid_n, grid_n))

# Ground-truth solution: shortest path from the (0, 0) corner to the opposite corner
c_start = (0, 0)
c_end = (grid_n - 1, grid_n - 1)
path_true = np.array(maze.find_shortest_path(c_start=c_start, c_end=c_end))
solved_maze: SolvedMaze = SolvedMaze.from_lattice_maze(
	lattice_maze=maze,
	solution=path_true,
)

# Tokenize the solved maze, then keep everything up to and including the
# path-start marker so the model is left to produce the path itself
tokens = solved_maze.to_tokens(cfg.dataset_cfg.node_token_map)
path_start_index = tokens.index(SPECIAL_TOKENS["path_start"])
maze_only_tokens = tokens[: path_start_index + 1]
print("maze tokens:", maze_only_tokens)

# Convert the space-joined prompt into model input ids (no BOS token prepended)
maze_prompt = " ".join(maze_only_tokens)
array = model.to_tokens(maze_prompt, prepend_bos=False)

In [None]:
# Run the model on the maze prompt and collect the attention patterns

# The token strings are rendered as HTML by the visualisations, so escape them
context_str: list[str] = [html.escape(t) for t in maze_only_tokens]

# `torch.as_tensor` reuses `array` when it is already a tensor; `torch.tensor`
# would copy-construct it and emit a UserWarning in that case.
array_tensor = torch.as_tensor(array).long().to(device)
with torch.no_grad():
	logits, cache = model.run_with_cache(array_tensor)

# Keep only the cached activations whose hook name contains 'hook_pattern'
attentions = [w for k, w in cache.items() if 'hook_pattern' in k]
print(f"{logits.shape = }\n{len(attentions) = }\n{[x.shape for x in attentions] = }")

# `attentions` is a list of tensors, one per cached attention-pattern hook (i.e. one per layer).
#  Each tensor is expected to have dimensions (1, n_heads, n_positions, n_positions);
#  the shapes printed above show the actual values.

In [None]:
# Read the layer / head / sequence sizes off the collected attention patterns
first_pattern = attentions[0]
n_layers: int = len(attentions)
n_heads: int = first_pattern.shape[1]
n_tokens: int = first_pattern.shape[2]

# Join the per-layer patterns along the leading axis, then collapse the leading
# axes so that every head becomes its own (n_tokens, n_tokens) matrix
stacked_patterns = torch.concatenate(attentions, dim=0)
attention_to_plot = stacked_patterns.reshape(-1, n_tokens, n_tokens)

# Labels are generated layer-major, matching the row order produced by the reshape
attention_head_names = [
	f"Layer {layer_idx} Head {head_idx}"
	for layer_idx in range(n_layers)
	for head_idx in range(n_heads)
]
attention_heads(attention_to_plot, maze_only_tokens, attention_head_names)

In [None]:
#! ALEX note - an earlier version applied np.power(head_np, 1/4) here; the reason is unclear
# Position whose outgoing attention (to the rest of the sequence) is visualised
FROM_TOKEN = -1

# For each layer take the rows of every head at FROM_TOKEN, then join all layers
# into a single (heads, positions) matrix
rows_per_layer = [w[0, :, FROM_TOKEN, :] for w in attentions]
attentions_from_token = torch.concatenate(rows_per_layer, dim=0)

# The transposed matrix puts tokens on the first axis and heads on the second
colored_tokens_multi(
	context_str,
	attentions_from_token.T,
	labels=attention_head_names,
)

In [None]:
def prediction_contained_a_coordinate_token(tokens: list[str], tokenizer: "HuggingMazeTokenizer") -> bool:
	"""Check whether a prediction contains at least one coordinate token.

	Any token that is not one of the tokenizer's special tokens counts as a
	coordinate token.

	Args:
		tokens: predicted string tokens to inspect.
		tokenizer: object exposing `special_tokens_map` (a mapping whose values are
			special tokens) and `additional_special_tokens` (an iterable of tokens).

	Returns:
		True if at least one token is not a special token. Otherwise prints a
		failure message and returns False (this includes an empty `tokens` list).
	"""
	# Build the lookup once instead of re-concatenating the lists for every token.
	special_tokens: set[str] = set(tokenizer.additional_special_tokens)
	for value in tokenizer.special_tokens_map.values():
		# A map entry may itself be a list of tokens; flatten it so that its
		# members are matched individually.
		if isinstance(value, (list, tuple)):
			special_tokens.update(value)
		else:
			special_tokens.add(value)

	if any(t not in special_tokens for t in tokens):
		return True
	print("FAIL: Sampled a path - No coordinate token found before EOS")
	return False

# Sample completions until one contains a coordinate token. The number of
# attempts is capped so that a model which never produces a coordinate makes
# this cell fail loudly instead of looping forever.
MAX_SAMPLE_ATTEMPTS: int = 100
for attempt in range(MAX_SAMPLE_ATTEMPTS):
	predictions = model.generate(array_tensor, max_new_tokens=50, stop_at_eos=True, verbose=False)
	# Drop the prompt tokens so that only the newly generated ones remain
	predicted_tokens = model.to_str_tokens(predictions)[len(maze_only_tokens):]
	if prediction_contained_a_coordinate_token(predicted_tokens, model.tokenizer):
		break
else:
	# Only reached when no attempt hit the `break` above
	raise RuntimeError(
		f"Model did not predict any coordinate token in {MAX_SAMPLE_ATTEMPTS} sampled paths"
	)
print("SUCCESS: Model predicted the path:")
print(predicted_tokens)

# Convert the predicted tokens to coordinates; non-coordinate tokens are skipped
path_predicted: list[tuple[int,int]] = tokens_to_coords(
	predicted_tokens,
	maze_data_cfg = cfg.dataset_cfg, 
	when_noncoord = "skip",
)

# plot the maze and both solutions
MazePlot(maze).add_true_path(path_true).add_predicted_path(path_predicted).plot()
plt.show()