In [3]:
import os
# Default BrainTreebank data location. setdefault applies it only when ROOT_DIR_BRAINTREEBANK is not already
# defined (e.g. exported in your shell before launching Jupyter), so defining the variable elsewhere takes precedence.
# Feel free to change this default to your own path.
os.environ.setdefault('ROOT_DIR_BRAINTREEBANK', '/om2/user/zaho/braintreebank/braintreebank/')

import torch
import neuroprobe.config as neuroprobe_config  # imported after the variable is set, presumably so the config picks it up

# Make sure the config ROOT_DIR is set correctly
print("Expected braintreebank data at:", neuroprobe_config.ROOT_DIR)
print("Sampling rate:", neuroprobe_config.SAMPLING_RATE, "Hz")

Expected braintreebank data at: /om2/user/zaho/braintreebank/braintreebank/
Sampling rate: 2048 Hz


## The BrainTreebankSubject Class

In [4]:
from neuroprobe.braintreebank_subject import BrainTreebankSubject

subject_id = 1

# cache=True loads this trial's neural data into RAM (only if you have enough memory!),
# which makes the loading process faster.
subject = BrainTreebankSubject(
    subject_id,
    allow_corrupted=False,
    cache=True,
    dtype=torch.float32,
)
print("Electrode labels:", subject.electrode_labels)  # list of electrode labels

# Optionally, restrict the subject to a specific set of electrodes.
# If you change the subset while using cache=True, clear the cache afterwards: subject.clear_neural_data_cache()
# subject.set_electrode_subset(['F3aOFa2', 'F3aOFa3', 'F3aOFa4', 'F3aOFa7'])
# print("Electrode labels after subsetting:", subject.electrode_labels)

Electrode labels: ['F3aOFa2', 'F3aOFa3', 'F3aOFa4', 'F3aOFa7', 'F3aOFa8', 'F3aOFa9', 'F3aOFa10', 'F3aOFa11', 'F3aOFa12', 'F3aOFa13', 'F3aOFa14', 'F3aOFa15', 'F3aOFa16', 'F3bIaOFb1', 'F3bIaOFb2', 'F3bIaOFb3', 'F3bIaOFb4', 'F3bIaOFb5', 'F3bIaOFb6', 'F3bIaOFb7', 'F3bIaOFb8', 'F3bIaOFb9', 'F3bIaOFb10', 'F3bIaOFb11', 'F3bIaOFb12', 'F3bIaOFb13', 'F3bIaOFb14', 'F3bIaOFb15', 'F3bIaOFb16', 'F3cId1', 'F3cId2', 'F3cId3', 'F3cId4', 'F3cId5', 'F3cId6', 'F3cId7', 'F3cId8', 'F3cId9', 'T1aIb1', 'T1aIb2', 'T1aIb3', 'T1aIb4', 'T1aIb5', 'T1aIb6', 'T1aIb7', 'T1aIb8', 'T2aA1', 'T2aA2', 'T2aA3', 'T2aA4', 'T2aA5', 'T2aA6', 'T2aA7', 'T2aA8', 'T2aA9', 'T2aA10', 'T2aA11', 'T2aA12', 'T2bHa1', 'T2bHa3', 'T2bHa4', 'T2bHa5', 'T2bHa7', 'T2bHa8', 'T2bHa9', 'T2bHa10', 'T2bHa11', 'T2bHa12', 'T2bHa13', 'T2bHa14', 'T1bIc1', 'T1bIc2', 'T1bIc3', 'T1bIc4', 'T1bIc5', 'T1bIc6', 'T1bIc7', 'T1bIc8', 'F3dIe1', 'F3dIe2', 'F3dIe3', 'F3dIe4', 'F3dIe5', 'F3dIe6', 'F3dIe7', 'F3dIe8', 'F3dIe9', 'F3dIe10', 'F3dIe14', 'T3aHb6', 'T3aHb9'

Loading the electrode data and electrode coordinates

In [5]:
trial_id = 1

subject.load_neural_data(trial_id)

# Optional window bounds; if both are None, the whole trial will be loaded.
window_from, window_to = None, None

print("All neural data shape:")
# Keep only the shape so the full trial tensor is not held in the namespace.
# Shape is (n_electrodes, n_samples); for a single electrode use subject.get_electrode_data(trial_id, electrode_label).
all_data_shape = subject.get_all_electrode_data(trial_id, window_from=window_from, window_to=window_to).shape
print(all_data_shape)

print("\nElectrode coordinates (MNI space) of the first 10 electrodes:")
electrode_coordinates = subject.get_electrode_coordinates()  # L, P, I coordinates of the electrodes
print(electrode_coordinates[:10])

All neural data shape:
torch.Size([129, 21401009])

Electrode coordinates (MNI space) of the first 10 electrodes:
tensor([[ 76.0103, -49.9502, -25.1740],
        [ 75.4765, -50.8993, -22.9590],
        [ 81.5899, -49.6018, -13.3198],
        [ 81.3702, -47.3542,  -6.0947],
        [ 83.1155, -43.3788,   0.5507],
        [ 79.8622, -41.9135,   3.7532],
        [ 79.1331, -41.2117,   4.8066],
        [ 67.2942, -28.6666,  14.7228],
        [ 68.9201, -28.2619,  15.2759],
        [ 77.7627, -21.2303,  19.6270]])


## The BrainTreebankSubjectTrialBenchmarkDataset Class

In [10]:
from neuroprobe.datasets import BrainTreebankSubjectTrialBenchmarkDataset

# Options for eval_name (from the Neuroprobe paper):
#   frame_brightness, global_flow, local_flow, global_flow_angle, local_flow_angle, face_num, volume, pitch, delta_volume,
#   delta_pitch, speech, onset, gpt2_surprisal, word_length, word_gap, word_index, word_head_pos, word_part_speech, speaker
eval_name = "volume"

# if True, the dataset will output the indices of the samples in the neural data in a tuple: (index_from, index_to);
# if False, the dataset will output the neural data directly
output_indices = False

start_neural_data_before_word_onset = 0  # the number of samples to start the neural data before each word onset
end_neural_data_after_word_onset = neuroprobe_config.SAMPLING_RATE * 1  # the number of samples to end the neural data after each word onset -- here we use 1 second

dataset = BrainTreebankSubjectTrialBenchmarkDataset(
    subject, trial_id, dtype=torch.float32, eval_name=eval_name, output_indices=output_indices,
    start_neural_data_before_word_onset=start_neural_data_before_word_onset,
    end_neural_data_after_word_onset=end_neural_data_after_word_onset,
    # The default is Neuroprobe Lite for standardized and quick benchmarking.
    # Set lite=False to access the Full dataset.
    lite=True,
)
# P.S. The dataset also offers an "allow partial cache" option (presumably an `allow_partial_cache` argument,
# not passed here -- check the class signature). It controls whether only the part of the neural data needed
# by this particular dataset is cached. Better set to False when running multiple evals back to back,
# better set to True when running a single eval.

print("Items in the dataset:", len(dataset), "\n")
# Index the dataset once and reuse the item for its shape, data and label; each lookup presumably
# extracts the neural window again, so repeated dataset[0] calls would redo that work.
first_item = dataset[0]
first_data, first_label = first_item[0], first_item[1]
print(f"The first item: (shape = {first_data.shape})", first_data, f"label = {first_label}", sep="\n")
print("")
print(f"Electrode labels in the data above in the following order ({len(dataset.electrode_labels)} electrodes):", dataset.electrode_labels)

print("\nNOTE: The electrode labels here are subset to the NEUROPROBE_LITE_ELECTRODES list. Accordingly, there will be fewer electrodes than in the full subject data.")
print("This is because the Neuroprobe benchmark only uses a subset of the electrodes for standardized and quick benchmarking. (assuming lite=True which is the default)")

Items in the dataset: 3500 

The first item: (shape = torch.Size([120, 2048]))
tensor([[ 32.9645,  27.3818,  20.4699,  ...,   2.1267,  -0.5317,   6.3802],
        [ 67.2582,  62.2072,  56.0928,  ...,  44.9274,  42.2690,  49.1809],
        [120.9584, 115.3757, 106.6029,  ...,  58.2195,  53.9661,  59.8146],
        ...,
        [ 15.1530,   7.4436,   3.7218,  ...,  -2.1267,  -5.8485,   2.9243],
        [ 26.3184,  19.4065,  14.8872,  ...,  16.4822,  14.3555,  19.9382],
        [-13.0263, -17.5456, -23.9258,  ...,   3.4560,   2.6584,  11.1654]])
label = 1

Electrode labels in the data above in the following order (120 electrodes): ['T1bIc1', 'T1bIc2', 'T1bIc3', 'T1bIc4', 'T1bIc5', 'T1bIc6', 'T1bIc7', 'T1bIc8', 'T1cIf10', 'T1cIf11', 'T1cIf12', 'T1cIf13', 'T1cIf14', 'T1cIf15', 'T1cIf16', 'T1aIb1', 'T1aIb2', 'T1aIb3', 'T1aIb4', 'T1aIb5', 'T1aIb6', 'T1aIb7', 'T1aIb8', 'T3aHb9', 'T3aHb10', 'T1cIf1', 'T1cIf2', 'T1cIf3', 'T1cIf4', 'T1cIf5', 'T1cIf6', 'T1cIf7', 'T1cIf8', 'T2bHa7', 'T2bHa8', 'T2bH

In [11]:
# Optionally, switch the dataset to output_dict=True to get each item as a dictionary with a bunch of metadata
# (as shown below: 'data', 'label', 'electrode_labels', ...).
dataset.output_dict = True
first_item_dict = dataset[0]
first_item_dict

{'data': tensor([[ 32.9645,  27.3818,  20.4699,  ...,   2.1267,  -0.5317,   6.3802],
         [ 67.2582,  62.2072,  56.0928,  ...,  44.9274,  42.2690,  49.1809],
         [120.9584, 115.3757, 106.6029,  ...,  58.2195,  53.9661,  59.8146],
         ...,
         [ 15.1530,   7.4436,   3.7218,  ...,  -2.1267,  -5.8485,   2.9243],
         [ 26.3184,  19.4065,  14.8872,  ...,  16.4822,  14.3555,  19.9382],
         [-13.0263, -17.5456, -23.9258,  ...,   3.4560,   2.6584,  11.1654]]),
 'label': 1,
 'electrode_labels': ['T1bIc1',
  'T1bIc2',
  'T1bIc3',
  'T1bIc4',
  'T1bIc5',
  'T1bIc6',
  'T1bIc7',
  'T1bIc8',
  'T1cIf10',
  'T1cIf11',
  'T1cIf12',
  'T1cIf13',
  'T1cIf14',
  'T1cIf15',
  'T1cIf16',
  'T1aIb1',
  'T1aIb2',
  'T1aIb3',
  'T1aIb4',
  'T1aIb5',
  'T1aIb6',
  'T1aIb7',
  'T1aIb8',
  'T3aHb9',
  'T3aHb10',
  'T1cIf1',
  'T1cIf2',
  'T1cIf3',
  'T1cIf4',
  'T1cIf5',
  'T1cIf6',
  'T1cIf7',
  'T1cIf8',
  'T2bHa7',
  'T2bHa8',
  'T2bHa9',
  'T2bHa10',
  'T2bHa11',
  'T2bHa12',
  

## Train/Test Splits

In this example, we generate train/test splits for the Single Subject Single Movie (SS-SM) evaluation.

All available split generators: `generate_splits_SS_SM`, `generate_splits_SS_DM`, `generate_splits_DS_DM`, `generate_splits_DS_SM`.

In [8]:
import neuroprobe.train_test_splits as neuroprobe_train_test_splits

# train_datasets and test_datasets are arrays of length k_folds; element k of each is a
# BrainTreebankSubjectTrialBenchmarkDataset for the train/test split of fold k.
train_datasets, test_datasets = neuroprobe_train_test_splits.generate_splits_SS_SM(
    subject, trial_id, eval_name, dtype=torch.float32,
    # Put the dataset parameters here (same as the single-dataset example above)
    output_indices=output_indices,
    start_neural_data_before_word_onset=start_neural_data_before_word_onset,
    end_neural_data_after_word_onset=end_neural_data_after_word_onset,
    lite=True,
)
# The message below states that both lists hold one entry per fold; verify it instead of assuming it.
assert len(train_datasets) == len(test_datasets), "expected one test dataset per train fold"
print("len(train_datasets) = len(test_datasets) = k_folds =", len(train_datasets))

len(train_datasets) = len(test_datasets) = k_folds = 2


## Example Logistic Regression on SS-SM

In [9]:
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import numpy as np


def dataset_to_numpy(benchmark_dataset):
    """Materialize a benchmark dataset into scikit-learn-ready arrays in a single pass.

    Each item is indexed as ``item[0]`` (neural data) and ``item[1]`` (label), i.e. the dataset is
    expected to be in its default tuple output mode (not ``output_dict=True``). Iterating once means
    every item's neural data is loaded only a single time, instead of once for X and again for y.

    Returns:
        X: 2-D array with one flattened data item per row.
        y: 1-D array of the corresponding labels.
    """
    features, labels = [], []
    for item in benchmark_dataset:
        features.append(np.asarray(item[0]).reshape(-1))
        labels.append(item[1])
    return np.stack(features), np.array(labels)


n_folds = len(train_datasets)
for fold_idx in range(n_folds):
    print(f"Fold {fold_idx+1} of {n_folds}")
    train_dataset = train_datasets[fold_idx]
    test_dataset = test_datasets[fold_idx]

    # Convert the PyTorch datasets to numpy arrays for scikit-learn (one pass per dataset)
    X_train, y_train = dataset_to_numpy(train_dataset)
    X_test, y_test = dataset_to_numpy(test_dataset)

    # Standardize using statistics fitted on the training fold only, so no test information leaks into training
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Train logistic regression
    clf = LogisticRegression(random_state=42, max_iter=1000, tol=1e-3)
    clf.fit(X_train, y_train)

    # Evaluate model: for classifiers, .score returns mean accuracy
    train_score = clf.score(X_train, y_train)
    test_score = clf.score(X_test, y_test)
    print(f"\t Train accuracy: {train_score:.3f} | Test accuracy: {test_score:.3f}")

Fold 1 of 2
	 Train accuracy: 1.000 | Test accuracy: 0.598
Fold 2 of 2
	 Train accuracy: 1.000 | Test accuracy: 0.570
