In [0]:
import os

# HuggingFace token used to load the `mistralai/Mistral-7B-v0.1` tokenizer.
# It is read from the HF_TOKEN environment variable when that is set, so that no
# real credential has to be stored in the notebook; "xxx" is only a placeholder.
hf_token = os.environ.get("HF_TOKEN") or "xxx"

In [0]:
import copy
import io
import pickle
import os
import time
import json
import functools
from collections import defaultdict, Counter

import numpy as np
import torch
import transformers
import pygtrie

from pyformlang.cfg.variable import Variable
from pyformlang.cfg.terminal import Terminal
from pyformlang.cfg.production import Production
from pyformlang.cfg.cfg import CFG
from pyformlang.regular_expression import Regex
from pyformlang.finite_automaton.deterministic_finite_automaton import DeterministicFiniteAutomaton as DFA
from pyformlang.finite_automaton import EpsilonNFA

import networkx as nx
from networkx.drawing.nx_agraph import to_agraph
import matplotlib.pyplot as plt

# Tokenizer

In [0]:
# Load the tokenizer of the `mistralai/Mistral-7B-v0.1` checkpoint, passing
# `hf_token` as the authentication token.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    token=hf_token
)

def get_token_id(token, tokenizer):
    """
    Look up the id of `token` in the model of the tokenizer's backend.
    """
    backend_model = tokenizer.backend_tokenizer.model
    token_id = backend_model.token_to_id(token)
    return token_id

def get_token(token_id, tokenizer):
    """
    Map a token id back to its token string.

    The token `<0x0A>` is returned as the line break character "\\n".
    """
    raw_token = tokenizer.backend_tokenizer.model.id_to_token(token_id)
    return "\n" if raw_token == "<0x0A>" else raw_token

def encode(s, tokenizer):
    """
    Encode a string into a list of token ids.
    """
    backend = tokenizer.backend_tokenizer
    # Unlike the standard encode function, the first character of the
    # normalized string (the blank space added at the beginning) is dropped.
    text = backend.normalizer.normalize_str(s)[1:]
    token_ids = []
    for piece in backend.model.tokenize(text):
        # The first field of the tuple representation is used as the id.
        token_ids.append(piece.as_tuple()[0])
    return token_ids

In [0]:
# We restrict ourselves to printable ASCII characters (which already include
# the blank space) and the line break.

characters = {chr(i) for i in range(32, 127)}
characters.add("\n")

# `vocabulary` maps the text of each retained token to its token id.
vocabulary = {}
tokenizer_vocab = tokenizer.vocab
for w in tokenizer_vocab:
    if w == "<0x0A>":
        # This token is stored under the line break character.
        vocabulary["\n"] = tokenizer_vocab[w]
        continue
    if w.startswith("<0x"):
        # The other `<0x..>` tokens are not part of the restricted vocabulary.
        continue
    if any(c not in characters and c != "▁" for c in w):
        # The token contains a character outside the restricted alphabet.
        # (The original loop used `break` here, which had no effect: such
        # tokens were still added to the vocabulary.)
        continue
    # "▁" is the tokenizer's marker for a blank space.
    vocabulary[w.replace("▁", " ")] = tokenizer_vocab[w]

# We create a trie corresponding to the available tokens to accelerate the creation of the token-level NFA.

trie = pygtrie.CharTrie()
for t in vocabulary:
    trie[t] = vocabulary[t]

# Grammar

In [0]:
import lark
from lark.indenter import PythonIndenter
import interegular
import re

# In the context of this notebook, we focus on the Python grammar, as specified in `lark`.

# Build the parser from the `python.lark` grammar file: LALR parsing with the
# basic lexer, `PythonIndenter` as post-lexer, and `file_input` as start rule.
parser = lark.Lark.open(
    'python.lark',
    parser='lalr',
    lexer='basic',
    postlex=PythonIndenter(),
    start='file_input'
)

# The set of terminals starts with the terminals declared in the lexer
# configuration; the non-terminals are collected from the rules.
terminals = {terminal_def.name for terminal_def in parser.lexer_conf.terminals}
non_terminals = set()

def get_name(non_terminal):
    """
    Return the name of a grammar symbol.

    The fourth quote-delimited field of `fullrepr` is used when it exists
    (presumably the rule name when the symbol's name is a lark token, to be
    confirmed); otherwise the `name` attribute is returned.
    """
    quoted_fields = non_terminal.fullrepr.split("'")
    if len(quoted_fields) > 3:
        return quoted_fields[3]
    return non_terminal.name

# Collect the names of all symbols appearing in the rules of the grammar.
for rule in parser.rules:
    name = get_name(rule.origin)
    non_terminals.add(name)
    for symbol in rule.expansion:
        symbol_type = type(symbol)
        if symbol_type == lark.grammar.NonTerminal:
            non_terminals.add(symbol.name)
        elif symbol_type == lark.grammar.Terminal:
            terminals.add(symbol.name)

print(len(parser.rules), "rules,", len(terminals), "terminals,", len(non_terminals), "non-terminals")

In [0]:
# Compute the immediate terminal or non-terminal successors of a symbol.
# The keys are node labels: `<rule>_start`, `<rule>_end`, or the name of a
# terminal. Each value is the set of labels that can directly follow the key.

successors = defaultdict(set)

def get_name(non_terminal):
    """
    Return the name of a grammar symbol.

    The fourth quote-delimited field of `fullrepr` is used when it exists
    (presumably the rule name when the symbol's name is a lark token, to be
    confirmed); otherwise the `name` attribute is returned.
    """
    quoted_fields = non_terminal.fullrepr.split("'")
    if len(quoted_fields) > 3:
        return quoted_fields[3]
    return non_terminal.name

def is_terminal(node):
    """
    Tell whether a node label denotes a terminal, i.e. whether lowercasing the
    label changes it.
    """
    lowered = node.lower()
    return lowered != node

# Each rule `name: s1 s2 ... sn` contributes a chain of edges
# `name_start -> s1 -> s2 -> ... -> sn -> name_end`. A non-terminal symbol is
# entered through its `_start` node and left through its `_end` node, while a
# terminal is a single node.
for r in parser.rules:
    name = get_name(r.origin)
    current = f"{name}_start"
    for node in r.expansion:
        if isinstance(node, lark.grammar.NonTerminal):
            new_name = get_name(node)
            # `successor` rather than `next`, so that the builtin `next` is not
            # shadowed in the notebook's global namespace.
            successor, new_current = f"{new_name}_start", f"{new_name}_end"
        else:
            successor, new_current = node.name, node.name
        successors[current].add(successor)
        current = new_current
    successors[current].add(f"{name}_end")

# Compute the immediate terminal successors of a symbol

# For each node `t`, `terminal_successors[t]` maps every terminal reachable
# from `t` without crossing another terminal to a tuple of the nodes traversed
# to reach it.
terminal_successors = defaultdict(dict)

for t in successors:
    visited = set()
    pending = [(str(t), ())]
    while pending:
        node, path = pending.pop()
        if node not in successors:
            continue
        extended_path = path + (node,)
        for next_node in successors[node]:
            if next_node in visited:
                continue
            if is_terminal(next_node):
                terminal_successors[t][next_node] = extended_path
            else:
                # Only non-terminal nodes are marked as visited and expanded.
                pending.append((next_node, extended_path))
                visited.add(next_node)

# Store the priorities of the terminals

# The priority of a terminal is its position in the list of terminals of the
# lexer; a few entries are then set explicitly.
terminal2priority = {}
for rank, terminal_def in enumerate(parser.parser.lexer.lexer.terminals):
    terminal2priority[terminal_def.name] = rank
terminal2priority.update({
    "": -1,
    "UNDERSCORE": 1,
    "_DEDENT": 1,
    "_INDENT": 1,
})

# Slightly adjust the regex of the terminals so that they are compatible with interegular.
# The map starts from the regex of each lexer terminal; the entries for STRING,
# LONG_STRING and DEC_NUMBER are then replaced by hand-written patterns, and
# _DEDENT / _INDENT are given an empty pattern.

terminal2regex = {t.name: t.pattern.to_regexp() for t in parser.lexer_conf.terminals}
terminal2regex["STRING"] = r"""([ubfr]|ur|br|rb|fr|rf)?(?:'[^'\\\n]*(?:\\.[^'\\\n]*)*'|"[^"\\\n]*(?:\\.[^"\\\n]*)*")"""
terminal2regex["LONG_STRING"] =  r'''(?:(?:[ubf]?[rR]?|[rR]?[ubf]?)(?:(?:"""(?:[^\\]|\\.|\\\n)*?"""''' + r"""|'''(?:[^\\]|\\.|\\\n)*?''')))"""
terminal2regex["DEC_NUMBER"] = r"0|(?:[1-9](?:(?:[0-9])|(?:[0-9]\_[0-9]))*)"
terminal2regex["_DEDENT"] = ""
terminal2regex["_INDENT"] = ""

# Identify and display sets of terminals that are mutually interchangeable

# For each terminal, record every context in which it occurs: the name of the
# rule and its expansion, where the position of the terminal is replaced by "X".
terminal2rules = defaultdict(set)

for r in parser.rules:
    origin_name = r.origin.name
    name = origin_name if type(origin_name) == str else origin_name.value
    symbols = [x.name for x in r.expansion]
    for i, symbol in enumerate(symbols):
        if symbol.upper() == symbol:
            context = symbols[:i] + ["X"] + symbols[i+1:]
            terminal2rules[symbol].add(name + ":" + "-".join(context))

# Two terminals are equivalent when they occur in exactly the same contexts.
equivalent_terminals = {
    t: [t2 for t2 in terminal2rules if terminal2rules[t] == terminal2rules[t2]]
    for t in terminal2rules
}
# Each terminal is represented by the smallest name of its class.
terminal_replacement = {t: min(members) for t, members in equivalent_terminals.items()}
terminals_to_replace = {
    t for t, representative in terminal_replacement.items() if representative != t
}

def get_regex(t):
    """
    Return a display string for terminal `t`: its regex without backslashes
    when the name starts with `__ANON_`, and the name itself otherwise.
    """
    if not t.startswith("__ANON_"):
        return t
    return terminal2regex[t].replace("\\", "")

# Group the terminals by their representative and display the classes that
# contain more than one terminal.
equivalence_classes = defaultdict(set)
for member, representative in terminal_replacement.items():
    equivalence_classes[representative].add(member)
print("Sets of interchangeable terminals:")
for members in equivalence_classes.values():
    if len(members) > 1:
        print("  "+", ".join([get_regex(t) for t in members]))

# Character-level NFA

In [0]:
class CharacterNFA():
    """
    Class to define an NFA at the character level.

    Attributes:
        states: set of state identifiers.
        alphabet: set of characters labelling the transitions.
        map: transition function, `map[state][symbol]` is a set of states.
        initial_state: identifier of the start state.
        final_states: set of accepting states.
    """
    def __init__(self, states, alphabet, map, initial_state, final_states):
        self.states = states
        self.alphabet = alphabet
        self.map = map
        self.initial_state = initial_state
        self.final_states = final_states

    def add_states(self, new_states):
        """
        Add `new_states` to the automaton, each with an empty transition table.

        States that are already known keep their existing transitions.
        """
        self.states = self.states.union(new_states)
        for state in new_states:
            # `setdefault` instead of a plain assignment, so that re-adding an
            # existing state does not erase its outgoing transitions.
            self.map.setdefault(state, {})

    def add_transition(self, state, symbol, new_state):
        """
        Add the transition `state --symbol--> new_state`.
        """
        if state not in self.map:
            self.map[state] = {}
        if symbol not in self.map[state]:
            self.map[state][symbol] = {new_state}
        else:
            self.map[state][symbol].add(new_state)

    def add_final_states(self, new_final_states):
        """
        Mark the states of `new_final_states` as accepting.
        """
        self.final_states = self.final_states.union(new_final_states)

In [0]:
# We create a character-level NFA to identify potential sequences of terminals.
# For this, we assemble the DFAs corresponding to the regex of each terminal.

# The NFA initially contains only the start state 0, which has no transition
# and is associated with the empty terminal name.
character_nfa = CharacterNFA(
    states={0},
    alphabet=characters,
    map={0: {}},
    initial_state=0,
    final_states=set(),
)

# Bookkeeping between the states of the NFA and the terminals of the grammar.
terminal2final_states = defaultdict(set)
terminal2initial_state = {}
state2terminal = {0: ""}

def get_fsm(terminal):
    """
    Parse the regex of `terminal` with interegular and return the reduced FSM.
    """
    pattern = interegular.parse_pattern(terminal2regex[terminal])
    fsm = pattern.to_fsm()
    return fsm.reduce()

# Embed the automaton of every terminal into `character_nfa`, one after the
# other, shifting its state ids so that they do not collide with existing ones.
for t in terminal2regex:
    # We create the DFA associated with the regex of the terminal.
    fsm = get_fsm(t)

    # We map the indices used in the FSM to the corresponding characters.
    # `fsm.alphabet` maps each key to a symbol index; the string keys are
    # remembered separately because non-string keys are expanded below to
    # "every character that is not listed explicitly".
    symbol2character = defaultdict(set)
    characters_in_map = set()
    for k in fsm.alphabet:
        symbol2character[fsm.alphabet[k]].add(k)
        if type(k) == str:
            characters_in_map.add(k)

    # We add one state to the NFA for each state of the DFA.
    # NOTE(review): the offset assumes that the ids in `fsm.states` are
    # non-negative integers small enough not to overlap with the ids of the
    # next terminal (presumably 0..n-1 after `reduce()`) — confirm.
    delta = len(character_nfa.states)
    new_states = {k + delta for k in fsm.states}
    for s in new_states:
        state2terminal[s] = t
    character_nfa.add_states(new_states)
    terminal2initial_state[t] = fsm.initial + delta
    for k in fsm.finals:
        terminal2final_states[t].add(k + delta)

    # We add the edges corresponding to the edges of the DFA.
    for node in fsm.map:
        for symbol in fsm.map[node]:
            origin = node + delta
            destination = fsm.map[node][symbol] + delta
            for character in symbol2character[symbol]:
                if type(character) == str:
                    # Characters outside the restricted alphabet are dropped.
                    if character not in characters:
                        continue
                    character_nfa.add_transition(origin, character, destination)
                else:
                    # Non-string key (presumably interegular's "anything else"
                    # marker — confirm): the edge is added for every character
                    # of the alphabet that has no explicit entry in this FSM.
                    for c in characters:
                        if c not in characters_in_map:
                            character_nfa.add_transition(origin, c, destination)

# The final states of the NFA are the final states of all the sub-DFAs.
for finals in terminal2final_states.values():
    character_nfa.add_final_states(finals)

# We connect the start state of the NFA with the states of the sub-DFAs.
# The transitions are gathered first and only then added to the automaton.
new_transitions = [
    (0, c, destination)
    for t in terminal2regex
    if terminal2initial_state[t] in character_nfa.map
    for c, destinations in character_nfa.map[terminal2initial_state[t]].items()
    for destination in destinations
]

for new_transition in new_transitions:
    character_nfa.add_transition(*new_transition)
del new_transitions

print(len(character_nfa.states), "states,", sum([len(character_nfa.map[s]) for s in character_nfa.map]), "transitions")

In [0]:
class PartialPythonIndenter(PythonIndenter):
    """
    A subclass of the Lark Python indenter to process partial strings.

    The terminals are consumed incrementally with `consume`; the pending
    dedents are only emitted when `consume` is called with `final=True`.
    """
    def handle_newline(self, terminal):
        """
        Return the tokens produced by a newline terminal: the terminal itself,
        followed by one indent token or by zero or more dedent tokens.

        Raises:
            lark.indenter.DedentError: if the new indentation does not match
                any enclosing indentation level.
        """
        # Inside parentheses, line breaks do not affect the indentation.
        if self.paren_level > 0:
            return []

        result = [terminal]

        indent_str = terminal.rsplit('\n', 1)[1] # Tabs and spaces
        indent = indent_str.count(' ') + indent_str.count('\t') * self.tab_len

        if indent > self.indent_level[-1]:
            self.indent_level.append(indent)
            result.append(lark.Token.new_borrow_pos(self.INDENT_type, indent_str, terminal))
        else:
            while indent < self.indent_level[-1]:
                self.indent_level.pop()
                result.append(lark.Token.new_borrow_pos(self.DEDENT_type, indent_str, terminal))

            if indent != self.indent_level[-1]:
                # `DedentError` is defined in the `lark.indenter` module, so it
                # is referenced there rather than on the top-level package.
                raise lark.indenter.DedentError('Unexpected dedent to column %s. Expected dedent to %s' % (indent, self.indent_level[-1]))
        return result

    def consume(self, terminals, final=False):
        """
        Process a list of terminals and return the resulting list of tokens,
        including the indent and dedent tokens triggered by the newlines.

        When `final` is True, one dedent token is emitted for each indentation
        level that is still open.
        """
        result = []
        for terminal in terminals:
            if terminal.type == self.NL_type:
                result += self.handle_newline(terminal)
            else:
                result.append(terminal)

            # Track the nesting depth of the parentheses.
            if terminal.type in self.OPEN_PAREN_types:
                self.paren_level += 1
            elif terminal.type in self.CLOSE_PAREN_types:
                self.paren_level -= 1
                assert self.paren_level >= 0

        if final:
            while len(self.indent_level) > 1:
                self.indent_level.pop()
                result.append(lark.Token(self.DEDENT_type, ''))
            assert self.indent_level == [0], self.indent_level

        return result

In [0]:
# We check the syntactic validity of a Python script with the character NFA and the interactive parser.

# Example script whose terminals are extracted character by character.
prompt = """def fun(n):
    if n > 0:
        if n % 2 == 0.1: # Blabla
            return True # Blalbla
    # Blabla
    return False
"""

# Terminals that the lexer ignores; they are not fed to the parser.
terminals_to_ignore = set(parser.ignore_tokens)
interactive_parser = parser.parse_interactive("")
# Not read again in this cell.
acceptable_terminals = interactive_parser.accepts()
indenter = PartialPythonIndenter()
# Characters of the terminal that is currently being read.
partial_string = ""
# Tokens fed to the interactive parser so far.
path = []

# Current set of states of the character-level NFA (0 is the start state).
states = [0]

def next_states(state, c):
    """
    Return the states reachable from `state` by reading the character `c`,
    together with a flag telling whether `c` starts a new terminal.

    A new terminal can only start when `state` is final and has no transition
    on `c` of its own.
    """
    transitions = character_nfa.map[state]
    if c in transitions:
        targets = transitions[c]
        if len(targets) > 0:
            return targets, False
        return [], False
    if state in character_nfa.final_states and c in character_nfa.map[0]:
        return character_nfa.map[0][c], True
    return [], False

# Read the prompt one character at a time. As long as at least one state can
# continue within its current terminal, the character is appended to the
# current partial string. Otherwise the current terminal is complete: it is
# passed through the indenter and fed to the parser, and the character starts
# a new terminal.
for c in prompt:
    with_new_terminals, without_new_terminals = [], []
    highest_priority = -1
    for state in states:
        new_states, new_terminal = next_states(state, c)
        for new_state in new_states:
            if not new_terminal:
                without_new_terminals.append(new_state)
            else:
                # Among the states that complete a terminal, only those whose
                # terminal has the highest priority are kept.
                priority = terminal2priority[state2terminal[state]]
                if priority < highest_priority:
                    continue
                if priority > highest_priority:
                    highest_priority = priority
                    with_new_terminals = []
                with_new_terminals.append((state, new_state))
    if len(without_new_terminals) > 0:
        states = without_new_terminals
        partial_string += c
    else:
        incoming_terminals, states = set(), []
        for (incoming_state, new_state) in with_new_terminals:
            incoming_terminal = state2terminal[incoming_state]
            incoming_terminals.add(incoming_terminal)
            states.append(new_state)
        # Exactly one terminal must have been completed; the assertion also
        # fails when no state can read the character at all.
        assert len(incoming_terminals) == 1
        if incoming_terminal != "" and incoming_terminal not in terminals_to_ignore:
            for x in indenter.consume([lark.Token(incoming_terminal, partial_string)]):
                path.append(x)
                interactive_parser.feed_token(x)
        partial_string = c

# If the text read last is a newline terminal, it is consumed with `final=True`
# so that the indenter also emits the pending dedents.
if len(states) > 0 and state2terminal[states[0]] == "_NEWLINE":
    for x in indenter.consume([lark.Token("_NEWLINE", partial_string)], final=True):
        path.append(x)
        interactive_parser.feed_token(x)

# Display the sequence of tokens fed to the parser.
path

# Token-level NFA

In [0]:
class TokenNFA():
    """
    Class to define an NFA at the token level.

    `map[state][token_id]` is a list of `(new_state, path, strings)` triples.
    """
    def __init__(self, states, initial_state, final_states):
        self.states = states
        self.initial_state = initial_state
        self.final_states = final_states
        # Every state starts with an empty transition table.
        self.map = {}
        for s in states:
            self.map[s] = {}

    def add_transition(self, state, token_id, new_state, path, strings):
        """
        Add a transition from `state` to `new_state` on `token_id`.

        The transition is labelled with the corresponding sequence of
        terminals (`path`) and the associated strings.
        """
        targets = self.map[state].setdefault(token_id, [])
        targets.append((new_state, path, strings))

In [0]:
%%time

# We create the token-level NFA corresponding to the terminals of the grammar.
# For this, we successively apply the transitions of the character-level NFA.
# A challenge is to properly take into account the priority of the terminals.

# A terminal is treated as a keyword when the regex of NAME matches the
# beginning of the terminal's own regex string.
name_pattern = re.compile(terminal2regex["NAME"])
keywords = [
    terminal_name
    for terminal_name, pattern in terminal2regex.items()
    if name_pattern.match(pattern)
]

def get_priority(terminals, strings):
    """
    Returns the priority of a path.

    The result is a flat list with two entries per terminal: the cumulative
    number of characters read up to the end of the terminal, followed by the
    priority of the terminal.
    """
    priority = []
    consumed = 0
    for i, terminal in enumerate(terminals):
        consumed += len(strings[i])
        priority.extend((consumed, terminal2priority[terminal]))
    return priority

def compare(p1, p2):
    """
    Compare two priorities (represented as a tuple) over their common length.

    Returns -1, 0 or 1 when the truncated `p1` is respectively lower than,
    equal to, or greater than the truncated `p2`.
    """
    common = min(len(p1), len(p2))
    head1, head2 = p1[:common], p2[:common]
    if head1 == head2:
        return 0
    return -1 if head1 < head2 else 1

# The token-level NFA shares its states with the character-level NFA. For every
# state, the trie of tokens is traversed once: the transitions of a token are
# derived from those of its longest strict prefix that is itself a token, so
# that a shared prefix is only processed once.
token_nfa = TokenNFA(character_nfa.states, character_nfa.initial_state, character_nfa.final_states)
for state in token_nfa.states:
    start_terminal = state2terminal[state]
    # The `traverse_callback` function is later used to traverse the trie to avoid redundant computations.
    # It reads `state` and `start_terminal` from the enclosing loop iteration.
    # Returning 0 without reading `children` stops the exploration of the
    # tokens that extend the current one (presumably because `children` is
    # evaluated lazily by pygtrie — confirm).
    def traverse_callback(path_conv, path, children, token_id=-1):
        # We check if the current node corresponds to a token (with token_id ≥ 0)
        if token_id >= 0:
            token = path_conv(path)
            if len(token) == 1:
                # If the length of the token is 1, we simply copy the transition of the character NFA
                if token in character_nfa.map[state]:
                    for new_state in character_nfa.map[state][token]:
                        if state == 0:
                            # From the start state, the character opens a terminal.
                            terminals = ["", state2terminal[new_state]]
                            strings = ["", token]
                        else:
                            terminals = [start_terminal]
                            strings = [token]
                        token_nfa.add_transition(state, token_id, new_state, terminals, strings)
                elif state in character_nfa.final_states and state != 0 and token in character_nfa.map[0]:
                    # The current terminal is complete and the character starts a new one,
                    # unless a keyword would be directly followed by a character matching NAME.
                    if state2terminal[state] not in keywords or not name_pattern.match(token):
                        for new_state in character_nfa.map[0][token]:
                            terminals = [start_terminal, state2terminal[new_state]]
                            strings = ["", token]
                            token_nfa.add_transition(state, token_id, new_state, terminals, strings)
                else:
                    return 0
            else:
                # If the length is > 1, we start from the longest strict prefix of the token...
                previous_token, previous_token_id = trie.longest_prefix(token[:-1])
                if previous_token_id in token_nfa.map[state]:
                    stack = token_nfa.map[state][previous_token_id]
                else:
                    return 0
                #... and for each character between the prefix and the token, we follow the character NFA.
                for i in range(len(previous_token), len(token)):
                    c = token[i]
                    new_configurations = []
                    for current_state, current_terminals, strings in stack:
                        if c in character_nfa.map[current_state]:
                            for new_state in character_nfa.map[current_state][c]:
                                if current_state == 0:
                                    # If we are at the start state, the character leads to a new state.
                                    new_configurations.append((new_state, [state2terminal[new_state]], [c]))
                                else:
                                    # Otherwise, we necessarily stay in the same terminal DFA.
                                    new_configurations.append((new_state, current_terminals, strings[:-1] + [strings[-1] + c]))
                        elif current_state in character_nfa.final_states and current_state != 0 and c in character_nfa.map[0]:
                            # A keyword cannot be directly followed by a character matching NAME.
                            if state2terminal[current_state] in keywords:
                                if name_pattern.match(c):
                                    continue
                            for new_state in character_nfa.map[0][c]:
                                new_configurations.append((new_state, current_terminals + [state2terminal[new_state]], strings + [c]))
                    # We favor continuations which don't lead to a new terminal
                    stack = []
                    if len(new_configurations) == 0:
                        return 0
                    else:
                        # The configurations are grouped by priority. Priorities are compared
                        # without their last two entries (the entries of the last terminal).
                        prioritized = defaultdict(list)
                        for k in new_configurations:
                            new_priority = tuple(get_priority(*k[1:]))
                            if len(prioritized) == 0:
                                prioritized[tuple(new_priority)].append(k)
                            else:
                                to_add, to_remove = False, set()
                                for priority in list(prioritized.keys()):
                                    comparison = compare(new_priority[:-2], priority[:-2])
                                    if comparison == 0:
                                        to_add = True
                                    elif comparison == 1:
                                        to_add = True
                                        to_remove.add(priority)
                                    else:
                                        # A stored priority is higher: the configuration is dropped.
                                        break
                                else:
                                    # Reached only when no stored priority is higher than the new one.
                                    if to_add:
                                        prioritized[tuple(new_priority)].append(k)
                                        for priority in to_remove:
                                            del prioritized[priority]
                        for p in prioritized:
                            stack += prioritized[p]

                for (new_state, terminals, strings) in stack:
                    token_nfa.add_transition(state, token_id, new_state, terminals, strings)

        # Reading `children` continues the traversal below the current node.
        return sum(children)

    trie.traverse(traverse_callback)

# Total number of labelled transitions of the token-level NFA.
num_transitions = sum(
    len(targets)
    for transitions in token_nfa.map.values()
    for targets in transitions.values()
)
print(len(token_nfa.states), "states,", num_transitions, "transitions")

In [0]:
# In `token_nfa.next_tokens` and `token_nfa.mask`, we rearrange the transition function
# to group together the tokens that lead to the same sequence of terminals

# The signature of `token_nfa.next_tokens` is state x (sequence_of_terminals, new_characters) --> set of tokens
# where `new_characters` is a boolean equal to True if the first characters of the token extend the current terminal.

# The signature of `token_nfa.mask` is state x (sequence_of_terminals, new_characters) --> mask
# where `mask` is a binary vector whose length is the size of the vocabulary.

# Group the tokens leaving each state by the sequence of new (non-ignored)
# terminals they produce and by whether their first characters extend the
# current terminal.
token_nfa.next_tokens = {}
for s, transitions in token_nfa.map.items():
    terminal = state2terminal[s]
    grouped = token_nfa.next_tokens[s] = {}
    for token_id, targets in transitions.items():
        for (_, new_terminals, strings) in targets:
            path = tuple(t for t in new_terminals[1:] if t not in parser.ignore_tokens)
            new_characters = len(strings[0]) > 0
            grouped.setdefault((path, new_characters), set()).add(token_id)

# Turn every non-empty set of tokens into a boolean mask over the vocabulary.
# Entries whose set of tokens is empty get no mask and are removed from
# `token_nfa.next_tokens`.
token_nfa.mask = {}
to_delete = set()
for s in token_nfa.next_tokens:
    token_nfa.mask[s] = {}
    for (path, new_characters), token_ids in token_nfa.next_tokens[s].items():
        if len(token_ids) > 0:
            mask = torch.zeros((tokenizer.vocab_size,), dtype=torch.bool)
            mask[list(token_ids)] = 1
            token_nfa.mask[s][(path, new_characters)] = mask
        else:
            # The key to delete is the key of the current entry (the original
            # code referred to `k`, which is not bound in this loop).
            to_delete.add((s, (path, new_characters)))
# The deletion is deferred so that the dictionaries are not modified while
# they are being iterated over.
for s, k in to_delete:
    del token_nfa.next_tokens[s][k]

In [0]:
print("Number of entries in the mask store:", sum(len(masks) for masks in token_nfa.mask.values()))

# A configuration is the terminal of a state together with a key of its
# `next_tokens` table.
unique_paths = {
    (state2terminal[s],) + k
    for s in token_nfa.next_tokens
    for k in token_nfa.next_tokens[s]
}
print(len(unique_paths), "unique configurations that can be tested at inference time")

In [0]:
def get_length_common_prefix(strings):
    """
    Return the length of the longest prefix shared by all the sequences of
    `strings`, where the last element of each sequence is never counted.

    Args:
        strings: an iterable of sequences (e.g. lists of terminal names).

    Returns:
        The length of the common prefix, or 0 when `strings` is empty.
    """
    prefix = None
    for s in strings:
        # The last element of a sequence is excluded from the comparison.
        candidate = s[:-1]
        if prefix is None:
            prefix = candidate
            continue
        length = 0
        for left, right in zip(prefix, candidate):
            if left != right:
                break
            length += 1
        prefix = prefix[:length]

    # `prefix` is still None when `strings` is empty.
    return 0 if prefix is None else len(prefix)

In [0]:
# Mask of the tokens that start with exactly one blank space.
mask_leading_space = torch.zeros((tokenizer.vocab_size,), dtype=torch.bool)
mask_leading_space[
    [token_id for w, token_id in vocabulary.items() if w.startswith(" ") and not w.startswith("  ")]
] = 1

# Mask of the tokens that do not start with a blank space.
mask_no_leading_space = torch.zeros((tokenizer.vocab_size,), dtype=torch.bool)
mask_no_leading_space[
    [token_id for w, token_id in vocabulary.items() if not w.startswith(" ")]
] = 1

# Id of the last token produced when encoding a line break.
linebreak_token_id = tokenizer.encode("\n")[-1]

class PythonMaskGenerator():
    """
    Class to generate masks to enforce a valid Python syntax.
    """
    def __init__(self, parser, token_nfa):
        self.token_nfa = token_nfa
        self.parser = parser
        self.reset()

    def reset(self):
        """
        Reset the states of the mask generator to generate a new string.
        """
        self.indenter = PartialPythonIndenter()
        self.states = [(0, [""], [""], -1)]
        self.partial_path = []
        self.partial_strings = []
        self.interactive_parser = self.parser.parse_interactive("")
        self.interactive_parser2 = self.parser.parse_interactive("")
        self.acceptable_terminals = self.interactive_parser.accepts()
        self.temp_interactive_parser = None

    def consume(self, token_id):
        """
        Update the states of the mask generator by consuming one token.

        Every current configuration is advanced along the transitions of the
        token-level NFA labelled with `token_id`, and only the configurations
        with the best priority are kept. The terminals shared by all remaining
        configurations (except the last terminal of each) are then fed to the
        indenter and to the interactive parser.

        NOTE(review): when no configuration accepts `token_id`, `new_states`
        is empty and the call fails — confirm that callers only pass tokens
        allowed by the mask.
        """
        new_states = []
        prioritized = defaultdict(list)
        for (state, path, strings, _) in self.states:
            if token_id in self.token_nfa.map[state]:
                for (new_state, new_terminals, new_strings) in self.token_nfa.map[state][token_id]:
                    # The first string of the transition extends the current
                    # terminal; the others belong to the new terminals.
                    new_path = path + new_terminals[1:]
                    new_strings = strings[:-1] + [strings[-1] + new_strings[0]] + new_strings[1:]
                    new_priority = tuple(get_priority(new_path, new_strings))
                    if len(prioritized) == 0:
                        prioritized[new_priority].append((new_state, new_path, new_strings, new_priority))
                    else:
                        # Priorities are compared without their last two
                        # entries (the entries of the last terminal).
                        to_add, to_remove = False, set()
                        for priority in list(prioritized.keys()):
                            comparison = compare(new_priority[:-2], priority[:-2])
                            if comparison == 0:
                                to_add = True
                            elif comparison == 1:
                                to_add = True
                                to_remove.add(priority)
                            else:
                                # A stored priority is higher: drop the candidate.
                                break
                        else:
                            if to_add:
                                prioritized[new_priority].append((new_state, new_path, new_strings, new_priority))
                                for priority in to_remove:
                                    del prioritized[priority]

        for p in prioritized:
            new_states += prioritized[p]
            
        # The terminals common to all configurations are final and can be parsed.
        length_common_prefix = get_length_common_prefix([s[1] for s in new_states])
        (x, p, s, _) = new_states[0]
        for i in range(length_common_prefix):
            if (p[i] not in terminals_to_ignore and p[i] != ""):
                for x in self.indenter.consume([lark.Token(p[i], s[i])]):
                    self.partial_path.append(x)
                    self.interactive_parser.feed_token(x)
                    self.acceptable_terminals = self.interactive_parser.accepts()

        if self.indenter.paren_level > 0:
            # Inside parentheses, `_NEWLINE` is relabelled as `__IGNORE_0`.
            # NOTE(review): 20 is a hard-coded state id — presumably the state
            # of the character NFA reached after an ignored line break; confirm.
            new_states = [
                (
                    20 if (s[-1][-1] == "\n" and p[-1] in ["_NEWLINE", "__IGNORE_0"]) else x,
                    ["__IGNORE_0" if t == "_NEWLINE" else t for t in p],
                    s,
                    priority
                )
                for (x, p, s, priority) in new_states
        ]
        # The consumed prefix is removed from each configuration, and the
        # configurations are sorted by decreasing priority.
        self.states = sorted(
            [(x, p[length_common_prefix:], s[-len(p)+length_common_prefix:], priority) for (x, p, s, priority) in new_states],
            key=lambda t: t[-1],
            reverse=True
        )

    def terminate(self):
        """
        Update the states of the mask generator by terminating the current string.
        """
        if len(self.states) > 0:
            if len(self.states[0][1]) > 0 and self.states[0][1][0] == "_NEWLINE":
                for x in self.indenter.consume([lark.Token("_NEWLINE", self.states[0][2][0])], final=True):
                    self.partial_path.append(x)
                    self.interactive_parser.feed_token(x)

    def validate_terminals(self, terminals, parentheses):
        """
        Check whether a sequence of terminals is acceptable given the current states of the mask generator.
        """
        if parentheses:
            terminals = [t for t in terminals if t not in terminals_to_ignore and t != "" and t != "_NEWLINE"]
        else:
            terminals = [t for t in terminals if t not in terminals_to_ignore and t != ""]
        if len(terminals) == 0:
            return True
        if terminals[0] not in self.acceptable_terminals:
            return False
        if len(terminals) == 1:
            return True
        self.interactive_parser2.parser_state = self.interactive_parser.parser_state.copy()
        try:
            for t in terminals:
                self.interactive_parser2.feed_token(lark.Token(t, ""))
            return True
        except:
            return False

    def get_mask(self, indent=False, streamlined=False, parentheses=False):
        """
        Return the mask corresponding to the current states (when the last token was not a new line).
        """
        mask = torch.zeros((tokenizer.vocab_size,), dtype=torch.bool)
        final_state_already_seen = False
        for (state, terminals, _, _) in self.states:
            current_path = tuple([t for t in terminals if t not in terminals_to_ignore])
            terminal = state2terminal[state]
            if len(terminal) > 0 and terminal not in terminals_to_ignore:
                if terminal not in self.acceptable_terminals:
                    if not parentheses or terminal != "_NEWLINE":
                        if state in self.token_nfa.final_states:
                            final_state_already_seen = True
                        continue
            if streamlined:
                for (path, new_characters) in self.token_nfa.streamlined_mask[state]:
                    #if self.states[0][2][-1] == 'class':
                    #    print("ççç", state, terminal, k)
                    if final_state_already_seen:
                        if not new_characters:
                            continue
                    to_check = current_path + (("_INDENT",) if indent else ()) + path
                    if self.validate_terminals(to_check, parentheses):
                        mask = torch.bitwise_or(self.token_nfa.streamlined_mask[state][(path, new_characters)], mask)
                        if parentheses and ((path == () and not new_characters) or terminal in terminals_to_ignore) and state in self.token_nfa.final_states:
                            mask[linebreak_token_id] = True
                if state in self.token_nfa.final_states:
                    final_state_already_seen = True
            else:
                for (path, new_characters) in self.token_nfa.mask[state]:
                    #if self.states[0][2][-1] == 'class':
                    #    print("ççç", state, terminal, k)
                    if final_state_already_seen:
                        if not new_characters:
                            continue
                    to_check = current_path + (("_INDENT",) if indent else ()) + path
                    if self.validate_terminals(to_check, parentheses):
                        mask = torch.bitwise_or(self.token_nfa.mask[state][(path, new_characters)], mask)
                        if parentheses and ((path == () and not new_characters) or terminal in terminals_to_ignore) and state in self.token_nfa.final_states:
                            mask[linebreak_token_id] = True
                if state in self.token_nfa.final_states:
                    final_state_already_seen = True
        return mask

    def build_mask(self, streamlined=False):
        """
        Return the mask corresponding to the current states.

        Only the first entry of `self.states` is inspected to pick one of the
        following situations:

        * no terminal has been recorded yet: the generic mask is returned with
          the line break token additionally allowed;
        * the last recorded terminal is `_NEWLINE`, we are outside any
          bracket and the text after the last line break is made of spaces
          only: the mask is restricted so that only acceptable indentations
          can be produced (the line break token is always allowed);
        * we are inside brackets (or an opening bracket was just read): the
          generic mask is computed with `parentheses=True`;
        * otherwise: the generic mask is returned.

        `streamlined` is forwarded to `get_mask` to select the streamlined
        mask store instead of the original one.
        """
        # state[1] is read as the sequence of terminal names and state[2][-1]
        # as the most recent piece of text (see the unpacking in `get_mask`).
        state = self.states[0]
        if len(state[1]) == 0:
            mask = self.get_mask(streamlined=streamlined)
            mask[linebreak_token_id] = True
            return mask
        if len(state[1]) > 0 and state[1][-1] == "_NEWLINE":
            if "\n" in state[2][-1] and self.indenter.paren_level == 0:
                # Text that follows the last line break of the pending _NEWLINE
                line_content = state[2][-1].split("\n")[-1]
                if len([c for c in line_content if c != " "]) == 0:
                    # The current line only contains spaces so far
                    num_spaces = len(line_content)
                    # Probe a copy of the parser state to know whether an
                    # _INDENT is accepted once the pending _NEWLINE is fed
                    parser_state = self.interactive_parser.parser_state.copy()
                    self.interactive_parser2.parser_state = parser_state
                    self.interactive_parser2.feed_token(lark.Token("_NEWLINE", state[2][-1]))
                    if "_INDENT" in self.interactive_parser2.accepts():
                        # A new block is expected: one more level of 4 spaces
                        target_indentation = self.indenter.indent_level[-1] + 4
                        # The `- 1` presumably accounts for the last space being
                        # provided as the leading space of the next token
                        # (see `mask_leading_space`) -- TODO confirm
                        delta_indentation = target_indentation - num_spaces - 1
                        if delta_indentation == 0:
                            # NOTE(review): `streamlined` is not forwarded to
                            # `get_mask` here, so the original mask store is
                            # always used in this branch -- confirm it is intended
                            mask = torch.bitwise_and(self.get_mask(indent=True), mask_leading_space)
                        else:
                            # Only allow the token made of the missing spaces
                            mask = torch.zeros((tokenizer.vocab_size,), dtype=torch.bool)
                            # Index 1 skips the first id returned by `encode`
                            # (presumably the BOS token -- TODO confirm)
                            mask[tokenizer.encode(" "*delta_indentation)[1]] = 1
                    else:
                        # No new block expected: the line must align with one of
                        # the indentation levels already opened
                        if 0 in self.indenter.indent_level and num_spaces == 0:
                            # NOTE(review): `streamlined` is not forwarded to
                            # `get_mask` here either -- confirm it is intended
                            mask = torch.bitwise_and(
                                self.get_mask(),
                                mask_no_leading_space
                            )
                        else:
                            mask = torch.zeros((tokenizer.vocab_size,), dtype=torch.bool)

                        # Non-zero indentation levels (the first entry is skipped)
                        for indent_level in self.indenter.indent_level[1:]:
                            delta_indentation = indent_level - num_spaces - 1
                            if delta_indentation == 0:
                                mask = torch.bitwise_or(
                                    mask,
                                    torch.bitwise_and(
                                        self.get_mask(streamlined=streamlined),
                                        mask_leading_space
                                    )
                                )
                            elif delta_indentation > 0:
                                mask[tokenizer.encode(" "*delta_indentation)[1]] = 1
                    # A line break is always allowed on a line made of spaces only
                    mask[linebreak_token_id] = True
                    return mask
        # Inside brackets, or right after an opening bracket
        if self.indenter.paren_level > 0 or (len(self.states[0][1]) > 0 and self.states[0][1][-1] in ["LPAR", "LBRACE", "LSQB"]):
            mask = self.get_mask(streamlined=streamlined, parentheses=True)
            return mask

        return self.get_mask(streamlined=streamlined)

In [0]:
# Sanity check: replay a prompt through the mask generator, token by token

mask_generator = PythonMaskGenerator(parser, token_nfa)
# The first id returned by the tokenizer is dropped (presumably the BOS token)
token_ids = tokenizer(prompt)["input_ids"][1:]

mask = mask_generator.build_mask(streamlined=False)
for token_id in token_ids:
    # The mask computed before reading a token must allow that token
    assert mask[token_id]
    mask_generator.consume(token_id)
    mask = mask_generator.build_mask(streamlined=False)
# Flush whatever is still pending once the whole prompt has been consumed
mask_generator.terminate()

# `is_never_legal`

In [0]:
# Components of the pyformlang grammar mirrored from the Lark parser rules
variables = set()
terminals = set()
productions = set()
start_symbol = Variable(parser.parser.parser_conf.start[0])

# Build the CFG with only one element of each set of interchangeable terminals
rules = []
for rule in parser.rules:
    name = get_name(rule.origin)
    head = Variable(name)
    variables.add(head)
    body, body2 = [], []
    # A rule is dropped as soon as it mentions an ignored or replaced terminal
    rule_is_kept = True
    for x in rule.expansion:
        symbol_type = type(x)
        if symbol_type == lark.grammar.NonTerminal:
            variable = Variable(x.name)
            variables.add(variable)
            body.append(variable)
        elif symbol_type == lark.grammar.Terminal:
            if x.name in parser.ignore_tokens or x.name in terminals_to_replace:
                rule_is_kept = False
                break
            terminal = Terminal(x.name)
            terminals.add(terminal)
            body.append(terminal)
        # Plain-name version of the body, kept alongside the pyformlang one
        body2.append(x.name)
    if rule_is_kept:
        productions.add(Production(head, body))
        rules.append((name, body2))

cfg = CFG(
    variables=variables,
    terminals=terminals,
    productions=productions,
    start_symbol=start_symbol,
)

In [0]:
def is_never_legal(terminal, target_sequence):
    """
    Tell whether `target_sequence` can never directly follow `terminal`.

    A regular expression describing every sequence of terminals in which
    `terminal` is immediately followed by `target_sequence` is built and
    intersected with the grammar `cfg`: the function returns True when the
    intersection is empty.

    * `terminal == ""`: `target_sequence` has to start the input;
    * `terminal` ignored by the lexer: `target_sequence` may appear anywhere;
    * otherwise: `terminal` followed by `target_sequence` may appear anywhere.

    `target_sequence` is a list of terminal names; an empty list yields True.
    """
    if len(target_sequence) == 0:
        return True

    sequence_expr = ".".join(target_sequence)
    if terminal == "":
        anything = f"(__ANYTHING_ELSE__|{'|'.join(target_sequence)})*"
        regex_str = f"{sequence_expr}.{anything}"
    elif terminal in parser.lexer_conf.ignore:
        anything = f"(__ANYTHING_ELSE__|{'|'.join(target_sequence)})*"
        regex_str = f"{anything}.{sequence_expr}.{anything}"
    else:
        anything = f"(__ANYTHING_ELSE__|{'|'.join(set([terminal] + target_sequence))})*"
        regex_str = f"{anything}.{terminal}.{sequence_expr}.{anything}"

    dfa = Regex(regex_str).to_epsilon_nfa().minimize()

    # `__ANYTHING_ELSE__` is a placeholder: expand it into one transition per
    # grammar terminal that has no explicit transition from the same state.
    all_terminals = [
        t.value
        for t in terminals
        if t.value not in terminals_to_ignore and t.value not in terminals_to_replace
    ]
    dfa_transitions = dfa._transition_function.to_dict()
    for state in dfa_transitions:
        outgoing = dfa_transitions[state]
        if "__ANYTHING_ELSE__" not in outgoing:
            continue
        target = outgoing["__ANYTHING_ELSE__"]
        for other_terminal in all_terminals:
            if other_terminal not in outgoing:
                dfa.add_transition(state, other_terminal, target)
        dfa.remove_transition(state, "__ANYTHING_ELSE__", target)

    return cfg.intersection(dfa).is_empty()

# `is_always_legal`

In [0]:
variables = set()
transitions = defaultdict(set)
# transitions[state] = {(symbol_to_read, stack_symbol_to_pop, new_state, stack_symbol_to_add), ...}

# Build the transition relation of the DFT pushdown automaton.
for rule in cfg.productions:
    head = rule.head.value
    variables.add(head)
    rule_start = (head, "variable_start")
    rule_end = (head, "variable_end")

    # Empty production: go straight from the start to the end of the variable
    if len(rule.body) == 0:
        transitions[rule_start].add(("@epsilon", "@epsilon", rule_end, "@epsilon"))
        continue

    # `source` is the state we leave, `pending` the stack symbol to pop when leaving it
    source, pending = rule_start, "@epsilon"
    prefix = []
    for symbol in rule.body:
        # A location identifies a position inside the rule: "head:sym1.sym2..."
        prefix.append(symbol.value)
        location = head + ":" + ".".join(prefix)
        if type(symbol) == Variable:
            # Enter the sub-variable and resume from its end state afterwards
            transitions[source].add(("@epsilon", pending, (symbol.value, "variable_start"), location))
            source = (symbol.value, "variable_end")
        else:
            # Read the terminal
            transitions[source].add((symbol.value, pending, (symbol.value, "terminal"), location))
            source = (symbol.value, "terminal")
        pending = location
    # End of the rule: pop the last location and reach the end of the variable
    transitions[source].add(("@epsilon", pending, rule_end, "@epsilon"))

next_potential_terminals = defaultdict(set)

# Use the DFT pushdown automaton to list terminals that may follow a given symbol.
# The stack is ignored here, so the result is an over-approximation.
search_roots = (
    [(t.value, "terminal") for t in terminals]
    + [(v, "variable_start") for v in variables]
    + [(v, "variable_end") for v in variables]
)
for symbol in search_roots:
    queue = [symbol]
    already_seen = set()
    while queue:
        current = queue.pop()
        # `.get` avoids creating an entry for states without outgoing transitions
        for (to_read, to_pop, destination, to_add) in transitions.get(current, ()):
            if destination[1] == "terminal":
                # A terminal is reached: record it and do not search past it
                next_potential_terminals[symbol].add(destination[0])
            elif destination not in already_seen:
                already_seen.add(destination)
                queue.append(destination)

In [0]:
class InstantaneousDescription:
    """
    Class to describe each step of the search trajectories within the DFT pushdown automaton.

    Attributes:
        state: current state of the automaton.
        stack: tuple of the stack symbols pushed since the start of the search
            (the stack that existed before the search is unknown).
        sequence: sequence of terminals that the trajectory has to read.
        idx: number of elements of `sequence` read so far.
        history: alternation of visited configurations and transition labels.
            The first entry is `(state, 0)`, the following configurations are
            `(state, stack, idx)`. A label is `"@epsilon"`, or the stack symbol
            that had to be popped while the tracked stack was empty.
    """
    def __init__(self, state, sequence):
        self.state = state
        self.stack = ()
        self.sequence = sequence
        self.idx = 0
        self.history = ((state, 0),)

    def apply_transition(self, to_read, to_pop, destination, to_add):
        """
        Apply a transition in place.

        Return `self` when the transition is applicable, `None` otherwise (the
        object may then be partially updated and must be discarded). A
        transition is not applicable when the symbol to read is not the next
        element of `sequence` (including when the sequence is exhausted), or
        when the symbol to pop differs from the top of a non-empty stack.
        """
        if to_read != "@epsilon":
            # Nothing can be read once the whole sequence has been consumed
            if self.idx >= len(self.sequence) or self.sequence[self.idx] != to_read:
                return None
            self.idx = self.idx + 1
        if to_pop != "@epsilon":
            if len(self.stack) == 0:
                # The symbol must come from the unknown part of the stack:
                # record it as the label of the transition
                self.history = self.history + (to_pop, (destination, self.stack, self.idx))
            elif self.stack[-1] == to_pop:
                self.stack = self.stack[:-1]
                self.history = self.history + ("@epsilon", (destination, self.stack, self.idx))
            else:
                return None
        else:
            self.history = self.history + ("@epsilon", (destination, self.stack, self.idx))
        if to_add != "@epsilon":
            self.stack = self.stack + (to_add,)
        self.state = destination
        return self

    def copy(self):
        """
        Return an independent copy of the description.

        A shallow copy is enough: the attributes are only ever rebound, never
        mutated in place, so there is no need to deep-copy the whole history.
        """
        return copy.copy(self)

def find_trajectories(symbol, target_sequence, maximum_delta=20, terminal=True):
    """
    Search the DFT pushdown automaton for the trajectories that start from
    `symbol` and read `target_sequence`.

    The search starts from the state `(symbol, "terminal")`, or from
    `(symbol, "variable_start")` when `terminal` is False, and proceeds
    breadth-first. Three lists of `InstantaneousDescription` are returned:

    * `successful_paths`: trajectories that read the whole sequence;
    * `uncertain_paths`: trajectories abandoned because they reached a
      (state, position) pair with a stack that had already been recorded;
    * `incomplete_paths`: trajectories abandoned because their stack exceeded
      a recorded stack for the same pair by more than `maximum_delta`
      symbols (termination criterion).
    """
    start_kind = "terminal" if terminal else "variable_start"
    frontier = [InstantaneousDescription((symbol, start_kind), target_sequence)]
    successful_paths, uncertain_paths, incomplete_paths = [], [], []
    # Stacks already encountered for each (state, position in the sequence)
    recorded_stacks = defaultdict(set)

    while frontier:
        next_frontier = []
        for description in frontier:
            for transition in transitions[description.state]:
                candidate = description.copy().apply_transition(*transition)
                if candidate is None:
                    continue

                key = (candidate.state, candidate.idx)
                # Decide whether the candidate must be abandoned; the first
                # recorded stack that matches one of the criteria decides.
                bucket = None
                for recorded_stack in recorded_stacks.get(key, ()):
                    if recorded_stack == candidate.stack:
                        bucket = uncertain_paths
                    elif len(candidate.stack) > len(recorded_stack) + maximum_delta:
                        bucket = incomplete_paths
                    if bucket is not None:
                        break
                if bucket is not None:
                    bucket.append(candidate)
                    continue
                recorded_stacks[key].add(candidate.stack)

                # The whole sequence has been read
                if candidate.idx == len(candidate.sequence):
                    successful_paths.append(candidate)
                    continue

                # Prune the branches that cannot reach the next expected terminal
                if candidate.state[1] != "terminal":
                    expected = candidate.sequence[candidate.idx]
                    if expected not in next_potential_terminals[candidate.state]:
                        continue

                next_frontier.append(candidate)

        frontier = next_frontier

    return successful_paths, uncertain_paths, incomplete_paths

def build_nfa(successful_paths, uncertain_paths):
    """
    Build the debt NFA out of the trajectories found by the search.

    The states are the configurations stored in the histories of the
    trajectories and the transitions are the labels found between them
    ("@epsilon" being renamed "epsilon"). The first configuration of the first
    successful path is the start state and the last configuration of every
    successful path is final. An uncertain path only contributes when its
    last configuration is already a state of the NFA at the time it is
    examined. `successful_paths` must not be empty.
    """
    nfa = EpsilonNFA()
    nfa.add_start_state(successful_paths[0].history[0])
    for path in successful_paths:
        nfa.add_final_state(path.history[-1])

    def add_path(path):
        # A history alternates configurations (even positions) and labels (odd positions)
        history = path.history
        for start, label, end in zip(history[0::2], history[1::2], history[2::2]):
            nfa.add_transitions([(start, label.replace("@epsilon", "epsilon"), end)])

    for path in successful_paths:
        add_path(path)

    for path in uncertain_paths:
        if path.history[-1] in nfa.states:
            add_path(path)

    return nfa

def accepts_empty_word(nfa):
    """
    Check whether the initial state of the NFA is ε-coaccessible.

    The set of ε-coaccessible states is seeded with the states whose
    ε-closure contains a final state, then grown until a pass over the states
    adds nothing new. A state is added when one of its "epsilon" transitions
    leads to a state of the set, or when the second criterion documented below
    holds. Return True as soon as the start state belongs to the set, False if
    the fixpoint is reached without it.
    """
    # NOTE(review): `final_states` is not used below; this line raises an
    # IndexError when the NFA has no final state.
    final_states = list(nfa.final_states)[0]
    # Only one start state is considered
    start_state = list(nfa.start_states)[0]
    nfa_transitions = nfa._transition_function.to_dict()
    epsilon_coaccessible_states = set()
    # Seed: states from which a final state is reachable through ε-transitions
    for s in nfa.states:
        eclose = nfa.eclose(s)
        for final_state in nfa.final_states:
            if final_state in eclose:
                epsilon_coaccessible_states.add(s)
                break
    num_new_states = len(epsilon_coaccessible_states)
    if start_state in epsilon_coaccessible_states:
        return True
    # Fixpoint: iterate until a full pass does not add any state
    while num_new_states > 0:
        num_new_states = 0
        for s in nfa.states:
            if s not in epsilon_coaccessible_states:
                if s in nfa_transitions:
                    # First criterion: an "epsilon" transition to a state of the set
                    for k in nfa_transitions[s]:
                        if k.value == "epsilon":
                            for destination in nfa_transitions[s][k]:
                                if destination in epsilon_coaccessible_states:
                                    if s == start_state:
                                        return True
                                    epsilon_coaccessible_states.add(s)
                                    num_new_states += 1
                                    break
                    else:
                        # This `else` belongs to the `for k` loop above, which is
                        # never left with `break` (the `break` above only leaves
                        # the `for destination` loop): it runs after every scan,
                        # even when `s` has just been added.
                        # Second criterion: the NFA holds as many distinct labels
                        # leaving `s` as the pushdown automaton has transitions
                        # from the underlying state, and every label has at
                        # least one destination in the set (presumably: whatever
                        # move is taken from `s`, acceptance remains reachable
                        # -- TODO confirm).
                        # `s.value[0]` is assumed to be the pushdown automaton
                        # state stored first in the configuration -- TODO confirm
                        if len(transitions[s.value[0]]) == len(nfa_transitions[s]):
                            for k in nfa_transitions[s]:
                                for destination in nfa_transitions[s][k]:
                                    if destination in epsilon_coaccessible_states:
                                        break
                                else:
                                    # No destination of this label is in the set
                                    break
                            else:
                                if s == start_state:
                                    return True
                                epsilon_coaccessible_states.add(s)
                                num_new_states += 1
    return False

@functools.lru_cache()
def accepts(symbol, target_sequence, terminal=True, maximum_delta=20):
    """
    Attempt to check whether `symbol` followed by `target_sequence` is always possible or never possible.

    Return one of "always legal", "possibly always legal", "sometimes legal"
    and "never legal". The "possibly" verdict is given when the conclusion
    relies on trajectories that were cut short by `maximum_delta`.
    The arguments have to be hashable since the results are cached
    (`target_sequence` is expected to be a tuple).
    """
    successful_paths, uncertain_paths, incomplete_paths = find_trajectories(
        symbol, target_sequence, terminal=terminal, maximum_delta=maximum_delta
    )

    def optimistic_verdict():
        # Consider the truncated trajectories as if they had been successful
        nfa = build_nfa(successful_paths + incomplete_paths, uncertain_paths)
        return "possibly always legal" if accepts_empty_word(nfa) else "sometimes legal"

    if len(successful_paths) > 0:
        if accepts_empty_word(build_nfa(successful_paths, uncertain_paths)):
            return "always legal"
        if len(incomplete_paths) == 0:
            return "sometimes legal"
        return optimistic_verdict()

    # No trajectory read the whole sequence
    if len(incomplete_paths) == 0:
        return "never legal"
    # The search was truncated: fall back on the intersection with the grammar.
    # The start variable `file_input` is represented by the empty terminal.
    reference = "" if symbol == "file_input" else symbol
    if is_never_legal(reference, list(target_sequence)):
        return "never legal"
    return optimistic_verdict()

def draw(nfa, output_path="/content/automaton.png"):
    """
    Display an automaton with Graphviz (left-to-right layout) and matplotlib.

    Parameters:
        nfa: automaton exposing `to_networkx()`.
        output_path: where the PNG rendering is also written. The default is
            the location historically used by this notebook (a Colab path);
            pass `None` to skip writing the file.
    """
    G = nfa.to_networkx()
    label_replacements = {
        'ɛ': '&#949;',   # Replace the epsilon character with its HTML entity
    }

    for u, v, data in G.edges(data=True):
        if "label" in data:
            if data['label'] in label_replacements:
                data['label'] = label_replacements[data['label']]

    A = to_agraph(G)
    A.graph_attr['rankdir'] = 'LR'  # Left to right layout

    with io.BytesIO() as buffer:
        if output_path is not None:
            A.draw(output_path, format='png', prog='dot')
        # Render in memory to display the picture inline
        A.draw(buffer, format='png', prog='dot')
        buffer.seek(0)
        img = plt.imread(buffer)

    plt.figure(figsize=(20,20))
    plt.imshow(img)
    plt.axis('off')
    plt.show()

# Mask store streamlining

In [0]:
%%time
# Successive simplifications of the (terminal, (target_sequence, new_characters))
# configurations of the mask store. Each map sends a configuration to its
# simplified form, or to None when the configuration is impossible.
if False:#"configuration_maps.pkl" in os.listdir():
    # NOTE: `pickle.load` can execute arbitrary code; only load a cache file
    # produced by this notebook.
    with open("configuration_maps.pkl", "rb") as f:
        (
            configuration_map,
            configuration_map2,
            configuration_map3,
            configuration_map4,
            possibly_always_legal,
        ) = pickle.load(f)
else:
    unique_configurations = set()
    configuration_map = {}
    configuration_map2 = {}
    configuration_map3 = {}
    configuration_map4 = {}
    for state in token_nfa.next_tokens:
        terminal = state2terminal[state]
        for (target_sequence, new_characters) in token_nfa.next_tokens[state]:
            unique_configurations.add((terminal, (target_sequence, new_characters)))

    # Exploit terminal interchangeability to merge equivalent configurations
    for (terminal, (target_sequence, new_characters)) in unique_configurations:
        terminal2 = terminal_replacement[terminal] if terminal in terminal_replacement else terminal
        target_sequence2 = [terminal_replacement[t] if t in terminal_replacement else t for t in target_sequence]
        configuration_map[(terminal, (target_sequence, new_characters))] = (terminal2, (tuple(target_sequence2), new_characters))

    # Use 1-hop successor relationships between terminals to remove impossible configurations
    for (terminal, (target_sequence, new_characters)) in set(configuration_map.values()):
        if len(target_sequence) == 0:
            # Nothing to check: the configuration is kept unchanged.
            # (The terminal is stored here, not the `state` variable left over
            # from the first loop of this cell.)
            configuration_map2[(terminal, (target_sequence, new_characters))] = (terminal, (target_sequence, new_characters))
            continue
        if terminal in parser.lexer_conf.ignore:
            to_test = target_sequence
        elif terminal == "":
            # Beginning of the input: the first terminal must be able to start a file
            if target_sequence[0] not in next_potential_terminals[("file_input", "variable_start")]:
                configuration_map2[(terminal, (target_sequence, new_characters))] = None
                continue
            else:
                to_test = target_sequence
        else:
            to_test = (terminal,) + target_sequence
        for i in range(len(to_test) - 1):
            if to_test[i+1] not in next_potential_terminals[(to_test[i], "terminal")]:
                configuration_map2[(terminal, (target_sequence, new_characters))] = None
                break
        else:
            configuration_map2[(terminal, (target_sequence, new_characters))] = (terminal, (target_sequence, new_characters))

    # Use the always legal (terminal, [terminal2]) configurations to merge equivalent configurations
    # Use the never legal (terminal, [terminal2]) configurations to remove impossible configurations

    # pairs_of_terminals[verdict] = set of (terminal, terminal2) pairs with that verdict
    pairs_of_terminals = defaultdict(set)
    for x in terminals:
        # Terminals without a replacement stand for themselves, as in the
        # construction of `configuration_map` above
        terminal = terminal_replacement.get(x.value, x.value)
        for terminal2 in next_potential_terminals[(terminal, "terminal")]:
            terminal2 = terminal_replacement.get(terminal2, terminal2)
            pairs_of_terminals[accepts(terminal, (terminal2,))].add((terminal, terminal2))

    for x in set(configuration_map2.values()):
        if x is None:
            continue
        (terminal, (target_sequence, new_characters)) = x
        if len(target_sequence) == 0:
            configuration_map3[(terminal, (target_sequence, new_characters))] = (terminal, (target_sequence, new_characters))
            continue
        if terminal in parser.lexer_conf.ignore or terminal == "":
            to_test = target_sequence
        else:
            to_test = (terminal,) + target_sequence
        if len(to_test) == 1:
            configuration_map3[(terminal, (target_sequence, new_characters))] = (terminal, (target_sequence, new_characters))
            continue
        # A single never legal pair of consecutive terminals makes the whole
        # configuration impossible: drop it and move on to the next one
        if any(
            (to_test[i], to_test[i+1]) in pairs_of_terminals["never legal"]
            for i in range(len(to_test) - 1)
        ):
            configuration_map3[(terminal, (target_sequence, new_characters))] = None
            continue
        # Remove the trailing terminals that are always legal after their predecessor
        i = len(to_test) - 2
        while True:
            if (to_test[i], to_test[i+1]) in pairs_of_terminals["always legal"]:
                to_test = to_test[:i+1]
            else:
                break
            i = i - 1
            if i < 0:
                break
        if terminal in parser.lexer_conf.ignore or terminal == "":
            configuration_map3[(terminal, (target_sequence, new_characters))] = (terminal, (to_test, new_characters))
        else:
            # `to_test` starts with `terminal` itself, which is not part of the target sequence
            configuration_map3[(terminal, (target_sequence, new_characters))] = (terminal, (to_test[1:], new_characters))

    # Remove "always legal" and "never legal" configurations
    possibly_always_legal = set()
    i = 0
    # counter[verdict] = number of configurations that received this verdict
    counter = defaultdict(int)
    for x in set(configuration_map3.values()):
        if x is None:
            continue
        (terminal, (target_sequence, new_characters)) = x
        i += 1
        if len(target_sequence) == 0:
            configuration_map4[(terminal, (target_sequence, new_characters))] = (terminal, (target_sequence, new_characters))
            continue
        if terminal == "":
            accepted = accepts("file_input", target_sequence, terminal=False, maximum_delta=2)
        elif terminal in parser.lexer_conf.ignore:
            if len(target_sequence) == 1:
                configuration_map4[(terminal, (target_sequence, new_characters))] = (terminal, (target_sequence, new_characters))
                continue
            # An ignored terminal gives no context: start from the first target terminal
            accepted = accepts(target_sequence[0], target_sequence[1:], maximum_delta=2)
        else:
            accepted = accepts(terminal, target_sequence, maximum_delta=2)
        counter[accepted] += 1
        if accepted == "never legal":
            configuration_map4[(terminal, (target_sequence, new_characters))] = None
            continue
        elif accepted == "always legal":
            if terminal in parser.lexer_conf.ignore:
                configuration_map4[(terminal, (target_sequence, new_characters))] = (terminal, (target_sequence[:1], new_characters))
            else:
                configuration_map4[(terminal, (target_sequence, new_characters))] = (terminal, ((), new_characters))
        elif accepted == "possibly always legal":
            configuration_map4[(terminal, (target_sequence, new_characters))] = (terminal, (target_sequence, new_characters))
            possibly_always_legal.add((terminal, (target_sequence, new_characters)))
        elif accepted == "sometimes legal":
            configuration_map4[(terminal, (target_sequence, new_characters))] = (terminal, (target_sequence, new_characters))
        elif accepted == "possibly never legal":
            # NOTE(review): `accepts` does not return this verdict in its
            # current form, so this branch looks unreachable -- confirm
            if not is_never_legal(terminal, list(target_sequence)):
                configuration_map4[(terminal, (target_sequence, new_characters))] = (terminal, (target_sequence, new_characters))
                possibly_always_legal.add((terminal, (target_sequence, new_characters)))
            else:
                counter["confirmed never legal"] += 1
                configuration_map4[(terminal, (target_sequence, new_characters))] = None
                continue

    with open("configuration_maps.pkl", "wb") as f:
        pickle.dump([
            configuration_map,
            configuration_map2,
            configuration_map3,
            configuration_map4,
            possibly_always_legal
        ], f)

# Number of distinct configurations before and after each simplification
print(
    len(configuration_map.keys()),
    len(set(configuration_map.values())),
    len(set(configuration_map2.values())),
    len(set(configuration_map3.values())),
    len(set(configuration_map4.values()))
)

In [0]:
def add_streamlined_mask(translation_map):
    """
    (Re)build `token_nfa.streamlined_mask` from `token_nfa.mask`.

    For every state, each path of the original mask store is translated with
    `translation_map[(terminal, path)]`, `terminal` being the terminal
    associated with the state:

    * a path translated to `None` is dropped;
    * otherwise the path component of the translation becomes the new key,
      and the masks of the paths sharing the same new key are OR-ed together;
    * a `None` path is copied as is.
    """
    token_nfa.streamlined_mask = {}
    for s in token_nfa.mask:
        terminal = state2terminal[s]
        original_masks = token_nfa.mask[s]
        merged_masks = {}
        token_nfa.streamlined_mask[s] = merged_masks
        for path in original_masks:
            if path is None:
                merged_masks[path] = original_masks[path]
                continue
            translated = translation_map[(terminal, path)]
            if translated is None:
                continue
            # Only the path part of the translated configuration is used as a key
            new_path = translated[1]
            if new_path in merged_masks:
                merged_masks[new_path] = torch.bitwise_or(
                    merged_masks[new_path],
                    original_masks[path]
                )
            else:
                merged_masks[new_path] = original_masks[path]

def _compose_translation_map(stages):
    """
    Chain several configuration maps.

    Every key of the first map is sent through the maps of `stages` one after
    the other. The chain stops as soon as a map returns None, in which case the
    key is mapped to None (impossible configuration).
    """
    composed = {}
    for key in stages[0]:
        value = key
        for stage in stages:
            value = stage[value]
            if value is None:
                break
        composed[key] = value
    return composed

# Size of the streamlined mask store of every state after applying the first
# one, two, three and four simplification stages
stages = [configuration_map, configuration_map2, configuration_map3, configuration_map4]
store_sizes = []
for depth in range(1, len(stages) + 1):
    translation_map = _compose_translation_map(stages[:depth])
    add_streamlined_mask(translation_map)
    store_sizes.append([len(token_nfa.streamlined_mask[s]) for s in token_nfa.streamlined_mask])
x1, x2, x3, x4 = store_sizes

# Total number of entries: original store, then after each stage.
# `token_nfa.streamlined_mask` is left in its fully streamlined form.
sum([len(token_nfa.mask[s]) for s in token_nfa.mask]), sum(x1), sum(x2), sum(x3), sum(x4)

# Experiments

## Experiment 1
We tokenize a series of Python files from four GitHub repositories (more than one million tokens in total). For each token position, we compute the mask using the streamlined mask store and check that the next token is indeed allowed by the mask.

In [0]:
if "python_files" not in os.listdir():
    import autopep8
    import sys

    def is_ascii(s):
        try:
            s.encode('ascii')
        except UnicodeEncodeError:
            return False
        return True

    def normalize_file_indentation(root, filename, target):
        # Read the content of the file
        with open(os.path.join(root, filename), 'r') as file:
            original_content = file.read()

        if not is_ascii(original_content):
            print(f"{filename} skipped (because it contains non-ASCII characters)")
        else:
            # Use autopep8 to format the file content
            formatted_content = autopep8.fix_code(original_content, options={
                'indent_size': 4,
                'aggressive': 2,
                'experimental': True
            })

            # Write the formatted content back to the file
            with open(os.path.join(target, root.replace("/", "-")+"-"+filename), 'w') as file:
                file.write(formatted_content)
                print(f"{filename} processed with autopep8")

    !git clone https://github.com/uiuc-focal-lab/syncode.git
    !git clone https://github.com/sgl-project/sglang.git
    !git clone https://github.com/dottxt-ai/outlines.git
    !git clone https://github.com/eth-sri/lmql.git
    os.mkdir("python_files")
    for repo in ["syncode", "sglang", "outlines", "lmql"]:
        for root, subdirs, files in os.walk(repo):
            for f in files:
                if f.endswith(".py"):
                    normalize_file_indentation(root, f, "python_files")

In [0]:
streamlined = True

def _save_already_checked():
    """Write `already_checked` to disk and close the file right away."""
    with open("already_checked.json", "w") as results_file:
        json.dump(already_checked, results_file)

def _record_failure(f, i, token_id, mask_generator):
    """Store, print and persist the failure record of the file `f`."""
    already_checked[f] = (f, i, token_id, get_token(token_id, tokenizer), mask_generator.states)
    print("******", *already_checked[f])
    _save_already_checked()

# already_checked[f] is the number of validated tokens (0 when the file is
# skipped), or a failure record when a check failed
if False:#"already_checked.json" in os.listdir():
    with open("already_checked.json", "r") as results_file:
        already_checked = json.load(results_file)
else:
    already_checked = {}

for f in os.listdir("python_files"):
    # Files with a numeric result are done; files with a failure record are
    # checked again
    if f in already_checked and type(already_checked[f]) == int:
        continue
    with open(os.path.join("python_files", f), "r") as file:
        prompt = file.read()

    # Skip the file if it contains some characters: `\t`, `\r`, `\f`
    if any(c in prompt for c in ("\t", "\r", "\f")):
        already_checked[f] = 0
        _save_already_checked()
        continue

    # Skip the file if it isn't syntactically correct given the Lark Python grammar
    try:
        parser.parse(prompt)
    except Exception:
        already_checked[f] = 0
        _save_already_checked()
        continue

    token_ids = tokenizer(prompt)["input_ids"][1:]
    mask_generator = PythonMaskGenerator(parser, token_nfa)
    mask = mask_generator.build_mask(streamlined=streamlined)
    i, start = 0, time.time()

    for token_id in token_ids:
        # Fail if the mask is excessively restrictive
        if not mask[token_id]:
            _record_failure(f, i, token_id, mask_generator)
            break
        try:
            # Consume the next token id and update the mask
            mask_generator.consume(token_id)
            mask = mask_generator.build_mask(streamlined=streamlined)
        # Fail if the next terminal is not accepted by the parser
        except lark.UnexpectedToken:
            _record_failure(f, i, token_id, mask_generator)
            # Stop here: going on would end in the `else` clause below, which
            # could overwrite the failure with a success
            break
        i += 1
    else:
        # Every token was accepted: the remaining text must terminate properly
        try:
            mask_generator.terminate()
        except Exception:
            _record_failure(f, i, token_id, mask_generator)
            # NOTE(review): this `break` stops the loop over the files, whereas
            # the other failures move on to the next file -- confirm it is intended
            break
        else:
            already_checked[f] = len(token_ids)
            _save_already_checked()

In [0]:
# Total number of validated tokens. Failure records (which are not integers)
# are left out instead of making the sum raise a TypeError.
sum(n for n in already_checked.values() if isinstance(n, int))

## Experiment 2
We check that strings (more than one million characters in total) generated using the streamlined mask store are syntactically correct Python code.

In [0]:
# parsing_outcome[i] is the number of characters validated for the i-th random
# string, or "" when its generation or its parsing failed
parsing_outcome = {}
i = 0
while True:
    i += 1
    # One seed per string so that any failure can be reproduced
    np.random.seed(seed=i)
    token_ids = []
    mask_generator = PythonMaskGenerator(parser, token_nfa)
    try:
        for _ in range(1000):
            mask = mask_generator.build_mask(streamlined=True)
            # Uniform sampling among the allowed tokens: the highest random
            # score wins, the forbidden tokens having their score zeroed
            token_id = np.argmax(np.random.rand(len(mask))*np.array(mask))
            mask_generator.consume(token_id)
            token_ids.append(token_id)

        prompt = tokenizer.decode(token_ids)
        # The text that the mask generator has not turned into terminals yet
        # is excluded from the check
        length_unparsed = len("".join(mask_generator.states[0][2]))
        # An explicit end index is required: `prompt[:-0]` would be empty
        parsed_length = max(len(prompt) - length_unparsed, 0)
        parser.parse_interactive(prompt[:parsed_length]).exhaust_lexer()
        parsing_outcome[i] = len(prompt) - length_unparsed
        with open("infinite_monkey.json", "w") as outcome_file:
            json.dump(parsing_outcome, outcome_file)

        if i % 50 == 0:
            num_characters_validated = sum(parsing_outcome[x] for x in parsing_outcome if type(parsing_outcome[x]) == int)
            print(i, num_characters_validated)
            if num_characters_validated > 1_000_000:
                break
    # `Exception` (rather than a bare `except`) lets a manual interruption of
    # this long-running loop propagate instead of being recorded as a failure
    except Exception:
        parsing_outcome[i] = ""
        with open("infinite_monkey.json", "w") as outcome_file:
            json.dump(parsing_outcome, outcome_file)
        print(i, tokenizer.decode(token_ids))
        break

## Experiment 3

For more than 100,000 tokens extracted from the same Python files as before, we compute the masks using both the original mask store and the streamlined mask store, and we check that the resulting masks are systematically identical.

In [0]:
# Replay the previously validated files token by token and check that the
# streamlined mask store and the original mask store always agree.
num_tokens = 0
for f in os.listdir("python_files"):
    # Only files whose entry is a positive token count are replayed: failed
    # files hold a diagnostic tuple/list instead of a count, and files that
    # were never checked have no entry at all
    num_checked = already_checked.get(f)
    if not isinstance(num_checked, int) or num_checked <= 0:
        continue

    with open(os.path.join("python_files", f), "r") as file:
        prompt = file.read()
    # The first id is dropped (presumably the BOS token added by the
    # tokenizer -- TODO confirm)
    token_ids = tokenizer(prompt)["input_ids"][1:]

    mask_generator = PythonMaskGenerator(parser, token_nfa)
    mask = mask_generator.build_mask(streamlined=True)
    mask2 = mask_generator.build_mask(streamlined=False)
    assert mask.equal(mask2), f"{f}: initial masks differ"

    for position, token_id in enumerate(token_ids):
        assert mask[token_id], f"{f}: token {position} rejected by the streamlined mask"
        assert mask2[token_id], f"{f}: token {position} rejected by the original mask"
        mask_generator.consume(token_id)
        mask = mask_generator.build_mask(streamlined=True)
        mask2 = mask_generator.build_mask(streamlined=False)
        assert mask.equal(mask2), f"{f}: masks differ after token {position}"

    num_tokens += num_checked
    if num_tokens > 100_000:
        print(f"{num_tokens} tokens read. No discrepancy between the mask store and the streamlined mask store")
        break
else:
    # The directory was exhausted before reaching the target: report it rather
    # than finishing silently
    print(f"Only {num_tokens} tokens read (target: more than 100,000). No discrepancy found")