In [9]:
%matplotlib inline
%load_ext autoreload
%autoreload 2


In [10]:
import enum


from IPython.core.display import SVG, display
import networkx as nx


def draw(graph, path=None):
    """Render ``graph`` with Graphviz ``dot`` and display it inline as SVG.

    Args:
        graph: a networkx graph, converted with ``nx.nx_agraph.to_agraph``.
        path: optional file name (str / path-like) or writable binary file
            object. When given, the rendered SVG bytes are also saved there.
    """
    # AGraph.draw only returns the rendered bytes when no path is passed; with a
    # path it writes the file and returns None, which left nothing to display.
    # Render to bytes first, then save separately.
    svg = nx.nx_agraph.to_agraph(graph).draw(prog='dot', format='svg')
    if path is not None:
        if hasattr(path, "write"):
            path.write(svg)
        else:
            with open(path, "wb") as fh:
                fh.write(svg)
    display(SVG(svg))


class NodeType(enum.Enum):
    """Role of a graph node, stored under the ``type`` node attribute."""

    PARAM = "param"                # parameter nodes (x0..x6 in the example graph)
    INTERMEDIATE = "intermediate"  # intermediate nodes (z1..z4 in the example graph)
    OBJECTIVE = "objective"        # objective node(s) (y in the example graph)

In [11]:
G = nx.DiGraph()

# Parameter nodes x0..x6, intermediate nodes z1..z4 and the objective y.
G.add_nodes_from((f"x{i}" for i in range(0, 7)), type=NodeType.PARAM)
G.add_nodes_from((f"z{i}" for i in range(1, 5)), type=NodeType.INTERMEDIATE)
G.add_node("y", type=NodeType.OBJECTIVE)

G.add_edges_from(
    [
        ("x1", "z1"), ("x2", "z1"),
        ("x3", "z2"), ("x4", "z2"), ("x5", "z2"),
        ("z1", "z3"),
        ("z2", "y"),
        # x1 also feeds z3 directly, so z3 becomes a collider of x1, z1 and z4
        ("x1", "z3"),
        ("z3", "y"),
        ("x0", "y"),
        ("x6", "z4"),
        ("z4", "z3"),
    ]
)
draw(G)

In [12]:
# Sanity checks on the example graph (previously left as commented-out code):
# conditioning on the collider y opens the path x1 -> z3 -> y <- z2 <- x3,
# whereas conditioning on z2 blocks every path between x1 and x3.
# `nx.d_separated` is deprecated since NetworkX 3.3; prefer `is_d_separator`.
_d_separated = getattr(nx, "is_d_separator", None) or nx.d_separated
assert not _d_separated(G, {"x1"}, {"x3"}, {"y"})
assert _d_separated(G, {"x1"}, {"x3"}, {"z2"})

draw(G, path="all_g.svg")


In [13]:
from collections.abc import Set, Sequence
from collections import defaultdict


def all_params(g) -> Set:
    """Return the set of nodes in ``g`` whose ``type`` attribute is ``NodeType.PARAM``.

    Every node is expected to carry a ``type`` attribute; a node without one
    raises ``KeyError``.
    """
    return {n for n, d in g.nodes(data=True) if d['type'] == NodeType.PARAM}


def all_intermediates(g) -> Set:
    """Collect every node of ``g`` whose ``type`` attribute is ``NodeType.INTERMEDIATE``."""
    found = set()
    for node, attrs in g.nodes(data=True):
        if attrs['type'] == NodeType.INTERMEDIATE:
            found.add(node)
    return found


def all_objectives(g) -> Set:
    """Collect every node of ``g`` whose ``type`` attribute is ``NodeType.OBJECTIVE``."""
    return set(
        node
        for node, attrs in g.nodes(data=True)
        if attrs['type'] == NodeType.OBJECTIVE
    )


def all_params_pairs(g, include_self: bool = False):
    """Get all unordered pairs of parameter nodes in ``g``.

    Each pair is a 2-tuple sorted in ascending order, so (a, b) and (b, a)
    collapse into a single entry. With ``include_self=True`` the pairs
    (p, p) are included as well.
    """
    params = list(all_params(g))
    # Inner slice starts at idx (allow self-pairs) or idx + 1 (distinct only).
    start = 0 if include_self else 1
    return {
        tuple(sorted((first, second)))
        for idx, first in enumerate(params)
        for second in params[idx + start:]
    }


def find_d_connected_subgraphsv1(g) -> defaultdict:
    """Group parameters by the intermediate node that leaves them d-connected.

    For every unordered pair of parameter nodes (p1, p2) and every
    intermediate node z, if p1 and p2 are *not* d-separated given {z}, both
    parameters are added to the group keyed by z.

    Returns:
        ``defaultdict(set)`` mapping intermediate node -> set of parameters.
    """
    # `nx.d_separated` is deprecated since NetworkX 3.3 in favour of
    # `nx.is_d_separator`; use the new name when it exists.
    d_separated = getattr(nx, "is_d_separator", None) or nx.d_separated
    # The intermediates do not depend on the pair, so compute them once.
    intermediates = all_intermediates(g)

    all_d_connected = defaultdict(set)
    for p1, p2 in all_params_pairs(g):
        for intermediate in intermediates:
            if not d_separated(g, {p1}, {p2}, {intermediate}):
                # D-connected given this intermediate
                all_d_connected[intermediate].update({p1, p2})

    return all_d_connected

def find_decomposition(g) -> defaultdict:
    """Partition the parameter nodes of ``g`` into groups.

    Groups are built in three passes, all stored in one ``defaultdict(set)``:

    1. Parameters that are direct parents of an objective are grouped under
       that objective's name and excluded from the pairwise pass.
    2. For every remaining pair of parameters that is *not* d-separated given
       the union of both parameters' children, the pair is added to the group
       keyed by a textual form of that union. A parameter can therefore land
       in several groups.
    3. Any parameter not placed by passes 1-2 is added to the group keyed by
       a textual form of its own children (merging into an existing group if
       the key already exists).

    Textual keys look like ``"{'z1', 'z3'}"`` (``"set()"`` for no children),
    with elements sorted so equal sets always produce the same key.

    Args:
        g: directed acyclic graph whose nodes carry a ``type`` attribute.

    Returns:
        ``defaultdict(set)`` mapping group key -> set of parameter nodes.
    """
    # `nx.d_separated` is deprecated since NetworkX 3.3 in favour of
    # `nx.is_d_separator`; use the new name when it exists.
    d_separated = getattr(nx, "is_d_separator", None) or nx.d_separated

    def children_key(nodes):
        # Same format as str(set(nodes)), but sorted: str() of equal sets can
        # differ with insertion history and between runs (hash randomisation).
        if not nodes:
            return "set()"
        return "{" + ", ".join(sorted(repr(n) for n in nodes)) + "}"

    all_d_connected = defaultdict(set)
    param_to_skip = set()
    unseen_params = all_params(g)

    # Pass 1: parameters feeding an objective directly.
    for obj in all_objectives(g):
        for n in g.predecessors(obj):
            if g.nodes[n]["type"] == NodeType.PARAM:
                all_d_connected[obj].add(n)
                param_to_skip.add(n)
                # discard, not remove: a parameter may parent several objectives
                unseen_params.discard(n)

    # Pass 2: pairs that stay d-connected given the union of their children.
    for p1, p2 in all_params_pairs(g):
        if p1 in param_to_skip or p2 in param_to_skip:
            continue
        union_of_children = set(g.successors(p1)) | set(g.successors(p2))
        if not d_separated(g, {p1}, {p2}, union_of_children):
            all_d_connected[children_key(union_of_children)].update((p1, p2))
            unseen_params.discard(p1)
            unseen_params.discard(p2)

    # Pass 3: leftovers, grouped by their own children.
    for unseen_param in unseen_params:
        key = children_key(set(g.successors(unseen_param)))
        all_d_connected[key].add(unseen_param)
    return all_d_connected


# Display the decomposition of the example graph: a dict mapping each group key
# (an objective name or a textual set of child nodes) to a set of parameters.
find_decomposition(G)



In [15]:
import utils

# Decomposition via the local `utils` module, which takes the parameter and
# objective node sets explicitly. NOTE(review): presumably the packaged
# counterpart of the notebook's find_decomposition — TODO confirm.
utils.find_decomposition(G, all_params(G), all_objectives(G))

In [446]:

def find_connecting_subgraph(g, sources, targets):
    """Return the subgraph of ``g`` spanned by shortest source-to-target paths.

    For each (source, target) pair joined by a directed path in ``g``, the
    nodes of one shortest such path are collected. Only a single shortest
    path per pair is kept, so nodes lying only on longer routes are omitted.

    Args:
        g: the directed graph to search.
        sources: iterable of start nodes.
        targets: iterable of end nodes; materialised once so that generators
            work when iterated for every source.

    Returns:
        ``g.subgraph(...)`` induced by the collected nodes.
    """
    targets = list(targets)
    connecting_nodes = set()
    for source in sources:
        for target in targets:
            if nx.has_path(g, source, target):
                connecting_nodes.update(nx.shortest_path(g, source, target))

    return g.subgraph(connecting_nodes)


# NOTE(review): this re-defines `draw` from an earlier cell; consider keeping a
# single definition.
def draw(graph, path=None):
    """Render ``graph`` with Graphviz ``dot`` and display it inline as SVG.

    Args:
        graph: a networkx graph, converted with ``nx.nx_agraph.to_agraph``.
        path: optional file name (str / path-like) or writable binary file
            object. When given, the rendered SVG bytes are also saved there.
    """
    # AGraph.draw only returns the rendered bytes when no path is passed; with a
    # path it writes the file and returns None, which left nothing to display.
    # Render to bytes first, then save separately.
    svg = nx.nx_agraph.to_agraph(graph).draw(prog='dot', format='svg')
    if path is not None:
        if hasattr(path, "write"):
            path.write(svg)
        else:
            with open(path, "wb") as fh:
                fh.write(svg)
    display(SVG(svg))
# Draw one SVG per decomposition group, restricted to the nodes connecting the
# group's parameters to the objective node(s). Targets come from the graph
# instead of a hard-coded "y", matching find_decomposition's handling of all
# objectives.
objectives = all_objectives(G)
for img, groups in enumerate(find_decomposition(G).values(), start=1):
    draw(find_connecting_subgraph(G, groups, objectives), path=f"subgraph_{img}.svg")

In [30]:
param_nodes = all_params(G)
objective_nodes = set(all_objectives(G))
# Draw each node group produced by utils.create_d_separable_subgraphs.
for subgraph in utils.create_d_separable_subgraphs(G, param_nodes, objective_nodes):
    utils.draw(G.subgraph(subgraph))


In [245]:
# Draw two hand-picked node groups of G, one after the other.
for hand_picked in (['x1', 'x2', 'z3', 'z1', 'y'], ['x3', 'x4', 'x5', 'z2', 'y']):
    draw(G.subgraph(hand_picked))


* Disregard any intermediate node that has no path to the objective y.
* Connect a parameter node directly to the objectives if it is not connected to any intermediate node.
* Create a sub-graph of all parameters that are direct parents of y, conditioning on the intermediate nodes.
* Then recursively generate sub-graphs for all intermediate nodes that are parents of y.


Discussion:
A reader may wonder why we need d-separation as a concept at all, rather than simply performing a recursive path lookup through the graph as was done in other works \cite{boat}.
Consider the example in the figure: x1 and x2 may appear independent given z1, but this is not correct. z1 is a "collider" on the path x1 -> z1 <- x2, and conditioning on a collider makes its parents d-connected rather than separating them \cite{pearl}.
Intuitively, fixing the value of z1 does not remove the influence of x1 and x2: to achieve the fixed z1 value, x1 and x2 must be set jointly, and setting x1 also influences z3 directly.
Without d-separation, important correlations like these go unnoticed and produce an incorrect optimization landscape.




Sampling:
* If a parameter is a direct parent of y, simply draw random samples for that parameter rather than using its predicted distribution.


