In [1]:
import numpy as np
import random
import matplotlib.pyplot as plt

from scipy.special import logsumexp
from scipy.stats import pearsonr
from sklearn.linear_model import Ridge
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score, mean_squared_error

import torch
import torch.distributions as D

import isosplit

from clusterless import preprocess
from clusterless import decoder

In [2]:
seed = 666
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.set_default_dtype(torch.double)

In [3]:
SMALL_SIZE = 15
MEDIUM_SIZE = 20
BIGGER_SIZE = 25

plt.rc('font', size=MEDIUM_SIZE)         
plt.rc('axes', titlesize=MEDIUM_SIZE)     
plt.rc('axes', labelsize=MEDIUM_SIZE) 
plt.rc('axes', linewidth = 1.5)
plt.rc('xtick', labelsize=MEDIUM_SIZE)   
plt.rc('ytick', labelsize=MEDIUM_SIZE)   
plt.rc('legend', fontsize=MEDIUM_SIZE)   
plt.rc('figure', titlesize=MEDIUM_SIZE)

#### load data

In [4]:
pid = 'febb430e-2d50-4f83-87a0-b5ffbb9a4943'
rootpath = '/mnt/3TB/yizi/Downloads/ONE/openalyx.internationalbrainlab.org'
trial_data_path = rootpath + '/danlab/Subjects/DY_009/2020-02-27/001/alf'
neural_data_path = '/mnt/3TB/yizi/danlab/Subjects/DY_009'
behavior_data_path = rootpath + '/paper_repro_ephys_data/figure9_10/original_data'
save_path = '../saved_results/danlab/Subjects/DY_009'

In [5]:
roi = 'lp'

In [6]:
sorted_trials, good_sorted_trials, unsorted_trials, stim_on_times, np1_channel_map= preprocess.load_neural_data(
    pid=pid, 
    trial_data_path=trial_data_path,
    neural_data_path=neural_data_path,
    behavior_data_path=behavior_data_path,
    keep_active_trials=True, 
    # roi='all',
    roi = roi,
    kilosort=True,
    triage=False,
    good_units=True,
    thresholding=True
)

behave_dict = preprocess.load_behaviors_data(behavior_data_path, pid)
motion_energy, wheel_velocity, wheel_speed, paw_speed, nose_speed, pupil_diameter = preprocess.preprocess_dynamic_behaviors(behave_dict)

pid: febb430e-2d50-4f83-87a0-b5ffbb9a4943
eid: db4df448-e449-4a6f-a0e7-288711e7a75a
found 85 good ibl units ..
1st trial stim on time: 40.81, last trial stim on time 2252.10
found 74 neurons in region lp ...
found 38 channels in region lp ...


In [7]:
class DataLoader():
    def __init__(self, data, y, stim_on_times, np1_channel_map, n_t_bins=30):
        self.data = data
        self.y = y
        self.stim_on_times = stim_on_times
        self.np1_channel_map = np1_channel_map
        self.n_t_bins = n_t_bins
        self.n_trials = stim_on_times.shape[0]
        self.n_channels = np1_channel_map.shape[0]
        self.t_binning = np.arange(0, 1.5, step = (1.5 - 0) / n_t_bins)
        self.rand_trial_ids = np.arange(self.n_trials)
        
        # allocate unsorted data into trials
        self.trial_ids = []
        self.t_ids = []
        self.trials = []
        self.t_bins = []
        for k in range(self.n_trials):
            mask = np.logical_and(data[:,0] >= stim_on_times[k] - 0.5,
                                  data[:,0] <= stim_on_times[k] + 1)
            trial = data[mask,:]
            trial[:,0] = trial[:,0] - trial[:,0].min()
            t_bins = np.digitize(trial[:,0], self.t_binning, right = False) - 1
            t_bin_lst = []
            for t in range(self.n_t_bins):
                t_bin = trial[t_bins == t,1:]
                self.trial_ids.append(np.ones_like(t_bin[:,0]) * k)
                self.t_ids.append(np.ones_like(t_bin[:,0]) * t)
                t_bin_lst.append(t_bin)
            self.trials.append(t_bin_lst)
    
    
    def split_train_test(self, train_ids, test_ids):
        
        self.train_ids = self.rand_trial_ids[train_ids]
        self.test_ids = self.rand_trial_ids[test_ids]
        self.y_train = self.y[self.train_ids]
        self.y_test = self.y[self.test_ids]
        
        trial_ids = np.concatenate(self.trial_ids)
        t_ids = np.concatenate(self.t_ids)
        trials = np.concatenate(np.concatenate(self.trials))

        train_mask = np.sum([trial_ids == idx for idx in self.train_ids], axis=0).astype(bool)
        test_mask = np.sum([trial_ids == idx for idx in self.test_ids], axis=0).astype(bool)
        train_trial_ids, test_trial_ids = trial_ids[train_mask], trial_ids[test_mask]
        train_t_ids, test_t_ids = t_ids[train_mask], t_ids[test_mask]
        train_trials, test_trials = trials[train_mask], trials[test_mask]
        
        return train_trials, train_trial_ids, train_t_ids, \
               test_trials, test_trial_ids, test_t_ids

In [8]:
data_loader = DataLoader(data = np.concatenate(unsorted_trials)[:,[0,1,2,3,4]], 
                         y = wheel_velocity, 
                         stim_on_times = stim_on_times, 
                         np1_channel_map = np1_channel_map, 
                         n_t_bins=30)

In [9]:
kf = KFold(n_splits=5, shuffle=True, random_state=seed)
# kf = KFold(n_splits=5, shuffle=False)
kf_train_ids = []; kf_test_ids = []
for i, (train_ids, test_ids) in enumerate(kf.split(data_loader.y)):
    kf_train_ids.append(train_ids)
    kf_test_ids.append(test_ids)

In [10]:
i = 1
train_trials, train_trial_ids, train_t_ids, \
test_trials, test_trial_ids, test_t_ids = data_loader.split_train_test(
    train_ids = kf_train_ids[i], test_ids = kf_test_ids[i]
)



In [15]:
sub_weights_lst = []
sub_means_lst = []
sub_covs_lst = []

all_trials = np.vstack([train_trials, test_trials])
for channel in np.unique(all_trials[:,0]):
    sub_s = all_trials[all_trials[:,0] == channel, 1:]
    
    if sub_s.shape[0] > 10:
        isosplit_labels = isosplit.isosplit(sub_s.T, K_init=20, min_cluster_size=10,
                                            whiten_cluster_pairs=1, refine_clusters=1)
    # if sub_s.shape[0] > 5:
    #     isosplit_labels = isosplit.isosplit(sub_s.T, K_init=20, min_cluster_size=5,
    #                                         whiten_cluster_pairs=1, refine_clusters=1)
    elif sub_s.shape[0] < 2:
        continue
    else:
        sub_gmm = GaussianMixture(n_components=1, 
                              covariance_type='full',
                              init_params='k-means++', 
                              verbose=0)
        sub_gmm.fit(sub_s)
        sub_labels = sub_gmm.predict(sub_s)
        sub_weights = len(sub_labels)/len(all_trials)
        sub_weights_lst.append(sub_weights)
        sub_means_lst.append(sub_gmm.means_)
        sub_covs_lst.append(sub_gmm.covariances_)
        continue
    
    n_splits = np.unique(isosplit_labels).shape[0]
    print(f'channel {channel} has {n_splits} modes ...')
    
    if n_splits == 1: 
        sub_gmm = GaussianMixture(n_components=1, 
                              covariance_type='full',
                              init_params='k-means++', 
                              verbose=0)
        sub_gmm.fit(sub_s)
        sub_labels = sub_gmm.predict(sub_s)
        sub_weights = len(sub_labels)/len(all_trials)
        sub_weights_lst.append(sub_weights)
        sub_means_lst.append(sub_gmm.means_)
        sub_covs_lst.append(sub_gmm.covariances_)
    else:
        for label in np.arange(n_splits):
            mask = isosplit_labels == label
            sub_gmm = GaussianMixture(n_components=1, 
                              covariance_type='full',
                              init_params='k-means++', 
                              verbose=0)
            sub_gmm.fit(sub_s[mask])
            sub_labels = sub_gmm.predict(sub_s[mask])
            sub_weights = len(sub_labels)/len(all_trials)
            sub_weights_lst.append(sub_weights)
            sub_means_lst.append(sub_gmm.means_)
            sub_covs_lst.append(sub_gmm.covariances_)
            
sub_weights = np.hstack(sub_weights_lst)
sub_means = np.vstack(sub_means_lst)
sub_covs = np.vstack(sub_covs_lst)

channel 170.0 has 2 modes ...
channel 171.0 has 1 modes ...
channel 172.0 has 2 modes ...
channel 173.0 has 2 modes ...
channel 174.0 has 1 modes ...
channel 177.0 has 1 modes ...
channel 178.0 has 2 modes ...
channel 179.0 has 2 modes ...
channel 180.0 has 1 modes ...
channel 181.0 has 1 modes ...
channel 182.0 has 1 modes ...
channel 183.0 has 1 modes ...
channel 184.0 has 1 modes ...
channel 185.0 has 2 modes ...
channel 186.0 has 2 modes ...
channel 187.0 has 1 modes ...
channel 189.0 has 1 modes ...
channel 190.0 has 2 modes ...
channel 191.0 has 1 modes ...
channel 192.0 has 2 modes ...
channel 194.0 has 3 modes ...
channel 195.0 has 1 modes ...
channel 196.0 has 1 modes ...
channel 197.0 has 1 modes ...
channel 198.0 has 1 modes ...
channel 199.0 has 1 modes ...
channel 200.0 has 1 modes ...
channel 201.0 has 1 modes ...
channel 202.0 has 2 modes ...
channel 204.0 has 2 modes ...
channel 205.0 has 1 modes ...
channel 206.0 has 1 modes ...


In [None]:
# sub_s_lst = []
# sub_weights_lst = []
# sub_means_lst = []
# sub_covs_lst = []
# for channel in np.unique(train_trials[:,0]):
#     sub_s = train_trials[train_trials[:,0] == channel, 1:]
#     sub_s_lst.append(sub_s)
#     if len(sub_s) > 1:
#         sub_gmm = GaussianMixture(n_components=1, 
#                           covariance_type='full',
#                           init_params='k-means++', verbose=0)
#         sub_gmm.fit(sub_s)
#         sub_labels = sub_gmm.predict(sub_s)
#         sub_weights = len(sub_s)/len(train_trials)
#         sub_weights_lst.append(sub_weights)
#         sub_means_lst.append(sub_gmm.means_)
#         sub_covs_lst.append(sub_gmm.covariances_)
        
# sub_weights = np.hstack(sub_weights_lst)
# sub_means = np.vstack(sub_means_lst)
# sub_covs = np.vstack(sub_covs_lst)

In [16]:
gmm = GaussianMixture(n_components=len(sub_weights), covariance_type='full', init_params='k-means++')
gmm.weights_ = sub_weights
gmm.means_ = sub_means
gmm.covariances_ = sub_covs
gmm.precisions_cholesky_ = np.linalg.cholesky(np.linalg.inv(sub_covs))

In [None]:
# gmm = GaussianMixture(n_components=100, 
#                       covariance_type='full', 
#                       init_params='k-means++',
#                       verbose=0)
# gmm.fit(train_trials)

In [17]:
s = torch.tensor(train_trials[:,1:])
y = torch.tensor(data_loader.y)
ks = torch.tensor(train_trial_ids)
ts = torch.tensor(train_t_ids)

Nk = len(data_loader.train_ids)
Nt = data_loader.n_t_bins
Nc = gmm.means_.shape[0]
Nd = gmm.means_.shape[1]

In [18]:
Nc

45

#### CAVI-SGD

In [19]:
def safe_log(x, minval=1e-10):
    return torch.log(x + minval)

In [20]:
class CAVI(torch.nn.Module):

    def __init__(self, Nk, Nt, Nc, Nd, init_means, init_covs):
        super(CAVI, self).__init__()
        self.Nk = Nk
        self.Nt = Nt
        self.Nc = Nc
        self.Nd = Nd
        self.ks = ks
        self.ts = ts
        
        # initialize variables for variational distribution
        self.means = torch.nn.Parameter(torch.tensor(init_means), requires_grad=False)
        self.covs = torch.nn.Parameter(torch.tensor(init_covs), requires_grad=False)
        self.bs = torch.nn.Parameter(torch.randn((Nc)))
        self.betas = torch.nn.Parameter(torch.randn((Nc, Nt)))

        
    def forward(self, s, y, ks, ts):
        
        # compute log-lambdas
        log_lambdas = torch.zeros((self.Nk, self.Nc, self.Nt))
        for k in range(self.Nk):
            for t in range(self.Nt):
                log_lambdas[k,:,t] = self.bs + self.betas[:,t] * y[k][t]
                
        
        # compute mixing proportions 
        log_pis = log_lambdas - torch.logsumexp(log_lambdas, 1)[:,None,:]
        
        # compute log-likelihood
        ll = torch.zeros((s.shape[0], self.Nc))
        for j in range(self.Nc):
            ll[:,j] = D.multivariate_normal.MultivariateNormal(
                            loc=self.means[j], 
                            covariance_matrix=self.covs[j]
                        ).log_prob(s)
            
        
        # order of update is: E step -> compute ELBO -> M step
        # E step
        r = torch.zeros((s.shape[0], self.Nc))
        for k in range(self.Nk):
            for t in range(self.Nt):
                k_t_idx = torch.logical_and(ks == torch.unique(ks).int()[k], ts == t)
                r[k_t_idx] = torch.exp( ll[k_t_idx] + log_pis[k,:,t] )
                r[k_t_idx] = r[k_t_idx] / r[k_t_idx].sum(1)[:,None]
                
                    
        # compute ELBO
        elbo_1 = 0; elbo_2 = 0; elbo_3 = 0
        elbo = 0
        for k in range(self.Nk):
            for t in range(self.Nt):
                k_t_idx = torch.logical_and(ks == torch.unique(ks).int()[k], ts == t)
                elbo_1 += torch.sum( r[k_t_idx] * ll[k_t_idx] )
                elbo_2 += torch.sum( r[k_t_idx] * log_pis[k,:,t] )
                elbo_3 -= torch.sum( r[k_t_idx] * safe_log(r[k_t_idx]) )
                
                
        # M step is done via back propagation
                
        return elbo_1 + elbo_2 + elbo_3

In [21]:
# batch_size = 6
batch_size = 1
batch_ids = list(zip(*(iter(data_loader.train_ids),) * batch_size))

In [22]:
cavi = CAVI(batch_size, Nt, Nc, Nd, gmm.means_, gmm.covariances_)
optim = torch.optim.Adam(cavi.parameters(), lr=1e-2)

In [None]:
%%time
max_iter = 100
elbos = []
N = s.shape[0]
for i in range(max_iter):
    tot_elbo = 0
    for n, batch_idx in enumerate(batch_ids): 
        mask = torch.logical_and(ks >= batch_idx[0], ks <= batch_idx[-1])
        batch_s = s[mask]
        batch_y = y[list(batch_idx)]
        batch_ks = ks[mask]
        batch_ts = ts[mask]
        batch_elbo = cavi(batch_s, batch_y, batch_ks, batch_ts)
        tot_elbo += batch_elbo
        loss = - batch_elbo
        loss.backward()
        if (n+1) % 100 == 0:
            print(f'iter: {i+1} batch {n+1}')
        optim.step()
        optim.zero_grad()
    print(f'iter: {i+1} total elbo: {tot_elbo:.2f}')
    elbos.append(tot_elbo.detach().numpy())

iter: 1 batch 100
iter: 1 total elbo: -537054.06
iter: 2 batch 100
iter: 2 total elbo: -511927.68
iter: 3 batch 100
iter: 3 total elbo: -505198.13
iter: 4 batch 100
iter: 4 total elbo: -503368.77
iter: 5 batch 100
iter: 5 total elbo: -502816.78
iter: 6 batch 100
iter: 6 total elbo: -502643.00
iter: 7 batch 100
iter: 7 total elbo: -502570.67
iter: 8 batch 100
iter: 8 total elbo: -502532.88
iter: 9 batch 100
iter: 9 total elbo: -502510.28
iter: 10 batch 100
iter: 10 total elbo: -502495.37
iter: 11 batch 100
iter: 11 total elbo: -502484.63
iter: 12 batch 100
iter: 12 total elbo: -502476.26
iter: 13 batch 100
iter: 13 total elbo: -502469.29


In [None]:
elbos = [elbo for elbo in elbos]

In [None]:
plt.figure(figsize=(4,2))
plt.plot(elbos)
plt.xlabel("Iteration")
plt.ylabel("ELBO");

In [None]:
log_lambdas = torch.zeros((Nk, Nc, Nt))
for k in range(Nk):
    for t in range(Nt):
        log_lambdas[k,:,t] = cavi.bs + cavi.betas[:,t] * y[k][t]

log_pis = log_lambdas - torch.logsumexp(log_lambdas, 1)[:,None,:]

In [None]:
plt.figure(figsize=(4,3))
plt.imshow(torch.exp(log_pis.mean(0)).detach().numpy(), 
           aspect='auto', cmap='cubehelix')
plt.colorbar();

In [None]:
plt.figure(figsize=(4,3))
plt.plot(cavi.betas[-1,:].detach().numpy());

In [None]:
# gmm.means_

In [None]:
# cavi.means

In [28]:
cont_y_enc_res = {
    'bs': cavi.bs,
    'betas': cavi.betas,
    'means': cavi.means,
    'covs': cavi.covs
}
np.save(save_path + f'dy009_cont_y_enc_res_c{len(cavi.means)}.npy', cont_y_enc_res)

#### MoG only

In [None]:
all_trials = np.concatenate(np.concatenate(data_loader.trials))[:,1:]
spike_times = data_loader.data[:,0]

spike_labels = []
spike_probs = []
spike_labels.extend(gmm.predict(all_trials))
spike_probs.extend(gmm.predict_proba(all_trials))
spike_labels = np.array(spike_labels)
spike_probs = np.array(spike_probs)

In [None]:
enc_gmm = preprocess.compute_neural_activity(
    (spike_times, spike_labels, spike_probs),
    data_loader.stim_on_times,
    'clusterless', 
    n_time_bins=data_loader.n_t_bins
)
print(enc_gmm.shape)

In [None]:
train = data_loader.train_ids
test = data_loader.test_ids

In [None]:
x_train = enc_gmm.reshape(-1, enc_gmm.shape[1] * enc_gmm.shape[2])[train]
x_test = enc_gmm.reshape(-1, enc_gmm.shape[1] * enc_gmm.shape[2])[test]
y_train = data_loader.y[train]

ridge = Ridge(alpha=2000)
ridge.fit(x_train, y_train)
y_hat = ridge.predict(x_test)

In [None]:
print(f'R2 = {r2_score(data_loader.y[test], y_hat):.3f}')
print(f'MSE = {mean_squared_error(data_loader.y[test], y_hat):.3f}')
print(f'corr = {pearsonr(data_loader.y[test].flatten(), y_hat.flatten()).statistic:.3f}')

In [None]:
import time
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import WhiteKernel, ExpSineSquared, RationalQuadratic, RBF

long_term_trend_kernel = 50.0**2 * RBF(length_scale=50.0)
seasonal_kernel = (
    2.0**2
    * RBF(length_scale=100.0)
    * ExpSineSquared(length_scale=1.0, periodicity=1.0, periodicity_bounds="fixed")
)
# irregularities_kernel = 0.5**2 * RationalQuadratic(length_scale=1.0, alpha=1.0)
# noise_kernel = 0.1**2 * RBF(length_scale=0.1) + WhiteKernel(
#     noise_level=0.1**2, noise_level_bounds=(1e-5, 1e5)
# )
kernel = (
    long_term_trend_kernel #+ seasonal_kernel + noise_kernel #+ irregularities_kernel 
)
gaussian_process = GaussianProcessRegressor(kernel=kernel, alpha=1e-10)
start_time = time.time()
gaussian_process.fit(x_train, y_train)
print(
    f"Time for GaussianProcessRegressor fitting: {time.time() - start_time:.3f} seconds"
)
gp_pred = gaussian_process.predict(x_test)
print(f'R2 = {r2_score(data_loader.y[test], gp_pred):.3f}')
print(f'MSE = {mean_squared_error(data_loader.y[test], gp_pred):.3f}')
print(f'corr = {pearsonr(data_loader.y[test].flatten(), gp_pred.flatten()).statistic:.3f}')

In [None]:
plt.figure(figsize=(12, 2))
plt.plot(data_loader.y[test].flatten()[:200], c='gray', linestyle='dashed', label='observed');
plt.plot(gp_pred.flatten()[:200], c='blue', alpha=.6, label='predicted');

In [None]:
windowed_enc_gmm, half_window_size, n_windows = decoder.sliding_window(
    enc_gmm, 
    data_loader.n_trials,
    window_size = 7
)
windowed_y = data_loader.y[:,half_window_size:n_windows].reshape(-1,1)

In [None]:
x_by_trial = windowed_enc_gmm.reshape((data_loader.n_trials, -1))
y_by_trial = windowed_y.reshape((data_loader.n_trials, -1))
x_train, x_test = x_by_trial[train], x_by_trial[test]
y_train, y_test = y_by_trial[train], y_by_trial[test]

x_train = x_train.reshape((-1, windowed_enc_gmm.shape[1]))
x_test = x_test.reshape((-1, windowed_enc_gmm.shape[1]))
y_train = y_train.flatten()
y_test = y_test.flatten()

ridge = Ridge(alpha=10000)
ridge.fit(x_train, y_train)
y_pred = ridge.predict(x_test)

In [None]:
print(f'R2 = {r2_score(y_test, y_pred):.3f}')
print(f'MSE = {mean_squared_error(y_test, y_pred):.3f}')
print(f'corr = {pearsonr(y_test, y_pred).statistic:.3f}')

In [None]:
plt.figure(figsize=(12, 2))
plt.plot(y_test[:200], c='gray', linestyle='dashed', label='observed');
plt.plot(y_pred[:200], c='blue', alpha=.6, label='predicted');

#### encoding MoG

In [None]:
n_trials = stim_on_times.shape[0]
unsorted = np.vstack([unsorted_trials[i] for i in np.arange(n_trials)]) 
spike_times = unsorted[:,0]
spike_channels = unsorted[:,1]
spike_features = unsorted[:,2:]

thresholded_neural_data = preprocess.compute_neural_activity(
    (spike_times, spike_channels),
    stim_on_times,
    'thresholded', 
    n_time_bins=30,
    regional=True
)
print(f'thresholded neural data shape: {thresholded_neural_data.shape}')

In [None]:
x_train = thresholded_neural_data.reshape(-1, thresholded_neural_data.shape[1] * thresholded_neural_data.shape[2])[train]
x_test = thresholded_neural_data.reshape(-1, thresholded_neural_data.shape[1] * thresholded_neural_data.shape[2])[test]
y_train = data_loader.y[train]

In [None]:
import time
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import WhiteKernel, ExpSineSquared, RationalQuadratic, RBF

long_term_trend_kernel = 50.0**2 * RBF(length_scale=50.0)
seasonal_kernel = (
    2.0**2
    * RBF(length_scale=100.0)
    * ExpSineSquared(length_scale=1.0, periodicity=1.0, periodicity_bounds="fixed")
)
# irregularities_kernel = 0.5**2 * RationalQuadratic(length_scale=1.0, alpha=1.0)
# noise_kernel = 0.1**2 * RBF(length_scale=0.1) + WhiteKernel(
#     noise_level=0.1**2, noise_level_bounds=(1e-5, 1e5)
# )
kernel = (
    long_term_trend_kernel #+ seasonal_kernel #+ irregularities_kernel + noise_kernel
)
gaussian_process = GaussianProcessRegressor(kernel=kernel, alpha=1e-2)
start_time = time.time()
gaussian_process.fit(x_train, y_train)
print(
    f"Time for GaussianProcessRegressor fitting: {time.time() - start_time:.3f} seconds"
)
gp_pred = gaussian_process.predict(x_test)
print(f'R2 = {r2_score(data_loader.y[test], gp_pred):.3f}')
print(f'MSE = {mean_squared_error(data_loader.y[test], gp_pred):.3f}')
print(f'corr = {pearsonr(data_loader.y[test].flatten(), gp_pred.flatten()).statistic:.3f}')

In [None]:
plt.figure(figsize=(12, 2))
plt.plot(data_loader.y[test].flatten()[:200], c='gray', linestyle='dashed', label='observed');
plt.plot(gp_pred.flatten()[:200], c='blue', alpha=.6, label='predicted');

In [None]:
ridge = Ridge(alpha=2000)
ridge.fit(x_train, y_train)
y_hat = ridge.predict(x_test)

In [None]:
print(f'R2 = {r2_score(data_loader.y[test], y_hat):.3f}')
print(f'MSE = {mean_squared_error(data_loader.y[test], y_hat):.3f}')
print(f'corr = {pearsonr(data_loader.y[test].flatten(), y_hat.flatten()).statistic:.3f}')

In [None]:
plt.figure(figsize=(12, 2))
plt.plot(data_loader.y[test].flatten()[:200], c='gray', linestyle='dashed', label='observed');
plt.plot(y_hat.flatten()[:200], c='blue', alpha=.6, label='predicted');

In [None]:
log_lambdas_hat = np.zeros((data_loader.n_trials, Nc, Nt))
for k in range(len(train)):
    for t in range(Nt):
        log_lambdas_hat[k,:,t] = cavi.bs.detach().numpy() + cavi.betas[:,t].detach().numpy() * y_train[k][t]

for k in range(len(test)):
    for t in range(Nt):
        log_lambdas_hat[k,:,t] = cavi.bs.detach().numpy() + cavi.betas[:,t].detach().numpy() * y_hat[k][t]

log_pis_hat = log_lambdas_hat - logsumexp(log_lambdas_hat, 1)[:,None,:]

In [None]:
enc_pis = np.exp(log_pis_hat)
enc_means = cavi.means.detach().numpy()
enc_covs = cavi.covs.detach().numpy()

In [None]:
enc_all = np.zeros((data_loader.n_trials, Nc, Nt))

for k in range(enc_all.shape[0]):
    for t in range(Nt):
        enc_gmm =  GaussianMixture(n_components=Nc, covariance_type='full')
        enc_gmm.precisions_cholesky_ = np.linalg.cholesky(np.linalg.inv(enc_covs))
        enc_gmm.weights_ = enc_pis[k,:,t]
        # enc_gmm.weights_ = enc_pis[:,:,t].mean(0)
        enc_gmm.means_ = enc_means
        enc_gmm.covariances_ = enc_covs
        if len(data_loader.trials[k][t]) > 0:
            enc_all[k,:,t] = enc_gmm.predict_proba(data_loader.trials[k][t][:,1:]).sum(0)

In [None]:
plt.figure(figsize=(4,3))
plt.imshow(enc_all.mean(0), aspect='auto', cmap='cubehelix')
plt.colorbar();

In [None]:
x_train = enc_all.reshape(-1, enc_all.shape[1] * enc_all.shape[2])[train]
x_test = enc_all.reshape(-1, enc_all.shape[1] * enc_all.shape[2])[test]
y_train = data_loader.y[train]

ridge = Ridge(alpha=2000)
ridge.fit(x_train, y_train)
y_hat = ridge.predict(x_test)

In [None]:
print(f'R2 = {r2_score(data_loader.y[test], y_hat):.3f}')
print(f'MSE = {mean_squared_error(data_loader.y[test], y_hat):.3f}')
print(f'corr = {pearsonr(data_loader.y[test].flatten(), y_hat.flatten()).statistic:.3f}')

In [None]:
import time
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import WhiteKernel, ExpSineSquared, RationalQuadratic, RBF

long_term_trend_kernel = 50.0**2 * RBF(length_scale=50.0)
seasonal_kernel = (
    2.0**2
    * RBF(length_scale=100.0)
    * ExpSineSquared(length_scale=1.0, periodicity=1.0, periodicity_bounds="fixed")
)
# irregularities_kernel = 0.5**2 * RationalQuadratic(length_scale=1.0, alpha=1.0)
# noise_kernel = 0.1**2 * RBF(length_scale=0.1) + WhiteKernel(
#     noise_level=0.1**2, noise_level_bounds=(1e-5, 1e5)
# )
kernel = (
    long_term_trend_kernel #+ seasonal_kernel + noise_kernel #+ irregularities_kernel 
)
gaussian_process = GaussianProcessRegressor(kernel=kernel, alpha=1e-10)
start_time = time.time()
gaussian_process.fit(x_train, y_train)
print(
    f"Time for GaussianProcessRegressor fitting: {time.time() - start_time:.3f} seconds"
)
gp_pred = gaussian_process.predict(x_test)
print(f'R2 = {r2_score(data_loader.y[test], gp_pred):.3f}')
print(f'MSE = {mean_squared_error(data_loader.y[test], gp_pred):.3f}')
print(f'corr = {pearsonr(data_loader.y[test].flatten(), gp_pred.flatten()).statistic:.3f}')

In [None]:
plt.figure(figsize=(12, 2))
plt.plot(data_loader.y[test].flatten()[:200], c='gray', linestyle='dashed', label='observed');
plt.plot(gp_pred.flatten()[:200], c='blue', alpha=.6, label='predicted');

In [None]:
windowed_enc_all, half_window_size, n_windows = decoder.sliding_window(
    enc_all, 
    data_loader.n_trials,
    window_size = 7
)
windowed_y = data_loader.y[:,half_window_size:n_windows].reshape(-1,1)

In [None]:
x_by_trial = windowed_enc_all.reshape((data_loader.n_trials, -1))
y_by_trial = windowed_y.reshape((data_loader.n_trials, -1))
x_train, x_test = x_by_trial[train], x_by_trial[test]
y_train, y_test = y_by_trial[train], y_by_trial[test]

x_train = x_train.reshape((-1, windowed_enc_all.shape[1]))
x_test = x_test.reshape((-1, windowed_enc_all.shape[1]))
y_train = y_train.flatten()
y_test = y_test.flatten()

In [None]:
ridge = Ridge(alpha=10000)
ridge.fit(x_train, y_train)
y_pred = ridge.predict(x_test)

In [None]:
print(f'R2 = {r2_score(y_test, y_pred):.3f}')
print(f'MSE = {mean_squared_error(y_test, y_pred):.3f}')
print(f'corr = {pearsonr(y_test, y_pred).statistic:.3f}')

In [None]:
plt.figure(figsize=(12, 2))
plt.plot(y_test[:200], c='gray', linestyle='dashed', label='observed');
plt.plot(y_pred[:200], c='blue', alpha=.6, label='predicted');

#### thresholded

In [None]:
windowed_thresh, half_window_size, n_windows = decoder.sliding_window(
    thresholded_neural_data, 
    data_loader.n_trials,
    window_size = 7
)
windowed_y = data_loader.y[:,half_window_size:n_windows].reshape(-1,1)

In [None]:
x_by_trial = windowed_thresh.reshape((data_loader.n_trials, -1))
y_by_trial = windowed_y.reshape((data_loader.n_trials, -1))
x_train, x_test = x_by_trial[train], x_by_trial[test]
y_train, y_test = y_by_trial[train], y_by_trial[test]

x_train = x_train.reshape((-1, windowed_thresh.shape[1]))
x_test = x_test.reshape((-1, windowed_thresh.shape[1]))
y_train = y_train.flatten()
y_test = y_test.flatten()

In [None]:
ridge = Ridge(alpha=10000)
ridge.fit(x_train, y_train)
y_pred = ridge.predict(x_test)

In [None]:
print(f'R2 = {r2_score(y_test, y_pred):.3f}')
print(f'MSE = {mean_squared_error(y_test, y_pred):.3f}')
print(f'corr = {pearsonr(y_test, y_pred).statistic:.3f}')

In [None]:
plt.figure(figsize=(12, 2))
plt.plot(y_test[:200], c='gray', linestyle='dashed', label='observed');
plt.plot(y_pred[:200], c='blue', alpha=.6, label='predicted');

#### KS & good IBL units

In [None]:
n_trials = stim_on_times.shape[0]
sorted = np.vstack([sorted_trials[i] for i in np.arange(n_trials)]) 
spike_times = sorted[:,0]
spike_clusters = sorted[:,1]

sorted_neural_data = preprocess.compute_neural_activity(
    (spike_times, spike_clusters),
    stim_on_times,
    'sorted', 
    n_time_bins=30,
    regional=True
)
print(f'sorted neural data shape: {sorted_neural_data.shape}')

good_sorted = np.vstack([good_sorted_trials[i] for i in np.arange(n_trials)]) 
spike_times = good_sorted[:,0]
spike_clusters = good_sorted[:,1]

good_sorted_neural_data = preprocess.compute_neural_activity(
    (spike_times, spike_clusters),
    stim_on_times,
    'sorted', 
    n_time_bins=30,
    regional=True
)
print(f'good sorted neural data shape: {good_sorted_neural_data.shape}')

In [None]:
x_train = sorted_neural_data.reshape(-1, sorted_neural_data.shape[1] * sorted_neural_data.shape[2])[train]
x_test = sorted_neural_data.reshape(-1, sorted_neural_data.shape[1] * sorted_neural_data.shape[2])[test]
y_train = data_loader.y[train]

ridge = Ridge(alpha=2000)
ridge.fit(x_train, y_train)
y_hat = ridge.predict(x_test)

In [None]:
print(f'R2 = {r2_score(data_loader.y[test], y_hat):.3f}')
print(f'MSE = {mean_squared_error(data_loader.y[test], y_hat):.3f}')
print(f'corr = {pearsonr(data_loader.y[test].flatten(), y_hat.flatten()).statistic:.3f}')

In [None]:
plt.figure(figsize=(12, 2))
plt.plot(data_loader.y[test].flatten()[:200], c='gray', linestyle='dashed', label='observed');
plt.plot(y_hat.flatten()[:200], c='blue', alpha=.6, label='predicted');

In [None]:
import time
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import WhiteKernel, ExpSineSquared, RationalQuadratic, RBF

long_term_trend_kernel = 50.0**2 * RBF(length_scale=50.0)
seasonal_kernel = (
    2.0**2
    * RBF(length_scale=100.0)
    * ExpSineSquared(length_scale=1.0, periodicity=1.0, periodicity_bounds="fixed")
)
# irregularities_kernel = 0.5**2 * RationalQuadratic(length_scale=1.0, alpha=1.0)
noise_kernel = 0.1**2 * RBF(length_scale=0.1) + WhiteKernel(
    noise_level=0.1**2, noise_level_bounds=(1e-5, 1e5)
)
kernel = (
    long_term_trend_kernel #+ seasonal_kernel + noise_kernel #+ irregularities_kernel
)
gaussian_process = GaussianProcessRegressor(kernel=kernel, alpha=1e-2)
start_time = time.time()
gaussian_process.fit(x_train, y_train)
print(
    f"Time for GaussianProcessRegressor fitting: {time.time() - start_time:.3f} seconds"
)
gp_pred = gaussian_process.predict(x_test)
print(f'R2 = {r2_score(data_loader.y[test], gp_pred):.3f}')
print(f'MSE = {mean_squared_error(data_loader.y[test], gp_pred):.3f}')
print(f'corr = {pearsonr(data_loader.y[test].flatten(), gp_pred.flatten()).statistic:.3f}')

In [None]:
windowed_sorted, half_window_size, n_windows = decoder.sliding_window(
    sorted_neural_data, 
    data_loader.n_trials,
    window_size = 7
)

windowed_good_units, half_window_size, n_windows = decoder.sliding_window(
    good_sorted_neural_data, 
    data_loader.n_trials,
    window_size = 7
)
windowed_y = data_loader.y[:,half_window_size:n_windows].reshape(-1,1)

In [None]:
x_by_trial = windowed_sorted.reshape((data_loader.n_trials, -1))
y_by_trial = windowed_y.reshape((data_loader.n_trials, -1))
x_train, x_test = x_by_trial[train], x_by_trial[test]
y_train, y_test = y_by_trial[train], y_by_trial[test]

x_train = x_train.reshape((-1, windowed_sorted.shape[1]))
x_test = x_test.reshape((-1, windowed_sorted.shape[1]))
y_train = y_train.flatten()
y_test = y_test.flatten()

ridge = Ridge(alpha=10000)
ridge.fit(x_train, y_train)
y_pred = ridge.predict(x_test)

In [None]:
print(f'R2 = {r2_score(y_test, y_pred):.3f}')
print(f'MSE = {mean_squared_error(y_test, y_pred):.3f}')
print(f'corr = {pearsonr(y_test, y_pred).statistic:.3f}')

In [None]:
plt.figure(figsize=(12, 2))
plt.plot(y_test[:200], c='gray', linestyle='dashed', label='observed');
plt.plot(y_pred[:200], c='blue', alpha=.6, label='predicted');

In [None]:
x_train = good_sorted_neural_data.reshape(-1, good_sorted_neural_data.shape[1] * good_sorted_neural_data.shape[2])[train]
x_test = good_sorted_neural_data.reshape(-1, good_sorted_neural_data.shape[1] * good_sorted_neural_data.shape[2])[test]
y_train = data_loader.y[train]

ridge = Ridge(alpha=2000)
ridge.fit(x_train, y_train)
y_hat = ridge.predict(x_test)

In [None]:
print(f'R2 = {r2_score(data_loader.y[test], y_hat):.3f}')
print(f'MSE = {mean_squared_error(data_loader.y[test], y_hat):.3f}')
print(f'corr = {pearsonr(data_loader.y[test].flatten(), y_hat.flatten()).statistic:.3f}')

In [None]:
import time
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import WhiteKernel, ExpSineSquared, RationalQuadratic, RBF

long_term_trend_kernel = 50.0**2 * RBF(length_scale=50.0)
seasonal_kernel = (
    2.0**2
    * RBF(length_scale=100.0)
    * ExpSineSquared(length_scale=1.0, periodicity=1.0, periodicity_bounds="fixed")
)
# irregularities_kernel = 0.5**2 * RationalQuadratic(length_scale=1.0, alpha=1.0)
noise_kernel = 0.1**2 * RBF(length_scale=0.1) + WhiteKernel(
    noise_level=0.1**2, noise_level_bounds=(1e-5, 1e5)
)
kernel = (
    long_term_trend_kernel #+ seasonal_kernel + noise_kernel #+ irregularities_kernel
)
gaussian_process = GaussianProcessRegressor(kernel=kernel, alpha=1e-2)
start_time = time.time()
gaussian_process.fit(x_train, y_train)
print(
    f"Time for GaussianProcessRegressor fitting: {time.time() - start_time:.3f} seconds"
)
gp_pred = gaussian_process.predict(x_test)
print(f'R2 = {r2_score(data_loader.y[test], gp_pred):.3f}')
print(f'MSE = {mean_squared_error(data_loader.y[test], gp_pred):.3f}')
print(f'corr = {pearsonr(data_loader.y[test].flatten(), gp_pred.flatten()).statistic:.3f}')

In [None]:
plt.figure(figsize=(12, 2))
plt.plot(data_loader.y[test].flatten()[:200], c='gray', linestyle='dashed', label='observed');
plt.plot(gp_pred.flatten()[:200], c='blue', alpha=.6, label='predicted');

In [None]:
x_by_trial = windowed_good_units.reshape((data_loader.n_trials, -1))
y_by_trial = windowed_y.reshape((data_loader.n_trials, -1))
x_train, x_test = x_by_trial[train], x_by_trial[test]
y_train, y_test = y_by_trial[train], y_by_trial[test]

x_train = x_train.reshape((-1, windowed_good_units.shape[1]))
x_test = x_test.reshape((-1, windowed_good_units.shape[1]))
y_train = y_train.flatten()
y_test = y_test.flatten()

ridge = Ridge(alpha=2000)
ridge.fit(x_train, y_train)
y_pred = ridge.predict(x_test)

In [None]:
print(f'R2 = {r2_score(y_test, y_pred):.3f}')
print(f'MSE = {mean_squared_error(y_test, y_pred):.3f}')
print(f'corr = {pearsonr(y_test, y_pred).statistic:.3f}')

In [None]:
plt.figure(figsize=(12, 2))
plt.plot(y_test[:200], c='gray', linestyle='dashed', label='observed');
plt.plot(y_pred[:200], c='blue', alpha=.6, label='predicted');

#### plotting

In [79]:
po_good_units_r2 = [0.008, 0.058, 0.066, 0.046, 0.068]
po_all_units_r2 = [0.159, 0.164, 0.119, 0.135, 0.140]
po_thresh_r2 = [0.138, 0.123, 0.099, 0.130, 0.145]
po_mog_r2 = [0.102, 0.048, 0.028, 0.064, 0.049]
po_enc_mog_r2 = [0.154, 0.170, 0.104, 0.135, 0.147]

po_good_units_mse = [1.685, 1.219, 1.447, 1.423, 1.535]
po_all_units_mse = [1.175, 0.889, 1.284, 1.193, 1.253]
po_thresh_mse = [1.248, 1.011, 1.323, 1.209, 1.216]
po_mog_mse = [1.329, 1.281, 1.547, 1.378, 1.597]
po_enc_mog_mse = [1.180, 0.892, 1.308, 1.191, 1.239]

po_good_units_corr = []
po_all_units_corr = []
po_thresh_corr = []
po_mog_corr = []
po_enc_mog_corr = []

In [81]:
print(f'PO good units R2: {np.mean(po_good_units_r2):.3f} MSE: {np.mean(po_good_units_mse):.3f}')
print(f'PO all units R2: {np.mean(po_all_units_r2):.3f} MSE: {np.mean(po_all_units_mse):.3f}')
print(f'PO thresholded R2: {np.mean(po_thresh_r2):.3f} MSE: {np.mean(po_thresh_mse):.3f}')
print(f'PO MoG R2: {np.mean(po_mog_r2):.3f} MSE: {np.mean(po_mog_mse):.3f}')
print(f'PO encoded-MoG R2: {np.mean(po_enc_mog_r2):.3f} MSE: {np.mean(po_enc_mog_mse):.3f}')

PO good units R2: 0.049 MSE: 1.462
PO all units R2: 0.143 MSE: 1.159
PO thresholded R2: 0.127 MSE: 1.201
PO MoG R2: 0.058 MSE: 1.426
PO encoded-MoG R2: 0.142 MSE: 1.162


In [None]:
lp_good_units_r2 = [0.004, ]
lp_all_units_r2 = [0.157, ]
lp_thresh_r2 = [0.126, ]
lp_mog_r2 = [0.053, ]
lp_enc_mog_r2 = [0.151, ]

lp_good_units_mse = [1.646, ]
lp_all_units_mse = [1.122, ]
lp_thresh_mse = [1.235, ]
lp_mog_mse = [1.513, ]
lp_enc_mog_mse = [1.136, ]

lp_good_units_corr = [0.407, ]
lp_all_units_corr = [0.667, ]
lp_thresh_corr = [0.614, ]
lp_mog_corr = [0.484, ]
lp_enc_mog_corr = [0.662, ]

In [None]:
dg_good_units_mse = []
dg_all_units_mse = []
dg_thresh_mse = []
dg_mog_mse = []
dg_enc_mog_mse = []

ca1_good_units_mse = []
ca1_all_units_mse = []
ca1_thresh_mse = []
ca1_mog_mse = []
ca1_enc_mog_mse = []

vis_good_units_mse = []
vis_all_units_mse = []
vis_thresh_mse = []
vis_mog_mse = []
vis_enc_mog_mse = []