# Origami Data Flow

This notebook traces the data transformations through the Origami pipeline step by step,
from raw input data to the model forward pass.

## Step 1: Data Preprocessing with NumericDiscretizer

The first step in the pipeline is preprocessing. When `numeric_mode="discretize"`,
high-cardinality numeric fields are binned into categories.

**Input:** `list[dict]` - raw JSON objects

**Output:** `list[dict]` - objects with each high-cardinality numeric value replaced by a representative value for its bin (the bin midpoint in the run below)

In [1]:
from origami.preprocessing import NumericDiscretizer

# Sample input data - a list of dictionaries with varied structure:
# - price: numeric field (high-cardinality, will be discretized)
# - category: categorical field (low-cardinality, stays as-is)
# - tags: array with 0-3 values
# - seller: subdocument with 2 subfields (name, rating)
# - some fields are missing in different objects

raw_data = [
    {
        "price": 15000,
        "category": "sedan",
        "tags": ["reliable", "fuel-efficient"],
        "seller": {"name": "Alice", "rating": 4.5},
    },
    {
        "price": 25000,
        "category": "suv",
        "tags": ["spacious"],
        "seller": {"name": "Bob", "rating": 4.8},
    },
    {"price": 18000, "category": "sedan", "seller": {"name": "Alice", "rating": 4.5}},  # no tags
    {
        "price": 32000,
        "category": "truck",
        "tags": [],
        "seller": {"name": "Carol", "rating": 4.2},
    },  # empty tags
    {
        "price": 28000,
        "tags": ["luxury", "fast", "new"],
        "seller": {"name": "Bob", "rating": 4.8},
    },  # no category
    {"price": 12000, "category": "sedan", "tags": ["budget"]},  # no seller
    {
        "price": 45000,
        "category": "suv",
        "tags": ["luxury", "spacious"],
        "seller": {"name": "Carol", "rating": 4.2},
    },
    {
        "price": 22000,
        "category": "sedan",
        "tags": ["reliable"],
        "seller": {"name": "Dave"},
    },  # seller missing rating
]

print("Raw input data:")
print(f"Length: {len(raw_data)}")
print()


def print_data(data: list[dict], title: str):
    print(title)
    for i, obj in enumerate(data):
        print(f"[{i}] {obj}")
    print()


print_data(raw_data, "Input Data:")

Raw input data:
Length: 8

Input Data:
[0] {'price': 15000, 'category': 'sedan', 'tags': ['reliable', 'fuel-efficient'], 'seller': {'name': 'Alice', 'rating': 4.5}}
[1] {'price': 25000, 'category': 'suv', 'tags': ['spacious'], 'seller': {'name': 'Bob', 'rating': 4.8}}
[2] {'price': 18000, 'category': 'sedan', 'seller': {'name': 'Alice', 'rating': 4.5}}
[3] {'price': 32000, 'category': 'truck', 'tags': [], 'seller': {'name': 'Carol', 'rating': 4.2}}
[4] {'price': 28000, 'tags': ['luxury', 'fast', 'new'], 'seller': {'name': 'Bob', 'rating': 4.8}}
[5] {'price': 12000, 'category': 'sedan', 'tags': ['budget']}
[6] {'price': 45000, 'category': 'suv', 'tags': ['luxury', 'spacious'], 'seller': {'name': 'Carol', 'rating': 4.2}}
[7] {'price': 22000, 'category': 'sedan', 'tags': ['reliable'], 'seller': {'name': 'Dave'}}



In [2]:
# Create NumericDiscretizer
# - cat_threshold: fields with more unique values than this are considered high-cardinality
# - n_bins: number of bins to discretize into
# - strategy: binning strategy ("quantile" or "uniform")

discretizer = NumericDiscretizer(
    cat_threshold=3,  # price has 8 unique values -> discretized; rating has 3 -> kept as-is
    n_bins=4,
    strategy="quantile",
)

print(f"NumericDiscretizer config:")
print(f"  cat_threshold: {discretizer.cat_threshold}")
print(f"  n_bins: {discretizer.n_bins}")
print(f"  strategy: {discretizer.strategy}")

NumericDiscretizer config:
  cat_threshold: 3
  n_bins: 4
  strategy: quantile


In [3]:
# fit_transform: analyze the data and transform it
preprocessed_data = discretizer.fit_transform(raw_data)

print("After fit_transform:")
print(f"Type: {type(preprocessed_data)}")
print(f"Length: {len(preprocessed_data)}")
print()

# Show which fields were discretized
print(f"Fields identified as high-cardinality: {discretizer.discretized_fields}")
print()

# Reuse the result from above instead of fitting the discretizer a second time
print_data(preprocessed_data, "Transformed Data:")

After fit_transform:
Type: <class 'list'>
Length: 8

Fields identified as high-cardinality: {'price'}

Transformed Data:
[0] {'price': 14250.0, 'category': 'sedan', 'tags': ['reliable', 'fuel-efficient'], 'seller': {'name': 'Alice', 'rating': 4.5}}
[1] {'price': 26750.0, 'category': 'suv', 'tags': ['spacious'], 'seller': {'name': 'Bob', 'rating': 4.8}}
[2] {'price': 20000.0, 'category': 'sedan', 'seller': {'name': 'Alice', 'rating': 4.5}}
[3] {'price': 37500.0, 'category': 'truck', 'tags': [], 'seller': {'name': 'Carol', 'rating': 4.2}}
[4] {'price': 26750.0, 'tags': ['luxury', 'fast', 'new'], 'seller': {'name': 'Bob', 'rating': 4.8}}
[5] {'price': 14250.0, 'category': 'sedan', 'tags': ['budget']}
[6] {'price': 37500.0, 'category': 'suv', 'tags': ['luxury', 'spacious'], 'seller': {'name': 'Carol', 'rating': 4.2}}
[7] {'price': 20000.0, 'category': 'sedan', 'tags': ['reliable'], 'seller': {'name': 'Dave'}}



In [4]:
# The discretizer stores bin edges for each field
print("Bin edges learned by discretizer:")
for field, discretizer_obj in discretizer.discretizers.items():
    print(f"\n{field}:")
    print(f"  Bin edges: {discretizer_obj.bin_edges_[0]}")

Bin edges learned by discretizer:

price:
  Bin edges: [12000. 16500. 23500. 30000. 45000.]


### Summary of Step 1

**What happened:**
- `NumericDiscretizer.fit_transform()` analyzed the data
- Fields with more than `cat_threshold` unique values were identified as high-cardinality
- Only `price` was discretized (8 unique values > threshold of 3)
- `seller.rating` was not discretized (only 3 unique values ≤ threshold)
- Arrays, subdocuments, and categorical fields pass through unchanged
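
The threshold rule can be illustrated with a small standalone sketch. It is not the `NumericDiscretizer` implementation: the helper name `numeric_cardinality` and the dotted-path keys are assumptions made for this example. It counts the distinct numeric values per field path and flags the paths whose count exceeds `cat_threshold`.

```python
from collections import defaultdict
from numbers import Number


def numeric_cardinality(docs):
    """Count distinct numeric values per dotted field path (hypothetical helper)."""
    seen = defaultdict(set)

    def walk(value, path):
        if isinstance(value, dict):
            for key, child in value.items():
                walk(child, f"{path}.{key}" if path else key)
        elif isinstance(value, list):
            for child in value:
                walk(child, path)  # array elements share the array's path
        elif isinstance(value, Number) and not isinstance(value, bool):
            seen[path].add(value)

    for doc in docs:
        walk(doc, "")
    return {path: len(values) for path, values in seen.items()}


# Only the numeric fields of the eight sample objects are reproduced here.
docs = [
    {"price": 15000, "seller": {"rating": 4.5}},
    {"price": 25000, "seller": {"rating": 4.8}},
    {"price": 18000, "seller": {"rating": 4.5}},
    {"price": 32000, "seller": {"rating": 4.2}},
    {"price": 28000, "seller": {"rating": 4.8}},
    {"price": 12000},
    {"price": 45000, "seller": {"rating": 4.2}},
    {"price": 22000, "seller": {}},
]

cat_threshold = 3
counts = numeric_cardinality(docs)
high_cardinality = {path for path, n in counts.items() if n > cat_threshold}
print(counts)            # {'price': 8, 'seller.rating': 3}
print(high_cardinality)  # {'price'}
```

The comparison is strict (`n > cat_threshold`), which is why `seller.rating` with exactly 3 distinct values stays untouched.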

**Data type transformation:**
- Input: `list[dict]` with nested structure (arrays, subdocs, missing fields)
- Output: `list[dict]` with the same structure, but each `price` value replaced by the midpoint of its bin (e.g. `15000` → `14250.0`)
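
The printed bin edges and replacement values are consistent with quantile edges taken midway between neighbouring order statistics, and with each value replaced by the midpoint of its bin. The sketch below reproduces the numbers under that assumption. Whether `NumericDiscretizer` computes them this way internally is not stated in this notebook, and both function names are made up for illustration.

```python
from bisect import bisect_right


def midpoint_quantile_edges(values, n_bins):
    """Quantile edges, averaging the two neighbouring order statistics."""
    xs = sorted(values)
    last = len(xs) - 1
    edges = []
    for k in range(n_bins + 1):
        pos = k / n_bins * last            # fractional rank of the k-th quantile
        lo = int(pos)
        hi = lo if pos == lo else lo + 1   # exact rank: use that value itself
        edges.append((xs[lo] + xs[hi]) / 2)
    return edges


def bin_midpoint(x, edges):
    """Replace x by the midpoint of the bin that contains it."""
    # Search interior edges only, so the maximum lands in the last bin.
    i = bisect_right(edges[1:-1], x)
    return (edges[i] + edges[i + 1]) / 2


prices = [15000, 25000, 18000, 32000, 28000, 12000, 45000, 22000]
edges = midpoint_quantile_edges(prices, n_bins=4)
binned = [bin_midpoint(p, edges) for p in prices]

print(edges)   # [12000.0, 16500.0, 23500.0, 30000.0, 45000.0]
print(binned)  # [14250.0, 26750.0, 20000.0, 37500.0, 26750.0, 14250.0, 37500.0, 20000.0]
```

With 8 values and 4 bins, every bin receives exactly two prices, which is the point of quantile binning: bins are equally populated rather than equally wide.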

**Next step:** Tokenization

## Step 2: Tokenization

The tokenizer converts JSON objects into token sequences with path information.

**Input:** `list[dict]` - preprocessed JSON objects

**Output:** `TokenizedInstance` per object, containing:
- `tokens`: List of Token objects (grammar tokens, keys, values)
- `paths`: Path for each token (location in JSON hierarchy)
- `numeric_values`: Scaled float values for NUM tokens (used with continuous head)

In [5]:
from origami.tokenizer import JSONTokenizer

# Create and fit the tokenizer on preprocessed data
tokenizer = JSONTokenizer(
    max_depth=32,  # Maximum nesting depth for paths
    max_array_index=256,  # Maximum array index supported
)

# fit() builds vocabulary from all keys and values in the data
tokenizer.fit(preprocessed_data)

print(f"Tokenizer fitted on {len(preprocessed_data)} objects")
print(f"Vocabulary frozen: {tokenizer.vocab.frozen}")

Tokenizer fitted on 8 objects
Vocabulary frozen: True


In [6]:
# Explore the vocabulary
vocab = tokenizer.vocab

print(f"Vocabulary size: {vocab.size}")
print(f"  - Grammar tokens (fixed): 10 (IDs 0-9)")
print(f"  - Dynamic tokens (keys + values): {vocab.size - 10}")
print()

# Show all tokens in the vocabulary
print("All tokens in vocabulary:")
print("-" * 50)
for token_id in range(vocab.size):
    token = vocab.decode(token_id)
    print(f"  ID {token_id:2d}: {token}")

Vocabulary size: 37
  - Grammar tokens (fixed): 10 (IDs 0-9)
  - Dynamic tokens (keys + values): 27

All tokens in vocabulary:
--------------------------------------------------
  ID  0: GrammarToken('START')
  ID  1: GrammarToken('END')
  ID  2: GrammarToken('OBJ_START')
  ID  3: GrammarToken('OBJ_END')
  ID  4: GrammarToken('ARRAY_START')
  ID  5: GrammarToken('ARRAY_END')
  ID  6: GrammarToken('PAD')
  ID  7: GrammarToken('UNK_KEY')
  ID  8: GrammarToken('UNK_VALUE')
  ID  9: GrammarToken('NUM')
  ID 10: KeyToken('price')
  ID 11: ValueToken(14250.0)
  ID 12: KeyToken('category')
  ID 13: ValueToken('sedan')
  ID 14: KeyToken('tags')
  ID 15: ValueToken('reliable')
  ID 16: ValueToken('fuel-efficient')
  ID 17: KeyToken('seller')
  ID 18: KeyToken('name')
  ID 19: ValueToken('Alice')
  ID 20: KeyToken('rating')
  ID 21: ValueToken(4.5)
  ID 22: ValueToken(26750.0)
  ID 23: ValueToken('suv')
  ID 24: ValueToken('spacious')
  ID 25: ValueToken('Bob')
  ID 26: ValueToken(4.8)
  ID 27: ValueToken(20000.0)
  ...

In [7]:
# Tokenize a single object
# tokenize() returns a TokenizedInstance with tokens, paths, and numeric_values

obj = preprocessed_data[0]
print(f"Object to tokenize:")
print(f"  {obj}")
print()

instance = tokenizer.tokenize(obj, shuffle=False)

print(f"TokenizedInstance:")
print(f"  Type: {type(instance)}")
print(f"  Length: {len(instance)} tokens")

print(f"  Tokens: {instance.tokens}")
print(f"  Paths: {instance.paths}")
print(f"  Numeric Values: {instance.numeric_values}")

Object to tokenize:
  {'price': 14250.0, 'category': 'sedan', 'tags': ['reliable', 'fuel-efficient'], 'seller': {'name': 'Alice', 'rating': 4.5}}

TokenizedInstance:
  Type: <class 'origami.tokenizer.json_tokenizer.TokenizedInstance'>
  Length: 20 tokens
  Tokens: [GrammarToken('START'), GrammarToken('OBJ_START'), KeyToken('price'), ValueToken(14250.0), KeyToken('category'), ValueToken('sedan'), KeyToken('tags'), GrammarToken('ARRAY_START'), ValueToken('reliable'), ValueToken('fuel-efficient'), GrammarToken('ARRAY_END'), KeyToken('seller'), GrammarToken('OBJ_START'), KeyToken('name'), ValueToken('Alice'), KeyToken('rating'), ValueToken(4.5), GrammarToken('OBJ_END'), GrammarToken('OBJ_END'), GrammarToken('END')]
  Paths: [(), (), (), (KeyElement('price'),), (), (KeyElement('category'),), (), (KeyElement('tags'),), (KeyElement('tags'), IndexElement(0)), (KeyElement('tags'), IndexElement(1)), (KeyElement('tags'),), (), (KeyElement('seller'),), (KeyElement('seller'),), ...]

In [8]:
# Show the token sequence with paths
# Each token has an associated path showing its location in the JSON hierarchy


def format_path(path):
    """Format a path tuple as a readable string."""
    if not path:
        return "(root)"
    parts = []
    for elem in path:
        if hasattr(elem, "key"):  # KeyElement
            parts.append(f".{elem.key}")
        else:  # IndexElement
            parts.append(f"[{elem.index}]")
    return "".join(parts)


print("Token sequence with paths:")
print("-" * 70)
print(f"{'Pos':>3}  {'Token':<35} {'Path':<20}")
print("-" * 70)
for i, (token, path) in enumerate(zip(instance.tokens, instance.paths, strict=True)):
    print(f"{i:3d}  {str(token):<35} {format_path(path):<20}")

Token sequence with paths:
----------------------------------------------------------------------
Pos  Token                               Path                
----------------------------------------------------------------------
  0  GrammarToken('START')               (root)              
  1  GrammarToken('OBJ_START')           (root)              
  2  KeyToken('price')                   (root)              
  3  ValueToken(14250.0)                 .price              
  4  KeyToken('category')                (root)              
  5  ValueToken('sedan')                 .category           
  6  KeyToken('tags')                    (root)              
  7  GrammarToken('ARRAY_START')         .tags               
  8  ValueToken('reliable')              .tags[0]            
  9  ValueToken('fuel-efficient')        .tags[1]            
 10  GrammarToken('ARRAY_END')           .tags               
 11  KeyToken('seller')                  (root)              
 12  GrammarToken('OBJ_START')           .seller
 ...

### Summary of Step 2

**What happened:**
- `JSONTokenizer.fit()` built a vocabulary from all keys and values in the data
- Grammar tokens have fixed IDs (0-9): START, END, OBJ_START, OBJ_END, ARRAY_START, ARRAY_END, PAD, UNK_KEY, UNK_VALUE, NUM
- Dynamic tokens (keys and values) are assigned IDs starting from 10
- `tokenize()` converts a JSON object to a `TokenizedInstance`
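
The ID assignment described above can be sketched with a toy vocabulary. This is an illustration only: the class `SketchVocab`, its `(kind, value)` keys, and the fallback to `UNK_KEY` / `UNK_VALUE` for unseen tokens are assumptions, not the Origami API. What it does reproduce is the first-seen ordering that yields IDs 10 to 21 for the first sample object.

```python
GRAMMAR_TOKENS = [
    "START", "END", "OBJ_START", "OBJ_END", "ARRAY_START",
    "ARRAY_END", "PAD", "UNK_KEY", "UNK_VALUE", "NUM",
]


class SketchVocab:
    """Toy vocabulary: fixed grammar IDs 0-9, dynamic IDs from 10 upwards."""

    def __init__(self):
        self.ids = {("grammar", name): i for i, name in enumerate(GRAMMAR_TOKENS)}
        self.frozen = False

    def add(self, kind, value):
        token = (kind, value)
        if token not in self.ids:
            if self.frozen:
                raise ValueError("vocabulary is frozen")
            self.ids[token] = len(self.ids)  # next free ID, in first-seen order
        return self.ids[token]

    def fit(self, value):
        """Register every key and primitive value found in a nested object."""
        if isinstance(value, dict):
            for key, child in value.items():
                self.add("key", key)
                self.fit(child)
        elif isinstance(value, list):
            for child in value:
                self.fit(child)
        else:
            self.add("value", value)

    def encode(self, kind, value):
        # Assumed fallback: unseen keys/values map to the matching UNK token.
        unk = "UNK_KEY" if kind == "key" else "UNK_VALUE"
        return self.ids.get((kind, value), self.ids[("grammar", unk)])


vocab_sketch = SketchVocab()
vocab_sketch.fit({
    "price": 14250.0,
    "category": "sedan",
    "tags": ["reliable", "fuel-efficient"],
    "seller": {"name": "Alice", "rating": 4.5},
})
vocab_sketch.frozen = True

print(vocab_sketch.encode("key", "seller"))   # 17
print(vocab_sketch.encode("value", 4.5))      # 21
print(vocab_sketch.encode("value", "truck"))  # 8 (UNK_VALUE, not seen during fit)
```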

**Token sequence structure:**
- Starts with `START`, ends with `END`
- Objects: `OBJ_START` → key-value pairs → `OBJ_END`
- Arrays: `ARRAY_START` → elements → `ARRAY_END`
- Each token has an associated **path** showing its location in the JSON hierarchy
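
A minimal recursive walk shows how such a sequence and its paths can be produced together. This is a sketch, not `JSONTokenizer`: tokens are plain strings, paths are tuples of keys and integer indices, and `sketch_tokenize` is a hypothetical name. The path given to a nested `OBJ_END` is assumed by symmetry with `ARRAY_END`.

```python
def sketch_tokenize(obj):
    """Flatten a nested object into parallel token and path lists."""
    tokens, paths = [], []

    def emit(token, path):
        tokens.append(token)
        paths.append(path)

    def walk(value, path):
        if isinstance(value, dict):
            emit("OBJ_START", path)
            for key, child in value.items():
                emit(f"KEY:{key}", path)       # a key sits at its parent's path
                walk(child, path + (key,))     # its value sits one level deeper
            emit("OBJ_END", path)
        elif isinstance(value, list):
            emit("ARRAY_START", path)
            for index, child in enumerate(value):
                walk(child, path + (index,))
            emit("ARRAY_END", path)
        else:
            emit(f"VALUE:{value!r}", path)

    emit("START", ())
    walk(obj, ())
    emit("END", ())
    return tokens, paths


sample = {
    "price": 14250.0,
    "category": "sedan",
    "tags": ["reliable", "fuel-efficient"],
    "seller": {"name": "Alice", "rating": 4.5},
}
sketch_tokens, sketch_paths = sketch_tokenize(sample)
for position, (token, path) in enumerate(zip(sketch_tokens, sketch_paths)):
    print(position, token, path)
```

For the first sample object this yields 20 tokens, the same length the real tokenizer reports.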

**Paths are used for Key-Value Position Encoding (KVPE):**
- Instead of standard positional encoding (1, 2, 3...), KVPE encodes the path through the JSON structure
- This allows the model to understand hierarchical relationships

**Next step:** Dataset wrapping, then batch encoding (converting to tensors)

## Step 3a: Dataset Wrappers (OrigamiDataset)

During training, preprocessed data is wrapped in `OrigamiDataset`, which tokenizes objects on demand in one of two modes:
- `shuffle=True` (training): key order is shuffled on every access, as data augmentation
- `shuffle=False` (validation): tokenization is deterministic, for reproducible evaluation

**Input:** `list[dict]` - preprocessed JSON objects

**Output:** `TokenizedInstance` when indexed (calls `tokenizer.tokenize()` on demand)

In [9]:
from origami.training.dataset import OrigamiDataset

# Create an OrigamiDataset for training with key shuffling enabled
train_dataset = OrigamiDataset(
    data=preprocessed_data,
    tokenizer=tokenizer,
    shuffle=True,  # Shuffle key order each time
)

print(f"OrigamiDataset (training):")
print(f"  Size: {len(train_dataset)} objects")
print(f"  Shuffle: {train_dataset.shuffle}")

OrigamiDataset (training):
  Size: 8 objects
  Shuffle: True


In [10]:
# Demonstrate key shuffling: access the same base object multiple times
# Each access returns a TokenizedInstance with potentially different key order


def extract_key_order(instance):
    """Return the first four keys in token-sequence order."""
    from origami.tokenizer.vocabulary import KeyToken

    keys = [t.key for t in instance.tokens if isinstance(t, KeyToken)]
    # Nested keys (name, rating) directly follow `seller`, so they appear in
    # this slice whenever `seller` is not the last root-level key.
    return keys[:4]


print("Key shuffling demonstration:")
print("Accessing the same object (index 0) multiple times:\n")

for i in range(4):
    inst = train_dataset[0]  # Always maps to base object 0
    keys = extract_key_order(inst)
    print(f"  Access {i + 1}: {keys}")

Key shuffling demonstration:
Accessing the same object (index 0) multiple times:

  Access 1: ['price', 'tags', 'category', 'seller']
  Access 2: ['category', 'seller', 'name', 'rating']
  Access 3: ['tags', 'category', 'price', 'seller']
  Access 4: ['price', 'tags', 'category', 'seller']
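
In the output above, access 2 lists `name` and `rating` because `extract_key_order` takes the first four keys in sequence order, and the keys nested under `seller` come right after it. Root-level key tokens carry an empty path, so filtering on the path separates them cleanly. The sketch below uses stand-in `KeyToken` and `ValueToken` tuples rather than the Origami classes, and `root_level_keys` is a hypothetical helper. The token list is hand-written, omits `tags` for brevity, and places `price` last purely for illustration.

```python
from typing import NamedTuple


class KeyToken(NamedTuple):      # stand-in for origami's KeyToken
    key: str


class ValueToken(NamedTuple):    # stand-in for origami's ValueToken
    value: object


def root_level_keys(tokens, paths):
    """Keys whose path is empty, i.e. keys of the top-level object."""
    return [
        tok.key
        for tok, path in zip(tokens, paths)
        if isinstance(tok, KeyToken) and len(path) == 0
    ]


tokens = [
    "START", "OBJ_START",
    KeyToken("category"), ValueToken("sedan"),
    KeyToken("seller"), "OBJ_START",
    KeyToken("name"), ValueToken("Alice"),
    KeyToken("rating"), ValueToken(4.5),
    "OBJ_END",
    KeyToken("price"), ValueToken(14250.0),
    "OBJ_END", "END",
]
paths = [
    (), (),
    (), ("category",),
    (), ("seller",),
    ("seller",), ("seller", "name"),
    ("seller",), ("seller", "rating"),
    ("seller",),
    (), ("price",),
    (), (),
]

print(root_level_keys(tokens, paths))  # ['category', 'seller', 'price']
```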


### Summary of Step 3a

**What happened:**
- `OrigamiDataset` wraps the preprocessed data with a tokenizer reference
- With `shuffle=True`, `__getitem__` calls `tokenizer.tokenize(obj, shuffle=True)` on demand
- Each access produces a fresh `TokenizedInstance` with randomized key order

**Why key shuffling matters:**
- JSON object keys have no inherent order
- Without shuffling, the model can memorize key positions instead of key semantics
- Shuffling forces the model to learn from key names, not positions
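
As a rough illustration of the augmentation, the sketch below permutes dict keys while leaving array order alone. It is not the `tokenize(obj, shuffle=True)` code path: `shuffle_keys` is a hypothetical helper, and shuffling at every nesting level is an assumption, since the notebook only demonstrates the effect on the key sequence of one object.

```python
import random


def shuffle_keys(value, rng):
    """Return a copy with dict keys in random order at every nesting level."""
    if isinstance(value, dict):
        items = list(value.items())
        rng.shuffle(items)
        return {key: shuffle_keys(child, rng) for key, child in items}
    if isinstance(value, list):
        # Array order carries meaning, so elements stay in place.
        return [shuffle_keys(child, rng) for child in value]
    return value


rng = random.Random(0)
obj = {
    "price": 14250.0,
    "category": "sedan",
    "tags": ["reliable", "fuel-efficient"],
    "seller": {"name": "Alice", "rating": 4.5},
}

for _ in range(3):
    print(list(shuffle_keys(obj, rng)))  # root-level key order of one shuffled copy
```

Because Python dicts compare equal regardless of insertion order, `shuffle_keys(obj, rng) == obj` always holds: the content is unchanged and only the serialization order differs.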

**Data type:** Still `TokenizedInstance` (no tensors yet)

**Next step:** Collation (batching + converting to tensors)

## Step 3b: Collation (OrigamiDataCollator)

The collator converts a batch of `TokenizedInstance` objects into padded tensors ready for model input.

**Input:** `list[TokenizedInstance]` - batch from DataLoader

**Output:** `EncodedBatch` dataclass whose fields are tensors:
- `input_ids`: Token IDs `(batch, seq_len)`
- `path_types`: Path element types `(batch, seq_len, max_depth)`
- `path_ids`: Path element IDs `(batch, seq_len, max_depth)`
- `path_lengths`: Path depths `(batch, seq_len)`
- `attention_mask`: Valid positions `(batch, seq_len)`
- `labels`: Same as input_ids for autoregressive training
- `numeric_values`, `numeric_mask`: For continuous head
- `lengths`: Original sequence lengths

**Key feature:** Uses **LEFT-PADDING** so all sequences end at the same position.

In [11]:
from origami.training.collator import OrigamiDataCollator
from origami.training.dataset import OrigamiDataset

# Create a collator
collator = OrigamiDataCollator(
    tokenizer=tokenizer,
    max_length=None,  # No truncation
)

# Use OrigamiDataset with shuffle=False so we get deterministic sequences
# This lets us see the varying lengths due to missing fields
eval_dataset = OrigamiDataset(data=preprocessed_data, tokenizer=tokenizer, shuffle=False)

# Get ALL 8 objects as a single batch
batch_instances = [eval_dataset[i] for i in range(len(eval_dataset))]

print(f"Batch of {len(batch_instances)} TokenizedInstances:")
for i, inst in enumerate(batch_instances):
    print(f"  [{i}] length={len(inst):2d} tokens  <- {list(preprocessed_data[i].keys())}")

Batch of 8 TokenizedInstances:
  [0] length=20 tokens  <- ['price', 'category', 'tags', 'seller']
  [1] length=19 tokens  <- ['price', 'category', 'tags', 'seller']
  [2] length=15 tokens  <- ['price', 'category', 'seller']
  [3] length=18 tokens  <- ['price', 'category', 'tags', 'seller']
  [4] length=19 tokens  <- ['price', 'tags', 'seller']
  [5] length=12 tokens  <- ['price', 'category', 'tags']
  [6] length=20 tokens  <- ['price', 'category', 'tags', 'seller']
  [7] length=17 tokens  <- ['price', 'category', 'tags', 'seller']


In [12]:
# Collate the batch into tensors
batch = collator(batch_instances)

print("Collated batch (EncodedBatch dataclass):")
print("-" * 50)
print(f"  input_ids: shape={tuple(batch.input_ids.shape)}, dtype={batch.input_ids.dtype}")
print(f"  path_types: shape={tuple(batch.path_types.shape)}, dtype={batch.path_types.dtype}")
print(f"  path_ids: shape={tuple(batch.path_ids.shape)}, dtype={batch.path_ids.dtype}")
print(f"  path_lengths: shape={tuple(batch.path_lengths.shape)}, dtype={batch.path_lengths.dtype}")
print(
    f"  attention_mask: shape={tuple(batch.attention_mask.shape)}, dtype={batch.attention_mask.dtype}"
)
print(
    f"  numeric_values: shape={tuple(batch.numeric_values.shape)}, dtype={batch.numeric_values.dtype}"
)
print(f"  numeric_mask: shape={tuple(batch.numeric_mask.shape)}, dtype={batch.numeric_mask.dtype}")
print(f"  lengths: shape={tuple(batch.lengths.shape)}, dtype={batch.lengths.dtype}")
print(f"  labels: shape={tuple(batch.labels.shape)}, dtype={batch.labels.dtype}")

Collated batch (EncodedBatch dataclass):
--------------------------------------------------
  input_ids: shape=(8, 20), dtype=torch.int64
  path_types: shape=(8, 20, 32), dtype=torch.int64
  path_ids: shape=(8, 20, 32), dtype=torch.int64
  path_lengths: shape=(8, 20), dtype=torch.int64
  attention_mask: shape=(8, 20), dtype=torch.bool
  numeric_values: shape=(8, 20), dtype=torch.float32
  numeric_mask: shape=(8, 20), dtype=torch.bool
  lengths: shape=(8,), dtype=torch.int64
  labels: shape=(8, 20), dtype=torch.int64


In [13]:
# Demonstrate LEFT-PADDING
# Shorter sequences have PAD tokens at the START, content at the END
# This allows batched prediction: logits[:, -1, :] gives next token for all sequences

print("LEFT-PADDING demonstration:")
print("=" * 80)
print(f"PAD=6, START=0, END=1, OBJ_START=2, OBJ_END=3, ARRAY_START=4, ARRAY_END=5")
print()

input_ids = batch.input_ids

for i in range(len(batch_instances)):
    ids = input_ids[i].tolist()
    print(f"[{i}] {ids}")

print()
print("All sequences end with token ID 1 (END) at the same position!")

LEFT-PADDING demonstration:
================================================================================
PAD=6, START=0, END=1, OBJ_START=2, OBJ_END=3, ARRAY_START=4, ARRAY_END=5

[0] [0, 2, 10, 11, 12, 13, 14, 4, 15, 16, 5, 17, 2, 18, 19, 20, 21, 3, 3, 1]
[1] [6, 0, 2, 10, 22, 12, 23, 14, 4, 24, 5, 17, 2, 18, 25, 20, 26, 3, 3, 1]
[2] [6, 6, 6, 6, 6, 0, 2, 10, 27, 12, 13, 17, 2, 18, 19, 20, 21, 3, 3, 1]
[3] [6, 6, 0, 2, 10, 28, 12, 29, 14, 4, 5, 17, 2, 18, 30, 20, 31, 3, 3, 1]
[4] [6, 0, 2, 10, 22, 14, 4, 32, 33, 34, 5, 17, 2, 18, 25, 20, 26, 3, 3, 1]
[5] [6, 6, 6, 6, 6, 6, 6, 6, 0, 2, 10, 11, 12, 13, 14, 4, 35, 5, 3, 1]
[6] [0, 2, 10, 28, 12, 23, 14, 4, 32, 24, 5, 17, 2, 18, 30, 20, 31, 3, 3, 1]
[7] [6, 6, 6, 0, 2, 10, 27, 12, 13, 14, 4, 15, 5, 17, 2, 18, 36, 3, 3, 1]

All sequences end with token ID 1 (END) at the same position!


In [14]:
# Path encoding for sample 0
# path_types: 0=pad, 1=key, 2=index
# path_ids: vocab ID for keys, array index for indices
# path_lengths: depth of path at each position

sample_idx = 0
print(f"Path encoding for sample {sample_idx}:")
print(f"Object: {preprocessed_data[sample_idx]}")
print()

path_types = batch.path_types[sample_idx]
path_ids = batch.path_ids[sample_idx]
path_lengths = batch.path_lengths[sample_idx]
ids = batch.input_ids[sample_idx]

print(
    f"{'Pos':>3}  {'TokID':>5}  {'Depth':>5}  {'Path Types (trimmed)':20}  {'Path IDs (trimmed)'}"
)
print("-" * 75)

for pos in range(len(ids)):
    tok_id = ids[pos].item()
    depth = path_lengths[pos].item()
    types = path_types[pos, :3].tolist()  # Show first 3 elements
    pids = path_ids[pos, :3].tolist()
    print(f"{pos:3d}  {tok_id:5d}  {depth:5d}  {str(types):20}  {pids}")

Path encoding for sample 0:
Object: {'price': 14250.0, 'category': 'sedan', 'tags': ['reliable', 'fuel-efficient'], 'seller': {'name': 'Alice', 'rating': 4.5}}

Pos  TokID  Depth  Path Types (trimmed)  Path IDs (trimmed)
---------------------------------------------------------------------------
  0      0      0  [0, 0, 0]             [0, 0, 0]
  1      2      0  [0, 0, 0]             [0, 0, 0]
  2     10      0  [0, 0, 0]             [0, 0, 0]
  3     11      1  [1, 0, 0]             [10, 0, 0]
  4     12      0  [0, 0, 0]             [0, 0, 0]
  5     13      1  [1, 0, 0]             [12, 0, 0]
  6     14      0  [0, 0, 0]             [0, 0, 0]
  7      4      1  [1, 0, 0]             [14, 0, 0]
  8     15      2  [1, 2, 0]             [14, 0, 0]
  9     16      2  [1, 2, 0]             [14, 1, 0]
 10      5      1  [1, 0, 0]             [14, 0, 0]
 11     17      0  [0, 0, 0]             [0, 0, 0]
 12      2      1  [1, 0, 0]             [17, 0, 0]
 13     18      1  [1, 0, 0]             [17, 0, 0]
 ...

### Summary of Step 3b

**What happened:**
- `OrigamiDataCollator` takes a list of `TokenizedInstance` objects
- Converts tokens to integer IDs using the vocabulary
- Encodes paths into `path_types`, `path_ids`, `path_lengths` tensors
- Applies **LEFT-PADDING**: PAD tokens at start, content at end
- Returns an `EncodedBatch` dataclass ready for `model.forward()`
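
The path encoding can be sketched for a single path as follows. The type codes (0 = pad, 1 = key, 2 = index) and the use of vocabulary IDs for keys come from the cell comments above. The function name `encode_path`, the tuple-based path representation, and the truncation of paths deeper than `max_depth` are assumptions for illustration.

```python
def encode_path(path, key_ids, max_depth):
    """Encode one path as (path_types, path_ids, path_length).

    Types follow the notebook's convention: 0 = pad, 1 = key, 2 = array index.
    """
    types = [0] * max_depth
    ids = [0] * max_depth
    depth = min(len(path), max_depth)  # assumption: deeper paths are truncated
    for d in range(depth):
        elem = path[d]
        if isinstance(elem, int):
            types[d], ids[d] = 2, elem           # array index stored directly
        else:
            types[d], ids[d] = 1, key_ids[elem]  # key stored as its vocab ID
    return types, ids, depth


key_ids = {"price": 10, "category": 12, "tags": 14, "seller": 17}
print(encode_path((), key_ids, 3))            # ([0, 0, 0], [0, 0, 0], 0)
print(encode_path(("tags", 1), key_ids, 3))   # ([1, 2, 0], [14, 1, 0], 2)
print(encode_path(("seller",), key_ids, 3))   # ([1, 0, 0], [17, 0, 0], 1)
```

Note that `path_ids` alone is ambiguous: `[14, 0, 0]` could be the key `tags` followed by padding or followed by index 0. The `path_types` and `path_lengths` tensors resolve that.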

**Why LEFT-PADDING?**
```
Right-padding (standard):     Left-padding (Origami):
[START, a, END, PAD, PAD]     [PAD, PAD, START, a, END]
[START, a, b, c, END]         [START, a, b, c, END]
       ↑                                         ↑
  Ends differ                           All end at same position
```
With left-padding, `logits[:, -1, :]` gives the next-token prediction for ALL sequences in the batch simultaneously.
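
A plain-Python sketch of left-padding, using lists instead of tensors. `left_pad` is a hypothetical helper, and `PAD_ID = 6` is taken from the vocabulary listing above. The two ID sequences are samples 0 and 5 from the batch.

```python
PAD_ID = 6


def left_pad(sequences, pad_id=PAD_ID):
    """Pad every sequence on the left to the length of the longest one."""
    width = max(len(seq) for seq in sequences)
    input_ids = [[pad_id] * (width - len(seq)) + list(seq) for seq in sequences]
    attention_mask = [
        [False] * (width - len(seq)) + [True] * len(seq) for seq in sequences
    ]
    return input_ids, attention_mask


sample_0 = [0, 2, 10, 11, 12, 13, 14, 4, 15, 16, 5, 17, 2, 18, 19, 20, 21, 3, 3, 1]
sample_5 = [0, 2, 10, 11, 12, 13, 14, 4, 35, 5, 3, 1]

padded_ids, padded_mask = left_pad([sample_0, sample_5])
print(padded_ids[1])
# [6, 6, 6, 6, 6, 6, 6, 6, 0, 2, 10, 11, 12, 13, 14, 4, 35, 5, 3, 1]
print([row[-1] for row in padded_ids])  # [1, 1] -> both rows end with END
```

The attention mask is `False` exactly where padding was inserted, so padded positions can be excluded from attention and from the loss.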

**Data type transformation:**
- Input: `list[TokenizedInstance]` (Python objects)
- Output: `EncodedBatch` dataclass (PyTorch tensors with attribute access)

**Next step:** Feeding batches through a `DataLoader`, then grammar constraints

## Step 3c: DataLoader

In training, we use a PyTorch `DataLoader` to iterate over batches. The DataLoader:
- Samples indices from the dataset
- Calls `dataset[i]` for each index to get `TokenizedInstance` objects
- Passes the list to the collator to produce an `EncodedBatch`

Let's verify that the DataLoader produces the same output as calling the collator directly.

In [15]:
from torch.utils.data import DataLoader

# Create a DataLoader with batch_size=4
# This gives us 2 batches from our 8 samples
data_loader = DataLoader(
    eval_dataset,
    batch_size=4,
    shuffle=False,  # Keep order deterministic for comparison
    collate_fn=collator,
)

print(f"DataLoader created:")
print(f"  Dataset size: {len(eval_dataset)}")
print(f"  Batch size: 4")
print(f"  Number of batches: {len(data_loader)}")

DataLoader created:
  Dataset size: 8
  Batch size: 4
  Number of batches: 2
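
Conceptually, with `shuffle=False` and a custom `collate_fn`, the loader behaves like the generator below. This is a simplified sketch that ignores workers, samplers, and pinned memory. `iter_batches` and the toy string dataset are made up for illustration.

```python
def iter_batches(dataset, batch_size, collate_fn):
    """Sequential batching: roughly what DataLoader does with shuffle=False."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        samples = [dataset[i] for i in range(start, stop)]  # calls __getitem__
        yield collate_fn(samples)


toy_dataset = [f"instance-{i}" for i in range(8)]
batches = list(iter_batches(toy_dataset, batch_size=4, collate_fn=list))

print(len(batches))  # 2
print(batches[0])    # ['instance-0', 'instance-1', 'instance-2', 'instance-3']
```

When the dataset size is not a multiple of the batch size, the last batch is smaller. PyTorch's `DataLoader` keeps that partial batch by default (`drop_last=False`).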


In [16]:
# Get the first batch from the DataLoader
first_batch = next(iter(data_loader))

print("First batch from DataLoader:")
print("-" * 50)
print(f"  input_ids: shape={tuple(first_batch.input_ids.shape)}")
print(f"  path_types: shape={tuple(first_batch.path_types.shape)}")
print(f"  path_ids: shape={tuple(first_batch.path_ids.shape)}")
print(f"  path_lengths: shape={tuple(first_batch.path_lengths.shape)}")
print(f"  attention_mask: shape={tuple(first_batch.attention_mask.shape)}")
print(f"  labels: shape={tuple(first_batch.labels.shape)}")

print()
print("input_ids (first 4 samples, indices 0-3):")
for i in range(4):
    print(f"[{i}] {first_batch.input_ids[i].tolist()}")

First batch from DataLoader:
--------------------------------------------------
  input_ids: shape=(4, 20)
  path_types: shape=(4, 20, 32)
  path_ids: shape=(4, 20, 32)
  path_lengths: shape=(4, 20)
  attention_mask: shape=(4, 20)
  labels: shape=(4, 20)

input_ids (first 4 samples, indices 0-3):
[0] [0, 2, 10, 11, 12, 13, 14, 4, 15, 16, 5, 17, 2, 18, 19, 20, 21, 3, 3, 1]
[1] [6, 0, 2, 10, 22, 12, 23, 14, 4, 24, 5, 17, 2, 18, 25, 20, 26, 3, 3, 1]
[2] [6, 6, 6, 6, 6, 0, 2, 10, 27, 12, 13, 17, 2, 18, 19, 20, 21, 3, 3, 1]
[3] [6, 6, 0, 2, 10, 28, 12, 29, 14, 4, 5, 17, 2, 18, 30, 20, 31, 3, 3, 1]


In [17]:
# Compare with manually collated batch (first 4 samples)
# The DataLoader batch should match samples 0-3 from our earlier batch of 8

from dataclasses import fields

import torch

# Manually collate just the first 4 samples for comparison
manual_batch = collator([eval_dataset[i] for i in range(4)])

print("Comparison: DataLoader batch vs manual collation")
print("=" * 50)

all_match = True
for field in fields(first_batch):
    key = field.name
    first_val = getattr(first_batch, key)
    manual_val = getattr(manual_batch, key)
    if first_val is None and manual_val is None:
        match = True
    elif first_val is None or manual_val is None:
        match = False
    else:
        match = torch.equal(first_val, manual_val)
    status = "✓" if match else "✗"
    print(f"  {key}: {status}")
    if not match:
        all_match = False

print()
if all_match:
    print("All tensors match! DataLoader + collator produces identical output.")

Comparison: DataLoader batch vs manual collation
==================================================
  input_ids: ✓
  path_types: ✓
  path_ids: ✓
  path_lengths: ✓
  attention_mask: ✓
  numeric_values: ✓
  numeric_mask: ✓
  lengths: ✓
  labels: ✓

All tensors match! DataLoader + collator produces identical output.


## Step 4: Grammar Constraints (JSONGrammarPDA)

During training, the model applies grammar constraints to mask out invalid tokens at each position. This ensures the model only learns to predict syntactically valid JSON.

**Input:** `input_ids` tensor `(batch, seq_len)`

**Output:** Boolean mask `(batch, seq_len, vocab_size)` where `True` = valid token for next position

**Complexity:** O(n) in the sequence length n: a single pass over positions, with no nested loops over positions.

The grammar rules enforced:
- After START: OBJ_START or ARRAY_START only
- After OBJ_START: any key or OBJ_END
- After key: value (primitive, OBJ_START, ARRAY_START)
- After value in object: key or OBJ_END
- After ARRAY_START: value or ARRAY_END
- After root closes: END only
- After END: PAD only
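
The rules above map naturally onto a pushdown automaton whose stack records the currently open containers. The sketch below works on token classes (`"KEY"`, `"VALUE"`, ...) rather than vocabulary IDs, handles one sequence at a time, and is not the `JSONGrammarPDA` implementation. Two behaviours are inferred rather than listed: a value inside an array may be followed by another value or `ARRAY_END`, and `PAD` before `START` allows only `START`.

```python
VALUE_STARTS = frozenset({"VALUE", "OBJ_START", "ARRAY_START"})


def valid_next_classes(token_classes):
    """For each position t, return the token classes allowed at position t + 1."""
    stack = []      # currently open containers: "obj" or "arr"
    ended = False   # True once END has been seen
    result = []
    for tok in token_classes:
        if tok == "PAD":
            allowed = {"PAD"} if ended else {"START"}
        elif tok == "START":
            allowed = {"OBJ_START", "ARRAY_START"}
        elif tok == "END":
            ended = True
            allowed = {"PAD"}
        elif tok == "KEY":
            allowed = VALUE_STARTS
        else:
            if tok == "OBJ_START":
                stack.append("obj")
            elif tok == "ARRAY_START":
                stack.append("arr")
            elif tok in ("OBJ_END", "ARRAY_END"):
                stack.pop()
            # After a value or a container boundary, the innermost open
            # container decides what may follow.
            if not stack:
                allowed = {"END"}
            elif stack[-1] == "obj":
                allowed = {"KEY", "OBJ_END"}
            else:
                allowed = VALUE_STARTS | {"ARRAY_END"}
        result.append(set(allowed))
    return result


# Token classes of {'price': ..., 'category': ..., 'tags': ['budget']} with one PAD.
sequence = [
    "PAD", "START", "OBJ_START", "KEY", "VALUE", "KEY", "VALUE",
    "KEY", "ARRAY_START", "VALUE", "ARRAY_END", "OBJ_END", "END",
]
for tok, allowed in zip(sequence, valid_next_classes(sequence)):
    print(f"{tok:<12} -> {sorted(allowed)}")
```

Each token triggers at most one push or pop, so the pass is linear in the sequence length, matching the complexity stated above.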

In [18]:
from origami.constraints.json_grammar import JSONGrammarPDA

# Create a grammar PDA with our vocabulary
pda = JSONGrammarPDA(vocab=tokenizer.vocab, max_depth=32)

print(f"JSONGrammarPDA created:")
print(f"  Vocabulary size: {tokenizer.vocab.size}")
print(f"  Max depth: {pda.max_depth}")
print(f"  Number of keys: {len(pda._key_ids)}")
print(f"  Number of values: {len(pda._value_ids)}")

JSONGrammarPDA created:
  Vocabulary size: 37
  Max depth: 32
  Number of keys: 7
  Number of values: 23


In [19]:
# Compute the grammar mask for the whole batch at once
# (individual samples are inspected in the next cell)
input_ids = batch.input_ids

grammar_mask = pda.compute_valid_mask(input_ids)

print(f"Grammar mask shape: {grammar_mask.shape}")
print(
    f"  (batch={grammar_mask.shape[0]}, seq_len={grammar_mask.shape[1]}, vocab_size={grammar_mask.shape[2]})"
)
print()
print("mask[b, t, v] = True means token v is valid at position t+1 for sequence b")

Grammar mask shape: torch.Size([8, 20, 37])
  (batch=8, seq_len=20, vocab_size=37)

mask[b, t, v] = True means token v is valid at position t+1 for sequence b


In [20]:
# Visualize grammar mask for sample 5 (has left-padding)
# Show: current token at position t, attention mask, and what tokens are valid for position t+1

sample_idx = 5
sample_ids = batch.input_ids[sample_idx]
sample_mask = grammar_mask[sample_idx]
sample_attn = batch.attention_mask[sample_idx]


def get_valid_tokens(mask_row):
    """Get list of valid token IDs from a mask row."""
    return mask_row.nonzero(as_tuple=True)[0].tolist()


def format_valid_tokens(valid_ids, vocab):
    """Format valid token IDs as readable string."""
    if len(valid_ids) == 0:
        return "(none)"
    return ", ".join(str(vocab.decode(tid)) for tid in valid_ids)


print(f"Grammar mask for sample {sample_idx}:")
print(f"Sequence: {preprocessed_data[sample_idx]}")
print()
print(f"{'Pos':>3}  {'Current Token':<30} {'Attn':>5}  {'# Valid':>7}  {'Valid Next Tokens'}")
print("-" * 130)

for t in range(len(sample_ids)):
    current_tok = vocab.decode(sample_ids[t].item())
    attn = sample_attn[t].item()
    valid_ids = get_valid_tokens(sample_mask[t])
    valid_str = format_valid_tokens(valid_ids, vocab)
    attn_str = "True" if attn else "False"
    print(f"{t:3d}  {str(current_tok):<30} {attn_str:>5}  {len(valid_ids):7d}  {valid_str}")

Grammar mask for sample 5:
Sequence: {'price': 14250.0, 'category': 'sedan', 'tags': ['budget']}

Pos  Current Token                   Attn  # Valid  Valid Next Tokens
----------------------------------------------------------------------------------------------------------------------------------
  0  GrammarToken('PAD')            False        1  GrammarToken('START')
  1  GrammarToken('PAD')            False        1  GrammarToken('START')
  2  GrammarToken('PAD')            False        1  GrammarToken('START')
  3  GrammarToken('PAD')            False        1  GrammarToken('START')
  4  GrammarToken('PAD')            False        1  GrammarToken('START')
  5  GrammarToken('PAD')            False        1  GrammarToken('START')
  6  GrammarToken('PAD')            False        1  GrammarToken('START')
  7  GrammarToken('PAD')            False        1  GrammarToken('START')
  8  GrammarToken('START')           True        2  GrammarToken('OBJ_START'), GrammarToken('ARRAY_START')
  ...

### Summary of Step 4

**What happened:**
- `JSONGrammarPDA.compute_valid_mask()` computes valid next-tokens for each position
- Single loop over positions: O(n) complexity, vectorized over batch
- Pre-computed mask patterns enable O(1) operations per position (no O(vocab_size) loops)

**Grammar in action:**
- After `START`: only `OBJ_START` or `ARRAY_START` valid (2 options)
- After `OBJ_START`: any key or `OBJ_END` valid (~6 keys + 1 = 7 options)
- After a key: any value or nested container valid (~21 values + 2 containers = 23 options)
- After `END`: only `PAD` valid (1 option)

**How it's used in training:**
```python
# In model._apply_grammar_mask():
logits = logits.masked_fill(~grammar_mask, float("-inf"))
```
Invalid tokens get `-inf` logits → 0 probability after softmax.
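
The effect of the `-inf` fill can be checked without PyTorch. The sketch below applies the same idea to a single position using plain lists. `masked_softmax` is a hypothetical helper and the logit values are arbitrary.

```python
import math


def masked_softmax(logits, mask):
    """Softmax over one position after setting invalid logits to -inf."""
    masked = [x if ok else float("-inf") for x, ok in zip(logits, mask)]
    peak = max(masked)                           # subtract the max for stability
    exps = [math.exp(x - peak) for x in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5]
mask = [True, False, True]   # the middle token is grammatically invalid
probs = masked_softmax(logits, mask)
print(probs[1])    # 0.0
print(sum(probs))
```

At least one token must be valid at every position. If the whole row were masked, the maximum would be `-inf` and the result would be `nan`.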

**Training vs Inference:**
- Training: Full mask computed for all positions (this O(n) method)
- Inference: Grammar state updated incrementally, O(1) per generated token

**Next step:** Model forward pass (embeddings → backbone → heads → loss)