# Token Insertion Module Demo

This notebook demonstrates how to use the `token_insertion` module to insert token sequences into training data at specific or random positions.

In [1]:
import numpy as np
from pretrain_experiments.token_insertion import (
    IntervalSet,
    wrap_sequences_in_eos_tokens,
    add_explicit_insertions,
    add_random_insertions,
)

## Setup

In [2]:
# Typical values for OLMo-2
SEQUENCE_LENGTH = 4096
EOS_TOKEN_ID = 100257

# Example token sequences (in practice, these come from tokenizer.encode())
token_sequences = [
    [1000, 2000, 3000],
    [100, 200],
    [42, 43, 44, 45],
]

## The insert_dict Structure

The core output of both insertion functions is an `insert_dict`, a dictionary that maps **global token positions** to **token sequences**:

```python
insert_dict = {
    global_position: [token_id, token_id, ...],
    global_position: [token_id, token_id, ...],
    ...
}
```

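This dict tells the training framework: "at position X in the training data stream, insert these tokens". Both functions also return an updated `IntervalSet` that records the positions now occupied. Pass it to the next call so that later insertions cannot collide with earlier ones.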
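Suppose the training data stream is a flat run of token IDs cut into consecutive training sequences of `SEQUENCE_LENGTH` tokens. This is an assumption about the framework, not something the module confirms. Under that assumption, each global position maps to a (sequence index, offset) pair. The hypothetical helper `locate_insertions` below computes that pair for every entry and also totals the inserted tokens.

```python
def locate_insertions(insert_dict: dict[int, list[int]], sequence_length: int):
    """Map each global position to (sequence index, offset, length) and count tokens.

    Hypothetical helper: assumes the stream is split into consecutive
    training sequences of exactly `sequence_length` tokens.
    """
    locations = {}
    for pos, tokens in sorted(insert_dict.items()):
        seq_idx, offset = divmod(pos, sequence_length)
        locations[pos] = (seq_idx, offset, len(tokens))
    total_tokens = sum(len(tokens) for tokens in insert_dict.values())
    return locations, total_tokens


example = {1000: [1000, 2000, 3000], 5000: [100, 200], 10000: [42, 43, 44, 45]}
locations, total = locate_insertions(example, sequence_length=4096)
for pos, (seq_idx, offset, length) in locations.items():
    print(f"position {pos}: sequence {seq_idx}, offset {offset}, {length} tokens")
print("total inserted tokens:", total)
```

With the explicit positions used in section 1 (1000, 5000, 10000), this places one insertion in each of the first three training sequences, for 9 tokens in total.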

## 1. Explicit Insertions

With explicit insertions, you specify exactly where each sequence goes.

In [3]:
positions = [1000, 5000, 10000]

insert_dict, _ = add_explicit_insertions(
    token_sequences=token_sequences,
    positions=positions,
    existing_insertions=IntervalSet()
)

print("insert_dict =")
print("{")
for pos, tokens in sorted(insert_dict.items()):
    print(f"    {pos}: {tokens},")
print("}")

insert_dict =
{
    1000: [1000, 2000, 3000],
    5000: [100, 200],
    10000: [42, 43, 44, 45],
}

### Overlapping insertions raise ValueError

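Each insertion occupies `len(tokens)` consecutive positions starting at its key. If a new insertion would overlap a range already recorded in the `IntervalSet`, `add_explicit_insertions` raises a `ValueError`: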

In [4]:
interval_set = IntervalSet()

# First insertion at position 1000 (length 3, so occupies 1000-1002)
insert_dict, interval_set = add_explicit_insertions(
    token_sequences=[[1, 2, 3]],
    positions=[1000],
    existing_insertions=interval_set
)

# Try overlapping insertion at position 1002 - this will raise ValueError
try:
    add_explicit_insertions(
        token_sequences=[[4, 5, 6]],
        positions=[1002],
        existing_insertions=interval_set
    )
except ValueError as e:
    print(f"ValueError: {e}")

ValueError: Insertion at position 1002 (length 3) overlaps with existing insertion
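The overlap check can be illustrated with a small stand-in for `IntervalSet`. `SimpleIntervalSet` and `explicit_insertions_sketch` below are hypothetical names, not the module's implementation. The sketch assumes half-open `[start, end)` ranges and imitates the error message shown above.

```python
class SimpleIntervalSet:
    """Hypothetical stand-in for IntervalSet: stores half-open [start, end) ranges."""

    def __init__(self):
        self.intervals: list[tuple[int, int]] = []

    def overlaps(self, start: int, end: int) -> bool:
        # Two half-open ranges overlap iff each one starts before the other ends.
        return any(start < e and s < end for s, e in self.intervals)

    def add(self, start: int, end: int) -> None:
        self.intervals.append((start, end))


def explicit_insertions_sketch(token_sequences, positions, occupied):
    """Hypothetical sketch: place each sequence at its position, rejecting overlaps."""
    if len(token_sequences) != len(positions):
        raise ValueError("need exactly one position per token sequence")
    insert_dict = {}
    for tokens, pos in zip(token_sequences, positions):
        end = pos + len(tokens)  # occupies pos .. end - 1
        if occupied.overlaps(pos, end):
            raise ValueError(
                f"Insertion at position {pos} (length {len(tokens)}) "
                "overlaps with existing insertion"
            )
        occupied.add(pos, end)
        insert_dict[pos] = tokens
    return insert_dict, occupied


occupied = SimpleIntervalSet()
first, occupied = explicit_insertions_sketch([[1, 2, 3]], [1000], occupied)
try:
    explicit_insertions_sketch([[4, 5, 6]], [1002], occupied)
except ValueError as e:
    print(f"ValueError: {e}")
```

With half-open ranges, an insertion at 1003 directly after the one at 1000 does not overlap. Whether the real `IntervalSet` treats endpoints the same way is not confirmed here.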

## 2. Random Insertions

With random insertions, positions are sampled between `start_idx` and `end_idx` using the supplied NumPy `rng`. Passing a seeded generator makes the placement reproducible.

In [5]:
insert_dict, _ = add_random_insertions(
    token_sequences=token_sequences,
    start_idx=0,
    end_idx=50 * SEQUENCE_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
    existing_insertions=IntervalSet(),
    rng=np.random.default_rng(seed=42)
)

print("insert_dict =")
print("{")
for pos, tokens in sorted(insert_dict.items()):
    print(f"    {pos}: {tokens},")
print("}")

Total number of inserted tokens: 9
Avoided collisions while inserting sequences: 0
insert_dict =
{
    88696: [100, 200],
    156013: [1000, 2000, 3000],
    173804: [42, 43, 44, 45],
}
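One simple way to place sequences randomly without collisions is rejection sampling. You draw a position, and if it clashes with an occupied range, you draw again. The sketch below is a hypothetical illustration of that idea, not the module's algorithm. It additionally *assumes* that an insertion must not straddle the boundary between two training sequences of `sequence_length` tokens. That is one plausible role for the `sequence_length` argument, but it is unconfirmed. `random_insertions_sketch` and its `rejections` counter are made-up names.

```python
import numpy as np


def random_insertions_sketch(token_sequences, start_idx, end_idx, sequence_length,
                             occupied, rng, max_tries=1000):
    """Hypothetical rejection-sampling sketch of random insertion.

    `occupied` is a list of half-open (start, end) ranges that are already taken.
    Assumption: an insertion may not cross a training-sequence boundary.
    """
    insert_dict = {}
    rejections = 0
    for tokens in token_sequences:
        n = len(tokens)
        for _ in range(max_tries):
            # Generator.integers excludes the upper bound, so pos + n <= end_idx.
            pos = int(rng.integers(start_idx, end_idx - n + 1))
            crosses_boundary = pos // sequence_length != (pos + n - 1) // sequence_length
            clashes = any(pos < end and start < pos + n for start, end in occupied)
            if crosses_boundary or clashes:
                rejections += 1
                continue
            occupied.append((pos, pos + n))
            insert_dict[pos] = tokens
            break
        else:  # no break: every draw was rejected
            raise RuntimeError(f"no free slot found for a sequence of length {n}")
    return insert_dict, occupied, rejections


rng = np.random.default_rng(seed=0)
seqs = [[1000, 2000, 3000], [100, 200], [42, 43, 44, 45]]
insert_dict, occupied, rejections = random_insertions_sketch(
    seqs, start_idx=0, end_idx=50 * 4096, sequence_length=4096, occupied=[], rng=rng
)
for pos, tokens in sorted(insert_dict.items()):
    print(f"{pos}: {tokens}")
print("rejected draws:", rejections)
```

The positions this prints will not match the output above, because the real function's sampling scheme is not reproduced here.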

## 3. With EOS Wrapping

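Typically, you wrap each sequence in EOS tokens before insertion so that it is cleanly separated from the surrounding content. `wrap_sequences_in_eos_tokens` adds one EOS token at each end. As its log output shows, it also drops empty sequences and sequences longer than `SEQUENCE_LENGTH`.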

In [6]:
# First wrap with EOS tokens
wrapped = wrap_sequences_in_eos_tokens(token_sequences, SEQUENCE_LENGTH, EOS_TOKEN_ID)

print("Before wrapping:")
for seq in token_sequences:
    print(f"  {seq}")

print("\nAfter wrapping (EOS = 100257):")
for seq in wrapped:
    print(f"  {seq}")

Minimum sequence length: 2
Maximum sequence length: 4
Second maximum sequence length: 3
Dropped 0 overly long sequences (longer than 4096 tokens).
Dropped 0 empty sequences.
Minimum sequence length after wrapping: 4
Maximum sequence length after wrapping: 6
Second maximum sequence length after wrapping: 5
Before wrapping:
  [1000, 2000, 3000]
  [100, 200]
  [42, 43, 44, 45]

After wrapping (EOS = 100257):
  [100257, 1000, 2000, 3000, 100257]
  [100257, 100, 200, 100257]
  [100257, 42, 43, 44, 45, 100257]

In [7]:
# Then insert
insert_dict, _ = add_random_insertions(
    token_sequences=wrapped,
    start_idx=0,
    end_idx=50 * SEQUENCE_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
    existing_insertions=IntervalSet(),
    rng=np.random.default_rng(seed=42)
)

print("Final insert_dict (with EOS tokens):")
print("{")
for pos, tokens in sorted(insert_dict.items()):
    print(f"    {pos}: {tokens},")
print("}")

Total number of inserted tokens: 15
Avoided collisions while inserting sequences: 0
Final insert_dict (with EOS tokens):
{
    88695: [100257, 100, 200, 100257],
    156013: [100257, 1000, 2000, 3000, 100257],
    173803: [100257, 42, 43, 44, 45, 100257],
}
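The wrapping step can be sketched in a few lines. `wrap_in_eos_sketch` is a hypothetical helper, not the module's function. The log above says sequences "longer than 4096 tokens" are dropped but does not say whether that check runs before or after wrapping. This sketch *assumes* it applies to the wrapped length, so that every kept sequence fits in one training sequence.

```python
def wrap_in_eos_sketch(token_sequences, sequence_length, eos_token_id):
    """Hypothetical sketch: surround each sequence with one EOS token on each side.

    Assumptions: empty sequences are dropped, and so is any sequence whose
    wrapped form is longer than `sequence_length`.
    """
    wrapped = []
    dropped_long = 0
    dropped_empty = 0
    for seq in token_sequences:
        if not seq:
            dropped_empty += 1
            continue
        candidate = [eos_token_id, *seq, eos_token_id]
        if len(candidate) > sequence_length:
            dropped_long += 1
            continue
        wrapped.append(candidate)
    return wrapped, dropped_long, dropped_empty


EOS_TOKEN_ID = 100257
wrapped, n_long, n_empty = wrap_in_eos_sketch(
    [[1000, 2000, 3000], [100, 200], [], [42, 43, 44, 45]],
    sequence_length=4096,
    eos_token_id=EOS_TOKEN_ID,
)
print(wrapped)
print(f"dropped {n_long} long and {n_empty} empty sequences")
print("tokens to insert:", sum(len(seq) for seq in wrapped))

# With a tiny sequence_length, a 3-token sequence (5 tokens once wrapped) no longer fits.
_, n_long_small, _ = wrap_in_eos_sketch([[1, 2, 3]], sequence_length=4, eos_token_id=0)
```

Wrapping adds two tokens per sequence, which is why the total rises from 9 to 15 for the three example sequences.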

## 4. Combining Explicit and Random

Process explicit insertions first, then random ones. Because both calls share the same `interval_set`, the random insertions avoid the positions already taken by the explicit ones.

In [8]:
interval_set = IntervalSet()

# Phase 1: Explicit
explicit_dict, interval_set = add_explicit_insertions(
    token_sequences=[[1, 2, 3]],
    positions=[8192],  # Exactly at position 8192
    existing_insertions=interval_set
)

# Phase 2: Random (uses the same interval_set, so won't overlap)
random_dict, interval_set = add_random_insertions(
    token_sequences=[[10, 20], [30, 40]],
    start_idx=0,
    end_idx=10 * SEQUENCE_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
    existing_insertions=interval_set,
    rng=np.random.default_rng(seed=42)
)

# Combine
insert_dict = {**explicit_dict, **random_dict}

print("Combined insert_dict:")
print("{")
for pos, tokens in sorted(insert_dict.items()):
    source = "explicit" if pos in explicit_dict else "random"
    print(f"    {pos}: {tokens},  # {source}")
print("}")

Total number of inserted tokens: 4
Avoided collisions while inserting sequences: 0
Combined insert_dict:
{
    8192: [1, 2, 3],  # explicit
    19064: [30, 40],  # random
    29037: [10, 20],  # random
}
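The `{**explicit_dict, **random_dict}` merge above silently keeps the last value when a key appears twice. It also does not check for overlapping ranges. That is safe here because both dicts were built against one shared `IntervalSet`. When dicts come from calls that did not share one, a defensive merge can catch mistakes. `merge_insert_dicts` below is a hypothetical helper, not part of the module.

```python
def merge_insert_dicts(*dicts):
    """Hypothetical helper: merge insert_dicts, refusing duplicate keys or overlaps."""
    merged = {}
    for d in dicts:
        for pos, tokens in d.items():
            if pos in merged:
                raise ValueError(f"position {pos} appears in more than one dict")
            merged[pos] = tokens
    # Sort ranges by start; each half-open range must end before the next begins.
    spans = sorted((pos, pos + len(tokens)) for pos, tokens in merged.items())
    for (s1, e1), (s2, _) in zip(spans, spans[1:]):
        if e1 > s2:
            raise ValueError(f"range starting at {s1} overlaps range starting at {s2}")
    return merged


explicit_dict = {8192: [1, 2, 3]}
random_dict = {19064: [30, 40], 29037: [10, 20]}
combined = merge_insert_dicts(explicit_dict, random_dict)
print(sorted(combined))

try:
    merge_insert_dicts({100: [1, 2, 3]}, {102: [4]})
except ValueError as e:
    print("ValueError:", e)
```

Checking only adjacent ranges after sorting is enough, because any overlap between non-adjacent ranges implies an overlap between some adjacent pair.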

## Summary

The `insert_dict` is a simple mapping:

```
{position: tokens, position: tokens, ...}
```

This dict is passed to the training framework (e.g., OLMo), which performs the actual insertion into the training data stream.
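How the framework applies the dict is not shown in this notebook. As a rough illustration only, the sketch below *assumes* overwrite semantics. Under that assumption, the inserted tokens replace the tokens at `[pos, pos + len(tokens))`, so the stream length stays unchanged. A framework could instead splice the tokens in and shift everything after them, so check the framework's code before relying on either behavior. `apply_insertions_overwrite` is a hypothetical name.

```python
import numpy as np


def apply_insertions_overwrite(stream: np.ndarray, insert_dict: dict[int, list[int]]) -> np.ndarray:
    """Return a copy of `stream` with each insertion written over the original tokens.

    Hypothetical helper; assumes overwrite (not splice) semantics.
    """
    out = stream.copy()
    for pos, tokens in insert_dict.items():
        end = pos + len(tokens)
        if end > len(out):
            raise IndexError(f"insertion at {pos} runs past the end of the stream")
        out[pos:end] = tokens
    return out


stream = np.arange(20)  # stand-in for a tokenized data stream
patched = apply_insertions_overwrite(stream, {3: [-1, -1], 10: [-2, -2, -2]})
print(patched.tolist())
```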