In [1]:
%load_ext autoreload
%autoreload 2

# Generic
import typing
import os
from pathlib import Path
import typing
import json

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch

# muutils
from zanj.zanj import ZANJ, ZANJ_GLOBAL_DEFAULTS

# Our Code
from muutils.nbutils.configure_notebook import configure_notebook
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer, BaseGPTConfig, TrainConfig
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from maze_dataset.generation import LatticeMazeGenerators
from maze_transformer.training.train_model import TrainingResult, train_model
from maze_transformer.training.wandb_logger import WandbProject


In [2]:
# --- ZANJ serialization defaults -------------------------------------------
# Both thresholds are set to 1024 on the shared ZANJ defaults object.
# (presumably the size above which arrays / lists are stored externally --
# confirm against the ZANJ docs)
ZANJ_GLOBAL_DEFAULTS.external_array_threshold = 1024
ZANJ_GLOBAL_DEFAULTS.external_list_threshold = 1024

# --- filesystem locations, relative to the notebook's working directory ----
PATH_EXAMPLES: Path = Path("..") / "examples"
PATH_DATA: Path = Path("..") / "data"

# --- reproducibility and device (currently disabled) -----------------------
# NOTE(review): with these lines commented out, this cell sets no seed and
# defines no DEVICE -- confirm that is intentional.
#DEVICE = configure_notebook(seed=42, dark_mode=True)
#print(f"{DEVICE = }")

In [3]:
# show which named dataset configs ship with maze_dataset
print(f"{list(MAZE_DATASET_CONFIGS.keys()) = }")

# Hand-written configuration: edit the values below to define your own
# dataset / model / training setup instead of using a named preset.
CFG_CUSTOM: ConfigHolder = ConfigHolder(
    name="custom",
    # dataset: 10K mazes with grid_n=6, built by the DFS generator
    dataset_cfg=MazeDatasetConfig(
        name="custom-dataset",
        grid_n=6,
        n_mazes=10_000,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    ),
    # model: very small transformer
    model_cfg=BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=8,
        d_head=4,
        n_layers=2,
    ),
    # training: RMSprop, batches of 16, single-process shuffled dataloader
    train_cfg=TrainConfig(
        name="custom-train",
        optimizer=torch.optim.RMSprop,
        optimizer_kwargs={"lr": 1e-4},
        batch_size=16,
        dataloader_cfg={
            "shuffle": True,
            "num_workers": 0,
            "drop_last": False,
        },
        intervals_count={
            "print_loss": 100,
            "checkpoint": 5,
            "eval_fast": 10,
            "eval_slow": 5,
        },
    ),
)

# "hallway" experiment: 3M DFS mazes with grid_n=7, generated with
# `do_forks=False`, and a larger model than CFG_CUSTOM.
# NOTE(review): the model/train configs reuse the names "custom-model" and
# "custom-train" from CFG_CUSTOM -- confirm that is intended.
CFG_HALLWAY: ConfigHolder = ConfigHolder(
    name="hallway_v3",
    dataset_cfg=MazeDatasetConfig(
        name="hallway",
        grid_n=7,
        n_mazes=3_000_000,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
        maze_ctor_kwargs={"do_forks": False},
        # a single filter entry, with no positional or keyword arguments
        applied_filters=[
            {"name": "collect_generation_meta", "args": (), "kwargs": {}},
        ],
        seq_len_max=256,
    ),
    model_cfg=BaseGPTConfig(
        name="custom-model",
        act_fn="gelu",
        d_model=128,
        d_head=32,
        n_layers=6,
    ),
    train_cfg=TrainConfig(
        name="custom-train",
        optimizer=torch.optim.AdamW,
        optimizer_kwargs={"lr": 1e-4},
        batch_size=32,
        # unshuffled, 8 worker processes
        dataloader_cfg={
            "shuffle": False,
            "num_workers": 8,
            "drop_last": False,
        },
        intervals_count={
            "print_loss": 100,
            "checkpoint": 20,
            "eval_fast": 100,
            "eval_slow": 50,
        },
        # presumably the number of mazes held out for validation -- TODO confirm
        validation_dataset_cfg=100,
    ),
)

list(MAZE_DATASET_CONFIGS.keys()) = ['test-g3-n5-a_dfs-h75556', 'demo_small-g3-n100-a_dfs-h88371', 'demo-g6-n10K-a_dfs-h30615']


In [4]:
# Named presets, assembled from (dataset, model, train) config names via
# `ConfigHolder.get_config_multisource`. The dataset names must be keys of
# MAZE_DATASET_CONFIGS (printed by the cell that defines CFG_CUSTOM).

# this is for training a "real" demo model
# Uses the 10K-maze, 6x6 demo dataset. The previous value,
# "test-g3-n5-a_dfs-h75556" (5 mazes on a 3x3 grid), is smaller than the
# dataset used by CFG_TEST below, which contradicts the purpose of this config.
CFG_DEMO: ConfigHolder = ConfigHolder.get_config_multisource(
    cfg_names=("demo-g6-n10K-a_dfs-h30615", "tiny-v1", "sweep-v1"),
)

# this is smaller, for testing
CFG_TEST: ConfigHolder = ConfigHolder.get_config_multisource(
    cfg_names=("demo_small-g3-n100-a_dfs-h88371", "nano-v1", "test-v1"),
)

In [21]:
# here is where to specify which config to actually use:
# bind CFG to one of the ConfigHolder objects defined above
# (CFG_CUSTOM, CFG_HALLWAY, CFG_DEMO or CFG_TEST); the cells below read CFG.
# NOTE(review): the saved outputs further down report name "custom", so they
# appear to come from a run with CFG_CUSTOM rather than CFG_TEST -- restart
# the kernel and re-run all cells to refresh them.
CFG: ConfigHolder = CFG_TEST

In [22]:
# dump the selected config's summary as 2-space-indented JSON
print(
    json.dumps(
        CFG.summary(),
        indent=2,
    )
)

{
  "name": "custom",
  "dataset_cfg": {
    "name": "custom-dataset",
    "fname": "custom-dataset-g6-n10K-a_dfs-h74621",
    "sdc_hash": 30263437081808734576530068202485318808735944160954955199116767321681194074621,
    "seed": 42,
    "seq_len_min": 1,
    "seq_len_max": 512,
    "applied_filters": [],
    "grid_n": 6,
    "grid_shape": [
      6,
      6
    ],
    "n_mazes": 10000,
    "maze_ctor_name": "gen_dfs",
    "maze_ctor_kwargs": {}
  },
  "model_cfg": {
    "name": "custom-model",
    "act_fn": "gelu",
    "d_model": 8,
    "d_head": 4,
    "n_layers": 2,
    "weight_processing": {
      "are_layernorms_folded": false,
      "are_weights_processed": false
    },
    "n_heads": 2
  },
  "train_cfg": {
    "name": "custom-train",
    "optimizer": "RMSprop",
    "optimizer_kwargs": {
      "lr": 0.0001
    },
    "batch_size": 16,
    "dataloader_cfg": {
      "shuffle": true,
      "num_workers": 0,
      "drop_last": false
    },
    "intervals": null,
    "intervals_count

In [33]:
# Obtain only the dataset for the selected config, generating it if needed.
# Per the original author's note, this step can be skipped if you set
# `do_generate_dataset=True` when calling `train_model`, or if the dataset in
# question already exists. (The `train_model` call in this notebook passes
# `dataset=DATASET`, so as written it needs this cell to have run.)

# load the dataset, using PATH_DATA as the local base path
DATASET: MazeDataset = MazeDataset.from_config(
    CFG.dataset_cfg,
    local_base_path=PATH_DATA,
    verbose=True,
)

trying to get the dataset 'custom-dataset-g6-n10K-a_dfs-h74621'
loading dataset from ../data/custom-dataset-g6-n10K-a_dfs-h74621.zanj
Got dataset custom-dataset with 10000 items. output.cfg.to_fname() = 'custom-dataset-g6-n10K-a_dfs-h74621'


In [24]:
# Train using the selected config and the dataset loaded above; outputs are
# written under PATH_DATA.
result: TrainingResult = train_model(
    # what to train and where to write
    cfg=CFG,
    base_path=PATH_DATA,
    # logging: swap in WandbProject.DEMO_NOTEBOOKS instead of None
    # (presumably enables wandb logging -- confirm)
    wandb_project=None,
    # dataset handling: reuse the already-loaded DATASET, do not regenerate
    dataset=DATASET,
    do_generate_dataset=False,
    dataset_verbose=True,
)

{
  "logger_cfg": {
    "output_dir": "../data/custom_2024-01-26-17-07-30",
    "cfg.name": "custom",
    "data_cfg.name": "custom-dataset",
    "train_cfg.name": "custom-train",
    "model_cfg.name": "custom-model",
    "cfg_summary": {
      "name": "custom",
      "dataset_cfg": {
        "name": "custom-dataset",
        "fname": "custom-dataset-g6-n10K-a_dfs-h74621",
        "sdc_hash": 30263437081808734576530068202485318808735944160954955199116767321681194074621,
        "seed": 42,
        "seq_len_min": 1,
        "seq_len_max": 512,
        "applied_filters": [],
        "grid_n": 6,
        "grid_shape": [
          6,
          6
        ],
        "n_mazes": 10000,
        "maze_ctor_name": "gen_dfs",
        "maze_ctor_kwargs": {}
      },
      "model_cfg": {
        "name": "custom-model",
        "act_fn": "gelu",
        "d_model": 8,
        "d_head": 4,
        "n_layers": 2,
        "weight_processing": {
          "are_layernorms_folded": false,
          "are_weig