# Prepare spike trains for gpfa and train models

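## !! This notebook saves two datasets via `save_dataset`: `RAW_..._latent_stability_dataset.pkl` (channel-split spike trains) and `TRAJECTORIES_..._latent_stability_dataset.pkl` (PCA & GPFA latent trajectories)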

In [7]:
# # IMPORTS
# %matplotlib inline
# %run -i '/home/jovyan/pablo_tostado/bird_song/manifold_paper_analysis/all_imports.py'

In [8]:
import pickle as pkl
import numpy as np
import quantities as pq

from songbirdcore.statespace_analysis.gpfa_songbirdcore import GPFACore
from songbirdcore.statespace_analysis.pca_songbirdcore import PCACore
from songbirdcore.statespace_analysis.statespace_analysis_utils import convert_to_neo_spike_trains, convert_to_neo_spike_trains_3d

import songbirdcore.spikefinder.spike_analysis_helper as sh
from songbirdcore.utils.params import GlobalParams as gparams
from songbirdcore.utils.data_utils import save_dataset

# Seed NumPy's global RNG once so the random channel split below is reproducible on Restart & Run All
np.random.seed(seed=24)

# Load data

In [9]:
# Session pickle to analyse -- keep exactly one `file_path` uncommented.
# b1, RAW
file_path = '/home/jovyan/pablo_tostado/bird_song/enSongDec/data/RAW_z_w12m7_20_20240325_210721.pkl'
# b2, RAW
# file_path = '/home/jovyan/pablo_tostado/bird_song/enSongDec/data/RAW_z_r12r13_21_20240328_185716.pkl'

# NOTE: pickle.load can execute arbitrary code; only open pickles produced by this pipeline.
with open(file_path, 'rb') as pickle_file:
    state_space_analysis_dict = pkl.load(pickle_file)

print("Loaded Dictionary!")
print(state_space_analysis_dict.keys())

# Unpack the stored fields into module-level names (same names as the dict keys)
(neural_dict, audio_motifs, audio_labels, fs_neural,
 fs_audio, t_pre, t_post, sess_params) = (
    state_space_analysis_dict[field]
    for field in ('neural_dict', 'audio_motifs', 'audio_labels', 'fs_neural',
                  'fs_audio', 't_pre', 't_post', 'sess_params')
)

# Display the loaded t_pre / t_post values
t_pre, t_post

Loaded Dictionary!
dict_keys(['neural_dict', 'audio_motifs', 'audio_labels', 'fs_neural', 'fs_audio', 't_pre', 't_post', 'sess_params'])


(0.1, 0.6)

### Drop silent clusters (keep only clusters whose total spike count is at least the number of trials, i.e. ≥ 1 spike per trial on average)

In [10]:
# Remove rarely-firing clusters: a cluster is dropped when its spike count summed over all
# trials (axis 0) and axis 2 (presumably time samples) is below the number of trials.
for group, spikes in neural_dict.items():
    n_trials = len(spikes)
    total_spikes = np.sum(spikes, axis=(0, 2))
    silent_clusters = np.flatnonzero(total_spikes < n_trials)
    neural_dict[group] = np.delete(spikes, silent_clusters, axis=1)

    print(group, neural_dict[group].shape)

ra_sua (10, 93, 21000)
ra_all (10, 172, 21000)
hvc_sua (10, 46, 21000)
hvc_all (10, 114, 21000)


In [11]:
# ## Generate spiketrains for shuffle control conditions

# neural_groups = list(neural_dict.keys())

# for k in neural_groups:
#     neural_dict[k+'_shuffle_time'] = np.array([permute_array_rows_independently(i) for i in neural_dict[k]])
#     neural_dict[k+'_shuffle_neurons'] = np.array([permute_array_cols_independently(i) for i in neural_dict[k]])

# display(neural_dict.keys())

# Fit GPFA & PCA

In [12]:
# State-space analysis parameters
bin_size = 5 * pq.ms  # bin width shared by the PCA spike-count binning and GPFA
# Latent dimensionalities to fit; one PCA and one GPFA model is trained per entry.
# (Alternative: use the full channel count, neural_dict[key].shape[1].)
latent_dim = [12]

## Loop to fit PCA and GPFA to each neural group of interest (including controls), separately on a random subset of channels and on its complementary subset

In [15]:
# neural_groups = list(neural_dict.keys())
neural_groups = ['ra_all']
latent_models = ['pca', 'gpfa']
# Fraction of channels placed in the 'original' split; the remaining channels form the
# complementary split. Set to >= 1 to fit on all channels without splitting.
neural_samp_perc = 0.5

# Initialize dictionaries to store PCA & GPFA state-space analysis results for each neural group.
# NOTE: the 'complimentary' spelling is kept on purpose -- these keys end up in the saved datasets.
neural_splits = ['original_neural_traces', 'complimentary_neural_traces']
state_space_analysis_dict_raw = {ct: {} for ct in neural_splits}
state_space_analysis_dict_trajectories = {ct: {} for ct in neural_splits}

# Number of raw neural samples summed into one PCA bin. Derived from the bin width in seconds so
# it does not depend on the unit bin_size was declared in, and rounded (not truncated) so float
# error in the product cannot drop a sample (e.g. 149.9999... -> 149).
pca_bin_samples = int(round(float(bin_size.rescale(pq.s).magnitude) * fs_neural))
assert pca_bin_samples >= 1, f'bin_size={bin_size} is shorter than one neural sample'


for ng in neural_groups:

    neural_traces = neural_dict[ng]
    print(f'Processing {ng}')

    # Randomly split the neural channels. Draws from NumPy's global RNG (seeded in the imports
    # cell), so re-running only this cell produces a different split.
    if neural_samp_perc < 1:
        print(f'subsampling {neural_samp_perc*100}% of the neural channels.')
        num_channels = neural_traces.shape[1]
        num_to_sample = round(num_channels * neural_samp_perc)

        all_indices = np.arange(num_channels)
        # Randomly select a subset of indices
        sampled_indices = np.random.choice(all_indices, num_to_sample, replace=False)
        # Find the complementary set of indices
        complementary_indices = np.setdiff1d(all_indices, sampled_indices)

        # Original and complementary neural sets
        combined_neural_traces = {
            neural_splits[0]: neural_traces[:, sampled_indices, :],
            neural_splits[1]: neural_traces[:, complementary_indices, :],
        }
    else:
        # No subsampling: fit the full population as the 'original' split; the complementary
        # split stays empty. Always (re)defining the dict also prevents reusing another group's split.
        combined_neural_traces = {neural_splits[0]: neural_traces}

    for ct, neural_traces in combined_neural_traces.items():
        print(ng, ct, neural_traces.shape)

        # Dictionary to store PCA & GPFA state-space analysis results for ng
        trajectories_dict = {model: {} for model in latent_models}

        for ld in latent_dim:
            # Result key; computed before the skip check so a skipped dimension is recorded under
            # its own key instead of a stale `k` left over from an earlier loop.
            k = ng + '_dim' + str(ld)

            # If not enough clusters for desired number of latent dimensions
            if ld > neural_traces.shape[1]:
                print(f'ld: {ld}, clusters: {neural_traces.shape[1]}. ld > num_clusters: skipping state-space analysis.')
                trajectories_dict['pca'][k] = None
                trajectories_dict['gpfa'][k] = None
                continue

            # --- Fit PCA ---
            myPCA = PCACore(neural_traces, round(fs_neural), audio_motifs, fs_audio, audio_labels, fs_audio)
            myPCA.instantiate_pca(ld)
            # Sum spikes into bins of width bin_size before fitting PCA
            spike_trains = sh.downsample_list_3d(myPCA.neural_traces, number_bin_samples=pca_bin_samples, mode='sum')
            pca_dict = myPCA.fit_transform_pca(spike_trains)
            pca_dict['bin_w'] = bin_size
            trajectories_dict['pca'][k] = pca_dict

            # --- Fit GPFA ---
            myGPFA = GPFACore(neural_traces, round(fs_neural), audio_motifs, fs_audio, audio_labels, fs_audio)
            myGPFA.instantiate_gpfa(bin_size, ld, em_max_iters=gparams.gpfa_max_iter)
            # GPFA is fit on neo spike trains built from the target neural data
            spike_trains = convert_to_neo_spike_trains_3d(myGPFA.neural_traces, myGPFA.fs_neural)
            trajectories_dict['gpfa'][k] = myGPFA.fit_transform_gpfa(spike_trains)

        state_space_analysis_dict_raw[ct][ng] = neural_traces
        state_space_analysis_dict_trajectories[ct][ng] = trajectories_dict


# Save params
# NOTE(review): absolute, user-specific output path -- consider making it configurable.
dir_path = '/home/jovyan/pablo_tostado/bird_song/enSongDec/data/'
# Tag the files with every processed group, not just the last value of the loop variable `ng`
filename_appendix = '_'.join(neural_groups) + '_latent_stability_dataset'

# Raw spike trains keep the original neural sampling rate
raw_fs_neural = fs_neural
# Latent trajectories have one sample per bin, so their rate is 1 / bin width (Hz).
# Rescaling to ms (instead of int(bin_size)) avoids truncating fractional bin widths and
# stays correct whatever unit bin_size was declared in.
traj_fs_neural = float(1000.0 / bin_size.rescale(pq.ms).magnitude)

# Save RAW first, then TRAJECTORIES; the two calls differ only in payload, sampling rate and tag.
for file_type, dataset, dataset_fs_neural in (
    ('RAW', state_space_analysis_dict_raw, raw_fs_neural),
    ('TRAJECTORIES', state_space_analysis_dict_trajectories, traj_fs_neural),
):
    save_dataset(
        dataset,
        audio_motifs,
        audio_labels,
        dataset_fs_neural,
        fs_audio,
        t_pre,
        t_post,
        sess_params,
        file_type,
        dir_path,
        filename_appendix)
    

Processing ra_all
subsampling 50.0% of the neural channels.
ra_all original_neural_traces (10, 86, 21000)
Initializing parameters using factor analysis...

Fitting GPFA model...
ra_all complimentary_neural_traces (10, 86, 21000)




Initializing parameters using factor analysis...

Fitting GPFA model...




Dictionary saved as /home/jovyan/pablo_tostado/bird_song/enSongDec/data/RAW_z_w12m7_20_20240509_051627_ra_all_latent_stability_dataset.pkl
Dictionary saved as /home/jovyan/pablo_tostado/bird_song/enSongDec/data/TRAJECTORIES_z_w12m7_20_20240509_051630_ra_all_latent_stability_dataset.pkl


In [16]:
def print_dict_tree(d, indent=0):
    """Print the keys of a (possibly nested) dict as an indented tree.

    Only keys are printed; each nesting level adds two spaces of indentation.
    Values that are themselves dicts are descended into recursively; all
    other values are not printed.

    Args:
        d: Mapping whose keys should be printed.
        indent: Current nesting depth (incremented on each recursive call).
    """
    prefix = '  ' * indent
    for name, child in d.items():
        print(prefix + str(name))
        if isinstance(child, dict):
            print_dict_tree(child, indent=indent + 1)


# print_dict_tree(state_space_analysis_dict)