In [2]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.pseudo_utils as pseudo_utils
import utils.pseudo_classifier_utils as pseudo_classifier_utils
import utils.behavioral_utils as behavioral_utils
from utils.session_data import SessionData
import utils.io_utils as io_utils

import json

from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)

import matplotlib.pyplot as plt
import matplotlib

from dPCA.dPCA import dPCA

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# the output directory to store the data
OUTPUT_DIR = "/data/patrick_res/pseudo"
# path to a dataframe of sessions to analyze
# SESSIONS_PATH = "/data/patrick_scratch/multi_sess/valid_sessions.pickle"
SESSIONS_PATH = "/data/patrick_res/multi_sess/valid_sessions_rpe.pickle"
# path for each session, specifying behavior
SESS_BEHAVIOR_PATH = "/data/rawdata/sub-SA/sess-{sess_name}/behavior/sub-SA_sess-{sess_name}_object_features.csv"
# path for each session, for spikes that have been pre-aligned to event time and binned. 
SESS_SPIKES_PATH = "/data/patrick_res/multi_sess/{sess_name}/{sess_name}_firing_rates_{pre_interval}_{event}_{post_interval}_{interval_size}_bins_1_smooth.pickle"

FEATURE_DIMS = ["Color", "Shape", "Pattern"]

In [4]:
sessions = pd.read_pickle(SESSIONS_PATH)

In [12]:
def calc_psth_per_session(row, conditions):
    sess_name = row.session_name
    beh, frs = io_utils.load_rpe_sess_beh_and_frs(sess_name)
    if frs.SpikeCounts.isnull().values.any():
        print("null valuessssss")
    mode = "FiringRate"
    def zscore_unit(group):
        mean = group[mode].mean()
        std = group[mode].std()
        group[f"Z{mode}"] = 0 if std == 0 else (group[mode] - mean) / std
        return group
    frs = frs.groupby(["UnitID", "TimeBins"]).apply(zscore_unit)
    if frs.ZFiringRate.isnull().values.any():
        raise ValueError("Why are there null values after zscoring")
    merged = pd.merge(beh[conditions], frs, on="TrialNumber")
    group_conds = conditions + ["UnitID", "TimeBins"]
    psth = merged.groupby(group_conds).mean()["ZFiringRate"].reset_index()
    psth["PseudoUnitID"] = int(sess_name) * 100 + psth["UnitID"]
    return psth

In [13]:
dim_psths = []
for feature_dim in FEATURE_DIMS:
    conditions = ["RPEGroup", feature_dim]
    dim_psth = pd.concat(sessions.apply(lambda x: calc_psth_per_session(x, conditions), axis=1).values)
    dim_psth = dim_psth.rename(columns={feature_dim: "Feature"})
    dim_psths.append(dim_psth)
full_psth = pd.concat(dim_psths)

To preserve the previous behavior, use

	>>> .groupby(..., group_keys=False)


	>>> .groupby(..., group_keys=True)
  frs = frs.groupby(["UnitID", "TimeBins"]).apply(zscore_unit)
To preserve the previous behavior, use

	>>> .groupby(..., group_keys=False)


	>>> .groupby(..., group_keys=True)
  frs = frs.groupby(["UnitID", "TimeBins"]).apply(zscore_unit)
To preserve the previous behavior, use

	>>> .groupby(..., group_keys=False)


	>>> .groupby(..., group_keys=True)
  frs = frs.groupby(["UnitID", "TimeBins"]).apply(zscore_unit)
To preserve the previous behavior, use

	>>> .groupby(..., group_keys=False)


	>>> .groupby(..., group_keys=True)
  frs = frs.groupby(["UnitID", "TimeBins"]).apply(zscore_unit)
To preserve the previous behavior, use

	>>> .groupby(..., group_keys=False)


	>>> .groupby(..., group_keys=True)
  frs = frs.groupby(["UnitID", "TimeBins"]).apply(zscore_unit)
To preserve the previous behavior, use

	>>> .groupby(..., group_keys=False)


	>>> .groupby(..., group_keys=T

In [8]:
psth_shapes = full_psth[full_psth.Feature.isin(["CIRCLE", "SQUARE", "STAR", "TRIANGLE"])]

In [9]:
psth_shapes.sort_values(by=["PseudoUnitID", "TimeBins", ""])

Unnamed: 0,RPEGroup,Feature,UnitID,TimeBins,ZSpikeCounts,PseudoUnitID
0,less neg,CIRCLE,0,0.0,0.207895,2018070900
1,less neg,CIRCLE,0,0.1,-0.046600,2018070900
2,less neg,CIRCLE,0,0.2,0.178076,2018070900
3,less neg,CIRCLE,0,0.3,0.024526,2018070900
4,less neg,CIRCLE,0,0.4,0.071063,2018070900
...,...,...,...,...,...,...
13435,more pos,TRIANGLE,29,2.3,-0.091495,2018091029
13436,more pos,TRIANGLE,29,2.4,-0.082526,2018091029
13437,more pos,TRIANGLE,29,2.5,-0.089189,2018091029
13438,more pos,TRIANGLE,29,2.6,0.166174,2018091029
