# Token Insertion Module Demo

This notebook demonstrates how to use the `token_insertion` module to insert token sequences into training data at specific or random positions.

In [1]:
import numpy as np
from pretrain_experiments.token_insertion import (
    IntervalSet,
    wrap_sequences_in_eos_tokens,
    add_explicit_insertions,
    add_random_insertions,
)

## Setup

In [2]:
# Typical values for OLMo-2
SEQUENCE_LENGTH = 4096
EOS_TOKEN_ID = 100257

# Example token sequences (in practice, these come from tokenizer.encode())
token_sequences = [
    [1000, 2000, 3000],
    [100, 200],
    [42, 43, 44, 45],
]

## The insert_dict Structure

The core output of the insertion functions is an `insert_dict`, a dictionary that maps **global token positions** to **token sequences**:

```python
insert_dict = {
    global_position: [token_id, token_id, ...],
    global_position: [token_id, token_id, ...],
    ...
}
```

This dict tells the training framework: "at position X in the training data stream, insert these tokens".
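To make the semantics concrete, here is a minimal sketch of how a framework might apply an `insert_dict` to a flat token stream. It assumes, as an illustration only, that each inserted sequence overwrites the tokens it lands on, so the stream length stays fixed. `apply_insert_dict` is a hypothetical helper, not part of `token_insertion`.

```python
from typing import Dict, List


def apply_insert_dict(stream: List[int], insert_dict: Dict[int, List[int]]) -> List[int]:
    """Hypothetical helper: overwrite tokens in a flat stream at global positions.

    Assumes insertions replace existing tokens (length unchanged) and do not overlap.
    """
    out = list(stream)  # copy so the caller's stream is untouched
    for pos, tokens in sorted(insert_dict.items()):
        end = pos + len(tokens)
        if pos < 0 or end > len(out):
            raise IndexError(f"insertion [{pos}, {end}) is outside the stream")
        out[pos:end] = tokens  # same-length slice assignment keeps len(out) fixed
    return out


stream = list(range(10))  # stand-in for real training tokens
print(apply_insert_dict(stream, {2: [100, 101], 7: [200]}))
# [0, 1, 100, 101, 4, 5, 6, 200, 8, 9]
```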

## 1. Explicit Insertions

With explicit insertions, you specify exactly where each sequence goes.

In [3]:
positions = [1000, 5000, 10000]

insert_dict, _ = add_explicit_insertions(
    token_sequences=token_sequences,
    positions=positions,
    existing_insertions=IntervalSet()
)

print("insert_dict =")
print("{")
for pos, tokens in sorted(insert_dict.items()):
    print(f"    {pos}: {tokens},")
print("}")

insert_dict =
{
    1000: [1000, 2000, 3000],
    5000: [100, 200],
    10000: [42, 43, 44, 45],
}


### Overlapping insertions raise ValueError

If you try to insert at overlapping positions, a `ValueError` is raised:

In [4]:
interval_set = IntervalSet()

# First insertion at position 1000 (length 3, so occupies 1000-1002)
insert_dict, interval_set = add_explicit_insertions(
    token_sequences=[[1, 2, 3]],
    positions=[1000],
    existing_insertions=interval_set
)

# Try overlapping insertion at position 1002 - this will raise ValueError
try:
    add_explicit_insertions(
        token_sequences=[[4, 5, 6]],
        positions=[1002],
        existing_insertions=interval_set
    )
except ValueError as e:
    print(f"ValueError: {e}")

ValueError: Insertion at position 1002 (length 3) overlaps with existing insertion
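The error follows from interval arithmetic: `[1, 2, 3]` at 1000 occupies the half-open range `[1000, 1003)`, and `[4, 5, 6]` at 1002 would occupy `[1002, 1005)`, so position 1002 is claimed twice. The sketch below is a hypothetical `SimpleIntervalSet`, not the module's `IntervalSet`. It keeps disjoint intervals sorted by start and uses `bisect`, so each overlap check only needs to inspect one neighboring interval.

```python
import bisect
from typing import List


class SimpleIntervalSet:
    """Hypothetical stand-in for IntervalSet: disjoint half-open intervals [start, end)."""

    def __init__(self) -> None:
        self._starts: List[int] = []  # kept sorted
        self._ends: List[int] = []    # _ends[i] pairs with _starts[i]

    def overlaps(self, start: int, end: int) -> bool:
        # First stored interval that starts at or after `end` cannot overlap.
        i = bisect.bisect_left(self._starts, end)
        # Stored intervals are disjoint and sorted, so their ends are sorted too;
        # only the interval just before i can reach past `start`.
        return i > 0 and self._ends[i - 1] > start

    def add(self, start: int, end: int) -> None:
        if self.overlaps(start, end):
            raise ValueError(f"[{start}, {end}) overlaps an existing interval")
        i = bisect.bisect_left(self._starts, start)
        self._starts.insert(i, start)
        self._ends.insert(i, end)


s = SimpleIntervalSet()
s.add(1000, 1003)              # tokens [1, 2, 3] at 1000 occupy 1000..1002
print(s.overlaps(1002, 1005))  # True: position 1002 is shared
print(s.overlaps(1003, 1006))  # False: half-open intervals may touch
```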


## 2. Random Insertions

With random insertions, positions are chosen automatically within the token range given by `start_idx` and `end_idx`.
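One plausible way to implement random placement is rejection sampling: draw a position, then redraw if the sequence would cross a training-sequence boundary or collide with an earlier insertion. The sketch below assumes that keeping insertions inside one sequence is what the `sequence_length` argument is for, which the notebook does not state. `random_positions_sketch` is hypothetical; it tracks occupied positions in a plain set instead of an `IntervalSet` and uses the standard-library `random.Random` rather than the NumPy generator passed in the cell below.

```python
import random
from typing import Dict, List, Set


def random_positions_sketch(
    sequences: List[List[int]],
    start_idx: int,
    end_idx: int,
    sequence_length: int,
    rng: random.Random,
    max_tries: int = 1000,
) -> Dict[int, List[int]]:
    """Hypothetical sketch: place each sequence at a random position in [start_idx, end_idx),
    never crossing a sequence boundary and never overlapping an earlier placement."""
    occupied: Set[int] = set()  # per-token bookkeeping; fine for a small demo
    result: Dict[int, List[int]] = {}
    for seq in sequences:
        n = len(seq)
        for _ in range(max_tries):
            pos = rng.randrange(start_idx, end_idx - n + 1)  # guarantees pos + n <= end_idx
            if pos % sequence_length + n > sequence_length:
                continue  # would spill into the next training sequence
            span = range(pos, pos + n)
            if any(p in occupied for p in span):
                continue  # collision with an earlier insertion; resample
            occupied.update(span)
            result[pos] = seq
            break
        else:
            raise RuntimeError("could not place sequence without collisions")
    return result


placed = random_positions_sketch([[1000, 2000, 3000], [100, 200]], 0, 50 * 4096, 4096, random.Random(42))
for pos, toks in sorted(placed.items()):
    print(pos, toks)
```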

In [5]:
insert_dict, _ = add_random_insertions(
    token_sequences=token_sequences,
    start_idx=0,
    end_idx=50 * SEQUENCE_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
    existing_insertions=IntervalSet(),
    rng=np.random.default_rng(seed=42)
)

print("insert_dict =")
print("{")
for pos, tokens in sorted(insert_dict.items()):
    print(f"    {pos}: {tokens},")
print("}")

Total number of inserted tokens: 9
Avoided collisions while inserting sequences: 0
insert_dict =
{
    88696: [100, 200],
    156013: [1000, 2000, 3000],
    173804: [42, 43, 44, 45],
}





## 3. With EOS Wrapping

Typically you wrap sequences with EOS tokens before insertion to cleanly separate them from surrounding content.
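A minimal sketch of the wrapping step, consistent with the log printed below: empty and overly long sequences are dropped, and the rest get one EOS token on each side. The hypothetical `wrap_in_eos_sketch` assumes the length limit applies to the *wrapped* sequence; the real function's log mentions the limit without saying whether the EOS tokens count toward it.

```python
from typing import List


def wrap_in_eos_sketch(sequences: List[List[int]], sequence_length: int, eos_id: int) -> List[List[int]]:
    """Hypothetical sketch: drop empty or too-long sequences, then surround each with EOS."""
    wrapped = []
    for seq in sequences:
        if not seq:
            continue  # nothing to insert
        if len(seq) + 2 > sequence_length:
            continue  # assumption: the wrapped sequence must fit in one training sequence
        wrapped.append([eos_id] + list(seq) + [eos_id])
    return wrapped


print(wrap_in_eos_sketch([[1000, 2000, 3000], [], [100, 200]], 4096, 100257))
# [[100257, 1000, 2000, 3000, 100257], [100257, 100, 200, 100257]]
```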

In [6]:
# First wrap with EOS tokens
wrapped = wrap_sequences_in_eos_tokens(token_sequences, SEQUENCE_LENGTH, EOS_TOKEN_ID)

print("Before wrapping:")
for seq in token_sequences:
    print(f"  {seq}")

print("\nAfter wrapping (EOS = 100257):")
for seq in wrapped:
    print(f"  {seq}")

Minimum sequence length: 2
Maximum sequence length: 4
Second maximum sequence length: 3
Dropped 0 overly long sequences (longer than 4096 tokens).
Dropped 0 empty sequences.
Minimum sequence length after wrapping: 4
Maximum sequence length after wrapping: 6
Second maximum sequence length after wrapping: 5
Before wrapping:
  [1000, 2000, 3000]
  [100, 200]
  [42, 43, 44, 45]

After wrapping (EOS = 100257):
  [100257, 1000, 2000, 3000, 100257]
  [100257, 100, 200, 100257]
  [100257, 42, 43, 44, 45, 100257]


In [7]:
# Then insert
insert_dict, _ = add_random_insertions(
    token_sequences=wrapped,
    start_idx=0,
    end_idx=50 * SEQUENCE_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
    existing_insertions=IntervalSet(),
    rng=np.random.default_rng(seed=42)
)

print("Final insert_dict (with EOS tokens):")
print("{")
for pos, tokens in sorted(insert_dict.items()):
    print(f"    {pos}: {tokens},")
print("}")

Total number of inserted tokens: 15
Avoided collisions while inserting sequences: 0
Final insert_dict (with EOS tokens):
{
    88695: [100257, 100, 200, 100257],
    156013: [100257, 1000, 2000, 3000, 100257],
    173803: [100257, 42, 43, 44, 45, 100257],
}





## 4. Combining Explicit and Random

Process explicit insertions first, then random ones. Because the random phase receives the `IntervalSet` returned by the explicit phase, the random insertions automatically avoid the explicit positions.
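Note that `{**explicit_dict, **random_dict}` silently keeps the second value if both dicts share a key. That cannot happen in the cell below because both phases share one `IntervalSet`, but if insert_dicts come from separate sources, a guarded merge fails loudly instead. The hypothetical `merge_insert_dicts` below only catches identical start positions; overlapping spans still need an interval check.

```python
from typing import Dict, List


def merge_insert_dicts(*dicts: Dict[int, List[int]]) -> Dict[int, List[int]]:
    """Hypothetical helper: merge insert_dicts, refusing duplicate start positions.

    A plain {**a, **b} would silently keep b's value for a repeated key.
    """
    merged: Dict[int, List[int]] = {}
    for d in dicts:
        for pos, tokens in d.items():
            if pos in merged:
                raise ValueError(f"position {pos} appears in more than one insert_dict")
            merged[pos] = tokens
    return merged


print(merge_insert_dicts({8192: [1, 2, 3]}, {19064: [30, 40], 29037: [10, 20]}))
# {8192: [1, 2, 3], 19064: [30, 40], 29037: [10, 20]}
```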

In [8]:
interval_set = IntervalSet()

# Phase 1: Explicit
explicit_dict, interval_set = add_explicit_insertions(
    token_sequences=[[1, 2, 3]],
    positions=[8192],  # Exactly at position 8192
    existing_insertions=interval_set
)

# Phase 2: Random (uses the same interval_set, so won't overlap)
random_dict, interval_set = add_random_insertions(
    token_sequences=[[10, 20], [30, 40]],
    start_idx=0,
    end_idx=10 * SEQUENCE_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
    existing_insertions=interval_set,
    rng=np.random.default_rng(seed=42)
)

# Combine
insert_dict = {**explicit_dict, **random_dict}

print("Combined insert_dict:")
print("{")
for pos, tokens in sorted(insert_dict.items()):
    source = "explicit" if pos in explicit_dict else "random"
    print(f"    {pos}: {tokens},  # {source}")
print("}")

Total number of inserted tokens: 4
Avoided collisions while inserting sequences: 0
Combined insert_dict:
{
    8192: [1, 2, 3],  # explicit
    19064: [30, 40],  # random
    29037: [10, 20],  # random
}





## Summary of Step 1

The `insert_dict` is a simple mapping from global token positions to token sequences:

```
{position: tokens, position: tokens, ...}
```

This is the **framework-agnostic** representation. Next, we convert it to an indexed format for efficient storage and framework-specific use.

# Step 2: Converting to Index-Based Format

While the `insert_dict` structure is convenient for specifying insertions, it is hard to integrate efficiently into dataloaders. Dataloaders process data in chunks (sequences or batches), and looking up insertions by global position would require scanning all positions for every chunk.

Instead, we convert to an **index-based format** that groups insertions by chunk:

```python
index_map = {
    index: [(local_position, [tokens]), (local_position, [tokens]), ...],
    index: [(local_position, [tokens]), ...],
    ...
}
```

Where:
- `index` is the chunk number (e.g., sequence index, batch index)
- `local_position` is the position within that chunk
- `[tokens]` is the token sequence to insert

This allows O(1) lookup: when processing chunk N, just check if `index_map[N]` exists. The `InsertionMapWriter` stores this format in HDF5 files for efficient random access during training.
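On the dataloader side, the lookup can be a single `dict.get` per chunk. The hypothetical `patch_chunk` below sketches that hook, again assuming insertions overwrite tokens in place; chunks without an entry, the common case, are returned untouched.

```python
from typing import Dict, List, Tuple

IndexMap = Dict[int, List[Tuple[int, List[int]]]]


def patch_chunk(chunk_idx: int, chunk: List[int], index_map: IndexMap) -> List[int]:
    """Hypothetical dataloader hook: apply one chunk's insertions via a dict lookup.

    Assumes insertions overwrite tokens in place, so the chunk length is unchanged.
    """
    insertions = index_map.get(chunk_idx)  # average O(1); most chunks have no entry
    if not insertions:
        return chunk
    patched = list(chunk)
    for local_pos, tokens in insertions:
        end = local_pos + len(tokens)
        if end > len(patched):
            raise IndexError(f"insertion [{local_pos}, {end}) exceeds chunk {chunk_idx}")
        patched[local_pos:end] = tokens
    return patched


demo_index_map = {1: [(2, [7, 7])]}  # chunk 1 gets [7, 7] at local positions 2..3
chunks = [[0] * 4, [0] * 4, [0] * 4]
print([patch_chunk(i, c, demo_index_map) for i, c in enumerate(chunks)])
# [[0, 0, 0, 0], [0, 0, 7, 7], [0, 0, 0, 0]]
```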

In [None]:
from pretrain_experiments.token_insertion import convert_insert_dict_to_index_map
from pretrain_experiments.insertion_map import InsertionMapWriter, InsertionMapReader
import tempfile
import os

## Converting with `convert_insert_dict_to_index_map`

The `num_index_tokens` parameter controls how positions are grouped:
- For **per-sequence** indexing: `num_index_tokens = sequence_length` (e.g., 4096)
- For **per-batch** indexing: `num_index_tokens = batch_size * sequence_length`
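For insertions that stay inside one chunk, the conversion reduces to `divmod(global_position, num_index_tokens)`, which yields the chunk index and the local position. The hypothetical `to_index_map_sketch` below ignores boundary crossings and may differ from `convert_insert_dict_to_index_map` in output details (for example, tuples versus lists). The second call uses `2 * 4096` to mimic per-batch indexing with an assumed batch size of 2.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def to_index_map_sketch(
    insert_dict: Dict[int, List[int]], num_index_tokens: int
) -> Dict[int, List[Tuple[int, List[int]]]]:
    """Hypothetical sketch of the conversion, ignoring boundary crossings."""
    index_map: Dict[int, List[Tuple[int, List[int]]]] = defaultdict(list)
    for global_pos, tokens in sorted(insert_dict.items()):
        index, local_pos = divmod(global_pos, num_index_tokens)
        index_map[index].append((local_pos, tokens))
    return dict(index_map)


demo = {1000: [1, 2, 3], 5000: [4, 5, 6, 7], 10000: [8, 9]}
print(to_index_map_sketch(demo, num_index_tokens=4096))
# {0: [(1000, [1, 2, 3])], 1: [(904, [4, 5, 6, 7])], 2: [(1808, [8, 9])]}
print(to_index_map_sketch(demo, num_index_tokens=2 * 4096))
# {0: [(1000, [1, 2, 3]), (5000, [4, 5, 6, 7])], 1: [(1808, [8, 9])]}
```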

In [None]:
# Create an insert_dict with positions across multiple sequences
insert_dict = {
    1000: [1, 2, 3],         # Sequence 0, position 1000
    5000: [4, 5, 6, 7],      # Sequence 1, position 904 (5000 - 4096)
    10000: [8, 9],           # Sequence 2, position 1808 (10000 - 2*4096)
}

# Convert to index-based format (per-sequence)
index_map = convert_insert_dict_to_index_map(
    insert_dict,
    num_index_tokens=SEQUENCE_LENGTH  # 4096 tokens per sequence
)

print("Original insert_dict (global positions):")
print("{")
for pos, tokens in sorted(insert_dict.items()):
    print(f"    {pos}: {tokens},")
print("}")

print("\nConverted index_map (sequence-indexed):")
print("{")
for idx, insertions in sorted(index_map.items()):
    print(f"    {idx}: {insertions},")
print("}")

## Handling Boundary Crossings

When an insertion crosses an index boundary, it can be automatically split:
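Splitting can be sketched as repeatedly taking as many tokens as still fit before the next boundary. The hypothetical `split_insertion` below assumes the pieces keep consecutive global positions, so the tokens keep their order and only their chunk assignment changes; it also handles insertions that span more than two chunks.

```python
from typing import Dict, List, Tuple


def split_insertion(
    global_pos: int, tokens: List[int], num_index_tokens: int
) -> Dict[int, List[Tuple[int, List[int]]]]:
    """Hypothetical sketch: cut one insertion into per-index pieces at each boundary."""
    pieces: Dict[int, List[Tuple[int, List[int]]]] = {}
    pos, remaining = global_pos, list(tokens)
    while remaining:
        index, local_pos = divmod(pos, num_index_tokens)
        room = num_index_tokens - local_pos  # tokens that still fit in this chunk
        head, remaining = remaining[:room], remaining[room:]
        pieces.setdefault(index, []).append((local_pos, head))
        pos += len(head)
    return pieces


print(split_insertion(4094, [1, 2, 3, 4, 5], num_index_tokens=4096))
# {0: [(4094, [1, 2])], 1: [(0, [3, 4, 5])]}
```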

In [None]:
# Insertion at position 4094 with 5 tokens crosses into the next sequence
insert_dict_crossing = {
    4094: [1, 2, 3, 4, 5],  # Starts in seq 0, ends in seq 1
}

index_map_split = convert_insert_dict_to_index_map(
    insert_dict_crossing,
    num_index_tokens=SEQUENCE_LENGTH,
    split_across_boundaries=True  # Default behavior
)

print("Original: position 4094, tokens [1, 2, 3, 4, 5]")
print(f"Sequence boundary at position {SEQUENCE_LENGTH}\n")
print("After splitting:")
for idx, insertions in sorted(index_map_split.items()):
    for pos, tokens in insertions:
        print(f"  Sequence {idx}, position {pos}: {tokens}")

## Storing in HDF5 with InsertionMapWriter

For large-scale training, we store the index_map in HDF5 format. This provides:
- Efficient random access (O(1) lookup by index)
- Compression for token sequences
- Incremental writes without loading the entire file

In [None]:
# Create a sample index_map
insert_dict = {
    1000: [1, 2, 3],
    5000: [4, 5, 6, 7],
    10000: [8, 9],
    12500: [10, 11, 12],
}

index_map = convert_insert_dict_to_index_map(insert_dict, num_index_tokens=SEQUENCE_LENGTH)

# Write to HDF5
with tempfile.TemporaryDirectory() as tmpdir:
    hdf5_path = os.path.join(tmpdir, "insertions.h5")
    
    # Write the index_map
    writer = InsertionMapWriter(hdf5_path)
    writer.write_dict(index_map)
    
    # Read it back
    with InsertionMapReader(hdf5_path) as reader:
        print(f"Stored {len(reader)} indices in HDF5\n")
        
        # Check which indices have insertions
        print("Indices with insertions:", reader.get_all_indices())
        
        # Load insertions for a specific index
        print("\nInsertions for sequence 1:")
        insertions = reader.load(1)
        if insertions:
            for pos, tokens in insertions:
                print(f"  Position {pos}: {tokens}")
        
        # Fast membership check (no file I/O after initial load)
        print(f"\nSequence 0 has insertions: {0 in reader}")
        print(f"Sequence 5 has insertions: {5 in reader}")

## Summary

The complete pipeline for token insertions:

1. **Create insert_dict** using `add_explicit_insertions()` or `add_random_insertions()`
   - Maps global token positions to token sequences
   - Use `IntervalSet` to prevent overlapping insertions

2. **Convert to index_map** using `convert_insert_dict_to_index_map()`
   - Groups insertions by sequence/batch index for efficient lookup
   - Handles boundary crossings with automatic splitting

3. **Store in HDF5** using `InsertionMapWriter`
   - Efficient random access during training
   - Use `InsertionMapReader` in your dataloader to look up insertions by index