# Loading the package

This notebook walks you through a simple use case of Neuroprobe and evaluates the logistic regression baseline. It can easily be adapted to evaluate any foundation model of neural activity.

In [1]:
import os
# NOTE: Change this default to your own path, or export ROOT_DIR_BRAINTREEBANK in your
# environment before starting the kernel. `setdefault` only fills in this path when the
# variable is not already defined, so a value defined elsewhere is respected instead of
# being silently overwritten.
# It is set before importing neuroprobe.config, which presumably reads it at import time.
os.environ.setdefault('ROOT_DIR_BRAINTREEBANK', '/om2/user/zaho/braintreebank/braintreebank/')

import torch
import neuroprobe.config as neuroprobe_config

# Make sure the config ROOT_DIR is set correctly
print("Expected braintreebank data at:", neuroprobe_config.ROOT_DIR)
print("Sampling rate:", neuroprobe_config.SAMPLING_RATE, "Hz")

Expected braintreebank data at: /om2/user/zaho/braintreebank/braintreebank/
Sampling rate: 2048 Hz


## The BrainTreebankSubject Class

In [2]:
from neuroprobe import BrainTreebankSubject

subject_id = 1

coordinates_type = "cortical" # "cortical", "mni", "lpi". NOTE: MNI are not yet available for the Braintreebank dataset.
# cortical = standardized brain atlas cortical projection of the coordinates in Freesurfer space
# mni = MNI coordinates
# lpi = LPI coordinates (left, posterior, inferior) in the subject's coordinate system

# use cache=True to load this trial's neural data into RAM, if you have enough memory!
# It will make the loading process faster.
subject = BrainTreebankSubject(subject_id, allow_corrupted=False, cache=True, dtype=torch.float32, coordinates_type=coordinates_type)
print("Loaded subject", subject_id)
print("Electrode labels (first 10):", subject.electrode_labels[:10]) # list of electrode labels

# The coordinates are presumably expressed in the system chosen via `coordinates_type` above
# (cortical here, not MNI), so the printed heading is derived from it rather than hard-coded.
print(f"\nElectrode coordinates ({coordinates_type} space) of the first 10 electrodes:")
print(subject.get_electrode_coordinates()[:10]) # one row of 3 coordinates per electrode

# Optionally, subset the electrodes to a specific set of electrodes. NOTE: you should not do this if you are using the neuroprobe as a standardized benchmark.
# subject.set_electrode_subset(['F3aOFa2', 'F3aOFa3', 'F3aOFa4', 'F3aOFa7']) # if you change this line when using cache=True, you need to clear the cache after: subject.clear_neural_data_cache()
# print("Electrode labels after subsetting:", subject.electrode_labels)

Loaded subject 1
Electrode labels (first 10): ['F3aOFa2', 'F3aOFa3', 'F3aOFa4', 'F3aOFa7', 'F3aOFa8', 'F3aOFa9', 'F3aOFa10', 'F3aOFa11', 'F3aOFa12', 'F3aOFa13']

Electrode coordinates (MNI space) of the first 10 electrodes:
tensor([[ 76.0103, -49.9502, -25.1740],
        [ 75.4765, -50.8993, -22.9590],
        [ 81.5899, -49.6018, -13.3198],
        [ 81.3702, -47.3542,  -6.0947],
        [ 83.1155, -43.3788,   0.5507],
        [ 79.8622, -41.9135,   3.7532],
        [ 79.1331, -41.2117,   4.8066],
        [ 67.2942, -28.6666,  14.7228],
        [ 68.9201, -28.2619,  15.2759],
        [ 77.7627, -21.2303,  19.6270]])


Loading the electrode data from a specific trial:

In [3]:
trial_id = 1

# Sample-index window into the trial's neural data array:
#   window_from -- index where loading starts (None presumably means the first sample)
#   window_to   -- index where loading stops; None loads through the end of the trial
window_from = None
window_to = None

subject.load_neural_data(trial_id)
all_neural_data = subject.get_all_electrode_data(trial_id, window_from=window_from, window_to=window_to)

# Layout is (n_electrodes, n_samples). For a single electrode, use
# subject.get_electrode_data(trial_id, electrode_label) instead.
print("All neural data shape:", all_neural_data.shape, sep="\n")

print("\nFirst 50 samples of the first electrode (data is in microvolts):", all_neural_data[0, :50], sep="\n")

All neural data shape:
torch.Size([129, 21401009])

First 50 samples of the first electrode (data is in microvolts):
tensor([-24.7234, -25.7867, -25.5209, -28.9769, -33.2303, -34.5595, -38.8130,
        -36.4204, -33.7620, -33.2303, -26.8501, -21.7991, -17.8115, -25.7867,
        -27.3818, -17.8115, -13.2921,  -5.3169,   2.9243,   1.8609,   7.4436,
         14.8872,  14.3555,  15.4189,  16.4822,  16.4822,  16.2164,  20.2040,
         19.6724,  15.4189,  18.6090,  17.8115,   4.7852,  -4.7852, -16.2164,
        -24.7234, -26.3184, -34.8254, -36.6863, -39.3447, -45.4591, -45.7249,
        -48.3834, -48.3834, -45.1933, -40.6739, -40.1422, -49.7126, -53.7002,
        -44.6616])


## The BrainTreebankSubjectTrialBenchmarkDataset Class

NOTE: The dataset below contains fewer electrodes than the full subject data. This is because the Neuroprobe benchmark uses only a subset of the electrodes, for standardized and quick benchmarking. The electrode labels below are restricted to those in the `neuroprobe_config.NEUROPROBE_LITE_ELECTRODES` list.

Accordingly, when using the `BrainTreebankSubjectTrialBenchmarkDataset` with `lite=True` (the default Neuroprobe benchmark option), make sure you use the `dataset.electrode_labels` and `dataset.electrode_coordinates` properties. They give the electrode labels and the corresponding electrode coordinates, respectively, in the exact order in which the `dataset` outputs the data tensors. The coordinates are not in MNI space here: MNI coordinates are not yet available for BrainTreebank, and this notebook loads the subject with `coordinates_type="cortical"`.

In [4]:
from neuroprobe import BrainTreebankSubjectTrialBenchmarkDataset

# Options for eval_name (from the Neuroprobe paper): neuroprobe_config.EVAL_NAMES
#   frame_brightness, global_flow, local_flow, face_num, volume, pitch, delta_volume, 
#   speech, onset, gpt2_surprisal, word_length, word_gap, word_index, word_head_pos, word_part_speech, speaker
eval_name = "volume"

# if True, the dataset will output the indices of the samples in the neural data in a tuple: (index_from, index_to); 
# if False, the dataset will output the neural data directly
output_indices = False

start_neural_data_before_word_onset = 0 # the number of samples to start the neural data before each word onset
end_neural_data_after_word_onset = neuroprobe_config.SAMPLING_RATE * 1 # the number of samples to end the neural data after each word onset -- here we use 1 second

dataset = BrainTreebankSubjectTrialBenchmarkDataset(subject, trial_id, dtype=torch.float32, eval_name=eval_name, output_indices=output_indices, 
                                                    start_neural_data_before_word_onset=start_neural_data_before_word_onset, end_neural_data_after_word_onset=end_neural_data_after_word_onset,
                                                    lite=True) # the default is Neuroprobe Lite for standardized and quick benchmarking. Feel free to set lite=False if trying to access the Full dataset.
# P.S. Allow partial cache -- whether to allow partial caching of the neural data, if only part of it is needed for this particular dataset. Better set to False when doing multiple evals back to back, but better set to True when doing a single eval.

data_electrode_labels = dataset.electrode_labels # NOTE: this is different from the subject.electrode_labels! Neuroprobe uses a special subset of electrodes in this exact order.
data_electrode_coordinates = dataset.electrode_coordinates 

# Fetch the first item once. Each dataset[i] access builds the item anew (an (n_electrodes, n_samples)
# tensor), so indexing it separately for the shape, the tensor and the label repeats that work.
first_item_data, first_item_label = dataset[0]

print("Items in the dataset:", len(dataset), "\n")
print(f"The first item: (shape = {first_item_data.shape})", first_item_data, f"label = {first_item_label}", sep="\n")
print("")
print(f"Electrode labels in the data above in the following order ({len(data_electrode_labels)} electrodes):", data_electrode_labels)
print(f"Electrode coordinates in the data above in the following order ({len(data_electrode_coordinates)} electrodes):", data_electrode_coordinates)

Items in the dataset: 3500 

The first item: (shape = torch.Size([120, 2048]))
tensor([[ 32.9645,  27.3818,  20.4699,  ...,   2.1267,  -0.5317,   6.3802],
        [ 67.2582,  62.2072,  56.0928,  ...,  44.9274,  42.2690,  49.1809],
        [120.9584, 115.3757, 106.6029,  ...,  58.2195,  53.9661,  59.8146],
        ...,
        [ 15.1530,   7.4436,   3.7218,  ...,  -2.1267,  -5.8485,   2.9243],
        [ 26.3184,  19.4065,  14.8872,  ...,  16.4822,  14.3555,  19.9382],
        [-13.0263, -17.5456, -23.9258,  ...,   3.4560,   2.6584,  11.1654]])
label = 1

Electrode labels in the data above in the following order (120 electrodes): ['T1bIc1', 'T1bIc2', 'T1bIc3', 'T1bIc4', 'T1bIc5', 'T1bIc6', 'T1bIc7', 'T1bIc8', 'T1cIf10', 'T1cIf11', 'T1cIf12', 'T1cIf13', 'T1cIf14', 'T1cIf15', 'T1cIf16', 'T1aIb1', 'T1aIb2', 'T1aIb3', 'T1aIb4', 'T1aIb5', 'T1aIb6', 'T1aIb7', 'T1aIb8', 'T3aHb9', 'T3aHb10', 'T1cIf1', 'T1cIf2', 'T1cIf3', 'T1cIf4', 'T1cIf5', 'T1cIf6', 'T1cIf7', 'T1cIf8', 'T2bHa7', 'T2bHa8', 'T2bH

In [5]:
# Optionally, you can request output_dict=True to get the data as a dictionary with a bunch of metadata
# (e.g. 'data', 'label', 'electrode_labels', ...).
# The previous setting is restored in `finally`, so the shared `dataset` is not left in dict mode
# for later cells if indexing or printing raises.
previous_output_dict = getattr(dataset, "output_dict", False)
dataset.output_dict = True
try:
    print(dataset[0])
finally:
    dataset.output_dict = previous_output_dict # set it back

{'data': tensor([[ 32.9645,  27.3818,  20.4699,  ...,   2.1267,  -0.5317,   6.3802],
        [ 67.2582,  62.2072,  56.0928,  ...,  44.9274,  42.2690,  49.1809],
        [120.9584, 115.3757, 106.6029,  ...,  58.2195,  53.9661,  59.8146],
        ...,
        [ 15.1530,   7.4436,   3.7218,  ...,  -2.1267,  -5.8485,   2.9243],
        [ 26.3184,  19.4065,  14.8872,  ...,  16.4822,  14.3555,  19.9382],
        [-13.0263, -17.5456, -23.9258,  ...,   3.4560,   2.6584,  11.1654]]), 'label': 1, 'electrode_labels': ['T1bIc1', 'T1bIc2', 'T1bIc3', 'T1bIc4', 'T1bIc5', 'T1bIc6', 'T1bIc7', 'T1bIc8', 'T1cIf10', 'T1cIf11', 'T1cIf12', 'T1cIf13', 'T1cIf14', 'T1cIf15', 'T1cIf16', 'T1aIb1', 'T1aIb2', 'T1aIb3', 'T1aIb4', 'T1aIb5', 'T1aIb6', 'T1aIb7', 'T1aIb8', 'T3aHb9', 'T3aHb10', 'T1cIf1', 'T1cIf2', 'T1cIf3', 'T1cIf4', 'T1cIf5', 'T1cIf6', 'T1cIf7', 'T1cIf8', 'T2bHa7', 'T2bHa8', 'T2bHa9', 'T2bHa10', 'T2bHa11', 'T2bHa12', 'T2bHa13', 'T2bHa14', 'T3bOT8', 'T3bOT9', 'T3bOT10', 'F3cId1', 'F3cId2', 'F3cId3', 'F3

In [6]:
# Also, you can request only the indices into the neural data array, instead of the actual data.
# NOTE: These are the indices into the data as in the raw h5 files in the braintreebank dataset.
# The previous setting is restored in `finally`, so the shared `dataset` does not keep returning
# indices in later cells if indexing or printing raises.
previous_output_indices = getattr(dataset, "output_indices", False)
dataset.output_indices = True
try:
    print(dataset[0]) # Data format: (index_from, index_to), label
finally:
    dataset.output_indices = previous_output_indices # set it back

((443667, 445715), 1)


## Train/Test Splits

In this example, we generate train/test splits for the WithinSession evaluation.

All options: `generate_splits_within_session`, `generate_splits_cross_session`, `generate_splits_cross_subject`.

In [7]:
import neuroprobe.train_test_splits as neuroprobe_train_test_splits

# Dataset parameters forwarded to every fold's dataset. They mirror the arguments used to
# build `dataset` above, so the folds cut the same windows around each event.
split_dataset_kwargs = dict(
    output_indices=output_indices,
    start_neural_data_before_word_onset=start_neural_data_before_word_onset,
    end_neural_data_after_word_onset=end_neural_data_after_word_onset,
    lite=True,
)

# Each fold is a dict holding a 'train_dataset' and a 'test_dataset'.
folds = neuroprobe_train_test_splits.generate_splits_within_session(
    subject, trial_id, eval_name, dtype=torch.float32, **split_dataset_kwargs
)
print("len(folds) = k_folds =", len(folds))
folds

len(folds) = k_folds = 2


[{'train_dataset': <torch.utils.data.dataset.Subset at 0x14c75cf21fa0>,
  'test_dataset': <torch.utils.data.dataset.Subset at 0x14c75cf21250>},
 {'train_dataset': <torch.utils.data.dataset.Subset at 0x14c75cf21430>,
  'test_dataset': <torch.utils.data.dataset.Subset at 0x14c75cf21400>}]

## Example: Logistic Regression on the WithinSession (SS_SM) Splits

In [8]:
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import numpy as np


def dataset_to_numpy(ds):
    """Materialize a (data, label) dataset into scikit-learn-ready arrays in a single pass.

    Each item is read exactly once. Iterating the dataset twice (once for X, once for y)
    would build every item's neural-data tensor twice.

    Assumes the dataset yields (data, label) tuples, i.e. output_dict and output_indices
    are both False.

    Returns:
        X: array of shape (n_items, n_features), each item's data flattened to one row.
        y: array of shape (n_items,) holding the labels.
    """
    features, labels = [], []
    for data, label in ds:
        features.append(np.asarray(data).reshape(-1))  # explicit tensor -> numpy conversion
        labels.append(label)
    return np.stack(features), np.array(labels)


test_scores = []
for fold_idx, fold in enumerate(folds):
    print(f"Fold {fold_idx+1} of {len(folds)}")
    train_dataset = fold["train_dataset"]
    test_dataset = fold["test_dataset"]

    # Convert PyTorch datasets to numpy arrays for scikit-learn
    X_train, y_train = dataset_to_numpy(train_dataset)
    X_test, y_test = dataset_to_numpy(test_dataset)

    # Standardize using statistics from the training split only (no test-set leakage)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Train logistic regression. With n_electrodes * n_samples features per item, the number of
    # features far exceeds the number of items, so a near-perfect train accuracy is expected
    # and only the test accuracy is informative.
    clf = LogisticRegression(random_state=42, max_iter=1000, tol=1e-3)
    clf.fit(X_train, y_train)

    # Evaluate model
    train_score = clf.score(X_train, y_train)
    test_score = clf.score(X_test, y_test)
    test_scores.append(test_score)
    print(f"\t Train accuracy: {train_score:.3f} | Test accuracy: {test_score:.3f}")

print(f"Mean test accuracy over {len(test_scores)} folds: {np.mean(test_scores):.3f}")

Fold 1 of 2
	 Train accuracy: 1.000 | Test accuracy: 0.598
Fold 2 of 2
	 Train accuracy: 1.000 | Test accuracy: 0.570
