In [1]:
# Generic
# to pick up edits to library code without restarting the kernel, uncomment:
# %load_ext autoreload
# %autoreload 2

import typing
import os
from pathlib import Path
import typing
import json

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch

# ZANJ (serialization)
from zanj.zanj import ZANJ, ZANJ_GLOBAL_DEFAULTS

# Our Code
from muutils.nbutils.configure_notebook import configure_notebook
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer, BaseGPTConfig, TrainConfig
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from maze_dataset.generation import LatticeMazeGenerators
from maze_transformer.training.train_model import TrainingResult, train_model
from maze_transformer.training.wandb_logger import WandbProject


In [2]:
# ZANJ serialization defaults; both thresholds share one value.
# NOTE(review): presumably arrays/lists larger than this are stored as separate
# entries inside the .zanj archive rather than inline -- confirm in the zanj docs
ZANJ_GLOBAL_DEFAULTS.external_array_threshold = ZANJ_GLOBAL_DEFAULTS.external_list_threshold = 1024

# filesystem locations, relative to the notebook's working directory
PATH_DATA: Path = Path("../data/")
PATH_EXAMPLES: Path = Path("../examples/")

# seed + plot setup for the notebook; the returned value is the torch device to use
DEVICE = configure_notebook(dark_mode=True, seed=42)
print(f"{DEVICE = }")

set up plots with PLOT_MODE = 'inline', FIG_OUTPUT_FMT = None, FIG_BASEPATH = None
DEVICE = device(type='cuda')


In [3]:
print(f"{list(MAZE_DATASET_CONFIGS.keys()) = }")

# A hand-written config for quick experiments: DFS mazes on a 6x6 grid and a
# tiny 2-layer transformer. Edit the values here to try a custom setup.
CFG_CUSTOM: ConfigHolder = ConfigHolder(
    name="custom",
    dataset_cfg=MazeDatasetConfig(
        name="custom-dataset",
        grid_n=6,
        n_mazes=10_000,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    ),
    model_cfg=BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=8,
        d_head=4,
        n_layers=2,
    ),
    train_cfg=TrainConfig(
        name="custom-train",
        optimizer=torch.optim.RMSprop,
        optimizer_kwargs={"lr": 0.0001},
        batch_size=16,
        dataloader_cfg={
            "shuffle": True,
            "num_workers": 0,
            "drop_last": False,
        },
        # number of times each event fires over the whole run (not an iteration stride)
        intervals_count={
            "print_loss": 100,
            "checkpoint": 5,
            "eval_fast": 10,
            "eval_slow": 5,
        },
        # NOTE(review): presumably the number of mazes held out for evaluation -- confirm in TrainConfig
        validation_dataset_cfg=100,
    ),
)


list(MAZE_DATASET_CONFIGS.keys()) = ['test-g3-n5-a_dfs-h75556', 'demo_small-g3-n100-a_dfs-h88371', 'demo-g6-n10K-a_dfs-h30615']


In [4]:

# Hallway config: DFS mazes on a 7x7 grid with forking disabled, trained with a
# 6-layer, 4-head (d_model=128, d_head=32) transformer.
CFG_HALLWAY: ConfigHolder = ConfigHolder(
    name="hallway_v3",
    dataset_cfg=MazeDatasetConfig(
        name="hallway",
        grid_n=7,
        # NOTE(review): the saved outputs of later cells (config summary, dataset
        # 'hallway-g7-n30K-...', 938 iterations at batch_size=32) were produced with
        # n_mazes=30_000. Re-running as written requests a dataset 100x larger.
        n_mazes=3_000_000,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
        # presumably yields single-corridor mazes with no branch points -- confirm in maze_dataset
        maze_ctor_kwargs={"do_forks": False},
        applied_filters=[{"name": "collect_generation_meta", "args": (), "kwargs": {}}],
        seq_len_max=256,
    ),
    # NOTE(review): the "custom-*" sub-config names were carried over from CFG_CUSTOM
    model_cfg=BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=128,
        d_head=32,
        n_layers=6,
    ),
    train_cfg=TrainConfig(
        name="custom-train",
        optimizer=torch.optim.AdamW,
        optimizer_kwargs={"lr": 0.0001},
        batch_size=32,
        # NOTE(review): training batches are not shuffled -- confirm this is intended
        dataloader_cfg={
            "shuffle": False,
            "num_workers": 8,
            "drop_last": False,
        },
        # number of times each event fires over the whole run; in the saved output,
        # checkpoint=20 produced a checkpoint roughly every 46 of 938 iterations
        intervals_count={
            "print_loss": 100,
            "checkpoint": 20,
            "eval_fast": 100,
            "eval_slow": 50,
        },
        validation_dataset_cfg=100,
    ),
)

In [5]:
# This is for training a "real" demo model on the full demo dataset (6x6 grid, 10K mazes).
# The key, including its hash suffix, must be one of MAZE_DATASET_CONFIGS.keys().
# Previously this pointed at the 5-maze 3x3 *test* dataset, which is even smaller
# than the one CFG_TEST uses.
CFG_DEMO: ConfigHolder = ConfigHolder.get_config_multisource(
    cfg_names=("demo-g6-n10K-a_dfs-h30615", "tiny-v1", "sweep-v1"),
)

# This is smaller, for testing: 100 mazes on a 3x3 grid.
CFG_TEST: ConfigHolder = ConfigHolder.get_config_multisource(
    cfg_names=("demo_small-g3-n100-a_dfs-h88371", "nano-v1", "test-v1"),
)

In [6]:
# Select the config that every following cell uses.
# Options defined above: CFG_CUSTOM, CFG_HALLWAY, CFG_DEMO, CFG_TEST.
CFG: ConfigHolder = CFG_HALLWAY


In [7]:
# record the selected config in the notebook output as indented JSON
print(
    json.dumps(
        obj=CFG.summary(),
        indent=2,
    )
)

{
  "name": "hallway_v3",
  "dataset_cfg": {
    "name": "hallway",
    "fname": "hallway-g7-n30K-a_dfs-h49802",
    "sdc_hash": 86367718438925181581864853467393593889412451469408090965213468443030807549802,
    "seed": 42,
    "seq_len_min": 1,
    "seq_len_max": 256,
    "applied_filters": [
      {
        "name": "collect_generation_meta",
        "args": [],
        "kwargs": {}
      }
    ],
    "grid_n": 7,
    "grid_shape": [
      7,
      7
    ],
    "n_mazes": 30000,
    "maze_ctor_name": "gen_dfs",
    "maze_ctor_kwargs": {
      "do_forks": false
    }
  },
  "model_cfg": {
    "name": "custom-model",
    "act_fn": "gelu",
    "d_model": 128,
    "d_head": 32,
    "n_layers": 6,
    "weight_processing": {
      "are_layernorms_folded": false,
      "are_weights_processed": false
    },
    "n_heads": 4
  },
  "train_cfg": {
    "name": "custom-train",
    "optimizer": "AdamW",
    "optimizer_kwargs": {
      "lr": 0.0001
    },
    "batch_size": 32,
    "dataloader_cfg":

In [8]:
# Load the dataset for the selected config, generating it if needed.
# This step is optional: it can be skipped if the dataset already exists, or if
# `train_model` is called with `do_generate_dataset=True`.
DATASET: MazeDataset = MazeDataset.from_config(
    CFG.dataset_cfg,
    local_base_path=PATH_DATA,
    verbose=True,
)

trying to get the dataset 'hallway-g7-n30K-a_dfs-h49802'
loading dataset from ../data/hallway-g7-n30K-a_dfs-h49802.zanj
Got dataset hallway with 30000 items. output.cfg.to_fname() = 'hallway-g7-n30K-a_dfs-h49802'


In [9]:
# Train on the already-loaded DATASET, so dataset generation is disabled here.
# Run outputs are written under PATH_DATA.
# NOTE: in the saved outputs this run was interrupted (KeyboardInterrupt) at
# iteration ~435/938, so `result` was never assigned.
result: TrainingResult = train_model(
    base_path=PATH_DATA,
    cfg=CFG,
    wandb_project=WandbProject.DEMO_NOTEBOOKS,  # W&B project that receives this run
    do_generate_dataset=False,  # DATASET is passed explicitly below
    dataset_verbose=True,
    dataset=DATASET,
)

2024-02-04 18:04:17 ERROR Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mwatermelontology[0m. Use [1m`wandb login --relogin`[0m to force relogin


2024-02-04 18:04:19 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'hallway', 'seq_len_min': 1, 'seq_len_max': 256, 'seed': 42, 'applied_filters': [{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}], 'grid_n': 7, 'n_mazes': 30000, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_dataset.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', '        # Arguments', '        - `grid_shape: Coord`: the shape of the grid', '        - `lattice_dim: int`: the dimension of the lattice', '          (default: `2`)', '        - `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**', '            (default: `None`)', '        - `max_tree_depth: int | float | None`

  0%|          | 0/938 [00:00<?, ?it/s]

2024-02-04 18:04:20 INFO iteration 0/938: loss=5.053
2024-02-04 18:04:20 INFO Saving model checkpoint to ../data/hallway_v3_2024-02-04-18-04-17/checkpoints/model.iter_0.zanj


  1%|          | 9/938 [00:03<02:57,  5.22it/s]

2024-02-04 18:04:22 INFO iteration 9/938: loss=2.147


  2%|▏         | 18/938 [00:04<02:17,  6.69it/s]

2024-02-04 18:04:23 INFO iteration 18/938: loss=1.866


  3%|▎         | 27/938 [00:05<02:23,  6.34it/s]

2024-02-04 18:04:25 INFO iteration 27/938: loss=1.806


  4%|▍         | 36/938 [00:06<02:06,  7.14it/s]

2024-02-04 18:04:26 INFO iteration 36/938: loss=1.832


  5%|▍         | 45/938 [00:07<01:35,  9.37it/s]

2024-02-04 18:04:27 INFO iteration 45/938: loss=2.016


  5%|▍         | 46/938 [00:07<01:33,  9.49it/s]

2024-02-04 18:04:27 INFO Saving model checkpoint to ../data/hallway_v3_2024-02-04-18-04-17/checkpoints/model.iter_46.zanj


  6%|▌         | 54/938 [00:09<01:50,  7.99it/s]

2024-02-04 18:04:28 INFO iteration 54/938: loss=1.488


  7%|▋         | 63/938 [00:10<01:58,  7.41it/s]

2024-02-04 18:04:30 INFO iteration 63/938: loss=1.682


  8%|▊         | 72/938 [00:11<01:45,  8.21it/s]

2024-02-04 18:04:31 INFO iteration 72/938: loss=1.575


  9%|▊         | 81/938 [00:12<01:39,  8.65it/s]

2024-02-04 18:04:32 INFO iteration 81/938: loss=1.426


 10%|▉         | 90/938 [00:13<01:57,  7.22it/s]

2024-02-04 18:04:33 INFO iteration 90/938: loss=1.339


 10%|▉         | 92/938 [00:14<02:04,  6.81it/s]

2024-02-04 18:04:33 INFO Saving model checkpoint to ../data/hallway_v3_2024-02-04-18-04-17/checkpoints/model.iter_92.zanj


 11%|█         | 99/938 [00:15<02:43,  5.14it/s]

2024-02-04 18:04:35 INFO iteration 99/938: loss=1.423


 11%|█▏        | 107/938 [00:17<02:07,  6.54it/s]

2024-02-04 18:04:36 INFO iteration 108/938: loss=1.254


 12%|█▏        | 117/938 [00:18<01:35,  8.56it/s]

2024-02-04 18:04:37 INFO iteration 117/938: loss=1.397


 13%|█▎        | 125/938 [00:19<01:43,  7.87it/s]

2024-02-04 18:04:38 INFO iteration 126/938: loss=1.256


 14%|█▍        | 135/938 [00:20<01:29,  8.94it/s]

2024-02-04 18:04:39 INFO iteration 135/938: loss=1.293


 15%|█▍        | 137/938 [00:20<01:41,  7.87it/s]

2024-02-04 18:04:40 INFO Saving model checkpoint to ../data/hallway_v3_2024-02-04-18-04-17/checkpoints/model.iter_138.zanj


 15%|█▌        | 144/938 [00:21<02:00,  6.58it/s]

2024-02-04 18:04:41 INFO iteration 144/938: loss=1.304


 16%|█▋        | 153/938 [00:23<01:39,  7.87it/s]

2024-02-04 18:04:42 INFO iteration 153/938: loss=1.311


 17%|█▋        | 162/938 [00:24<01:44,  7.39it/s]

2024-02-04 18:04:43 INFO iteration 162/938: loss=1.244


 18%|█▊        | 171/938 [00:25<01:30,  8.43it/s]

2024-02-04 18:04:44 INFO iteration 171/938: loss=1.265


 19%|█▉        | 180/938 [00:26<01:35,  7.94it/s]

2024-02-04 18:04:45 INFO iteration 180/938: loss=1.129


 20%|█▉        | 184/938 [00:26<01:47,  7.03it/s]

2024-02-04 18:04:46 INFO Saving model checkpoint to ../data/hallway_v3_2024-02-04-18-04-17/checkpoints/model.iter_184.zanj


 20%|██        | 189/938 [00:28<02:10,  5.76it/s]

2024-02-04 18:04:47 INFO iteration 189/938: loss=1.321


 21%|██        | 198/938 [00:29<01:41,  7.28it/s]

2024-02-04 18:04:48 INFO iteration 198/938: loss=1.231


 22%|██▏       | 207/938 [00:30<01:19,  9.18it/s]

2024-02-04 18:04:49 INFO iteration 207/938: loss=1.115


 23%|██▎       | 216/938 [00:31<01:35,  7.54it/s]

2024-02-04 18:04:50 INFO iteration 216/938: loss=1.301


 24%|██▍       | 224/938 [00:32<01:34,  7.52it/s]

2024-02-04 18:04:52 INFO iteration 225/938: loss=1.066


 25%|██▍       | 230/938 [00:33<01:44,  6.75it/s]

2024-02-04 18:04:52 INFO Saving model checkpoint to ../data/hallway_v3_2024-02-04-18-04-17/checkpoints/model.iter_230.zanj


 25%|██▍       | 234/938 [00:34<02:31,  4.66it/s]

2024-02-04 18:04:54 INFO iteration 234/938: loss=1.098


 26%|██▌       | 243/938 [00:35<01:42,  6.77it/s]

2024-02-04 18:04:55 INFO iteration 243/938: loss=1.102


 27%|██▋       | 252/938 [00:37<01:34,  7.25it/s]

2024-02-04 18:04:56 INFO iteration 252/938: loss=1.257


 28%|██▊       | 261/938 [00:38<01:40,  6.77it/s]

2024-02-04 18:04:57 INFO iteration 261/938: loss=1.238


 29%|██▊       | 269/938 [00:39<01:29,  7.50it/s]

2024-02-04 18:04:59 INFO iteration 270/938: loss=1.275


 29%|██▉       | 276/938 [00:40<01:25,  7.75it/s]

2024-02-04 18:04:59 INFO Saving model checkpoint to ../data/hallway_v3_2024-02-04-18-04-17/checkpoints/model.iter_276.zanj


 30%|██▉       | 279/938 [00:41<02:06,  5.21it/s]

2024-02-04 18:05:00 INFO iteration 279/938: loss=1.219


 31%|███       | 288/938 [00:42<01:23,  7.76it/s]

2024-02-04 18:05:01 INFO iteration 288/938: loss=1.229


 32%|███▏      | 297/938 [00:43<01:18,  8.22it/s]

2024-02-04 18:05:02 INFO iteration 297/938: loss=1.036


 33%|███▎      | 305/938 [00:44<01:25,  7.40it/s]

2024-02-04 18:05:04 INFO iteration 306/938: loss=1.118


 34%|███▎      | 315/938 [00:45<01:09,  8.93it/s]

2024-02-04 18:05:05 INFO iteration 315/938: loss=1.035


 34%|███▍      | 321/938 [00:46<01:24,  7.32it/s]

2024-02-04 18:05:05 INFO Saving model checkpoint to ../data/hallway_v3_2024-02-04-18-04-17/checkpoints/model.iter_322.zanj


 35%|███▍      | 324/938 [00:47<02:10,  4.70it/s]

2024-02-04 18:05:06 INFO iteration 324/938: loss=1.115


 36%|███▌      | 333/938 [00:48<01:13,  8.21it/s]

2024-02-04 18:05:07 INFO iteration 333/938: loss=1.141


 36%|███▋      | 341/938 [00:49<01:12,  8.24it/s]

2024-02-04 18:05:08 INFO iteration 342/938: loss=1.220


 37%|███▋      | 350/938 [00:50<01:12,  8.10it/s]

2024-02-04 18:05:09 INFO iteration 351/938: loss=1.078


 38%|███▊      | 360/938 [00:51<01:13,  7.83it/s]

2024-02-04 18:05:10 INFO iteration 360/938: loss=1.079


 39%|███▉      | 368/938 [00:52<01:12,  7.82it/s]

2024-02-04 18:05:11 INFO Saving model checkpoint to ../data/hallway_v3_2024-02-04-18-04-17/checkpoints/model.iter_368.zanj


 39%|███▉      | 369/938 [00:53<02:20,  4.04it/s]

2024-02-04 18:05:12 INFO iteration 369/938: loss=1.097


 40%|████      | 378/938 [00:54<01:20,  6.95it/s]

2024-02-04 18:05:13 INFO iteration 378/938: loss=1.127


 41%|████▏     | 387/938 [00:55<01:18,  6.99it/s]

2024-02-04 18:05:14 INFO iteration 387/938: loss=1.183


 42%|████▏     | 396/938 [00:56<01:08,  7.88it/s]

2024-02-04 18:05:16 INFO iteration 396/938: loss=1.095


 43%|████▎     | 405/938 [00:57<01:05,  8.19it/s]

2024-02-04 18:05:17 INFO iteration 405/938: loss=1.115


 44%|████▍     | 414/938 [00:58<00:58,  8.96it/s]

2024-02-04 18:05:18 INFO iteration 414/938: loss=1.067
2024-02-04 18:05:18 INFO Saving model checkpoint to ../data/hallway_v3_2024-02-04-18-04-17/checkpoints/model.iter_414.zanj


 45%|████▌     | 423/938 [01:00<01:11,  7.23it/s]

2024-02-04 18:05:19 INFO iteration 423/938: loss=1.127


 46%|████▌     | 432/938 [01:01<01:01,  8.26it/s]

2024-02-04 18:05:21 INFO iteration 432/938: loss=1.114


 46%|████▋     | 435/938 [01:01<01:11,  7.02it/s]


KeyboardInterrupt: 