In [2]:
# Generic
import typing
import os
from pathlib import Path
import typing
import json

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch

# muutils
from zanj.zanj import ZANJ, ZANJ_GLOBAL_DEFAULTS

# Our Code
from muutils.nbutils.configure_notebook import configure_notebook
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer, BaseGPTConfig, TrainConfig
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from maze_dataset.generation import LatticeMazeGenerators
from maze_transformer.training.train_model import TrainingResult, train_model
from maze_transformer.training.wandb_logger import WandbProject


  torch.utils._pytree._register_pytree_node(


In [1]:
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"\
# torch.cuda.is_available() 
# print(torch.cuda.get_device_name(0))  # Should return the name of your GPU
# ! pip install --upgrade pip

import torch
print(torch.cuda.is_available())  # Verifies CUDA is available to PyTorch
print(torch.cuda.get_device_name(0))  # Displays the name of your CUDA device


True
NVIDIA RTX 2000 Ada Generation Laptop GPU


In [3]:
# !pip install torch torchvision torchaudio


In [4]:
# install the torch version that is compatible with the cuda version
# !pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html	
# !pip install torch==2.2.1+cu111 torchvision==0.13.1+cu111 torchaudio==0.12.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# !pip install torch==2.2.1+cu118 torchvision==0.13.1+cu118 torchaudio==0.12.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html

# !pip install torch==2.2.1+cu118




In [5]:
# !pip install torchvision==0.15.0+cu118 torchaudio==0.12.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html


In [3]:
# set global defaults for ZANJ
ZANJ_GLOBAL_DEFAULTS.external_array_threshold = 1024
ZANJ_GLOBAL_DEFAULTS.external_list_threshold = 1024

# paths
PATH_EXAMPLES: Path = Path("../examples/")
PATH_DATA: Path = Path("../data/")

# reproducibility and device
DEVICE = configure_notebook(seed=42, dark_mode=True, device="cuda")
print(f"{DEVICE = }")

set up plots with PLOT_MODE = 'inline', FIG_OUTPUT_FMT = None, FIG_BASEPATH = None
DEVICE = device(type='cuda')


In [4]:
# set DEVICE to cuda
# DEVICE = torch.device("cuda")

# 1. Train the model according to the config
- NOTE: intervals_count specifies how often to do that thing in 1 epoch. So setting it low will mean do it infrequent.

### new configs

In [24]:
eval_scale = 1/100
CFG_HALLWAY_MOD1: ConfigHolder = ConfigHolder(
    name = "hallway_vLeon",
    dataset_cfg = MazeDatasetConfig(
		name="hallway_vLeon",
		grid_n=6,
		n_mazes=10_000,
		maze_ctor=LatticeMazeGenerators.gen_dfs,
        maze_ctor_kwargs=dict(
            do_forks=False,
        ),
        applied_filters=[{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}],
        seq_len_max=256,
        seq_len_min=3,
	),
    model_cfg = BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=512,
        d_head=64,
        n_layers=3,
    ),
    train_cfg = TrainConfig(
        name="custom-train",
        optimizer=torch.optim.AdamW,
        optimizer_kwargs=dict(lr=0.0003),
        batch_size=32,
        num_epochs=12,
        dataloader_cfg=dict(
            shuffle=True,
            num_workers=2,
            drop_last=False,
        ),

        intervals_count=dict(
            print_loss=2,
            checkpoint=2,
            eval_fast=2,
            eval_slow=2,
        ),
        validation_dataset_cfg=10,
    ),
)

In [25]:
# here is where to specify which config to actually use
CFG: ConfigHolder = CFG_HALLWAY_MOD1
print("model name;", CFG.dataset_cfg.to_fname())

model name; hallway_vLeon-g6-n10K-a_dfs-h67461


In [26]:
# load the dataset
DATASET: MazeDataset = MazeDataset.from_config(CFG.dataset_cfg, verbose=True, local_base_path=PATH_DATA)
print("len dataset", len(DATASET))

trying to get the dataset 'hallway_vLeon-g6-n10K-a_dfs-h67461'
loading dataset from ../data/hallway_vLeon-g6-n10K-a_dfs-h67461.zanj
Got dataset hallway_vLeon with 10000 items. output.cfg.to_fname() = 'hallway_vLeon-g6-n10K-a_dfs-h67461'
len dataset 10000


In [27]:
result: TrainingResult = train_model(
	base_path=PATH_DATA,
    cfg=CFG,
	wandb_project=WandbProject.DEMO_NOTEBOOKS, # change this to WandbProject.DEMO_NOTEBOOKS!
	do_generate_dataset=False,
	dataset_verbose=True,
    dataset=DATASET,
)

0,1
average_predicted_step_size-mean,▁▁▁▁▁
corner_jumps-mean,▁▁▁▁▁
distance_between_end_nodes-mean,█▂▅▂▁
epoch,▁▃▅▆█
exact_path_predicted-mean,▁▃▃▆█
fraction_connections_adjacent-mean,▁▅▆▇█
fraction_connections_adjacent_lattice-mean,▁▅█▅▃
loss,█▅▅▄▃▃▃▃▃▃▃▂▃▃▃▃▂▂▂▂▁▂▂▁▂▁▂▂▂▂▁▂▂▁▁▂▂▁▁▁
n_samples_seen,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
node_overlap-mean,▁▆▅▆█

0,1
average_predicted_step_size-mean,1.0
corner_jumps-mean,0.0
device,cuda
distance_between_end_nodes-mean,0.87287
epoch,4
exact_path_predicted-mean,0.6
fraction_connections_adjacent-mean,0.94286
fraction_connections_adjacent_lattice-mean,0.75833
loss,0.69056
model.device,cuda


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016672908750001623, max=1.0…

2024-03-16 19:32:18 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'hallway_vLeon', 'seq_len_min': 3, 'seq_len_max': 256, 'seed': 42, 'applied_filters': [{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}], 'grid_n': 6, 'n_mazes': 10000, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_dataset.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', '        # Arguments', '        - `grid_shape: Coord`: the shape of the grid', '        - `lattice_dim: int`: the dimension of the lattice', '          (default: `2`)', '        - `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**', '            (default: `None`)', '        - `max_tree_depth: int | float |

  0%|          | 0/313 [00:00<?, ?it/s]

2024-03-16 19:32:19 INFO iteration 0/313: loss=4.915
2024-03-16 19:32:19 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_0.zanj


 50%|████▉     | 156/313 [00:22<00:20,  7.80it/s]

2024-03-16 19:32:41 INFO iteration 156/313: loss=0.733
2024-03-16 19:32:41 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_156.zanj


100%|█████████▉| 312/313 [00:44<00:00,  7.40it/s]

2024-03-16 19:33:03 INFO iteration 312/313: loss=1.136
2024-03-16 19:33:03 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_312.zanj


100%|██████████| 313/313 [00:45<00:00,  6.85it/s]

2024-03-16 19:33:04 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.0": 3,
    "0.13043478260869565": 1,
    "0.2": 1,
    "0.3333333333333333": 1,
    "0.5": 3,
    "0.75": 1
  },
  "summary": {
    "total_items": 10,
    "n_keys": 6,
    "mode": 0.5,
    "mean": 0.2913768115942029,
    "std": 0.25132048867477014,
    "min": 0.0,
    "q1": 0.03260869565217391,
    "median": 0.26666666666666666,
    "q3": 0.5,
    "max": 0.75
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "1.0": 3,
    "3.0": 1,
    "4.0": 1,
    "5.0": 1,
    "7.0": 4
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
    "mode": 7.0,
    "mean": 4.3,
    "std": 2.5317977802344322,
    "min": 1.0,
    "q1": 1.5,
    "median": 4.5,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.5": 3,
    "0.75": 1,
    "0.8": 1,
    "0.8333333333333334": 1,
    "0.875": 4
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
 

  0%|          | 0/313 [00:00<?, ?it/s]

2024-03-16 19:33:06 INFO iteration 0/313: loss=0.915
2024-03-16 19:33:06 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_0.zanj


 50%|████▉     | 156/313 [00:22<00:20,  7.53it/s]

2024-03-16 19:33:28 INFO iteration 156/313: loss=0.864
2024-03-16 19:33:28 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_156.zanj


100%|█████████▉| 312/313 [00:44<00:00,  7.38it/s]

2024-03-16 19:33:50 INFO iteration 312/313: loss=0.720
2024-03-16 19:33:50 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_312.zanj


100%|██████████| 313/313 [00:50<00:00,  6.24it/s]

2024-03-16 19:33:56 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.21739130434782608": 1,
    "0.3333333333333333": 1,
    "0.4": 1,
    "0.5": 2,
    "0.6666666666666666": 1,
    "1.0": 4
  },
  "summary": {
    "total_items": 10,
    "n_keys": 6,
    "mode": 1.0,
    "mean": 0.6617391304347826,
    "std": 0.2972924094684865,
    "min": 0.21739130434782608,
    "q1": 0.42500000000000004,
    "median": 0.5833333333333333,
    "q3": 1.0,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "1.0": 1,
    "2.0": 1,
    "3.0": 2,
    "4.0": 1,
    "5.0": 2,
    "7.0": 3
  },
  "summary": {
    "total_items": 10,
    "n_keys": 6,
    "mode": 7.0,
    "mean": 4.4,
    "std": 2.0591260281974,
    "min": 1.0,
    "q1": 3.0,
    "median": 4.5,
    "q3": 6.5,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.5": 1,
    "0.6666666666666666": 1,
    "0.75": 2,
    "0.8": 1,
    "0.8333333333333334": 2,
    "0.87

  0%|          | 0/313 [00:00<?, ?it/s]

2024-03-16 19:33:57 INFO iteration 0/313: loss=0.745
2024-03-16 19:33:57 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_0.zanj


 50%|████▉     | 156/313 [00:22<00:21,  7.45it/s]

2024-03-16 19:34:20 INFO iteration 156/313: loss=0.755
2024-03-16 19:34:20 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_156.zanj


100%|█████████▉| 312/313 [00:44<00:00,  7.59it/s]

2024-03-16 19:34:42 INFO iteration 312/313: loss=0.728
2024-03-16 19:34:42 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_312.zanj


100%|██████████| 313/313 [00:46<00:00,  6.76it/s]

2024-03-16 19:34:43 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.16666666666666666": 1,
    "0.25": 1,
    "0.3333333333333333": 1,
    "0.34782608695652173": 1,
    "0.4666666666666667": 1,
    "0.5": 1,
    "0.6666666666666666": 1,
    "1.0": 3
  },
  "summary": {
    "total_items": 10,
    "n_keys": 8,
    "mode": 1.0,
    "mean": 0.5731159420289856,
    "std": 0.3082166880298473,
    "min": 0.16666666666666666,
    "q1": 0.33695652173913043,
    "median": 0.48333333333333334,
    "q3": 0.9166666666666666,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "1.0": 1,
    "2.0": 1,
    "3.0": 1,
    "7.0": 7
  },
  "summary": {
    "total_items": 10,
    "n_keys": 4,
    "mode": 7.0,
    "mean": 5.5,
    "std": 2.3345235059857505,
    "min": 1.0,
    "q1": 4.0,
    "median": 7.0,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.5": 1,
    "0.6666666666666666": 1,
    "0.75": 1,
 

  0%|          | 0/313 [00:00<?, ?it/s]

2024-03-16 19:34:45 INFO iteration 0/313: loss=0.677
2024-03-16 19:34:45 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_0.zanj


 50%|████▉     | 156/313 [00:22<00:22,  7.08it/s]

2024-03-16 19:35:07 INFO iteration 156/313: loss=0.688
2024-03-16 19:35:07 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_156.zanj


100%|█████████▉| 312/313 [00:44<00:00,  7.35it/s]

2024-03-16 19:35:29 INFO iteration 312/313: loss=0.675
2024-03-16 19:35:29 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_312.zanj


100%|██████████| 313/313 [00:46<00:00,  6.78it/s]

2024-03-16 19:35:31 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.16666666666666666": 1,
    "0.2": 1,
    "0.3333333333333333": 1,
    "0.34782608695652173": 1,
    "1.0": 6
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
    "mode": 1.0,
    "mean": 0.7047826086956521,
    "std": 0.3650598061219404,
    "min": 0.16666666666666666,
    "q1": 0.33695652173913043,
    "median": 1.0,
    "q3": 1.0,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "1.0": 2,
    "2.0": 1,
    "5.0": 1,
    "7.0": 6
  },
  "summary": {
    "total_items": 10,
    "n_keys": 4,
    "mode": 7.0,
    "mean": 5.1,
    "std": 2.5475478405713994,
    "min": 1.0,
    "q1": 2.75,
    "median": 7.0,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.5": 2,
    "0.6666666666666666": 1,
    "0.8333333333333334": 1,
    "0.875": 6
  },
  "summary": {
    "total_items": 10,
    "n_keys": 4,
    "mode": 0.8

  0%|          | 0/313 [00:00<?, ?it/s]

2024-03-16 19:35:32 INFO iteration 0/313: loss=0.662
2024-03-16 19:35:32 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_0.zanj


 50%|████▉     | 156/313 [00:22<00:22,  6.98it/s]

2024-03-16 19:35:55 INFO iteration 156/313: loss=0.543
2024-03-16 19:35:55 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_156.zanj


100%|█████████▉| 312/313 [00:44<00:00,  7.48it/s]

2024-03-16 19:36:17 INFO iteration 312/313: loss=0.691
2024-03-16 19:36:17 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_312.zanj


100%|██████████| 313/313 [00:46<00:00,  6.77it/s]

2024-03-16 19:36:18 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.3333333333333333": 1,
    "0.34782608695652173": 1,
    "0.5333333333333333": 1,
    "1.0": 7
  },
  "summary": {
    "total_items": 10,
    "n_keys": 4,
    "mode": 1.0,
    "mean": 0.8214492753623188,
    "std": 0.2772633084975822,
    "min": 0.3333333333333333,
    "q1": 0.65,
    "median": 1.0,
    "q3": 1.0,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "1.0": 2,
    "2.0": 1,
    "3.0": 1,
    "5.0": 2,
    "7.0": 4
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
    "mode": 7.0,
    "mean": 4.5,
    "std": 2.4186773244895647,
    "min": 1.0,
    "q1": 2.25,
    "median": 5.0,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.5": 2,
    "0.6666666666666666": 1,
    "0.75": 1,
    "0.8333333333333334": 2,
    "0.875": 4
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
    "mode": 0.875

  0%|          | 0/313 [00:00<?, ?it/s]

2024-03-16 19:36:20 INFO iteration 0/313: loss=0.558
2024-03-16 19:36:20 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_0.zanj


 50%|████▉     | 156/313 [00:22<00:21,  7.35it/s]

2024-03-16 19:36:42 INFO iteration 156/313: loss=0.568
2024-03-16 19:36:42 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_156.zanj


100%|█████████▉| 312/313 [00:44<00:00,  7.31it/s]

2024-03-16 19:37:04 INFO iteration 312/313: loss=0.566
2024-03-16 19:37:04 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_312.zanj


100%|██████████| 313/313 [00:46<00:00,  6.79it/s]

2024-03-16 19:37:05 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.17391304347826086": 1,
    "0.2": 1,
    "0.3333333333333333": 1,
    "0.5": 1,
    "0.6666666666666666": 2,
    "1.0": 4
  },
  "summary": {
    "total_items": 10,
    "n_keys": 6,
    "mode": 1.0,
    "mean": 0.6540579710144927,
    "std": 0.3228509644117423,
    "min": 0.17391304347826086,
    "q1": 0.375,
    "median": 0.6666666666666666,
    "q3": 1.0,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "3.0": 2,
    "5.0": 1,
    "7.0": 7
  },
  "summary": {
    "total_items": 10,
    "n_keys": 3,
    "mode": 7.0,
    "mean": 6.0,
    "std": 1.61245154965971,
    "min": 3.0,
    "q1": 5.5,
    "median": 7.0,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.75": 2,
    "0.8333333333333334": 1,
    "0.875": 7
  },
  "summary": {
    "total_items": 10,
    "n_keys": 3,
    "mode": 0.875,
    "mean": 0.8458333333333

  0%|          | 0/313 [00:00<?, ?it/s]

2024-03-16 19:37:07 INFO iteration 0/313: loss=0.548
2024-03-16 19:37:07 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_0.zanj


 50%|████▉     | 156/313 [00:22<00:21,  7.28it/s]

2024-03-16 19:37:30 INFO iteration 156/313: loss=0.660
2024-03-16 19:37:30 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_156.zanj


100%|█████████▉| 312/313 [00:44<00:00,  7.33it/s]

2024-03-16 19:37:52 INFO iteration 312/313: loss=0.663
2024-03-16 19:37:52 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_312.zanj


100%|██████████| 313/313 [00:46<00:00,  6.77it/s]

2024-03-16 19:37:53 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.2": 1,
    "0.21739130434782608": 1,
    "0.5": 1,
    "0.6666666666666666": 1,
    "1.0": 6
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
    "mode": 1.0,
    "mean": 0.7584057971014493,
    "std": 0.3209220922404139,
    "min": 0.2,
    "q1": 0.5416666666666666,
    "median": 1.0,
    "q3": 1.0,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "1.0": 2,
    "2.0": 2,
    "3.0": 1,
    "5.0": 1,
    "7.0": 4
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
    "mode": 7.0,
    "mean": 4.2,
    "std": 2.5219040425836985,
    "min": 1.0,
    "q1": 2.0,
    "median": 4.0,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.5": 2,
    "0.6666666666666666": 2,
    "0.75": 1,
    "0.8333333333333334": 1,
    "0.875": 4
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
    "mode": 0.875,
 

  0%|          | 0/313 [00:00<?, ?it/s]

2024-03-16 19:37:54 INFO iteration 0/313: loss=0.544
2024-03-16 19:37:54 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_0.zanj


 50%|████▉     | 156/313 [00:22<00:20,  7.53it/s]

2024-03-16 19:38:17 INFO iteration 156/313: loss=0.524
2024-03-16 19:38:17 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_156.zanj


100%|█████████▉| 312/313 [00:44<00:00,  7.52it/s]

2024-03-16 19:38:39 INFO iteration 312/313: loss=0.541
2024-03-16 19:38:39 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_312.zanj


100%|██████████| 313/313 [00:46<00:00,  6.75it/s]

2024-03-16 19:38:41 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.16666666666666666": 1,
    "0.25": 1,
    "0.34782608695652173": 1,
    "0.5": 1,
    "0.5333333333333333": 1,
    "0.6666666666666666": 1,
    "1.0": 4
  },
  "summary": {
    "total_items": 10,
    "n_keys": 7,
    "mode": 1.0,
    "mean": 0.6464492753623189,
    "std": 0.3179910371802138,
    "min": 0.16666666666666666,
    "q1": 0.3858695652173913,
    "median": 0.6,
    "q3": 1.0,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "1.0": 1,
    "2.0": 2,
    "3.0": 2,
    "6.0": 1,
    "7.0": 4
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
    "mode": 7.0,
    "mean": 4.5,
    "std": 2.3769728648009427,
    "min": 1.0,
    "q1": 2.25,
    "median": 4.5,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.5": 1,
    "0.6666666666666666": 2,
    "0.75": 3,
    "0.875": 4
  },
  "summary": {
    "total_i

  0%|          | 0/313 [00:00<?, ?it/s]

2024-03-16 19:38:42 INFO iteration 0/313: loss=0.593
2024-03-16 19:38:42 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_0.zanj


 50%|████▉     | 156/313 [00:22<00:21,  7.44it/s]

2024-03-16 19:39:04 INFO iteration 156/313: loss=0.599
2024-03-16 19:39:04 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_156.zanj


100%|█████████▉| 312/313 [00:44<00:00,  7.43it/s]

2024-03-16 19:39:27 INFO iteration 312/313: loss=0.690
2024-03-16 19:39:27 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_312.zanj


100%|██████████| 313/313 [00:46<00:00,  6.77it/s]

2024-03-16 19:39:28 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.16666666666666666": 1,
    "0.2": 1,
    "0.34782608695652173": 1,
    "1.0": 7
  },
  "summary": {
    "total_items": 10,
    "n_keys": 4,
    "mode": 1.0,
    "mean": 0.7714492753623189,
    "std": 0.351769941861836,
    "min": 0.16666666666666666,
    "q1": 0.5108695652173914,
    "median": 1.0,
    "q3": 1.0,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "1.0": 2,
    "2.0": 2,
    "3.0": 1,
    "5.0": 1,
    "6.0": 1,
    "7.0": 3
  },
  "summary": {
    "total_items": 10,
    "n_keys": 6,
    "mode": 7.0,
    "mean": 4.1,
    "std": 2.4269322199023193,
    "min": 1.0,
    "q1": 2.0,
    "median": 4.0,
    "q3": 6.75,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.5": 2,
    "0.6666666666666666": 2,
    "0.75": 2,
    "0.8333333333333334": 1,
    "0.875": 3
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
   

  0%|          | 0/313 [00:00<?, ?it/s]

2024-03-16 19:39:29 INFO iteration 0/313: loss=0.549
2024-03-16 19:39:29 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_0.zanj


 50%|████▉     | 156/313 [00:22<00:20,  7.51it/s]

2024-03-16 19:39:52 INFO iteration 156/313: loss=0.597
2024-03-16 19:39:52 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_156.zanj


100%|█████████▉| 312/313 [00:44<00:00,  7.29it/s]

2024-03-16 19:40:14 INFO iteration 312/313: loss=0.791
2024-03-16 19:40:14 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_312.zanj


100%|██████████| 313/313 [00:46<00:00,  6.78it/s]

2024-03-16 19:40:15 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.3333333333333333": 1,
    "0.34782608695652173": 1,
    "0.5": 1,
    "0.5333333333333333": 1,
    "0.6666666666666666": 1,
    "1.0": 5
  },
  "summary": {
    "total_items": 10,
    "n_keys": 6,
    "mode": 1.0,
    "mean": 0.7381159420289855,
    "std": 0.27619405279515324,
    "min": 0.3333333333333333,
    "q1": 0.5083333333333333,
    "median": 0.8333333333333333,
    "q3": 1.0,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "1.0": 2,
    "2.0": 1,
    "3.0": 1,
    "5.0": 1,
    "6.0": 2,
    "7.0": 3
  },
  "summary": {
    "total_items": 10,
    "n_keys": 6,
    "mode": 7.0,
    "mean": 4.5,
    "std": 2.3769728648009427,
    "min": 1.0,
    "q1": 2.25,
    "median": 5.5,
    "q3": 6.75,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.5": 2,
    "0.6666666666666666": 1,
    "0.75": 3,
    "0.8333333333333334": 1,
    

  0%|          | 0/313 [00:00<?, ?it/s]

2024-03-16 19:40:17 INFO iteration 0/313: loss=0.563
2024-03-16 19:40:17 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_0.zanj


 50%|████▉     | 156/313 [00:22<00:20,  7.50it/s]

2024-03-16 19:40:39 INFO iteration 156/313: loss=0.583
2024-03-16 19:40:39 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_156.zanj


100%|█████████▉| 312/313 [00:44<00:00,  7.45it/s]

2024-03-16 19:41:01 INFO iteration 312/313: loss=0.681
2024-03-16 19:41:01 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_312.zanj


100%|██████████| 313/313 [00:46<00:00,  6.77it/s]

2024-03-16 19:41:03 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.16666666666666666": 1,
    "0.2": 1,
    "0.25": 1,
    "0.3333333333333333": 1,
    "0.34782608695652173": 1,
    "1.0": 5
  },
  "summary": {
    "total_items": 10,
    "n_keys": 6,
    "mode": 1.0,
    "mean": 0.6297826086956522,
    "std": 0.3736456253592922,
    "min": 0.16666666666666666,
    "q1": 0.2708333333333333,
    "median": 0.6739130434782609,
    "q3": 1.0,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "1.0": 2,
    "2.0": 1,
    "5.0": 1,
    "6.0": 1,
    "7.0": 5
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
    "mode": 7.0,
    "mean": 5.0,
    "std": 2.4899799195977463,
    "min": 1.0,
    "q1": 2.75,
    "median": 6.5,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.5": 2,
    "0.6666666666666666": 1,
    "0.75": 1,
    "0.8333333333333334": 1,
    "0.875": 5
  },
  "summary":

  0%|          | 0/313 [00:00<?, ?it/s]

2024-03-16 19:41:04 INFO iteration 0/313: loss=0.576
2024-03-16 19:41:04 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_0.zanj


 50%|████▉     | 156/313 [00:22<00:21,  7.40it/s]

2024-03-16 19:41:27 INFO iteration 156/313: loss=0.544
2024-03-16 19:41:27 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_156.zanj


100%|█████████▉| 312/313 [00:44<00:00,  7.93it/s]

2024-03-16 19:41:49 INFO iteration 312/313: loss=0.807
2024-03-16 19:41:49 INFO Saving model checkpoint to ../data/hallway_vLeon_2024-03-16-19-32-09/checkpoints/model.iter_312.zanj


100%|██████████| 313/313 [00:46<00:00,  6.79it/s]

2024-03-16 19:41:50 INFO Running evals: eval_fast





for interval_key eval_fast scores are {'node_overlap': {
  "StatCounter": {
    "0.13333333333333333": 1,
    "0.16666666666666666": 1,
    "0.21739130434782608": 1,
    "0.3333333333333333": 1,
    "1.0": 6
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
    "mode": 1.0,
    "mean": 0.685072463768116,
    "std": 0.3886750364687126,
    "min": 0.13333333333333333,
    "q1": 0.24637681159420288,
    "median": 1.0,
    "q3": 1.0,
    "max": 1.0
  }
}, 'num_connections_adjacent_lattice': {
  "StatCounter": {
    "1.0": 2,
    "2.0": 1,
    "3.0": 1,
    "5.0": 1,
    "7.0": 5
  },
  "summary": {
    "total_items": 10,
    "n_keys": 5,
    "mode": 7.0,
    "mean": 4.7,
    "std": 2.5317977802344322,
    "min": 1.0,
    "q1": 2.25,
    "median": 6.0,
    "q3": 7.0,
    "max": 7.0
  }
}, 'fraction_connections_adjacent_lattice': {
  "StatCounter": {
    "0.5": 2,
    "0.6666666666666666": 1,
    "0.75": 1,
    "0.8333333333333334": 1,
    "0.875": 5
  },
  "summary": {
    "total_

In [13]:

def store_model_dataset(data_path, model, data_set, exp_name):
    # create path if does not exist
    data_path = Path(data_path)
    data_path.mkdir(parents=True, exist_ok=True)

    # save the model
    model.save(data_path / f"model_{exp_name}.zanj")

    # save the dataset
    data_set.save(data_path / f"dataset_{exp_name}.zanj")

def load_model_dataset(data_path):
    # create path if does not exist
    data_path = Path(data_path)

    # load the model
    model = ZanjHookedTransformer.read(data_path / f"model_{exp_name}.zanj")

    # load the dataset
    data_set = MazeDataset.read(data_path / f"dataset_{exp_name}.zanj")

    return model, data_set

data_path = "./data/training_exp/"
exp_name = "hallway_6x6_3k_ArchLeon"
# store_model_dataset(data_path, result.model, DATASET, exp_name)

# model2, dataset2 = load_model_dataset(data_path, exp_name)
# print(peniz)

NameError: name 'peniz' is not defined

# 2. Evaluate how good the model is

In [28]:
import pandas as pd
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode
from maze_transformer.evaluation.eval_model import predict_maze_paths
from maze_transformer.evaluation.path_evals import rollout_evals
from maze_dataset import MazeDataset, MazeDatasetConfig, SolvedMaze, LatticeMaze, SPECIAL_TOKENS, LatticeMazeGenerators, CoordArray


## 2.1 Get evaluation statistics:

In [16]:
def evaluate_model_rollouts(
    model,
    dataset: MazeDataset,
    rollouts_acc: int|None = None,
    rollout_temperature: float = 0.0,
    batch_size: int|None = 16,
) -> dict[str, float]:


    tokenizer: MazeTokenizer = model.zanj_model_config.maze_tokenizer

    if rollouts_acc is not None:
        if rollouts_acc == "all":
            rollouts_acc = len(dataset)
            print(f"number of rollouts = {rollouts_acc}")
        dataset_tokens: list[list[str]] = dataset.as_tokens(maze_tokenizer=tokenizer)[:rollouts_acc]
        # do rollouts
        predictions: list[list[str|tuple[int, int]]] = predict_maze_paths(
            tokens_batch=dataset_tokens,
            data_cfg=dataset.cfg,
            model=model,
            max_new_tokens=None,
            temperature=rollout_temperature,
            when_noncoord="include",
            smart_max_new_tokens=True,
            batch_size=batch_size,
        )

        mazes: list[SolvedMaze] = dataset.mazes[:rollouts_acc]
        assert len(predictions) == len(mazes)

        # evaluate rollouts
        rollout_eval_results: dict[str, float] = rollout_evals(
            predictions=predictions,
            mazes=mazes,
        )

        return rollout_eval_results

    return {}

In [17]:
import typing

def print_keys(d: typing.Dict, prefix: str = " ", level: int = 1):
    indent = "  " * level  # Basic indentation for readability
    marker = "|-" if level > 0 else ""  # Add a marker for nested levels
    for k, v in d.items():
        print(f"{indent}{marker}{prefix}{k}: {v}")  # Print the current key with marker



In [18]:
RESULTS: pd.DataFrame = evaluate_model_rollouts(
	model=result.model,
	dataset=DATASET,
	# rollouts_acc="all", # number of mazes to check
	batch_size=64, 
    rollouts_acc=3000,
)

print(RESULTS)

{'correct EOS': 0.969, 'mean invalid tokens': 3.768666666666667, 'percent with invalid tokens': 0.5489999999999999, 'exactly correct rollouts': 0.018666666666666668, 'valid rollouts': 0.22633333333333333, 'rollouts with target reached': 0.841}


In [20]:
print_keys(RESULTS)

  |- correct EOS: 0.969
  |- mean invalid tokens: 3.768666666666667
  |- percent with invalid tokens: 0.5489999999999999
  |- exactly correct rollouts: 0.018666666666666668
  |- valid rollouts: 0.22633333333333333
  |- rollouts with target reached: 0.841


# Evaluate on Test set

In [29]:
CFG_data = CFG.dataset_cfg
# copy the object
CFG_testdata = MazeDatasetConfig(
		name="hallway_vLeon_test",
		grid_n=CFG_data.grid_n,
		n_mazes=1000,
		maze_ctor=LatticeMazeGenerators.gen_dfs,
        maze_ctor_kwargs=dict(
            do_forks=False,
        ),
        applied_filters=CFG_data.applied_filters,
        seq_len_max=CFG_data.seq_len_max,
        seq_len_min=CFG_data.seq_len_min,
        seed = 100,
	)


DATASET_test: MazeDataset = MazeDataset.from_config(CFG_testdata, verbose=True, local_base_path=PATH_DATA)


trying to get the dataset 'hallway_vLeon_test-g6-n1000-a_dfs-h55104'
loading dataset from ../data/hallway_vLeon_test-g6-n1000-a_dfs-h55104.zanj
Got dataset hallway_vLeon_test with 1000 items. output.cfg.to_fname() = 'hallway_vLeon_test-g6-n1000-a_dfs-h55104'



in GPTDatasetConfig self.name='hallway_vLeon_test', self.seed=100 is trying to override GLOBAL_SEED=42 which has already been changed elsewhere from DEFAULT_SEED=42



### Code to check the test dataset is different
- I check only the part of the maze after adjacency list

In [30]:
# check that DATASET2 and DATASET_test are different
tokenizer: MazeTokenizer = result.model.zanj_model_config.maze_tokenizer

dataset_tokens2: list[list[str]] = DATASET.as_tokens(maze_tokenizer=tokenizer)
dataset_tokens_test: list[list[str]] = DATASET_test.as_tokens(maze_tokenizer=tokenizer)

dataset2_str = [" ".join(tokens).split("<ORIGIN_START>")[-1] for tokens in dataset_tokens2]
dataset_test_str = [" ".join(tokens).split("<ORIGIN_START>")[-1] for tokens in dataset_tokens_test]

print(f"{len(dataset_tokens2) = }")

# check that no maze in  DATASET2 is also in  DATASET_test 
duplicate = 0
for maze_str in dataset_test_str:
    if maze_str in dataset2_str:
        duplicate += 1

print("number of duplicates", duplicate, "out of ", len(dataset_test_str), ", so perc duplicate", duplicate/len(dataset_test_str))

len(dataset_tokens2) = 9990
number of duplicates 398 out of  1000 , so perc duplicate 0.398


In [31]:
# evaluate on test set
RESULTS_test: pd.DataFrame = evaluate_model_rollouts(
	model=result.model,
	dataset=DATASET_test,
	rollouts_acc="all", # number of mazes to check
	batch_size=64, 
    # rollouts_acc=3000,
)

print(RESULTS_test)
print_keys(RESULTS_test)

number of rollouts = 1000
{'correct EOS': 0.994, 'mean invalid tokens': 1.616, 'percent with invalid tokens': 0.523, 'exactly correct rollouts': 0.043, 'valid rollouts': 0.26, 'rollouts with target reached': 0.882}
  |- correct EOS: 0.994
  |- mean invalid tokens: 1.616
  |- percent with invalid tokens: 0.523
  |- exactly correct rollouts: 0.043
  |- valid rollouts: 0.26
  |- rollouts with target reached: 0.882


In [None]:
model = result.model
dataset =DATASET
tokenizer: MazeTokenizer = model.zanj_model_config.maze_tokenizer
dataset_tokens: list[list[str]] = dataset.as_tokens(maze_tokenizer=tokenizer)
print(" ".join(dataset_tokens[0]))


<ADJLIST_START> (1,0) <--> (1,1) ; (1,1) <--> (1,2) ; (0,1) <--> (0,0) ; (0,2) <--> (0,1) ; (1,2) <--> (0,2) ; (1,0) <--> (2,0) ; <ADJLIST_END> <ORIGIN_START> (1,1) <ORIGIN_END> <TARGET_START> (1,2) <TARGET_END> <PATH_START> (1,1) (1,2) <PATH_END>


<ADJLIST_START> (1,0) <--> (0,0) ; (0,2) <--> (1,2) ; (1,2) <--> (2,2) ; (1,1) <--> (1,0) ; (1,4) <--> (1,3) ; (4,0) <--> (3,0) ; (2,0) <--> (2,1) ; (0,2) <--> (0,3) ; (3,1) <--> (3,2) ; (0,3) <--> (1,3) ; (4,4) <--> (3,4) ; (4,3) <--> (4,4) ; (4,2) <--> (4,3) ; (3,0) <--> (2,0) ; (1,4) <--> (0,4) ; (3,4) <--> (2,4) ; (3,3) <--> (3,2) ; (2,3) <--> (2,4) ; (2,2) <--> (2,1) ; (2,2) <--> (3,2) ; (0,1) <--> (0,0) ; (0,1) <--> (0,2) ; (4,0) <--> (4,1) ; (4,2) <--> (4,1) ; <ADJLIST_END> <ORIGIN_START> (1,3) <ORIGIN_END> <TARGET_START> (2,3) <TARGET_END> <PATH_START> (1,3) (0,3) (0,2) (1,2) (2,2) (2,1) (2,0) (3,0) (4,0) (4,1) (4,2) (4,3) (4,4) (3,4) (2,4) (2,3) <PATH_END>