# Tutorial 2: Inference

In [None]:
import libspn as spn
import tensorflow as tf

def init_const(v):
    """Build a TensorFlow constant initializer that fills a variable with ``v``.

    Used below to give each sum node fixed, hand-picked weights.
    """
    initializer = tf.initializers.constant(v)
    return initializer

### Building a Test Graph with Initialized Weights

In [None]:
# Indicator variables (IVs) for two discrete variables with two values each.
# NOTE(review): the IV outputs appear to be laid out variable-major, so indices
# [0, 1] are the two values of the first variable and [2, 3] those of the
# second. The sums below select them this way; confirm against libspn docs.
iv_x = spn.IVs(num_vars=2, num_vals=2, name="iv_x")
# Two alternative sums over indicators [0, 1], each given fixed weights
# (every weight vector in this cell sums to 1).
sum_11 = spn.Sum((iv_x, [0,1]), name="sum_11")
sum_11.generate_weights(init_const([0.4, 0.6]))
sum_12 = spn.Sum((iv_x, [0,1]), name="sum_12")
sum_12.generate_weights(init_const([0.1, 0.9]))
# Two alternative sums over indicators [2, 3].
sum_21 = spn.Sum((iv_x, [2,3]), name="sum_21")
sum_21.generate_weights(init_const([0.7, 0.3]))
sum_22 = spn.Sum((iv_x, [2,3]), name="sum_22")
sum_22.generate_weights(init_const([0.8, 0.2]))
# Each product pairs one sum over [0, 1] with one sum over [2, 3];
# sum_11 and sum_22 are each shared by two products.
prod_1 = spn.Product(sum_11, sum_21, name="prod_1")
prod_2 = spn.Product(sum_11, sum_22, name="prod_2")
prod_3 = spn.Product(sum_12, sum_22, name="prod_3")
# Root mixture over the three products, weighted 0.5 / 0.2 / 0.3.
root = spn.Sum(prod_1, prod_2, prod_3, name="root")
root.generate_weights(init_const([0.5, 0.2, 0.3]))
# Latent indicators attached to the root's children. Later cells feed this
# with a single column per row. NOTE(review): presumably one value per root
# child (the three products); confirm.
iv_y = root.generate_ivs(name="iv_y")

### Visualizing the SPN Graph

In [None]:
# Render the SPN rooted at `root`. NOTE(review): presumably displayed inline
# in the notebook output.
spn.display_spn_graph(root)

### Adding Value Ops

In [None]:
# Op that initializes the weights generated above for the SPN under `root`;
# it is run with `.run()` before any value is evaluated.
init_weights = spn.initialize_weights(root)
# Root value tensors under the two inference types. NOTE(review): in SPN terms,
# MARGINAL presumably sums at sum nodes and MPE takes the max; confirm.
marginal_val = root.get_value(inference_type=spn.InferenceType.MARGINAL)
mpe_val = root.get_value(inference_type=spn.InferenceType.MPE)

### Calculating Values

In [None]:
# Evidence batch for iv_x: one row per sample, one column per variable.
# Each entry is an observed value index, or -1 for an unobserved variable.
# NOTE(review): -1 is libspn's "unknown" marker, so that variable is
# marginalized (MARGINAL) or maximized over (MPE); confirm in the IVs docs.
iv_x_arr = [[0, 1],
            [1, 0],
            [1, -1],
            [-1, -1]]

# Leave the latent iv_y unobserved in every row. The row count is derived from
# iv_x_arr so the two feeds cannot drift apart when samples are added or
# removed. The comprehension also gives each row its own list instead of N
# aliases of one shared list, as `[[-1]] * 4` did.
iv_y_arr = [[-1] for _ in iv_x_arr]

evidence_feed = {iv_x: iv_x_arr, iv_y: iv_y_arr}

with spn.session() as (sess, _):
    init_weights.run()
    # Fetch both values in a single sess.run call with one shared feed,
    # instead of running the graph twice with duplicated feed dicts.
    marginal_val_arr, mpe_val_arr = sess.run([marginal_val, mpe_val],
                                             feed_dict=evidence_feed)

print(marginal_val_arr)
print(mpe_val_arr)

### Adding MPE State Ops

In [None]:
# Generator of MPE-state ops. NOTE(review): value_inference_type=MPE presumably
# selects MPE (max) values for the upward pass that the state is derived from.
mpe_state_gen = spn.MPEState(value_inference_type=spn.InferenceType.MPE)
# Tensors holding the inferred most-probable assignments of iv_x and iv_y.
iv_x_state, iv_y_state = mpe_state_gen.get_state(root, iv_x, iv_y)

In [None]:
# Infer the most probable assignment of every variable with no evidence
# (-1 in every IV column), then print the resulting states.
with spn.session() as (sess, _):
    init_weights.run()
    state_fetches = [iv_x_state, iv_y_state]
    no_evidence_feed = {iv_x: [[-1, -1]], iv_y: [[-1]]}
    iv_x_state_arr, iv_y_state_arr = sess.run(state_fetches,
                                              feed_dict=no_evidence_feed)

print(iv_x_state_arr)
print(iv_y_state_arr)