## Train STORM model on Dungeons dataset


In [1]:
import json

from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

from storm_ml.datasets.dungeons import generate_data
from storm_ml.preprocessing import (
    DocTokenizerPipe,
    PadTruncTokensPipe,
    SchemaParserPipe,
    TargetFieldPipe,
    TokenEncoderPipe,
    docs_to_df,
)

# Synthesize the Dungeons dataset (generator lives in storm_ml/datasets/dungeons.py).
dungeon_options = dict(
    num_instances=10_000,
    num_doors_range=(5, 10),
    num_colors=3,
    with_monsters=True,
    num_treasures=5,
)
data = generate_data(**dungeon_options)

# Show the first generated document as pretty-printed JSON.
first_doc = data[0]
print(json.dumps(first_doc, indent=2))

# Load the documents into a dataframe and split them 80/20 into train/test.
# A fixed random_state keeps the shuffled split identical across reruns for
# the same input dataframe.
SEED = 42
df = docs_to_df(data)
train_docs_df, test_docs_df = train_test_split(
    df, test_size=0.2, shuffle=True, random_state=SEED
)

# Field of each document that is used as the prediction target.
TARGET_FIELD = "treasure"

# Preprocessing stages, keyed by step name so the individual pipe objects
# can be looked up again after the pipeline has been fitted.
pipes = dict(
    schema=SchemaParserPipe(),
    target=TargetFieldPipe(TARGET_FIELD),
    tokenizer=DocTokenizerPipe(path_in_field_tokens=True),
    padding=PadTruncTokensPipe(length="max"),
    encoder=TokenEncoderPipe(),
)

# Dicts preserve insertion order, so the steps run as
# schema -> target -> tokenizer -> padding -> encoder.
pipeline = Pipeline(list(pipes.items()))

# Fit the pipeline on the training documents, then apply the fitted
# pipeline to the test documents.
train_df = pipeline.fit_transform(train_docs_df)
test_df = pipeline.transform(test_docs_df)

# Pull the state held by the individual pipes for use in later cells.
schema, encoder, block_size = (
    pipes["schema"].schema,
    pipes["encoder"].encoder,
    pipes["padding"].length,
)

# Report dataset sizes, vocabulary size and block size.
print(
    f"len train: {len(train_df)}, len test: {len(test_df)}",
    f"vocab size {encoder.vocab_size}",
    f"block size {block_size}",
    sep="\n",
)

{
  "door": 2,
  "key_color": "red",
  "corridor": [
    {
      "door_no": 0,
      "red_key": "spellbooks",
      "blue_key": "artifacts",
      "green_key": "diamonds"
    },
    {
      "monsters": [
        "dragon"
      ],
      "door_no": 1,
      "red_key": "gold",
      "blue_key": "diamonds",
      "green_key": "diamonds"
    },
    {
      "monsters": [
        "troll"
      ],
      "door_no": 2,
      "red_key": "gold",
      "blue_key": "artifacts",
      "green_key": "gemstones"
    },
    {
      "monsters": [
        "dragon"
      ],
      "door_no": 3,
      "red_key": "artifacts",
      "blue_key": "gold",
      "green_key": "gemstones"
    },
    {
      "monsters": [
        "wolf"
      ],
      "door_no": 4,
      "red_key": "artifacts",
      "blue_key": "spellbooks",
      "green_key": "artifacts"
    },
    {
      "monsters": [
        "troll",
        "wolf"
      ],
      "door_no": 5,
      "red_key": "diamonds",
      "blue_key": "gold",
      "green_ke

In [2]:
# create datasets, VPDA and model

from storm_ml.model import STORM
from storm_ml.model.vpda import DocumentVPDA
from storm_ml.preprocessing import DFDataset
from storm_ml.utils import ModelConfig, TrainConfig

# Wrap the preprocessed train/test frames as datasets.
train_dataset = DFDataset(train_df)
test_dataset = DFDataset(test_df)

# Model hyperparameters: begin with the "gpt-micro" preset, then override
# the fields that depend on this dataset (vocabulary and block size).
model_config = ModelConfig.from_preset("gpt-micro")
model_config.position_encoding = "NONE"
model_config.vocab_size = encoder.vocab_size
model_config.block_size = block_size

# Training hyperparameters.
train_config = TrainConfig()
train_config.learning_rate = 1e-3
train_config.n_warmup_batches = 1000

# Build the VPDA from the encoder and schema, then hand it to the model
# together with both configs.
vpda = DocumentVPDA(encoder, schema)
model = STORM(model_config, train_config, vpda=vpda)


running on device mps
number of parameters: 0.81M


In [3]:
from storm_ml.inference import Predictor
from storm_ml.utils.guild import print_guild_scalars

# Create a predictor from the model, the token encoder and the name of the
# target field. It is used below to compute test loss, test accuracy and
# the individual predictions.
predictor = Predictor(model, encoder, TARGET_FIELD)


# model callback during training, prints training and test metrics
def progress_callback(model):
    """Log training and test metrics every ``train_config.eval_every`` batches.

    Does nothing unless ``model.batch_num`` is a multiple of
    ``train_config.eval_every``. On a reporting batch it draws one random
    sample of 100 items from ``test_dataset`` and reports both the
    cross-entropy loss and the accuracy on that same sample, so the two
    test metrics on a log line describe the same data.

    Args:
        model: object passed in by the training loop; ``batch_num``,
            ``epoch_num``, ``batch_dt``, ``loss`` and ``learning_rate``
            are read from it.
    """
    eval_every = train_config.eval_every
    if model.batch_num % eval_every != 0:
        return

    # Sample once and reuse it: drawing a fresh sample per metric would
    # report loss and accuracy for two different subsets of the test data.
    eval_sample = test_dataset.sample(n=100)

    print_guild_scalars(
        step=f"{int(model.batch_num // eval_every)}",
        epoch=model.epoch_num,
        batch_num=model.batch_num,
        batch_dt=f"{model.batch_dt*1000:.2f}",  # batch_dt is scaled by 1000 for display
        batch_loss=f"{model.loss:.4f}",
        test_loss=f"{predictor.ce_loss(eval_sample):.4f}",
        test_acc=f"{predictor.accuracy(eval_sample):.4f}",
        lr=f"{model.learning_rate:.2e}",
    )


# Attach the progress logger to the "on_batch_end" hook, then train on the
# training dataset for a fixed number of batches.
NUM_TRAIN_BATCHES = 5000
model.set_callback("on_batch_end", progress_callback)
model.train_model(train_dataset, batches=NUM_TRAIN_BATCHES)


|  step: 0  |  epoch: 0  |  batch_num: 0  |  batch_dt: 0.00  |  batch_loss: 2.6633  |  test_loss: 2.6673  |  test_acc: 0.0300  |  lr: 1.01e-06  |
|  step: 1  |  epoch: 1  |  batch_num: 100  |  batch_dt: 70.73  |  batch_loss: 1.2316  |  test_loss: 1.2135  |  test_acc: 0.1800  |  lr: 1.01e-04  |
|  step: 2  |  epoch: 2  |  batch_num: 200  |  batch_dt: 73.97  |  batch_loss: 0.7978  |  test_loss: 0.7518  |  test_acc: 0.1600  |  lr: 2.01e-04  |
|  step: 3  |  epoch: 3  |  batch_num: 300  |  batch_dt: 75.49  |  batch_loss: 0.6736  |  test_loss: 0.6504  |  test_acc: 0.2200  |  lr: 3.01e-04  |
|  step: 4  |  epoch: 5  |  batch_num: 400  |  batch_dt: 76.41  |  batch_loss: 0.6348  |  test_loss: 0.6253  |  test_acc: 0.2000  |  lr: 4.01e-04  |
|  step: 5  |  epoch: 6  |  batch_num: 500  |  batch_dt: 77.66  |  batch_loss: 0.6225  |  test_loss: 0.6184  |  test_acc: 0.1600  |  lr: 5.01e-04  |
|  step: 6  |  epoch: 7  |  batch_num: 600  |  batch_dt: 79.02  |  batch_loss: 0.6276  |  test_loss: 0.6186  



|  step: 10  |  epoch: 12  |  batch_num: 1000  |  batch_dt: 79.27  |  batch_loss: 0.6186  |  test_loss: 0.6175  |  test_acc: 0.2200  |  lr: 1.00e-03  |
|  step: 11  |  epoch: 13  |  batch_num: 1100  |  batch_dt: 80.86  |  batch_loss: 0.6177  |  test_loss: 0.6173  |  test_acc: 0.1900  |  lr: 9.83e-04  |
|  step: 12  |  epoch: 15  |  batch_num: 1200  |  batch_dt: 83.06  |  batch_loss: 0.6166  |  test_loss: 0.6162  |  test_acc: 0.2600  |  lr: 9.67e-04  |
|  step: 13  |  epoch: 16  |  batch_num: 1300  |  batch_dt: 77.07  |  batch_loss: 0.6172  |  test_loss: 0.6177  |  test_acc: 0.2900  |  lr: 9.50e-04  |
|  step: 14  |  epoch: 17  |  batch_num: 1400  |  batch_dt: 83.70  |  batch_loss: 0.6159  |  test_loss: 0.6153  |  test_acc: 0.2500  |  lr: 9.34e-04  |
|  step: 15  |  epoch: 18  |  batch_num: 1500  |  batch_dt: 84.99  |  batch_loss: 0.6140  |  test_loss: 0.6174  |  test_acc: 0.2800  |  lr: 9.17e-04  |
|  step: 16  |  epoch: 20  |  batch_num: 1600  |  batch_dt: 80.91  |  batch_loss: 0.6163

In [4]:
# Accuracy over the complete test dataset, with a progress bar.
acc = predictor.accuracy(test_dataset, show_progress=True)
print(f"Test accuracy: {acc:.4f}")

# `predict()` gives access to the individual predictions; show the first
# ten next to the corresponding entries of the "target" column.
predictions = predictor.predict(test_dataset)
true_labels = test_dataset.df["target"].to_list()
print("Model predictions: ", predictions[:10])
print("Correct labels: ", true_labels[:10])

Predicting:   0%|          | 0/83 [00:00<?, ?it/s]

Test accuracy: 0.9810
Model predictions:  ['gold', 'artifacts', 'diamonds', 'artifacts', 'diamonds', 'diamonds', 'artifacts', 'gemstones', 'artifacts', 'gemstones']
Correct labels:  ['gold', 'artifacts', 'diamonds', 'artifacts', 'diamonds', 'diamonds', 'artifacts', 'gemstones', 'artifacts', 'gemstones']
