# Time-Frequency Analysis of EEG Data to Visualize Alpha-Band Lateralization

This script performs a time-frequency analysis of EEG data to examine differences in power across frequencies and time between two conditions (trials attending "left" vs. "right") and between homologous left and right electrodes.

---

## **1. Data Preprocessing**
- Loads preprocessed EEG data for multiple subjects.
- Band-pass filters the data to 5-15 Hz (4th-order Butterworth) to isolate the relevant frequencies.
- Trims 1 s from both ends of every trial after filtering to remove filter edge artifacts, and excludes noisy channels based on a per-subject rejection list.

---

## **2. Time-Frequency Representation (TFR)**
- Uses Short-Time Fourier Transform (STFT) to compute the time-frequency representation of power for each trial.
- Separates trials into "left" and "right" conditions and computes average power for each condition across trials.
- Calculates the difference in power between conditions (`right - left`).

---

## **3. Analysis**
- Averages power across subjects for each channel, frequency, and time, excluding subject VPpdib, whose large broadband power artifact would otherwise dominate the average.
- Focuses on the alpha frequency band (8-12 Hz) to compute the lateralization effect.

---

## **4. Visualization**
- Generates time-frequency plots for selected electrode pairs to highlight condition differences:
  - **Electrode-wise TFR:** Displays power difference across frequencies and time for each electrode.
  - **Alpha Power Time Series:** Shows time-resolved alpha power differences for lateralized electrode pairs.

---

## **5. Output**
- Saves the visualizations as `.png` and `.svg` files for further analysis.


In [None]:
import os
from os.path import join

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import butter, sosfilt, sosfiltfilt, stft

# Project root; every data/analysis path below is derived from it.
# The working directory is also switched to it so relative paths resolve there.
wd = r'C:\Users\Radovan\OneDrive\Radboud\Studentships\Jordy Thielen\root'
os.chdir(wd)

# Raw experiment data and its derivatives.
data_dir = join(wd, "data")
experiment_dir = join(data_dir, "experiment")
files_dir = join(experiment_dir, 'files')
sourcedata_dir = join(experiment_dir, 'sourcedata')
derivatives_dir = join(experiment_dir, 'derivatives')

# Outputs of the alpha analysis.
analysis_dir = join(data_dir, "analysis")
alpha_dir = join(analysis_dir, "alpha")
decoding_results_dir = join(alpha_dir, "decoding_results")
plots_dir = join(alpha_dir, "plots")
features_dir = join(plots_dir, "features")


# Participant IDs: VPpdia ... VPpdiz, followed by VPpdiza ... VPpdizc.
subjects = [
    "VPpdia", "VPpdib", "VPpdic", "VPpdid", "VPpdie", "VPpdif",
    "VPpdig", "VPpdih", "VPpdii", "VPpdij", "VPpdik", "VPpdil",
    "VPpdim", "VPpdin", "VPpdio", "VPpdip", "VPpdiq", "VPpdir",
    "VPpdis", "VPpdit", "VPpdiu", "VPpdiv", "VPpdiw", "VPpdix",
    "VPpdiy", "VPpdiz", "VPpdiza", "VPpdizb", "VPpdizc",
]

# Channels that enter the time-frequency analysis (order defines the plot index).
selected_channels = ["P3", "P4", "P7", "P8", "O1", "O2", "Oz", "Pz"]

# Full montage. Channel positions in the stored X arrays are looked up in this
# list (after removing rejected channels), so its order must not change.
picks_hubner = [
    "F7", "F3", "Fz", "F4", "F8",
    "FC1", "FC2", "FC5", "FC6", "FCz",
    "T7", "C3", "Cz", "C4", "T8",
    "CP1", "CP2", "CP5", "CP6", "CPz",
    "P7", "P3", "Pz", "P4", "P8",
    "Oz", "O1", "O2",
]

# Per-subject channels excluded as noisy.
subjects_channel_reject = {
    "VPpdib": ["FC2"],
    "VPpdih": ["C3"],
    "VPpdizb": ["Fz"],
    "VPpdizc": ["FC2"],
}

task = 'covert'
sampling_rate = 120                 # Hz; passed as fs to the STFT
fmin, fmax = 5, 25                  # frequency range (Hz) kept in the TFR
n_fft = 512                         # FFT length per STFT segment
window_length = 1.0                 # STFT segment length in seconds
nperseg = int(window_length * sampling_rate)
noverlap = nperseg // 2             # 50 % overlap between segments
ch_names = selected_channels        # alias used for indexing/plotting

def bandpass_filter(data, lowcut, highcut, fs=120, order=4, *, zero_phase=False):
    """
    Band-pass filter ``data`` along its last axis with a Butterworth filter.

    Parameters
    ----------
    data : array_like
        Signal(s) to filter; filtering is applied along the last axis.
    lowcut, highcut : float
        Pass-band edges in Hz.
    fs : float, optional
        Sampling rate in Hz (default 120).
    order : int, optional
        Design order passed to ``scipy.signal.butter`` (default 4).
    zero_phase : bool, optional
        If False (default), apply the filter once, forward only, with
        ``sosfilt``. This is causal and introduces a frequency-dependent
        phase delay, which shifts time-resolved power estimates. If True,
        use ``sosfiltfilt`` (forward-backward) for zero phase delay. Note
        that the magnitude response is then applied twice.

    Returns
    -------
    ndarray
        Filtered signal with the same shape as ``data``.
    """
    # Second-order sections are numerically more stable than (b, a) coefficients.
    sos = butter(order, [lowcut, highcut], btype='bandpass', fs=fs, output='sos')
    if zero_phase:
        return sosfiltfilt(sos, data, axis=-1)
    return sosfilt(sos, data, axis=-1)

def compute_stft_tfr(data, fs, fmin, fmax, n_fft, nperseg, noverlap):
    """
    Compute trial-averaged STFT power for every channel.

    Parameters
    ----------
    data : ndarray, shape (n_trials, n_channels, n_samples)
        Epoched signal; the STFT is taken along the last axis.
    fs : float
        Sampling rate in Hz.
    fmin, fmax : float
        Inclusive frequency range (Hz) kept in the output.
    n_fft : int
        FFT length per segment (passed as ``nfft`` to ``scipy.signal.stft``).
    nperseg : int
        Samples per STFT segment.
    noverlap : int
        Samples shared by consecutive segments.

    Returns
    -------
    power : ndarray, shape (n_channels, n_freqs, n_times)
        Mean over trials of ``abs(Zxx) ** 2``.
    freqs : ndarray, shape (n_freqs,)
        STFT frequencies (Hz) within [fmin, fmax].
    times : ndarray, shape (n_times,)
        Segment times (s) as returned by ``scipy.signal.stft``.

    Raises
    ------
    ValueError
        If ``data`` is not 3-D or contains no trials (nothing to average).
    """
    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError(
            f"data must be 3-D (n_trials, n_channels, n_samples), got shape {data.shape}"
        )
    if data.shape[0] == 0:
        raise ValueError("data contains no trials; cannot average power over trials")

    # stft transforms along the last axis of an N-D array, so all trials and
    # channels are handled in one call: Zxx is (n_trials, n_channels, n_freqs, n_times).
    freqs, times, Zxx = stft(data, fs=fs, nperseg=nperseg, noverlap=noverlap, nfft=n_fft, axis=-1)

    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    power = np.abs(Zxx[:, :, freq_mask, :]) ** 2
    # Average over trials -> (n_channels, n_freqs, n_times)
    return power.mean(axis=0), freqs[freq_mask], times

# Accumulators for per-subject, per-condition trial-averaged power.
all_subject_tfr = []     # NOTE(review): not used elsewhere in this notebook
all_subject_labels = []  # NOTE(review): not used elsewhere in this notebook
selected_indices = [picks_hubner.index(ch) for ch in selected_channels if ch in picks_hubner]
power_tfr_left = []
power_tfr_right = []
loaded_subjects = []  # subject IDs, in the same order as power_tfr_left/right
ch_map = {ch: i for i, ch in enumerate(ch_names)}

# Samples trimmed from both ends of every trial after filtering to drop the
# filter's edge transient (1 s at the configured sampling rate).
edge_samples = int(1.0 * sampling_rate)

for subject in subjects:
    # Locate the subject's preprocessed NPZ file
    file_dir = os.path.join(derivatives_dir, 'preprocessed', "alpha", f"sub-{subject}")
    file_path = os.path.join(file_dir, f"sub-{subject}_task-{task}_alpha.npz")

    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        continue

    # Channel positions in X are looked up in the montage minus this subject's
    # rejected channels.
    rejected = set(subjects_channel_reject.get(subject, []))
    picks_clean = [ch for ch in picks_hubner if ch not in rejected]

    # A rejected analysis channel would shift every later channel and break
    # the ch_map indexing used for plotting, so fail loudly instead.
    missing = [ch for ch in selected_channels if ch not in picks_clean]
    if missing:
        raise ValueError(
            f"{subject}: selected channel(s) {missing} are not available; "
            "channel order would no longer match ch_map"
        )

    # Context manager releases the archive's file handle after loading.
    with np.load(file_path) as npz_data:
        X = npz_data['X']  # (n_trials, n_channels, n_samples)
        y = npz_data['y']  # (n_trials,)

    selected_indices = [picks_clean.index(ch) for ch in selected_channels]
    data = X[:, selected_indices, :]
    data = bandpass_filter(data, 5, 15, fs=sampling_rate)
    data = data[:, :, edge_samples:-edge_samples]

    # Separate trials into left (y == 0) and right (y == 1)
    X_left = data[y == 0]   # (N_left_trials, n_channels, n_samples)
    X_right = data[y == 1]  # (N_right_trials, n_channels, n_samples)

    # Compute trial-averaged TFR for each condition
    power_l, freqs, times = compute_stft_tfr(X_left, sampling_rate, fmin, fmax, n_fft, nperseg, noverlap)
    power_r, freqs, times = compute_stft_tfr(X_right, sampling_rate, fmin, fmax, n_fft, nperseg, noverlap)

    # Store for later averaging across subjects
    power_tfr_left.append(power_l)  # (n_channels, n_freqs, n_times)
    power_tfr_right.append(power_r)
    loaded_subjects.append(subject)


# Stack over subjects
power_tfr_left_stack = np.stack(power_tfr_left, axis=0)   # (n_subj, n_channels, n_freqs, n_times)
power_tfr_right_stack = np.stack(power_tfr_right, axis=0) # (n_subj, n_channels, n_freqs, n_times)

# Exclude subjects from the grand average by ID, not by position: a missing
# file shifts positions, so a fixed index could drop the wrong subject.
# VPpdib shows massive power across frequencies at t=14 that pollutes the average.
subjects_excluded_from_average = ["VPpdib"]
keep = [i for i, s in enumerate(loaded_subjects) if s not in subjects_excluded_from_average]
power_tfr_left_stack = power_tfr_left_stack[keep]
power_tfr_right_stack = power_tfr_right_stack[keep]

# Grand average per condition and the condition contrast (right - left)
power_tfr_l = np.mean(power_tfr_left_stack, axis=0)   # (n_channels, n_freqs, n_times)
power_tfr_r = np.mean(power_tfr_right_stack, axis=0)  # (n_channels, n_freqs, n_times)
power_tfr_delta = power_tfr_r - power_tfr_l


# Visualize


# Plotting parameters
cmap = 'coolwarm'
# Data-driven symmetric colour limit. Not applied below (fixed limits are used
# so figures stay comparable), but ready for vmin, vmax = -abs_max, abs_max.
abs_max = max(abs(np.min(power_tfr_delta)), abs(np.max(power_tfr_delta)))
vmin, vmax = -4e-12, 4e-12

# Alpha band (Hz), used for band averaging and for the dashed guide lines.
alpha_min, alpha_max = 8, 12
alpha_mask = (freqs >= alpha_min) & (freqs <= alpha_max)
alpha_power = np.mean(power_tfr_delta[:, alpha_mask, :], axis=1)  # (n_channels, n_times)

# Fixed y-limits for the alpha time series (alternative: data min/max +/- 2e-12).
alpha_vmin, alpha_vmax = -4e-12, 4e-12

# Electrode pairs, one figure row per pair
electrode_pairs = [
    ('Oz', 'Pz'),
    ('P3', 'P4'),
    ('P7', 'P8'),
    ('O1', 'O2'),
]


def _plot_tfr_panel(fig, ax, tfr, title):
    """Draw one (n_freqs, n_times) R-L power map on ``ax`` with alpha-band guide lines."""
    pcm = ax.pcolormesh(times, freqs, tfr, shading='auto', cmap=cmap, vmin=vmin, vmax=vmax)
    fig.colorbar(pcm, ax=ax, label='Power Difference (R-L)')
    ax.set_title(title, fontsize=15)
    ax.set_yticks(np.linspace(6, 24, 10))
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    ax.axhline(y=alpha_min, color='gray', linestyle='--')
    ax.axhline(y=alpha_max, color='gray', linestyle='--')


fig = plt.figure(figsize=(24, 4 * len(electrode_pairs)))
fig.suptitle('Time-Frequency Power Difference (R - L) per Electrode', fontsize=16, fontweight='bold')
gs = gridspec.GridSpec(len(electrode_pairs), 3)

for i, (left_electrode, right_electrode) in enumerate(electrode_pairs):
    left_idx = ch_map[left_electrode]
    right_idx = ch_map[right_electrode]

    # Columns 0 and 1: full TFR difference for each electrode of the pair
    _plot_tfr_panel(fig, fig.add_subplot(gs[i, 0]), power_tfr_delta[left_idx], f'{left_electrode}')
    _plot_tfr_panel(fig, fig.add_subplot(gs[i, 1]), power_tfr_delta[right_idx], f'{right_electrode}')

    # Column 2: alpha-band-averaged difference over time for both electrodes
    ax = fig.add_subplot(gs[i, 2])
    ax.plot(times, alpha_power[left_idx], label=left_electrode, color='blue')
    ax.plot(times, alpha_power[right_idx], label=right_electrode, color='red')
    ax.set_ylim(alpha_vmin, alpha_vmax)
    # NOTE(review): x=0 and x=20 presumably mark trial boundaries in STFT
    # time; confirm against the trial length after the edge trim.
    ax.axvline(x=0, color='k', linestyle='--', alpha=0.5)
    ax.axvline(x=20, color='k', linestyle='--', alpha=0.5)
    ax.set_title(f'Difference Wave for {left_electrode} vs {right_electrode}', fontsize=15)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Alpha Power")
    ax.legend(loc='upper right')

fig.tight_layout(rect=[0, 0, 1, 0.94])
fig.subplots_adjust(top=0.92)

# Create the output folder itself (not only its parent) so savefig cannot
# fail on a missing directory.
plots_dir = join(alpha_dir, "plots")
os.makedirs(plots_dir, exist_ok=True)

fig.savefig(join(plots_dir, "TFR_alpha_lateralization.png"), dpi=300)
fig.savefig(join(plots_dir, "TFR_alpha_lateralization.svg"), dpi=300)
plt.close(fig)
