In [1]:
import sys
import numpy as np
import pandas as pd
from one.api import ONE
path_root = '/home/yzhang39/IBL_foundation_model/'
sys.path.append(str(path_root))
from src.utils.ibl_data_utils import (
    prepare_data,
    select_brain_regions, 
    list_brain_regions, 
    bin_spiking_data,
    load_anytime_behaviors
)
from kirby.data import Data, IrregularTimeSeries, Interval, DatasetBuilder, ArrayDict

from datetime import datetime
from kirby.taxonomy import (
    Task,
    Sex,
    Species,
    SubjectDescription,
)
# Seed NumPy's legacy global RNG; the train/val/test trial split further down
# draws from it through np.random.choice.
np.random.seed(42)

In [2]:
# Client for the Alyx instance at openalyx.internationalbrainlab.org, in remote mode.
one = ONE(
    base_url='https://openalyx.internationalbrainlab.org',
    password='international',
    mode='remote',
)

# Release table stored under the project checkout; its first column is the index.
freeze_file = f'{path_root}/data/2023_12_bwm_release.csv'
bwm_df = pd.read_csv(freeze_file, index_col=0)

# Session (experiment id) converted by this notebook.
eid = 'db4df448-e449-4a6f-a0e7-288711e7a75a'

# Options forwarded to the helpers imported from src.utils.ibl_data_utils.
params = dict(
    interval_len=2,
    binsize=0.02,
    single_region=False,
    align_time='stimOn_times',
    time_window=(-0.5, 1.5),
    fr_thresh=0.5,
)

In [3]:
# load data
# CAUTION: Match trial selection criteria
neural_dict, _, meta_data, trials_data = prepare_data(
    one, eid, bwm_df, params, n_workers=1
)

# List the regions, pick the clusters to use and bin their spikes per trial
# (all three helpers come from src.utils.ibl_data_utils).
regions, beryl_reg = list_brain_regions(neural_dict, **params)
region_cluster_ids = select_brain_regions(neural_dict, beryl_reg, regions, **params)
binned_spikes, clusters_used_in_bins = bin_spiking_data(
    region_cluster_ids,
    neural_dict,
    trials_df=trials_data['trials_df'],
    n_workers=1,
    **params,
)

# Sum the binned counts over axis 1, average over axis 0 and divide by the
# interval length to get one rate per column of `binned_spikes`.
avg_fr = binned_spikes.sum(axis=1).mean(axis=0) / params['interval_len']

# NOTE(review): the cut-off is 1 / fr_thresh (2.0 for fr_thresh = 0.5), not
# fr_thresh itself -- confirm this is the intended threshold.
# NOTE(review): the result holds column positions of `binned_spikes`; later
# cells use them as cluster ids, which presumably requires the columns to be
# ordered by cluster id -- verify against `clusters_used_in_bins`.
is_active = avg_fr > 1/params['fr_thresh']
active_neuron_ids = np.argwhere(is_active).flatten()

Merge 1 probes for session eid: db4df448-e449-4a6f-a0e7-288711e7a75a


100%|██████████| 2/2 [00:00<00:00,  3.03it/s]


Use spikes from brain regions:  ['CA1' 'DG' 'LP' 'PoT' 'SGN' 'SNc' 'SPF' 'VISp' 'ZI' 'root']


100%|██████████| 402/402 [00:03<00:00, 119.37it/s]


In [4]:
# extract spiking activity
# Keep only the spikes whose cluster id appears in `active_neuron_ids`.
unit_mask = np.isin(neural_dict['spike_clusters'], active_neuron_ids)
spike_times = neural_dict['spike_times'][unit_mask]
spike_clusters = neural_dict['spike_clusters'][unit_mask]

In [5]:
# One uuid per active unit, in the same order as `active_neuron_ids`.
unit_ids = np.asarray(meta_data['uuids'])[active_neuron_ids]

In [6]:
# Build the per-unit metadata table and the flat, labelled spike train.
#
# `spike_clusters` still holds the original cluster ids (the values stored in
# `active_neuron_ids`), whereas `unit_number` / `unit_index` are contiguous
# positions 0..n_units-1. Unit `i` must therefore be matched against
# `active_neuron_ids[i]`; comparing against the loop position `i` assigns the
# wrong spikes to a unit (or none at all).
if len(unit_ids) == 0:
    raise ValueError("No active units selected; cannot build the spike train.")

unit_meta = []
timestamps = []
unit_index = []

for i, (unit_id, cluster_id) in enumerate(zip(unit_ids, active_neuron_ids)):

    # extract the spikes emitted by this cluster
    times = spike_times[spike_clusters == cluster_id]
    timestamps.append(times)

    # label every spike with the contiguous unit number; an empty array is
    # appended for silent units so both lists always stay aligned
    unit_index.append(np.full(len(times), i, dtype=np.int64))

    # extract unit metadata
    unit_meta.append(
        {
            "id": unit_id,
            "unit_number": i,
            "count": len(times),
            "type": 0,
        }
    )

unit_meta_df = pd.DataFrame(unit_meta)  # list of dicts to dataframe
units = ArrayDict.from_dataframe(
    unit_meta_df,
    unsigned_to_long=True,
)

# concatenate spikes (grouped by unit at this point, not yet time-ordered)
timestamps = np.concatenate(timestamps)
unit_index = np.concatenate(unit_index)

# create spikes object
spikes = IrregularTimeSeries(
    timestamps=timestamps,
    unit_index=unit_index,
    domain="auto",
)

# make sure to sort the spikes
spikes.sort()

In [7]:
# extract_trials
# CAUTION: Need to exclude trials in which behavior is NONE

# Work on plain NumPy arrays so the trial table gets a clean 0..n-1 index
# instead of inheriting the gaps left by the mask.
trial_mask = np.asarray(trials_data['trials_mask'], dtype=bool)
align_times = trials_data['trials_df'][params['align_time']].to_numpy()[trial_mask]

# `time_window` holds signed offsets relative to the alignment event and its
# first entry is negative, so BOTH offsets are added. Subtracting the first
# one placed the start 0.5 s after the event and yielded 1 s trials instead of
# the 2 s declared by `interval_len`.
window_start, window_end = params['time_window']
start_time = align_times + window_start
end_time = align_times + window_end

# Random 70 / 10 / 20 split of the kept trials (uses the global NumPy seed).
max_num_trials = int(trial_mask.sum())
trial_idxs = np.random.choice(np.arange(max_num_trials), max_num_trials, replace=False)
n_train = int(0.7 * max_num_trials)
n_train_val = int(0.8 * max_num_trials)
train_idxs = trial_idxs[:n_train]
val_idxs = trial_idxs[n_train:n_train_val]
test_idxs = trial_idxs[n_train_val:]

# Every trial starts as 'train'; the val / test positions are overwritten.
trial_split = np.full(max_num_trials, 'train', dtype='<U5')
trial_split[val_idxs] = 'val'
trial_split[test_idxs] = 'test'

trial_table = pd.DataFrame({
    "start": start_time,
    "end": end_time,
    "split_indicator": trial_split,
})
trials = Interval.from_dataframe(trial_table)

train_mask_nwb = trial_split == "train"
val_mask_nwb = trial_split == "val"
test_mask_nwb = trial_split == "test"

trials.train_mask_nwb = (
    train_mask_nwb  # Naming with "_" since train_mask is reserved
)
trials.val_mask_nwb = val_mask_nwb
trials.test_mask_nwb = test_mask_nwb

In [8]:
# extract_behavior
# Returns a dict keyed by behaviour name; the next cell reads the
# 'right-whisker-motion-energy' entry and its 'times' / 'values' fields.
behave_dict = load_anytime_behaviors(one, eid, n_workers=1)

100%|██████████| 2/2 [00:00<00:00,  4.62it/s]


In [9]:
# Right-whisker motion energy: sample times plus values scaled down by 100.
# NOTE(review): the reason for the 1/100 scaling is not visible here -- confirm.
whisker_trace = behave_dict['right-whisker-motion-energy']
timestamps = whisker_trace['times']
whisker = whisker_trace['values'] / 100.

In [10]:
# Every behaviour sample gets subtask index 0.
behavior_type = np.zeros_like(timestamps, dtype=np.int64)

# report accuracy only on the evaluation intervals
# A sample is evaluated when it falls inside [start, end) of any trial.
eval_mask = np.zeros_like(timestamps, dtype=bool)

for trial_no in range(len(trials)):
    after_start = timestamps >= trials.start[trial_no]
    before_end = timestamps < trials.end[trial_no]
    eval_mask[after_start & before_end] = True

In [11]:
# Behaviour time series: the whisker trace as a single column, plus the
# per-sample subtask index and evaluation mask built above.
whisker_column = whisker.reshape(-1,1)
behavior = IrregularTimeSeries(
    timestamps=timestamps,
    whisker=whisker_column,
    subtask_index=behavior_type,
    eval_mask=eval_mask,
    domain="auto",
)

In [12]:
# Session container spanning the first trial start to the last trial end.
trial_span = Interval(trials.start[0], trials.end[-1])
data = Data(
    spikes=spikes,          # neural activity
    units=units,
    trials=trials,          # stimuli and behavior
    behavior=behavior,
    domain=trial_span,      # domain
)

In [15]:
# Sanity check: number of spikes left after filtering to the active units.
spike_times.shape

(3761212,)

In [16]:
# Sanity check: one cluster label per retained spike (same mask as spike_times).
spike_clusters.shape

(3761212,)

In [18]:
# Sanity check: number of trials kept by the trial mask.
start_time.shape

(288,)

In [13]:
# Draft of a dictionary-style export. Most values were left blank in the
# original cell, which is a SyntaxError and stops "Restart & Run All". The
# unresolved entries are explicit `None` placeholders until their intended
# contents are decided.
data_dict = {
    'data': {
        'spikes': None,  # TODO: fill in the spike data
        'behavior': whisker,
    },
    'intervals': None,           # TODO: all trial intervals
    'train_intervals': None,     # TODO
    'test_intervals': None,      # TODO
    'finetune_intervals': None,  # TODO
}

37879

In [14]:
# Dataset-level metadata; only the experiment name is filled in so far.
dataset_metadata = dict(
    experiment_name=eid,
    origin_version='',
    derived_version='',
    source='',
    description='',
)

# Builder that reads from the raw folder and writes the processed session.
db = DatasetBuilder(
    raw_folder_path='/home/yzhang39/project-kirby/data/raw/',
    processed_folder_path='/home/yzhang39/project-kirby/data/processed/ibl',
    **dataset_metadata,
)

In [15]:
# Register the subject, session, sorted units, data and trial splits for this
# session with the dataset builder, then write everything to disk.
with db.new_session() as session:

    # extract subject metadata
    # The subject id comes from the metadata returned by prepare_data.
    # NOTE(review): species and sex are hard-coded here -- confirm them
    # against the real subject record.
    subject = SubjectDescription(
        id=meta_data['subject'],
        species=Species.from_string('MUS_MUSCULUS'),
        sex=Sex.from_string('MALE'),
    )
    session.register_subject(subject)

    # extract experiment metadata
    # The experiment id doubles as the session id and the sortset id.
    session_id = eid

    # register session
    # NOTE(review): recording_date is the day this notebook is run
    # (datetime.today()), not the date of the recording -- TODO replace with
    # the real session date.
    session.register_session(
        id=session_id,
        recording_date=datetime.today().strftime('%Y%m%d'),
        task=Task.FREE_BEHAVIOR,
    )

    # register sortset
    session.register_sortset(
        id=session_id,
        units=units,
    )

    # session extent: first to last behaviour timestamp
    # NOTE(review): spikes and trials may fall outside this range -- confirm
    # that the behaviour trace is the right reference for the domain.
    session_start, session_end = (
        behavior.timestamps[0].item(),
        behavior.timestamps[-1].item(),
    )

    # Rebuilds `data` (replacing any earlier value) with the session domain.
    data = Data(
        # neural activity
        spikes=spikes,
        units=units,
        # stimuli and behavior
        trials=trials,
        behavior=behavior,
        # domain
        domain=Interval(session_start, session_end),
    )

    session.register_data(data)

    # split and register trials into train, validation and test
    train_trials = trials.select_by_mask(trials.train_mask_nwb)
    valid_trials = trials.select_by_mask(trials.val_mask_nwb)
    test_trials = trials.select_by_mask(trials.test_mask_nwb)

    session.register_split("train", train_trials)
    session.register_split("valid", valid_trials)
    session.register_split("test", test_trials)

    # save data to disk
    session.save_to_disk()

# all sessions added, finish by generating a description file for the entire dataset
db.finish()

In [None]:
# EVAL
# Need to bin spikes and standardize behavior before evaluation




