In [6]:
import os
import mne
import numpy as np
import warnings
from mne.preprocessing import create_eog_epochs
from mne import set_bipolar_reference
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

from utils.config import DATASETS, set_plot_style, PLOTS_PATH
from utils.helpers import iterate_dataset_items
from utils.file_io import load_ica, save_ica_excluded_components

%matplotlib qt

warnings.filterwarnings("ignore")  # Optional: suppress verbose MNE warnings
set_plot_style()

VERBOSE = True
SHOW_PLOTS = False  # Set False to avoid popups

z_threshold_blink = 5.0
z_threshold_saccade = 5.0

# Datasets
#DATASETS.pop('braboszcz2017')
#DATASETS.pop('jin2019', None)

def epochs_to_raw(epochs):
    """Convert MNE Epochs to a Raw object by stitching epochs together."""
    info = epochs.info.copy()

    # Get the epochs data
    data = epochs.get_data()  # shape (n_epochs, n_channels, n_times)

    # Concatenate the epochs along time axis
    data = np.concatenate(data, axis=1)  # now (n_channels, total_timepoints)

    # Create Raw object
    raw = mne.io.RawArray(data, info)

    return raw

def _lighter_red(iteration, max_iterations=3):
    """Return lighter shades of red based on iteration."""
    base_color = np.array([1.0, 0.0, 0.0])  # Pure red (RGB)
    white = np.array([1.0, 1.0, 1.0])  # White
    factor = min(iteration / max_iterations, 1.0)  # Gradually move toward white
    color = base_color + (white - base_color) * factor
    return tuple(color)

def plot_ica_component_scores(
    scores,
    z_threshold,
    title,
    save_dir,
    filename,
    figsize=(12, 4),
    ):
    """
    Plot ICA component scores with iterative adaptive threshold visualization.

    Parameters
    ----------
    scores : np.ndarray
        Correlation scores for ICA components.
    z_threshold : float
        Z-score threshold for exclusion.
    title : str
        Title of the plot.
    save_dir : str
        Directory to save the figure.
    filename : str
        Filename for saving the figure.
    color_cycle : list of str, optional
        List of colors to use for different exclusion rounds.
    figsize : tuple of float, optional
        Size of the figure.
    """

    if scores is None or len(scores) == 0:
        print("No scores to plot.")
        return

    remaining_scores = scores.copy()
    n_components = len(scores)

    fig, ax = plt.subplots(figsize=figsize)

    excluded_components = []
    excluded_colors = []

    iteration = 0
    while True:
        if np.all(np.isnan(remaining_scores)):
            break  # No more scores left

        # Recompute z-scores excluding NaNs
        mean_now = np.nanmean(remaining_scores)
        std_now = np.nanstd(remaining_scores)
        z_scores_iter = (remaining_scores - mean_now) / std_now

        max_z = np.nanmax(np.abs(z_scores_iter))
        if max_z > z_threshold:
            idx = np.nanargmax(np.abs(z_scores_iter))
            excluded_components.append(idx)
            excluded_colors.append(_lighter_red(iteration))

            # Mark as excluded
            remaining_scores[idx] = np.nan

            # Plot threshold lines for this iteration
            thresh_value = z_threshold * std_now
            ax.axhline(y=mean_now + thresh_value, linestyle='--', color=_lighter_red(iteration), alpha=0.6, linewidth=0.8)
            ax.axhline(y=mean_now - thresh_value, linestyle='--', color=_lighter_red(iteration), alpha=0.6, linewidth=0.8)

            iteration += 1
        else:
            break

    # Final bar plot
    bars = ax.bar(np.arange(n_components), scores, color='gray', edgecolor='k')

    for idx, color in zip(excluded_components, excluded_colors):
        bars[idx].set_color(color)
        bars[idx].set_edgecolor('black')

    ax.set_xlabel('ICA components')
    ax.set_ylabel('Score (correlation)')
    ax.set_title(title)
    ax.axhline(y=0, color='black', linewidth=0.8)

    # --- Custom grey proxy artist for legend ---
    proxy_line = Line2D(
        [0], [0],
        color='gray',
        linestyle='--',
        linewidth=1,
        alpha=0.7
    )
    ax.legend([proxy_line], [f'Adaptive threshold: ±{z_threshold} z-score'], loc='upper right')

    fig.tight_layout()

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path)
    plt.close(fig)

    print(f"Saved plot to: {save_path}")

def add_title_above_properties(fig, title, height=0.08, fontsize=14):
    """Move all axes down to make space and add a clean title above."""
    for ax in fig.axes:
        pos = ax.get_position()
        new_pos = [pos.x0, pos.y0 - height, pos.width, pos.height]
        ax.set_position(new_pos)
    fig.text(0.5, 0.98, title, ha='center', va='top', fontsize=fontsize)

for dataset, subject, label, item, kwargs in iterate_dataset_items(DATASETS):
    print(f"\n\nProcessing Subject: {subject}, Item: {item}")

    # --- Load data
    file_path = os.path.join(dataset.path_epochs, 'ica_epochs', f"sub-{subject}_{label}-{item}_ica-epo.fif")
    if not os.path.exists(file_path):
        print(f"    No data found for subject {subject} with {label} {item}.")
        continue
    epochs = mne.read_epochs(os.path.join(dataset.path_epochs, 'ica_epochs', f"sub-{subject}_{label}-{item}_ica-epo.fif"), preload=True)
    ica = load_ica(dataset, subject, **kwargs, verbose=VERBOSE)

    # --- Create bipolar EOG channels
    epochs = set_bipolar_reference(epochs, anode='UVEOG', cathode='LVEOG', ch_name='VEOG', drop_refs=False, copy=True)
    epochs = set_bipolar_reference(epochs, anode='LHEOG', cathode='RHEOG', ch_name='HEOG', drop_refs=False, copy=True)

    # plot the EOG channels
    eog_picks = mne.pick_types(epochs.info, meg=False, eeg=False, eog=True)
    epochs.plot(picks=eog_picks, block=True)

Datasets:   0%|          | 0/2 [00:00<?, ?it/s]

[DATASET PROGRESSION] Processing dataset: Jin et al. (2019)


                                               
Datasets:   0%|          | 0/2 [00:00<?, ?it/s]                   

[SUBJECT PROGRESSION] Processing subject: 1



                                               
[A                                                               

Datasets:   0%|          | 0/2 [00:00<?, ?it/s] 
[A

[ ITEM  PROGRESSION ] Processing session: 1


Processing Subject: 1, Item: 1
Reading /home/sivert/Documents/Master_AttentionalDirectionResearch/data/jin2019/epochs/ica_epochs/sub-1_session-1_ica-epo.fif ...
    Found the data of interest:
        t =       0.00 ...    1992.19 ms
        0 CTF compensation matrices available
Not setting metadata
3133 matching events found
No baseline correction applied
0 projection items activated
Reading /home/sivert/Documents/Master_AttentionalDirectionResearch/data/jin2019/derivatives/ica/sub1_1_ica.fif ...
Now restoring ICA solution ...
Ready.
[ICA] Loaded ICA from: /home/sivert/Documents/Master_AttentionalDirectionResearch/data/jin2019/derivatives/ica/sub1_1_ica.fif
EEG channel type selected for re-referencing
Not setting metadata
3133 matching events found
No baseline correction applied
0 projection items activated
Added the following bipolar channels:
VEOG
EEG channel type selected for re-referencing
Not setting metadata
3133 matching events foun


Jin et al. (2019) Subjects:   0%|          | 0/30 [3:59:37<?, ?it/s]
Datasets:   0%|          | 0/2 [3:59:37<?, ?it/s]


KeyboardInterrupt: 

QSocketNotifier: Invalid socket 90 and type 'Read', disabling...
