In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from scipy.interpolate import interp1d

from daart.data import DataGenerator
from daart.models import Segmenter
from daart.transforms import ZScore

from one.api import ONE
from brainbox.io.one import SessionLoader

In [2]:
from pathlib import Path

In [3]:
# Connect to the IBL database
# NOTE(review): the password looks like the shared public-access credential for
# openalyx; confirm it is not a private secret before sharing this notebook.
ONE.setup(base_url='https://openalyx.internationalbrainlab.org', silent=True)
one = ONE(password='international')

# Locate the trained model checkpoint and the hyperparameters it was trained with
model_dir = './artifacts/daart_binary'
model_file = os.path.join(model_dir, 'best_val_model.pt')
hparams_file = os.path.join(model_dir, 'hparams.yaml')

# Read the hyperparameters inside a context manager so the file handle is closed
with open(hparams_file, 'rb') as hparams_handle:
    hparams = yaml.safe_load(hparams_handle)

# Inference is pinned to the CPU. The previous version probed for CUDA and then
# unconditionally overwrote the result with 'cpu'; only the effective value is kept.
device = 'cpu'

# Rebuild the network from its hyperparameters and restore the trained weights.
# map_location=device remaps every stored tensor onto the chosen device.
model = Segmenter(hparams)
model.load_state_dict(torch.load(model_file, map_location=device))
model.to(device)
model.eval()

  WeightNorm.apply(module, name, dim)
  model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))


Segmenter(
  (model): ModuleDict(
    (encoder): DilatedTCN(
      (model): Sequential(
        (tcn_block_00): DilationBlock(
          (conv0): Conv1d(6, 32, kernel_size=(9,), stride=(1,), padding=(4,))
          (conv1): Conv1d(32, 32, kernel_size=(9,), stride=(1,), padding=(4,))
          (activation): LeakyReLU(negative_slope=0.05)
          (final_activation): LeakyReLU(negative_slope=0.05)
          (dropout): Dropout2d(p=0.1, inplace=False)
          (block): Sequential(
            (conv1d_layer_0): Conv1d(6, 32, kernel_size=(9,), stride=(1,), padding=(4,))
            (lrelu_0): LeakyReLU(negative_slope=0.05)
            (dropout_0): Dropout2d(p=0.1, inplace=False)
            (conv1d_layer_1): Conv1d(32, 32, kernel_size=(9,), stride=(1,), padding=(4,))
            (lrelu_1): LeakyReLU(negative_slope=0.05)
            (dropout_1): Dropout2d(p=0.1, inplace=False)
          )
          (downsample): Conv1d(6, 32, kernel_size=(1,), stride=(1,))
        )
        (tcn_block_01): 

In [4]:
# Function to create data generator
def create_data_generator(sess_id, input_file):
    """Wrap one session's marker feature file in a daart ``DataGenerator``.

    The generator is built for a single session with a single ``'markers'``
    signal that is passed through a ``ZScore`` transform and read from
    ``input_file``.  Batch size, sequence length/padding, trial splits and
    input type are taken from the module-level ``hparams`` dict, and the
    module-level ``device`` is forwarded unchanged.

    Parameters
    ----------
    sess_id : identifier handed to ``DataGenerator`` for this session
    input_file : path of the file holding the marker features

    Returns
    -------
    DataGenerator
    """
    session_signals = ['markers']
    session_transforms = [ZScore()]
    session_paths = [input_file]

    # Options shared by every generator built in this notebook
    generator_kwargs = dict(
        batch_size=hparams['batch_size'],
        sequence_length=hparams['sequence_length'],
        sequence_pad=hparams['sequence_pad'],
        trial_splits=hparams['trial_splits'],
        input_type=hparams['input_type'],
        device=device,
    )

    # DataGenerator expects one entry per session, hence the extra list nesting
    return DataGenerator(
        [sess_id], [session_signals], [session_transforms], [session_paths],
        **generator_kwargs
    )

In [5]:
# Function to run inference
def get_states(model, data_gen):
    """Return the highest-scoring class index for every row of the predictions.

    Calls ``model.predict_labels(data_gen, return_scores=True)``, stacks the
    arrays found under ``['labels'][0]`` vertically and takes the argmax
    along axis 1.

    Parameters
    ----------
    model : object exposing ``predict_labels(data_gen, return_scores=...)``
    data_gen : data generator forwarded to ``predict_labels``

    Returns
    -------
    np.ndarray
        1-D array of integer class indices (0-based, as produced by argmax).
    """
    # assumes ['labels'][0] is a sequence of 2-D (rows, classes) arrays — TODO confirm
    predictions = model.predict_labels(data_gen, return_scores=True)
    stacked_scores = np.vstack(predictions['labels'][0])
    return stacked_scores.argmax(axis=1)

In [6]:
# Functions to Perform Smoothing
def non_uniform_savgol(x, y, window, polynom):
    """Savitzky-Golay smoothing for samples on a non-uniform grid.

    For every sample that has a full window around it, a polynomial of degree
    ``polynom`` is least-squares fitted to the ``window`` samples centred on
    it, and the value of that polynomial at the centre becomes the smoothed
    sample.  The first and last ``window // 2`` samples have no full window
    and are copied from ``y`` unchanged.

    Parameters
    ----------
    x : array-like
        Sample positions, same length as ``y``.
    y : array-like
        Sample values.
    window : int
        Odd number of samples in each fit.
    polynom : int
        Polynomial degree; must be smaller than ``window``.

    Returns
    -------
    np.ndarray
        Float array with the same length as ``y``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in length, the data is shorter than
        ``window``, ``window`` is even, or ``polynom >= window``.
    TypeError
        If ``window`` or ``polynom`` is not an ``int``.
    """
    if len(x) != len(y):
        raise ValueError('"x" and "y" must be of the same size')
    if len(x) < window:
        raise ValueError('The data size must be larger than the window size')
    if type(window) is not int:
        raise TypeError('"window" must be an integer')
    if window % 2 == 0:
        raise ValueError('The "window" must be an odd integer')
    if type(polynom) is not int:
        raise TypeError('"polynom" must be an integer')
    if polynom >= window:
        raise ValueError('"polynom" must be less than "window"')

    # Accept plain sequences as well as arrays
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    half_window = window // 2
    n_coeffs = polynom + 1
    y_smoothed = np.full(len(y), np.nan)

    for i in range(half_window, len(x) - half_window):
        in_window = slice(i - half_window, i + half_window + 1)
        # Design matrix of powers 0..polynom in coordinates centred on x[i], so
        # the constant coefficient is the fitted value at the centre sample.
        A = np.vander(x[in_window] - x[i], n_coeffs, increasing=True)
        # Solve the normal equations directly rather than forming an explicit inverse
        coeffs = np.linalg.solve(A.T @ A, A.T @ y[in_window])
        y_smoothed[i] = coeffs[0]

    # Only the samples without a full window keep their raw value.  The slices
    # previously reached one sample too far on each side and overwrote the first
    # and last fitted values with raw data.
    if half_window:
        y_smoothed[:half_window] = y[:half_window]
        y_smoothed[-half_window:] = y[-half_window:]

    return y_smoothed

# Function to smooth and interpolate the marker signals
def smooth_interpolate_signal_sg(signal, window=31, order=3, interp_kind='linear'):
    """Smooth the non-NaN samples of a 1-D signal and fill the NaN gaps.

    The non-NaN samples are smoothed with ``non_uniform_savgol`` using their
    integer sample index as position, and the smoothed samples are then
    interpolated (and extrapolated at the ends) back onto every sample index
    with ``scipy.interpolate.interp1d``.

    Parameters
    ----------
    signal : array-like
        1-D signal that may contain NaNs.
    window : int
        Window length forwarded to ``non_uniform_savgol``.
    order : int
        Polynomial order forwarded to ``non_uniform_savgol``.
    interp_kind : str
        ``kind`` argument forwarded to ``interp1d``.

    Returns
    -------
    np.ndarray
        Smoothed and gap-filled signal, or an unmodified copy of the input
        when fewer than ``window`` non-NaN samples are available.
    """
    signal_noisy_w_nans = np.copy(signal)
    timestamps = np.arange(signal_noisy_w_nans.shape[0])
    good_idxs = np.where(~np.isnan(signal_noisy_w_nans))[0]

    # The filter needs at least one full window of valid samples
    if len(good_idxs) < window:
        print('Not enough non-nan indices to filter; returning original signal')
        return signal_noisy_w_nans

    signal_smooth_nonans = non_uniform_savgol(
        timestamps[good_idxs], signal_noisy_w_nans[good_idxs], window=window, polynom=order)

    # Evaluate the smoothed samples at every index; 'extrapolate' covers NaNs
    # that sit before the first or after the last valid sample.
    interpolater = interp1d(
        timestamps[good_idxs], signal_smooth_nonans, kind=interp_kind, fill_value='extrapolate')
    return interpolater(timestamps)

In [7]:
# Function to extract and/or smooth ibl data
def extract_marker_data(eid, l_thresh, view, paw, smooth, path):
    """Build z-scored paw/wheel features for one session and save them as CSV.

    Loads the pose traces of one paw and the wheel velocity through
    ``SessionLoader``, resamples the wheel velocity at the pose timestamps,
    optionally smooths the paw traces, appends first differences of all three
    signals, z-scores every column and writes the table to
    ``{path}/marker_features/``.

    Parameters
    ----------
    eid : str
        Session id handed to ``SessionLoader``.
    l_thresh : float
        Passed to ``SessionLoader.load_pose`` as ``likelihood_thr``.
    view : str
        Camera view; the pose table is read from ``f'{view}Camera'``.
    paw : str
        Marker name prefix; columns ``f'{paw}_x'`` and ``f'{paw}_y'`` are used.
    smooth : bool
        When truthy, both paw traces go through
        ``smooth_interpolate_signal_sg(..., window=7)`` and the file name gets
        a ``_smooth`` suffix.
    path : str
        Base output directory.

    Returns
    -------
    str
        Path of the CSV file that was written.
    """
    # Load the pose data
    sl = SessionLoader(one=one, eid=eid)
    sl.load_pose(likelihood_thr=l_thresh, views=[view])
    camera_pose = sl.pose[f'{view}Camera']
    times = camera_pose.times.to_numpy()
    markers = camera_pose.loc[:, (f'{paw}_x', f'{paw}_y')].to_numpy()

    # Load wheel data
    sl.load_wheel()
    wh_times = sl.wheel.times.to_numpy()
    wh_vel_oversampled = sl.wheel.velocity.to_numpy()

    # Resample wheel data at marker times
    interpolator = interp1d(wh_times, wh_vel_oversampled, fill_value='extrapolate')
    wh_vel = interpolator(times)

    if smooth:
        # Smooth the marker data
        markers[:, 0] = smooth_interpolate_signal_sg(markers[:, 0], window=7)
        markers[:, 1] = smooth_interpolate_signal_sg(markers[:, 1], window=7)

    # Stack positions with wheel velocity, then append their first differences
    # (a zero row keeps the differenced block aligned with the original rows).
    markers_comb = np.hstack([markers, wh_vel[:, None]])
    velocity = np.vstack([np.array([0, 0, 0]), np.diff(markers_comb, axis=0)])
    markers_comb = np.hstack([markers_comb, velocity])

    # Z-score every feature column
    markers_z = (markers_comb - np.mean(markers_comb, axis=0)) / np.std(markers_comb, axis=0)
    feature_names = ['paw_x_pos', 'paw_y_pos', 'wheel_vel', 'paw_x_vel', 'paw_y_vel', 'wheel_acc']
    df = pd.DataFrame(markers_z, columns=feature_names)

    # One code path for both variants; the returned string is unchanged
    suffix = '_smooth' if smooth else ''
    markers_file = path + f'/marker_features/{eid}_features{suffix}.csv'

    # Create the output folder on first use instead of failing inside to_csv
    os.makedirs(os.path.dirname(markers_file), exist_ok=True)
    df.to_csv(markers_file)
    return markers_file

In [8]:
# Session information (repro-ephys): session ids (eids) processed by the main
# loop; the trailing comment on each line is the lab/site tag of the session.
eids = [
    'db4df448-e449-4a6f-a0e7-288711e7a75a',  # Berkeley
    'd23a44ef-1402-4ed7-97f5-47e9a7a504d9',  # Berkeley
    '4a45c8ba-db6f-4f11-9403-56e06a33dfa4',  # Berkeley
    'e535fb62-e245-4a48-b119-88ce62a6fe67',  # Berkeley
    '54238fd6-d2d0-4408-b1a9-d19d24fd29ce',  # Berkeley
    'b03fbc44-3d8e-4a6c-8a50-5ea3498568e0',  # Berkeley
    '30c4e2ab-dffc-499d-aae4-e51d6b3218c2',  # CCU
    'd0ea3148-948d-4817-94f8-dcaf2342bbbe',  # CCU
    'a4a74102-2af5-45dc-9e41-ef7f5aed88be',  # CCU
    '746d1902-fa59-4cab-b0aa-013be36060d5',  # CCU
    '88224abb-5746-431f-9c17-17d7ef806e6a',  # CCU
    '0802ced5-33a3-405e-8336-b65ebc5cb07c',  # CCU
    'ee40aece-cffd-4edb-a4b6-155f158c666a',  # CCU
    'c7248e09-8c0d-40f2-9eb4-700a8973d8c8',  # CCU
    '72cb5550-43b4-4ef0-add5-e4adfdfb5e02',  # CCU
    'dda5fc59-f09a-4256-9fb5-66c67667a466',  # CSHL(C)
    '4b7fbad4-f6de-43b4-9b15-c7c7ef44db4b',  # CSHL(C)
    'f312aaec-3b6f-44b3-86b4-3a0c119c0438',  # CSHL(C)
    '4b00df29-3769-43be-bb40-128b1cba6d35',  # CSHL(C)
    'ecb5520d-1358-434c-95ec-93687ecd1396',  # CSHL(C)
    '51e53aff-1d5d-4182-a684-aba783d50ae5',  # NYU
    'f140a2ec-fd49-4814-994a-fe3476f14e66',  # NYU
    'a8a8af78-16de-4841-ab07-fde4b5281a03',  # NYU
    '61e11a11-ab65-48fb-ae08-3cb80662e5d6',  # NYU
    '73918ae1-e4fd-4c18-b132-00cb555b1ad2',  # Princeton
    'd9f0c293-df4c-410a-846d-842e47c6b502',  # Princeton
    'dac3a4c1-b666-4de0-87e8-8c514483cacf',  # SWC(H)
    '6f09ba7e-e3ce-44b0-932b-c003fb44fb89',  # SWC(H)
    '56b57c38-2699-4091-90a8-aba35103155e',  # SWC(M)
    '3638d102-e8b6-4230-8742-e548cd87a949',  # SWC(M)
    '7cb81727-2097-4b52-b480-c89867b5b34c',  # SWC(M)
    '781b35fd-e1f0-4d14-b2bb-95b7263082bb',  # UCL
    '3f859b5c-e73a-4044-b49e-34bb81e96715',  # UCL
    'b22f694e-4a34-4142-ab9d-2556c3487086',  # UCL
    '0a018f12-ee06-4b11-97aa-bbbff5448e9f',  # UCL
    'aad23144-0e52-4eac-80c5-c4ee2decb198',  # UCL
    'b196a2ad-511b-4e90-ac99-b5a29ad25c22',  # UCL
    'e45481fa-be22-4365-972c-e7404ed8ab5a',  # UCL
    'd04feec7-d0b7-4f35-af89-0232dd975bf0',  # UCL
    '1b715600-0cbc-442c-bd00-5b0ac2865de1',  # UCL
    'c7bf2d49-4937-4597-b307-9f39cb1c7b16',  # UCL
    '8928f98a-b411-497e-aa4b-aa752434686d',  # UCL
    'ebce500b-c530-47de-8cb1-963c552703ea',  # UCLA
    'dc962048-89bb-4e6a-96a9-b062a2be1426',  # UCLA
    '6899a67d-2e53-4215-a52a-c7021b5da5d4',  # UCLA
    '15b69921-d471-4ded-8814-2adad954bcd8',  # UCLA
    '5ae68c54-2897-4d3a-8120-426150704385',  # UCLA
    'ca4ecb4c-4b60-4723-9b9e-2c54a6290a53',  # UCLA
    '824cf03d-4012-4ab1-b499-c83a92c5589e',  # UCLA
    '3bcb81b4-d9ca-4fc9-a1cd-353a966239ca',  # UW
    'f115196e-8dfe-4d2a-8af3-8206d93c1729',  # UW
    '9b528ad0-4599-4a55-9148-96cc1d93fb24',  # UW
    '3e6a97d3-3991-49e2-b346-6948cb4580fb',  # UW
]

# Maps each session id (eid) to the session name used as `sess_id` in the main loop.
# Every key appears exactly once: the literal previously repeated four eids with
# identical values, which Python silently collapses into a single entry.
dropbox_marker_paths = {
    'db4df448-e449-4a6f-a0e7-288711e7a75a': 'danlab_DY_009_2020-02-27-001',
    'd23a44ef-1402-4ed7-97f5-47e9a7a504d9': 'danlab_DY_016_2020-09-12-001',
    '4a45c8ba-db6f-4f11-9403-56e06a33dfa4': 'danlab_DY_020_2020-09-29-001',
    'e535fb62-e245-4a48-b119-88ce62a6fe67': 'danlab_DY_013_2020-03-12-001',
    '54238fd6-d2d0-4408-b1a9-d19d24fd29ce': 'danlab_DY_018_2020-10-15-001',
    'b03fbc44-3d8e-4a6c-8a50-5ea3498568e0': 'danlab_DY_010_2020-01-27-001',
    '30c4e2ab-dffc-499d-aae4-e51d6b3218c2': 'mainenlab_ZFM-02370_2021-04-29-001',
    'd0ea3148-948d-4817-94f8-dcaf2342bbbe': 'mainenlab_ZFM-01936_2021-01-19-001',
    'a4a74102-2af5-45dc-9e41-ef7f5aed88be': 'mainenlab_ZFM-02368_2021-06-01-001',
    '746d1902-fa59-4cab-b0aa-013be36060d5': 'mainenlab_ZFM-01592_2020-10-20-001',
    '88224abb-5746-431f-9c17-17d7ef806e6a': 'mainenlab_ZFM-02372_2021-06-01-002',
    '0802ced5-33a3-405e-8336-b65ebc5cb07c': 'mainenlab_ZFM-02373_2021-06-23-001',
    'ee40aece-cffd-4edb-a4b6-155f158c666a': 'mainenlab_ZM_2241_2020-01-30-001',
    'c7248e09-8c0d-40f2-9eb4-700a8973d8c8': 'mainenlab_ZM_3001_2020-08-05-001',
    '72cb5550-43b4-4ef0-add5-e4adfdfb5e02': 'mainenlab_ZFM-02369_2021-05-19-001',
    'dda5fc59-f09a-4256-9fb5-66c67667a466': 'churchlandlab_CSHL059_2020-03-06-001',
    '4b7fbad4-f6de-43b4-9b15-c7c7ef44db4b': 'churchlandlab_CSHL049_2020-01-08-001',
    'f312aaec-3b6f-44b3-86b4-3a0c119c0438': 'churchlandlab_CSHL058_2020-07-07-001',
    '4b00df29-3769-43be-bb40-128b1cba6d35': 'churchlandlab_CSHL052_2020-02-21-001',
    'ecb5520d-1358-434c-95ec-93687ecd1396': 'churchlandlab_CSHL051_2020-02-05-001',
    '51e53aff-1d5d-4182-a684-aba783d50ae5': 'angelakilab_NYU-45_2021-07-19-001',
    'f140a2ec-fd49-4814-994a-fe3476f14e66': 'angelakilab_NYU-47_2021-06-21-003',
    'a8a8af78-16de-4841-ab07-fde4b5281a03': 'angelakilab_NYU-12_2020-01-22-001',
    '61e11a11-ab65-48fb-ae08-3cb80662e5d6': 'angelakilab_NYU-21_2020-08-10-002',
    '73918ae1-e4fd-4c18-b132-00cb555b1ad2': 'wittenlab_ibl_witten_27_2021-01-21-001',
    'd9f0c293-df4c-410a-846d-842e47c6b502': 'wittenlab_ibl_witten_25_2020-12-07-002',
    'dac3a4c1-b666-4de0-87e8-8c514483cacf': 'hoferlab_SWC_060_2020-11-24-001',
    '6f09ba7e-e3ce-44b0-932b-c003fb44fb89': 'hoferlab_SWC_043_2020-09-16-002',
    '56b57c38-2699-4091-90a8-aba35103155e': 'mrsicflogellab_SWC_054_2020-10-05-001',
    '3638d102-e8b6-4230-8742-e548cd87a949': 'mrsicflogellab_SWC_058_2020-12-07-001',
    '7cb81727-2097-4b52-b480-c89867b5b34c': 'mrsicflogellab_SWC_052_2020-10-22-001',
    '781b35fd-e1f0-4d14-b2bb-95b7263082bb': 'cortexlab_KS044_2020-12-09-001',
    '3f859b5c-e73a-4044-b49e-34bb81e96715': 'cortexlab_KS094_2022-06-17-001',
    'b22f694e-4a34-4142-ab9d-2556c3487086': 'cortexlab_KS055_2021-05-02-001',
    '0a018f12-ee06-4b11-97aa-bbbff5448e9f': 'cortexlab_KS051_2021-05-11-001',
    'aad23144-0e52-4eac-80c5-c4ee2decb198': 'cortexlab_KS023_2019-12-10-001',
    'b196a2ad-511b-4e90-ac99-b5a29ad25c22': 'cortexlab_KS084_2022-02-01-001',
    'e45481fa-be22-4365-972c-e7404ed8ab5a': 'cortexlab_KS086_2022-03-15-001',
    'd04feec7-d0b7-4f35-af89-0232dd975bf0': 'cortexlab_KS089_2022-03-19-001',
    '1b715600-0cbc-442c-bd00-5b0ac2865de1': 'cortexlab_KS084_2022-01-31-001',
    'c7bf2d49-4937-4597-b307-9f39cb1c7b16': 'cortexlab_KS074_2021-11-22-001',
    '8928f98a-b411-497e-aa4b-aa752434686d': 'cortexlab_KS096_2022-06-17-001',
    'ebce500b-c530-47de-8cb1-963c552703ea': 'churchlandlab_ucla_MFD_09_2023-10-19-001',
    'dc962048-89bb-4e6a-96a9-b062a2be1426': 'churchlandlab_ucla_UCLA048_2022-08-16-001',
    '6899a67d-2e53-4215-a52a-c7021b5da5d4': 'churchlandlab_ucla_MFD_06_2023-08-29-001',
    '15b69921-d471-4ded-8814-2adad954bcd8': 'churchlandlab_ucla_MFD_07_2023-08-31-001',
    '5ae68c54-2897-4d3a-8120-426150704385': 'churchlandlab_ucla_MFD_08_2023-09-07-001',
    'ca4ecb4c-4b60-4723-9b9e-2c54a6290a53': 'churchlandlab_ucla_MFD_05_2023-08-16-001',
    '824cf03d-4012-4ab1-b499-c83a92c5589e': 'churchlandlab_ucla_UCLA011_2021-07-20-001',
    '3bcb81b4-d9ca-4fc9-a1cd-353a966239ca': 'steinmetzlab_NR_0024_2023-01-19-001',
    'f115196e-8dfe-4d2a-8af3-8206d93c1729': 'steinmetzlab_NR_0021_2022-06-23-003',
    '9b528ad0-4599-4a55-9148-96cc1d93fb24': 'steinmetzlab_NR_0019_2022-04-29-001',
    '3e6a97d3-3991-49e2-b346-6948cb4580fb': 'steinmetzlab_NR_0020_2022-05-08-001',
    # dropbox follows
    # (db4df448-..., 54238fd6-..., 7cb81727-... and 73918ae1-... were also listed
    # in this section; they are already defined above with the same session names)
    '46794e05-3f6a-4d35-afb3-9165091a5a74': 'churchlandlab_CSHL045_2020-02-27-001',
    'f3ce3197-d534-4618-bf81-b687555d1883': 'hoferlab_SWC_043_2020-09-15-001',
    '493170a6-fd94-4ee4-884f-cc018c17eeb9': 'hoferlab_SWC_061_2020-11-23-001',
    'ff96bfe1-d925-4553-94b5-bf8297adf259': 'wittenlab_ibl_witten_26_2021-01-27-002',
}

In [9]:
# Main Loop
save_dir = './artifacts'

l_thresh = 0.0
view = 'left'
paw = 'paw_r'
if view == 'right':
    raise NotImplementedError("view='right' is not implemented")

# Number of sessions to process. The loop used to end in an unconditional
# `break`, i.e. only the first eid was ever handled; the default of 1 keeps
# that behaviour while making it visible. Set to None to process every eid.
max_sessions = 1

# Predictions per session: eid -> (states_ibl_truncated, states_ibl_smooth_truncated)
states_by_eid = {}

for eid in eids[:max_sessions]:
    sess_id = dropbox_marker_paths[eid]
    print(f'Eid: {eid} -- Session: {sess_id}')

    # Extract the marker data from the ibl database
    ibl_markers_file = extract_marker_data(eid, l_thresh, view, paw, smooth=False, path=save_dir)
    ibl_smooth_markers_file = extract_marker_data(eid, l_thresh, view, paw, smooth=True, path=save_dir)

    # Build the Data Generators
    data_gen_ibl = create_data_generator(sess_id, ibl_markers_file)
    data_gen_ibl_smooth = create_data_generator(sess_id, ibl_smooth_markers_file)

    # Run Inference
    states_ibl = get_states(model, data_gen_ibl)
    states_ibl_smooth = get_states(model, data_gen_ibl_smooth)

    # Ensure state values are within the correct range
    # NOTE(review): get_states returns argmax indices, which start at 0, so this
    # clip folds any class 0 into class 1 — confirm that is intended.
    states_ibl = np.clip(states_ibl, 1, 4)
    states_ibl_smooth = np.clip(states_ibl_smooth, 1, 4)

    # Truncate the arrays to match the length of the shortest one
    min_length = min(len(states_ibl), len(states_ibl_smooth))
    states_ibl_truncated = states_ibl[:min_length]
    states_ibl_smooth_truncated = states_ibl_smooth[:min_length]

    # Keep every session's result; the plain variables above only hold the last one
    states_by_eid[eid] = (states_ibl_truncated, states_ibl_smooth_truncated)

Eid: db4df448-e449-4a6f-a0e7-288711e7a75a -- Session: danlab_DY_009_2020-02-27-001


NameError: name 'ZScore' is not defined

In [None]:
# Distinct state labels in the unsmoothed predictions left behind by the main loop
np.unique(states_ibl_truncated)

array([1, 2, 3, 4])