In [1]:
import os
import numpy as np
import mne
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import warnings
warnings.filterwarnings('ignore')

# Constructing an Interactive Heatmap

In [3]:
#!/usr/bin/env python
# coding: utf-8

# Reading only the channel geometry from sub-03/run_02.fif
def get_info_sub03_run02(path="/srv/openfmri/val/sub-03/run_02.fif"):
    """
    Load only the measurement ``Info`` (channel geometry) of a FIF recording.

    The file is opened with ``preload=False`` so no sample data is read into
    memory; the Raw handle is closed once its ``info`` has been taken.

    Parameters
    ----------
    path : str
        FIF file to read. Defaults to sub-03/run_02.fif, the run whose
        EEG+MAG digitization is used for plotting.

    Returns
    -------
    mne.Info or None
        The recording's info, or None when the file is missing or cannot be
        read (a warning is printed in both cases).
    """
    missing = not os.path.exists(path)
    if missing:
        print(f"[Warning] Info file not found: {path}")
        return None

    try:
        recording = mne.io.read_raw_fif(path, preload=False, verbose=False)
        measurement_info = recording.info
        recording.close()
    except Exception as err:
        # Best-effort: any read failure is reported and turned into None.
        print(f"[Warning] Could not read info from {path}, reason: {err}")
        return None

    print(f"[Info] Successfully read channel geometry from: {path}")
    return measurement_info



# Gather data from all .fif files under /val/ EXCEPT sub-03/run_02.fif.
def gather_val_data_excluding_sub03run02(dataset_path="/srv/openfmri"):
    """
    Concatenate EEG and magnetometer (MAG) data from every run under
    ``<dataset_path>/val``, excluding ``sub-03/run_02.fif``.

    Subjects and runs are visited in sorted order so that the concatenation
    order, and therefore the meaning of each time index, is reproducible
    (``os.listdir`` returns entries in arbitrary order).

    Runs are appended along the time axis (axis=1). A run is skipped with a
    warning when:
      * it cannot be read, or
      * its EEG or MAG channel count differs from that of the first run that
        contributed that channel type (such data cannot be stacked).
    The whole run is skipped, not just one modality, so the EEG and MAG time
    axes stay aligned for the runs that are kept.

    Parameters
    ----------
    dataset_path : str
        Root directory containing the ``val`` folder.

    Returns
    -------
    (EEG_data_combined, MAG_data_combined)
        Arrays of shape [num_eeg_channels, total_time_points] and
        [num_mag_channels, total_time_points]. Either entry is None when no
        kept run had that channel type; (None, None) if ``val`` is missing or
        nothing was found.
    """
    val_path = os.path.join(dataset_path, "val")
    if not os.path.isdir(val_path):
        print(f"[Warning] val path not found: {val_path}")
        return None, None

    # Collect per-run chunks and concatenate once at the end; hstack-ing onto
    # a growing array would re-copy all previous data for every run.
    eeg_chunks = []
    mag_chunks = []

    for subject in sorted(os.listdir(val_path)):
        subject_path = os.path.join(val_path, subject)
        if not os.path.isdir(subject_path):
            continue

        fif_files = sorted(f for f in os.listdir(subject_path) if f.endswith(".fif"))
        for fif_file in fif_files:
            # Skip sub-03/run_02.fif because it is corrupted
            # (we only want its 'info', not its data).
            if subject == "sub-03" and fif_file == "run_02.fif":
                continue

            fif_path = os.path.join(subject_path, fif_file)
            try:
                raw = mne.io.read_raw_fif(fif_path, preload=True, verbose=False)
            except Exception as e:
                print(f"[Warning] Could not read {fif_path}, skipping. Error: {e}")
                continue

            data = raw.get_data()  # shape [n_channels, n_times]

            # Separate EEG vs MAG
            eeg_picks = mne.pick_types(raw.info, meg=False, eeg=True)
            mag_picks = mne.pick_types(raw.info, meg='mag', eeg=False)
            n_eeg, n_mag = len(eeg_picks), len(mag_picks)

            # Rows must line up across runs; a different channel count would
            # make the final hstack fail, so drop this run instead.
            eeg_mismatch = n_eeg > 0 and len(eeg_chunks) > 0 and n_eeg != eeg_chunks[0].shape[0]
            mag_mismatch = n_mag > 0 and len(mag_chunks) > 0 and n_mag != mag_chunks[0].shape[0]
            if eeg_mismatch or mag_mismatch:
                print(f"[Warning] Channel count mismatch in {fif_path} "
                      f"(EEG={n_eeg}, MAG={n_mag}), skipping.")
                continue

            if n_eeg > 0:
                eeg_chunks.append(data[eeg_picks, :])
            if n_mag > 0:
                mag_chunks.append(data[mag_picks, :])

    EEG_data_combined = np.hstack(eeg_chunks) if eeg_chunks else None
    MAG_data_combined = np.hstack(mag_chunks) if mag_chunks else None
    return EEG_data_combined, MAG_data_combined



# A helper function to z-score each channel individually
def zscore_channels(data, rtol=1e-10):
    """
    Z-score each channel (row) independently over time.

    Parameters
    ----------
    data : array-like of shape [n_channels, n_timepoints], or None
    rtol : float, optional
        A channel is treated as constant when its standard deviation is at
        most ``rtol * max(|channel|)``. The test is relative to each channel's
        own magnitude because MNE returns data in SI units. Magnetometer
        signals in tesla can have stds around 1e-13..1e-12, so an absolute
        epsilon such as 1e-12 would wrongly flag real MEG channels as constant
        and leave them unscaled.

    Returns
    -------
    ndarray or None
        Same shape as ``data``. Each non-constant channel has mean 0 and std 1.
        Constant channels are only mean-subtracted, so they come out as
        (approximately) 0. Returns None if ``data`` is None.
    """
    if data is None:
        return None
    data = np.asarray(data)

    # compute means and stds over time (axis=1)
    means = data.mean(axis=1, keepdims=True)
    stds = data.std(axis=1, keepdims=True)

    # Avoid division by ~zero for constant channels, using a scale-aware
    # threshold. initial=0 keeps this valid for an empty time axis.
    scale = np.max(np.abs(data), axis=1, keepdims=True, initial=0)
    stds[stds <= rtol * scale] = 1.0

    return (data - means) / stds



# Build a "position array" matching exactly the channels we have
def get_pos_for_picks(info, pick_type='eeg'):
    """
    Return 2-D sensor positions and names for the EEG or MAG channels in ``info``.

    Steps:
      1. ``mne.pick_types`` selects the requested channels: ``eeg=True`` for
         'eeg'; any other ``pick_type`` is treated as ``meg='mag'``.
      2. ``mne.find_layout(info, ch_type='eeg' | 'meg')`` builds a layout,
         which may hold more sensors than were picked (e.g. 306 if Elekta).
      3. Each picked channel is looked up in the layout by name, in pick order.

    Parameters
    ----------
    info : mne.Info
    pick_type : str
        'eeg' for EEG channels; anything else selects magnetometers.

    Returns
    -------
    (pos_array, pick_ch_names)
        pos_array has shape (N, 2): the first two columns of ``layout.pos``
        for each picked channel. Picked channels missing from the layout get
        (0.0, 0.0) so the rows stay aligned with the data, and a warning is
        printed. Returns (None, []) when no channels match, the layout cannot
        be built (RuntimeError), or the layout is None.
    """
    if pick_type == 'eeg':
        picks_idx = mne.pick_types(info, meg=False, eeg=True)
        ch_layout_type = 'eeg'
    else:
        # pick_type == 'mag' (any non-'eeg' value lands here)
        picks_idx = mne.pick_types(info, meg='mag', eeg=False)
        ch_layout_type = 'meg'

    if len(picks_idx) == 0:
        print(f"[Note] No {pick_type} channels found in info.")
        return None, []

    pick_ch_names = [info.ch_names[i] for i in picks_idx]

    try:
        layout = mne.find_layout(info, ch_type=ch_layout_type)
    except RuntimeError as e:
        print(f"[Warning] Could not create '{pick_type}' layout due to: {e}")
        return None, []

    if layout is None:
        print(f"[Warning] {pick_type} layout came back None.")
        return None, []

    # Name -> row in layout.pos, built once. The first occurrence wins,
    # matching list.index semantics.
    row_of = {}
    for row, name in enumerate(layout.names):
        row_of.setdefault(name, row)

    pos_list = []
    missing = []
    for ch_name in pick_ch_names:
        row = row_of.get(ch_name)
        if row is None:
            missing.append(ch_name)
            pos_list.append((0.0, 0.0))  # fallback for missing channels
        else:
            pos_list.append((layout.pos[row, 0], layout.pos[row, 1]))

    if missing:
        # These sensors are drawn at the layout origin; make that visible
        # instead of silently distorting the topomap.
        print(f"[Warning] {len(missing)} {pick_type} channel(s) not found in the "
              f"layout were placed at (0.0, 0.0): {missing[:10]}")

    pos_array = np.array(pos_list, dtype=float)
    return pos_array, pick_ch_names



# 5) An interactive widget to topomap EEG & MAG side by side 
def _draw_topomap_panel(ax, data, pos, label, frame_idx):
    """Draw one modality's topomap at ``frame_idx`` on ``ax``, or blank it with a reason title."""
    if data is None or data.size == 0 or pos is None:
        ax.axis('off')
        ax.set_title(f"No {label} data or layout.")
        return
    if data.shape[0] != len(pos):
        ax.axis('off')
        ax.set_title(f"Mismatch: {label} data vs. layout length.")
        return
    if frame_idx >= data.shape[1]:
        # The slider spans the longer modality; this one has no sample here.
        ax.axis('off')
        ax.set_title(f"No {label} data at frame={frame_idx}.")
        return
    mne.viz.plot_topomap(data[:, frame_idx], pos,
                         axes=ax, show=False, cmap='RdBu_r')
    ax.set_title(f"{label} (frame={frame_idx})")


def interactive_brain_heatmap(EEG_data, MAG_data, info):
    """
    Show EEG and MAG topomaps side by side, driven by a single time slider.

    Only basic ``plot_topomap`` arguments are used, for compatibility with
    older MNE versions. Sensor positions come from ``info`` via
    ``get_pos_for_picks``. The data rows are only checked against that layout
    by count, so they are assumed to be in the same channel order as ``info``.

    Parameters
    ----------
    EEG_data, MAG_data : ndarray of shape [n_channels, n_times], or None
        The two modalities may have different numbers of time points. The
        slider covers the longer one, and the shorter panel is blanked past
        its end.
    info : mne.Info or None
        Measurement info providing channel geometry. Nothing is plotted if None.

    Returns
    -------
    None
        Displays an ipywidgets slider plus output area. Prints a message and
        returns early when there is nothing to plot.
    """
    if info is None:
        print("[Error] No valid info. Cannot plot topomaps.")
        return

    has_eeg = EEG_data is not None and EEG_data.size > 0
    has_mag = MAG_data is not None and MAG_data.size > 0
    if not has_eeg and not has_mag:
        print("[Warning] No EEG or MAG data loaded. Nothing to plot.")
        return

    pos_eeg, _ = get_pos_for_picks(info, pick_type='eeg')
    pos_mag, _ = get_pos_for_picks(info, pick_type='mag')

    if pos_eeg is None and pos_mag is None:
        print("[Warning] No layout for either EEG or MAG. Nothing to plot.")
        return

    def update_plot(frame_idx):
        plt.close('all')
        _, axs = plt.subplots(1, 2, figsize=(10, 4))
        _draw_topomap_panel(axs[0], EEG_data, pos_eeg, "EEG", frame_idx)
        _draw_topomap_panel(axs[1], MAG_data, pos_mag, "MAG", frame_idx)
        plt.tight_layout()
        plt.show()

    # The slider must reach the last frame of the longer modality.
    max_frames = max(EEG_data.shape[1] if has_eeg else 0,
                     MAG_data.shape[1] if has_mag else 0)

    if max_frames == 0:
        print("[Warning] No valid data frames to plot.")
        return

    slider = widgets.IntSlider(
        value=0,
        min=0,
        max=max_frames - 1,
        step=1,
        description='Time frame',
        continuous_update=False,
        layout=widgets.Layout(width='70%')
    )

    ui = widgets.VBox([slider])
    out = widgets.interactive_output(update_plot, {'frame_idx': slider})
    display(ui, out)

if __name__ == "__main__":
    # Single source of truth for data locations: the geometry file and the
    # runs gathered below must come from the same dataset root.
    DATASET_ROOT = "/srv/openfmri"
    INFO_FIF_PATH = os.path.join(DATASET_ROOT, "val", "sub-03", "run_02.fif")

    # 1) Extract just the 'info' from sub-03/run_02.fif for EEG+MAG digitization
    info_example = get_info_sub03_run02(INFO_FIF_PATH)
    if info_example is None:
        # interactive_brain_heatmap returns early without info, so neither panel is drawn.
        print("[Warning] Could not load 'info' from sub-03 run_02. EEG/MAG topomaps won't be feasible.")

    # 2) Gather actual data from all other runs in 'val', skipping sub-03/run_02
    EEG_data_combined, MAG_data_combined = gather_val_data_excluding_sub03run02(DATASET_ROOT)

    # 3) Z-score each channel's data so that the color range is comparable.
    #    zscore_channels passes None through, so no guard is needed.
    EEG_data_combined = zscore_channels(EEG_data_combined)
    MAG_data_combined = zscore_channels(MAG_data_combined)

    # 4) Show the interactive topomap with the normalized data
    interactive_brain_heatmap(EEG_data_combined, MAG_data_combined, info_example)

[Info] Successfully read channel geometry from: /srv/openfmri/val/sub-03/run_02.fif


VBox(children=(IntSlider(value=0, continuous_update=False, description='Time frame', layout=Layout(width='70%'…

Output()