In [None]:
import numpy as np
from time import time
from theano import tensor as T
from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.tensor.shared_randomstreams import RandomStreams
from theano import function, shared
from theano.tensor.extra_ops import repeat, to_one_hot
from sklearn.preprocessing import OneHotEncoder
from matplotlib import pyplot as plt
%matplotlib inline

### Study the processes and variables needed for theano implementation of variational inference on the probe model
Here I built and tested most of the machinery needed to implement variational inference.

#### Dimensions

In [None]:
##global dimensions shared by the cells below
floatX = 'float32'   ##dtype used when allocating / casting numpy arrays
sqrt_D = 8           ##side length of the square pixel grid
D = sqrt_D ** 2      ##number of pixels
K = 2                ##number of objects
N = 200              ##number of windows (data points)
M = 10               ##number of object map samples to generate for calculating object responsibility

#### Generate samples of object maps from a posterior over object maps

In [None]:
##the variational posterior over object maps Z
_Q_Z = T.matrix('Q_Z') ##(K,D)
_M = T.scalar('M',dtype='int32')

##a theano random number generator
rng = MRG_RandomStreams(use_cuda = True)

##sample M maps of Z from the posterior Q_Z
##  Q_Z.T           (D x K)    one row of object probabilities per pixel
##  repeat          (D*M x K)  each pixel's row duplicated M times in a row
##  multinomial     (D*M x K)  a draw for every row
##  reshape         (D x M x K)
##  dimshuffle      (M x K x D)
_n_objects, _n_pixels = _Q_Z.shape[0], _Q_Z.shape[1]
_pvals_tiled = repeat(_Q_Z.T, _M, axis=0)
_draws = rng.multinomial(pvals = _pvals_tiled)
_Z_samples = _draws.reshape((_n_pixels, _M, _n_objects)).dimshuffle((1,2,0))

##functionalize
Z_sample_func = function([_Q_Z,_M],outputs=_Z_samples)

In [None]:
##to test, first build a posterior whose columns come from a dirichlet distribution
alpha_0 = 1.1
Q_Z = np.zeros((K,D), dtype=floatX)
probs = np.random.dirichlet([alpha_0]*K,)
for d in range(D): #np.random.permutation(D):
    ##redraw the object probabilities every 64 pixels (including at d = 0)
    if d % 64 == 0:
        probs = np.random.dirichlet([alpha_0]*K,)
    Q_Z[:,d] = probs
##column sums over the K objects
print(Q_Z.sum(axis=0))

In [None]:
##this shows how the tensor manipulation works: the repeat -> reshape -> dimshuffle
##chain applied to Q_Z on its own, with no random draw in between
_tiled = repeat(_Q_Z.T,_M,axis=0)                              ##(D*M x K)
_tiled = _tiled.reshape((_Q_Z.shape[1],_M,_Q_Z.shape[0]))      ##(D x M x K)
r_shuff = function([_Q_Z, _M], outputs = _tiled.dimshuffle((1,2,0)))   ##(M x K x D)

In [None]:
##this will help visualize
def from_one_hot(Z,axis=0):
    '''
    Collapse a one-hot encoding into integer labels.

    Z ~ K x D
    Returns the index of the largest entry along `axis`; with the default
    axis=0 the result has length D and result[i] = argmax(Z[:,i]).
    '''
    labels = np.argmax(Z, axis=axis)
    return labels

In [None]:
##generate M sample maps (M x K x D)
Z_samples = Z_sample_func(Q_Z, M)

In [None]:
##check the shape of the sample block, then view the first sampled map as an image
print(Z_samples.shape)
fig = plt.figure(figsize=(5,5))
##reshape needs integer dimensions: use sqrt_D, not the float returned by np.sqrt(D)
plt.pcolor(from_one_hot(Z_samples[0]).reshape(sqrt_D,sqrt_D))

#### The object responsibility matrix (i.e., N x K matrix of object count probabilities, N = #windows, K = #possible object counts)
**TODO:** this computation is far too slow.

In [None]:
##block of sampled object maps
_Z = T.tensor3('Z') ##(M x K x D)

##window index indicator (N X D)
_W = T.matrix('windows')


An expression for an (M,N) matrix of object counts

In [None]:
##(M x K x 1 x D)  sampled maps, with a broadcast axis inserted for the windows
##         N x D   window indicators
##(M x K x N x D)  product, then sum(D)
##(M x K x N)      clip(0,1)
##(M x K x N)      sum(K)
##(M x N)
_Z_bcast = _Z.dimshuffle((0,1,'x',2))
_pix_in_win = T.sum(_Z_bcast*_W,axis=-1)
_O_W = _pix_in_win.clip(0,1).sum(axis=1)

In [None]:
window_object_count_func = function([_Z,_W], outputs=_O_W)

In [None]:
##construct some contiguous windows of varying size
W = np.zeros((N,D),dtype=floatX)
##np.round returns a float, but slice bounds must be integers, so cast explicitly.
##NOTE: with D=64 and N=200 the stride evaluates to 0, so every window starts at
##pixel 0 and only its length grows with n.
win_stride = int(np.round(D/N))
size_factor = 2
for n in range(N):
    W[n,(n*win_stride):(n*win_stride+size_factor*n+1)] = 1

In [None]:
plt.pcolor(W)
plt.title('Windows')
plt.xlabel('pixels')
plt.ylabel('window number')

In [None]:
foo = window_object_count_func(Z_samples.astype(floatX),W)

In [None]:
plt.pcolor(foo)
plt.title('object counts')
plt.xlabel('windows')
plt.ylabel('random samples from Q_Z')

In [None]:
##running mean of the object counts over the first m+1 samples;
##the last row is never filled and is left out of the plot
rolling_mean = np.zeros((M,N))
for m in range(M-1):
    rolling_mean[m,:] = foo[:m+1, :].mean(axis=0)

_=plt.plot(rolling_mean[:-1,:])
plt.title('rolling mean of object counts for %d windows' %(N))
plt.xlabel('number of random samples from Q_Z')
plt.ylabel('mean object count')

Test out one-hot encoding of the object counts (M, N, K)

In [None]:
##one-hot encode a matrix of integer counts; a count of c lands at index c-1
X = T.matrix(dtype='int32')
_X_one_hot = to_one_hot(X.flatten()-1,K).reshape((X.shape[0],X.shape[1],K))
object_count_one_hot_func = function([X],_X_one_hot)

In [None]:
baz = object_count_one_hot_func(foo.astype('int32'))

In [None]:
foo[:,12]

In [None]:
baz[:,12,:]

Having tested the one-hot encoding, we sum across samples and divide to obtain object count probabilities (i.e., the "object responsibility matrix").

In [None]:
_K = T.scalar('objects',dtype='int32')
##one-hot encode every sampled count (count c -> index c-1), then sum over the
##sample axis: per window, how many samples produced each count
_counts_flat = _O_W.astype('int32').flatten()-1
_counts_one_hot = to_one_hot(_counts_flat,_K).reshape((_O_W.shape[0],_O_W.shape[1],_K))
_R_nn = _counts_one_hot.sum(axis=0)

In [None]:
non_norm_resp_func = function([_Z, _W, _K], outputs = _R_nn)

In [None]:
non_norm_resp = non_norm_resp_func(Z_samples.astype(floatX),W,K)

In [None]:
plt.pcolor(non_norm_resp)

In [None]:
_=plt.plot(non_norm_resp.T)

In [None]:
##normalize: divide every row by its own sum
_R = _R_nn / _R_nn.sum(axis=1, keepdims=True)

In [None]:
##object count probabilities function
object_count_prob_func = function([_Z, _W, _K], outputs = _R)

In [None]:
oc = object_count_prob_func(Z_samples.astype(floatX),W,K)

In [None]:
oc.shape

In [None]:
_=plt.plot(oc.T)

In [None]:
##some timing info--how long for one full sweep of calls for each pixel/object pair?
Z_samples = Z_samples.astype(floatX)
report_every = np.round(D/12.)
start = time()
for d in range(D):
    ##progress message roughly a dozen times per sweep
    if np.mod(d, report_every) == 0:
        print('%d pixels remaining' %(D-d))
    for k in range(K):
        _=object_count_prob_func(Z_samples,W,K)
end = time()-start
print(end)

In [None]:
2*26*1000/60./60.

#### Likelihood function and parameter updates

In [None]:
from imagery_psychophysics.src.model_z import noise_grid
from scipy.misc import comb as nCk

In [None]:
def counts(r,d,n):
    '''
    Combinatorial weights used by lkhd.

    Returns an array whose entry m (m = 0 .. min(r,d)) is C(d,m)*C(n-d,r-m):
    the number of ways to pick m items out of d and the other r-m out of the
    remaining n-d. Entries with r-m > n-d are 0.
    '''
    return np.array([nCk(d,m)*nCk(n-d, r-m) for m in range(min(r,d)+1)])

def lkhd(r,d,n,p_on,p_off):
    '''
    Probability of a response of r given d "on" items out of n.

    Sums, over m, the count C(d,m)*C(n-d,r-m) times
        p_on**m * (1-p_on)**(d-m) * p_off**(r-m) * (1-p_off)**(n-d-r+m)
    i.e. m of the d "on" items and r-m of the n-d "off" items contribute to
    the response.

    Only feasible values of m are evaluated (r-m cannot exceed n-d). The
    skipped terms have a zero count, so the result is unchanged whenever it
    was finite, but their negative exponents are no longer computed. That
    used to raise ZeroDivisionError / give 0*inf = NaN when p_off == 1.
    Returns 0.0 if no value of m is feasible.
    '''
    ##smallest m for which the r-m "off" contributions fit into the n-d "off" items
    m_lo = max(0, r-(n-d))
    probs = np.array([(1-p_on)**(d-m) * (p_on)**m * (p_off)**(r-m) * (1-p_off)**(n-d-r+m) for m in range(m_lo, min(r,d)+1)])
    #print probs
    return counts(r,d,n)[m_lo:].dot(probs)

Critical tensor is the likelihoods iterated over a fine grid of noise parameters


In [None]:
##grid density passed to noise_grid for both of its arguments
theta_dns = 25
##paired noise parameters, one (p_on[g], p_off[g]) per grid point
##NOTE(review): noise_grid is imported from the project; presumably returns two equal-length 1-D arrays -- confirm
p_on, p_off = noise_grid(theta_dns,theta_dns)
G = len(p_on)
##likelihood table, filled in a later cell: (G x N x K)
P_theta = np.zeros((G, N, K),dtype=floatX)
##random placeholder responses, one integer in 1..K per window
r = np.random.randint(1,high=K+1, size=(N,))

In [None]:
np.min(p_on), np.max(p_on), np.min(p_off), np.max(p_off)

In [None]:
##fortunately we only need to generate this once.
##P_theta[g,n,k] = lkhd of response r[n] given k+1 objects out of K, under the g-th noise pair
for g,p in enumerate(zip(p_on,p_off)):
    on_g, off_g = p
    for n in range(N):
        for k in range(K):
            P_theta[g,n,k]  = lkhd(r[n],k+1,K, on_g,off_g)

In [None]:
_=plt.plot(np.log(P_theta[-20]).T, '-o')

In [None]:
print np.sum(np.isfinite(np.log(P_theta)))
print G*N*K

This is the simple update rule for the variational log posterior over theta (i.e., the noise parameters). It's understood that this is technically $ln[q(\theta)] - const$

Note also that we probably won't be needing the $ln[q(\theta)]$ output, but we emit it anyway

In [None]:


_P_theta = T.tensor3('P_theta') ##(G x N x K)
_X = T.matrix('dummy') ##N x K ~ this is a stand-in for the "object responsibility matrix" R

##(G x N x K)  log of the likelihood table
##(    N x K)  contracted against X over both of its axes
##(G,)         one value per grid point
##(G x 1)      reshaped to a column so the output is 2-dimensional
_lnP_theta = T.log(_P_theta)
_lnQ_theta_flat = T.tensordot(_lnP_theta,_X, axes=[[1,2], [0,1]],)
_lnQ_theta = _lnQ_theta_flat.reshape((_P_theta.shape[0], 1))

In [None]:
lnQ_theta_update_func = function([_P_theta, _X], outputs = _lnQ_theta)

In [None]:
lnQ_theta_update_func(np.random.random((G,N,K)).astype(floatX), np.random.random((N,K)).astype(floatX)).shape

In [None]:
#select the best noise params
_P_star = _P_theta[T.argmax(_lnQ_theta),:,:]

In [None]:
P_star_update_func = function([_P_theta, _X], outputs = _P_star)

In [None]:
foo = P_star_update_func(P_theta, np.random.random((N,K)).astype(floatX))
print foo.shape

In [None]:
##put the lnQ_theta and lnP_star updates into same handy function
theta_update_func = function([_P_theta,_X], outputs = [_lnQ_theta, _P_star])

In [None]:
lnQ_theta, P_star = theta_update_func(P_theta, np.random.random((N,K)).astype(floatX))

In [None]:
print lnQ_theta.shape ##(G, 1) -- the lnQ_theta expression is reshaped to a column
print P_star.shape  ##(N , K)

#### Prior over object maps
The main quantity of interest here is $\mathbb{E}[\ln \pi]$.

In [None]:
_alpha_0 = T.scalar('alpha_0')
_q_Z = T.matrix('q_Z')  ##K x 1, this is result of summing over pixels in Q_Z matrix

##concentration parameters: q_Z plus the scalar alpha_0 (broadcast across K)
_alpha = _alpha_0 + _q_Z

##psi(alpha_k) - psi(sum_j alpha_j), elementwise over K
_alpha_total = _alpha.sum()
_Eln_pi = T.psi(_alpha) - T.psi(_alpha_total)

In [None]:
Eln_pi_update_func = function([_q_Z, _alpha_0], outputs = _Eln_pi)

In [None]:
alpha_0 = 1.1
q_Z = np.random.dirichlet([alpha_0]*K,).astype(floatX)[:,np.newaxis] ##a fake q_Z

Eln_pi = Eln_pi_update_func(q_Z, alpha_0)
plt.plot(Eln_pi)

In [None]:
Eln_pi.shape

#### Update variational posterior for object maps

This will be the only update that returns a normalized variational posterior.
Uses the responsibility matrices above.


In [None]:
_R_full = T.tensor4('responsibility_tensor') ##K x D x N x K
_lnP_star = T.matrix('lnP_star') ##N x K
_V = T.matrix('prior_penalties') ## K x D

##K x D x N x K
##        N x K (dot over the last two axes)
##K x D         (add V)
##K x D         subtract the per-pixel max, then exp
##K x D         normalize over objects

_data_term = T.tensordot(_R_full, _lnP_star, [[2,3], [0,1]])
_lnQ_Z_nn = _data_term+_V
_col_max = T.max(_lnQ_Z_nn,axis=0)
_Q_Z_nn = T.exp(_lnQ_Z_nn-_col_max)
_col_total = _Q_Z_nn.sum(axis=0)
_Q_Z = _Q_Z_nn / _col_total

In [None]:
np.exp(-88, dtype=floatX)

In [None]:
np.exp(-500)

In [None]:
Q_Z_update_func = function([_R_full, _lnP_star, _V], outputs=[_Q_Z, _Q_Z_nn])

In [None]:
##some timing info--how long for one full sweep of calls for each pixel/object pair?
R_full = np.zeros((K,D,N,K),dtype=floatX)
lnP_star = np.log(P_star).astype(floatX)
V = np.random.random(size=(K,D)).astype(floatX)+alpha_0
start = time()
for k in range(K):
    print('%d objects remaining' %(K-k))
    for d in range(D):
        ##same responsibilities stored in every (k,d) slot; only the cost is of interest here
        R_full[k,d] = object_count_prob_func(Z_samples,W,K)
Q_Z_new, foo = Q_Z_update_func(R_full, lnP_star, V)
end = time()-start
print(end)

In [None]:
foo = _lnQ_Z_nn.eval({_R_full: R_full, _lnP_star: lnP_star, _V: V})

In [None]:
foo

In [None]:
baz = np.exp((foo-np.max(foo,axis=0)))
print baz
# print baz / baz.sum(axis=0)

In [None]:
print Q_Z_new.sum(axis=0)

In [None]:
_=plt.pcolor(Q_Z_new[1,:].reshape(sqrt_D,sqrt_D))

#### Expressions for the ELBO
I suppose this is an approximate ELBO since we are using a $max \approx expectation$ approximation for $\theta$.

In [None]:
##fresh symbolic inputs for the ELBO expressions.
##NOTE: _Eln_pi, _Q_Z, _lnQ_theta and _q_Z were bound to other graph nodes in
##earlier cells; from here on the names refer to the variables created below.
##Functions compiled earlier keep the nodes they were built from.
_Eln_pi = T.matrix('Eln_pi')  ##K x 1
_Q_Z = T.matrix('Q_Z')        ##K x D
_lnQ_theta = T.matrix('lnQ_theta') ##G x 1

##sum of Q_Z over pixels, kept as a column
_q_Z = _Q_Z.sum(axis=1,keepdims=True) #K x 1

In [None]:
np.zeros((K,D)).sum(axis=1,keepdims=True).shape

In [None]:
##tensordot is called without `axes`; presumably this contracts both axes of the
##two matrices, giving a scalar -- confirm against the theano tensordot docs
##NOTE(review): the gammaln terms use _q_Z (pixel sums of Q_Z) directly, without
##adding alpha_0 -- confirm this is the intended expression
_prior_entropy = -(T.tensordot(_q_Z-1, _Eln_pi)-(T.gammaln(_q_Z.sum()) - T.sum(T.gammaln(_q_Z))))
##-sum(Q_Z * ln Q_Z) over all objects and pixels; entries of Q_Z equal to 0 would give 0*log(0)
_posterior_entropy = -T.tensordot(_Q_Z, T.log(_Q_Z)) 

In [None]:
print _prior_entropy.eval({_Q_Z : Q_Z_new, _Eln_pi : Eln_pi})
print _posterior_entropy.eval({_Q_Z : Q_Z_new})

In [None]:
_ELBO = _lnQ_theta.max()  - _posterior_entropy - _prior_entropy

In [None]:
ELBO_update_func = function([_Eln_pi, _Q_Z, _lnQ_theta], outputs=_ELBO)

In [None]:
Eln_pi.shape

In [None]:
lnQ_theta.shape

In [None]:
Q_Z_new.sum(axis=1,keepdims=True).shape

In [None]:
ELBO = ELBO_update_func(Eln_pi, Q_Z_new, lnQ_theta)

In [None]:
ELBO

Well. The sign is right. We'll see if it makes any sense...

#### Simulate data

In [None]:
from imagery_psychophysics.src.stirling_maps import sparse_point_maps as spm

In [None]:
##first some more sensical windows: square windows at several scales, placed on a regular grid
scales = np.array([2, 4, 6, 8])
stride = 2
##half-widths; explicit floor division keeps them integer (needed for slicing)
##regardless of which division semantics are in effect
sizes = scales//2
Windows = []
for sz in sizes:
    scale_count = 0
    for rows in np.arange(sz,sqrt_D,stride,dtype=int):
        for cols in np.arange(sz,sqrt_D,stride,dtype=int):
            one_win = np.zeros((sqrt_D,sqrt_D),dtype=floatX)
            ##square of side 2*sz centred at (rows, cols); slices running past the grid edge are truncated
            one_win[(rows-sz):(rows+sz), (cols-sz):(cols+sz)]=1
            Windows.append(one_win)
            scale_count +=1
    ##number of windows generated at this scale
    print(scale_count)


N = len(Windows)
npairs = 1000
W = np.zeros((N+npairs,D),dtype=floatX)
##single windows, stored last-built-first. Reading the list instead of popping
##from it leaves Windows intact, so the cell can be re-run or inspected afterwards.
for n, one_win in enumerate(reversed(Windows)):
    W[n,:] = one_win.ravel()

##unions of two distinct, randomly chosen single windows
for n in range(N,N+npairs):
    rand_pairs = np.random.permutation(N)[:2]
    W[n,:] = np.clip(W[rand_pairs[0],:]+W[rand_pairs[1],:], 0, 1)

##from here on N counts single windows plus pairs
N = W.shape[0]
print(N)

In [None]:
##construct a test object map
test_object_map = spm(3,3,sqrt_D,cluster_pref = 'random',number_of_clusters = K)
test_object_map.scatter()
test_object_map = np.squeeze(test_object_map.nn_interpolation())

In [None]:
plt.imshow(test_object_map, cmap='Dark2')
plt.imshow(W[-1].reshape((sqrt_D,sqrt_D)).astype('uint8')*255, interpolation='nearest', alpha = .5, cmap=plt.cm.gray, clim=[0,255])

In [None]:
##convert to one_hot encoding
test_Z = np.eye(K)[test_object_map.ravel()-1].T  ##K x D
d = 5
print test_object_map[d,d]
print test_Z.reshape((K,sqrt_D,sqrt_D))[:,d,d]

In [None]:
##get true object counts for each window
object_counts = np.sum(test_Z[:,np.newaxis,:]*W,axis=-1).clip(0,1).sum(axis=0).astype('int')
object_counts[100]

In [None]:
object_counts

In [None]:
##generate some fake responses using fixed theta parameters
p_on = 0.99
p_off = 0.01
r = np.zeros(object_counts.shape[0], dtype = 'int')
for ii,o in enumerate(object_counts):
    ##likelihood of each response 1..K for a window containing o objects
    resp_dist = np.zeros(K)
    for k in range(K):
        resp_dist[k] = lkhd(k+1,o,K,p_on,p_off)
    ##draw the response once, after the whole distribution has been filled in
    ##(the draw used to sit inside the k loop and sampled from partial distributions).
    ##NOTE(review): resp_dist leaves out a response of 0, so it can sum to slightly
    ##less than 1; np.random.multinomial is documented to give the leftover mass to
    ##the last category -- confirm that is acceptable here.
    r[ii]=np.argmax(np.random.multinomial(1,resp_dist))+1


In [None]:
r

In [None]:
np.abs(r-object_counts)

In [None]:
np.mean(np.abs(r-object_counts))

#### Complete loop for variational dynamics

In [None]:
##to test first generate posterior from a dirichlet distribution

alpha_0 = 1.1

theta_plus, theta_minus = .99, .01

def init_Q():
    '''
    Build a random initial (K x D) posterior over object maps.

    Object probabilities are drawn from a symmetric dirichlet with
    concentration alpha_0 and redrawn every 64 pixels; every pixel in
    between receives the current draw as its column.
    '''
    Q_init = np.zeros((K,D), dtype=floatX)
    col_probs = np.random.dirichlet([alpha_0]*K,)
    for pix in range(D):
        if pix % 64 == 0:
            col_probs = np.random.dirichlet([alpha_0]*K,)
        Q_init[:,pix] = col_probs
    return Q_init

def init_Eln_pi():
    '''
    Initial value for Eln_pi: evaluates Eln_pi_update_func on a fake (K x 1)
    q_Z drawn from a symmetric dirichlet with concentration 2.1, using the
    global alpha_0.
    '''
    fake_q_Z = np.random.dirichlet([2.1]*K,).astype(floatX).reshape((K,1))
    return Eln_pi_update_func(fake_q_Z, alpha_0)

def init_P_star():
    '''
    Initial (N x K) likelihood table at the fixed noise parameters
    (theta_plus, theta_minus): entry [n,k] is lkhd of the response r[n]
    given k+1 objects out of K.
    '''
    P_init = np.zeros((N,K),dtype=floatX)
    for n in range(N):
        P_init[n,:] = [lkhd(r[n],k+1,K, theta_plus,theta_minus) for k in range(K)]
    return P_init


R_full = np.zeros((K,D,N,K),dtype=floatX)
V = np.zeros((K,D),dtype=floatX)

So it seems this posterior update is wrong. It is not coordinate ascent: it ascends multiple coordinates at once. Apparently there is no guarantee that this update rule will converge.

In [None]:
def p_Z(lnP_star, Eln_pi):
    '''
    One update of the posterior over object maps.

    For every pixel d and object k, a copy of the global Z_samples is made in
    which pixel d is set to object k in all samples; the resulting object
    count probabilities go into R_full[k,d] and the dot product of that
    pixel's column (first sample) with Eln_pi goes into V[k,d]. The filled
    arrays are then passed to Q_Z_update_func together with lnP_star.

    Side effects: overwrites the module-level R_full and V in place and
    prints progress. Reads the globals Z_samples, W, K and D.

    Returns (Q_Z_new, Q_Z_new_nn, V, R_full).
    '''
    report_every = np.round(D/12.)
    for d in range(D):
        if np.mod(d, report_every) == 0:
            print('%d pixels remaining' %(D-d))
        for k in range(K):
            ##clamp pixel d to object k in every sampled map
            Z_clamped = Z_samples.copy()
            Z_clamped[:,:,d] = 0.
            Z_clamped[:,k, d] = 1.
            R_full[k,d,:,:] = object_count_prob_func(Z_clamped,W,K).astype(floatX)
            V[k,d] = np.dot(Z_clamped[0,:,d], Eln_pi)
    Q_Z_new, Q_Z_new_nn = Q_Z_update_func(R_full, lnP_star, V)
    return Q_Z_new, Q_Z_new_nn, V, R_full



##starting values for the variational loop
Q_Z = init_Q()
Eln_pi = init_Eln_pi()
P_star = init_P_star()

##ELBO bookkeeping; delta_ELBO is the quantity tested by the loop's stopping condition
ELBO = -np.inf
delta_ELBO = np.inf

##likelihood table with a single noise setting (theta_plus, theta_minus): (1 x N x K).
##Same entries as init_P_star computes, with a leading axis of length 1.
P_theta = np.zeros((1,N,K), dtype=floatX)
for n in range(N):
    for k in range(K):
        P_theta[0,n,k]  = lkhd(r[n],k+1,K, theta_plus,theta_minus)

In [None]:
##main variational loop: alternate the Q_Z, theta and pi updates until the
##ELBO stops increasing or max_T iterations have been run
t = 0
max_T = 300
##NaN-filled so that iterations which never run do not appear as zeros when plotted
ELBO = np.full(max_T, np.nan)
##reset the stopping criterion so that re-running this cell starts a fresh loop
delta_ELBO = np.inf
M = 200
while (delta_ELBO > 0) and (t < max_T):
    print('iteration: %d' %(t,))

    ##update the posterior over object maps from fresh samples of the current Q_Z
    Z_samples = Z_sample_func(Q_Z, M).astype(floatX)
    lnP_star = np.log(P_star).astype(floatX)
    Q_Z, Q_Z_nn,V,R_full = p_Z(lnP_star, Eln_pi.astype(floatX))
    print(_lnQ_Z_nn.eval({_R_full: R_full, _lnP_star: lnP_star, _V: V}))
#     print Eln_pi
#     if not np.all(np.isfinite(Q_Z)):
#         print _lnQ_Z_nn.eval({_R_full: R_full, _lnP_star: lnP_star, _V: V})
#         print Q_Z_nn
#         assert False
    ##resample from the updated Q_Z and update the noise parameters and Eln_pi
    Z_new = Z_sample_func(Q_Z, M).astype(floatX)
    R = object_count_prob_func(Z_new, W, K)
    lnQ_theta, P_star = theta_update_func(P_theta, R)
    Eln_pi = Eln_pi_update_func(Q_Z.sum(axis=1,keepdims=True), alpha_0)
    print(Eln_pi)
    ELBO[t] = ELBO_update_func(Eln_pi, Q_Z, lnQ_theta)
    print('============ELBO: %f' %(ELBO[t]))
    print('prior entropy: %f' %(_prior_entropy.eval({_Q_Z : Q_Z, _Eln_pi : Eln_pi})))
    print('posterior entropy: %f' %(_posterior_entropy.eval({_Q_Z : Q_Z})))
    print('goodness of fit: %f' %(np.max(lnQ_theta)))
    ##update the stopping criterion (it used to stay at its initial value forever,
    ##so the loop could only end at max_T). A NaN ELBO also ends the loop.
    if t > 0:
        delta_ELBO = ELBO[t] - ELBO[t-1]
    t += 1

In [None]:
_=plt.plot(lnP_star.T)

In [None]:
plt.plot(ELBO)

In [None]:
from mpl_toolkits.axes_grid1 import ImageGrid
##one posterior map per object, in object order (works for any K, not just 2)
images = [Q_Z[obj,:].reshape((sqrt_D,sqrt_D)) for obj in range(K)]

##view: construct an image grid
fig = plt.figure(1, (15,5))
grid = ImageGrid(fig, 111, # similar to subplot(111)
                nrows_ncols = (1, K+1), # one row: the true map followed by K posterior maps
                axes_pad=0.5, # pad between axes in inch.
                cbar_mode = 'each',
                cbar_pad = .05
                )
##first panel: the object map used to simulate the data
im = grid[0].imshow(test_object_map,cmap='Dark2')
grid[0].cax.colorbar(im)
##remaining panels: posterior probability of each object, on a common [0,1] colour scale
for kk in range(1,K+1):
    im = grid[kk].imshow(images[kk-1], cmap='hot', clim=[0,1])
    grid[kk].cax.colorbar(im)

In [None]:
Q_Z