# Generative learning for continuous MNIST data using randomly structured SPNs
This notebook shows how to build a randomly structured SPN and train it with online hard EM on continuous MNIST data.

### Setting up the imports and preparing the data
We load the data from `tf.keras.datasets`. Preprocessing flattens each 28x28 image into a vector of 784 pixels and scales the pixel values to the range [0, 1].

In [1]:
%matplotlib inline
import libspn as spn
import tensorflow as tf
import numpy as np
from libspn.examples.utils.dataiterator import DataIterator
import matplotlib.pyplot as plt

# Load
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()

def scale(x):
    return x / 255.

def flatten(x):
    return x.reshape(-1, np.prod(x.shape[1:]))

def preprocess(x, y):
    return scale(flatten(x)), np.expand_dims(y, axis=1)

# Preprocess
train_x, train_y = preprocess(train_x, train_y)
test_x, test_y = preprocess(test_x, test_y)

### Defining the hyperparameters
Some hyperparameters for the SPN. 
- `num_subsets` is used for the `DenseSPNGenerator`. This corresponds to the number of variable subsets joined by product nodes in the SPN.
- `num_mixtures` is used for the `DenseSPNGenerator`. This corresponds to the number of sum nodes per scope.
- `num_decomps` is used for the `DenseSPNGenerator`. This corresponds to the number of decompositions generated at each level of products from top-down.
- `num_vars` corresponds to the number of input variables (the number of pixels in the case of MNIST).
- `balanced` is used for the `DenseSPNGenerator`. If true, then the generated SPN will have balanced subsets and will consequently be a balanced tree.
- `input_dist` is the input distribution (the first product/sum layer in the SPN). `spn.DenseSPNGenerator.InputDist.RAW` joins the raw indicators directly, so the first layer is a product layer. `spn.DenseSPNGenerator.InputDist.MIXTURE` would place sums on top of each indicator first.
- `num_leaf_components` is the number of continuous (Gaussian) components per variable in the leaf distribution.
- `inference_type` determines the kind of forward inference where `spn.InferenceType.MARGINAL` corresponds to sum nodes marginalizing their inputs. `spn.InferenceType.MPE` would correspond to having max nodes instead.
- `num_classes`, `batch_size` and `num_epochs` are the number of digit classes, the mini-batch size and the number of passes over the training set.
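
To get a feel for how `num_subsets` and `balanced` shape the structure, the sketch below counts the product levels needed to break the full scope down to single variables. It assumes every product splits its scope into `num_subsets` near-equal parts. `balanced_product_levels` is a hypothetical helper written for this illustration, not a libspn function, and the real generator may differ in detail.

```python
import math

def balanced_product_levels(num_vars, num_subsets):
    """Count the product levels needed to reduce a scope of num_vars
    variables to single variables, assuming each product splits its
    scope into num_subsets near-equal parts (illustrative helper)."""
    levels = 0
    largest_scope = num_vars
    while largest_scope > 1:
        # After a balanced split the largest child scope has ceil(n / k) variables
        largest_scope = math.ceil(largest_scope / num_subsets)
        levels += 1
    return levels

# 28 * 28 = 784 pixels, split in two at every product level
levels = balanced_product_levels(num_vars=784, num_subsets=2)
print(levels)
```

For `num_subsets = 2` this matches `ceil(log2(num_vars))`, the same quantity that the commented-out block-based construction further down computes as `num_alterations`.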

In [2]:
# Number of variable subsets that a product joins
num_subsets = 2
# Number of sums per scope
num_mixtures = 4
# Number of variables
num_vars = train_x.shape[1]
# Number of decompositions per product layer
num_decomps = 1
# Generate balanced subsets -> balanced tree
balanced = True
# Input distribution. Raw corresponds to first layer being product that 
# takes raw indicators
input_dist = spn.DenseSPNGenerator.InputDist.RAW
# Number of Gaussian components per variable at the leaves
num_leaf_components = 2
# Initial value for path count accumulators
initial_accum_value = 0.0
# Inference type (can also be spn.InferenceType.MPE) where 
# sum nodes are turned into max nodes
inference_type = spn.InferenceType.MARGINAL
# Using unweighted log probabilities when determining winning child during hard EM
use_unweighted = False
# Sample the MPE paths
sample_winner = False
# Sample probabilities
sample_prob = None
# Add one smoothing
additive_smoothing = 1.0
# L0 prior
l0_prior_factor = 16 / 50000

# Number of classes
num_classes = 10
batch_size = 16
num_epochs = 10

### Building the SPN
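Our SPN consists of Gaussian leaves, a convolutional SPN built on top of them by `wicker_convspn_two_non_overlapping`, and a root node connecting the 10 class-wise sub-SPNs. The block-based, randomly structured alternative is left commented out in the cell below. We also add an indicator node to the root node to model the latent class variable. Finally, we generate the weights for the full SPN with `spn.generate_weights`.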
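
As a rough picture of what a Gaussian leaf layer computes, the sketch below evaluates one log density per (variable, component) pair. It is plain Python for illustration only and is not how `spn.NormalLeaf` is implemented. The names `normal_log_pdf` and `leaf_log_values` and the parameter values are made up for the example.

```python
import math

def normal_log_pdf(x, loc, scale):
    """Log density of a univariate Gaussian N(loc, scale^2) at x."""
    return (-0.5 * math.log(2.0 * math.pi * scale ** 2)
            - (x - loc) ** 2 / (2.0 * scale ** 2))

def leaf_log_values(pixels, locs, scales):
    """Return one log density per (variable, component) pair.
    locs[v][c] and scales[v][c] parameterize component c of variable v."""
    return [
        [normal_log_pdf(x, loc, scale) for loc, scale in zip(locs[v], scales[v])]
        for v, x in enumerate(pixels)
    ]

# Two pixels scaled to [0, 1] with two components each (illustrative values)
pixels = [0.0, 1.0]
locs = [[0.0, 1.0], [0.0, 1.0]]
scales = [[0.5, 0.5], [0.5, 0.5]]
log_values = leaf_log_values(pixels, locs, scales)
```

With `num_leaf_components = 2`, each pixel contributes two such values to the layers above. Making `loc` and `scale` trainable lets learning move the components toward the pixel intensities seen in the data.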

In [3]:
from libspn.examples.convspn.architecture import wicker_convspn_two_non_overlapping

# Reset graph
tf.reset_default_graph()

# Leaf nodes
normal_leafs = spn.NormalLeaf(
    trainable_scale=True,
    trainable_loc=True,
    num_components=num_leaf_components, 
    num_vars=num_vars)

# Get convolutional SPN
root, class_roots = wicker_convspn_two_non_overlapping(
    normal_leafs, num_channels_prod=[32, 32, 64, 64, 64], num_channels_sums=[32, 32, 64, 64, 64])
# conv_0 = spn.ConvProducts(normal_leafs, dilation_rate=1, kernel_size=2, padding=)

# x = randomize = spn.BlockRandomDecompositions(normal_leafs, num_decomps=10)

# num_alterations = int(np.ceil(np.log2(num_vars)))
# for i in range(num_alterations - 1):
#     layer_suffix = "_{}".format(i)
#     x = spn.BlockPermuteProduct(x, num_factors=2, name="Products" + layer_suffix)
#     x = spn.BlockSum(x, num_sums_per_block=4, name="Sums" + layer_suffix)

# layer_suffix = "_{}".format(num_alterations - 1)
# x = spn.BlockPermuteProduct(x, num_factors=2, name="Products" + layer_suffix)
# x = spn.BlockSum(x, num_sums_per_block=1, name="ClassRoots")
# x = spn.BlockMergeDecompositions(x, num_decomps=1)
# root = spn.BlockRootSum(x, name="Root")

# Add an IndicatorLeaf node to the root as a latent class variable
class_indicators = root.generate_latent_indicators()

# Generate the weights for the SPN rooted at `root`
spn.generate_weights(root)

print("SPN depth: {}".format(root.get_depth()))
print("Number of products layers: {}".format(root.get_num_nodes(node_type=spn.ConvProductsDepthwise)))
print("Number of sums layers: {}".format(root.get_num_nodes(node_type=spn.LocalSums)))








SPN depth: 13
Number of products layers: 5
Number of sums layers: 5


### Defining the TensorFlow graph
Now that we have defined the SPN graph, we can declare the TensorFlow operations needed for training and evaluation. We use the `HardEMLearning` class to construct the ops that accumulate counts and update the weights. The `MPEState` class can be used to find the MPE state of any node in the graph. In this case we might be interested in generating images or in finding the most likely class given the evidence elsewhere. These correspond to finding the MPE state for `normal_leafs` and `class_indicators` respectively.
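
The idea behind an online hard EM step can be sketched for a single sum node. The winning child is the one with the highest weighted log value (matching `use_unweighted = False`), its count is incremented, and the weights are re-estimated from the smoothed counts. This is a toy version under those assumptions, not the code inside `spn.HardEMLearning`. `hard_em_update` is a hypothetical name, and the exact way libspn applies `additive_smoothing` may differ.

```python
import math

def hard_em_update(weights, child_log_values, counts, additive_smoothing=1.0):
    """One online hard EM step for a single sum node (toy version).
    Picks the child maximizing log(weight) + child log value, increments
    its count in place and returns re-normalized weights plus the winner."""
    scores = [math.log(w) + v for w, v in zip(weights, child_log_values)]
    winner = scores.index(max(scores))
    counts[winner] += 1
    # Smoothing keeps children that never win at a non-zero weight
    smoothed = [c + additive_smoothing for c in counts]
    total = sum(smoothed)
    return [s / total for s in smoothed], winner

weights = [0.5, 0.5]
counts = [0.0, 0.0]  # plays the role of initial_accum_value = 0.0
for child_log_values in [[-1.0, -3.0], [-2.0, -5.0], [-4.0, -1.0]]:
    weights, winner = hard_em_update(weights, child_log_values, counts)
```

In the full SPN the same counting happens along the whole MPE path from the root down to the leaves, which is why the learning ops need a downward pass through every layer.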

In [4]:
# Op for getting the log probability of the root
root_log_prob = root.get_log_value(inference_type=inference_type)

# Helper for constructing EM learning ops
em_learning = spn.HardEMLearning(
    initial_accum_value=initial_accum_value, root=root, value_inference_type=inference_type,
    sample_prob=sample_prob, sample_winner=sample_winner, use_unweighted=use_unweighted,
    l0_prior_factor=l0_prior_factor, additive_smoothing=additive_smoothing)

# Accumulate counts and update weights
online_em_update_op = em_learning.accumulate_and_update_weights()

# Op for initializing accumulators
init_accumulators = em_learning.reset_accumulators()

# MPE state generator
mpe_state_generator = spn.MPEState()
# Generate MPE state ops for the Gaussian leaves and the class indicators
normal_leaf_mpe, class_indicator_mpe = mpe_state_generator.get_state(root, normal_leafs, class_indicators)



### Display TF Graph
This only works in the Chrome browser.

In [None]:
spn.display_tf_graph()

### Training the SPN
Here we train while monitoring the likelihood. Note that we train the SPN generatively, so it is not optimized to discriminate between digits. This is why we observe lower accuracies than with a discriminative model, e.g. an MLP trained with a cross-entropy loss.
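
During testing the class indicators are fed `-1` instead of the labels, which appears to mark the class as unobserved so that the root sums over all classes. The sketch below shows that computation for a root sum over class-wise sub-SPNs, along with the MPE class used for the accuracy. It is plain Python with made-up numbers, and `root_log_value` and `mpe_class` are hypothetical names rather than libspn calls.

```python
import math

def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))

def root_log_value(class_log_likelihoods, log_class_weights, label):
    """Log value of a root sum over class sub-SPNs.
    label == -1 means the class is unobserved and is summed out."""
    joint = [lw + ll for lw, ll in zip(log_class_weights, class_log_likelihoods)]
    if label == -1:
        return log_sum_exp(joint)
    return joint[label]

def mpe_class(class_log_likelihoods, log_class_weights):
    """Most probable class given the pixels: argmax over the joint terms."""
    joint = [lw + ll for lw, ll in zip(log_class_weights, class_log_likelihoods)]
    return joint.index(max(joint))

# Three classes with uniform weights and illustrative per-class log likelihoods
log_class_weights = [math.log(1.0 / 3.0)] * 3
class_log_likelihoods = [-900.0, -880.0, -910.0]
marginal = root_log_value(class_log_likelihoods, log_class_weights, label=-1)
predicted = mpe_class(class_log_likelihoods, log_class_weights)
```

Comparing `predicted` with the true label for every test image gives the accuracy reported per epoch, while `marginal` plays the role of the test log likelihood.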

In [5]:
# Set up some convenient iterators
train_iterator = DataIterator([train_x, train_y], batch_size=batch_size)
test_iterator = DataIterator([test_x, test_y], batch_size=batch_size)

def fd(x, y):
    return {normal_leafs: x, class_indicators: y}

with tf.Session() as sess:
    # Initialize things
    sess.run(tf.global_variables_initializer())
    
    # Do one run for test likelihoods
    log_likelihoods = []
    for batch_x, batch_y in test_iterator.iter_epoch("Testing"):
        batch_llh = sess.run(root_log_prob, fd(batch_x, batch_y))
        log_likelihoods.extend(batch_llh)
        test_iterator.display_progress(LLH="{:.2f}".format(np.mean(batch_llh)))
    mean_test_llh = np.mean(log_likelihoods)
    
    print("Before training test LLH = {:.2f}".format(mean_test_llh))                              
    for epoch in range(num_epochs):
        
        # Train
        log_likelihoods = []
        for batch_x, batch_y in train_iterator.iter_epoch("Training"):
            batch_llh, _ = sess.run(
                [root_log_prob, online_em_update_op], fd(batch_x, batch_y))
            log_likelihoods.extend(batch_llh)
            train_iterator.display_progress(LLH="{:.2f}".format(np.mean(batch_llh)))
        mean_train_llh = np.mean(log_likelihoods)
        
        # Test
        log_likelihoods, matches = [], []
        for batch_x, batch_y in test_iterator.iter_epoch("Testing"):
            # Feed -1 instead of the labels so that the class has to be inferred
            batch_llh, batch_class_mpe = sess.run(
                [root_log_prob, class_indicator_mpe],
                fd(batch_x, -np.ones_like(batch_y, dtype=int)))
            log_likelihoods.extend(batch_llh)
            matches.extend(np.equal(batch_class_mpe, batch_y))
            test_iterator.display_progress(LLH="{:.2f}".format(np.mean(batch_llh)))
        mean_test_llh = np.mean(log_likelihoods)
        mean_test_acc = np.mean(matches)
        
        # Report
        print("Epoch {}, train LLH = {:.2f}, test LLH = {:.2f}, test accuracy = {:.2f}".format(
            epoch, mean_train_llh, mean_test_llh, mean_test_acc))
    
    # Compute MPE state of all digits
    per_class_mpe = sess.run(
        normal_leaf_mpe, 
        fd(
            -np.ones([num_classes, num_vars], dtype=int), 
            np.expand_dims(np.arange(num_classes, dtype=int), 1)
        )
    )
    

Testing: 100%|██████████| 625/625 [00:41<00:00, 15.02it/s, LLH=-888.99]
Training:   0%|          | 0/3750 [00:00<?, ?it/s]

Before training test LLH = -889.02


InvalidArgumentError: Conv2DCustomBackpropInputOp only supports NHWC.
	 [[node TrueMPEPath/ConvProductsDepthwise_4/Conv2DBackpropInput (defined at /home/jos/spn/libspn/libspn/graph/op/conv_products_depthwise.py:96) ]]

Errors may have originated from an input operation.
Input Source operations connected to node TrueMPEPath/ConvProductsDepthwise_4/Conv2DBackpropInput:
 TrueMPEPath/ConvProductsDepthwise_4/Shape (defined at /home/jos/spn/libspn/libspn/graph/op/conv_products_depthwise.py:90)	
 TrueMPEPath/ConvProductsDepthwise_4/ones (defined at /home/jos/spn/libspn/libspn/graph/op/conv_products_depthwise.py:91)	
 TrueMPEPath/ConvProductsDepthwise_4/Reshape_2 (defined at /home/jos/spn/libspn/libspn/graph/op/conv_products_depthwise.py:72)

Original stack trace for 'TrueMPEPath/ConvProductsDepthwise_4/Conv2DBackpropInput':
  File "<ipython-input-4-d7ec4c591fe5>", line 11, in <module>
    online_em_update_op = em_learning.accumulate_and_update_weights()
  File "/home/jos/spn/libspn/libspn/learning/hard_em.py", line 76, in accumulate_and_update_weights
    accumulate_updates = self.accumulate_updates()
  File "/home/jos/spn/libspn/libspn/learning/hard_em.py", line 83, in accumulate_updates
    self._mpe_path.get_mpe_path(self._root)
  File "/home/jos/spn/libspn/libspn/inference/mpe_path.py", line 80, in get_mpe_path
    compute_graph_up_down(root, down_fun=down_fun, graph_input=graph_input)
  File "/home/jos/spn/libspn/libspn/graph/algorithms.py", line 144, in compute_graph_up_down
    down_values[child] = down_fun(child, parent_vals)
  File "/home/jos/spn/libspn/libspn/inference/mpe_path.py", line 68, in down_fun
    for i in node.inputs], **kwargs)
  File "/home/jos/spn/libspn/libspn/graph/op/conv_products.py", line 359, in _compute_log_mpe_path
    return self._compute_mpe_path_common(counts, *input_values)
  File "/home/jos/spn/libspn/libspn/graph/op/conv_products_depthwise.py", line 96, in _compute_mpe_path_common
    data_format="NCHW")
  File "/home/jos/.local/lib/python3.5/site-packages/tensorflow/python/ops/nn_ops.py", line 2077, in conv2d_backprop_input
    explicit_paddings, data_format, dilations, name)
  File "/home/jos/.local/lib/python3.5/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 1407, in conv2d_backprop_input
    name=name)
  File "/home/jos/.local/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/jos/.local/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/jos/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/home/jos/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()


### Visualize MPE state per class
We can visualize the MPE state of the leaves for each class, which is computed at the end of the training cell above.

In [None]:
for sample in per_class_mpe:
    _, ax = plt.subplots()
    ax.imshow(sample.reshape(28, 28).astype(float), cmap='gray')
    plt.show()