# SLR Parser

This notebook is an implementation of SLR(1) parser - a bottom-up parser that constructs a parse tree by starting from the input tokens and working towards the start symbol of the grammar. The parser employs first and follow sets to determine valid lookahead tokens.

In [3]:
import copy

In [4]:
class SLRParser:
    """
    An implementation of SLR(1) parser. This parser constructs parsing tables
    and processes context-free grammars using a bottom-up approach with
    first and follow set to determine valid lookahead tokens
    """
    def __init__(self, input_grammar, start):
        """
        Initialize the LRParser with a given grammar.

        Args:
            input_grammar (dict): A dictionary defining the context-free grammar (CFG).
        """
        # Initialize parameters of the CFG
        self.grammar = {}
        self.start = start
        self.terminals = []
        self.non_terminals = []
        self.dot = "·"

        self.formattingGrammar(input_grammar)
        
        self.first_table = {}
        self.follow_table = {}
        self.in_progress = set()     # this variable is used to avoid left recursive when calculating first
        self.calculateFirstTable()
        self.calculateFollowTable()
        
        self.augmented_rules = []    # format of rule: [rhs, [<lhs symbol>, <lhs symbol>, ...]
        self.state_map = {}          # store rules of a state (format: state_count: [[rule1], [rule2], ...])
        self.state_dict = {}         # store which state go to which state
        self.state_count = 0
        self.addDot()
        self.generateStates()

        self.parse_table = []
        self.createParseTable()


    def formattingGrammar(self, input_grammar):
        """
        Processes the input grammar into an internal representation for the parser.

        This method reformats the provided input grammar into a format suitable 
        for parsing and initializes the grammar rules, start symbol, non-terminals, 
        and terminals. The first rule in the input grammar is augmented to create 
        a new start rule.

        Args:
            input_grammar (dict): The input grammar represented as a dictionary 
                where keys are non-terminals and values are lists of production rules.

        Attributes Modified:
            self.grammar (dict)
            self.start (str)
            self.non_terminals (list)
            self.terminals (list)
        """
        # Process the input grammar into a dictionary with each rule have the format of
        # key: rulenumber (int) 
        # value: [lhs (str), rhs (list of symbol)]
        count = 0
        for key in input_grammar.keys():
            # Augment the first rule
            if count == 0:
                if self.start == "":
                    self.start = f"{key}'"
                else:
                    self.start = f"{self.start}'"
                self.grammar[0] = (self.start, [key])
                count += 1

            # Process each rule into the format above
            for rule in input_grammar[key]:
                self.grammar[count] = (key, rule)
                count += 1

        # Detecting terminals and non-terminals symbols if it was not given
        if len(self.terminals) == 0 and len(self.non_terminals) == 0:
            for key in self.grammar.keys():
                lhs, rhs = self.grammar[key]
                self.non_terminals.append(lhs)
            for key in self.grammar.keys():
                lhs, rhs = self.grammar[key]
                for sym in rhs:
                    if sym not in self.non_terminals and sym not in self.terminals:
                        self.terminals.append(sym)


    def addDot(self):
        """
        Adding a dot (·) (tracker of process of parsing) at the start of each production's RHS.
        """
        for key in self.grammar.keys():
            lhs, rhs = self.grammar[key]
            new_rhs = [self.dot]
            for elem in rhs:
                new_rhs.append(elem)
            self.augmented_rules.append([lhs, new_rhs])

    def generateStates(self):
        """
        Generate all states for the parser, starting with the initial state.
        """
        # generate and calculate the closure of the initial state I_0
        first_state = []
        for rule in self.augmented_rules:
            if rule[0] == self.start:
                first_state.append(rule)
        closure_rules = self.findClosure(first_state)
        self.state_dict[0] = closure_rules

        # generate states until no more state is able to be generated
        prev_len = -1
        state_completed_GOTO = []
        while prev_len != len(self.state_dict):
            prev_len = len(self.state_dict)

            keys = list(self.state_dict.keys())
            for state in keys:
                if state not in state_completed_GOTO:
                    self.computeGOTO(state)
                    state_completed_GOTO.append(state)

    
    def computeGOTO(self, state):
        """
        Check and manage states that need to compute GOTO transitions

        Args:
            state (int): The state number.
        """
        generate_new_state_for = []
        for rule in self.state_dict[state]:
            # if the rule ends with dot (can't shift anymore) => skip
            if rule[1][-1] == self.dot:
                continue

            dot_ind = rule[1].index(self.dot)
            next_sym = rule[1][dot_ind+1]

            if next_sym not in generate_new_state_for:
                generate_new_state_for.append(next_sym)

        for sym in generate_new_state_for:
            self.GOTO(state, sym)

    
    def GOTO(self, state, sym):
        """
        Compute the GOTO transitions for a given state.

        Args:
            state (int): The current state.
            sym (str): The grammar symbol.
        """
        new_state = []
        for rule in self.state_dict[state]:
            # if the rule ends with dot (can't shift anymore) => skip
            if rule[1][-1] == self.dot:
                continue

            dot_ind = rule[1].index(self.dot)
            next_sym = rule[1][dot_ind+1]

            # shift operation from the previous state of rule on that
            if next_sym == sym:
                # swap dot with next_sym
                shifted_rule = copy.deepcopy(rule)
                shifted_rule[1][dot_ind] = shifted_rule[1][dot_ind + 1]
                shifted_rule[1][dot_ind + 1] = self.dot
                new_state.append(shifted_rule)

        closure_rules = self.findClosure(new_state)

        # check if state exist
        state_exists = -1
        for state_num in self.state_dict:
            if self.state_dict[state_num] == new_state:
                state_exists = state_num
                break
     
        # stateMap is a mapping of GOTO with
        # its output states
        if state_exists == -1:
            self.state_count += 1
            self.state_dict[self.state_count] = closure_rules
            self.state_map[(state, sym)] = self.state_count
        else:
            self.state_map[(state, sym)] = state_exists
            

    def findClosure(self, closure_rules):
        """
        Generate the closure for a rules.

        Args:
            closure_rules (list): A list of rules for which the closure is to be computed.

        Returns:
            list: The closure of the input rules.
        """

        # generate closure for the rules in new_state
        # generate until can't generate anymore
        # start with the closure of the initial state
        prev_len = -1
        while prev_len != len(closure_rules):
            prev_len = len(closure_rules)
            for rule in closure_rules:
                if rule[1][-1] == self.dot:
                    continue
                    
                dot_ind = rule[1].index(self.dot)
                next_sym = rule[1][dot_ind+1]
    
                # if next_sym is non_terminal then continue adding rule with that nonterminals as lhs
                if next_sym in self.non_terminals:
                    for augmented_rule in self.augmented_rules:
                        if augmented_rule[0] == next_sym and augmented_rule not in closure_rules:
                            closure_rules.append(augmented_rule)
        return closure_rules

        
    def calculateFirstTable(self):
        """
        Compute the FIRST table for the grammar.
        """
        for key in self.grammar.keys():
            rule = self.grammar[key]
            lhs, rhs = rule

            if lhs not in self.first_table:
                self.first_table[lhs] = list(elem for elem in self.first(rule))
            else:
                res = self.first(rule)
                for elem in res:
                    if elem not in self.first_table[lhs]:
                        self.first_table[lhs].append(elem)

    
    def calculateFollowTable(self):
        """
        Compute the FOLLOW table for the grammar.
        """
        for nt in self.non_terminals:
            self.follow_table[nt] = self.follow(nt)

    
    def first(self, rule):
        """
        Compute the FIRST set for a given rule.

        Args:
            rule (tuple): A tuple (LHS, RHS) representing a grammar rule.

        Returns:
            list: The FIRST set for the rule.
        """
        lhs, rhs = rule
        
        if lhs in self.in_progress:
            return []  # prevent infinite recursion
        
        # mark this non-terminal as being processed
        self.in_progress.add(lhs)
        
        # rule for terminals
        if rhs[0] in self.terminals:
            return [rhs[0]]
            
        # rule for epsilon
        elif rhs[0] == "#":
            return ["#"]
            
        # rule for non-terminal
        else:
            res = []
            for key in self.grammar.keys():
                if rhs[0] == self.grammar[key][0]:
                    for elem in self.first(self.grammar[key]):
                        res.append(elem) 

            if "#" in res:
                res.remove("#")
                
            self.in_progress.remove(lhs)  # finished processing this non-terminal
            return res

    
    def follow(self, nt, visited=None):
        """
        Compute the FOLLOW set for a non-terminal.

        Args:
            nt (str): The non-terminal symbol.
            visited (set): A set of visited non-terminals to prevent infinite recursion.

        Returns:
            list: The FOLLOW set for the non-terminal.
        """
        if visited is None:
            visited = set()
    
        if nt in visited:
            return []

        visited.add(nt)
        res = set()

        # for start symbol return $
        if nt == self.start:
            res.add("$")

        for key in self.grammar.keys():
            lhs, rhs = self.grammar[key]
            
            for i, symbol in enumerate(rhs):
                if symbol == nt:
                    rhs = rhs[i + 1:]

                    # rule 2: there is a symbol after nt
                    if len(rhs) != 0:
                        # if the symbol after nt is also a non-terminal:
                        #   - calculate its first (remove epsilon) and add to res
                        #   - if its first contain epsilon, then continue checking the next symbol
                        # else the symbol after nt is a terminal:
                        #   - then add it to res
                        for sym in rhs:
                            if sym in self.terminals:
                                res.add(sym)
                                break
                            elif sym in self.first_table:
                                first_sym = self.first_table[sym]
                                res.update(set(first_sym) - {"#"})
    
                                if "#" in first_sym:
                                    res.remove("#")
                                else:
                                    break

                    # rule 3: there is no symbol after nt -> FOLLOW(lhs) ⊆ FOLLOW(nt)
                    if len(rhs) == 0:  
                        if lhs != nt:
                            res.update(self.follow(lhs, visited))
                            
        visited.remove(nt)
        return list(res)

    def createParseTable(self):
        """
        Create the parsing table for the SLR(1) parser.
        """
        rows = list(self.state_dict.keys())
        cols = self.terminals + ["$"] + self.non_terminals

        # create empty table
        temp_row = []
        for i in range(len(cols)):
            temp_row.append([])
        for i in range(len(rows)):
            self.parse_table.append(copy.deepcopy(temp_row))

        # add shift and goto entries to table
        for entry in self.state_map.keys():
            state = entry[0]
            sym = entry[1]

            row_ind = rows.index(state)
            col_ind = cols.index(sym)

            if sym in self.terminals:
                self.parse_table[row_ind][col_ind].append(f"S{self.state_map[entry]}")
            elif sym in self.non_terminals:
                self.parse_table[row_ind][col_ind].append(f"G{self.state_map[entry]}")

        # add reduce to table
        for state in self.state_dict.keys():
            for rule in self.state_dict[state]:
                # if the rule is a handle -> add reduce correspondingly
                if rule[1][-1] == self.dot:
                    copy_rhs = copy.deepcopy(rule[1])
                    copy_rhs.remove(self.dot)

                    # add entry R_rule_num (Reduce -> rule_num) to entry (state, follow(rhs)) in parse table
                    for rule_num in self.grammar.keys():
                        if self.grammar[rule_num][0] == rule[0] and self.grammar[rule_num][1] == copy_rhs:
                            for follow in self.follow_table[rule[0]]:
                                row_ind = rows.index(state)
                                col_ind = cols.index(follow)
                                if rule_num == 0:
                                    self.parse_table[row_ind][col_ind].append("Accept")
                                else:
                                    self.parse_table[row_ind][col_ind].append(f"R{rule_num}")

    	# printing table
        print("\nParsing table:\n")
        frmt = "{:>8}" * len(cols)
        print(" ", frmt.format(*cols), "\n")
        ptr = 0
        j = 0
        for y in self.parse_table:
            # frmt1 = "{:>8}"
            print(f"{{:>3}}".format('I'+str(j)), end="")
            for e in y:
                print(f"{{:>8}}".format("/".join(e)), end="")
            print()
            j += 1

        # saving the parse table to a csv file
        file = open("rules/parse_tables/parsetable1.csv", "w")
        file.write("state,"+",".join(cols)+"\n")
        j = 0
        for y in self.parse_table:
            line = ""
            line += f"I{j}"
            for e in y:
                line += "," + "/".join(e)
            file.write(line + "\n")
            j += 1
        file.close()

    def parse(self, input_string):
        """
        Parses the given input string using the constructed SLR parse table.

        Args:
            input_string (list): The input string represented as a list of symbols 
                (terminals) to be parsed. The end of the input is marked by "$".

        Returns:
            bool: True if the input string is successfully parsed and reaches 
                the "Accept" state; False otherwise.

        Notes: This method handles conflicts by always selecting the first operation 
        in a conflicting cell in the parse table.
        """
        # self.printResultAndGoto()
        rows = list(self.state_dict.keys())
        cols = self.terminals + ["$"] + self.non_terminals
        
        # appends "$" to indicate the end of input.
        ls_input = input_string + ["$"]
        current_char = ls_input[0]
        ls_output = []
        stack = [0]
        while True:
            if current_char not in cols:
                return False
            
            row_ind = rows.index(stack[-1])
            col_ind = cols.index(current_char)
            
            operation = self.parse_table[row_ind][col_ind]
            
            if operation == []:
                return False
                
            else:
                operation = operation[0] # just get the first operation in conflict cell
                if operation[0] == "R":
                    rule_num = int(operation[1:])
                    current_char = self.grammar[rule_num][0]
                    
                    # pop stack equal to number of char on rhs of reduce rule
                    stack_pop_count = len(self.grammar[rule_num][1])
                    stack = stack[:-stack_pop_count]

                    ls_output.append(rule_num)
                
                # goto operation
                elif operation[0] == "G":
                    stack.append(int(operation[1:]))
                    current_char = ls_input[0]  
                    
                # shift operation
                elif operation[0] == "S":
                    stack.append(int(operation[1:]))
                    ls_input.pop(0) 
                    current_char = ls_input[0]      

                # accept reached
                elif operation == "Accept":
                    return True

In [None]:
# class SLRParser(LRParser):
#     def __init__(self, grammar, terminals, non_terminals, start, dot):
#         super().__init__(grammar, terminals, non_terminals, start, dot)

#     def parse(self, input_string):
#         # self.printResultAndGoto()
#         rows = list(self.state_dict.keys())
#         cols = self.terminals + ["$"] + self.non_terminals
        
#         ls_input = input_string + ["$"]
#         current_char = ls_input[0]
#         ls_output = []
#         stack = [0]
#         while True:
#             # print(ls_input, current_char, stack)
#             # time.sleep(1)
#             if current_char not in cols:
#                 return False
            
#             row_ind = rows.index(stack[-1])
#             col_ind = cols.index(current_char)
            
#             operation = self.parse_table[row_ind][col_ind]
            
#             if operation == []:
#                 return False
                
#             else:
#                 operation = operation[0] # just get the first operation in conflict cell
#                 # print(operation)
#                 # reduce operation
#                 if operation[0] == "R":
#                     rule_num = int(operation[1:])
#                     current_char = self.grammar[rule_num][0]
                    
#                     # pop stack equal to number of char on rhs of reduce rule
#                     stack_pop_count = len(self.grammar[rule_num][1])
#                     stack = stack[:-stack_pop_count]

#                     ls_output.append(rule_num)
                
#                 # goto operation
#                 elif operation[0] == "G":
#                     stack.append(int(operation[1:]))
#                     current_char = ls_input[0]  
                    
#                 # shift operation
#                 elif operation[0] == "S":
#                     stack.append(int(operation[1:]))
#                     ls_input.pop(0) 
#                     current_char = ls_input[0]      

#                 # accept reached
#                 elif operation == "Accept":
#                     return True

    
# # Example 1 Grammar and Tables
# grammar = {
#     0: ("E'", ["E"]),                # Rule 0: E'→ E
#     1: ("E", ["E", "+", "T"]),       # Rule 1: E → E + T
#     2: ("E", ["T"]),                 # Rule 2: E → T
#     3: ("T", ["T", "*", "F"]),       # Rule 3: T → T * F
#     4: ("T", ["F"]),                 # Rule 4: T → F
#     5: ("F", ["(", "E", ")"]),       # Rule 5: F → ( E )
#     6: ("F", ["a"]),                 # Rule 6: F → a
    
# }

# terminals = ["a", "+", "*","(", ")"]
# non_terminals = ["E'", "E", "T", "F"]
# start = "E'"
# dot = '·'

# # Test the Parser
# parser = SLRParser(grammar, terminals, non_terminals, start, dot)
# input_string = list("a*a+a*a+a")
# res = parser.parse(input_string)

# print()
# if res == False:
#     print(f"Input not accepted - {''.join(input_string)}")
# else:
#     print(f"Input accepted - {''.join(input_string)}")



# # # Example 2 Grammar and Tables
# # grammar = {
# #     0: ("S'", ["S"]),
# #     1: ("S", ["L", "=", "R"]),    # Rule 1: S → L = R
# #     2: ("S", ["R"]),              # Rule 2: S → R
# #     3: ("L", ["*", "R"]),         # Rule 3: L → * R
# #     4: ("L", ["a"]),              # Rule 4: L → a
# #     5: ("R", ["L"]),              # Rule 5: R → L
# # }

# # terminals = ["a", "=", "*"]
# # non_terminals = ["S'", "S", "L", "R"]
# # start = "S'"
# # dot = '·'

# # # Test the Parser
# # parser = SLRParser(grammar, terminals, non_terminals, start, dot)
# # # input_string = list("a=a")

# # # parser.parse(input_string)




## Testing - K-path coverage

In [None]:
import simplefuzzer as fuzzer

def parents(g):
    parent = {}
    for k in g:
        for r in g[k]:
            for t in r:
                if t not in g: continue
                if t not in parent: parent[t] = set()
                parent[t].add(k)
    return parent


def _k_paths(g, k, parent):
    if k == 1: return [[k] for k in g]
    _k_1_paths = _k_paths(g, k-1, parent)
    # attach parents to each of the _k_1_paths.
    new_paths = []
    for path in _k_1_paths:
        if path[0] not in parent: continue
        for p in parent[path[0]]:
            new_paths.append([p] + path)
    return new_paths


def k_paths(g, k):
    g_parents = parents(g)
    return _k_paths(g, k, g_parents)


def find_rule_containing_key(g, key, root):
    leaf = root[0]
    for rule in g[key]:
        r = []
        while rule:
            token, *rule = rule
            if leaf != token:
                r.append((token, None))
            else:
                return r + [root] + [(t, None) for t in rule]
    assert False


def path_to_tree(path_, g):
    leaf, *path = reversed(path_)
    root = (leaf, [])
    # take the lowest
    while path:
        leaf, *path = path
        if not path: return root
        rule = find_rule_containing_key(g, leaf, root)
        root = [leaf, rule]

def tree_fill_(g, pt, f):
    key, children = pt
    if not children:
        if key in g:
            return (key, [(f.fuzz(key), [])])
        else:
            return (key, [])
    else:
        return (key, [tree_fill_(g, c, f) for c in children])


def tree_fill(g, pt):
    rgf = fuzzer.LimitFuzzer(g)
    return tree_fill_(g, pt, rgf)


def collapse(t):
    key, children = t
    if not children:
        return key
    return ''.join([collapse(c) for c in children])

def display_tree(node, level=0, c='-'):
    key, children = node
    if children is None:
        print(' ' * 4 * level + c+'> ' + key)
    else:
        print(' ' * 4 * level + c+'> ' + key)
        for c in children:
            if isinstance(c, str):
                print(' ' * 4 * (level+1) + c)
            else:
                display_tree(c, level + 1, c='+')


def test_valid_k_path_1():
    # Example 1 Grammar and Tables
    grammar = {
        "E": [
            ["E", "+", "T"],       # Rule 1: E → E + T
            ["T"]                  # Rule 2: E → T
            ],        
        "T": [
            ["T", "*", "F"],       # Rule 3: T → T * F
            ["F"]                  # Rule 4: T → F
            ],           
        "F": [
            ["(", "E", ")"],       # Rule 5: F → ( E )
            ["a"]                  # Rule 6: F → a
            ]
    }

    start = "E"

    parser = SLRParser(grammar, start)

    # Test the parser
    for path in k_paths(grammar, 5):
        if path[0] in start: 
            # print(path)
            tree = path_to_tree(path, grammar)
            # print(tree)
            for i in range(100):
                t = tree_fill(grammar, tree)
                s = collapse(t)
                res = parser.parse(list(s))
                assert res == True, f"Fail test: {s}"


def test_valid_k_path_2():
    # Example 2 Grammar and Tables
    grammar = {
        "S": [
              ["L", "=", "R"],    # Rule 1: S → L = R
              ["R"]               # Rule 2: S → R
             ],
        "L": [
              ["*", "R"],         # Rule 3: L → * R
              ["a"]               # Rule 4: L → a
             ],
        "R": [
              ["L"]               # Rule 5: R → L
             ]
    }
    start = "S"

    parser = SLRParser(grammar, start)

    # Test the parser
    for path in k_paths(grammar, 5):
        if path[0] in start: 
            # print(path)
            tree = path_to_tree(path, grammar)
            # print(tree)
            for i in range(100):
                t = tree_fill(grammar, tree)
                s = collapse(t)
                res = parser.parse(list(s))
                assert res == True, f"Fail test: {s}"


test_valid_k_path_1()
test_valid_k_path_2()


Parsing table:

         +       *       (       )       a       $      E'       E       E       T       T       F       F 

 I0                      S4              S5                      G1              G2              G3        
 I1      S6                                  Accept                                                        
 I2      R2      S7              R2              R2                                                        
 I3      R4      R4              R4              R4                                                        
 I4                      S4              S5                      G8              G2              G3        
 I5      R6      R6              R6              R6                                                        
 I6                      S4              S5                                      G9              G3        
 I7                      S4              S5                                                     G10        
 I8      S