## Code for Figure 3

In [None]:
import mne
import numpy as np
import matplotlib.pyplot as plt
import h5io
import pandas as pd
import seaborn as sns
import statsmodels.api as sm

from meeglet import define_frequencies
from mne.viz.topomap import _prepare_topomap_plot, _make_head_outlines
from mpl_toolkits.axes_grid1 import ImageGrid

from scipy import io
from scipy.stats import bootstrap, permutation_test

In [None]:
# Frequencies of interest, requested from 1 to 64; `bw_oct` and `delta_oct` are
# presumably the bandwidth and spacing in octaves -- confirm against meeglet.
# Only the first element returned by define_frequencies is kept.
foi = define_frequencies(foi_start=1, foi_end=64, bw_oct=0.35, delta_oct=0.05)[0]

In [None]:
# Load the MNE sample recording. In this notebook it supplies the magnetometer
# adjacency (below) and the channel info / names used for the topographic maps.
sample_data_folder = mne.datasets.sample.data_path()
sample_data_raw_file = (
    sample_data_folder / "MEG" / "sample" / "sample_audvis_filt-0-40_raw.fif"
)
raw_for_adjacency = mne.io.read_raw_fif(sample_data_raw_file)
# NOTE(review): pick_types appears to restrict `raw_for_adjacency` itself to
# magnetometers and return that same object, so despite its name `meg_indices`
# is presumably the Raw instance rather than an array of indices -- confirm.
meg_indices = raw_for_adjacency.pick_types(meg='mag')
# find_ch_adjacency returns several values; only the first one is kept.
adj_matrix = mne.channels.find_ch_adjacency(raw_for_adjacency.info, 'mag')[0]

### GROUP LEVEL STATISTICS

### all features

In [None]:
# Read the three precomputed feature files (one CBU file and two CTB files).
# Paths are relative to the working directory of the notebook.
features_CBU = h5io.read_hdf5('./BIOFIND/cosmeeg_CBU_2023-06-22_10-06.h5')
features_CTB1 = h5io.read_hdf5('./BIOFIND/cosmeeg_CTB_2023-06-22_10-46.h5')
features_CTB2 = h5io.read_hdf5('./BIOFIND/cosmeeg_CTB_2023-06-23_11-04.h5')

In [None]:
# Merge the three HDF5 contents into one mapping (key: patient, value: tuple of
# features such as cov, pow). Presumably plain dicts, in which case a key present
# in several files keeps the value from the right-most operand -- confirm there
# are no duplicate subjects across files.
features_all = features_CTB1 | features_CTB2 | features_CBU

# Build groups

In [None]:
participants_fname = './participants.tsv'

# Load the tab-separated participants table, drop the 'sub-' text from the
# participant identifiers and use them as the row index.
subject_df = (
    pd.read_csv(participants_fname, delimiter='\t')
    .assign(
        participant_id=lambda table: table['participant_id'].str.replace('sub-', '')
    )
    .set_index('participant_id')
)

In [None]:
# Keep only the participants that have features, in the iteration order of
# `features_all`, so that rows line up with arrays built by iterating over it.
subject_df = subject_df.loc[features_all.keys()]

In [None]:
# Converter label per subject; missing values are replaced by -1 so that they
# can be excluded later with a `>= 0` comparison.
feature_converter = subject_df['Converters'].fillna(-1)

## import power envelope

In [None]:
# Stack element 6 of every subject's feature entry into one array (subjects on
# the first axis). NOTE(review): index 6 is presumably the power envelope
# correlation given the section title -- confirm against the feature layout.
r_plain = np.array([features_all[subject][6] for subject in features_all])
# Notebook display of the resulting shape.
r_plain.shape

In [None]:
# Drop the references to the raw feature containers; only `r_plain` is used
# from here on in this notebook.
del features_all, features_CBU, features_CTB1, features_CTB2

## ANALYSIS

## power envelope for converters vs non converters

In [None]:
plt.rcParams['font.size'] = 11
fig, axes = plt.subplots(1, 1, figsize=[3.6, 1.4], sharey=True)

# Per-group curve: average over subjects (axis 0), then twice more over the
# leading remaining axis, leaving one value per entry of `foi`.
converter_labels = np.array(feature_converter)
r_plain_converter = (
    r_plain[converter_labels == 1].mean(axis=0).mean(axis=0).mean(axis=0)
)
r_plain_non_converter = (
    r_plain[converter_labels == 0].mean(axis=0).mean(axis=0).mean(axis=0)
)
for group_curve, group_name in (
    (r_plain_converter, 'AD progression'),
    (r_plain_non_converter, 'Stable MCI'),
):
    axes.plot(foi, group_curve, label=group_name, linewidth=3)

# Legend drawn without a visible frame
legend = axes.legend(loc='lower left', bbox_to_anchor=(-0.07, -0.16))
legend.get_frame().set_linewidth(0)

axes.set_xlabel('Frequencies (Hz)')
axes.set_ylabel('Correlation')
axes.set_xscale('log', base=2)
octave_ticks = [1, 2, 4, 8, 16, 32, 64]
axes.set_xticks(octave_ticks)
axes.set_xticklabels(octave_ticks)
axes.set_ylim(0.00, 0.20)

# Despine after ticks and limits are set, since trim=True depends on them.
sns.despine(offset=10, trim=True)
axes.set_title('Average power envelope correlation', y=1.05, fontsize=11)
plt.savefig('./figures/figure3a.pdf', dpi=300, bbox_inches='tight')

## power envelope for converters vs non converters

In [None]:
# New frame (subject_df itself is left untouched) in which missing converter
# labels are encoded as -1.
subject_df_adjust3 = subject_df.assign(
    Converters=subject_df['Converters'].fillna(-1)
)
# Boolean selector: True for subjects with a known label (>= 0).
converters_idx = subject_df_adjust3['Converters'].ge(0)

In [None]:
# Subjects with a known label only, averaged over axis 1 twice so that one
# value per subject and per remaining (last) axis entry is left.
labelled_subjects = r_plain[converters_idx]
r_plain_clean_converter = labelled_subjects.mean(axis=1).mean(axis=1)
r_plain_clean_mean_converter = r_plain_clean_converter
r_plain_clean_converter.shape

## Permutation F-test on sensor data averaged over all sensors

In [None]:
# Labels of the subjects kept above (label >= 0), as a plain array.
known_labels = np.array(feature_converter[feature_converter >= 0])

# data_condition_1: label 0, data_condition_2: label 1
data_condition_1, data_condition_2 = (
    r_plain_clean_mean_converter[known_labels == label] for label in (0, 1)
)

X = [data_condition_1, data_condition_2]

## TFCE cluster permutation F-test across frequencies (sensor-averaged data)

In [None]:
# Cluster permutation test on the sensor-averaged data of the two groups
# (label-1 group first, label-0 group second). No adjacency is passed, and a
# threshold given as a dict with `start`/`step` requests TFCE per the MNE docs.
# NOTE(review): no `seed` is passed, so the permutation p-values are presumably
# not reproducible from run to run -- confirm whether that is intended.
T_obs, clusters, cluster_p_values, H0 = mne.stats.permutation_cluster_test(
    [data_condition_2, data_condition_1],
    threshold=dict(start=0, step=0.2),
    n_jobs = None, n_permutations = 10000, 
    tail=0, out_type="mask"
)

In [None]:
# Show the p-values returned by the cluster permutation test above.
print(cluster_p_values)

### Visualization

In [None]:
rng = np.random.RandomState(23)


def my_statistic(sample1, sample2, axis=0):
    """Return the mean of ``sample1`` minus the mean of ``sample2``.

    Each mean is taken over all elements of its input. ``axis`` is accepted
    for signature compatibility but is not used.
    """
    first_mean, second_mean = (np.mean(sample) for sample in (sample1, sample2))
    return first_mean - second_mean

# condition1: label-0 group, condition2: label-1 group
condition1 = data_condition_1
condition2 = data_condition_2

# One 'basic' bootstrap of the group-mean difference (condition2 minus
# condition1) per column, all drawing from the same seeded generator.
results = [
    bootstrap(
        (condition2[:, ii], condition1[:, ii]),
        my_statistic,
        method='basic',
        random_state=rng,
        vectorized=False,
    )
    for ii in range(condition2.shape[1])
]

# One row per column tested, holding that column's confidence interval.
conf_ints = np.array([tuple(res.confidence_interval) for res in results])
conf_ints.shape

In [None]:
plt.rcParams['font.size'] = 11

# Difference of the group means (condition2 minus condition1), one per column
mean_difference = condition2.mean(axis=0) - condition1.mean(axis=0)

# Bootstrap interval bounds: column 0 is the lower bound, column 1 the upper one
lower_bound, upper_bound = conf_ints[:, 0], conf_ints[:, 1]

fig, ax = plt.subplots(1, 1, figsize=[3.6, 1.4], sharey=True)
ax.set_title('Power envelope difference', fontsize=11, y=1.05)
ax.set_xscale('log', base=2)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
xticks = [1, 2, 4, 8, 16, 32, 64]
ax.set_xticks(xticks)
ax.set_xticklabels(xticks)
ax.set_ylim(-0.05, 0.05)
ax.set_xlim(1, 64)
# Zero reference line: no difference between the groups
ax.axhline(0, color='black', linestyle="--")
sns.despine(offset=10, trim=False)

# Mean difference curve
line = ax.plot(foi, mean_difference, label="Mean Difference")

# Shaded band between the two interval bounds
ax.fill_between(foi, upper_bound, lower_bound, alpha=0.5, label="95% CI")

ax.set_ylabel("Difference")
ax.set_xlabel("Frequencies (Hz)")

plt.savefig('./figures/figure3b.pdf', dpi=300, bbox_inches='tight')

In [None]:
# Figure 3c: test statistic across frequencies with the low-p-value runs shaded.
# Requires `foi`, `T_obs` and `cluster_p_values` from the cluster-test cell.
# A frequency is highlighted when its p-value is at most `p_threshold`.
p_threshold = 0.25
mask = cluster_p_values <= p_threshold
# Single legend text for every shaded run, whether or not the run reaches the
# end of the spectrum (the trailing run used to be labelled 'P < 0.10').
cluster_label = 'cluster P < 0.25'

plt.rcParams['font.size'] = 11
fig, ax2 = plt.subplots(1, 1, figsize=[3.6,1.4], sharey=True)
ax2.set_title('Permutation F-test on power envelope', fontsize = 11, y=1.1)

# Plot the main black line
hf, = ax2.plot(foi, T_obs, "black")

# Locate contiguous runs of highlighted frequencies as half-open index ranges
# [start, stop). Padding with False on both sides makes a run that touches
# either end of the spectrum produce a rising and a falling edge like any other.
padded = np.concatenate(([False], np.asarray(mask, dtype=bool), [False]))
edges = np.flatnonzero(padded[1:] != padded[:-1])
run_starts, run_stops = edges[::2], edges[1::2]

for run_number, (start_idx, end_idx) in enumerate(zip(run_starts, run_stops)):
    # Only the first patch carries a label so the legend shows a single entry.
    label_kwargs = {'label': cluster_label} if run_number == 0 else {}
    ax2.fill_between(foi[start_idx:end_idx], y1=T_obs[start_idx:end_idx], y2=0,
                     color="green", alpha=0.3, **label_kwargs)

# Set axes labels and scales
ax2.set_xlabel("Frequencies (Hz)")
ax2.set_ylabel("F-values")
ax2.set_xscale('log', base=2)
xticks = [1, 2, 4, 8, 16, 32, 64]
ax2.set_xticks(xticks)
ax2.set_xlim(1, 64)
ax2.set_ylim(0,1.5)
ax2.set_xticklabels(xticks)

# Customize appearance and remove unnecessary spines
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)
sns.despine(offset=5, trim=False)

# Add the legend
ax2.legend(loc='upper right', bbox_to_anchor=(1.1, 1.15), ncol=2, fontsize=10, frameon=False)

# Save the figure
plt.savefig('./figures/figure3c.pdf', bbox_inches='tight')

In [None]:
# Print the frequencies selected by `mask` (cluster p-value threshold above).
# NOTE: `mask` is rebound further down in this notebook (topographic-map
# section), so this cell refers to the frequency mask only if run before that.
print(foi[mask])

### Non parametric two-tailed permutation test (uncorrected for multiple comparisons) to find frequencies with p-val < 0.05 

In [None]:
# Subjects with a known label, averaged over axis 1 only.
# NOTE: this rebinds `r_plain_converter`, a name the figure 3a cell used for the
# label-1 group's per-frequency curve.
r_plain_converter = np.mean(r_plain[converters_idx],1)

In [None]:
# Labels of the subjects contained in r_plain_converter
group_labels = feature_converter[converters_idx]
r_plain_cluster_converters = r_plain_converter[group_labels == 1]
r_plain_cluster_nonconverters = r_plain_converter[group_labels == 0]

# Swap the last two axes while keeping subjects on axis 0 (all labelled
# subjects, both groups together).
r_plainT = np.transpose(r_plain_converter, (0, 2, 1))
r_plainT.shape

In [None]:
rng = np.random.RandomState(23)

# Custom statistic: difference of the group means
def my_statistic(sample1, sample2, axis=0):
    """Return mean(sample1) - mean(sample2), both taken along ``axis``."""
    group_means = [np.mean(sample, axis=axis) for sample in (sample1, sample2)]
    return group_means[0] - group_means[1]

# Per-group data averaged over the last axis of r_plainT:
# condition1 is the label-0 group, condition2 the label-1 group.
labels_known = feature_converter[converters_idx]
condition1 = np.mean(r_plainT[labels_known == 0], axis=2)
condition2 = np.mean(r_plainT[labels_known == 1], axis=2)


def _column_pvalue(col):
    """Return the permutation-test p-value comparing both groups at column ``col``."""
    samples = (condition1[:, col], condition2[:, col])
    outcome = permutation_test(samples, my_statistic, vectorized=True, random_state=rng)
    return outcome.pvalue


# Columns are tested in increasing order, sharing the seeded generator; a
# column is kept when its p-value is below 0.05.
significant_indices = [
    col for col in range(condition1.shape[1]) if _column_pvalue(col) < 0.05
]

# Report the retained column indices
print("Indices with p-value below 0.05:", significant_indices)


In [None]:
# Notebook display: frequencies corresponding to the retained indices.
foi[significant_indices]

## topographical maps, permutation cluster test

In [None]:
# Cluster permutation test with TFCE (threshold given as a start/step dict),
# keeping both the channel and the frequency dimension; `adj_matrix` supplies
# the channel adjacency. Groups are passed as label 0 first, label 1 second.
# Per the MNE documentation the four returned values are the observed statistic,
# the clusters, the cluster p-values and the permutation null distribution, so
# `alpha` holds p-values here (not a significance level).
# NOTE(review): no `seed` is passed, so p-values presumably vary between runs.
F, c, alpha, h = mne.stats.spatio_temporal_cluster_test([
    r_plainT[feature_converter[converters_idx] == 0],
    r_plainT[feature_converter[converters_idx] == 1]
    ],
    threshold = dict (start = 0, step = 0.2),
    n_jobs = 4,
    tail=0,
    n_permutations = 10000,
    adjacency = adj_matrix
)

In [None]:
# -log10(p) map laid out as (frequencies, channels); entries not covered by any
# cluster stay at 0.
n_freqs, n_channels = r_plainT.shape[1:3]
p_clust = np.zeros((n_freqs, n_channels))
for cluster_index, cluster_p in zip(c, alpha):
    p_clust[cluster_index] = -np.log10(cluster_p)

# True wherever -log10(p) is not below 1.3.
# NOTE(review): 1.3 is slightly under -log10(0.05) (about 1.301), so p-values
# marginally above 0.05 also pass -- confirm the intended cut-off.
mask = ~(p_clust < 1.3)

In [None]:
# Row indices of `mask` that contain at least one True entry
idx = np.flatnonzero(mask.any(axis=1))
print(idx[::6])  # show every 6th index (step of 6)

In [None]:
# Split the labelled subjects by group, then average each group over subjects.
labels_for_split = feature_converter[converters_idx]
r_plain_converters = r_plain_converter[labels_for_split == 1]
r_plain_nonconverters = r_plain_converter[labels_for_split == 0]
r_plain_converters_mean = r_plain_converters.mean(axis=0)
r_plain_nonconverters_mean = r_plain_nonconverters.mean(axis=0)

## topoplots

In [None]:
def topoplot(
    data,
    channel_names,
    info,
    montage = "standard_1020",
    cmap='RdBu_r',
    scale_limits=(None, None),
    size=1,
    axes=None,
    mask=None,
):
    """Draw one magnetometer topographic map and return its image and contours.

    Parameters
    ----------
    data : 1D array or list
        One value per channel, in the order of ``channel_names``.
    channel_names : list of str
        Channel names; only the length is checked against ``data``.
    info : measurement info passed to MNE to obtain the sensor positions.
    montage : accepted but not used by this function.
    cmap : colormap forwarded to ``mne.viz.plot_topomap``.
    scale_limits : pair forwarded as ``vlim`` (useful to share a colour scale).
    size : size of the topoplot, forwarded as ``size``.
    axes : axes to draw into, forwarded as ``axes``.
    mask : optional channel mask, forwarded as ``mask``.

    Returns
    -------
    tuple
        The two values returned by ``mne.viz.plot_topomap``.
    """
    assert len(data) == len(channel_names)

    # Sensor positions and sphere for the "mag" channel type (hard-coded),
    # using a fixed sphere specification.
    prepared = _prepare_topomap_plot(
        info,
        "mag",
        sphere=(0, 0.023, 0.021, 0.1),
    )
    _, sensor_pos, _, _, _, head_sphere, clip_origin = prepared

    head_outlines = _make_head_outlines(head_sphere, sensor_pos, 'head', clip_origin)

    image, contours = mne.viz.plot_topomap(
        data=data,
        pos=sensor_pos,
        axes=axes,
        vlim=scale_limits,
        cmap=cmap,
        outlines=head_outlines,
        image_interp='cubic',
        size=size,
        mask=mask,
        show=False,
    )
    return image, contours

def plot_topoplot_grid(
    data,
    info,
    row_labels,
    col_labels,
    channel_names,
    cbar_mode="single",
    cbar_label=r"$10\times\log_{10}fT^2/Hz$[dB]",
    cmap='RdBu_r',
    scale_limits=(None, None),
    mask=None,
):
    """Plot a grid of topographic maps and return the figure.

    Parameters
    ----------
    data : nested sequence
        ``data[i][j]`` is the 1D data for the map in row ``i`` and column ``j``.
    info : measurement info forwarded to ``topoplot``.
    row_labels : list of str
        One label per row, e.g. ["Condition 1", "Condition 2"].
    col_labels : list of str
        One label per column, e.g. ["Alpha", "Beta", "Gamma"].
    channel_names : list of str
        Channel labels in the order of the 1D entries of ``data``.
    cbar_mode : "single" for one shared colorbar, "each" for one per map.
    cbar_label : text written next to the colorbar(s).
    cmap : colormap forwarded to ``topoplot``.
    scale_limits : colour limits forwarded to ``topoplot``. With
        ``cbar_mode="single"`` and the default ``(None, None)`` the global
        min/max of ``data`` is used; with ``cbar_mode="each"`` the string
        ``"columns"`` selects per-column min/max limits.
    mask : optional nested sequence shaped like ``data`` with one mask per map.

    Returns
    -------
    The matplotlib figure holding the grid.
    """
    one_bar_each = cbar_mode == "each"

    # Layout constants
    label_size = {"x": 12, "y": 12, "cbar": 10}
    label_pad = {"x": 10, "y": 10}

    n_rows = len(data)
    n_cols = len(data[0])

    fig = plt.figure(figsize=(12, 4))
    grid = ImageGrid(
        fig,
        (0.05, 0.05, 0.90, 0.95),
        nrows_ncols=(n_rows, n_cols),
        axes_pad=(0.9 if one_bar_each else 0.4, 0.4),
        share_all=True,
        cbar_location="right",
        cbar_mode=cbar_mode,
        cbar_size="5%",
        cbar_pad=0.1 if one_bar_each else 0.8,
    )

    if cbar_mode == "single" and scale_limits == (None, None):
        # Shared colorbar without explicit limits: use the global data range.
        stacked = np.array(data)
        scale_limits = stacked.min(), stacked.max()
    elif one_bar_each and scale_limits == "columns":
        # One range per column, taken across all rows.
        by_column = np.hstack(data)
        mins = by_column.min(axis=1)
        maxs = by_column.max(axis=1)

    panel_limits = scale_limits
    axes_rows = grid.axes_row
    for row, row_label in enumerate(row_labels):
        for col, col_label in enumerate(col_labels):
            ax = axes_rows[row][col]
            if one_bar_each and scale_limits == "columns":
                panel_limits = [mins[col], maxs[col]]
            im, _ = topoplot(
                data[row][col],
                channel_names,
                info,
                scale_limits=panel_limits,
                axes=ax,
                cmap=cmap,
                mask=mask[row][col] if mask is not None else None,
            )

            ax.set_xlabel(col_label, fontsize=label_size["x"], labelpad=label_pad["x"])
            ax.set_ylabel(row_label, fontsize=label_size["y"], labelpad=label_pad["y"])

            # Attach (or refresh) the colorbar on this panel's colorbar axes.
            cbar = ax.cax.colorbar(im)
            cbar.ax.set_ylabel(cbar_label, fontsize=label_size["cbar"])
            ax.cax.toggle_label(True)

    return fig

In [None]:
def min_max(c):
    """Return the largest absolute value among the minimum and maximum of ``c``.

    Useful as a symmetric colour limit: the range ``[-min_max(c), min_max(c)]``
    covers every value of ``c``.
    """
    extremes = (np.min(c), np.max(c))
    return max(abs(value) for value in extremes)

In [None]:
# Group difference of the subject-averaged data (label-1 group minus label-0
# group); the first axis indexes channels and the second one frequencies,
# matching the transposed `mask` used below.
difference = np.squeeze(r_plain_converters_mean - r_plain_nonconverters_mean)

# Indices into `foi` of the frequencies shown, in the order of `col_labels`.
freq_indices = [5, 13, 39, 62, 116]
data_diff = [[difference[:, ix] for ix in freq_indices]]
mask_grid = [[mask.T[:, ix] for ix in freq_indices]]
row_labels = ["Difference"]
col_labels = ["1.2 Hz", "1.6 Hz", "3.9 Hz", "8.6 Hz", "55.7 Hz"]
channel_names = raw_for_adjacency.info.ch_names

# The plotted values are differences of correlations, so the colorbar label is
# passed explicitly instead of relying on the grid helper's power-unit default.
# NOTE(review): the limits are callables, so MNE presumably evaluates them per
# panel while cbar_mode="single" draws one colorbar -- confirm that the shared
# colorbar is valid for every panel.
fig = plot_topoplot_grid(
    data_diff,
    raw_for_adjacency.info,
    row_labels,
    col_labels,
    channel_names,
    scale_limits=[lambda x: -min_max(x), min_max],
    cbar_mode="single",
    mask=mask_grid,
    cbar_label="Correlation difference",
)
fig.set_size_inches(10, 5)
plt.savefig('./figures/figure3_topoplot.pdf', dpi=300, bbox_inches='tight')