# This notebook is designed to give a basic overview of the functionality of our model.

In this package we implement binary decoding of timestamped events. This can be trivially extended to decoding continuous variables. To decode binary variables as described here, the design matrix must contain one row of timestamps for each of the two events. In the descriptor of the design matrix, these two rows should be labelled identically (see the 'dec' variable in desc [the descriptor of the design matrix] and the corresponding rows of the design matrix in the example below).

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
import seaborn
# Nine-colour palette; the event plots in this notebook index it as clrs[kk]
clrs = seaborn.color_palette(n_colors=9)
# Global seaborn styling: 'ticks' style with the font scale set to 2
seaborn.set(style='ticks',font_scale=2)

In [None]:
from mouse_vae import BAE
from mouse_vae import dat_utils

In [None]:
os.listdir(ROOT_PATH)

In [None]:
##load video, design matrix and descriptor of design matrix


##fill in this path
ROOT_PATH = '/path/to/bae'

#Load video data
video_path = os.path.join(ROOT_PATH,'video_data')
#os.listdir returns file names in arbitrary order, so sort them to make the
#concatenation order reproducible across machines and runs.
#NOTE(review): this assumes lexicographic file-name order matches recording order -- confirm
video_files = sorted(os.listdir(video_path))
video = np.concatenate([np.load(os.path.join(video_path,i)) for i in video_files])

#Design Matrix containing timestamps of task events. One frame is approximately
#70ms duration
DM = np.load(os.path.join(ROOT_PATH,'DM.npy'))

#Descriptor of Design Matrix specifying what each row of
#the design matrix contains. Numbers appended to event types
#describe lags
desc = np.load(os.path.join(ROOT_PATH,'desc.npy'))

#Index of levels at which stimuli were presented. 1 corresponds to
#the loudest stimulus and 99 corresponds to a catch trial
allVols = np.load(os.path.join(ROOT_PATH,'allVols.npy'))


In [None]:
#Z-score the video: subtract the mean and divide by the standard deviation,
#both computed along axis 0 of the video array.

mu1 = video.mean(axis=0)
std1 = video.std(axis=0)

#Entries with (almost) no variance get a very large divisor instead, so they
#end up close to zero after normalisation rather than being blown up
low_variance = std1 < 1e-2
std1[low_variance] = 1e6

video = (video - mu1) / std1

In [None]:
#Add cognitive variables to the design matrix. Returns the extended design
#matrix (allDM) and its descriptor (allDesc).
#NOTE(review): the meaning of the positional argument 5 is not visible in this
#notebook -- confirm against dat_utils.get_full_DM
allDM, allDesc = dat_utils.get_full_DM(DM,desc,5,allVols)

# Visualise elements of the Design Matrix

In the visualisations below, events are timestamped according to the time of their entries in the Design Matrix. So, for example, bout initiation is not locked to the time of the first lick but precedes it by several frames.

In [None]:
plt.figure(figsize=(14,3))

#Row index of each zero-lag event regressor in the design matrix.
#The first match is selected explicitly with [0][0]: calling int() on a
#1-element array is deprecated since NumPy 1.25.
lick_ix = int(np.where(desc=='lickL0')[0][0])
stim_ix = int(np.where(desc=='clicks0')[0][0])
rew_ix = int(np.where(desc=='rews0')[0][0])
bout_ix = int(np.where(desc=='bout_init0')[0][0])

#One raster row per event type; x positions are frame offsets within the
#5000:9000 slice of the design matrix
var_names = ['click','rew','lick','bout']
for kk,var in enumerate([stim_ix,rew_ix,lick_ix,bout_ix]):
    plt.vlines(np.where(DM[var,5000:9000])[0],1.5*kk,1.5*kk+1,label=var_names[kk],color=clrs[kk])


plt.legend()
plt.yticks([])
plt.xlim(1000,2500)
seaborn.despine()


### Show decision basis regressor relative to other events.

In [None]:
#Plot the two 'dec0' (decision) regressors alongside click and lick events

#The first match is selected explicitly with [0][0]: calling int() on a
#1-element array is deprecated since NumPy 1.25.
lick_ix = int(np.where(allDesc=='lickL0')[0][0])
stim_ix = int(np.where(allDesc=='clicks0')[0][0])
#Two rows of allDesc carry the label 'dec0', one per bout type
dec0_ix = int(np.where(allDesc=='dec0')[0][1]) #stimulus driven bout
dec1_ix = int(np.where(allDesc=='dec0')[0][0]) #spontaneous

#NOTE(review): the inline comments above mark dec0_ix as the stimulus-driven
#bout and dec1_ix as the spontaneous one, yet the labels below pair dec0_ix
#with 'spont-bout' and dec1_ix with 'stim-bout' -- confirm which is correct
var_names = ['click','lick','spont-bout','stim-bout']
_,(a1,a2) = plt.subplots(1,2, gridspec_kw = {'width_ratios':[3, 2]},figsize=(16,3))

#Left panel shows frames 5000:6000, right panel zooms in on frames 5000:5100
for ax,stop in ((a1,6000),(a2,5100)):
    for kk,var in enumerate([stim_ix,lick_ix,dec0_ix,dec1_ix]):
        ax.vlines(np.where(allDM[var,5000:stop])[0],1.5*kk,1.5*kk+1,label=var_names[kk],color=clrs[kk])

a1.set_xlim(0,2000)
a1.legend()

seaborn.despine()


In [None]:
plt.figure(figsize=(14,3))

#The first match is selected explicitly with [0][0]: calling int() on a
#1-element array is deprecated since NumPy 1.25.
lick_ix = int(np.where(allDesc=='lickL0')[0][0])
stim_ix = int(np.where(allDesc=='clicks0')[0][0])
#The reward regressor is labelled 'rews0' in the descriptor (as in the first
#raster plot of this notebook); the previous 'rew0' lookup did not match that label
rew_ix = int(np.where(allDesc=='rews0')[0][0])
#Two rows of allDesc carry the label 'att0', one per attentional state
att0_ix = int(np.where(allDesc=='att0')[0][1]) #signifies the animal paying attention
att1_ix = int(np.where(allDesc=='att0')[0][0]) #signifies the animal is not paying attention

#One raster row per regressor; x positions are frame offsets within the
#5000:5500 slice of the design matrix
var_names = ['click','rew','lick','att0','att1']
for kk,var in enumerate([stim_ix,rew_ix,lick_ix,att0_ix,att1_ix]):
    plt.vlines(np.where(allDM[var,5000:5500])[0],1.5*kk,1.5*kk+1,label=var_names[kk],color=clrs[kk])

plt.legend()
plt.yticks([])
plt.xlim(0,700)
seaborn.despine()


# Fit Model

In [None]:
#Create an instance of the BAE model (class imported from mouse_vae)
lvm = BAE()

In [None]:
#Add the data to the model instance: the normalised video, the full design
#matrix (allDM) and its descriptor (allDesc)
lvm.add_data(video=video,DM=allDM,descriptor=allDesc)

In [None]:
#Display the network parameters. Change any of them here as you like.
#To just run a VAE, set encode_weight to 0
lvm.network_params

In [None]:
#Override the 'n_epochs' entry of the network parameters
#(presumably the number of training epochs -- confirm)
lvm.network_params['n_epochs'] = 5

In [None]:
#Initialise tensorflow variables and functions. Running this function uses our default
#encoder and decoder networks. Other networks may be used by simply passing a function
#implementing some form of network as an argument to this function (see make_encoder and
#make_decoder in model_utils.py for the required structure)
lvm.run_tf_setup()

In [None]:
#View the network parameters again, now that run_tf_setup() has been called
lvm.network_params

In [None]:
#Estimate decoding performance with full cross-validation (going by the method
#name) using window=[0,4].
#NOTE(review): this cell runs before lvm.fit() -- confirm that ordering is intended
lvm.estimate_decoding_perf_full_cv(window=[0,4])

In [None]:
#Fit the encoder and decoder networks
lvm.fit()

In [None]:
#Compute the latent states; presumably this populates lvm.lats, which the
#plotting cells below read -- confirm
lvm.get_latent_states()

In [None]:
#Decoding decision basis ('dec' rows of the design matrix). This implementation
#assumes that events are discrete
lvm.estimate_decoding_perf(decode_ev='dec',window=[0,4],verbose=1,kfp=[5,1])

In [None]:
#Fit the encoding model; presumably this sets lvm.enc_params, which the next
#cell reads -- confirm
lvm.fit_encoding_model()

In [None]:
#Linear prediction of the latent states: encoding parameters times the design
#matrix. Indexed below as lin_pred_lats[k] for the k-th latent state
lin_pred_lats = lvm.enc_params.dot(lvm.DM)

In [None]:
plt.figure(figsize=(16,5))

#Shared y-limits for both panels: the 100th and 0th percentiles (i.e. the
#extremes) of the first latent state and its linear prediction taken together
first_latent_values = np.concatenate([lvm.lats[:,0], lin_pred_lats[0]])
up = np.percentile(first_latent_values, 100)
lw = np.percentile(first_latent_values, 0)

#Left panel: the complete time course
plt.subplot(1, 2, 1)
plt.plot(lvm.lats[:,0], label='latent state')
plt.plot(lin_pred_lats[0], label='linear prediction')
plt.ylim(lw, up)
plt.xlabel("Time (frames)")
plt.ylabel("Latent State \nValue (a.u.)")
plt.locator_params('x', nbins=3)

plt.legend()

#Right panel: zoom on the first 1000 frames, same y-limits
plt.subplot(1, 2, 2)
plt.plot(lvm.lats[:1000,0])
plt.plot(lin_pred_lats[0,:1000])
plt.ylim(lw, up)
plt.xlabel("Time (frames)")
plt.locator_params('x', nbins=3)

seaborn.despine()

In [None]:
#Correlation coefficient between each linearly predicted latent state and the
#corresponding measured latent state
ccs = [
    np.corrcoef(predicted, measured)[0,1]
    for predicted, measured in zip(lin_pred_lats, lvm.lats.T)
]

print('Correlations between predicted and measured latent states is:')
for cc in ccs:
    print(cc)

In [None]:
#To reconstruct full images from latent states, feed an estimate of the latent
#states (or the full latent states) into lvm.predictor, which returns the
#decoder network's prediction of the image.
#Here the linear predictions for the first five frames are used.
first_five_latents = lin_pred_lats[:, :5].T
reconstructed_images = lvm.sess.run(
    lvm.predictor,
    feed_dict={lvm._latents: first_five_latents},
)

In [None]:
def _show_frame(im, position, clim, row_label=None):
    """Draw one image into cell `position` of the 2x5 grid with colour limits [-clim, clim]."""
    plt.subplot(2, 5, position)
    plt.imshow(im, cmap='binary_r', vmin=-clim, vmax=clim)
    plt.xticks([])
    plt.yticks([])
    if row_label is not None:
        plt.ylabel(row_label)


plt.figure(figsize=(16,8))

#Top row: reconstructed images
for kk, im in enumerate(reconstructed_images):
    _show_frame(im, kk + 1, 2, "Reconstruction" if kk == 0 else None)

#Bottom row: the first five video frames; subplot numbering continues from
#where the loop above stopped
kk += 1
for im in video[:5]:
    _show_frame(im, kk + 1, 5, "Data" if kk == 5 else None)
    kk += 1

In [None]:
#Perform decoding. Returns the indices of stimulus-driven bouts (click_sel)
#and spontaneous bouts (spont_sel), as well as their projections
#onto the decoding axis (proj_stim & proj_spont). proj<0
#signifies that a bout is decoded as stimulus driven; proj>0
#indicates that it is classified as spontaneous. Greater distances
#from 0 indicate that the classifier is 'more confident'
(click_sel,spont_sel), (proj_stim,proj_spont) = lvm.decode()

In [None]:
#21 bin edges (20 bins) spanning [-3, 3] on the decoding axis, shared by both histograms
bins = np.linspace(-3,3,num=21)

#seaborn.distplot is deprecated; with the KDE switched off it only drew a
#count histogram, so draw that directly with matplotlib. Labels and a legend
#make the two distributions distinguishable.
plt.hist(proj_stim,bins=bins,alpha=0.4,label='stimulus driven')
plt.hist(proj_spont,bins=bins,alpha=0.4,label='spontaneous')

plt.xlabel("Projection onto decoding axis")
plt.ylabel("Number of bouts")
plt.legend()
seaborn.despine()