# Setup

In [1]:
# Generic
import copy
import typing
from pathlib import Path
import html

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch
# Utilities
from muutils.statcounter import StatCounter

# Our Code
from muutils.nbutils.configure_notebook import configure_notebook
from maze_dataset.plotting import MazePlot
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_transformer.training.config import ZanjHookedTransformer
from maze_transformer.evaluation.eval_model import predict_maze_paths, evaluate_model
from maze_transformer.evaluation.baseline_models import RandomBaseline
from muutils.mlutils import get_checkpoint_paths_for_run
from maze_transformer.evaluation.plotting import plot_predicted_paths, eval_model_at_checkpoints, plot_pathdist_scores

In [None]:
# Notebook-wide configuration: seed fixed at 42, dark mode off.
# The value returned by configure_notebook is kept as DEVICE and echoed below.
DEVICE: torch.device = configure_notebook(seed=42, dark_mode=False)
print(f"{DEVICE = }")

# We won't be training any models, so turn autograd off globally.
# The trailing semicolon stops the returned object's repr from being shown
# as the cell's output.
torch.set_grad_enabled(False);

In [None]:
# Base directories, relative to the notebook's location.
PATH_EXAMPLES: Path = Path("../examples/")
PATH_DATA: Path = Path("../data/")

# NOTE: gradients are already disabled in the notebook-configuration cell above,
# so torch.set_grad_enabled(False) is not repeated here.

# get the default model from examples
MODEL_PATH: Path = PATH_EXAMPLES / "multsrc_demo-g6-n10K-a_dfs-h50618_tiny-v1_sweep-v1_2024-08-21-12-21-39/model.final.zanj"
# alternative model files (uncomment one to use it instead):
# MODEL_PATH: Path = PATH_DATA / "custom_2023-05-24-05-03-04/model.final.zanj"
# MODEL_PATH: Path = PATH_EXAMPLES / "hallway-medium_2023-06-16-03-40-47.iter_26554.zanj"
print(f"will try to get model from {MODEL_PATH.as_posix()}")
MODEL: ZanjHookedTransformer = ZanjHookedTransformer.read(MODEL_PATH)
# The config holder is read via `zanj_model_config`, the same attribute the
# baseline-solver cell passes to RandomBaseline.
print(f"loaded model: {MODEL.zanj_model_config.name} with {MODEL.num_params()} parameters")

In [None]:
# generate a smaller test dataset from the same config
# A deep copy is taken so that shrinking n_mazes does not modify the config
# object stored on the model. The config holder is read via
# `zanj_model_config`, the same attribute the baseline-solver cell uses.
DATASET_TEST_CFG: MazeDatasetConfig = copy.deepcopy(MODEL.zanj_model_config.dataset_cfg)
DATASET_TEST_CFG.n_mazes = 100
DATASET_TEST: MazeDataset = MazeDataset.from_config(
    DATASET_TEST_CFG,
    local_base_path=PATH_DATA,
    verbose=True,
)
print(f"got test dataset: {DATASET_TEST.cfg.name} with {len(DATASET_TEST)} mazes")

# Plot predicted paths on example mazes

In [None]:
# Visualise the loaded model's predicted paths on the test dataset
# (n_mazes=5, max_new_tokens=50).
plot_predicted_paths(
    MODEL,
    DATASET_TEST,
    n_mazes=5,
    max_new_tokens=50,
)

In [None]:
# Same visualisation, this time for a RandomBaseline constructed from the
# loaded model's config holder (note max_new_tokens=15 here).
BASELINE_SOLVER: RandomBaseline = RandomBaseline(MODEL.zanj_model_config)
plot_predicted_paths(
    BASELINE_SOLVER,
    DATASET_TEST,
    n_mazes=5,
    max_new_tokens=15,
)

# Run path-distance evaluation across checkpoints

In [None]:
# Evaluate on DATASET_TEST via eval_model_at_checkpoints (max_checkpoints=5).
# Per the annotation, the result is a nested dict ending in StatCounter values;
# presumably keyed by score name, then checkpoint index -- TODO confirm.
PATHDIST_SCORES: dict[str, dict[int, StatCounter]] = eval_model_at_checkpoints(
    MODEL_PATH,
    DATASET_TEST,
    max_checkpoints=5,
)

In [None]:
# Plot the scores collected by eval_model_at_checkpoints in the previous cell.
plot_pathdist_scores(PATHDIST_SCORES)