In [1]:
import contextlib
import io

import numpy as np
from transformers import AutoTokenizer

from transduction import LazyPrecoverNFA, FST, EPSILON
from transduction.lm.statelm import HfTokenizerVocab
from transduction.util import memory_limit
from transduction.viz import visualize_automaton
memory_limit(8);  # cap this kernel's memory (argument presumably in GB — confirm with transduction.util.memory_limit); trailing `;` hides the return value

In [2]:
# GPT-2 tokenizer; use_fast=False selects the slow (pure-Python) implementation.
tokenizer = AutoTokenizer.from_pretrained(
    'gpt2',
    use_fast=False,
)
# `_decode[token_id]` gives that token's bytes (see the sampler output below).
_decode = HfTokenizerVocab(tokenizer).decode

In [3]:
def bpe_fst(decode, token_ids=None, drop=frozenset()):
    """Construct a byte-level BPE transducer over a selection of token ids.

    For every kept token the machine gets a chain of arcs that walks a trie of
    byte-prefix states rooted at ``()``. Each arc in the chain carries labels
    ``(EPSILON, byte)``. The chain ends with one arc labelled
    ``(token_id, EPSILON)`` leading back to ``()``. Assuming ``add_arc`` takes
    ``(src, input, output, dst)``, this maps token-id sequences to byte
    sequences.

    Args:
        decode: indexable token id -> bytes lookup; ``len(decode)`` supplies the
            id range when ``token_ids`` is omitted.
        token_ids: ids to include (default: ``0 .. len(decode) - 1``).
        drop: decoded byte strings to leave out of the machine.

    Returns:
        Whatever ``FST.renumber()`` returns for the constructed machine.
    """
    ids = range(len(decode)) if token_ids is None else token_ids
    root = ()
    machine = FST()
    machine.add_start(root)
    for tok in ids:
        raw = decode[tok]
        if raw not in drop:
            seq = tuple(raw)
            # Emit the token's bytes one prefix state at a time...
            for k, byte in enumerate(seq):
                machine.add_arc(seq[:k], EPSILON, byte, seq[:k + 1])
            # ...then consume the token id itself and return to the root.
            machine.add_arc(seq, tok, EPSILON, root)
    machine.add_stop(root)
    return machine.renumber()

In [4]:
# Full GPT-2 BPE transducer; `[:]` slices the whole vocab (presumably materializing it as a plain sequence — confirm).
fst = bpe_fst(decode=_decode[:])

In [5]:
from transduction.util import sample

In [6]:
class sampler:
    """Endless iterator of uniformly-random token paths through a precover DFA.

    The DFA is ``LazyPrecoverNFA(fst, target)`` determinized (``.det()``) and
    cached (``.cache()``). Each ``next`` starts at the DFA's start state. It
    picks an outgoing arc uniformly at random until a final state is reached,
    then returns the decoded labels of the arcs taken. ``StopIteration`` is
    never raised.

    Args:
        fst: transducer handed to ``LazyPrecoverNFA``.
        target: target sequence handed to ``LazyPrecoverNFA``.
        decode: token id -> bytes lookup for arc labels. Defaults to the
            notebook-level ``_decode`` (looked up at sampling time).
        verbose: when true (default), print the current state's size and up to
            five decoded candidate labels at every step.

    Raises:
        RuntimeError: from ``next`` if a non-final state has no outgoing arcs.
    """
    def __init__(self, fst, target, decode=None, verbose=True):
        self.nfa = LazyPrecoverNFA(fst, target)
        self.dfa = self.nfa.det().cache()
        self.decode = decode
        self.verbose = verbose

    def __iter__(self): return self

    def __next__(self):
        # Resolved per call so the default follows the notebook's `_decode`.
        decode = _decode if self.decode is None else self.decode
        path = []
        # start() yields start states; take the first (presumably the only one after det()).
        state = next(self.dfa.start())
        while not self.dfa.is_final(state):
            actions = list(self.dfa.arcs(state))
            if not actions:
                # Otherwise sample() would get an empty distribution and fail obscurely.
                raise RuntimeError(f'dead end: non-final DFA state {state!r} has no outgoing arcs')
            if self.verbose:
                # len(state): presumably the number of NFA states in this powerset state.
                print(len(state), [decode[label] for label, _ in actions][:5])
            choice = sample(np.ones(len(actions)) / len(actions))
            label, state = actions[choice]
            path.append(decode[label])
        return path

In [None]:
import numpy as np
from transduction.rust_bridge import RustLazyPrecoverDFA

class rust_sampler:
    """Same as `sampler` but traverses the Rust precover DFA directly.

    RustLazyPrecoverDFA wraps PrecoverNFA + PowersetArena in Rust and
    exposes lazy on-demand state expansion — no need to materialize the
    full Q(target) FSA up front.

    Each ``next`` walks from ``dfa.start()`` by uniformly random arc choices
    until a final state is reached. It returns the decoded labels of the arcs
    taken. The iterator never raises ``StopIteration``.

    Args:
        fst: transducer handed to ``RustLazyPrecoverDFA``.
        target: target sequence handed to ``RustLazyPrecoverDFA``.
        decode: token id -> bytes lookup for arc labels. Defaults to the
            notebook-level ``_decode`` (looked up at sampling time).
        verbose: when true (default), print ``powerset_size`` of the current
            state and up to five decoded candidate labels at every step.

    Raises:
        RuntimeError: from ``next`` if a non-final state has no outgoing arcs.
    """
    def __init__(self, fst, target, decode=None, verbose=True):
        self.dfa = RustLazyPrecoverDFA(fst, target)
        self.decode = decode
        self.verbose = verbose

    def __iter__(self): return self

    def __next__(self):
        # Resolved per call so the default follows the notebook's `_decode`.
        decode = _decode if self.decode is None else self.decode
        path = []
        state = self.dfa.start()
        while not self.dfa.is_final(state):
            actions = list(self.dfa.arcs(state))
            if not actions:
                # Otherwise sample() would get an empty distribution and fail obscurely.
                raise RuntimeError(f'dead end: non-final DFA state {state!r} has no outgoing arcs')
            if self.verbose:
                print(self.dfa.powerset_size(state), [decode[label] for label, _ in actions][:5])
            choice = sample(np.ones(len(actions)) / len(actions))
            label, state = actions[choice]
            path.append(decode[label])
        return path

In [8]:
# Pangram (ASCII) used as the target byte string for the precover samplers.
target = b"The quick brown fox jumps over the lazy dog. "

In [None]:
S = rust_sampler(fst, target)
# The sampler is an endless iterator; `next` draws exactly one path
# (equivalent to the old `for x in S: print(x); break`).
print(next(S))

4 [b'T', b'The', b'Th']
7 [b' ', b' qui', b' q', b' qu', b' quick']
7 [b' br', b' brow', b' bro', b' ', b' brown']
2 [b'n']
5 [b' f', b' fox', b' ', b' fo']
3 [b'o', b'ox']
7 [b' jumps', b' ', b' ju', b' jump', b' j']
2 [b's']
6 [b' ov', b' over', b' ', b' o', b' ove']
5 [b'over', b'o', b'ove', b'ov']
5 [b' t', b' the', b' ', b' th']
6 [b' la', b' l', b' laz', b' lazy', b' ']
4 [b'azy', b'a', b'az']
2 [b'y']
5 [b' do', b' dog', b' ', b' d']
4 [b'dog', b'do', b'd']
2 [b'g']
2 [b'.']


In [None]:
# Draw 50 more samples. After each one, report how the DFA's arc cache has grown:
# (size of the largest cached state key, number of cached states).
# Previously the sampling call was commented out, so all 50 lines were identical.
# NOTE(review): `_arcs_cache` is a private attribute, and `len(key)` assumes sized
# state keys — confirm both hold for RustLazyPrecoverDFA (its sampler uses
# `powerset_size(i)` rather than `len(i)`).
for _ in range(50):
    with contextlib.redirect_stdout(io.StringIO()):  # hide the sampler's per-step trace
        next(S)
    arcs_cache = S.dfa._arcs_cache
    print(max(map(len, arcs_cache.keys()), default=0), len(arcs_cache))