In [1]:
# Generic
import typing
import os
from pathlib import Path
import typing
import json

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch

# muutils
from muutils.zanj.zanj import ZANJ, ZANJ_GLOBAL_DEFAULTS

# Our Code
from maze_transformer.utils.notebook_utils import configure_notebook
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer, BaseGPTConfig, TrainConfig
from maze_transformer.dataset.maze_dataset import MazeDataset, MazeDatasetConfig
from maze_transformer.dataset.maze_dataset_configs import MAZE_DATASET_CONFIGS
from maze_transformer.generation.generators import LatticeMazeGenerators
from maze_transformer.training.train_model import TrainingResult, train_model
from maze_transformer.training.wandb_logger import WandbProject


In [10]:
# ZANJ serialization defaults: both thresholds get the same value.
# NOTE(review): presumably the size above which arrays / lists are written as
# external files inside the archive -- confirm against the muutils docs.
ZANJ_GLOBAL_DEFAULTS.external_array_threshold = ZANJ_GLOBAL_DEFAULTS.external_list_threshold = 1024

# Directories used by the rest of the notebook (relative to the notebook's cwd)
PATH_DATA: Path = Path("../data/")  # datasets are saved here and passed to training as `base_path`
PATH_EXAMPLES: Path = Path("../examples/")  # pre-built example datasets are loaded from here

# Fixed seed for reproducibility; the call hands back the device, which we echo
DEVICE = configure_notebook(
    seed=42,
    dark_mode=True,
)
print(f"{DEVICE = }")

DEVICE = device(type='cpu')


In [3]:
# Show the names of the predefined dataset configs available in MAZE_DATASET_CONFIGS
print(f"{list(MAZE_DATASET_CONFIGS.keys()) = }")

# A fully hand-written configuration, as an alternative to the named presets:
# the dataset, model and training sub-configs are each spelled out explicitly.
CFG_CUSTOM: ConfigHolder = ConfigHolder(
    name="custom",
    # 10k mazes on a grid with grid_n=6, built by the depth-first-search generator
    dataset_cfg=MazeDatasetConfig(
        name="custom-dataset",
        grid_n=6,
        n_mazes=10_000,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    ),
    # deliberately tiny model
    model_cfg=BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=8,
        d_head=4,
        n_layers=2,
    ),
    train_cfg=TrainConfig(
        name="custom-train",
        optimizer=torch.optim.RMSprop,
        optimizer_kwargs={"lr": 1e-4},
        batch_size=16,
        # NOTE(review): presumably forwarded to the torch DataLoader -- confirm in TrainConfig
        dataloader_cfg={
            "shuffle": True,
            "num_workers": 0,
            "drop_last": False,
        },
        print_loss_interval=100,
        checkpoint_interval=1000,
    ),
)

list(MAZE_DATASET_CONFIGS.keys()) = ['test-g3-n5-a_dfs-h82387', 'demo_small-g3-n100-a_dfs-h69818', 'demo-g6-n10K-a_dfs-h92077']


In [4]:
# Configs assembled from named presets, one name per sub-config.

# "Real" demo run
CFG_DEMO: ConfigHolder = ConfigHolder.get_config_multisource(
    cfg_names=(
        "demo-g6-n10K-a_dfs-h92077",  # dataset preset
        "tiny-v1",  # model preset
        "sweep-v1",  # presumably the training preset -- confirm
    ),
)

# Much smaller variant, intended for quick tests
CFG_TEST: ConfigHolder = ConfigHolder.get_config_multisource(
    cfg_names=(
        "demo_small-g3-n100-a_dfs-h69818",  # dataset preset
        "nano-v1",  # model preset
        "test-v1",  # presumably the training preset -- confirm
    ),
)

In [5]:
# Select the active configuration here. The cells below only read `CFG`
# (summary printout, dataset loading, training), so this is the single place
# to switch between CFG_DEMO, CFG_TEST and CFG_CUSTOM.
CFG: ConfigHolder = CFG_DEMO

In [6]:
# Pretty-print the selected config's summary as indented JSON, as a sanity
# check of what is about to be used for the dataset and for training.
print(json.dumps(CFG.summary(), indent=2))

{
  "name": "multsrc_demo-g6-n10K-a_dfs-h92077_tiny-v1_sweep-v1",
  "dataset_cfg": {
    "name": "demo",
    "fname": "demo-g6-n10K-a_dfs-h92077",
    "sdc_hash": 65001276129750023809936129568423056416231185113961566638389670113962440292077,
    "seed": 42,
    "seq_len_min": 1,
    "seq_len_max": 512,
    "padding_token_index": 10,
    "token_arr_joined": "<ADJLIST_START> <ADJLIST_END> <TARGET_START> <TARGET_END> <ORIGIN_START> <ORIGIN_END> <PATH_START> <PATH_END> <--> ; <PADDING> (0,0) (0,1) (1,0) (1,1) (0,2) (2,0) (1,2) (2,1) (2,2) (0,3) (3,0) (3,1) (2,3) (3,2) (1,3) (3,3) (0,4) (2,4) (4,0) (1,4) (4,1) (4,2) (3,4) (4,3) (4,4) (0,5) (5,0) (5,1) (2,5) (5,2) (5,3) (4,5) (5,4) (1,5) (3,5) (5,5)",
    "applied_filters": [],
    "grid_n": 6,
    "grid_shape": [
      6,
      6
    ],
    "n_mazes": 10000,
    "maze_ctor_name": "gen_dfs",
    "maze_ctor_kwargs": {}
  },
  "model_cfg": {
    "__format__": "BaseGPTConfig(SerializableDataclass)",
    "name": "tiny-v1",
    "act_fn": "gelu",


In [11]:
# Obtain only the dataset (it is generated if it does not exist yet).
# You can skip this cell when `train_model` is called with
# `do_generate_dataset=True`, or when the dataset is already present.

# Step 1: look for the dataset in the examples directory
DATASET: MazeDataset = MazeDataset.from_config(
    CFG.dataset_cfg,
    local_base_path=PATH_EXAMPLES,
    verbose=True,
)

# Step 2: store a copy in the data directory, which is also where the models will be saved
DATASET.save(
    PATH_DATA / DATASET.cfg.to_fname()
)

loading dataset from ../examples/demo-g6-n10K-a_dfs-h92077.zanj


In [8]:
# Train with the selected config. The dataset is expected to already exist
# under PATH_DATA (do_generate_dataset=False), so nothing is generated here.
result: TrainingResult = train_model(
    cfg=CFG,
    base_path=PATH_DATA,
    do_generate_dataset=False,
    dataset_verbose=True,
    wandb_project=WandbProject.DEMO_NOTEBOOKS,
)

2023-05-20 21:30:02 ERROR Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mmiv[0m. Use [1m`wandb login --relogin`[0m to force relogin


2023-05-20 21:30:06 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'demo', 'dtype': 'torch.int16', 'seq_len_min': 1, 'seq_len_max': 512, 'seed': 42, 'applied_filters': [], 'grid_n': 6, 'n_mazes': 10000, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_transformer.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', '        # Arguments', '        - `grid_shape: Coord`: the shape of the grid', '        - `lattice_dim: int`: the dimension of the lattice', '          (default: `2`)', '        - `n_accessible_cells: int | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid.', '            (default: `None`)', '        - `max_tree_depth: int | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_accessible_cells`.', '            (default: `None`)', '        - `