# Load and Prompt a Checkpoint

This notebook demonstrates how to reconstruct a model from a checkpoint file.
To correctly load the vocabulary and tokenizer used to train this model, this notebook assumes that the corresponding maze datasets have been loaded into MongoDB, as outlined in [TokenDatasets.ipynb](TokenDatasets.ipynb).
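If the dataset is missing from MongoDB, the id lists used further down come back empty and fetching a test prompt fails. The sketch below checks for this up front. It assumes, based only on the debug logs printed later in this notebook, that token datasets live in the MongoDB database `tokenSeqDB` and that each dataset keeps its split ids in collections named `<name>.meta.train` and `<name>.meta.test`. The helpers `split_sizes` and `check_dataset` are hypothetical and not part of the searchformer code base; `Collection.count_documents` is a standard pymongo method.

```python
# Hypothetical helpers; the database and collection naming is an assumption
# taken from the debug logs in this notebook, not from the searchformer API.
from typing import Any, Dict


def split_sizes(db: Any, dataset_name: str) -> Dict[str, int]:
    """Count the stored ids for the train and test splits of a dataset."""
    sizes = {}
    for split in ("train", "test"):
        # pymongo: db[name] returns a Collection; count_documents needs a filter.
        collection = db[f"{dataset_name}.meta.{split}"]
        sizes[split] = collection.count_documents({})
    return sizes


def check_dataset(db: Any, dataset_name: str) -> None:
    """Raise a descriptive error if any split of the dataset has no ids."""
    sizes = split_sizes(db, dataset_name)
    empty = [split for split, count in sizes.items() if count == 0]
    if empty:
        raise RuntimeError(
            f"Dataset {dataset_name!r} has no ids for split(s) {empty}; "
            "import it into MongoDB as outlined in TokenDatasets.ipynb."
        )


# Example usage (requires a running MongoDB instance and a loaded checkpoint):
# from pymongo import MongoClient
# db = MongoClient("mongodb://localhost:27017")["tokenSeqDB"]
# check_dataset(db, ckpt.config_obj.data.train_name)
```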

A single checkpoint holds data about
* Hyper-parameters used for training (excluding the token vocabulary, which is stored together with the corresponding token dataset).
* All model parameters
* Optimizer state
* Number of gradient steps at which the model was saved

This notebook only outlines how a model can be reconstructed from a checkpoint file.
To use any of the other workflows included in this code base, the checkpoint files must first be imported into MongoDB.
In this notebook we focus on the checkpoint resulting from the run `maze-sweep-rep-nondet-small-0` and assume that the file `maze-sweep-rep-nondet-small-0.ckpt` is present in the project's root directory.
This checkpoint file can be downloaded [here](https://dl.fbaipublicfiles.com/searchformer/ckptDB/maze-sweep-rep-nondet-small-0.ckpt).
The file [`checkpoint_index.csv`](../doc.checkpoint_index.csv) lists all released checkpoints and their download links.
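Fetching a checkpoint can also be scripted. The sketch below is a minimal example that assumes every released checkpoint follows the same URL pattern as the `maze-sweep-rep-nondet-small-0` link above; `checkpoint_index.csv` remains the authoritative list of links, and the helper names are hypothetical. It uses only the Python standard library (`urllib.request.urlretrieve`) and saves the file to the project root, which is `..` relative to this notebook.

```python
# Hypothetical helpers; the URL pattern is assumed from the single link above.
from pathlib import Path
from urllib.request import urlretrieve

BASE_URL = "https://dl.fbaipublicfiles.com/searchformer/ckptDB"


def checkpoint_url(checkpoint_id: str) -> str:
    """Build the (assumed) download URL for a released checkpoint."""
    return f"{BASE_URL}/{checkpoint_id}.ckpt"


def download_checkpoint(checkpoint_id: str, root: Path = Path("..")) -> Path:
    """Download a checkpoint into `root`, skipping files that already exist."""
    target = root / f"{checkpoint_id}.ckpt"
    if not target.exists():
        urlretrieve(checkpoint_url(checkpoint_id), str(target))
    return target


# Example usage:
# path = download_checkpoint("maze-sweep-rep-nondet-small-0")
# ckpt = Checkpoint.from_file(str(path))
```

Skipping existing files keeps repeated notebook runs from re-downloading a large checkpoint.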

First, import the required modules.

In [1]:
import sys 
sys.path.append("..")

import logging
import torch
from searchformer.train import Checkpoint
from searchformer.transformer import EncoderDecoderConfig, sample_probability
from searchformer.trace import DictTokenizer, TokenizedDataset


logging.basicConfig(
    level=logging.DEBUG,
    format="%(levelname)s - %(asctime)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

Load the checkpoint file and print the training configuration.

In [2]:
ckpt = Checkpoint.from_file("/home/rahim/sokoban-7722-m-trace-plan-100k-2-step-3.ckpt")
ckpt.config

{'_id': 'sokoban-7722-m-trace-plan-100k-2-step-3',
 'data': {'train_name': '65ca57d67f455f390d05bf33.improved',
  'test_name': 'sokoban.7-by-7-walls-2-boxes-2.with-box-40k',
  'batch_size': 1,
  'plan_only': False,
  'num_train_sequences': 100000,
  'num_test_sequences': 1000,
  'load_batch_size': 10000,
  'num_workers': 2,
  'min_reasoning_len': 0,
  'max_reasoning_len': 10000},
 'encoder': 'enc-m-s',
 'decoder': 'dec-m-s',
 'optimizer': {'lr': 7.5e-05,
  'lr_schedule': 'cosine',
  'train_steps': 10000,
  'warmup': 2000,
  'beta_0': 0.9,
  'beta_1': 0.99,
  'cycle_length': 1.0,
  'cosine_theta': 1.0,
  'lr_min_ratio': 0.1},
 'log_interval': 200,
 'eval_interval': 5000,
 'start_checkpoint': 'sokoban-7722-m-trace-plan-100k-2-step-2'}

First, load the tokenized dataset and construct a `DictTokenizer` object, which maps word token sequences to integer lists.
Next, construct an `EncoderDecoderConfig` object, which holds the hyperparameters of the network architecture.
From this object, build the encoder-decoder Transformer and load the trained model parameters (state dictionary) into it.
The example below runs inference on CPU for the smallest model and shortest sequences to reduce compute requirements.
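To make the token-to-integer mapping concrete, here is a minimal dictionary-based tokenizer. It is a hypothetical stand-in, not the searchformer `DictTokenizer`: reserving the first indices for `bos` and `eos` is an assumption, although this notebook does use `tokenizer.bos`, `tokenizer.eos`, `tokenizer.vocab_size`, `encode`, and `decode` in the same roles. The vocabulary size is what the notebook passes to `EncoderDecoderConfig.from_name`, presumably because it fixes the number of rows in the embedding table.

```python
# Hypothetical toy tokenizer; special-token handling is an assumption.
from typing import Dict, List


class ToyDictTokenizer:
    def __init__(self, vocabulary: List[str]):
        # Prepend special tokens that the vocabulary does not already contain.
        specials = [t for t in ("bos", "eos") if t not in vocabulary]
        self.vocab: List[str] = specials + list(vocabulary)
        self.token_to_idx: Dict[str, int] = {t: i for i, t in enumerate(self.vocab)}
        self.bos = self.token_to_idx["bos"]
        self.eos = self.token_to_idx["eos"]

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)

    def encode(self, tokens: List[str]) -> List[int]:
        """Map word tokens to integer ids; unknown tokens raise KeyError."""
        return [self.token_to_idx[t] for t in tokens]

    def decode(self, ids: List[int]) -> List[str]:
        """Map integer ids back to word tokens."""
        return [self.vocab[i] for i in ids]


# Example usage:
# tok = ToyDictTokenizer(["start", "goal", "wall", "c1"])
# tok.encode(["start", "c1", "goal"])  # ids into the extended vocabulary
```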

In [14]:
ckpt.checkpoint_id

'sokoban-7722-m-trace-plan-100k-2-step-3'

In [None]:
# Load vocabulary from tokenized dataset. This is needed to load the training token vocabulary and a test prompt.
tok_dataset = TokenizedDataset(ckpt.config_obj.data.train_name)


INFO - 2025-05-06 06:07:13 - root - Connecting to mongodb://localhost:27017/mongo

In [4]:
tok_dataset.vocabulary

['c27750',
 'c7041',
 'c3091',
 'c19843',
 'c23232',
 'c27000',
 'c26372',
 'c38151',
 'c889',
 'c9052',
 'c19086',
 'c24251',
 'c15645',
 'c29417',
 'c33091',
 'c38177',
 'c10393',
 'c574',
 'c12667',
 'c27631',
 'c31151',
 'c8021',
 'c788',
 'c17837',
 'c7407',
 'c26041',
 'c15377',
 'c16136',
 'c16354',
 'c11811',
 'c20833',
 'c1521',
 'c22551',
 'c20414',
 'c23650',
 'c3962',
 'c37828',
 'c25792',
 'c36849',
 'c31302',
 'c32992',
 'c22168',
 'c28627',
 'c2429',
 'c10824',
 'c3006',
 'c10663',
 'c15168',
 'c21917',
 'c17964',
 'c3035',
 'c20816',
 'c26665',
 'c14446',
 'c16451',
 'c35639',
 'c35211',
 'c25841',
 'c28779',
 'c6976',
 'c8227',
 'c30480',
 'c18714',
 'c33544',
 'c13813',
 'c3268',
 'c8614',
 'c15566',
 'c37641',
 'c13122',
 'c30282',
 'c15517',
 'c30426',
 'c37415',
 'c25150',
 'c35393',
 'c39169',
 'c13094',
 'c20755',
 'c34005',
 'c33484',
 'c23914',
 'c38427',
 'c23436',
 'c3234',
 'c36837',
 'c3120',
 'c37252',
 'c8206',
 'c3201',
 'c11418',
 'c10850',
 'c6205',
 '

In [5]:
# Load tokenizer mapping tokens to indices.
tokenizer = DictTokenizer(tok_dataset.vocabulary)

In [6]:
# Construct model config object.
enc_dec_config = EncoderDecoderConfig.from_name(
    enc_name=ckpt.config_obj.encoder,
    dec_name=ckpt.config_obj.decoder,
    vocab_size=tokenizer.vocab_size,
)
# Construct model from config.
model = enc_dec_config.construct_model()
# Loading trained weights into model.
model.load_state_dict(ckpt.model_only_state_dict)

DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block: n_heads=4, dim=384.
DEBUG - 2025-05-06 06:07:27 - root - Creating block

<All keys matched successfully>

The following cell prints the model architecture. The cells after it load the first test prompt and print it.

In [23]:
print(model)

EncoderDecoder(
  (encoder): Encoder(
    (embedding): Embedding(40016, 384)
    (rope): RoPE()
    (layers): ModuleList(
      (0-7): 8 x SelfAttentionBlock(
        (attention): Attention(
          (rope): RoPE()
          (wq): Linear(in_features=384, out_features=384, bias=False)
          (wk): Linear(in_features=384, out_features=384, bias=False)
          (wv): Linear(in_features=384, out_features=384, bias=False)
          (wo): Linear(in_features=384, out_features=384, bias=False)
        )
        (feed_forward): FeedForward(
          (w1): Linear(in_features=384, out_features=1024, bias=False)
          (w2): Linear(in_features=1024, out_features=384, bias=False)
          (w3): Linear(in_features=384, out_features=1024, bias=False)
        )
        (attention_norm): RMSLayerNorm(normalized_shape=torch.Size([384])eps=1e-05elementwise_affine=True)
        (ffn_norm): RMSLayerNorm(normalized_shape=torch.Size([384])eps=1e-05elementwise_affine=True)
      )
    )
    (output)

In [10]:
test_trace_id_list = tok_dataset.test_ids


DEBUG - 2025-05-06 06:08:24 - root - Loading all ids from Collection(Database(MongoClient(host=['localhost:27017'], document_class=dict, tz_aware=False, connect=True, sockettimeoutms=1800000, connecttimeoutms=1800000), 'tokenSeqDB'), '65ca57d67f455f390d05bf33.improved.meta.test') ...

In [None]:
ckpt.config_obj

{'train_name': '65ca57d67f455f390d05bf33.improved',
 'test_name': 'sokoban.7-by-7-walls-2-boxes-2.with-box-40k',
 'batch_size': 1,
 'plan_only': False,
 'num_train_sequences': 100000,
 'num_test_sequences': 1000,
 'load_batch_size': 10000,
 'num_workers': 2,
 'min_reasoning_len': 0,
 'max_reasoning_len': 10000}

In [17]:
tok_dataset.train_ids


DEBUG - 2025-05-06 06:13:49 - root - Loading all ids from Collection(Database(MongoClient(host=['localhost:27017'], document_class=dict, tz_aware=False, connect=True, sockettimeoutms=1800000, connecttimeoutms=1800000), 'tokenSeqDB'), '65ca57d67f455f390d05bf33.improved.meta.train') ...

[]

In [13]:
test_trace_id_list

[]

In [12]:
test_trace = next(iter(tok_dataset.test_it(test_trace_id_list[:1])))[0]


DEBUG - 2025-05-06 06:08:37 - root - Iterating over 0 ids.


StopIteration: 

In [None]:

prompt_str = " ".join(test_trace.prompt)
prompt_str = prompt_str.replace("start", "\n\tstart")
prompt_str = prompt_str.replace("wall", "\n\twall ")
prompt_str = prompt_str.replace("goal", "\n\tgoal ")
print("Prompt: " + prompt_str)
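The cell above pretty-prints the prompt with chained `str.replace` calls, which match substrings anywhere in the joined string. A token-level variant, sketched below under the assumption that the prompt is a sequence of word tokens (as `test_trace.prompt` appears to be), only starts a new line before whole tokens. The helper `format_tokens` is hypothetical.

```python
# Hypothetical helper: break lines before whole keyword tokens only.
from typing import Iterable, List


def format_tokens(tokens: Iterable[str], break_before: Iterable[str]) -> str:
    """Join tokens with spaces, starting a new indented line before each keyword."""
    keys = set(break_before)
    lines: List[List[str]] = [[]]
    for token in tokens:
        # Open a new line for a keyword unless the current line is still empty.
        if token in keys and lines[-1]:
            lines.append([])
        lines[-1].append(token)
    return "\n\t".join(" ".join(line) for line in lines)


# Example usage:
# print("Prompt: " + format_tokens(test_trace.prompt, ["start", "wall", "goal"]))
```

Because matching is done per token, a keyword that happens to appear inside a longer token is left untouched.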

The following cell maps the prompt to an integer tensor and generates a response sequence.
This response (an integer tensor) is then decoded into a token sequence and printed.
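Conceptually, generation is an autoregressive loop: starting from the BOS token, the decoder predicts a distribution over the next token, one token is sampled, and the loop stops at EOS or once `max_rollout_len` is reached. The sketch below illustrates that loop in plain Python. It is not the implementation of `model.rollout`; the `next_token_probs` callable is a hypothetical stand-in for the decoder conditioned on the encoded prompt, drawing tokens in proportion to their probability is an assumption about what `sample_probability` does, and whether the length cap counts the BOS token is also an assumption.

```python
# Illustrative sampling loop; not searchformer's rollout implementation.
import random
from typing import Callable, List, Sequence

# Hypothetical decoder stand-in: tokens so far -> probability per token id.
NextTokenProbs = Callable[[List[int]], Sequence[float]]


def sample_from_probs(probs: Sequence[float], rng: random.Random) -> int:
    """Draw a token id with probability proportional to probs[id]."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def rollout(
    next_token_probs: NextTokenProbs,
    bos_idx: int,
    eos_idx: int,
    max_rollout_len: int,
    seed: int = 0,
) -> List[int]:
    """Start from BOS and append sampled tokens until EOS or the length cap."""
    rng = random.Random(seed)
    seq = [bos_idx]
    while len(seq) < max_rollout_len:
        token = sample_from_probs(next_token_probs(seq), rng)
        seq.append(token)
        if token == eos_idx:
            break
    return seq
```

Sampling (rather than always taking the most likely token) is what lets repeated rollouts for the same prompt produce different responses.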

In [5]:
prompt_tokens = tokenizer.encode(test_trace.prompt)
prompt_tokens_tensor = torch.tensor(prompt_tokens, dtype=torch.long)
response = model.rollout(
    prompt=prompt_tokens_tensor,
    bos_idx=tokenizer.bos,
    eos_idx=tokenizer.eos,
    max_rollout_len=2000,
    sample_fn=sample_probability,
)
response_token_list = tokenizer.decode(response[0].tolist())
# Start a new indented line before selected keywords so the trace is readable.
response_str = (
    " ".join(response_token_list)
    .replace("bos ", "\n\tbos")
    .replace("eos", "\n\teos")
    .replace("create", "\n\tcreate")
    .replace("close", "\n\tclose ")
    .replace("plan ", "\n\tplan   ")
)
print("Response:" + response_str)

NameError: name 'tokenizer' is not defined