# PTB Benchmark Analysis

Analysis of precover decomposition benchmark results on the PTB tokenizer FST.

# Audit: Inspect Individual Automata

Run the decomposition at a specific target-prefix length and visualize the resulting quotient (Q) and remainder (R) automata.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import time
import transduction_core
from transduction.rust_bridge import to_rust_fst
from tqdm.auto import tqdm
from transduction.applications.ptb import build_ptb_fst_pynini, string_to_byte_strs, SEP
from transduction.applications.wikitext import load_wikitext, wikitext_detokenize
from transduction.fst import FST
from transduction.fsa import EPSILON
from transduction.rust_bridge import RustDecomp
from transduction.vibes import visualize_automaton
from arsenal import timeit

In [None]:
def fmt_sym(s):
    """Render a single FST symbol as a short printable string.

    Args:
        s: ``EPSILON``, ``SEP``, or an integer byte value.

    Returns:
        ``'ε'`` for ``EPSILON``, ``'<SEP>'`` for ``SEP``, ``'␣'`` for the
        space byte (32), the escaped ``bytes`` repr (without the ``b'...'``
        wrapper) for any other value ``bytes([...])`` accepts, and ``str(s)``
        when the value cannot be converted to a single byte.

    Raises:
        AssertionError: if ``s`` is neither sentinel nor an ``int``
            (only when assertions are enabled).
    """
    if s == EPSILON:
        return 'ε'
    if s == SEP:
        return '<SEP>'
    assert isinstance(s, int)
    try:
        byte = int(s)
        if byte == 32:
            return '␣'  # visible space
        return repr(bytes([byte]))[2:-1]  # e.g. 'a', '\\n', '\\x00'
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures fall back to str(); a bare ``except`` would
        # also have swallowed KeyboardInterrupt/SystemExit.
        return str(s)

In [None]:
# Build the PTB tokenizer FST once; every later cell reuses `ptb_fst`.
# The build is wrapped in arsenal's `timeit` context manager, labelled 'build'.
with timeit('build'):
    ptb_fst = build_ptb_fst_pynini()

In [None]:
# Pick the first line of the "test" split that is non-empty after stripping
# and does not start with "=".
for item in load_wikitext("test"):
    text = item["text"].strip()
    if not text or text.startswith("="):
        continue
    break

In [None]:
# Apply the tokenizer to the raw corpus to get a complete target string.
from arsenal import timeit

# `detok` is the detokenized paragraph; `target_full` is the FST output for its
# byte-string encoding and is the sequence that later cells take prefixes of.
with timeit('transduce'):
    detok = wikitext_detokenize(text)
    target_full = ptb_fst.transduce(string_to_byte_strs(detok))

In [None]:
# Preview: first 80 characters of the detokenized text, then every output
# symbol rendered with fmt_sym and separated by '|', then the symbol count.
print(f"Text: {detok[:80]}...")
print('|'.join(fmt_sym(sym) for sym in target_full))
print(f"Total symbols: {len(target_full)}")

In [None]:
# Run decomposition - change PREFIX_LEN to inspect different positions
# Try: 20, 23, 46, 49 for non-empty remainder cases
#
# `colors` must be imported before its first use in the print below; importing
# it afterwards raises NameError when the notebook runs top-to-bottom in a
# fresh kernel.
from arsenal import timeit, colors

PREFIX_LEN = 20

# The target prefix whose decomposition is inspected in the following cells.
target = target_full[:PREFIX_LEN]
print((colors.dark.white % '|').join(map(fmt_sym, target)))

# Raw decomposition; construction and extraction of Q/R are timed separately.
with timeit('decomp'):
    result = RustDecomp(ptb_fst, target)

with timeit('post'):
    Q, R = result.quotient, result.remainder

# Same decomposition with minimize=True, for a size comparison.
with timeit('decomp (minimize=True)'):
    result_min = RustDecomp(ptb_fst, target, minimize=True)

with timeit('post (minimize=True)'):
    Q_min, R_min = result_min.quotient, result_min.remainder

print(f"\nRaw:       Q={len(Q.states)} states, R={len(R.states)} states")
print(f"Minimized: Q={len(Q_min.states)} states, R={len(R_min.states)} states")

In [None]:
# Decode target: split the symbol sequence at SEP markers, drop EPSILON, and
# decode each non-empty byte run as UTF-8 (undecodable bytes are replaced).
with timeit('decode'):
    tokens = []
    buf = []
    # A trailing SEP sentinel flushes whatever bytes follow the last real SEP,
    # so the flush logic only has to be written once.
    for sym in [*target, SEP]:
        if sym != SEP:
            if sym != EPSILON:
                buf.append(sym)
            continue
        if buf:
            tokens.append(bytes(int(b) for b in buf).decode('utf-8', errors='replace'))
        buf = []

In [None]:
# Summarize the inspected prefix: its decoded tokens, then the state and
# final-state counts of the raw (non-minimized) quotient Q and remainder R.
print(f"Prefix {PREFIX_LEN}: {' | '.join(tokens)}")
print(f"Q: {len(Q.states)} states, {len(Q.stop)} final")
print(f"R: {len(R.states)} states, {len(R.stop)} final")

In [None]:
# State counts of the quotient: raw vs. minimize=True.
len(Q.states), len(Q_min.states)

In [None]:
# Arc counts of the quotient (raw vs. minimize=True), counted without
# materializing the arc lists.
sum(1 for _ in Q.arcs()), sum(1 for _ in Q_min.arcs())

In [None]:
# Draw the minimized quotient, relabelling each arc label x as the one-byte
# string bytes([x]).
# NOTE(review): presumably every label is an int in range(256) — confirm.
visualize_automaton(Q_min.map_labels(lambda x: bytes([x])))

In [None]:
# Draw the minimized remainder, relabelling each arc label x as the one-byte
# string bytes([x]).
# NOTE(review): presumably every label is an int in range(256) — confirm.
visualize_automaton(R_min.map_labels(lambda x: bytes([x])))

In [None]:
# Print the first 20 strings (at most) enumerated from the minimized
# quotient's language, one per line.
for t, x in enumerate(Q_min.language(tuple=True), start=1):
    if t > 20:
        break
    print(''.join(map(fmt_sym, x)))

In [None]:
# Print the first 20 strings (at most) enumerated from the minimized
# remainder's language, one per line.
for t, x in enumerate(R_min.language(tuple=True), start=1):
    if t > 20:
        break
    print(''.join(map(fmt_sym, x)))

## Q/R Size and Timing vs. Prefix Length

In [None]:
# Sweep every prefix length of `target_full`, running both the batch and the
# dirty-state decomposition at each position and recording sizes, counters,
# and wall-clock time into `data` (one list entry per position).

# Cache the Rust FST conversion (avoids repeated Python→Rust conversion)
# The third value returned by to_rust_fst is not needed here.
rust_fst_cached, sym_map_cached, _ = to_rust_fst(ptb_fst)

# Pre-build all decomposers
# A single dirty-state decomposer is reused across the whole sweep.
dirty_decomp = transduction_core.RustDirtyStateDecomp(rust_fst_cached)

# Collect Q/R sizes and timing as a function of prefix length (all algorithms).
# NOTE: We use wall-clock time (time.perf_counter) rather than stats.total_ms
# because stats.total_ms excludes the ip_universal_states precomputation that
# rust_decompose does on every call but the other variants cache.
data = {
    'pos': [],
    # rust_decompose (batch) — no caching
    'batch_ms': [],
    # NOTE(review): 'Q_min' / 'R_min' are filled from a minimize=False run
    # below, so despite the key names they hold unminimized state counts —
    # confirm whether the names or the flag are what was intended.
    'dfa_states': [], 'Q_min': [], 'R_min': [],
    # RustDirtyStateDecomp (persists DFA structure, skips clean states)
    'dirty_ms': [],
    'dirty_compute_arcs_calls': [],  # states fully expanded (dirty/border/new)
    'dirty_intern_calls': [],        # arcs created (one intern per arc)
    'dirty_dfa_states': [],          # total arena size (grows monotonically)
    'dirty_eps_hits': [], 'dirty_eps_misses': [],
}

for pos in tqdm(range(1, len(target_full) + 1), desc="prefix sweep"):
    # Map the first `pos` target symbols through sym_map_cached; this happens
    # outside both timed regions.
    target_u32 = [sym_map_cached(y) for y in target_full[:pos]]

    # Batch decomposition (no caching); perf_counter returns seconds, *1000 → ms
    t0 = time.perf_counter()
    d = transduction_core.rust_decompose(rust_fst_cached, target_u32, minimize=False)
    data['batch_ms'].append((time.perf_counter() - t0) * 1000)

    # Dirty-state (persists DFA structure, only re-expands dirty/border states)
    # NOTE(review): the positional False presumably is the minimize flag,
    # matching the batch call above — confirm against the binding's signature.
    t0 = time.perf_counter()
    dd = dirty_decomp.decompose(target_u32, False)
    data['dirty_ms'].append((time.perf_counter() - t0) * 1000)

    # Record sizes and counters after both timers have stopped.
    data['pos'].append(pos)
    data['dfa_states'].append(d.stats.dfa_states)
    data['Q_min'].append(d.quotient.num_states())
    data['R_min'].append(d.remainder.num_states())
    data['dirty_compute_arcs_calls'].append(dd.stats.compute_arcs_calls)
    data['dirty_intern_calls'].append(dd.stats.intern_calls)
    data['dirty_dfa_states'].append(dd.stats.dfa_states)
    data['dirty_eps_hits'].append(dd.stats.eps_cache_hits)
    data['dirty_eps_misses'].append(dd.stats.eps_cache_misses)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Two stacked panels sharing the prefix-length axis: automaton sizes on top,
# per-position wall-clock time below.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 9), sharex=True)

# Top panel: automaton sizes.
# data["Q_min"] / data["R_min"] are collected from rust_decompose(...,
# minimize=False), so despite the key names they are unminimized state counts.
# The legend states that instead of claiming "(minimized)".
ax1.plot(data["pos"], data["dfa_states"], label="DFA states (raw)", alpha=0.7)
ax1.plot(data["pos"], data["Q_min"], label="Q states (minimize=False)", alpha=0.7, linestyle="--")
ax1.plot(data["pos"], data["R_min"], label="R states (minimize=False)", alpha=0.7, linestyle="--")
ax1.set_ylabel("Number of states")
ax1.set_title("Automaton Size vs. Prefix Length")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Bottom panel: wall-clock timing (log scale)
window = 20  # running-average width, in positions (loop-invariant)
for key, label, color in [
    ('batch_ms', 'batch (no caching)', 'C0'),
    ('dirty_ms', 'dirty-state (incr BFS + univ$)', 'C3'),
]:
    arr = np.array(data[key])
    # Faint scatter of the raw per-position samples.
    ax2.scatter(data["pos"], arr, alpha=0.2, s=1, color=color)
    # Running average; 'valid' mode yields len(arr) - window + 1 points, which
    # are plotted against the positions starting at index window - 1.
    if len(arr) >= window:
        avg = np.convolve(arr, np.ones(window)/window, mode='valid')
        ax2.plot(data["pos"][window-1:], avg, label=label, linewidth=2, color=color, alpha=0.8)

ax2.set_xlabel("Prefix length")
ax2.set_ylabel("Time (ms)")
ax2.set_yscale("log")
ax2.set_title("Wall-Clock Timing: Batch vs. Dirty-State")
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Aggregate the per-position timings into a small throughput table.
total_symbols = len(target_full)
batch_total = sum(data['batch_ms'])
dirty_total = sum(data['dirty_ms'])

print(f"Total symbols: {total_symbols}")
print(f"{'Variant':<35s} {'Total ms':>10s} {'sym/sec':>10s} {'Speedup':>10s}")
print("-" * 70)

# Variant label -> total wall-clock milliseconds over the whole sweep.
totals_by_variant = {
    'batch (no caching)': batch_total,
    'dirty-state (incr BFS + univ$)': dirty_total,
}
for name, total in totals_by_variant.items():
    seconds = total / 1000
    throughput = total_symbols / seconds   # symbols per second
    speedup = batch_total / total          # relative to the batch variant
    print(f"{name:<35s} {total:>10.0f} {throughput:>10.0f} {speedup:>10.2f}x")


## Dirty-State Analysis

The dirty-state variant persists the entire DFA structure (arena + per-state arcs + state classification) across calls. On each prefix extension, it:
1. Marks **dirty** states (NFA set contains elements at `buf_pos >= frontier`)
2. Marks **border** states (clean states with arcs to dirty states)
3. Re-expands only dirty + border states; **clean states copy cached arcs**

Additionally, it caches **universality results at the FST-state level**. For pure-frontier DFA states (all NFA elements at `buf_pos == target_len`), the universality sub-BFS only explores states at `buf_pos == target_len`, making the result purely FST-topology-dependent and target-independent. This cache never needs eviction and eliminates ~98% of the original runtime (the universality sub-BFS).

In [None]:
# 2x2 figure comparing the dirty-state variant against batch decomposition:
# how many states are re-expanded per step, what fraction of the batch DFA
# that is, per-position timing, and per-position speedup. Reads `data` from
# the prefix sweep and relies on `np` / `plt` imported by an earlier cell.
dirty_arr = np.array(data['dirty_ms'])
batch_arr = np.array(data['batch_ms'])
dirty_arcs = np.array(data['dirty_compute_arcs_calls'])
dfa_states = np.array(data['dfa_states'])

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Top-left: states re-expanded vs total DFA states
ax = axes[0, 0]
ax.plot(data['pos'], dfa_states, label='DFA states (batch)', alpha=0.7)
ax.plot(data['pos'], dirty_arcs, label='States re-expanded (dirty)', alpha=0.7)
ax.set_xlabel('Prefix length')
ax.set_ylabel('Count')
ax.set_title('States Re-Expanded vs. Total DFA States')
ax.legend()
ax.grid(True, alpha=0.3)

# Top-right: fraction of states re-expanded
# The denominator is clamped to at least 1 to avoid dividing by zero.
reexpand_frac = dirty_arcs / np.maximum(dfa_states, 1) * 100
ax = axes[0, 1]
ax.plot(data['pos'], reexpand_frac, alpha=0.7, linewidth=1)
window = 20
if len(reexpand_frac) >= window:
    avg = np.convolve(reexpand_frac, np.ones(window)/window, mode='valid')
    # The mean in the label skips the first position (index 0).
    ax.plot(data['pos'][window-1:], avg, color='red', linewidth=2,
            label=f'running avg (w={window}), mean={reexpand_frac[1:].mean():.1f}%')
ax.set_xlabel('Prefix length')
ax.set_ylabel('% of DFA states re-expanded')
ax.set_title('Fraction of DFA States Re-Expanded')
ax.legend()
ax.grid(True, alpha=0.3)

# Bottom-left: timing of both variants (log scale to show dirty-state separation)
ax = axes[1, 0]
for arr, label, color in [
    (batch_arr, 'batch', 'C0'),
    (dirty_arr, 'dirty-state', 'C3'),
]:
    ax.scatter(data['pos'], arr, alpha=0.2, s=1, color=color)
    if len(arr) >= window:
        avg = np.convolve(arr, np.ones(window)/window, mode='valid')
        ax.plot(data['pos'][window-1:], avg, label=label, linewidth=2, color=color, alpha=0.8)
ax.set_xlabel('Prefix length')
ax.set_ylabel('Time (ms)')
ax.set_yscale('log')
ax.set_title('Per-Position Timing (log scale)')
ax.legend()
ax.grid(True, alpha=0.3)

# Bottom-right: per-position speedup dirty vs batch
# Computed as batch time / dirty time; the dirty time is floored at 0.01 ms so
# near-zero measurements cannot blow up the ratio.
dirty_over_batch = batch_arr / np.maximum(dirty_arr, 0.01)
ax = axes[1, 1]
ax.scatter(data['pos'], dirty_over_batch, alpha=0.3, s=2)
if len(dirty_over_batch) >= window:
    avg = np.convolve(dirty_over_batch, np.ones(window)/window, mode='valid')
    ax.plot(data['pos'][window-1:], avg, color='red', linewidth=2,
            label=f'running avg (w={window})')
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='break-even')
ax.set_xlabel('Prefix length')
ax.set_ylabel('Speedup (batch / dirty)')
ax.set_title('Per-Position Speedup: Dirty-State over Batch')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary stats
# Every statistic below skips the first position via the [1:] slice.
print(f"Mean states re-expanded per position: {dirty_arcs[1:].mean():.1f} / {dfa_states[1:].mean():.1f} "
      f"({reexpand_frac[1:].mean():.1f}%)")
print(f"Mean dirty/batch speedup: {dirty_over_batch[1:].mean():.1f}x")
print(f"Median dirty/batch speedup: {np.median(dirty_over_batch[1:]):.1f}x")

## Direct Correlation: Time vs. Arc Churn

The key claim for incremental decomposition is that per-step cost scales with the
**change** in the DFA graph, not the **total** DFA size.

We measure change as `intern_calls` — the number of arcs created per step (one
`arena.intern()` call per arc). Each re-expanded state first has its old arcs
removed, then new arcs are created via BFS, so `intern_calls` directly measures
the arc-level work done.

- **Left panel**: Scatter of per-position wall-clock time vs. arcs created
  (`intern_calls`). A strong linear fit supports the claim that cost ∝ arc churn.
- **Right panel**: Scatter of per-position wall-clock time vs. total DFA arena size
  (`dfa_states`). A weak or absent correlation supports the claim that cost does
  not depend on total size.

In [None]:
# Relate the dirty-state variant's per-step time to (left) the number of arcs
# created in that step and (right) the total DFA arena size. Reads `data` from
# the prefix sweep and relies on `np` / `plt` imported by an earlier cell.
dirty_arr = np.array(data['dirty_ms'])
dirty_arcs_created = np.array(data['dirty_intern_calls'])
dirty_dfa = np.array(data['dirty_dfa_states'])

# Skip step 1 (cold start) for steady-state analysis
ss = slice(1, None)
t_us = dirty_arr[ss] * 1000  # ms → μs
change = dirty_arcs_created[ss]
total = dirty_dfa[ss]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5.5))

# --- Left: time vs arcs created (should correlate) ---
ax1.scatter(change, t_us, alpha=0.25, s=8, color='tab:blue', edgecolors='none')

# Linear fit through origin: t ≈ slope * change
# Closed-form least squares with no intercept: slope = <x, y> / <x, x>.
slope = np.dot(change.astype(float), t_us) / np.dot(change.astype(float), change.astype(float))
xs = np.linspace(0, change.max() * 1.05, 100)
ax1.plot(xs, slope * xs, 'r--', linewidth=1.5, label=f'fit: {slope:.1f} μs/arc')

# R² compares the through-origin fit's residuals with the variance around the
# mean of t_us, so it can come out negative if the fit is worse than a constant.
ss_res = np.sum((t_us - slope * change) ** 2)
ss_tot = np.sum((t_us - t_us.mean()) ** 2)
r2 = 1 - ss_res / ss_tot

ax1.set_xlabel('arcs created (intern_calls)', fontsize=11)
ax1.set_ylabel('time per step (μs)', fontsize=11)
ax1.set_title(f'Time vs. Arcs Created (R² = {r2:.2f})', fontsize=12)
ax1.legend(fontsize=9)
ax1.set_xlim(left=0)
ax1.set_ylim(bottom=0)
ax1.grid(True, alpha=0.3)

# --- Right: time vs total DFA size (should NOT correlate) ---
ax2.scatter(total, t_us, alpha=0.25, s=8, color='tab:orange', edgecolors='none')

# Pearson correlation coefficient between arena size and per-step time.
r_total = np.corrcoef(total.astype(float), t_us)[0, 1]

ax2.set_xlabel('total DFA arena size (dfa_states)', fontsize=11)
ax2.set_ylabel('time per step (μs)', fontsize=11)
ax2.set_title(f'Time vs. Total DFA Size (r = {r_total:.2f})', fontsize=12)
ax2.set_ylim(bottom=0)
ax2.grid(True, alpha=0.3)

# Add annotation
ax2.annotate(f'DFA grows {total.min()}→{total.max()} states\n'
             f'but time stays {np.median(t_us):.0f} μs (median)',
             xy=(0.95, 0.95), xycoords='axes fraction',
             ha='right', va='top', fontsize=9,
             bbox=dict(boxstyle='round,pad=0.3', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.show()

# Text summary of the same statistics shown in the figure.
print(f"Time vs arcs created:    R² = {r2:.3f}  (slope = {slope:.1f} μs/arc)")
print(f"Time vs total DFA size:  r  = {r_total:.3f}")
print(f"Median time: {np.median(t_us):.0f} μs, DFA range: {total.min()}–{total.max()} states")
print(f"Median arcs created/step: {np.median(change):.0f}")

## Conclusions

**Key finding:** The dirty-state incremental decomposition achieves **~40x speedup**
over batch decomposition on the PTB FST.

The decisive optimization is the **FST-level universality cache**. For "pure frontier"
DFA states (all NFA elements at `buf_pos == target_len`), universality depends only on
the FST state set, not the target string. Caching this permanently eliminates ~98% of
the original runtime.

**Architecture:** The dirty-state variant persists: (1) the PowersetArena, (2) per-state
cached arcs and classification (NEW/INTERIOR/QSTOP/RSTOP), (3) the UniversalityFilter,
(4) the eps_cache, and (5) a permanent `fst_univ_cache` mapping FST state sets to
universality results. On prefix extension, it marks dirty/border states for re-expansion
while clean states copy cached arcs.