In [47]:
import numpy as np 
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import glob
import pickle
from scipy.linalg import hankel
from scipy.stats import norm, zscore
import statsmodels.api as sm
import matplotx
from matplotlib.backends.backend_pdf import PdfPages
plt.style.use(matplotx.styles.aura['dark'])

In [95]:
with open('test_data_dms_ind_492_0603.pickle', 'rb') as handle:
    data = pickle.load(handle)
print(data.keys())
n_neurons = len(data['spikes'])
print(f'n_neurons: {n_neurons}')


dict_keys(['nCues_RminusL', 'currMaze', 'laserON', 'trialStart', 'trialEnd', 'keyFrames', 'time', 'cueOnset_L', 'cueOnset_R', 'choice', 'trialType', 'spikes', 'timeSqueezedFR'])
n_neurons: 173


In [96]:
neuron = 12

psth_figs = PdfPages("arm_enter_psth_dms.pdf")

for neuron in range(n_neurons):
    trial_indices = np.nonzero((data["currMaze"] > 7)*(data['laserON'] == 0))[0]
    print(f'number of trials: {trial_indices.size}')
    X = []
    y = []
    filt_len = 30
    sp_filt_len = 30
    bin_size = 0.35
    trial_id = []
    bin_positions = []
    bin_centers_all = []
    time_before = 1
    time_after = 1
    spikes_all = []

    for trial_idx in trial_indices:
        keyframes = data["keyFrames"][trial_idx]
        keyframe_times = data["time"][trial_idx][keyframes.astype(int)].tolist()
        trial_start = data["trialStart"][trial_idx]
        # trial_end = data["trialEnd"][trial_idx]
        # lcue_times = data["cueOnset_L"][trial_idx] 
        # rcue_times = data["cueOnset_R"][trial_idx]
        # trial_start = keyframe_times[0] + data["trialStart"][trial_idx]
        trial_end = keyframe_times[2] + trial_start
        # arm_enter = trial_start + keyframe_times[2]
        psth_center = trial_start + keyframe_times[1]
        lcue_times = data["cueOnset_L"][trial_idx] - keyframe_times[0] # remove keyframe subtraction if using all phases of trial
        rcue_times = data["cueOnset_R"][trial_idx] - keyframe_times[0]

        trial_length = trial_end - trial_start
        spikes = data["spikes"][neuron]
        spikes = spikes[(spikes > (psth_center-time_before)) * (spikes < (psth_center + time_after))]
        spikes -= psth_center
        spikes_all.append(spikes)


    fig, axs = plt.subplots(2, 2, figsize=(8, 8))
    axs = axs.ravel()
    bin_edges = np.arange(-14, 15, 7)
    n_ev_level, _ = np.histogram(data['nCues_RminusL'][trial_indices], bins=bin_edges)
    bin_edges[-1] +=0.001
    ev_level = np.digitize(data['nCues_RminusL'][trial_indices], bins=bin_edges) - 1
    max_count = []
    bin_size = 0.02
    for i, ax in enumerate(axs):
        ev_spikes = np.hstack([spikes_all[i] for i in np.nonzero(ev_level==i)[0]])
        time_bins = np.arange(-time_before, time_after, bin_size)
        n, _, = np.histogram(ev_spikes, time_bins)
        ax.bar(np.convolve(time_bins, [0.5, 0.5], mode='valid'), n/(n_ev_level[i] * bin_size), bin_size)
        max_count.append(n.max())
        ax.set_title(f'Evidence: {bin_edges[i]:.0f} to {bin_edges[i+1]:.0f}')
        ax.set_ylabel('Spikes/s')

    max_count = np.max(max_count)
    max_count *= 1.05
    for ax in axs:
        ax.set_ylim([0, max_count])
    plt.suptitle(f'PSTH centered at arm enter | neuron {neuron}')
    plt.tight_layout()

    psth_figs.savefig(fig)
    plt.close(fig)

    print(f'completed neuron {neuron}')
    
psth_figs.close()

number of trials: 165
completed neuron 0
number of trials: 165
completed neuron 1
number of trials: 165
completed neuron 2
number of trials: 165
completed neuron 3
number of trials: 165
completed neuron 4
number of trials: 165
completed neuron 5
number of trials: 165
completed neuron 6
number of trials: 165
completed neuron 7
number of trials: 165
completed neuron 8
number of trials: 165
completed neuron 9
number of trials: 165
completed neuron 10
number of trials: 165
completed neuron 11
number of trials: 165
completed neuron 12
number of trials: 165
completed neuron 13
number of trials: 165
completed neuron 14
number of trials: 165
completed neuron 15
number of trials: 165
completed neuron 16
number of trials: 165
completed neuron 17
number of trials: 165
completed neuron 18
number of trials: 165
completed neuron 19
number of trials: 165


  ax.set_ylim([0, max_count])


completed neuron 20
number of trials: 165
completed neuron 21
number of trials: 165
completed neuron 22
number of trials: 165
completed neuron 23
number of trials: 165
completed neuron 24
number of trials: 165
completed neuron 25
number of trials: 165
completed neuron 26
number of trials: 165
completed neuron 27
number of trials: 165
completed neuron 28
number of trials: 165
completed neuron 29
number of trials: 165
completed neuron 30
number of trials: 165
completed neuron 31
number of trials: 165
completed neuron 32
number of trials: 165
completed neuron 33
number of trials: 165
completed neuron 34
number of trials: 165
completed neuron 35
number of trials: 165
completed neuron 36
number of trials: 165
completed neuron 37
number of trials: 165
completed neuron 38
number of trials: 165
completed neuron 39
number of trials: 165
completed neuron 40
number of trials: 165
completed neuron 41
number of trials: 165
completed neuron 42
number of trials: 165
completed neuron 43
number of tria

In [97]:
time_before = 0
time_after = 0.5
psth_figs = PdfPages("tower_psth_dms.pdf")

for neuron in range(n_neurons):
    spikes_left = []
    spikes_right = []
    for trial_idx in trial_indices:
        keyframes = data["keyFrames"][trial_idx]
        keyframe_times = data["time"][trial_idx][keyframes.astype(int)].tolist()
        trial_start = data["trialStart"][trial_idx]
        # trial_end = data["trialEnd"][trial_idx]
        # lcue_times = data["cueOnset_L"][trial_idx] 
        # rcue_times = data["cueOnset_R"][trial_idx]
        # trial_start = keyframe_times[0] + data["trialStart"][trial_idx]
        trial_end = keyframe_times[2] + trial_start
        # arm_enter = trial_start + keyframe_times[2]
        # psth_center = trial_start + keyframe_times[1]
        lcue_times = data["cueOnset_L"][trial_idx]# - keyframe_times[0] # remove keyframe subtraction if using all phases of trial
        rcue_times = data["cueOnset_R"][trial_idx] #- keyframe_times[0]
        trial_length = trial_end - trial_start
        for l in np.array([lcue_times]).reshape(-1, 1):
            spikes = data["spikes"][neuron]
            psth_center = l + trial_start
            spikes = spikes[(spikes > (psth_center-time_before)) * (spikes < (psth_center + time_after))]
            spikes -= psth_center
            spikes_left.append(spikes)

        for r in np.array([rcue_times]).reshape(-1, 1):
            spikes = data["spikes"][neuron]
            psth_center = r + trial_start
            spikes = spikes[(spikes > (psth_center-time_before)) * (spikes < (psth_center + time_after))]
            spikes -= psth_center
            spikes_right.append(spikes)

    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    bin_size = 0.01
    time_bins = np.arange(-time_before, time_after+0.01, bin_size)
    max_count = []
    for i, spikes in enumerate([spikes_left, spikes_right]):
        n, _ = np.histogram(np.hstack(spikes), time_bins)
        print(len(spikes))
        rate = n/(len(spikes) * bin_size)
        axs[i].bar(np.convolve(time_bins, [0.5, 0.5], mode='valid'), rate, bin_size)
        max_count.append(rate.max())
        axs[i].set_title('{} tower psth'.format(['left', 'right'][i]))
        axs[i].set_ylabel('Spikes/s')

    max_count = np.max(max_count)
    max_count *= 1.05
    for ax in axs:
        ax.set_ylim([0, max_count])
    plt.suptitle(f'PSTH centered at towers | neuron {neuron}')
    plt.tight_layout()
    psth_figs.savefig(fig)
    plt.close(fig)
    print(f'completed neuron {neuron}')
psth_figs.close()

    

621
910
completed neuron 0
621
910
completed neuron 1
621
910
completed neuron 2
621
910
completed neuron 3
621
910
completed neuron 4
621
910
completed neuron 5
621
910
completed neuron 6
621
910
completed neuron 7
621
910
completed neuron 8
621
910
completed neuron 9
621
910
completed neuron 10
621
910
completed neuron 11
621
910
completed neuron 12
621
910
completed neuron 13
621
910
completed neuron 14
621
910
completed neuron 15
621
910
completed neuron 16
621
910
completed neuron 17
621
910
completed neuron 18
621
910
completed neuron 19
621
910
completed neuron 20
621
910
completed neuron 21
621
910
completed neuron 22
621
910
completed neuron 23
621
910
completed neuron 24
621
910
completed neuron 25
621
910
completed neuron 26
621
910
completed neuron 27
621
910
completed neuron 28
621
910
completed neuron 29
621
910
completed neuron 30
621
910
completed neuron 31
621
910
completed neuron 32
621
910
completed neuron 33
621
910
completed neuron 34
621
910
completed neuron 35
62

In [62]:
l

array([1.3718512, 1.6384308, 2.3556488], dtype=float32)