In [None]:
import os

import libspn as spn
import numpy as np
import tensorflow as tf
# Fix the random seeds so "Restart & Run All" reproduces the same
# RANDOM_UNIFORM weight initialization and the same random input feeds.
SEED = 42
np.random.seed(SEED)
tf.set_random_seed(SEED)  # graph-level seed: applies to random ops created after this call
sess = tf.InteractiveSession()

# Build a custom SPN that uses every Op node type (Sum, Sums, ParSums, Product, Products, PermProducts)

In [None]:
# Inputs: two IVs nodes with the same shape (2 variables, 2 values each).
iv_shape = dict(num_vars=2, num_vals=2)
iv_1 = spn.IVs(name="IVs_1", **iv_shape)
iv_2 = spn.IVs(name="IVs_2", **iv_shape)

# DECOMPOSITION 1 -- built only on top of iv_1.
# Layer 1: sum-type nodes over column subsets of iv_1.
# NOTE(review): [0, 1] / [2, 3] presumably select the two values of
# variable 0 / variable 1 of iv_1 -- confirm against the spn.IVs output layout.
parsums1_1_1 = spn.ParSums(
    (iv_1, [0, 1]),
    num_sums=2,
    name="Sums1.1/ParSums1")
sums1_2_1 = spn.Sums(
    (iv_1, [2, 3]),
    (iv_1, [2, 3]),  # columns repeated -- presumably one input group per sum (num_sums=2)
    num_sums=2,
    name="Sums1.2/Sums1")

# Layer 2: PermProducts over both layer-1 nodes.
permprods1 = spn.PermProducts(
    parsums1_1_1,
    sums1_2_1,
    name="Products1/PermProducts1")

# Layer 3: two different sum nodes over the same products.
parsums3_1_1 = spn.ParSums(
    permprods1,
    num_sums=2,
    name="Sums3.1/ParSums1")
sum3_1_2 = spn.Sum(
    permprods1,
    name="Sums3.1/Sum2")

# DECOMPOSITION 2 -- built only on top of iv_2.
# Layer 1: two sum-type nodes for each column subset of iv_2.
sum2_1_1 = spn.Sum(
    (iv_2, [0, 1]),
    name="Sums2.1/Sum1")
parsums2_1_2 = spn.ParSums(
    (iv_2, [0, 1]),
    name="Sums2.1/ParSums2")
sums2_2_1 = spn.Sums(
    (iv_2, [2, 3]),
    num_sums=1,
    name="Sums2.2/Sums1")
sum2_2_2 = spn.Sum(
    (iv_2, [2, 3]),
    name="Sums2.2/Sum2")

# Layer 2: four products over eight inputs, written one pair per line.
# NOTE(review): assumes Products splits its inputs evenly (in order) across
# num_prods -- confirm against spn.Products.
prods2 = spn.Products(
    sum2_1_1, sums2_2_1,
    sum2_1_1, sum2_2_2,
    parsums2_1_2, sums2_2_1,
    parsums2_1_2, sum2_2_2,
    num_prods=4,
    name="Products2/Products1")

# Layer 3: two different sum nodes over the same products.
sum3_2_1 = spn.Sum(
    prods2,
    name="Sums3.2/Sum1")
sums3_2_2 = spn.Sums(
    prods2,
    num_sums=2,
    name="Sums3.2/Sums2")

# Layer 4: products that combine a decomposition-1 output (parsums3_1_1 /
# sum3_1_2) with a decomposition-2 output (sum3_2_1 / sums3_2_2).
prod3_1 = spn.Product(
    (parsums3_1_1, 0),
    sum3_2_1,
    name="Products3/Products1")
prod3_2 = spn.Product(
    (parsums3_1_1, 0),
    (sums3_2_2, 0),
    name="Products3/Products2")
# Five products over ten inputs, written one pair per line.
# NOTE(review): assumes Products consumes its inputs pairwise in order -- confirm.
prods3_3 = spn.Products(
    (parsums3_1_1, 0), (sums3_2_2, 1),
    (parsums3_1_1, 1), sum3_2_1,
    (parsums3_1_1, 1), (sums3_2_2, 0),
    (parsums3_1_1, 1), (sums3_2_2, 1),
    (sum3_1_2, 0), sum3_2_1,
    num_prods=5,
    name="Products3/Products3")
prods3_4 = spn.Products(
    sum3_1_2,
    (sums3_2_2, 0),
    name="Products3/Products4")
prod3_5 = spn.Product(
    sum3_1_2,
    (sums3_2_2, 0),
    name="Products3/Products5")

# Layer 5: a single root sum over every layer-4 node, plus the IVs node that
# generate_ivs() attaches to it (fed as iv_y; later looked up as "root_IVs").
layer4_nodes = (prod3_1, prod3_2, prods3_3, prods3_4, prod3_5)
root = spn.Sum(*layer4_nodes, name="root")
iv_y = root.generate_ivs()

# Generate weights for the graph under root with uniformly random initial
# values, then run their initializer in the interactive session.
uniform_init = spn.ValueType.RANDOM_UNIFORM()
spn.generate_weights(root, init_value=uniform_init)
weights_init_op = spn.initialize_weights(root)
weights_init_op.run()

## Inspect

In [None]:
# Summarize the freshly built graph: size, scope of the root, validity.
node_count = root.get_num_nodes()
print("Number of nodes: ", node_count)
root_scope = root.get_scope()
print("\nScope of root: ", root_scope)
validity_label = "YES" if root.is_valid() else "NO"
print("\nNetwork valid?: ", validity_label)

## Visualize SPN graph

In [None]:
# Render the graph of root; skip_params=False presumably also draws the
# parameter (weights) nodes. Kept as the last expression so its output renders.
spn.display_spn_graph(root, skip_params=False)

# Save Model With Initialized Weights

In [None]:
# Persist the graph (with its freshly initialized weights) as JSON.
# saved_models/ may not exist in a fresh checkout, so create it first.
# NOTE(review): assumes JSONSaver does not create missing directories itself.
model_path = 'saved_models/test_init.spn'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
spn.JSONSaver(model_path).save(root)

# Marginal Value

In [None]:
# Build (evaluated in a later cell) the MARGINAL-inference value tensors of the
# root, in linear and in log space.
marginal_inference = spn.InferenceType.MARGINAL
value = root.get_value(marginal_inference)
log_value = root.get_log_value(marginal_inference)

In [None]:
# Inputs Feed
# Random evidence for a small batch. Columns 0-1 of iv_x_feed feed iv_1 and
# columns 2-3 feed iv_2, each entry in {0, 1} (num_vals=2).
num_samples = 5
iv_x_feed = np.random.randint(2, size=(num_samples, 4))
# Take the upper bound from the root's IVs node instead of hard-coding 9
# (the root's current input count) so the feed stays in range if the graph changes.
iv_y_feed = np.random.randint(iv_y.num_vals, size=(num_samples, 1))

feed = {iv_1: iv_x_feed[:, 0:2], iv_2: iv_x_feed[:, 2:4], iv_y: iv_y_feed}
# One session run evaluates both tensors against the same feed.
value_array, value_array_log = sess.run([value, log_value], feed_dict=feed)

print("Marginal Value:\n", value_array)
print("\nMarginal Value (log):\n", value_array_log)

## Marginal Path

In [None]:
# MPE path whose choices are based on MARGINAL-inference values; print the
# resulting `counts` tensors for iv_1 and iv_2 under the current feeds.
mpe_marginal_path_gen = spn.MPEPath(
    value_inference_type=spn.InferenceType.MARGINAL, log=False)
mpe_marginal_path_gen.get_mpe_path(root)

path_feed = {iv_1: iv_x_feed[:, 0:2], iv_2: iv_x_feed[:, 2:4], iv_y: iv_y_feed}
for header, iv_node in (("IV_1 Counts:\n", iv_1), ("\nIV_2 Counts:\n", iv_2)):
    print(header, mpe_marginal_path_gen.counts[iv_node].eval(feed_dict=path_feed))

# MPE Value

In [None]:
# Rebind value / log_value to the MPE-inference tensors of the root.
mpe_inference = spn.InferenceType.MPE
value = root.get_value(mpe_inference)
log_value = root.get_log_value(mpe_inference)

In [None]:
# Evaluate the MPE value tensors on the same random evidence as above.
mpe_feed = {iv_1: iv_x_feed[:, 0:2], iv_2: iv_x_feed[:, 2:4], iv_y: iv_y_feed}
value_array = value.eval(feed_dict=mpe_feed)
value_array_log = log_value.eval(feed_dict=mpe_feed)

print("MPE Value:\n", value_array)
print("\nMPE Value (log):\n", value_array_log)

## MPE Path

In [None]:
# MPE path whose choices are based on MPE-inference values; print the
# resulting `counts` tensors for iv_1 and iv_2 under the current feeds.
mpe_path_gen = spn.MPEPath(value_inference_type=spn.InferenceType.MPE, log=False)
mpe_path_gen.get_mpe_path(root)

path_feed = {iv_1: iv_x_feed[:, 0:2], iv_2: iv_x_feed[:, 2:4], iv_y: iv_y_feed}
for header, iv_node in (("IV_1 Counts:\n", iv_1), ("\nIV_2 Counts:\n", iv_2)):
    print(header, mpe_path_gen.counts[iv_node].eval(feed_dict=path_feed))

# Load SPN With Initialized Weights from File

In [None]:
# Rebuild the SPN from the JSON file written earlier; `loader` is kept because
# later cells look up nodes by name through it.
saved_model_path = 'saved_models/test_init.spn'
loader = spn.JSONLoader(saved_model_path)
root_1 = loader.load()

## Inspect Loaded SPN Graph

In [None]:
# Summarize the loaded graph: size, scope of the root, validity.
loaded_node_count = root_1.get_num_nodes()
print("Number of nodes: ", loaded_node_count)
loaded_scope = root_1.get_scope()
print("\nScope of root: ", loaded_scope)
loaded_validity = "YES" if root_1.is_valid() else "NO"
print("\nNetwork valid?: ", loaded_validity)

## Visualize Loaded SPN graph

In [None]:
# Render the loaded graph; skip_params=False presumably also draws the
# parameter (weights) nodes. Kept as the last expression so its output renders.
spn.display_spn_graph(root_1, skip_params=False)

In [None]:
# Run the weight initializers of the loaded graph.
# NOTE(review): presumably these restore the values saved by JSONSaver -- confirm.
spn.initialize_weights(root_1).run()

# Rebind the input handles to the loaded graph's nodes, looked up by name, so
# the feed dicts in the following cells target root_1 instead of root.
iv_1, iv_2, iv_y = (loader.find_node(node_name)
                    for node_name in ("IVs_1", "IVs_2", "root_IVs"))

# Marginal Value (Loaded SPN)

In [None]:
# MARGINAL-inference value tensors of the loaded root (linear and log space).
marginal_inference = spn.InferenceType.MARGINAL
value_1 = root_1.get_value(marginal_inference)
log_value_1 = root_1.get_log_value(marginal_inference)

In [None]:
# Feed the same random evidence to the loaded graph's IVs nodes.
loaded_feed = {iv_1: iv_x_feed[:, 0:2], iv_2: iv_x_feed[:, 2:4], iv_y: iv_y_feed}
value_array_1 = value_1.eval(feed_dict=loaded_feed)
value_array_log_1 = log_value_1.eval(feed_dict=loaded_feed)

print("Marginal Value:\n", value_array_1)
print("\nMarginal Value (log):\n", value_array_log_1)

## Marginal Path (Loaded SPN)

In [None]:
# MPE path over the loaded graph based on MARGINAL-inference values; print
# the resulting `counts` tensors for the loaded iv_1 and iv_2.
mpe_marginal_path_gen = spn.MPEPath(
    value_inference_type=spn.InferenceType.MARGINAL, log=False)
mpe_marginal_path_gen.get_mpe_path(root_1)

path_feed = {iv_1: iv_x_feed[:, 0:2], iv_2: iv_x_feed[:, 2:4], iv_y: iv_y_feed}
for header, iv_node in (("IV_1 Counts:\n", iv_1), ("\nIV_2 Counts:\n", iv_2)):
    print(header, mpe_marginal_path_gen.counts[iv_node].eval(feed_dict=path_feed))

# MPE Value (Loaded SPN)

In [None]:
# Rebind value_1 / log_value_1 to the MPE-inference tensors of the loaded root.
mpe_inference = spn.InferenceType.MPE
value_1 = root_1.get_value(mpe_inference)
log_value_1 = root_1.get_log_value(mpe_inference)

In [None]:
# Evaluate the loaded graph's MPE value tensors on the same random evidence.
loaded_feed = {iv_1: iv_x_feed[:, 0:2], iv_2: iv_x_feed[:, 2:4], iv_y: iv_y_feed}
value_array_1 = value_1.eval(feed_dict=loaded_feed)
value_array_log_1 = log_value_1.eval(feed_dict=loaded_feed)

print("MPE Value:\n", value_array_1)
print("\nMPE Value (log):\n", value_array_log_1)

## MPE Path (Loaded SPN)

In [None]:
# MPE path over the loaded graph based on MPE-inference values; print the
# resulting `counts` tensors for the loaded iv_1 and iv_2.
mpe_path_gen = spn.MPEPath(value_inference_type=spn.InferenceType.MPE, log=False)
mpe_path_gen.get_mpe_path(root_1)

path_feed = {iv_1: iv_x_feed[:, 0:2], iv_2: iv_x_feed[:, 2:4], iv_y: iv_y_feed}
for header, iv_node in (("IV_1 Counts:\n", iv_1), ("\nIV_2 Counts:\n", iv_2)):
    print(header, mpe_path_gen.counts[iv_node].eval(feed_dict=path_feed))