# Look at all the sessions for Blanche
- How many trials per session?
- How many blocks?

In [1]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils
import utils.information_utils as information_utils
import utils.visualization_utils as visualization_utils
import utils.glm_utils as glm_utils
from matplotlib import pyplot as plt
import utils.spike_utils as spike_utils
from constants.glm_constants import *
from constants.behavioral_constants import *

import seaborn as sns

In [8]:
def valid_trials_blanche(beh):
    last_block = beh.BlockNumber.max()
    valid_beh = beh[
        (beh.Response.isin(["Correct", "Incorrect"])) & 
        (beh.BlockNumber >= 1) &
        (beh.BlockNumber != last_block)
    ]
    return valid_beh

def load_subject_data(row, beh_path):
    session = row.session_name
    behavior_path = beh_path.format(sess_name=session)
    beh = pd.read_csv(behavior_path)
    beh = valid_trials_blanche(beh)
    feature_selections = behavioral_utils.get_selection_features(beh)
    beh = pd.merge(beh, feature_selections, on="TrialNumber", how="inner")
    beh["session"] = session
    return beh

In [9]:
# Monkey B
sessions = pd.DataFrame({"session_name": [20190123, 20190124, 20190125, 20190128, 20190312, 20190313, 20190329]})
beh_path = "/data/rawdata/sub-BL/sess-{sess_name}/behavior/sub-BL_sess-{sess_name}_object_features.csv"
blanche_res = pd.concat(sessions.apply(lambda x: load_subject_data(x, beh_path), axis=1).values)

In [10]:
sessions.to_pickle("/data/patrick_res/sessions/all_sessions_blanche.pickle")

### Num trials per session

In [11]:
blanche_res.groupby("session").TrialNumber.nunique()

session
20190123    143
20190124    394
20190125    486
20190128    558
20190312    240
20190313    142
20190329    277
Name: TrialNumber, dtype: int64

### Num blocks per session

In [12]:
blanche_res.groupby("session").BlockNumber.nunique()

session
20190123     3
20190124     8
20190125    12
20190128    13
20190312     8
20190313     3
20190329     4
Name: BlockNumber, dtype: int64

In [7]:
blanche_res.groupby(["session", "CurrentRule"]).BlockNumber.nunique()

session   CurrentRule
20190123  CYAN           1
          MAGENTA        1
          RIPPLE         1
          SQUARE         1
          YELLOW         1
20190124  CYAN           1
          ESCHER         1
          MAGENTA        1
          POLKADOT       2
          RIPPLE         2
          SQUARE         1
          STAR           1
          YELLOW         1
20190125  CIRCLE         1
          CYAN           3
          ESCHER         1
          GREEN          2
          MAGENTA        1
          POLKADOT       1
          RIPPLE         2
          SQUARE         1
          STAR           1
          SWIRL          1
20190128  CIRCLE         1
          CYAN           1
          ESCHER         2
          GREEN          1
          MAGENTA        1
          RIPPLE         1
          SQUARE         2
          STAR           4
          TRIANGLE       1
          YELLOW         1
20190312  CIRCLE         2
          ESCHER         1
          MAGENTA        1
      

### Is the first block always the same?

In [14]:
def first_block_rule(beh):
    row = {}
    for i in range(4): 
        block = beh[beh.BlockNumber == i]
        block_rule = block.CurrentRule.iloc[0]
        block_length = len(block)
        row[f"block {i} rule"] = block_rule
        row[f"block {i} length"] = block_length
    return pd.Series(row)

blanche_res.groupby("session", group_keys=True).apply(first_block_rule).to_csv("/data/patrick_res/behavior/blanche_first_few_blocks.csv")

### Look at pairs of rules, how many do we have? 

In [26]:
pairs = behavioral_utils.get_good_pairs_across_sessions(blanche_res, 1).sort_values(by="num_sessions", ascending=False)[:10]
pairs.style.hide()


pair,sessions,num_sessions,dim_type
"['SQUARE', 'MAGENTA']",[20190124 20190125 20190128 20190312],4,across dim
"['CYAN', 'MAGENTA']",[20190123 20190124 20190125 20190128],4,within dim
"['STAR', 'MAGENTA']",[20190124 20190125 20190128 20190312],4,across dim
"['CYAN', 'YELLOW']",[20190123 20190124 20190128 20190329],4,within dim
"['SQUARE', 'STAR']",[20190124 20190125 20190128 20190312],4,within dim
"['MAGENTA', 'YELLOW']",[20190123 20190124 20190128 20190312],4,within dim
"['SQUARE', 'YELLOW']",[20190124 20190128 20190312],3,across dim
"['CYAN', 'GREEN']",[20190125 20190128 20190329],3,within dim
"['STAR', 'ESCHER']",[20190124 20190125 20190128],3,across dim
"['STAR', 'YELLOW']",[20190124 20190128 20190312],3,across dim


### Some blanche spikes debugging code

In [27]:
import numpy as np
import pandas as pd
from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)
import utils.behavioral_utils as behavioral_utils
import utils.spike_utils as spike_utils
import os

In [30]:
PRE_INTERVAL = 500
POST_INTERVAL = 500
INTERVAL_SIZE = 50
NUM_BINS_SMOOTH = 1
EVENT = "FixationOnCross"
SUBJECT = "BL"

In [31]:
sess_name = "20190123"

behavior_path = f"/data/rawdata/sub-SA/sess-{sess_name}/behavior/sub-SA_sess-{sess_name}_object_features.csv"
beh = pd.read_csv(behavior_path)
valid_beh = beh[beh.Response.isin(["Correct", "Incorrect"])]
spike_times = spike_general.get_spike_times(None, SUBJECT, sess_name, species_dir="/data")

print("Calculating spikes by trial interval")
interval_size_secs = INTERVAL_SIZE / 1000
intervals = behavioral_utils.get_trial_intervals(valid_beh, EVENT, PRE_INTERVAL, POST_INTERVAL)

spike_by_trial_interval = spike_utils.get_spikes_by_trial_interval(spike_times, intervals)
end_bin = (PRE_INTERVAL + POST_INTERVAL) / 1000 + interval_size_secs

Calculating spikes by trial interval


In [43]:
len(spike_by_trial_interval.TrialNumber.unique())

579

In [34]:
all_units = spike_general.list_session_units(None, SUBJECT, sess_name, species_dir="/data")


In [35]:
all_units

Unnamed: 0,Channel,Unit,SpikeTimesFile,UnitID
5,102,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,0
0,104,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,1
7,104,2,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,2
4,109,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,3
2,109,2,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,4
3,18,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,5
11,31,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,6
10,48,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,7
1,62,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,8
6,79,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,9


In [40]:
len(valid_beh.TrialNumber.unique())

904

In [37]:
firing_rates = spike_analysis.firing_rate(
    spike_by_trial_interval, 
    all_units, 
    bins=np.arange(0, end_bin, interval_size_secs), 
    smoothing=NUM_BINS_SMOOTH,
    trials=valid_beh.TrialNumber.unique()
)

AssertionError: 