### Explore other sessions and check whether the existing utilities work for them

In [10]:
%load_ext autoreload
%autoreload 2

import glob
import os
from datetime import datetime
import pandas as pd
from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)
from lfp_tools import (
    startup as lfp_startup,
)
import json

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Grab sessions and put info in a dataframe

In [9]:
# NOTE: this is hacky code to replicate information that should already be stored in a Datajoint table, but don't have access atm
# Session folders are parsed as "sess-<YYYYMMDD>[<count>]", e.g. sess-20180802 or sess-201807250001.
sess_folder_path = "/data/rawdata/sub-SA/"
# glob returns paths in arbitrary order; sort so the resulting table is reproducible.
sess_paths = sorted(glob.glob(os.path.join(sess_folder_path, "sess-*")))
session_names = [os.path.basename(sess_path).split("-")[1] for sess_path in sess_paths]

rows = []
for sess_name in session_names:
    # Skip folders whose suffix is not purely numeric, or is too short to
    # hold a full YYYYMMDD date (strptime would otherwise mis-parse or raise).
    if not sess_name.isdigit() or len(sess_name) < 8:
        continue
    # hacky way to grab a datetime: the first 8 digits are the date ...
    date = datetime.strptime(sess_name[:8], "%Y%m%d").date()
    # ... and any remaining digits are a per-day session counter (0 when absent).
    rest = sess_name[8:]
    count = int(rest) if rest else 0
    rows.append({
        "session_datetime": date,
        "session_count": count,
        "session_name": sess_name,
    })
# Declare the columns explicitly so the frame keeps its schema even when no
# session folder is found.
sess_df = pd.DataFrame(rows, columns=["session_datetime", "session_count", "session_name"])

### Filter sessions by num neurons, num trials

In [10]:
# Thresholds used when filtering sessions.
# A session must have strictly more than this many distinct units (see check_num_neurons).
NUM_NEURONS = 20
# Trial-count threshold for a usable session.
NUM_TRIALS = 500
# A session must be dated strictly before this day (see check_date).
VALID_SESS_BEFORE = datetime.strptime("20181015", "%Y%m%d").date()

In [None]:
def check_num_trials(sess, min_trials=None, subject_dir="/data/rawdata/sub-SA"):
    """Return True when the session has more than ``min_trials`` scored trials.

    A trial is counted when its ``Response`` value in the session's
    object-features CSV is "Correct" or "Incorrect".

    Args:
        sess: row-like object exposing ``session_name``.
        min_trials: exclusive lower bound on the number of counted trials.
            Defaults to the notebook-level ``NUM_TRIALS`` constant.
        subject_dir: directory that contains the ``sess-<name>`` folders.

    Returns:
        bool: False when the behavior CSV does not exist, otherwise whether
        the number of counted trials is strictly greater than ``min_trials``.
    """
    if min_trials is None:
        # Resolve at call time so edits to NUM_TRIALS take effect.
        min_trials = NUM_TRIALS
    sess_name = sess.session_name
    behavior_path = os.path.join(
        subject_dir,
        f"sess-{sess_name}",
        "behavior",
        f"sub-SA_sess-{sess_name}_object_features.csv",
    )
    if not os.path.isfile(behavior_path):
        return False
    beh = pd.read_csv(behavior_path)
    valid_beh = beh[beh.Response.isin(["Correct", "Incorrect"])]
    return len(valid_beh) > min_trials

def check_date(sess):
    """Return True when the session's date is strictly earlier than VALID_SESS_BEFORE."""
    session_day = sess.session_datetime
    return session_day < VALID_SESS_BEFORE

def check_num_neurons(sess):
    """Report whether a session's distinct-unit count exceeds NUM_NEURONS.

    Prints the expected spikes directory, returns False if that directory
    is absent, and otherwise counts the distinct ``UnitID`` values in the
    table returned by ``spike_general.get_spike_times``.
    """
    spikes_folder = f"/data/rawdata/sub-SA/sess-{sess.session_name}/spikes"
    print(spikes_folder)
    if os.path.isdir(spikes_folder):
        spikes = spike_general.get_spike_times(None, "SA", sess.session_name, species_dir="/data")
        distinct_units = spikes.UnitID.unique()
        return len(distinct_units) > NUM_NEURONS
    return False

def filter_sessions(sess):
    """Return whether a session passes the date, trial-count and unit-count checks.

    Prints the session name first. The checks run in that order and stop at
    the first failure, so later checks are not evaluated for sessions that
    fail an earlier one.
    """
    print(sess.session_name)
    if not check_date(sess):
        return False
    if not check_num_trials(sess):
        return False
    return check_num_neurons(sess)


# Run filter_sessions on every session row and store the result in a "valid" column.
sess_df["valid"] = sess_df.apply(filter_sessions, axis=1)

### Found 36 sessions before the task change date, with more than 500 trials and more than 20 neurons

In [19]:
# Number of sessions whose "valid" flag is True.
len(sess_df[sess_df.valid])

36

### Look at electrode positions again, weights

In [3]:
def get_electrode_locations(row, subject_dir="/data/rawdata/sub-SA"):
    """Load the electrode table stored in a session's sessioninfo JSON.

    Earlier versions of this function also listed the units of the
    hard-coded session "20180802" and merged them with the electrode
    positions, but that result was never returned. The dead computation
    is removed; the returned electrode table is unchanged.

    Args:
        row: row-like object exposing ``session_name``.
        subject_dir: directory that contains the ``sess-<name>`` folders.

    Returns:
        pandas.DataFrame built from the JSON's ``electrode_info`` entry.
        Rows with missing coordinates are kept.

    Raises:
        FileNotFoundError: if the sessioninfo JSON does not exist.
        KeyError: if the JSON has no ``electrode_info`` entry.
    """
    session = row.session_name
    if len(session) == 12:  # For the cases like 201807250001
        # The sessioninfo file is looked up under the 8-digit date prefix.
        session = session[:8]
    info_path = os.path.join(
        subject_dir,
        f"sess-{session}",
        "session_info",
        f"sub-SA_sess-{session}_sessioninfo.json",
    )
    with open(info_path, "r") as f:
        data = json.load(f)
    return pd.DataFrame.from_dict(data["electrode_info"])

In [4]:
# Load a previously saved table of valid sessions.
# NOTE(review): no visible cell writes this pickle; presumably it was saved from
# sess_df[sess_df.valid]. Confirm, or add the cell that creates it.
# pd.read_pickle can execute arbitrary code, so only load files you trust.
valid_sess = pd.read_pickle("/data/patrick_scratch/multi_sess/valid_sessions.pickle")
# Call get_electrode_locations once per session row.
all_sess_locations = valid_sess.apply(get_electrode_locations, axis=1)

In [18]:
# Attach electrode coordinates to the units of one example session.
session = "20180802"
# For the cases like 201807250001: keep only the 8-character date prefix
session = session[:8] if len(session) == 12 else session
info_path = f"/data/rawdata/sub-SA/sess-{session}/session_info/sub-SA_sess-{session}_sessioninfo.json"
with open(info_path, 'r') as f:
    data = json.load(f)
locs = data['electrode_info']
locs_df = pd.DataFrame.from_dict(locs)
# Keep only electrodes for which all three coordinates are present.
electrode_pos_not_nan = locs_df[locs_df[['x', 'y', 'z']].notna().all(axis=1)]
units = spike_general.list_session_units(None, "SA", session, species_dir="/data")
# Left merge: every unit is kept, with missing position columns when no
# electrode row matches its Channel.
unit_pos = (
    units
    .merge(electrode_pos_not_nan, left_on="Channel", right_on="electrode_id", how="left")
    .astype({"UnitID": int})
)

In [20]:
# Display the unit table returned by spike_general.list_session_units for this session.
units

Unnamed: 0,UnitID,Channel,Unit,SpikeTimesFile
0,0,13a,1,/data/rawdata/sub-SA/sess-20180802/spikes/sub-...
1,1,32a,2,/data/rawdata/sub-SA/sess-20180802/spikes/sub-...
2,2,85a,2,/data/rawdata/sub-SA/sess-20180802/spikes/sub-...
3,3,37a,1,/data/rawdata/sub-SA/sess-20180802/spikes/sub-...
4,4,33a,1,/data/rawdata/sub-SA/sess-20180802/spikes/sub-...
5,5,81a,2,/data/rawdata/sub-SA/sess-20180802/spikes/sub-...
6,6,36a,2,/data/rawdata/sub-SA/sess-20180802/spikes/sub-...
7,7,28a,3,/data/rawdata/sub-SA/sess-20180802/spikes/sub-...
8,8,116,1,/data/rawdata/sub-SA/sess-20180802/spikes/sub-...
9,9,63a,1,/data/rawdata/sub-SA/sess-20180802/spikes/sub-...


In [40]:
# Select the rows whose structure_level2 is the medial temporal lobe.
# NOTE(review): `temporals` is not defined in any cell visible here; confirm it
# is created earlier, or restore the cell that defines it so Run All works.
temporals[temporals.structure_level2 == 'medial_temporal_lobe (MTL)']

17