In [None]:
import json
import random


# Generate parity data
def make_parity_data(n_samples, min_len=8, max_len=16, seed=None):
    """Generate random bit sequences labelled with the parity of their 1-count.

    Args:
        n_samples: Number of examples to generate.
        min_len: Smallest sequence length (inclusive).
        max_len: Largest sequence length (inclusive).
        seed: Optional seed. When given, a private ``random.Random(seed)`` is
            used, so the call is reproducible and leaves the global ``random``
            state untouched. When ``None`` (the default), the module-level
            ``random`` functions are used, exactly as before.

    Returns:
        A list of dicts of the form ``{"values": [0, 1, ...], "result": "even" | "odd"}``,
        where ``result`` is ``"even"`` if the number of 1s in ``values`` is even
        and ``"odd"`` otherwise.
    """
    # Both the `random` module and `random.Random` instances expose
    # `randint` and `choice`, so either can serve as the source of randomness.
    rng = random if seed is None else random.Random(seed)

    data = []
    for _ in range(n_samples):
        seq_len = rng.randint(min_len, max_len)
        values = [rng.choice([0, 1]) for _ in range(seq_len)]
        # The values are 0/1, so their sum is the number of 1s.
        result = "even" if sum(values) % 2 == 0 else "odd"
        data.append({"values": values, "result": result})
    return data


# Seed the global RNG so the generated datasets (and the preview below) are
# identical on every Restart & Run All.
SEED = 42
random.seed(SEED)

# Test sequences are drawn from a longer length range (6-12) than training
# sequences (4-10), so the evaluation set includes lengths (11, 12) that never
# occur in the training set.
train_data = make_parity_data(1000, min_len=4, max_len=10)
test_data = make_parity_data(100, min_len=6, max_len=12)

# Preview the first few training examples.
for d in train_data[:5]:
    print(d)

In [None]:
from origami import ModelConfig, OrigamiConfig, OrigamiPipeline, TrainingConfig
from origami.training import TableLogCallback, accuracy

# Model settings: LSTM backbone with d_model=16 and lstm_num_layers=2.
model_config = ModelConfig(backbone="lstm", d_model=16, lstm_num_layers=2)

# Training settings. "result" is the label field produced by the data
# generator; accuracy is reported under the name "acc".
# NOTE(review): target_loss_weight=1000.0 presumably up-weights the loss on the
# target key relative to the other fields — confirm against the origami docs.
training_config = TrainingConfig(
    shuffle_keys=False,
    batch_size=32,
    warmup_steps=100,
    learning_rate=1e-3,
    eval_strategy="steps",
    eval_steps=5000,
    eval_metrics={"acc": accuracy},
    eval_on_train=True,
    target_key="result",
    target_loss_weight=1000.0,
)

config = OrigamiConfig(model=model_config, training=training_config, device="cpu")

# Build the pipeline and train it on the parity data, evaluating on test_data.
pipeline = OrigamiPipeline(config)
log_callback = TableLogCallback(100)
pipeline.fit(
    train_data,
    eval_data=test_data,
    epochs=1000,
    callbacks=[log_callback],
    verbose=True,
)

In [None]:
from origami.training import accuracy

# Score the trained pipeline on the held-out data, reporting accuracy as "acc".
test_metrics = {"acc": accuracy}
pipeline.evaluate(test_data, metrics=test_metrics)

In [None]:
# Request one item from the pipeline and take the first element of the result.
# NOTE(review): presumably generate(1) samples a single synthetic document — confirm.
generated = pipeline.generate(1)
doc = generated[0]

# Pretty-print it as indented JSON.
pretty_doc = json.dumps(doc, indent=2)
print(pretty_doc)