In [None]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class EEGMultimodalDataset(Dataset):
    """
    PyTorch Dataset of aligned (EEG trial, image, caption, category) samples.

    For every requested subject / session / run the constructor reads two
    files from ``<bids_root>/<sub>/<ses>/``:

    * ``<sub>_<ses>_task-lowSpeed_run-<run>_image.csv`` -- a CSV *with a header
      row* whose ``FilePath`` column names the image shown on each trial.
    * ``<sub>_<ses>_task-lowSpeed_run-<run>_1000Hz.npy`` -- an array whose
      first axis indexes trials, in the same order as the CSV rows.

    Trials whose image cannot be found in ``images_dir`` are dropped. The
    remaining EEG trials are clamped to ``[-clamp_thres, clamp_thres]`` and
    standardized per feature using statistics computed over the trials loaded
    into *this* dataset instance.

    After construction the aligned data is exposed as ``eeg_dataset`` (NumPy
    array), ``image_paths``, ``captions`` and ``categories`` (lists of equal
    length).
    """

    # Runs loaded per session when the caller does not pass ``run_list``.
    DEFAULT_RUNS = ('01', '02', '03', '04')
    # Extensions tried, in order, when resolving an image base name to a file.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.JPEG')

    def __init__(self,
                 bids_root,          # Path to the .../ds005589/ directory
                 images_dir,         # Path to the .../All_images/ directory
                 captions_path,      # Path to the captions.txt file
                 subject_list,       # List of subjects to load, e.g., ['sub-02', 'sub-03']
                 session_list,       # List of sessions to load, e.g., ['ses-01', 'ses-02']
                 image_transform=None, # PyTorch transforms for the images
                 clamp_thres=500,    # Clamping threshold applied to the raw EEG values
                 run_list=None       # Run ids to load per session; None -> DEFAULT_RUNS
                ):
        """
        Load, align and normalize all requested runs.

        Missing or unreadable runs are reported with a printed warning and
        skipped; they do not raise. If nothing could be loaded, the dataset is
        left empty (``len(self) == 0``) and an error message is printed.
        """
        self.bids_root = bids_root
        self.images_dir = images_dir
        self.image_transform = image_transform
        self.clamp_thres = clamp_thres

        # Parallel lists: index k in each list describes the same trial.
        self.all_eeg_trials = []
        self.all_image_paths = []
        self.all_captions = []
        self.all_categories = []

        print("Initializing dataset... This may take a moment.")

        print(f"Loading captions from {captions_path}...")
        self.captions_dict = self._load_captions(captions_path)
        print(f"Loaded {len(self.captions_dict)} captions.")

        runs = self.DEFAULT_RUNS if run_list is None else run_list
        for sub in subject_list:
            for ses in session_list:
                for run in runs:
                    self._load_run(sub, ses, run)

        print(f"Found {len(self.all_eeg_trials)} total aligned trials.")

        if not self.all_eeg_trials:
            print("ERROR: No trials were loaded. Check your BIDS_ROOT, IMAGE_DIR, and CAPTIONS_FILE paths.")
            # Keep every public attribute defined so __len__ / attribute access
            # on an empty dataset behaves predictably.
            self.eeg_dataset = np.array([])
            self.image_paths = []
            self.captions = []
            self.categories = []
            return

        eeg_dataset = np.array(self.all_eeg_trials, dtype=np.float32)
        self.eeg_dataset = self._clamp_and_normalize(eeg_dataset)
        self.image_paths = self.all_image_paths
        self.captions = self.all_captions
        self.categories = self.all_categories

        print("Dataset initialization complete.")

    def _load_run(self, sub, ses, run):
        """Append every aligned trial of one run to the ``all_*`` lists.

        Prints a warning and returns without adding anything when the run's
        files are missing, unreadable, lack a ``FilePath`` column, or disagree
        on the number of trials.
        """
        session_path = os.path.join(self.bids_root, sub, ses)
        stem = f"{sub}_{ses}_task-lowSpeed_run-{run}"
        csv_path = os.path.join(session_path, f"{stem}_image.csv")
        npy_path = os.path.join(session_path, f"{stem}_1000Hz.npy")

        if not (os.path.exists(csv_path) and os.path.exists(npy_path)):
            print(f"Warning: Missing files for {sub} {ses} {run}. Skipping.")
            return

        # 1. Parse metadata (the .csv, read WITH its header row).
        try:
            csv_data = pd.read_csv(csv_path)
        except (OSError, ValueError) as e:
            # pandas parser / empty-data / decoding errors derive from ValueError.
            print(f"Error reading CSV {csv_path}: {e}. Skipping run.")
            return

        if 'FilePath' not in csv_data.columns:
            print(f"Warning: No 'FilePath' column in {csv_path}. Skipping run.")
            return

        # 2. Load EEG trials (the .npy).
        try:
            eeg_data = np.load(npy_path)
        except (OSError, ValueError) as e:
            print(f"Error reading EEG {npy_path}: {e}. Skipping run.")
            return

        # 3. Verify that the CSV rows and EEG trials correspond one-to-one.
        if eeg_data.shape[0] != len(csv_data):
            print(f"Warning: Trial mismatch in {sub} {ses} {run}. "
                  f"EEG has {eeg_data.shape[0]}, CSV has {len(csv_data)}. Skipping.")
            return

        # enumerate() gives the row *position*, which is what indexes eeg_data;
        # this stays correct even if the DataFrame index is not 0..n-1.
        for trial_idx, file_path in enumerate(csv_data['FilePath']):
            img_base_name = self._get_base_name(file_path)
            if not img_base_name:
                continue

            img_path = self._find_image_path(img_base_name)
            if not img_path:
                continue

            category, caption = self.captions_dict.get(
                img_base_name, ("Unknown", "No Caption"))

            # All four lists are appended together so they stay aligned.
            self.all_eeg_trials.append(eeg_data[trial_idx])
            self.all_image_paths.append(img_path)
            self.all_captions.append(caption)
            self.all_categories.append(category)

    def _clamp_and_normalize(self, eeg_dataset):
        """Clamp to ``+/- clamp_thres`` and standardize each feature.

        Mean and standard deviation are taken over the first (trial) axis, so
        every remaining position is normalized independently. ``1e-6`` guards
        against division by zero for constant features.
        """
        np.clip(eeg_dataset, -self.clamp_thres, self.clamp_thres, out=eeg_dataset)
        mean = np.mean(eeg_dataset, axis=0)
        std = np.std(eeg_dataset, axis=0)
        return (eeg_dataset - mean) / (std + 1e-6)

    def _load_captions(self, captions_path):
        """
        Load captions.txt into ``{image_id: (category, caption)}``.

        The file is TAB-separated with a header row. Only lines with exactly
        four fields -- ``[Source] [Category] [Image_ID] [Caption]`` -- are
        used; other lines are silently ignored. An empty file yields ``{}``.
        """
        captions_dict = {}
        with open(captions_path, 'r', encoding='utf-8') as f:
            # Skip the header line (the default avoids StopIteration on an empty file).
            next(f, None)

            for line in f:
                parts = line.strip().split('\t')
                if len(parts) != 4:
                    continue
                _source, category, img_name, caption = parts
                captions_dict[img_name] = (category, caption)
        return captions_dict

    def _get_base_name(self, file_path):
        """
        Return the image identifier for a CSV ``FilePath`` value.

        Backslashes are treated as path separators, the directory and file
        extension are dropped, and a trailing ``_resized`` suffix is removed.
        Returns ``None`` for a missing value (``None`` / NaN).
        """
        if file_path is None or (isinstance(file_path, float) and np.isnan(file_path)):
            return None

        normalized_path = str(file_path).replace('\\', '/')
        base_name = os.path.splitext(os.path.basename(normalized_path))[0]

        suffix = '_resized'
        if base_name.endswith(suffix):
            base_name = base_name[:-len(suffix)]
        return base_name

    def _find_image_path(self, img_base_name):
        """Return the first existing ``images_dir/<name><ext>`` path, or ``None``."""
        for ext in self.IMAGE_EXTENSIONS:
            img_path = os.path.join(self.images_dir, img_base_name + ext)
            if os.path.exists(img_path):
                return img_path
        return None

    def __len__(self):
        """Returns the total number of aligned trials."""
        return len(self.eeg_dataset)

    def __getitem__(self, idx):
        """Return one aligned ``(eeg_tensor, image_tensor, caption, category)`` sample.

        If the image cannot be loaded or transformed, a message is printed and
        an all-zero ``3 x 224 x 224`` tensor is returned in its place.
        """
        eeg_tensor = torch.tensor(self.eeg_dataset[idx], dtype=torch.float32)
        caption = self.captions[idx]
        category = self.categories[idx]
        img_path = self.image_paths[idx]

        try:
            # The context manager closes the underlying file promptly;
            # convert() returns an independent, fully loaded image.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.image_transform:
                image_tensor = self.image_transform(image)
            else:
                image_tensor = transforms.ToTensor()(image)
        except Exception as e:
            # Deliberate best-effort: one bad image must not abort a whole epoch.
            print(f"Error loading image {img_path}: {e}. Returning a dummy image.")
            image_tensor = torch.zeros(3, 224, 224)

        return eeg_tensor, image_tensor, caption, category

In [2]:
import os
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# --- 1. Define Your Paths ---
# (Update these paths to match your system)
BIDS_ROOT = '/ocean/projects/cis250019p/gandotra/11785-gp-eeg/ds005589'
IMAGE_DIR = '/ocean/projects/cis250019p/gandotra/11785-gp-eeg/images'
CAPTIONS_FILE = '/ocean/projects/cis250019p/gandotra/11785-gp-eeg/captions.txt'

# --- 2. Define Your Subject List ---
ALL_SUBJECTS = ['sub-02', 'sub-03', 'sub-05', 'sub-09', 'sub-14', 'sub-15', 
                'sub-17', 'sub-19', 'sub-20', 'sub-23', 'sub-24', 'sub-28', 'sub-29']

# --- 3. Define Image Transforms (e.g., for CLIP) ---
# (You would get the specific transforms from your model)
image_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# --- 4. Create the 3 Datasets (Train/Val/Test) ---
# Split by session: sessions 01-03 train, session 04 validation, session 05 test.
# NOTE: each EEGMultimodalDataset standardizes its EEG with statistics computed
# from its own trials, so the three splits are normalized independently.

print("Creating Training Dataset...")
train_dataset = EEGMultimodalDataset(
    bids_root=BIDS_ROOT,
    images_dir=IMAGE_DIR,
    captions_path=CAPTIONS_FILE,
    subject_list=ALL_SUBJECTS,
    session_list=['ses-01', 'ses-02', 'ses-03'], # 3 sessions for training
    image_transform=image_transforms
)

print("\nCreating Validation Dataset...")
val_dataset = EEGMultimodalDataset(
    bids_root=BIDS_ROOT,
    images_dir=IMAGE_DIR,
    captions_path=CAPTIONS_FILE,
    subject_list=ALL_SUBJECTS,
    session_list=['ses-04'], # 1 session for validation
    image_transform=image_transforms
)

print("\nCreating Test Dataset...")
test_dataset = EEGMultimodalDataset(
    bids_root=BIDS_ROOT,
    images_dir=IMAGE_DIR,
    captions_path=CAPTIONS_FILE,
    subject_list=ALL_SUBJECTS,
    session_list=['ses-05'], # 1 session for testing
    image_transform=image_transforms
)

# --- 5. Create PyTorch DataLoaders ---
# Only the training loader shuffles; evaluation loaders keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# --- 6. Test the loader ---
print("\nTesting the training loader...")
# Each sample is (eeg, image, caption, category), so a collated batch has four
# fields; unpacking only three would raise "too many values to unpack".
eeg_batch, image_batch, caption_batch, category_batch = next(iter(train_loader))

print(f"EEG batch shape:   {eeg_batch.shape}")
print(f"Image batch shape: {image_batch.shape}")
print(f"Caption batch (first item): '{caption_batch[0]}'")
print(f"Category batch (first item): '{category_batch[0]}'")

Creating Training Dataset...
Initializing dataset... This may take a moment.
Loading captions from /ocean/projects/cis250019p/gandotra/11785-gp-eeg/captions.txt...
Loaded 9825 captions.
Found 15600 total aligned trials.
Dataset initialization complete.

Creating Validation Dataset...
Initializing dataset... This may take a moment.
Loading captions from /ocean/projects/cis250019p/gandotra/11785-gp-eeg/captions.txt...
Loaded 9825 captions.
Found 5200 total aligned trials.
Dataset initialization complete.

Creating Test Dataset...
Initializing dataset... This may take a moment.
Loading captions from /ocean/projects/cis250019p/gandotra/11785-gp-eeg/captions.txt...
Loaded 9825 captions.
Found 5200 total aligned trials.
Dataset initialization complete.

Testing the training loader...
EEG batch shape:   torch.Size([32, 500, 122])
Image batch shape: torch.Size([32, 3, 224, 224])
Caption batch (first item): 'Bottle with message lying on sandy beach'
