In [1]:
# Generic
import typing
import os
from pathlib import Path
import typing
import json

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch

# muutils
from zanj.zanj import ZANJ, ZANJ_GLOBAL_DEFAULTS

# Our Code
from muutils.nbutils.configure_notebook import configure_notebook
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer, BaseGPTConfig, TrainConfig
from maze_dataset import MazeDataset, MazeDatasetConfig, MazeDatasetCollectionConfig
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from maze_dataset.generation import LatticeMazeGenerators
from maze_transformer.training.train_model import TrainingResult, train_model
from maze_transformer.training.wandb_logger import WandbProject


In [2]:
# locations of example files and of the data directory, relative to the working directory
# (datasets and training-run checkpoints are written under PATH_DATA)
PATH_EXAMPLES: Path = Path("../examples/")
PATH_DATA: Path = Path("../data/")

# global ZANJ defaults: arrays and lists share the same threshold of 1024
# NOTE(review): presumably anything larger than the threshold is stored as an
# external item inside the archive -- confirm against the zanj documentation
ZANJ_GLOBAL_DEFAULTS.external_array_threshold = ZANJ_GLOBAL_DEFAULTS.external_list_threshold = 1024

# fixed seed for reproducibility; the call hands back the device to run on
DEVICE = configure_notebook(seed=42, dark_mode=True)
print(f"{DEVICE = }")

DEVICE = device(type='cuda')


In [3]:
# show the names of the dataset configs available in `MAZE_DATASET_CONFIGS`
print(f"{list(MAZE_DATASET_CONFIGS.keys()) = }")

# if you want to specify a custom config, you can do so here

# "Real" training config: DFS mazes generated with `do_forks=False` on a grid with
# `grid_n=8`, one million mazes, and a 6-layer transformer with `d_model=128`.
CFG_CUSTOM: ConfigHolder = ConfigHolder(
    name="hallway-medium",
    dataset_cfg=MazeDatasetConfig(
        name="custom-hallway",
        grid_n=8,
        n_mazes=1_000_000,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
        maze_ctor_kwargs=dict(
            do_forks=False,
        ),
        applied_filters=[{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}],
    ),
    model_cfg=BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=128,
        d_head=32,
        n_layers=6,
    ),
    train_cfg=TrainConfig(
        name="custom-train",
        optimizer=torch.optim.AdamW,
        optimizer_kwargs=dict(lr=0.001),
        batch_size=32,
        dataloader_cfg=dict(
            # NOTE(review): shuffling is off here but on in CFG_TEST -- confirm this is intended
            shuffle=False,
            num_workers=8,
            drop_last=False,
        ),
        # These are counts over the whole run, not step intervals: in the recorded
        # run of ~31k iterations, `checkpoint=20` saved a checkpoint about every
        # 1562 iterations and `print_loss=100` logged the loss about every 312.
        intervals_count=dict(
            print_loss=100,
            checkpoint=20,
            eval_fast=100,
            eval_slow=50,
        ),
        # NOTE(review): presumably the number of mazes held out for validation -- confirm
        validation_dataset_cfg=100,
    ),
)

# Tiny config (`grid_n=3`, 50 mazes, 2-layer model) that runs in seconds; meant for CI testing.
CFG_TEST: ConfigHolder = ConfigHolder(
    name="hallway-nano",
    dataset_cfg=MazeDatasetConfig(
        name="custom-hallway",
        grid_n=3,
        n_mazes=50,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
        maze_ctor_kwargs=dict(
            do_forks=False,
        ),
        applied_filters=[{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}],
    ),
    # TODO: fix this
    # (disabled alternative: a collection of hallway datasets with grid sizes 3 and 4)
    # dataset_cfg = MazeDatasetCollectionConfig(
    #     name="custom-hallway-collection",
    #     maze_dataset_configs=[
    #         MazeDatasetConfig(
    #             name=f"custom-hallway-{n}",
    #             grid_n=n,
    #             n_mazes=50,
    #             maze_ctor=LatticeMazeGenerators.gen_dfs,
    #             maze_ctor_kwargs=dict(
    #                 do_forks=False,
    #             ),
    #             applied_filters=[{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}],
    #         )
    #         for n in [3, 4]
    #     ]
    # ),
    model_cfg=BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=8,
        d_head=2,
        n_layers=2,
    ),
    train_cfg=TrainConfig(
        name="custom-train",
        optimizer=torch.optim.AdamW,
        optimizer_kwargs=dict(lr=0.0001),
        batch_size=4,
        dataloader_cfg=dict(
            shuffle=True,
            drop_last=False,
        ),
        # NOTE(review): presumably the number of mazes held out for validation -- confirm
        validation_dataset_cfg=5,
    ),
)

list(MAZE_DATASET_CONFIGS.keys()) = ['test-g3-n5-a_dfs-h89001', 'demo_small-g3-n100-a_dfs-h58410', 'demo-g6-n10K-a_dfs-h86254']


In [4]:
# here is where to specify which config to actually use
CFG: ConfigHolder = CFG_CUSTOM # CFG_CUSTOM trains a "real" model; change to CFG_TEST for a quick run (CFG_TEST is for CI testing)

In [5]:
# pretty-print a JSON summary of the selected config as a sanity check before training
print(json.dumps(CFG.summary(), indent=2))

{
  "name": "hallway-medium",
  "dataset_cfg": {
    "name": "custom-hallway",
    "fname": "custom-hallway-g8-n1000000-a_dfs-h67201",
    "sdc_hash": 17702762527462501104362081349751668865096610287696345990139527306009298467201,
    "seed": 42,
    "seq_len_min": 1,
    "seq_len_max": 512,
    "padding_token_index": 10,
    "token_arr_joined": "<ADJLIST_START> <ADJLIST_END> <TARGET_START> <TARGET_END> <ORIGIN_START> <ORIGIN_END> <PATH_START> <PATH_END> <--> ; <PADDING> (0,0) (0,1) (1,0) (1,1) (0,2) (2,0) (1,2) (2,1) (2,2) (0,3) (3,0) (3,1) (2,3) (3,2) (1,3) (3,3) (0,4) (2,4) (4,0) (1,4) (4,1) (4,2) (3,4) (4,3) (4,4) (0,5) (5,0) (5,1) (2,5) (5,2) (5,3) (4,5) (5,4) (1,5) (3,5) (5,5) (0,6) (2,6) (4,6) (6,0) (1,6) (6,1) (6,2) (3,6) (6,3) (6,4) (5,6) (6,5) (6,6) (0,7) (7,0) (7,1) (2,7) (7,2) (7,3) (4,7) (7,4) (7,5) (6,7) (7,6) (1,7) (3,7) (5,7) (7,7)",
    "applied_filters": [
      {
        "name": "collect_generation_meta",
        "args": [],
        "kwargs": {}
      }
    ],
    "gr

In [6]:
# Get just the dataset, generating it only if needed.
# This step can be skipped if you set `do_generate_dataset=True` when calling `train_model`
# or if the dataset in question already exists.

# Load the dataset. A copy already saved under `PATH_DATA` is reused
# (`load_local=True`); with `load_local=False` the recorded run skipped the
# local copy and regenerated all mazes, which took about 21 minutes for 1M mazes.
# NOTE(review): set `load_local=False` again if a forced regeneration is wanted.
DATASET: MazeDataset = MazeDataset.from_config(
    CFG.dataset_cfg,
    verbose=True,
    load_local=True,
    save_local=True,
    local_base_path=PATH_DATA,
)
# from here on, use the config object attached to the dataset that was actually obtained
CFG.dataset_cfg = DATASET.cfg

seeing if we can download the dataset...
no download found, or download failed
generating dataset...


generating & solving mazes: 100%|██████████| 1000000/1000000 [21:12<00:00, 786.11maze/s]


saving dataset to ..\data\custom-hallway-g8-n1000000-a_dfs-h67201.zanj
Got dataset custom-hallway with 1000000 items. output.cfg.to_fname() = 'custom-hallway-g8-n1000000-a_dfs-h67201'


In [7]:
# Train the model described by `CFG`; run outputs (e.g. checkpoints) go under `PATH_DATA`.
# Dataset generation is switched off for this call (`do_generate_dataset=False`).
# Runs are logged to the DEMO_NOTEBOOKS wandb project; INTEGRATION_TESTS is the one for CI testing.
result: TrainingResult = train_model(
    cfg=CFG,
    base_path=PATH_DATA,
    wandb_project=WandbProject.DEMO_NOTEBOOKS,
    dataset_verbose=True,
    do_generate_dataset=False,
)

2023-06-16 03:40:47 ERROR Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mmiv[0m. Use [1m`wandb login --relogin`[0m to force relogin


2023-06-16 03:40:50 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'custom-hallway', 'seq_len_min': 1, 'seq_len_max': 512, 'seed': 42, 'applied_filters': [{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}], 'grid_n': 8, 'n_mazes': 1000000, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_dataset.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', '        # Arguments', '        - `grid_shape: Coord`: the shape of the grid', '        - `lattice_dim: int`: the dimension of the lattice', '          (default: `2`)', '        - `n_accessible_cells: int | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid.', '            (default: `None`)', '        - `max_tree_depth: int | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_accessible_cells`.', 


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[4, 1]])



2023-06-16 03:47:34 INFO Running evals: eval_slow
2023-06-16 03:47:48 INFO iteration 0/31247: loss=4.777
2023-06-16 03:47:48 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-16-03-40-47/checkpoints/model.iter_0.zanj
2023-06-16 03:48:38 INFO Running evals: eval_fast
2023-06-16 03:48:52 INFO iteration 312/31247: loss=0.988
2023-06-16 03:49:42 INFO Running evals: eval_fast
2023-06-16 03:49:55 INFO Running evals: eval_slow
2023-06-16 03:50:09 INFO iteration 624/31247: loss=0.698
2023-06-16 03:50:59 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[4, 7]])


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[4, 4]])



2023-06-16 03:51:12 INFO iteration 936/31247: loss=0.623
2023-06-16 03:52:02 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[7, 1]])


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[0, 2]])



2023-06-16 03:52:15 INFO Running evals: eval_slow
2023-06-16 03:52:28 INFO iteration 1248/31247: loss=0.677
2023-06-16 03:53:19 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[3, 7]])


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[2, 4]])


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[2, 0]])


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[6, 7]])


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[2, 5]])


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[0, 5]])



2023-06-16 03:53:32 INFO iteration 1560/31247: loss=0.697
2023-06-16 03:53:32 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-16-03-40-47/checkpoints/model.iter_1562.zanj
2023-06-16 03:54:24 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[1, 4]])


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[6, 4]])



2023-06-16 03:54:38 INFO Running evals: eval_slow
2023-06-16 03:54:51 INFO iteration 1872/31247: loss=0.775
2023-06-16 03:55:42 INFO Running evals: eval_fast
2023-06-16 03:55:54 INFO iteration 2184/31247: loss=0.721
2023-06-16 03:56:48 INFO Running evals: eval_fast
2023-06-16 03:57:01 INFO Running evals: eval_slow
2023-06-16 03:57:14 INFO iteration 2496/31247: loss=0.773
2023-06-16 03:58:08 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[1, 6]])


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[5, 3]])



2023-06-16 03:58:20 INFO iteration 2808/31247: loss=0.701
2023-06-16 03:59:17 INFO Running evals: eval_fast
2023-06-16 03:59:30 INFO Running evals: eval_slow
2023-06-16 03:59:43 INFO iteration 3120/31247: loss=0.560
2023-06-16 03:59:44 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-16-03-40-47/checkpoints/model.iter_3124.zanj
2023-06-16 04:00:40 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[7, 2]])


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[5, 5]])



2023-06-16 04:00:52 INFO iteration 3432/31247: loss=0.651
2023-06-16 04:01:55 INFO Running evals: eval_fast
2023-06-16 04:02:08 INFO Running evals: eval_slow
2023-06-16 04:02:20 INFO iteration 3744/31247: loss=0.634
2023-06-16 04:03:17 INFO Running evals: eval_fast
2023-06-16 04:03:30 INFO iteration 4056/31247: loss=0.574
2023-06-16 04:04:30 INFO Running evals: eval_fast
2023-06-16 04:04:43 INFO Running evals: eval_slow
2023-06-16 04:04:55 INFO iteration 4368/31247: loss=0.644
2023-06-16 04:05:50 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[7, 7]])



2023-06-16 04:06:03 INFO iteration 4680/31247: loss=0.606
2023-06-16 04:06:04 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-16-03-40-47/checkpoints/model.iter_4686.zanj
2023-06-16 04:07:07 INFO Running evals: eval_fast
2023-06-16 04:07:21 INFO Running evals: eval_slow



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[2, 3]])



2023-06-16 04:07:34 INFO iteration 4992/31247: loss=0.695
2023-06-16 04:08:37 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[5, 0]])



2023-06-16 04:08:50 INFO iteration 5304/31247: loss=0.536
2023-06-16 04:09:53 INFO Running evals: eval_fast
2023-06-16 04:10:05 INFO Running evals: eval_slow
2023-06-16 04:10:18 INFO iteration 5616/31247: loss=0.436
2023-06-16 04:11:18 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[4, 2]])



2023-06-16 04:11:31 INFO iteration 5928/31247: loss=0.545
2023-06-16 04:12:35 INFO Running evals: eval_fast
2023-06-16 04:12:48 INFO Running evals: eval_slow
2023-06-16 04:13:00 INFO iteration 6240/31247: loss=0.493
2023-06-16 04:13:01 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-16-03-40-47/checkpoints/model.iter_6248.zanj
2023-06-16 04:14:01 INFO Running evals: eval_fast
2023-06-16 04:14:13 INFO iteration 6552/31247: loss=0.612
2023-06-16 04:15:20 INFO Running evals: eval_fast
2023-06-16 04:15:32 INFO Running evals: eval_slow
2023-06-16 04:15:45 INFO iteration 6864/31247: loss=0.595
2023-06-16 04:16:48 INFO Running evals: eval_fast
2023-06-16 04:17:01 INFO iteration 7176/31247: loss=0.509
2023-06-16 04:18:05 INFO Running evals: eval_fast
2023-06-16 04:18:18 INFO Running evals: eval_slow
2023-06-16 04:18:30 INFO iteration 7488/31247: loss=0.548
2023-06-16 04:19:33 INFO Running evals: eval_fast
2023-06-16 04:19:45 INFO iteration 7800/31247: loss=0.509
2023-06-16 04:19


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[2, 7]])



2023-06-16 04:26:48 INFO Running evals: eval_slow
2023-06-16 04:27:00 INFO iteration 9360/31247: loss=0.532
2023-06-16 04:27:02 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-16-03-40-47/checkpoints/model.iter_9372.zanj




2023-06-16 04:28:11 INFO Running evals: eval_fast
2023-06-16 04:28:23 INFO iteration 9672/31247: loss=0.588
2023-06-16 04:29:30 INFO Running evals: eval_fast
2023-06-16 04:29:43 INFO Running evals: eval_slow
2023-06-16 04:29:55 INFO iteration 9984/31247: loss=0.465
2023-06-16 04:31:03 INFO Running evals: eval_fast
2023-06-16 04:31:15 INFO iteration 10296/31247: loss=0.469
2023-06-16 04:32:23 INFO Running evals: eval_fast
2023-06-16 04:32:36 INFO Running evals: eval_slow
2023-06-16 04:32:49 INFO iteration 10608/31247: loss=0.539
2023-06-16 04:33:53 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[7, 0]])



2023-06-16 04:34:05 INFO iteration 10920/31247: loss=0.547
2023-06-16 04:34:08 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-16-03-40-47/checkpoints/model.iter_10934.zanj
2023-06-16 04:35:13 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[3, 3]])



2023-06-16 04:35:26 INFO Running evals: eval_slow
2023-06-16 04:35:38 INFO iteration 11232/31247: loss=0.536
2023-06-16 04:36:45 INFO Running evals: eval_fast
2023-06-16 04:36:57 INFO iteration 11544/31247: loss=0.549
2023-06-16 04:38:04 INFO Running evals: eval_fast
2023-06-16 04:38:16 INFO Running evals: eval_slow
2023-06-16 04:38:28 INFO iteration 11856/31247: loss=0.577
2023-06-16 04:39:31 INFO Running evals: eval_fast
2023-06-16 04:39:43 INFO iteration 12168/31247: loss=0.507
2023-06-16 04:40:52 INFO Running evals: eval_fast
2023-06-16 04:41:04 INFO Running evals: eval_slow
2023-06-16 04:41:17 INFO iteration 12480/31247: loss=0.605
2023-06-16 04:41:19 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-16-03-40-47/checkpoints/model.iter_12496.zanj
2023-06-16 04:42:22 INFO Running evals: eval_fast
2023-06-16 04:42:34 INFO iteration 12792/31247: loss=0.529
2023-06-16 04:43:45 INFO Running evals: eval_fast
2023-06-16 04:43:57 INFO Running evals: eval_slow
2023-06-16 04:44:


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[0, 4]])



2023-06-16 04:46:46 INFO Running evals: eval_slow
2023-06-16 04:46:57 INFO iteration 13728/31247: loss=0.651
2023-06-16 04:48:02 INFO Running evals: eval_fast
2023-06-16 04:48:14 INFO iteration 14040/31247: loss=0.502
2023-06-16 04:48:17 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-16-03-40-47/checkpoints/model.iter_14058.zanj
2023-06-16 04:49:21 INFO Running evals: eval_fast
2023-06-16 04:49:33 INFO Running evals: eval_slow
2023-06-16 04:49:45 INFO iteration 14352/31247: loss=0.521
2023-06-16 04:50:52 INFO Running evals: eval_fast
2023-06-16 04:51:04 INFO iteration 14664/31247: loss=0.610
2023-06-16 04:52:16 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[5, 1]])



2023-06-16 04:52:28 INFO Running evals: eval_slow
2023-06-16 04:52:41 INFO iteration 14976/31247: loss=0.539
2023-06-16 04:53:51 INFO Running evals: eval_fast
2023-06-16 04:54:04 INFO iteration 15288/31247: loss=0.497
2023-06-16 04:55:13 INFO Running evals: eval_fast
2023-06-16 04:55:26 INFO Running evals: eval_slow
2023-06-16 04:55:38 INFO iteration 15600/31247: loss=0.396
2023-06-16 04:55:42 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-16-03-40-47/checkpoints/model.iter_15620.zanj
2023-06-16 04:56:44 INFO Running evals: eval_fast
2023-06-16 04:56:56 INFO iteration 15912/31247: loss=0.449
2023-06-16 04:58:06 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[7, 4]])



2023-06-16 04:58:19 INFO Running evals: eval_slow
2023-06-16 04:58:31 INFO iteration 16224/31247: loss=0.529
2023-06-16 04:59:36 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[1, 3]])



2023-06-16 04:59:48 INFO iteration 16536/31247: loss=0.585
2023-06-16 05:00:57 INFO Running evals: eval_fast
2023-06-16 05:01:09 INFO Running evals: eval_slow
2023-06-16 05:01:22 INFO iteration 16848/31247: loss=0.558
2023-06-16 05:02:33 INFO Running evals: eval_fast
2023-06-16 05:02:45 INFO iteration 17160/31247: loss=0.445
2023-06-16 05:02:48 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-16-03-40-47/checkpoints/model.iter_17182.zanj
2023-06-16 05:03:55 INFO Running evals: eval_fast
2023-06-16 05:04:07 INFO Running evals: eval_slow
2023-06-16 05:04:19 INFO iteration 17472/31247: loss=0.389
2023-06-16 05:05:25 INFO Running evals: eval_fast
2023-06-16 05:05:38 INFO iteration 17784/31247: loss=0.481
2023-06-16 05:06:50 INFO Running evals: eval_fast
2023-06-16 05:07:03 INFO Running evals: eval_slow
2023-06-16 05:07:15 INFO iteration 18096/31247: loss=0.474
2023-06-16 05:08:24 INFO Running evals: eval_fast
2023-06-16 05:08:36 INFO iteration 18408/31247: loss=0.592
2023-06-


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[2, 1]])



2023-06-16 05:21:30 INFO Running evals: eval_slow
2023-06-16 05:21:42 INFO iteration 21216/31247: loss=0.564
2023-06-16 05:22:50 INFO Running evals: eval_fast
2023-06-16 05:23:02 INFO iteration 21528/31247: loss=0.590
2023-06-16 05:24:15 INFO Running evals: eval_fast
2023-06-16 05:24:27 INFO Running evals: eval_slow
2023-06-16 05:24:39 INFO iteration 21840/31247: loss=0.522
2023-06-16 05:24:44 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-16-03-40-47/checkpoints/model.iter_21868.zanj
2023-06-16 05:25:49 INFO Running evals: eval_fast
2023-06-16 05:26:02 INFO iteration 22152/31247: loss=0.661
2023-06-16 05:27:09 INFO Running evals: eval_fast
2023-06-16 05:27:22 INFO Running evals: eval_slow
2023-06-16 05:27:35 INFO iteration 22464/31247: loss=0.534
2023-06-16 05:28:45 INFO Running evals: eval_fast
2023-06-16 05:28:58 INFO iteration 22776/31247: loss=0.526
2023-06-16 05:30:06 INFO Running evals: eval_fast
2023-06-16 05:30:18 INFO Running evals: eval_slow
2023-06-16 05:30: