# 📊 EEG Emotion Recognition — Exploratory Data Analysis

This notebook provides a comprehensive visual exploration of the DEAP dataset (subject `s01`); the SEED dataset is not loaded in this notebook yet:

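1. **Raw EEG signals** — multi-channel time-series
2. **Power Spectral Density** — Welch PSD with frequency band overlays
3. **Electrode topography** — 2D scalp map of alpha-band power
4. **Class distribution** — arousal/valence histograms and binary bar charts
5. **Signal statistics** — per-channel summary
6. **Differential entropy** — electrode × frequency-band heatmap
7. **Band-pass filtering** — raw vs. filtered signal comparison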

> **Note:** If the DEAP file `s01.dat` is not found in `../data/raw/deap`, this notebook generates **synthetic EEG data** for demonstration.

In [None]:
import os
import pickle
import sys
sys.path.insert(0, '..')  # make the project-local `src` package importable

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.signal import welch
from scipy.stats import kurtosis, skew

from src.utils.helpers import DEAP_ELECTRODE_NAMES, DEAP_ELECTRODE_POS_2D, FREQ_BANDS
from src.preprocessing.filters import bandpass_filter
from src.preprocessing.feature_extraction import compute_fft_features, compute_differential_entropy

sns.set_theme(style='whitegrid', palette='deep', font_scale=1.1)
%matplotlib inline

print('All imports OK âœ“')

---
## 1. Load Data

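Loads subject `s01` of the DEAP dataset. Falls back to **synthetic EEG** (10 Hz alpha and 6 Hz theta sinusoids plus Gaussian noise, with uniformly random ratings in [1, 9]) if the dataset is unavailable.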

In [None]:
# Load subject s01 from DEAP if present, otherwise build a seeded synthetic stand-in
# with the same array layout: data (trials, channels, samples), labels (trials, 4).
DEAP_DIR = '../data/raw/deap'
FS = 128.0  # Hz
N_CHANNELS = 32  # only the first N_CHANNELS channels of each trial are kept
USE_SYNTHETIC = True

# Try loading real data
deap_path = os.path.join(DEAP_DIR, 's01.dat')
if os.path.exists(deap_path):
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load DEAP files obtained from the official dataset distribution.
    with open(deap_path, 'rb') as f:
        subject = pickle.load(f, encoding='latin1')
    data_all = subject['data'][:, :N_CHANNELS, :]  # (40, 32, 8064)
    labels_all = subject['labels']                   # (40, 4)
    USE_SYNTHETIC = False
    print(f'Loaded DEAP s01.dat — data: {data_all.shape}, labels: {labels_all.shape}')
else:
    print('DEAP data not found — generating synthetic EEG for demo.')
    np.random.seed(42)
    n_trials, n_samples = 40, 8064
    t = np.arange(n_samples) / FS
    # Carrier phase ramps do not depend on trial/channel: compute them once.
    alpha_phase_t = 2 * np.pi * 10 * t  # 10 Hz (alpha)
    theta_phase_t = 2 * np.pi * 6 * t   # 6 Hz (theta)

    data_all = np.zeros((n_trials, N_CHANNELS, n_samples))
    for trial in range(n_trials):
        for ch in range(N_CHANNELS):
            # Synthetic: mix of alpha (10Hz) + theta (6Hz) + noise.
            # The random draws keep their original order, so the seeded output is unchanged.
            alpha_amp = np.random.uniform(5, 20)
            theta_amp = np.random.uniform(2, 10)
            noise_amp = np.random.uniform(1, 5)
            data_all[trial, ch, :] = (
                alpha_amp * np.sin(alpha_phase_t + np.random.uniform(0, 2*np.pi)) +
                theta_amp * np.sin(theta_phase_t + np.random.uniform(0, 2*np.pi)) +
                noise_amp * np.random.randn(n_samples)
            )

    # Uniform ratings on the 1-9 scale, columns ordered like DEAP's labels.
    labels_all = np.column_stack([
        np.random.uniform(1, 9, n_trials),  # valence
        np.random.uniform(1, 9, n_trials),  # arousal
        np.random.uniform(1, 9, n_trials),  # dominance
        np.random.uniform(1, 9, n_trials),  # liking
    ])
    print(f'Synthetic data: {data_all.shape}, labels: {labels_all.shape}')

# Sanity checks: later cells index channels 0..N_CHANNELS-1 and label columns 0..3.
assert data_all.ndim == 3 and data_all.shape[1] == N_CHANNELS, data_all.shape
assert labels_all.shape == (data_all.shape[0], 4), labels_all.shape

# Select a single trial for detailed plots
trial_idx = 0
trial_data = data_all[trial_idx].T  # → (n_samples, n_channels)
print(f'\nUsing trial {trial_idx}: shape = {trial_data.shape}')

---
## 2. Raw EEG Signal Visualisation

Multi-channel time-series with vertical offset for readability.

In [None]:
# Stacked multi-channel plot of the first seconds of the selected trial.
os.makedirs('../results', exist_ok=True)  # savefig below fails if the folder does not exist

SHOW_SECONDS = 5    # length of the plotted window (s)
MAX_CHANNELS = 16   # number of stacked traces
spacing = 80        # µV vertical spacing between traces

ch_labels = DEAP_ELECTRODE_NAMES[:N_CHANNELS]  # also used by later cells
show_samples = min(int(SHOW_SECONDS * FS), trial_data.shape[0])
n_show_channels = min(N_CHANNELS, MAX_CHANNELS)
t_axis = np.arange(show_samples) / FS

fig, ax = plt.subplots(figsize=(16, 12))

offset = 0
for ch in range(n_show_channels):
    ax.plot(t_axis, trial_data[:show_samples, ch] + offset, linewidth=0.6)
    # x in axes fraction (just left of the plot area), y in data units, so each label
    # sits beside its trace regardless of the plotted time span.
    ax.text(-0.01, offset, ch_labels[ch], transform=ax.get_yaxis_transform(),
            fontsize=9, ha='right', va='center', fontweight='bold', color='#333')
    offset -= spacing

ax.set_xlabel('Time (s)', fontsize=12)
ax.set_title(f'Raw EEG — {n_show_channels} Channels (first {show_samples / FS:g} seconds)',
             fontsize=14, fontweight='bold')
ax.set_yticks([])
ax.set_xlim([0, show_samples / FS])
ax.spines['left'].set_visible(False)
plt.tight_layout()
plt.savefig('../results/raw_eeg_signals.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: results/raw_eeg_signals.png')

---
## 3. Power Spectral Density

Welch PSD for selected channels with frequency band overlays.

In [None]:
# Welch PSD of four representative electrodes with the EEG frequency bands shaded.
os.makedirs('../results', exist_ok=True)  # savefig below fails if the folder does not exist

# Select electrodes by name rather than hard-coded index: the previous indices
# [0, 6, 14, 31] were commented as "Fp1, C3, Oz, Cz", but if DEAP_ELECTRODE_NAMES
# follows the DEAP channel order, index 31 is O2, not Cz.
PSD_CHANNEL_NAMES = ['Fp1', 'C3', 'Oz', 'Cz']
name_to_index = {str(name).lower(): idx for idx, name in enumerate(ch_labels)}
missing = [name for name in PSD_CHANNEL_NAMES if name.lower() not in name_to_index]
assert not missing, f'Electrodes not found in ch_labels: {missing}'
selected_channels = [name_to_index[name.lower()] for name in PSD_CHANNEL_NAMES]

band_colors = {'delta': '#FF6B6B', 'theta': '#FFA07A', 'alpha': '#98D8C8',
               'beta': '#7EC8E3', 'gamma': '#B19CD9'}

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

for ax, ch_idx in zip(axes.flat, selected_channels):
    # nperseg=256 -> 2 s windows at 128 Hz, i.e. 0.5 Hz frequency resolution.
    freqs, psd = welch(trial_data[:, ch_idx], fs=FS, nperseg=256)

    ax.semilogy(freqs, psd, color='#333', linewidth=1.5, zorder=5)

    # Shade frequency bands
    for band_name, (f_lo, f_hi) in FREQ_BANDS.items():
        mask = (freqs >= f_lo) & (freqs <= f_hi)
        if mask.any():
            ax.fill_between(freqs[mask], psd[mask], alpha=0.3,
                            color=band_colors[band_name], label=band_name)

    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('PSD (µV²/Hz)')
    ax.set_title(f'Channel: {ch_labels[ch_idx]}', fontweight='bold')
    ax.set_xlim([0, 50])
    ax.legend(fontsize=8, loc='upper right')
    ax.grid(True, alpha=0.3)

plt.suptitle('Power Spectral Density (Welch)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('../results/psd_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: results/psd_analysis.png')

---
## 4. Electrode Topography

2D scalp map showing alpha-band power distribution across electrodes.

In [None]:
# Scalp map: each electrode drawn at its 2D position, coloured by alpha-band power.
os.makedirs('../results', exist_ok=True)  # savefig below fails if the folder does not exist

# Look up the alpha column from FREQ_BANDS instead of hard-coding 2.
# Assumes compute_fft_features orders its band columns like FREQ_BANDS — TODO confirm.
alpha_col = list(FREQ_BANDS).index('alpha')
fft_features = compute_fft_features(trial_data, sampling_rate=FS)
alpha_power = fft_features[:, alpha_col]

# Get electrode positions
positions = np.array([DEAP_ELECTRODE_POS_2D[name] for name in ch_labels])

fig, ax = plt.subplots(figsize=(8, 8))

# Draw head outline (unit circle)
head_angle = np.linspace(0, 2 * np.pi, 100)
ax.plot(np.cos(head_angle), np.sin(head_angle), 'k-', linewidth=2)

# Draw nose
ax.plot([0, -0.08, 0, 0.08, 0], [1.0, 1.08, 1.15, 1.08, 1.0], 'k-', linewidth=2)

# Draw ears (mirror the same outline on both sides)
for sign in [-1, 1]:
    ear_x = sign * np.array([1.0, 1.05, 1.08, 1.05, 1.0])
    ear_y = np.array([0.15, 0.1, 0.0, -0.1, -0.15])
    ax.plot(ear_x, ear_y, 'k-', linewidth=1.5)

# Scatter electrodes coloured by alpha power
sc = ax.scatter(positions[:, 0], positions[:, 1],
                c=alpha_power, s=400, cmap='YlOrRd',
                edgecolors='black', linewidths=1.5, zorder=5)

# Label electrodes
for (x, y), name in zip(positions, ch_labels):
    ax.annotate(name, (x, y), fontsize=7, ha='center', va='center', fontweight='bold')

plt.colorbar(sc, ax=ax, label='Alpha Power', shrink=0.6)
ax.set_title('Electrode Topography — Alpha Band Power', fontsize=14, fontweight='bold')
ax.set_aspect('equal')
ax.set_xlim([-1.3, 1.3])
ax.set_ylim([-1.2, 1.3])
ax.axis('off')
plt.tight_layout()
plt.savefig('../results/electrode_topography.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: results/electrode_topography.png')

---
## 5. Class Distribution

Distribution of the continuous valence/arousal ratings and of their binarised classes (rating ≥ 5.0 → high arousal / positive valence).

In [None]:
# Continuous rating histograms plus binary class counts for arousal and valence.
os.makedirs('../results', exist_ok=True)  # savefig below fails if the folder does not exist

LABEL_THRESHOLD = 5.0            # ratings >= threshold form the high / positive class
VALENCE_COL, AROUSAL_COL = 0, 1  # label columns: valence, arousal, dominance, liking


def plot_binary_counts(ax, binary_labels, class_names, colors, title):
    """Draw a two-bar chart of class counts (0 vs. 1) with the count above each bar.

    Returns the [count_of_0, count_of_1] list.
    """
    counts = [int(np.sum(binary_labels == 0)), int(np.sum(binary_labels == 1))]
    bars = ax.bar(class_names, counts, color=colors, edgecolor='white', linewidth=2)
    for bar, count in zip(bars, counts):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.3,
                str(count), ha='center', fontweight='bold')
    ax.set_ylabel('Count')
    ax.set_title(title, fontweight='bold')
    return counts


fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# --- Continuous label distributions ---
ax = axes[0]
ax.hist(labels_all[:, VALENCE_COL], bins=15, alpha=0.7, color='#FF6B6B', label='Valence', edgecolor='white')
ax.hist(labels_all[:, AROUSAL_COL], bins=15, alpha=0.7, color='#7EC8E3', label='Arousal', edgecolor='white')
ax.axvline(LABEL_THRESHOLD, color='black', linestyle='--', linewidth=1.5,
           label=f'Threshold ({LABEL_THRESHOLD})')
ax.set_xlabel('Rating')
ax.set_ylabel('Count')
ax.set_title('Continuous Label Distribution', fontweight='bold')
ax.legend()

# --- Binary Arousal / Valence ---
arousal_binary = (labels_all[:, AROUSAL_COL] >= LABEL_THRESHOLD).astype(int)
valence_binary = (labels_all[:, VALENCE_COL] >= LABEL_THRESHOLD).astype(int)
counts_a = plot_binary_counts(axes[1], arousal_binary, ['Low', 'High'],
                              ['#FFA07A', '#FF6B6B'], 'Arousal (Binary)')
counts_v = plot_binary_counts(axes[2], valence_binary, ['Negative', 'Positive'],
                              ['#B19CD9', '#98D8C8'], 'Valence (Binary)')

plt.suptitle('Class Distribution', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('../results/class_distribution.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: results/class_distribution.png')

---
## 6. Signal Statistics

Per-channel summary for the selected trial: mean, standard deviation, min, max, skewness and (excess) kurtosis.

In [None]:
# Per-channel descriptive statistics for the selected trial, one row per electrode.
# Values stay numeric (rounded to 3 decimals) so the table can still be sorted,
# filtered or plotted. scipy's `kurtosis` defaults to the Fisher definition
# (excess kurtosis: 0 for a Gaussian); `np.std` uses ddof=0.
channel_signals = trial_data[:, :N_CHANNELS]  # (n_samples, n_channels)

df_stats = pd.DataFrame({
    'Channel': list(ch_labels)[:N_CHANNELS],
    'Mean': np.mean(channel_signals, axis=0),
    'Std': np.std(channel_signals, axis=0),
    'Min': np.min(channel_signals, axis=0),
    'Max': np.max(channel_signals, axis=0),
    'Skewness': skew(channel_signals, axis=0),
    'Kurtosis': kurtosis(channel_signals, axis=0),
}).round(3)

df_stats

---
## 7. Differential Entropy Heatmap

Differential entropy (DE) features per electrode × frequency band — the primary features used by the FAT model.

In [None]:
# Heatmap of differential-entropy (DE) features: rows are electrodes, columns are bands.
os.makedirs('../results', exist_ok=True)  # savefig below fails if the folder does not exist

band_names = list(FREQ_BANDS.keys())
de_features = compute_differential_entropy(trial_data, sampling_rate=FS)
# The tick labels below assume (channels x bands) in FREQ_BANDS order; fail loudly otherwise.
assert de_features.shape == (len(ch_labels), len(band_names)), de_features.shape

fig, ax = plt.subplots(figsize=(10, 8))

# No `center=0`: DE values are usually all of one sign, and centring a diverging
# colormap at 0 would compress them into half of the colour range.
sns.heatmap(de_features,
            xticklabels=band_names,
            yticklabels=ch_labels,
            cmap='RdYlBu_r', annot=False,
            linewidths=0.3, linecolor='white', ax=ax)

ax.set_xlabel('Frequency Band', fontsize=12)
ax.set_ylabel('Electrode', fontsize=12)
ax.set_title('Differential Entropy Features', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('../results/de_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: results/de_heatmap.png')
print(f'\nDE features shape: {de_features.shape} (channels × bands)')

---
## 8. Band-Pass Filtering Effect

Visual comparison of raw vs. filtered signal.

In [None]:
# Compare one channel before and after band-pass filtering over a short window.
os.makedirs('../results', exist_ok=True)  # savefig below fails if the folder does not exist

LOW_CUT_HZ, HIGH_CUT_HZ = 0.5, 50.0  # pass-band edges, reused in the plot title
trial_filtered = bandpass_filter(trial_data, low=LOW_CUT_HZ, high=HIGH_CUT_HZ, fs=FS)

ch_show = 0  # first channel (Fp1 in the DEAP channel order)
show_sec = 3
n_show = min(int(show_sec * FS), trial_data.shape[0])
t_show = np.arange(n_show) / FS

fig, axes = plt.subplots(2, 1, figsize=(14, 6), sharex=True)

panels = [
    (trial_data, '#FF6B6B', f'Raw Signal — {ch_labels[ch_show]}'),
    (trial_filtered, '#7EC8E3',
     f'Band-Pass Filtered ({LOW_CUT_HZ:g}–{HIGH_CUT_HZ:g} Hz) — {ch_labels[ch_show]}'),
]
for ax, (signals, color, title) in zip(axes, panels):
    ax.plot(t_show, signals[:n_show, ch_show], color=color, linewidth=0.8)
    ax.set_title(title, fontweight='bold')
    ax.set_ylabel('Amplitude (µV)')
    ax.grid(True, alpha=0.3)
axes[-1].set_xlabel('Time (s)')

plt.tight_layout()
plt.savefig('../results/filter_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: results/filter_comparison.png')

---

## Summary

| Metric | Value |
|--------|-------|
| Data source | DEAP (or synthetic) |
| Channels | 32 |
| Sampling rate | 128 Hz |
| Frequency bands | δ, θ, α, β, γ |
| Feature types | FFT PSD, DCT, DWT, DE |

**Next:** Phase 2 — Build classical ML baselines (SVM, KNN, RF, XGBoost).