In [1]:
import os
import json
from pathlib import Path
import typing

import numpy as np
import torch
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTConfig

from muutils.logger import Logger, TimerContext
from muutils.json_serialize import json_serialize, dataclass_serializer_factory
from muutils.tensor_utils import ATensor, NDArray
from muutils.statcounter import StatCounter
from muutils.misc import shorten_numerical_to_str

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

from maze_transformer.generation.latticemaze import LatticeMaze
from maze_transformer.generation.generators import LatticeMazeGenerators
from maze_transformer.training.tokenizer import MazeTokenizer, SPECIAL_TOKENS
from maze_transformer.training.mazedataset import MazeDatasetConfig, MazeDataset
from maze_transformer.evaluation.plot_maze import plot_multi_paths, PathFormat
from maze_transformer.training.dataset import GPTDatasetConfig
from maze_transformer.training.config import TrainConfig
from maze_transformer.training.training import TRAIN_SAVE_FILES
from maze_transformer.evaluation.eval_model import generate_plot_predicted_path, MazePath, load_model_with_configs, LoadedModelConfigs, predict_maze_path
from maze_transformer.evaluation.pathdist import MazeEvalFunction, ArrMazeEvalFunction, MazeEvalFuncs, ArrMazeEvalFuncs



# plot example mazes

In [None]:
# setup consts
def testdata_plot_predicted_path(
		model_path: str = "data/g4-n4K/g4-n4K_tiny-v1_2022-10-05-02-03-44/model.final.pt",
		n_mazes: int = 10,
		n_tokens_pred: int = 8,
		filename: str = "data/g4-n1K-test/maze_tokens.jsonl",
	):
	mazes_tokens: list[list[str]] = list()

	# load model and configs
	loaded_model_and_configs: LoadedModelConfigs = load_model_with_configs(model_path, MazeDatasetConfig)
	data_cfg: MazeDatasetConfig; train_cfg: TrainConfig
	model_cfg: OpenAIGPTConfig; model: OpenAIGPTLMHeadModel
	data_cfg, train_cfg, model_cfg, model = loaded_model_and_configs
	model.config.max_length = model_cfg.n_positions + 1

	# load maze test data
	with open(filename, "r") as f:
		for line in f:
			mazes_tokens.append(line.split())

	mazes_tokens = mazes_tokens[:n_mazes]

	mazes_solved: list[tuple[LatticeMaze, MazePath, MazePath]] = [
		predict_maze_path(
			tokens = tokens,
			data_cfg = data_cfg,
			model = model,
			n_tokens_pred = n_tokens_pred,
		)
		for tokens in mazes_tokens
	]


	# plot
	for maze, path, path_pred in mazes_solved:
		plot_multi_paths(
			maze = maze,
			paths = [
				PathFormat(path, "true", color = "red", quiver_kwargs = {'width': 0.015}),
				PathFormat(path_pred, "predicted", color = "blue", quiver_kwargs = {}),
			],
		)

# testdata_plot_predicted_path()

# run path dist eval

In [5]:
EvalFuncTuple = tuple[typing.Literal["arr", "list"], MazeEvalFunction|ArrMazeEvalFunction]

ALL_PATHDIST_FUNCS: dict[str, EvalFuncTuple] = {
	**{
		name: ("arr", func)
		for name, func in ArrMazeEvalFuncs.__dict__.items()
		if not name.startswith("_")
	},
	**{
		name: ("list", func)
		for name, func in MazeEvalFuncs.__dict__.items()
		if not name.startswith("_")
	},
}

print(ALL_PATHDIST_FUNCS)

{'fraction_connections_adjacent_lattice': ('arr', <staticmethod(<function ArrMazeEvalFuncs.fraction_connections_adjacent_lattice at 0x000002135AFE4F70>)>), 'fraction_connections_adjacent': ('arr', <staticmethod(<function ArrMazeEvalFuncs.fraction_connections_adjacent at 0x000002135AFE5000>)>), 'node_overlap': ('list', <staticmethod(<function MazeEvalFuncs.node_overlap at 0x000002135AFE4EE0>)>)}


In [8]:


from maze_transformer.evaluation.eval_model import ArrMazePath


def evaluate_model_pathdist_scores(
		model_path: str,
		pathdist_functions: dict[str, EvalFuncTuple]|None = None,
		n_tokens_pred: int = 8,
		test_data: str = "data/g4-n1K-test/maze_tokens.jsonl",
		n_mazes: int = 10,
	) -> dict[str, StatCounter]:

	if pathdist_functions is None:
		pathdist_functions = ALL_PATHDIST_FUNCS

	mazes_tokens: list[list[str]] = list()

	# load model and configs
	loaded_model_and_configs: LoadedModelConfigs = load_model_with_configs(model_path, MazeDatasetConfig)
	data_cfg: MazeDatasetConfig; train_cfg: TrainConfig
	model_cfg: OpenAIGPTConfig; model: OpenAIGPTLMHeadModel
	data_cfg, train_cfg, model_cfg, model = loaded_model_and_configs
	model.config.max_length = model_cfg.n_positions + 1

	# load maze test data
	with open(test_data, "r") as f:
		for line in f:
			mazes_tokens.append(line.split())

	mazes_tokens = mazes_tokens[:n_mazes]

	# predict paths
	mazes_solved: list[tuple[LatticeMaze, MazePath, MazePath]] = [
		predict_maze_path(
			tokens = tokens,
			data_cfg = data_cfg,
			model = model,
			n_tokens_pred = n_tokens_pred,
		)
		for tokens in mazes_tokens
	]

	# convert paths
	mazes_solved_arrpath: list[tuple[LatticeMaze, ArrMazePath, ArrMazePath]] = [
		(maze, np.array(p_true), np.array(p_pred))
		for maze, p_true, p_pred in mazes_solved
	]

	# evaluate
	pathdist_scores: dict[str, StatCounter] = dict()
	for name, (pathdist_type, pathdist_func) in pathdist_functions.items():
		if pathdist_type == "arr":
			pathdist_scores[name] = StatCounter(
				pathdist_func(maze, p_true, p_pred)
				for maze, p_true, p_pred in mazes_solved
			)
		elif pathdist_type == "list":
			pathdist_scores[name] = StatCounter(
				pathdist_func(maze, p_true, p_pred)
				for maze, p_true, p_pred in mazes_solved_arrpath
			)
		else:
			raise ValueError(f"Invalid pathdist_type: {pathdist_type}")

	return pathdist_scores

evaluate_model_pathdist_scores("data/g4-n4K/g4-n4K_tiny-v1_2022-10-05-02-03-44/model.final.pt")



loaded model with 28M parameters


TypeError: unsupported operand type(s) for -: 'int' and 'tuple'

In [None]:


# generate_plot_predicted_path("data/g4-n4K/g4-n4K_tiny-v1_2022-09-10-02-02-12/checkpoints/model.iter_3120.pt")
