# Lazy Precover DFA: Bitset & Hash-Consing Optimizations

This notebook walks through the Python port of the three key optimizations from the Rust `transduction-core` crate:

1. **Integer packing** — Pack `(fst_state, buf_pos)` into a single `int`
2. **Hash-consing (PowersetArena)** — Deduplicate DFA states, with a singleton fast path
3. **Epsilon closure caching with productivity filtering** — Compute each closure once, keep only productive states

These three optimizations combine in `LazyPrecoverDFA`, a drop-in replacement for `PrecoverNFA(fst, target).det()` that can be dramatically faster on BPE-style FSTs.

In [1]:
from transduction.lazy_precover_dfa import PowersetArena, PackedPrecoverNFA, LazyPrecoverDFA
from transduction.precover_nfa import PrecoverNFA
from transduction.fst import FST, EPSILON
from transduction import examples

## 1. Integer Packing

The original `PrecoverNFA` uses tuple states `(fst_state, buffer_string)`:
```python
state = (0, ('x', 'a'))   # FST state 0, buffer has matched 'x','a'
```

The key insight: since the buffer is always a prefix of the target, we only need its **length**:
```python
packed = fst_state * stride + buf_pos   # single int!
```

This makes hashing O(1) instead of O(len(buffer)).

In [2]:
fst = examples.samuel_example()   # FST with epsilon arcs
target = ('c', 'x')

nfa = PackedPrecoverNFA(fst, target)
print(f"target_len = {nfa.target_len}")
print(f"stride     = {nfa.stride}  (= target_len + 1)")
print(f"num FST states = {len(nfa._state_map)}")
print()

# Pack and unpack a state
packed = nfa.pack(2, 1)   # FST state 2, buf_pos 1
print(f"pack(2, 1)   = {packed}")
print(f"unpack({packed})  = {nfa.unpack(packed)}")
print()

# Compare with the tuple-based representation:
print("Tuple state:  (2, ('c',))   — hashes a 2-tuple of (int, tuple)")
print(f"Packed state: {packed}              — hashes a single int")

target_len = 2
stride     = 3  (= target_len + 1)
num FST states = 5

pack(2, 1)   = 7
unpack(7)  = (2, 1)

Tuple state:  (2, ('c',))   — hashes a 2-tuple of (int, tuple)
Packed state: 7              — hashes a single int
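
The packing arithmetic itself can be reproduced without the library. The following is a minimal standalone sketch, assuming only the formula shown above; the free functions `pack` and `unpack` here are illustrative stand-ins and are not the `PackedPrecoverNFA` methods.

```python
def pack(fst_state: int, buf_pos: int, stride: int) -> int:
    """Combine an FST state index and a buffer position into one int."""
    assert 0 <= buf_pos < stride, "buf_pos must be in [0, target_len]"
    return fst_state * stride + buf_pos


def unpack(packed: int, stride: int) -> tuple[int, int]:
    """Invert pack(): recover (fst_state, buf_pos)."""
    return divmod(packed, stride)


target = ("c", "x")
stride = len(target) + 1          # buf_pos ranges over 0..target_len

packed = pack(2, 1, stride)
print(packed)                     # 7
print(unpack(packed, stride))     # (2, 1)

# The buffer contents are recoverable from the position alone,
# because the buffer is always a prefix of the target.
print(target[:1])                 # ('c',)
```

Because `buf_pos < stride` always holds, `divmod` inverts the packing exactly, so no information is lost relative to the tuple form.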


## 2. Hash-Consing with PowersetArena

In powerset construction, DFA states are **sets of NFA states**. The original code uses `frozenset`, which:
- Creates a new object for every distinct set
- Costs O(n) to hash (n = set size) whenever a freshly built set is looked up

`PowersetArena` replaces this with **hash-consing**: each unique sorted tuple of NFA states is assigned an integer ID. Subsequent lookups return the same ID.

### Singleton fast path

For BPE FSTs, ~99% of powerset states are singletons (exactly one NFA state). The arena has a special `_single_map` that hashes just the bare int, avoiding tuple construction entirely.

In [3]:
arena = PowersetArena()

# Intern some NFA state sets
id0 = arena.intern((100,), any_final=False)       # singleton → fast path
id1 = arena.intern((200, 300), any_final=True)     # multi-element → general path
id2 = arena.intern((100,), any_final=True)         # same set → SAME ID, finality updated

print(f"id0 = {id0}  (singleton {{100}})")
print(f"id1 = {id1}  (multi-element {{200, 300}})")
print(f"id2 = {id2}  (same singleton {{100}} → same ID!)")
print()
print(f"is_final[id0] = {arena.is_final[id0]}  (was False, updated to True on re-intern)")
print(f"is_final[id1] = {arena.is_final[id1]}")
print()
print(f"Singletons stored in _single_map: {arena._single_map}")
print(f"Multi-element stored in _map:      {arena._map}")
print(f"Total interned: {len(arena)}")

id0 = 0  (singleton {100})
id1 = 1  (multi-element {200, 300})
id2 = 0  (same singleton {100} → same ID!)

is_final[id0] = True  (was False, updated to True on re-intern)
is_final[id1] = True

Singletons stored in _single_map: {100: 0}
Multi-element stored in _map:      {(200, 300): 1}
Total interned: 2


### Why finality is updated on cache hits

In incremental decomposition, the same NFA state set can be final at one target length but not another. By always overwriting `is_final` on cache hits, the arena stays correct when reused across steps.
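
To make the mechanism concrete, here is a minimal sketch of a hash-consing arena with a singleton fast path and finality overwrite on hits. `TinyArena` is a hypothetical illustration written for this notebook, not the actual `PowersetArena` source; in particular, it still receives singletons as 1-tuples, whereas the real fast path is described as avoiding tuple construction.

```python
class TinyArena:
    """Hypothetical hash-consing arena: sorted tuples of ints -> dense IDs."""

    def __init__(self):
        self._single_map = {}   # bare int -> ID (singleton fast path)
        self._map = {}          # sorted tuple -> ID (general path)
        self.sets = []          # ID -> sorted tuple of NFA states
        self.is_final = []      # ID -> bool

    def intern(self, states, any_final):
        """Return the ID for `states` (a sorted tuple), allocating if new."""
        if len(states) == 1:
            table, key = self._single_map, states[0]
        else:
            table, key = self._map, states
        sid = table.get(key)
        if sid is None:
            sid = len(self.sets)
            table[key] = sid
            self.sets.append(states)
            self.is_final.append(any_final)
        else:
            # Overwrite on a hit so a reused arena reflects the current target.
            self.is_final[sid] = any_final
        return sid

    def __len__(self):
        return len(self.sets)


arena = TinyArena()
a = arena.intern((100,), any_final=False)
b = arena.intern((200, 300), any_final=True)
c = arena.intern((100,), any_final=True)
print(a, b, c)            # 0 1 0
print(arena.is_final)     # [True, True]
print(len(arena))         # 2
```

After interning, DFA states are compared and hashed as small integers, so the cost of hashing a set is paid only when a set is first built.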

## 3. Epsilon Closure Caching & Productivity Filtering

For FSTs with epsilon (input) arcs, the powerset construction must epsilon-close each NFA state. This is done via BFS.

### Productivity filtering

Not all states reachable via epsilon arcs matter. A state is **productive** if:
- Its FST state has at least one non-epsilon input arc (contributes DFA transitions), OR
- It's NFA-final (affects powerset finality)

**Transit-only** states (epsilon-input-only, non-final) are filtered out. For BPE FSTs where each token is a chain of epsilon arcs producing one byte each, this collapses the chain to just the endpoints.
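
The BFS, the productivity filter, and the cache can be sketched together on a toy epsilon chain. Everything below (`eps_arcs`, `has_non_eps_input`, `final_states`, and the module-level cache) is a hypothetical stand-in chosen for illustration, assuming states are plain ints; it is not the library's implementation.

```python
from collections import deque

# Hypothetical toy graph: eps_arcs[s] lists states reachable by one epsilon-input arc.
eps_arcs = {0: [1], 1: [2], 2: [3], 3: []}
has_non_eps_input = {0, 3}   # states with at least one non-epsilon input arc
final_states = set()         # NFA-final states (none in this toy chain)

_cache = {}
stats = {"hits": 0, "misses": 0}


def is_productive(s):
    return s in has_non_eps_input or s in final_states


def eps_closure_single(s):
    """BFS over epsilon arcs from s; keep only productive states; memoize."""
    cached = _cache.get(s)
    if cached is not None:
        stats["hits"] += 1
        return cached
    stats["misses"] += 1
    seen = {s}
    queue = deque([s])
    while queue:
        u = queue.popleft()
        for v in eps_arcs[u]:
            if v not in seen:
                seen.add(v)
                queue.append(v)
    # Transit-only states are traversed but dropped from the result.
    result = tuple(sorted(x for x in seen if is_productive(x)))
    _cache[s] = result
    return result


print(eps_closure_single(0))   # (0, 3)
print(eps_closure_single(0))   # (0, 3), served from the cache
print(stats)                   # {'hits': 1, 'misses': 1}
```

States 1 and 2 are visited during the search but filtered out, which is the chain-collapsing effect described above.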

In [4]:
fst = examples.samuel_example()
target = ('c', 'x')
nfa = PackedPrecoverNFA(fst, target)

# Show which FST states are productive
print("FST state productivity:")
for orig_state in sorted(fst.states):
    int_s = nfa._state_map[orig_state]   # _state_map is a mapping, so index it
    has_neps = int_s in nfa._has_non_eps_input
    is_fin = fst.is_final(orig_state)
    print(f"  state {orig_state}: has_non_eps_input={has_neps}, is_final={is_fin}")

print()

# Demonstrate epsilon closure with filtering
for s in nfa.start_states():
    closure = nfa.eps_closure_single(s)
    fst_state, buf_pos = nfa.unpack(s)
    orig = nfa._inv_state_map[fst_state]
    print(f"eps_closure of ({orig}, buf={buf_pos}):")
    for cs in closure:
        cs_fst, cs_buf = nfa.unpack(cs)
        cs_orig = nfa._inv_state_map[cs_fst]
        prod = nfa.is_productive(cs)
        print(f"  ({cs_orig}, buf={cs_buf})  productive={prod}")

FST state productivity:
  state 0: has_non_eps_input=True, is_final=True
  state 1: has_non_eps_input=True, is_final=False
  state 2: has_non_eps_input=True, is_final=True
  state 3: has_non_eps_input=False, is_final=True
  state 4: has_non_eps_input=True, is_final=True

eps_closure of (0, buf=0):
  (0, buf=0)  productive=True


In [5]:
# Show cache hits vs misses
nfa2 = PackedPrecoverNFA(fst, target)

# First pass: all misses
for s in nfa2.start_states():
    nfa2.eps_closure_single(s)
hits1, misses1 = nfa2.eps_cache_stats()
print(f"After first pass:  hits={hits1}, misses={misses1}")

# Second pass: all hits
for s in nfa2.start_states():
    nfa2.eps_closure_single(s)
hits2, misses2 = nfa2.eps_cache_stats()
print(f"After second pass: hits={hits2}, misses={misses2}  (no new misses!)")

After first pass:  hits=0, misses=1
After second pass: hits=1, misses=1  (no new misses!)


## 4. LazyPrecoverDFA: Putting It All Together

The `LazyPrecoverDFA` wraps all three optimizations into a single `Lazy` automaton:

```
PackedPrecoverNFA  ─→  PowersetArena  ─→  LazyPrecoverDFA
  (packed ints)         (hash-consing)      (Lazy interface)
  (cached closures)     (singleton fast)    (arc caching)
```

It's a drop-in replacement for `PrecoverNFA(fst, target).det()`.
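
The lazy part of the design, expanding a DFA state only when its arcs are first requested and caching the result, can be sketched independently of the precover machinery. `LazyDet` below is a hypothetical, epsilon-free toy determinizer written for illustration; its method names loosely mirror the interface used in this notebook, but it is not the `LazyPrecoverDFA` implementation.

```python
class LazyDet:
    """Hypothetical lazy determinizer over an epsilon-free toy NFA."""

    def __init__(self, nfa_arcs, nfa_start, nfa_final):
        self._nfa_arcs = nfa_arcs      # {nfa_state: [(symbol, dest), ...]}
        self._nfa_final = nfa_final
        self._ids = {}                 # sorted tuple -> DFA state ID
        self._sets = []                # DFA state ID -> sorted tuple
        self._arcs = {}                # DFA state ID -> {symbol: dest ID}
        self._start = self._intern(tuple(sorted(nfa_start)))

    def _intern(self, states):
        sid = self._ids.get(states)
        if sid is None:
            sid = len(self._sets)
            self._ids[states] = sid
            self._sets.append(states)
        return sid

    def start(self):
        return self._start

    def is_final(self, sid):
        return any(s in self._nfa_final for s in self._sets[sid])

    def arcs(self, sid):
        """Expand `sid` on first access, then reuse the cached arcs."""
        if sid not in self._arcs:
            by_symbol = {}
            for s in self._sets[sid]:
                for sym, dest in self._nfa_arcs.get(s, []):
                    by_symbol.setdefault(sym, set()).add(dest)
            self._arcs[sid] = {
                sym: self._intern(tuple(sorted(dests)))
                for sym, dests in by_symbol.items()
            }
        return self._arcs[sid]

    def num_states(self):
        return len(self._sets)


# Toy NFA: accepts strings over {a, b} ending in "ab"
nfa_arcs = {0: [("a", 0), ("b", 0), ("a", 1)], 1: [("b", 2)]}
dfa = LazyDet(nfa_arcs, nfa_start={0}, nfa_final={2})

print(dfa.num_states())        # 1: only the start state exists so far
print(dfa.arcs(dfa.start()))   # {'a': 1, 'b': 0}
print(dfa.num_states())        # 2: one new state discovered by the expansion
```

Only states that are actually reached get built, which matters when the full powerset would be much larger than the part a given query explores.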

In [6]:
fst = examples.samuel_example()
target = ('c', 'x')

# Build the optimized DFA
dfa = LazyPrecoverDFA(fst, target)

# Explore from the start state
[start] = list(dfa.start())
print(f"Start state: {start} (is_final={dfa.is_final(start)})")
print(f"Powerset size: {dfa.powerset_size(start)} NFA states")
print()

# Show arcs (lazily computed on first access)
print("Arcs from start:")
for sym, dest in dfa.arcs(start):
    print(f"  {sym!r} → state {dest}  (final={dfa.is_final(dest)}, powerset_size={dfa.powerset_size(dest)})")

print(f"\nDFA states created so far: {dfa.num_states()}")

Start state: 0 (is_final=False)
Powerset size: 1 NFA states

Arcs from start:
  'a' → state 1  (final=False, powerset_size=2)

DFA states created so far: 2


In [7]:
# Materialize the full DFA to see all states
materialized = dfa.materialize()
print(f"Materialized DFA: {len(materialized.states)} states")
print("Stats after full expansion:")
for k, v in dfa.stats().items():
    print(f"  {k}: {v}")

Materialized DFA: 3 states
Stats after full expansion:
  num_dfa_states: 3
  num_expanded: 3
  avg_powerset_size: 1.3333333333333333
  max_powerset_size: 2
  singleton_fraction: 0.6666666666666666
  eps_cache_size: 5
  eps_cache_hits: 2
  eps_cache_misses: 5


## 5. Inspecting DFA Internals

Each DFA state is backed by a set of packed NFA states. We can unpack them to see what's inside.

In [8]:
fst = examples.samuel_example()
target = ('c', 'x', 'x')
dfa = LazyPrecoverDFA(fst, target)
dfa.materialize()   # expand all states

print(f"DFA has {dfa.num_states()} states\n")

for sid in range(dfa.num_states()):
    arcs = dfa.arcs(sid)
    nfa_set = dfa.nfa_states(sid)
    final_str = " [FINAL]" if dfa.is_final(sid) else ""
    print(f"DFA state {sid}{final_str}  ({len(nfa_set)} NFA states):")
    for packed in nfa_set:
        int_s, buf_pos = dfa.unpack_nfa_state(packed)
        orig = dfa._nfa._inv_state_map[int_s]
        print(f"    (fst={orig}, buf_pos={buf_pos})")
    if arcs:
        for sym, dest in arcs:
            print(f"  {sym!r} → {dest}")
    else:
        print("  (no arcs)")
    print()

DFA has 4 states

DFA state 0  (1 NFA states):
    (fst=0, buf_pos=0)
  'a' → 1

DFA state 1  (2 NFA states):
    (fst=1, buf_pos=0)
    (fst=2, buf_pos=1)
  'a' → 2
  'b' → 2

DFA state 2  (1 NFA states):
    (fst=4, buf_pos=2)
  'a' → 3
  'b' → 3

DFA state 3 [FINAL]  (1 NFA states):
    (fst=4, buf_pos=3)
  'a' → 3
  'b' → 3



## 6. Equivalence with the Reference

Let's verify that `LazyPrecoverDFA` accepts exactly the same language as `PrecoverNFA(...).det()`.

In [9]:
from collections import deque

def accepted_strings(dfa, alphabet, max_len=5):
    """Enumerate all accepted strings up to max_len."""
    accepted = set()
    [start] = list(dfa.start())
    worklist = deque([(start, ())])
    visited = {(start, ())}
    while worklist:
        state, path = worklist.popleft()
        if dfa.is_final(state):
            accepted.add(path)
        if len(path) >= max_len:
            continue
        for x in alphabet:
            for dest in dfa.arcs_x(state, x):
                key = (dest, path + (x,))
                if key not in visited:
                    visited.add(key)
                    worklist.append((dest, path + (x,)))
    return accepted


test_cases = [
    (examples.small(), ('x',), "small / target=('x',)"),
    (examples.samuel_example(), ('c', 'x'), "samuel / target=('c','x')"),
    (examples.delete_b(), ('A',), "delete_b / target=('A',)"),
    (examples.triplets_of_doom(), ('a', 'a', 'a'), "triplets / target=('a','a','a')"),
    (examples.lookahead(), ('x', 'a'), "lookahead / target=('x','a')"),
]

for fst, target, name in test_cases:
    alpha = fst.A - {EPSILON}
    ref = PrecoverNFA(fst, target).det()
    opt = LazyPrecoverDFA(fst, target)
    ref_s = accepted_strings(ref, alpha)
    opt_s = accepted_strings(opt, alpha)
    match = "✓" if ref_s == opt_s else "✗"
    print(f"  {match}  {name:40s}  |L| = {len(ref_s)}")

  ✓  small / target=('x',)                     |L| = 31
  ✓  samuel / target=('c','x')                 |L| = 30
  ✓  delete_b / target=('A',)                  |L| = 57
  ✓  triplets / target=('a','a','a')           |L| = 1
  ✓  lookahead / target=('x','a')              |L| = 14


## 7. Performance on BPE-like FSTs

The optimizations really shine on BPE-like FSTs, which have:
- Long epsilon chains (one arc per byte in a token)
- Many singleton powerset states (~99% for BPE FSTs, although the small synthetic FST below produces none)
- Large state spaces that benefit from hash-consing

In [10]:
import time

fst = examples.bpe_like(vocab_size=50, alphabet=tuple("abc"), max_len=4)
target = tuple("abcabc")

# Time the optimized version
t0 = time.perf_counter()
dfa = LazyPrecoverDFA(fst, target)
dfa.materialize()
t_opt = time.perf_counter() - t0

print(f"LazyPrecoverDFA: {t_opt:.4f}s")
print("\nStats:")
for k, v in dfa.stats().items():
    if isinstance(v, float):
        print(f"  {k}: {v:.3f}")
    else:
        print(f"  {k}: {v}")

print("\nKey takeaways:")
stats = dfa.stats()
print(f"  - {stats['singleton_fraction']*100:.0f}% of DFA states are singletons (fast path)")
total = stats['eps_cache_hits'] + stats['eps_cache_misses']
if total > 0:
    print(f"  - Eps cache hit rate: {stats['eps_cache_hits']/total*100:.0f}% ({stats['eps_cache_hits']}/{total})")
print(f"  - {stats['num_dfa_states']} DFA states, avg powerset size {stats['avg_powerset_size']:.1f}")

LazyPrecoverDFA: 0.0036s

Stats:
  num_dfa_states: 7
  num_expanded: 7
  avg_powerset_size: 11.571
  max_powerset_size: 51
  singleton_fraction: 0.000
  eps_cache_size: 7
  eps_cache_hits: 74
  eps_cache_misses: 7

Key takeaways:
  - 0% of DFA states are singletons (fast path)
  - Eps cache hit rate: 91% (74/81)
  - 7 DFA states, avg powerset size 11.6


## 8. How It Maps to the Rust Code

| Python | Rust | File / notes |
|--------|------|------|
| `PowersetArena` | `PowersetArena` | `powerset.rs` |
| `PowersetArena._single_map` | `single_map: FxHashMap<u64, u32>` | singleton fast path |
| `PowersetArena._map` | `map: FxHashMap<Vec<u64>, u32>` | general path |
| `PackedPrecoverNFA` | `PrecoverNFA` | `precover.rs` |
| `pack(s, p) = s * stride + p` | `pack(s, p, tl) = s * (tl+1) + p` | same formula |
| `eps_closure_single` | `eps_closure_single_cached` | BFS + filter + cache |
| `is_productive` | `is_productive` | non-eps input OR final |
| `compute_all_arcs` | `compute_all_arcs_into` | batch arc computation |
| `LazyPrecoverDFA` | `LazyPrecoverDFA` | `lazy_precover.rs` |
| `_ensure_arcs` | `ensure_arcs_for` | lazy expansion + caching |
| `_arcs_buf` (reusable dict) | `arcs_buf: FxHashMap` | buffer reuse pattern |

### Key difference

In Rust, the NFA is re-created as a temporary for each DFA expansion step, and the epsilon cache is transferred in and out via `take_eps_cache()`. This is needed to satisfy Rust's borrow checker. In Python, we just keep one NFA instance alive, which has the same effect because the FST doesn't change.