# The L* Algorithm

In [None]:
import sys
import random
import enum
import src.utils as utils
import string

In [None]:
import ipynb.fs.full.x0_3_EarleyParser as parser
import ipynb.fs.full.x0_2_GrammarFuzzer as fuzzer

## Definitions

- Input symbol: A single symbol that is consumed by the machine which can move it from one state to another. The set of such symbols is called an alphabet, and is represented by $A$
- Membership query: A string that is passed to the blackbox. The blackbox answers yes or no.
- Equivalence query: A grammar that is passed to the teacher as a hypothesis of what the target language is. The teacher answers yes or a counter example that behaves differently on the blackbox and the hypothesis grammar.
- Prefix closed: a set is prefix closed if all prefixes of any of its elements are also in the same set.
- Suffix closed: a set is suffix closed if all suffixes of any of its elements are also in the same set.
- Observation table: A table whose rows correspond to the candidate states. The rows are made up of prefix strings that can reach given states — commonly represented as $S$, but here we will denote these by $P$ for prefixes — and the columns are made up of suffix strings that serve to distinguish these states — commonly expressed as $E$ for extensions, but we will use $S$ to denote suffixes here. The table also contains auxiliary rows $p.a$ that extend each item $p \in P$ with each symbol $a \in A$, as we discuss later in closedness. This table defines the language inferred by the algorithm. The contents of the table are the answers from the oracle on a string composed of the row and column labels — prefix + suffix. That is, $T[p,s] = O(p.s)$. The table has two properties: closedness and consistency. If these are not met at any time, we take steps to resolve it.
- The state: A state in the DFA is represented by a prefix in the observation table, and is named by the pattern of 1s and 0s in the cell contents. We represent a state corresponding to the prefix $p$ as $[p]$
- Closedness of the observation table means that for each $p \in P$ and each $a \in A$, the state represented by the auxiliary row $[p.a]$ (i.e., its contents) exists in $P$. That is, there is some $p' \in P$ such that $[p.a] = [p']$. The idea is that the state $[p]$ corresponding to the prefix $p$ accepts the symbol $a$ and transitions to the state $[p']$, and $p'$ must be in the main set of rows $P$.
- Consistency of the observation table means that if two prefixes represent the same state (i.e. the contents of two rows are equal), that is $[p1] = [p2]$, then $[p1.a] = [p2.a]$ for all symbols $a \in A$. The idea is that if two prefixes reach the same state, then when fed any symbol, both prefixes should transition to the same next state (represented by the pattern produced by the suffixes).
The candidate states $P$ is prefix closed, while the set of suffixes S is suffix closed.

## Observation Table

We initialize the observation table with the alphabet. We keep the table itself as an internal dict `_T`. We also keep the prefixes in `P` and suffixes in `S`. We initialize the set of prefixes `P` to be $\epsilon$ and the set of suffixes `S` also to be $\epsilon$. We also add a few utility functions.

In [None]:
class ObservationTable:
    """Observation table for L*.

    Attributes:
        _T: nested dict, row label -> {column label -> recorded answer}.
        P:  list of prefixes (row labels); starts as [''].
        S:  list of suffixes (column labels); starts as [''].
        A:  the alphabet object passed to the constructor (stored as-is).
    """

    def __init__(self, alphabet):
        self._T = {}
        self.P = ['']
        self.S = ['']
        self.A = alphabet

    def cell(self, v, e):
        """Return the table entry at row ``v``, column ``e``.

        Raises KeyError when either the row or the column is missing.
        """
        row = self._T[v]
        return row[e]

    def state(self, p):
        """Name the state of row ``p``.

        The name is the string form of the row's cells, one per suffix in
        ``S`` (in ``S`` order), wrapped in angle brackets.
        """
        pattern = (str(self.cell(p, suffix)) for suffix in self.S)
        return '<' + ''.join(pattern) + '>'

Using the observation table

In [None]:
if __name__ == '__main__':
    # Hand-fill a table: every single-letter row gets the same four cells.
    alphabet = list('abcdefgh')
    o = ObservationTable(alphabet)
    o._T = {p:{'':1, 'a':0, 'ba':1, 'cba':0} for p in alphabet}
    # Look up a single cell: row 'a', column 'ba'.
    print(o.cell('a', 'ba'))
    # S was left at its initial [''] above, so the state name is built from
    # the '' column only.
    print(o.state('a'))


### Convert Table to Grammar

Given the observation table, we can recover the grammar from this table (corresponding to the DFA). The unique cell contents of rows are states. In many cases, multiple rows may correspond to the same state (as the cell contents are the same). 
The *start state* is given by the state that corresponds to the $\epsilon$ row.

A state is accepting if, on query of $ \epsilon $ (i.e. `''`), it returns 1.
 
The formal notations are as follows. The notation $ [p] $ means the state corresponding to the prefix $ p $. The notation $ [[p,s]] $ means the result of oracle for the prefix $ p $ and the suffix $ s $. The notation $ [p](a) $ means the state obtained by feeding the input symbol $ a $ to the state $ [p] $. We take the first prefix that resulted in a particular state as its *access prefix*, and we denote the access prefix of a state $ s $ by $ \lfloor{}s\rfloor $ (this is not used in this post). The following is the DFA from our table.

* states: $ Q = \{[p] : p \in P\} $
* start state: $ q_0 = [\epsilon] $
* transition function: $ [p](a) \rightarrow [p.a] $
* accepting states: $ F = \{[p] : p \in P, [[p,\epsilon]] = 1\} $

For constructing the grammar from the table, we first identify all distinguished states. Next, we identify the start state, followed by
accepting states. Finally, we connect states together with transitions between them.


In [None]:
class ObservationTable(ObservationTable):
    def table_to_grammar(self):
        """Turn the current table into a grammar with one nonterminal per state.

        Returns a ``(grammar, start)`` pair. ``grammar`` maps each state name
        (as produced by ``self.state``) to a list of rules; ``start`` is the
        state name of the ``''`` row. When no row in ``P`` has a 1 in the
        ``''`` column, the fixed pair ``({'<start>': []}, '<start>')`` is
        returned instead.
        """
        # Step 1: group the prefixes of P by the state their row denotes.
        prefix_to_state = {}   # prefix -> state name
        states = {}            # state name -> prefixes, in P order
        for p in self.P:
            stateid = self.state(p)
            prefix_to_state[p] = stateid
            states.setdefault(stateid, []).append(p)

        grammar = {stateid: [] for stateid in states}

        # Step 2: the start state is whatever state the epsilon row denotes.
        start_nt = prefix_to_state['']

        # Step 3: accepting states are those whose row has 1 under ''.
        accepting = []
        for p in self.P:
            if self.cell(p, '') == 1:
                accepting.append(prefix_to_state[p])
        if not accepting:
            return {'<start>': []}, '<start>'
        for stateid in accepting:
            # '<_>' expands to the empty rule, so this lets the state stop.
            grammar[stateid] = [['<_>']]
        grammar['<_>'] = [[]]

        # Step 4: one transition per (state, symbol), computed from the first
        # prefix in P that reached the state.
        for stateid, prefixes in states.items():
            access_prefix = prefixes[0]
            grammar[stateid].extend(
                [a, self.state(access_prefix + a)] for a in self.A)

        return grammar, start_nt


In [None]:
if __name__ == '__main__':
    # Hand-built table over the alphabet {a, b} with two columns: '' and 'a'.
    alphabet = list('ab')
    o = ObservationTable(alphabet)
    o._T = {'':    {'': 0, 'a': 1},
            'a':   {'': 1, 'a': 0},
            'b':   {'': 0, 'a': 0},
            'aa':  {'': 0, 'a': 0},
            'ab':  {'': 0, 'a': 0},
            'ba':  {'': 0, 'a': 0},
            'bb':  {'': 0, 'a': 0},
            'baa': {'': 0, 'a': 0},
            'bab': {'': 0, 'a': 0}}
    # Use every row key as a prefix and every column of the '' row as a suffix.
    P = [k for k in o._T]
    S = [k for k in o._T['']]
    o.P, o.S = P, S
    # Convert the table to a grammar and print it, one rule per line.
    g, s = o.table_to_grammar()
    print('start: ', s)
    for k in g:
        print(k)
        for r in g[k]:
            print(" | ", r)


### Remove infinite loops

In [None]:
import math
def symbol_cost(grammar, symbol, seen, cache):
    """Return the cheapest rule cost for ``symbol``.

    ``seen`` holds the symbols already on the current expansion path; meeting
    one of them again yields ``float('inf')``. ``cache`` is a nested dict
    ``symbol -> {str(rule) -> cost}`` that is consulted before recursing.
    A symbol without rules costs 0.
    """
    if symbol in seen:
        return float('inf')
    costs = []
    for rule in grammar.get(symbol, []):
        try:
            cost = cache[symbol][str(rule)]
        except KeyError:
            cost = expansion_cost(grammar, rule, seen | {symbol}, cache)
        costs.append(cost)
    return min(costs, default=0)


# The cost of a rule is one more than the cost of its most expensive
# nonterminal; tokens that are not keys of the grammar do not contribute.

def expansion_cost(grammar, tokens, seen, cache):
    """Return the cost of expanding the rule ``tokens`` (at least 1)."""
    nonterminal_costs = [symbol_cost(grammar, token, seen, cache)
                         for token in tokens if token in grammar]
    worst = max(nonterminal_costs) if nonterminal_costs else 0
    return worst + 1


def compute_cost(grammar):
    """Return ``{key: {str(rule): cost}}`` for every rule in ``grammar``.

    The table being built is also passed down as the cache, so costs already
    computed for earlier rules are reused.
    """
    rule_cost = {}
    for key, rules in grammar.items():
        per_rule = rule_cost[key] = {}
        for rule in rules:
            per_rule[str(rule)] = expansion_cost(grammar, rule, set(), rule_cost)
    return rule_cost

In [None]:
class ObservationTable(ObservationTable):
    def remove_infinite_loops(self, g, start):
        """Drop nonterminals of ``g`` that have no finite-cost rule.

        Rules mentioning a dropped nonterminal are dropped too, and the
        pruning is repeated until a pass removes nothing. ``start`` itself is
        never removed, even if it ends up with no rules. Returns
        ``(pruned_grammar, start)``; ``g`` is not modified.
        """
        infinite = float('inf')
        rule_cost = compute_cost(g)

        # Seed: every non-start key whose rules all have infinite cost
        # (this includes keys with no rules at all).
        remove_keys = [
            key for key, costs in rule_cost.items()
            if key != start and all(c == infinite for c in costs.values())
        ]

        while True:
            removed_more = False
            new_g = {}
            for key, rules in g.items():
                if key in remove_keys:
                    continue
                surviving = [rule for rule in rules
                             if not any(tok in remove_keys for tok in rule)]
                new_g[key] = surviving
                if surviving or key == start:
                    continue
                # Every rule of this key died: drop it and rebuild once more.
                remove_keys.append(key)
                removed_more = True
            if not removed_more:
                return new_g, start

class ObservationTable(ObservationTable):
    def grammar(self):
        """Return ``(grammar, start)`` for the table after loop pruning.

        Equivalent to feeding the result of ``table_to_grammar()`` straight
        into ``remove_infinite_loops()``.
        """
        return self.remove_infinite_loops(*self.table_to_grammar())

In [None]:
if __name__ == '__main__':
    # NOTE: `alphabet`, `P` and `S` are not defined in this cell; they are
    # the globals left behind by the earlier table_to_grammar demo cell.
    o = ObservationTable(alphabet)
    o._T = {'':    {'': 0, 'a': 1},
            'a':   {'': 1, 'a': 0},
            'b':   {'': 0, 'a': 0},
            'aa':  {'': 0, 'a': 0},
            'ab':  {'': 0, 'a': 0},
            'ba':  {'': 0, 'a': 0},
            'bb':  {'': 0, 'a': 0},
            'baa': {'': 0, 'a': 0},
            'bab': {'': 0, 'a': 0}}
    o.P, o.S = P, S
    # Same table as before, but this time through grammar(), which also
    # prunes nonterminals that cannot be expanded at finite cost.
    g, s = o.grammar()
    print('start: ', s)
    for k in g:
        print(k)
        for r in g[k]:
            print(" | ", r)

Now that we are convinced that we can produce a DFA or a grammar out of the table let us proceed to examining how to produce this table.

We start with the start state in the table, because we know for sure that it exists, and is represented by the empty string in row and column,
which together (prefix + suffix) is the empty string `''` or $ \epsilon $.  We ask the program if it accepts the empty string, and if it accepts, we mark  the corresponding cell in the table as *accept* (or `1`).

For any given state in the DFA, we should be able to say what happens when an input symbol is fed into the machine in that state. So, we can extend the    table with what happens when each input symbol is fed into the start state.  This means that we extend the table with rows corresponding to each symbol in the input alphabet.

So, we can initialize the table as follows. First, we check whether the empty string is in the language. Then, we extend the table `T` to `(P u P.A).S`    using membership queries. This is given in `update_table()`


In [None]:
class ObservationTable(ObservationTable):
    def init_table(self, oracle):
        """Seed the table with the answer for the empty string, then fill it.

        ``oracle`` must provide ``is_member(string)``.
        """
        accepts_empty = oracle.is_member('')
        self._T[''] = {'': accepts_empty}
        self.update_table(oracle)

The update table has two parts. First, it takes the current set of prefixes (`rows`) and determines the auxiliary rows to compute based on extensions of   the current rows with the symbols in the alphabet (`auxrows`). This gives the complete set of rows for the table. Then, for each suffix in `S`, ensure     that the table has a cell, and it is updated with the oracle result.

In [None]:
class ObservationTable(ObservationTable):
    def update_table(self, oracle):
        """Fill every missing cell of ``(P u P.A) x S`` by asking ``oracle``.

        Cells that already hold an answer are left alone, so each
        ``prefix + suffix`` cell is queried at most once per row/column pair.
        """
        # Rows to cover: P first, then each p + a, keeping first occurrences.
        all_rows = dict.fromkeys(self.P)
        for p in self.P:
            for a in self.A:
                all_rows.setdefault(p + a)

        for row in all_rows:
            cells = self._T.setdefault(row, {})
            for s in self.S:
                if s not in cells:
                    cells[s] = oracle.is_member(row + s)

In [None]:
if __name__ == '__main__':
    # `alphabet` is the global set by an earlier demo cell.
    o = ObservationTable(alphabet)
    # A bare function object serves as a stub oracle; only its `is_member`
    # attribute is used, and it answers 1 for every string.
    def orcl(): pass
    orcl.is_member = lambda x: 1
    o.init_table(orcl)
    # Dump each row of the freshly initialised table.
    for p in o._T: print(p, o._T[p])

While doing this, there is one requirement we need to ensure. The result of transition from every state for every alphabet needs to be defined.  The property that ensures this for the observation table is called *closedness* or equivalently, the observation table is *closed* if the table has the following property.

### Closed
The idea is that for every prefix we have, in set $ P $, we need to find the state that is reached for every $ a \in A $. Then, we need to make sure that the *state* represented by that   prefix exists in $ P $. (If such a state does not exist in P, then it means that we have found a new state).

Formally: An observation table $ P \times S $ is closed if for each $ t \in P·A $ there exists a $ p \in P $ such that $ [t] = [p] $

In [None]:
class ObservationTable(ObservationTable):
    def closed(self):
        """Check closedness of the table.

        Returns ``(True, None)`` when the state of every extension ``p + a``
        (``p`` in ``P``, ``a`` in ``A``) is also the state of some row in
        ``P``; otherwise ``(False, t)`` with the first offending extension.
        """
        known_states = {self.state(p) for p in self.P}
        for p in self.P:
            for a in self.A:
                extension = p + a
                if self.state(extension) not in known_states:
                    return False, extension
        return True, None


In [None]:
if __name__ == '__main__':
    # Stub oracle that accepts exactly the string 'a'.
    def orcl(): pass
    orcl.is_member = lambda x: 1 if x in ['a'] else 0

    ot = ObservationTable(list('ab'))
    ot.init_table(orcl)
    for p in ot._T: print(p, ot._T[p])

    # The table is expected not to be closed yet; `counter` is the
    # extension row that closed() reported.
    res, counter = ot.closed()
    assert not res
    print(counter)

### Add prefix

In [None]:
class ObservationTable(ObservationTable):
    def add_prefix(self, p, oracle):
        """Append ``p`` to ``P`` and fill in the cells this makes necessary.

        Does nothing when ``p`` is already in ``P``.
        """
        if p not in self.P:
            self.P.append(p)
            self.update_table(oracle)

In [None]:
if __name__ == '__main__':
    # Stub oracle that accepts exactly the string 'a'.
    def orcl(): pass
    orcl.is_member = lambda x: 1 if x in ['a'] else 0

    ot = ObservationTable(list('ab'))
    ot.init_table(orcl)
    # Not closed right after initialisation ...
    res, counter = ot.closed()
    assert not res

    # ... but closed once 'a' has been promoted to a prefix.
    ot.add_prefix('a', orcl)
    for p in ot._T: print(p, ot._T[p])
    res, counter = ot.closed()
    assert res

This is essentially the intuition behind most of the grammar inference algorithms, and the cleverness lies in how the suffixes are chosen. In the case of L\*, when we find that one of the transitions from the current states result in a new state, we add the alphabet that caused the transition from the current state and the suffix that distinguished the new state to the     suffixes (i.e, a + suffix is added to the columns).

This particular aspect is governed by the *consistency* property of the observation table.

### Consistent

An observation table $ P \times S $ is consistent if, whenever $ p1 $ and $ p2 $ are elements of P such that $ [p1] = [p2] $, for each $ a \in A $, $ [p1.a] = [p2.a] $.  *If* there are    two rows in the top part of the table repeated, then the corresponding suffix results should be the same.  If not, we have found a counter example. So we report the alphabet and the       suffix that distinguished the rows. We will then add the new string (a + suffix) as a new suffix to the table.


In [None]:
class ObservationTable(ObservationTable):
    def consistent(self):
        """Check consistency of the table.

        Two distinct prefixes with the same state must still agree after any
        symbol ``a`` is appended, under every suffix ``s``. Returns
        ``(True, None, None)`` when that holds; otherwise
        ``(False, (p1, p2), a + s)`` for the first disagreement found, where
        ``a + s`` is the string that tells the two rows apart.
        """
        # State names are computed once per prefix instead of once per pair.
        state_of = {p: self.state(p) for p in self.P}
        for p1 in self.P:
            for p2 in self.P:
                if p1 == p2 or state_of[p1] != state_of[p2]:
                    continue
                for a in self.A:
                    for s in self.S:
                        if self.cell(p1 + a, s) != self.cell(p2 + a, s):
                            return False, (p1, p2), (a + s)
        return True, None, None

### Add suffix

In [None]:
class ObservationTable(ObservationTable):
    def add_suffix(self, a_s, oracle):
        """Append ``a_s`` to ``S`` and fill in the new column.

        Does nothing when ``a_s`` is already in ``S``.
        """
        if a_s not in self.S:
            self.S.append(a_s)
            self.update_table(oracle)

In [None]:
if __name__ == '__main__':
    # Stub oracle that accepts exactly the string 'a'.
    def orcl(): pass
    orcl.is_member = lambda x: 1 if x in ['a'] else 0

    ot = ObservationTable(list('ab'))
    ot.init_table(orcl)
    is_closed, counter = ot.closed()
    assert not is_closed
    # Promote three rows to prefixes; with only the '' column, several of
    # them share a state name.
    ot.add_prefix('a', orcl)
    ot.add_prefix('b', orcl)
    ot.add_prefix('ba', orcl)
    for p in ot._T: print(p, ot._T[p])

    is_closed, unknown_P = ot.closed()
    print(is_closed)

    # The table is expected to be inconsistent at this point.
    is_consistent,_, unknown_A = ot.consistent()
    assert not is_consistent

    # Adding 'a' as a suffix is expected to restore consistency.
    ot.add_suffix('a', orcl)
    for p in ot._T: print(p, ot._T[p])

    is_consistent,_, unknown_A = ot.consistent()
    assert is_consistent

(Of course readers will quickly note that the table is not the best data structure here, and just because a suffix distinguished two particular states does not mean that it is a good idea to evaluate the same suffix on all other states. These are ideas that will be explored in later algorithms).


Finally, L\* also relies on a *Teacher* for it to suggest new suffixes that can distinguish unrecognized states from current ones.

## Teacher

We now construct our teacher. We have two requirements for the teacher.  The first is that it should fulfil the requirement for Oracle. That is, it should answer `is_member()` queries.    Secondly, it should also answer `is_equivalent()` queries.

First, we define the oracle interface.


In [None]:
class Oracle:
    """Interface for anything that can answer membership queries."""

    def is_member(self, q):
        """Answer a membership query for ``q``.

        This base implementation does nothing and returns None; subclasses
        are expected to override it.
        """
        return None

We define a simple teacher based on regular grammars. That is, if you give it a regular grammar, it will convert it to an acceptor based on a parser and a generator based on a grammar fuzzer, and will then use them for verification of hypothesis grammars.

### PAC Learning

PAC learning was introduced by Valiant in 1984 as a way to think about inferred models in computational linguistics and machine learning.  The basic idea is that given a    blackbox model, we need to be able to produce samples that can then be tested against the model to construct an inferred model (i.e, to train the model). For sampling during training, we  have to assume some sampling procedure, and hence a distribution for training.  Per PAC learning, we can only guarantee the performance of the learned model when tested using samples from the same distribution. Given that we are sampling from a distribution, there is a possibility that due to non-determinism, the data is not as spread out as we may like, and hence the      training data is not optimal by a certain probability. This reflects on the quality of the model learned. This is indicated by the concept of confidence intervals, and indicated by the $  \delta $ parameter. That is, $ 1 - \delta $ quantifies the confidence we have in our model. Next, given any training data, due to the fact that the training data is finite, our grammar    learned is an approximation of the real grammar, and there will always be an error term. This error is quantified by the $ \epsilon $ parameter. Given the desired $ \delta $ and $         \epsilon $ Angluin provides a formula to compute the number of calls to make to the membership oracle at the $ i^{th} $ equivalence query.

$$ n=\left\lceil\frac{1}{\epsilon}\times\left(\log\frac{1}{\delta}+i\times\log 2\right)\right\rceil $$

In essence the PAC framework says that there is $ 1 - \delta  $ probability that the model learned will be approximately correct. That is, it will classify samples with an error rate less than $ \epsilon $.


In [None]:
class Teacher(Oracle):
    """An oracle that can additionally answer equivalence queries."""

    def is_equivalent(self, grammar, start):
        """Equivalence query for the hypothesis ``grammar`` rooted at ``start``.

        Abstract here: concrete teachers must override it.

        Raises:
            NotImplementedError: always. An explicit raise is used instead of
                ``assert False`` because assertions are stripped under
                ``python -O``, which would make this silently return None.
        """
        raise NotImplementedError('is_equivalent() must be overridden')

We input the PAC parameters delta for confidence and epsilon for accuracy

In [None]:
class Teacher(Teacher):
    def __init__(self, re_g, re_s, delta=0.5, epsilon=0.5):
        """Set up a teacher for the target grammar ``re_g`` rooted at ``re_s``.

        An Earley parser and a limit fuzzer are built from ``re_g``.
        ``delta`` and ``epsilon`` are the PAC confidence and accuracy
        parameters; the equivalence query counter starts at 0.
        """
        self.g = re_g
        self.s = re_s
        self.parser = parser.EarleyParser(re_g)
        self.sampler = fuzzer.LimitFuzzer(re_g)
        self.equivalence_query_counter = 0
        self.delta = delta
        self.epsilon = epsilon

We can define the membership query `is_member()` as follows:

In [None]:
class Teacher(Teacher):
    def is_member(self, q):
        """Membership query: 1 if the target grammar recognizes ``q``, else 0.

        The string is run through ``self.parser.recognize_on`` from the start
        symbol ``self.s``; any ``Exception`` raised there counts as a
        rejection.
        """
        try:
            list(self.parser.recognize_on(q, self.s))
        except Exception:
            # Not a bare ``except``: KeyboardInterrupt and SystemExit must
            # propagate rather than be reported as "not a member".
            return 0
        return 1

Given a grammar, check whether it is equivalent to the given grammar.
The PAC guarantee is that we only need `num_calls` for the `i`th equivalence
query. For equivalence check here, we check for strings of length 1, then
length 2 etc, whose sum should be `num_calls`. We take the easy way out here,
and just use `num_calls` as the number of calls for each string length.
We have what is called a *cooperative teacher*, that tries to respond with
a shortest possible counter example. We also take the easy way out and only
check for a maximum length of 10.

In [None]:
class Teacher(Teacher):
    def is_equivalent(self, grammar, start, max_length_limit=10):
        """Equivalence query against the target grammar by random sampling.

        Each call bumps ``equivalence_query_counter`` (call it ``i``) and
        derives the sample size
        ``ceil(1/epsilon * (ln(1/delta) + i * ln 2))``.
        ``is_equivalent_for`` is then run ``max_length_limit - 1`` times with
        that sample size.

        Returns ``(True, None)`` if no round finds a disagreement, otherwise
        ``(False, counterexample_string)`` from the first round that does.
        """
        self.equivalence_query_counter += 1
        confidence_term = math.log(1.0/self.delta)
        round_term = self.equivalence_query_counter * math.log(2)
        num_calls = math.ceil(1.0/self.epsilon * (confidence_term + round_term))

        # NOTE(review): the loop index is never handed to is_equivalent_for,
        # so every round samples in the same way.
        for _ in range(1, max_length_limit):
            is_eq, counterex, _count = self.is_equivalent_for(
                self.g, self.s, grammar, start, num_calls)
            if counterex is None or is_eq:
                continue
            # The counterexample pair has its string in one slot and None
            # in the other.
            witness = [a for a in counterex if a is not None][0]
            return False, witness
        return True, None


We need to remove epsilon tokens from places other than the start rule to make the grammar well behaved.

In [None]:
class Teacher(Teacher):
    def digest_grammar(self, g, s):
        """Build the tools needed to compare grammar ``g`` rooted at ``s``.

        Returns:
            ``(parser, fuzzer)`` — an ``EarleyParser`` and a ``LimitFuzzer``
            for ``g``. When the start symbol has no rules, ``(None, None)``
            is returned instead. The result always has exactly two elements
            because callers unpack it into two names; the earlier
            three-element return raised ValueError on that path.
        """
        if not g[s]:
            return None, None
        rgf = fuzzer.LimitFuzzer(g)
        ep = parser.EarleyParser(g)
        return ep, rgf

    def gen_random(self, f, s):
        """Return one string fuzzed by ``f`` from start symbol ``s``.

        Returns None when ``f`` is None (no fuzzer exists because the
        grammar's start symbol has no rules).
        """
        if f is None:
            return None
        return f.fuzz(s)

## Check Grammar Equivalence
Checking if two grammars are equivalent to a length of string for n count.

In [None]:
class Teacher(Teacher):
    def is_equivalent_for(self, g1, s1, g2, s2, n):
        """Compare two grammars on up to ``n`` random samples from each.

        Strings fuzzed from ``g1`` are checked against the parser of ``g2``
        and vice versa; a string counts as rejected when ``recognize_on``
        raises an ``Exception``.

        Returns ``(True, None, count)`` if every sample is recognized by the
        other grammar, else ``(False, (st1, None), count)`` or
        ``(False, (None, st2), count)`` depending on which side produced the
        rejected string. ``count`` is the number of strings checked so far.
        """
        ep1, lf1 = self.digest_grammar(g1, s1)
        ep2, lf2 = self.digest_grammar(g2, s2)

        # Both sample sets are drawn before any checking starts.
        samples1 = {self.gen_random(lf1, s1) for _ in range(n)}
        samples2 = {self.gen_random(lf2, s2) for _ in range(n)}

        # (samples, parser of the *other* grammar, its start, pair builder)
        rounds = (
            (samples1, ep2, s2, lambda st: (st, None)),
            (samples2, ep1, s1, lambda st: (None, st)),
        )
        count = 0
        for samples, recognizer, start, as_counterexample in rounds:
            for st in samples:
                if st is None:
                    continue
                count += 1
                try:
                    list(recognizer.recognize_on(st, start))
                except Exception:
                    return False, as_counterexample(st), count

        return True, None, count

In [None]:
if __name__ == '__main__':
    # Reference grammar: <0> loops on '0', moves to <1> on '1'; <1> loops on
    # '1' or stops.
    g1 = { # should end with one.
            '<0>': [
                ['1', '<1>'],
                ['0', '<0>']
                ],
            '<1>':[
                ['1', '<1>'],
                []
                ]
    }
    # Variant of g1: the only difference is that '0' also moves to <1>.
    g2 = { # should end with one.
            '<0>': [
                ['1', '<1>'],
                ['0', '<1>'] # changed here.
                ],
            '<1>':[
                ['1', '<1>'],
                []
                ]
    }
    # Both grammars start at '<0>'.
    s1 = '<0>'
    s2 = '<0>'
    

In [None]:
if __name__ == '__main__':
    # Build a teacher for g1 and keep a parser/fuzzer pair for each grammar;
    # parse2, fuzz2 (and s2) are used again by a later cell.
    t = Teacher(g1, '<0>')
    parse1, fuzz1 = t.digest_grammar(g1, s1)
    parse2, fuzz2 = t.digest_grammar(g2, s2)
    #for v in parse2.parse_on('000111', s1):
    #    print(utils.tree_to_str(v))

In [None]:
if __name__ == '__main__':
    # Compare g1 and g2 on 10 random samples each; prints the
    # (is_equal, counterexample_pair, count) triple.
    v = t.is_equivalent_for(g1, '<0>', g2, '<0>', 10)
    print(v)

In [None]:
if __name__ == '__main__':
    # Guarded like the other demo cells: fuzz2, parse2 and s2 only exist when
    # the earlier guarded cells have run, so running this unguarded on import
    # would raise NameError.
    # Fuzz one string from g2, then print and display each of its parse trees.
    my_str = fuzz2.fuzz(s2)
    print(my_str)
    for v in parse2.parse_on(my_str, s2):
        print(utils.tree_to_str(v))
        utils.display_tree(v)


## L star main loop

Given the observation table and the teacher, the algorithm itself is simple. The L* algorithm loops, doing the following operations in sequence. (1) keep the table closed, (2) keep the table consistent, and if it is closed and consistent (3) ask the teacher if the corresponding hypothesis grammar is correct.

In [None]:
def l_star(T, teacher):
    """Run the L* loop until the teacher accepts the hypothesis.

    ``T`` is the observation table and ``teacher`` answers both membership
    queries (through the table's methods) and equivalence queries.

    Each round first repairs the table until it is both closed and
    consistent, then submits the table's grammar to the teacher. On a
    counterexample, every non-empty prefix of it is added to the table and
    the loop repeats. Progress is printed as ``+`` per round and ``.`` per
    repair step.

    Returns the accepted ``(grammar, start)`` pair.
    """
    T.init_table(teacher)

    while True:
        print('+')
        stable = False
        while not stable:
            print('.', end='', flush=True)
            is_closed, unknown_P = T.closed()
            is_consistent, _, unknown_AS = T.consistent()
            stable = is_closed and is_consistent
            if not is_closed:
                T.add_prefix(unknown_P, teacher)
            if not is_consistent:
                T.add_suffix(unknown_AS, teacher)

        grammar, start = T.grammar()
        eq, counterX = teacher.is_equivalent(grammar, start)
        if eq:
            return grammar, start
        print('Equivalence query: counter', counterX)
        # Feed the counterexample in as prefixes, shortest first.
        for end in range(1, len(counterX) + 1):
            T.add_prefix(counterX[:end], teacher)

In [None]:
if __name__ == '__main__':
    import string
    # `e` only records the regular expression that the grammar below spells
    # out; it is not used by the code.
    e = '(ab|cd|ef)*'
    eg = {
          '<X>': [['a', 'b'], ['c', 'd'], ['e', 'f']],
          '<Xs>': [['<X>', '<Xs>'], []],
          '<start>': [['<Xs>']],
    }

    es = '<start>'
    # Learn the language over the full ASCII-letter alphabet using the
    # default PAC parameters of Teacher.
    teacher = Teacher(eg,es)
    tbl = ObservationTable(list(string.ascii_letters))
    g, s = l_star(tbl, teacher)

    # Sample ten strings from the learned grammar.
    gf = fuzzer.LimitFuzzer(g)
    for i in range(10):
        res = gf.fuzz(key=s, max_depth=100)
        print(res)

In [None]:
import fuzzingbook.GrammarMiner as miner

In [None]:
if __name__ == '__main__':
    # Guarded because `g` is only defined by the guarded demo cells; running
    # this unguarded on import would raise NameError.
    miner.syntax_diagram(utils.to_fuzzingbook_grammar(g))

# Done

In [None]:
#%tb