In [21]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pds
import os
import pickle
import re
from collections import defaultdict

In [22]:
# Directory holding the pickled Fig2 simulation outputs. Set the
# STRIATUM_FIG2_OUTPUT environment variable to point elsewhere instead of
# editing this machine-specific absolute path.
output_dir = os.environ.get(
    'STRIATUM_FIG2_OUTPUT',
    '/Users/peirui/code/striatum-microcircuit/striatum-microcircuit/Fig2/output/',
)
# Analysis window, in seconds of simulated time.
start_time = 1
end_time = 3
# Number of neurons in each MSN population (D1 or D2).
msn_size = 2000

In [23]:
def load_data(parms_set, output_dir):
    '''
    Load the pickled simulation outputs of one parameter set and organize them into dicts.

    Filenames in `output_dir` are split on '_': field 0 is the data category
    ('spk' or 'lfp'), field 1 the parameter-set name, field 3 the MSN subtype
    ('d1' or 'd2') and field 4 the sort key (the q value; one simulation has
    length 21). Files with fewer than five fields are skipped.

    Parameters
    ----------
    parms_set : str
        Parameter-set name compared with filename field 1, e.g. 'Control' or 'PD'.
        NOTE(review): names that themselves contain '_' (PD_GluInh, Ctl_Glu) can
        never equal a single underscore-separated field, so they match no file
        with this parsing scheme -- TODO confirm how those files are named.
    output_dir : str
        Directory containing the pickle files.

    Returns
    -------
    (dict, dict)
        data_dict_spk: {'ts_d1': [...], 'ts_d2': [...]} -- evs['events']['times']
        of each matching spk file, in sorted order.
        data_dict_lfp: {'d1_times', 'd1_g_ex', 'd1_g_in', 'd1_V_m', and the same
        four keys for d2} -- the corresponding entries of each matching lfp file.
    '''
    data_dict_spk = {'ts_d1': [], 'ts_d2': []}
    data_dict_lfp = {
        'd1_times': [], 'd1_g_ex': [], 'd1_g_in': [], 'd1_V_m': [],
        'd2_times': [], 'd2_g_ex': [], 'd2_g_in': [], 'd2_V_m': []
    }

    # Select this parameter set's files and sort them by field 4 (the q value),
    # ties broken by filename. NOTE: this is a string sort, so e.g. '10' < '2'.
    selected = []
    for filename in os.listdir(output_dir):
        char = filename.split('_')
        # Skip files that do not follow the naming scheme (e.g. '.DS_Store', 'README').
        if len(char) > 4 and char[1] == parms_set:
            selected.append((char[4], filename))
    selected.sort()

    for _, sort_file in selected:
        char = sort_file.split('_')
        category, subtype = char[0], char[3]
        # Only spk/lfp files of the d1/d2 populations are used; don't unpickle the rest.
        if category not in ('spk', 'lfp') or subtype not in ('d1', 'd2'):
            continue

        filepath = os.path.join(output_dir, sort_file)
        # NOTE: pickle.load can execute arbitrary code -- only use on trusted simulation output.
        with open(filepath, "rb") as file:
            evs = pickle.load(file)

        if category == 'spk':
            data_dict_spk[f'ts_{subtype}'].append(evs['events']['times'])
        else:  # 'lfp'
            for key in ('times', 'g_ex', 'g_in', 'V_m'):
                data_dict_lfp[f'{subtype}_{key}'].append(evs[key])

    return data_dict_spk, data_dict_lfp

In [24]:
def proecess_data_spk_firing(data_dict_spk, start_time, end_time, msn_size):
    '''
    Mean per-neuron firing rate of the D1 and D2 MSN populations.

    For every run, spikes strictly inside (start_time, end_time) -- compared in ms --
    are counted and divided by the window length in seconds and by msn_size.

    Parameters
    ----------
    data_dict_spk : dict
        {'ts_d1': [...], 'ts_d2': [...]}: one spike-time array (ms) per run.
    start_time, end_time : float
        Window bounds in seconds.
    msn_size : int
        Number of neurons in each population (D1 or D2).

    Returns
    -------
    (list, list)
        Firing rate per run for D1 and for D2.
    '''
    window_lo = start_time * 1000
    window_hi = end_time * 1000
    seconds = float(end_time - start_time)
    per_population = {'d1': [], 'd2': []}

    for label, runs in data_dict_spk.items():
        population = label.split('_')[1]
        rates = []
        for times in runs:
            n_spikes = np.sum((times > window_lo) & (times < window_hi))
            rates.append(n_spikes / seconds / float(msn_size))
        if population in per_population:
            per_population[population].append(rates)

    return per_population['d1'][0], per_population['d2'][0]



In [25]:
def process_data_spk_synchrony_index(data_dict_spk, start_time, end_time, bin_ms=5):
    '''
    Synchrony index (variance / mean of binned spike counts) of D1 and D2 MSNs.

    Spikes of each run are counted in right-closed bins (t0, t0 + bin_ms] for
    t0 = start_ms, start_ms + bin_ms, ... < end_ms, and the index is
    var(counts) / mean(counts) (Fano factor, refer to Yim et al. 2011).
    A run with no spikes in the window gets 0.

    Parameters
    ----------
    data_dict_spk : dict
        {'ts_d1': [...], 'ts_d2': [...]}: one spike-time array (ms) per run.
    start_time, end_time : float
        Analysis window in seconds (truncated to whole ms).
    bin_ms : int or float, optional
        Bin width in ms (default 5).

    Returns
    -------
    (list, list)
        Synchrony index per run for D1 and for D2, in input order.
    '''
    syn_inx_d1 = []
    syn_inx_d2 = []
    start_ms = int(start_time * 1000)
    end_ms = int(end_time * 1000)

    # Bin edges: left edges match range(start_ms, end_ms, bin_ms).
    bin_lo = np.arange(start_ms, end_ms, bin_ms)
    bin_hi = bin_lo + bin_ms

    for key, spike_times_list in data_dict_spk.items():
        population = key.split('_')[1]
        for spk_times in spike_times_list:
            # count(lo < t <= hi) = count(t <= hi) - count(t <= lo). On sorted times,
            # searchsorted(side='right') yields count(t <= x), so one sort replaces a
            # full scan of the spike array for every bin.
            times = np.sort(np.asarray(spk_times, dtype=float))
            counts = (np.searchsorted(times, bin_hi, side='right')
                      - np.searchsorted(times, bin_lo, side='right'))

            if counts.size:
                mean_count = np.mean(counts)
                var_count = np.var(counts)
            else:  # empty window
                mean_count = var_count = 0.0
            synchrony_index = var_count / mean_count if mean_count > 0 else 0

            if population == 'd1':
                syn_inx_d1.append(synchrony_index)
            elif population == 'd2':
                syn_inx_d2.append(synchrony_index)

    return syn_inx_d1, syn_inx_d2

In [26]:
def z_score_normalize(data):
    '''
    Z-score `data` along axis 0: (x - mean) / std.

    Entries whose standard deviation is exactly zero (a flat signal, e.g. a
    population with no synaptic input) are mapped to 0 instead of NaN/inf, so
    they do not silently poison later averages.

    Parameters
    ----------
    data : array_like
        Values to normalize; statistics are taken over axis 0.

    Returns
    -------
    numpy.ndarray
        Normalized array with the same shape as `data`.
    '''
    values = np.asarray(data, dtype=float)
    mean = np.mean(values, axis=0)
    std_dev = np.std(values, axis=0)
    flat = std_dev == 0
    # Divide by 1 where std is 0 to avoid the divide-by-zero, then force those entries to 0.
    normalized_data = (values - mean) / np.where(flat, 1.0, std_dev)
    return np.where(flat, 0.0, normalized_data)

def compute_synaptic_current(g_ex, v_m_ex, g_in, v_m_in, E_ex1, E_in1):
    '''
    Summed excitatory and inhibitory synaptic currents, I_syn = g_syn * (V_m - E_syn).

    g_ex / v_m_ex and g_in / v_m_in are element-wise paired conductances and
    membrane potentials; E_ex1 and E_in1 are the corresponding reversal potentials.
    Returns the tuple (sum of excitatory currents, sum of inhibitory currents).
    '''
    driving_force_ex = v_m_ex - E_ex1
    driving_force_in = v_m_in - E_in1
    total_ex = np.sum(g_ex * driving_force_ex)
    total_in = np.sum(g_in * driving_force_in)
    return total_ex, total_in

In [27]:
def process_data_lfp(data_dict_lfp, start_time, end_time, E_ex1, E_in1,
                     n_neurons=2000, ampa_delay_ms=6, alpha=1.65):
    '''
    Build an MSN LFP proxy per simulation from recorded conductances and membrane potentials.

    For every 1 ms step in the analysis window, the summed excitatory and
    inhibitory synaptic currents of each population (D1, D2) are computed with
    compute_synaptic_current. The excitatory current is taken `ampa_delay_ms`
    earlier than the inhibitory one. Each population's proxy is
    z_score_normalize(|I_ex - alpha * I_in|). The returned signal is the mean of
    the D1 and D2 proxies. This follows the reference weighted sum (RWS) LFP proxy
    of Mazzoni et al. 2015.

    Parameters
    ----------
    data_dict_lfp : dict
        Lists (one entry per simulation) under the keys 'd1_g_ex', 'd1_g_in',
        'd1_V_m', 'd2_g_ex', 'd2_g_in', 'd2_V_m'.
    start_time, end_time : float
        Analysis window in seconds.
    E_ex1, E_in1 : float
        Excitatory / inhibitory reversal potentials (E_syn in I_syn = g_syn * (V_m - E_syn)).
    n_neurons : int, optional
        Number of consecutive samples that make up one 1 ms step (default 2000;
        presumably one sample per neuron, cf. msn_size -- TODO confirm).
    ampa_delay_ms : int, optional
        Delay in ms of the excitatory current relative to the inhibitory one (default 6).
    alpha : float, optional
        Weight of the inhibitory current in the proxy (default 1.65).

    Returns
    -------
    list of numpy.ndarray
        One z-scored LFP series (one value per ms step) per simulation.

    Raises
    ------
    ValueError
        If start_time is too early for the delayed excitatory window. Negative
        slice indices would otherwise silently read from the end of the arrays.
    '''
    start_ms = int(start_time * 1000)
    end_ms = int(end_time * 1000)

    # Every n_neurons samples cover 1 ms. The -1 presumably accounts for the first
    # recorded step being t = 1 ms -- TODO confirm against the recorder setup.
    start_idx = int((start_ms - 1) * n_neurons)
    end_idx = int((end_ms - 1) * n_neurons)
    ex_offset = int(ampa_delay_ms * n_neurons)

    if start_idx - ex_offset < 0:
        raise ValueError(
            f'start_time={start_time} s leaves no room for the {ampa_delay_ms} ms '
            'excitatory delay; choose a later start_time'
        )

    LFP_list = []
    for gex_d1, gin_d1, vm_d1, gex_d2, gin_d2, vm_d2 in zip(
            data_dict_lfp['d1_g_ex'], data_dict_lfp['d1_g_in'], data_dict_lfp['d1_V_m'],
            data_dict_lfp['d2_g_ex'], data_dict_lfp['d2_g_in'], data_dict_lfp['d2_V_m']):
        LFP_d1 = _population_lfp_proxy(gex_d1, gin_d1, vm_d1, start_idx, end_idx,
                                       n_neurons, ex_offset, alpha, E_ex1, E_in1)
        LFP_d2 = _population_lfp_proxy(gex_d2, gin_d2, vm_d2, start_idx, end_idx,
                                       n_neurons, ex_offset, alpha, E_ex1, E_in1)
        LFP_list.append((LFP_d1 + LFP_d2) / 2)

    return LFP_list


def _population_lfp_proxy(g_ex, g_in, v_m, start_idx, end_idx, n_neurons,
                          ex_offset, alpha, E_ex1, E_in1):
    '''Z-scored |I_ex(t - delay) - alpha * I_in(t)| of one population, one value per step.'''
    syn_ex, syn_in = [], []
    for i in range(start_idx, end_idx, n_neurons):
        ex_lo = i - ex_offset  # excitatory window lags the inhibitory one by ex_offset samples
        I_ex, I_in = compute_synaptic_current(
            g_ex[ex_lo:ex_lo + n_neurons], v_m[ex_lo:ex_lo + n_neurons],
            g_in[i:i + n_neurons], v_m[i:i + n_neurons], E_ex1, E_in1)
        syn_ex.append(I_ex)
        syn_in.append(I_in)
    # NOTE(review): Mazzoni et al. 2015 write RWS as sum|I_AMPA| - alpha * sum|I_GABA|.
    # Here the absolute value is taken of the difference of the summed currents.
    # Kept as-is to preserve results -- confirm this is intended.
    return z_score_normalize(abs(np.array(syn_ex) - alpha * np.array(syn_in)))


In [None]:
def avg_data(data, chunk_size=21):
    '''
    Element-wise average of consecutive, equally sized chunks of `data`.

    `data` is cut into len(data) // chunk_size chunks of `chunk_size` items;
    trailing items that do not fill a whole chunk are dropped. The chunks are
    averaged position by position. The default of 21 presumably is the number
    of runs per simulation sweep -- TODO confirm.

    Parameters
    ----------
    data : sequence or numpy.ndarray
        Items to average; each item may itself be an array of equal shape.
    chunk_size : int, optional
        Number of items per chunk (default 21).

    Returns
    -------
    numpy.ndarray
        Array of shape (chunk_size, *item_shape).

    Raises
    ------
    ValueError
        If chunk_size < 1 or `data` holds fewer than chunk_size items. The
        original code returned NaN with a warning in that case.
    '''
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')
    num_slices = len(data) // chunk_size
    if num_slices == 0:
        raise ValueError(f'avg_data needs at least {chunk_size} items, got {len(data)}')

    stacked = np.asarray(data[:num_slices * chunk_size])
    stacked = stacked.reshape((num_slices, chunk_size) + stacked.shape[1:])
    return np.mean(stacked, axis=0)

In [28]:
data_dict_spk_ctl, data_dict_lfp_ctl = load_data(parms_set='Control', output_dir=output_dir)

KeyboardInterrupt: 

In [None]:
# Use the configured population size instead of repeating the literal 2000.
d1_firing_ctl, d2_firing_ctl = proecess_data_spk_firing(data_dict_spk_ctl, start_time, end_time, msn_size)

In [None]:
# Synchrony index (variance / mean of binned spike counts) per Control run.
syn_ctl_d1, syn_ctl_d2 = process_data_spk_synchrony_index(
    data_dict_spk_ctl, start_time=start_time, end_time=end_time
)

In [None]:
# E_ex1 / E_in1: reversal potentials of the excitatory / inhibitory synaptic
# currents (presumably mV, like V_m -- TODO confirm).
LFP_msn_ctl = process_data_lfp(
    data_dict_lfp_ctl, start_time, end_time,
    E_ex1=0, E_in1=-64,
)

In [19]:
# Element-wise mean of the D1 synchrony index over consecutive groups of 21 runs.
syn_ctl_d1_avg = avg_data(data=syn_ctl_d1)