[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pronobis/libspn-keras/blob/master/examples/notebooks/DGC-SPN%20Image%20Completion.ipynb)
# Training a Deep Generalized Convolutional Sum-Product Network (DGC-SPN) for Image Completion
Let's go through an example of building complex SPNs with [`libspn-keras`](https://github.com/pronobis/libspn-keras). The layer-based API of the library makes it straightforward to build advanced SPN architectures.

First let's set up the dependencies:

In [None]:
!pip install libspn-keras matplotlib

## The Data
We'll use the Olivetti faces dataset: 400 grayscale images of 64×64 pixels, of which we hold out the last 50 for testing. This dataset is small by today's standards. Nevertheless, we'll be able to produce pretty good completions.

In [None]:
!wget https://raw.githubusercontent.com/pronobis/libspn-keras/master/examples/notebooks/olivetti.raw

In [None]:
import numpy as np
import tensorflow as tf

def load_olivetti(test_size=50):
    x = np.loadtxt("olivetti.raw").transpose().reshape(
        400, 64, 64, 1).transpose((0, 2, 1, 3)).astype(np.float32)
    train_x = x[:-test_size]
    test_x = x[-test_size:]
    return train_x, test_x

train_x, test_x = load_olivetti()

train_x_ds = tf.data.Dataset.from_tensor_slices((train_x,)).batch(32)
test_x_ds = tf.data.Dataset.from_tensor_slices((test_x,)).batch(32)


## A Reasonable Default for Accumulator Initializers
Since we're going to be using hard EM-based learning that will automatically break symmetries for equally weighted sums in the SPN, 
we'll use a default accumulator initializer that initializes all weights with 
$$\boldsymbol w_{\text{sum}} \sim Dir(\alpha)$$

In [None]:
import libspn_keras as spnk

spnk.set_default_accumulator_initializer(
    spnk.initializers.Dirichlet(alpha=0.1)
)
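
To get a feel for what a Dirichlet initializer with a small $\alpha$ produces, here is a minimal NumPy sketch. It does not use `libspn-keras`, and the layer sizes (`num_sums`, `num_children`) are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sum layer: 4 sum nodes with 8 children each (illustrative sizes).
num_sums, num_children, alpha = 4, 8, 0.1

# Each row is the weight vector of one sum node, drawn from a symmetric Dirichlet.
weights = rng.dirichlet(alpha * np.ones(num_children), size=num_sums)

print(weights.sum(axis=1))  # every row sums to 1 up to rounding
print(weights.max(axis=1))  # with alpha < 1, a single child tends to dominate
```

With $\alpha < 1$ the samples are sparse, so the sums start out far from uniform.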

## Leaf Variables
For the leaf variable, we'll use `spnk.layers.NormalLeaf`. We initialize the locations of the normal distribution components through an initializer that is based on the seminal work by [Poon and Domingos (2011)](https://arxiv.org/abs/1202.3732). This splits up the training data into $n$ quantiles per pixel, after which
the mean of the $i$-th quantile is used as the location of the $i$-th component.

In [None]:
normalize = spnk.layers.NormalizeStandardScore(axes=spnk.layers.NormalizeAxes.GLOBAL, input_shape=(64, 64, 1))

normalize.adapt(train_x_ds)

train_x_normalized = train_x_ds.map(normalize)

location_initializer = spnk.initializers.PoonDomingosMeanOfQuantileSplit(
    data=train_x_normalized
)
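
The quantile-based initialization described above can be sketched in plain NumPy. This is an illustration of the idea under simplifying assumptions (flattened pixels, random toy data), not the library's implementation; `mean_of_quantile_split` is a hypothetical helper name.

```python
import numpy as np

def mean_of_quantile_split(data, num_components):
    """data: [num_samples, num_pixels]; returns [num_pixels, num_components]."""
    # Sort every pixel's values independently across the samples.
    sorted_data = np.sort(data, axis=0)
    # Split the sorted values into (almost) equally sized quantile chunks.
    chunks = np.array_split(sorted_data, num_components, axis=0)
    # The mean of the i-th chunk becomes the location of the i-th component.
    return np.stack([chunk.mean(axis=0) for chunk in chunks], axis=-1)

rng = np.random.default_rng(0)
toy_pixels = rng.normal(size=(350, 16))  # 350 samples, 16 toy 'pixels'
locations = mean_of_quantile_split(toy_pixels, num_components=8)
print(locations.shape)  # (16, 8)
```

Because the chunks come from sorted values, the locations of each pixel are ordered from the darkest to the brightest component.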

## DGC-SPN
A DGC-SPN consists of convolutional product and sum nodes. 

We begin by using non-overlapping convolution patches for our first two product layers,
since this way we make sure that no scopes overlap between products. 

For all convolutional product layers, we'll use *exponentially increasing dilation rates*. By doing so, we have 'overlapping' patches that still yield a valid SPN. These exponentially increasing dilation rates for convolutional SPNs were first introduced in [this paper](https://arxiv.org/abs/1902.06155).
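
To see why doubling the dilation rate keeps the SPN valid, the sketch below tracks which pixel offsets a single product node covers along one axis. It is a one-dimensional simplification that assumes a kernel size of 2 and ignores padding; `scope_offsets` is a hypothetical helper, not part of `libspn-keras`.

```python
def scope_offsets(dilations):
    """Pixel offsets (along one axis) covered by a product node after each layer."""
    offsets = {0}  # a leaf covers only its own pixel
    for d in dilations:
        # A kernel of size 2 with dilation d joins the nodes at p and p + d.
        shifted = {o + d for o in offsets}
        # Decomposability: children of a product must have disjoint scopes.
        if not offsets.isdisjoint(shifted):
            raise ValueError(f"scopes overlap at dilation {d}")
        offsets = offsets | shifted
    return offsets

offsets = scope_offsets([1, 2, 4, 8, 16, 32])
print(sorted(offsets) == list(range(64)))  # True
```

Each layer doubles the scope without overlap, so six layers cover 64 consecutive pixels. Repeating a dilation rate, for example `scope_offsets([1, 1])`, raises an error because the two children would share a pixel.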

The layered SPN construction is wrapped in a function so that we can redefine it later for other purposes. Don't worry, we'll get to the details.

In [None]:
def define_spn(sum_op, infer_no_evidence=False):
    spnk.set_default_sum_op(sum_op)
    spnk.set_default_linear_accumulators_constraint(spnk.constraints.GreaterEqualEpsilon())
    stack = [
        spnk.layers.NormalizeStandardScore(input_shape=(64, 64, 1)),
        # Non-overlapping products
        spnk.layers.NormalLeaf(
            num_components=8, 
            location_trainable=True,
            location_initializer=location_initializer,
            use_accumulators=True,
            scale_trainable=False
        ),
        tf.keras.layers.Dropout(rate=0.1),
        spnk.layers.Conv2DProduct(
            depthwise=True, 
            strides=[1, 1], 
            dilations=[1, 1], 
            kernel_size=[2, 2],
            padding='full'
        ),
        spnk.layers.Local2DSum(num_sums=2),
        # Non-overlapping products
        spnk.layers.Conv2DProduct(
            depthwise=False, 
            strides=[1, 1], 
            dilations=[2, 2], 
            kernel_size=[2, 2],
            padding='full'
        ),
        spnk.layers.Local2DSum(num_sums=2),
        # Overlapping products, with dilations [4, 4] and full padding
        spnk.layers.Conv2DProduct(
            depthwise=False, 
            strides=[1, 1], 
            dilations=[4, 4], 
            kernel_size=[2, 2],
            padding='full'
        ),
        spnk.layers.Local2DSum(num_sums=2),
        # Overlapping products, with dilations [8, 8] and full padding
        spnk.layers.Conv2DProduct(
            depthwise=False, 
            strides=[1, 1], 
            dilations=[8, 8], 
            kernel_size=[2, 2],
            padding='full'
        ),
        spnk.layers.Local2DSum(num_sums=2),
        # Overlapping products, with dilations [16, 16] and full padding
        spnk.layers.Conv2DProduct(
            depthwise=False, 
            strides=[1, 1], 
            dilations=[16, 16], 
            kernel_size=[2, 2],
            padding='full'
        ),
        spnk.layers.Local2DSum(num_sums=2),
        spnk.layers.Conv2DProduct(
            depthwise=False, 
            strides=[1, 1], 
            dilations=[32, 32], 
            kernel_size=[2, 2],
            padding='full'
        ),
        spnk.layers.Local2DSum(num_sums=2),
        # Overlapping products, with dilations [64, 64] and 'final' padding to combine 
        # all scopes
        spnk.layers.Conv2DProduct(
            depthwise=False, 
            strides=[1, 1], 
            dilations=[64, 64], 
            kernel_size=[2, 2],
            padding='final'
        ),
        spnk.layers.LogDropout(rate=0.1),
        spnk.layers.SpatialToRegions(),
        spnk.layers.RootSum(
            return_weighted_child_logits=False
        )
    ]
    sum_product_network = spnk.models.SequentialSumProductNetwork(
      stack, infer_no_evidence=infer_no_evidence)
    return sum_product_network

sum_product_network = define_spn(sum_op=spnk.SumOpEMBackprop())
sum_product_network.summary()

### Training the DGC-SPN with hard EM
We use the `libspn_keras.optimizers.OnlineExpectationMaximization` optimizer. Note that this optimizer works for generative training setups only and should be
combined with one of the EM-based sum ops. In the snippet above, we passed `spnk.SumOpEMBackprop()` to `define_spn`, so we're good to go!

We'll configure a train set and a test set using `tf.data.Dataset`. As first suggested in the work by [Poon and Domingos](https://arxiv.org/abs/1202.3732), we normalize each sample by subtracting the mean and dividing by the
standard deviation.
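
As a rough illustration of per-sample standard-score normalization, here is a NumPy sketch. The helper name `standard_score_per_sample` and the `epsilon` guard against division by zero are assumptions made for this example, and the toy images are random.

```python
import numpy as np

def standard_score_per_sample(images, epsilon=1e-8):
    """images: [batch, height, width, channels]; normalizes each image on its own."""
    mean = images.mean(axis=(1, 2, 3), keepdims=True)
    std = images.std(axis=(1, 2, 3), keepdims=True)
    # epsilon (an assumption of this sketch) avoids dividing by zero for flat images.
    return (images - mean) / (std + epsilon)

rng = np.random.default_rng(0)
toy_images = rng.uniform(0, 255, size=(4, 64, 64, 1)).astype(np.float32)
normalized = standard_score_per_sample(toy_images)
```

After this step every image has approximately zero mean and unit standard deviation, regardless of its overall brightness or contrast.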

In [None]:
import tensorflow as tf
from libspn_keras.optimizers import OnlineExpectationMaximization
from libspn_keras import losses
from libspn_keras import metrics

batch_size = 100

train_data = (
    tf.data.Dataset.from_tensor_slices((train_x,))
    .shuffle(350)
    .batch(batch_size)
)

test_data = (
    tf.data.Dataset.from_tensor_slices((test_x,))
    .batch(batch_size)
)

sum_product_network.compile(
    optimizer=spnk.optimizers.OnlineExpectationMaximization(learning_rate=0.05),
    loss=spnk.losses.NegativeLogLikelihood(),
    metrics=[spnk.metrics.LogLikelihood()]
)
sum_product_network.fit(train_data, epochs=100)
sum_product_network.evaluate(test_data)

Now comes the tricky part. To do image completion, we need to propagate some kind of signal to the leaves that we occlude. Probabilistically, we exclude the variables corresponding to the occluded pixels from the evidence $\mathbf E$.

In the seminal work by [Poon and Domingos (2011)](https://arxiv.org/abs/1202.3732), in which Sum-Product Networks were proposed as a new type of deep probabilistic model, the image completion was accomplished by backpropagating the **gradients** and using these to obtain the 'posterior marginal' of the missing variables.

The paper itself, however, does not mention this explicitly. Only when we inspect the code that accompanies the paper do we find:
```java
// compute marginal by differentiation; see Darwiche-03 for details 
public void cmpMAPBottomHalfMarginal(Instance inst) {
  setInputOccludeBottomHalf(inst);
  eval();		
  cmpDerivative();
      
  for (int i=0; i<Parameter.inputDim1_/2; i++) {
    for (int j=0; j<Parameter.inputDim2_; j++) 
      MyMPI.buf_int_[MyMPI.buf_idx_++]=Utils.getIntVal(inst, inst.vals_[i][j]);
  }
  for (int i=Parameter.inputDim1_/2; i<Parameter.inputDim1_; i++) {
    for (int j=0; j<Parameter.inputDim2_; j++) {
      int ri=Region.getRegionId(i, i+1, j, j+1);
      Region r=Region.getRegion(ri);
      double p=cmpMarginal(r);
      MyMPI.buf_int_[MyMPI.buf_idx_++]=Utils.getIntVal(inst,p);//(int)(p*255);
    }
  }
}
```
Note the comment at the top! [The paper by Darwiche (2003)](http://reasoning.cs.ucla.edu/fetch.php?id=22&type=pdf) indeed mentions a way of determining the _posterior marginal_ of any variable when the network computes a _network polynomial_. SPNs happen to compute a network polynomial, so we can indeed use gradients to compute posterior marginals!

From the paper, we have:

_For every variable $X \notin \mathbf E$ and evidence $\mathbf e$_:
$$
P(x_i \mid \mathbf e) = \frac{1}{\text{Root}(\mathbf e)} \frac{\partial \text{Root}}{\partial \text{Leaf}_{x_i}}(\mathbf e)
$$

Note that in the last equation we are using a lower-case variable with an index: $x_i$. The index corresponds to the _component_ on top of the variable. In other words, $P(x_i \mid \mathbf e)$ is the posterior marginal of the $i$-th component attached to variable $X$. The variable $X$ is excluded from the evidence $\mathbf E$. Lower-case $\mathbf e$ is the actual assignment of pixel values in the part of the image that is _included_ (not occluded). Note that $P(x_i \mid \mathbf e)$ is **not** the output probability of the leaf. Rather, it is the probability that the pixel value at $X$ was _generated_ by the component at $x_i$.

In the code by Poon and Domingos we mentioned earlier, the posterior marginals are used as follows:
$$
\text{InferredPixelValue} = \sum_i \mu_i P(x_i \mid \mathbf e)
$$
Where $\mu_i$ is the mean of the $i$-th Gaussian component.
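
The two formulas above can be checked on a toy network. The sketch below uses a hypothetical SPN over two pixels $X$ and $Y$ with two Gaussian components each. All weights, densities and means are made-up numbers, and the derivative is written out by hand rather than obtained from an autodiff engine.

```python
import numpy as np

# Hypothetical toy SPN over two pixels X and Y with 2 components each:
# Root = sum_k w[k] * (sum_i a[k, i] * leaf_x[i]) * (sum_j b[k, j] * leaf_y[j])
w = np.array([0.3, 0.7])
a = np.array([[0.9, 0.1], [0.2, 0.8]])
b = np.array([[0.6, 0.4], [0.1, 0.9]])
mu_x = np.array([-1.0, 1.0])  # means of the components attached to X

def root(leaf_x, leaf_y):
    return np.sum(w * (a @ leaf_x) * (b @ leaf_y))

def droot_dleaf_x(leaf_y):
    # Root is linear in leaf_x, so the derivative does not depend on leaf_x.
    return (w * (b @ leaf_y)) @ a

leaf_y = np.array([0.05, 0.40])  # component densities at the observed value of Y
leaf_x = np.ones(2)              # X is occluded, so all of its leaves are set to 1

posterior = droot_dleaf_x(leaf_y) / root(leaf_x, leaf_y)  # P(x_i | e)
inferred_pixel = np.sum(mu_x * posterior)                 # sum_i mu_i * P(x_i | e)
print(posterior.sum())  # sums to 1 up to rounding
```

The posterior marginals form a proper distribution over the components of $X$, so the inferred pixel value is a convex combination of the component means.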

So how would this work for the SPN we have now? The layer implementations in LibSPN Keras are propagating log probabilities. So at the root of the network we'll find $\log(\text{Root}(\mathbf e))$ instead of just $\text{Root}(\mathbf e)$. Simply wrapping the value with $\exp(\cdot)$ results in numerical difficulties, so that's a dead end. 
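
The numerical problem is easy to reproduce. The NumPy sketch below compares a naive weighted sum in linear space with the usual max-shifted log-sum-exp. The log probabilities are made-up values of a magnitude that is plausible for a 64×64 image, and `log_sum_exp` is a helper written for this example.

```python
import numpy as np

def log_sum_exp(log_values):
    """Numerically stable log(sum(exp(log_values)))."""
    shift = np.max(log_values)
    return shift + np.log(np.sum(np.exp(log_values - shift)))

# Made-up child log probabilities and sum weights.
child_log_probs = np.array([-5000.0, -5001.0, -5003.0])
log_weights = np.log(np.array([0.2, 0.5, 0.3]))

with np.errstate(divide="ignore"):
    # exp(-5000) underflows to 0 in float64, so the naive result is -inf.
    naive = np.log(np.sum(np.exp(log_weights + child_log_probs)))

stable = log_sum_exp(log_weights + child_log_probs)
print(naive, stable)
```

The naive value carries no information, while the shifted version stays finite, which is why the computation has to remain in log space.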

Let's reverse our approach: we'll start by propagating gradients from the root to the leaves and see how far that gets us toward solving the problem!

If we just use TensorFlow's autograd engine we can obtain something like:
```python
leaf_log_prob = leaf(img)
leaf_log_prob = tf.where(evidence_mask, leaf_log_prob, tf.zeros_like(leaf_log_prob))
with tf.GradientTape() as g:
    g.watch(leaf_log_prob)
    root_log_prob = apply_sum_product_stack(leaf_log_prob)
droot_log_prob_dleaf_log_prob = g.gradient(root_log_prob, leaf_log_prob)
```

This gives us:

$$
\frac{\partial \log(\text{Root})}{\partial \log(\text{Leaf}_{x_i})}
$$

Which we can rewrite a bit:

\begin{align}
\frac{\partial \log(\text{Root})}{\partial \log(\text{Leaf}_{x_i})}
&= \frac{\partial \log(\text{Root})}{\partial \text{Root}} \frac{\partial \text{Root}}{\partial \log(\text{Leaf}_{x_i})} \\
&= \frac{1}{\text{Root}} \frac{\partial \text{Root}}{\partial \log(\text{Leaf}_{x_i})} \\
&= \frac{1}{\text{Root}} \frac{\partial \text{Root}}{\partial \log(\text{Leaf}_{x_i})} \frac{1}{\text{Leaf}_{x_i}} \text{Leaf}_{x_i} \\
&= \frac{1}{\text{Root}} \frac{\partial \text{Root}}{\partial \log(\text{Leaf}_{x_i})} \frac{\partial \log(\text{Leaf}_{x_i})}{\partial \text{Leaf}_{x_i}} \text{Leaf}_{x_i} \\
&= \frac{1}{\text{Root}} \frac{\partial \text{Root}}{\partial \text{Leaf}_{x_i}} \text{Leaf}_{x_i} \\
&= \frac{1}{\text{Root}} \frac{\partial \text{Root}}{\partial \text{Leaf}_{x_i}} & (X \notin \mathbf E\text{, so } \text{Leaf}_{x_i} = 1) \\
&= P(x_i \mid \mathbf e)
\end{align}

So in fact, we already have what we want in `droot_log_prob_dleaf_log_prob`!
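
This identity can be verified numerically without TensorFlow. The sketch below uses a hypothetical two-pixel toy SPN with made-up weights, takes a central finite difference of $\log(\text{Root})$ with respect to $\log(\text{Leaf}_{x_i})$ at $\log(\text{Leaf}_{x_i}) = 0$, and compares it to the posterior marginal computed in linear space.

```python
import numpy as np

# Hypothetical toy SPN: Root = sum_k w[k] * (a[k] . leaf_x) * (b[k] . leaf_y)
w = np.array([0.3, 0.7])
a = np.array([[0.9, 0.1], [0.2, 0.8]])
b = np.array([[0.6, 0.4], [0.1, 0.9]])
leaf_y = np.array([0.05, 0.40])  # component densities at the observed pixel Y

def log_root(log_leaf_x):
    leaf_x = np.exp(log_leaf_x)
    return np.log(np.sum(w * (a @ leaf_x) * (b @ leaf_y)))

def numerical_gradient(f, x, step=1e-6):
    # Central finite differences, one coordinate at a time.
    grad = np.zeros_like(x)
    for i in range(x.size):
        delta = np.zeros_like(x)
        delta[i] = step
        grad[i] = (f(x + delta) - f(x - delta)) / (2 * step)
    return grad

log_leaf_x = np.zeros(2)  # occluded pixel X: log(Leaf) = 0 for all components
grad = numerical_gradient(log_root, log_leaf_x)

# Posterior marginal P(x_i | e) computed directly in linear space.
posterior = (w * (b @ leaf_y)) @ a / np.sum(w * (b @ leaf_y))
print(np.allclose(grad, posterior, atol=1e-6))  # True
```

The gradient in log space matches the posterior marginal, so no explicit division by the root value is needed.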

Below, we use those observations to compute the image completions.


In [None]:
import matplotlib.pyplot as plt
from tensorflow import keras

%matplotlib inline

completion_model = define_spn(
    sum_op=spnk.SumOpGradBackprop(logspace_accumulators=False),
    infer_no_evidence=True
)
completion_model.set_weights(sum_product_network.get_weights())
completion_model.compile(metrics=[keras.metrics.MeanSquaredError()])

def eval(model, x, omit_side):
    print("omitting ", omit_side)
    evidence_mask = get_evidence_mask(omit_side).astype(bool)
    model.evaluate([x, evidence_mask], x, verbose=2)
    completion_out = model.predict([x, evidence_mask])
    image_grid = make_image_grid(completion_out, num_rows=5)

    plt.figure(figsize=(14, 14))
    plt.imshow(image_grid.squeeze(), cmap='gray')
    plt.show()

def make_image_grid(images, num_rows):
    images_per_row = np.split(images, axis=0, indices_or_sections=num_rows)
    rows = [np.concatenate(imgs, axis=1) for imgs in images_per_row]
    full_grid = np.concatenate(rows, axis=0)
    return full_grid


def get_evidence_mask(omit_side):
  if omit_side == "top":
    return np.concatenate(
        [np.zeros([50, 32, 64, 1]), np.ones([50, 32, 64, 1])], axis=1)
  elif omit_side == 'bottom':
    return np.concatenate(
        [np.ones([50, 32, 64, 1]), np.zeros([50, 32, 64, 1])], axis=1)
  elif omit_side == 'right':
    return np.concatenate(
        [np.ones([50, 64, 32, 1]), np.zeros([50, 64, 32, 1])], axis=2)
  elif omit_side == 'left':
    return np.concatenate(
        [np.zeros([50, 64, 32, 1]), np.ones([50, 64, 32, 1])], axis=2)
  else:
    raise ValueError(f"Unknown omit_side: {omit_side!r}")

eval(completion_model, test_x, 'top')
eval(completion_model, test_x, 'bottom')
eval(completion_model, test_x, 'left')
eval(completion_model, test_x, 'right')
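
Since the evidence pixels are given as input, it can be informative to measure the error on the occluded region only. The NumPy sketch below is one way to do that. `occluded_mse` is a hypothetical helper, and it assumes the mask uses the same convention as `get_evidence_mask` (1 for evidence, 0 for occluded). Whether the Keras metric above is restricted to the occluded pixels is not something this sketch relies on.

```python
import numpy as np

def occluded_mse(original, completed, evidence_mask):
    """Mean squared error over the pixels where evidence_mask is False."""
    occluded = ~evidence_mask.astype(bool)
    squared_error = (original - completed) ** 2
    return squared_error[occluded].mean()

original = np.zeros((2, 4, 4, 1), dtype=np.float32)
completed = original.copy()
completed[:, 2:] = 0.5           # pretend the bottom half was completed with 0.5
mask = np.ones_like(original, dtype=bool)
mask[:, 2:] = False              # the bottom half is occluded
print(occluded_mse(original, completed, mask))  # 0.25
```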