# Sampling neural network
Here we develop a neural network that learns by Gibbs sampling from the posterior distribution over its parameters.

## Math

We have a Multi-Layer Perceptron (MLP) whose parameters are its weights. We use the notation of [1](http://papers.nips.cc/paper/5269-expectation-backpropagation-parameter-free-training-of-multilayer-neural-networks-with-continuous-or-discrete-weights), and refer to the weight from unit $i$ in layer $l-1$ to unit $j$ in layer $l$ as $W_{ijl}$.

Our MLP defines a deterministic function $f(X, W)$, where $X$ is an array of input data with shape $(N_{samples}, N_{dims})$, and $W$ is a list of weight matrices, with $W_l$ of shape $(N_{units(l-1)}, N_{units(l)})$ (this may later be generalized to conv-nets, etc.).

For classification into one of $C$ categories, we restrict $f(X_n, W)$ to output a discrete distribution over $\{1, \dots, C\}$. The network then defines a distribution over the labels $Y$ given $X$ and $W$:

$$p(Y|X, W) = \prod_{n=1}^{N_{samples}}\text{Categorical}\big(f(X_n, W)\big)(Y_n)$$
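As a concrete illustration of $f$ and $p(Y|X,W)$, here is a minimal NumPy sketch. It assumes a single sigmoid hidden layer and a softmax output with no biases (the demo below fixes its biases at zero); the names `mlp_forward` and `log_p_y_given_xw` are hypothetical. The likelihood is evaluated in log form by picking out $f(X_n, W)_{Y_n}$ for each sample.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax_rows(z):
    # Subtract each row's max before exponentiating, for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def mlp_forward(X, W):
    """f(X, W): X has shape (N_samples, N_dims); W = [W_1, W_2], W_l of shape
    (N_units(l-1), N_units(l)). Returns class probabilities, shape (N_samples, C)."""
    hidden = sigmoid(X @ W[0])
    return softmax_rows(hidden @ W[1])

def log_p_y_given_xw(X, Y, W):
    """log p(Y|X,W) = sum_n log f(X_n, W)_{Y_n}, for integer labels Y."""
    probs = mlp_forward(X, W)
    return np.sum(np.log(probs[np.arange(len(Y)), Y]))

rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(5, 10)).astype(float)   # 5 binary input vectors
Y = rng.integers(0, 4, size=5)                        # labels in {0, ..., 3}
W = [rng.choice([-1.0, 0.0, 1.0], size=(10, 10)),
     rng.choice([-1.0, 0.0, 1.0], size=(10, 4))]
ll = log_p_y_given_xw(X, Y, W)
```

With all-zero weights the output is uniform over the $C$ classes, so the log-likelihood is exactly $N_{samples}\log(1/C)$, which makes a handy sanity check.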

The goal of training is to find $W$ that maximizes $p(W|X,Y)$. Applying Bayes' rule, and assuming the prior over weights does not depend on the inputs ($p(W|X) = p(W)$):

$$
\begin{align}
p(W|X,Y)&=\frac{p(Y|X,W)\,p(W)}{p(Y|X)} \\
&\propto p(W)\,p(Y|X,W) \\
p(W)\,p(Y|X,W)&=p(W)\prod_{n=1}^{N_{samples}}\text{Categorical}\big(f(X_n, W)\big)(Y_n) \\
&=p(W)\prod_{n=1}^{N_{samples}}f(X_n, W)_{Y_n} \\
&\equiv L(W|X,Y)
\end{align}
$$

For numerical stability (a product of many probabilities quickly underflows to zero in floating point), we work with the log-likelihood:

$$
\log L(W|X,Y) = \log p(W)+\sum_{n=1}^{N_{samples}}\log f(X_n,W)_{Y_n}
$$
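To see why the log form matters, consider 2000 samples that each have probability 0.5 under the model. The numbers here are purely illustrative: the product underflows to exactly zero in double precision, while the sum of logs is an ordinary finite number.

```python
import numpy as np

p = np.full(2000, 0.5)               # per-sample likelihoods f(X_n, W)_{Y_n}
likelihood = np.prod(p)              # 0.5**2000 is far below the smallest double, so this is 0.0
log_likelihood = np.sum(np.log(p))   # 2000 * log(0.5), roughly -1386.29
print(likelihood, log_likelihood)
```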


Let's make a few assumptions to simplify things:
- Each weight $W_{ijl}$ has an independent prior, so $p(W)=\prod_{ijl}p(W_{ijl})$.
- Weights are discrete: each one takes one of $K$ possible values.
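Under these two assumptions, $\log p(W)$ is just a sum of per-weight lookups into a table of $K$ log-probabilities. A minimal sketch, assuming candidate values $(-1, 0, 1)$ and a hypothetical sparsity-encouraging prior that puts extra mass on 0 (the demo below uses a uniform prior instead). Changing one weight changes only that weight's term.

```python
import numpy as np

possible_ws = np.array([-1.0, 0.0, 1.0])   # candidate values c_1..c_K (must be sorted)
prior_probs = np.array([0.25, 0.5, 0.25])  # hypothetical prior p(W_ijl = c_k), favoring 0
log_prior_table = np.log(prior_probs)

def log_prior(W):
    """log p(W) = sum over every weight of log p(W_ijl), for a list of weight matrices W.

    Assumes every weight is exactly one of possible_ws.
    """
    total = 0.0
    for W_l in W:
        k = np.searchsorted(possible_ws, W_l)  # index of each weight's value in possible_ws
        total += log_prior_table[k].sum()
    return total

W = [np.zeros((2, 2)), np.array([[1.0], [-1.0]])]
before = log_prior(W)
W[0][0, 0] = 1.0  # change a single weight from 0 to 1
after = log_prior(W)
# Only that weight's term changes: after - before == log p(1) - log p(0)
```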

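Now we want to find $W$ that maximizes $p(W|X,Y)$. Rather than searching for the maximum directly, we can draw samples from $p(W|X,Y)$ with Gibbs sampling: we visit each weight in turn and resample it from its distribution conditioned on the data and all the other weights. Consider a weight index $\alpha = (i,j,l)$ and the set of possible weight values $c_k$, $k \in 1, \dots, K$. Write $W_{\setminus\alpha}$ for all weights except $W_\alpha$, and $W^{(\alpha \leftarrow c_k)}$ for the full set of weights with $W_\alpha$ set to $c_k$. Then

$$
\begin{align}
p(W_{\alpha}=c_k|W_{\setminus\alpha}, X, Y) &= \frac{L(W^{(\alpha \leftarrow c_k)}|X, Y)}{\sum_{k'=1}^{K} L(W^{(\alpha \leftarrow c_{k'})}|X, Y)} \\
&= \text{softmax}\Big(\big[\log L(W^{(\alpha \leftarrow c_{k'})}|X, Y)\big]_{k'=1}^{K}\Big)_k \\
&= \text{softmax}\Big(\big[\log p(W^{(\alpha \leftarrow c_{k'})})+\sum_{n=1}^{N_{samples}}\log f(X_n,W^{(\alpha \leftarrow c_{k'})})_{Y_n}\big]_{k'=1}^{K}\Big)_k
\end{align}
$$

Because the prior factorizes, $\log p(W^{(\alpha \leftarrow c_k)}) = \log p(W_\alpha = c_k) + \sum_{\beta \neq \alpha} \log p(W_\beta)$, and the second term is the same for every $k$. Softmax is unchanged by adding the same constant to all of its inputs, so that term drops out. This gives the Gibbs sampling update:

$$
W_{\alpha} \sim \text{Categorical}\Big(\text{softmax}\Big(\big[\log p(W_{\alpha}=c_k)+\sum_{n=1}^{N_{samples}}\log f(X_n,W^{(\alpha \leftarrow c_k)})_{Y_n}\big]_{k=1}^{K}\Big)\Big)
$$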
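The update can be sketched in NumPy as follows. This is a minimal, self-contained illustration, assuming the same bias-free sigmoid-hidden / softmax-output network as the demo and a caller-supplied table of log prior probabilities; `conditional_dist` and `gibbs_update_weight` are hypothetical helpers, not part of the original code. For each candidate value it sets the weight, scores the data, normalizes the scores with a shifted softmax, and samples.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def log_f(X, W):
    """Row-wise log of the softmax output, log f(X, W), shape (N_samples, N_categories)."""
    z = sigmoid(X @ W[0]) @ W[1]
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def conditional_dist(X, Y, W, alpha, candidates, log_prior_table):
    """p(W_alpha = c_k | all other weights, X, Y) for every candidate c_k."""
    i, j, l = alpha
    old = W[l][i, j]
    rows = np.arange(len(Y))
    scores = np.empty(len(candidates))
    for k, c_k in enumerate(candidates):
        W[l][i, j] = c_k  # W^(alpha <- c_k)
        scores[k] = log_prior_table[k] + log_f(X, W)[rows, Y].sum()
    W[l][i, j] = old  # leave W as we found it
    scores -= scores.max()  # softmax is unchanged by this shift; it just avoids overflow in exp
    probs = np.exp(scores)
    return probs / probs.sum()

def gibbs_update_weight(X, Y, W, alpha, candidates, log_prior_table, rng):
    """Resample W_alpha in place from its conditional distribution; return that distribution."""
    probs = conditional_dist(X, Y, W, alpha, candidates, log_prior_table)
    i, j, l = alpha
    W[l][i, j] = rng.choice(candidates, p=probs)
    return probs

# Example: one update of a hidden-to-output weight (layer index 1) on random binary data
rng = np.random.default_rng(0)
candidates = np.array([-1.0, 0.0, 1.0])
log_prior_table = np.log(np.full(3, 1.0 / 3))  # uniform prior over candidates
X = rng.integers(0, 2, size=(50, 10)).astype(float)
Y = rng.integers(0, 4, size=50)
W = [np.zeros((10, 10)), np.zeros((10, 4))]
probs = gibbs_update_weight(X, Y, W, (2, 3, 1), candidates, log_prior_table, rng)
```

Only `log_prior_table[k]` for the weight being updated enters the score, matching the update equation above; adding any constant to the whole table leaves the resulting distribution unchanged.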

## Demo
To demo this concept, we design a simple 3-layer (input, hidden, output) network with discrete-valued weights and train it on a trivial synthetic dataset consisting of clustered binary vectors.
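The demo loads its data with `get_synthetic_clusters_dataset` from the author's own `utils` package. For readers without that package, here is a hypothetical stand-in that follows the description in the demo's code comments: each label gets a random prototype binary vector, and each sample copies its label's prototype with every bit flipped independently with probability `flip_noise`. The function `make_synthetic_clusters` and its signature are assumptions, not the original utility.

```python
import numpy as np

def make_synthetic_clusters(n_dims, n_training, n_test, n_clusters, flip_noise, rng):
    """Hypothetical generator of clustered binary vectors (not the author's utility).

    Each label gets a random binary prototype; each sample copies its label's prototype
    and flips every bit independently with probability flip_noise.
    """
    prototypes = rng.integers(0, 2, size=(n_clusters, n_dims))  # shared by train and test sets

    def sample(n):
        y = rng.integers(0, n_clusters, size=n)
        flips = (rng.random((n, n_dims)) < flip_noise).astype(prototypes.dtype)
        x = prototypes[y] ^ flips  # XOR with 1 flips a bit
        return x.astype(float), y

    x_tr, y_tr = sample(n_training)
    x_ts, y_ts = sample(n_test)
    return x_tr, y_tr, x_ts, y_ts

rng = np.random.default_rng(3525423)
x_tr, y_tr, x_ts, y_ts = make_synthetic_clusters(10, 500, 500, 4, 0.1, rng)

# Sanity check: with flip_noise=0, all samples with the same label are identical
x0, y0, _, _ = make_synthetic_clusters(10, 100, 10, 4, 0.0, rng)
```

The prototypes are drawn once and shared by the training and test sets; drawing them separately for each set would make the test labels meaningless.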

```python
from general.mymath import sigm, softmax

n_training_samples = 500           # Number of samples in training set
n_test_samples = 500               # Number of samples in test set
n_passes = 20                      # Number of Gibbs-sampling passes to make through all the weights
n_input = 10                       # Input dimension of dataset / number of units in input layer
n_hidden = 10                      # Number of units in hidden layer
n_categories = 4                   # Number of categories in dataset
possible_ws = (-1, 0, 1)           # Possible values for W
test_every = 1                     # Test every <X> passes
minibatch_size = None              # None for full-batch
hidden_activation_function = sigm  # Nonlinear activation function for hidden layer
seed = 3525423                     # Random seed
```

```python
# Get Dataset
# The Synthetic Clusters Dataset consists of binary input vectors and integer labels.  Each label
# has an associated binary vector.  Samples are generated by taking an (input_vector, label) pair and corrupting
# the input vector by randomly flipping some of its bits (see the flip_noise parameter).
from utils.datasets.synthetic_clusters import get_synthetic_clusters_dataset
dataset = get_synthetic_clusters_dataset(n_dims=n_input, n_training=n_training_samples, n_test=n_test_samples, n_clusters=n_categories, flip_noise=0.1)
print(dataset)
x_tr, y_tr, x_ts, y_ts = dataset.xyxy
```

```python
import numpy as np
from utils.benchmarks.train_and_test import percent_argmax_correct

# Define network.  Biases are fixed at zero; only the weights are sampled.
b_h = np.zeros(n_hidden)
b_o = np.zeros(n_categories)
w_ih = np.zeros((n_input, n_hidden))
w_ho = np.zeros((n_hidden, n_categories))
w = [w_ih, w_ho]  # References, not copies: writing w[l][i, j] updates w_ih / w_ho in place
p_y_given_wx = lambda x: softmax(hidden_activation_function(x.dot(w_ih)+b_h).dot(w_ho)+b_o, axis=1)
rng = np.random.RandomState(seed)

# Now, train the network on the dataset.
n_samples_in_minibatch = minibatch_size if minibatch_size is not None else n_training_samples
all_indices = [(i, j, 0) for i, j in np.ndindex(n_input, n_hidden)] + [(i, j, 1) for i, j in np.ndindex(n_hidden, n_categories)]

for t in range(n_passes):
    if t % test_every == 0:
        score = percent_argmax_correct(p_y_given_wx(x_ts), y_ts)
        print('Pass %s of %s.  Score: %.2f%%' % (t, n_passes, score))
    for (i, j, l) in all_indices:
        # Full batch, or a fresh random minibatch for every weight update
        ixs = slice(None) if minibatch_size is None else rng.choice(n_training_samples, size=minibatch_size, replace=False)
        log_likelihoods = np.empty((len(possible_ws), n_samples_in_minibatch))
        for k, c_k in enumerate(possible_ws):
            w[l][i, j] = c_k  # Temporarily set W_ijl = c_k
            p_y = p_y_given_wx(x_tr[ixs])[np.arange(n_samples_in_minibatch), y_tr[ixs]]
            log_likelihoods[k, :] = np.log(p_y)
        # Conditional distribution over the candidate values.  The prior over w is ignored (i.e. uniform) for now.
        w_ijl_dist = softmax(np.sum(log_likelihoods, axis=1))
        w[l][i, j] = rng.choice(possible_ws, p=w_ijl_dist)

# Report the score after the final pass, which the loop above never tests
score = percent_argmax_correct(p_y_given_wx(x_ts), y_ts)
print('Final score after %s passes: %.2f%%' % (n_passes, score))
```


## Conclusions

We can see that training works. With `n_categories` (passed to the dataset as `n_clusters`) set to 4, the expected score from random guessing is 25%. Once it has converged, our network classifies around 96% of test samples correctly on average. It converges within about 3 passes through the data (which says more about the simplicity of the dataset than about our network).

Note that:
- For simplicity, we do not learn biases. It is also not clear whether biases should be discretized in the same way as weights.
- We have assumed a uniform prior over possible weight values. There may be some advantage to designing a prior that encourages sparse weights.
- We could likely get a better score by averaging `p_y_given_wx(x_ts)` over passes `t` (that is, over posterior samples), but for simplicity we do not.
- The code is extremely inefficient: not just because it loops in Python, but also because it recomputes the forward pass of the entire network for every candidate value each time a weight is updated.
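The last point can be made concrete. When an input-to-hidden weight $W_{ij0}$ changes from its current value to $c_k$, only column $j$ of the hidden pre-activations changes, by $X_{:,i}(c_k - W_{ij0})$; when a hidden-to-output weight changes, the hidden activations do not change at all. The sketch below covers the first case. It assumes the demo's bias-free sigmoid/softmax network, and `candidate_log_likelihoods_input_weight` is a hypothetical helper. It caches the hidden pre-activations and the base output logits, updates one column per candidate, and is checked against a full recomputation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def log_softmax_rows(z):
    # Numerically stable row-wise log softmax
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def full_log_likelihood(X, Y, w_ih, w_ho):
    """sum_n log f(X_n, W)_{Y_n}, recomputing the whole forward pass (as the demo does)."""
    lp = log_softmax_rows(sigmoid(X @ w_ih) @ w_ho)
    return lp[np.arange(len(Y)), Y].sum()

def candidate_log_likelihoods_input_weight(X, Y, w_ih, w_ho, h_pre, i, j, candidates):
    """Log-likelihood for each candidate value of w_ih[i, j], given cached h_pre = X @ w_ih.

    Changing w_ih[i, j] only affects hidden unit j, so only that column is recomputed.
    """
    h = sigmoid(h_pre)
    base_logits = h @ w_ho  # computed once per call, shared by all candidates
    rows = np.arange(len(Y))
    out = np.empty(len(candidates))
    for k, c_k in enumerate(candidates):
        h_j = sigmoid(h_pre[:, j] + X[:, i] * (c_k - w_ih[i, j]))  # new activations of unit j
        logits = base_logits + np.outer(h_j - h[:, j], w_ho[j])    # swap in unit j's new contribution
        out[k] = log_softmax_rows(logits)[rows, Y].sum()
    return out

# Check against full recomputation on random data
rng = np.random.default_rng(1)
X = rng.integers(0, 2, size=(30, 10)).astype(float)
Y = rng.integers(0, 4, size=30)
w_ih = rng.choice([-1.0, 0.0, 1.0], size=(10, 10))
w_ho = rng.choice([-1.0, 0.0, 1.0], size=(10, 4))
candidates = [-1.0, 0.0, 1.0]
fast = candidate_log_likelihoods_input_weight(X, Y, w_ih, w_ho, X @ w_ih, 2, 5, candidates)
slow = []
for c_k in candidates:
    w_new = w_ih.copy()
    w_new[2, 5] = c_k
    slow.append(full_log_likelihood(X, Y, w_new, w_ho))
```

After the new value is sampled, a caller would apply the same column update to the cached `h_pre[:, j]` so the cache stays consistent with the weights.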
 