In [14]:
# Generic
import typing
import os
from pathlib import Path
import typing
import json

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch

# muutils
from zanj.zanj import ZANJ, ZANJ_GLOBAL_DEFAULTS

# Our Code
from muutils.nbutils.configure_notebook import configure_notebook
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer, BaseGPTConfig, TrainConfig
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from maze_dataset.generation import LatticeMazeGenerators
from maze_transformer.training.train_model import TrainingResult, train_model
from maze_transformer.training.wandb_logger import WandbProject

import pandas as pd
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode
from maze_transformer.evaluation.eval_model import predict_maze_paths
from maze_transformer.evaluation.path_evals import rollout_evals
from maze_dataset import MazeDataset, MazeDatasetConfig, SolvedMaze, LatticeMaze, SPECIAL_TOKENS, LatticeMazeGenerators, CoordArray

import typing

def print_keys(d: typing.Dict, prefix: str = " ", level: int = 1):
    indent = "  " * level  # Basic indentation for readability
    marker = "|-" if level > 0 else ""  # Add a marker for nested levels
    for k, v in d.items():
        print(f"{indent}{marker}{prefix}{k}: {v}")  # Print the current key with marker

print(torch.cuda.is_available())  # Verifies CUDA is available to PyTorch
print(torch.cuda.get_device_name(0))  # Displays the name of your CUDA device

True
NVIDIA RTX 2000 Ada Generation Laptop GPU


In [5]:
# set global defaults for ZANJ
ZANJ_GLOBAL_DEFAULTS.external_array_threshold = 1024
ZANJ_GLOBAL_DEFAULTS.external_list_threshold = 1024

# paths
PATH_EXAMPLES: Path = Path("../examples/")
PATH_DATA: Path = Path("../data/")

# reproducibility and device
DEVICE = configure_notebook(seed=42, dark_mode=True, device="cuda")
print(f"{DEVICE = }")

set up plots with PLOT_MODE = 'inline', FIG_OUTPUT_FMT = None, FIG_BASEPATH = None
DEVICE = device(type='cuda')


# 1. Train the model according to the config
- NOTE: intervals_count specifies how often to do that thing in 1 epoch. So setting it low will mean do it infrequent.

### new configs

In [10]:
eval_scale = 1/100
exp_name_data = "hallway_vLeon_v4_30k"
exp_name = "hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6"

CFG_HALLWAY_MOD1: ConfigHolder = ConfigHolder(
    name = exp_name,
    dataset_cfg = MazeDatasetConfig(
		name=exp_name_data,
		grid_n=6,
		n_mazes=30_000,
		maze_ctor=LatticeMazeGenerators.gen_dfs,
        maze_ctor_kwargs=dict(
            do_forks=True,
        ),
        applied_filters=[{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}],
        seq_len_max=256,
        seq_len_min=3,
	),
    model_cfg = BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=256,
        d_head=32,
        n_layers=6,
    ),
    train_cfg = TrainConfig(
        name="custom-train",
        optimizer=torch.optim.AdamW,
        optimizer_kwargs=dict(lr=0.0003),
        batch_size=32,
        num_epochs=5,
        dataloader_cfg=dict(
            shuffle=True,
            num_workers=2,
            drop_last=False,
        ),

        intervals_count=dict(
            print_loss=2,
            checkpoint=2,
            eval_fast=2,
            eval_slow=2,
        ),
        validation_dataset_cfg=10,
        loss_update_from_pathstart=True,
    ),
)

In [11]:
# here is where to specify which config to actually use
CFG: ConfigHolder = CFG_HALLWAY_MOD1
print("model name;", CFG.dataset_cfg.to_fname())

model name; hallway_vLeon_v4_30k-g6-n30K-a_dfs-h48882


In [12]:
# load the dataset
DATASET: MazeDataset = MazeDataset.from_config(CFG.dataset_cfg, verbose=True, local_base_path=PATH_DATA)
print("len dataset", len(DATASET))

trying to get the dataset 'hallway_vLeon_v4_30k-g6-n30K-a_dfs-h48882'
loading dataset from ../data/hallway_vLeon_v4_30k-g6-n30K-a_dfs-h48882.zanj
Got dataset hallway_vLeon_v4_30k with 30000 items. output.cfg.to_fname() = 'hallway_vLeon_v4_30k-g6-n30K-a_dfs-h48882'
len dataset 30000


In [13]:
result: TrainingResult = train_model(
	base_path=PATH_DATA,
    cfg=CFG,
	wandb_project=WandbProject.DEMO_NOTEBOOKS, # change this to WandbProject.DEMO_NOTEBOOKS!
	do_generate_dataset=False,
	dataset_verbose=True,
    dataset=DATASET,
)

2024-03-17 17:36:22 ERROR Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mwatermelontology[0m. Use [1m`wandb login --relogin`[0m to force relogin


2024-03-17 17:36:24 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'hallway_vLeon_v4_30k', 'seq_len_min': 3, 'seq_len_max': 256, 'seed': 42, 'applied_filters': [{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}], 'grid_n': 6, 'n_mazes': 30000, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_dataset.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', '        # Arguments', '        - `grid_shape: Coord`: the shape of the grid', '        - `lattice_dim: int`: the dimension of the lattice', '          (default: `2`)', '        - `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**', '            (default: `None`)', '        - `max_tree_depth: int | 

  0%|          | 0/938 [00:00<?, ?it/s]

2024-03-17 17:36:25 INFO iteration 0/938: loss=0.292
2024-03-17 17:36:25 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_0.zanj


 50%|████▉     | 468/938 [01:05<01:05,  7.16it/s]

2024-03-17 17:37:30 INFO iteration 468/938: loss=0.038
2024-03-17 17:37:30 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_468.zanj


100%|█████████▉| 936/938 [02:10<00:00,  7.35it/s]

2024-03-17 17:38:35 INFO iteration 936/938: loss=0.046
2024-03-17 17:38:35 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_936.zanj


100%|██████████| 938/938 [02:11<00:00,  7.11it/s]

2024-03-17 17:38:36 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.08333333333333333": 1,
    "0.16666666666666666": 1,
    "0.2": 1,
    "0.25": 1,
    "0.26666666666666666": 1,
    "0.2777777777777778": 1,
    "0.3333333333333333": 1,
    "0.375": 1,
    "0.6": 1,
    "0.75": 1
  },
  "summary": {
    "total_items": 10,
    "n_keys": 10,
    "mode": 0.08333333333333333,
    "mean": 0.3302777777777778,
    "std": 0.1920666117088895,
    "min": 0.08333333333333333,
    "q1": 0.21250000000000002,
    "median": 0.27222222222222225,
    "q3": 0.3645833333333333,
    "max": 0.75
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "3.0": 1,
    "5.0": 2,
    "6.0": 1,
    "7.0": 6
  },
  "summary": {
    "total_items": 10,
    "n_keys": 4,
    "mode": 7.0,
    "mean": 6.1,
    "std": 1.3,
    "min": 3.0,
    "q1": 5.25,
    "median": 7.0,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.75": 1,
    "0.8

  0%|          | 0/938 [00:00<?, ?it/s]

2024-03-17 17:38:39 INFO iteration 0/938: loss=0.038
2024-03-17 17:38:39 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_0.zanj


 50%|████▉     | 468/938 [01:05<01:04,  7.32it/s]

2024-03-17 17:39:45 INFO iteration 468/938: loss=0.027
2024-03-17 17:39:45 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_468.zanj


100%|█████████▉| 936/938 [02:11<00:00,  7.23it/s]

2024-03-17 17:40:51 INFO iteration 936/938: loss=0.033
2024-03-17 17:40:51 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_936.zanj


100%|██████████| 938/938 [02:12<00:00,  7.08it/s]

2024-03-17 17:40:52 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.0625": 1,
    "0.08333333333333333": 1,
    "0.125": 1,
    "0.16666666666666666": 2,
    "0.2": 1,
    "0.3333333333333333": 1,
    "0.5833333333333334": 1,
    "1.0": 2
  },
  "summary": {
    "total_items": 10,
    "n_keys": 8,
    "mode": 0.16666666666666666,
    "mean": 0.3720833333333333,
    "std": 0.3448130028645156,
    "min": 0.0625,
    "q1": 0.13541666666666666,
    "median": 0.18333333333333335,
    "q3": 0.5208333333333334,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "2.0": 1,
    "4.0": 1,
    "7.0": 8
  },
  "summary": {
    "total_items": 10,
    "n_keys": 3,
    "mode": 7.0,
    "mean": 6.2,
    "std": 1.6613247725836149,
    "min": 2.0,
    "q1": 7.0,
    "median": 7.0,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.6666666666666666": 1,
    "0.8": 1,
    "0.875": 8
  },
  "summary": {
   

  0%|          | 0/938 [00:00<?, ?it/s]

2024-03-17 17:40:54 INFO iteration 0/938: loss=0.026
2024-03-17 17:40:54 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_0.zanj


 50%|████▉     | 468/938 [01:05<01:07,  6.94it/s]

2024-03-17 17:42:00 INFO iteration 468/938: loss=0.032
2024-03-17 17:42:00 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_468.zanj


100%|█████████▉| 936/938 [02:11<00:00,  7.15it/s]

2024-03-17 17:43:06 INFO iteration 936/938: loss=0.024
2024-03-17 17:43:06 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_936.zanj


100%|██████████| 938/938 [02:12<00:00,  7.08it/s]

2024-03-17 17:43:07 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.16666666666666666": 1,
    "0.1875": 1,
    "0.25": 1,
    "0.26666666666666666": 1,
    "0.4": 1,
    "0.4444444444444444": 1,
    "0.6666666666666666": 1,
    "0.75": 1,
    "1.0": 2
  },
  "summary": {
    "total_items": 10,
    "n_keys": 9,
    "mode": 1.0,
    "mean": 0.5131944444444445,
    "std": 0.3045217676036385,
    "min": 0.16666666666666666,
    "q1": 0.25416666666666665,
    "median": 0.4222222222222222,
    "q3": 0.7291666666666666,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "2.0": 1,
    "4.0": 1,
    "5.0": 2,
    "6.0": 1,
    "7.0": 5
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
    "mode": 7.0,
    "mean": 5.7,
    "std": 1.6155494421403513,
    "min": 2.0,
    "q1": 5.0,
    "median": 6.5,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.6666666666666666": 1,
    "0.8": 1,


  0%|          | 0/938 [00:00<?, ?it/s]

2024-03-17 17:43:10 INFO iteration 0/938: loss=0.028
2024-03-17 17:43:10 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_0.zanj


 50%|████▉     | 468/938 [01:07<01:08,  6.88it/s]

2024-03-17 17:44:17 INFO iteration 468/938: loss=0.021
2024-03-17 17:44:17 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_468.zanj


100%|█████████▉| 936/938 [02:13<00:00,  7.07it/s]

2024-03-17 17:45:23 INFO iteration 936/938: loss=0.019
2024-03-17 17:45:23 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_936.zanj


100%|██████████| 938/938 [02:14<00:00,  6.95it/s]

2024-03-17 17:45:24 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.1111111111111111": 1,
    "0.13333333333333333": 1,
    "0.3333333333333333": 1,
    "0.4": 1,
    "0.5": 1,
    "0.6666666666666666": 3,
    "0.75": 1,
    "1.0": 1
  },
  "summary": {
    "total_items": 10,
    "n_keys": 8,
    "mode": 0.6666666666666666,
    "mean": 0.5227777777777777,
    "std": 0.2672268421869053,
    "min": 0.1111111111111111,
    "q1": 0.35,
    "median": 0.5833333333333333,
    "q3": 0.6666666666666666,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "4.0": 1,
    "5.0": 2,
    "6.0": 1,
    "7.0": 6
  },
  "summary": {
    "total_items": 10,
    "n_keys": 4,
    "mode": 7.0,
    "mean": 6.2,
    "std": 1.0770329614269007,
    "min": 4.0,
    "q1": 5.25,
    "median": 7.0,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.8": 1,
    "0.8333333333333334": 2,
    "0.8571428571428571": 1,
    

  0%|          | 0/938 [00:00<?, ?it/s]

2024-03-17 17:45:28 INFO iteration 0/938: loss=0.016
2024-03-17 17:45:28 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_0.zanj


 50%|████▉     | 468/938 [01:07<01:06,  7.08it/s]

2024-03-17 17:46:35 INFO iteration 468/938: loss=0.009
2024-03-17 17:46:35 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_468.zanj


100%|█████████▉| 936/938 [02:13<00:00,  7.24it/s]

2024-03-17 17:47:41 INFO iteration 936/938: loss=0.006
2024-03-17 17:47:41 INFO Saving model checkpoint to ../data/hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_2024-03-17-17-36-22/checkpoints/model.iter_936.zanj


100%|██████████| 938/938 [02:14<00:00,  6.98it/s]

2024-03-17 17:47:42 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.08333333333333333": 2,
    "0.1111111111111111": 1,
    "0.3333333333333333": 1,
    "0.5": 1,
    "0.6666666666666666": 1,
    "1.0": 4
  },
  "summary": {
    "total_items": 10,
    "n_keys": 6,
    "mode": 1.0,
    "mean": 0.5777777777777777,
    "std": 0.38646067309863735,
    "min": 0.08333333333333333,
    "q1": 0.16666666666666666,
    "median": 0.5833333333333333,
    "q3": 1.0,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "2.0": 1,
    "4.0": 2,
    "5.0": 1,
    "7.0": 6
  },
  "summary": {
    "total_items": 10,
    "n_keys": 4,
    "mode": 7.0,
    "mean": 5.7,
    "std": 1.7349351572897471,
    "min": 2.0,
    "q1": 4.25,
    "median": 7.0,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.6666666666666666": 1,
    "0.8": 2,
    "0.8333333333333334": 1,
    "0.875": 6
  },
  "summary": {
    "total_

### Store model and data

In [16]:
exp_name += "_firepeople"
print(exp_name)

hallway_vLeon_v4_30k_PathTokenLossOnly_true_tk6_firepeople


In [18]:

def store_model_dataset(data_path, model, data_set, exp_name):
    # create path if does not exist
    data_path = Path(data_path)
    data_path.mkdir(parents=True, exist_ok=True)

    # save the model
    model.save(data_path / f"model_{exp_name}.zanj")

    # save the dataset
    data_set.save(data_path / f"dataset_{exp_name}.zanj")

def load_model_dataset(data_path):
    # create path if does not exist
    data_path = Path(data_path)

    # load the model
    model = ZanjHookedTransformer.read(data_path / f"model_{exp_name}.zanj")

    # load the dataset
    data_set = MazeDataset.read(data_path / f"dataset_{exp_name}.zanj")

    return model, data_set

data_path = "./data/training_exp/"
# exp_name = "hallway_6x6_3k_ArchLeon_v3"
# store_model_dataset(data_path, result.model, DATASET, exp_name)

# model2, dataset2 = load_model_dataset(data_path, exp_name)

# 2. Evaluate how good the model is

## 2.1 Get evaluation statistics:

In [19]:
def evaluate_model_rollouts(
    model,
    dataset: MazeDataset,
    rollouts_acc: int|None = None,
    rollout_temperature: float = 0.0,
    batch_size: int|None = 16,
) -> dict[str, float]:


    tokenizer: MazeTokenizer = model.zanj_model_config.maze_tokenizer

    if rollouts_acc is not None:
        if rollouts_acc == "all":
            rollouts_acc = len(dataset)
            print(f"number of rollouts = {rollouts_acc}")
        dataset_tokens: list[list[str]] = dataset.as_tokens(maze_tokenizer=tokenizer)[:rollouts_acc]
        # do rollouts
        predictions: list[list[str|tuple[int, int]]] = predict_maze_paths(
            tokens_batch=dataset_tokens,
            data_cfg=dataset.cfg,
            model=model,
            max_new_tokens=None,
            temperature=rollout_temperature,
            when_noncoord="include",
            smart_max_new_tokens=True,
            batch_size=batch_size,
        )

        mazes: list[SolvedMaze] = dataset.mazes[:rollouts_acc]
        assert len(predictions) == len(mazes)

        # evaluate rollouts
        rollout_eval_results: dict[str, float] = rollout_evals(
            predictions=predictions,
            mazes=mazes,
        )

        return rollout_eval_results

    return {}

In [20]:
RESULTS: pd.DataFrame = evaluate_model_rollouts(
	model=result.model,
	dataset=DATASET,
	# rollouts_acc="all", # number of mazes to check
	batch_size=32, 
    rollouts_acc=1000,
)

print(RESULTS)

{'correct EOS': 0.099, 'mean invalid tokens': 0.87, 'percent with invalid tokens': 0.87, 'exactly correct rollouts': 0.0, 'valid rollouts': 0.001, 'rollouts with target reached': 0.067}


In [21]:
print_keys(RESULTS)

  |- correct EOS: 0.099
  |- mean invalid tokens: 0.87
  |- percent with invalid tokens: 0.87
  |- exactly correct rollouts: 0.0
  |- valid rollouts: 0.001
  |- rollouts with target reached: 0.067


# Evaluate on Test set

In [22]:
CFG_data = CFG.dataset_cfg
# copy the object
CFG_testdata = MazeDatasetConfig(
		name="hallway_vLeon_test",
		grid_n=CFG_data.grid_n,
		n_mazes=1000,
		maze_ctor=LatticeMazeGenerators.gen_dfs,
        maze_ctor_kwargs=dict(
            do_forks=True,
        ),
        applied_filters=CFG_data.applied_filters,
        seq_len_max=CFG_data.seq_len_max,
        seq_len_min=CFG_data.seq_len_min,
        seed = 100,
	)


DATASET_test: MazeDataset = MazeDataset.from_config(CFG_testdata, verbose=True, local_base_path=PATH_DATA)


trying to get the dataset 'hallway_vLeon_test-g6-n1000-a_dfs-h75706'
loading dataset from ../data/hallway_vLeon_test-g6-n1000-a_dfs-h75706.zanj
Got dataset hallway_vLeon_test with 1000 items. output.cfg.to_fname() = 'hallway_vLeon_test-g6-n1000-a_dfs-h75706'



in GPTDatasetConfig self.name='hallway_vLeon_test', self.seed=100 is trying to override GLOBAL_SEED=42 which has already been changed elsewhere from DEFAULT_SEED=42



### Code to check the test dataset is different
- I check only the part of the maze after adjacency list

In [23]:
tokenizer=  MazeTokenizer()

dataset_tokens2: list[list[str]] = DATASET.as_tokens(maze_tokenizer=tokenizer)
dataset_tokens_test: list[list[str]] = DATASET_test.as_tokens(maze_tokenizer=tokenizer)

dataset2_str = [" ".join(tokens).split("<ORIGIN_START>")[-1] for tokens in dataset_tokens2]
dataset_test_str = [" ".join(tokens).split("<ORIGIN_START>")[-1] for tokens in dataset_tokens_test]


# check if maze in DATASET_test is also in the training data 
duplicate = 0
for maze_str in dataset_test_str:
    if maze_str in dataset2_str:
        duplicate += 1

print("number of duplicates", duplicate, "out of ", len(dataset_test_str), ", so perc duplicate", duplicate/len(dataset_test_str))

number of duplicates 307 out of  1000 , so perc duplicate 0.307


In [24]:
# get the index of path_start token for each test maze
path_start_idx = [tokens.index("<PATH_START>") for tokens in dataset_tokens_test]
print(path_start_idx[:10])
print(dataset_tokens_test[path_start_idx[0]:])

[148, 148, 148, 148, 148, 148, 148, 148, 148, 148]
[['<ADJLIST_START>', '(5,1)', '<-->', '(5,2)', ';', '(4,3)', '<-->', '(4,4)', ';', '(3,1)', '<-->', '(3,2)', ';', '(1,5)', '<-->', '(0,5)', ';', '(3,2)', '<-->', '(2,2)', ';', '(4,1)', '<-->', '(3,1)', ';', '(2,0)', '<-->', '(3,0)', ';', '(5,1)', '<-->', '(4,1)', ';', '(5,0)', '<-->', '(5,1)', ';', '(0,1)', '<-->', '(0,2)', ';', '(2,2)', '<-->', '(2,1)', ';', '(1,2)', '<-->', '(0,2)', ';', '(0,4)', '<-->', '(0,5)', ';', '(2,5)', '<-->', '(3,5)', ';', '(0,0)', '<-->', '(1,0)', ';', '(5,4)', '<-->', '(5,3)', ';', '(1,1)', '<-->', '(1,0)', ';', '(1,3)', '<-->', '(1,2)', ';', '(1,1)', '<-->', '(0,1)', ';', '(1,4)', '<-->', '(0,4)', ';', '(3,3)', '<-->', '(2,3)', ';', '(2,1)', '<-->', '(2,0)', ';', '(4,5)', '<-->', '(5,5)', ';', '(1,3)', '<-->', '(1,4)', ';', '(4,5)', '<-->', '(3,5)', ';', '(3,4)', '<-->', '(2,4)', ';', '(5,2)', '<-->', '(5,3)', ';', '(4,3)', '<-->', '(4,2)', ';', '(3,4)', '<-->', '(3,5)', ';', '(5,4)', '<-->', '(4,4)', ';'

In [25]:
# evaluate on test set
RESULTS_test: pd.DataFrame = evaluate_model_rollouts(
	model=result.model,
	dataset=DATASET_test,
	rollouts_acc="all", # number of mazes to check
	batch_size=32, 
    # rollouts_acc=3000,
)

print(RESULTS_test)
print_keys(RESULTS_test)

number of rollouts = 1000
{'correct EOS': 0.086, 'mean invalid tokens': 0.886, 'percent with invalid tokens': 0.886, 'exactly correct rollouts': 0.0, 'valid rollouts': 0.001, 'rollouts with target reached': 0.061}
  |- correct EOS: 0.086
  |- mean invalid tokens: 0.886
  |- percent with invalid tokens: 0.886
  |- exactly correct rollouts: 0.0
  |- valid rollouts: 0.001
  |- rollouts with target reached: 0.061


# Inspect predicted path vs true path:

In [29]:
from typing import List, Union, Tuple
from maze_dataset.tokenization.token_utils import (
    WhenMissing,
    get_context_tokens,
    get_path_tokens,
    remove_padding_from_token_str,
    strings_to_coords,
)

In [31]:
def get_predictions(model, dataset, rollouts_acc, batch_size, tokenizer, rollout_temperature):
    if rollouts_acc == "all":
        rollouts_acc = len(dataset)
        print(f"number of rollouts = {rollouts_acc}")
    dataset_tokens = dataset.as_tokens(maze_tokenizer=tokenizer)[:rollouts_acc]
    predictions = predict_maze_paths(
        tokens_batch=dataset_tokens,
        data_cfg=dataset.cfg,
        model=model,
        max_new_tokens=None,
        temperature=rollout_temperature,
        when_noncoord="include",
        smart_max_new_tokens=True,
        batch_size=batch_size,
    )
    return dataset_tokens, predictions

def inspect_predictions(dataset_tokens, predictions):
    for i in range(len(dataset_tokens)):
        temp = dataset_tokens[i]
        temp_pred = [str(item) for item in predictions[i]]
        print(" ".join(temp))
        print(" ".join(get_context_tokens(temp)))
        print(" ".join(get_context_tokens(temp)) + " ".join(temp_pred))
        print()
        print(" ".join(temp_pred))
        print(" ".join(dataset_tokens[i]))
        print("------------------- \n")

# Usage
model = result.model
dataset = DATASET_test
rollouts_acc = 2  # number of mazes to check
batch_size = 64
tokenizer: MazeTokenizer = model.zanj_model_config.maze_tokenizer
rollout_temperature = 0.0

dataset_tokens, predictions = get_predictions(model, dataset, rollouts_acc, batch_size, tokenizer, rollout_temperature)
inspect_predictions(dataset_tokens, predictions)

<ADJLIST_START> (3,3) <--> (3,4) ; (2,3) <--> (2,4) ; (1,1) <--> (0,1) ; (1,2) <--> (0,2) ; (4,4) <--> (5,4) ; (1,0) <--> (2,0) ; (1,1) <--> (1,0) ; (5,3) <--> (4,3) ; (3,5) <--> (3,4) ; (5,5) <--> (4,5) ; (2,2) <--> (2,1) ; (2,5) <--> (3,5) ; (0,2) <--> (0,3) ; (5,4) <--> (5,5) ; (5,2) <--> (5,1) ; (4,2) <--> (4,1) ; (3,0) <--> (2,0) ; (4,1) <--> (3,1) ; (5,0) <--> (4,0) ; (0,4) <--> (0,5) ; (1,4) <--> (1,3) ; (4,2) <--> (4,3) ; (5,3) <--> (5,2) ; (3,2) <--> (3,3) ; (2,5) <--> (1,5) ; (0,1) <--> (0,0) ; (1,5) <--> (0,5) ; (2,3) <--> (1,3) ; (0,4) <--> (1,4) ; (4,0) <--> (3,0) ; (4,5) <--> (3,5) ; (5,1) <--> (5,0) ; (0,3) <--> (1,3) ; (2,2) <--> (3,2) ; (2,1) <--> (3,1) ; <ADJLIST_END> <ORIGIN_START> (0,0) <ORIGIN_END> <TARGET_START> (3,0) <TARGET_END> <PATH_START> (0,0) (0,1) (1,1) (1,0) (2,0) (3,0) <PATH_END>
<ADJLIST_START> (3,3) <--> (3,4) ; (2,3) <--> (2,4) ; (1,1) <--> (0,1) ; (1,2) <--> (0,2) ; (4,4) <--> (5,4) ; (1,0) <--> (2,0) ; (1,1) <--> (1,0) ; (5,3) <--> (4,3) ; (3,5) <--