# Look at all the sessions for Blanche
- How many trials per session?
- How many blocks?

In [54]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils
import utils.information_utils as information_utils
import utils.visualization_utils as visualization_utils
import utils.glm_utils as glm_utils
from matplotlib import pyplot as plt
import utils.spike_utils as spike_utils
from constants.glm_constants import *
from constants.behavioral_constants import *

import seaborn as sns

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
def valid_trials_blanche(beh):
    last_block = beh.BlockNumber.max()
    valid_beh = beh[
        (beh.Response.isin(["Correct", "Incorrect"])) & 
        (beh.BlockNumber >= 1) &
        (beh.BlockNumber != last_block)
    ]
    return valid_beh

def load_subject_data(row, beh_path):
    session = row.session_name
    behavior_path = beh_path.format(sess_name=session)
    beh = pd.read_csv(behavior_path)
    beh = valid_trials_blanche(beh)
    feature_selections = behavioral_utils.get_selection_features(beh)
    beh = pd.merge(beh, feature_selections, on="TrialNumber", how="inner")
    beh["session"] = session
    return beh

In [56]:
# Monkey B
# sessions = pd.DataFrame({"session_name": [20190123, 20190124, 20190125, 20190128, 20190312, 20190313, 20190329]})
sessions = pd.read_pickle("/data/patrick_res/sessions/BL/valid_sessions_61.pickle")
beh_path = "/data/rawdata/sub-BL/sess-{sess_name}/behavior/sub-BL_sess-{sess_name}_object_features.csv"
blanche_res = pd.concat(sessions.apply(lambda x: load_subject_data(x, beh_path), axis=1).values)

In [10]:
# sessions.to_pickle("/data/patrick_res/sessions/all_sessions_blanche.pickle")

### Num trials per session

In [5]:
blanche_res.groupby("session").TrialNumber.nunique()

session
20190123     143
20190124     394
20190125     486
20190128     558
20190129     466
20190130     529
20190131      46
20190201     567
20190206     587
20190207     613
20190214     429
20190215     681
20190220     766
20190221     846
20190226     957
20190227     543
20190228     924
20190312     240
20190313     142
20190314     390
20190318     270
20190319     414
20190320     333
20190321     358
20190325     329
20190326     330
20190328     289
20190329     277
20190522     528
20190524    1009
20190529     984
20190530     999
20190531    1174
20190603     819
20190605     976
20190606    1024
20190607     986
20190611     696
Name: TrialNumber, dtype: int64

### Num blocks per session

In [58]:
blanche_res.groupby("session").BlockNumber.nunique()

session
20190123     3
20190124     8
20190125    12
20190128    13
20190129    10
            ..
20190821    24
20190823    28
20190917    15
20191010     5
20191031     1
Name: BlockNumber, Length: 61, dtype: int64

In [7]:
blanche_res.groupby(["session", "CurrentRule"]).BlockNumber.nunique()

session   CurrentRule
20190123  CYAN           1
          MAGENTA        1
          YELLOW         1
20190124  CYAN           1
          ESCHER         1
                        ..
20190611  POLKADOT       2
          RIPPLE         3
          SQUARE         2
          STAR           2
          TRIANGLE       2
Name: BlockNumber, Length: 290, dtype: int64

### Is the first block always the same?

In [14]:
def first_block_rule(beh):
    row = {}
    for i in range(4): 
        block = beh[beh.BlockNumber == i]
        block_rule = block.CurrentRule.iloc[0]
        block_length = len(block)
        row[f"block {i} rule"] = block_rule
        row[f"block {i} length"] = block_length
    return pd.Series(row)

blanche_res.groupby("session", group_keys=True).apply(first_block_rule).to_csv("/data/patrick_res/behavior/blanche_first_few_blocks.csv")

### Look at pairs of rules, how many do we have? 

In [59]:
pairs = behavioral_utils.get_good_pairs_across_sessions(blanche_res, 3).sort_values(by="num_sessions", ascending=False)


In [60]:
session_priorities_tier1 = pairs[pairs.num_sessions >= 5].sessions.explode().value_counts().reset_index()
session_priorities_tier1 = session_priorities_tier1.rename(columns={"index": "session", "sessions": "priority"})
session_priorities_tier1.to_csv("/data/patrick_res/tmp/sess_priorities_tier1.csv")

In [61]:
session_priorities_tier1

Unnamed: 0,session,priority
0,20190529,3
1,20190617,2
2,20190627,2
3,20190823,2
4,20190207,1
5,20190220,1
6,20190531,1
7,20190611,1
8,20190710,1
9,20190816,1


In [62]:
session_priorities_tier2 = pairs[pairs.num_sessions >= 4].sessions.explode().value_counts().reset_index()
session_priorities_tier2 = session_priorities_tier2.rename(columns={"index": "session", "sessions": "priority"})
session_priorities_tier2 = session_priorities_tier2[~session_priorities_tier2.session.isin(session_priorities_tier1.session)]
session_priorities_tier2.to_csv("/data/patrick_res/tmp/sess_priorities_tier2.csv")

In [63]:
session_priorities_tier2

Unnamed: 0,session,priority
3,20190605,3
9,20190226,2
10,20190524,2
11,20190201,1
12,20190227,1
13,20190815,1
14,20190708,1
15,20190228,1
17,20190703,1
21,20190215,1


### Spot check sessions

In [30]:
session = "20190814"
behavior_path = beh_path.format(sess_name=session)
beh = pd.read_csv(behavior_path)
beh = valid_trials_blanche(beh)

In [32]:
beh.groupby("CurrentRule").BlockNumber.nunique()

CurrentRule
CIRCLE      1
ESCHER      3
MAGENTA     1
POLKADOT    1
RIPPLE      1
SQUARE      2
STAR        3
SWIRL       3
TRIANGLE    3
YELLOW      1
Name: BlockNumber, dtype: int64

In [6]:
# good_pairs = pairs[pairs.num_sessions >= 3]
# good_pairs.to_pickle("/data/patrick_res/sessions/BL/pairs_at_least_1blocks_3sess.pickle")
good_pairs = pairs[pairs.num_sessions >= 1]
good_pairs.to_pickle("/data/patrick_res/sessions/BL/pairs_at_least_2blocks_1sess.pickle")

In [7]:
good_pairs

Unnamed: 0,pair,sessions,num_sessions,dim_type
38,"[CYAN, GREEN]",[20190125],1,within dim
11,"[SQUARE, STAR]",[20190128],1,within dim


In [11]:
len(good_pairs)

19

### Some blanche spikes debugging code

In [1]:
import numpy as np
import pandas as pd
from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)
import utils.behavioral_utils as behavioral_utils
import utils.spike_utils as spike_utils
import os

In [2]:
PRE_INTERVAL = 500
POST_INTERVAL = 500
INTERVAL_SIZE = 50
NUM_BINS_SMOOTH = 1
EVENT = "FixationOnCross"
SUBJECT = "BL"

In [3]:
sess_name = "20190123"

behavior_path = f"/data/rawdata/sub-SA/sess-{sess_name}/behavior/sub-SA_sess-{sess_name}_object_features.csv"
beh = pd.read_csv(behavior_path)
valid_beh = beh[beh.Response.isin(["Correct", "Incorrect"])]
spike_times = spike_general.get_spike_times(None, SUBJECT, sess_name, species_dir="/data")

print("Calculating spikes by trial interval")
interval_size_secs = INTERVAL_SIZE / 1000
intervals = behavioral_utils.get_trial_intervals(valid_beh, EVENT, PRE_INTERVAL, POST_INTERVAL)

spike_by_trial_interval = spike_utils.get_spikes_by_trial_interval(spike_times, intervals)
end_bin = (PRE_INTERVAL + POST_INTERVAL) / 1000 + interval_size_secs

Calculating spikes by trial interval


In [43]:
len(spike_by_trial_interval.TrialNumber.unique())

579

In [34]:
all_units = spike_general.list_session_units(None, SUBJECT, sess_name, species_dir="/data")


In [35]:
all_units

Unnamed: 0,Channel,Unit,SpikeTimesFile,UnitID
5,102,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,0
0,104,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,1
7,104,2,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,2
4,109,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,3
2,109,2,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,4
3,18,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,5
11,31,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,6
10,48,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,7
1,62,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,8
6,79,1,/data/rawdata/sub-BL/sess-20190123/spikes/sub-...,9


In [40]:
len(valid_beh.TrialNumber.unique())

904

In [37]:
firing_rates = spike_analysis.firing_rate(
    spike_by_trial_interval, 
    all_units, 
    bins=np.arange(0, end_bin, interval_size_secs), 
    smoothing=NUM_BINS_SMOOTH,
    trials=valid_beh.TrialNumber.unique()
)

AssertionError: 