In [1]:
# tone_trial_level_decoding: imports

import numpy as np
from cebra import CEBRA
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

from core import (
    DEFAULT_DATA_ROOT,
    load_tone_success_punish_trials_train_test,
)


  import pkg_resources


In [2]:
# tone_trial_level_decoding: load trials + build train tensors

root = DEFAULT_DATA_ROOT
print(f"Using data root: {root}")

(
    train_trials,
    train_labels,
    test_trials,
    test_labels,
    train_paths,
    test_paths,
) = load_tone_success_punish_trials_train_test(
    root,
    test_size=0.2,
    random_seed=0,
)

print("n train trials:", len(train_trials))
print("n test  trials:", len(test_trials))

# Build CEBRA-Time train data (X + session_ids)
X_train = np.vstack(train_trials)
session_ids_train = np.concatenate(
    [np.full(t.shape[0], i, dtype=int) for i, t in enumerate(train_trials)]
)
labels_time_train = np.concatenate(
    [np.full(t.shape[0], lbl, dtype=int) for t, lbl in zip(train_trials, train_labels)]
)

print("X_train:", X_train.shape)
print("session_ids_train:", session_ids_train.shape)
print("labels_time_train:", labels_time_train.shape)

assert X_train.shape[0] == session_ids_train.shape[0] == labels_time_train.shape[0]


Using data root: /Users/Columbia/Downloads/transcriptomics_ym/CEBRA_source_test/preprocessed time series
n train trials: 1940
n test  trials: 528
X_train: (202648, 77)
session_ids_train: (202648,)
labels_time_train: (202648,)


In [4]:
# tone_trial_level_decoding: train CEBRA-Time

cebra_time = CEBRA(
    model_architecture="offset10-model",
    output_dimension=3,
    num_hidden_units=32,
    temperature=2.0,
    time_offsets=10,
    conditional="session",   # critical: tells CEBRA where time is continuous
    batch_size=256,
    learning_rate=3e-4,
    max_iterations=4000,
    distance="cosine",
    verbose=True,
    device="cpu",
)

cebra_time.fit(X_train, session_ids_train)
emb_time_train = cebra_time.transform(X_train)

print("emb_time_train:", emb_time_train.shape)


pos: -0.4684 neg:  5.5919 total:  5.1235 temperature:  2.0000: 100%|██████████| 4000/4000 [05:42<00:00, 11.70it/s]


emb_time_train: (202648, 3)


In [5]:
# tone_trial_level_decoding: build train trial-level features

trial_features_train = []
trial_outcomes_train = []

for i in range(len(train_trials)):
    mask = (session_ids_train == i)
    traj = emb_time_train[mask]  # (T, 3)

    # 6-D feature: [mean_x, mean_y, mean_z, std_x, std_y, std_z]
    feat = np.concatenate([traj.mean(axis=0), traj.std(axis=0)])
    trial_features_train.append(feat)
    trial_outcomes_train.append(train_labels[i])

trial_features_train = np.array(trial_features_train)
trial_outcomes_train = np.array(trial_outcomes_train)

print("trial_features_train:", trial_features_train.shape)
print("trial_outcomes_train:", trial_outcomes_train.shape)


trial_features_train: (1940, 6)
trial_outcomes_train: (1940,)


In [6]:
# tone_trial_level_decoding: build test trial-level features

X_test = np.vstack(test_trials)
session_ids_test = np.concatenate(
    [np.full(t.shape[0], i, dtype=int) for i, t in enumerate(test_trials)]
)
labels_time_test = np.concatenate(
    [np.full(t.shape[0], lbl, dtype=int) for t, lbl in zip(test_trials, test_labels)]
)

print("X_test:", X_test.shape)
print("session_ids_test:", session_ids_test.shape)
print("labels_time_test:", labels_time_test.shape)

emb_time_test = cebra_time.transform(X_test)
print("emb_time_test:", emb_time_test.shape)

trial_features_test = []
trial_outcomes_test = []

for i in range(len(test_trials)):
    mask = (session_ids_test == i)
    traj = emb_time_test[mask]

    feat = np.concatenate([traj.mean(axis=0), traj.std(axis=0)])
    trial_features_test.append(feat)
    trial_outcomes_test.append(test_labels[i])

trial_features_test = np.array(trial_features_test)
trial_outcomes_test = np.array(trial_outcomes_test)

print("trial_features_test:", trial_features_test.shape)
print("trial_outcomes_test:", trial_outcomes_test.shape)


X_test: (55923, 77)
session_ids_test: (55923,)
labels_time_test: (55923,)
emb_time_test: (55923, 3)
trial_features_test: (528, 6)
trial_outcomes_test: (528,)


In [8]:
# tone_trial_level_decoding: classify success vs punish at trial level

clf = SVC(kernel="linear")
clf.fit(trial_features_train, trial_outcomes_train)

test_preds = clf.predict(trial_features_test)
test_acc = accuracy_score(trial_outcomes_test, test_preds)

print("✅ TRIAL-LEVEL TEST DECODING ACCURACY:", test_acc)
print("Chance level: 0.5")


✅ TRIAL-LEVEL TEST DECODING ACCURACY: 0.5056818181818182
Chance level: 0.5
