# Measuring Mutual Information (MI) between EEG & MEG
- Measures shared information between two signals (in the information theory sense)
- Tells us how strongly the EEG and MEG signals resemble each other, i.e. how much knowing one signal reduces uncertainty about the other.
- Formula: MI(X,Y) = Σ Σ P(x,y) * log₂(P(x,y)/(P(x)P(y)))



- Importing and loading the data.

# Import mean pooled data

# Sharding the data most probably messes up my electrode indices. Am I only saving the MAG channels in shards? TODO: investigate how they are saved!

In [1]:
import os
import torch
from pathlib import Path
import glob
from tqdm.notebook import tqdm
from dotenv import load_dotenv
import warnings

# Silence every warning for the rest of the session. The 'ignore' action is
# installed without a category or module filter, so it applies to all of them.
warnings.filterwarnings('ignore')

# load_dotenv() (python-dotenv) is expected to populate the process environment
# from a .env file; the dataset root is then read from DATASET_PATH.
# os.getenv returns None when the variable is not set.
load_dotenv()
DATASET_PATH = os.getenv('DATASET_PATH')

# ONLY IMPLEMENTS TRAINING DATA ATM
def load_and_concat_shards(dataset_path, mode='train'):
    """
    Read every EEG and MAG shard under ``dataset_path/mode`` and join them
    into one tensor per modality.

    Every immediate sub-directory of the mode directory is treated as one
    subject. Inside a subject, all ``*.pt`` files found in ``EEG_shards`` and
    ``MAG_shards`` are read with ``torch.load``; a missing shard folder is
    skipped. Subjects and shard files are both visited in sorted order.

    Args:
        dataset_path (str): Path to the dataset root directory
        mode (str): Name of the split sub-directory ('train', 'val' or 'test')

    Returns:
        tuple: (concatenated_eeg, concatenated_mag). Each entry is the loaded
        shards joined along dim 2, or None when no shard of that kind exists.
    """
    # Shard layout as described by the notebook author:
    #   (num_channels, 275, num_windows)
    #         ^         ^        ^
    #         |         |        └── windows (the axis shards are joined on)
    #         |         └── timepoints per window
    #         └── channels (EEG or MAG)

    def read_shards(folder):
        # A subject without this shard folder simply contributes nothing
        if not folder.exists():
            return []
        return [torch.load(shard_path) for shard_path in sorted(folder.glob("*.pt"))]

    split_dir = Path(dataset_path) / mode
    subjects = sorted(entry for entry in split_dir.iterdir() if entry.is_dir())

    print(f"Loading {mode} data from {len(subjects)} subjects...")

    eeg_parts = []
    mag_parts = []
    for subject in tqdm(subjects, desc="Loading subjects"):
        # Keep the per-subject order: EEG shards first, then MAG shards
        eeg_parts.extend(read_shards(subject / "EEG_shards"))
        mag_parts.extend(read_shards(subject / "MAG_shards"))

    # torch.cat cannot join an empty list, so report "nothing found" as None
    concatenated_eeg = torch.cat(eeg_parts, dim=2) if eeg_parts else None
    concatenated_mag = torch.cat(mag_parts, dim=2) if mag_parts else None

    return concatenated_eeg, concatenated_mag

# Load the training split and keep both modalities as module-level tensors
dataset_path = DATASET_PATH  # Adjust this path as needed
eeg_data, mag_data = load_and_concat_shards(dataset_path, mode='train')

# Report shape and dtype for each modality (or that nothing was found)
print("\nData shapes:")
if eeg_data is None:
    print("No EEG data found")
else:
    print(
        f"EEG data: {eeg_data.shape}",
        f"  - {eeg_data.shape[0]} channels",
        f"  - {eeg_data.shape[1]} timepoints per window",
        f"  - {eeg_data.shape[2]} total windows",
        f"  - dtype: {eeg_data.dtype}",
        sep="\n",
    )

if mag_data is None:
    print("No MAG data found")
else:
    print(
        f"\nMAG data: {mag_data.shape}",
        f"  - {mag_data.shape[0]} channels",
        f"  - {mag_data.shape[1]} timepoints per window",
        f"  - {mag_data.shape[2]} total windows",
        f"  - dtype: {mag_data.dtype}",
        sep="\n",
    )

# Basic statistics, computed on a float32 copy of each tensor
if eeg_data is not None:
    print(
        "\nEEG Statistics:",
        f"Mean: {eeg_data.float().mean():.3f}",
        f"Std: {eeg_data.float().std():.3f}",
        f"Min: {eeg_data.float().min():.3f}",
        f"Max: {eeg_data.float().max():.3f}",
        sep="\n",
    )

if mag_data is not None:
    print(
        "\nMAG Statistics:",
        f"Mean: {mag_data.float().mean():.3f}",
        f"Std: {mag_data.float().std():.3f}",
        f"Min: {mag_data.float().min():.3f}",
        f"Max: {mag_data.float().max():.3f}",
        sep="\n",
    )

Loading train data from 11 subjects...


Loading subjects:   0%|          | 0/11 [00:00<?, ?it/s]


Data shapes:
EEG data: torch.Size([74, 275, 25706])
  - 74 channels
  - 275 timepoints per window
  - 25706 total windows
  - dtype: torch.float16

MAG data: torch.Size([102, 275, 25706])
  - 102 channels
  - 275 timepoints per window
  - 25706 total windows
  - dtype: torch.float16

EEG Statistics:
Mean: 0.001
Std: 1.049
Min: -122.562
Max: 77.312

MAG Statistics:
Mean: -0.000
Std: 0.060
Min: -5.055
Max: 4.867


In [2]:
# Interactive viewer: raw EEG (channel 13) and MAG (channel 21) traces for one window
import ipywidgets as widgets
from ipywidgets import interact
import matplotlib.pyplot as plt

# Make sure matplotlib displays inline in Jupyter
%matplotlib inline

@interact(
    time_frames=widgets.IntSlider(
        min=1, max=275, step=1, value=275, description='Time Frames'
    ),
    window_index=widgets.IntSlider(
        min=0,
        max=0 if eeg_data is None else eeg_data.shape[2] - 1,
        step=1,
        value=0,
        description='Window',
    ),
)
def dynamic_plot(time_frames, window_index):
    """
    Plot the raw traces of EEG channel 13 and MAG channel 21 for one window.

    Both sliders are provided by ipywidgets: `time_frames` limits the plot to
    the first N samples of the window, `window_index` selects the window.
    A modality whose global tensor is None leaves its subplot empty.
    """
    fig, (eeg_ax, mag_ax) = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

    if eeg_data is not None:
        # .cpu() hands back the same tensor when it already lives on the CPU
        eeg_trace = eeg_data[13, :time_frames, window_index].cpu().numpy()
        eeg_ax.plot(eeg_trace, label='EEG Ch 13', color='red')
        eeg_ax.set_title(f'EEG channel 13 (Window {window_index})')
        eeg_ax.legend()

    if mag_data is not None:
        mag_trace = mag_data[21, :time_frames, window_index].cpu().numpy()
        mag_ax.plot(mag_trace, label='MAG Ch 21', color='blue')
        mag_ax.set_title(f'MAG channel 21 (Window {window_index})')
        mag_ax.legend()

    # The x axis is shared, so only the bottom subplot carries the label
    mag_ax.set_xlabel('Sample Index')
    plt.tight_layout()
    plt.show()

interactive(children=(IntSlider(value=275, description='Time Frames', max=275, min=1), IntSlider(value=0, desc…

In [3]:
import pywt
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import clear_output

def compute_wavelet_transform(data, sampling_rate=250, frequencies=None, wavelet='cmor1.5-1.0'):
    """
    Compute the continuous wavelet transform (CWT) of a signal.

    Args:
        data: 1D array of signal values
        sampling_rate: sampling rate of the signal in Hz (default 250)
        frequencies: frequencies in Hz to analyze, given as any array-like
            (list, tuple or ndarray). Default: 50 log-spaced values from
            1 Hz to 100 Hz.
        wavelet: PyWavelets wavelet name (default: complex Morlet 'cmor1.5-1.0')

    Returns:
        tuple: (wavelet_coeffs, frequencies). `wavelet_coeffs` is the
        coefficient array returned by `pywt.cwt` (complex for the default
        wavelet; callers in this notebook plot it as frequency x time).
        `frequencies` is the float ndarray of analyzed frequencies in Hz.
    """
    if frequencies is None:
        frequencies = np.logspace(np.log10(1), np.log10(100), num=50)
    else:
        # Lists/tuples do not support the element-wise division below
        frequencies = np.asarray(frequencies, dtype=float)

    # frequency2scale expects frequencies normalized by the sampling rate
    scales = pywt.frequency2scale(wavelet, frequencies / sampling_rate)
    # The second return value (pywt's own frequency estimate) is not needed:
    # the requested frequencies in Hz are returned instead
    coefficients, _ = pywt.cwt(data, scales, wavelet)

    return coefficients, frequencies

def plot_wavelet_visualizations(eeg_data, mag_data, window_idx=0, eeg_channel=13, mag_channel=21, cmap='inferno', sampling_rate=250):
    """
    Draw three wavelet views (scalogram, 3D surface, filled contour) for one
    EEG channel and one MEG channel, side by side, for a single window.

    Args:
        eeg_data: tensor indexed as [channel, timepoint, window]
        mag_data: tensor indexed as [channel, timepoint, window]
        window_idx: index of the window to analyze
        eeg_channel: EEG channel index (default 13)
        mag_channel: MAG channel index (default 21)
        cmap: Matplotlib colormap name used for every panel
        sampling_rate: sampling rate in Hz, used for both the wavelet
            transform and the time axis (default 250)

    Returns:
        The Matplotlib figure holding the 3x2 grid of panels
        (left column EEG, right column MEG).
    """
    # Clear previous plots
    plt.close('all')

    # Extract the two 1D signals for this window
    eeg_signal = eeg_data[eeg_channel, :, window_idx].cpu().numpy()
    mag_signal = mag_data[mag_channel, :, window_idx].cpu().numpy()

    # Both transforms use the same default frequency grid, so `freqs` applies to both
    eeg_coeffs, freqs = compute_wavelet_transform(eeg_signal, sampling_rate=sampling_rate)
    mag_coeffs, _ = compute_wavelet_transform(mag_signal, sampling_rate=sampling_rate)

    time = np.arange(len(eeg_signal)) / sampling_rate

    fig = plt.figure(figsize=(15, 15))

    # Matplotlib's 3D axes only handle linear scales (a log scale there gives
    # nonsensical results), so the surfaces are drawn over log10(frequency)
    # and the ticks are relabelled with the frequencies they stand for.
    time_mesh, freq_mesh = np.meshgrid(time, freqs)
    log_freq_mesh = np.log10(freq_mesh)
    decade_ticks = np.arange(
        np.floor(log_freq_mesh.min()), np.ceil(log_freq_mesh.max()) + 1
    )
    decade_labels = [f'{10.0 ** exponent:g}' for exponent in decade_ticks]

    panels = (('EEG', eeg_coeffs), ('MEG', mag_coeffs))
    for column, (name, coeffs) in enumerate(panels, start=1):
        magnitude = np.abs(coeffs)

        # 1. Standard scalogram (heat map), row 1
        ax = fig.add_subplot(3, 2, column)
        mesh = ax.pcolormesh(time, freqs, magnitude, shading='gouraud', cmap=cmap)
        ax.set_ylabel('Frequency (Hz)')
        ax.set_yscale('log')
        ax.set_title(f'{name} Scalogram')
        fig.colorbar(mesh, ax=ax)

        # 2. 3D surface, row 2
        ax = fig.add_subplot(3, 2, column + 2, projection='3d')
        ax.plot_surface(time_mesh, log_freq_mesh, magnitude, cmap=cmap, alpha=0.8)
        ax.set_yticks(decade_ticks)
        ax.set_yticklabels(decade_labels)
        ax.set_ylabel('Frequency (Hz)')
        ax.set_title(f'{name} 3D Wavelet')

        # 3. Filled contour plot, row 3
        ax = fig.add_subplot(3, 2, column + 4)
        contour = ax.contourf(time, freqs, magnitude, levels=20, cmap=cmap)
        ax.set_ylabel('Frequency (Hz)')
        ax.set_yscale('log')
        ax.set_title(f'{name} Contour')
        fig.colorbar(contour, ax=ax)

    plt.tight_layout()
    return fig
@interact(
    window_index=widgets.IntSlider(
        min=0, max=eeg_data.shape[2] - 1, step=1, value=0, description='Window'
    ),
    visualization_type=widgets.Dropdown(
        options=['3d', 'standard', 'contour', 'all'],
        value='3d',  # the 3D view is the default
        description='View Type:',
    ),
)
def plot_interactive_wavelet_viz(window_index, visualization_type='3d'):
    """
    Interactive wavelet viewer for EEG channel 13 and MAG channel 21.

    `window_index` picks the window; `visualization_type` picks the view:
    '3d' (surface), 'standard' (scalogram heat map), 'contour' (filled
    contours) or 'all' (delegates to plot_wavelet_visualizations).
    Every view uses the 'RdYlBu_r' colormap (blue for low, red for high).
    """
    clear_output(wait=True)

    # Signals and wavelet coefficients for the selected window
    eeg_signal = eeg_data[13, :, window_index].cpu().numpy()
    mag_signal = mag_data[21, :, window_index].cpu().numpy()
    eeg_coeffs, freqs = compute_wavelet_transform(eeg_signal)
    mag_coeffs, _ = compute_wavelet_transform(mag_signal)
    time = np.arange(len(eeg_signal)) / 250

    colormap = 'RdYlBu_r'
    # (subplot position, title prefix, coefficients): EEG left, MEG right
    panels = ((121, 'EEG', eeg_coeffs), (122, 'MEG', mag_coeffs))

    if visualization_type == '3d':
        # Larger figure for better 3D visualization
        plt.close('all')
        fig = plt.figure(figsize=(20, 8))
        time_mesh, freq_mesh = np.meshgrid(time, freqs)

        for position, label, coeffs in panels:
            ax = fig.add_subplot(position, projection='3d')
            surface = ax.plot_surface(
                time_mesh, freq_mesh, np.abs(coeffs),
                cmap=colormap, alpha=0.8,
                rstride=1, cstride=1,  # stride of 1 for a smoother surface
                linewidth=0, antialiased=True,
            )
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Frequency (Hz)')
            ax.set_zlabel('Magnitude')
            ax.set_title(f'{label} Wavelet - 3D')
            ax.view_init(elev=30, azim=45)
            fig.colorbar(surface, ax=ax, shrink=0.5, aspect=5)

        plt.tight_layout(w_pad=5)  # extra spacing between the two surfaces

    elif visualization_type == 'standard':
        plt.close('all')
        plt.figure(figsize=(14, 6))

        for position, label, coeffs in panels:
            ax = plt.subplot(position)
            heat_map = ax.pcolormesh(
                time, freqs, np.abs(coeffs), shading='gouraud', cmap=colormap
            )
            ax.set_title(f'{label} Wavelet - Standard')
            ax.set_yscale('log')
            plt.colorbar(heat_map, ax=ax)
        plt.tight_layout()

    elif visualization_type == 'contour':
        plt.close('all')
        plt.figure(figsize=(14, 6))

        for position, label, coeffs in panels:
            ax = plt.subplot(position)
            filled = ax.contourf(
                time, freqs, np.abs(coeffs), levels=20, cmap=colormap
            )
            ax.set_yscale('log')
            ax.set_title(f'{label} Wavelet - Contour')
            plt.colorbar(filled, ax=ax)
        plt.tight_layout()

    elif visualization_type == 'all':
        # The helper closes old figures and builds the full 3x2 grid itself
        plot_wavelet_visualizations(
            eeg_data, mag_data, window_idx=window_index, cmap=colormap
        )

    plt.show()

interactive(children=(IntSlider(value=0, description='Window', max=25705), Dropdown(description='View Type:', …

# Applying Fast Fourier Transform

In [4]:
import ipywidgets as widgets
from ipywidgets import interact
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

def bandpass_filter(signal, sf, low_freq, high_freq, axis=-1):
    """
    Band-pass a signal by zeroing FFT bins outside [low_freq, high_freq].

    This is a hard (brick-wall) cut in the frequency domain: every real-FFT
    bin whose frequency lies outside the inclusive band is set to zero and
    the signal is rebuilt with the inverse real FFT.

    Args:
        signal: array-like of samples; 1D, or N-D with time along `axis`
        sf: sampling frequency in Hz
        low_freq: lower band edge in Hz (inclusive)
        high_freq: upper band edge in Hz (inclusive)
        axis: axis holding the time samples (default: last axis), so a whole
            stack of signals can be filtered in one call

    Returns:
        ndarray with the same shape as `signal` holding the filtered samples.

    Raises:
        ValueError: if `signal` has no samples along `axis`.
    """
    samples = np.asarray(signal)
    n = samples.shape[axis]
    if n == 0:
        raise ValueError("signal must contain at least one sample along the filtered axis")

    # Frequency of each real-FFT bin (0 .. Nyquist)
    freqs = np.fft.rfftfreq(n, 1.0 / sf)
    spectrum = np.fft.rfft(samples, axis=axis)

    # Zero every bin outside the pass band, for all signals at once
    outside_band = ~((freqs >= low_freq) & (freqs <= high_freq))
    selector = [slice(None)] * spectrum.ndim
    selector[axis] = outside_band
    spectrum[tuple(selector)] = 0

    # Passing n restores the original length (needed for odd sample counts)
    return np.fft.irfft(spectrum, n=n, axis=axis)

@interact(
    time_frames=widgets.IntSlider(
        min=1, max=275, step=1, value=250, description='Time Frames'
    ),
    window_index=widgets.IntSlider(
        min=0,
        max=0 if eeg_data is None else eeg_data.shape[2] - 1,
        step=1,
        value=0,
        description='Window',
    ),
)
def dynamic_six_subplots(time_frames, window_index):
    """
    Plot band-limited versions of EEG channel 13 and MAG channel 21, one
    subplot per band and modality (six stacked subplots in total):

    1) EEG Alpha   2) EEG Beta   3) EEG Gamma
    4) MAG Alpha   5) MAG Beta   6) MAG Gamma

    Bands: Alpha 8-12 Hz, Beta 13-30 Hz, Gamma 31-80 Hz.

    `time_frames` (first N samples of the window) and `window_index` are
    controlled by sliders. A modality whose global tensor is None leaves its
    three subplots empty.
    """
    sf = 250
    # (band name, low edge in Hz, high edge in Hz, line colour)
    bands = (
        ('Alpha', 8, 12, 'blue'),
        ('Beta', 13, 30, 'orange'),
        ('Gamma', 31, 80, 'green'),
    )
    fig, axes = plt.subplots(6, 1, figsize=(10, 14), sharex=True)

    def draw_modality(data, channel, tag, first_row):
        # .cpu() hands back the same tensor when it already lives on the CPU
        trace = data[channel, :time_frames, window_index].cpu().numpy()
        for row, (band, low, high, colour) in enumerate(bands, start=first_row):
            axes[row].plot(bandpass_filter(trace, sf, low, high), color=colour)
            axes[row].set_ylabel(f'{tag} {band}\n({low}-{high} Hz)')
        # Only the first subplot of each modality carries a title
        axes[first_row].set_title(f'{tag} (Ch {channel}) - Window {window_index}')

    if eeg_data is not None:
        draw_modality(eeg_data, 13, 'EEG', 0)

    if mag_data is not None:
        draw_modality(mag_data, 21, 'MAG', 3)

    axes[-1].set_xlabel('Sample Index')
    plt.tight_layout()
    plt.show()

interactive(children=(IntSlider(value=250, description='Time Frames', max=275, min=1), IntSlider(value=0, desc…

# Compute MI

In [None]:
# MI between the EEG (channel 13) and MAG (channel 21) signals across ALL windows.

import numpy as np
from sklearn.feature_selection import mutual_info_regression

def bandpass_filter(signal, sf, low_freq, high_freq, axis=-1):
    """
    Band-pass a signal by zeroing FFT bins outside [low_freq, high_freq].

    This is a hard (brick-wall) cut in the frequency domain: every real-FFT
    bin whose frequency lies outside the inclusive band is set to zero and
    the signal is rebuilt with the inverse real FFT.

    Args:
        signal: array-like of samples; 1D, or N-D with time along `axis`
        sf: sampling frequency in Hz
        low_freq: lower band edge in Hz (inclusive)
        high_freq: upper band edge in Hz (inclusive)
        axis: axis holding the time samples (default: last axis), so a whole
            stack of signals can be filtered in one call

    Returns:
        ndarray with the same shape as `signal` holding the filtered samples.

    Raises:
        ValueError: if `signal` has no samples along `axis`.
    """
    samples = np.asarray(signal)
    n = samples.shape[axis]
    if n == 0:
        raise ValueError("signal must contain at least one sample along the filtered axis")

    # Frequency of each real-FFT bin (0 .. Nyquist)
    freqs = np.fft.rfftfreq(n, 1.0 / sf)
    spectrum = np.fft.rfft(samples, axis=axis)

    # Zero every bin outside the pass band, for all signals at once
    outside_band = ~((freqs >= low_freq) & (freqs <= high_freq))
    selector = [slice(None)] * spectrum.ndim
    selector[axis] = outside_band
    spectrum[tuple(selector)] = 0

    # Passing n restores the original length (needed for odd sample counts)
    return np.fft.irfft(spectrum, n=n, axis=axis)

def compute_alpha_mutual_information(eeg_data, mag_data, eeg_channel=13, mag_channel=21, alpha_range=(8, 12), sf=250):
    """
    Estimate the mutual information between the band-limited signals of one
    EEG channel and one MAG channel, pooled over ALL windows.

    Steps:
    1) Band-pass every window of both channels to
       [alpha_range[0], alpha_range[1]] with `bandpass_filter`.
    2) Concatenate the filtered windows into one long 1D array per modality.
    3) Estimate MI with sklearn's `mutual_info_regression` (EEG as the single
       feature, MAG as the target, random_state=42).

    Args:
        eeg_data: torch tensor or array-like indexed as [channel, timepoint, window]
        mag_data: torch tensor or array-like indexed as [channel, timepoint, window]
        eeg_channel: EEG channel index (default 13)
        mag_channel: MAG channel index (default 21)
        alpha_range: (low, high) band edges in Hz (default (8, 12))
        sf: sampling frequency in Hz (default 250)

    Returns:
        The estimated MI as a scalar. Per the scikit-learn documentation the
        estimate is in nats, not bits; divide by ln(2) to compare with a
        log2-based formula.

    Raises:
        ValueError: if the two selected channels do not have the same
            (timepoints, windows) shape, or contain no windows.
    """
    low_freq, high_freq = alpha_range

    def _channel_as_numpy(data, channel):
        # Torch tensors (CPU or GPU) expose .cpu(); anything else is treated
        # as array-like, so plain NumPy input works as well
        selected = data[channel]
        if hasattr(selected, 'cpu'):
            return selected.cpu().numpy()
        return np.asarray(selected)

    # Convert each channel once instead of once per window
    eeg_windows = _channel_as_numpy(eeg_data, eeg_channel)
    mag_windows = _channel_as_numpy(mag_data, mag_channel)

    # MI needs paired samples: both channels must cover the same windows
    if eeg_windows.shape != mag_windows.shape:
        raise ValueError(
            "EEG and MAG channel data must have the same (timepoints, windows) "
            f"shape, got {eeg_windows.shape} and {mag_windows.shape}"
        )
    if eeg_windows.ndim != 2 or eeg_windows.shape[1] == 0:
        raise ValueError("expected at least one window of (timepoints, windows) data")

    # Transposing iterates window by window; each window is filtered on its own
    # so the FFT never mixes samples from neighbouring windows
    alpha_eeg_all = np.concatenate(
        [bandpass_filter(window, sf, low_freq, high_freq) for window in eeg_windows.T]
    )
    alpha_mag_all = np.concatenate(
        [bandpass_filter(window, sf, low_freq, high_freq) for window in mag_windows.T]
    )

    # sklearn expects X: (n_samples, n_features) and y: (n_samples,)
    # NOTE(review): every sample of every window is used; with many windows
    # this estimate may take a very long time - consider subsampling.
    mi_value = mutual_info_regression(
        alpha_eeg_all.reshape(-1, 1), alpha_mag_all, random_state=42
    )
    # One MI value is returned per feature column; there is only one column
    return mi_value[0]

# Estimate the alpha-band MI for EEG channel 13 vs MAG channel 21 on the
# tensors loaded from the 'train' split, then print it.
# NOTE(review): mutual information as defined at the top of this notebook is
# symmetric, so "EEG->MAG" in the label names the pair, not a direction.
alpha_mi = compute_alpha_mutual_information(eeg_data, mag_data, eeg_channel=13, mag_channel=21)
print("Alpha band Mutual Information (EEG->MAG), channel 13->21: ", alpha_mi)