In [None]:
# %load draw.py
#!/usr/bin/env python

import tensorflow as tf
from tensorflow.examples.tutorials import mnist
import numpy as np
import os
import copy

# Directory that receives the saved .npy results and the checkpoint ("" = current directory).
tf.flags.DEFINE_string("data_dir", "", "")
FLAGS = tf.flags.FLAGS

## MODEL PARAMETERS ##
n_input=600 # length of each state's input vector (columns produced by input_cells)
img_size=n_input # width of the x / x_n placeholders; always equal to n_input here
z_pc_size=50 # number of place-cell latent units (output size of the PC samplers)
z_gc_size=10 # number of grid-cell latent units
train_iters=3000 # outer training iterations; each runs seq_len optimiser steps (10000 noted as an alternative)
learning_rate=1e-3 # Adam learning rate
eps=1e-8 # epsilon for numerical stability (used inside the logs of binary_crossentropy)
state_size=16 # number of discrete states; torus/head-direction code uses width=int(sqrt(state_size)), so keep it a perfect square
batch_size=20 # training minibatch size (one walk is cut into batch_size segments)
beta=2 # weight on E|z| in the place-cell KL term; NOTE(review): presumably the rate of a Laplace prior — confirm
head_dir_on=1 # 1: walking_data fills one-hot head-direction features; any other value leaves them zero

def seq_sort(it_num,b_size,data):
    """Gather the rows of `data` whose index is congruent to `it_num` modulo seq_len.

    `data` (2-D) is viewed as b_size consecutive segments of length
    seq_len = int(n_rows / b_size). When n_rows == seq_len * b_size and
    0 <= it_num < seq_len, row r of the result is data[r*seq_len + it_num],
    i.e. step `it_num` of segment r. Unfilled rows stay zero and the result
    is always a float64 array of shape (b_size, n_cols).
    """
    n_rows,n_cols=data.shape[0],data.shape[1]
    seq_len=int(n_rows/b_size)
    picked=[row for row in range(n_rows) if np.mod(row,seq_len)==it_num]
    out=np.zeros((b_size,n_cols))
    for slot,row in enumerate(picked):
        out[slot,:]=data[row,:]
    return out

def accuracy(states,reconstuction,inp):
    """Percentage of rows whose reconstruction is nearest to the true state's input.

    For each row j of `reconstuction`, the predicted state is the index k of
    the row of `inp` with the smallest squared Euclidean distance. Ties go to
    the lowest index. The result is 100 * (#rows with prediction == true
    state) / #rows, returned as a float.

    Args:
        states: true state index per row, shape (n,) or (n, 1); float values
            such as 3.0 are accepted.
        reconstuction: (n, n_input) reconstructed inputs.
        inp: (n_states, n_input) input pattern for each state.

    Returns:
        Accuracy in percent as a float (0.0 for an empty batch).
    """
    recon=np.asarray(reconstuction,dtype=float)
    prototypes=np.asarray(inp,dtype=float)
    # Flatten so (n,1) and (n,) labels compare element-wise instead of broadcasting to (n,n).
    true_states=np.asarray(states).reshape(-1)
    n_rows=recon.shape[0]
    if n_rows==0:
        return 0.0
    # (n_rows, n_states) matrix of squared distances to every state's input.
    sq_dist=((recon[:,None,:]-prototypes[None,:,:])**2).sum(axis=2)
    predicted=np.argmin(sq_dist,axis=1)
    n_correct=np.count_nonzero(predicted==true_states)
    # Float arithmetic: plain `/` on ints truncates under Python 2.
    return 100.0*n_correct/n_rows

def linear_track(length):
    """Adjacency matrices for a ring of `length` states (state i linked to (i+1) % length).

    Returns (A_both, A_for, A_back), each a float (length, length) array:
      A_for[i, (i+1) % length] = 1   -- forward moves only
      A_back[(i+1) % length, i] = 1  -- backward moves only (the transpose of A_for)
      A_both                         -- both directions (undirected ring)
    """
    nodes=np.arange(length)
    successors=np.mod(nodes+1,length)
    A_for=np.zeros((length,length))
    A_for[nodes,successors]=1
    A_back=np.zeros((length,length))
    A_back[successors,nodes]=1
    A_both=np.zeros((length,length))
    A_both[nodes,successors]=1
    A_both[successors,nodes]=1
    return A_both,A_for,A_back

def square_box(states):
    """Symmetric adjacency of an open width x width grid, width = int(sqrt(states)).

    States are numbered down each column, then across. State s is linked to
    s + width (the same row, next column) when that index exists, and to s - 1
    (the cell above) unless s starts a column. There is no wrap-around.
    """
    width=int(np.sqrt(states))
    A_box=np.zeros((states,states))
    cells=np.arange(states)
    # horizontal (right - left) links
    left=cells[cells+width<states]
    A_box[left,left+width]=1
    A_box[left+width,left]=1
    # vertical (up - down) links
    lower=cells[np.mod(cells,width)!=0]
    A_box[lower,lower-1]=1
    A_box[lower-1,lower]=1
    return A_box

def square_torus(states):
    """Symmetric adjacency of a width x width grid with wrap-around, width = int(sqrt(states)).

    States are numbered down each column, then across. Horizontal links join
    s and s + width, and the last column wraps to the first
    (s <-> s - (width-1)*width). Vertical links join s and s - 1, and the top
    of each column wraps to its bottom (s <-> s + width - 1).
    """
    width=int(np.sqrt(states))
    A_box=np.zeros((states,states))
    cells=np.arange(states)
    # right - left, including the wrap from the last column to the first
    has_next_col=cells+width<states
    src=cells[has_next_col]
    A_box[src,src+width]=1
    A_box[src+width,src]=1
    last_col=cells[~has_next_col]
    wrapped=last_col-(width-1)*width
    A_box[last_col,wrapped]=1
    A_box[wrapped,last_col]=1
    # up - down, including the wrap from the top of a column to its bottom
    not_top=np.mod(cells,width)!=0
    inner=cells[not_top]
    A_box[inner,inner-1]=1
    A_box[inner-1,inner]=1
    top=cells[~not_top]
    A_box[top,top+width-1]=1
    A_box[top+width-1,top]=1
    return A_box

def input_cells(state_size,n_input):
    """Random input pattern per state: a (state_size, n_input) array of U[0, 1) draws.

    Values come from NumPy's global RNG in row-major order, so the random
    stream consumed is the same as drawing one row at a time.
    """
    return np.random.uniform(0,1,(state_size,n_input))

def _head_direction_index(diff,width):
    """Map a state-index step on the width x width torus to a head_dir column.

    States count down each column, then across. Returns 0=down, 1=up,
    2=right, 3=left (wrap-around steps included), or None if diff matches
    none of them.
    """
    if diff==1 or diff==-(width-1):  #down
        return 0
    if diff==-1 or diff==(width-1):  #up
        return 1
    if diff==width or diff==-width*(width-1):  #right
        return 2
    if diff==-width or diff==width*(width-1):  #left
        return 3
    return None

def walking_data(state_size,time_steps,n_input,inp):
    """Simulate a random walk on square_torus(state_size) and return per-step data.

    Args:
        state_size: number of states; the torus side is int(sqrt(state_size)).
        time_steps: number of steps to simulate.
        n_input: width of each input row.
        inp: (state_size, n_input) array; row k is the input observed in state k.

    Returns:
        (context, data_t, data_plus, state, head_dir):
          context   (time_steps, 1): 0 for the first half of every block of
                    batch_size steps, 1 for the second half;
          data_t    (time_steps, n_input): input at step i;
          data_plus (time_steps, n_input): input at step i+1;
          state     (time_steps+1, 1): visited state indices (stored as floats);
          head_dir  (time_steps, 4): one-hot move direction (down, up, right,
                    left) when head_dir_on == 1, otherwise zeros.
    """
    #A_both,A_1,A_2=linear_track(state_size)
    #A_1=square_box(state_size)
    A_1=square_torus(state_size)
    data_t=np.zeros((time_steps,n_input))
    data_plus=np.zeros((time_steps,n_input))
    context=np.zeros((time_steps,1))
    state=np.zeros((time_steps+1,1))
    head_dir=np.zeros((time_steps,4))
    width=int(np.sqrt(state_size))
    # Keep the current state as a Python int for indexing: `state` is a float
    # array, and NumPy >= 1.12 rejects float array indices with IndexError.
    current=int(np.random.choice(state_size)) # start at a random state
    state[0,0]=current
    for i in range(time_steps):
        data_t[i,:]=inp[current,:]
        # Both contexts currently share the same adjacency (A_1); only the
        # label differs: first vs. second half of each batch_size block.
        context[i]=0 if np.mod(i,batch_size)<batch_size//2 else 1
        available=np.where(A_1[current,:]==1)[0]
        nxt=int(np.random.choice(available))
        state[i+1,0]=nxt
        data_plus[i,:]=inp[nxt,:]
        if head_dir_on==1:
            direction=_head_direction_index(nxt-current,width)
            if direction is not None:
                head_dir[i,direction]=1
        current=nxt
    return (context,data_t,data_plus,state,head_dir)

## BUILD MODEL ## 
# Passed as `reuse=` to every tf.variable_scope below. None while the first copy
# of each layer is built (tf.get_variable creates it); the graph section sets it
# to True before the second calls so they reuse the same variables.
DO_SHARE=None # workaround for variable_scope(reuse=True)

x = tf.placeholder(tf.float32,shape=(batch_size,img_size)) # input at time t, (batch_size, img_size)
gc_t = tf.placeholder(tf.float32,shape=(batch_size,z_gc_size)) # previous grid-cell state, (batch_size, z_gc_size); only read by the disabled graph variant
mumu = tf.placeholder(tf.float32,shape=(batch_size,z_pc_size)) # (batch_size, z_pc_size); only read by disabled KL "ATTEMPT 1" code
logsigsig = tf.placeholder(tf.float32,shape=(batch_size,z_pc_size)) # (batch_size, z_pc_size); only read by disabled KL "ATTEMPT 1" code
context = tf.placeholder(tf.float32,shape=(batch_size,1)) # context flag per row (0/1 from walking_data), (batch_size, 1)
head_d = tf.placeholder(tf.float32,shape=(batch_size,4)) # one-hot move direction (down, up, right, left), (batch_size, 4)
x_n = tf.placeholder(tf.float32,shape=(batch_size,img_size)) # input at time t+1, (batch_size, img_size)
e_pc_1=tf.random_normal((batch_size,z_pc_size), mean=0, stddev=1) # reparameterisation noise for sampleQ_PC / sampleQ_PC_1
e_pc_2=tf.random_normal((batch_size,z_pc_size), mean=0, stddev=1) # reparameterisation noise for sampleQ_PC_2
e_gc=tf.random_normal((batch_size,z_gc_size), mean=0, stddev=1) # grid-cell noise (multiplied by 0 in sampleQ_GC)
# Scalar outer-iteration index fed each step. NOTE(review): only consumed by the
# tf.assert_less check in the optimizer setup.
train_iteration=tf.placeholder(tf.int32,shape=())

def linear(x,output_dim):
    """Affine layer x @ w + b using variables "w" and "b" of the current variable scope.

    "w" has shape [in_dim, output_dim], where in_dim is x's static size along
    axis 1. "b" has shape [output_dim] and is initialised to zero. Whether the
    variables are created or reused is decided by the enclosing
    tf.variable_scope.
    """
    in_dim=x.get_shape()[1]
    weights=tf.get_variable("w",[in_dim,output_dim])
    bias=tf.get_variable("b",[output_dim],initializer=tf.constant_initializer(0.0))
    affine=tf.matmul(x,weights)
    return affine+bias
 
## PLACE CELL STUFF
def encode_PC(x):
    """Place-cell encoder: tanh of an affine map of x to z_pc_size units (scope "encoder_pc")."""
    with tf.variable_scope("encoder_pc",reuse=DO_SHARE):
        pre_activation=linear(x,z_pc_size)
    activated=tf.tanh(pre_activation)
    return activated

def sampleQ_PC(h_enc_x):
    """Reparameterised Gaussian sample of the place-cell latent from an encoding.

    mu and logsigma are separate affine maps of h_enc_x (scopes "mu_x2pc" and
    "sigma_x2pc"). Returns (z, mu, logsigma, sigma) with z = mu + sigma * e_pc_1.
    """
    with tf.variable_scope("mu_x2pc",reuse=DO_SHARE):
        mean=linear(h_enc_x,z_pc_size)
    with tf.variable_scope("sigma_x2pc",reuse=DO_SHARE):
        log_std=linear(h_enc_x,z_pc_size)
        std=tf.exp(log_std)
    sample=mean+std*e_pc_1
    return (sample,mean,log_std,std)

def sampleQ_PC_1(h_enc_x,h_dec_gc):
    """Place-cell posterior formed as the product of two diagonal Gaussians.

    One Gaussian comes from the grid-cell decoding (sampleQ_PC_2(h_dec_gc)),
    the other from the input encoding (sampleQ_PC(h_enc_x)). Their normalised
    product is Gaussian with
        precision = 1/sigma_gc^2 + 1/sigma_x^2
        mu        = (mu_gc/sigma_gc^2 + mu_x/sigma_x^2) / precision
    The previous version multiplied the precision-weighted sum by sigma rather
    than sigma^2 (= 1/precision), which gave a wrong mean.

    Returns:
        (z, mu, logsigma, sigma) with z = mu + sigma * e_pc_1.
    """
    _,mu_gc2pc,logsigma_gc2pc,_=sampleQ_PC_2(h_dec_gc)
    _,mu_x2pc,logsigma_x2pc,_=sampleQ_PC(h_enc_x)
    # Precisions 1/sigma^2 computed as exp(-2*logsigma), avoiding 1/square(exp(.)).
    prec_gc=tf.exp(-2*logsigma_gc2pc)
    prec_x=tf.exp(-2*logsigma_x2pc)
    prec=prec_gc+prec_x
    logsigma=-0.5*tf.log(prec)
    sigma=tf.exp(logsigma)
    mu=(mu_gc2pc*prec_gc+mu_x2pc*prec_x)/prec
    return (mu + sigma*e_pc_1, mu, logsigma, sigma)

def sampleQ_PC_2(h_enc):
    """Reparameterised place-cell sample from a grid-cell decoding.

    mu and logsigma are affine maps of h_enc (scopes "mu_gc2pc" and
    "sigma_gc2pc"). Returns (z, mu, logsigma, sigma) with z = mu + sigma * e_pc_2.
    """
    with tf.variable_scope("mu_gc2pc",reuse=DO_SHARE):
        mean=linear(h_enc,z_pc_size)
    with tf.variable_scope("sigma_gc2pc",reuse=DO_SHARE):
        log_std=linear(h_enc,z_pc_size)
        std=tf.exp(log_std)
    sample=mean+std*e_pc_2
    return (sample,mean,log_std,std)

def decode_PC(x):
    """Place-cell decoder: tanh of an affine map of x to img_size units (scope "decoder_pc")."""
    with tf.variable_scope("decoder_pc",reuse=DO_SHARE):
        pre_activation=linear(x,img_size)
    activated=tf.tanh(pre_activation)
    return activated
    
def write_PC(h_dec):
    """Linear read-out from the place-cell decoder to img_size logits (scope "write_pc")."""
    with tf.variable_scope("write_pc",reuse=DO_SHARE):
        logits=linear(h_dec,img_size)
        return logits

## GRID CELL STUFF
def encode_GC(x):
    """Grid-cell encoder: tanh of an affine map of x to z_gc_size units (scope "encoder_gc")."""
    with tf.variable_scope("encoder_gc",reuse=DO_SHARE):
        pre_activation=linear(x,z_gc_size)
    activated=tf.tanh(pre_activation)
    return activated

def sampleQ_GC(h_enc):
    """Grid-cell latent from an encoding (scopes "mu_gc" and "sigma_gc").

    Returns (z, mu, logsigma, sigma). The noise term is multiplied by 0, so z
    equals mu numerically (a deterministic grid-cell state). sigma is still
    computed and returned.
    """
    with tf.variable_scope("mu_gc",reuse=DO_SHARE):
        mean=linear(h_enc,z_gc_size)
    with tf.variable_scope("sigma_gc",reuse=DO_SHARE):
        log_std=linear(h_enc,z_gc_size)
        std=tf.exp(log_std)
    sample=mean + 0*std*e_gc  # noise deliberately switched off
    return (sample,mean,log_std,std)

def decode_GC(x):
    """Grid-cell decoder: tanh of an affine map of x to z_pc_size units (scope "decoder_gc")."""
    with tf.variable_scope("decoder_gc",reuse=DO_SHARE):
        pre_activation=linear(x,z_pc_size)
    activated=tf.tanh(pre_activation)
    return activated
    
def write_GC(h_dec):
    """Linear read-out from the grid-cell decoder to z_pc_size units (scope "write_gc")."""
    with tf.variable_scope("write_gc",reuse=DO_SHARE):
        projected=linear(h_dec,z_pc_size)
        return projected

## STATE VARIABLES ##
# Two-slot lists filled by the graph below: index 0 = place cells inferred from
# x (time t), index 1 = place cells produced from the grid-cell decoding.
mus_pc,logsigmas_pc,sigmas_pc,zs_pc=[0]*2,[0]*2,[0]*2,[0]*2
# Grid-cell counterparts; the active graph uses mu_gc/z_gc directly and does not read these lists.
mus_gc,logsigmas_gc,sigmas_gc,zs_gc=[0],[0],[0],[0]

# Disabled earlier version of the graph, kept as a string literal (never executed).
# It fused x and the fed-back grid state gc_t via sampleQ_PC_1 and used sampleQ_PC_2
# for the t+1 place cells.
"""
## COMPUTATIONAL GRAPH
#place cell t=1
h_pc1_enc=encode_PC(x)
h_dec_gc_t=decode_GC(gc_t)
zs_pc[0],mus_pc[0],logsigmas_pc[0],sigmas_pc[0]=sampleQ_PC_1(h_pc1_enc,h_dec_gc_t)
h_dec_pc=decode_PC(zs_pc[0])
PC1_rec=write_PC(h_dec_pc)
PC1_recons=tf.nn.sigmoid(PC1_rec)
#grid cell
h_gc_enc=encode_GC(tf.concat(1,[context,head_d,mus_pc[0]]))
z_gc,mu_gc,logsigma_gc,sigma_gc=sampleQ_GC(h_gc_enc)
DO_SHARE=True
h_dec_gc=decode_GC(z_gc)
#place cell t=2
zs_pc[1],mus_pc[1],logsigmas_pc[1],sigmas_pc[1]=sampleQ_PC_2(h_dec_gc)
#DO_SHARE=True
h_dec_pc=decode_PC(zs_pc[1])
PC2_rec=write_PC(h_dec_pc)
PC2_recons=tf.nn.sigmoid(PC2_rec)
# z_pc_t+1
h_pc1_enc_2=encode_PC(x_n)
zs_pc_2,mus_pc_2,logsigmas_pc_2,_=sampleQ_PC_1(h_pc1_enc_2,h_dec_gc)
"""

## COMPUTATIONAL GRAPH -- this is where the model is specified, calls above encoders, decoders, samplers, etc
# Statement order matters: every layer is created on its first call while
# DO_SHARE is None, and reused on later calls once DO_SHARE is True.
#place cell t=1
h_pc1_enc=encode_PC(x)
zs_pc[0],mus_pc[0],logsigmas_pc[0],sigmas_pc[0]=sampleQ_PC(h_pc1_enc)
h_dec_pc=decode_PC(zs_pc[0])
PC1_rec=write_PC(h_dec_pc)
PC1_recons=tf.nn.sigmoid(PC1_rec) # reconstruction of x, compared with x in the loss
#grid cell
# Grid-cell input = [context (1) | head direction (4) | PC encoding (z_pc_size)] along axis 1.
# NOTE(review): tf.concat(axis, values) is the pre-TF-1.0 argument order.
h_gc_enc=encode_GC(tf.concat(1,[context,head_d,h_pc1_enc]))
z_gc,mu_gc,logsigma_gc,sigma_gc=sampleQ_GC(h_gc_enc)
h_dec_gc=decode_GC(z_gc)
DO_SHARE=True # from here on, encode/sample/decode calls reuse the variables created above
#place cell t=2
# The grid-cell decoding goes through the SAME x->PC sampler weights (mu_x2pc / sigma_x2pc),
# then the shared PC decoder, to predict the t+1 input.
zs_pc[1],mus_pc[1],logsigmas_pc[1],sigmas_pc[1]=sampleQ_PC(h_dec_gc)
h_dec_pc=decode_PC(zs_pc[1])
PC2_rec=write_PC(h_dec_pc)
PC2_recons=tf.nn.sigmoid(PC2_rec) # prediction of x_n, compared with x_n in the loss

# z_pc_t+1
# Place cells inferred directly from the t+1 input with the shared encoder/sampler;
# used as the target for zs_pc[1].
h_pc1_enc_t_plus=encode_PC(x_n)
zs_pc_t_plus,mus_pc_t_plus,logsigmas_pc_t_plus,sigmas_pc_t_plus=sampleQ_PC(h_pc1_enc_t_plus)





## LOSS FUNCTION ## ~
def binary_crossentropy(t,o):
    """Element-wise binary cross-entropy of predictions o against targets t.

    eps is added inside both logs so that o = 0 or o = 1 stays finite.
    """
    pos_term=t*tf.log(o+eps)
    neg_term=(1.0-t)*tf.log(1.0-o+eps)
    return -(pos_term+neg_term)

def squared_error(t,o):
    """Element-wise squared difference (t - o)^2; reduction is left to the caller."""
    residual=t-o
    return tf.square(residual)

## RECONSTRUCTIONS ERRORS
Lx1=tf.reduce_sum(squared_error(x,PC1_recons),1) # per-example squared error of the t reconstruction
Lx1=tf.reduce_mean(Lx1)
Lx2=tf.reduce_sum(squared_error(x_n,PC2_recons),1) # per-example squared error of the t+1 prediction
Lx2=tf.reduce_mean(Lx2)
## KL PLACE CELLS
# Per latent dimension, KL( N(mu, sigma^2) || p ) with p(z) = (beta/2) exp(-beta|z|) is
#   -log(sigma) - 1/2 - log(2*pi)/2 + log(2/beta) + beta*E|z|,
# where E|z| = sigma*sqrt(2/pi)*exp(-c^2) + mu*erf(c) and c = mu/(sqrt(2)*sigma)
# (mean of a folded normal). Changes from the previous version:
#   * c used 2*sigma instead of sqrt(2)*sigma, so E|z| and its gradient were wrong;
#   * log(2/beta) was subtracted instead of added (invisible while beta == 2);
#   * the constant terms were added once per example instead of once per dimension.
# The last two shift the reported Lz_pc only; they do not change gradients.
kl_terms_pc=[0]*2
for t in range(2):
    mu2_pc=tf.square(mus_pc[t])
    sigma2_pc=tf.square(sigmas_pc[t])
    logsigma_pc=logsigmas_pc[t]
    c=mus_pc[t]/(np.sqrt(2.0)*sigmas_pc[t])
    c_sq=tf.square(c)
    abs_mean=tf.sqrt(2*sigma2_pc/np.pi)*tf.exp(-c_sq)+mus_pc[t]*tf.erf(c) # E|z| under q
    kl_terms_pc[t]=tf.reduce_sum(-logsigma_pc-.5-.5*np.log(2*np.pi)+np.log(2.0/beta)+beta*abs_mean,1) # shape (minibatch,)
KL_pc=tf.add_n(kl_terms_pc) # this is 1xminibatch, corresponding to summing kl_terms from 1:T
#KL_pc=kl_terms_pc[0]
Lz_pc=tf.reduce_mean(KL_pc) # average over minibatches

## KL GRID CELLS
# KL( N(mu, sigma^2) || N(0, 1) ) = 0.5 * sum_d (mu^2 + sigma^2 - 2 log sigma - 1).
# The -1 now sits inside the sum (previously a single -.5 per example); value-only change.
mu2_gc=tf.square(mu_gc)
sigma2_gc=tf.square(sigma_gc)
kl_terms_gc=0.5*tf.reduce_sum(mu2_gc+sigma2_gc-2*logsigma_gc-1,1) # shape (minibatch,)
#KL_gc=tf.add_n(kl_terms_gc)
Lz_gc=tf.reduce_mean(kl_terms_gc) # average over minibatches

#KL beetween zs
# Quantities of the second place-cell slot, read only by the disabled attempts below.
mu2_pc=tf.square(mus_pc[1])
sigma2_pc=tf.square(sigmas_pc[1])
logsigma_pc=logsigmas_pc[1]
"""
#ATTEMPT 1 - treat mu,sigma at t+1 as constant - maybe doesnt work as get differnt z's due to stochasticity?
sigsig=tf.exp(logsigsig)
sigsig2=tf.square(sigsig)
mu2=tf.square(mumu)
KL_between_zs=0.5*tf.reduce_sum(2*logsigsig-2*logsigma_pc+\
                                (mu2+mu2_pc+sigma2_pc-2*mus_pc[1]*mumu)/sigsig2,1)- .5 
Lz_zs=tf.reduce_mean(KL_between_zs)
"""
"""
#ATTEMPT 2 - treat mu,sigma at t+1 as variables
sigsig=tf.exp(logsigmas_pc_2)
sigsig2=tf.square(sigsig)
mu2=tf.square(mus_pc_2)
KL_between_zs=0.5*tf.reduce_sum(2*logsigmas_pc_2-2*logsigma_pc+\
                                (mu2+mu2_pc+sigma2_pc-2*mus_pc[1]*mus_pc_2)/sigsig2,1)- .5 
Lz_zs=tf.reduce_mean(KL_between_zs)
"""

# Active choice: squared distance between the place cells predicted through the
# grid cells and those inferred directly from x_n (both are noisy samples).
Lz_zs=tf.reduce_mean(squared_error(zs_pc[1],zs_pc_t_plus))

## COST
#cost_all=Lx1+Lx2+Lz_pc+Lz_gc+Lz_zs
cost_all=Lz_zs+Lz_gc # grid-cell phase objective
cost_autoenc=Lx1+tf.reduce_mean(kl_terms_pc[0]) # place-cell autoencoder objective


# Look up the already-created grid-cell variables (reuse=True) so the grid-cell
# objective can be restricted to them: encoder, mu/sigma samplers and decoder,
# weight "w" and bias "b" of each.
with tf.variable_scope("encoder_gc",reuse=True):
    a=tf.get_variable("w")
    b=tf.get_variable("b")
with tf.variable_scope("mu_gc",reuse=True):
    a1=tf.get_variable("w")
    b1=tf.get_variable("b") 
with tf.variable_scope("sigma_gc",reuse=True):
    a2=tf.get_variable("w")
    b2=tf.get_variable("b") 
with tf.variable_scope("decoder_gc",reuse=True):
    a3=tf.get_variable("w")
    b3=tf.get_variable("b")
# Disabled lookups (string literal, never executed). NOTE(review): the gc2pc
# scopes are only created by sampleQ_PC_2, which the active graph never calls.
"""
with tf.variable_scope("mu_gc2pc",reuse=True):
    a4=tf.get_variable("w")
    b4=tf.get_variable("b") 
with tf.variable_scope("sigma_gc2pc",reuse=True):
    a5=tf.get_variable("w")
    b5=tf.get_variable("b") 
#with tf.variable_scope("write_gc",reuse=True):
   # a6=tf.get_variable("w")
    #b6=tf.get_variable("b") 
#tf.Variable(a, trainable=False)
"""
#my_var_list={a,b,a1,b1,a2,b2,a3,b3,a4,b4,a5,b5}#,a6,b6}
# Set of grid-cell variables; used as the var_list for cost_all in the optimizer setup below.
my_var_list={a,b,a1,b1,a2,b2,a3,b3}#,a4,b4,a5,b5}#,a6,b6}
    
    
#tf.Variable([tf.get_variable("decoder_pc/w")], trainable=False)
## OPTIMISATION ##
# Two-phase schedule, chosen per outer iteration inside the training loop:
#   * i <  AUTOENC_ITERS : minimise cost_autoenc over all trainable variables
#   * i >= AUTOENC_ITERS : minimise cost_all over the grid-cell variables only
# The previous version selected the phase with `if tf.assert_less(train_iteration,500):`
# at graph-construction time. tf.assert_less returns an Operation, which is always
# truthy in Python, so the else-branch was never built and cost_autoenc was
# optimised for the whole run.
AUTOENC_ITERS=500 # outer iterations spent training only the place-cell autoencoder

def _clipped_train_op(opt,cost,var_list=None):
    """Return an op that applies `opt` to `cost`, clipping each gradient to norm 5.

    var_list=None means every trainable variable (TensorFlow's default).
    Variables with no gradient path to `cost` keep their None gradient and
    are skipped by apply_gradients.
    """
    grads=opt.compute_gradients(cost,var_list=var_list)
    clipped=[(tf.clip_by_norm(g,5),v) if g is not None else (g,v) for g,v in grads]
    return opt.apply_gradients(clipped)

# Separate Adam instances so each phase keeps its own moment estimates and slot variables.
optimizer_autoenc=tf.train.AdamOptimizer(learning_rate,beta1=0.5,name="Adam_autoenc")
optimizer_all=tf.train.AdamOptimizer(learning_rate,beta1=0.5,name="Adam_all")
train_op_autoenc=_clipped_train_op(optimizer_autoenc,cost_autoenc)
# my_var_list is a set; sort it so the variable order is deterministic.
train_op_all=_clipped_train_op(optimizer_all,cost_all,
                               var_list=sorted(my_var_list,key=lambda v:v.name))
train_op=train_op_autoenc # op for the first phase; reassigned per iteration below


## RUN TRAINING ##
loss_fetches=[Lx1,Lx2,Lz_pc,Lz_gc,Lz_zs,z_gc]
Lx1s=[0]*train_iters
Lx2s=[0]*train_iters
Lz_pcs=[0]*train_iters
Lz_gcs=[0]*train_iters
Lzs=[0]*train_iters

sess=tf.InteractiveSession()

saver = tf.train.Saver() # saves variables learned during training (created after both optimisers)
tf.initialize_all_variables().run()
#saver.restore(sess, "infer_place_grid_v1.ckpt") # to restore from model, uncomment this line
inp_1=input_cells(state_size,n_input)
inp_2=input_cells(state_size,n_input)
seq_len=10 # walk steps per batch row; each outer iteration runs seq_len optimiser steps
for i in range(train_iters):
    # Hook for switching environments: both branches currently pick inp_1,
    # so inp_2 is generated but never shown to the model.
    if np.random.rand(1) <= 0.5:
        inp=inp_1
    else:
        inp=inp_1

    train_op=train_op_autoenc if i<AUTOENC_ITERS else train_op_all
    if i==AUTOENC_ITERS:
        print("iter=%d : switching to grid-cell training (cost_all)" % (i*seq_len))

    gcs=np.zeros((batch_size,z_gc_size)) # grid-cell state fed back through gc_t

    # One walk of seq_len*batch_size steps, cut into batch_size segments.
    (contexxy,data_t,data_plus,states,head_dir)=walking_data(state_size,seq_len*batch_size,n_input,inp)

    for ii in range(seq_len):
        # Row a of every minibatch is step ii of walk segment a.
        x_t=seq_sort(ii,batch_size,data_t)
        x_p=seq_sort(ii,batch_size,data_plus)
        cx=seq_sort(ii,batch_size,contexxy)
        hd=seq_sort(ii,batch_size,head_dir)

        feed_dict={x:x_t,x_n:x_p,context:cx,head_d:hd,gc_t:gcs,train_iteration:i}
        results=sess.run(loss_fetches+[train_op],feed_dict)
        # Only the last inner step's losses are kept for iteration i.
        Lx1s[i],Lx2s[i],Lz_pcs[i],Lz_gcs[i],Lzs[i],gcs,_=results

    if i%100==0:
        print("iter=%d : Lx1: %f Lx2: %f Lzpc: %f Lzgc: %f Lzs : %f" % (i*seq_len,Lx1s[i],Lx2s[i],Lz_pcs[i],Lz_gcs[i],Lzs[i]))

## TRAINING FINISHED ##

# Evaluate on the final training minibatch: `ii`, `feed_dict`, `states` and `inp`
# still hold their values from the last loop iteration.
state1=seq_sort(ii,batch_size,states[0:-1]) # true state at time t per batch row
state2=seq_sort(ii,batch_size,states[1:])   # true state at time t+1 per batch row

PC1_reconsructions=np.array(sess.run(PC1_recons,feed_dict))
PC2_reconsructions=np.array(sess.run(PC2_recons,feed_dict))
## test accuracy: % of rows whose reconstruction is closest to the true state's input
correct1=accuracy(state1,PC1_reconsructions,inp)
correct2=accuracy(state2,PC2_reconsructions,inp)

place_cells=sess.run(zs_pc,feed_dict)
place_cells_mus=sess.run(mus_pc,feed_dict)
grid_cells=sess.run(z_gc,feed_dict)
grid_cells_mus=sess.run(mu_gc,feed_dict)

out_file=os.path.join(FLAGS.data_dir,"infer_place_grid_v1_data.npy")
saved_items=[PC1_reconsructions,PC2_reconsructions,Lx1s,Lx2s,Lz_pcs,Lz_gcs,contexxy,inp_1,inp_2]
# The items have different lengths/shapes. Passing the list directly makes NumPy
# build a ragged array implicitly, which NumPy >= 1.24 rejects with ValueError,
# so pack them into an explicit 1-D object array instead.
# Read back with np.load(out_file, allow_pickle=True).
saved=np.empty(len(saved_items),dtype=object)
for k,item in enumerate(saved_items):
    saved[k]=item
np.save(out_file,saved)
print("Outputs saved in file: %s" % out_file)

ckpt_file=os.path.join(FLAGS.data_dir,"infer_place_grid_v1.ckpt")
print("Model saved in file: %s" % saver.save(sess,ckpt_file))

# %.1f rather than %d: the accuracy percentage need not be a whole number.
print("PC1 accuracy=%.1f" % correct1)
print("PC2 accuracy=%.1f" % correct2)

sess.close()

print('Done drawing! Have a nice day! :)')



iter=0 : Lx1: 54.880035 Lx2: 59.166199 Lzpc: 160.260696 Lzgc: 6.763613 Lzs : 0.321085


In [None]:
import matplotlib
import sys
import numpy as np
import copy
import matplotlib.pyplot as plt
%matplotlib inline

#out_file='simple_model_adj_data.npy'
#[PC1,PC2,Lx1s,Lx2s,Lz_pcs,Lz_gcs,contexxy]=np.load(out_file)
#batch_size,img_size=PC2.shape

# Training curves: one value per outer iteration (the losses of its last inner step).
plt.plot(Lx1s,label='Reconstruction Loss Lx : PC1')
plt.plot(Lx2s,label='Reconstruction Loss Lx : PC2')
plt.plot(np.add(Lz_pcs,Lz_gcs),label='Latent Loss Lz')
plt.xlabel('iterations')
plt.legend()

# Place-cell means for the final evaluated minibatch (rows = batch rows).
# Index 0 is inferred from x; index 1 comes from the grid-cell decoding.
place_cell_means=np.array(place_cells_mus)
for k in range(place_cell_means.shape[0]):
    image=plt.matshow(place_cell_means[k,:,:],cmap=plt.cm.gray)
    plt.colorbar(image)
    plt.xlabel('Place cells %d' % (k+1))
    plt.ylabel('State')

# Grid-cell means: first half of the batch rows beside the second half.
# `//` keeps the slice bound an int (`batch_size/2` is a float in Python 3 and cannot slice).
# NOTE(review): with seq_len=10 and batch_size=20, walking_data's context label
# alternates with batch-row parity rather than by halves — confirm this split.
half=batch_size//2
grid_cell_means=np.array(grid_cells_mus)
plt.matshow(np.concatenate((grid_cell_means[0:half,:],grid_cell_means[half:batch_size,:]),axis=1),cmap=plt.cm.gray)
plt.colorbar()
plt.xlabel('Grid cells context 1 - Context 2')
plt.ylabel('State')