Block 1 — imports & paths

In [1]:
import os, h5py, numpy as np, torch
from glob import glob
from tqdm import tqdm

# Per-day split files expected inside each day directory under RAW_ROOT.
SPLITS = ["data_train.hdf5", "data_val.hdf5", "data_test.hdf5"]

# Raw data root: one sub-directory per recording day
# (presumably unzipped by the earlier "00" notebook — confirm).
RAW_ROOT = "../data/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"

# Destination for the combined, processed datasets; created if it does not exist yet.
PROC_ROOT = "../data/processed/brain-to-text-25"
os.makedirs(PROC_ROOT, exist_ok=True)


Block 2 — discover all split files across days

In [2]:
def discover_split_files(raw_root, split_name):
    """Find `split_name` inside every immediate sub-directory of `raw_root`.

    Sub-directories are visited in sorted name order; plain files directly
    under `raw_root` are ignored, as are directories lacking `split_name`.

    Returns:
        List of paths `<raw_root>/<day>/<split_name>` that exist on disk.
    """
    candidate_dirs = (os.path.join(raw_root, entry) for entry in sorted(os.listdir(raw_root)))
    return [
        os.path.join(day_dir, split_name)
        for day_dir in candidate_dirs
        if os.path.isdir(day_dir) and os.path.exists(os.path.join(day_dir, split_name))
    ]

# Map each split file name to the list of per-day paths that contain it.
files_by_split = {}
for split_name in SPLITS:
    files_by_split[split_name] = discover_split_files(RAW_ROOT, split_name)

# Summary: file count per split plus a preview of the first five paths.
for split_name, split_files in files_by_split.items():
    print(split_name, len(split_files), "files")
    for path in split_files[:5]:
        print("  ", path)


data_train.hdf5 45 files
   ../data/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.11/data_train.hdf5
   ../data/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.13/data_train.hdf5
   ../data/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.18/data_train.hdf5
   ../data/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.20/data_train.hdf5
   ../data/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.25/data_train.hdf5
data_val.hdf5 41 files
   ../data/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.13/data_val.hdf5
   ../data/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.18/data_val.hdf5
   ../data/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.20/data_val.hdf5
   ../data/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.25/data_val.hdf5
   ../data/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t

Block 3 — small helpers (trim labels, per-trial z-score)

In [3]:
def trim_label_padding(label_ids: np.ndarray, pad_id: int = 0) -> np.ndarray:
    """Strip trailing `pad_id` entries from a label sequence.

    Interior `pad_id` values are kept; only the run after the last non-pad
    entry is removed.

    Args:
        label_ids: label array; a 0-d (scalar) array is treated as "no labels".
        pad_id: value used for padding.

    Returns:
        A new int64 array; empty when the input is 0-d or contains only padding.
    """
    if label_ids.ndim == 0:
        return np.array([], dtype=np.int64)

    non_pad_positions = np.nonzero(label_ids != pad_id)[0]
    if not non_pad_positions.size:
        return np.array([], dtype=np.int64)

    end = non_pad_positions[-1] + 1
    return label_ids[:end].astype(np.int64)

def zscore_per_trial(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Standardize each feature column using statistics from this trial only.

    Args:
        x: 2-D array laid out as (time, features); mean/std are taken over axis 0.
        eps: added to the standard deviation so constant features map to 0
            instead of triggering a division by zero.

    Returns:
        Array of the same shape: (x - mean) / (std + eps), column-wise.
    """
    center = x.mean(axis=0, keepdims=True)
    spread = x.std(axis=0, keepdims=True) + eps
    return (x - center) / spread


Block 4 — load one split across all days; build unified lists


In [4]:
def _read_trials_from_file(fp, require_labels, normalize, max_time):
    """Read every usable trial group of one HDF5 file; return (features, labels) lists."""
    file_X, file_y = [], []
    with h5py.File(fp, "r") as f:
        # NOTE(review): lexicographic key order; assumes trial keys sort in
        # recording order (e.g. zero-padded numbers) — confirm against the data.
        for tkey in sorted(f.keys()):
            g = f[tkey]
            if "input_features" not in g:
                continue

            # Resolve labels before touching the features so trials that will be
            # dropped cost neither the feature read nor the z-score.
            if require_labels:
                if "seq_class_ids" not in g:
                    continue  # unlabeled trial in a labeled split
                y = trim_label_padding(g["seq_class_ids"][()])
            else:
                y = np.array([], dtype=np.int64)  # e.g. test split: labels unknown

            feats = g["input_features"]
            # Slice the h5py dataset directly so a time cap also limits the disk read.
            x = feats[()] if max_time is None else feats[:max_time]
            if normalize:
                x = zscore_per_trial(x)

            file_X.append(x.astype(np.float32))
            file_y.append(y)
    return file_X, file_y


def load_split_across_days(file_list, require_labels=True, normalize=True, max_time=None):
    """Load every trial from a list of per-day HDF5 files into flat lists.

    Each top-level group of a file is treated as one trial; groups without an
    ``input_features`` dataset are ignored.

    Args:
        file_list: HDF5 paths to read, in order.
        require_labels: if True, trials lacking ``seq_class_ids`` are skipped and
            labels are trimmed of trailing padding via ``trim_label_padding``;
            if False, every trial gets an empty int64 label array.
        normalize: apply ``zscore_per_trial`` to each trial's features.
        max_time: optional cap on the number of time steps kept per trial.

    Returns:
        (X_list, y_list, lengths): float32 feature arrays, int64 label arrays,
        and an int32 numpy array of per-trial time lengths.

    A file that raises while being read is reported and contributes no trials at
    all (partial results are discarded), so the "skipped files" count matches
    what actually reached the output.
    """
    X_list, y_list, lengths = [], [], []
    bad = 0

    for fp in tqdm(file_list, desc="Loading"):
        try:
            file_X, file_y = _read_trials_from_file(fp, require_labels, normalize, max_time)
        except Exception as e:
            # Best-effort loading: report and move on, but commit nothing from a
            # file that failed part-way through.
            bad += 1
            print(f"⚠️  Skipping {fp}: {e}")
            continue
        X_list.extend(file_X)
        y_list.extend(file_y)
        lengths.extend(x.shape[0] for x in file_X)

    print(f"Collected {len(X_list)} trials | skipped files: {bad}")
    return X_list, y_list, np.array(lengths, dtype=np.int32)


Block 5 — build combined datasets (train/val/test)


In [5]:
# Optional hard cap on trial length in time steps (e.g. 1000); None keeps every trial at full length.
MAX_TIME = None

train_X, train_y, train_L = load_split_across_days(files_by_split["data_train.hdf5"], require_labels=True,  max_time=MAX_TIME)
val_X,   val_y,   val_L   = load_split_across_days(files_by_split["data_val.hdf5"],   require_labels=True,  max_time=MAX_TIME)
test_X,  test_y,  test_L  = load_split_across_days(files_by_split["data_test.hdf5"],  require_labels=False, max_time=MAX_TIME)

# Invariants: per-split containers stay aligned, and all trials share one feature
# dimension (pad_sequence in the collate step requires matching trailing dims).
for split_label, (split_X, split_y, split_L) in {
    "train": (train_X, train_y, train_L),
    "val":   (val_X,   val_y,   val_L),
    "test":  (test_X,  test_y,  test_L),
}.items():
    assert len(split_X) == len(split_y) == len(split_L), f"{split_label}: misaligned X / y / lengths"
feature_dims = {x.shape[1] for xs in (train_X, val_X, test_X) for x in xs}
assert len(feature_dims) <= 1, f"inconsistent feature dims across trials: {feature_dims}"

print("Train:", len(train_X), "| Val:", len(val_X), "| Test:", len(test_X))
print("Feature dim (train[0]):", train_X[0].shape[1] if train_X else None)
# Cast to plain ints so the printout is not cluttered with np.int32(...) reprs.
train_length_range = (int(train_L.min()), int(train_L.max())) if train_L.size else (None, None)
print("Trial length range (train):", train_length_range)


Loading: 100%|██████████| 45/45 [00:32<00:00,  1.39it/s]


Collected 8072 trials | skipped files: 0


Loading: 100%|██████████| 41/41 [00:38<00:00,  1.05it/s]


Collected 1426 trials | skipped files: 0


Loading: 100%|██████████| 41/41 [00:25<00:00,  1.58it/s]

Collected 1450 trials | skipped files: 0
Train: 8072 | Val: 1426 | Test: 1450
Feature dim (train[0]): 512
Trial length range (train): (np.int32(138), np.int32(2475))





Block 6 — save processed sets (torch .pt for convenience)

In [6]:
def save_pt(path, X_list, y_list, lengths):
    """Write one processed split to `path` with torch.save and report its size.

    The payload is a dict with keys "X", "y" and "lengths" holding the
    objects exactly as passed in (no conversion is performed).
    """
    torch.save(dict(X=X_list, y=y_list, lengths=lengths), path)
    bytes_per_mb = 1024 * 1024
    size_mb = os.path.getsize(path) / bytes_per_mb
    print(f"Saved {path} ({size_mb:.1f} MB)")

# Persist each split as <PROC_ROOT>/<split>.pt.
for split_label, (split_X, split_y, split_L) in (
    ("train", (train_X, train_y, train_L)),
    ("val",   (val_X,   val_y,   val_L)),
    ("test",  (test_X,  test_y,  test_L)),
):
    save_pt(os.path.join(PROC_ROOT, f"{split_label}.pt"), split_X, split_y, split_L)


: 

Block 7 — PyTorch Dataset + collate (padding on the fly)

In [None]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class BrainTextCombined(Dataset):
    """Map-style dataset over one combined split file written by `save_pt`.

    Each item is ``(x, y)``: ``x`` is the trial's feature tensor (time, features),
    and ``y`` is a 1-D long tensor of label ids, or an empty long tensor when
    ``require_labels`` is False.
    """

    def __init__(self, pt_path, require_labels=True):
        # The payload holds plain numpy arrays, which torch.load refuses to
        # unpickle under its default weights_only=True (PyTorch >= 2.6).
        # The file is produced by this notebook, so full unpickling is acceptable
        # here. Do NOT point this at an untrusted .pt file.
        blob = torch.load(pt_path, map_location="cpu", weights_only=False)
        self.X = blob["X"]    # per-trial feature arrays, (T, F)
        self.y = blob["y"]    # per-trial label arrays, (L,)
        self.require_labels = require_labels

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # as_tensor accepts numpy arrays (zero-copy) as well as tensors.
        x = torch.as_tensor(self.X[idx])
        if self.require_labels:
            y = torch.as_tensor(self.y[idx]).long()
        else:
            y = torch.empty(0, dtype=torch.long)
        return x, y

def collate_batch(batch):
    """Zero-pad a list of (x, y) pairs into batch tensors.

    Returns:
        xs_pad: (B, T_max, F) features, zero-padded along time.
        x_lens: (B,) int32 original time lengths.
        ys_pad: (B, L_max) labels, zero-padded.
        y_lens: (B,) int32 original label lengths.
    """
    features = [pair[0] for pair in batch]
    labels = [pair[1] for pair in batch]

    feature_lengths = torch.tensor([seq.size(0) for seq in features], dtype=torch.int32)
    label_lengths = torch.tensor([seq.size(0) for seq in labels], dtype=torch.int32)

    return (
        pad_sequence(features, batch_first=True),
        feature_lengths,
        pad_sequence(labels, batch_first=True),
        label_lengths,
    )

# Example datasets and loaders for every split.
BATCH_SIZE = 16
# Seeded generator so the training shuffle order is reproducible across Restart & Run All.
LOADER_SEED = 0

train_ds = BrainTextCombined(os.path.join(PROC_ROOT, "train.pt"), require_labels=True)
val_ds   = BrainTextCombined(os.path.join(PROC_ROOT, "val.pt"),   require_labels=True)
test_ds  = BrainTextCombined(os.path.join(PROC_ROOT, "test.pt"),  require_labels=False)

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)
val_loader  = DataLoader(val_ds,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
# test_ds was built but had no loader; labels here are empty tensors.
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)


Block 8 — quick sanity check (shapes)

In [None]:
xb, xl, yb, yl = next(iter(train_loader))

# Padded dims must agree with the per-item length vectors produced by collate_batch.
assert xb.shape[0] == xl.shape[0] == yb.shape[0] == yl.shape[0], "batch size mismatch"
assert int(xl.max()) == xb.shape[1], "time padding does not match longest trial"
assert int(yl.max()) == yb.shape[1], "label padding does not match longest label"

print("X batch:", xb.shape, "| x_lens:", xl[:5].tolist())
print("Y batch:", yb.shape, "| y_lens:", yl[:5].tolist())
