In [1]:
# Generic
import typing
import os
from pathlib import Path
import typing
import json

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch

# muutils
from muutils.nbutils.configure_notebook import configure_notebook
from zanj.zanj import ZANJ, ZANJ_GLOBAL_DEFAULTS

# maze-dataset
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS

# maze-transformer
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer, BaseGPTConfig, TrainConfig
from maze_transformer.training.train_model import TrainingResult, train_model
from maze_transformer.training.wandb_logger import WandbProject


In [2]:
# --- notebook-wide configuration ---------------------------------------------

# ZANJ serialization: one shared size threshold for both arrays and lists
# (presumably containers above this size are stored externally rather than
# inline -- confirm against the ZANJ docs)
_ZANJ_EXTERNAL_THRESHOLD: int = 1024
ZANJ_GLOBAL_DEFAULTS.external_array_threshold = _ZANJ_EXTERNAL_THRESHOLD
ZANJ_GLOBAL_DEFAULTS.external_list_threshold = _ZANJ_EXTERNAL_THRESHOLD

# directories, relative to the notebook's working directory
PATH_EXAMPLES: Path = Path("../examples/")
PATH_DATA: Path = Path("../data/")  # datasets are loaded from and checkpoints saved under here

# configure the notebook (fixed seed for reproducibility, dark-mode plots)
# and get back the torch device to run on
DEVICE = configure_notebook(seed=42, dark_mode=True)
print(f"{DEVICE = }")

set up plots with PLOT_MODE = 'inline', FIG_OUTPUT_FMT = None, FIG_BASEPATH = None
DEVICE = device(type='cpu')


In [3]:
# list the preset dataset configs shipped with maze-dataset (usable by name further down)
print(f"{list(MAZE_DATASET_CONFIGS.keys()) = }")

# A fully hand-specified config: grid_n=6 DFS mazes and a tiny model
# (d_model=8, 2 layers) trained with RMSprop. Edit freely, then select it
# in the config-selection cell below to use it.
CFG_CUSTOM: ConfigHolder = ConfigHolder(
    name="custom",
    dataset_cfg=MazeDatasetConfig(
        name="custom-dataset",
        grid_n=6,
        n_mazes=10_000,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    ),
    model_cfg=BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=8,
        d_head=4,
        n_layers=2,
    ),
    train_cfg=TrainConfig(
        name="custom-train",
        optimizer=torch.optim.RMSprop,
        optimizer_kwargs={"lr": 1e-4},
        batch_size=16,
        dataloader_cfg={
            "shuffle": True,
            "num_workers": 0,
            "drop_last": False,
        },
        # presumably: how many times each action happens over the whole run
        # (the training log converts these into per-batch intervals) -- TODO confirm
        intervals_count={
            "print_loss": 100,
            "checkpoint": 5,
            "eval_fast": 10,
            "eval_slow": 5,
        },
    ),
)

# "Hallway" config: 3M grid_n=7 DFS mazes generated with `do_forks=False`
# (presumably fork-free, corridor-like mazes -- confirm in maze-dataset docs),
# trained on a 6-layer, d_model=128 model with AdamW.
# NOTE(review): the model/train sub-config names are copied from CFG_CUSTOM;
# consider giving them hallway-specific names.
CFG_HALLWAY: ConfigHolder = ConfigHolder(
    name="hallway_v3",
    dataset_cfg=MazeDatasetConfig(
        name="hallway",
        grid_n=7,
        n_mazes=3_000_000,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
        maze_ctor_kwargs={"do_forks": False},
        # attach generation metadata to the dataset
        applied_filters=[
            {"name": "collect_generation_meta", "args": (), "kwargs": {}},
        ],
        seq_len_max=256,
    ),
    model_cfg=BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=128,
        d_head=32,
        n_layers=6,
    ),
    train_cfg=TrainConfig(
        name="custom-train",
        optimizer=torch.optim.AdamW,
        optimizer_kwargs={"lr": 1e-4},
        batch_size=32,
        dataloader_cfg={
            "shuffle": False,
            "num_workers": 8,
            "drop_last": False,
        },
        intervals_count={
            "print_loss": 100,
            "checkpoint": 20,
            "eval_fast": 100,
            "eval_slow": 50,
        },
        # an int here: presumably the number of samples split off the training
        # set for validation -- TODO confirm against TrainConfig
        validation_dataset_cfg=100,
    ),
)

list(MAZE_DATASET_CONFIGS.keys()) = ['test-g3-n5-a_dfs-h73257', 'demo_small-g3-n100-a_dfs-h44636', 'demo-g6-n10K-a_dfs-h50618']


In [4]:
# Preset configs assembled from named parts, given as (dataset, model, train).

# full-size demo: trains a "real" demo model
CFG_DEMO: ConfigHolder = ConfigHolder.get_config_multisource(
    cfg_names=(
        "demo-g6-n10K-a_dfs-h50618",  # dataset
        "tiny-v1",  # model
        "sweep-v1",  # training
    ),
)

# much smaller variant, for quick testing
CFG_TEST: ConfigHolder = ConfigHolder.get_config_multisource(
    cfg_names=(
        "demo_small-g3-n100-a_dfs-h44636",  # dataset
        "nano-v1",  # model
        "test-v1",  # training
    ),
)

In [5]:
# Pick the config for this run by changing the key below:
#   "demo"    -> train the real demo model
#   "test"    -> small, fast variant for testing
#   "custom"  -> the hand-written CFG_CUSTOM
#   "hallway" -> the hand-written CFG_HALLWAY
_AVAILABLE_CFGS: typing.Dict[str, ConfigHolder] = {
    "demo": CFG_DEMO,
    "test": CFG_TEST,
    "custom": CFG_CUSTOM,
    "hallway": CFG_HALLWAY,
}
CFG: ConfigHolder = _AVAILABLE_CFGS["demo"]

In [6]:
# pretty-print the selected config's summary as indented JSON
_cfg_summary_json: str = json.dumps(CFG.summary(), indent=2)
print(_cfg_summary_json)

{
  "name": "multsrc_demo-g6-n10K-a_dfs-h50618_tiny-v1_sweep-v1",
  "dataset_cfg": {
    "name": "demo",
    "fname": "demo-g6-n10K-a_dfs-h50618",
    "sdc_hash": 109846810483272090382151445746698648148887661707374334566910004327535649450618,
    "seed": 42,
    "seq_len_min": 1,
    "seq_len_max": 512,
    "applied_filters": [],
    "grid_n": 6,
    "n_mazes": 10000,
    "maze_ctor_name": "gen_dfs",
    "maze_ctor_kwargs": {},
    "endpoint_kwargs": {}
  },
  "model_cfg": {
    "name": "tiny-v1",
    "act_fn": "gelu",
    "d_model": 32,
    "d_head": 16,
    "n_layers": 4,
    "positional_embedding_type": "standard",
    "weight_processing": {
      "are_layernorms_folded": false,
      "are_weights_processed": false
    },
    "n_heads": 2
  },
  "train_cfg": {
    "name": "sweep-v1",
    "optimizer": "AdamW",
    "optimizer_kwargs": {
      "lr": 0.0001
    },
    "batch_size": 64,
    "dataloader_cfg": {
      "shuffle": true,
      "num_workers": 8,
      "persistent_workers": tru

In [7]:
# Fetch the dataset described by CFG.dataset_cfg: a saved copy under PATH_DATA
# is loaded if present, otherwise the dataset is generated.
# This cell is optional -- `train_model` can obtain the dataset itself
# (e.g. with `do_generate_dataset=True`); loading it here just makes it
# available for inspection and lets us pass it in directly.
DATASET: MazeDataset = MazeDataset.from_config(
    CFG.dataset_cfg,
    local_base_path=PATH_DATA,
    verbose=True,
)

trying to get the dataset 'demo-g6-n10K-a_dfs-h50618'
loading dataset from ../data/demo-g6-n10K-a_dfs-h50618.zanj
load successful!
Got dataset demo with 10000 items. output.cfg.to_fname() = 'demo-g6-n10K-a_dfs-h53138'


In [8]:
# Train a model with the selected config on the already-loaded DATASET.
# Checkpoints are written under PATH_DATA (see the training log below).
# Demo runs are logged to the dedicated DEMO_NOTEBOOKS W&B project so they
# do not mix with the main research runs in UNDERSTANDING_SEARCH.
result: TrainingResult = train_model(
    base_path=PATH_DATA,
    cfg=CFG,
    wandb_project=WandbProject.DEMO_NOTEBOOKS,
    do_generate_dataset=False,  # DATASET was already obtained in the previous cell
    dataset_verbose=True,
    dataset=DATASET,
)

2024-08-21 12:21:40 ERROR Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


wandb: Currently logged in as: miv. Use `wandb login --relogin` to force relogin


2024-08-21 12:21:46 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'demo', 'seq_len_min': 1, 'seq_len_max': 512, 'seed': 42, 'applied_filters': [], 'grid_n': 6, 'n_mazes': 10000, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_dataset.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', '        # Arguments', '        - `grid_shape: Coord`: the shape of the grid', '        - `lattice_dim: int`: the dimension of the lattice', '          (default: `2`)', '        - `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**', '            (default: `None`)', '        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * acc


dataset has different config than cfg.dataset_cfg, but the only difference is in applied_filters, so using passed dataset. This is due to fast dataset loading collecting generation metadata for performance reasons



2024-08-21 12:21:46 INFO finished getting training dataset with 10000 samples
2024-08-21 12:21:46 INFO got validation dataset by splitting training dataset into 9950 train and 50 validation samples
2024-08-21 12:21:46 INFO Loaded 9950 sequences
2024-08-21 12:21:46 INFO Creating dataloader
2024-08-21 12:21:46 INFO finished dataloader, passing to train()
2024-08-21 12:21:46 INFO Initializing model



`clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884



Moving model to device:  cpu
2024-08-21 12:21:47 INFO Initializing optimizer
2024-08-21 12:21:50 INFO will train for 155 batches, evals_enabled=True, with intervals: {'print_loss': 1, 'checkpoint': 15, 'eval_fast': 7, 'eval_slow': 15}
2024-08-21 12:21:50 INFO Starting training
2024-08-21 12:22:36 INFO Running evals: eval_fast
2024-08-21 12:22:40 INFO Running evals: eval_slow
2024-08-21 12:22:43 INFO iteration 0/155: loss=8.466
2024-08-21 12:22:43 INFO Saving model checkpoint to ../data/multsrc_demo-g6-n10K-a_dfs-h50618_tiny-v1_sweep-v1_2024-08-21-12-21-39/checkpoints/model.iter_0.zanj
2024-08-21 12:22:46 INFO iteration 1/155: loss=8.426
2024-08-21 12:22:47 INFO iteration 2/155: loss=8.392
2024-08-21 12:22:49 INFO iteration 3/155: loss=8.333
2024-08-21 12:22:51 INFO iteration 4/155: loss=8.290
2024-08-21 12:22:53 INFO iteration 5/155: loss=8.255
2024-08-21 12:22:55 INFO iteration 6/155: loss=8.237
2024-08-21 12:22:57 INFO Running evals: eval_fast
2024-08-21 12:23:00 INFO iteration 7/155


fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN
prediction = array([[24, 25]])



2024-08-21 12:24:34 INFO iteration 49/155: loss=7.220
2024-08-21 12:24:36 INFO iteration 50/155: loss=7.215
2024-08-21 12:24:37 INFO iteration 51/155: loss=7.158
2024-08-21 12:24:39 INFO iteration 52/155: loss=7.115
2024-08-21 12:24:42 INFO iteration 53/155: loss=7.113
2024-08-21 12:24:44 INFO iteration 54/155: loss=7.084
2024-08-21 12:24:46 INFO iteration 55/155: loss=7.102
2024-08-21 12:24:49 INFO Running evals: eval_fast
2024-08-21 12:24:54 INFO iteration 56/155: loss=7.065
2024-08-21 12:24:56 INFO iteration 57/155: loss=7.040
2024-08-21 12:24:59 INFO iteration 58/155: loss=7.064
2024-08-21 12:25:01 INFO iteration 59/155: loss=7.026
2024-08-21 12:25:03 INFO Running evals: eval_slow
2024-08-21 12:25:07 INFO iteration 60/155: loss=7.031
2024-08-21 12:25:07 INFO Saving model checkpoint to ../data/multsrc_demo-g6-n10K-a_dfs-h50618_tiny-v1_sweep-v1_2024-08-21-12-21-39/checkpoints/model.iter_60.zanj
2024-08-21 12:25:12 INFO iteration 61/155: loss=6.986
2024-08-21 12:25:16 INFO iteration 6