In [2]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import utils.spike_utils as spike_utils
import utils.classifier_utils as classifier_utils
import utils.visualization_utils as visualization_utils
import utils.behavioral_utils as behavioral_utils
import utils.io_utils as io_utils

import os
import pandas as pd
import matplotlib

from sklearn.linear_model import LinearRegression

EVENT = "FeedbackOnset"  # event in behavior to align on
PRE_INTERVAL = 1300   # time in ms before event
POST_INTERVAL = 1500  # time in ms after event
INTERVAL_SIZE = 100  # size of interval in ms

SESS_BEHAVIOR_PATH = "/data/rawdata/sub-SA/sess-{sess_name}/behavior/sub-SA_sess-{sess_name}_object_features.csv"
SESS_SPIKES_PATH = "/data/patrick_res/multi_sess/{sess_name}/{sess_name}_firing_rates_{pre_interval}_{event}_{post_interval}_{interval_size}_bins_1_smooth.pickle"
SESSIONS_PATH = "/data/patrick_res/multi_sess/valid_sessions_rpe.pickle"

feature_dims = ["Color", "Shape", "Pattern"]

In [3]:
sessions = pd.read_pickle(SESSIONS_PATH)

In [None]:
conditions = ["RPEGroup", "Pattern"]

def calc_psth_per_session(row, conditions):
    sess_name = row.session_name
    beh, frs = io_utils.load_rpe_sess_beh_and_frs(sess_name)
    mode = "SpikeCounts"
    def zscore_unit(group):
        mean = group[mode].mean()
        std = group[mode].std()
        group[f"Z{mode}"] = (group[mode] - mean) / std
        return group
    frs = frs.groupby(["UnitID", "TimeBins"]).apply(zscore_unit)
    merged = pd.merge(beh[conditions], frs, on="TrialNumber")
    group_conds = conditions + ["UnitID", "TimeBins"]
    psth = merged.groupby(group_conds).mean()["ZSpikeCounts"].reset_index()
    psth["PseudoUnitID"] = int(sess_name) * 100 + psth["UnitID"]
    return psth
full_psth = pd.concat(sessions.apply(lambda x: calc_psth_per_session(x, conditions), axis=1).values)

In [None]:
full_psth.groupby(conditions + ["TimeBins"]).apply(lambda x: print(len(x.PseudoUnitID.unique())))

In [35]:
pos = spike_utils.get_unit_positions(sessions)

In [36]:
pos

Unnamed: 0,Channel,Unit,SpikeTimesFile,UnitID,electrode_id,x,y,z,distance,in_brain,...,structure_level1,structure_level2,structure_level3,structure_level4,structure_level5,structure_level6,structure_potential,session,PseudoUnitID,manual_structure
0,100,1,/data/rawdata/sub-SA/sess-20180709/spikes/sub-...,0,100,13.535342,-55.049086,19.90947,27.5,True,...,telencephalon (tel),lateral_and_ventral_pallium (LVPal),lateral_pallium (LPal),claustrum (Cl),claustrum (Cl),claustrum (Cl),[],20180709,2018070900,Claustrum
1,108,1,/data/rawdata/sub-SA/sess-20180709/spikes/sub-...,1,108,13.699304,-54.08206,21.82937,28.75,True,...,telencephalon (tel),lateral_and_ventral_pallium (LVPal),lateral_pallium (LPal),claustrum (Cl),claustrum (Cl),claustrum (Cl),[],20180709,2018070901,Claustrum
2,109,1,/data/rawdata/sub-SA/sess-20180709/spikes/sub-...,2,109,15.204069,-54.488859,21.577754,27.5,True,...,telencephalon (tel),lateral_and_ventral_pallium (LVPal),lateral_pallium (LPal),claustrum (Cl),claustrum (Cl),claustrum (Cl),[],20180709,2018070902,Claustrum
3,10a,1,/data/rawdata/sub-SA/sess-20180709/spikes/sub-...,3,10a,0.789428,-75.75489,38.106774,9,True,...,Frontal_Lobe (Frontal),lateral_prefrontal_cortex (lat_PFC),dorsolateral_prefrontal_cortex (dlPFC),area_8B (area_8B),area_8B (area_8B),medial_area_8B (area_8Bm),dlPFC,20180709,2018070903,Prefrontal Cortex
4,10a,2,/data/rawdata/sub-SA/sess-20180709/spikes/sub-...,4,10a,0.789428,-75.75489,38.106774,9,True,...,Frontal_Lobe (Frontal),lateral_prefrontal_cortex (lat_PFC),dorsolateral_prefrontal_cortex (dlPFC),area_8B (area_8B),area_8B (area_8B),medial_area_8B (area_8Bm),dlPFC,20180709,2018070904,Prefrontal Cortex
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25,90,1,/data/rawdata/sub-SA/sess-20180910/spikes/sub-...,25,90,10.452756,-52.703145,19.318001,31.125,True,...,telencephalon (tel),amygdala (Amy),pallial_amygdala (pAmy),lateropallial_amygdala (lpAmy),lateral_amygdaloid_nucleus (La),lateral_dorsal_amygdaloid_nucleus (LaD),Hippocampal,20180910,2018091025,Amygdala
26,92,1,/data/rawdata/sub-SA/sess-20180910/spikes/sub-...,26,92,13.520363,-51.91999,19.341959,30.625,True,...,telencephalon (tel),basal_ganglia (BG),striatum (Str),dorsal_striatum (DStr),caudate (Cd),caudate_tail (CdT),Hippocampal,20180910,2018091026,Basal Ganglia
27,95,1,/data/rawdata/sub-SA/sess-20180910/spikes/sub-...,27,95,5.433664,-68.902483,15.922112,16,True,...,telencephalon (tel),basal_ganglia (BG),striatum (Str),dorsal_striatum (DStr),caudate (Cd),caudate_head (CdH),[],20180910,2018091027,Basal Ganglia
28,99,1,/data/rawdata/sub-SA/sess-20180910/spikes/sub-...,28,99,12.140106,-51.630938,21.155323,31.75,True,...,telencephalon (tel),amygdala (Amy),pallial_amygdala (pAmy),lateropallial_amygdala (lpAmy),lateral_amygdaloid_nucleus (La),lateral_dorsal_amygdaloid_nucleus (LaD),Hippocampal,20180910,2018091028,Amygdala


In [18]:
sess_name = "20180802"
beh, frs = io_utils.load_rpe_sess_beh_and_frs(sess_name)

In [19]:
frs

Unnamed: 0_level_0,UnitID,TimeBins,SpikeCounts,FiringRate
TrialNumber,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0.0,0,0.0,1.0,9.368775
0.0,0,0.1,1.0,6.993379
0.0,0,0.2,0.0,3.005283
0.0,0,0.3,0.0,0.585568
0.0,0,0.4,0.0,0.045657
...,...,...,...,...
1749.0,45,2.3,0.0,6.504820
1749.0,45,2.4,0.0,3.635169
1749.0,45,2.5,1.0,4.664978
1749.0,45,2.6,0.0,5.383355


In [20]:
mode = "SpikeCounts"
def zscore_unit(group):
    mean = group[mode].mean()
    std = group[mode].std()
    group[f"Z{mode}"] = (group[mode] - mean) / std
    return group
frs = frs.groupby(["UnitID", "TimeBins"]).apply(zscore_unit)

To preserve the previous behavior, use

	>>> .groupby(..., group_keys=False)


	>>> .groupby(..., group_keys=True)
  frs = frs.groupby(["UnitID", "TimeBins"]).apply(zscore_unit)


In [21]:
frs

Unnamed: 0_level_0,UnitID,TimeBins,SpikeCounts,FiringRate,ZSpikeCounts
TrialNumber,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0.0,0,0.0,1.0,9.368775,1.637521
0.0,0,0.1,1.0,6.993379,1.948359
0.0,0,0.2,0.0,3.005283,-0.250590
0.0,0,0.3,0.0,0.585568,-0.229230
0.0,0,0.4,0.0,0.045657,-0.197749
...,...,...,...,...,...
1749.0,45,2.3,0.0,6.504820,-0.924937
1749.0,45,2.4,0.0,3.635169,-0.912903
1749.0,45,2.5,1.0,4.664978,0.409594
1749.0,45,2.6,0.0,5.383355,-0.898966


In [22]:
conditions = ["Color", "RPEGroup"]
merged = pd.merge(beh[conditions], frs, on="TrialNumber")

In [23]:
group_conds = conditions + ["UnitID", "TimeBins"]
psth = merged.groupby(group_conds).mean()["ZSpikeCounts"].reset_index()
psth["PseudoUnitID"] = int(sess_name) + psth["UnitID"] * 100

In [24]:
psth

Unnamed: 0,Color,RPEGroup,UnitID,TimeBins,ZSpikeCounts,PseudoUnitID
0,CYAN,less neg,0,0.0,0.043465,20180802
1,CYAN,less neg,0,0.1,0.054362,20180802
2,CYAN,less neg,0,0.2,-0.076807,20180802
3,CYAN,less neg,0,0.3,-0.144036,20180802
4,CYAN,less neg,0,0.4,0.044996,20180802
...,...,...,...,...,...,...
20603,YELLOW,more pos,45,2.3,0.105872,20185302
20604,YELLOW,more pos,45,2.4,-0.029498,20185302
20605,YELLOW,more pos,45,2.5,-0.110191,20185302
20606,YELLOW,more pos,45,2.6,-0.078440,20185302
