# TransducedLM Demo

This notebook demonstrates the `TransducedLM` — an incremental language model that computes the **pushforward** of an inner LM through an FST.

Given:
- An inner LM $P_{\text{inner}}(\text{source})$ over source strings
- An FST mapping source → target

TransducedLM computes $P(y \mid \text{target prefix})$ — the distribution over the next target symbol $y$ (or EOS) — by marginalizing over all source strings whose FST output begins with the target prefix, using the peekaboo decomposition for incremental next-symbol computation.

In [None]:
import numpy as np
from collections import defaultdict, Counter

from transduction import examples, FST, EPSILON
from transduction.lm.transduced import TransducedLM, logsumexp

## A simple char-level n-gram LM

The inner LM must implement the `StateLM` interface. We use `CharNgramLM` from `transduction.lm.ngram`, which provides:
- `lm.initial()` → initial state
- `state << token` → advance by one token
- `state.logp_next[token]` → log P(token | context)
- `state.eos` → the EOS token

In [None]:
from transduction.lm.ngram import CharNgramLM

## Helpers: a copy FST and a distribution display

In [None]:
def copy_fst(alphabet):
    """Build a one-state identity transducer over `alphabet`.

    State 0 is both initial and final and carries a self-loop labelled `x:x`
    for every symbol `x`, so every string over `alphabet` maps to itself.
    """
    identity = FST()
    start = 0
    identity.add_I(start)
    identity.add_F(start)
    for sym in alphabet:
        identity.add_arc(start, sym, sym, start)
    return identity


def show_dist(state, symbols, show_zeros=False):
    """Print the next-symbol distribution of `state`, including EOS, as a bar chart.

    Args:
        state: an LM state exposing `logp_next` (indexable by symbol), `eos`,
            and `_peekaboo_state.target` (the current target prefix; this is a
            TransducedLM internal).
        symbols: target symbols to display; EOS is always appended as a last row.
        show_zeros: if False, rows whose probability is below 1e-6 are hidden.

    The final `sum` line is the total probability over `symbols` plus EOS,
    including hidden rows; it should be ~1.0 when `symbols` covers the whole
    target alphabet.
    """
    lp = state.logp_next
    entries = [(repr(y), lp[y]) for y in sorted(symbols)]
    entries.append(('EOS', lp[state.eos]))

    # Sum in log space; the total includes rows that are hidden below.
    total = np.exp(logsumexp([logp for _, logp in entries]))
    print(f"  target = {state._peekaboo_state.target!r}")
    for name, logp in entries:
        # Treat extremely small log-probabilities as exactly zero.
        prob = np.exp(logp) if logp > -50 else 0.0
        if not show_zeros and prob < 1e-6:
            continue
        bar = '#' * int(prob * 40)
        print(f"  P({name:>5s}) = {prob:.4f}  {bar}")
    # Width 8 puts this '=' in the same column as the '  P(.....) =' rows above.
    print(f"  {'sum':>8s} = {total:.6f}")

---
## 1. Copy FST: TransducedLM reproduces the inner LM

With an identity transducer (each symbol maps to itself), the transduced distribution should match the inner LM exactly — including EOS.

In [None]:
inner = CharNgramLM.train("aabbaabb" * 10, n=2, alpha=0.5)
inner_state = inner.initial()
# Build the copy FST over real symbols only: the inner LM's EOS token is not an
# FST symbol (end-of-string is reported through `state.eos` instead).
fst = copy_fst([s for s in inner.alphabet if s != inner_state.eos])
symbols = sorted(fst.B - {EPSILON})

tlm = TransducedLM(inner, fst, max_steps=2000, max_beam=200)
state = tlm.initial()

print("From empty prefix:")
show_dist(state, symbols)

# Side-by-side log-probabilities; with a copy FST these should agree.
print("\nInner LM comparison:")
for y in symbols:
    print(f"  {y!r}: inner={inner_state.logp_next[y]:.4f}  transduced={state.logp_next[y]:.4f}")
print(f"  EOS: inner={inner_state.logp_next[inner_state.eos]:.4f}  transduced={state.logp_next[state.eos]:.4f}")

In [None]:
# Advance the empty-prefix `state` (created in the previous cell) one target
# symbol at a time, showing the next-symbol distribution after each step.
print("After advancing by 'a':")
state_a = state << 'a'
show_dist(state_a, symbols)

print("\nAfter advancing by 'a' then 'b':")
state_ab = state_a << 'b'
show_dist(state_ab, symbols)

---
## 2. Non-trivial FST: `examples.small()`

The `small()` FST has structure (arcs written `input→output→next state`):
- State 0 (initial, final): `a→x→1`, `b→x→2`
- State 1 (final): no outgoing arcs
- State 2: `a→a→3`, `b→b→3`
- State 3 (final): `a→a→3`, `b→b→3`

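So the possible outputs are: `ε` (from the empty source, since state 0 is final), `x` (from source `a`, via state 1), and `x` followed by a *nonempty* string of `a`s and `b`s (from sources `b` followed by at least one more symbol, via states 2→3; state 2 itself is not final).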

Note that state 0 is final, so the empty source string produces the empty output — this gives non-trivial P(EOS) from the initial state.

In [None]:
fst2 = examples.small()
inner2 = CharNgramLM.train("aababba" * 10, n=2, alpha=0.5)
# Target symbols to display (the FST's output alphabet, minus epsilon).
symbols2 = sorted(fst2.B - {EPSILON})

# Inspect the FST: input (A) / output (B) alphabets and initial / final states.
print(f"Input alphabet:  {sorted(fst2.A - {EPSILON})}")
print(f"Output alphabet: {symbols2}")
print(f"Initial: {sorted(fst2.I)}, Final: {sorted(fst2.F)}")
print()

tlm2 = TransducedLM(inner2, fst2, max_steps=2000, max_beam=200)

s = tlm2.initial()
print("Empty prefix — only 'x' and EOS are reachable:")
show_dist(s, symbols2)

In [None]:
# `s` is the empty-prefix state of tlm2 from the previous cell.
s_x = s << 'x'
print("After 'x':")
print("  Source 'a' maps to just 'x' (state 1, final, no outgoing arcs)")
print("  Source 'b...' maps to 'x' + a nonempty continuation (state 2 is not final).")
print("  So P(EOS | 'x') = P(src = 'a') / [P(src = 'a') + P(src = 'b' + more)],")
print("  i.e. it depends on how likely the inner LM is to stop right after 'a':")
show_dist(s_x, symbols2)

In [None]:
# togglecase() maps a→A, b→B, A→a, B→b (and space→space)
# With an inner LM trained on lowercase, the output will be uppercase.
fst_tc = examples.togglecase()
target_alpha_tc = sorted(fst_tc.B - {EPSILON})
source_alpha_tc = sorted(fst_tc.A - {EPSILON})
print(f"togglecase: {source_alpha_tc} → {target_alpha_tc}")

inner_tc = CharNgramLM.train("ab ba ab ba" * 10, n=2, alpha=0.1)
tlm_tc = TransducedLM(inner_tc, fst_tc, max_steps=2000, max_beam=200)

s = tlm_tc.initial()
print("\nEmpty prefix — uppercase symbols should dominate:")
show_dist(s, target_alpha_tc)

s = s << 'A'
print("\nAfter 'A' (inner LM saw 'a' → next likely 'b' → output 'B'):")
show_dist(s, target_alpha_tc)

s = s << 'B'
# Assuming n=2 means a bigram, the inner LM conditions only on the last source
# character 'b'. In the training text 'b' is followed by ' ' and by 'a' equally
# often, so ' ' and 'A' should receive comparable mass here.
print("\nAfter 'AB' (inner saw 'b' → next ' ' or 'a' → output ' ' or 'A'):")
show_dist(s, target_alpha_tc)

---
## 3. Autoregressive decoding

We can use `TransducedLM` for greedy (or sampled) autoregressive decoding. Greedy decoding picks the most likely next symbol at each step and stops when that symbol is EOS (or after `max_len` steps).

In [None]:
# `greedy_decode` and `sample_decode` are available as methods on LM states
# (via LMState): call state.greedy_decode() or state.sample_decode() directly.
#
# For step-by-step output in this demo we use the verbose wrapper below.

def show_greedy(state, max_len=15):
    """Greedily decode from `state`, printing one line per step.

    At every step the argmax of `state.logp_next` is taken. Decoding stops as
    soon as that argmax is the EOS token, or after `max_len` steps.

    Returns:
        (tokens, state): the emitted non-EOS tokens, and the state reached after
        consuming them.
    """
    tokens = []
    for step in range(max_len):
        dist = state.logp_next
        choice = dist.argmax()
        eos_logp = dist[state.eos]
        if choice != state.eos:
            tokens.append(choice)
            print(f"  step {step}: {choice!r}  (logp={dist[choice]:.3f}, P(EOS)={np.exp(eos_logp):.4f})")
            state = state << choice
            continue
        print(f"  step {step}: EOS (logp={eos_logp:.3f})")
        break
    return tokens, state

In [None]:
# Copy FST over {a, b} with an inner LM trained on strictly alternating text.
inner3 = CharNgramLM.train("abababab" * 20, n=2, alpha=0.1)
fst3 = copy_fst(['a', 'b'])
symbols3 = sorted(fst3.B - {EPSILON})
tlm3 = TransducedLM(inner3, fst3, max_steps=2000, max_beam=200)

print("Greedy decoding with copy FST (inner LM trained on 'abababab'):")
decoded, final_state = show_greedy(tlm3.initial())
print(f"  => {''.join(map(str, decoded))!r}  (logp={final_state.logp:.4f})")

# The built-in method does the same greedy decode without per-step output.
tokens = tlm3.initial().greedy_decode(max_len=15)
print(f"\n  state.greedy_decode() => {tokens!r}")

In [None]:
# Each draw starts from a fresh initial state of tlm3.
# NOTE(review): no RNG seed is set here, so the draws presumably differ between
# runs; seed whatever RNG sample_decode uses if reproducible output matters.
print("Sampling (5 draws) via state.sample_decode():")
draws = (tlm3.initial().sample_decode(max_len=15) for _ in range(5))
for i, tokens in enumerate(draws):
    print(f"  {i}: {''.join(map(str, tokens))!r}")

---
## 4. Decoding through a non-trivial FST

Decode through `togglecase()` — the inner LM is trained on lowercase text, but the output is uppercase.

In [None]:
# Inner LM trained on lowercase text, decoded through togglecase().
inner4 = CharNgramLM.train("ab ba ab ba ab" * 10, n=2, alpha=0.1)
fst4 = examples.togglecase()
symbols4 = sorted(fst4.B - {EPSILON})
tlm4 = TransducedLM(inner4, fst4, max_steps=2000, max_beam=200)

print("Greedy decoding through togglecase() FST:")
decoded, _ = show_greedy(tlm4.initial())
print(f"  => {''.join(map(str, decoded))!r}")

# NOTE(review): unseeded, so these draws presumably change from run to run.
print("\nSampling (5 draws) via state.sample_decode():")
draws = (tlm4.initial().sample_decode(max_len=15) for _ in range(5))
for i, tokens in enumerate(draws):
    print(f"  {i}: {''.join(map(str, tokens))!r}")

---
## 5. Normalization check

Verify that at each step, the distribution over symbols + EOS sums to 1.

In [None]:
inner5 = CharNgramLM.train("aabbaabb" * 10, n=2, alpha=0.5)
# Exclude the inner LM's EOS token: it is not an FST symbol.
eos5 = inner5.initial().eos
fst5 = copy_fst([s for s in inner5.alphabet if s != eos5])
symbols5 = sorted(fst5.B - {EPSILON})

tlm5 = TransducedLM(inner5, fst5, max_steps=2000, max_beam=200)


def report_normalization(state, symbols):
    """Print the total next-step probability (symbols + EOS) and P(EOS) at `state`.

    Returns the total probability so the caller can check it against 1.
    """
    lp = state.logp_next
    total = np.exp(logsumexp([lp[y] for y in symbols] + [lp[state.eos]]))
    print(f"  target={state._peekaboo_state.target!r:>6s}  "
          f"sum={total:.10f}  "
          f"P(EOS)={np.exp(lp[state.eos]):.6f}")
    return total


# Check the initial state, then the state after each prefix of 'abab'.
state = tlm5.initial()
totals = [report_normalization(state, symbols5)]
for target_sym in ['a', 'b', 'a', 'b']:
    state = state << target_sym
    totals.append(report_normalization(state, symbols5))

print(f"  max |sum - 1| = {max(abs(t - 1.0) for t in totals):.2e}")