### Explore other sessions to check whether the existing utilities work for them

In [1]:
%load_ext autoreload
%autoreload 2

import glob
import os
from datetime import datetime
import pandas as pd
from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)
from lfp_tools import (
    startup as lfp_startup,
)
import json

### Grab sessions and put info in a dataframe

In [2]:
# NOTE: this is hacky code to replicate information that should already be stored in a Datajoint table, but don't have access atm
sess_folder_path = "/data/rawdata/sub-SA/"
# glob returns matches in arbitrary order; sort so the session table is reproducible across runs.
sess_paths = sorted(glob.glob(f"{sess_folder_path}/sess-*"))
session_names = [os.path.split(sess_path)[1].split("-")[1] for sess_path in sess_paths]


def parse_session_name(sess_name):
    """Turn a session folder suffix into one row of session metadata.

    The name is expected to be all digits: the first 8 are a YYYYMMDD date and
    anything after that is a per-day session counter (e.g. "201807250001").

    Returns:
        A dict with keys "session_datetime" (datetime.date), "session_count"
        (int, 0 when there is no suffix) and "session_name" (the input), or
        None when the name is not all digits, is shorter than 8 characters,
        or does not start with a valid date.
    """
    if not sess_name.isdigit() or len(sess_name) < 8:
        return None
    try:
        # hacky way to grab a datetime
        date = datetime.strptime(sess_name[:8], "%Y%m%d").date()
    except ValueError:
        # Digits that do not form a real calendar date (e.g. month 13): skip instead of crashing.
        return None
    rest = sess_name[8:]
    return {
        "session_datetime": date,
        "session_count": int(rest) if rest else 0,
        "session_name": sess_name,
    }


rows = [row for row in map(parse_session_name, session_names) if row is not None]
# Name the columns explicitly so the frame has the expected schema even when no session matched.
sess_df = pd.DataFrame(rows, columns=["session_datetime", "session_count", "session_name"])

### Filter sessions by date, number of trials, and number of neurons

In [3]:
# Session-selection thresholds used by the filter functions in the next cell.
# Unit-count threshold: a session must have more than this many unique UnitIDs.
NUM_NEURONS = 20
# Trial-count threshold for the num-trials filter.
NUM_TRIALS = 500
# Only sessions dated strictly before this day are kept
# (the notebook refers to it as the task change date).
VALID_SESS_BEFORE = datetime(2018, 10, 15).date()

In [4]:
def check_num_trials(sess, min_trials=None, species_dir="/data"):
    """Return True if the session has more than ``min_trials`` valid trials.

    A trial is counted when its ``Response`` column is "Correct" or
    "Incorrect" in the session's object-features CSV.

    Args:
        sess: row-like object exposing ``session_name``.
        min_trials: strict lower bound on the number of valid trials.
            Defaults to the notebook-level ``NUM_TRIALS`` constant.
        species_dir: root directory that contains ``rawdata/``.

    Returns:
        bool. False when the behavior file does not exist.
    """
    if min_trials is None:
        # Looked up at call time so changing NUM_TRIALS in the config cell takes effect.
        min_trials = NUM_TRIALS
    sess_name = sess.session_name
    behavior_path = f"{species_dir}/rawdata/sub-SA/sess-{sess_name}/behavior/sub-SA_sess-{sess_name}_object_features.csv"
    if not os.path.isfile(behavior_path):
        return False
    beh = pd.read_csv(behavior_path)
    valid_beh = beh[beh.Response.isin(["Correct", "Incorrect"])]
    # Strict comparison, matching check_num_neurons.
    return len(valid_beh) > min_trials

def check_date(sess):
    """Return True when the session's date is strictly earlier than VALID_SESS_BEFORE."""
    session_day = sess.session_datetime
    return session_day < VALID_SESS_BEFORE

def check_num_neurons(sess):
    """Return True if the session has more than NUM_NEURONS distinct UnitIDs.

    Sessions without a ``spikes`` directory are rejected without loading anything.
    """
    session_id = sess.session_name
    if os.path.isdir(f"/data/rawdata/sub-SA/sess-{session_id}/spikes"):
        spikes = spike_general.get_spike_times(None, "SA", session_id, species_dir="/data")
        unit_ids = spikes.UnitID.unique()
        return len(unit_ids) > NUM_NEURONS
    return False

def filter_sessions(sess):
    """Return True only if the session passes the date, trial-count and neuron-count checks.

    Checks run in that order and stop at the first failure, so the spike
    data is only loaded for sessions that passed the two earlier checks.
    """
    if not check_date(sess):
        return False
    if not check_num_trials(sess):
        return False
    return check_num_neurons(sess)


# Run filter_sessions on every session row and store the result in a "valid" column.
# This reads a behavior CSV and, for sessions that get that far, the spike times, so it can be slow.
sess_df["valid"] = sess_df.apply(filter_sessions, axis=1)

### Found 36 sessions before the task-change date, each with more than 500 trials and more than 20 neurons

In [29]:
# Number of sessions that passed every filter.
len(sess_df.loc[sess_df["valid"]])

36

### Look at electrode positions again, weights

In [7]:
def get_electrode_locations(row, species_dir="/data"):
    """Load the electrode table for one session from its session-info JSON.

    Args:
        row: row-like object exposing ``session_name``.
        species_dir: root directory that contains ``rawdata/``.

    Returns:
        DataFrame built from the JSON's ``electrode_info`` entry, with one
        extra ``session`` column holding the full session name. Electrodes
        with missing coordinates are kept.

    Raises:
        FileNotFoundError: if the session-info file does not exist.
        KeyError: if the JSON has no ``electrode_info`` entry.
    """
    session = row.session_name
    # For the cases like 201807250001: the session-info file lives under the 8-digit day.
    sess_day = session[:8]
    info_path = f"{species_dir}/rawdata/sub-SA/sess-{sess_day}/session_info/sub-SA_sess-{sess_day}_sessioninfo.json"
    with open(info_path, 'r') as f:
        data = json.load(f)
    # Only the electrode table is returned. The unit/electrode merge that used to be
    # computed here was never used, so it (and its extra unit-listing call) is gone.
    locs_df = pd.DataFrame.from_dict(data['electrode_info'])
    locs_df["session"] = session
    return locs_df

In [8]:
# Load the table of valid sessions from a previously saved pickle.
# NOTE(review): no visible cell writes this file — presumably it was saved from
# sess_df[sess_df.valid]; confirm, otherwise this cell depends on hidden state.
# NOTE: unpickling can execute arbitrary code; only load pickles from a trusted source.
valid_sess = pd.read_pickle("/data/patrick_scratch/multi_sess/valid_sessions.pickle")
# Build one electrode table per session and stack them into a single frame with a fresh index.
all_sess_locations = pd.concat(valid_sess.apply(get_electrode_locations, axis=1).values, ignore_index=True)

In [27]:
# Total number of electrode rows across all loaded sessions.
all_sess_locations.shape[0]

7920

In [20]:
# Distinct labels present in the structure_level1 column.
all_sess_locations["structure_level1"].unique()

array(['Occipital_Lobe (Occipital)', None, 'Parietal_Lobe (Parietal)',
       'diencephalon (di)', 'telencephalon (tel)',
       'Temporal_Lobe (Temporal)', 'Frontal_Lobe (Frontal)',
       'metencephalon (met)'], dtype=object)

In [21]:
# Distinct labels present in the structure_level2 column.
all_sess_locations["structure_level2"].unique()

array(['extrastriate_visual_areas_2-4 (V2-V4)', None,
       'inferior_parietal_lobule (IPL)', 'superior_parietal_lobule (SPL)',
       'middle_temporal_area (MT)', 'posterior_medial_cortex (PMC)',
       'thalamus (Thal)', 'basal_ganglia (BG)', 'medial_pallium (MPal)',
       'preoptic_complex (POC)',
       'floor_of_the_lateral_sulcus (floor_of_ls)',
       'inferior_temporal_cortex (ITC)', 'amygdala (Amy)',
       'lateral_and_ventral_pallium (LVPal)',
       'lateral_prefrontal_cortex (lat_PFC)', 'motor_cortex (motor)',
       'anterior_cingulate_gyrus (ACgG)', 'orbital_frontal_cortex (OFC)',
       'pons (Pons)', 'primary_visual_cortex (V1)',
       'medial_temporal_lobe (MTL)', 'diagonal_subpallium (DSP)',
       'somatosensory_cortex (SI/SII)'], dtype=object)

In [36]:
# Distinct labels present in the structure_level3 column.
all_sess_locations["structure_level3"].unique()

array(['preoccipital_visual_areas_2-3 (V2-V3)', None,
       'area_7_in_the_inferior_parietal_lobule (area_7_in_IPL)',
       'area_5 (area_5)', 'lateral_intraparietal_sulcus (lat_IPS)',
       'middle_temporal_area (MT)', 'posterior_cingulate_gyrus (PCgG)',
       'posterior_thalamus (PThal)', 'geniculate_thalamus (GThal)',
       'striatum (Str)', 'hippocampal_formation (HF)',
       'preoptic_complex (POC)',
       'floor_of_the_lateral_sulcus (floor_of_ls)',
       'fundus_of_the_superior_temporal_sulcus (STSf)',
       'subpallial_amygdala (spAmy)', 'pallidum (Pd)',
       'pallial_amygdala (pAmy)', 'lateral_pallium (LPal)',
       'ventral_pallium (VPal)', 'dorsolateral_prefrontal_cortex (dlPFC)',
       'medial_supplementary_motor_areas (SMA/preSMA)',
       'anterior_cingulate_cortex (ACC)',
       'ventrolateral_prefrontal_cortex (vlPFC)',
       'lateral_orbital_frontal_cortex (lat_OFC)', 'ventral_pons (VPons)',
       'primary_visual_cortex (V1)', 'visual_area_4 (V4)', 'area

In [15]:
# Count rows per structure label at each hierarchy level and write one CSV per level.
# NOTE: the output column keeps its existing name "num_neurons" so the CSV header is unchanged,
# but the value is the number of non-null electrode_id entries per label, i.e. an electrode count.
out_dir = "/data/patrick_scratch/num_electrodes"
# Create the output directory up front so to_csv does not fail on a fresh machine.
os.makedirs(out_dir, exist_ok=True)
for level in ("structure_level1", "structure_level2", "structure_level3"):
    (
        all_sess_locations
        # dropna=False keeps electrodes whose label is missing as their own group.
        .groupby(all_sess_locations[level], dropna=False)
        .count()[["electrode_id"]]
        .rename(columns={"electrode_id": "num_neurons"})
        .to_csv(f"{out_dir}/{level}.csv")
    )
