In [None]:
import libspn as spn
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import scipy as scp

# Params

### Structure

In [None]:
# Type of input distributions for dense generator
input_dist = spn.DenseSPNGeneratorMultiNodes.InputDist.RAW

num_decomps=1
num_subsets=5
num_mixtures=2

# Number of input mixtures for dense generator
num_input_mixtures = 2

### Learning

In [None]:
# Additive smoothing during learning
additive_smoothing=100
min_additive_smoothing=1
smoothing_decay=0.2

# Weight initialization
# weight_init_value = 1
weight_init_value = spn.ValueType.RANDOM_UNIFORM(10, 11)

# Type of inference during upward pass of learning
value_inference_type = spn.InferenceType.MARGINAL

# Training Set

In [None]:
# Load MNIST through TensorFlow's contrib dataset helper, using "mnist" as the data
# directory (presumably downloading the files there if missing — confirm). The
# returned object exposes .train / .validation / .test splits, each with .images
# and .labels, which the following cells use.
datasets = tf.contrib.learn.datasets.mnist.read_data_sets("mnist")

In [None]:
def show_mnist(img):
    """Render one flattened 14x14 image with matplotlib and show it.

    Args:
        img: array-like holding exactly 196 values, e.g. shape (196,) or (1, 196).
    """
    pixels = np.asarray(img).reshape(14, 14)
    plt.imshow(pixels, interpolation='none')
    plt.show()

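### Process data

Downsample each 28×28 MNIST image to 14×14 and binarize it. Each of the resulting 196 pixels then becomes one two-valued indicator variable of the SPN.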

In [None]:
def process_set(data, threshold=20):
    """Downsample flattened 28x28 images to 14x14 and binarize them.

    Re-implements the former ``scipy.misc.imresize(img, 0.5, interp='nearest')``
    pipeline in vectorized NumPy. ``scipy.misc.imresize`` was removed in SciPy 1.3,
    and ``import scipy`` alone does not guarantee ``scipy.misc`` is loaded.

    Args:
        data: array-like with 784 values per image (anything that reshapes to
            (-1, 28, 28)).
        threshold: a pixel becomes 1 when its 0-255 intensity is strictly greater
            than this value, otherwise 0. Defaults to 20, the original hard-coded value.

    Returns:
        int ndarray of shape (num_images, 196) with values in {0, 1}.
    """
    images = np.reshape(data, (-1, 28, 28))
    if images.dtype != np.uint8:
        # Same as imresize's bytescale step: stretch each image's own [min, max]
        # range onto [0, 255] and round half up. Constant images map to all zeros.
        images = images.astype(np.float64)
        cmin = images.min(axis=(1, 2), keepdims=True)
        crange = images.max(axis=(1, 2), keepdims=True) - cmin
        crange[crange == 0] = 1
        images = np.floor(np.clip((images - cmin) * (255.0 / crange), 0, 255) + 0.5)
    # Nearest-neighbour 2x downsampling: keep every other row and column.
    # NOTE(review): the odd-indexed rows/columns are assumed to match PIL's
    # pixel-centre nearest sampling; confirm against old outputs if bit-exact
    # reproduction matters.
    small = images[:, 1::2, 1::2]
    # reshape(-1, ...) also works for zero images (the old np.vstack([]) raised).
    flat = small.reshape(-1, small.shape[1] * small.shape[2])
    return (flat > threshold).astype(int)

In [None]:
# Binarized 14x14 images for the train and test splits, plus the raw MNIST labels.
train_images, test_images = (process_set(datasets.train.images),
                             process_set(datasets.test.images))
train_labels, test_labels, validation_labels = (datasets.train.labels,
                                                datasets.test.labels,
                                                datasets.validation.labels)

In [None]:
# Sanity check: display one processed training image
show_mnist(img=train_images[19])

In [None]:
# Split out the digit-0 and digit-1 images. Index with the MNIST label arrays on
# `datasets` rather than the `train_labels` / `test_labels` globals: later cells
# rebind those names to (N, 1) binary class columns, so using them here would
# break (boolean-mask shape mismatch) if this cell were re-run.
train_images_0 = train_images[datasets.train.labels == 0]
train_images_1 = train_images[datasets.train.labels == 1]
test_images_0 = test_images[datasets.test.labels == 0]
test_images_1 = test_images[datasets.test.labels == 1]

In [None]:
# Spot-check each class subset: print its shape, then display one example image.
for subset, example_idx in ((train_images_0, 334), (train_images_1, 22),
                            (test_images_0, 334), (test_images_1, 22)):
    print(subset.shape)
    show_mnist(subset[example_idx])

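# Generate SPN structure

Two dense SPNs are generated over the same pixel indicators and joined by a root sum node. The root sum's latent indicators (`root.generate_ivs()`) carry the class label (digit 0 or 1). The true labels are fed into them during training, and their MPE state is read back as the prediction at test time.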

In [None]:
# One two-valued indicator variable per pixel of the 14x14 images
ivs = spn.IVs(num_vars=14 * 14, num_vals=2)
dense_gen = spn.DenseSPNGeneratorMultiNodes(
    num_decomps=num_decomps,
    num_subsets=num_subsets,
    num_mixtures=num_mixtures,
    input_dist=input_dist,
    num_input_mixtures=num_input_mixtures,
    multi_nodes=True,
)
# Two independently generated sub-SPNs over the same pixels, joined by a root sum
root0, root1 = dense_gen.generate(ivs), dense_gen.generate(ivs)
root = spn.Sum(root0, root1)
spn.generate_weights(root, init_value=weight_init_value)
# Latent IVs of the root sum; later cells feed/read the class label through them
latent = root.generate_ivs()
print(root.get_num_nodes())

# Add Learning Ops

In [None]:
# Smoothing lives in a TF variable so the training loop can change it per epoch
additive_smoothing_var = tf.Variable(additive_smoothing, dtype=spn.conf.dtype)
learning = spn.EMLearning(
    root,
    log=True,
    value_inference_type=value_inference_type,
    additive_smoothing=additive_smoothing_var,
)
init_weights = spn.initialize_weights(root)
# EM ops: clear accumulators, accumulate statistics for a batch, apply them
reset_accumulators = learning.reset_accumulators()
accumulate_updates = learning.accumulate_updates()
update_spn = learning.update_spn()
# Per-sample root value from the learning's value pass (presumably log-space,
# given log=True) and its batch mean
train_likelihood = learning.value.values[root]
avg_train_likelihood = tf.reduce_mean(train_likelihood)
avg_train_likelihood = tf.reduce_mean(train_likelihood)

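# Run Training

Train with EM until the epoch-average likelihood changes by no more than 0.1 between epochs. Additive smoothing starts at `additive_smoothing` and is multiplied by $e^{-\text{smoothing\_decay}}$ each epoch. It never drops below `min_additive_smoothing`.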

In [None]:
# No session yet; the training cell closes and replaces whatever is stored here
sess = None

In [None]:
# Training data: all "0" images followed by all "1" images. The class index is kept
# as an (N, 1) float column and fed to the root's latent IVs during EM.
# NOTE: this rebinds `train_labels`, which previously held the raw MNIST labels.
train_set = np.concatenate([train_images_0, train_images_1], axis=0)
train_labels = np.concatenate([np.zeros(train_images_0.shape[0]),
                               np.ones(train_images_1.shape[0])]).reshape(-1, 1)

# Close the session left over from a previous run of this cell, if any.
if sess is not None:
    sess.close()

# Build the smoothing-update op once, before the loop. Calling
# `additive_smoothing_var.assign(value)` inside the loop adds a new assign op and
# constant to the graph on every iteration, so the graph keeps growing.
ads_placeholder = tf.placeholder(spn.conf.dtype, shape=())
assign_additive_smoothing = additive_smoothing_var.assign(ads_placeholder)

sess = tf.Session()
sess.run(init_weights)
sess.run(reset_accumulators)

num_batches = 1
# Stop once the epoch-average likelihood changes by no more than this
convergence_tol = 0.1
# Cap on epochs so the loop terminates even if the likelihood never settles
max_epochs = 1000
# Integer batch boundaries that cover every sample. When num_batches does not
# divide the sample count, the remainder is spread over the batches, not dropped.
num_samples = train_set.shape[0]
batch_bounds = [b * num_samples // num_batches for b in range(num_batches + 1)]

prev_likelihood = 100
likelihood = 0
epoch = 0
while abs(prev_likelihood - likelihood) > convergence_tol and epoch < max_epochs:
    prev_likelihood = likelihood
    # Smoothing depends only on the epoch: exponential decay, floored at the minimum
    ads = max(np.exp(-epoch * smoothing_decay) * additive_smoothing, min_additive_smoothing)
    sess.run(assign_additive_smoothing, feed_dict={ads_placeholder: ads})
    print("Smoothing: ", sess.run(additive_smoothing_var))
    likelihoods = []
    for batch in range(num_batches):
        start, stop = batch_bounds[batch], batch_bounds[batch + 1]
        print("EPOCH", epoch, "BATCH", batch, "SAMPLES", start, stop)
        # Accumulate EM statistics for this batch and fetch its mean likelihood
        avg_train_likelihood_val, _ = sess.run(
            [avg_train_likelihood, accumulate_updates],
            feed_dict={ivs: train_set[start:stop],
                       latent: train_labels[start:stop]})
        # Avg likelihood of this batch's data under the weights before this batch's update
        print("Avg likelihood (this batch data on previous weights): %s" % (avg_train_likelihood_val))
        likelihoods.append(avg_train_likelihood_val)
        # Apply the accumulated updates to the SPN weights
        sess.run(update_spn)
    likelihood = sum(likelihoods) / len(likelihoods)
    print("Avg likelihood: %s" % (likelihood))
    epoch += 1
    sess.run(reset_accumulators)

if abs(prev_likelihood - likelihood) > convergence_tol:
    print("WARNING: stopped after %d epochs without converging" % epoch)

# Generate Testing Ops

In [None]:
# MPE-state generator that runs MPE-type value inference with log=True
mpe_state_gen = spn.MPEState(value_inference_type=spn.InferenceType.MPE, log=True)

In [None]:
# Ops that produce the MPE state of the pixel IVs and of the root's latent IVs.
# Presumably they are conditioned on the values fed for `ivs` / `latent` at run
# time (the cells below feed -1, 0 or 1 for the latent) — confirm in libspn.
mpe_ivs, mpe_latent = mpe_state_gen.get_state(root, ivs, latent)

# Run Testing

### MPE of the whole network

In [None]:
# Every pixel and the latent class are fed -1 (presumably "unobserved" for
# libspn IVs — confirm); show the resulting MPE image and latent value.
mpe_ivs_val, mpe_latent_val = sess.run(
    [mpe_ivs, mpe_latent],
    feed_dict={ivs: np.full((1, 14 * 14), -1, dtype=int),
               latent: [[-1]]})
show_mnist(mpe_ivs_val)
print(mpe_latent_val)

### MPE for 0

In [None]:
# Pixels fed -1, latent class clamped to 0; show the resulting MPE image.
mpe_ivs_val, mpe_latent_val = sess.run(
    [mpe_ivs, mpe_latent],
    feed_dict={ivs: np.full((1, 14 * 14), -1, dtype=int),
               latent: [[0]]})
show_mnist(mpe_ivs_val)
print(mpe_latent_val)

### MPE for 1

In [None]:
# Pixels fed -1, latent class clamped to 1; show the resulting MPE image.
mpe_ivs_val, mpe_latent_val = sess.run(
    [mpe_ivs, mpe_latent],
    feed_dict={ivs: np.full((1, 14 * 14), -1, dtype=int),
               latent: [[1]]})
show_mnist(mpe_ivs_val)
print(mpe_latent_val)

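### Classification of Training Data

Each image is classified by the MPE state of the latent indicators, with the latent fed as -1. The cell reports the fraction of predictions that match the true label.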

In [None]:
# Predict each training image's class as the MPE state of the latent IVs, feeding
# the latent as -1 (as in the unconstrained MPE cell above).
# Fetch the tensor itself, not a one-element list, so the result is a plain array.
mpe_latent_val = sess.run(mpe_latent,
                          feed_dict={ivs: train_set,
                                     latent: np.full((train_set.shape[0], 1), -1, dtype=int)})
# Flatten both sides so the comparison is element-wise whether the prediction comes
# back as (N,) or (N, 1). Comparing (1, N)-shaped data against the (N, 1) labels
# would silently broadcast to an (N, N) matrix and give a meaningless score.
result = np.ravel(mpe_latent_val) == np.ravel(train_labels)
np.mean(result)

### Classification of Test Data

In [None]:
# Test data: all "0" images followed by all "1" images, with the class index as a
# float (N, 1) column. NOTE: this rebinds `test_labels` (previously the raw MNIST labels).
test_set = np.concatenate([test_images_0, test_images_1], axis=0)
test_labels = np.concatenate(
    [np.zeros(test_images_0.shape[0]), np.ones(test_images_1.shape[0])]
).reshape(-1, 1)

In [None]:
# Predict each test image's class as the MPE state of the latent IVs, feeding
# the latent as -1 (as in the unconstrained MPE cell above).
# Fetch the tensor itself, not a one-element list, so the result is a plain array.
mpe_latent_val = sess.run(mpe_latent,
                          feed_dict={ivs: test_set,
                                     latent: np.full((test_set.shape[0], 1), -1, dtype=int)})
# Flatten both sides so the comparison is element-wise whether the prediction comes
# back as (N,) or (N, 1). Comparing (1, N)-shaped data against the (N, 1) labels
# would silently broadcast to an (N, N) matrix and give a meaningless score.
result = np.ravel(mpe_latent_val) == np.ravel(test_labels)
np.mean(result)