In [1]:
%load_ext autoreload
%autoreload 2

# TransducedLM vs FusedTransducedLM Benchmark

Compares two approaches to computing next-symbol log-probabilities through
an FST on the Penn Treebank tokenizer (~296 states, 257 input symbols):

- **TransducedLM**: two-phase (PeekabooState BFS decomposition, then LM-weighted search)
- **FusedTransducedLM**: single-pass (interleaved decomposition + LM search, no separate BFS)

Uses a 3-gram CharNgramLM as the inner LM, with per-call timeouts and a
process-wide memory limit.

In [None]:
import signal, resource, time, gc
import numpy as np

# ---- Safety: memory limit (8 GB virtual address space) ----
_GB = 1024 ** 3
# Capture the original limits only once. If this cell is re-run while the cap
# is already installed, re-reading them would overwrite `_soft` with the cap
# itself and the original soft limit could no longer be restored.
if '_soft' not in globals() or '_hard' not in globals():
    _soft, _hard = resource.getrlimit(resource.RLIMIT_AS)
# setrlimit raises ValueError when the soft limit exceeds the hard limit, so
# clamp the cap whenever the hard limit is finite and below 8 GB.
_cap = 8 * _GB
if _hard != resource.RLIM_INFINITY:
    _cap = min(_cap, _hard)
resource.setrlimit(resource.RLIMIT_AS, (_cap, _hard))

class _Timeout(Exception):
    """Raised by the SIGALRM handler when a timed call exceeds its budget."""


def _alarm(signum, frame):
    """SIGALRM handler: abort whatever is running by raising `_Timeout`."""
    raise _Timeout()


def timed(fn, timeout_s=30, label=''):
    """Run fn() with a wall-clock timeout.

    Args:
        fn: zero-argument callable to execute.
        timeout_s: timeout in seconds; may be fractional.  0 disables the
            timer (no timeout), as with `signal.alarm(0)`.
        label: prefix used in the diagnostic line printed on failure.

    Returns:
        (result, elapsed_s) on success, or (None, None) if fn() timed out,
        raised MemoryError, or raised any other Exception (a one-line
        diagnostic is printed in each of those cases).

    Relies on SIGALRM, so it must be called from the main thread on a
    platform that provides `signal.setitimer`.
    """
    prev = signal.signal(signal.SIGALRM, _alarm)
    try:
        # setitimer (unlike signal.alarm) accepts fractional seconds.  It is
        # armed inside the outer try so that an invalid timeout value still
        # restores the previous SIGALRM handler before propagating.
        signal.setitimer(signal.ITIMER_REAL, timeout_s)
        try:
            t0 = time.perf_counter()
            out = fn()
            return out, time.perf_counter() - t0
        except _Timeout:
            print(f'  {label} TIMEOUT ({timeout_s}s)')
            return None, None
        except MemoryError:
            print(f'  {label} OOM')
            return None, None
        except Exception as e:
            print(f'  {label} ERROR: {type(e).__name__}: {e}')
            return None, None
    finally:
        # Always disarm the timer and put the caller's handler back.
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, prev)

# ---- Remap multi-char PTB symbols to single Unicode chars ----
# PeekabooPrecover NFA uses string concatenation for buffers and indexes by
# character position, which breaks for multi-character symbol names like '84'.
def remap_fst_to_single_chars(fst):
    """Copy `fst`, replacing every non-epsilon symbol with a single character.

    Symbols from both alphabets (`fst.A | fst.B`, epsilon excluded) are sorted
    and assigned consecutive code points starting at U+E000.

    Returns:
        (new_fst, fwd, inv) where `fwd` maps each original symbol to its
        replacement character and `inv` is the reverse mapping.
    """
    from transduction.fst import FST as FSTClass
    from transduction.fsa import EPSILON as _EPS

    base = 0xE000  # Unicode private-use area
    ordered_syms = sorted((fst.A | fst.B) - {_EPS})
    fwd = {sym: chr(base + offset) for offset, sym in enumerate(ordered_syms)}
    inv = {ch: sym for sym, ch in fwd.items()}

    def relabel(sym):
        # Epsilon passes through untouched; unknown symbols are kept as-is.
        if sym != _EPS:
            return fwd.get(sym, sym)
        return _EPS

    remapped = FSTClass()
    for q in fst.start:
        remapped.add_start(q)
    for q in fst.stop:
        remapped.add_stop(q)
    for src in fst.states:
        # Each arc is unpacked as (first label, second label, destination);
        # only the two labels are rewritten.
        for a, b, dst in fst.arcs(src):
            remapped.add_arc(src, relabel(a), relabel(b), dst)
    return remapped, fwd, inv

# ---- Build PTB FST (requires pynini) ----
from transduction.applications.ptb import build_ptb_fst_pynini, string_to_byte_strs, decode_ptb_output
from transduction.fst import FST
from transduction.fsa import EPSILON

# Build the raw PTB transducer and relabel its symbols with
# remap_fst_to_single_chars; the reported time covers both steps.
t0 = time.perf_counter()
raw_fst = build_ptb_fst_pynini()
ptb_fst, fwd_map, inv_map = remap_fst_to_single_chars(raw_fst)
print(f'PTB FST built in {time.perf_counter()-t0:.1f}s: '
      f'{len(ptb_fst.states)} states, |A|={len(ptb_fst.A)}, |B|={len(ptb_fst.B)}')

# Generate target sequence
text = "The quick brown fox jumps over the lazy dog."
byte_strs = string_to_byte_strs(text)
# Translate the input symbols into the remapped alphabet used by ptb_fst.
remapped_input = tuple(fwd_map[s] for s in byte_strs)
input_fst_obj = FST.from_string(remapped_input)
# NOTE(review): presumably `@` is FST composition and `.project(1)` keeps the
# output side of the result -- confirm against transduction.fst.
output_fsa = (input_fst_obj @ ptb_fst).project(1)
# The benchmark target is the first string yielded by language().
target_seq = list(next(output_fsa.language()))
# Map back to the original symbol names before decoding for display.
decoded = decode_ptb_output(tuple(inv_map.get(c, c) for c in target_seq))
print(f'Target: {len(target_seq)} symbols')
print(f'  {decoded!r}')

# Train inner LM for TransducedLM benchmarks
from transduction.lm.ngram import CharNgramLM
source_alpha = ptb_fst.A - {EPSILON}
train_text = (
    "The quick brown fox jumps over the lazy dog. "
    "A stitch in time saves nine. To be or not to be, that is the question. "
    "All that glitters is not gold. Actions speak louder than words. "
    "Practice makes perfect. Where there is a will, there is a way. "
) * 3
train_syms = [fwd_map[s] for s in string_to_byte_strs(train_text)]
# Append every source symbol once (presumably so each one is seen by the LM at
# least once -- confirm).  `source_alpha` is a set of strings, whose iteration
# order changes between kernel restarts under hash randomization; sorting keeps
# the training sequence, and hence the trained model, reproducible.
train_syms.extend(sorted(source_alpha))
inner_lm = CharNgramLM.train(train_syms, n=3, alpha=0.5)
print(f'Inner LM: alphabet={len(inner_lm.alphabet)} symbols')

## TransducedLM Scaling

Per-step decode time for **TransducedLM** (two-phase: PeekabooState BFS
decomposition, then LM-weighted search) vs **FusedTransducedLM** (single-pass:
interleaved decomposition + LM search, no separate BFS).

Each step includes both decomposition and LM search costs.  For TransducedLM,
the PeekabooState BFS dominates (~35s per step on PTB).  FusedTransducedLM
avoids the BFS entirely but builds the lazy DFA inline during search.

Both use `max_steps=200` and `max_beam=100`, decode the first 10 target symbols, and apply a 120s timeout per step.

In [None]:
from collections import defaultdict
from transduction.lm.transduced import TransducedLM
from transduction.lm.fused_transduced import FusedTransducedLM

# ---- Benchmark knobs ----
MAX_DECODE = 10              # number of decode steps
MAX_SEARCH = 200             # max priority-queue steps per logp_next
MAX_BEAM = 100               # max items carried forward
LM_TIMEOUT = 120             # seconds per step

# Expected runtime: ~15-20 min total (TransducedLM ~65s/step, FusedTransducedLM ~35s/step)

lm_results = defaultdict(list)  # name -> [(step, time_s, logp)]

for name, cls in (('TransducedLM', TransducedLM),
                  ('FusedTransducedLM', FusedTransducedLM)):
    print(f'\n{name} (max_steps={MAX_SEARCH}, max_beam={MAX_BEAM}):')
    tlm = cls(inner_lm, ptb_fst, max_steps=MAX_SEARCH, max_beam=MAX_BEAM)
    state = tlm.initial()
    # Walk the first MAX_DECODE target symbols, numbering steps from 1.
    for step_no, y in zip(range(1, MAX_DECODE + 1), target_seq):
        # Defaults freeze the current state and symbol for this one call.
        def advance(cur=state, sym=y):
            logp = cur.logp_next[sym]
            return cur >> sym, logp
        outcome, elapsed = timed(advance, timeout_s=LM_TIMEOUT, label=f'step {step_no}')
        if elapsed is None:
            # Timeout / OOM / error: stop benchmarking this algorithm.
            break
        state, lp = outcome
        lm_results[name].append((step_no, elapsed, lp))
        print(f'  {step_no:2d}: {elapsed*1000:8.1f} ms  logp={lp:.4f}')
    gc.collect()

# Summary table
print(f'\n{"Algorithm":<25s} {"Total (s)":>10s} {"Avg/step (s)":>12s} {"Steps":>6s}')
print('-' * 55)
for name, data in sorted(lm_results.items()):
    total = sum(t for _, t, _ in data)
    # Guard against an empty entry (e.g. created by a stray defaultdict lookup).
    avg = total / len(data) if data else float('nan')
    print(f'{name:<25s} {total:10.1f} {avg:12.1f} {len(data):6d}')

# Speedup is baseline time / fused time, so >1 means Fused is faster.  The runs
# are looked up by name: sorting the keys alphabetically puts
# 'FusedTransducedLM' first, which would invert the ratio.
baseline_runs = lm_results.get('TransducedLM', [])
fused_runs = lm_results.get('FusedTransducedLM', [])
# Only steps completed by both algorithms are comparable; one of them may have
# stopped early on a timeout.
n_common = min(len(baseline_runs), len(fused_runs))
if n_common:
    if len(baseline_runs) != len(fused_runs):
        print(f'\n(speedups compare the first {n_common} steps completed by both)')
    base_total = sum(t for _, t, _ in baseline_runs[:n_common])
    fused_total = sum(t for _, t, _ in fused_runs[:n_common])
    if fused_total > 0:
        print(f'\nFused speedup (overall): {base_total/fused_total:.2f}x')
    # Exclude step 1 (amortization penalty for Fused)
    if n_common > 1:
        base_skip1 = sum(t for _, t, _ in baseline_runs[1:n_common])
        fused_skip1 = sum(t for _, t, _ in fused_runs[1:n_common])
        if fused_skip1 > 0:
            print(f'Fused speedup (step 2+): {base_skip1/fused_skip1:.2f}x')


TransducedLM (max_steps=200, max_beam=100):


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: time per step
ax = axes[0]
for name, data in sorted(lm_results.items()):
    steps = [d[0] for d in data]
    times = [d[1] for d in data]
    ax.plot(steps, times, 'o-', label=name, markersize=4)
ax.set_xlabel('Target step')
ax.set_ylabel('Time per step (s)')
ax.set_title(f'TransducedLM vs Fused (PTB, max_steps={MAX_SEARCH})')
ax.legend()
ax.grid(True, alpha=0.3)

# Right: per-step speedup
ax = axes[1]
# Look the two runs up by name.  Sorting the keys alphabetically would put
# 'FusedTransducedLM' first and plot Fused / Original, the inverse of what the
# axis label and title promise.
baseline_runs = lm_results.get('TransducedLM', [])
fused_runs = lm_results.get('FusedTransducedLM', [])
# zip() below stops at the shorter run, i.e. at steps completed by both.
paired = list(zip(baseline_runs, fused_runs))
if paired:
    steps = [base[0] for base, _ in paired]
    # Original time / Fused time; 0 is used when the Fused time is not positive.
    speedups = [base[1] / fused[1] if fused[1] > 0 else 0 for base, fused in paired]
    colors = ['#2ecc71' if s > 1 else '#e74c3c' for s in speedups]
    ax.bar(steps, speedups, color=colors, alpha=0.7, edgecolor='white')
    ax.axhline(1.0, color='black', linestyle='--', linewidth=0.8)
    ax.set_xlabel('Target step')
    ax.set_ylabel('Speedup (Original / Fused)')
    ax.set_title('Per-step speedup (>1 = Fused faster)')
    ax.grid(True, alpha=0.3, axis='y')

    # Largest disagreement between the two algorithms' log-probabilities.
    logp_diffs = [abs(base[2] - fused[2]) for base, fused in paired]
    print(f'Max |logp| diff: {max(logp_diffs):.6f}')

plt.tight_layout()
plt.show()

In [None]:
# Restore original memory limit
# `_soft` / `_hard` hold the values read with resource.getrlimit before the
# 8 GB address-space cap was installed; writing them back lifts the cap for
# this process.
resource.setrlimit(resource.RLIMIT_AS, (_soft, _hard))
print('Memory limit restored.')
