In [2]:
import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer


class EEGTextMetaDataset(Dataset):
    def __init__(self, eeg_dir, metadata_dir, tokenizer, max_length=64, use_emotional_tone=True):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.use_emotional_tone = use_emotional_tone

        # -------------------------
        # 1. Load EEG files (all subjects)
        # -------------------------
        eeg_files = []
        for root, dirs, files in os.walk(eeg_dir):
            for f in files:
                if f.endswith(".npy") and "_preprocessed" in f:
                    eeg_files.append(os.path.join(root, f))
        eeg_files = sorted(eeg_files)

        if not eeg_files:
            raise FileNotFoundError(f"No EEG .npy files found in {eeg_dir}")

        self.eeg_file_paths = eeg_files
        self.eeg_data_list = []
        self.index_map = []

        for subj_idx, path in enumerate(self.eeg_file_paths):
            eeg = np.load(path, mmap_mode='r')

            # --- Assertion 1: EEG shape ---
            assert eeg.ndim == 3 and eeg.shape[1:] == (62, 400), \
                f"EEG file {path} has shape {eeg.shape}, expected (*, 62, 400)"

            self.eeg_data_list.append(eeg)

            # build index map
            n_samples = eeg.shape[0]
            self.index_map.extend([(subj_idx, i) for i in range(n_samples)])

        total_samples = len(self.index_map)
        print(f"Found {len(self.eeg_file_paths)} EEG files → Total samples: {total_samples}")

        # -------------------------
        # 2. Load Metadata JSONs
        # -------------------------
        metadata_files = sorted(
            [os.path.join(dp, f)
             for dp, dn, filenames in os.walk(metadata_dir)
             for f in filenames if f.endswith(".json")]
        )
        if not metadata_files:
            raise FileNotFoundError(f"No metadata JSON files found in {metadata_dir}")

        self.metadata_list = []
        for fpath in metadata_files:
            with open(fpath, 'r', encoding='utf-8') as f:
                meta = json.load(f)

                # --- Assertion 2: Metadata content validity ---
                assert "semantic_features" in meta and "scene_category" in meta["semantic_features"], \
                    f"Missing scene_category in {fpath}"
                assert "visual_attributes" in meta and "major_colors" in meta["visual_attributes"], \
                    f"Missing major_colors in {fpath}"
                assert "optical_flow_score" in meta["visual_attributes"], \
                    f"Missing optical_flow_score in {fpath}"

                self.metadata_list.append(meta)

        base_count = len(self.metadata_list)
        print(f"Loaded {base_count} metadata JSON files")

        # -------------------------
        # 3. Build base captions
        # -------------------------
        base_captions = []
        for meta in self.metadata_list:
            caption_text = meta["caption"]["text"]
            if self.use_emotional_tone and "emotional_tone" in meta["caption"]:
                caption_text += f". Tone: {meta['caption']['emotional_tone']}"
            base_captions.append(caption_text)

        # --- Assertion 3: captions vs metadata ---
        assert len(base_captions) == len(self.metadata_list), \
            "Mismatch between metadata and base captions count"

        # Repeat for each subject
        num_subjects = len(self.eeg_file_paths)
        self.captions = base_captions * num_subjects
        self.metadata_repeated = self.metadata_list * num_subjects

        # --- Assertion 4: Repeated counts ---
        assert len(self.captions) == len(self.metadata_repeated) == len(self.index_map), \
            "Mismatch after repeating captions and metadata for subjects"

        # -------------------------
        # 4. Encode metadata categories
        # -------------------------
        scene_categories = sorted(list({m["semantic_features"]["scene_category"] for m in self.metadata_list}))
        colors = sorted(list({m["visual_attributes"]["major_colors"][0]["color"].split()[0]
                              for m in self.metadata_list}))

        self.scene_to_id = {scene: i for i, scene in enumerate(scene_categories)}
        self.color_to_id = {c: i for i, c in enumerate(colors)}

        print(f"Scene categories: {len(self.scene_to_id)} | Colors: {len(self.color_to_id)}")

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        subj_idx, local_idx = self.index_map[idx]

        # 1. EEG tensor
        eeg_tensor = torch.tensor(self.eeg_data_list[subj_idx][local_idx], dtype=torch.float32)  # (62, 400)

        # 2. Tokenized caption
        caption = self.captions[idx]
        tokenized = self.tokenizer(
            caption,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        tokenized = {k: v.squeeze(0) for k, v in tokenized.items()}

        # 3. Metadata tensor
        meta = self.metadata_repeated[idx]
        scene_id = self.scene_to_id[meta["semantic_features"]["scene_category"]]
        color_id = self.color_to_id[meta["visual_attributes"]["major_colors"][0]["color"].split()[0]]
        motion_score = float(meta["visual_attributes"]["optical_flow_score"]["value"])
        metadata_tensor = torch.tensor([scene_id, color_id, motion_score], dtype=torch.float32)

        return eeg_tensor, tokenized, metadata_tensor


# -------------------------------
# Test the dataset
# -------------------------------
if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("/home/poorna/models/bert-base-uncased")

    EEG_DIR = "/home/poorna/data/preprocessed_eeg"
    METADATA_DIR = "/home/poorna/data/metadata_dir"

    dataset = EEGTextMetaDataset(
        eeg_dir=EEG_DIR,
        metadata_dir=METADATA_DIR,
        tokenizer=tokenizer,
        max_length=64
    )

    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    for eeg_batch, tokenized_batch, meta_batch in loader:
        print("\nEEG batch:", eeg_batch.shape)
        print("Input IDs:", tokenized_batch["input_ids"].shape)
        print("Metadata batch:", meta_batch.shape)

        # Sample inspection
        sample_idx = 0
        print("\nSample Inspection:")
        print("EEG shape:", eeg_batch[sample_idx].shape)
        print("Token IDs:", tokenized_batch["input_ids"][sample_idx][:20])
        print("Decoded caption:", tokenizer.decode(
            tokenized_batch["input_ids"][sample_idx], skip_special_tokens=True
        ))
        print("Metadata tensor:", meta_batch[sample_idx])

        scene_id = int(meta_batch[sample_idx][0].item())
        color_id = int(meta_batch[sample_idx][1].item())
        motion_score = float(meta_batch[sample_idx][2].item())
        scene_name = [k for k, v in dataset.scene_to_id.items() if v == scene_id][0]
        color_name = [k for k, v in dataset.color_to_id.items() if v == color_id][0]

        print(f"Scene: {scene_name}, Color: {color_name}, Motion: {motion_score:.3f}")
        break

    print(f"\nTotal samples loaded: {len(dataset)}")



Found 20 EEG files → Total samples: 28000
Loaded 1400 metadata JSON files
Scene categories: 77 | Colors: 53

EEG batch: torch.Size([32, 62, 400])
Input IDs: torch.Size([32, 64])
Metadata batch: torch.Size([32, 3])

Sample Inspection:
EEG shape: torch.Size([62, 400])
Token IDs: tensor([  101, 14231,  2980,  2250, 22163, 14257,  2114,  1037,  3154,  2630,
         3712,  1012,  1012,  4309,  1024, 25388,   102,     0,     0,     0])
Decoded caption: colorful hot air balloons float against a clear blue sky.. tone : serene
Metadata tensor: tensor([1.0000, 3.0000, 0.1000])
Scene: aerial, Color: blue, Motion: 0.100

Total samples loaded: 28000


In [5]:
import h5py
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer

# --- Paths ---
HDF5_FILE = "eeg_dataset.h5"
EEG_DIR = "/home/poorna/data/preprocessed_eeg"
METADATA_DIR = "/home/poorna/data/metadata_dir"

# --- Dataset ---
tokenizer = BertTokenizer.from_pretrained("/home/poorna/models/bert-base-uncased")
dataset = EEGTextMetaDataset(EEG_DIR, METADATA_DIR, tokenizer)
loader = DataLoader(dataset, batch_size=1, shuffle=False)  # batch=1 for sequential save

# --- Create HDF5 file ---
with h5py.File(HDF5_FILE, "w") as f:
    n_samples = len(dataset)

    # Create datasets (preallocate space)
    eeg_shape = (n_samples, 62, 400)               # EEG shape
    token_shape = (n_samples, dataset.max_length)  # tokenized input_ids length
    meta_shape = (n_samples, 3)                    # scene_id, color_id, motion_score

    eeg_ds = f.create_dataset("eeg", shape=eeg_shape, dtype="float32")
    tokens_ds = f.create_dataset("input_ids", shape=token_shape, dtype="int64")
    meta_ds = f.create_dataset("metadata", shape=meta_shape, dtype="float32")

    # Iterate and save directly to file (low memory usage)
    for idx, (eeg_tensor, tokenized, metadata_tensor) in enumerate(loader):
        eeg_ds[idx] = eeg_tensor.squeeze(0).numpy()
        tokens_ds[idx] = tokenized["input_ids"].squeeze(0).numpy()
        meta_ds[idx] = metadata_tensor.squeeze(0).numpy()

print(f"Saved dataset to {HDF5_FILE}")

Found 20 EEG files → Total samples: 28000
Loaded 1400 metadata JSON files
Scene categories: 77 | Colors: 53
Saved dataset to eeg_dataset.h5


In [6]:
import h5py
import torch

with h5py.File("eeg_dataset.h5", "r") as f:
    print(list(f.keys()))  # ['eeg', 'input_ids', 'metadata']
    eeg_sample = torch.tensor(f["eeg"][0])         # first sample EEG tensor
    tokens_sample = torch.tensor(f["input_ids"][0])
    meta_sample = torch.tensor(f["metadata"][0])

print(eeg_sample.shape, tokens_sample.shape, meta_sample.shape)

['eeg', 'input_ids', 'metadata']
torch.Size([62, 400]) torch.Size([64]) torch.Size([3])


In [7]:
import h5py

HDF5_FILE = "eeg_dataset.h5"

with h5py.File(HDF5_FILE, "r") as f:
    print("\nDatasets in file:", list(f.keys()))

    for name in f.keys():
        dset = f[name]
        print(f"{name}: shape={dset.shape}, dtype={dset.dtype}")

    # Optional: verify a few entries
    sample_idx = 0
    print("\nSample check:")
    print("EEG sample shape:", f["eeg"][sample_idx].shape)
    print("Tokens sample shape:", f["input_ids"][sample_idx].shape)
    print("Metadata sample shape:", f["metadata"][sample_idx].shape)


Datasets in file: ['eeg', 'input_ids', 'metadata']
eeg: shape=(28000, 62, 400), dtype=float32
input_ids: shape=(28000, 64), dtype=int64
metadata: shape=(28000, 3), dtype=float32

Sample check:
EEG sample shape: (62, 400)
Tokens sample shape: (64,)
Metadata sample shape: (3,)
