In [None]:
# Generic
import typing
import os
from pathlib import Path
import typing
import json

# Plotting
import matplotlib.pyplot as plt

# Numerical Computing
import numpy as np
import torch

# muutils
from zanj.zanj import ZANJ, ZANJ_GLOBAL_DEFAULTS

# Our Code
from muutils.mlutils import configure_notebook
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer, BaseGPTConfig, TrainConfig
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from maze_dataset.generation import LatticeMazeGenerators
from maze_transformer.training.train_model import TrainingResult, train_model
from maze_transformer.training.wandb_logger import WandbProject


In [None]:
# filesystem locations; relative paths resolve against the current working
# directory (normally the folder this notebook lives in)
PATH_EXAMPLES: Path = Path("../examples/")
PATH_DATA: Path = Path("../data/")

# ZANJ global defaults: array and list thresholds are both set to 1024
# (presumably the size above which items are stored outside the main file —
# TODO confirm against the zanj documentation)
ZANJ_GLOBAL_DEFAULTS.external_array_threshold = ZANJ_GLOBAL_DEFAULTS.external_list_threshold = 1024

# reproducibility (fixed seed) and compute-device selection
DEVICE = configure_notebook(seed=42, dark_mode=True)
print(f"{DEVICE = }")

In [None]:
print(f"{list(MAZE_DATASET_CONFIGS.keys()) = }")


def _hallway_config(
    name: str,
    grid_n: int,
    n_mazes: int,
    d_model: int,
    d_head: int,
    n_layers: int,
    lr: float,
    batch_size: int,
    print_loss_interval: int,
    dataloader_cfg: dict,
    checkpoint_interval: int = 1000,
) -> ConfigHolder:
    """Build a `ConfigHolder` for a fork-free DFS maze ("hallway") training run.

    `CFG_CUSTOM` and `CFG_TEST` differ only in scale. They share this builder
    instead of being two copy-pasted config literals that can drift apart.

    # Parameters:
     - `name`: name of the resulting `ConfigHolder`
     - `grid_n`, `n_mazes`: maze grid size and number of mazes in the dataset
     - `d_model`, `d_head`, `n_layers`: `BaseGPTConfig` model dimensions
     - `lr`: learning rate passed to `torch.optim.AdamW`
     - `batch_size`, `print_loss_interval`, `checkpoint_interval`: `TrainConfig` loop settings
     - `dataloader_cfg`: dict stored as `TrainConfig.dataloader_cfg`

    # Returns:
     - a new `ConfigHolder`. The sub-config names are always
       "custom-hallway" / "custom-model" / "custom-train".
    """
    return ConfigHolder(
        name=name,
        dataset_cfg=MazeDatasetConfig(
            name="custom-hallway",
            grid_n=grid_n,
            n_mazes=n_mazes,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
            # do_forks=False presumably yields unbranched "hallway" mazes — TODO confirm
            maze_ctor_kwargs=dict(
                do_forks=False,
            ),
        ),
        model_cfg=BaseGPTConfig(
            name="custom-model",
            act_fn="gelu",
            d_model=d_model,
            d_head=d_head,
            n_layers=n_layers,
        ),
        train_cfg=TrainConfig(
            name="custom-train",
            optimizer=torch.optim.AdamW,
            optimizer_kwargs=dict(lr=lr),
            batch_size=batch_size,
            dataloader_cfg=dataloader_cfg,
            print_loss_interval=print_loss_interval,
            checkpoint_interval=checkpoint_interval,
        ),
    )


# if you want to specify a custom config, you can do so here
CFG_CUSTOM: ConfigHolder = _hallway_config(
    name="hallway-medium",
    grid_n=8,
    n_mazes=1_000_000,
    d_model=128,
    d_head=32,
    n_layers=6,
    lr=0.00001,
    batch_size=32,
    print_loss_interval=100,
    dataloader_cfg=dict(
        shuffle=True,
        num_workers=4,
        drop_last=False,
    ),
)

# tiny config used for CI testing
CFG_TEST: ConfigHolder = _hallway_config(
    name="hallway-nano",
    grid_n=3,
    n_mazes=8,
    d_model=8,
    d_head=2,
    n_layers=2,
    lr=0.0001,
    batch_size=4,
    print_loss_interval=4,
    dataloader_cfg=dict(
        shuffle=True,
        drop_last=False,
    ),
)

In [None]:
# Choose the config this run uses:
#  - CFG_TEST   : tiny config for CI testing (selected here)
#  - CFG_CUSTOM : change to this to train a "real" model
CFG: ConfigHolder = CFG_TEST

In [None]:
# dump the selected config so the run's settings are recorded in the notebook output
print(
    json.dumps(
        CFG.summary(),
        indent=2,
    )
)

In [None]:
# Build just the dataset, separately from training.
# This step can be skipped if you set `do_generate_dataset=True` when calling
# `train_model`, or if the dataset in question already exists.
# NOTE(review): `load_local=False` is passed, so a previously saved local copy is
# presumably NOT reused and the dataset is built fresh — TODO confirm against
# `MazeDataset.from_config`.
# The `collect_generation_meta` filter is applied to the loaded dataset before it is
# bound to DATASET.
DATASET: MazeDataset = (
    MazeDataset.from_config(
        CFG.dataset_cfg,
        local_base_path=PATH_DATA,
        load_local=False,
        verbose=True,
    )
    .filter_by.collect_generation_meta()
)

In [None]:
# Persist the filtered dataset under PATH_DATA, using a filename derived from its config.
DATASET.save(
    PATH_DATA / DATASET.cfg.to_fname(),
)
# Point CFG at the dataset's (post-filter) config. This is presumably so that
# `train_model`, called with `do_generate_dataset=False`, finds this exact
# saved dataset — TODO confirm.
CFG.dataset_cfg = DATASET.cfg

In [None]:
# Train the model. The dataset was already built and saved above, so dataset
# generation is turned off here.
# NOTE: WandbProject.INTEGRATION_TESTS is for CI testing — change this to
# WandbProject.DEMO_NOTEBOOKS for a real run.
result: TrainingResult = train_model(
    cfg=CFG,
    base_path=PATH_DATA,
    wandb_project=WandbProject.INTEGRATION_TESTS,
    do_generate_dataset=False,
    dataset_verbose=True,
)