In [1]:
# Generic
import typing
import os
from pathlib import Path
import typing
import json

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch

# Our Code
from maze_transformer.utils.notebook_utils import configure_notebook
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer, BaseGPTConfig, TrainConfig
from maze_transformer.dataset.maze_dataset import MazeDataset, MazeDatasetConfig
from maze_transformer.dataset.maze_dataset_configs import MAZE_DATASET_CONFIGS
from maze_transformer.generation.generators import LatticeMazeGenerators
from maze_transformer.training.train_model import TrainingResult, train_model
from maze_transformer.training.wandb_logger import WandbProject
# from maze_transformer.evaluation.baseline_models import RandomBaseline
# from maze_transformer.evaluation.path_evals import PathEvals, PathEvalFunction
# from maze_transformer.evaluation.eval_model import evaluate_model

# from maze_transformer.utils.utils import set_reproducibility, chunks, get_checkpoint_paths_for_run

In [2]:
# One-time notebook setup. configure_notebook returns the torch device used
# below (the output shows `device(type='cpu')` on this machine). The `seed`
# argument presumably seeds the RNGs and `dark_mode` presumably styles the
# plots -- TODO confirm against maze_transformer.utils.notebook_utils.
DEVICE = configure_notebook(
    dark_mode=True,
    seed=42,
)
print(f"{DEVICE = }")

DEVICE = device(type='cpu')


In [3]:
# show which preset dataset configs are available by name
print(f"{list(MAZE_DATASET_CONFIGS.keys()) = }")

# If you want to specify a custom config, you can do so here.
# Defining it does not use it: to train with it, select it in the
# config-selection cell further down (where `CFG` is assigned).
CFG_CUSTOM: ConfigHolder = ConfigHolder(
    name="custom",
    dataset_cfg=MazeDatasetConfig(
        name="custom-dataset",
        grid_n=6,  # side length of the square maze grid (grid_n=3 gives grid_shape [3, 3])
        n_mazes=10000,
        maze_ctor=LatticeMazeGenerators.gen_dfs,  # depth-first-search lattice maze generator
    ),
    model_cfg=BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=8,
        d_head=4,
        n_layers=2,
    ),
    train_cfg=TrainConfig(
        name="custom-train",
        # the optimizer class is passed uninstantiated; its kwargs are given separately
        optimizer=torch.optim.RMSprop,
        optimizer_kwargs=dict(lr=0.0001),
        batch_size=16,
        dataloader_cfg=dict(
            shuffle=True,
            num_workers=0,
            drop_last=False,
        ),
        print_loss_interval=100,
        checkpoint_interval=1000,
    ),
)

list(MAZE_DATASET_CONFIGS.keys()) = ['test-g3-n5-a_dfs-h82387', 'demo_small-g3-n100-a_dfs-h69818', 'demo-g6-n10K-a_dfs-h92077']


In [4]:
# Full-size config for training a "real" demo model.
# `cfg_names` combines named presets. Judging by the combined config name and
# the summary this notebook prints, the order is (dataset config, model config,
# train config). The third slot being the train config is an inference -- TODO confirm.
CFG_DEMO: ConfigHolder = ConfigHolder.get_config_multisource(
    cfg_names=(
        "demo-g6-n10K-a_dfs-h92077",
        "tuned-v1",
        "sweep-v1",
    ),
)

# Smaller config (3x3 grid, 100 mazes, nano model), for quick testing.
CFG_TEST: ConfigHolder = ConfigHolder.get_config_multisource(
    cfg_names=(
        "demo_small-g3-n100-a_dfs-h69818",
        "nano-v1",
        "test-v1",
    ),
)

In [5]:
# Here is where to specify which config to actually use, by key:
#   "test"   -> CFG_TEST   (small dataset + nano model, for quick testing)
#   "demo"   -> CFG_DEMO   (for training a "real" demo model)
#   "custom" -> CFG_CUSTOM (the hand-written config defined earlier)
CFG_CHOICE: str = "test"
CFG: ConfigHolder = {
    "test": CFG_TEST,
    "demo": CFG_DEMO,
    "custom": CFG_CUSTOM,
}[CFG_CHOICE]

In [6]:
# Pretty-print the selected config's summary as 2-space-indented JSON.
# A JSONEncoder with indent=2 and otherwise-default options yields the same
# text as json.dumps(..., indent=2).
print(json.JSONEncoder(indent=2).encode(CFG.summary()))

{
  "name": "multsrc_demo_small-g3-n100-a_dfs-h69818_nano-v1_test-v1",
  "dataset_cfg": {
    "name": "demo_small",
    "fname": "demo_small-g3-n100-a_dfs-h69818",
    "sdc_hash": 112638380443659042591091270151231815740307954105906328170842675288284902869818,
    "seed": 42,
    "seq_len_min": 1,
    "seq_len_max": 512,
    "padding_token_index": 10,
    "token_arr_joined": "<ADJLIST_START> <ADJLIST_END> <TARGET_START> <TARGET_END> <ORIGIN_START> <ORIGIN_END> <PATH_START> <PATH_END> <--> ; <PADDING> (0,0) (0,1) (1,0) (1,1) (0,2) (2,0) (1,2) (2,1) (2,2)",
    "applied_filters": [],
    "grid_n": 3,
    "grid_shape": [
      3,
      3
    ],
    "n_mazes": 100,
    "maze_ctor_name": "gen_dfs",
    "maze_ctor_kwargs": {}
  },
  "model_cfg": {
    "__format__": "BaseGPTConfig(SerializableDataclass)",
    "name": "nano-v1",
    "act_fn": "gelu",
    "d_model": 8,
    "d_head": 4,
    "n_layers": 2,
    "weight_processing": {
      "are_layernorms_folded": false,
      "are_weights_processe

In [8]:
# Train a model for the selected CFG. Per the log below, checkpoints go to
# <base_path>/<cfg name>_<timestamp>/checkpoints/model.iter_<i>.zanj and the
# final model to <base_path>/<cfg name>_<timestamp>/model.final.zanj.
# The run is logged to the DEMO_NOTEBOOKS wandb project.
result: TrainingResult = train_model(
    base_path="../data/",
    cfg=CFG,
    wandb_project=WandbProject.DEMO_NOTEBOOKS,
    # mazes were generated during this run (see "generating & solving mazes");
    # presumably False would load an existing dataset instead -- TODO confirm
    do_generate_dataset=True,
)

VBox(children=(Label(value='0.125 MB of 0.125 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,██▆▄▄▂▁

0,1
device,cpu
loss,3.21348
model.device,cpu
model_n_params,1536
n_batches,7


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

2023-05-20 02:07:24 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'demo_small', 'dtype': 'torch.int16', 'seq_len_min': 1, 'seq_len_max': 512, 'seed': 42, 'applied_filters': [], 'grid_n': 3, 'n_mazes': 100, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_transformer.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', '        # Arguments', '        - `grid_shape: Coord`: the shape of the grid', '        - `lattice_dim: int`: the dimension of the lattice', '          (default: `2`)', '        - `n_accessible_cells: int | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid.', '            (default: `None`)', '        - `max_tree_depth: int | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_accessible_cells`.', '            (default: `None`)', '       

generating & solving mazes: 100%|██████████| 100/100 [00:00<00:00, 1265.15maze/s]

2023-05-20 02:07:25 INFO finished getting dataset
2023-05-20 02:07:25 INFO Loaded 100 sequences
2023-05-20 02:07:25 INFO Creating dataloader
2023-05-20 02:07:25 INFO finished dataloader, passing to train()
2023-05-20 02:07:25 INFO Initializing model
2023-05-20 02:07:25 INFO Initializing optimizer





2023-05-20 02:07:25 INFO Starting training
2023-05-20 02:07:25 INFO will train for 7 batches, checkpoint_interval_iters = 6, loss_interval_iters = 1
2023-05-20 02:07:25 INFO iteration 0/7: loss=3.258
2023-05-20 02:07:25 INFO Saving model to ../data/multsrc_demo_small-g3-n100-a_dfs-h69818_nano-v1_test-v1_2023-05-20-02-07-18/checkpoints/model.iter_0.zanj
2023-05-20 02:07:25 INFO iteration 1/7: loss=3.269
2023-05-20 02:07:25 INFO iteration 2/7: loss=3.271
2023-05-20 02:07:25 INFO iteration 3/7: loss=3.227
2023-05-20 02:07:25 INFO iteration 4/7: loss=3.212
2023-05-20 02:07:25 INFO iteration 5/7: loss=3.186
2023-05-20 02:07:25 INFO iteration 6/7: loss=3.152
2023-05-20 02:07:25 INFO Saving model to ../data/multsrc_demo_small-g3-n100-a_dfs-h69818_nano-v1_test-v1_2023-05-20-02-07-18/checkpoints/model.iter_6.zanj
2023-05-20 02:07:25 INFO Saving final model to ../data/multsrc_demo_small-g3-n100-a_dfs-h69818_nano-v1_test-v1_2023-05-20-02-07-18/model.final.zanj
2023-05-20 02:07:26 INFO Done!
