In [1]:
%load_ext autoreload
%autoreload 2

# TransducedLM vs FusedTransducedLM Benchmark

Compares two approaches to computing next-symbol log-probabilities through
an FST on the Penn Treebank tokenizer (~296 states, 257 input symbols):

- **TransducedLM**: two-phase (PeekabooState BFS decomposition, then LM-weighted search)
- **FusedTransducedLM**: single-pass (interleaved decomposition + LM search, no separate BFS)

Uses a 3-gram CharNgramLM as the inner LM, with per-call timeouts and a
process-wide memory limit.

In [2]:
import time, gc
import numpy as np

from transduction.util import Timeout, timelimit, set_memory_limit

# ---- Safety: memory limit (8 GB virtual address space) ----
import resource
_soft, _hard = resource.getrlimit(resource.RLIMIT_AS)
set_memory_limit(8)

# ---- Remap multi-char PTB symbols to single Unicode chars via map_labels ----
from transduction.applications.ptb import build_ptb_fst_pynini, string_to_byte_strs, decode_ptb_output
from transduction.fst import FST
from transduction.fsa import EPSILON

In [3]:
t0 = time.perf_counter()
raw_fst = build_ptb_fst_pynini()

# Build forward/inverse maps for all non-epsilon symbols
fwd_map, inv_map = {}, {}
code = 0xE000  # Unicode private-use area
for sym in sorted((raw_fst.A | raw_fst.B) - {EPSILON}):
    fwd_map[sym] = chr(code)
    inv_map[chr(code)] = sym
    code += 1

ptb_fst = raw_fst.map_labels(lambda a, b: (fwd_map.get(a, a), fwd_map.get(b, b)))
print(f'PTB FST built in {time.perf_counter()-t0:.1f}s: '
      f'{len(ptb_fst.states)} states, |A|={len(ptb_fst.A)}, |B|={len(ptb_fst.B)}')

Composing PTB rules...
Core PTB FST: 310 states
Final pynini FST: 296 states
Converting to native FST...
Native FST: 296 states, 23723 arcs
  eps: 108 in, 352 out
  MARKER: 0 in, 0 out
  [EOS]: 0 in, 0 out
PTB FST built in 31.1s: 296 states, |A|=257, |B|=256


In [4]:
# Generate target sequence
text = "The quick brown fox jumps over the lazy dog."
byte_strs = string_to_byte_strs(text)
remapped_input = tuple(fwd_map[s] for s in byte_strs)
input_fst_obj = FST.from_string(remapped_input)
output_fsa = (input_fst_obj @ ptb_fst).project(1)
target_seq = list(next(output_fsa.language()))
decoded = decode_ptb_output(tuple(inv_map.get(c, c) for c in target_seq))
print(f'Target: {len(target_seq)} symbols')
print(f'  {decoded!r}')

# Train inner LM for TransducedLM benchmarks
from transduction.lm.ngram import CharNgramLM
source_alpha = ptb_fst.A - {EPSILON}
train_text = (
    "The quick brown fox jumps over the lazy dog. "
    "A stitch in time saves nine. To be or not to be, that is the question. "
    "All that glitters is not gold. Actions speak louder than words. "
    "Practice makes perfect. Where there is a will, there is a way. "
) * 3
train_syms = [fwd_map[s] for s in string_to_byte_strs(train_text)]
for sym in source_alpha:
    train_syms.append(sym)
inner_lm = CharNgramLM.train(train_syms, n=3, alpha=0.5)
print(f'Inner LM: alphabet={len(inner_lm.alphabet)} symbols')

Target: 45 symbols
  'The quick brown fox jumps over the lazy dog .'
Inner LM: alphabet=257 symbols


## TransducedLM Scaling

Per-step decode time for **TransducedLM** (two-phase: PeekabooState BFS
decomposition, then LM-weighted search) vs **FusedTransducedLM** (single-pass:
interleaved decomposition + LM search, no separate BFS).

Each step includes both decomposition and LM search costs.  For TransducedLM,
the PeekabooState BFS dominates (~35s per step on PTB).  FusedTransducedLM
avoids the BFS entirely but builds the lazy DFA inline during search.

Both use `max_steps=200`, `max_beam=100`, with a 120s timeout per step.

In [6]:
from collections import defaultdict
from transduction.lm.transduced import TransducedLM
from transduction.lm.fused_transduced import FusedTransducedLM

MAX_DECODE = 10              # number of decode steps
MAX_SEARCH = 200             # max priority-queue steps per logp_next
MAX_BEAM = 10                # max items carried forward
LM_TIMEOUT = 3               # seconds per step

# Expected runtime: ~15-20 min total (TransducedLM ~65s/step, FusedTransducedLM ~35s/step)

lm_results = defaultdict(list)  # name -> [(step, time_s, logp)]

for name, cls in [('TransducedLM', TransducedLM),
                  #('FusedTransducedLM', FusedTransducedLM)
                 ]:
    print(f'\n{name} (max_steps={MAX_SEARCH}, max_beam={MAX_BEAM}):')
    tlm = cls(inner_lm, ptb_fst, max_steps=MAX_SEARCH, max_beam=MAX_BEAM)
    try:
        with timelimit(LM_TIMEOUT):
            state = tlm.initial()
    except (Timeout, MemoryError) as e:
        print(f'  initial() failed: {type(e).__name__}: {e}')
        continue
    for i in range(min(MAX_DECODE, len(target_seq))):
        y = target_seq[i]
        try:
            with timelimit(LM_TIMEOUT):
                t0 = time.perf_counter()
                lp = state.logp_next[y]
                state = state >> y
                t1 = time.perf_counter()
        except Timeout:
            print(f'  step {i+1} TIMEOUT ({LM_TIMEOUT}s)')
            break
        except MemoryError:
            print(f'  step {i+1} OOM')
            break
        elapsed = t1 - t0
        lm_results[name].append((i + 1, elapsed, lp))
        print(f'  {i+1:2d}: {elapsed*1000:8.1f} ms  logp={lp:.4f}')
    gc.collect()


TransducedLM (max_steps=200, max_beam=10):
hello
  initial() failed: Timeout: Call took longer than 3 seconds.


In [None]:
# Summary table
print(f'\n{"Algorithm":<25s} {"Total (s)":>10s} {"Avg/step (s)":>12s} {"Steps":>6s}')
print('-' * 55)
for name, data in sorted(lm_results.items()):
    total = sum(t for _, t, _ in data)
    avg = total / len(data)
    print(f'{name:<25s} {total:10.1f} {avg:12.1f} {len(data):6d}')
if len(lm_results) == 2:
    names = sorted(lm_results.keys())
    d0, d1 = lm_results[names[0]], lm_results[names[1]]
    t0 = sum(t for _, t, _ in d0)
    t1 = sum(t for _, t, _ in d1)
    if t1 > 0:
        print(f'\nFused speedup (overall): {t0/t1:.2f}x')
    # Exclude step 1 (amortization penalty for Fused)
    if len(d0) > 1 and len(d1) > 1:
        t0_skip1 = sum(t for _, t, _ in d0[1:])
        t1_skip1 = sum(t for _, t, _ in d1[1:])
        if t1_skip1 > 0:
            print(f'Fused speedup (step 2+): {t0_skip1/t1_skip1:.2f}x')

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: time per step
ax = axes[0]
for name, data in sorted(lm_results.items()):
    steps = [d[0] for d in data]
    times = [d[1] for d in data]
    ax.plot(steps, times, 'o-', label=name, markersize=4)
ax.set_xlabel('Target step')
ax.set_ylabel('Time per step (s)')
ax.set_title(f'TransducedLM vs Fused (PTB, max_steps={MAX_SEARCH})')
ax.legend()
ax.grid(True, alpha=0.3)

# Right: per-step speedup
ax = axes[1]
if len(lm_results) == 2:
    names = sorted(lm_results.keys())
    d0, d1 = lm_results[names[0]], lm_results[names[1]]
    n = min(len(d0), len(d1))
    steps = [d0[i][0] for i in range(n)]
    speedups = [d0[i][1] / d1[i][1] if d1[i][1] > 0 else 0 for i in range(n)]
    colors = ['#2ecc71' if s > 1 else '#e74c3c' for s in speedups]
    ax.bar(steps, speedups, color=colors, alpha=0.7, edgecolor='white')
    ax.axhline(1.0, color='black', linestyle='--', linewidth=0.8)
    ax.set_xlabel('Target step')
    ax.set_ylabel('Speedup (Original / Fused)')
    ax.set_title('Per-step speedup (>1 = Fused faster)')
    ax.grid(True, alpha=0.3, axis='y')

    logp_diffs = [abs(d0[i][2] - d1[i][2]) for i in range(n)]
    print(f'Max |logp| diff: {max(logp_diffs):.6f}')

plt.tight_layout()
plt.show()

In [None]:
# Restore original memory limit
resource.setrlimit(resource.RLIMIT_AS, (_soft, _hard))
print('Memory limit restored.')
