In [4]:
!pip install torch

Collecting torch
  Downloading torch-2.8.0-cp313-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting filelock (from torch)
  Downloading filelock-3.19.1-py3-none-any.whl.metadata (2.1 kB)
Collecting typing-extensions>=4.10.0 (from torch)
  Downloading typing_extensions-4.15.0-py3-none-any.whl.metadata (3.3 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.5-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2025.9.0-py3-none-any.whl.metadata (10 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Collecting MarkupSafe>=2.0 (from jinja2->torch)
  Downloading markupsafe-3.0.3-cp313-cp313-macosx_11_0_arm64.whl.metadata (2.7 kB)
Downloading torch-2.8.0-cp313-none-macosx_11_0_ar

In [6]:
import os

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, ConcatDataset

# ---- Dataset class: loads every trial of one HDF5 file into memory ----
class BrainToTextDataset(Dataset):
    """In-memory dataset of trials read from a single HDF5 file.

    Every top-level key of the file is treated as one trial. Its
    ``input_features`` dataset, and its ``seq_class_ids`` dataset when
    present, are read eagerly in ``__init__`` in sorted key order.

    Args:
        filepath: Path to the HDF5 file.
        require_labels: If True, trials without ``seq_class_ids`` are skipped
            and their keys are recorded in ``skipped_keys``. If False, such
            trials are kept with an empty label array.

    Items are ``(x, y)``: ``x`` is a float32 tensor of the trial's features,
    ``y`` is an int64 (``torch.long``) tensor of class ids, empty for
    unlabeled trials. NOTE(review): features are presumably
    (time_steps, channels); confirm against the dataset documentation.
    """

    def __init__(self, filepath, require_labels=True):
        self.filepath = filepath
        self.trials = []
        self.require_labels = require_labels
        # Keys of trials dropped because labels were required but missing,
        # so the skip is inspectable rather than silent.
        self.skipped_keys = []

        with h5py.File(filepath, "r") as f:
            for tkey in sorted(f.keys()):
                group = f[tkey]
                feats = group["input_features"][()]

                if "seq_class_ids" in group:
                    labels = group["seq_class_ids"][()]
                elif require_labels:
                    self.skipped_keys.append(tkey)
                    continue
                else:
                    # Integer-typed empty placeholder (needs numpy imported at
                    # the top of the notebook).
                    labels = np.empty(0, dtype=np.int64)

                self.trials.append((feats, labels))

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        x, y = self.trials[idx]
        # torch.tensor copies, so callers cannot mutate the cached arrays.
        x = torch.tensor(x, dtype=torch.float32)
        # Cast unconditionally so empty (unlabeled) targets are long as well;
        # torch.tensor([]) would default to float32 and mismatch labeled items.
        y = torch.tensor(y, dtype=torch.long)
        return x, y


# ---- Function to gather datasets across all days ----
def load_all_days(base_dir):
    """Build train/val/test datasets spanning every day directory under ``base_dir``.

    Each immediate sub-directory of ``base_dir`` (visited in sorted order) may
    contain ``data_train.hdf5``, ``data_val.hdf5`` and/or ``data_test.hdf5``.
    Every file that exists becomes one ``BrainToTextDataset``. The train and val
    files require labels; the test file does not. Non-directory entries are
    ignored.

    Returns:
        ``(train_ds, val_ds, test_ds)``: a ``ConcatDataset`` per split, or
        ``None`` for a split where no file was found.
    """
    # (filename, require_labels) per split, in the order the splits are returned.
    split_specs = (
        ("data_train.hdf5", True),
        ("data_val.hdf5", True),
        ("data_test.hdf5", False),  # test has no labels
    )
    per_split = [[] for _ in split_specs]

    for day_name in sorted(os.listdir(base_dir)):
        day_dir = os.path.join(base_dir, day_name)
        if not os.path.isdir(day_dir):
            continue
        for bucket, (fname, needs_labels) in zip(per_split, split_specs):
            fpath = os.path.join(day_dir, fname)
            if os.path.exists(fpath):
                bucket.append(BrainToTextDataset(fpath, require_labels=needs_labels))

    train_ds, val_ds, test_ds = (
        ConcatDataset(parts) if parts else None for parts in per_split
    )
    return train_ds, val_ds, test_ds



# ---- Usage ----
# Root folder containing one sub-directory per day (see load_all_days).
base_dir = "data/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
train_ds, val_ds, test_ds = load_all_days(base_dir)

# load_all_days yields None for a split with no files on disk; report 0 then.
print("Total trials across all days:")
for split_label, split_ds in (("Train:", train_ds), ("Val:", val_ds), ("Test:", test_ds)):
    print(split_label, len(split_ds) if split_ds else 0)


Total trials across all days:
Train: 8072
Val: 1426
Test: 1450
