In [None]:
# Generic
import typing
import os
from pathlib import Path
import typing
import html

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch
import pandas as pd

# Utilities
from muutils.statcounter import StatCounter

# Our Code
from maze_transformer.utils.notebook_utils import configure_notebook
from maze_transformer.generation.lattice_maze import LatticeMaze, SolvedMaze
from maze_transformer.evaluation.plot_maze import PathFormat, MazePlot
from maze_transformer.training.maze_dataset import MazeDataset, MazeDatasetConfig
from maze_transformer.evaluation.eval_model import (
    load_model_with_configs,
    predict_maze_paths,
    evaluate_model
)
from maze_transformer.evaluation.path_evals import (
    PathEvals,
    PathEvalFunction,
)
from maze_transformer.utils.utils import set_reproducibility, chunks, get_checkpoint_paths_for_run

In [None]:
# Setup: configure the notebook (seed 42, dark mode); the returned value is kept as `device`
device = configure_notebook(seed=42, dark_mode=True)

# We won't be training any models, so gradient tracking is disabled globally
torch.set_grad_enabled(False)

# Get latest model
# `run_path` should point towards a directory containing a run.
# If you don't have any runs, you can create a dataset with
#   `poetry run python scripts/create_dataset.py create ./data/maze 10 --grid_n=4`
# and then train a model with
#   `poetry run python scripts/train_model.py ./data/maze/g4-n10`
run_path = Path("../examples/maze/g4-n10/g4-n10_tiny-v1_2023-03-28-22-44-54")
assert run_path.exists(), f"Run path {run_path.as_posix()} does not exist"

# All final checkpoints anywhere below the run directory, oldest first (by mtime).
# Checked explicitly so that a run without a final checkpoint fails with a readable
# message instead of an opaque IndexError on `[-1]`.
final_model_candidates = sorted(run_path.glob("**/model.final.pt"), key=os.path.getmtime)
assert final_model_candidates, f"No model.final.pt found under {run_path.as_posix()}"
model_path = final_model_candidates[-1].resolve()

# The tokenized mazes and the dataset are read from the parent of the run directory
maze_path = run_path.parent / "maze_tokens.jsonl"
dataset = MazeDataset.disk_load(run_path.parent, do_config=True, do_tokens=True)

# Plot example mazes with true and predicted paths

In [None]:
# Helper: plot true vs. predicted paths for mazes read from a token file
def testdata_plot_predicted_path(
		model_path: Path,
		maze_tokens_path: Path,
		n_mazes: int = 10,
		max_new_tokens: int = 8,
	):
	"""Plot the true and the model-predicted path for the first `n_mazes` mazes of a token file.

	NOTE: this function reads the notebook-level `dataset` for its config
	(`dataset.cfg`), so the setup cell must have been run first.

	# Parameters
	- `model_path`: checkpoint passed to `load_model_with_configs`
	- `maze_tokens_path`: text file with one whitespace-separated token sequence per line
	- `n_mazes`: number of leading lines of the file to use
	- `max_new_tokens`: forwarded unchanged to `predict_maze_paths`

	# Returns
	`None`; one `MazePlot` is shown per maze. If no mazes are selected
	(empty file or `n_mazes=0`) a notice is printed and nothing is plotted.

	# Raises
	`ValueError` if the number of predictions differs from the number of mazes.
	"""
	# load maze test data: each line is split on whitespace into its tokens
	mazes_tokens: list[list[str]] = [
		line.split() for line in maze_tokens_path.read_text().splitlines()
	][:n_mazes]

	if not mazes_tokens:
		# `zip(*[])` yields nothing, so unpacking it into (mazes, solutions) would fail
		print(f"No mazes to plot from {maze_tokens_path.as_posix()}")
		return

	# load model; the second return value is not needed because the
	# data config is taken from the global `dataset`
	model, _ = load_model_with_configs(model_path)

	# each parsed item is unpacked into a (maze, solution) pair
	solved_mazes = [SolvedMaze.from_tokens(tokens, dataset.cfg) for tokens in mazes_tokens]
	mazes, solutions = zip(*solved_mazes)

	predictions = predict_maze_paths(
		tokens_batch=mazes_tokens,
		data_cfg=dataset.cfg,
		model=model,
		max_new_tokens=max_new_tokens,
	)

	# plot; strict=True turns a maze/prediction count mismatch into a loud error
	for maze, solution, prediction in zip(mazes, solutions, predictions, strict=True):
		MazePlot(maze).add_true_path(solution).add_predicted_path(prediction).show()

# Plot true vs. predicted paths for mazes from `maze_path`, using the model at
# `model_path`; `n_mazes` and `max_new_tokens` are left at their defaults.
testdata_plot_predicted_path(model_path, maze_path)

# Run path-distance evaluation on every checkpoint

In [None]:
# Collect the checkpoints stored alongside the final model; the result is
# iterated below as (index, path) pairs.
model_checkpoints = get_checkpoint_paths_for_run(model_path.parent)
print(f"Found {len(model_checkpoints)} checkpoints:\n\t{model_checkpoints = }")

# checkpoint index -> {eval name -> StatCounter}
pathdist_scores_idx: dict[int, dict[str, StatCounter]] = dict()

for idx, checkpoint_path in model_checkpoints:
	print(f"# Evaluating checkpoint {idx} at {checkpoint_path}")
	# NOTE: this rebinds the notebook-level `model` on every iteration
	model, _ = load_model_with_configs(checkpoint_path)
	pathdist_scores_idx[idx] = evaluate_model(
		model=model,
		dataset=dataset
	)

print(f"Found {len(pathdist_scores_idx)} checkpoints:\n\t{list(pathdist_scores_idx.keys())}")

# Fail early with a readable message when there is nothing to pivot.
assert pathdist_scores_idx, (
	f"No checkpoints were evaluated for {model_path.parent.as_posix()}"
)

# Pivot to: eval name -> {checkpoint index -> StatCounter}.
# The eval names are read from the lowest-index checkpoint instead of a
# hard-coded index 0, which raised KeyError for runs without a checkpoint 0.
reference_scores = pathdist_scores_idx[min(pathdist_scores_idx)]
data = {
	name: {
		idx: scores[name]
		for idx, scores in pathdist_scores_idx.items()
	}
	for name in reference_scores
}

In [None]:
# sort the data: rebuild `data` so that, for every eval name, the checkpoints
# appear in ascending index order (the plot uses this order for its x axis)
data_old = data  # keep a reference to the previous mapping

assert pathdist_scores_idx, "No checkpoint scores available to sort"

# Sort once up front instead of once per eval name. Indices are unique dict
# keys, so the tuple comparison never falls through to the score dicts.
scores_by_idx_sorted = sorted(pathdist_scores_idx.items())

# Eval names come from the lowest-index checkpoint rather than a hard-coded
# index 0, which raised KeyError for runs without a checkpoint 0.
data = {
	name: {
		idx: scores[name]
		for idx, scores in scores_by_idx_sorted
	}
	for name in scores_by_idx_sorted[0][1]
}


In [None]:
def plot_pathdist_scores(
		data: dict[str, dict[int, StatCounter]],
		colors: dict[str, str]|None = None,
		percentile_bounds: tuple[float, float] = (0.4, 0.6),
	):
	"""Plot one subplot per entry of `data`: median score per checkpoint with a shaded percentile band.

	# Parameters
	- `data`: eval name -> {checkpoint index -> StatCounter}; the inner key
	  order is used as the x-axis order
	- `colors`: eval name -> matplotlib color; defaults to the color cycle
	  entries `"C0"`, `"C1"`, ... in the order of `data`
	- `percentile_bounds`: (lower, upper) values passed unchanged to
	  `StatCounter.percentile` for the shaded band
	  (presumably fractions in [0, 1] -- TODO confirm against muutils)

	# Returns
	`None`; the figure is shown with `plt.show()`. If `data` is empty a
	notice is printed and no figure is created.
	"""
	if not data:
		# plt.subplots cannot create a grid with zero rows
		print("No scores to plot")
		return

	if colors is None:
		colors = {
			func_name: f"C{color_idx}"
			for color_idx, func_name in enumerate(data.keys())
		}

	# squeeze=False always returns a 2D array of axes; with the default
	# squeezing a single-row grid comes back as a bare Axes and cannot be indexed
	fig, axes_grid = plt.subplots(
		len(data), 1,
		figsize=(8, 4 * len(data)),
		squeeze=False,
	)
	fig.subplots_adjust(hspace=0.5)

	lower_p, upper_p = percentile_bounds[0], percentile_bounds[1]

	# the grid has a single column, so column 0 holds one Axes per eval
	for ax, (name, scores_indexed) in zip(axes_grid[:, 0], data.items()):
		x = list(scores_indexed.keys())
		y = [scores_indexed[idx].median() for idx in x]
		ax.plot(x, y, label=name, color=colors[name])

		# plot shaded error bars between the two percentile bounds
		y_ub = [scores_indexed[idx].percentile(upper_p) for idx in x]
		y_lb = [scores_indexed[idx].percentile(lower_p) for idx in x]
		ax.fill_between(
			x, y_lb, y_ub,
			alpha=0.5,
			edgecolor=colors[name], facecolor=colors[name],
		)

		ax.set_title(f"{name}, {percentile_bounds = }")
		ax.set_xlabel("Checkpoint")
		ax.set_ylabel("score")

	plt.show()

In [None]:
# Plot the median score per checkpoint with a percentile band, one subplot per
# entry of `data`; colors and percentile bounds are left at their defaults.
plot_pathdist_scores(data)