In [None]:
import dataclasses


class JoinHypergraph:
    """Join hypergraph: a list of table nodes and a list of hyperedges.

    Each edge is stored as a ``[left, right]`` pair of table-name sets.
    """

    def __init__(self):
        self.nodes = []
        self.edges = []

    def __repr__(self):
        # Header lines first, then one line per edge; every line ends in '\n'.
        lines = [f'nodes: {set(self.nodes)}', 'edges:']
        lines.extend(f'{edge}' for edge in self.edges)
        return '\n'.join(lines) + '\n'


class JoinExpression:
    """Node of a join expression tree: either a base table or a binary join.

    Attributes:
        type: ``'TABLE'`` for a leaf, otherwise the join operator name
            (e.g. ``'INNER'``, ``'LEFT OUTER'``, ``'FULL OUTER'``).
        left, right: Operand expressions of a join; ``None`` until assigned.
        subtree_nodes: Set of table names contained in this subtree.
        condition_used_tables: Set of table names referenced by the join
            condition.
    """

    def __init__(self):
        self.type = ''

        self.left = None
        self.right = None

        self.subtree_nodes = set()
        self.condition_used_tables = set()

    @staticmethod
    def _operand_repr(operand):
        """Render a join operand: a bare table name, or a parenthesized join."""
        if operand.type == 'TABLE':
            return str(list(operand.subtree_nodes)[0])
        return f'({operand!r})'

    def __repr__(self):
        if self.type == 'TABLE':
            return f'{self.subtree_nodes}'

        left = self._operand_repr(self.left)
        right = self._operand_repr(self.right)
        # Sorted so the text does not depend on set iteration order, which for
        # strings changes between interpreter runs (hash randomization).
        condition_used = ', '.join(sorted(self.condition_used_tables))
        return f'{left} {self.type} JOIN {right} on ({condition_used})'


@dataclasses.dataclass(eq=False)
class ConflictRule:
    """Conflict rule ``activation_nodes -> required_nodes``.

    When any table of ``activation_nodes`` is in a join's total eligibility
    set (TES), all tables of ``required_nodes`` must be added to that TES
    (this is how ``convert_cr_into_tes`` applies it).

    ``eq=False`` keeps identity-based equality and hashing, matching a plain
    class. No class-level defaults are declared, so instances never share a
    mutable set.
    """

    activation_nodes: set
    required_nodes: set


# Operator-property tables used by the conflict detector. Each maps the type of
# the operator passed as ``lhs`` to the set of ``rhs`` operator types for which
# the property holds; a type missing from the set means it does not hold.
# NOTE(review): these look like the assoc / l-asscom / r-asscom tables of
# Moerkotte, Fender & Eich (SIGMOD 2013); there some entries hold only under
# extra conditions (e.g. null-rejecting predicates) that this code does not
# check -- confirm.

# assoc(lhs, rhs)
ASSOC_TABLE = {
    'INNER': {'INNER', 'LEFT OUTER'},
    'LEFT OUTER': {'LEFT OUTER'},
    'FULL OUTER': {'LEFT OUTER', 'FULL OUTER'},
}

# l-asscom(lhs, rhs)
LASSCOM_TABLE = {
    'INNER': {'INNER', 'LEFT OUTER'},
    'LEFT OUTER': {'INNER', 'LEFT OUTER', 'FULL OUTER'},
    'FULL OUTER': {'LEFT OUTER', 'FULL OUTER'},
}

# r-asscom(lhs, rhs)
RASSCOM_TABLE = {
    'INNER': {'INNER'},
    # Empty *set*: the literal ``{}`` would be an empty dict.
    'LEFT OUTER': set(),
    'FULL OUTER': {'FULL OUTER'},
}


def operators_are_assoc(lhs, rhs):
    """Return True if ``rhs.type`` is listed for ``lhs.type`` in ASSOC_TABLE.

    Raises KeyError when ``lhs.type`` is not a row of the table.
    """
    compatible_types = ASSOC_TABLE[lhs.type]
    return rhs.type in compatible_types


def operators_are_left_asscom(lhs, rhs):
    """Return True if ``rhs.type`` is listed for ``lhs.type`` in LASSCOM_TABLE.

    Raises KeyError when ``lhs.type`` is not a row of the table.
    """
    compatible_types = LASSCOM_TABLE[lhs.type]
    return rhs.type in compatible_types


def operators_are_right_asscom(lhs, rhs):
    """Return True if ``rhs.type`` is listed for ``lhs.type`` in RASSCOM_TABLE.

    Raises KeyError when ``lhs.type`` is not a row of the table.
    """
    compatible_types = RASSCOM_TABLE[lhs.type]
    return rhs.type in compatible_types


class ConflictRulesCollector:
    """Collect the conflict rules for one join operator, ``root``.

    Every join nested in ``root.left`` is checked against ``root`` with the
    assoc and l-asscom tables; every join nested in ``root.right`` with the
    assoc and r-asscom tables. Each failed check appends a ``ConflictRule``
    built from the nested join's operand table sets.
    """

    def __init__(self, root):
        self.root = root
        self.conflict_rules = []

    def _add_rule(self, activation_nodes, required_nodes):
        # Record the rule "activation_nodes -> required_nodes".
        self.conflict_rules.append(ConflictRule(activation_nodes, required_nodes))

    def visit_expr_tree(self, child, visitor):
        """Call ``visitor`` on every join node under ``child``, children first.

        Table leaves are skipped; the left operand is visited before the right.
        """
        if child.type != 'TABLE':
            for operand in (child.left, child.right):
                self.visit_expr_tree(operand, visitor)
            visitor(child)

    def collect_left_conflict(self, child):
        """Add rules for a join ``child`` located in the root's left subtree."""
        left_tables = child.left.subtree_nodes
        right_tables = child.right.subtree_nodes

        if not operators_are_assoc(child, self.root):
            self._add_rule(right_tables, left_tables)
        if not operators_are_left_asscom(child, self.root):
            self._add_rule(left_tables, right_tables)

    def collect_right_conflict(self, child):
        """Add rules for a join ``child`` located in the root's right subtree."""
        left_tables = child.left.subtree_nodes
        right_tables = child.right.subtree_nodes

        if not operators_are_assoc(self.root, child):
            self._add_rule(left_tables, right_tables)
        if not operators_are_right_asscom(self.root, child):
            self._add_rule(right_tables, left_tables)

    def collect_conflicts(self):
        """Walk both subtrees of the root and return the accumulated rules."""
        walks = (
            (self.root.left, self.collect_left_conflict),
            (self.root.right, self.collect_right_conflict),
        )
        for subtree, visitor in walks:
            self.visit_expr_tree(subtree, visitor)
        return self.conflict_rules


def convert_cr_into_tes(ses, conflict_rules):
    """Grow a join's condition tables into its total eligibility set (TES).

    Rules are applied repeatedly: whenever a rule's ``activation_nodes``
    intersect the current TES, all of its ``required_nodes`` are added.
    A rule whose ``required_nodes`` are already contained in the TES can
    never add anything, so it is dropped from later passes. Iteration stops
    at a fixpoint (a pass added no table) or when no rules remain.

    Args:
        ses: Iterable of table names referenced by the join condition.
            Not modified.
        conflict_rules: Sequence of objects with set-valued
            ``activation_nodes`` and ``required_nodes``. Not modified.

    Returns:
        A new set: ``ses`` plus every table pulled in by the rules.
    """
    # Copy: callers pass the join's own condition set, which must not change.
    tes = set(ses)
    pending = list(conflict_rules)

    while pending:
        size_before = len(tes)

        for rule in pending:
            if rule.activation_nodes & tes:
                tes |= rule.required_nodes

        # The TES only ever grows, so a satisfied rule stays satisfied.
        pending = [
            rule for rule in pending
            if not rule.required_nodes.issubset(tes)
        ]

        # Because the TES only grows, an unchanged size means a fixpoint.
        if len(tes) == size_before:
            break

    return tes


def find_hyperedge(expr):
    """Compute the hyperedge ``(left, right)`` for the join node ``expr``.

    Collects the conflict rules for ``expr`` and uses them to grow the tables
    referenced by its condition into a total eligibility set (TES). The TES is
    then split by the two operand subtrees.

    ``subtree_nodes`` must already be populated on the operands of ``expr``
    and on every join nested below it; ``make_join_hypergraph`` does this
    bottom-up before calling here.

    Returns:
        ``(left, right)``: the TES tables lying in ``expr.left`` and in
        ``expr.right``, respectively.
    """
    cr_collector = ConflictRulesCollector(expr)
    conflict_rules = cr_collector.collect_conflicts()

    # Pass a fresh set so growing the TES can never mutate the join's own
    # ``condition_used_tables`` (which is also what ``repr(expr)`` prints).
    ses = set(expr.condition_used_tables)
    tes = convert_cr_into_tes(ses=ses, conflict_rules=conflict_rules)

    # NOTE(review): if the condition references no table of one side, that
    # side of the edge comes out empty; confirm whether such degenerate
    # predicates should pull in that side's full table set instead.
    left = tes & expr.left.subtree_nodes
    right = tes & expr.right.subtree_nodes
    return left, right


def make_join_hypergraph(graph, expr):
    """Populate ``graph`` from the join tree rooted at ``expr``.

    Table leaves are appended to ``graph.nodes`` in left-to-right order. Each
    join is handled after both of its operands: its ``subtree_nodes`` is set
    to the union of the operands' tables, then its ``[left, right]``
    hyperedge from ``find_hyperedge`` is appended to ``graph.edges``.
    """
    if expr.type != 'TABLE':
        for operand in (expr.left, expr.right):
            make_join_hypergraph(graph, operand)

        expr.subtree_nodes = expr.left.subtree_nodes | expr.right.subtree_nodes

        left_side, right_side = find_hyperedge(expr)
        graph.edges.append([left_side, right_side])
    else:
        graph.nodes.append([*expr.subtree_nodes][0])

def INNER_JOIN(lhs, rhs, on=None):
    """Build ``lhs INNER JOIN rhs``; ``on`` is forwarded to JOIN as ``used_tables``."""
    return JOIN(lhs, rhs, type='INNER', used_tables=on)

IJ = INNER_JOIN

def LEFT_OUTER_JOIN(lhs, rhs, on=None):
    """Build ``lhs LEFT OUTER JOIN rhs``; ``on`` is forwarded to JOIN as ``used_tables``."""
    return JOIN(lhs, rhs, type='LEFT OUTER', used_tables=on)

LJ = LEFT_OUTER_JOIN

def FULL_OUTER_JOIN(lhs, rhs, on=None):
    """Build ``lhs FULL OUTER JOIN rhs``; ``on`` is forwarded to JOIN as ``used_tables``."""
    return JOIN(lhs, rhs, type='FULL OUTER', used_tables=on)

OJ = FULL_OUTER_JOIN

def _as_join_operand(operand):
    """Wrap a bare table name in a 'TABLE' JoinExpression; pass others through."""
    if not isinstance(operand, str):
        return operand

    table = JoinExpression()
    table.type = 'TABLE'
    # One-element set holding the whole name; ``set(operand)`` would split a
    # multi-character name such as 'orders' into its individual letters.
    table.subtree_nodes = {operand}
    return table


def JOIN(lhs, rhs, type, used_tables):
    """Build a binary join expression node.

    Args:
        lhs, rhs: Operands, each a table name (str) or a JoinExpression.
        type: Join operator name, e.g. 'INNER', 'LEFT OUTER', 'FULL OUTER'.
        used_tables: Tables referenced by the join condition. Accepts a
            comma-separated string such as 'A, B' (whitespace around names is
            ignored), any other iterable of table names, or None. None means
            the condition references every table of both operands.

    Returns:
        The new JoinExpression. Its ``subtree_nodes`` is already set to the
        union of the operands' tables.
    """
    left_expr = _as_join_operand(lhs)
    right_expr = _as_join_operand(rhs)

    expr = JoinExpression()
    expr.type = type
    expr.left = left_expr
    expr.right = right_expr
    # make_join_hypergraph recomputes this to the same value. Setting it here
    # lets a default condition on an enclosing JOIN see this subtree's tables.
    # NOTE(review): an operand JoinExpression built elsewhere may still have
    # empty subtree_nodes, so a default condition would miss its tables.
    expr.subtree_nodes = left_expr.subtree_nodes | right_expr.subtree_nodes

    if used_tables is None:
        # Use the operands' table names, never the operand objects themselves,
        # so the condition only ever holds table names. All tables is the
        # widest possible condition.
        condition = set(expr.subtree_nodes)
    elif isinstance(used_tables, str):
        condition = {name.strip() for name in used_tables.split(',')} - {''}
    else:
        condition = set(used_tables)

    expr.condition_used_tables = condition
    return expr

def print_join_hypergraph(join_expr):
    """Print ``join_expr``, then build and print its join hypergraph.

    Building the graph (via ``make_join_hypergraph``) also overwrites
    ``subtree_nodes`` on every join node of ``join_expr``.
    """
    print(join_expr)
    hypergraph = JoinHypergraph()
    make_join_hypergraph(hypergraph, join_expr)
    print(hypergraph)

# Example query:
#   A INNER JOIN ((B LEFT OUTER JOIN C on (C, B)) LEFT OUTER JOIN D on (B, D)) on (B, A)
# The inner LEFT OUTER JOIN has no explicit ``on``, so JOIN's default is used.

print_join_hypergraph(
    INNER_JOIN(
        'A',
        LEFT_OUTER_JOIN(LEFT_OUTER_JOIN('B', 'C'), 'D', on='B, D'),
        on='B, A',
    )
)
