[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pronobis/libspn-keras/blob/master/examples/notebooks/Randomly%20Structured%20SPNs%20Image%20Classification.ipynb)

In [None]:
!pip install libspn-keras

# RAT-SPNs
Random tensorized sum-product networks (RAT-SPNs) are SPNs built on randomized region graphs. `libspn-keras` makes it straightforward to implement them. For more background, see the paper _Random Sum-Product Networks: A Simple and Effective Approach to Probabilistic Deep Learning_ by [Peharz et al. (2019)](http://auai.org/uai2019/proceedings/papers/124.pdf).

## Defining a `tf.keras.Sequential`
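We build the SPN as a single `tf.keras.Sequential`: the input layers (`pre_stack`) followed by alternating product and sum layers (`sum_product_stack`):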

In [None]:
import libspn_keras as spnk
from tensorflow import keras

spnk.set_default_sum_op(spnk.SumOpGradBackprop())
spnk.set_default_accumulator_initializer(
    keras.initializers.TruncatedNormal(stddev=0.5, mean=1.0)
)

sum_product_stack = [
    # 8 vars
    spnk.layers.ReduceProduct(num_factors=8),
    spnk.layers.DenseSum(num_sums=8),
    # 64 vars
    spnk.layers.ReduceProduct(num_factors=8),
    spnk.layers.DenseSum(num_sums=8),
    # 128 vars
    spnk.layers.DenseProduct(num_factors=2),
    spnk.layers.DenseSum(num_sums=8),
    # 256 vars
    spnk.layers.DenseProduct(num_factors=2),
    spnk.layers.DenseSum(num_sums=8),
    # 512 vars
    spnk.layers.DenseProduct(num_factors=2),
    spnk.layers.DenseSum(num_sums=8),
    # 1024 vars
    spnk.layers.DenseProduct(num_factors=2),
    spnk.layers.DenseSum(num_sums=1),
    spnk.layers.Undecompose(),
    spnk.layers.RootSum(return_weighted_child_logits=True)
]

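# Collect the factor of every product layer; the random scope permutation
# and padding layer below needs them.
factors = [
    layer.num_factors for layer in sum_product_stack
    if isinstance(layer, (spnk.layers.DenseProduct, spnk.layers.ReduceProduct))
]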

location_initializer = keras.initializers.TruncatedNormal(stddev=0.5, mean=0.0)

pre_stack = [
    spnk.layers.NormalizeStandardScore(input_shape=[784]),
    spnk.layers.FlatToRegions(num_decomps=10),
    spnk.layers.NormalLeaf(
        num_components=32, 
        location_initializer=location_initializer,
    ),
    spnk.layers.PermuteAndPadScopesRandom(factors=factors)
]

rat_spn = keras.Sequential(pre_stack + sum_product_stack)

rat_spn.summary()
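
The `# N vars` comments in the stack track how many variables a node covers after each product layer. As a quick sanity check, the sketch below reproduces that arithmetic in plain Python without calling `libspn-keras`. The name `stack_factors` is illustrative and simply repeats the values that `factors` holds above. The factors multiply to 1024, which is more than the 784 MNIST pixels, so the difference has to be made up by padding before the first product layer.

```python
from itertools import accumulate
from operator import mul

# Same values as `factors` above, written out by hand for illustration.
stack_factors = [8, 8, 2, 2, 2, 2]
num_pixels = 28 * 28  # one flattened MNIST image

# Running product: number of variables covered after each product layer.
scope_sizes = list(accumulate(stack_factors, mul))

# Number of scopes that must be added so the products divide evenly.
num_padded = scope_sizes[-1] - num_pixels

print(scope_sizes)  # [8, 64, 128, 256, 512, 1024]
print(num_padded)   # 240
```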

## Preparing the data
We'll use `tensorflow_datasets` to set up our data.

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf

batch_size = 128

def flatten_img(img, labels):
    return tf.reshape(img, [tf.shape(img)[0], -1]), labels

mnist_train = tfds.load(name="mnist", split="train", as_supervised=True)
mnist_train = (
    mnist_train
    .shuffle(10_000)
    .batch(batch_size)
    .map(flatten_img)
)

mnist_test = tfds.load(name="mnist", split="test", as_supervised=True)
mnist_test = mnist_test.batch(100).map(flatten_img)
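
Note that `flatten_img` is mapped after `.batch(...)`, so it receives whole batches and keeps the first axis as the batch dimension. The shape change can be illustrated with NumPy, whose `reshape` treats `-1` the same way as `tf.reshape`. This is only a sketch: `fake_batch` is a made-up array standing in for one batch of MNIST images, assumed here to have shape `(28, 28, 1)` each.

```python
import numpy as np

# Stand-in for one batch of 128 grayscale MNIST images (hypothetical data).
fake_batch = np.zeros((128, 28, 28, 1), dtype=np.uint8)

# Keep the batch axis, collapse height, width and channel into one axis.
flat_batch = fake_batch.reshape(fake_batch.shape[0], -1)

print(flat_batch.shape)  # (128, 784)
```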

## Optimizer, loss and metrics
The optimizer, loss, and metrics come straight from `tensorflow.keras`:

In [None]:
optimizer = keras.optimizers.Adam(learning_rate=1e-2)
metrics = [keras.metrics.SparseCategoricalAccuracy()]
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

rat_spn.compile(loss=loss, metrics=metrics, optimizer=optimizer)
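
The loss uses `from_logits=True` because the root layer was created with `return_weighted_child_logits=True`, so the model outputs one unnormalized log score per class rather than probabilities. The sketch below shows, in plain Python with made-up scores, what a sparse categorical cross-entropy on logits computes for a single example. Reading the root's outputs as per-class log scores is our assumption here, and the helper names are illustrative, not part of either library.

```python
import math

def log_sum_exp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def sparse_ce_from_logits(logits, label):
    """Negative log-softmax of the true class."""
    return log_sum_exp(logits) - logits[label]

# Hypothetical log scores for one image and three classes.
logits = [-5.0, -2.0, -9.0]
label = 1

loss_value = sparse_ce_from_logits(logits, label)

# Softmax over the logits gives a normalized distribution over classes.
posterior = [math.exp(v - log_sum_exp(logits)) for v in logits]

print(loss_value, posterior)
```

The loss is the negative log of the softmax probability assigned to the true class, so it shrinks as that class's score pulls ahead of the others.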

## Training and evaluation
Training and evaluation also go through the standard `tensorflow.keras` interface:

In [None]:
rat_spn.fit(mnist_train, epochs=10)

rat_spn.evaluate(mnist_test)