In [None]:
# Generic
import typing
import os
from pathlib import Path
import typing
import html

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch
import pandas as pd
# Utilities
from muutils.statcounter import StatCounter

# Our Code
from maze_transformer.utils.notebook_utils import configure_notebook
from maze_transformer.generation.lattice_maze import LatticeMaze, SolvedMaze
from maze_transformer.evaluation.plot_maze import PathFormat, MazePlot
from maze_transformer.training.maze_dataset import MazeDataset, MazeDatasetConfig
from maze_transformer.evaluation.eval_model import (
    load_model_with_configs,
    predict_maze_paths,
    evaluate_model
)
from maze_transformer.evaluation.baseline_models import RandomBaseline
from maze_transformer.evaluation.path_evals import (
    PathEvals,
    PathEvalFunction,
)

from maze_transformer.utils.utils import set_reproducibility, chunks, get_checkpoint_paths_for_run

In [None]:
# Setup
device = configure_notebook(seed=42, dark_mode=True)

# Inference only: we won't be training any models, so disable autograd globally
torch.set_grad_enabled(False)

# Get latest model
# `run_path` should point towards a directory containing a run.
# If you don't have any runs, you can create a dataset with `poetry run python scripts/create_dataset.py create ./data/maze 10 --grid_n=4`
# Then train a model with `poetry run python scripts/train_model.py ./data/maze/g4-n10`
run_path = Path("../examples/maze/g4-n10/g4-n10_tiny-v1_2023-03-28-22-44-54")
assert run_path.exists(), f"Run path {run_path.as_posix()} does not exist"

# Pick the most recently modified final checkpoint anywhere below the run directory.
# Fail with a readable message (instead of a bare IndexError) if there is none.
final_model_paths = sorted(run_path.glob("**/model.final.pt"), key=os.path.getmtime)
assert final_model_paths, f"No model.final.pt found under {run_path.as_posix()}"
model_path = final_model_paths[-1].resolve()

# The maze token file and the dataset are both read from the run's parent directory
maze_path = run_path.parent / "maze_tokens.jsonl"

dataset = MazeDataset.disk_load(run_path.parent, do_config=True, do_tokens=True)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01668119935008387, max=1.0)…

2023-04-21 15:32:32 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'test', 'device': 'cpu', 'dtype': 'torch.int16', 'seq_len_min': 1, 'seq_len_max': 512, 'seed': 42, 'applied_filters': [], 'grid_n': 3, 'n_mazes': 5, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_transformer.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', '        algorithm:', '        1. Choose the initial cell, mark it as visited and push it to the stack', '        2. While the stack is not empty', '                1. Pop a cell from the stack and make it a current cell', '                2. If the current cell has any neighbours which have not been visited', '                        1. Push the current cell to the stack', '                        2. Choose one of the unvisited neighbours', '                        3. Remove the wall between the current 

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

# plot example mazes

In [None]:
# helper: plot the predicted path of a model on mazes read from a token file
def testdata_plot_predicted_path(
		model,
		maze_tokens_path: Path,
		n_mazes: int = 10,
		max_new_tokens: int = 8,
		data_cfg: "MazeDatasetConfig | None" = None,
	):
	"""Plot predicted paths on top of solved mazes read from a token file.

	Reads one whitespace-separated token sequence per line from
	`maze_tokens_path`, keeps the first `n_mazes` of them, asks
	`predict_maze_paths` for one prediction per maze and shows one `MazePlot`
	figure per maze.

	# Parameters:
	 - `model`: forwarded unchanged to `predict_maze_paths`
	 - `maze_tokens_path`: text file with one tokenized maze per line
	 - `n_mazes`: maximum number of mazes (taken from the top of the file) to plot
	 - `max_new_tokens`: forwarded unchanged to `predict_maze_paths`
	 - `data_cfg`: dataset config used to decode the tokens; when `None`, falls
	   back to the notebook-level `dataset.cfg`

	# Raises:
	 - `ValueError`: if the number of predictions differs from the number of mazes
	"""
	if data_cfg is None:
		# backward-compatible default: use the dataset loaded in the setup cell
		data_cfg = dataset.cfg

	# load maze test data, skipping blank lines (they would yield empty token lists)
	mazes_tokens: list[list[str]] = [
		line.split()
		for line in maze_tokens_path.read_text().splitlines()
		if line.strip()
	]
	mazes_tokens = mazes_tokens[:n_mazes]

	solved_mazes = [SolvedMaze.from_tokens(tokens, data_cfg) for tokens in mazes_tokens]

	predictions = predict_maze_paths(
		tokens_batch=mazes_tokens,
		data_cfg=data_cfg,
		model=model,
		max_new_tokens=max_new_tokens,
	)

	# plot: one figure per maze; strict=True surfaces a maze/prediction count mismatch
	for maze, predicted_path in zip(solved_mazes, predictions, strict=True):
		MazePlot(maze).add_predicted_path(predicted_path).plot()
		plt.show()


In [None]:
# Load the final checkpoint together with its config holder,
# then plot its predictions on the mazes from the token file
model, cfg_holder = load_model_with_configs(model_path)
testdata_plot_predicted_path(model=model, maze_tokens_path=maze_path)

In [None]:
baseline_model = RandomBaseline(cfg_holder)

# Longest possible path in random DFS is NxN - 1, so derive the token budget
# from the dataset's grid size instead of hard-coding it (15 for a 4x4 grid)
baseline_max_new_tokens = dataset.cfg.grid_n ** 2 - 1
testdata_plot_predicted_path(
	baseline_model,
	maze_path,
	max_new_tokens=baseline_max_new_tokens,
)

# run path dist eval

In [None]:
model_checkpoints = get_checkpoint_paths_for_run(model_path.parent)
print(f"Found {len(model_checkpoints)} checkpoints:\n\t{model_checkpoints = }")
assert len(model_checkpoints) > 0, f"No checkpoints found in {model_path.parent.as_posix()}"

# checkpoint index -> {eval name -> StatCounter of scores}
pathdist_scores_idx: dict[int, dict[str, StatCounter]] = dict()

for idx, checkpoint_path in model_checkpoints:
	print(f"# Evaluating checkpoint {idx} at {checkpoint_path}")
	# NOTE: this rebinds the notebook-level `model` to the checkpoint being evaluated
	model, _ = load_model_with_configs(checkpoint_path)
	pathdist_scores_idx[idx] = evaluate_model(
		model=model,
		dataset=dataset
	)

# Transpose into eval name -> {checkpoint index -> StatCounter}.
# The eval names are taken from the first evaluated checkpoint rather than from
# `pathdist_scores_idx[0]`, which raised a KeyError when no checkpoint had index 0.
eval_names = list(next(iter(pathdist_scores_idx.values())))

data = {
	name: {
		idx: scores[name]
		for idx, scores in pathdist_scores_idx.items()
	}
	for name in eval_names
}

In [None]:
def plot_pathdist_scores(
		data: dict[str, dict[int, StatCounter]],
		colors: dict[str, str]|None = None,
		percentile_bounds: tuple[float, float] = (0.4, 0.6),
	):
	"""Plot the median score per checkpoint, one stacked panel per metric.

	For every metric name in `data` a panel is drawn with the checkpoint index
	on the x axis, the median of the checkpoint's `StatCounter` as a line, and
	a shaded band between the two requested percentiles.

	# Parameters:
	 - `data`: metric name -> {checkpoint index -> StatCounter}
	 - `colors`: metric name -> matplotlib color; defaults to the color cycle
	   entries `C0`, `C1`, ... in the iteration order of `data`
	 - `percentile_bounds`: `(lower, upper)`, each passed as-is to
	   `StatCounter.percentile` to get the band edges

	# Raises:
	 - `ValueError`: if `data` is empty (there is nothing to plot)
	 - `KeyError`: if `colors` is given but lacks an entry for a metric in `data`
	"""
	if len(data) == 0:
		raise ValueError("no scores to plot: `data` is empty")

	if colors is None:
		colors = {
			func_name: f"C{i}"
			for i, func_name in enumerate(data.keys())
		}

	lower_bound, upper_bound = percentile_bounds

	# squeeze=False keeps `axes` a 2D array even for a single metric; without it
	# `plt.subplots(1, 1)` returns a bare Axes object that cannot be indexed
	fig, axes = plt.subplots(len(data), 1, figsize = (8, 4 * len(data)), squeeze = False)
	fig.subplots_adjust(hspace = 0.5)

	for ax, (name, scores_indexed) in zip(axes[:, 0], data.items()):
		# sort the checkpoint indices so the line and band are drawn left to right
		x = sorted(scores_indexed.keys())
		y = [scores_indexed[idx].median() for idx in x]
		ax.plot(x, y, label=name, color=colors[name])

		# plot shaded error bars
		y_ub = [scores_indexed[idx].percentile(upper_bound) for idx in x]
		y_lb = [scores_indexed[idx].percentile(lower_bound) for idx in x]
		ax.fill_between(
			x, y_lb, y_ub,
			alpha=0.5,
			edgecolor=colors[name], facecolor=colors[name],
		)

		ax.set_title(f"{name}, {percentile_bounds = }")
		ax.set_xlabel("Checkpoint")
		ax.set_ylabel("score")

	plt.show()

In [None]:
# Show the per-checkpoint score curves for every evaluated metric
plot_pathdist_scores(data=data)