# Imports

In [11]:
import libspn as spn
import tensorflow as tf
import numpy as np

from spn_topo.spn_model import SpnModel
from spn_topo.tbm.spn_template import TemplateSpn, NodeTemplateSpn, EdgeTemplateSpn, InstanceSpn
from spn_topo.tbm.template import EdgeTemplate, NodeTemplate, SingleEdgeTemplate, PairEdgeTemplate, SingletonTemplate, PairTemplate, ThreeNodeTemplate

# Copy the SPN graph

### Write a graph traversal algorithm that works on inputs instead of nodes, i.e. the value function receives `(node, indices)` pairs rather than bare nodes.

In [12]:
##############################
# WARNING: MY OWN METHOD!
##############################
from collections import deque, defaultdict
def mod_compute_graph_up(root, val_fun, const_fun=None, all_values=None,
                         key_fun=None):
    """Computes a value for the ``root`` node by evaluating graph *inputs*.

    Each visited input is represented as a ``(node, indices)`` tuple, and this
    tuple (not the bare node) is what ``val_fun`` and ``const_fun`` receive.
    The root itself is visited as ``(root, None)``. For op nodes the value
    depends on the values produced for the node's inputs. The graph is
    traversed depth-first from ``root`` towards the leaves with an explicit
    stack, and an op node is evaluated only once values for all of its
    non-empty inputs are available.

    NOTE: values are cached in ``all_values`` under ``key_fun(input)``. With
    the default key (the node alone), a node reached through several inputs
    with different indices is evaluated only once, using the indices of the
    first input that was visited. Pass a ``key_fun`` that includes the
    indices to evaluate each distinct input separately.

    Args:
        root (Node): The root of the SPN graph.
        val_fun (function): ``val_fun(input, *args)`` producing a value for
            ``input`` (a ``(node, indices)`` tuple). If ``input`` has an op
            node, ``args`` are the values produced for the node's inputs, in
            order; an empty input contributes ``None``. No ``args`` are
            passed if ``const_fun`` returns a truthy value for the input.
        const_fun (function): ``const_fun(input)`` returning a truthy value if
            the value generated by ``val_fun`` does not depend on the values
            generated for the node's inputs. It is only called for op nodes.
            If ``None``, every ``val_fun`` is treated as non-constant.
        all_values (dict): Dictionary in which the computed values are stored,
            keyed by ``key_fun(input)``. Entries already present are reused
            and not recomputed. Can be ``None``.
        key_fun (function): ``key_fun(input)`` returning the hashable cache
            key of an input. Defaults to the node itself (``input[0]``).

    Returns:
        The value stored for the root input, also when it was already present
        in ``all_values`` before the call.
    """
    if all_values is None:  # Dictionary of computed values indexed by key
        all_values = {}
    if key_fun is None:
        def key_fun(inpt):
            # Default: one cached value per node, regardless of indices.
            return inpt[0]

    root_input = (root, None)
    stack = deque([root_input])  # Stack of (node, indices) inputs to process
    while stack:
        current = stack[-1]
        node = current[0]
        key = key_fun(current)
        if key in all_values:
            # Already evaluated, e.g. the node is referenced by several parents.
            stack.pop()
            continue

        input_vals = []  # Values of the inputs of `node`, in order
        ready = True
        # Non-op nodes (VarNode, ParamNode) and constant val_funs need no inputs.
        if node.is_op and (const_fun is None or not const_fun(current)):
            for inpt in node.inputs:
                if not inpt:
                    # This input was empty, use None as value.
                    input_vals.append(None)
                    continue
                child = (inpt.node, inpt.indices)
                child_key = key_fun(child)
                if child_key in all_values:
                    input_vals.append(all_values[child_key])
                else:
                    # Evaluate the child first; revisit `current` afterwards.
                    ready = False
                    stack.append(child)

        if ready:
            all_values[key] = val_fun(current, *input_vals)
            stack.pop()

    return all_values[key_fun(root_input)]
##############################
# END WARNING: MY OWN METHOD!
##############################


### Create a NodeTemplateSpn.

In [17]:
# Build the SPN for the three-node template and render its graph.
three_node_spn = NodeTemplateSpn(ThreeNodeTemplate)
spn.display_spn_graph(three_node_spn.root, skip_params=False)
# `set_inputs()` with no arguments presumably detaches every input of the
# internal concat node, so the traversal below never reaches the IVs behind it
# (fun_up raises on IVs). NOTE(review): this mutates three_node_spn in place.
three_node_spn._conc_inputs.set_inputs()
three_node_spn._conc_inputs

Generating SPN structure...
Generating weight initialization Ops...
Initializing learning Ops...


Conc_NodeTemplateSpn_3_3

In [14]:
conc = None  # Set by fun_up to the replacement Concat of the copied graph.
def fun_up(inpt, *args):
    """Value function for ``mod_compute_graph_up`` that rebuilds the SPN.

    Every ``Sum``/``Product`` node is replaced by a fresh node of the same type
    built from the values (already-copied nodes) of its inputs. ``Weights``
    nodes are reused as-is. A ``Concat`` node is replaced by a new input-less
    ``spn.Concat``, stored in the global ``conc`` so its inputs can be set
    afterwards.

    Args:
        inpt: A ``(node, indices)`` tuple for the input being evaluated.
        *args: Values computed for the inputs of ``node`` (op nodes only).

    Returns:
        The copied node or, for a ``Concat``, an ``spn.Input`` pointing at the
        new concat with the original ``indices``.

    Raises:
        ValueError: For op nodes other than Sum/Product/Concat (previously
            these silently produced ``None``), and for non-op nodes other
            than ``Weights`` (e.g. IVs).
    """
    global conc
    node, indices = inpt
    if node.is_op:
        if isinstance(node, spn.Sum):
            # [2:] is to skip the weights node and the explicit IVs node for this sum.
            return spn.Sum(*args[2:], weights=args[0])
        elif isinstance(node, spn.Product):
            return spn.Product(*args)
        elif isinstance(node, spn.Concat):
            # Assumes there is only one concat node: each call overwrites `conc`.
            conc = spn.Concat()
            return spn.Input(conc, indices)
        # Fail loudly instead of returning None into the copied graph.
        raise ValueError(
            "Unsupported op node type: {}".format(type(node).__name__))
    elif isinstance(node, spn.Weights):
        return node
    else:
        raise ValueError("We don't intend to deal with IVs here. Please remove them from the concat.")

In [15]:
# Copy the SPN by evaluating fun_up bottom-up over the inputs of the original
# graph. new_root is the copied root; fun_up stores the new Concat in `conc`.
new_root=mod_compute_graph_up(three_node_spn.root, val_fun=fun_up)
print(conc)

Concat_1


In [16]:
# Feed the new concat from a fresh IVs node sized like the original SPN's
# inputs (same num_vars / num_vals), then render the copied graph.
conc.set_inputs(spn.IVs(num_vars=three_node_spn._inputs.num_vars, num_vals=three_node_spn._inputs.num_vals))

spn.display_spn_graph(new_root, skip_params=False)

### Bug

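The copy does not preserve the input indices. In the original graph (first cell below), the third input node of the root reads `Sums1.1_6/Sum1` at indices `[0]` and the concat at `[10]`. In the copy (second cell below), both of its inputs have `None` indices. Likely contributors (TODO confirm):

- `fun_up` returns new `Sum`/`Product` nodes directly, so the indices of inputs pointing at them are dropped.
- `mod_compute_graph_up` caches one value per node, so every reference to the concat reuses the value built for the first reference visited.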

In [46]:
# Inputs (with their indices) of the third input node of the original root.
three_node_spn.root.inputs[2].node.inputs

(Input(Sums1.1_6/Sum1, [0]), Input(Conc_NodeTemplateSpn_3_6, [10]))

In [40]:
# Same position in the copied graph, for comparison with the original above.
new_root.inputs[2].node.inputs

(Input(IVs_5, None), Input(Sum_13, None))