# Prepare spike trains for GPFA and fit PCA/GPFA models

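## !! This notebook generates gpfa_dict.pkl (the run recorded below saved it as `TRAJECTORIES_<session>_<timestamp>_100%_neural_channels.pkl.pkl`)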

In [17]:
# # IMPORTS
# %matplotlib inline
# %run -i '/home/jovyan/pablo_tostado/bird_song/manifold_paper_analysis/all_imports.py'

In [18]:
import pickle as pkl
import numpy as np
import quantities as pq

from songbirdcore.statespace_analysis.gpfa_songbirdcore import GPFACore
from songbirdcore.statespace_analysis.pca_songbirdcore import PCACore
from songbirdcore.statespace_analysis.statespace_analysis_utils import convert_to_neo_spike_trains, convert_to_neo_spike_trains_3d

import songbirdcore.spikefinder.spike_analysis_helper as sh
from songbirdcore.params import GlobalParams as gparams
from songbirdcore.utils.data_utils import save_dataset

# Load data

In [23]:
# Pickled RAW dataset for this session.
# NOTE(review): absolute, user-specific path — consider a configurable data directory.
file_path = '/home/jovyan/pablo_tostado/bird_song/enSongDec/data/RAW_z_r12r13_21_20240328_185716.pkl'

# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open(file_path, 'rb') as pickle_file:
    state_space_analysis_dict = pkl.load(pickle_file)

print("Loaded Dictionary!")
print(state_space_analysis_dict.keys())

# Pull every dataset field out into its own notebook-level variable in one pass.
(
    neural_dict,
    audio_motifs,
    audio_labels,
    fs_neural,
    fs_audio,
    t_pre,
    t_post,
    sess_params,
) = (
    state_space_analysis_dict[field]
    for field in (
        'neural_dict',
        'audio_motifs',
        'audio_labels',
        'fs_neural',
        'fs_audio',
        't_pre',
        't_post',
        'sess_params',
    )
)

# Echo t_pre and t_post as the cell output.
t_pre, t_post

Loaded Dictionary!
dict_keys(['neural_dict', 'audio_motifs', 'audio_labels', 'fs_neural', 'fs_audio', 't_pre', 't_post', 'sess_params'])


(0.1, 0.7999999999999999)

### Drop silent clusters (keep clusters averaging at least 1 spike per trial)

In [24]:
# Filter each neural group, keeping only clusters that are active enough.
# Arrays are indexed as (axis 0, axis 1, axis 2); axis 1 is the one being filtered,
# presumably (trials, clusters, samples) — TODO confirm.
for k, traces in neural_dict.items():
    num_trials = len(traces)

    # Total spike count of every axis-1 entry, summed over axes 0 and 2.
    spikes_per_cluster = np.sum(traces, axis=(0, 2))

    # A cluster survives when its total is not below the number of trials,
    # i.e. it averages at least one spike per trial.
    keep = ~(spikes_per_cluster < num_trials)
    neural_dict[k] = traces[:, keep, :]

    print(k, neural_dict[k].shape)

ra_sua (47, 61, 26999)
ra_all (47, 84, 26999)
hvc_sua (47, 20, 26999)
hvc_all (47, 50, 26999)


In [25]:
# ## Generate spiketrains for shuffle control conditions

# neural_groups = list(neural_dict.keys())

# for k in neural_groups:
#     neural_dict[k+'_shuffle_time'] = np.array([permute_array_rows_independently(i) for i in neural_dict[k]])
#     neural_dict[k+'_shuffle_neurons'] = np.array([permute_array_cols_independently(i) for i in neural_dict[k]])

# display(neural_dict.keys())

# Fit GPFA & PCA

In [30]:
# State-Space Analysis Params
# bin_size: width of the time bins, as a `quantities` value in milliseconds.
# Other cells in this notebook divide it by 1000 and call int() on it, so they
# rely on the unit staying milliseconds.
bin_size = 15 * pq.ms
# latent_dim: every latent dimensionality to fit, one PCA and one GPFA model each.
# A value larger than the number of clusters in a group is skipped by the fit loop.
latent_dim = [2, 4, 8, 12, 24] # neural_dict[key].shape[1]

## Loop to fit PCA and GPFA to each neural group of interest (including controls)

In [31]:
# neural_groups = list(neural_dict.keys())
neural_groups = ['ra_all', 'hvc_all']
latent_models = ['pca', 'gpfa']
neural_samp_perc = 1  # fraction of channels to keep; values < 1 enable random subsampling

# Dictionary to store PCA & GPFA state-space analysis results for each neural group
state_space_analysis_dict = {}
# Traces actually used for fitting (after optional channel subsampling), per neural group
resampled_neural_dict = {}

for ng in neural_groups:

    neural_traces = neural_dict[ng]
    print(f'Processing {ng}')

    # Randomly sample neural channels (axis 1) without replacement.
    # NOTE(review): np.random.choice is not seeded here, so a subsampled run
    # cannot be reproduced exactly.
    if neural_samp_perc < 1:
        print(f'subsampling {neural_samp_perc*100}% of the neural channels.')
        num_channels = neural_traces.shape[1]
        num_to_sample = round(num_channels * neural_samp_perc)
        sampled_indices = np.random.choice(num_channels, num_to_sample, replace=False)
        neural_traces = neural_traces[:, sampled_indices, :]
        print(ng, neural_traces.shape)

    # Remember which traces the models were fitted on, so they can be saved alongside.
    resampled_neural_dict[ng] = neural_traces

    # Dictionary to store PCA & GPFA state-space analysis results for ng
    trajectories_dict = {model_name: {} for model_name in latent_models}

    for ld in latent_dim:

        # Build the result key up front. It must exist before the skip branch
        # below; otherwise the None placeholders land under a stale key left
        # over from a previous iteration (or from another cell) and can
        # overwrite an already fitted result.
        k = ng+'_dim'+str(ld)

        # If not enough clusters for desired number of latent dimensions
        if ld > neural_traces.shape[1]:
            print(f'ld: {ld}, clusters: {neural_traces.shape[1]}. ld > num_clusters: skipping state-space analysis.')
            trajectories_dict['pca'][k] = None
            trajectories_dict['gpfa'][k] = None
            continue

        # ---- Fit PCA ----
        # Instantiate PCA
        myPCA = PCACore(neural_traces, round(fs_neural), audio_motifs, fs_audio, audio_labels, fs_audio)
        myPCA.instantiate_pca(ld)

        # Downsample spiketrains by summing over bins of
        # bin_size (ms) / 1000 * fs_neural samples; int() truncates the product.
        spike_trains = sh.downsample_list_3d(myPCA.neural_traces, number_bin_samples=int(bin_size/1000*fs_neural), mode='sum')

        # Fit PCA and record the bin width next to the result
        pca_dict = myPCA.fit_transform_pca(spike_trains)
        pca_dict['bin_w'] = bin_size
        trajectories_dict['pca'][k] = pca_dict

        # ---- Fit GPFA ----
        # Instantiate GPFA
        myGPFA = GPFACore(neural_traces, round(fs_neural), audio_motifs, fs_audio, audio_labels, fs_audio)

        # Run GPFA on the spike trains of the target neural data
        myGPFA.instantiate_gpfa(bin_size, ld, em_max_iters=gparams.gpfa_max_iter)

        spike_trains = convert_to_neo_spike_trains_3d(myGPFA.neural_traces, myGPFA.fs_neural)
        trajectories_dict['gpfa'][k] = myGPFA.fit_transform_gpfa(spike_trains)

    state_space_analysis_dict[ng] = trajectories_dict

# Where the results are written, and the suffix that tags the channel fraction used
dir_path = '/home/jovyan/pablo_tostado/bird_song/enSongDec/data/'
filename_appendix = f"{neural_samp_perc*100}%_neural_channels"

# # Save RAW
# file_type = 'RAW'
# raw_fs_neural = fs_neural
# save_dataset(
#     resampled_neural_dict,
#     audio_motifs,
#     audio_labels,
#     raw_fs_neural,
#     fs_audio,
#     t_pre,
#     t_post,
#     sess_params, 
#     file_type,
#     dir_path, 
#     filename_appendix)

# Save trajectories
file_type = 'TRAJECTORIES'

# Rate of the binned trajectories: int() takes the magnitude of bin_size, and
# the factor 1000 turns "per millisecond" into "per second" (bin_size is in ms).
traj_fs_neural = 1/int(bin_size)*1000

# Positional arguments, in the order save_dataset is called with elsewhere in this cell.
trajectory_dataset_args = (
    state_space_analysis_dict,
    audio_motifs,
    audio_labels,
    traj_fs_neural,
    fs_audio,
    t_pre,
    t_post,
    sess_params,
    file_type,
    dir_path,
    filename_appendix,
)
save_dataset(*trajectory_dataset_args)
    

Processing ra_all




Initializing parameters using factor analysis...

Fitting GPFA model...
Fitting has converged after 345 EM iterations.)




Initializing parameters using factor analysis...

Fitting GPFA model...
Fitting has converged after 1560 EM iterations.)




Initializing parameters using factor analysis...

Fitting GPFA model...




Initializing parameters using factor analysis...

Fitting GPFA model...




Initializing parameters using factor analysis...

Fitting GPFA model...




Processing hvc_all




Initializing parameters using factor analysis...

Fitting GPFA model...
Fitting has converged after 1240 EM iterations.)




Initializing parameters using factor analysis...

Fitting GPFA model...




Initializing parameters using factor analysis...

Fitting GPFA model...




Initializing parameters using factor analysis...

Fitting GPFA model...




Initializing parameters using factor analysis...

Fitting GPFA model...




Dictionary saved as /home/jovyan/pablo_tostado/bird_song/enSongDec/data/TRAJECTORIES_z_r12r13_21_20240426_030454_100%_neural_channels.pkl.pkl


In [33]:
# Echo the nested results: neural group -> model ('pca' / 'gpfa') -> '<group>_dim<ld>' -> fit output.
# NOTE(review): this prints every trajectory array in full; print_dict_tree
# gives a compact view of the same structure.
state_space_analysis_dict

{'ra_all': {'pca': {'ra_all_dim2': {'model': PCA(n_components=2),
    'trajectories': array([[[ 9.43426693e-01,  1.94358226e+00, -8.69310860e-01, ...,
             -3.03830052e+00, -3.09324700e+00, -2.28961565e+00],
            [ 1.50721962e+00, -1.26186622e-01,  2.14460492e+00, ...,
             -1.56076583e-01,  9.69385181e-02,  4.85035806e-01]],
    
           [[-1.45526056e+00, -1.00242701e+00, -5.31890172e-01, ...,
             -2.54725205e+00, -3.24718633e+00, -2.95089532e+00],
            [-5.90985099e-01, -1.68681203e+00, -2.21070046e+00, ...,
              4.70906793e-01, -2.27490467e-01,  4.71389878e-02]],
    
           [[-1.88175655e+00, -7.35268109e-01, -8.47983036e-01, ...,
             -2.79086185e+00, -2.96947442e+00, -3.08902978e+00],
            [-4.78506880e-01, -1.98133466e+00, -2.99598781e+00, ...,
             -9.52976045e-01, -1.47078286e+00, -1.31619729e+00]],
    
           ...,
    
           [[-2.10821246e+00, -4.44286504e-01,  1.36405796e+00, ...,
      

In [35]:
def print_dict_tree(d, indent=0):
    """Print the keys of a nested dictionary as an indented outline.

    Each key is printed on its own line, prefixed by two spaces per nesting
    level. Values that are themselves dicts are descended into recursively;
    all other values are ignored (only their key is shown).

    Args:
        d: Dictionary (possibly nested) whose keys should be listed.
        indent: Nesting level of ``d``; controls the leading whitespace.

    Returns:
        None. The outline is written to standard output.
    """
    pad = '  ' * indent
    for key, child in d.items():
        print(f"{pad}{key!s}")
        if not isinstance(child, dict):
            continue
        # Nested dict: list its keys one level deeper.
        print_dict_tree(child, indent + 1)


# print_dict_tree(state_space_analysis_dict)