In [1]:
# Generic
import typing
import os
from pathlib import Path
import typing
import json

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch

# muutils
from zanj.zanj import ZANJ, ZANJ_GLOBAL_DEFAULTS

# Our Code
from muutils.nbutils.configure_notebook import configure_notebook
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer, BaseGPTConfig, TrainConfig
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from maze_dataset.generation import LatticeMazeGenerators
from maze_transformer.training.train_model import TrainingResult, train_model
from maze_transformer.training.wandb_logger import WandbProject


In [2]:
# --- paths (relative to this notebook's working directory) -------------------
PATH_EXAMPLES: Path = Path("..") / "examples"
PATH_DATA: Path = Path("..") / "data"

# --- global ZANJ serialization defaults -------------------------------------
# NOTE(review): presumably arrays/lists larger than this threshold are stored
# externally rather than inline -- confirm against the zanj documentation.
ZANJ_GLOBAL_DEFAULTS.external_array_threshold = ZANJ_GLOBAL_DEFAULTS.external_list_threshold = 1024

# --- reproducibility and device ----------------------------------------------
# `configure_notebook` seeds the RNGs, applies dark-mode plotting, and returns
# the torch device to use for training.
DEVICE = configure_notebook(dark_mode=True, seed=42)
print(f"{DEVICE = }")

DEVICE = device(type='cuda')


In [3]:
print(f"{list(MAZE_DATASET_CONFIGS.keys()) = }")


def _hallway_dataset_cfg(grid_n: int, n_mazes: int) -> MazeDatasetConfig:
    """Build the "custom-hallway" maze dataset config shared by both configs below.

    Mazes are generated with `LatticeMazeGenerators.gen_dfs` with `do_forks=False`.

    # Parameters:
     - `grid_n : int`
        side length of the square maze grid (e.g. `grid_n=8` gives an 8x8 grid)
     - `n_mazes : int`
        number of mazes in the dataset

    # Returns:
     - `MazeDatasetConfig`
        config named `"custom-hallway"` with the given size
    """
    return MazeDatasetConfig(
        name="custom-hallway",
        grid_n=grid_n,
        n_mazes=n_mazes,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
        # no forks -> presumably a single unbranched corridor ("hallway"); confirm in gen_dfs
        maze_ctor_kwargs=dict(
            do_forks=False,
        ),
    )


# Config for a "real" training run: 8x8 grid, 100K mazes, 6-layer model.
CFG_CUSTOM: ConfigHolder = ConfigHolder(
    name="hallway-medium",
    dataset_cfg=_hallway_dataset_cfg(grid_n=8, n_mazes=100_000),
    model_cfg=BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=128,
        d_head=32,
        n_layers=6,
    ),
    train_cfg=TrainConfig(
        name="custom-train",
        optimizer=torch.optim.AdamW,
        optimizer_kwargs=dict(lr=0.00001),
        batch_size=32,
        dataloader_cfg=dict(
            shuffle=True,
            num_workers=4,
            drop_last=False,
        ),
        # Each value is how many times the action runs over the WHOLE training run,
        # not a step interval. For example, with 3125 iterations, print_loss=1000
        # logs roughly every 3 iterations and checkpoint=20 saves roughly every 156.
        intervals_count=dict(
            print_loss=1000,
            checkpoint=20,
            eval_fast=100,
            eval_slow=50,
        ),
        # NOTE(review): presumably the number of held-out validation mazes -- confirm in TrainConfig
        validation_dataset_cfg=10,
    ),
)

# Tiny config for CI testing: 3x3 grid, 50 mazes, 2-layer model.
CFG_TEST: ConfigHolder = ConfigHolder(
    name="hallway-nano",
    dataset_cfg=_hallway_dataset_cfg(grid_n=3, n_mazes=50),
    model_cfg=BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=8,
        d_head=2,
        n_layers=2,
    ),
    train_cfg=TrainConfig(
        name="custom-train",
        optimizer=torch.optim.AdamW,
        optimizer_kwargs=dict(lr=0.0001),
        batch_size=4,
        dataloader_cfg=dict(
            shuffle=True,
            drop_last=False,
        ),
        validation_dataset_cfg=5,
    ),
)

list(MAZE_DATASET_CONFIGS.keys()) = ['test-g3-n5-a_dfs-h89001', 'demo_small-g3-n100-a_dfs-h58410', 'demo-g6-n10K-a_dfs-h86254']


In [4]:
# Select which config to actually use:
#  - `CFG_CUSTOM` trains a "real" model (100K mazes; slow)
#  - `CFG_TEST` is a tiny config intended for CI testing
# `CFG` is an alias of the chosen object, not a copy, so later changes to `CFG`
# also modify the selected `CFG_*` config.
CFG: ConfigHolder = CFG_CUSTOM

In [5]:
# Pretty-print the selected config's summary as indented JSON, so the dataset,
# model, and training settings can be checked before training starts.
print(
    json.dumps(
        CFG.summary(),
        indent=2,
    )
)

{
  "name": "hallway-medium",
  "dataset_cfg": {
    "name": "custom-hallway",
    "fname": "custom-hallway-g8-n100K-a_dfs-h31024",
    "sdc_hash": 43825844172657897425957296941560401936207458692340074698180715204853252031024,
    "seed": 42,
    "seq_len_min": 1,
    "seq_len_max": 512,
    "padding_token_index": 10,
    "token_arr_joined": "<ADJLIST_START> <ADJLIST_END> <TARGET_START> <TARGET_END> <ORIGIN_START> <ORIGIN_END> <PATH_START> <PATH_END> <--> ; <PADDING> (0,0) (0,1) (1,0) (1,1) (0,2) (2,0) (1,2) (2,1) (2,2) (0,3) (3,0) (3,1) (2,3) (3,2) (1,3) (3,3) (0,4) (2,4) (4,0) (1,4) (4,1) (4,2) (3,4) (4,3) (4,4) (0,5) (5,0) (5,1) (2,5) (5,2) (5,3) (4,5) (5,4) (1,5) (3,5) (5,5) (0,6) (2,6) (4,6) (6,0) (1,6) (6,1) (6,2) (3,6) (6,3) (6,4) (5,6) (6,5) (6,6) (0,7) (7,0) (7,1) (2,7) (7,2) (7,3) (4,7) (7,4) (7,5) (6,7) (7,6) (1,7) (3,7) (5,7) (7,7)",
    "applied_filters": [],
    "grid_n": 8,
    "grid_shape": [
      8,
      8
    ],
    "n_mazes": 100000,
    "maze_ctor_name": "gen_dfs"

In [6]:
# Get just the dataset. This step can be skipped if you set
# `do_generate_dataset=True` when calling `train_model`.
#
# With `load_local=False`, any copy already saved under `PATH_DATA` is NOT reused.
# The dataset is downloaded if a download is available; otherwise it is regenerated
# and saved under `PATH_DATA` (for 100K mazes this takes a few minutes).
# NOTE(review): presumably `load_local=True` would reuse a previously saved copy
# instead -- confirm against `MazeDataset.from_config` before switching.
DATASET: MazeDataset = MazeDataset.from_config(
    CFG.dataset_cfg,
    verbose=True,
    load_local=False,
    local_base_path=PATH_DATA,
# apply the `collect_generation_meta` filter; it is recorded in the dataset
# config's `applied_filters`
).filter_by.collect_generation_meta()

seeing if we can download the dataset...
no download found, or download failed
generating dataset...


generating & solving mazes: 100%|██████████| 100000/100000 [02:51<00:00, 581.55maze/s]


saving dataset to ..\data\custom-hallway-g8-n100K-a_dfs-h31024.zanj
Got dataset custom-hallway with 100000 items. output.cfg.to_fname() = 'custom-hallway-g8-n100K-a_dfs-h31024'


In [7]:
# Persist the filtered dataset under PATH_DATA, named after its own config.
DATASET.save(
    PATH_DATA / DATASET.cfg.to_fname(),
)
# Point the training config at the filtered dataset's config, so its
# `applied_filters` are part of the config used for training.
# NOTE: `CFG` aliases the selected `CFG_*` object, so this modifies that object.
CFG.dataset_cfg = DATASET.cfg

In [8]:
# Train the model. The dataset was already generated and saved above, so
# `do_generate_dataset=False`. Run outputs (e.g. checkpoints) go under `base_path`.
result: TrainingResult = train_model(
    base_path=PATH_DATA,
    cfg=CFG,
    # use WandbProject.DEMO_NOTEBOOKS for notebook runs; INTEGRATION_TESTS is for CI testing
    wandb_project=WandbProject.DEMO_NOTEBOOKS,
    do_generate_dataset=False,
    dataset_verbose=True,
)

2023-06-15 03:51:26 ERROR Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mmiv[0m. Use [1m`wandb login --relogin`[0m to force relogin


2023-06-15 03:51:29 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'custom-hallway', 'seq_len_min': 1, 'seq_len_max': 512, 'seed': 42, 'applied_filters': [{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}], 'grid_n': 8, 'n_mazes': 100000, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_dataset.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', '        # Arguments', '        - `grid_shape: Coord`: the shape of the grid', '        - `lattice_dim: int`: the dimension of the lattice', '          (default: `2`)', '        - `n_accessible_cells: int | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid.', '            (default: `None`)', '        - `max_tree_depth: int | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_accessible_cells`.', '


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[3, 1]])


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[5, 6]])


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[7, 0]])



2023-06-15 03:52:39 INFO iteration 96/3125: loss=2.114
2023-06-15 03:52:39 INFO iteration 99/3125: loss=2.200
2023-06-15 03:52:40 INFO iteration 102/3125: loss=2.350
2023-06-15 03:52:40 INFO iteration 105/3125: loss=2.254
2023-06-15 03:52:41 INFO iteration 108/3125: loss=2.067
2023-06-15 03:52:41 INFO iteration 111/3125: loss=2.306
2023-06-15 03:52:42 INFO iteration 114/3125: loss=1.934
2023-06-15 03:52:43 INFO iteration 117/3125: loss=2.222
2023-06-15 03:52:43 INFO iteration 120/3125: loss=2.202
2023-06-15 03:52:44 INFO iteration 123/3125: loss=2.323
2023-06-15 03:52:44 INFO Running evals: eval_fast
2023-06-15 03:52:45 INFO Running evals: eval_slow
2023-06-15 03:52:47 INFO iteration 126/3125: loss=2.056
2023-06-15 03:52:48 INFO iteration 129/3125: loss=2.239
2023-06-15 03:52:48 INFO iteration 132/3125: loss=2.183
2023-06-15 03:52:49 INFO iteration 135/3125: loss=2.259
2023-06-15 03:52:49 INFO iteration 138/3125: loss=1.942
2023-06-15 03:52:50 INFO iteration 141/3125: loss=2.238
2023-0


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[2, 3]])



2023-06-15 03:53:53 INFO iteration 372/3125: loss=1.362
2023-06-15 03:53:54 INFO iteration 375/3125: loss=1.427
2023-06-15 03:53:54 INFO iteration 378/3125: loss=1.687
2023-06-15 03:53:55 INFO iteration 381/3125: loss=1.659
2023-06-15 03:53:55 INFO iteration 384/3125: loss=1.505
2023-06-15 03:53:56 INFO iteration 387/3125: loss=1.603
2023-06-15 03:53:57 INFO iteration 390/3125: loss=1.798
2023-06-15 03:53:57 INFO iteration 393/3125: loss=1.650
2023-06-15 03:53:58 INFO iteration 396/3125: loss=1.591
2023-06-15 03:53:58 INFO iteration 399/3125: loss=1.922
2023-06-15 03:53:59 INFO iteration 402/3125: loss=1.772
2023-06-15 03:53:59 INFO Running evals: eval_fast



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[1, 7]])



2023-06-15 03:54:01 INFO iteration 405/3125: loss=1.671
2023-06-15 03:54:01 INFO iteration 408/3125: loss=1.619
2023-06-15 03:54:02 INFO iteration 411/3125: loss=1.607
2023-06-15 03:54:03 INFO iteration 414/3125: loss=1.910
2023-06-15 03:54:03 INFO iteration 417/3125: loss=1.596
2023-06-15 03:54:04 INFO iteration 420/3125: loss=1.796
2023-06-15 03:54:04 INFO iteration 423/3125: loss=1.733
2023-06-15 03:54:05 INFO iteration 426/3125: loss=1.561
2023-06-15 03:54:05 INFO iteration 429/3125: loss=1.784
2023-06-15 03:54:06 INFO iteration 432/3125: loss=1.731
2023-06-15 03:54:06 INFO Running evals: eval_fast
2023-06-15 03:54:08 INFO Running evals: eval_slow
2023-06-15 03:54:10 INFO iteration 435/3125: loss=1.788
2023-06-15 03:54:10 INFO iteration 438/3125: loss=1.669
2023-06-15 03:54:11 INFO iteration 441/3125: loss=1.593
2023-06-15 03:54:11 INFO iteration 444/3125: loss=1.630
2023-06-15 03:54:12 INFO iteration 447/3125: loss=1.704
2023-06-15 03:54:13 INFO iteration 450/3125: loss=1.863
2023


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[2, 6]])



2023-06-15 03:54:43 INFO iteration 558/3125: loss=1.488
2023-06-15 03:54:44 INFO iteration 561/3125: loss=1.411
2023-06-15 03:54:44 INFO iteration 564/3125: loss=1.527
2023-06-15 03:54:45 INFO iteration 567/3125: loss=1.703
2023-06-15 03:54:46 INFO iteration 570/3125: loss=1.965
2023-06-15 03:54:46 INFO iteration 573/3125: loss=1.832
2023-06-15 03:54:47 INFO iteration 576/3125: loss=1.567
2023-06-15 03:54:47 INFO iteration 579/3125: loss=1.714
2023-06-15 03:54:48 INFO iteration 582/3125: loss=1.573
2023-06-15 03:54:48 INFO iteration 585/3125: loss=1.546
2023-06-15 03:54:49 INFO iteration 588/3125: loss=1.604
2023-06-15 03:54:49 INFO Running evals: eval_fast
2023-06-15 03:54:51 INFO iteration 591/3125: loss=1.703
2023-06-15 03:54:52 INFO iteration 594/3125: loss=1.419
2023-06-15 03:54:52 INFO iteration 597/3125: loss=1.491
2023-06-15 03:54:53 INFO iteration 600/3125: loss=1.616
2023-06-15 03:54:54 INFO iteration 603/3125: loss=1.414
2023-06-15 03:54:54 INFO iteration 606/3125: loss=1.55


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[4, 7]])



2023-06-15 03:55:00 INFO iteration 621/3125: loss=1.401
2023-06-15 03:55:01 INFO iteration 624/3125: loss=1.865
2023-06-15 03:55:01 INFO Saving model checkpoint to ../data/hallway-medium_2023-06-15-03-51-25/checkpoints/model.iter_624.zanj
2023-06-15 03:55:02 INFO iteration 627/3125: loss=1.515
2023-06-15 03:55:03 INFO iteration 630/3125: loss=1.533
2023-06-15 03:55:03 INFO iteration 633/3125: loss=1.736
2023-06-15 03:55:04 INFO iteration 636/3125: loss=1.680
2023-06-15 03:55:04 INFO iteration 639/3125: loss=1.473
2023-06-15 03:55:05 INFO iteration 642/3125: loss=1.554
2023-06-15 03:55:05 INFO iteration 645/3125: loss=1.734
2023-06-15 03:55:06 INFO iteration 648/3125: loss=1.515
2023-06-15 03:55:07 INFO Running evals: eval_fast
2023-06-15 03:55:08 INFO iteration 651/3125: loss=1.416



fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[5, 5]])



2023-06-15 03:55:09 INFO iteration 654/3125: loss=1.815
2023-06-15 03:55:09 INFO iteration 657/3125: loss=1.551
2023-06-15 03:55:10 INFO iteration 660/3125: loss=1.556
2023-06-15 03:55:10 INFO iteration 663/3125: loss=1.453
2023-06-15 03:55:11 INFO iteration 666/3125: loss=1.482
2023-06-15 03:55:12 INFO iteration 669/3125: loss=1.584
2023-06-15 03:55:12 INFO iteration 672/3125: loss=1.512
2023-06-15 03:55:13 INFO iteration 675/3125: loss=1.475
2023-06-15 03:55:13 INFO iteration 678/3125: loss=1.744
2023-06-15 03:55:14 INFO iteration 681/3125: loss=1.845
2023-06-15 03:55:14 INFO Running evals: eval_fast
2023-06-15 03:55:15 INFO Running evals: eval_slow
2023-06-15 03:55:17 INFO iteration 684/3125: loss=1.744
2023-06-15 03:55:18 INFO iteration 687/3125: loss=1.411
2023-06-15 03:55:18 INFO iteration 690/3125: loss=1.736
2023-06-15 03:55:19 INFO iteration 693/3125: loss=1.321
2023-06-15 03:55:19 INFO iteration 696/3125: loss=1.649
2023-06-15 03:55:20 INFO iteration 699/3125: loss=1.510
2023

KeyboardInterrupt: 