svi - decoder for continuous y:
    
the key is to represent important quantities as random variables, and then sample from them:
    
$s \sim \mathcal{N}(\cdot, \cdot)$,

$z \sim \text{Categorical}(\pi)$,

$\pi = \lambda / \sum \lambda$,

$\log(\lambda) \sim \mathcal{N}(b + \beta y, 0)$,

$y \sim \mathcal{GP}(\cdot, \cdot)$

(remember to specify the priors for these R.V.s)

when calculating the ELBO, we can sample from these R.V.s and then use `pytorch`'s existing functionality to compute the log-likelihood and differential entropy terms.

**caveat**: since a finite subset of GP is just a multivariate normal, we can set its mean and kernel function as parameters to be learned by SGD, and then do everything else w.r.t. the multivariate normal. (see [gaussian-process-tutorial](https://peterroelants.github.io/posts/gaussian-process-tutorial/))

we can still keep the CAVI encoder but implement an SVI decoder. the principle is that if we can get exact estimates (CAVI), then exact is preferred over stochastic (SVI).

In [15]:
import numpy as np
import random
import matplotlib.pyplot as plt

from sklearn.linear_model import Ridge
from sklearn.mixture import GaussianMixture
from sklearn.metrics import r2_score

import torch
import torch.distributions as D

from clusterless import preprocess
from clusterless import decoder

In [2]:
# Seed Python's, NumPy's and PyTorch's random number generators with the same
# value, then make float64 the default dtype for newly created torch tensors.
seed = 666
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)
torch.set_default_dtype(torch.double)

In [3]:
# Font-size presets for the figures; only MEDIUM_SIZE is applied below.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 15, 20, 25

# One entry per matplotlib rc group, listing the settings overridden for it.
for rc_group, rc_settings in (
    ('font', {'size': MEDIUM_SIZE}),
    ('axes', {'titlesize': MEDIUM_SIZE}),
    ('axes', {'labelsize': MEDIUM_SIZE}),
    ('axes', {'linewidth': 1.5}),
    ('xtick', {'labelsize': MEDIUM_SIZE}),
    ('ytick', {'labelsize': MEDIUM_SIZE}),
    ('legend', {'fontsize': MEDIUM_SIZE}),
    ('figure', {'titlesize': MEDIUM_SIZE}),
):
    plt.rc(rc_group, **rc_settings)

#### data

In [4]:
# Identifier of the probe insertion analysed in this notebook.
pid = 'dab512bd-a02d-4c1f-8dbc-9155a163efc0'

# Locations of the inputs on the local machine; trial and behavior data live
# under the same download root, neural data under a separate directory.
rootpath = '/mnt/3TB/yizi/Downloads/ONE/openalyx.internationalbrainlab.org'
trial_data_path = f'{rootpath}/danlab/Subjects/DY_016/2020-09-12/001/alf'
neural_data_path = '/mnt/3TB/yizi/danlab/Subjects/DY_016'
behavior_data_path = f'{rootpath}/paper_repro_ephys_data/figure9_10/original_data'

# Relative directory that results are written to.
save_path = '../saved_results/danlab/Subjects/DY_016/cavi_results/'

In [5]:
# Options forwarded to the project's neural-data loader: keep active trials,
# use every region of interest, and skip both kilosort and triage.
neural_load_options = dict(
    pid=pid,
    trial_data_path=trial_data_path,
    neural_data_path=neural_data_path,
    behavior_data_path=behavior_data_path,
    keep_active_trials=True,
    roi='all',
    kilosort=False,
    triage=False,
)
unsorted_trials, stim_on_times, np1_channel_map = preprocess.load_neural_data(
    **neural_load_options
)

# Of the six values returned by the behavior preprocessing, only the second
# one (wheel velocity) is kept.
behave_dict = preprocess.load_behaviors_data(behavior_data_path, pid)
_, wheel_velocity, _, _, _, _ = preprocess.preprocess_dynamic_behaviors(behave_dict)

pid: dab512bd-a02d-4c1f-8dbc-9155a163efc0
eid: d23a44ef-1402-4ed7-97f5-47e9a7a504d9
1st trial stim on time: 17.56, last trial stim on time 2310.24


In [6]:
class DataLoader():
    """Group unsorted spike features by trial and time bin, and split them by trial.

    Each trial covers the 1.5 s window ``[stim_on - 0.5, stim_on + 1]`` and is
    divided into ``n_t_bins`` equal-width time bins.

    Parameters
    ----------
    data : np.ndarray
        2-D array with one row per spike. Column 0 is read as the spike time;
        the remaining columns are kept as the spike features.
    y : array-like
        Per-trial behavior; it is indexed with arrays of trial ids.
    stim_on_times : np.ndarray
        Stimulus onset time of every trial; its length defines ``n_trials``.
    np1_channel_map : np.ndarray
        Only its first dimension is used here (stored as ``n_channels``).
    n_t_bins : int
        Number of time bins per trial.
    """

    def __init__(self, data, y, stim_on_times, np1_channel_map, n_t_bins=30):
        self.data = data
        self.y = y
        self.stim_on_times = stim_on_times
        self.np1_channel_map = np1_channel_map
        self.n_t_bins = n_t_bins
        self.n_trials = stim_on_times.shape[0]
        self.n_channels = np1_channel_map.shape[0]
        # Left edges of the time bins. linspace always yields exactly
        # n_t_bins edges, whereas arange with a float step can be off by one.
        self.t_binning = np.linspace(0, 1.5, n_t_bins, endpoint=False)
        # Trial order used by split_train_test (not shuffled).
        self.rand_trial_ids = np.arange(self.n_trials)

        # allocate unsorted data into trials
        self.trial_ids = []   # per (trial, bin): trial index repeated once per spike
        self.t_ids = []       # per (trial, bin): bin index repeated once per spike
        self.trials = []      # per trial: list of n_t_bins feature arrays
        self.t_bins = []      # never filled; kept so the attribute still exists
        for k in range(self.n_trials):
            in_window = np.logical_and(data[:, 0] >= stim_on_times[k] - 0.5,
                                       data[:, 0] <= stim_on_times[k] + 1)
            # Boolean indexing returns a copy, so `data` itself is not modified.
            trial = data[in_window, :]
            if trial.shape[0] > 0:
                # Times are measured from the first spike in the window.
                # NOTE(review): this is not the window start; confirm intended.
                trial[:, 0] = trial[:, 0] - trial[:, 0].min()
            # A trial without spikes simply produces n_t_bins empty bins.
            bin_of_spike = np.digitize(trial[:, 0], self.t_binning, right=False) - 1
            t_bin_lst = []
            for t in range(self.n_t_bins):
                t_bin = trial[bin_of_spike == t, 1:]
                self.trial_ids.append(np.ones_like(t_bin[:, 0]) * k)
                self.t_ids.append(np.ones_like(t_bin[:, 0]) * t)
                t_bin_lst.append(t_bin)
            self.trials.append(t_bin_lst)

    def split_train_test(self, split=.8):
        """Split the spikes into train and test sets by trial.

        The first ``int(split * n_trials)`` trials form the training set and
        the rest the test set. Also sets ``train_ids``, ``test_ids``,
        ``y_train`` and ``y_test`` on the instance.

        Returns
        -------
        train_trials, train_k_ids, train_t_ids, test_trials, test_k_ids, test_t_ids
            ``*_trials`` are the stacked spike features of the set.
            ``*_k_ids`` is a list with one tensor per trial in the set, and
            ``*_t_ids`` a list with one tensor per time bin; each tensor holds
            the row positions within ``*_trials`` of the matching spikes.
        """
        n_train = int(split * self.n_trials)
        self.train_ids = self.rand_trial_ids[:n_train]
        self.test_ids = self.rand_trial_ids[n_train:]
        self.y_train = self.y[self.train_ids]
        self.y_test = self.y[self.test_ids]

        trial_ids = np.concatenate(self.trial_ids)
        t_ids = np.concatenate(self.t_ids)
        # Flatten the per-trial lists explicitly: the bins hold different
        # numbers of spikes, and recent NumPy versions refuse to build an
        # array from such ragged nested lists.
        trials = np.concatenate([t_bin for trial in self.trials for t_bin in trial])

        train_mask = np.isin(trial_ids, self.train_ids)
        test_mask = np.isin(trial_ids, self.test_ids)

        # Convert the id vectors once instead of once per trial / time bin.
        train_k = torch.tensor(trial_ids[train_mask])
        test_k = torch.tensor(trial_ids[test_mask])
        train_t = torch.tensor(t_ids[train_mask])
        test_t = torch.tensor(t_ids[test_mask])

        train_k_ids = [torch.argwhere(train_k == k).reshape(-1) for k in self.train_ids]
        train_t_ids = [torch.argwhere(train_t == t).reshape(-1) for t in range(self.n_t_bins)]
        test_k_ids = [torch.argwhere(test_k == k).reshape(-1) for k in self.test_ids]
        test_t_ids = [torch.argwhere(test_t == t).reshape(-1) for t in range(self.n_t_bins)]
        train_trials, test_trials = trials[train_mask], trials[test_mask]

        return train_trials, train_k_ids, train_t_ids, \
               test_trials, test_k_ids, test_t_ids

In [7]:
# Stack the spikes of all trials and keep columns 0, 2, 3 and 4; DataLoader
# reads the first kept column as the spike time and the rest as features.
spike_features = np.concatenate(unsorted_trials)[:, [0, 2, 3, 4]]
data_loader = DataLoader(
    spike_features,
    wheel_velocity,
    stim_on_times,
    np1_channel_map,
    n_t_bins=30,
)

In [8]:
# Keep only the training outputs and discard the test ones. With split=.1 the
# first int(0.1 * n_trials) trials form the training set.
train_trials, train_k_ids, train_t_ids, _, _, _ = data_loader.split_train_test(split=.1)



In [9]:
# Fit a 10-component, full-covariance Gaussian mixture to the training spike
# features, initialised with k-means++.
gmm_options = {
    'n_components': 10,
    'covariance_type': 'full',
    'init_params': 'k-means++',
    'verbose': 0,
}
gmm = GaussianMixture(**gmm_options)
gmm.fit(train_trials)

In [10]:
# Training spike features and behavior as torch tensors.
s, y = torch.tensor(train_trials), torch.tensor(data_loader.y_train)

# Problem sizes: training trials, time bins per trial, then the number of
# mixture components and feature dimensions taken from the fitted GMM means.
Nk, Nt = len(data_loader.train_ids), data_loader.n_t_bins
Nc, Nd = gmm.means_.shape[0], gmm.means_.shape[1]

#### decoder

In [11]:
def safe_log(x, minval=1e-10):
    """Return the elementwise natural log of ``x + minval``.

    The small offset ``minval`` keeps the result finite where ``x`` is zero.
    """
    shifted = x + minval
    return shifted.log()

In [None]:
class Decoder(torch.nn.Module):
    """Draft SVI decoder for continuous y (the GP parts for y are still TODO).

    Parameters
    ----------
    Nk, Nt, Nc, Nd : int
        Numbers of trials, time bins, mixture components and feature dims.
    init_means, init_covs : torch.Tensor
        Initial mixture means and covariances (learned). They are indexed by
        component along the first dimension.
    init_bs, init_betas : torch.Tensor
        Initial intercepts and coefficients of the log firing rates (kept
        fixed, ``requires_grad=False``).
        NOTE(review): ``betas[:, t] @ y[k]`` must broadcast with ``bs`` into a
        length-``Nc`` vector, e.g. betas of shape (Nc, Nt, dim of y[k]);
        confirm against the caller.
    ks, ts : sequence of index tensors
        ``ks[k]`` holds the rows of ``s`` that belong to trial ``k`` and
        ``ts[t]`` the rows that belong to time bin ``t``.
    """

    def __init__(self,
                 Nk, Nt, Nc, Nd,
                 init_means, init_covs,
                 init_bs, init_betas,
                 ks, ts):
        super().__init__()
        self.Nk = Nk
        self.Nt = Nt
        self.Nc = Nc
        self.Nd = Nd
        self.ks = ks
        self.ts = ts

        # initialize variables for variational distribution
        self.means = torch.nn.Parameter(init_means, requires_grad=True)
        self.covs = torch.nn.Parameter(init_covs, requires_grad=True)
        self.bs = torch.nn.Parameter(init_bs, requires_grad=False)
        self.betas = torch.nn.Parameter(init_betas, requires_grad=False)

        # TO DO: initialize mean and cov functions (kernels) for GP

    def forward(self, s, M=1000, y=None):
        """Return a Monte Carlo estimate of the ELBO as a scalar tensor.

        Parameters
        ----------
        s : torch.Tensor
            Spike features, one row per spike.
        M : int
            Number of Monte Carlo samples to be drawn for z.
        y : torch.Tensor, optional
            Behavior indexed by trial. When omitted, the module-level variable
            ``y`` is used, which is what this method relied on before the
            argument existed.

        Raises
        ------
        ValueError
            If ``y`` is neither passed nor defined at module level.
        """
        if y is None:
            y = globals().get("y")
            if y is None:
                raise ValueError("y must be passed to forward() or defined at module level")

        # Read the sizes from the instance rather than from notebook globals.
        Nk, Nt, Nc = self.Nk, self.Nt, self.Nc

        # specify the variational dist. for y
        # TO DO: use torch's multivariate_normal with learned mean and cov functions

        # compute log-lambdas, shape (Nk, Nc, Nt)
        log_lambdas = torch.zeros((Nk, Nc, Nt),
                                  dtype=self.means.dtype, device=self.means.device)
        for k in range(Nk):
            for t in range(Nt):
                log_lambdas[k, :, t] = self.bs + self.betas[:, t] @ y[k]

        # compute mixing proportions (normalised over the component axis)
        log_pis = log_lambdas - torch.logsumexp(log_lambdas, dim=1, keepdim=True)

        # specify the variational dist. for z: one categorical per (k, t).
        # Categorical expects the category axis last, hence the permute.
        q_z_dist = D.Categorical(logits=log_pis.permute(0, 2, 1))

        # compute log-likelihood for s as mixture density, shape (n_spikes, Nc)
        ll = torch.stack([
            D.MultivariateNormal(loc=self.means[j],
                                 covariance_matrix=self.covs[j]).log_prob(s)
            for j in range(Nc)
        ], dim=1)

        # Unlike CAVI where we have exact formula for q(z),
        # we'd have to sample z from q in SVI.
        # sampling z: q_z[k, j, t] is the fraction of the M draws equal to j.
        # Dividing by M (via the mean) makes this an estimate of a probability
        # rather than a raw count, so the ELBO does not scale with M.
        z_samples = q_z_dist.sample((M,))                       # (M, Nk, Nt)
        q_z = torch.nn.functional.one_hot(z_samples, num_classes=Nc)
        q_z = q_z.to(log_pis.dtype).mean(dim=0).permute(0, 2, 1)  # (Nk, Nc, Nt)

        # TO DO: sampling y

        # compute ELBO = E_z,y[logp(s,z,y)] - E_z,y[logq(z,y)]
        # E_z,y[logp(s,z,y)] = E_z,y[logp(s|z) + logp(z|y) + logp(y)]

        # E_z[logp(s|z)]: for every (k, t), weight the log-likelihood of the
        # spikes in that bin by q_z and reduce to a scalar.
        k_rows = [torch.as_tensor(idx) for idx in self.ks]
        t_rows = [torch.as_tensor(idx) for idx in self.ts]
        log_p_s_cond_z = ll.new_zeros(())
        for k in range(Nk):
            for t in range(Nt):
                # rows of s that are in trial k and in time bin t
                k_t_idx = k_rows[k][torch.isin(k_rows[k], t_rows[t])]
                log_p_s_cond_z = log_p_s_cond_z + (q_z[k, :, t] * ll[k_t_idx]).sum()

        # E_z[logp(z|y)]
        log_p_z_cond_y = (q_z * log_pis).sum()

        # TO DO: E_y[logp(y)]
        log_p_y = 0

        # E_z[logq(z)] (the negative entropy of q(z)), computed exactly from
        # the categorical probabilities instead of from samples.
        # NOTE(review): summed over components and all (k, t); confirm this is
        # the intended reduction.
        entropy_z = (torch.exp(log_pis) * log_pis).sum()

        # TO DO: Entropy E_y[logq(y)]
        entropy_y = 0

        return log_p_s_cond_z + log_p_z_cond_y + log_p_y - entropy_z - entropy_y