In [None]:
import networkx as nx
from uccgGenerator import tree_insertion
from plotNetwork import plotGraph
import random
import matplotlib.pyplot as plt
from itertools import combinations
from networkx.drawing.nx_agraph import graphviz_layout, to_agraph
from matplotlib.patches import ArrowStyle
from networkx.algorithms.approximation.treewidth import *
from networkx.algorithms.dag import *
from functools import partial

In [None]:
# Node-attribute keys written by get_sep_and_res on the rooted tree decomposition.
res, sep = 'res', 'sep'

# Edge-attribute key (`label`) and its possible values on the merged I/F graph,
# plus the flag marking whether an edge belongs to the current orientation.
label = 'I'
status = 'status'  # NOTE(review): not referenced by the code visible here
I_edge, F_edge, common_edge = 'I_edge', 'F_edge', 'common_edge'
in_current = 'in_current'

# Drawing defaults shared by the plotting helpers.
connectionstyle = 'arc3, rad=0.1'
node_size, color = 300, 'white'
font_size = 13
alpha, width = 1, 2

In [None]:
def random_flip(g):
    """Reverse one uniformly chosen edge of ``g`` in place and return ``g``.

    ``g.edges`` is materialised as a list before sampling. networkx edge views
    are set-like, and since Python 3.11 ``random.sample`` raises TypeError for
    non-sequence populations.

    A graph with no edges is returned unchanged. Edge attributes of the flipped
    edge are not carried over to the reversed edge.
    """
    edges = list(g.edges)
    if not edges:
        return g
    u, v = random.choice(edges)
    g.remove_edge(u, v)
    g.add_edge(v, u)
    return g

In [None]:
def get_IF(I, F):
    """Merge the directed graphs ``I`` and ``F`` into one annotated DiGraph.

    Every edge of either graph is present in the result with two attributes:
      * ``label``: ``common_edge`` if the edge is in both graphs, ``I_edge`` if
        it is only in ``I``, ``F_edge`` if it is only in ``F``;
      * ``in_current``: True for edges of ``I`` (I-only and common), False for
        F-only edges, so the "current" orientation starts out equal to ``I``.

    Nodes enter the result only through edges.
    """
    in_I = set(I.edges)
    in_F = set(F.edges)
    common_edges = list(set(I.edges).intersection(F.edges))
    diff_edges = list(set(I.edges).symmetric_difference(F.edges))

    def _attrs(edge):
        # Each edge falls in exactly one of the three categories.
        if edge in in_I and edge in in_F:
            return {label: common_edge, in_current: True}
        if edge in in_I:
            return {label: I_edge, in_current: True}
        return {label: F_edge, in_current: False}

    merged = nx.DiGraph()
    merged.add_edges_from((u, v, _attrs((u, v))) for u, v in common_edges + diff_edges)
    return merged

In [None]:
def current_graph(IF):
    """Return a new DiGraph holding only the edges of ``IF`` flagged ``in_current``.

    Nodes are introduced only via those edges. A vertex whose edges are all
    non-current does not appear in the result.
    """
    current = nx.DiGraph()
    current.add_edges_from(edge for edge in IF.edges if IF.edges[edge][in_current])
    return current

In [None]:
def plot_edges(IF, pos, edges):
    """Draw one homogeneous group of edges of the merged graph ``IF``.

    Styling is derived from the first edge only. Callers must pass edges that
    share the same ``label`` and ``in_current`` values, as plot_IF does.

    Colors: I edges are blue, F edges red, common edges black.
    Edges in the current orientation are drawn solid, wider and opaque. The
    others are drawn dotted, thinner and semi-transparent.

    Args:
        IF: merged graph produced by get_IF.
        pos: node -> (x, y) layout mapping.
        edges: list of edge tuples to draw; an empty list draws nothing.
    """
    if len(edges) == 0:
        return
    first = edges[0]
    is_in_current = IF.edges[first][in_current]
    # Includes common_edge so a group of common edges does not raise KeyError.
    edge_color = {I_edge: 'b', F_edge: 'r', common_edge: 'k'}[IF.edges[first][label]]
    edge_width = 1.5 if is_in_current else 1
    edge_style = 'solid' if is_in_current else 'dotted'
    edge_alpha = 1 if is_in_current else 0.5
    # `style` was previously computed but never passed, so non-current edges
    # were never rendered dotted.
    nx.draw_networkx_edges(
        IF, pos, edges,
        edge_color=edge_color,
        alpha=edge_alpha,
        width=edge_width,
        style=edge_style,
        connectionstyle=connectionstyle,
    )

In [None]:
def plot_IF(IF, pos):
    """Draw the merged I/F graph ``IF`` at layout ``pos``.

    The draw order is:
      1. nodes;
      2. common edges, thin and faint;
      3. I edges (current, then non-current) and F edges (current, then
         non-current), each group styled by plot_edges;
      4. node labels.
    """
    nx.draw_networkx_nodes(IF, pos, linewidths=width, node_size=node_size, node_color=color, edgecolors='k')
    shared = [edge for edge in IF.edges if IF.edges[edge][label] == common_edge]
    nx.draw_networkx_edges(IF, pos, shared, width=1, alpha=0.3)

    for kind in (I_edge, F_edge):
        of_kind = [edge for edge in IF.edges if IF.edges[edge][label] == kind]
        for want_current in (True, False):
            group = [edge for edge in of_kind if bool(IF.edges[edge][in_current]) == want_current]
            plot_edges(IF, pos, group)

    nx.draw_networkx_labels(IF, pos)

In [None]:
def plot_complement(IF, pos):
    """Draw ``IF`` highlighting non-common edges by their current status.

    Common edges are drawn faint. Non-common edges outside the current
    orientation are drawn magenta. Non-common edges inside it are drawn
    translucent green. Nodes and labels are drawn on top.
    """
    shared = [edge for edge in IF.edges if IF.edges[edge][label] == common_edge]
    nx.draw_networkx_edges(IF, pos, shared, width=1, alpha=0.3)

    shared_lookup = set(shared)
    off_edges = [edge for edge in IF.edges
                 if edge not in shared_lookup and not IF.edges[edge][in_current]]
    on_edges = [edge for edge in IF.edges
                if IF.edges[edge][in_current] and IF.edges[edge][label] != common_edge]

    nx.draw_networkx_edges(IF, pos, off_edges, edge_color='m', width=1.5, connectionstyle=connectionstyle)
    nx.draw_networkx_edges(IF, pos, on_edges, edge_color='g', alpha=0.5, width=1, connectionstyle=connectionstyle)
    nx.draw_networkx_nodes(IF, pos, linewidths=width, node_size=node_size, node_color=color, edgecolors='k')
    nx.draw_networkx_labels(IF, pos)

In [None]:
def get_sep_and_res(rt, u):
    """Annotate every bag reachable from ``u`` in the rooted tree ``rt`` with its separator and residual.

    For a bag ``b`` whose (first) predecessor in ``rt`` is ``p``:
      * ``rt.nodes[b][sep] = b.intersection(p)``, the vertices shared with the parent bag;
      * ``rt.nodes[b][res] = b - rt.nodes[b][sep]``, the vertices new to this bag.
    A bag without a predecessor gets an empty separator, so its residual is the
    whole bag.

    Bags must support ``intersection`` and ``-`` (e.g. frozensets).
    ``rt`` is expected to be a tree; with several predecessors only the first
    one is used.

    The traversal is iterative so deep decompositions do not exceed Python's
    recursion limit. Each bag is visited at most once.
    """
    pending = [u]
    visited = set()
    while pending:
        bag = pending.pop()
        if bag in visited:
            continue
        visited.add(bag)
        parent = next(iter(rt.predecessors(bag)), None)
        # `is not None` rather than truthiness: an empty parent bag is still a parent.
        separator = bag.intersection(parent) if parent is not None else set()
        rt.nodes[bag][sep] = separator
        rt.nodes[bag][res] = bag - separator
        pending.extend(rt.successors(bag))

In [None]:
def process_clique(IF, c):
    """Greedily reverse current I edges inside bag ``c`` of the tree decomposition.

    Each pass does the following:
      1. Take the subgraph of ``IF`` induced by ``c``.
      2. Compute a topological order of its current edges.
      3. Scan the current I edges ``u -> v``, skipping any whose endpoints are
         not consecutive (``u`` immediately before ``v``) in that order.
      4. Reverse the first candidate: mark ``(u, v)`` not current and ``(v, u)``
         current.
      5. If the current graph of the whole ``IF`` becomes cyclic, undo the
         reversal and try the next candidate. Otherwise plot the new state and
         start a new pass.
    The loop ends when a pass reverses nothing.

    Mutates the ``in_current`` flags of ``IF`` in place.

    NOTE(review): plotting reads the notebook-global ``pos`` (assigned in a later
    cell), so this only works once that cell has run. One figure is opened per
    successful reversal and never closed.

    NOTE(review): assumes ``(v, u)`` is an edge of ``IF`` whenever ``(u, v)`` is
    an I edge. That holds when F has the same skeleton as I; otherwise
    ``IF.edges[(v, u)]`` raises KeyError.

    Returns:
        True if every F edge inside the bag is in the current orientation after
        processing.
    """
    update = True
    while update:
        update = False
        # `sub_g` is a view, so it reflects flips made to IF below.
        sub_g = IF.subgraph(c)
        top_order = list(topological_sort(current_graph(sub_g)))
        reversible_edges = [(u, v) for (u, v) in sub_g.edges if sub_g.edges[(u,v)][in_current] and sub_g.edges[(u,v)][label] == I_edge]
        for u, v in reversible_edges:
            # (u, v) is a current I edge, i.e. currently directed u -> v.
            # Only consider it when u comes immediately before v in the order.
            if top_order.index(u)-top_order.index(v) != -1:
                continue
            IF.edges[(u, v)][in_current] = False
            IF.edges[(v, u)][in_current] = True
            if not nx.is_directed_acyclic_graph(current_graph(IF)):
                # Reversal created a cycle somewhere in IF: undo it.
                IF.edges[(u, v)][in_current] = True
                IF.edges[(v, u)][in_current] = False
            else:
                # Accepted: plot the new state and restart the scan from scratch.
                plt.figure(figsize=(12, 6))
                plt.subplot(1,2,1)
                plot_IF(IF, pos)
                plt.title(f'clique: {set(c)}, reverse {(u, v)}')
                plt.subplot(1,2,2)
                plot_complement(IF, pos)
                update = True
                break
    # Finished when every F edge inside the bag is now part of the current orientation.
    return all([sub_g.edges[e][in_current] for e in sub_g.edges if sub_g.edges[e][label] == F_edge])

def process(rt, IF, root):
    """Run process_clique over the rooted tree ``rt``, starting at bag ``root``.

    For each bag, in this order:
      1. Process the bag itself.
      2. Recurse into each child subtree.
      3. Print ``(bag, finished)``, where ``finished`` is the first-pass result.
      4. If the first pass did not finish, process the bag once more; that
         retry's result is discarded.
    """
    finished_first_pass = process_clique(IF, root)
    for child_bag in rt.successors(root):
        process(rt, IF, child_bag)
    print(root, finished_first_pass)
    if finished_first_pass:
        return
    process_clique(IF, root)

In [None]:
# Build a random acyclic orientation of a random undirected graph.
# Seed `random` so Restart & Run All reproduces the same instance.
# NOTE(review): the seed only affects tree_insertion if it draws from the
# `random` module -- confirm.
SEED = 0
random.seed(SEED)

g = tree_insertion(20, 30)  # NOTE(review): meaning of (20, 30) defined in uccgGenerator
dg = nx.DiGraph()
dg.add_edges_from(g.edges)
# Flip random edges until the orientation has no directed cycle.
while not nx.is_directed_acyclic_graph(dg):
    dg = random_flip(dg)
assert nx.is_directed_acyclic_graph(dg)

In [None]:
# Derive the target DAG F from the initial DAG I. Each of N_FLIPS random
# reversals is followed by further random reversals until F is acyclic again.
# F is a copy, so I is untouched (random_flip mutates its argument).
N_FLIPS = 100

I = dg
F = nx.DiGraph(I)
for _ in range(N_FLIPS):
    F = random_flip(F)
    while not nx.is_directed_acyclic_graph(F):
        F = random_flip(F)

# Hierarchical layout of I, reused for every plot of I/F below.
# (A kamada_kawai_layout was previously computed here and immediately overwritten.)
pos = graphviz_layout(I, prog='dot')

In [None]:
# Merge I and F into one annotated graph.
IF = get_IF(I, F)

# Build a tree decomposition of I's undirected skeleton. Root it at a randomly
# chosen bag that contains a source vertex of I (the first vertex of a
# topological order).
_, t = treewidth_min_fill_in(I.to_undirected())
root = list(topological_sort(I))[0]
root_clique = random.choice([bag for bag in t if root in bag])
rt = nx.dfs_tree(t, root_clique)
get_sep_and_res(rt, root_clique)

# Left: the I/F difference graph. Right: the rooted tree decomposition.
plt.figure(figsize=(12, 5))
plt.subplot(121)
plot_IF(IF, pos)
plt.subplot(122)
plotGraph(rt, graphviz_layout(rt, prog='dot'))

In [None]:
# Display the bag chosen as the root of the tree decomposition.
root_clique

In [None]:
# Run the bag-by-bag reversal procedure over the rooted decomposition.
# It plots each accepted reversal and prints (bag, finished) for every bag.
process(rt, IF, root_clique)

In [None]:
# Sanity check after `process`: print every F-only edge that is still not part
# of the current orientation. No output means every F edge has been adopted.
for e in IF.edges:
    if IF.edges[e][label] == F_edge and not IF.edges[e][in_current]:
        print(e, IF.edges[e][in_current])