# EEG Data Exploration

This notebook demonstrates how to:
1. Load preprocessed EEG data from .mat files
2. Merge with behavioral data (familiarity/liking ratings)
3. Explore the data structure
4. Extract basic features
5. Train a simple classifier

In [7]:
import sys
sys.path.insert(0, '..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import loadmat

%matplotlib inline

## 1. Load Data

In [8]:
# Path to the preprocessed .mat file (MATLAB v5 format)
MAT_FILE = "../../preprocessed_data/sub-001.mat"

# Load the .mat file
mat_data = loadmat(MAT_FILE, squeeze_me=True, struct_as_record=False)

# Show all variables (excluding MATLAB metadata)
print("Variables in .mat file:")
for key in mat_data.keys():
    if not key.startswith('__'):
        val = mat_data[key]
        print(f"  {key}: type={type(val).__name__}, shape={getattr(val, 'shape', 'N/A')}")

Variables in .mat file:
  cellArray: type=mat_struct, shape=N/A


In [9]:
# Explore the FieldTrip structure
# Based on preprocessing.m, the variable is 'eeg_data' (FieldTrip raw format)

# Get the main data variable
eeg_data = mat_data['eeg_data']

# Show all fields in the FieldTrip structure
print("FieldTrip structure fields:")
for attr in dir(eeg_data):
    if not attr.startswith('_'):
        val = getattr(eeg_data, attr)
        if hasattr(val, 'shape'):
            print(f"  .{attr}: shape={val.shape}, dtype={val.dtype}")
        elif isinstance(val, np.ndarray):
            print(f"  .{attr}: array with {len(val)} elements")
        else:
            print(f"  .{attr}: {type(val).__name__} = {val}")

KeyError: 'eeg_data'

In [None]:
# Extract trial data from FieldTrip structure
# .trial is a cell array where each cell contains (n_channels x n_samples)

trials = eeg_data.trial  # Array of trial matrices
fsample = eeg_data.fsample  # Sampling rate
labels = eeg_data.label  # Channel labels

print(f"Number of trials: {len(trials)}")
print(f"Sampling rate: {fsample} Hz")
print(f"Number of channels: {len(labels)}")
print(f"First trial shape: {trials[0].shape}")

# Check if all trials have same length
trial_lengths = [t.shape[1] for t in trials]
if len(set(trial_lengths)) == 1:
    print(f"All trials have {trial_lengths[0]} samples ({trial_lengths[0]/fsample:.2f} s)")
else:
    print(f"Variable trial lengths: min={min(trial_lengths)}, max={max(trial_lengths)}")

In [None]:
# Convert to 3D numpy array: (n_trials, n_channels, n_samples)
# Note: FieldTrip stores as (channels x samples), so we stack along new axis

epochs = np.stack(trials, axis=0)  # Stack trials
# epochs is now (n_trials, n_channels, n_samples)

print(f"Epochs array shape: {epochs.shape}")
print(f"  - {epochs.shape[0]} trials")
print(f"  - {epochs.shape[1]} channels") 
print(f"  - {epochs.shape[2]} samples ({epochs.shape[2]/fsample:.2f} seconds)")

# Channel names
print(f"\nChannel names: {list(labels[:10])}..." if len(labels) > 10 else f"\nChannel names: {list(labels)}")

In [None]:
# Visualize a single trial
trial_idx = 0
channel_idx = 0

fig, axes = plt.subplots(2, 1, figsize=(14, 6))

# Time vector (based on epoch window: -3 to 32 seconds from preprocessing.m)
time = np.arange(epochs.shape[2]) / fsample - 3  # Offset by -3s (epoch start)

# Plot single channel
ax = axes[0]
ax.plot(time, epochs[trial_idx, channel_idx, :], 'b-', linewidth=0.5)
ax.axvline(x=0, color='r', linestyle='--', alpha=0.5, label='Stimulus onset')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude (µV)')
ax.set_title(f'Trial {trial_idx + 1}, Channel: {labels[channel_idx]}')
ax.legend()
ax.grid(True, alpha=0.3)

# Butterfly plot (multiple channels)
ax = axes[1]
for ch in range(min(10, epochs.shape[1])):
    ax.plot(time, epochs[trial_idx, ch, :], linewidth=0.3, alpha=0.7)
ax.axvline(x=0, color='r', linestyle='--', alpha=0.5)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude (µV)')
ax.set_title(f'Trial {trial_idx + 1}, First 10 channels (butterfly plot)')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 2. Explore Class Distribution

In [None]:
# # Familiarity rating distribution
# fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# behavioral['familiarity_rating'].value_counts().sort_index().plot(
#     kind='bar', ax=axes[0], color='steelblue'
# )
# axes[0].set_title('Familiarity Ratings')
# axes[0].set_xlabel('Rating')
# axes[0].set_ylabel('Count')

# behavioral['liking_rating'].value_counts().sort_index().plot(
#     kind='bar', ax=axes[1], color='coral'
# )
# axes[1].set_title('Liking Ratings')
# axes[1].set_xlabel('Rating')

# plt.tight_layout()
# plt.show()

## 3. Create Classification Labels

In [None]:
from data.loader import create_labels

# # Binary labels: familiar (4-5) vs unfamiliar (1-2)
# labels = create_labels(behavioral, target_type="familiarity_binary")
# unique, counts = np.unique(labels, return_counts=True)
# print(f"Binary labels: {dict(zip(unique, counts))}")

## 4. Extract Features

In [None]:
from features.frequency_domain import extract_frequency_features, compute_band_power
from features.time_domain import extract_time_features

# # Assuming you have epochs loaded as shape (n_trials, n_channels, n_samples)
# # epochs = ...

# # Extract frequency features
# freq_features = extract_frequency_features(epochs)
# print(f"Frequency features shape: {freq_features.shape}")

# # Extract time features
# time_features = extract_time_features(epochs)
# print(f"Time features shape: {time_features.shape}")

## 5. Train a Simple Classifier

In [None]:
from models.svm import SVMClassifier
from evaluation.cross_validation import cross_validate

# # Combine features
# X = np.concatenate([freq_features, time_features], axis=1)
# y = labels

# # Train SVM with cross-validation
# model = SVMClassifier(kernel="rbf", C=1.0)
# results = cross_validate(model, X, y, n_folds=5, verbose=True)

# print(f"\nMean Accuracy: {results['mean_score']:.3f} ± {results['std_score']:.3f}")

## 6. Visualize Results

In [None]:
from utils.visualization import plot_cv_results, plot_confusion_matrix

# # Plot CV results
# plot_cv_results(results, title="SVM Cross-Validation")
# plt.show()

# # Plot confusion matrix
# if 'predictions' in results:
#     plot_confusion_matrix(
#         y, results['predictions'],
#         class_names=['Unfamiliar', 'Familiar']
#     )
#     plt.show()

## Next Steps

1. Try different classifiers (LDA, EEGNet)
2. Experiment with different feature combinations
3. Use hyperparameter tuning
4. Try leave-one-subject-out cross-validation for multi-subject data