# Imports

Various useful standard-library modules:

In [None]:
from typing import *
import functools
import pathlib

External packages:

In [None]:
import pandas as pd
import numpy as np

Our main workhorse, `allensdk`:

In [None]:
import allensdk
from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache
from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession

Local modules:

In [None]:
from constraints import *

# Global Variables
We have a bunch of variables which we will realistically set only once
and never again.  Because of that, and to keep them clearly apart from
ordinary local names, we will write these in uppercase and annotate
them as `Final`.

## Introspection & debugging
First are the two debugging variables, `DEBUG` and `EXAMPLE`.  
`DEBUG` should control behavior checks.  
`EXAMPLE` should control examples which are not necessary for the project.

In [None]:
# Debugging switches, fixed for the lifetime of the notebook.
EXAMPLE: Final[bool] = True   # run the optional, illustrative example cells
DEBUG: Final[bool] = False    # enable additional behavior checks

## Paths
Then we have our data directory, `DATA_DIR`, which contains the data
we cache from `allensdk`.

In [None]:
# Directory for data cached from allensdk.  The path is relative, so it
# resolves against the current working directory.
DATA_DIR: Final[pathlib.Path] = pathlib.Path(
    './data/'
)

We also need to ensure that this directory exists:

In [None]:
# Create the cache directory and any missing parents.  Do nothing if it
# already exists.  The keyword is `exist_ok`; `exists_ok` raises TypeError.
DATA_DIR.mkdir(parents=True, exist_ok=True)

Furthermore, we have the path to the manifest file, again relative to
the project root, which is by default the `manifest.json` file located
directly in `DATA_DIR`.

In [None]:
# Path of the allensdk manifest, by default directly inside DATA_DIR.
# This is an annotated assignment (`:`), not a chained one.  The annotation
# uses `pathlib.Path` because only the `pathlib` module is imported.
MANIFEST_PATH: Final[pathlib.Path] = DATA_DIR / 'manifest.json'

## Cache & sessions
Next is our `allensdk` cache and the sessions in it.

First define the type aliases `Cache` and `Session` for easier
typing in functions.

In [None]:
# Short aliases for the allensdk classes, used in annotations below.
Session: TypeAlias = EcephysSession
Cache: TypeAlias = EcephysProjectCache

Set the cache's timeout, `CACHE_TIMEOUT`, which is measured in
seconds:

In [None]:
# Timeout passed to the allensdk cache, in seconds.
CACHE_TIMEOUT: Final[int] = 60 * 30  # half an hour

Define the global cache, `CACHE`:

In [None]:
# Global allensdk project cache, built from the warehouse and backed by the
# manifest at MANIFEST_PATH.
CACHE: Final[Cache] = EcephysProjectCache.from_warehouse(
    manifest=MANIFEST_PATH,
    timeout=CACHE_TIMEOUT,
)

Load the table of sessions for simple access.

In [None]:
# Table of all sessions available through CACHE.
SESSIONS_TABLE: Final = CACHE.get_session_table()
if EXAMPLE:
    # NOTE(review): Jupyter echoes only a cell's last top-level expression,
    # so this bare expression inside `if` is likely not displayed.  Consider
    # display(...) if the table should actually be shown.
    SESSIONS_TABLE

Get the complete list of session ids.

In [None]:
# All session ids known to the cache: the index of SESSIONS_TABLE.
SESSION_IDS: Final[Sequence[int]] = SESSIONS_TABLE.index
if EXAMPLE:
    # NOTE(review): as a bare expression nested in `if`, this is likely not
    # echoed by Jupyter.  Consider display(...) if output is wanted.
    SESSION_IDS

We typically want to work on just one session, `CURRENT_SESSION`.
That can be specified here for convenience using `CURRENT_SESSION_ID`.

Functions strive to accept a `session` parameter where meaningful and
default to `CURRENT_SESSION`.

In [None]:
# The session that helper functions default to.  Change CURRENT_SESSION_ID
# to work on a different one.
CURRENT_SESSION_ID: Final[int] = 715093703
# Check explicitly instead of using `assert`, which is stripped when Python
# runs with -O and would then let an unknown id reach get_session_data.
if CURRENT_SESSION_ID not in SESSION_IDS:
    raise ValueError(f'session {CURRENT_SESSION_ID} is not in the session table')
CURRENT_SESSION: Final[Session] = CACHE.get_session_data(CURRENT_SESSION_ID)

# Accessing filtered data
The `Session` object exposes a lot of useful functions and properties
to us.  In order to free us from having to always type the session
(which will overwhelmingly be `CURRENT_SESSION`) and to simultaneously
allow us to filter our data without having to go through much trouble,
we will define simple wrappers around those functions.

Typically, these functions accept a `session` parameter which defaults
to `CURRENT_SESSION`, and a bunch of keyword arguments which filter
the resulting dataset.

Each argument `foo=bar`, unless stated otherwise, filters the `foo`
column of the resulting dataset to values matching `bar`.  See the
commentary at the start of `constraints.py` for a full description of
what values `bar` can take.

In [None]:
def get_sessions(**kwargs):
    """Return the rows of `SESSIONS_TABLE` that satisfy the given filters.

    Each keyword `foo=bar` filters the `foo` column; see `constraints.py`
    for the value forms `bar` may take.  Meaningful filters:
    - published_at                (time)
    - specimen_id                 (integer, key)
    - session_type                ('brain_observatory_1.1' or 'functional_connectivity')
    - age_in_days                 (float)
    - sex                         ('M' or 'F')
    - full_genotype               (string)
    - unit_count                  (integer)
    - channel_count               (integer)
    - probe_count                 (integer)
    - ecephys_structure_acronyms  (list of strings)

    Passing `__total__=False` permits additional filters, which then have
    no effect on the result.

    """
    constraint = FIELD(**kwargs)
    return filter_df(SESSIONS_TABLE, constraint)

def get_session_ids(**kwargs):
    """Return the index (session ids) of the sessions matching `kwargs`.

    Accepts exactly the same filters as `get_sessions`.

    """
    matching = get_sessions(**kwargs)
    return matching.index

if EXAMPLE:
    # Example: sessions with sex 'M' and unit_count in RANGE(650, None)
    # (presumably "at least 650"; see constraints.py for RANGE bounds).
    # NOTE(review): nested in `if`, this result is likely not echoed by
    # Jupyter.
    get_sessions(
        sex='M',
        unit_count=RANGE(650, None),
    )

In [None]:
def get_units(ecephys_structure_acronym=None,
              session: Session = CURRENT_SESSION,
              **kwargs):
    """Return a `Session.units` dataframe of the matching units in `session`.

    `ecephys_structure_acronym` may be given positionally as a shorthand for
    the filter of the same name.  When it is None, no structure filter is
    applied.  All other filters are passed as keyword arguments.

    The following filters are meaningful:
    - waveform_PT_ratio                      (float)
    - waveform_amplitude                     (float)
    - amplitude_cutoff                       (float)
    - cluster_id                             (integer, key)
    - cumulative_drift                       (float)
    - d_prime                                (float or null)
    - firing_rate                            (float)
    - isi_violations                         (float)
    - isolation_distance                     (float or null)
    - L_ratio                                (float or null)
    - local_index                            (integer)
    - max_drift                              (float)
    - nn_hit_rate                            (float or null)
    - nn_miss_rate                           (float or null)
    - peak_channel_id                        (integer, key)
    - presence_ratio                         (float)
    - waveform_recovery_slope                (float or null)
    - waveform_repolarization_slope          (float)
    - silhouette_score                       (float or null)
    - snr                                    (float)
    - waveform_spread                        (float)
    - waveform_velocity_above                (float or null)
    - waveform_velocity_below                (float or null)
    - waveform_duration                      (float)
    - filtering                              (string)
    - probe_channel_number                   (integer)
    - probe_horizontal_position              (integer)
    - probe_id                               (integer)
    - probe_vertical_position                (integer)
    - structure_acronym                      (string)
    - ecephys_structure_id                   (float, key)
    - ecephys_structure_acronym              (string)
    - anterior_posterior_ccf_coordinate      (float or null)
    - dorsal_ventral_ccf_coordinate          (float or null)
    - left_right_ccf_coordinate              (float or null)
    - probe_description                      (string, probeA..F)
    - location                               (object)
    - probe_sampling_rate                    (float)
    - probe_lfp_sampling_rate                (float)
    - probe_has_lfp_data                     (bool)

    If an argument `__total__=False` is passed, additional filters may
    be provided with no effect on the result.
    """
    if ecephys_structure_acronym is not None:
        kwargs['ecephys_structure_acronym'] = ecephys_structure_acronym
    # Filter unconditionally: the structure shorthand is only one optional
    # filter.  With the return nested under the `if`, omitting it made the
    # function return None.
    return filter_df(session.units, FIELD(**kwargs))
    
def get_unit_ids(ecephys_structure_acronym=None,
                 session: Session = CURRENT_SESSION,
                 **kwargs):
    """Return the index (unit ids) of the units in `session` matching the filters.

    Takes the same arguments and filters as `get_units`.

    """
    units = get_units(ecephys_structure_acronym=ecephys_structure_acronym,
                      session=session,
                      **kwargs)
    return units.index

if EXAMPLE:
    # Example: units whose isi_violations fall in RANGE(None, 0.7) and whose
    # structure_acronym is 'VISam' (RANGE bounds as defined in constraints.py).
    # NOTE(review): nested in `if`, this result is likely not echoed by
    # Jupyter.
    get_units(
        isi_violations=RANGE(None, 0.7),
        structure_acronym='VISam',
    )

In [None]:
def get_stimulus_presentations(stimulus_name=None,
                               stimulus_condition_id=None,
                               session: Session = CURRENT_SESSION,
                               **kwargs):
    """Return the matching rows of the `Session.stimulus_presentations` dataframe of `session`.

    `stimulus_name` and `stimulus_condition_id` may be given positionally as
    shorthands for the filters of the same name.  Each is skipped when None.

    The following filters are meaningful:
    - stimulus_block           (float or null, key)
    - start_time               (float)
    - stop_time                (float)
    - contrast                 (float or null)
    - spatial_frequency        (float, string, or null)
    - frame                    (float or null)
    - stimulus_name            (string)
    - x_position               (float or null)
    - y_position               (float or null)
    - orientation              (float or null)
    - temporal_frequency       (float or null)
    - size                     (object)
    - color                    (-1.0, 1.0, or null)
    - phase                    (object)
    - duration                 (float)
    - stimulus_condition_id    (integer, key)

    If an argument `__total__=False` is passed, additional filters may
    be provided with no effect on the result.

    """
    if stimulus_name is not None:
        kwargs['stimulus_name'] = stimulus_name
    if stimulus_condition_id is not None:
        kwargs['stimulus_condition_id'] = stimulus_condition_id
    # Filter unconditionally: both shorthands above are optional.  With the
    # return nested under the second `if`, the function returned None unless
    # a condition id was given.
    return filter_df(session.stimulus_presentations, FIELD(**kwargs))

def get_stimulus_presentation_ids(stimulus_name=None,
                                  stimulus_condition_id=None,
                                  session: Session = CURRENT_SESSION,
                                  **kwargs):
    """Return the index (presentation ids) of the matching stimulus presentations.

    Takes the same arguments and filters as `get_stimulus_presentations`.

    """
    presentations = get_stimulus_presentations(
        stimulus_name=stimulus_name,
        stimulus_condition_id=stimulus_condition_id,
        session=session,
        **kwargs,
    )
    return presentations.index

if EXAMPLE:
    # Example: 'static_gratings' presentations with orientation in
    # RANGE(30, 60, ub_strict=False) (bound semantics per constraints.py).
    # NOTE(review): nested in `if`, this result is likely not echoed by
    # Jupyter.
    get_stimulus_presentations(
        stimulus_name='static_gratings',
        orientation=RANGE(30, 60, ub_strict=False),
    )

In [None]:
def get_presentationwise_spike_times(session: Session = CURRENT_SESSION, **kwargs):
    """Return spike times of the matching units during the matching presentations.

    The same keyword filters go to both `get_stimulus_presentation_ids` and
    `get_unit_ids`, so any filter understood by `get_units` or
    `get_stimulus_presentations` is meaningful here.  `__total__` is forced
    to False so that filters meant for one table have no effect on the
    other.

    """
    filters = {**kwargs, '__total__': False}
    stimulus_ids = get_stimulus_presentation_ids(session=session, **filters)
    unit_ids = get_unit_ids(session=session, **filters)
    return session.presentationwise_spike_times(
        stimulus_presentation_ids=stimulus_ids,
        unit_ids=unit_ids,
    )