## Training an ORiGAMi model on the Dungeons dataset

### The Dungeons Dataset

The Dungeons dataset is a challenging, dungeon-themed synthetic dataset for supervised classification on
semi-structured data.

Each instance contains a corridor array with several rooms. Each room has a door number and contains multiple
treasure chests, each opened by a differently colored key. All but one of the treasures are fake, though.

The goal is to find the correct room number and key color in each dungeon from two clues, and to return the
only real treasure. The clues are given at the top level of the object in the fields `door` and `key_color`.

To make it even harder, the `corridor` array may be shuffled (`shuffle_rooms=True`), and room objects may
have a `monsters` array as their first field (`with_monsters=True`), which shifts the token positions of the
serialized object by a variable amount.
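
To make this structure concrete, here is a minimal, hypothetical generator for a single dungeon. It is not ORiGAMi's `generate_data` (that lives in `origami/datasets/dungeons.py`); the function name `make_dungeon` and the color, treasure, and monster lists are assumptions chosen for illustration. It only shows how the two clues, the shuffled corridor, and the optional monsters fit together.

```python
import random

COLORS = ["red", "blue", "green"]  # assumed key colors
TREASURES = ["gold", "diamonds", "gemstones", "spellbooks", "artifacts"]  # assumed treasure names
MONSTERS = ["troll", "wolf", "goblin"]  # assumed monster names


def make_dungeon(num_doors: int, with_monsters: bool, shuffle_rooms: bool, rng: random.Random) -> dict:
    """Hypothetical generator for one Dungeons-style instance (illustration only)."""
    rooms = []
    for door_no in range(num_doors):
        room = {}
        if with_monsters:
            # a variable number of monsters (possibly none) is added as the first field,
            # shifting the position of every field that follows it
            count = rng.randint(0, 2)
            if count > 0:
                room["monsters"] = rng.sample(MONSTERS, count)
        room["door_no"] = door_no
        for color in COLORS:
            room[f"{color}_key"] = rng.choice(TREASURES)
        rooms.append(room)

    # pick the clues; the chest behind this door/key pair holds the real treasure (the label)
    door = rng.randrange(num_doors)
    key_color = rng.choice(COLORS)
    treasure = rooms[door][f"{key_color}_key"]

    if shuffle_rooms:
        rng.shuffle(rooms)  # the correct room no longer sits at list index `door`

    return {"door": door, "key_color": key_color, "corridor": rooms, "treasure": treasure}


dungeon = make_dungeon(num_doors=5, with_monsters=True, shuffle_rooms=True, rng=random.Random(0))
print(dungeon["door"], dungeon["key_color"], dungeon["treasure"])
```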

The following annotated example shows one JSON instance (the `//` comments are explanations, not part of the data):

```json
{
    "door": 1,                              // clue which door is the correct one
    "key_color": "blue",                    // clue which key is the correct one
    "corridor": [                           // a corridor with many doors
        {
            "monsters": ["troll", "wolf"],  // optional monsters in front of the door
            "door_no": 1,                   // door number in the corridor
            "red_key": "gemstones",         // different keys return different treasures,
            "blue_key": "spellbooks",       // but only one is real, the others are fake
            "green_key": "artifacts"
        },
        {                                   // another room, here without monsters
            "door_no": 0,                   // rooms can be shuffled, here room 0 comes after 1        
            "red_key": "diamonds",          
            "blue_key": "gold",           
            "green_key": "gemstones"
        },
        // ... more rooms ...
    ],
    "treasure": "spellbooks"                // correct treasure (target label)
}
```

The correct answer for this instance is "spellbooks": the `door` clue points to the room with `door_no` 1, and in
that room the `blue_key` (from the `key_color` clue) opens the chest containing spellbooks.
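
The labeling rule can be written as a direct lookup. In the sketch below (the function name `find_treasure` is our own, not part of ORiGAMi), we find the room whose `door_no` matches the `door` clue and read the chest for the `key_color` clue. Because it searches by `door_no` rather than by list position, it is unaffected by shuffled rooms or monsters; the model has to learn an equivalent mapping from the token sequences alone.

```python
def find_treasure(instance: dict) -> str:
    """Return the real treasure by following the two top-level clues."""
    key_field = f"{instance['key_color']}_key"
    for room in instance["corridor"]:
        # match by door number, not by list position, since rooms may be shuffled
        if room["door_no"] == instance["door"]:
            return room[key_field]
    raise ValueError(f"no room with door_no={instance['door']}")


example = {
    "door": 1,
    "key_color": "blue",
    "corridor": [
        {"monsters": ["troll", "wolf"], "door_no": 1,
         "red_key": "gemstones", "blue_key": "spellbooks", "green_key": "artifacts"},
        {"door_no": 0, "red_key": "diamonds", "blue_key": "gold", "green_key": "gemstones"},
    ],
}
print(find_treasure(example))  # spellbooks
```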


### Preprocessing

The JSON objects are tokenized by recursively walking through them depth-first and extracting key and value tokens.
Additionally, when arrays or nested objects are encountered, special grammar tokens are included in the sequence.
The diagram below illustrates the tokenization process.

<img src="../assets/preprocessing-diagram.png" width="600px" />
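
As a rough illustration of depth-first tokenization, the sketch below flattens a nested document into a token list. The token spellings (`START`, `END`, `OBJ_START`, `ARRAY_START`, `FIELD:…`, `VALUE:…`) are assumptions made for readability and do not reproduce `DocTokenizerPipe`'s actual vocabulary. Field tokens here carry their full dotted path, loosely in the spirit of `path_in_field_tokens=True`.

```python
from typing import Any


def tokenize_value(value: Any, path: str) -> list[str]:
    """Depth-first walk that emits grammar, field, and value tokens (illustrative token names)."""
    if isinstance(value, dict):
        tokens = ["OBJ_START"]
        for key, child in value.items():
            child_path = f"{path}.{key}" if path else key
            tokens.append(f"FIELD:{child_path}")
            tokens.extend(tokenize_value(child, child_path))
        tokens.append("OBJ_END")
        return tokens
    if isinstance(value, list):
        tokens = ["ARRAY_START"]
        for item in value:
            # array items share the array's path (no positional index)
            tokens.extend(tokenize_value(item, path))
        tokens.append("ARRAY_END")
        return tokens
    # leaf value; repr() keeps 1 and "1" distinguishable
    return [f"VALUE:{value!r}"]


def tokenize_doc(doc: dict) -> list[str]:
    """Tokenize a top-level document, using START/END instead of object delimiters."""
    inner = tokenize_value(doc, "")
    return ["START", *inner[1:-1], "END"]


room = {"door_no": 1, "blue_key": "spellbooks"}
room_with_monsters = {"monsters": ["troll", "wolf"], **room}
print(tokenize_doc(room))
print(tokenize_doc(room_with_monsters))
```

With two monsters, `FIELD:door_no` moves from position 1 to position 6 in this toy scheme, which is the kind of variable shift that `with_monsters=True` introduces.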


```python
import json

from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

from origami.utils import set_seed
from origami.datasets.dungeons import generate_data
from origami.preprocessing import (
    DocTokenizerPipe,
    PadTruncTokensPipe,
    SchemaParserPipe,
    TargetFieldPipe,
    TokenEncoderPipe,
    docs_to_df,
)

# for reproducibility, uncomment the next line
# set_seed(123)

# generate the Dungeons dataset (see origami/datasets/dungeons.py)
data = generate_data(
    num_instances=100_000,
    num_doors_range=(5, 10),
    num_colors=3,
    num_treasures=5,
    with_monsters=True,    # harder: token positions get shifted by a variable amount
    shuffle_rooms=True     # harder: rooms appear in random order
)

# print an example document
print(json.dumps(data[0], indent=2))

# load data into a dataframe and make an 80/20 train/test split
df = docs_to_df(data)
train_docs_df, test_docs_df = train_test_split(df, test_size=0.2, shuffle=True)

TARGET_FIELD = "treasure"

# preprocessing steps, applied in order by the pipeline below
pipes = {
    "schema": SchemaParserPipe(),
    "target": TargetFieldPipe(TARGET_FIELD),
    "tokenizer": DocTokenizerPipe(path_in_field_tokens=True),
    "padding": PadTruncTokensPipe(length="max"),
    "encoder": TokenEncoderPipe(),
}

pipeline = Pipeline([(name, pipes[name]) for name in ("schema", "target", "tokenizer", "padding", "encoder")])

# fit the pipeline on the training data only, then apply the fitted pipeline to the test data
train_df = pipeline.fit_transform(train_docs_df)
test_df = pipeline.transform(test_docs_df)

# get the stateful objects fitted on the training data
schema = pipes["schema"].schema
encoder = pipes["encoder"].encoder
block_size = pipes["padding"].length

# print data stats
print(f"len train: {len(train_df)}, len test: {len(test_df)}")
print(f"vocab size: {encoder.vocab_size}")
print(f"block size: {block_size}")
```
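
The `padding` and `encoder` steps can be pictured as follows: every token sequence is padded to a common length (with `length="max"`, presumably the longest training sequence, which then serves as `block_size`), and each distinct token is mapped to an integer id. This is a simplified stand-in written for illustration only; the helper names and the `PAD`/`UNK` tokens are assumptions, and `PadTruncTokensPipe` and `TokenEncoderPipe` may differ in details such as padding side or reserved ids. The toy variables are prefixed with `toy_` so they do not overwrite the notebook's `block_size` and `encoder`.

```python
from typing import Optional

PAD, UNK = "PAD", "UNK"  # assumed special tokens


def pad_sequences(sequences: list[list[str]], length: Optional[int] = None) -> list[list[str]]:
    """Truncate or right-pad every sequence to `length` (default: the longest sequence)."""
    if length is None:
        length = max(len(seq) for seq in sequences)
    padded = []
    for seq in sequences:
        seq = seq[:length]
        padded.append(seq + [PAD] * (length - len(seq)))
    return padded


def fit_vocab(sequences: list[list[str]]) -> dict[str, int]:
    """Assign integer ids in order of first appearance; PAD=0 and UNK=1 are reserved."""
    vocab = {PAD: 0, UNK: 1}
    for seq in sequences:
        for token in seq:
            vocab.setdefault(token, len(vocab))
    return vocab


def encode(sequences: list[list[str]], vocab: dict[str, int]) -> list[list[int]]:
    """Map tokens to ids; tokens never seen during fitting become UNK."""
    return [[vocab.get(token, vocab[UNK]) for token in seq] for seq in sequences]


# fit on training sequences only, then reuse the fitted length and vocab for test sequences
toy_train = pad_sequences([["START", "FIELD:door", "VALUE:1", "END"], ["START", "END"]])
toy_block_size = len(toy_train[0])
toy_vocab = fit_vocab(toy_train)
toy_train_ids = encode(toy_train, toy_vocab)
toy_test_ids = encode(pad_sequences([["START", "FIELD:door", "VALUE:7", "END"]], toy_block_size), toy_vocab)

print(toy_block_size, len(toy_vocab))  # 4 6
print(toy_train_ids)  # [[2, 3, 4, 5], [2, 5, 0, 0]]
print(toy_test_ids)   # [[2, 3, 1, 5]]
```

Fitting on the training split and only transforming the test split mirrors the `fit_transform`/`transform` calls in the cell above, so no test-set information leaks into the vocabulary or sequence length.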

### ORiGAMi Model

Here we instantiate an ORiGAMi model, a modified transformer that is trained on the token sequences created above.
We use the standard "medium" preset. ORiGAMi models are relatively robust to the choice of hyperparameters,
and the default configurations often work well for mid-sized datasets.
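
Assuming ORiGAMi follows the usual decoder-only (GPT-style) recipe of next-token prediction, each encoded sequence yields inputs and targets shifted by one position, with padding positions excluded from the loss. The sketch below shows only this data shaping in plain Python. The ignore value `-100` follows the PyTorch `CrossEntropyLoss` default `ignore_index` convention, and `PAD_ID = 0` is an assumption; neither is a description of ORiGAMi's internals.

```python
PAD_ID = 0           # assumed id of the padding token
IGNORE_INDEX = -100  # targets with this value are skipped by the loss (PyTorch convention)


def next_token_pairs(ids: list[int]) -> tuple[list[int], list[int]]:
    """Split one encoded sequence into model inputs and next-token targets."""
    inputs = ids[:-1]
    # the target at position i is the token at position i + 1; padding is masked out
    targets = [t if t != PAD_ID else IGNORE_INDEX for t in ids[1:]]
    return inputs, targets


inputs, targets = next_token_pairs([2, 3, 4, 5, 0, 0])
print(inputs)   # [2, 3, 4, 5, 0]
print(targets)  # [3, 4, 5, -100, -100]
```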

```python
from origami.model import ORIGAMI
from origami.model.vpda import ObjectVPDA
from origami.preprocessing import DFDataset
from origami.utils import ModelConfig, TrainConfig, count_parameters

# model and train configs
model_config = ModelConfig.from_preset("medium")   # see origami/utils/config.py for different presets
model_config.vocab_size = encoder.vocab_size
model_config.block_size = block_size

train_config = TrainConfig()
train_config.learning_rate = 1e-3
train_config.print_every = 10
train_config.eval_every = 500

# wrap dataframes in datasets
train_dataset = DFDataset(train_df)
test_dataset = DFDataset(test_df)

# create the visibly pushdown automaton (VPDA) from the encoder and schema, and pass it to the model
vpda = ObjectVPDA(encoder, schema)
model = ORIGAMI(model_config, train_config, vpda=vpda)

n_params = count_parameters(model)
print(f"Number of parameters: {n_params/1e6:.2f}M")
```
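
The `ObjectVPDA` passed to the model is a visibly pushdown automaton built from the encoder and schema. As a loose illustration of why a stack-based automaton helps with nested JSON, the following sketch tracks object/array nesting on a stack and reports which grammar tokens may come next. The class name `ToyObjectPDA`, the token names, and the rules are simplifications of our own; the real `ObjectVPDA` works on encoded ids and presumably also uses the schema to restrict field and value tokens.

```python
VALUE_STARTS = {"VALUE", "OBJ_START", "ARRAY_START"}


class ToyObjectPDA:
    """Stack-based checker for a simplified object/array token grammar (illustration only)."""

    def __init__(self) -> None:
        self.stack: list[str] = ["OBJ"]  # start inside the top-level object
        self.expect_value = False         # True right after a FIELD token

    def allowed(self) -> set[str]:
        if not self.stack:
            return set()                  # document closed, nothing may follow
        if self.expect_value:
            return set(VALUE_STARTS)      # a field must be followed by a value
        if self.stack[-1] == "OBJ":
            return {"FIELD", "OBJ_END"}
        return VALUE_STARTS | {"ARRAY_END"}  # inside an array

    def step(self, token: str) -> None:
        if token not in self.allowed():
            raise ValueError(f"unexpected token {token!r}")
        self.expect_value = token == "FIELD"
        if token == "OBJ_START":
            self.stack.append("OBJ")
        elif token == "ARRAY_START":
            self.stack.append("ARRAY")
        elif token in ("OBJ_END", "ARRAY_END"):
            self.stack.pop()


pda = ToyObjectPDA()
for tok in ["FIELD", "ARRAY_START", "OBJ_START", "FIELD", "VALUE"]:
    pda.step(tok)
print(sorted(pda.allowed()))  # ['FIELD', 'OBJ_END']
```

One common way to use such an automaton during generation is to mask the logits of every token outside `allowed()`, so the model can only emit structurally valid sequences.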

```python
from origami.inference import Predictor
from origami.utils import make_progress_callback

# create a predictor
predictor = Predictor(model, encoder, TARGET_FIELD)

# create and register the progress callback
progress_callback = make_progress_callback(
    train_config, train_dataset=train_dataset, test_dataset=test_dataset, predictor=predictor
)
model.set_callback("on_batch_end", progress_callback)

# train the model (train and test accuracy should start approaching 1.0 after ~3000 batches,
# once the loss drops below 0.7)
model.train_model(train_dataset, batches=5000)
```
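
How much data does `batches=5000` cover? With 100,000 generated instances and `test_size=0.2`, the training split holds 80,000 documents, so the number of passes over the training data is `batches * batch_size / 80_000`. The batch size below is a placeholder assumption, not ORiGAMi's default; check the `TrainConfig` defaults (presumably alongside the presets in `origami/utils/config.py`) for the actual value.

```python
def epochs_covered(batches: int, batch_size: int, num_train: int) -> float:
    """Number of full passes over the training set after `batches` optimizer steps."""
    return batches * batch_size / num_train


NUM_TRAIN = 80_000         # 80% of the 100,000 generated instances
ASSUMED_BATCH_SIZE = 100   # placeholder, not ORiGAMi's actual default

print(epochs_covered(5000, ASSUMED_BATCH_SIZE, NUM_TRAIN))  # 6.25
print(epochs_covered(3000, ASSUMED_BATCH_SIZE, NUM_TRAIN))  # 3.75
```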

```python
# calculate test accuracy
acc = predictor.accuracy(test_dataset, show_progress=True)
print(f"Test accuracy: {acc:.4f}")

# we can also access the predictions with the `predict()` method
predictions = predictor.predict(test_dataset)
print("Model predictions: ", predictions[:10])
print("Correct labels: ", test_dataset.df["target"].to_list()[:10])
```
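
Test accuracy is the fraction of test documents whose predicted target matches the true label. An equivalent plain-Python check is sketched below, assuming `predict()` returns one label per document in the same order as `test_dataset.df["target"]` (consistent with how the two lists are printed side by side). The function name `exact_match_accuracy` is our own.

```python
def exact_match_accuracy(predictions: list, labels: list) -> float:
    """Fraction of positions where the prediction equals the label."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        return 0.0
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


# in the notebook, this would be called as:
# exact_match_accuracy(predictions, test_dataset.df["target"].to_list())
print(exact_match_accuracy(
    ["gold", "spellbooks", "gold", "artifacts"],
    ["gold", "spellbooks", "diamonds", "artifacts"],
))  # 0.75
```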