In [None]:
import warnings
# NOTE(review): this installs a process-wide "ignore" filter for *every*
# warning category, so deprecations or runtime warnings from any library used
# later in this notebook are hidden too. Consider narrowing it with a
# ``category=`` / ``module=`` argument once the noisy source is known.
warnings.filterwarnings('ignore')

In [None]:
import re
import matplotlib.pyplot as plt
import networkx as nx
import pydot
from networkx.drawing.nx_pydot import graphviz_layout
from collections import deque

class Node:
    """Base class for every AST node: a vertex carrying a display ``name``.

    ``name`` is also the label used when the tree is drawn with networkx.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        # Readable debug output (e.g. ``BinaryExpression('=')``) instead of the
        # default ``<... object at 0x...>``.
        return f"{type(self).__name__}({self.name!r})"

class Expression(Node):
    """Leaf expression wrapping a raw token; the token is both its name and value."""

    def __init__(self, value):
        self.value = value
        super().__init__(value)

class NameLettingStatement(Node):
    """``letting NAME be EXPR``.

    The statement node itself is labelled with the ``letting`` keyword token;
    the bound name lives in ``name_node`` and the value in ``expression``.
    """

    def __init__(self, letting, name, expression):
        self.expression = expression
        self.name_node = Node(name)
        super().__init__(letting)

class NamedConstant(Node):
    """An expression labelled with the name it was declared under."""

    def __init__(self, name, constant):
        self.constant = constant
        super().__init__(name)

class DomainNameLettingStatement(Node):
    """``letting NAME be domain DOMAIN``.

    Labelled with the ``letting`` keyword token; the declared name is wrapped in
    ``name_of_domain`` and the parsed domain is kept in ``domain``.
    """

    def __init__(self, letting, name_of_domain, domain):
        self.domain = domain
        self.name_of_domain = Node(name_of_domain)
        super().__init__(letting)

class FindStatement(Node):
    """``find NAME : DOMAIN`` — a decision variable declaration.

    Labelled with the ``find`` keyword token; the variable name is wrapped in
    ``name_node``.
    """

    def __init__(self, find, name, domain):
        self.domain = domain
        self.name_node = Node(name)
        super().__init__(find)

class SuchThatStatement(Node):
    """``such that EXPR`` — a constraint statement.

    The ``such_that`` argument (the consumed keyword token) is accepted but not
    stored; the node is always labelled ``"such that"``.
    """

    def __init__(self, such_that, expression):
        self.expression = expression
        super().__init__("such that")

class Domain(Node):
    """Generic wrapper around a domain object, labelled ``"Domain"``."""

    def __init__(self, domain):
        self.domain = domain
        super().__init__("Domain")

class IntDomain(Node):
    """``int(lower..upper)``; both bounds are wrapped in ``Node``s so they can be drawn."""

    def __init__(self, name, lower, upper):
        super().__init__(name)
        self.lower, self.upper = Node(lower), Node(upper)

class TupleDomain(Node):
    """``tuple (D1, D2, ...)``; ``domains`` holds the component domains in order."""

    def __init__(self, name, domains):
        self.domains = domains
        super().__init__(name)

class RelationDomain(Node):
    """``relation of (D1 * D2 * ...)``; ``domains`` holds the column domains in order."""

    def __init__(self, name, domains):
        self.domains = domains
        super().__init__(name)

class NamedDomain(Domain):
    """A domain referred to by a user-chosen name.

    Labelled with that name rather than the generic ``"Domain"``; the
    underlying domain object is kept in ``domain``.
    """

    def __init__(self, name, domain):
        # Domain.__init__ stores ``domain`` and a placeholder label...
        super().__init__(domain)
        # ...which is replaced by the user-chosen name.
        self.name = name

class Operator(Node):
    """An operator token (e.g. ``+`` or ``=``); constructed as ``Operator(name)``.

    Construction is inherited unchanged from ``Node``.
    """
        
class BinaryExpression(Node):
    """``left <operator> right``, labelled with the operator's name."""

    def __init__(self, left, operator, right):
        self.left, self.operator, self.right = left, operator, right
        super().__init__(operator.name)

class TupleConstant(Expression):
    """Literal tuple such as ``(1, 2)``; ``values`` lists the element expressions.

    ``value`` also holds the element list (set by ``Expression``), while the
    display label is fixed to ``"tuple"``.
    """

    def __init__(self, values):
        super().__init__(values)
        self.values = values
        self.name = "tuple"

class TupleVariable(Node):
    """A tuple pattern of quantified names, e.g. ``(u, c)`` in ``forAll (u,c) in C``."""

    def __init__(self, elements):
        self.elements = elements
        super().__init__('tuple')

class RelationConstant(Expression):
    """Literal relation such as ``relation((1,2),(1,3))``; ``values`` lists its tuples.

    ``value`` also holds the tuple list (set by ``Expression``), while the
    display label is fixed to ``"relation"``.
    """

    def __init__(self, values):
        super().__init__(values)
        self.values = values
        self.name = "relation"

class QuantificationExpression(Node):
    """Quantifier header such as ``forAll (u,c) in C`` or ``exists x : D``.

    ``variables`` is either a list of ``Node``s or a single ``TupleVariable``;
    ``preposition`` wraps the ``:``/``in`` token; ``domain`` is what is ranged over.
    """

    def __init__(self, quantifier, variables, preposition, domain):
        super().__init__(quantifier)
        self.quantifier = quantifier
        self.variables = variables
        self.preposition = Node(preposition)
        self.domain = domain

class ConcatenationExpression(Node):
    """Two expressions joined by ``.`` (e.g. a quantifier header and its body)."""

    def __init__(self, left, right):
        self.left, self.right = left, right
        super().__init__(" . ")

class EssenceParser:
    """Parser for a small subset of the Essence constraint-modelling language.

    The input is tokenized in ``__init__``; :meth:`parse` then walks the token
    list once and returns one AST node per top-level statement:

    * ``letting NAME be domain DOMAIN``  -> DomainNameLettingStatement
    * ``letting NAME be EXPR``           -> NameLettingStatement
    * ``find NAME : DOMAIN``             -> FindStatement
    * ``such that EXPR (. EXPR)*``       -> SuchThatStatement

    Malformed input raises ``SyntaxError``.
    """

    # Binding strength of each binary operator (larger binds tighter), used by
    # the shunting-yard loop in parse_expression.
    _PRECEDENCE = {
        "*": 6, "/": 6, "%": 6,
        "+": 5, "-": 5,
        "<": 4, ">": 4, "<=": 4, ">=": 4, "=": 4, "!=": 4, "in": 4,
        "/\\": 3,  # conjunction
        "\\/": 2,  # disjunction
        "->": 1,   # implication (loosest)
    }
    # Operators that group right-to-left: a -> b -> c == a -> (b -> c).
    _RIGHT_ASSOCIATIVE = frozenset({"->"})

    def __init__(self, input_str):
        # Multi-character operators come before the single-character catch-all
        # so that e.g. "->", ".." and "/\" each stay a single token.
        self.tokens = re.findall(r'\.\.|\->|\\\/|\/\\|>=|<=|!=|[^\s\w]|[\w]+', input_str.replace('\n', ' '))
        self.index = 0
        self.named_domains = {}    # name -> domain (domain lettings and find variables)
        self.named_constants = {}  # declared name -> NamedConstant (value lettings)
        self.binary_operators = [
            "<", ">", "<=", ">=", "+", "-", "*", "/", "%", "=",
            "!=", "in", "/\\", "\\/", "->",
        ]

    def parse(self):
        """Parse every statement in the input and return them as a list.

        Declarations are recorded as they are parsed so that later statements
        can refer to them: domain lettings and ``find`` variables go into
        ``named_domains`` (which lets quantifiers range over a ``find``
        variable), value lettings go into ``named_constants``.
        """
        statements = []
        while self.index < len(self.tokens):
            statement = self.parse_statement()
            if isinstance(statement, NameLettingStatement):
                # Key by the declared name; ``statement.name`` is the
                # ``letting`` keyword and would make every entry collide.
                self.named_constants[statement.name_node.name] = statement.expression
            elif isinstance(statement, DomainNameLettingStatement):
                self.named_domains[statement.name_of_domain.name] = statement.domain
            elif isinstance(statement, FindStatement):
                self.named_domains[statement.name_node.name] = statement.domain
            statements.append(statement)
        return statements

    def _peek(self, offset=0):
        """Return the token ``offset`` positions ahead, or None past the end."""
        position = self.index + offset
        if position < len(self.tokens):
            return self.tokens[position]
        return None

    def consume(self):
        """Return the current token and advance.

        Raises:
            SyntaxError: if the input is already exhausted.
        """
        if self.index >= len(self.tokens):
            raise SyntaxError("Unexpected end of input at token " + str(self.index))
        token = self.tokens[self.index]
        self.index += 1
        return token

    def _expect(self, expected):
        """Consume one token and raise SyntaxError unless it equals ``expected``."""
        token = self.consume()
        if token != expected:
            raise SyntaxError(
                "Expected '" + expected + "' but found '" + token
                + "' at token " + str(self.index - 1)
            )
        return token

    def match(self, expected):
        """Return True if the current token equals ``expected`` (without consuming)."""
        return self.index < len(self.tokens) and self.tokens[self.index] == expected

    def parse_statement(self):
        """Dispatch on the leading keyword and parse one top-level statement."""
        if self.match("letting"):
            # letting NAME be domain ...  |  letting NAME be EXPR
            if self._peek(2) == "be":
                if self._peek(3) == "domain":
                    return self.parse_domain_name_letting_statement()
                return self.parse_name_letting_statement()
            raise SyntaxError("Invalid letting statement: " + str(self._peek()))
        if self.match("find"):
            return self.parse_find_statement()
        if self.match("such"):
            return self.parse_such_that_statement()
        raise SyntaxError("Invalid statement:" + str(self._peek()) + " Token Num: " + str(self.index))

    def parse_name_letting_statement(self):
        """Parse ``letting NAME be EXPR``."""
        letting = self.consume()  # "letting"
        name = self.consume()  # Name
        self._expect("be")
        expression = NamedConstant(name, self.parse_expression())
        return NameLettingStatement(letting, name, expression)

    def parse_domain_name_letting_statement(self):
        """Parse ``letting NAME be domain DOMAIN``."""
        letting = self.consume()  # "letting"
        name_of_domain = self.consume()  # NameDomain
        self._expect("be")
        self._expect("domain")
        domain = NamedDomain(name_of_domain, self.parse_domain())
        return DomainNameLettingStatement(letting, name_of_domain, domain)

    def parse_find_statement(self):
        """Parse ``find NAME : DOMAIN``."""
        find = self.consume()  # "find"
        name = self.consume()  # Name
        self._expect(":")
        domain = self.parse_domain()
        return FindStatement(find, name, domain)

    def parse_such_that_statement(self):
        """Parse ``such that EXPR (. EXPR)*``.

        Expressions separated by ``.`` (e.g. a quantifier header followed by
        its body) are folded left-to-right into ConcatenationExpression nodes.
        """
        such_that = self.consume()  # "such"
        self._expect("that")
        expression = self.parse_expression()
        while self.match("."):
            self.consume()  # "."
            next_expression = self.parse_expression()
            expression = ConcatenationExpression(expression, next_expression)
        return SuchThatStatement(such_that, expression)

    def parse_domain(self):
        """Parse ``int(lo..hi)``, ``tuple (D, ...)``, ``relation of (D * ...)``
        or a previously declared name (looked up in ``named_domains``).
        """
        token = self._peek()
        if token == "int":
            domain_name = self.consume()  # "int"
            self._expect("(")
            lower = self.consume()  # lower bound
            self._expect("..")
            upper = self.consume()  # upper bound
            self._expect(")")
            return IntDomain(domain_name, lower, upper)
        if token == "tuple":
            domain_name = self.consume()  # "tuple"
            self._expect("(")
            domains = []
            while not self.match(")"):
                domains.append(self.parse_domain())
                if self.match(","):
                    self.consume()  # ","
            self.consume()  # ")"
            return TupleDomain(domain_name, domains)
        if token == "relation":
            domain_name = self.consume()  # "relation"
            self._expect("of")
            self._expect("(")
            domains = []
            while not self.match(")"):
                domains.append(self.parse_domain())
                if self.match("*"):
                    self.consume()  # "*"
            self.consume()  # ")"
            return RelationDomain(domain_name, domains)
        if token in self.named_domains:
            name_of_domain = self.consume()
            return NamedDomain(name_of_domain, self.named_domains[name_of_domain])
        raise SyntaxError("Domain Parsing Error. Token: " + str(token))

    def parse_term(self):
        """Consume one token as a leaf Expression (not called by the other methods here)."""
        return Expression(self.consume())  # Literal (integerConstant)

    def parse_operator(self):
        """Consume one token as an Operator."""
        return Operator(self.consume())

    def match_any(self, tokens):
        """Return True if the current token equals any of ``tokens``."""
        return any(self.match(token) for token in tokens)

    def parse_constant(self):
        """Parse a tuple/relation literal, a quantifier header or an integer literal."""
        token = self._peek()
        if token == "(" and self._peek(2) == ",":
            return self.parse_tuple_constant()
        if token == "relation":
            return self.parse_relation_constant()
        if token in ("forAll", "exists"):
            return self.parse_quantification()
        if token is not None and token.isdigit():
            return Expression(self.consume())
        raise SyntaxError("Invalid constant: " + str(token))

    def is_expression_terminator(self):
        """Return True at a ``.``, at the start of a new statement, or at end of input."""
        return (
            self.match(".")
            or self.match("such")
            or self.match("letting")
            or self.match("find")
            or self.index >= len(self.tokens)
        )

    def _pops_before(self, stacked, incoming):
        """Return True if operator ``stacked`` must be emitted before ``incoming`` is pushed."""
        stacked_prec = self._PRECEDENCE.get(stacked.name, 0)
        incoming_prec = self._PRECEDENCE.get(incoming.name, 0)
        if stacked_prec != incoming_prec:
            return stacked_prec > incoming_prec
        # Equal strength: left-associative operators reduce first.
        return incoming.name not in self._RIGHT_ASSOCIATIVE

    def parse_expression(self):
        """Parse an infix expression up to the next terminator and return its tree.

        Shunting-yard: operands go straight to the output queue; ``Operator``s
        wait on a stack and are released according to ``_PRECEDENCE`` and
        associativity; parentheses group. Relation/tuple literals and
        quantifier headers are parsed by their dedicated helpers and treated
        as single operands.

        Raises:
            SyntaxError: on unbalanced parentheses or a malformed expression.
        """
        output_queue = []
        operator_stack = []  # holds "(" strings and Operator objects

        while not self.is_expression_terminator():
            if self.match("relation"):
                output_queue.append(self.parse_relation_constant())
            elif self.match("(") and self._peek(2) == ",":
                output_queue.append(self.parse_tuple_constant())
            elif self.match("("):
                operator_stack.append(self.consume())  # "("
            elif self.match(")"):
                while operator_stack and isinstance(operator_stack[-1], Operator):
                    output_queue.append(operator_stack.pop())
                if not operator_stack:
                    raise SyntaxError("Unbalanced ')' at token " + str(self.index))
                operator_stack.pop()  # discard the matching "("
                self.consume()  # ")"
            elif self.match_any(self.binary_operators):
                current_operator = self.parse_operator()
                while (operator_stack and isinstance(operator_stack[-1], Operator)
                        and self._pops_before(operator_stack[-1], current_operator)):
                    output_queue.append(operator_stack.pop())
                operator_stack.append(current_operator)
            elif self.match_any(["forAll", "exists"]):
                output_queue.append(self.parse_quantification())
            else:
                output_queue.append(Expression(self.consume()))  # Literal or Name

        while operator_stack:
            top = operator_stack.pop()
            if not isinstance(top, Operator):
                raise SyntaxError("Unbalanced '(' in expression ending at token " + str(self.index))
            output_queue.append(top)

        return self.build_expression_tree(output_queue)

    def build_expression_tree(self, postfix_expression):
        """Fold a postfix token list into a single expression tree.

        Raises:
            SyntaxError: if an operator lacks operands, or if the list does not
                reduce to exactly one tree (previously the extra operands were
                silently dropped).
        """
        stack = []

        for token in postfix_expression:
            if isinstance(token, Operator):
                if len(stack) < 2:
                    raise SyntaxError("Operator '" + str(token.name) + "' is missing an operand")
                right = stack.pop()
                left = stack.pop()
                stack.append(BinaryExpression(left, token, right))
            elif isinstance(token, Node):
                stack.append(token)
            else:
                stack.append(Expression(token))

        if not stack:
            raise SyntaxError("Empty expression at token " + str(self.index))
        if len(stack) > 1:
            raise SyntaxError(
                "Malformed expression: " + str(len(stack))
                + " operands are not joined by an operator"
            )
        return stack[0]

    def peek(self, stack):
        """Return the top of ``stack`` without removing it, or None if empty."""
        if stack:
            return stack[-1]
        return None

    def parse_tuple_constant(self):
        """Parse ``( tok , tok , ... )`` into a TupleConstant of leaf Expressions."""
        self._expect("(")
        values = []
        while not self.match(")"):
            values.append(Expression(self.consume()))  # Literal (integerConstant)
            if self.match(","):
                self.consume()  # ","
        self.consume()  # ")"
        return TupleConstant(values)

    def parse_relation_constant(self):
        """Parse ``relation( (..), (..), ... )`` into a RelationConstant of tuples."""
        self.consume()  # "relation"
        self._expect("(")
        values = []
        while not self.match(")"):
            if self.match("("):
                values.append(self.parse_tuple_constant())
            elif self.match(","):
                self.consume()  # ","
            else:
                # Without this branch an unexpected token (or end of input)
                # would never be consumed and the loop would spin forever.
                raise SyntaxError("Invalid relation element: " + str(self._peek()))
        self.consume()  # ")"
        return RelationConstant(values)

    def parse_quantification(self):
        """Parse a quantifier header: ``forAll``/``exists``, variables, ``:``/``in``, domain.

        Variables are either one parenthesised tuple pattern (stored as a
        TupleVariable) or a comma-separated list of names (stored as a list
        of Nodes). The quantified body is not parsed here; it follows the
        next ``.``.
        """
        quantifier = self.consume()  # "forAll" or "exists"
        variables = []

        while not self.match_any([":", "in"]):
            if self.match("("):
                if variables:
                    raise SyntaxError("A tuple pattern must be the only quantified variable")
                self.consume()  # "("
                tuple_elements = []
                while not self.match(")"):
                    tuple_elements.append(Node(self.consume()))
                    if self.match(","):
                        self.consume()  # ","
                self.consume()  # ")"
                variables = TupleVariable(tuple_elements)
            else:
                if isinstance(variables, TupleVariable):
                    raise SyntaxError("A tuple pattern must be the only quantified variable")
                variables.append(Node(self.consume()))
                if self.match(","):
                    self.consume()  # ","

        preposition = self.consume()  # ":" or "in"
        domain = self.parse_domain()

        return QuantificationExpression(quantifier, variables, preposition, domain)

def add_nodes_edges(graph, node, parent=None, index=None):
    """Recursively add ``node`` and its descendants to ``graph``.

    Each AST object becomes one vertex keyed by ``id(node)`` with a ``label``
    attribute: ``node.name``, suffixed with `` (index)`` for the ordered
    children of tuples and relations. When ``parent`` is given, an edge
    ``parent -> node`` is added as well.

    Vertices are keyed by object identity, so every node passed here must stay
    alive while the graph is being built. A throw-away ``Node`` could be
    garbage-collected and its id reused by an unrelated node.

    Args:
        graph: object with ``add_node``/``add_edge`` (a networkx ``DiGraph``
            in this notebook).
        node: AST node to add.
        parent: node this one hangs from, or None for a root.
        index: 1-based position among ordered siblings, shown in the label.
    """
    unique_id = id(node)
    label = node.name if index is None else f"{node.name} ({index})"
    graph.add_node(unique_id, label=label)

    if parent is not None:
        graph.add_edge(id(parent), unique_id)

    # Statements that declare a name hang their payload off the name node.
    if isinstance(node, (NameLettingStatement, FindStatement)):
        add_nodes_edges(graph, node.name_node, parent=node)

    if isinstance(node, NameLettingStatement):
        add_nodes_edges(graph, node.expression, parent=node.name_node)

    elif isinstance(node, FindStatement):
        add_nodes_edges(graph, node.domain, parent=node.name_node)

    elif isinstance(node, NamedConstant):
        # Draw the constant's value; this was previously a leaf and the value
        # never appeared in the graph.
        add_nodes_edges(graph, node.constant, parent=node)

    elif isinstance(node, DomainNameLettingStatement):
        add_nodes_edges(graph, node.name_of_domain, parent=node)
        add_nodes_edges(graph, node.domain.domain, parent=node.name_of_domain)

    elif isinstance(node, SuchThatStatement):
        add_nodes_edges(graph, node.expression, parent=node)

    elif isinstance(node, IntDomain):
        add_nodes_edges(graph, node.lower, parent=node)
        add_nodes_edges(graph, node.upper, parent=node)

    elif isinstance(node, (BinaryExpression, ConcatenationExpression)):
        add_nodes_edges(graph, node.left, parent=node)
        add_nodes_edges(graph, node.right, parent=node)

    elif isinstance(node, (TupleDomain, RelationDomain)):
        for i, domain in enumerate(node.domains, start=1):
            add_nodes_edges(graph, domain, parent=node, index=i)

    elif isinstance(node, (TupleConstant, RelationConstant)):
        for i, value in enumerate(node.values, start=1):
            add_nodes_edges(graph, value, parent=node, index=i)

    elif isinstance(node, QuantificationExpression):
        if isinstance(node.variables, TupleVariable):
            # quantifier -> (a, b) -> in/: -> domain
            add_nodes_edges(graph, node.variables, parent=node)
            add_nodes_edges(graph, node.preposition, parent=node.variables)
            add_nodes_edges(graph, node.domain, parent=node.preposition)
        else:
            # Show the body of a named domain; inline domains such as
            # int(1..3) have no ``.domain`` attribute and are drawn directly.
            domain = node.domain.domain if isinstance(node.domain, NamedDomain) else node.domain
            for variable in node.variables:
                # Use the persistent variable node itself so its id is stable
                # and the variable -> domain edge starts from a labelled vertex.
                add_nodes_edges(graph, variable, parent=node)
                add_nodes_edges(graph, domain, parent=variable)

    elif isinstance(node, TupleVariable):
        for element in node.elements:
            add_nodes_edges(graph, element, parent=node)
        
def createASG(ast):
    """Return a new networkx DiGraph containing every root node in ``ast``.

    ``ast`` is an iterable of AST root nodes; each one (and its subtree) is
    added via ``add_nodes_edges``.
    """
    graph = nx.DiGraph()
    for root in ast:
        add_nodes_edges(graph, root)
    return graph

In [None]:
input_str = """

letting vertices be domain int(1..3)
letting colours be domain int(1..3)
letting G be relation((1,2),(1,3),(2,3))
letting map be domain relation of (vertices * colours)
letting T be domain tuple (vertices,colours)
find C : map
find t : T
such that
  forAll (u,c) in C .
     forAll (v,d) in C .
        ((u = v) -> (c = d))"""


parser = EssenceParser(input_str)
asts = parser.parse()

# List the top-level statements that were parsed. (Statement nodes never set
# a ``value`` attribute, so there is nothing else to print for them.)
for statement in asts:
    print(statement.name)

# Build the graph with the shared helper instead of repeating its loop here.
G = createASG(asts)

# plot
labels = nx.get_node_attributes(G, 'label')
plt.figure(figsize=(20, 20), dpi=40)
pos = graphviz_layout(G, prog="dot")
# Alternative layout without Graphviz: pos = nx.spring_layout(G, k=3, iterations=450)
nx.draw(G, pos, with_labels=True, labels=labels, node_size=300, node_color="lightblue", font_size=14, font_weight="bold")
plt.show()