In [1]:
import math
from pathlib import Path

import wandb
import torch
from muutils.misc import shorten_numerical_to_str

from maze_transformer.training.config import BaseGPTConfig, TrainConfig, ConfigHolder
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_transformer.training.config import ZanjHookedTransformer



In [2]:
# Download wandb artifact


def get_step(artifact):
    """Return the training step encoded in an artifact's ``step=<n>`` alias.

    Args:
        artifact: Any object exposing an ``aliases`` iterable of strings.

    Returns:
        The integer following ``step=`` when exactly one alias starts with
        ``"step="``; ``-1`` when there is no such alias or more than one.

    Raises:
        ValueError: If the single matching alias does not end in an integer.
    """
    # Collect the text after the last "=" of every alias tagged with "step=".
    step_values = []
    for alias in artifact.aliases:
        if alias.startswith("step="):
            step_values.append(alias.split('=')[-1])

    # Zero or several step aliases are both treated as "unknown step".
    if len(step_values) == 1:
        return int(step_values[0])
    return -1

def load_model(config_holder, model_path, fold_ln=True):
    """Create a model from ``config_holder`` and fill it with saved weights.

    Args:
        config_holder: Object whose ``create_model()`` returns the model to
            populate (a ``ConfigHolder`` in this notebook).
        model_path: Path of the saved state dict, read with ``torch.load``.
        fold_ln: Forwarded to ``process_weights_`` after the state dict has
            been loaded.

    Returns:
        The model, switched to eval mode.
    """
    net = config_holder.create_model()
    weights = torch.load(model_path, map_location=net.cfg.device)

    # The initial load always runs with fold_ln=False; the caller's fold_ln
    # choice is only applied by the separate process_weights_ call below.
    load_options = dict(
        fold_ln=False,
        center_writing_weights=True,
        center_unembed=True,
        refactor_factored_attn_matrices=True,
    )
    net.load_and_process_state_dict(weights, **load_options)

    net.process_weights_(fold_ln=fold_ln)
    net.setup()  # Re-attach layernorm hooks by calling setup
    net.eval()
    return net

def load_wandb_run(
    project="aisc-search/alex",
    run_id="sa973hyn",
    output_path='./downloaded_models',
    checkpoint=None,
):
    """Download a checkpoint artifact of a wandb run and load it as a model.

    Args:
        project: wandb project path; a trailing ``/`` is stripped and it is
            joined with ``run_id`` to form the run path.
        run_id: Id of the wandb run. The ``:latest`` artifact is looked up
            under this same ``project/run_id`` path.
        output_path: Root directory under which the artifact is downloaded.
        checkpoint: Training step to load, matched against each logged
            artifact's ``step=<n>`` alias (see ``get_step``). ``None`` loads
            the artifact tagged ``latest``.

    Returns:
        ``(model, cfg)`` where ``cfg`` is the ``ConfigHolder`` reconstructed
        from the run configuration. Returns ``None`` (after printing the
        available checkpoints) when ``checkpoint`` is given but does not match
        exactly one artifact, so callers that unpack the result should only do
        so once the requested step is known to exist.

    Raises:
        AssertionError: If ``d_vocab - 11`` from the run config is not a
            perfect square.
        FileNotFoundError: If the expected ``model.iter_<step>.pt`` file is not
            present after the artifact has been downloaded.
    """
    api = wandb.Api()

    artifact_name = f"{project.rstrip('/')}/{run_id}"

    run = api.run(artifact_name)
    wandb_cfg = run.config  # Get run configuration

    # -- Get / Match checkpoint --
    if checkpoint is not None:
        # Query the run's artifacts once. Every logged artifact is a candidate:
        # a previous `artifact.type == 'model'` filter was immediately
        # overwritten by this unfiltered list, so it never had any effect and
        # only cost a second API query.
        available_checkpoints = list(run.logged_artifacts())
        matches = [
            candidate
            for candidate in available_checkpoints
            if get_step(candidate) == checkpoint
        ]
        if len(matches) != 1:
            print(f"Could not find checkpoint {checkpoint} in {artifact_name}")
            print("Available checkpoints:")
            for candidate in available_checkpoints:
                print(candidate.name, '| Steps: ', get_step(candidate))
            return

        artifact = matches[0]
        print('Loading checkpoint', checkpoint)
    else:
        # Get latest checkpoint
        print('Loading latest checkpoint')
        artifact_name = f"{artifact_name}:latest"
        artifact = api.artifact(artifact_name)
        checkpoint = get_step(artifact)

    # -- Initialize configurations --
    # Model cfg
    model_properties = {k: wandb_cfg[k] for k in ['act_fn', 'd_model', 'd_head', 'n_layers']}
    model_cfg = BaseGPTConfig(
        name=f"model {run_id}",
        weight_processing={'are_layernorms_folded': True, 'are_weights_processed': True},
        **model_properties,
    )

    # Dataset cfg
    # The grid size is recovered from the vocabulary size: d_vocab is expected
    # to be grid_n**2 + 11.
    # NOTE(review): 11 is presumably the number of non-coordinate tokens — confirm.
    d_vocab = wandb_cfg['d_vocab']
    grid_n = math.sqrt(d_vocab - 11)  #! Jank
    assert grid_n == int(grid_n), f"d_vocab ({d_vocab}) must be a perfect square + 11"  # check integer
    ds_cfg = MazeDatasetConfig(
        name=wandb_cfg.get('dataset_name', 'no_name'),
        grid_n=int(grid_n),
        n_mazes=1,
    )

    cfg = ConfigHolder(
        model_cfg=model_cfg,
        dataset_cfg=ds_cfg,
        train_cfg=TrainConfig(name=f"artifact '{artifact_name}', checkpoint '{checkpoint}'"),
    )

    # Artifact names carry a ":<version>" suffix; only the base name is used
    # for the local directory.
    artifact_dir = Path(output_path) / artifact.name.split(":")[0]
    download_path = artifact_dir / f'model.iter_{checkpoint}.pt'
    #! Account for final checkpoint
    if not download_path.exists():
        artifact.download(root=download_path.parent)
        # The file name is derived from the step alias, not read from the
        # artifact, so verify it before reporting success.
        if not download_path.exists():
            found = sorted(p.name for p in artifact_dir.glob('*')) if artifact_dir.exists() else []
            raise FileNotFoundError(
                f"Expected {download_path} after downloading artifact '{artifact.name}', "
                f"but the directory contains: {found}"
            )
        print(f'Downloaded model to {download_path}')
    else:
        print(f'Model already downloaded to {download_path}')

    print('Loading model')
    model = load_model(cfg, download_path, fold_ln=True)
    return model, cfg


In [3]:
# Which wandb run to fetch; checkpoint=None makes load_wandb_run pick the
# artifact tagged "latest".
# NOTE(review): the saved outputs of this notebook mention run "sa973hyn",
# not "jerpkipj" — they look stale; re-run to refresh them.
MODEL_KWARGS: dict = {
    "project": "aisc-search/alex",
    "run_id": "jerpkipj",
    "checkpoint": None,
}
MODEL, CFG = load_wandb_run(**MODEL_KWARGS)
print(f"{type(MODEL) = } {type(CFG) = }")

Loading latest checkpoint


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Downloaded model to downloaded_models\sa973hyn\model.iter_29000000.pt
Loading model
type(MODEL) = <class 'transformer_lens.HookedTransformer.HookedTransformer'> type(CFG) = <class 'maze_transformer.training.config.ConfigHolder'>


In [4]:
# Build a ZanjHookedTransformer from the same config and copy the loaded
# weights into it, so the model can be written out with its `.save` method.
MODEL_ZANJ: ZanjHookedTransformer = ZanjHookedTransformer(CFG)
MODEL_ZANJ.load_state_dict(MODEL.state_dict())
# Record provenance alongside the weights: the arguments passed to
# load_wandb_run and the generated train-config name (artifact + checkpoint).
MODEL_ZANJ.training_records = {
    "load_wandb_run_kwargs": MODEL_KWARGS,
    "train_cfg.name": CFG.train_cfg.name,
}
# Report the parameter count in shortened form and echo the stored provenance.
print(f"loaded model with {shorten_numerical_to_str(MODEL_ZANJ.num_params())} parameters")
print(MODEL_ZANJ.training_records)

loaded model with 9.6M parameters
{'load_wandb_run_kwargs': {'project': 'aisc-search/alex', 'run_id': 'sa973hyn', 'checkpoint': None}, 'train_cfg.name': "artifact 'aisc-search/alex/sa973hyn:latest', checkpoint '29000000'"}


In [5]:
# Write the model to a .zanj file under the relative "examples/" directory,
# named after the wandb run id used to fetch it.
MODEL_ZANJ.save(f"examples/wandb.{MODEL_KWARGS['run_id']}.zanj")