In [ ]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from scipy.stats import zscore
import pingouin as pg
plt.rc('font', family = 'arial')
from pingouin import mediation_analysis
from scipy import stats
from statsmodels.stats.multitest import fdrcorrection

In [ ]:
# Frequency bands definition
BANDS = {'Delta': (1, 4), 'Theta': (4, 8), 'Alpha': (8, 13)}

# Preload channel names
CHANNELS = ['F7', 'Fp1', 'Fp2', 'F8', 'F3', 'Fz', 'F4', 'C3', 'Cz', 'P8', 'P7', 'Pz', 'P4', 'T3', 'P3', 'O1', 'O2', 'C4', 'T4']

# Find the index location of channels in anatomical groups
All_Channel_indices = np.arange(len(CHANNELS))
Occipital_Channel_indices = [CHANNELS.index(ch) for ch in ('O1', 'O2')]
Frontal_Channel_indices = [CHANNELS.index(ch) for ch in ('F7', 'F3', 'Fz', 'F4', 'F8')]
Central_Channel_indices = [CHANNELS.index(ch) for ch in ('Cz', 'C3', 'C4')]

# Name of anatomical channel groups and channel indices
Channel_Groups = {
    'All': All_Channel_indices,
    'Occipital': Occipital_Channel_indices,
    'Frontal': Frontal_Channel_indices,
    'Central': Central_Channel_indices,
}

# Directory paths
csv_path = '04_Features_EEG/Range_13_10_Sec_Epoch_Fixed_Merged.csv'

# Timepoints kept for the analysis and the minimum number a subject must have
TIMEPOINTS = ['Baseline', 'Post_Wash', 'Post_Placebo', 'Post_CBD']
MIN_TIMEPOINTS = 3

# Load dataframe
all_data = pd.read_csv(csv_path)

# Filter dataframe for specific timepoints and Record IDs with at least 3 timepoints.
# The explicit copy makes the column assignments below act on an independent frame.
filt_data = all_data[all_data['Timepoint'].isin(TIMEPOINTS)]
filt_data = filt_data.groupby('Record ID').filter(lambda x: x['Timepoint'].nunique() >= MIN_TIMEPOINTS).copy()

# Z-score normalize metabolite levels across subjects.
# nan_policy='omit' leaves missing values as NaN; the default ('propagate')
# would turn the entire column into NaN if a single value were missing.
for metabolite in ['CBD', 'OHCBD', 'COOHCBD', 'AEA']:
    filt_data[f'{metabolite}_Z_score'] = zscore(filt_data[metabolite], nan_policy='omit')

# Calculate combined score: mean of the three CBD-related z-scores computed above
filt_data['Combined_score'] = (filt_data['CBD_Z_score'] + filt_data['OHCBD_Z_score'] + filt_data['COOHCBD_Z_score']) / 3

# Convert categorical variables to numeric codes (categories are inferred and
# sorted, so codes follow the sorted order of the labels). cat.codes marks
# missing values as -1; keep them missing rather than as a fake numeric level.
for col in ['Timepoint', 'Randomization', 'ADOS_Module']:
    codes = filt_data[col].astype('category').cat.codes
    filt_data[f'{col}_numeric'] = codes.where(codes >= 0)

# Sort the dataframe
filt_data = filt_data.sort_values(by=['Record ID', 'Timepoint'], ascending=[True, False])

# Print summary information
print(f"Number of unique 'Record ID' in the comparisons: n = {len(filt_data['Record ID'].unique())}")
print(filt_data['Record ID'].unique())

# Function to calculate statistics for each group
def calculate_stats(group):
    """Summarise age and subject count for one group of rows.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows belonging to one group; must contain 'Age' and 'Record ID'.

    Returns
    -------
    pandas.Series
        Indexed by 'Mean Age', 'SEM Age', 'Min Age', 'Max Age' and
        'Unique Record ID' (number of distinct 'Record ID' values).
    """
    ages = group['Age']
    summary = {}
    summary['Mean Age'] = ages.mean()
    summary['SEM Age'] = ages.sem()
    summary['Min Age'] = ages.min()
    summary['Max Age'] = ages.max()
    summary['Unique Record ID'] = group['Record ID'].nunique()
    return pd.Series(summary)

# Group by 'Timepoint' and calculate statistics. Only the columns
# calculate_stats reads are passed, so apply() does not operate on the
# grouping column (deprecated in recent pandas).
result = filt_data.groupby('Timepoint')[['Age', 'Record ID']].apply(calculate_stats).reset_index()
print(result)

In [0]:
def _slope_difference_pvalue(x, y, group):
    """Return the two-sided p-value for a slope difference between two groups.

    Fits ``y = b0 + b1*x + b2*g + b3*x*g`` by ordinary least squares, where
    ``g`` is the 0/1 group indicator, and t-tests ``b3``. In this saturated
    model ``b3`` equals the slope of a separate line fitted to the ``g == 1``
    rows minus that of the ``g == 0`` rows, so the p-value refers to exactly
    the slope difference reported next to it.

    Returns NaN when the design is rank deficient, has no residual degrees of
    freedom, or the residual variance is zero.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    g = np.asarray(group, dtype=float)
    design = np.column_stack([np.ones_like(x), x, g, x * g])
    n_obs, n_params = design.shape
    beta, _, rank, _ = np.linalg.lstsq(design, y, rcond=None)
    dof = n_obs - n_params
    if rank < n_params or dof <= 0:
        return np.nan
    residuals = y - design @ beta
    sigma2 = residuals @ residuals / dof
    se = np.sqrt(sigma2 * np.linalg.inv(design.T @ design)[3, 3])
    if not se > 0:
        return np.nan
    t_stat = beta[3] / se
    return float(2 * stats.t.sf(abs(t_stat), dof))


def plot_significant_mediation(results, x_label, m_label, y_label, title, data, np=np, plt=plt, stats=stats, pg=pg):
    """Plot diagnostics for a mediation model with a significant indirect effect.

    If no row of ``results`` whose path contains 'Indirect' has ``sig == 'Yes'``,
    a message is printed and nothing is plotted. Otherwise four figures are
    shown:

    1. model-implied lines X -> mediator(s) and the direct X -> Y line,
    2. model-implied indirect and total effects on Y,
    3. a scatter of Y vs X, Post_CBD vs other timepoints, with separate linear
       fits and an interaction test of their slope difference,
    4. a mediation-triangle summary with coefficients, p-values and
       (partial) correlations.

    Parameters
    ----------
    results : pandas.DataFrame
        Table from ``pingouin.mediation_analysis`` (columns 'path', 'coef',
        'sig' and a p-value column 'pval' or 'p-val').
    x_label, m_label, y_label : str
        Column names in ``data`` of predictor X, mediator M and outcome Y.
    title : str
        Label used in the printed messages.
    data : pandas.DataFrame
        Observations; must also contain a 'Timepoint' column.
    np, plt, stats, pg : module
        Injected modules; default to the notebook's imports.
    """
    paths = results['path'].astype(str)

    # Check for significant indirect effects
    significant_indirect = results[paths.str.contains('Indirect') & (results['sig'] == 'Yes')]
    if significant_indirect.empty:
        print(f"No significant indirect effect found for {title}")
        return

    # Determine which column contains p-values
    p_value_col = next((col for col in ('pval', 'p-val') if col in results.columns), None)
    if p_value_col is None:
        print("Warning: Could not find p-value column. Skipping p-value annotation.")

    def p_text(row):
        # Formatted p-value of one results row; 'n/a' when no p-value column exists
        return 'n/a' if p_value_col is None else f"{row[p_value_col]:.3f}"

    # Rows for path a ('<M> ~ X'), path b ('Y ~ <M>') and the direct effect c'
    a_rows = results.loc[paths.str.contains('~ X')]
    b_rows = results.loc[paths.str.startswith('Y ~')]
    direct_row = results.loc[paths == 'Direct'].iloc[0]
    a = a_rows['coef'].to_numpy()
    b = b_rows['coef'].to_numpy()
    c = direct_row['coef']

    # Drop rows missing X, M or Y so pearsonr / polyfit below are not fed NaN
    data = data.dropna(subset=[x_label, m_label, y_label])

    # Model-implied lines built from the path coefficients only (no intercepts)
    x = np.linspace(data[x_label].min(), data[x_label].max(), 100)
    m = np.outer(x, a)
    y_direct = c * x
    y_indirect = np.dot(m, b)
    y_total = y_direct + y_indirect

    x_data = data[x_label]
    y_data = data[y_label]

    # Correlations and partial correlations (unadjusted for the model covariates)
    r_a = stats.pearsonr(data[x_label], data[m_label])[0]
    r2_a = r_a ** 2
    r_b = pg.partial_corr(data=data, x=m_label, y=y_label, covar=x_label)['r'].values[0]
    r2_b = r_b ** 2
    r_c = stats.pearsonr(data[x_label], data[y_label])[0]
    r2_c = r_c ** 2
    r_c_prime = pg.partial_corr(data=data, x=x_label, y=y_label, covar=m_label)['r'].values[0]
    r2_c_prime = r_c_prime ** 2

    # Plot 1: Relationships between X, mediators, and Y
    significant_paths = set(significant_indirect['path'])
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    for i, m_path in enumerate(a_rows['path']):
        mediator = m_path.split('~')[0].strip()
        # Accept both the single-mediator label 'Indirect' and the
        # per-mediator label 'Indirect <name>'
        is_significant = bool({'Indirect', f'Indirect {mediator}'} & significant_paths)
        color = 'red' if is_significant else 'gray'
        ax1.plot(x, m[:, i], label=f'{x_label} -> {mediator}', color=color)
    ax1.plot(x, y_direct, label=f'{x_label} -> {y_label} (Direct)', linestyle='--', color='black')
    ax1.set_xlabel(x_label, fontsize=14)
    ax1.set_ylabel('Estimated Value', fontsize=14)
    ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12)
    ax1.set_title(f'Relationships: {x_label} to mediators and {y_label}', fontsize=16)
    ax1.tick_params(axis='both', which='major', labelsize=12)
    plt.tight_layout()
    plt.show()

    # Plot 2: Indirect and Total Effects
    fig2, ax2 = plt.subplots(figsize=(12, 8))
    ax2.plot(x, y_indirect, label=f'Indirect Effect on {y_label}', color='blue')
    ax2.plot(x, y_total, label=f'Total Effect on {y_label}', color='green')
    ax2.set_xlabel(x_label, fontsize=14)
    ax2.set_ylabel(y_label, fontsize=14)
    ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12)
    ax2.set_title(f'Indirect and Total Effects on {y_label}', fontsize=16)
    ax2.tick_params(axis='both', which='major', labelsize=12)

    # Add effect size annotation: one box listing every significant indirect
    # effect (separate boxes at the same position would overlap)
    annotation_parts = []
    for _, row in significant_indirect.iterrows():
        part = f"{row['path']}\nEffect: {row['coef']:.3f}"
        if p_value_col:
            part += f"\np-value: {row[p_value_col]:.3f}"
        annotation_parts.append(part)
    ax2.annotate('\n\n'.join(annotation_parts), xy=(0.05, 0.95), xycoords='axes fraction',
                 va='top', ha='left', bbox=dict(boxstyle='round', fc='white', ec='gray', alpha=0.8),
                 fontsize=12)
    plt.tight_layout()
    plt.show()

    # Plot 3: Scatter plot with regression lines
    fig3, ax3 = plt.subplots(figsize=(12, 8))
    post_cbd_mask = data['Timepoint'] == 'Post_CBD'
    ax3.scatter(x_data[post_cbd_mask], y_data[post_cbd_mask], color='red', label='Post_CBD', alpha=0.7)
    ax3.scatter(x_data[~post_cbd_mask], y_data[~post_cbd_mask], color='black', label='Other Timepoints', alpha=0.7)

    # A line fit per group needs at least two points in each group
    can_fit = min(int(post_cbd_mask.sum()), int((~post_cbd_mask).sum())) >= 2
    if can_fit:
        # Post_CBD regression
        post_cbd_fit = np.polyfit(x_data[post_cbd_mask], y_data[post_cbd_mask], 1)
        ax3.plot(x_data, np.poly1d(post_cbd_fit)(x_data), color='red', linestyle='--', label='Post_CBD Regression')
        # Other Timepoints regression
        other_fit = np.polyfit(x_data[~post_cbd_mask], y_data[~post_cbd_mask], 1)
        ax3.plot(x_data, np.poly1d(other_fit)(x_data), color='black', linestyle='--', label='Other Timepoints Regression')

    ax3.set_xlabel(x_label, fontsize=14)
    ax3.set_ylabel(y_label, fontsize=14)
    ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12)
    ax3.set_title(f'{y_label} with Linear Regression', fontsize=16)
    ax3.tick_params(axis='both', which='major', labelsize=12)

    if can_fit:
        # Differences between the two fitted lines (Post_CBD minus other)
        slope_diff = post_cbd_fit[0] - other_fit[0]
        intercept_diff = post_cbd_fit[1] - other_fit[1]
        # Test the slope difference with an interaction regression. A t-test on
        # the residuals of the two separate fits cannot detect it, because OLS
        # residuals average to zero within each fit.
        p_value = _slope_difference_pvalue(x_data, y_data, post_cbd_mask)

        ax3.text(0.05, 0.95, f"Slope difference: {slope_diff:.3f}", transform=ax3.transAxes,
                 verticalalignment='top', fontsize=12)
        ax3.text(0.05, 0.90, f"Intercept difference: {intercept_diff:.3f}", transform=ax3.transAxes,
                 verticalalignment='top', fontsize=12)
        ax3.text(0.05, 0.85, f"Slope difference p-value: {p_value:.3f}", transform=ax3.transAxes,
                 verticalalignment='top', fontsize=12)
    else:
        ax3.text(0.05, 0.95, "Too few points per group to compare regressions", transform=ax3.transAxes,
                 verticalalignment='top', fontsize=12)

    plt.tight_layout()
    plt.show()

    # Plot 4: Mediation Triangle with Full Statistics
    fig4, ax4 = plt.subplots(figsize=(12, 10))
    ax4.axis('off')
    ax4.set_title('Mediation Triangle with Full Statistics', fontsize=16)

    # Draw the triangle: X bottom-left, Y bottom-right, M at the apex
    triangle = plt.Polygon([(0, 0), (1, 0), (0.5, 0.866)], fill=False)
    ax4.add_patch(triangle)

    # Add labels
    ax4.text(0, -0.1, x_label, ha='center', va='center', fontsize=12)
    ax4.text(1, -0.1, y_label, ha='center', va='center', fontsize=12)
    ax4.text(0.5, 0.92, m_label, ha='center', va='center', fontsize=12)

    # Add arrows (directional paths) and statistics
    ax4.annotate('', xy=(0.9, 0.05), xytext=(0.1, 0.05), arrowprops=dict(arrowstyle='->', color='r'))
    ax4.text(0.5, -0.05, f"Direct Effect (c'):\nCoef = {direct_row['coef']:.3f}\np-value = {p_text(direct_row)}\nr = {r_c_prime:.3f}, r² = {r2_c_prime:.3f}", ha='center', va='center', fontsize=10)

    ax4.annotate('', xy=(0.25, 0.433), xytext=(0.05, 0.05), arrowprops=dict(arrowstyle='->', color='b'))
    a_row = a_rows.iloc[0]
    ax4.text(0.2, 0.25, f"X -> M (a):\nCoef = {a_row['coef']:.3f}\np-value = {p_text(a_row)}\nr = {r_a:.3f}, r² = {r2_a:.3f}", ha='center', va='center', fontsize=10)

    ax4.annotate('', xy=(0.95, 0.05), xytext=(0.75, 0.433), arrowprops=dict(arrowstyle='->', color='g'))
    b_row = b_rows.iloc[0]
    ax4.text(0.8, 0.25, f"M -> Y (b):\nCoef = {b_row['coef']:.3f}\np-value = {p_text(b_row)}\nr = {r_b:.3f}, r² = {r2_b:.3f}", ha='center', va='center', fontsize=10)

    # Add full model statistics below the diagram
    stats_parts = ["Full Model Statistics:\n\n"]
    for _, row in results.iterrows():
        entry = f"{row['path']}:\nCoefficient = {row['coef']:.3f}, p-value = {p_text(row)}"
        if row['path'] == 'Total':
            entry += f"\nr = {r_c:.3f}, r² = {r2_c:.3f}"
        elif row['path'] == 'Direct':
            entry += f"\nr = {r_c_prime:.3f}, r² = {r2_c_prime:.3f}"
        stats_parts.append(entry + "\n\n")
    stats_text = ''.join(stats_parts)

    ax4.text(0.5, -0.5, stats_text, ha='center', va='center', fontsize=10, bbox=dict(facecolor='white', edgecolor='black', alpha=0.8))

    plt.tight_layout()
    plt.show()

    # Print overall title and significant indirect effects
    print(f"\n{title}\nSignificant Indirect Effect(s):\n{', '.join(significant_indirect['path'])}")

In [ ]:
# Initialize dictionaries and lists to store results
mediation_results = {}
all_p_values = []
all_result_keys = []

# Define lists of metabolites, behavioral tests, and electrode groups to analyze
metabolite_list = ['COOHCBD_Z_score']
beh_test_list = ['rbs_total_score', 'ppvt_raw_score', 'toni4_raw_score', 'eowpvt4_raw_score', 'beery_vmi_raw_score', 'beery_vp_raw_score', 'beery_mc_raw_score']
electrode_group = ['Occipital', 'Frontal', 'Central']

# Fixed covariates shared by every model. The preprocessing cell creates the
# ADOS module code as 'ADOS_Module_numeric'; 'ADOS_numeric' is only used if the
# input data already provides a column with that name.
ados_covariate = 'ADOS_numeric' if 'ADOS_numeric' in filt_data.columns else 'ADOS_Module_numeric'
fixed_covariates = ['Age', 'Timepoint_numeric', 'Randomization_numeric', ados_covariate]

# EEG feature suffixes; the column analysed is f'{electrodes}_{suffix}'
eeg_feature_suffixes = ['Exponent', 'Alpha_SNR', 'Offset', 'Delta_SNR', 'Theta_SNR']

# Iterate through each behavioral test
for beh_test in beh_test_list:

    # Drop rows with missing values in the relevant behavioral test column
    df = filt_data.dropna(subset=[beh_test])
    # Filter to include only subjects with at least 3 timepoints
    df_beh_filt = df.groupby('Record ID').filter(lambda x: x['Timepoint'].nunique() >= 3)

    # Calculate statistics for each timepoint (only the columns calculate_stats reads)
    result = df_beh_filt.groupby('Timepoint')[['Age', 'Record ID']].apply(calculate_stats).reset_index()
    print(beh_test)
    print(result)

    # Iterate through each metabolite
    for metabolite in metabolite_list:
        # Iterate through each electrode group
        for electrodes in electrode_group:
            # EEG predictors (X) for the current electrode group
            eeg_covariates = [f'{electrodes}_{suffix}' for suffix in eeg_feature_suffixes]

            # Iterate through each EEG covariate
            for eeg_covariate in eeg_covariates:
                # Create a unique key for the current analysis
                result_key = f'{beh_test}_{metabolite}_{electrodes}_{eeg_covariate}'

                # Perform mediation analysis (X = EEG feature, M = metabolite, Y = behavior).
                # NOTE(review): each row is one Record ID x Timepoint, so subjects
                # contribute several rows that this model treats as independent;
                # consider whether a repeated-measures approach is needed.
                results = mediation_analysis(data=df_beh_filt, x=eeg_covariate, m=metabolite, y=beh_test, covar=fixed_covariates, n_boot=1000, seed=42)

                # Print results (rounded for display only)
                print(result_key)
                print('\n')
                print(results.round(3))
                print('\n')
                print('\n')

                # Store full-precision results so later FDR correction is not
                # applied to p-values rounded to 3 decimals
                mediation_results[result_key] = results

                # Store p-values and result keys for FDR correction
                all_p_values.extend(results['pval'].tolist())
                all_result_keys.extend([result_key] * len(results))

                # Plot significant mediation relationships
                plot_significant_mediation(
                    results=results,
                    x_label=f'{eeg_covariate}',
                    m_label=metabolite,
                    y_label=beh_test,
                    title=result_key,
                    data=df_beh_filt
                )

In [ ]:
def _electrode_group_of(key, electrode_groups=None):
    """Return the electrode group encoded in a mediation result key.

    The analysis loop builds keys as
    f'{beh_test}_{metabolite}_{electrodes}_{eeg_covariate}', and every EEG
    covariate starts with f'{electrodes}_'. The group name therefore appears as
    two consecutive identical '_'-separated tokens, e.g.
    '..._Occipital_Occipital_Alpha_SNR'; the last such pair is used. If
    ``electrode_groups`` is given, only those names are accepted. Falls back to
    the second-to-last token when nothing matches.
    """
    parts = key.split('_')
    candidates = [a for a, b in zip(parts, parts[1:]) if a == b]
    if electrode_groups is not None:
        allowed = set(electrode_groups)
        candidates = [g for g in candidates if g in allowed] or [p for p in parts if p in allowed]
    return candidates[-1] if candidates else parts[-2]


def _simes_pvalue(p_values):
    """Return the Simes combined p-value min_i(n * p_(i) / i), capped at 1."""
    p_sorted = np.sort(np.asarray(p_values, dtype=float))
    ranks = np.arange(1, p_sorted.size + 1)
    return float(min(1.0, np.min(p_sorted * p_sorted.size / ranks)))


def hierarchical_fdr_correction(mediation_results, alpha=0.05, electrode_groups=None):
    """Apply two-level FDR control to the indirect-effect p-values.

    Level 1 (electrode groups): every indirect-effect p-value, significant or
    not, is assigned to exactly one electrode group. Each group is summarised
    by its Simes p-value, and Benjamini-Hochberg is applied across groups at
    ``alpha``.

    Level 2 (tests within a selected group): Benjamini-Hochberg is applied
    within each selected group at ``alpha * R / G``, where R of the G groups
    were selected. This follows the Benjamini & Bogomolov (2014)
    selective-inference procedure.

    Parameters
    ----------
    mediation_results : dict[str, pandas.DataFrame]
        Result key -> mediation table (needs 'path' and 'pval' columns).
    alpha : float
        Target FDR level (default 0.05).
    electrode_groups : iterable of str, optional
        Known electrode group names; inferred from the keys when omitted.

    Returns
    -------
    dict[str, pandas.DataFrame]
        Copies of the tables of the keys in selected groups. Their indirect
        rows gain two columns:

        - 'corrected_pval': the BH-adjusted p-value rescaled by G / R and
          capped at 1, so it is directly comparable with ``alpha``.
        - 'significant_after_correction'.

        Keys in unselected groups are omitted.
    """
    # Dictionary to store FDR-corrected results
    fdr_corrected_results = {}

    # Collect every indirect-effect p-value per electrode group. Restricting to
    # effects already significant before correction would pre-select the tests
    # and void the FDR guarantee.
    group_tests = {}
    for key, results in mediation_results.items():
        group = _electrode_group_of(key, electrode_groups)
        indirect_rows = results[results['path'].astype(str).str.contains('Indirect')]
        for row_index, p_value in indirect_rows['pval'].items():
            if pd.notna(p_value):
                group_tests.setdefault(group, []).append((key, row_index, float(p_value)))

    if not group_tests:
        return fdr_corrected_results

    # First level: FDR correction across electrode groups using Simes p-values
    # (the raw minimum p-value of a group is not a valid group-level p-value)
    group_names = list(group_tests)
    group_p_values = [_simes_pvalue([p for _, _, p in group_tests[g]]) for g in group_names]
    rejected_groups, _ = fdrcorrection(group_p_values, alpha=alpha, method='indep')
    n_selected = int(np.sum(rejected_groups))
    if n_selected == 0:
        return fdr_corrected_results

    # Second level: FDR correction within selected groups at the
    # selection-adjusted level alpha * R / G
    scale = len(group_names) / n_selected
    within_alpha = alpha / scale
    for group, is_selected in zip(group_names, rejected_groups):
        if not is_selected:
            continue
        tests = group_tests[group]
        rejected, corrected = fdrcorrection([p for _, _, p in tests], alpha=within_alpha, method='indep')
        corrected = np.minimum(np.asarray(corrected) * scale, 1.0)

        # Update results with corrected p-values and significance
        for (key, row_index, _), p_value, is_rejected in zip(tests, corrected, rejected):
            if key not in fdr_corrected_results:
                fdr_corrected_results[key] = mediation_results[key].copy()
            fdr_corrected_results[key].loc[row_index, 'corrected_pval'] = p_value
            fdr_corrected_results[key].loc[row_index, 'significant_after_correction'] = bool(is_rejected)

    return fdr_corrected_results

# Run the two-level FDR correction over every stored mediation table
fdr_corrected_results = hierarchical_fdr_correction(mediation_results)

# Show each corrected table under its key, separated by blank lines
for key, results in fdr_corrected_results.items():
    print(f"Results for {key}:", results, "\n", sep="\n")