# EEG State-wise Channel Mean Alpha Power Analysis

This notebook loads EEG data for one subject from a configured dataset and, for each state, computes the mean alpha power (summed power in the 8–12 Hz band) of every available channel. The mean is computed by:
1. Calculating alpha power for each epoch and channel using `Metrics.alpha_power()`.
2. Averaging this alpha power across epochs for each channel within a specific task-state.
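3. Averaging these per-channel means across all tasks (in all recordings) that include the state. Because this is a mean of per-task means, every task contributes equally regardless of how many epochs it contains.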

In [2]:
import numpy as np
import os
import sys

from eeg_analyzer.eeg_analyzer import EEGAnalyzer
from eeg_analyzer.metrics import Metrics # Added Metrics import
from utils.config import DATASETS

# --- Configuration ---
# DATASET_NAME is the key looked up in utils.config.DATASETS; SUBJECT_ID picks
# the subject within that dataset. SUBJECT_ID is kept as a string so that its
# leading zeros survive.
DATASET_NAME: str = "braboszcz2017"
SUBJECT_ID: str = "078"

In [3]:
# --- Load Data ---
# Resolve the dataset configuration, build the analyzer, and fetch the single
# subject this notebook analyses. Every lookup fails loudly so later cells never
# run against a missing dataset or subject.
if DATASET_NAME not in DATASETS:
    raise ValueError(f"Dataset configuration for '{DATASET_NAME}' not found.")
dataset_config = DATASETS[DATASET_NAME]

print(f"Initializing EEGAnalyzer for dataset: {DATASET_NAME}...")
# EEGAnalyzer expects a list of dataset configurations. According to
# eeg_analyzer.py, its __init__ calls dataset.load_data(), which loads the
# subjects, so this cell makes no separate load_subjects() call.
# NOTE(review): if get_subject() below returns None unexpectedly, confirm that
# subject loading really happens inside EEGAnalyzer.__init__.
analyzer = EEGAnalyzer(dataset_configs=[dataset_config])

print(f"Getting dataset '{DATASET_NAME}' from analyzer...")
eeg_dataset = analyzer.get_dataset(DATASET_NAME)
if eeg_dataset is None:
    raise RuntimeError(f"Could not retrieve dataset '{DATASET_NAME}' from EEGAnalyzer.")

print(f"Getting subject '{SUBJECT_ID}' from dataset '{DATASET_NAME}'...")
subject = eeg_dataset.get_subject(SUBJECT_ID)
if subject is None:
    raise RuntimeError(f"Subject '{SUBJECT_ID}' not found in dataset '{DATASET_NAME}'.")

print(f"Successfully loaded: {subject}")
print(f"Subject has {len(subject.get_all_recordings())} recording(s).")

Initializing EEGAnalyzer for dataset: braboszcz2017...
Getting dataset 'braboszcz2017' from analyzer...
Loading subjects for dataset 'braboszcz2017'...
Getting subject '078' from dataset 'braboszcz2017'...
Successfully loaded: <Subject 078 (vip) - 1 sessions>
Subject has 1 recording(s).
Getting dataset 'braboszcz2017' from analyzer...
Loading subjects for dataset 'braboszcz2017'...
Getting subject '078' from dataset 'braboszcz2017'...
Successfully loaded: <Subject 078 (vip) - 1 sessions>
Subject has 1 recording(s).


In [4]:
# --- Calculate Mean Alpha Power per State per Channel ---
# Maps state name -> {
#     'sum_alpha_means': running per-channel sum of the per-task mean alpha power,
#     'task_count': number of (recording, task) pairs added to that sum,
# }
# The final per-state mean is sum_alpha_means / task_count (see the display cell).
aggregated_state_channel_alpha_means = {}

channel_names = None

for recording in subject.get_all_recordings():
    recording_channel_names = recording.get_channel_names()
    if channel_names is None:
        channel_names = recording_channel_names
    elif recording_channel_names is not None and list(recording_channel_names) != list(channel_names):
        # Per-channel vectors are summed position-by-position, so a recording
        # with a different channel set/order would silently mix electrodes.
        print(f"Warning: Recording {recording} has channels that differ from the first recording's. Skipping.")
        continue

    for task in recording.get_available_tasks():
        for state in recording.get_available_states(task):
            try:
                psd_data = recording.get_psd(task, state)  # Expected shape: (n_epochs, n_channels, n_frequencies)
                freqs_data = recording.get_freqs(task, state)
            except ValueError as e:
                print(f"Warning: Could not get PSD/freqs for task '{task}', state '{state}': {e}. Skipping.")
                continue

            if psd_data.ndim == 3 and psd_data.shape[0] > 0:  # n_epochs > 0
                # Alpha power per epoch and channel -> shape (n_epochs, n_channels)
                alpha_power_epochs_channels = Metrics.alpha_power(psd_data, freqs_data)
                # Mean across epochs -> shape (n_channels,)
                mean_alpha_power_this_task_state = np.mean(alpha_power_epochs_channels, axis=0)
            elif psd_data.ndim == 2 and psd_data.shape[0] > 0:
                # NOTE(review): assumes a 2D array is precomputed band power of
                # shape (n_epochs, n_channels). If get_psd can return a
                # (n_channels, n_frequencies) PSD, this is NOT alpha power.
                # TODO confirm.
                print(f"Info: PSD data for task '{task}', state '{state}' is 2D, assuming it's already band power (epochs, channels). Using directly.")
                mean_alpha_power_this_task_state = np.mean(psd_data, axis=0)
            else:
                print(f"Warning: PSD data for task '{task}', state '{state}' has unexpected shape {psd_data.shape} or no epochs. Skipping.")
                continue

            state_entry = aggregated_state_channel_alpha_means.get(state)
            if state_entry is None:
                state_entry = {
                    'sum_alpha_means': np.zeros_like(mean_alpha_power_this_task_state, dtype=float),
                    'task_count': 0,
                }
                aggregated_state_channel_alpha_means[state] = state_entry
            elif state_entry['sum_alpha_means'].shape != mean_alpha_power_this_task_state.shape:
                # Guard against broadcasting errors or silent broadcasting.
                print(f"Warning: Mean alpha shape {mean_alpha_power_this_task_state.shape} for task '{task}', state '{state}' does not match accumulated shape {state_entry['sum_alpha_means'].shape}. Skipping.")
                continue

            state_entry['sum_alpha_means'] += mean_alpha_power_this_task_state
            state_entry['task_count'] += 1

if channel_names is None and subject.get_all_recordings():
    # Fallback: no recording reported channel names, so ask the subject instead.
    try:
        channel_names = subject.get_channel_names()
    except ValueError as e:
        print(f"Could not determine channel names: {e}")

print("Aggregation complete.")
if not aggregated_state_channel_alpha_means:
    print("No data found to aggregate.")

Aggregation complete.


In [5]:
# --- Display Results ---
# Finalize the per-state means first. This does not depend on whether channel
# names are known, so `final_alpha_results` is populated even when the
# per-channel printout below cannot be produced.
final_alpha_results = {
    state: data['sum_alpha_means'] / data['task_count']
    for state, data in aggregated_state_channel_alpha_means.items()
    if data['task_count'] > 0
}

print(f"\nMean Alpha Power per Channel for Subject '{SUBJECT_ID}' in Dataset '{DATASET_NAME}':")

if not aggregated_state_channel_alpha_means:
    print("No results to display.")
elif channel_names is None:
    print("Error: Channel names could not be determined.")
else:
    for state, data in aggregated_state_channel_alpha_means.items():
        if state not in final_alpha_results:
            print(f"\nState: {state} - No data processed.")
            continue

        final_mean_alpha = final_alpha_results[state]
        # task_count counts (recording, task) pairs that contained this state.
        print(f"\nState: {state} (averaged over {data['task_count']} task(s))")
        if len(channel_names) == len(final_mean_alpha):
            for ch_name, channel_mean in zip(channel_names, final_mean_alpha):
                print(f"  {ch_name}: {channel_mean:.4f}")
        else:
            print(f"  Error: Mismatch between channel count ({len(channel_names)}) and mean alpha values count ({len(final_mean_alpha)}).")
            print(f"  Raw mean alpha values: {final_mean_alpha}")


Mean Alpha Power per Channel for Subject '078' in Dataset 'braboszcz2017':

State: OT (averaged over 1 task(s))
  Fp1: 22.9324
  AF7: 20.9996
  AF3: 23.1303
  F1: 23.8335
  F3: 21.7317
  F5: 19.0515
  F7: 18.5411
  FT7: 15.4132
  FC5: 12.8166
  FC3: 17.9056
  FC1: 20.8608
  C1: 14.8171
  C3: 10.4888
  C5: 8.0372
  T7: 16.0743
  TP7: 26.1695
  CP5: 10.0681
  CP3: 9.7521
  CP1: 13.5789
  P1: 23.9010
  P3: 25.0785
  P5: 40.2538
  P7: 60.6737
  P9: 53.6776
  PO7: 97.1089
  PO3: 71.3529
  O1: 85.0252
  Iz: 53.6597
  Oz: 64.0659
  POz: 67.3214
  Pz: 26.8229
  CPz: 15.9674
  Fpz: 23.8211
  Fp2: 23.1602
  AF8: 20.7287
  AF4: 23.2950
  AFz: 24.7946
  Fz: 24.9071
  F2: 24.1355
  F4: 22.4974
  F6: 18.8447
  F8: 16.8216
  FT8: 13.0673
  FC6: 12.0749
  FC4: 17.2205
  FC2: 21.2223
  FCz: 22.6003
  Cz: 17.3155
  C2: 13.9748
  C4: 9.5730
  C6: 7.5755
  T8: 15.6415
  TP8: 34.0953
  CP6: 15.0748
  CP4: 10.5751
  CP2: 13.0879
  P2: 25.6825
  P4: 27.7284
  P6: 43.6645
  P8: 84.7628
  P10: 89.2624
  PO8: 