In [1]:
# load dungeon data

import json
import random
from pathlib import Path

data_path = Path("../datasets/dungeon_10k_4_8_3_5_mkr.jsonl")

# Fixed seed so the train/eval split is identical on every Restart & Run All;
# an unseeded shuffle trains and evaluates on a different partition each run.
SEED = 42
TRAIN_FRACTION = 0.8

# JSON Lines is UTF-8 by definition, so don't depend on the platform's locale
# encoding. Blank lines (e.g. an extra trailing newline) are skipped rather
# than crashing json.loads.
with open(data_path, encoding="utf-8") as f:
    data = [json.loads(line) for line in f if line.strip()]
assert data, f"no records loaded from {data_path}"

# strip _id fields (if present) before the records are used for training
for entry in data:
    entry.pop("_id", None)

# Split into train/eval sets. A dedicated seeded RNG keeps the split
# independent of any other use of the global `random` state.
random.Random(SEED).shuffle(data)
split_idx = int(TRAIN_FRACTION * len(data))
train_data = data[:split_idx]
eval_data = data[split_idx:]

print(f"Train set: {len(train_data)} records")
print(f"Eval set: {len(eval_data)} records")

# print first data point (after the shuffle this is also train_data[0])
print(json.dumps(data[0], indent=2))

Train set: 8000 records
Eval set: 2000 records
{
  "door": 0,
  "key_color": "blue",
  "corridor": [
    {
      "monsters": [
        "wolf",
        "troll"
      ],
      "door_no": 4,
      "red_key": "gemstones",
      "green_key": "diamonds",
      "blue_key": "gold"
    },
    {
      "door_no": 1,
      "green_key": "diamonds",
      "blue_key": "gold",
      "red_key": "gold"
    },
    {
      "monsters": [
        "dragon",
        "wolf"
      ],
      "door_no": 3,
      "blue_key": "spellbooks",
      "red_key": "diamonds",
      "green_key": "gemstones"
    },
    {
      "monsters": [
        "goblin",
        "dragon"
      ],
      "door_no": 0,
      "blue_key": "diamonds",
      "red_key": "gemstones",
      "green_key": "gemstones"
    },
    {
      "monsters": [
        "troll"
      ],
      "door_no": 2,
      "green_key": "gemstones",
      "blue_key": "spellbooks",
      "red_key": "artifacts"
    }
  ],
  "treasure": "diamonds"
}


In [None]:
from origami import ModelConfig, OrigamiConfig, OrigamiPipeline, TrainingConfig
from origami.training import TableLogCallback, accuracy

import torch  # used here only to detect which accelerator is available

# Choose the training device instead of hard-coding "mps", so the notebook also
# runs on machines without Apple-silicon GPU support. Where MPS is available
# the choice is the same as before.
if torch.backends.mps.is_available():
    TRAIN_DEVICE = "mps"
elif torch.cuda.is_available():
    TRAIN_DEVICE = "cuda"  # assumes origami accepts any torch device string — confirm
else:
    TRAIN_DEVICE = "cpu"

config = OrigamiConfig(
    model=ModelConfig(
        backbone="cached_transformer",
        kvpe_pooling="sum",
        d_model=128,
        n_heads=8,
        n_layers=6,
        # NOTE(review): 784 is an unusual feed-forward width (4 * d_model = 512;
        # 768 is a common size) — confirm this is intentional, not a typo.
        d_ff=784,
        dropout=0.0,
        use_grammar_constraints=True,
    ),
    training=TrainingConfig(
        shuffle_keys=False,
        batch_size=100,
        warmup_steps=1000,
        learning_rate=5e-4,
        # NOTE(review): eval_steps reads like a step-based setting; with
        # eval_strategy="epoch" it is probably ignored — confirm against origami.
        eval_strategy="epoch",
        eval_steps=100,
        eval_metrics={"acc": accuracy},
        eval_sample_size=100,
        # The prediction target is `treasure`; the large weight presumably
        # emphasises the loss on that field relative to the rest of the
        # document — TODO confirm.
        target_key="treasure",
        target_loss_weight=1000.0,
    ),
    device=TRAIN_DEVICE,
)

pipeline = OrigamiPipeline(config)
# fit() returns the pipeline, which is displayed as this cell's output.
pipeline.fit(
    train_data,
    eval_data=eval_data,
    callbacks=[TableLogCallback(print_every=100)],
    epochs=50,
    verbose=True,
)

Vocabulary size: 40
Model parameters: 1,652,360
Training device: mps

Training interrupted at epoch 1, step 82


OrigamiPipeline(numeric_mode='none', fitted)

In [3]:
# Persist the fitted pipeline. Disabled by default so re-running the notebook
# does not touch the checkpoint file; set SAVE_PIPELINE = True to write it.
SAVE_PIPELINE = False
if SAVE_PIPELINE:
    pipeline.save("dungeon_pipeline.pt")

In [4]:
from origami import OrigamiPipeline

# Optionally restore a pipeline from dungeon_pipeline.pt instead of using the
# one trained in this session. Disabled by default so Restart & Run All keeps
# the freshly trained `pipeline`; set LOAD_PIPELINE = True to load it.
LOAD_PIPELINE = False
if LOAD_PIPELINE:
    pipeline = OrigamiPipeline.load("dungeon_pipeline.pt")

In [8]:
from origami.training import accuracy

# Score the held-out split; the result reports the loss plus one entry per
# metric name given here.
pipeline.evaluate(
    eval_data,
    metrics=dict(acc=accuracy),
)

{'loss': 0.8326728854860578, 'acc': 0.9965}

In [10]:
# Draw a single document from the trained model and show it as indented JSON.
samples = pipeline.generate(1)
doc = samples[0]

pretty = json.dumps(doc, indent=2)
print(pretty)

{
  "door": 4,
  "key_color": "green",
  "corridor": [
    {
      "door_no": 4,
      "red_key": "gold",
      "green_key": "spellbooks",
      "blue_key": "spellbooks"
    },
    {
      "monsters": [
        "troll",
        "orc"
      ],
      "door_no": 1,
      "red_key": "gold",
      "green_key": "diamonds"
    },
    {
      "door_no": 4,
      "blue_key": "spellbooks",
      "red_key": "artifacts",
      "green_key": "gold"
    },
    {
      "monsters": [
        "goblin",
        "troll"
      ],
      "door_no": 3,
      "green_key": "diamonds"
    },
    {
      "monsters": [
        "goblin",
        "orc"
      ],
      "door_no": 3,
      "red_key": "spellbooks"
    },
    {
      "monsters": [
        "troll",
        "goblin"
      ],
      "door_no": 1,
      "blue_key": "artifacts",
      "green_key": "gemstones"
    },
    {
      "door_no": 0,
      "green_key": "artifacts"
    },
    {
      "door_no": 1,
      "blue_key": "gemstones"
    },
    {
      "monste