# Is the pynini PTB FST a function?

An FST is a **function** if every input string maps to at most one output string.
If some input can produce multiple distinct outputs, it's a **relation**.

We test this by composing `FST.from_string(input) @ ptb_fst`, projecting onto the output,
and checking whether the resulting FSA has more than one string in its language.

In [None]:
# Put the parent directory first on the import path so the project packages
# imported below (`benchmark`, `transduction`) resolve from there.
import sys; sys.path.insert(0, '..')

from collections import deque
from benchmark.fsts.ptb_pynini import build_ptb_fst_pynini, string_to_byte_strs, SEP
from transduction.fst import FST
from transduction.fsa import EPSILON

# Build the PTB FST once; every check in the cells below reuses this instance.
fst = build_ptb_fst_pynini()

In [None]:
def decode_output(output_tuple):
    """Render an FST output symbol sequence as a space-joined token string.

    Symbols are scanned left to right. ``SEP`` closes the token being
    accumulated, ``EPSILON`` is skipped, and any other symbol is kept only
    when its integer value is below 256 (i.e. it fits in one byte). Each
    accumulated run is decoded as UTF-8, with undecodable bytes replaced
    rather than raising. Empty runs (e.g. consecutive separators) produce no
    token.

    Args:
        output_tuple: Iterable of output symbols taken from the FST.

    Returns:
        The decoded tokens joined by single spaces.
    """

    def _flush(pending, words):
        # Turn the buffered byte symbols into one decoded token, if any.
        if pending:
            words.append(bytes(int(b) for b in pending).decode('utf-8', errors='replace'))

    words = []
    pending = []
    for symbol in output_tuple:
        if symbol == SEP:
            _flush(pending, words)
            pending = []
            continue
        if symbol != EPSILON and int(symbol) < 256:
            pending.append(symbol)

    # The final token is not necessarily followed by a separator.
    _flush(pending, words)
    return ' '.join(words)


def enumerate_outputs(fst, text, max_outputs=20):
    """
    Enumerate up to `max_outputs` distinct output strings for a given input text.

    The input is converted with `string_to_byte_strs`, composed with `fst`,
    and projected onto the output side. The resulting automaton is explored
    breadth-first, so outputs are reported in order of discovery.

    Args:
        fst: The transducer to test.
        text: The input string.
        max_outputs: Stop as soon as this many distinct outputs were found.

    Returns:
        A list of distinct output tuples (each a tuple of non-epsilon output
        symbols), in discovery order.

    Warns:
        RuntimeWarning: if the search hits its internal iteration limit
            before the automaton is exhausted. The returned list is then only
            a lower bound, so "at most one output" is NOT established.
    """
    import warnings

    byte_strs = string_to_byte_strs(text)
    input_fst = FST.from_string(byte_strs)
    output_fsa = (input_fst @ fst).project(1)

    results = []
    seen_outputs = set()

    # A search node is (state, output-so-far). Two nodes with the same state
    # and the same output prefix have identical continuations, so each pair is
    # expanded only once. This both breaks epsilon cycles and avoids
    # re-walking the many distinct arc sequences that spell the same output.
    visited = set()
    worklist = deque()
    for s in output_fsa.start:
        node = (s, ())
        if node not in visited:
            visited.add(node)
            worklist.append(node)

    iterations = 0
    max_iterations = 500_000  # safety limit

    while worklist and len(results) < max_outputs:
        if iterations >= max_iterations:
            # Never truncate silently: a partial search could otherwise be
            # mistaken for evidence that the input has a single output.
            warnings.warn(
                f"enumerate_outputs: stopped after {max_iterations} iterations "
                f"for input {text!r}; found {len(results)} output(s) so far, "
                f"result may be incomplete",
                RuntimeWarning,
                stacklevel=2,
            )
            break
        iterations += 1
        state, path = worklist.popleft()

        # Different final states can be reached with the same output, hence
        # the separate de-duplication of completed outputs.
        if state in output_fsa.stop and path not in seen_outputs:
            seen_outputs.add(path)
            results.append(path)
            if len(results) >= max_outputs:
                break

        for a, j in output_fsa.arcs(state):
            # Epsilon arcs move to another state without extending the output.
            new_path = path if a == EPSILON else path + (a,)
            node = (j, new_path)
            if node not in visited:
                visited.add(node)
                worklist.append(node)

    return results


def check_functional(fst, text, max_outputs=10):
    """
    Decide whether `fst` yields at most one distinct output for `text`.

    Outputs are collected with `enumerate_outputs`, which returns no more than
    `max_outputs` entries; a value below 2 therefore can never produce a
    non-functional verdict. An input with no output at all counts as
    functional.

    Returns:
        (is_functional, outputs_list): the verdict and the outputs found.
    """
    distinct_outputs = enumerate_outputs(fst, text, max_outputs=max_outputs)
    has_multiple = len(distinct_outputs) > 1
    return (not has_multiple), distinct_outputs

## Quick sanity check

A simple input should produce exactly one output.

In [None]:
# Smoke test: run the functionality check on one plain sentence and list
# every output that was found for it.
is_fn, outputs = check_functional(fst, "Hello, world!")
print(f"Functional: {is_fn}")
print(f"Number of outputs: {len(outputs)}")
for decoded in map(decode_output, outputs):
    print(f"  -> {decoded!r}")

## Curated test cases

These target areas most likely to cause ambiguity:
quotes, contractions, clitics, punctuation edge cases.

In [None]:
# Hand-picked inputs, grouped by the phenomenon they exercise.
test_cases = [
    # Basic
    "Hello world",
    "Hello, world!",

    # Contractions and clitics
    "I can't do it.",
    "I'll be there.",
    "We've been here.",
    "Don't you think?",
    "It's a test.",
    "She'd gone home.",

    # Contractions that split
    "cannot stop",
    "gonna be great",
    "gotta run",
    "lemme see",
    "wanna go",
    "gimme that",
    "CANNOT STOP",
    "Cannot stop",

    # Contractions3
    "'tis nothing",
    "'twas long ago",
    "'Tis the season",

    # Quotes — potential ambiguity source
    '"Hello," she said.',
    'She said "hello" to me.',
    'He said "don\'t" loudly.',
    '"Can\'t stop," he said.',
    "''Hello,'' she replied",
    # Must be double-quoted: written as 'She said ''hello'' there', Python
    # concatenates three adjacent literals and the doubled apostrophes vanish.
    "She said ''hello'' there",
    '""double quotes""',

    # Nested/adjacent quotes
    '"She said \'hello\'," he replied.',
    '"It\'s fine," she said.',

    # Punctuation
    "Hello...",
    "What?!",
    "a -- b",
    "(hello) [world]",
    "1,000 people",
    "at 3:00 PM",
    "items: none",
    "$100 & more",
    "50% off @ store",

    # Period edge cases
    "Dr. Smith went home.",
    "U.S.A.",
    "end.",
    "end. ",

    # Apostrophe edge cases
    "the dog's bone",
    "the dogs' bones",
    "rock 'n' roll",
    "'twas brillig",

    # Multiple clitics in one sentence
    "I can't believe she'd've done that.",
    "They'll've gone by then.",

    # Mixed contractions and quotes
    '"I can\'t," she said.',
    'He replied, "We\'ll see."',

    # Edge cases that might cause overlapping rule matches
    "''",
    '"',
    "'",
    "a'b",
    "a''b",
    'a"b',
    "d'ye know",
    "more'n enough",

    # Longer text
    'The company reported $1,000,000 in revenue (a 50% increase), which "exceeded expectations."',
]

# Inputs from the curated list that produced more than one output, stored as
# (text, outputs) pairs for the final summary.
non_functional = []

for text in test_cases:
    is_fn, outputs = check_functional(fst, text)
    verdict = '✓' if is_fn else f'✗ ({len(outputs)} outputs)'
    print(f"{verdict} {text!r}")
    if is_fn:
        continue
    non_functional.append((text, outputs))
    # Show at most five of the competing outputs to keep the log readable.
    for decoded in map(decode_output, outputs[:5]):
        print(f"    -> {decoded!r}")

print(f"\n{'='*60}")
print(f"Functional: {len(test_cases) - len(non_functional)}/{len(test_cases)}")
print(f"Non-functional: {len(non_functional)}")

## Adversarial inputs

Target overlapping rule contexts: contractions inside quotes, mixed punctuation,
boundary conditions, adjacent special characters.

In [None]:
# Inputs built to make several rewrite contexts apply at the same position.
adversarial = [
    # Contraction + clitic overlap
    "cannot's",
    "gimme's",
    "gonna's",

    # Contraction at sentence boundary
    "cannot.",
    "cannot!",
    "cannot?",

    # Contraction inside quotes
    '"cannot"',
    '"gimme"',
    '"gonna"',

    # Contraction adjacent to other contractions
    "cannot cannot",
    "I can't cannot go",

    # Multiple apostrophe patterns
    "it's 'cause",
    "she's 'n' he's",
    "'twas it's",

    # Quotes with contractions
    "\"I can't,\" she can't.",
    "''can't''",
    '"cannot stop"',

    # Period with quotes
    'said "hello."',
    "said ''hello.''",
    'He said "Dr. Smith."',

    # Multiple punctuation
    "hello!!",
    "hello??",
    "hello?!",
    "hello!?",
    "a...b...c",

    # Brackets with quotes
    '("hello")',
    '["world"]',
    '(cannot)',

    # Edge: minimal inputs
    "a",
    "ab",
    " ",
    "  ",

    # Double-dash with quotes
    '"hello" -- "world"',
    "a--b--c",

    # Hash with context
    "#hello",
    "a#b",
    "100#",

    # Single character punctuation
    ".",
    ",",
    ":",
    ";",
    "!",
    "?",
    "@",
    "#",
    "%",
    "&",

    # Adjacent special chars
    "@#$%",
    ".,;:!?",

    # Whitespace-sensitive
    " hello",
    "hello ",
    " hello ",
    "  hello  world  ",
]

# Adversarial inputs that produced more than one output, stored as
# (text, outputs) pairs for the final summary.
adv_non_functional = []
for text in adversarial:
    is_fn, outputs = check_functional(fst, text)
    verdict = '✓' if is_fn else f'✗ ({len(outputs)} outputs)'
    print(f"{verdict} {text!r}")
    if is_fn:
        continue
    adv_non_functional.append((text, outputs))
    # Show at most five of the competing outputs to keep the log readable.
    for decoded in map(decode_output, outputs[:5]):
        print(f"    -> {decoded!r}")

print(f"\n{'='*60}")
print(f"Adversarial: {len(adversarial) - len(adv_non_functional)}/{len(adversarial)} functional")

## WikiText paragraphs

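Test on up to 200 real-world paragraphs from the WikiText-103 test split. Section headings are skipped, and each paragraph is detokenized and truncated to its first 200 characters.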

In [None]:
from benchmark.data import load_wikitext, wikitext_detokenize

# Limits for the sample drawn from the dataset.
MAX_PARAGRAPHS = 200   # stop collecting once this many paragraphs are kept
MAX_CHARS = 200        # each paragraph is cut to this many characters
MIN_CHARS = 10         # paragraphs must be longer than this to be kept

dataset = load_wikitext('test')

paragraphs = []
for item in dataset:
    text = item['text'].strip()
    # Skip blank lines and lines that start with '=' (section headings).
    if not text or text.startswith('='):
        continue
    detokenized = wikitext_detokenize(text)[:MAX_CHARS]
    if len(detokenized) <= MIN_CHARS:
        continue
    paragraphs.append(detokenized)
    if len(paragraphs) >= MAX_PARAGRAPHS:
        break

print(f"Loaded {len(paragraphs)} paragraphs from WikiText")

# Paragraphs that produced more than one output, as (text, outputs) pairs.
wiki_non_functional = []
for i, text in enumerate(paragraphs):
    is_fn, outputs = check_functional(fst, text, max_outputs=5)
    if is_fn:
        # Progress line every 50 paragraphs (only on functional ones).
        if (i + 1) % 50 == 0:
            print(f"  ... checked {i + 1}/{len(paragraphs)}, {len(wiki_non_functional)} non-functional so far")
        continue
    wiki_non_functional.append((text, outputs))
    print(f"✗ [{i}] ({len(outputs)} outputs) {text[:80]!r}")
    for decoded in map(decode_output, outputs[:3]):
        print(f"    -> {decoded!r}")

print(f"\n{'='*60}")
print(f"WikiText: {len(paragraphs) - len(wiki_non_functional)}/{len(paragraphs)} functional")

## Random stress test

500 random strings from a punctuation-biased alphabet.

In [None]:
import random
import string

# Fixed seed so the same 500 inputs are generated on every run.
random.seed(42)
# The space and the dash appear twice, which makes them twice as likely to be
# drawn as any other single character.
alphabet = string.ascii_letters + string.digits + "  .,;:!?'\"()[]{}--@#$%&"

rand_non_functional = []  # (text, outputs) for inputs with several outputs
rand_errors = []          # (text, exception) for inputs whose check raised
n_tests = 500
for i in range(n_tests):
    length = random.randint(1, 30)
    text = ''.join(random.choice(alphabet) for _ in range(length))
    try:
        is_fn, outputs = check_functional(fst, text, max_outputs=5)
    except Exception as exc:
        # Keep going, but record the failure: an input that could not be
        # checked must not be reported as functional.
        rand_errors.append((text, exc))
        print(f"! error on {text!r}: {exc!r}")
    else:
        if not is_fn:
            rand_non_functional.append((text, outputs))
            print(f"✗ ({len(outputs)} outputs) {text!r}")
            for o in outputs[:3]:
                print(f"    -> {decode_output(o)!r}")
    if (i + 1) % 100 == 0:
        print(f"  ... tested {i+1}/{n_tests}, {len(rand_non_functional)} non-functional")

# Only inputs that were actually checked enter the functional ratio.
n_checked = n_tests - len(rand_errors)
print(f"\n{'='*60}")
print(f"Random: {n_checked - len(rand_non_functional)}/{n_checked} functional")
if rand_errors:
    print(f"Random: {len(rand_errors)} inputs raised an exception and were not checked")

## Conclusion

In [None]:
# Gather every (text, outputs) counterexample found by the four suites above.
all_non_functional = non_functional + adv_non_functional + wiki_non_functional + rand_non_functional

# Count the inputs from the suites themselves instead of hard-coding a number
# that goes stale whenever a list is edited. `n_tests` counts every generated
# random string, including any that the stress-test cell skipped after an
# exception.
n_inputs = len(test_cases) + len(adversarial) + len(paragraphs) + n_tests

if not all_non_functional:
    print("RESULT: The PTB pynini FST is a FUNCTION.")
    print()
    print(f"No input (out of {n_inputs} tested) produced multiple distinct outputs.")
    print()
    print("This is expected: each pynini cdrewrite rule is functional by construction,")
    print("and composition of functional FSTs preserves functionality.")
else:
    print(f"RESULT: The PTB FST is a RELATION — {len(all_non_functional)} inputs produced multiple outputs.")
    print()
    # List each counterexample together with all outputs collected for it.
    for text, outputs in all_non_functional:
        print(f"Input: {text!r}")
        for o in outputs:
            print(f"  -> {decode_output(o)!r}")
        print()