In [None]:
#importing libraries
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from caiman.utils.utils import load_dict_from_hdf5

In [None]:

# Execute info_functions.py inside the notebook namespace. The cells below call
# shuffling, compute_tuning_curves, compute_SI and compute_MI, none of which are
# defined or imported in this notebook, so they are presumably provided by that script
# (TODO confirm).
%run info_functions.py

In [None]:
def compute_info_versus_sample_size(spike_train, stimulus_trace, sample_sizes, dt, repetitions, info_measures, shuffle_peaks):

    """
    Computes naive and shuffled information content for multiple sample sizes.

    For every sample size, random subsets of time bins are drawn repeatedly. The
    requested measures are computed on each subset and on a shuffled copy of it,
    and the per-draw values are averaged (ignoring NaNs) for each row of spike_train.

    Arguments
    ----------
    spike_train (np.array): activity matrix of shape (N, T); columns are time bins
    stimulus_trace (np.array): stimulus value per time bin, indexed with the same
        time-bin indices as the columns of spike_train
    sample_sizes (np.array): sample sizes given as fractions of T (e.g. 0.1 ... 1.0);
        each is multiplied by T to obtain the number of time bins per draw
    dt (float): Temporal bin size (in seconds)
    repetitions (int): number of repetitions; ceil(repetitions * T / sample size)
        random draws are made for each sample size
    info_measures (np.array): binary array to indicate measures to compute (size 1*3).
        Entries 0 and 1 both enable the bit_spike / bit_sec outputs (they are always
        computed together); entry 2 enables the MI outputs
    shuffle_peaks: forwarded unchanged as the second argument of `shuffling`


    Returns
    ----------
    results (list of np.ndarray): arrays of shape (N, len(sample_sizes)), in the order
        [info_bit_spike, shuffle_info_bit_spike, info_bit_sec, shuffle_info_bit_sec]
        when entry 0 or 1 of info_measures is set, followed by
        [info_mi, shuffle_info_mi] when entry 2 is set. Empty list if nothing is requested.

    Raises
    ----------
    ValueError: if any sample size corresponds to less than one time bin

    """

    N, T = spike_train.shape

    # Derive the sizes from the argument (not from a notebook-level global) so the
    # function gives the same answer regardless of what else lives in the kernel.
    sample_sizes = np.atleast_1d(np.asarray(sample_sizes, dtype=float)) * T
    if np.any(sample_sizes < 1):
        raise ValueError("every sample size must correspond to at least one time bin")
    nbr_samples = len(sample_sizes)

    want_si = bool(info_measures[0] or info_measures[1])
    want_mi = bool(info_measures[2])

    # nothing requested: skip the (expensive) resampling loop entirely
    if not (want_si or want_mi):
        return []

    def _nan_matrix(n_cols):
        # NaN-filled so that draws which are never written are ignored by nanmean
        return np.full((N, n_cols), np.nan, order='F')

    # output order is fixed by this tuple
    si_keys = ('bit_spike', 'shuffle_bit_spike', 'bit_sec', 'shuffle_bit_sec')
    mi_keys = ('mi', 'shuffle_mi')
    keys = (si_keys if want_si else ()) + (mi_keys if want_mi else ())

    #arrays holding the averaged information content per sample size
    averaged = {key: _nan_matrix(nbr_samples) for key in keys}

    #calculating info for different sample sizes
    for n in range(nbr_samples):

        # smaller samples get proportionally more draws
        col_dim = int(np.ceil(repetitions * T / sample_sizes[n]))
        num_time_bins = int(np.floor(sample_sizes[n]))

        #arrays holding the information content of every individual draw
        draws = {key: _nan_matrix(col_dim) for key in keys}

        for k in range(col_dim):
            # random subset of time bins, without replacement (in random order)
            sample_indexes = np.random.permutation(T)[:num_time_bins]
            sampled_spikes = spike_train[:, sample_indexes]
            sampled_stimulus = stimulus_trace[sample_indexes]

            #shuffling spike trains
            shuffled_spikes = np.squeeze(shuffling('shift', shuffle_peaks, spike_train=sampled_spikes))

            if want_si:
                #computing tuning curves and calculating information content
                tc, states_distribution = compute_tuning_curves(sampled_spikes, sampled_stimulus, dt)
                fr = np.mean(sampled_spikes, axis=1) / dt
                draws['bit_spike'][:, k], draws['bit_sec'][:, k] = compute_SI(fr, tc, states_distribution)

                # the shuffled data reuses the state distribution of the unshuffled sample
                shuffled_tc, _ = compute_tuning_curves(shuffled_spikes, sampled_stimulus, dt)
                shuffled_fr = np.mean(shuffled_spikes, axis=1) / dt
                draws['shuffle_bit_spike'][:, k], draws['shuffle_bit_sec'][:, k] = compute_SI(shuffled_fr, shuffled_tc, states_distribution)

            if want_mi:
                draws['mi'][:, k] = compute_MI(sampled_spikes, sampled_stimulus)
                draws['shuffle_mi'][:, k] = compute_MI(shuffled_spikes, sampled_stimulus)

        #averaging info content across the draws of this sample size
        for key in keys:
            averaged[key][:, n] = np.nanmean(draws[key], axis=1)

    return [averaged[key] for key in keys]

In [None]:
#loading calcium imaging data from the OnACID results file (HDF5)
onacid_results_path = '/Users/namraaamir/Desktop/AD_hipp_analysis/OnACID_results/OnACID_results.hdf5'
ca_data = load_dict_from_hdf5(onacid_results_path)

#listing the fields available in the loaded dictionary
ca_data.keys()

dict_keys(['A', 'C', 'Cn', 'S', 'SNR_comp', 'b', 'cnn_preds', 'dims', 'f', 'r_values'])

In [None]:
#loading behavioral data (pickled dictionary)
#NOTE: unpickling can execute arbitrary code, so only load files from a trusted source
behavior_path = '/Users/namraaamir/Desktop/AD_hipp_analysis/OnACID_results/aligned_behavior.pkl'
with open(behavior_path, mode='rb') as f:
    beh_data = pkl.load(f)

#listing the fields available in the loaded dictionary
beh_data.keys()

dict_keys(['frame', 'time', 'position', 'velocity_ref', 'velocity', 'reward', 'trials', 'bin_position', 'active', 'reward_location', 'reward_prob'])

In [None]:
#binned position trace, copied into a NumPy array
pos_data = np.array(beh_data['bin_position'])

#track distance = range spanned by the binned positions
distance = pos_data.max() - pos_data.min()

In [None]:
#getting trials' start indices: a new trial starts at the sample right after the
#position drops by more than half of the track distance
trial_idx = np.where(np.diff(pos_data) < -(distance/2))[0] + 1

#checking if last trial is complete: pos_data[trial_idx[-1]:][-1] is simply the last
#position sample, so the last trial start is dropped when the recording ends below distance/2.
#The size check avoids an IndexError when no trial start was detected at all.
#NOTE(review): the comparison uses distance/2 as an absolute position, which assumes the
#positions start at 0 - confirm, otherwise compare against np.min(pos_data) + distance/2.
if trial_idx.size > 0 and pos_data[-1] < (distance/2):
    trial_idx = trial_idx[:-1]

In [None]:
#restricting activity and position to the active instances; the same selector
#(beh_data['active']) indexes the time axis of both arrays
active_selector = beh_data['active']
active_spike = ca_data['S'][:, active_selector]
stimulus_trace = pos_data[active_selector]

In [None]:
#shuffling the active spike data once using the 'shift' option; shuffle_peaks is passed
#through to the helper and is reused by the sample-size analysis further down.
#(shuffling is not defined in this notebook, so its exact behavior is not visible here)
shuffle_peaks = False
shuffled_activity = shuffling('shift',shuffle_peaks,spike_train=active_spike)

In [None]:

#temporal bin size (in seconds) and the mean activity of every row expressed per second
dt = 0.05
average_firing_rates = np.mean(active_spike, axis=1) / dt

#fractions of the recording used as sample sizes: 0.1, 0.2, ..., 1.0
sample_fraction = np.arange(0.1, 1.1, 0.1)

#number of repetitions and the measures to compute (binary flags, one per measure)
repetitions = 500
info_measures = [1, 1, 1]

In [None]:
#tuning curves over the full active data set; the helper returns two values, the second
#of which is later passed to compute_SI together with the tuning curves
tuning_curves, stimulus_distribution  = compute_tuning_curves(active_spike,stimulus_trace,dt)

In [None]:
#information computed on the full active data set from the average firing rates, the tuning
#curves and the stimulus distribution; presumably in bits/spike and bits/sec, going by the
#variable names (TODO confirm against compute_SI)
SI_bit_spike, SI_bit_sec = compute_SI(average_firing_rates,tuning_curves, stimulus_distribution)

In [None]:
#MI between activity and stimulus on the full active data set
#(presumably mutual information - TODO confirm against compute_MI)
MI = compute_MI(active_spike,stimulus_trace)

In [None]:
#information content versus sample size, naive and shuffled, for all three measures
#NOTE(review): this call passes 5000 repetitions and a literal [1, 1, 1], while the
#parameter cell defines repetitions = 500 and info_measures - confirm which is intended
(
    SI_naive_bit_spike_versus_sample_size,
    SI_shuffle_bit_spike_versus_sample_size,
    SI_naive_bit_sec_versus_sample_size,
    SI_shuffle_bit_sec_versus_sample_size,
    MI_naive_versus_sample_size,
    MI_shuffle_versus_sample_size,
) = compute_info_versus_sample_size(
    active_spike,
    stimulus_trace,
    sample_fraction,
    dt,
    repetitions=5000,
    info_measures=[1, 1, 1],
    shuffle_peaks=shuffle_peaks,
)