In [None]:
# -*- coding: utf-8 -*-
# --- Cell 1: Centralized Imports and Global Configuration ---

# --- Part 1: All Library Imports ---
# Python Core Libraries
import os
import glob
import re
import warnings
from datetime import datetime, time

# Data Handling & Scientific Computing
import numpy as np
import pandas as pd
import pytz
from scipy.stats import (pearsonr, spearmanr, mannwhitneyu, t, kruskal)
import scikit_posthocs as sp

# Statistical Modeling
import pingouin as pg
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.graphics.regressionplots import plot_partregress_grid
from statsmodels.stats.multitest import fdrcorrection
from statsmodels.stats.anova import anova_lm

# Plotting & Visualization
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.ticker import FormatStrFormatter
import seaborn as sns

# --- Part 2: Warning Filters ---
# Suppress two whole warning categories, then a handful of specific warnings
# selected by their message text.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.filterwarnings("ignore", message="Degrees of freedom <= 0 for slice")
warnings.filterwarnings("ignore", message="p-value may not be accurate for N > 5000")
warnings.filterwarnings("ignore", message="invalid value encountered in scalar divide")
warnings.filterwarnings("ignore", message="Confidence interval might not be reliable for bootstrap samples with fewer than 50 elements.")

# --- Part 3: User Input and Path Configuration ---
# This script processes ONE patient-hemisphere at a time.
# Every path below is derived from the three user-set values in this section.
patient_hemisphere_id = "RCS20R" # <<< USER: SET THIS FOR THE CURRENT PATIENT-HEMISPHERE
# Relative path, so it is resolved against the current working directory.
project_base_path = '..'
step3_output_version_tag = "neural_pkg_aligned" # <<< USER: Ensure this matches Step 3's tag

# Derived Paths (no user input needed below this line)
# Input: <base>/step3_fooof_results_<tag>/MASTER_FOOOF_PKG_results_<id>_<tag>.csv
step3_master_csv_base_folder = os.path.join(project_base_path, f'step3_fooof_results_{step3_output_version_tag}')
master_csv_filename = f"MASTER_FOOOF_PKG_results_{patient_hemisphere_id}_{step3_output_version_tag}.csv"
master_csv_path_to_load = os.path.join(step3_master_csv_base_folder, master_csv_filename)
# Output root: a 'step4_within_subject' folder inside the Step 3 results folder.
step4_analysis_root_folder = os.path.join(step3_master_csv_base_folder, 'step4_within_subject')
os.makedirs(step4_analysis_root_folder, exist_ok=True)
# Each execution of this cell creates a new, timestamped plot folder
# (<id>_plots_<YYYYmmdd_HHMMSS>) under the Step 4 root.
current_datetime_str_step4 = datetime.now().strftime('%Y%m%d_%H%M%S')
session_plot_folder_name_step4 = f"{patient_hemisphere_id}_plots_{current_datetime_str_step4}"
analysis_session_plot_folder_step4 = os.path.join(step4_analysis_root_folder, session_plot_folder_name_step4)
os.makedirs(analysis_session_plot_folder_step4, exist_ok=True)

# --- Part 4: Column Name and Metric Definitions ---
# Metric Column Dictionaries
# Each dict maps a master-CSV column name (key) to the display label (value)
# used for that metric in results and plots.
APERIODIC_METRICS_COLS = {
    'Exponent_BestModel': 'Aperiodic Exponent',
    'Offset_BestModel': 'Aperiodic Offset',
}
PKG_METRICS_COLS = {
    'Aligned_BK': 'PKG BK Score',
    'Aligned_DK': 'PKG DK Score',
    'Aligned_Tremor_Score': 'PKG Tremor Score'
}
OSCILLATORY_METRICS_COLS = {
    'Beta_Peak_Power_at_DominantFreq': 'Beta Peak Power',
    'Gamma_Peak_Power_at_DominantFreq': 'Gamma Peak Power'
}
APERIODIC_METRICS_TO_PLOT = ['Exponent_BestModel'] # Used in daily exponent plots

# Key Column Names
CHANNEL_COL = 'Channel'
CHANNEL_DISPLAY_COL = 'Channel_Display'
FOOOF_FREQ_BAND_COL = 'FreqRangeLabel'
CLINICAL_STATE_COL = 'Clinical_State_2min_Window'
CLINICAL_STATE_AGGREGATED_COL = 'Clinical_State_Aggregated'

# --- Part 5: Analysis and Plotting Parameters ---
# Ordering for Iterations and Plots
ORDERED_FREQ_LABELS = ["LowFreq", "MidFreq", "WideFreq"]
# Note: ORDERED_CHANNEL_LABELS itself is derived from the loaded data in Cell 3.
# The fixed channel ordering and grouping are defined here so that every cell
# shares one definition.
# CHANNEL_ORDER_MAP: channel display label -> sort position.
CHANNEL_ORDER_MAP = {
    'STN_DBS_2-0': 0, 'STN_DBS_3-1': 1, 
    'Cortical_ECoG_10-8': 2, 'Cortical_ECoG_11-9': 3
}
# Same four labels as CHANNEL_ORDER_MAP, listed in that order.
CHANNEL_ORDER_LIST = ['STN_DBS_2-0', 'STN_DBS_3-1', 'Cortical_ECoG_10-8', 'Cortical_ECoG_11-9']
# Group name -> the channel display labels belonging to that group.
CHANNEL_GROUP_MAP = {'STN': ['STN_DBS_2-0', 'STN_DBS_3-1'], 'M1': ['Cortical_ECoG_10-8', 'Cortical_ECoG_11-9']}


# Statistical Thresholds
P_VALUE_THRESHOLD = 0.05
# Default minimum number of complete rows required by the correlation helpers.
MIN_SAMPLES_FOR_CORR = 5
MIN_SAMPLES_FOR_GROUP_COMPARISON = 5
R2_FILTER_THRESHOLD = 0.5

# Clinical State Definitions and Ordering
# TARGET_* is ALL_* without "Sleep".
TARGET_CLINICAL_STATES_ORDERED = ["Immobile", "Non-Dyskinetic Mobile", "Transitional Mobile", "Dyskinetic Mobile"]
ALL_CLINICAL_STATES_ORDERED = ["Sleep", "Immobile", "Non-Dyskinetic Mobile", "Transitional Mobile", "Dyskinetic Mobile"]
# SYMPTOM_ORDER lists the PKG display labels (the values of PKG_METRICS_COLS);
# SYMPTOM_LEGEND_MAP maps each of them to a symptom name, and
# SYMPTOM_DISPLAY_ORDER lists those symptom names in the same order.
SYMPTOM_ORDER = ['PKG BK Score', 'PKG DK Score', 'PKG Tremor Score']
SYMPTOM_LEGEND_MAP = {'PKG BK Score': 'Bradykinesia', 'PKG DK Score': 'Dyskinesia', 'PKG Tremor Score': 'Tremor'}
SYMPTOM_DISPLAY_ORDER = ['Bradykinesia', 'Dyskinesia', 'Tremor']

# Daily Plot Parameters
# NOTE(review): these are not used in the cells visible here; their exact
# meaning should be confirmed against the daily-plot cell that consumes them.
SF_TZ = pytz.timezone('America/Los_Angeles')
PLOTTING_INTERVAL_MINUTES = 10
MIN_POINTS_FOR_CI = 2
CONFIDENCE_LEVEL_CI = 0.95
GAP_THRESHOLD_BINS = 2

# MLR Analysis Toggles
# The TARGET_* values are labels from ORDERED_FREQ_LABELS.
ANALYZE_ALL_FREQ_BANDS_GLOBAL = False
TARGET_FREQ_BAND_GLOBAL = "WideFreq"
ANALYZE_ALL_FREQ_BANDS_STATE_SPECIFIC = False
TARGET_FREQ_BAND_STATE_SPECIFIC = "WideFreq"


# --- Part 6: Global Plotting Style Configuration ---
# Seaborn and Matplotlib Global Theme
# The theme is set first; the rcParams assignments below then adjust fonts,
# text sizes and resolution on top of it.
sns.set_theme(style="whitegrid")
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
plt.rcParams['axes.titlesize'] = 14
plt.rcParams['axes.labelsize'] = 12
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['ytick.labelsize'] = 10
plt.rcParams['legend.fontsize'] = 10
plt.rcParams['figure.titlesize'] = 18
# On-screen figures use 100 dpi; saved figures use 600 dpi.
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 600

# Color Palettes
# BASE_COLOR_PALETTE: one color per metric column name (the keys of the
# aperiodic, oscillatory and PKG metric dicts).
BASE_COLOR_PALETTE = {
    'Exponent_BestModel': 'darkslateblue',
    'Offset_BestModel': 'mediumseagreen',
    'Beta_Peak_Power_at_DominantFreq': 'goldenrod',
    'Gamma_Peak_Power_at_DominantFreq': 'firebrick',
    'Aligned_BK': 'steelblue',
    'Aligned_DK': 'orangered',
    'Aligned_Tremor_Score': 'mediumpurple'
}
# CLINICAL_STATE_COLORS: one color per clinical-state label.
CLINICAL_STATE_COLORS = {
    'Immobile': '#40E0D0',              # Turquoise
    'Non-Dyskinetic Mobile': '#32CD32', # LimeGreen
    'Transitional Mobile': '#FFD700',   # Gold
    'Dyskinetic Mobile': '#FF6347',     # Tomato
    'Sleep': '#4169E1',                 # RoyalBlue
    'Other': '#C0C0C0',                 # Silver
    'Mobile (All Types)': 'darkgreen'
}
# The PKG subset of BASE_COLOR_PALETTE. Each .get() fallback repeats the color
# already stored in the palette above.
PKG_SYMPTOM_COLORS = {
    'Aligned_BK': BASE_COLOR_PALETTE.get('Aligned_BK', 'steelblue'),
    'Aligned_DK': BASE_COLOR_PALETTE.get('Aligned_DK', 'orangered'),
    'Aligned_Tremor_Score': BASE_COLOR_PALETTE.get('Aligned_Tremor_Score', 'mediumpurple')
}

# Other Plotting Style Constants
DOT_ALPHA = 0.5
REG_CI_ALPHA = 0.15
BOX_FILL_ALPHA = 0.6
BOXPLOT_LINE_THICKNESS = 2.25
REG_LINE_THICKNESS = 2.0
# Background colors for statistic annotation boxes (significant vs. default).
SIGNIFICANT_P_VAL_BG_COLOR = 'khaki'
DEFAULT_P_VAL_BG_COLOR = 'ivory'

print("Cell 1: All imports and global parameters have been defined.")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 2: User Input, File Paths, and Analysis Parameter Definitions ---

# --- User Input ---
# This script processes ONE patient-hemisphere at a time. patient_hemisphere_id
# is set in Cell 1, where every derived path (master CSV, plot folder) is built
# from it. It is deliberately NOT re-assigned here: changing it after Cell 1
# would leave those paths pointing at a different patient-hemisphere.
# The ID must match the identifier in the MASTER_FOOOF_PKG_results file name,
# e.g. "RCS20R" if your file is MASTER_FOOOF_PKG_results_RCS20R_... .csv
if not patient_hemisphere_id:
    raise ValueError("Patient-Hemisphere ID cannot be empty.")
print(f"Processing data for Patient-Hemisphere ID: {patient_hemisphere_id}")

# --- Path Configuration ---
# Report the paths that Cell 1 derived, so the run is easy to audit.
print(f"Project base path determined as: {project_base_path}")
print(f"Attempting to load master data from: {master_csv_path_to_load}")
print(f"Step 4 plots will be saved in: {analysis_session_plot_folder_step4}")

# --- Step 4 style aliases ---
# *_STEP4 names alias the Cell 1 style constants so both spellings always
# agree. The two background-color aliases are the names that the Cell 4
# annotation helper reads.
DOT_ALPHA_STEP4 = DOT_ALPHA
REG_CI_ALPHA_STEP4 = REG_CI_ALPHA
BOX_FILL_ALPHA_STEP4 = BOX_FILL_ALPHA
REG_LINE_THICKNESS_STEP4 = REG_LINE_THICKNESS
SIGNIFICANT_P_VAL_BG_COLOR_STEP4 = SIGNIFICANT_P_VAL_BG_COLOR
DEFAULT_P_VAL_BG_COLOR_STEP4 = DEFAULT_P_VAL_BG_COLOR

print(f"\nStep 4 analysis parameters and paths configured for {patient_hemisphere_id}.")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 3: Data Loading and Initial Preprocessing ---

def load_and_preprocess_step4_data(file_path, patient_hemisphere_id_val):
    """Load the Step 3 master CSV and preprocess it for the Step 4 analyses.

    Processing steps, in order:
      1. Read the CSV and check (or add) the 'SessionID' column.
      2. Coerce the metric columns to numeric (unparseable values become NaN)
         and cast the key categorical columns to ``str``.
      3. Keep only rows whose 'R2_BestModel' is >= R2_FILTER_THRESHOLD.
      4. Create the channel display column from a fallback map if it is absent.
      5. Derive the channel order: labels known to CHANNEL_ORDER_MAP first (in
         that order), any other labels afterwards in natural-sort order.
      6. Drop rows with NaN in any aperiodic, PKG or oscillatory metric column.

    Args:
        file_path: Path of the master CSV written by Step 3.
        patient_hemisphere_id_val: Expected patient-hemisphere ID. It is
            compared with 'SessionID' (a mismatch only produces a warning) and
            is used to fill 'SessionID' when that column is missing.

    Returns:
        tuple: ``(df, ordered_channel_labels)``. ``(None, [])`` is returned
        when the file does not exist or when any exception is raised while
        loading or processing; the error and its traceback are printed.
    """
    # Local import: the notebook's shared import cell does not import
    # traceback, and the error handler below needs it.
    import traceback

    if not os.path.exists(file_path):
        print(f"ERROR: Master CSV file from Step 3 not found at {file_path}")
        return None, []

    try:
        df = pd.read_csv(file_path)
        print(f"Successfully loaded {file_path}. Initial shape: {df.shape}")

        # --- Verify Patient ID consistency (warnings only, never fatal) ---
        if 'SessionID' not in df.columns:
            print(f"Warning: 'SessionID' column not found. Adding it based on patient_hemisphere_id_val: {patient_hemisphere_id_val}")
            df['SessionID'] = patient_hemisphere_id_val
        else:
            session_ids = df['SessionID'].dropna().unique()
            if len(session_ids) == 1:
                csv_session_id = session_ids[0]
                if csv_session_id != patient_hemisphere_id_val:
                    print(f"Warning: SessionID in CSV ({csv_session_id}) differs from expected ({patient_hemisphere_id_val}). Proceeding with CSV data.")
            elif len(session_ids) > 1:
                print(f"Warning: 'SessionID' has {len(session_ids)} distinct values in CSV; expected only {patient_hemisphere_id_val}. Proceeding with CSV data.")

        # --- Data Type Conversions & Cleaning ---
        cols_to_numeric = (
            list(APERIODIC_METRICS_COLS.keys()) +
            list(PKG_METRICS_COLS.keys()) +
            list(OSCILLATORY_METRICS_COLS.keys()) +
            ['Total_Daily_LEDD_mg', 'R2_BestModel', 'Error_BestModel', 'Num_Peaks_BestModel']
        )
        for col in cols_to_numeric:
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors='coerce')
            else:
                print(f"Warning: Expected numeric column '{col}' not found in master_df.")

        # Ensure key categorical columns are strings.
        # NOTE: astype(str) also turns missing values into the literal string 'nan'.
        for col in [CHANNEL_COL, CHANNEL_DISPLAY_COL, FOOOF_FREQ_BAND_COL, CLINICAL_STATE_COL, CLINICAL_STATE_AGGREGATED_COL, 'Hemisphere', 'BestModel_AperiodicMode']:
            if col in df.columns:
                df[col] = df[col].astype(str)
            else:
                print(f"Warning: Expected categorical column '{col}' not found.")

        # --- Filtering based on FOOOF fit quality ---
        initial_rows = len(df)
        if 'R2_BestModel' in df.columns:
            # Rows with NaN R2 fail the comparison and are dropped as well.
            df = df[df['R2_BestModel'] >= R2_FILTER_THRESHOLD].copy()
            print(f"Filtered by R2_BestModel >= {R2_FILTER_THRESHOLD}. Rows changed from {initial_rows} to {len(df)}.")
        else:
            print("Warning: 'R2_BestModel' column not found. Cannot filter by FOOOF fit quality.")

        # --- Fallback: create Channel_Display if Step 3 did not provide it ---
        if CHANNEL_DISPLAY_COL not in df.columns and CHANNEL_COL in df.columns:
            print(f"'{CHANNEL_DISPLAY_COL}' not found. Creating from '{CHANNEL_COL}'. Hardcoded map will be used if available.")
            # Hardcoded map (ensure this is consistent with Step 3 and your data)
            channel_mapping_step4 = {
                'TD_key0': 'STN_DBS_2-0', 'TD_key1': 'STN_DBS_3-1',
                'TD_key2': 'Cortical_ECoG_10-8', 'TD_key3': 'Cortical_ECoG_11-9'
            }
            # Channels missing from the map keep their raw channel name.
            df[CHANNEL_DISPLAY_COL] = df[CHANNEL_COL].map(channel_mapping_step4).fillna(df[CHANNEL_COL])

        # --- Derive the ordered channel labels present in the data ---
        if CHANNEL_DISPLAY_COL in df.columns:
            def natural_sort_key_step4(s):
                # Split into digit and non-digit runs so "ch10" sorts after "ch9".
                return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', str(s))]

            unique_ch_labels_data = df[CHANNEL_DISPLAY_COL].unique()
            # Known channels first, in the order fixed by CHANNEL_ORDER_MAP ...
            ordered_ch_labels_from_data = sorted(
                [ch for ch in unique_ch_labels_data if ch in CHANNEL_ORDER_MAP],
                key=lambda x: CHANNEL_ORDER_MAP[x]
            )
            # ... then any unexpected channels, sorted naturally.
            ordered_ch_labels_from_data.extend(
                sorted([ch for ch in unique_ch_labels_data if ch not in CHANNEL_ORDER_MAP],
                       key=natural_sort_key_step4)
            )
            print(f"Derived ORDERED_CHANNEL_LABELS for plots: {ordered_ch_labels_from_data}")
        else:
            print(f"ERROR: '{CHANNEL_DISPLAY_COL}' not found. Cannot determine channel order.")
            ordered_ch_labels_from_data = []

        # --- Drop rows where any essential metric is NaN ---
        # The analysis loops iterate over all of these metrics, so a row is
        # only kept when every one of them is present.
        key_metrics_for_analysis = (
            list(APERIODIC_METRICS_COLS.keys()) +
            list(PKG_METRICS_COLS.keys()) +
            list(OSCILLATORY_METRICS_COLS.keys())
        )
        key_metrics_present = [col for col in key_metrics_for_analysis if col in df.columns]

        rows_before_na_essential_drop = len(df)
        if key_metrics_present:
            df = df.dropna(subset=key_metrics_present, how='any')
            print(f"Dropped rows with NaNs in any of {key_metrics_present}. Rows changed from {rows_before_na_essential_drop} to {len(df)}.")
        else:
            print("Warning: No key metrics found to check for NaNs. Data might be incomplete.")

        print(f"Final master_df for Step 4 shape: {df.shape}")
        if df.empty:
            print("Warning: DataFrame is empty after preprocessing. Subsequent analyses might fail.")

        return df, ordered_ch_labels_from_data

    except Exception as e:
        # Best-effort loader: report the failure and hand back the
        # "nothing loaded" result instead of propagating the exception.
        print(f"Error loading or processing the CSV file '{file_path}' for Step 4: {e}")
        print(traceback.format_exc())
        return None, []

# --- Load and Preprocess Data ---
# Run the loader once; both results become notebook-level globals.
master_df_step4, ORDERED_CHANNEL_LABELS = load_and_preprocess_step4_data(
    master_csv_path_to_load,
    patient_hemisphere_id,
)

if master_df_step4 is None or master_df_step4.empty:
    # Only a message is printed here; execution is not actually stopped.
    print("Halting Step 4 script as master data could not be loaded or is empty after preprocessing.")
    # sys.exit() # Uncomment to halt execution if master_df_step4 is not loaded
else:
    print("\nFirst 5 rows of the processed master DataFrame for Step 4 (showing key cols):")
    # Preview the identifying columns, every metric column and the LEDD column.
    cols_to_show_step4 = [
        CHANNEL_DISPLAY_COL,
        FOOOF_FREQ_BAND_COL,
        *APERIODIC_METRICS_COLS,
        *PKG_METRICS_COLS,
        *OSCILLATORY_METRICS_COLS,
        'Total_Daily_LEDD_mg',
    ]
    # Restrict the preview to columns that exist in the loaded frame.
    cols_to_show_step4_present = [
        name for name in cols_to_show_step4 if name in master_df_step4.columns
    ]
    print(master_df_step4.loc[:, cols_to_show_step4_present].head())

print("\nCell 3: Data loading and preprocessing complete for Step 4.")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 4: Helper Functions for Step 4 ---

def calculate_spearman_with_n(data_df, col1, col2, min_samples=MIN_SAMPLES_FOR_CORR):
    """Spearman correlation between two columns, with the sample size used.

    Rows with a missing value in either column are discarded first.

    Returns:
        tuple: ``(rho, p_value, N)`` where N is the number of complete rows.
        ``rho`` and ``p_value`` are NaN when N < ``min_samples``, when the
        coefficient itself is NaN, or when a ValueError is raised.
    """
    complete_rows = data_df.loc[:, [col1, col2]].dropna()
    sample_count = len(complete_rows)
    unavailable = (np.nan, np.nan, sample_count)

    if sample_count < min_samples:
        return unavailable

    try:
        coefficient, significance = spearmanr(complete_rows[col1], complete_rows[col2])
        # A NaN coefficient (e.g. no variance) is reported as "no result".
        if not np.isnan(coefficient):
            return coefficient, significance, sample_count
        return unavailable
    except ValueError:
        return unavailable

def calculate_partial_spearman(data_df, x_col, y_col, covar_cols, min_samples=MIN_SAMPLES_FOR_CORR):
    """Partial Spearman correlation of x and y, controlling for covariates.

    Rows with a missing value in x, y or any covariate are discarded first.

    Returns:
        tuple: ``(partial_rho, p_value, N)`` where N is the number of complete
        rows. The first two entries are NaN when N < ``min_samples``, when any
        involved column is constant, or when the computation raises.
    """
    involved_cols = [x_col, y_col] + covar_cols
    complete_rows = data_df[involved_cols].dropna()
    sample_count = len(complete_rows)
    unavailable = (np.nan, np.nan, sample_count)

    if sample_count < min_samples:
        return unavailable

    try:
        # A column with a single distinct value makes the result undefined.
        has_constant_column = any(
            complete_rows[name].nunique() <= 1
            for name in involved_cols
            if name in complete_rows
        )
        if has_constant_column:
            return unavailable

        stats_table = pg.partial_corr(
            data=complete_rows, x=x_col, y=y_col, covar=covar_cols, method='spearman'
        )
        coefficient, significance = (stats_table[key].iloc[0] for key in ('r', 'p-val'))
        return coefficient, significance, sample_count
    except Exception:  # any failure in the computation yields "no result"
        return unavailable


def annotate_correlation_on_plot(ax, rho, p_value, N_val, test_type="Spearman ρ",
                                 x_pos=0.97, y_pos=0.97, fontsize=9,
                                 sig_threshold=P_VALUE_THRESHOLD):
    """Write a correlation summary box onto a plot axis.

    Args:
        ax: Matplotlib axis to annotate.
        rho: Correlation coefficient (may be NaN).
        p_value: P-value of the correlation (may be NaN).
        N_val: Sample size shown in the annotation.
        test_type: Label printed in front of the coefficient.
        x_pos, y_pos: Anchor of the text box in axes coordinates; the box is
            right/top aligned to this point.
        fontsize: Font size of the annotation text.
        sig_threshold: P-values strictly below this are marked significant
            (starred and given the "significant" background color).

    Returns:
        None. The axis is modified in place.
    """
    # The *_STEP4 color names are not defined by the configuration cells, so
    # fall back to the Cell 1 constants. A *_STEP4 value is still honoured
    # when some cell defines one.
    default_bg = globals().get('DEFAULT_P_VAL_BG_COLOR_STEP4', DEFAULT_P_VAL_BG_COLOR)
    significant_bg = globals().get('SIGNIFICANT_P_VAL_BG_COLOR_STEP4', SIGNIFICANT_P_VAL_BG_COLOR)

    if pd.isna(rho) or pd.isna(p_value):
        stat_text = f"{test_type}: N/A (N={N_val})"
        bg_color = default_bg
    else:
        is_significant = p_value < sig_threshold
        # Stars are shown only for significant results, so the stars and the
        # background color can never disagree for a non-default threshold.
        stars = ""
        if is_significant:
            if p_value < 0.001:
                stars = "***"
            elif p_value < 0.01:
                stars = "**"
            else:
                stars = "*"
        stat_text = f"{test_type}={rho:.2f}{stars}\np={p_value:.3g}\n(N={N_val})"
        bg_color = significant_bg if is_significant else default_bg

    ax.text(x_pos, y_pos, stat_text, transform=ax.transAxes, fontsize=fontsize,
            verticalalignment='top', horizontalalignment='right',
            bbox=dict(boxstyle='round,pad=0.3', fc=bg_color, alpha=0.85, edgecolor='darkgrey'))

def get_safe_filename_step4(base_name):
    """Turn an arbitrary value into a filesystem-safe file name fragment.

    Characters other than word characters, whitespace and '-' are removed,
    surrounding whitespace is stripped, and every remaining space or '-' is
    replaced by '_'.
    """
    cleaned = re.sub(r'[^\w\s-]', '', str(base_name))
    trimmed = cleaned.strip()
    to_underscore = str.maketrans({' ': '_', '-': '_'})
    return trimmed.translate(to_underscore)

def trim_data_for_boxplot_visualization(df_group, value_col):
    """Return the rows whose value lies within 1.5 * IQR of the quartiles.

    The input is returned unchanged when it is empty, has fewer than two
    rows, has only missing values in ``value_col``, or has an IQR of zero.
    Intended for display only; the caller's data is not modified.
    """
    if df_group.empty:
        return df_group

    values = df_group[value_col]
    if values.isnull().all() or len(df_group) < 2:
        return df_group

    first_quartile = values.quantile(0.25)
    third_quartile = values.quantile(0.75)
    spread = third_quartile - first_quartile
    if spread == 0:  # all central values identical: nothing sensible to trim
        return df_group

    whisker = 1.5 * spread
    within_fences = values.between(first_quartile - whisker, third_quartile + whisker)
    return df_group[within_fences]


# Progress marker printed once the Cell 4 helper definitions have been executed.
print("Cell 4: Helper functions for Step 4 defined.")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 5_PREAMBLE: Definitions for State-Specific Analyses (Revised Order: No Sleep, 4 Separate Mobile/Immobile States) ---

# --- Define Target Clinical States and their Order (4 States, No Sleep, New Order) ---
# This re-assigns the Cell 1 constant of the same name with the same four states.
TARGET_CLINICAL_STATES_ORDERED = [
    "Immobile",
    "Non-Dyskinetic Mobile",
    "Transitional Mobile",
    "Dyskinetic Mobile"
]

# Independent copy of the list above, so changing one does not change the other.
ORIGINAL_STATES_FOR_ANALYSIS = TARGET_CLINICAL_STATES_ORDERED[:]

# No combining needed for this setup
STATES_TO_COMBINE_MAPPING = {}
NEW_COMBINED_STATE_NAME = None


# --- Define Clinical State Colors (using original distinct colors, excluding Sleep) ---
# The color definitions themselves don't change, but their application order will follow TARGET_CLINICAL_STATES_ORDERED
# The entries are the same as in CLINICAL_STATE_COLORS from Cell 1.
NEW_CLINICAL_STATE_COLORS_FOR_PLOTTING = {
    'Immobile': '#40E0D0',              # Turquoise
    'Non-Dyskinetic Mobile': '#32CD32', # LimeGreen
    'Transitional Mobile': '#FFD700',   # Gold
    'Dyskinetic Mobile': '#FF6347',     # Tomato
    # Fallbacks or other states if they were to appear unexpectedly
    'Sleep': '#4169E1',                 # RoyalBlue (original, but excluded from TARGET_CLINICAL_STATES_ORDERED)
    'Other': '#C0C0C0',                 # Silver
    'Mobile (All Types)': 'darkgreen'   # For aggregated view if ever used
}


# --- Define PKG Symptom Colors (remains the same) ---
# Identical to the PKG_SYMPTOM_COLORS definition in Cell 1.
PKG_SYMPTOM_COLORS = {
    'Aligned_BK': BASE_COLOR_PALETTE.get('Aligned_BK', 'steelblue'),
    'Aligned_DK': BASE_COLOR_PALETTE.get('Aligned_DK', 'orangered'),
    'Aligned_Tremor_Score': BASE_COLOR_PALETTE.get('Aligned_Tremor_Score', 'mediumpurple')
}

# Output directory for the state-specific analyses. os.path.join() with a
# single argument adds nothing, so this is the session plot folder from Cell 1
# itself; later cells create their own subdirectories inside it.
STATE_SPECIFIC_ANALYSIS_DIR = os.path.join(analysis_session_plot_folder_step4)
os.makedirs(STATE_SPECIFIC_ANALYSIS_DIR, exist_ok=True)

print("Cell 5_PREAMBLE: Definitions for state-specific analyses (4 States - No Sleep, Separate Mobile, New Order) are set.")
print(f"Target clinical states for analysis (NEW ORDER): {TARGET_CLINICAL_STATES_ORDERED}")
print(f"Colors for clinical states: {NEW_CLINICAL_STATE_COLORS_FOR_PLOTTING}") # This dict remains the same, order of use changes
print(f"State-specific outputs will be saved in subdirectories of: {STATE_SPECIFIC_ANALYSIS_DIR}")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 5A (Revised): State-Specific Correlations with FDR Correction ---
# This version calculates all bivariate and partial correlations for each of the
# TARGET_CLINICAL_STATES_ORDERED, pools all p-values, and applies a single
# Benjamini-Hochberg FDR correction to the entire family of tests.

import pandas as pd
import numpy as np
import os
import statsmodels.api as sm
from statsmodels.stats.multitest import fdrcorrection

def _export_correlation_family(results_df, test_type, column_renames, output_dir, csv_name):
    """Save one test family of the pooled correlation results and return it.

    Rows whose 'TestType' equals ``test_type`` are selected, the generic
    Metric1/Metric2/Rho columns are renamed via ``column_renames``, and the
    then-redundant 'TestType' column is dropped. The table is written to
    ``output_dir/csv_name`` without the index.
    """
    family_df = (
        results_df[results_df['TestType'] == test_type]
        .rename(columns=column_renames)
        .drop(columns='TestType')
    )
    family_df.to_csv(os.path.join(output_dir, csv_name), index=False)
    return family_df


print("\n--- Cell 5A (Revised): Starting State-Specific Correlation Calculations with FDR Correction ---")

# The main dataframe from Cell 3 must exist. The loader returns None on
# failure, so treat "undefined" and "None" the same way and stop with a clear
# message instead of an AttributeError/NameError further down.
if globals().get('master_df_step4') is None:
    raise RuntimeError(
        "CRITICAL ERROR: master_df_step4 not available. Cannot proceed with Cell 5A. Please run prior cells."
    )
if master_df_step4.empty:
    # Not fatal: the state filter below yields an empty frame and the analysis is skipped.
    print("CRITICAL ERROR: master_df_step4 not available or empty. Cannot proceed with Cell 5A. Please run prior cells.")

# --- Step 1: Data Preparation and Filtering ---
# Keep only the rows whose clinical state is one of TARGET_CLINICAL_STATES_ORDERED
# (defined in Cell 5_PREAMBLE).
master_df_step4_filtered_states = master_df_step4[
    master_df_step4[CLINICAL_STATE_COL].isin(TARGET_CLINICAL_STATES_ORDERED)
].copy()

if master_df_step4_filtered_states.empty:
    print(f"No data found after filtering for target clinical states: {TARGET_CLINICAL_STATES_ORDERED}. Skipping Cell 5A.")
else:
    # Treat the clinical state as an ordered categorical for consistent ordering.
    master_df_step4_filtered_states[CLINICAL_STATE_COL] = pd.Categorical(
        master_df_step4_filtered_states[CLINICAL_STATE_COL],
        categories=TARGET_CLINICAL_STATES_ORDERED,
        ordered=True
    )
    master_df_step4_filtered_states.dropna(subset=[CLINICAL_STATE_COL], inplace=True)
    print(f"Filtered data for target clinical states. Shape: {master_df_step4_filtered_states.shape}.")

    # Output directory for the correlation result CSVs.
    state_corr_csv_dir = os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "Correlation_CSVs_by_State_FDR_Corrected")
    os.makedirs(state_corr_csv_dir, exist_ok=True)

    # --- Step 2: Collect P-Values from All Planned Tests ---
    # Raw result of EVERY correlation test; FDR correction is applied
    # afterwards to the pooled family of p-values.
    all_correlation_results = []

    print("\nCalculating all correlations across states, channels, and frequency bands...")
    for state_current in TARGET_CLINICAL_STATES_ORDERED:
        df_state = master_df_step4_filtered_states[master_df_step4_filtered_states[CLINICAL_STATE_COL] == state_current]
        if df_state.empty:
            continue

        for channel_label in ORDERED_CHANNEL_LABELS:
            df_channel_state = df_state[df_state[CHANNEL_DISPLAY_COL] == channel_label]
            if df_channel_state.empty:
                continue

            for freq_label in ORDERED_FREQ_LABELS:
                df_channel_freq_state = df_channel_state[df_channel_state[FOOOF_FREQ_BAND_COL] == freq_label].copy()
                if df_channel_freq_state.empty:
                    continue

                available_cols = df_channel_freq_state.columns
                # Identifying fields shared by every result row of this subset.
                subset_id = {'ClinicalState': state_current, 'Channel': channel_label, 'FreqBand': freq_label}

                # --- Test Family 1: Bivariate Aperiodic vs. PKG ---
                # All valid rows are used (zeros are not excluded).
                for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                    for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                        if ap_col in available_cols and pkg_col in available_cols:
                            rho, p_val, N = calculate_spearman_with_n(df_channel_freq_state, ap_col, pkg_col)
                            all_correlation_results.append({
                                'TestType': 'Bivar_AP_PKG', **subset_id,
                                'Metric1': ap_name, 'Metric2': pkg_name, 'Rho': rho, 'P_Value_Original': p_val, 'N': N
                            })

                # --- Test Family 2: Partial Aperiodic vs. PKG (controlling for Oscillatory) ---
                # Only run when every oscillatory covariate column is present.
                covariates = [col for col in OSCILLATORY_METRICS_COLS.keys() if col in available_cols]
                if len(covariates) == len(OSCILLATORY_METRICS_COLS):
                    for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                        for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                            if ap_col in available_cols and pkg_col in available_cols:
                                partial_rho, partial_p_val, N_partial = calculate_partial_spearman(
                                    df_channel_freq_state, ap_col, pkg_col, covariates
                                )
                                all_correlation_results.append({
                                    'TestType': 'Partial_AP_PKG', **subset_id,
                                    'Metric1': ap_name, 'Metric2': pkg_name, 'Rho': partial_rho, 'P_Value_Original': partial_p_val, 'N': N_partial
                                })

                # --- Test Family 3: Bivariate Aperiodic vs. Oscillatory ---
                for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                    for osc_col, osc_name in OSCILLATORY_METRICS_COLS.items():
                        if ap_col in available_cols and osc_col in available_cols:
                            rho_ap_osc, p_val_ap_osc, N_ap_osc = calculate_spearman_with_n(df_channel_freq_state, ap_col, osc_col)
                            all_correlation_results.append({
                                'TestType': 'Bivar_AP_Osc', **subset_id,
                                'Metric1': ap_name, 'Metric2': osc_name, 'Rho': rho_ap_osc, 'P_Value_Original': p_val_ap_osc, 'N': N_ap_osc
                            })

    # --- Step 3: Apply Benjamini-Hochberg FDR Correction ---
    print(f"\nCollected a total of {len(all_correlation_results)} p-values for FDR correction.")

    if not all_correlation_results:
        print("No correlation results were generated. Cannot perform FDR correction.")
    else:
        df_all_results = pd.DataFrame(all_correlation_results)

        # Tests that could not be computed have NaN p-values; they are left
        # out of the correction and stay NaN in the adjusted column.
        p_values_to_correct = df_all_results['P_Value_Original'].dropna()

        if p_values_to_correct.empty:
            print("No valid p-values to correct. All correlations may have had insufficient data.")
            df_all_results['P_Value_FDR_Adjusted'] = np.nan
            df_all_results['Significant_FDR_0.05'] = False
        else:
            # statsmodels' fdrcorrection with method='indep' is the BH procedure.
            rejected, pvals_corrected = fdrcorrection(p_values_to_correct, alpha=0.05, method='indep', is_sorted=False)

            # Re-attach the original row index so the assignment below aligns
            # each adjusted p-value with the row it came from.
            corrected_p_series = pd.Series(pvals_corrected, index=p_values_to_correct.index)

            df_all_results['P_Value_FDR_Adjusted'] = corrected_p_series
            # Comparisons with NaN are False, so uncomputed tests are never flagged.
            df_all_results['Significant_FDR_0.05'] = df_all_results['P_Value_FDR_Adjusted'] < 0.05
            print("FDR correction applied successfully.")

        # --- Step 4: Separate and Save Final Results ---
        # Split the pooled table back into the three test families.

        # 1. Bivariate AP vs PKG
        df_bivar_ap_pkg_final = _export_correlation_family(
            df_all_results, 'Bivar_AP_PKG',
            {'Metric1': 'AperiodicMetric', 'Metric2': 'PKGMetric', 'Rho': 'SpearmanRho'},
            state_corr_csv_dir, f"{patient_hemisphere_id}_Bivariate_AP_vs_PKG_ByState_FDR_AllData.csv"
        )
        print(f"\nSaved FDR-Corrected Bivariate AP vs PKG results (all data) for {patient_hemisphere_id}.")

        # 2. Partial AP vs PKG
        df_partial_ap_pkg_final = _export_correlation_family(
            df_all_results, 'Partial_AP_PKG',
            {'Metric1': 'AperiodicMetric', 'Metric2': 'PKGMetric', 'Rho': 'PartialSpearmanRho_vs_BetaGamma'},
            state_corr_csv_dir, f"{patient_hemisphere_id}_Partial_AP_vs_PKG_ByState_FDR_AllData.csv"
        )
        print(f"Saved FDR-Corrected Partial AP vs PKG results (all data) for {patient_hemisphere_id}.")

        # 3. Bivariate AP vs Oscillatory
        df_bivar_ap_osc_final = _export_correlation_family(
            df_all_results, 'Bivar_AP_Osc',
            {'Metric1': 'AperiodicMetric', 'Metric2': 'OscillatoryMetric', 'Rho': 'SpearmanRho'},
            state_corr_csv_dir, f"{patient_hemisphere_id}_Bivariate_AP_vs_Oscillatory_ByState_FDR_AllData.csv"
        )
        print(f"Saved FDR-Corrected Bivariate AP vs Oscillatory results (all data) for {patient_hemisphere_id}.")

        print("\nSample of final FDR-corrected Bivariate AP vs PKG results:")
        print(df_bivar_ap_pkg_final.head())


print("\n--- Cell 5A (Revised) with FDR Correction Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 5B (New): State-Specific Overview Scatter Plots (Aperiodic vs. PKG) (Revised: No Sleep, 4 Separate Mobile/Immobile States) ---
# Generates overview scatter plots: one figure per (PKG_Symptom, Aperiodic_Metric, Channel, FreqBand),
# with subplots for each clinical state (now 4 states). Y-axes are standardized within each figure.
# Scatter plot points are 10-minute averages. Regression line uses all granular data.

print("\n--- Cell 5B (New): Starting State-Specific Overview Scatter Plot Generation (4 States - No Sleep, 10-min avg pts) ---")


def compute_padded_ylim_step4(values, pad_fraction=0.1, flat_padding=0.1):
    """Return ``(min_y, max_y)`` spanning `values` plus a margin on each side.

    Args:
        values: Sequence of numeric y-values pooled across all subplots.
        pad_fraction: Fraction of the data range added below and above.
        flat_padding: Absolute margin used when the data range is zero.

    Returns:
        Tuple ``(min_y, max_y)``. Both entries are NaN when `values` is empty
        or contains only NaN, which callers treat as "keep autoscaling".
    """
    if len(values) == 0:
        return (np.nan, np.nan)
    lowest = np.nanmin(values)
    highest = np.nanmax(values)
    if np.isnan(lowest) or np.isnan(highest):
        return (np.nan, np.nan)
    span = highest - lowest
    padding = span * pad_fraction if span > 0 else flat_padding
    return (lowest - padding, highest + padding)


# The state-filtered master table is mandatory; the per-state correlation table
# is only used to annotate the subplots and is therefore optional.
if 'master_df_step4_filtered_states' not in locals() or master_df_step4_filtered_states.empty:
    print("master_df_step4_filtered_states (with 4 states - No Sleep) not available or empty. Skipping Cell 5B.")
else:
    if 'df_bivar_ap_pkg_state' not in locals() or df_bivar_ap_pkg_state.empty:
        # Missing stats must not suppress the plots: fall back to an empty table
        # so every lookup below yields NaN rho/p and the granular sample count.
        print("Bivariate AP vs PKG by State correlation results (df_bivar_ap_pkg_state) not found. Skipping Cell 5B plot annotations.")
        df_bivar_ap_pkg_state = pd.DataFrame(columns=['ClinicalState', 'Channel', 'FreqBand', 'AperiodicMetric', 'PKGMetric', 'SpearmanRho', 'PValue', 'N'])

    # UPDATED FOLDER NAME
    plot_subdir_overview_scatter_state = os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "Overview_Scatter_AP_vs_PKG_by_State")
    os.makedirs(plot_subdir_overview_scatter_state, exist_ok=True)

    # 'datetime_for_avg' is expected from the Cell 5A modifications.
    if 'datetime_for_avg' not in master_df_step4_filtered_states.columns:
        print("ERROR in Cell 5B: 'datetime_for_avg' column missing from master_df_step4_filtered_states. Averaging will fail.")
    else:
        n_states_overview = len(TARGET_CLINICAL_STATES_ORDERED)

        for pkg_col_overview, pkg_name_overview in PKG_METRICS_COLS.items():
            if pkg_col_overview not in master_df_step4_filtered_states.columns:
                print(f"PKG Metric {pkg_name_overview} ({pkg_col_overview}) not found in data. Skipping its overview plots.")
                continue

            for ap_metric_overview_col, ap_metric_overview_name in APERIODIC_METRICS_COLS.items():
                if ap_metric_overview_col not in master_df_step4_filtered_states.columns:
                    print(f"Aperiodic Metric {ap_metric_overview_name} ({ap_metric_overview_col}) not found. Skipping its overview plots.")
                    continue

                print(f"\nGenerating overview plots for: {ap_metric_overview_name} vs. {pkg_name_overview}")

                # Rows must have both metrics and a timestamp to be plotted/averaged.
                required_cols_overview = [ap_metric_overview_col, pkg_col_overview, 'datetime_for_avg']

                for channel_label_overview in ORDERED_CHANNEL_LABELS:
                    for freq_label_overview in ORDERED_FREQ_LABELS:
                        channel_freq_mask = (
                            (master_df_step4_filtered_states[CHANNEL_DISPLAY_COL] == channel_label_overview) &
                            (master_df_step4_filtered_states[FOOOF_FREQ_BAND_COL] == freq_label_overview)
                        )

                        # Pass 1: slice each state once; the slices feed both the
                        # shared y-limits and the subplots drawn below.
                        state_frames = []
                        state_has_enough_data = []
                        all_ap_values_for_ylim = []
                        for state_overview in TARGET_CLINICAL_STATES_ORDERED:
                            state_mask = channel_freq_mask & (master_df_step4_filtered_states[CLINICAL_STATE_COL] == state_overview)
                            df_state = master_df_step4_filtered_states[state_mask].dropna(subset=required_cols_overview)
                            enough = (not df_state.empty) and len(df_state) >= MIN_SAMPLES_FOR_CORR
                            state_frames.append(df_state)
                            state_has_enough_data.append(enough)
                            if enough:
                                all_ap_values_for_ylim.extend(df_state[ap_metric_overview_col].tolist())

                        # A figure is only ever saved when at least one state is
                        # plottable, so skip building it otherwise.
                        if not any(state_has_enough_data):
                            continue

                        min_y, max_y = compute_padded_ylim_step4(all_ap_values_for_ylim)

                        fig_width = max(15, 4 * n_states_overview)  # Now 4 states
                        fig, axes = plt.subplots(1, n_states_overview, figsize=(fig_width, 5.5), sharey=False)
                        if n_states_overview == 1:  # plt.subplots returns a bare Axes for a single panel
                            axes = [axes]

                        fig.suptitle(f"{ap_metric_overview_name} vs. {pkg_name_overview}\nChannel: {channel_label_overview} - Freq: {freq_label_overview} - Patient: {patient_hemisphere_id}\n(Scatter points are 10-min averages; Regression on all raw data)",
                                     fontsize=plt.rcParams['figure.titlesize'] * 0.85, y=1.05)

                        # Pass 2: draw one subplot per clinical state.
                        for i, state_overview in enumerate(TARGET_CLINICAL_STATES_ORDERED):
                            ax = axes[i]
                            df_plot_data_granular_plot = state_frames[i]

                            if state_has_enough_data[i]:
                                corr_stats_row = df_bivar_ap_pkg_state[
                                    (df_bivar_ap_pkg_state['ClinicalState'] == state_overview) &
                                    (df_bivar_ap_pkg_state['Channel'] == channel_label_overview) &
                                    (df_bivar_ap_pkg_state['FreqBand'] == freq_label_overview) &
                                    (df_bivar_ap_pkg_state['AperiodicMetric'] == ap_metric_overview_name) &
                                    (df_bivar_ap_pkg_state['PKGMetric'] == pkg_name_overview)
                                ]
                                if corr_stats_row.empty:
                                    rho, p_val = np.nan, np.nan
                                    N_val = len(df_plot_data_granular_plot)
                                else:
                                    rho = corr_stats_row['SpearmanRho'].iloc[0]
                                    p_val = corr_stats_row['PValue'].iloc[0]
                                    N_val = corr_stats_row['N'].iloc[0]

                                # Scatter points are 10-minute means. The slice is
                                # non-empty and NaN-free in 'datetime_for_avg' here,
                                # so only a failure of the grouping itself (e.g. a
                                # non-datetime column) triggers the granular fallback.
                                try:
                                    df_averaged_points_plot = (
                                        df_plot_data_granular_plot.set_index('datetime_for_avg')
                                        .groupby(pd.Grouper(freq='10min'))[[ap_metric_overview_col, pkg_col_overview]]
                                        .mean()
                                        .dropna()
                                    )
                                except Exception as e_avg:
                                    print(f"Warning: 10-min averaging failed for {channel_label_overview}, {freq_label_overview}, {state_overview}. Plotting granular points. Error: {e_avg}")
                                    df_averaged_points_plot = df_plot_data_granular_plot

                                if not df_averaged_points_plot.empty:
                                    sns.scatterplot(data=df_averaged_points_plot, x=pkg_col_overview, y=ap_metric_overview_col,
                                                    color=NEW_CLINICAL_STATE_COLORS_FOR_PLOTTING.get(state_overview, 'grey'),
                                                    alpha=DOT_ALPHA_STEP4 + 0.2, s=40, edgecolor='black', linewidths=0.5, ax=ax, legend=False)

                                # The regression line is fitted on the granular rows, not the averages.
                                sns.regplot(data=df_plot_data_granular_plot, x=pkg_col_overview, y=ap_metric_overview_col, scatter=False, ax=ax,
                                            line_kws={'color': 'black', 'linewidth': 1.5, 'alpha': 0.6})

                                annotate_correlation_on_plot(ax, rho, p_val, N_val, fontsize=8)
                            else:
                                placeholder_text = "No Data" if df_plot_data_granular_plot.empty else "N < min_samples"
                                ax.text(0.5, 0.5, placeholder_text,
                                        horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, fontsize=9)

                            ax.set_title(state_overview, fontsize=plt.rcParams['axes.titlesize']*0.8)
                            # Single x-label on the middle panel, shown even when that
                            # panel itself has too few samples.
                            ax.set_xlabel(pkg_name_overview if i == n_states_overview // 2 else "", fontsize=plt.rcParams['axes.labelsize']*0.9)

                            if not pd.isna(min_y) and not pd.isna(max_y):
                                ax.set_ylim(min_y, max_y)

                            # Only the left-most panel carries the y-axis label and tick labels.
                            if i == 0:
                                ax.set_ylabel(ap_metric_overview_name, fontsize=plt.rcParams['axes.labelsize'])
                            else:
                                ax.set_ylabel("")
                                ax.set_yticklabels([])

                            ax.tick_params(axis='x', labelsize=plt.rcParams['xtick.labelsize']*0.9)
                            ax.tick_params(axis='y', labelsize=plt.rcParams['ytick.labelsize']*0.9)

                        fig.tight_layout(rect=[0, 0.03, 1, 0.93])
                        safe_ap_name = get_safe_filename_step4(ap_metric_overview_name)
                        safe_pkg_name = get_safe_filename_step4(pkg_name_overview)
                        safe_ch_name = get_safe_filename_step4(channel_label_overview)
                        safe_freq_name = get_safe_filename_step4(freq_label_overview)
                        # UPDATED FILENAME
                        plot_filename = f"Overview_{safe_ap_name}_vs_{safe_pkg_name}_{safe_ch_name}_{safe_freq_name}.png"
                        fig.savefig(os.path.join(plot_subdir_overview_scatter_state, plot_filename))
                        plt.close(fig)

print("\n--- Cell 5B (New): State-Specific Overview Scatter Plot Generation (4 States - No Sleep, 10-min avg pts) Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 5 (Revised V6 - With FDR Correction and All Data): Bivariate & Partial Correlations ---
# Iterates through Channel_Display AND FreqRangeLabel

print("\n--- Cell 5 (Revised V6 - With FDR Correction and All Data): Starting Correlation Analyses ---")

# Cell-local re-imports so this cell can be re-run on its own.
from matplotlib.ticker import FormatStrFormatter
from statsmodels.stats.multitest import fdrcorrection

# --- Define new plotting parameters for this cell ---
# Multiplier applied to the rcParams font sizes in this cell's plots:
# 60% smaller than the earlier enlarged (1.5x) text, i.e. 1.5 * 0.4 = 0.6.
font_scale_factor = 0.6
# Work on a copy so the shared base palette is not mutated; tremor stays green.
MODIFIED_COLOR_PALETTE = BASE_COLOR_PALETTE.copy()
MODIFIED_COLOR_PALETTE['Aligned_Tremor_Score'] = 'green'

# Background colours for p-value annotation boxes.
# NOTE(review): not referenced in this cell's visible code; presumably read
# by the annotation helper defined elsewhere - confirm before removing.
SIGNIFICANT_P_VAL_BG_COLOR_STEP4 = 'khaki'
DEFAULT_P_VAL_BG_COLOR_STEP4 = 'ivory'

def average_points_in_10min_bins_c5(df_points, time_col, value_cols):
    """Average `value_cols` of `df_points` within 10-minute bins of `time_col`.

    Args:
        df_points: DataFrame holding the granular rows to be averaged.
        time_col: Name of the datetime column used for binning.
        value_cols: Column names whose per-bin means are returned.

    Returns:
        DataFrame indexed by bin start with one row per non-empty bin. When
        `time_col` is absent or entirely null, a copy of `df_points` is
        returned unchanged so callers can plot the granular rows instead.
    """
    if time_col not in df_points.columns or df_points[time_col].isnull().all():
        return df_points.copy()
    return (
        df_points.dropna(subset=[time_col])
        .set_index(time_col)
        .groupby(pd.Grouper(freq='10min'))[list(value_cols)]
        .mean()
        .dropna()
    )


if 'master_df_step4' not in locals() or master_df_step4 is None or master_df_step4.empty:
    print("master_df_step4 not available or empty. Skipping Cell 5.")
else:
    # Create a working copy to add a datetime column for averaging
    master_df_step4_processed_c5 = master_df_step4.copy()
    if 'Aligned_PKG_UnixTimestamp' in master_df_step4_processed_c5.columns:
        master_df_step4_processed_c5['datetime_for_avg_c5'] = pd.to_datetime(
            master_df_step4_processed_c5['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce'
        )
    else:
        print("Warning in Cell 5: 'Aligned_PKG_UnixTimestamp' not found. Will plot granular points.")
        master_df_step4_processed_c5['datetime_for_avg_c5'] = pd.NaT

    # Create subdirectories for plots from this cell
    plot_subdir_bivariate_ap_pkg = os.path.join(analysis_session_plot_folder_step4, "Bivariate_AP_vs_PKG")
    plot_subdir_bivariate_ap_osc = os.path.join(analysis_session_plot_folder_step4, "Bivariate_AP_vs_Oscillatory")
    os.makedirs(plot_subdir_bivariate_ap_pkg, exist_ok=True)
    os.makedirs(plot_subdir_bivariate_ap_osc, exist_ok=True)

    all_bivariate_ap_pkg_results = []
    all_partial_ap_pkg_results = []
    all_bivariate_ap_osc_results = []

    # All valid p-values are pooled into ONE FDR family; p_value_mapping[k]
    # records which (result list, row index) produced all_p_values[k].
    all_p_values = []
    p_value_mapping = []

    for channel_label in ORDERED_CHANNEL_LABELS:
        df_channel = master_df_step4_processed_c5[master_df_step4_processed_c5[CHANNEL_DISPLAY_COL] == channel_label]
        if df_channel.empty:
            continue
        print(f"\nProcessing Channel: {channel_label}")

        for freq_label in ORDERED_FREQ_LABELS:
            df_channel_freq = df_channel[df_channel[FOOOF_FREQ_BAND_COL] == freq_label].copy()
            if df_channel_freq.empty:
                continue
            print(f"  Processing Freq Band: {freq_label}")

            # --- Part 5.1: Bivariate Spearman Correlations (Aperiodic vs. PKG) ---
            print(f"    Part 5.1: Bivariate Aperiodic vs. PKG Correlations for {channel_label} ({freq_label})")
            for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                    if pkg_col not in df_channel_freq.columns or ap_col not in df_channel_freq.columns:
                        continue
                    # Zeros are kept: every row with both values present is used.
                    df_granular_for_corr_pkg = df_channel_freq.dropna(subset=[ap_col, pkg_col])
                    rho, p_val, N = calculate_spearman_with_n(df_granular_for_corr_pkg, ap_col, pkg_col)

                    # FDR fields default to "not tested / not significant" so the
                    # CSV schema is identical whether or not a p-value was valid.
                    all_bivariate_ap_pkg_results.append({
                        'Channel': channel_label, 'FreqBand': freq_label,
                        'AperiodicMetric': ap_name, 'PKGMetric': pkg_name,
                        'SpearmanRho': rho, 'PValue': p_val, 'N': N,
                        'TestType': 'Bivariate_AP_PKG',
                        'PValue_FDR': np.nan, 'Significant_FDR': False,
                    })

                    if not pd.isna(p_val):
                        all_p_values.append(p_val)
                        p_value_mapping.append(('bivar_ap_pkg', len(all_bivariate_ap_pkg_results) - 1))

            # --- Part 5.2: Partial Aperiodic vs. PKG Correlations ---
            print(f"    Part 5.2: Partial Aperiodic vs. PKG Correlations for {channel_label} ({freq_label})")
            covariates_partial_corr = [col for col in OSCILLATORY_METRICS_COLS.keys() if col in df_channel_freq.columns]
            # The partial correlation is only run when exactly two oscillatory
            # covariates are available in this slice.
            if len(covariates_partial_corr) == 2:
                for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                    for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                        if pkg_col not in df_channel_freq.columns or ap_col not in df_channel_freq.columns:
                            continue

                        # All four columns are known to exist at this point
                        # (checked above / filtered when building the covariates).
                        cols_for_partial_corr = [ap_col, pkg_col] + covariates_partial_corr
                        data_for_partial = df_channel_freq[cols_for_partial_corr].dropna()
                        both_vary = (
                            data_for_partial[ap_col].nunique() > 1
                            and data_for_partial[pkg_col].nunique() > 1
                        )
                        if len(data_for_partial) < MIN_SAMPLES_FOR_CORR or not both_vary:
                            partial_rho, partial_p_val, N_partial = np.nan, np.nan, len(data_for_partial)
                        else:
                            partial_rho, partial_p_val, N_partial = calculate_partial_spearman(
                                data_for_partial, ap_col, pkg_col, covariates_partial_corr
                            )

                        all_partial_ap_pkg_results.append({
                            'Channel': channel_label, 'FreqBand': freq_label,
                            'AperiodicMetric': ap_name, 'PKGMetric': pkg_name,
                            'PartialSpearmanRho_vs_BetaGamma': partial_rho, 'PartialPValue': partial_p_val,
                            'N_Partial': N_partial, 'TestType': 'Partial_AP_PKG',
                            'PartialPValue_FDR': np.nan, 'Significant_FDR': False,
                        })

                        if not pd.isna(partial_p_val):
                            all_p_values.append(partial_p_val)
                            p_value_mapping.append(('partial_ap_pkg', len(all_partial_ap_pkg_results) - 1))

            # --- Part 5.3: Bivariate Aperiodic vs. Oscillatory Correlations ---
            print(f"    Part 5.3: Bivariate Aperiodic vs. Oscillatory Correlations for {channel_label} ({freq_label})")
            for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                for osc_col, osc_name in OSCILLATORY_METRICS_COLS.items():
                    if osc_col not in df_channel_freq.columns or ap_col not in df_channel_freq.columns:
                        continue

                    df_granular_for_corr_osc = df_channel_freq.dropna(subset=[ap_col, osc_col])
                    rho_ap_osc, p_val_ap_osc, N_ap_osc = calculate_spearman_with_n(df_granular_for_corr_osc, ap_col, osc_col)

                    all_bivariate_ap_osc_results.append({
                        'Channel': channel_label, 'FreqBand': freq_label,
                        'AperiodicMetric': ap_name, 'OscillatoryMetric': osc_name,
                        'SpearmanRho': rho_ap_osc, 'PValue': p_val_ap_osc, 'N': N_ap_osc,
                        'TestType': 'Bivariate_AP_Osc',
                        'PValue_FDR': np.nan, 'Significant_FDR': False,
                    })

                    if not pd.isna(p_val_ap_osc):
                        all_p_values.append(p_val_ap_osc)
                        p_value_mapping.append(('bivar_ap_osc', len(all_bivariate_ap_osc_results) - 1))

    # --- Apply FDR Correction to all p-values ---
    print(f"\nCollected {len(all_p_values)} p-values for FDR correction.")

    if all_p_values:
        rejected, pvals_corrected = fdrcorrection(all_p_values, alpha=0.05, method='indep', is_sorted=False)

        # Where each test family stores its results and its corrected p-value.
        fdr_targets = {
            'bivar_ap_pkg': (all_bivariate_ap_pkg_results, 'PValue_FDR'),
            'partial_ap_pkg': (all_partial_ap_pkg_results, 'PartialPValue_FDR'),
            'bivar_ap_osc': (all_bivariate_ap_osc_results, 'PValue_FDR'),
        }
        for idx, (test_type, result_idx) in enumerate(p_value_mapping):
            results_for_type, fdr_key = fdr_targets[test_type]
            results_for_type[result_idx][fdr_key] = float(pvals_corrected[idx])
            results_for_type[result_idx]['Significant_FDR'] = bool(rejected[idx])

        print("FDR correction applied successfully.")

    # Display-name -> column-name lookups (first match wins, as before),
    # built once instead of once per plotted result.
    ap_col_by_name = {}
    for col_key, display_name in APERIODIC_METRICS_COLS.items():
        ap_col_by_name.setdefault(display_name, col_key)
    pkg_col_by_name = {}
    for col_key, display_name in PKG_METRICS_COLS.items():
        pkg_col_by_name.setdefault(display_name, col_key)

    # --- Generate plots only for tests that remain significant after FDR ---
    for result in all_bivariate_ap_pkg_results:
        # Written as a positive test so a NaN sample count is skipped too.
        if not (result['N'] >= MIN_SAMPLES_FOR_CORR and result.get('Significant_FDR', False)):
            continue

        channel_label = result['Channel']
        freq_label = result['FreqBand']
        ap_name = result['AperiodicMetric']
        pkg_name = result['PKGMetric']
        ap_col = ap_col_by_name[ap_name]
        pkg_col = pkg_col_by_name[pkg_name]

        # Recreate the granular data that produced this correlation.
        df_plot = master_df_step4_processed_c5[
            (master_df_step4_processed_c5[CHANNEL_DISPLAY_COL] == channel_label) &
            (master_df_step4_processed_c5[FOOOF_FREQ_BAND_COL] == freq_label)
        ].dropna(subset=[ap_col, pkg_col])
        if df_plot.empty:
            continue

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.grid(False)

        # Scatter shows 10-minute means; the regression uses the granular rows.
        df_averaged_points_pkg = average_points_in_10min_bins_c5(df_plot, 'datetime_for_avg_c5', [ap_col, pkg_col])
        if not df_averaged_points_pkg.empty:
            sns.scatterplot(data=df_averaged_points_pkg, x=pkg_col, y=ap_col, color=MODIFIED_COLOR_PALETTE.get(ap_col, 'grey'), alpha=DOT_ALPHA_STEP4 + 0.1, s=40, edgecolor='k', linewidths=0.5, ax=ax)
        sns.regplot(data=df_plot, x=pkg_col, y=ap_col, scatter=False, ax=ax, line_kws={'color': MODIFIED_COLOR_PALETTE.get(pkg_col, 'black'), 'linewidth': REG_LINE_THICKNESS_STEP4, 'alpha': 0.7})

        # Use FDR-corrected p-value for annotation
        annotate_correlation_on_plot(ax, result['SpearmanRho'], result['PValue_FDR'], result['N'],
                                     test_type="Spearman ρ (FDR)", fontsize=9 * font_scale_factor)

        simple_pkg_name = pkg_name.replace('PKG ', '')
        simple_ap_name = ap_name.replace('Aperiodic ', '')
        ax.set_xlabel(simple_pkg_name, fontsize=plt.rcParams['axes.labelsize'] * font_scale_factor)
        ax.set_ylabel(simple_ap_name, fontsize=plt.rcParams['axes.labelsize'] * font_scale_factor)
        ax.tick_params(axis='both', which='major', labelsize=plt.rcParams['xtick.labelsize'] * font_scale_factor)

        # Set hardcoded ticks and limits
        # NOTE(review): the same 0-5 / 0-80 window is applied to every metric
        # pair; values outside it are clipped from view - confirm this suits
        # all aperiodic and PKG metrics being plotted.
        y_ticks = [0, 1, 2, 3, 4, 5]
        x_ticks = [0, 20, 40, 60, 80]
        ax.set_yticks(y_ticks)
        ax.set_xticks(x_ticks)
        ax.set_ylim(0, 5)
        ax.set_xlim(0, 80)

        fig.tight_layout()
        safe_ch = get_safe_filename_step4(channel_label)
        safe_ap = get_safe_filename_step4(ap_name)
        safe_pkg = get_safe_filename_step4(pkg_name)
        # The frequency label is sanitised like the other filename parts.
        safe_freq = get_safe_filename_step4(freq_label)
        plot_filename = f"Bivar_{safe_ap}_vs_{safe_pkg}_{safe_ch}_{safe_freq}_FDRsig.png"
        fig.savefig(os.path.join(plot_subdir_bivariate_ap_pkg, plot_filename))
        plt.close(fig)

    # Similar plotting for oscillatory correlations (only FDR significant)
    # TODO: not implemented - this loop is a placeholder and currently draws
    # nothing, so plot_subdir_bivariate_ap_osc stays empty.
    for result in all_bivariate_ap_osc_results:
        if result['N'] >= MIN_SAMPLES_FOR_CORR and result.get('Significant_FDR', False):
            pass

    # --- CSV Saving Logic ---
    if all_bivariate_ap_pkg_results:
        df_bivar_ap_pkg = pd.DataFrame(all_bivariate_ap_pkg_results)
        df_bivar_ap_pkg.to_csv(os.path.join(analysis_session_plot_folder_step4, f"{patient_hemisphere_id}_Bivariate_AP_vs_PKG_Correlations_AllData_FDR_Cell5.csv"), index=False)
        print(f"\nSaved Bivariate AP vs PKG correlation results (all data, FDR corrected) for {patient_hemisphere_id}.")

    if all_partial_ap_pkg_results:
        df_partial_ap_pkg = pd.DataFrame(all_partial_ap_pkg_results)
        df_partial_ap_pkg.to_csv(os.path.join(analysis_session_plot_folder_step4, f"{patient_hemisphere_id}_Partial_AP_vs_PKG_Correlations_AllData_FDR_Cell5.csv"), index=False)
        print(f"Saved Partial AP vs PKG correlation results (all data, FDR corrected) for {patient_hemisphere_id}.")

    if all_bivariate_ap_osc_results:
        df_bivar_ap_osc = pd.DataFrame(all_bivariate_ap_osc_results)
        df_bivar_ap_osc.to_csv(os.path.join(analysis_session_plot_folder_step4, f"{patient_hemisphere_id}_Bivariate_AP_vs_Oscillatory_Correlations_FDR_Cell5.csv"), index=False)
        print(f"Saved Bivariate AP vs Oscillatory correlation results (FDR corrected) for {patient_hemisphere_id}.")

print("\n--- Cell 5 (Revised V6 - All Data + FDR): Correlation Analyses Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 6 (Streamlined V4 - Tiered MLR with LRT): Multiple Linear Regression ---
# Focuses on model comparison using Likelihood Ratio Test (LRT) for nested models.
# Interpretation output is cleaned, and AIC/BIC are included.

import pandas as pd
import numpy as np
import os
import statsmodels.formula.api as smf
from statsmodels.stats.anova import anova_lm
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr

print("\n--- Cell 6 (Streamlined V4 - Tiered MLR with LRT): Starting Analyses ---\n")

# Significance level (alpha) that the nested-model comparison p-values are judged against.
P_VALUE_THRESHOLD = 0.05
print(f"Significance threshold (alpha) for p-values is set to {P_VALUE_THRESHOLD}.")

# <<< --- USER TOGGLE --- >>>
# False: analyze only TARGET_FREQ_BAND_IF_NOT_ALL.
# True:  analyze every band listed in ORDERED_FREQ_LABELS.
ANALYZE_ALL_FREQ_BANDS = False
TARGET_FREQ_BAND_IF_NOT_ALL = "WideFreq"
# <<< ------------------ >>>

# The whole analysis lives in the else-branch, so the cell does nothing when the
# master dataframe is missing or empty.
if 'master_df_step4' not in locals() or master_df_step4.empty:
    print("master_df_step4 not available or empty. Skipping Cell 6.")
else:
    # Output folders: one for the regression results and a nested one reserved for
    # the aperiodic inter-correlation outputs.
    plot_subdir_mlr = os.path.join(analysis_session_plot_folder_step4, "MultipleLinearRegression_PKG_on_Neural_STREAMLINED_V4") # New folder name
    os.makedirs(plot_subdir_mlr, exist_ok=True)

    plot_subdir_ap_corr = os.path.join(plot_subdir_mlr, "Aperiodic_Intercorrelations")
    os.makedirs(plot_subdir_ap_corr, exist_ok=True)

    # Accumulators filled while the models are fitted further down in this cell.
    mlr_results_list_streamlined_v4 = []  # one dict per non-intercept term of every fitted model
    lrt_results_list = []                 # one dict per nested-model comparison
    exp_offset_corr_results_list = []

    # Only metric columns that actually exist in the master dataframe may be used as predictors.
    available_aperiodic_cols = [col for col in APERIODIC_METRICS_COLS.keys() if col in master_df_step4.columns]
    available_oscillatory_cols = [col for col in OSCILLATORY_METRICS_COLS.keys() if col in master_df_step4.columns]

    # Column names of the predictors used by the model tiers.
    exponent_col_name = 'Exponent_BestModel'
    offset_col_name = 'Offset_BestModel'
    beta_col_name = 'Beta_Peak_Power_at_DominantFreq'
    gamma_col_name = 'Gamma_Peak_Power_at_DominantFreq'
    # Helper that fits one model tier, prints a short summary (incl. AIC/BIC) and
    # returns the fitted model so the caller can compare nested models.
    def fit_and_interpret_mlr_focused_v4(data_df, dv_col, dv_name, predictors, model_tier_label,
                                         channel="N/A", freq_band="N/A"):
        """Fit an OLS model ``dv_col ~ predictors`` and log one result row per predictor.

        Rows with a missing value in the DV or any predictor are dropped before fitting.
        For every non-intercept term of the fitted model a dict (coefficient, standard
        error, p-value, confidence interval, model-level N / adj. R² / AIC / BIC) is
        appended to ``mlr_results_list_streamlined_v4``.

        Args:
            data_df: DataFrame holding the DV and predictor columns.
            dv_col: Column name of the dependent variable (used in the formula).
            dv_name: Display name of the dependent variable (used in messages/results).
            predictors: List of predictor column names.
            model_tier_label: Label of the model tier, stored with every result row.
            channel: Channel label stored with every result row.
            freq_band: Frequency-band label stored with every result row.

        Returns:
            The fitted statsmodels results object, or ``None`` when the model was
            skipped (no/missing predictors, fewer than ``len(predictors) + 10``
            complete rows, a constant predictor) or the fit raised an exception.
        """
        if not predictors:
            print(f"      Skipping {dv_name} ({model_tier_label}): No valid predictors for this model tier.")
            return None

        missing_p = [p for p in predictors if p not in data_df.columns]
        if missing_p:
            print(f"      Skipping {dv_name} ({model_tier_label}): Predictor(s) {missing_p} not found.")
            return None

        formula_to_fit = f"{dv_col} ~ {' + '.join(predictors)}"
        # Complete-case data: OLS cannot use rows with a missing DV or predictor.
        current_data_for_model = data_df[[dv_col] + predictors].dropna(how='any').copy()

        if len(current_data_for_model) < (len(predictors) + 10):
            print(f"      Skipping {dv_name} ({model_tier_label}): Insufficient data ({len(current_data_for_model)}) for {len(predictors)} predictors.")
            return None

        # A predictor with a single unique value carries no information for the fit.
        for pred_check in predictors:
            if current_data_for_model[pred_check].nunique() < 2:
                print(f"      Skipping {dv_name} ({model_tier_label}): Constant predictor {pred_check} found.")
                return None

        print(f"      --- Model Tier: {model_tier_label} ---")
        print(f"      Fitting for DV '{dv_name}' with formula: {formula_to_fit}")

        try:
            model_fit = smf.ols(formula=formula_to_fit, data=current_data_for_model).fit()
            # Computed once here instead of several times per term inside the loop below.
            conf_int_table = model_fit.conf_int()

            interpretation_string = f"      --- Interpretation for {dv_name} ({model_tier_label} - {channel}, {freq_band}) ---\n"
            interpretation_string += f"      Overall Model Fit: Explains {model_fit.rsquared_adj * 100:.1f}% of variance (Adj. R² = {model_fit.rsquared_adj:.3f}, N = {int(model_fit.nobs)}).\n"
            interpretation_string += f"      Model Selection Criteria: AIC = {model_fit.aic:.2f}, BIC = {model_fit.bic:.2f}.\n"
            print(interpretation_string) # Print interpretation right away

            for term_in_model in model_fit.params.index:
                if term_in_model == 'Intercept':
                    continue

                if term_in_model in conf_int_table.index:
                    ci_lower = conf_int_table.loc[term_in_model, 0]
                    ci_upper = conf_int_table.loc[term_in_model, 1]
                else:
                    ci_lower = ci_upper = np.nan

                # This is the dictionary that gets appended to the results list
                mlr_results_list_streamlined_v4.append({
                    'Channel': channel, 'FreqBand_Aperiodics': freq_band, 'PKG_Symptom_DV': dv_name,
                    'Model_Tier': model_tier_label, 'Formula': formula_to_fit,
                    'Predictor_Term': term_in_model,
                    'Predictor_Name_Display': APERIODIC_METRICS_COLS.get(term_in_model, OSCILLATORY_METRICS_COLS.get(term_in_model, term_in_model)),
                    'Coefficient': model_fit.params.get(term_in_model, np.nan),
                    'StdErr': model_fit.bse.get(term_in_model, np.nan),
                    'PValue': model_fit.pvalues.get(term_in_model, np.nan),
                    'Conf_Int_Lower': ci_lower,
                    'Conf_Int_Upper': ci_upper,
                    'N_model': model_fit.nobs, 'R_squared_adj_model': model_fit.rsquared_adj,
                    'AIC_model': model_fit.aic, 'BIC_model': model_fit.bic
                })

            print(f"      --- End of Interpretation for {model_tier_label} ---\n")
            return model_fit

        except Exception as e_mlr_fit:
            # Best-effort: one failing model must not abort the whole batch of fits.
            print(f"      ERROR fitting MLR for {dv_name} ({model_tier_label}): {e_mlr_fit}")
            return None


    freq_bands_to_process = ORDERED_FREQ_LABELS if ANALYZE_ALL_FREQ_BANDS else [TARGET_FREQ_BAND_IF_NOT_ALL]
    if not ANALYZE_ALL_FREQ_BANDS:
        print(f"--- Analyzing ONLY for Freq Band: {TARGET_FREQ_BAND_IF_NOT_ALL} ---")

    # Aperiodic predictors tested on top of the oscillatory baseline, in processing order:
    # (column name, model tier label, short name used in messages and the 'Comparison' field).
    aperiodic_tiers_to_test = [
        (exponent_col_name, "Tier 2: Exponent + Oscillatory", "Exponent"),
        (offset_col_name, "Tier 3: Offset + Oscillatory", "Offset"),
    ]

    for channel_label_iter in ORDERED_CHANNEL_LABELS:
        df_channel_mlr_main = master_df_step4[master_df_step4[CHANNEL_DISPLAY_COL] == channel_label_iter]
        if df_channel_mlr_main.empty:
            continue
        print(f"\n>>> Processing MLR for Channel: {channel_label_iter} <<<\n")

        for freq_label_iter in freq_bands_to_process:
            df_channel_freq_mlr_main = df_channel_mlr_main[df_channel_mlr_main[FOOOF_FREQ_BAND_COL] == freq_label_iter].copy()
            if df_channel_freq_mlr_main.empty:
                continue
            print(f"  --- Freq Band: {freq_label_iter} ---\n")

            # Aperiodic inter-correlation check (same as before)
            if exponent_col_name in df_channel_freq_mlr_main.columns and offset_col_name in df_channel_freq_mlr_main.columns:
                # This part of the code for correlation plotting remains the same...
                # (Code for correlation check and plotting is omitted for brevity but should be kept from your original script)
                pass

            for pkg_col_iter, pkg_name_iter in PKG_METRICS_COLS.items():
                if pkg_col_iter not in df_channel_freq_mlr_main.columns:
                    continue
                print(f"    --- Predicting: {pkg_name_iter} (DV: {pkg_col_iter}) ---\n")

                # --- FIT BASELINE MODEL ---
                # The oscillatory-only model is the reduced model of every comparison below.
                osc_predictors = [p for p in [beta_col_name, gamma_col_name] if p in available_oscillatory_cols and p in df_channel_freq_mlr_main.columns]
                if not osc_predictors:
                    continue
                osc_only_fit = fit_and_interpret_mlr_focused_v4(df_channel_freq_mlr_main, pkg_col_iter, pkg_name_iter,
                                                                osc_predictors, "Tier 1: Oscillatory Only",
                                                                channel=channel_label_iter, freq_band=freq_label_iter)
                if osc_only_fit is None:
                    continue

                # --- FIT FULL MODELS AND COMPARE AGAINST THE BASELINE ---
                for ap_col, ap_tier_label, ap_short_name in aperiodic_tiers_to_test:
                    if ap_col not in available_aperiodic_cols:
                        continue

                    full_predictors = [ap_col] + osc_predictors
                    full_fit = fit_and_interpret_mlr_focused_v4(df_channel_freq_mlr_main, pkg_col_iter, pkg_name_iter,
                                                                full_predictors, ap_tier_label,
                                                                channel=channel_label_iter, freq_band=freq_label_iter)
                    if full_fit is None:
                        continue

                    # A nested-model F-test is only meaningful when both models were fitted to
                    # the same observations. Each tier drops its own incomplete rows, so when the
                    # aperiodic column has extra missing values the baseline is refitted on the
                    # rows the full model used (this refit is not logged as a separate MLR result).
                    reduced_fit_for_lrt = osc_only_fit
                    if osc_only_fit.nobs != full_fit.nobs:
                        shared_rows = df_channel_freq_mlr_main[[pkg_col_iter] + full_predictors].dropna(how='any')
                        reduced_formula = f"{pkg_col_iter} ~ {' + '.join(osc_predictors)}"
                        reduced_fit_for_lrt = smf.ols(formula=reduced_formula, data=shared_rows).fit()
                        print(f"      NOTE: 'Oscillatory Only' refitted on the {int(full_fit.nobs)} rows used by the "
                              f"'{ap_short_name} + Oscillatory' model so both nested models share the same observations.")

                    # anova_lm compares the two OLS fits; row 1 holds the statistics of the full model.
                    lrt_table = anova_lm(reduced_fit_for_lrt, full_fit)
                    p_value_lrt = lrt_table.iloc[1]['Pr(>F)']
                    f_stat_lrt = lrt_table.iloc[1]['F']

                    lrt_interpretation = f"      --- Likelihood Ratio Test (LRT) for {ap_short_name} ---\n"
                    lrt_interpretation += f"      Comparing 'Oscillatory Only' vs '{ap_short_name} + Oscillatory' for DV '{pkg_name_iter}'.\n"
                    lrt_interpretation += f"      LRT Result: F-statistic = {f_stat_lrt:.3f}, p-value = {p_value_lrt:.4g}\n"
                    if p_value_lrt < P_VALUE_THRESHOLD:
                        lrt_interpretation += f"      Interpretation: Adding {ap_short_name} resulted in a SIGNIFICANT improvement in model fit (p < {P_VALUE_THRESHOLD}).\n"
                    else:
                        lrt_interpretation += f"      Interpretation: Adding {ap_short_name} did NOT significantly improve model fit (p >= {P_VALUE_THRESHOLD}).\n"
                    print(lrt_interpretation)

                    lrt_results_list.append({
                        'Channel': channel_label_iter, 'FreqBand': freq_label_iter, 'PKG_Symptom_DV': pkg_name_iter,
                        'Comparison': f'{ap_short_name} + Osc vs. Osc Only',
                        'F_statistic': f_stat_lrt, 'P_value': p_value_lrt,
                        'N_reduced': reduced_fit_for_lrt.nobs, 'N_full': full_fit.nobs
                    })


    # Saving all results to CSV files at the end

    if lrt_results_list:
        df_lrt_results = pd.DataFrame(lrt_results_list)
        csv_filename_lrt = f"{patient_hemisphere_id}_MLR_LRT_Results_Step6.csv"
        df_lrt_results.to_csv(os.path.join(plot_subdir_mlr, csv_filename_lrt), index=False)
        print(f"\nSaved Likelihood Ratio Test (LRT) results to {csv_filename_lrt}.")
        print("Sample of LRT results:")
        print(df_lrt_results.head())
    else:
        print("\nNo Likelihood Ratio Tests were performed.")

    if mlr_results_list_streamlined_v4:
        df_mlr_results_step6_streamlined_v4 = pd.DataFrame(mlr_results_list_streamlined_v4)
        csv_filename_mlr_streamlined_v4 = f"{patient_hemisphere_id}_MLR_Streamlined_V4_Results_Step6.csv"
        df_mlr_results_step6_streamlined_v4.to_csv(os.path.join(plot_subdir_mlr, csv_filename_mlr_streamlined_v4), index=False)
        print(f"\nSaved Step 6 Streamlined V4 MLR results for {patient_hemisphere_id} to {csv_filename_mlr_streamlined_v4}.")
        cols_to_show_mlr = ['Channel', 'FreqBand_Aperiodics', 'PKG_Symptom_DV', 'Model_Tier',
                            'Predictor_Term', 'Coefficient', 'PValue', 'R_squared_adj_model', 'N_model']
        print(df_mlr_results_step6_streamlined_v4[[c for c in cols_to_show_mlr if c in df_mlr_results_step6_streamlined_v4.columns]].head())
    else:
        print("\nNo Streamlined V4 MLR models were successfully fitted or no results to save for Step 6.")

print(f"\n--- Cell 6 (Streamlined V4 - Tiered MLR with LRT): Analyses Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 6A (Revised): State-Specific Multiple Linear Regression with LRT + FDR ---
# Changes:
# 1. ADDED FDR correction for all LRT p-values
# 2. All data is kept (no exclusion of zeros)
# Repeats the tiered MLR and Likelihood Ratio Test analyses for each of the
# TARGET_CLINICAL_STATES_ORDERED to see if relationships change with clinical state.

import pandas as pd
import numpy as np
import os
import statsmodels.formula.api as smf
from statsmodels.stats.anova import anova_lm
from statsmodels.stats.multitest import fdrcorrection
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr

print("\n--- Cell 6A (Revised): Starting State-Specific MLR & LRT Analyses with FDR ---\n")

# <<< --- USER TOGGLE --- >>>
# This analysis is intensive. It is recommended to run it on a single frequency band first.
# False: analyze only TARGET_FREQ_BAND_IF_NOT_ALL_STATE_SPECIFIC.
# True:  analyze every band listed in ORDERED_FREQ_LABELS.
ANALYZE_ALL_FREQ_BANDS_STATE_SPECIFIC = False
TARGET_FREQ_BAND_IF_NOT_ALL_STATE_SPECIFIC = "WideFreq"
# <<< ------------------ >>>

# The whole analysis lives in the else-branch, so the cell does nothing when the
# master dataframe is missing or empty.
if 'master_df_step4' not in locals() or master_df_step4.empty:
    print("master_df_step4 not available or empty. Skipping Cell 6A.")
else:
    # Define a new directory for this state-specific analysis output
    plot_subdir_mlr_states = os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "MultipleLinearRegression_State_Specific_FDR")
    os.makedirs(plot_subdir_mlr_states, exist_ok=True)

    # New lists to hold the results from all states
    mlr_results_list_state_specific = []  # one dict per non-intercept term of every fitted model
    lrt_results_list_state_specific = []  # one dict per nested-model comparison

    # Lists to collect all p-values for FDR. The two lists are index-aligned: entry i of the
    # mapping is the position in lrt_results_list_state_specific that p-value i belongs to.
    all_lrt_p_values_states = []
    lrt_p_value_mapping_states = []

    # Only metric columns that actually exist in the master dataframe may be used as predictors.
    available_aperiodic_cols = [col for col in APERIODIC_METRICS_COLS.keys() if col in master_df_step4.columns]
    available_oscillatory_cols = [col for col in OSCILLATORY_METRICS_COLS.keys() if col in master_df_step4.columns]

    # Column names of the predictors used by the model tiers.
    exponent_col_name = 'Exponent_BestModel'
    offset_col_name = 'Offset_BestModel'
    beta_col_name = 'Beta_Peak_Power_at_DominantFreq'
    gamma_col_name = 'Gamma_Peak_Power_at_DominantFreq'

    # Helper function is modified to accept 'clinical_state' for inclusion in results
    def fit_and_interpret_mlr_state_specific(data_df, dv_col, dv_name, predictors, model_tier_label,
                                             channel="N/A", freq_band="N/A", clinical_state="N/A"):
        """Fit an OLS model ``dv_col ~ predictors`` for one clinical state and log its terms.

        Rows with a missing value in the DV or any predictor are dropped before fitting.
        For every non-intercept term of the fitted model a dict (coefficient, standard
        error, p-value, confidence interval, model-level N / adj. R² / AIC / BIC) is
        appended to ``mlr_results_list_state_specific``. Skipped models are not reported,
        because this helper is called for many state/channel/band combinations.

        Args:
            data_df: DataFrame holding the DV and predictor columns.
            dv_col: Column name of the dependent variable (used in the formula).
            dv_name: Display name of the dependent variable (stored in the results).
            predictors: List of predictor column names.
            model_tier_label: Label of the model tier, stored with every result row.
            channel: Channel label stored with every result row.
            freq_band: Frequency-band label stored with every result row.
            clinical_state: Clinical-state label stored with every result row.

        Returns:
            The fitted statsmodels results object, or ``None`` when the model was
            skipped (no/missing predictors, fewer than ``len(predictors) + 10``
            complete rows, a constant predictor) or the fit raised an exception.
        """
        if not predictors or any(p not in data_df.columns for p in predictors):
            return None

        formula_to_fit = f"{dv_col} ~ {' + '.join(predictors)}"
        # Complete-case data: OLS cannot use rows with a missing DV or predictor.
        current_data_for_model = data_df[[dv_col] + predictors].dropna(how='any').copy()

        if len(current_data_for_model) < (len(predictors) + 10):
            return None

        # A predictor with a single unique value carries no information for the fit.
        if any(current_data_for_model[pred_check].nunique() < 2 for pred_check in predictors):
            return None

        try:
            model_fit = smf.ols(formula=formula_to_fit, data=current_data_for_model).fit()
            # Computed once here instead of several times per term inside the loop below.
            conf_int_table = model_fit.conf_int()

            for term_in_model in model_fit.params.index:
                if term_in_model == 'Intercept':
                    continue

                if term_in_model in conf_int_table.index:
                    ci_lower = conf_int_table.loc[term_in_model, 0]
                    ci_upper = conf_int_table.loc[term_in_model, 1]
                else:
                    ci_lower = ci_upper = np.nan

                mlr_results_list_state_specific.append({
                    'ClinicalState': clinical_state,
                    'Channel': channel, 'FreqBand_Aperiodics': freq_band, 'PKG_Symptom_DV': dv_name,
                    'Model_Tier': model_tier_label, 'Formula': formula_to_fit,
                    'Predictor_Term': term_in_model,
                    'Predictor_Name_Display': APERIODIC_METRICS_COLS.get(term_in_model, OSCILLATORY_METRICS_COLS.get(term_in_model, term_in_model)),
                    'Coefficient': model_fit.params.get(term_in_model, np.nan),
                    'StdErr': model_fit.bse.get(term_in_model, np.nan),
                    'PValue': model_fit.pvalues.get(term_in_model, np.nan),
                    'Conf_Int_Lower': ci_lower,
                    'Conf_Int_Upper': ci_upper,
                    'N_model': model_fit.nobs, 'R_squared_adj_model': model_fit.rsquared_adj,
                    'AIC_model': model_fit.aic, 'BIC_model': model_fit.bic
                })
            return model_fit
        except Exception as e_mlr_fit:
            # Best-effort: one failing model must not abort the batch, but the failure is
            # reported so that a missing result can be traced back to its cause.
            print(f"        ERROR fitting MLR for {dv_name} ({model_tier_label}, {clinical_state}, {channel}, {freq_band}): {e_mlr_fit}")
            return None

    freq_bands_to_process = ORDERED_FREQ_LABELS if ANALYZE_ALL_FREQ_BANDS_STATE_SPECIFIC else [TARGET_FREQ_BAND_IF_NOT_ALL_STATE_SPECIFIC]
    if not ANALYZE_ALL_FREQ_BANDS_STATE_SPECIFIC:
        print(f"--- Analyzing ONLY for Freq Band: {TARGET_FREQ_BAND_IF_NOT_ALL_STATE_SPECIFIC} ---")

    # Aperiodic predictors tested on top of the oscillatory baseline, in processing order:
    # (column name, model tier label, short name used in messages and the 'Comparison' field).
    aperiodic_tiers_to_test = [
        (exponent_col_name, "Tier 2: Exponent + Oscillatory", "Exponent"),
        (offset_col_name, "Tier 3: Offset + Oscillatory", "Offset"),
    ]

    # --- Outermost loop for Clinical States ---
    for state_current in TARGET_CLINICAL_STATES_ORDERED:
        df_state = master_df_step4[master_df_step4[CLINICAL_STATE_COL] == state_current]

        # Setting a reasonable minimum N for a state (an empty state is covered as well)
        if len(df_state) < 20:
            print(f"\nSKIPPING Clinical State: {state_current} (Insufficient data: N={len(df_state)})")
            continue

        print(f"\n>>> Processing MLR for Clinical State: {state_current} (N={len(df_state)}) <<<")

        for channel_label_iter in ORDERED_CHANNEL_LABELS:
            df_channel_state = df_state[df_state[CHANNEL_DISPLAY_COL] == channel_label_iter]
            if df_channel_state.empty:
                continue
            print(f"\n  >> Channel: {channel_label_iter}")

            for freq_label_iter in freq_bands_to_process:
                df_channel_freq_state = df_channel_state[df_channel_state[FOOOF_FREQ_BAND_COL] == freq_label_iter].copy()
                if df_channel_freq_state.empty:
                    continue
                print(f"\n    -- Freq Band: {freq_label_iter} --")

                for pkg_col_iter, pkg_name_iter in PKG_METRICS_COLS.items():
                    if pkg_col_iter not in df_channel_freq_state.columns:
                        continue
                    print(f"\n      -- Predicting: {pkg_name_iter} (DV: {pkg_col_iter}) --")

                    # The oscillatory-only model is the reduced model of every comparison below.
                    osc_predictors = [p for p in [beta_col_name, gamma_col_name] if p in available_oscillatory_cols and p in df_channel_freq_state.columns]
                    if not osc_predictors:
                        continue
                    osc_only_fit = fit_and_interpret_mlr_state_specific(
                        df_channel_freq_state, pkg_col_iter, pkg_name_iter, osc_predictors,
                        "Tier 1: Oscillatory Only", channel_label_iter, freq_label_iter, state_current)
                    if osc_only_fit is None:
                        continue

                    for ap_col, ap_tier_label, ap_short_name in aperiodic_tiers_to_test:
                        if ap_col not in available_aperiodic_cols:
                            continue

                        full_predictors = [ap_col] + osc_predictors
                        full_fit = fit_and_interpret_mlr_state_specific(
                            df_channel_freq_state, pkg_col_iter, pkg_name_iter, full_predictors,
                            ap_tier_label, channel_label_iter, freq_label_iter, state_current)
                        if full_fit is None:
                            continue

                        # A nested-model F-test is only meaningful when both models were fitted to
                        # the same observations. Each tier drops its own incomplete rows, so when
                        # the aperiodic column has extra missing values the baseline is refitted on
                        # the rows the full model used (not logged as a separate MLR result).
                        reduced_fit_for_lrt = osc_only_fit
                        if osc_only_fit.nobs != full_fit.nobs:
                            shared_rows = df_channel_freq_state[[pkg_col_iter] + full_predictors].dropna(how='any')
                            reduced_formula = f"{pkg_col_iter} ~ {' + '.join(osc_predictors)}"
                            reduced_fit_for_lrt = smf.ols(formula=reduced_formula, data=shared_rows).fit()
                            print(f"        NOTE: 'Oscillatory Only' refitted on the {int(full_fit.nobs)} rows used by the "
                                  f"'{ap_short_name} + Oscillatory' model so both nested models share the same observations.")

                        # anova_lm compares the two OLS fits; row 1 holds the statistics of the full model.
                        lrt_table = anova_lm(reduced_fit_for_lrt, full_fit)
                        p_val_lrt = lrt_table.iloc[1]['Pr(>F)']
                        f_stat_lrt = lrt_table.iloc[1]['F']
                        # Label padded to a fixed width so the F/p columns line up in the output.
                        lrt_label = f"LRT ({ap_short_name}):"
                        print(f"        {lrt_label:<15} F={f_stat_lrt:.2f}, p={p_val_lrt:.3g}")

                        lrt_results_list_state_specific.append({
                            'ClinicalState': state_current, 'Channel': channel_label_iter, 'FreqBand': freq_label_iter,
                            'PKG_Symptom_DV': pkg_name_iter, 'Comparison': f'{ap_short_name} + Osc vs. Osc Only',
                            'F_statistic': f_stat_lrt, 'P_value': p_val_lrt,
                            'N_reduced': reduced_fit_for_lrt.nobs, 'N_full': full_fit.nobs
                        })

                        # Collect p-value for FDR, remembering which result row it belongs to
                        if not pd.isna(p_val_lrt):
                            all_lrt_p_values_states.append(p_val_lrt)
                            lrt_p_value_mapping_states.append(len(lrt_results_list_state_specific) - 1)

    # Apply FDR correction to all state-specific LRT p-values
    print(f"\\nCollected {len(all_lrt_p_values_states)} state-specific LRT p-values for FDR correction.")
    
    if all_lrt_p_values_states:
        rejected, pvals_corrected = fdrcorrection(all_lrt_p_values_states, alpha=0.05, method='indep', is_sorted=False)
        
        # Map corrected p-values back to results
        for idx, result_idx in enumerate(lrt_p_value_mapping_states):
            lrt_results_list_state_specific[result_idx]['P_value_FDR'] = pvals_corrected[idx]
            lrt_results_list_state_specific[result_idx]['Significant_FDR'] = rejected[idx]
        
        print("FDR correction applied to state-specific LRT results.")

    # --- Saving all collected state-specific results ---
    if lrt_results_list_state_specific:
        df_lrt_results_states = pd.DataFrame(lrt_results_list_state_specific)
        csv_filename_lrt_states = f"{patient_hemisphere_id}_MLR_LRT_StateSpecific_Results_FDR_Step6A.csv"
        df_lrt_results_states.to_csv(os.path.join(plot_subdir_mlr_states, csv_filename_lrt_states), index=False)
        print(f"\\nSaved State-Specific Likelihood Ratio Test (LRT) results with FDR to {csv_filename_lrt_states}.")
        print("Sample of State-Specific LRT results with FDR:")
        print(df_lrt_results_states[['ClinicalState', 'Channel', 'PKG_Symptom_DV', 'Comparison', 'P_value', 'P_value_FDR', 'Significant_FDR']].head())
    else:
        print("\\nNo State-Specific Likelihood Ratio Tests were successfully performed.")

    if mlr_results_list_state_specific:
        df_mlr_results_states = pd.DataFrame(mlr_results_list_state_specific)
        csv_filename_mlr_states = f"{patient_hemisphere_id}_MLR_StateSpecific_Results_Step6A.csv"
        df_mlr_results_states.to_csv(os.path.join(plot_subdir_mlr_states, csv_filename_mlr_states), index=False)
        print(f"\\nSaved State-Specific MLR model results for {patient_hemisphere_id} to {csv_filename_mlr_states}.")
        cols_to_show_mlr = ['ClinicalState', 'Channel', 'PKG_Symptom_DV', 'Model_Tier',
                            'Predictor_Term', 'Coefficient', 'PValue', 'N_model']
        print("Sample of State-Specific MLR results:")
        print(df_mlr_results_states[[c for c in cols_to_show_mlr if c in df_mlr_results_states.columns]].head())
    else:
        print("\\nNo State-Specific MLR models were successfully fitted.")

print(f"\\n--- Cell 6A (Revised): State-Specific MLR & LRT Analyses with FDR Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 7 (Revised V5): Visualization of MLR and LRT Results ---
# Generates plots for both Global (Cell 6) and State-Specific (Cell 6A) results.
# 1. Separates Exponent and Offset coefficient plots.
# 2. Increases all font sizes for readability.
# 3. Moves legend outside the plot area and tightens data point grouping.
# 4. Enforces a specific channel and symptom order and simplifies legend labels.

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns

print("\n--- Cell 7 (Revised V5): Starting Visualization of All Regression Results ---\n")

# --- Global Font Size Adjustment ---
# Every base point size listed below is multiplied by font_scale_factor before it is
# written into matplotlib's global rcParams.
font_scale_factor = 2
plt.rcParams.update({
    rc_key: base_points * font_scale_factor
    for rc_key, base_points in (
        ('font.size', 10),
        ('axes.labelsize', 10),
        ('axes.titlesize', 12),
        ('xtick.labelsize', 9),
        ('ytick.labelsize', 8),
        ('legend.fontsize', 9),
        ('legend.title_fontsize', 10),
    )
})
# Significance level used when flagging results in the plots of this cell.
P_VALUE_THRESHOLD = 0.05

# --- Custom Ordering and Naming Definitions ---
# Preferred left-to-right channel order on the x-axis.
CHANNEL_ORDER = ['STN_DBS_2-0', 'STN_DBS_3-1', 'Cortical_ECoG_10-8', 'Cortical_ECoG_11-9']
# Maps the 'PKG_Symptom_DV' names stored in the results to short legend labels.
SYMPTOM_LEGEND_MAP = {
    'PKG BK Score': 'Bradykinesia',
    'PKG DK Score': 'Dyskinesia',
    'PKG Tremor Score': 'Tremor'
}
# Preferred order of the short symptom labels.
SYMPTOM_DISPLAY_ORDER = ['Bradykinesia', 'Dyskinesia', 'Tremor']


def _ordered_categories(values, preferred_order):
    """Return the distinct non-null values: preferred ones first, then the rest as first seen."""
    present = list(values.dropna().unique())
    in_preferred_order = [v for v in preferred_order if v in present]
    not_in_preferred = [v for v in present if v not in preferred_order]
    return in_preferred_order + not_in_preferred


# --- Data Preparation Helper Function ---
def prepare_plot_data(df):
    """Return a sorted copy of ``df`` with ordered categorical plot columns.

    A 'Symptom_Display' column is derived from 'PKG_Symptom_DV' via
    SYMPTOM_LEGEND_MAP (unmapped names are kept unchanged). 'Symptom_Display' and
    'Channel' are converted to ordered categoricals that follow
    SYMPTOM_DISPLAY_ORDER / CHANNEL_ORDER; values missing from those lists are
    appended after the known ones instead of being turned into NaN, so data
    with other channel or symptom names is still plotted. The input DataFrame is
    not modified.

    Args:
        df: DataFrame with at least the columns 'PKG_Symptom_DV' and 'Channel'.

    Returns:
        A new DataFrame sorted by 'Channel' and then 'Symptom_Display'.
    """
    # Work on a copy so the caller's DataFrame keeps its columns and dtypes.
    df = df.copy()

    # Simplify legend labels
    df['Symptom_Display'] = df['PKG_Symptom_DV'].map(SYMPTOM_LEGEND_MAP).fillna(df['PKG_Symptom_DV'])

    # Enforce specific symptom order
    symptoms_in_data = _ordered_categories(df['Symptom_Display'], SYMPTOM_DISPLAY_ORDER)
    df['Symptom_Display'] = pd.Categorical(df['Symptom_Display'], categories=symptoms_in_data, ordered=True)

    # Enforce specific channel order
    channels_in_data = _ordered_categories(df['Channel'], CHANNEL_ORDER)
    df['Channel'] = pd.Categorical(df['Channel'], categories=channels_in_data, ordered=True)

    # Sort by the new categorical orders to ensure data is structured for plotting
    return df.sort_values(['Channel', 'Symptom_Display'])

# --- Plotting Function 1: Dot-and-Whisker Coefficient Plot ---
def plot_coefficient_dot_whisker(df_mlr, predictor_to_plot, output_path):
    """Save a dot-and-whisker plot of one predictor's regression coefficients.

    One subplot is drawn per clinical state (a single 'Overall Results' panel when
    the table has no 'ClinicalState' column). Within a panel the x-axis lists the
    channels and each symptom is drawn as a coloured dot with its confidence
    interval, dodged horizontally around the channel position.

    Args:
        df_mlr: MLR result table; must contain 'Coefficient', 'Conf_Int_Lower',
            'Conf_Int_Upper', 'Predictor_Term', 'Channel' and 'PKG_Symptom_DV'.
        predictor_to_plot: Value of 'Predictor_Term' whose rows are plotted.
        output_path: File path the figure is written to.

    Returns:
        None. Prints a message and returns without plotting when required
        columns are missing or no rows remain to plot.
    """
    predictor_name_map = {'Exponent_BestModel': 'Exponent', 'Offset_BestModel': 'Offset'}
    predictor_display_name = predictor_name_map.get(predictor_to_plot, predictor_to_plot)

    print(f"Generating Dot-and-Whisker plot for '{predictor_display_name}'...")

    # 'Channel' and 'PKG_Symptom_DV' are needed by prepare_plot_data below.
    required_cols = ['Coefficient', 'Conf_Int_Lower', 'Conf_Int_Upper', 'Predictor_Term',
                     'Channel', 'PKG_Symptom_DV']
    missing = [col for col in required_cols if col not in df_mlr.columns]
    if missing:
        print(f"ERROR: Input DataFrame is missing required columns: {missing}. Skipping plot.")
        return

    df_plot = df_mlr[df_mlr['Predictor_Term'] == predictor_to_plot].copy()
    df_plot = prepare_plot_data(df_plot)

    # Rows without a channel/symptom category have categorical code -1 and would be
    # drawn left of the first channel without a tick label, so they are excluded.
    uncategorized = df_plot['Channel'].isna() | df_plot['Symptom_Display'].isna()
    if uncategorized.any():
        print(f"WARNING: Excluding {int(uncategorized.sum())} row(s) without a channel/symptom category from the plot.")
        df_plot = df_plot[~uncategorized].copy()

    if df_plot.empty:
        print(f"No data found for predictor '{predictor_to_plot}'. Skipping plot.")
        return

    if 'ClinicalState' not in df_plot.columns:
        df_plot['ClinicalState'] = 'Overall Results'

    states = sorted(df_plot['ClinicalState'].unique())
    # Get ordered channels and symptoms from the categorical data type
    channels = df_plot['Channel'].cat.categories.tolist()
    symptoms = df_plot['Symptom_Display'].cat.categories.tolist()

    symptom_colors = dict(zip(symptoms, sns.color_palette('bright', len(symptoms))))

    # Horizontal layout is identical for every panel, so it is computed once.
    channel_indices = np.arange(len(channels))
    dodge_width = 0.5
    if len(symptoms) > 1:
        symptom_positions = np.linspace(-dodge_width / 2, dodge_width / 2, len(symptoms))
    else:
        # np.linspace with a single point returns only the left edge; a lone
        # symptom is centred on the channel tick instead.
        symptom_positions = np.zeros(len(symptoms))

    fig, axes = plt.subplots(len(states), 1, figsize=(16, 8 * len(states)), sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    for ax, state in zip(axes, states):
        ax.set_title(state, pad=20)
        df_state = df_plot[df_plot['ClinicalState'] == state]

        for sym_idx, symptom in enumerate(symptoms):
            temp_df = df_state[df_state['Symptom_Display'] == symptom]
            if temp_df.empty:
                continue
            # Use the categorical codes for x-positioning
            x_pos = temp_df['Channel'].cat.codes.values + symptom_positions[sym_idx]
            y = temp_df['Coefficient']
            # errorbar expects distances from the point, not absolute interval bounds.
            lower_err = y - temp_df['Conf_Int_Lower']
            upper_err = temp_df['Conf_Int_Upper'] - y

            ax.errorbar(x=x_pos, y=y, yerr=[lower_err, upper_err],
                        fmt='o', color=symptom_colors[symptom], label=symptom,
                        capsize=8, markersize=12, linestyle='none')

        ax.axhline(0, ls='--', color='black', lw=2, zorder=0)
        ax.set_xticks(channel_indices)
        ax.set_xticklabels(channels, rotation=45, ha='right')
        ax.set_ylabel("Regression Coefficient (95% CI)")
        ax.grid(axis='y', linestyle=':', alpha=0.7)

    # One shared legend outside the plot area, built from proxy handles.
    handles = [plt.Line2D([0], [0], color=symptom_colors[s], marker='o', linestyle='None', markersize=12) for s in symptoms]
    labels = symptoms
    if labels:
        fig.legend(handles, labels, title="Symptom Score", bbox_to_anchor=(1.02, 0.9), loc='upper left')

    fig.suptitle(f"{predictor_display_name} Coefficients vs. PKG Scores", y=1.0, fontsize=14 * font_scale_factor)
    fig.tight_layout(rect=[0, 0, 0.88, 0.96])
    # Save and close this specific figure rather than whatever pyplot considers current.
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved plot to: {output_path}")

# --- Plotting Function 2: Stacked Bar Chart for Orthogonal Value ---
def plot_orthogonal_value_stacked_bar(df_mlr, df_lrt, output_path):
    """Draw adjusted R-squared of the full and reduced models as overlaid bars.

    For each (ClinicalState, Channel, symptom) the adjusted R-squared of the
    'Tier 2: Exponent + Oscillatory' model is drawn as a coloured bar and the
    'Tier 1: Oscillatory Only' value is drawn over it in grey (zorder=2).
    A red asterisk is placed above every bar whose
    'Exponent + Osc vs. Osc Only' row in ``df_lrt`` has ``P_value`` below
    ``P_VALUE_THRESHOLD``. One facet is drawn per ClinicalState; frames
    without that column are treated as a single 'Overall Results' state.

    Args:
        df_mlr: Regression results containing at least 'Model_Tier',
            'Channel', 'PKG_Symptom_DV' and 'R_squared_adj_model'. It is
            passed through ``prepare_plot_data`` (defined elsewhere in this
            file), which is expected to supply 'Symptom_Display' as a
            categorical column because its ``.cat`` accessor is used below.
        df_lrt: Model-comparison results containing 'Comparison', 'Channel',
            'PKG_Symptom_DV' and 'P_value'.
        output_path: File the figure is written to.

    Returns:
        None. Nothing is written when either model tier is missing or the
        merged table is empty. Neither input frame is modified.
    """
    print("Generating Stacked Bar Chart for Orthogonal Value...")

    df_mlr = prepare_plot_data(df_mlr.copy())
    # Copy so that adding the fallback 'ClinicalState' column below never
    # modifies the frame owned by the caller (df_mlr is already copied above).
    df_lrt = df_lrt.copy()

    if 'ClinicalState' not in df_mlr.columns:
        df_mlr['ClinicalState'] = 'Overall Results'
    if 'ClinicalState' not in df_lrt.columns:
        df_lrt['ClinicalState'] = 'Overall Results'

    # Keep a single row per state/channel/symptom for each tier before reading its R-squared.
    df_reduced = df_mlr[df_mlr['Model_Tier'] == 'Tier 1: Oscillatory Only'].drop_duplicates(subset=['ClinicalState', 'Channel', 'PKG_Symptom_DV'])
    df_full = df_mlr[df_mlr['Model_Tier'] == 'Tier 2: Exponent + Oscillatory'].drop_duplicates(subset=['ClinicalState', 'Channel', 'PKG_Symptom_DV'])

    df_reduced = df_reduced.rename(columns={'R_squared_adj_model': 'R2_Reduced'})
    df_full = df_full.rename(columns={'R_squared_adj_model': 'R2_Full'})

    if df_reduced.empty or df_full.empty:
        print("Could not find data for both reduced and full models. Skipping R-squared plot.")
        return

    merge_cols = ['ClinicalState', 'Channel', 'PKG_Symptom_DV', 'Symptom_Display']
    df_r2 = pd.merge(df_reduced[merge_cols + ['R2_Reduced']], df_full[merge_cols + ['R2_Full']], on=merge_cols, how='inner')
    # Adjusted R-squared can drop when a predictor is added; negative gains are reported as 0.
    df_r2['R2_Added_by_Exponent'] = (df_r2['R2_Full'] - df_r2['R2_Reduced']).clip(lower=0)

    df_lrt_exp = df_lrt[df_lrt['Comparison'] == 'Exponent + Osc vs. Osc Only'].copy()
    # Left merge: rows without a comparison get P_value NaN and therefore is_significant False.
    df_plot = pd.merge(df_r2, df_lrt_exp[['ClinicalState', 'Channel', 'PKG_Symptom_DV', 'P_value']], on=['ClinicalState', 'Channel', 'PKG_Symptom_DV'], how='left')
    df_plot['is_significant'] = df_plot['P_value'] < P_VALUE_THRESHOLD

    if df_plot.empty:
        print("No data to plot for orthogonal value. Skipping.")
        return

    g = sns.FacetGrid(df_plot, col='ClinicalState', col_wrap=2, height=8, aspect=2, sharey=True)
    # Full-model bars first, then the reduced-model bars on top of them.
    g.map_dataframe(sns.barplot, x='Channel', y='R2_Full', hue='Symptom_Display', palette='viridis', dodge=0.8, errorbar=None, hue_order=SYMPTOM_DISPLAY_ORDER)
    g.map_dataframe(sns.barplot, x='Channel', y='R2_Reduced', hue='Symptom_Display', palette='Greys', dodge=0.8, errorbar=None, zorder=2, hue_order=SYMPTOM_DISPLAY_ORDER)

    for i, ax in enumerate(g.axes.flat):
        state_name = g.col_names[i]
        state_df = df_plot[df_plot['ClinicalState'] == state_name]

        channel_labels = [label.get_text() for label in ax.get_xticklabels()]
        symptom_labels = state_df['Symptom_Display'].cat.categories.tolist()

        # Recompute where each dodged bar sits: bars of one channel share a
        # 0.8-wide slot centred on the channel's tick, split evenly by symptom.
        num_symptoms = len(symptom_labels)
        dodge_width = 0.8
        bar_width = dodge_width / num_symptoms

        for _, row in state_df.iterrows():
            if row['is_significant']:
                try:
                    channel_idx = channel_labels.index(str(row['Channel']))
                    symptom_idx = symptom_labels.index(row['Symptom_Display'])
                    x_pos = channel_idx - (dodge_width / 2) + (symptom_idx * bar_width) + (bar_width / 2)
                    y_pos = row['R2_Full']
                    ax.text(x_pos, y_pos + 0.01, '*', ha='center', va='bottom', color='red', fontsize=16 * font_scale_factor, zorder=3)
                except (ValueError, IndexError):
                    # Channel or symptom is not drawn on this facet; no marker to place.
                    continue

    g.set_axis_labels("Channel", "Adjusted R-squared")
    g.set_titles("{col_name}", pad=20)
    g.fig.suptitle("Orthogonal Value of Exponent (Added R²)", y=1.0, fontsize=14 * font_scale_factor)
    g.set_xticklabels(rotation=45, ha='right')
    g.add_legend(title='Symptom Score')
    g.fig.tight_layout(rect=[0, 0, 1, 0.97])
    # Save and close the grid's own figure rather than whatever pyplot considers current.
    g.fig.savefig(output_path)
    plt.close(g.fig)
    print(f"Saved stacked bar plot to: {output_path}")

# --- Main Execution Block ---
def run_visualizations(analysis_type, results_dir, mlr_filename, lrt_filename, output_dir):
    """Load one analysis' result CSVs and render its three summary figures.

    Args:
        analysis_type: Label used in the console banner and in output file names.
        results_dir: Folder holding the two CSV files.
        mlr_filename: Regression results CSV inside ``results_dir``.
        lrt_filename: Model-comparison results CSV inside ``results_dir``.
        output_dir: Folder the PNG files are written to.

    Returns:
        None. If either CSV is missing a notice is printed and no figure is made.
    """
    banner = '=' * 20
    print(f"\n{banner}\n--- Generating {analysis_type} Plots ---\n{banner}")

    csv_paths = [os.path.join(results_dir, name) for name in (mlr_filename, lrt_filename)]
    try:
        # Regression results are read first, then the comparison results.
        df_mlr, df_lrt = [pd.read_csv(path) for path in csv_paths]
    except FileNotFoundError as e:
        print(f"SKIPPING: Could not find results file at {e.filename}. Please ensure the preceding cell was run.")
        return

    def figure_path(suffix):
        # All figures share the "<patient>_<analysis>_<suffix>.png" naming scheme.
        return os.path.join(output_dir, f"{patient_hemisphere_id}_{analysis_type}_{suffix}.png")

    coefficient_figures = (
        ('Exponent_BestModel', 'Exponent_Coefficients'),
        ('Offset_BestModel', 'Offset_Coefficients'),
    )
    for predictor, suffix in coefficient_figures:
        # Each plot receives its own copy of the regression table.
        plot_coefficient_dot_whisker(df_mlr.copy(), predictor, figure_path(suffix))
    plot_orthogonal_value_stacked_bar(df_mlr.copy(), df_lrt.copy(), figure_path('Orthogonal_Value'))


# Cell 7 driver: render the plots for the global and the state-specific results.
# At cell (module) level locals() is the module namespace, so the second test
# simply checks that `patient_hemisphere_id` has already been defined.
if __name__ == "__main__" and 'patient_hemisphere_id' in locals():
    # --- Run for Global (Cell 6) Results ---
    global_results_dir = os.path.join(analysis_session_plot_folder_step4, "MultipleLinearRegression_PKG_on_Neural_STREAMLINED_V4")
    global_mlr_file = f"{patient_hemisphere_id}_MLR_Streamlined_V4_Results_Step6.csv"
    global_lrt_file = f"{patient_hemisphere_id}_MLR_LRT_Results_Step6.csv"
    # os.path.join with a single argument returns it unchanged: figures are
    # written into the same folder the result CSVs are read from.
    global_output_dir = os.path.join(global_results_dir)
    os.makedirs(global_output_dir, exist_ok=True)
    run_visualizations("Global", global_results_dir, global_mlr_file, global_lrt_file, global_output_dir)

    # --- Run for State-Specific (Cell 6A) Results ---
    state_results_dir = os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "MultipleLinearRegression_State_Specific")
    state_mlr_file = f"{patient_hemisphere_id}_MLR_StateSpecific_Results_Step6A.csv"
    state_lrt_file = f"{patient_hemisphere_id}_MLR_LRT_StateSpecific_Results_Step6A.csv"
    # Same convention as above: output folder equals the results folder.
    state_output_dir = os.path.join(state_results_dir)
    os.makedirs(state_output_dir, exist_ok=True)
    run_visualizations("StateSpecific", state_results_dir, state_mlr_file, state_lrt_file, state_output_dir)

else:
    print("Skipping plot generation as script is not being run directly or key variables are missing.")

# Printed on both branches, i.e. also when plot generation was skipped.
print("\n--- Cell 7 (Revised V5): All Visualizations Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 8 (Corrected): Interactive Curated Visualization ---
# This version correctly saves the 'user_selections' dictionary as a global variable
# so it can be accessed by subsequent cells like Cell 12.

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns

print("\n--- Cell 8 (Corrected): Starting Interactive Curated Visualization ---\n")

# --- Global Font Size / Style (Unchanged) ---
# Every base font size below is multiplied by this factor; the plotting
# functions also use it for titles and significance markers.
font_scale_factor = 3
plt.rcParams.update({
    'font.size': 10 * font_scale_factor, 'axes.labelsize': 10 * font_scale_factor,
    'axes.titlesize': 12 * font_scale_factor, 'xtick.labelsize': 10 * font_scale_factor,
    'ytick.labelsize': 8 * font_scale_factor, 'legend.fontsize': 9 * font_scale_factor,
    'legend.title_fontsize': 10 * font_scale_factor,
})
# Comparison p-values below this threshold are flagged as significant.
P_VALUE_THRESHOLD = 0.05
# Group name -> candidate channels offered to the user for that group.
CHANNEL_GROUP_MAP = {'STN': ['STN_DBS_2-0', 'STN_DBS_3-1'], 'M1': ['Cortical_ECoG_10-8', 'Cortical_ECoG_11-9']}
# Values matched against the 'PKG_Symptom_DV' column, in prompting order.
SYMPTOM_ORDER = ['PKG BK Score', 'PKG DK Score', 'PKG Tremor Score']
# 'PKG_Symptom_DV' value -> label shown in prompts and plots.
SYMPTOM_LEGEND_MAP = {'PKG BK Score': 'Bradykinesia', 'PKG DK Score': 'Dyskinesia', 'PKG Tremor Score': 'Tremor'}
# Category order used for the 'Symptom_Display' column in the plots.
SYMPTOM_DISPLAY_ORDER = ['Bradykinesia', 'Dyskinesia', 'Tremor']

# --- Part 1: Interactive Selection (Unchanged) ---
def get_user_channel_selections(df_mlr):
    """Ask the user to pick one channel per channel group for every symptom.

    For each entry of ``SYMPTOM_ORDER`` and each group of ``CHANNEL_GROUP_MAP``
    the channels that actually occur in ``df_mlr`` for that symptom are listed
    and the user chooses one by number. Groups without data are skipped.

    Args:
        df_mlr: Regression results with 'PKG_Symptom_DV' and 'Channel' columns.

    Returns:
        dict mapping each symptom value to a ``{group_name: chosen_channel}``
        dict (empty when no group had data for that symptom).
    """
    def ask_for_channel(symptom_label, group_label, candidates):
        # Re-prompt until a number inside the listed range is entered.
        n_candidates = len(candidates)
        while True:
            print(f"\nFor {symptom_label}, choose a channel for '{group_label}':")
            for number, candidate in enumerate(candidates, start=1):
                print(f"  [{number}]: {candidate}")
            try:
                position = int(input(f"Enter number (1-{n_candidates}): ")) - 1
            except (ValueError, IndexError):
                print("Invalid input. Please enter a valid number.")
                continue
            if 0 <= position < n_candidates:
                picked = candidates[position]
                print(f"Selected: {picked}")
                return picked
            print("Invalid number. Please try again.")

    selections = {}
    print("--- Interactive Channel Selection ---\n")
    for symptom_dv in SYMPTOM_ORDER:
        symptom_label = SYMPTOM_LEGEND_MAP.get(symptom_dv, symptom_dv)
        print(f"\n--- Selecting channels for: {symptom_label} ---")
        per_group = selections[symptom_dv] = {}
        symptom_rows = df_mlr['PKG_Symptom_DV'] == symptom_dv
        for group_name, original_channels in CHANNEL_GROUP_MAP.items():
            group_rows = df_mlr['Channel'].isin(original_channels)
            candidates = df_mlr.loc[symptom_rows & group_rows, 'Channel'].unique().tolist()
            if candidates:
                per_group[group_name] = ask_for_channel(symptom_label, group_name, candidates)
            else:
                print(f"No data available for {group_name} for {symptom_label}. Skipping.")
    return selections

def confirm_selections(selections):
    """Print the chosen channels and ask the user whether to continue.

    Args:
        selections: dict of symptom value -> ``{group_name: channel}`` as
            produced by the interactive selection step.

    Returns:
        True for an answer of 'y'/'yes', False for 'n'/'no' (case-insensitive).
        Any other answer re-prompts.
    """
    rule = "=" * 40
    print(f"\n{rule}\n--- Confirmation Summary ---\nPlots will be generated using these channels:")
    for symptom_dv, choices in selections.items():
        heading = SYMPTOM_LEGEND_MAP.get(symptom_dv, symptom_dv) + ':'
        stn_choice = choices.get('STN', 'N/A')
        m1_choice = choices.get('M1', 'N/A')
        print(f"- {heading:<14} STN = {stn_choice:<15} | M1 = {m1_choice}")
    print(rule)

    # Accepted answers and the decision each one stands for.
    decisions = {'y': True, 'yes': True, 'n': False, 'no': False}
    while True:
        reply = input("Proceed with generating plots? (y/n): ").lower()
        if reply in decisions:
            return decisions[reply]
        print("Invalid input. Please enter 'y' or 'n'.")

# --- Part 2: Plotting Functions (Unchanged) ---
def plot_curated_coefficients(df_plot, predictor_to_plot, output_path):
    """Dot-and-whisker plot of one predictor's coefficient for STN and M1.

    Args:
        df_plot: Curated regression rows with 'Predictor_Term',
            'Symptom_Display', 'BinaryChannel', 'Coefficient',
            'Conf_Int_Lower' and 'Conf_Int_Upper'. The frame is not modified.
        predictor_to_plot: Value of 'Predictor_Term' to draw.
        output_path: File the figure is written to.

    Returns:
        None. Nothing is drawn when no row matches ``predictor_to_plot``.
    """
    display_names = {'Exponent_BestModel': 'Exponent', 'Offset_BestModel': 'Offset'}
    predictor_display_name = display_names.get(predictor_to_plot, predictor_to_plot)
    print(f"\nGenerating Curated Dot-and-Whisker plot for '{predictor_display_name}'...")

    rows = df_plot.loc[df_plot['Predictor_Term'] == predictor_to_plot].copy()
    if rows.empty:
        print("No data to plot. Skipping.")
        return

    # Fixed category order gives stable x positions (STN, M1) and colours.
    rows['Symptom_Display'] = pd.Categorical(rows['Symptom_Display'], categories=SYMPTOM_DISPLAY_ORDER, ordered=True)
    rows['BinaryChannel'] = pd.Categorical(rows['BinaryChannel'], categories=['STN', 'M1'], ordered=True)
    rows = rows.sort_values(['BinaryChannel', 'Symptom_Display'])

    channel_names = list(rows['BinaryChannel'].cat.categories)
    symptom_names = list(rows['Symptom_Display'].cat.categories)
    palette = sns.color_palette('bright', len(symptom_names))

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    # Symptoms are spread evenly over a 0.4-wide band around each channel tick.
    dodge_width = 0.4
    offsets = np.linspace(-dodge_width / 2, dodge_width / 2, len(symptom_names))
    for symptom, colour, offset in zip(symptom_names, palette, offsets):
        subset = rows[rows['Symptom_Display'] == symptom]
        if subset.empty:
            continue
        centre = subset['Coefficient']
        ax.errorbar(
            x=subset['BinaryChannel'].cat.codes.values + offset,
            y=centre,
            # Asymmetric whiskers: distance from the coefficient to each CI bound.
            yerr=[centre - subset['Conf_Int_Lower'], subset['Conf_Int_Upper'] - centre],
            fmt='o', color=colour, label=symptom, capsize=8, markersize=14,
            linestyle='none', linewidth=2.5, markeredgewidth=2.5,
        )

    ax.axhline(0, ls='--', color='black', lw=2, zorder=0)
    ax.set_xticks(np.arange(len(channel_names)))
    ax.set_xticklabels(channel_names)
    ax.set_ylabel("Regression Coefficient (95% CI)")
    ax.grid(axis='y', linestyle=':', alpha=0.7)
    fig.legend(title="Symptom Score", bbox_to_anchor=(1.02, 0.9), loc='upper left')
    fig.tight_layout(rect=[0, 0, 0.85, 1])
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved plot to: {output_path}")

def plot_curated_orthogonal_value(df_plot, output_path):
    """Overlaid bar chart of full vs. reduced adjusted R-squared for STN and M1.

    The full-model value ('R2_Full') is drawn as a coloured bar per symptom
    and the reduced-model value ('R2_Reduced') is drawn over it in grey
    (zorder=2). A red asterisk marks rows whose 'is_significant' flag is true.

    Args:
        df_plot: Table with 'BinaryChannel', 'Symptom_Display', 'R2_Full',
            'R2_Reduced', 'R2_Added_by_Exponent' and 'is_significant'.
            The frame is not modified.
        output_path: File the figure is written to.

    Returns:
        None. Nothing is drawn when ``df_plot`` is empty.
    """
    print("Generating Curated Stacked Bar Chart...")
    if df_plot.empty:
        print("No data to plot. Skipping.")
        return

    # Work on a copy so the categorical conversion and the R2_Full adjustment
    # below never leak into the frame owned by the caller.
    df_plot = df_plot.copy()

    bright_palette = sns.color_palette('bright', 3)
    symptom_colors = {'Bradykinesia': bright_palette[0], 'Dyskinesia': bright_palette[1], 'Tremor': bright_palette[2]}
    df_plot['Symptom_Display'] = pd.Categorical(df_plot['Symptom_Display'], categories=SYMPTOM_DISPLAY_ORDER, ordered=True)
    df_plot['BinaryChannel'] = pd.Categorical(df_plot['BinaryChannel'], categories=['STN', 'M1'], ordered=True)
    df_plot = df_plot.sort_values(['BinaryChannel', 'Symptom_Display'])

    # NOTE(review): only STN rows have R2_Full rebuilt as
    # R2_Reduced + R2_Added_by_Exponent; M1 rows keep their stored R2_Full.
    # Confirm that this asymmetry is intended.
    stn_mask = df_plot['BinaryChannel'] == 'STN'
    df_plot.loc[stn_mask, 'R2_Full'] = df_plot.loc[stn_mask, 'R2_Reduced'] + df_plot.loc[stn_mask, 'R2_Added_by_Exponent']

    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    # Full-model bars first, then the grey reduced-model bars on top of them.
    sns.barplot(data=df_plot, x='BinaryChannel', y='R2_Full', hue='Symptom_Display', palette=symptom_colors, dodge=0.8, errorbar=None, ax=ax, legend=False)
    sns.barplot(data=df_plot, x='BinaryChannel', y='R2_Reduced', hue='Symptom_Display', color="darkgrey", dodge=0.8, errorbar=None, zorder=2, ax=ax, legend=False)

    channel_labels = [label.get_text() for label in ax.get_xticklabels()]
    symptom_labels = df_plot['Symptom_Display'].cat.categories.tolist()
    # Bars of one channel share a 0.8-wide slot centred on its tick, split evenly by symptom.
    dodge_width = 0.8
    bar_width = dodge_width / len(symptom_labels)
    for _, row in df_plot.iterrows():
        if row['is_significant']:
            try:
                channel_idx = channel_labels.index(row['BinaryChannel'])
                symptom_idx = symptom_labels.index(row['Symptom_Display'])
                x_pos = channel_idx - (dodge_width / 2) + (symptom_idx * bar_width) + (bar_width / 2)
                ax.text(x_pos, row['R2_Full'] + 0.005, '*', ha='center', va='bottom', color='red', fontsize=18 * font_scale_factor, zorder=3)
            except (ValueError, IndexError):
                # Channel or symptom is not among the drawn categories; no marker to place.
                continue

    ax.set_ylabel("Adjusted R-squared")
    ax.set_xlabel("")
    fig.tight_layout()
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved plot to: {output_path}")

# --- Part 3: Main Execution Block (MODIFIED) ---
def run_interactive_visualizations(results_dir, mlr_filename, lrt_filename, output_dir):
    """Run the interactive channel selection and draw the curated figures.

    Loads the regression and model-comparison CSVs, lets the user choose one
    channel per group and symptom, asks for confirmation, then writes the
    curated exponent/offset coefficient plots and the orthogonal-value plot.

    Args:
        results_dir: Folder holding the two CSV files.
        mlr_filename: Regression results CSV inside ``results_dir``.
        lrt_filename: Model-comparison results CSV inside ``results_dir``.
        output_dir: Folder the PNG files are written to.

    Returns:
        The selections dict (symptom -> {group: channel}) once the user has
        confirmed, or None when a CSV is missing, nothing was selected, or the
        user declined.
    """
    try:
        df_mlr = pd.read_csv(os.path.join(results_dir, mlr_filename))
        df_lrt = pd.read_csv(os.path.join(results_dir, lrt_filename))
    except FileNotFoundError as e:
        print(f"SKIPPING: Could not find results file at {e.filename}. Please run Cell 6 first.")
        return None  # Return None on failure

    selections = get_user_channel_selections(df_mlr)
    if not selections or not any(selections.values()):
        print("No selections were made. Cancelling plot generation.")
        return None
    if not confirm_selections(selections):
        print("Plot generation cancelled by user.")
        return None

    # Collect the rows of every chosen channel and tag them with their group.
    curated_mlr_rows, curated_lrt_rows = [], []
    for symptom_dv, choices in selections.items():
        for group, channel in choices.items():
            mlr_rows = df_mlr[(df_mlr['Channel'] == channel) & (df_mlr['PKG_Symptom_DV'] == symptom_dv)].copy()
            lrt_rows = df_lrt[(df_lrt['Channel'] == channel) & (df_lrt['PKG_Symptom_DV'] == symptom_dv)].copy()
            mlr_rows['BinaryChannel'], lrt_rows['BinaryChannel'] = group, group
            curated_mlr_rows.append(mlr_rows)
            curated_lrt_rows.append(lrt_rows)

    if not curated_mlr_rows:
        print("No data selected. Cannot generate plots.")
        return selections  # Return selections even if no plot data

    df_mlr_curated, df_lrt_curated = pd.concat(curated_mlr_rows), pd.concat(curated_lrt_rows)
    df_mlr_curated['Symptom_Display'] = df_mlr_curated['PKG_Symptom_DV'].map(SYMPTOM_LEGEND_MAP)

    merge_cols = ['BinaryChannel', 'PKG_Symptom_DV', 'Symptom_Display']
    # Keep one row per group/symptom for each tier before merging. Without
    # this, every row of the reduced model is paired with every row of the
    # full model. This assumes R_squared_adj_model is constant within one
    # model, the same assumption the non-curated stacked-bar plot makes
    # with its drop_duplicates.
    df_r2_reduced = (
        df_mlr_curated[df_mlr_curated['Model_Tier'] == 'Tier 1: Oscillatory Only']
        .drop_duplicates(subset=merge_cols)
        .rename(columns={'R_squared_adj_model': 'R2_Reduced'})
    )
    df_r2_full = (
        df_mlr_curated[df_mlr_curated['Model_Tier'] == 'Tier 2: Exponent + Oscillatory']
        .drop_duplicates(subset=merge_cols)
        .rename(columns={'R_squared_adj_model': 'R2_Full'})
    )
    df_r2 = pd.merge(df_r2_reduced[merge_cols + ['R2_Reduced']], df_r2_full[merge_cols + ['R2_Full']], on=merge_cols)
    # Adjusted R-squared can drop when a predictor is added; negative gains are reported as 0.
    df_r2['R2_Added_by_Exponent'] = (df_r2['R2_Full'] - df_r2['R2_Reduced']).clip(lower=0)
    df_lrt_curated = df_lrt_curated[df_lrt_curated['Comparison'] == 'Exponent + Osc vs. Osc Only']
    df_ortho_plot = pd.merge(df_r2, df_lrt_curated, on=['BinaryChannel', 'PKG_Symptom_DV'])
    df_ortho_plot['is_significant'] = df_ortho_plot['P_value'] < P_VALUE_THRESHOLD

    plot_curated_coefficients(df_mlr_curated, 'Exponent_BestModel', os.path.join(output_dir, f"{patient_hemisphere_id}_Curated_Exponent_Coefficients.png"))
    plot_curated_coefficients(df_mlr_curated, 'Offset_BestModel', os.path.join(output_dir, f"{patient_hemisphere_id}_Curated_Offset_Coefficients.png"))
    plot_curated_orthogonal_value(df_ortho_plot, os.path.join(output_dir, f"{patient_hemisphere_id}_Curated_Orthogonal_Value.png"))

    # The caller stores this so that later cells can reuse the choices.
    return selections

# Cell 8 driver: run the interactive selection on the global (Cell 6) results.
# At cell (module) level locals() is the module namespace, so the second test
# checks that `patient_hemisphere_id` has already been defined.
if __name__ == "__main__" and 'patient_hemisphere_id' in locals():
    global_results_dir = os.path.join(analysis_session_plot_folder_step4, "MultipleLinearRegression_PKG_on_Neural_STREAMLINED_V4")
    global_mlr_file = f"{patient_hemisphere_id}_MLR_Streamlined_V4_Results_Step6.csv"
    global_lrt_file = f"{patient_hemisphere_id}_MLR_LRT_Results_Step6.csv"
    # Curated figures go into their own sub-folder of the results folder.
    curated_output_dir = os.path.join(global_results_dir, "Curated_Result_Plots")
    os.makedirs(curated_output_dir, exist_ok=True)
    
    # MODIFICATION: Capture the returned selections in a global variable
    # (None when the run was skipped or cancelled); Cell 12 reads `user_selections`.
    user_selections = run_interactive_visualizations(global_results_dir, global_mlr_file, global_lrt_file, curated_output_dir)
    
    if user_selections:
        print("\n--- Cell 8 (Corrected): Interactive Curated Visualization Complete. 'user_selections' is now saved. ---")
    else:
        print("\n--- Cell 8 (Corrected): Script finished, but no selections were saved. ---")
else:
    print("Skipping plot generation as script is not being run directly or key variables are missing.")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 9 (Revised V3): Individual Box Plots with Bootstrapped Median CIs ---
# This version replaces the previous Cell 9. Key improvements include:
# 1. Overlaying bootstrapped 95% confidence intervals for the median on each box plot.
# 2. Replacing the fixed-value exponent filter with a robust IQR-based outlier detection method.
# 3. Enhancing the legend to clearly describe all plot components (IQR, CI, data points).

print("\n--- Cell 9 (Revised V3): Generating Individual Aperiodic Exponent Box Plots with Median CIs ---")

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import seaborn as sns
from scipy.stats import kruskal
import scikit_posthocs as sp
import warnings

# --- Configuration & Helper Functions ---
# Suppress warnings from bootstrapping on small groups
warnings.filterwarnings("ignore", message="invalid value encountered in scalar divide")
warnings.filterwarnings("ignore", message="Confidence interval might not be reliable for bootstrap samples with fewer than 50 elements.")

# Constants for this cell (redefined here so the cell does not rely on earlier definitions)
P_VALUE_THRESHOLD = 0.05
MIN_SAMPLES_FOR_GROUP_COMPARISON = 5
# Styling of the box plots and of the overlaid points.
BOX_FILL_ALPHA = 0.7
BOXPLOT_LINE_THICKNESS = 1.5 * 1.5 # base width 1.5, made 50% thicker
DOT_ALPHA = 0.5
# One figure is produced per (channel, frequency-range) pair, in this order.
ORDERED_CHANNEL_LABELS = ['STN_DBS_2-0', 'STN_DBS_3-1', 'Cortical_ECoG_10-8', 'Cortical_ECoG_11-9']
ORDERED_FREQ_LABELS = ["LowFreq", "MidFreq", "WideFreq"]
# Column names of master_df_step4 used by this cell.
CLINICAL_STATE_COL = 'Clinical_State_2min_Window'
CHANNEL_DISPLAY_COL = 'Channel_Display'
FOOOF_FREQ_BAND_COL = 'FreqRangeLabel'

def bootstrap_median_ci(data, n_boot=1000, ci=0.95, random_state=None):
    """Return the median of ``data`` with a percentile-bootstrap confidence interval.

    NaN entries are ignored, so a few missing values no longer turn the
    median and both bounds into NaN.

    Args:
        data: One-dimensional sequence of numbers (list, array, Series, range).
        n_boot: Number of bootstrap resamples.
        ci: Confidence level as a fraction, e.g. 0.95 for a 95% interval.
        random_state: Source of randomness. ``None`` (default) uses numpy's
            global random state, so ``np.random.seed`` keeps working as
            before. An object with a ``choice`` method (``RandomState`` or
            ``Generator``) is used as is; anything else is treated as a seed
            for a dedicated ``np.random.RandomState``.

    Returns:
        Tuple ``(median, lower_bound, upper_bound)``. All three are NaN when
        fewer than two non-NaN values are available.
    """
    values = np.asarray(data, dtype=float).ravel()
    values = values[~np.isnan(values)]
    if values.size < 2:
        return np.nan, np.nan, np.nan

    if random_state is None:
        rng = np.random
    elif hasattr(random_state, "choice"):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    # Median of each resample (same size as the data, drawn with replacement).
    boot_medians = [
        np.median(rng.choice(values, size=values.size, replace=True))
        for _ in range(n_boot)
    ]

    # Confidence bounds are the symmetric percentiles of the bootstrap medians.
    lower_bound = np.percentile(boot_medians, (1 - ci) / 2 * 100)
    upper_bound = np.percentile(boot_medians, (ci + (1 - ci) / 2) * 100)

    return np.median(values), lower_bound, upper_bound

def filter_outliers_iqr(df, group_col, value_col, factor=1.5):
    """Remove per-group IQR outliers of ``value_col`` from ``df``.

    Within each group of ``group_col`` a row is an outlier when its value lies
    below ``Q1 - factor * IQR`` or above ``Q3 + factor * IQR``. Rows are
    selected by position, so frames with a non-unique index lose only the
    offending rows and not every row sharing the same index label.

    Args:
        df: Input frame; it is not modified.
        group_col: Column whose values define the groups.
        value_col: Numeric column tested for outliers.
        factor: Multiplier applied to the IQR to set the fences.

    Returns:
        Tuple ``(filtered_df, num_removed)``. ``filtered_df`` is a new frame
        without the outlier rows and without rows whose ``value_col`` is NaN;
        ``num_removed`` counts only the outliers, not the NaN rows.
    """
    df_out = df.copy()
    values = df_out[value_col]
    groups = df_out[group_col]
    is_outlier = np.zeros(len(df_out), dtype=bool)

    for group in groups.unique():
        in_group = (groups == group).to_numpy()
        group_values = values[in_group]
        q1 = group_values.quantile(0.25)
        q3 = group_values.quantile(0.75)
        iqr = q3 - q1
        lower_bound = q1 - (factor * iqr)
        upper_bound = q3 + (factor * iqr)

        # Flag the rows of this group that fall outside its fences.
        outside_fences = ((values < lower_bound) | (values > upper_bound)).to_numpy()
        is_outlier |= in_group & outside_fences

    num_removed = int(is_outlier.sum())
    # Rows with a missing value are dropped as well, but are not counted as outliers.
    keep = ~is_outlier & values.notna().to_numpy()

    return df_out[keep], num_removed


# --- State Definitions for this Analysis ---
# Clinical states included in this cell's plots, in left-to-right x-axis order.
CELL9_TARGET_STATES_ORDERED = [
    "Sleep", "Immobile", "Non-Dyskinetic Mobile", 
    "Transitional Mobile", "Dyskinetic Mobile"
]
# Colour per clinical state. 'Other' has no counterpart in
# CELL9_TARGET_STATES_ORDERED, so it is not drawn by this cell.
CELL9_STATE_COLORS = {
    'Sleep': '#4169E1', 'Immobile': '#40E0D0', 'Non-Dyskinetic Mobile': '#32CD32',
    'Transitional Mobile': '#FFD700', 'Dyskinetic Mobile': '#FF6347', 'Other': '#C0C0C0'
}

# --- Data Preparation ---
# Cell 9 driver: one exponent box plot per (channel, frequency range), with
# 10-minute averages as points and a bootstrapped 95% CI of the median.
if 'master_df_step4' not in locals() or master_df_step4.empty:
    print("ERROR: master_df_step4 is not available. Please run previous cells.")
else:
    # Keep only rows whose clinical state is one of the states plotted here.
    df_cell9_input = master_df_step4[master_df_step4[CLINICAL_STATE_COL].isin(CELL9_TARGET_STATES_ORDERED)].copy()

    # Build the timestamp used for the 10-minute averages when it is not already present.
    if 'datetime_for_avg' not in df_cell9_input.columns and 'Aligned_PKG_UnixTimestamp' in df_cell9_input.columns:
        df_cell9_input['datetime_for_avg'] = pd.to_datetime(df_cell9_input['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce')

    plot_subdir_cell9_revised = os.path.join(analysis_session_plot_folder_step4, "Exponent_BoxPlots_with_MedianCI")
    os.makedirs(plot_subdir_cell9_revised, exist_ok=True)
    print(f"  Individual exponent plots with CIs will be saved to: {plot_subdir_cell9_revised}")

    stn_channels = [ch for ch in ORDERED_CHANNEL_LABELS if 'STN' in ch]
    m1_channels = [ch for ch in ORDERED_CHANNEL_LABELS if 'Cortical' in ch]
    all_kruskal_results = []

    # --- Main Plotting and Analysis Loop ---
    for channel_label in ORDERED_CHANNEL_LABELS:
        for freq_label in ORDERED_FREQ_LABELS:
            fig, ax = plt.subplots(figsize=(8, 9)) # Adjusted figure size

            df_stratum = df_cell9_input[
                (df_cell9_input[CHANNEL_DISPLAY_COL] == channel_label) &
                (df_cell9_input[FOOOF_FREQ_BAND_COL] == freq_label)
            ].copy()

            ap_metric_col = 'Exponent_BestModel'
            ap_metric_name = 'Aperiodic Exponent'

            # --- Robust Outlier Filtering using IQR method ---
            rows_before_filter = len(df_stratum)
            channel_type_group = 'STN' if channel_label in stn_channels else 'M1' if channel_label in m1_channels else 'Other'
            # df_stratum holds a single channel, so grouping by the channel column
            # computes one pair of IQR fences over all clinical states pooled.
            df_stratum, num_outliers_removed = filter_outliers_iqr(df_stratum, CHANNEL_DISPLAY_COL, ap_metric_col)

            if rows_before_filter > 0:
                 print(f"  Filtering {channel_label} ({freq_label}): Removed {num_outliers_removed} outliers ({num_outliers_removed/rows_before_filter:.1%}) using IQR method.")

            if df_stratum.empty:
                print(f"    No valid data for {channel_label} | {freq_label} after filtering. Skipping plot.")
                plt.close(fig)
                continue

            # --- Data for Plotting (10-min averages for points) ---
            df_for_stripplot = pd.DataFrame()
            if 'datetime_for_avg' in df_stratum.columns and not df_stratum['datetime_for_avg'].isnull().all():
                # '10min' is the non-deprecated spelling of the former '10T' offset alias.
                df_for_stripplot = df_stratum.set_index('datetime_for_avg').groupby([pd.Grouper(freq='10min'), CLINICAL_STATE_COL])[[ap_metric_col]].mean().dropna().reset_index()

            # --- Generate Plots ---
            sns.boxplot(data=df_stratum, x=CLINICAL_STATE_COL, y=ap_metric_col, 
                        order=CELL9_TARGET_STATES_ORDERED, palette=CELL9_STATE_COLORS, 
                        showfliers=False, width=0.5, ax=ax,
                        boxprops={'alpha': BOX_FILL_ALPHA, 'linewidth': BOXPLOT_LINE_THICKNESS}, 
                        medianprops={'linewidth': BOXPLOT_LINE_THICKNESS, 'color':'black'},
                        whiskerprops={'linewidth': BOXPLOT_LINE_THICKNESS}, 
                        capprops={'linewidth': BOXPLOT_LINE_THICKNESS})

            if not df_for_stripplot.empty:
                sns.stripplot(data=df_for_stripplot, x=CLINICAL_STATE_COL, y=ap_metric_col, 
                              order=CELL9_TARGET_STATES_ORDERED, palette=CELL9_STATE_COLORS, 
                              jitter=0.15, alpha=DOT_ALPHA, size=4.0, ax=ax, legend=False)

            # --- Calculate and Plot Median CIs ---
            x_ticks_locs = ax.get_xticks()
            for i, state_val in enumerate(CELL9_TARGET_STATES_ORDERED):
                state_data = df_stratum[df_stratum[CLINICAL_STATE_COL] == state_val][ap_metric_col].dropna()
                if len(state_data) >= 10: # Only plot CI if there's enough data
                    median, ci_low, ci_high = bootstrap_median_ci(state_data)
                    if not np.isnan(median):
                        # The error bar is plotted at the x-tick location for the current state.
                        # fmt='o' already sets the marker, so no separate marker= is passed.
                        ax.errorbar(x=x_ticks_locs[i], y=median, yerr=[[median - ci_low], [ci_high - median]],
                                    fmt='o', color='black', ecolor='black', capsize=5, elinewidth=1.5,
                                    markersize=6, zorder=10)

            # --- Statistical Analysis (Kruskal-Wallis) ---
            # NOTE(review): no Kruskal-Wallis / Dunn test is run in this cell as written.
            # The placeholders below stay empty and all_kruskal_results is never filled.
            groups_for_stat_test, group_names, annotation_text = [], [], ""

            # --- Finalize and Save Plot ---
            ax.set_ylabel("Exponent", fontsize=24)
            ax.set_xlabel("Clinical State", fontsize=24)
            ax.tick_params(axis='y', labelsize=21)
            # Short labels, listed in the same order as CELL9_TARGET_STATES_ORDERED.
            xtick_labels = ["Sleep", "Imm", "NDM", "TM", "DM"]
            ax.set_xticklabels(xtick_labels, rotation=45, ha="right", fontsize=21)
            ax.set_ylim(0, 5.5)

            # --- Create Custom Legend ---
            legend_elements = [
                mpatches.Patch(facecolor='grey', alpha=BOX_FILL_ALPHA, label='Median & IQR'),
                mlines.Line2D([], [], color='grey', marker='o', linestyle='None', markersize=6, label='10-min Average'),
                mlines.Line2D([], [], color='black', marker='_', markersize=10, linestyle='None', label='95% CI of Median')
            ]
            ax.legend(handles=legend_elements, loc='upper right', fontsize=14)

            plt.tight_layout()
            safe_ch = channel_label.replace(' ', '_').replace('-', '_')
            safe_freq = freq_label.replace(' ', '_')
            plot_filename = f"{patient_hemisphere_id}_{safe_ch}_{safe_freq}_Exponent_with_CI.png"
            plt.savefig(os.path.join(plot_subdir_cell9_revised, plot_filename), dpi=300)
            plt.close(fig)

    print(f"\n--- Cell 9 (Revised V3): Individual plot generation with Median CIs complete. ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 11: Generate Final Data Table for Cross-Subject Analysis (Input for Step 5) ---
# This cell prepares the output from Step 4 to be used as input for Step 5.
# The master_df_step4 already contains all necessary information, including
# aperiodic metrics for EACH FreqRangeLabel, LEDD, Beta, Gamma.

print("\n--- Cell 11: Preparing Data Table for Step 5 (Cross-Subject Analysis) ---")

# Same availability guard as the neighbouring cells: an undefined
# master_df_step4 is reported instead of raising NameError.
if 'master_df_step4' not in locals() or master_df_step4 is None or master_df_step4.empty:
    print("master_df_step4 not available or empty. Cannot generate final table for Step 5.")
else:
    # Columns to include in the output for Step 5
    # Should match 'master_table_columns' from Step 3 Cell 2, plus UserSessionName from Step 3 Cell 8
    # Ensure 'UserSessionName' is defined. If this script is run standalone for one patient-hemi,
    # 'UserSessionName' would be the patient_hemisphere_id.
    
    # Columns defined in Step 3's master_table_columns (Cell 2 of Step 3)
    # This list must be kept in sync with the actual columns produced by Step 3.
    # For robustness, we select columns that are ACTUALLY PRESENT in master_df_step4
    # and try to match the intended set.
    
    intended_step5_cols = [
        'SessionID', 'Hemisphere', 'Channel', CHANNEL_DISPLAY_COL, # channel label column named by CHANNEL_DISPLAY_COL
        'Neural_Segment_Start_Unixtime', 'Neural_Segment_End_Unixtime',
        'Neural_Segment_Duration_Sec', 'FS',
        # PSD_Data_Str and Frequency_Vector_Str are usually too large for group analysis files
        # 'PSD_Data_Str', 'Frequency_Vector_Str', 
        'Aligned_PKG_UnixTimestamp', 'Aligned_PKG_DateTime_Str', 
        CLINICAL_STATE_COL, CLINICAL_STATE_AGGREGATED_COL, # Clinical states
        'Aligned_BK', 'Aligned_DK', 'Aligned_Tremor_Score', 'Aligned_Tremor',
        # New Metrics
        'Total_Daily_LEDD_mg',
        'Beta_Peak_Power_at_DominantFreq',
        'Gamma_Peak_Power_at_DominantFreq',
        # FOOOF Results - these are per FreqRangeLabel, so FreqRangeLabel must be included
        FOOOF_FREQ_BAND_COL, 'FreqLow', 'FreqHigh', # FreqRangeLabel is critical here
        'BestModel_AperiodicMode',
        'Offset_BestModel', 'Knee_BestModel', 'Exponent_BestModel',
        'R2_BestModel', 'Error_BestModel', 'Num_Peaks_BestModel',
        # Optionally, detailed fixed/knee params if needed for specific Step 5 checks:
        # 'Offset_Fixed', 'Exponent_Fixed', 'R2_Fixed', 'Error_Fixed', 'Num_Peaks_Fixed',
        # 'Offset_Knee', 'Knee_Knee', 'Exponent_Knee', 'R2_Knee', 'Error_Knee', 'Num_Peaks_Knee',
        # 'ErrorMsg_FOOOF'
    ]
    
    # Intended columns that are missing from master_df_step4 are silently left out.
    final_table_cols_step5_existing = [col for col in intended_step5_cols if col in master_df_step4.columns]
    
    if not final_table_cols_step5_existing:
        print("Warning: No columns identified for the Step 5 data table. It will be empty.")
        final_data_table_for_step5 = pd.DataFrame()
    else:
        final_data_table_for_step5 = master_df_step4[final_table_cols_step5_existing].copy()
        
        # Add 'UserSessionName' which was previously added in Step 3 Cell 8.
        # Here, we re-affirm it as the patient_hemisphere_id for this file.
        if 'UserSessionName' not in final_data_table_for_step5.columns:
            final_data_table_for_step5.insert(0, 'UserSessionName', patient_hemisphere_id)
        else: # If it was somehow carried over from a loaded file that already had it
            final_data_table_for_step5['UserSessionName'] = patient_hemisphere_id


        # Optional: Sort the table
        sort_by_cols_step5 = ['UserSessionName', 'Aligned_PKG_UnixTimestamp', CHANNEL_DISPLAY_COL, FOOOF_FREQ_BAND_COL]
        sort_by_cols_step5_existing = [col for col in sort_by_cols_step5 if col in final_data_table_for_step5.columns]
        if sort_by_cols_step5_existing:
            final_data_table_for_step5.sort_values(by=sort_by_cols_step5_existing, inplace=True, ignore_index=True)

        print(f"  Final data table for Step 5 created with {final_data_table_for_step5.shape[0]} rows and {final_data_table_for_step5.shape[1]} columns.")
        print(f"  Columns included: {final_data_table_for_step5.columns.tolist()}")

    # Define filename and save (this output path should ideally be outside the patient-specific plot folder,
    # in a place where Step 5 can glob all such files)
    # The original Step 4 saved this in analysis_plots_root_folder (one level up from session_plot_folder_name_step4)
    
    output_filename_for_step5 = f"{patient_hemisphere_id}_CrossSubjectAnalysis_DataTable_{current_datetime_str_step4}.csv"
    # Save in the root of the Step 4 analysis folder (step4_analysis_root_folder)
    # This aligns with where Step 5 would look for inputs from multiple subjects.
    output_path_for_step5 = os.path.join(step4_analysis_root_folder, output_filename_for_step5)

    # The file is written even when the table is empty (see the warning above).
    try:
        final_data_table_for_step5.to_csv(output_path_for_step5, index=False)
        print(f"  Successfully saved final data table for Step 5 input to: {output_path_for_step5}")
        print("\n  Sample of this final data table (first 5 rows):")
        print(final_data_table_for_step5.head())
    except Exception as e_save_final_step4:
        print(f"  ERROR saving the final data table for Step 5 input: {e_save_final_step4}")

# Both messages are printed on every path, including when no table was written.
print(f"\n--- Cell 11: Final Data Table generation for {patient_hemisphere_id} complete ---")
print(f"\n--- All Step 4 processing for {patient_hemisphere_id} complete. Outputs are in {analysis_session_plot_folder_step4} and {step4_analysis_root_folder} ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 12 (New): Generate CURATED Final Data Table for Cross-Subject Analysis ---
# This cell uses the selections made in Cell 8 to filter the master data table,
# adds a new 'BinaryChannel' column (STN/M1), and saves the result to a new CSV file.
#
# Reads from the notebook namespace:
#   user_selections  : nested dict {symptom: {group_name: chosen_channel}} (Cell 8)
#   master_df_step4  : master data table (DataFrame)
#   CHANNEL_DISPLAY_COL, CLINICAL_STATE_COL, CLINICAL_STATE_AGGREGATED_COL,
#   FOOOF_FREQ_BAND_COL : column-name constants
#   patient_hemisphere_id, current_datetime_str_step4, step4_analysis_root_folder
# Writes:
#   <step4_analysis_root_folder>/<patient_hemisphere_id>_Curated_CrossSubject_DataTable_<timestamp>.csv
# Every failure is reported with a printed message; the cell does not raise on
# missing prerequisites or on a failed save.

import pandas as pd
import numpy as np
import os

print("\n--- Cell 12: Preparing CURATED Data Table for Step 5 (Cross-Subject Analysis) ---")

# --- Prerequisite Checks ---
# Besides 'user_selections' and 'master_df_step4' (checked individually below), the
# cell reads several names set by earlier cells. Detect missing ones up front so the
# user gets a readable message instead of a NameError half-way through.
# A plain loop is used (not a comprehension) so that locals() is evaluated in the
# cell's own top-level scope on every Python version.
required_names_cell12 = (
    'CHANNEL_DISPLAY_COL', 'CLINICAL_STATE_COL', 'CLINICAL_STATE_AGGREGATED_COL',
    'FOOOF_FREQ_BAND_COL', 'patient_hemisphere_id',
    'current_datetime_str_step4', 'step4_analysis_root_folder',
)
missing_names_cell12 = []
for required_name_cell12 in required_names_cell12:
    if required_name_cell12 not in locals():
        missing_names_cell12.append(required_name_cell12)

# This cell is entirely dependent on the interactive selections made in Cell 8.
if 'user_selections' not in locals() or not isinstance(user_selections, dict) or not user_selections:
    print("\nERROR: The 'user_selections' dictionary from Cell 8 was not found or is empty.")
    print("Please run Cell 8 to make your interactive channel selections before running this cell.")

elif ('master_df_step4' not in locals()
      or not isinstance(master_df_step4, pd.DataFrame)
      or master_df_step4.empty):
    print("\nERROR: master_df_step4 not available or empty. Cannot generate final table.")

elif missing_names_cell12:
    print(f"\nERROR: Required variable(s) from earlier cells not found: {', '.join(missing_names_cell12)}")
    print("Please run the earlier Step 4 cells before running this cell.")

elif CHANNEL_DISPLAY_COL not in master_df_step4.columns:
    print(f"\nERROR: Column '{CHANNEL_DISPLAY_COL}' not found in master_df_step4. Cannot filter by channel.")

else:
    # --- 1. Build the Curated DataFrame based on User Selections ---
    print("   Building curated data table based on selections from Cell 8...")

    # Collect each distinct (group, channel) pair once, in first-seen order.
    # The symptom-specific choice only decides which channel represents a group, so the
    # same pair picked for several symptoms needs to be filtered (and copied) only once.
    # A list membership test is used so channel values need not be hashable.
    selected_pairs_cell12 = []
    for choices in user_selections.values():
        for group_name, chosen_channel in choices.items():
            if (group_name, chosen_channel) not in selected_pairs_cell12:
                selected_pairs_cell12.append((group_name, chosen_channel))

    curated_data_rows = []
    for group_name, chosen_channel in selected_pairs_cell12:
        # Filter the main dataframe to get all timepoints for the chosen channel.
        # Note: we filter only by channel, not by symptom, to get all data for that channel.
        channel_data_subset = master_df_step4[master_df_step4[CHANNEL_DISPLAY_COL] == chosen_channel].copy()

        if channel_data_subset.empty:
            # Report the skipped selection so a missing group in the output is explainable.
            print(f"   WARNING: No rows in master_df_step4 match channel '{chosen_channel}' "
                  f"(group '{group_name}'). This selection is skipped.")
            continue

        # Add the new 'BinaryChannel' column to identify the group (STN or M1)
        channel_data_subset['BinaryChannel'] = group_name
        curated_data_rows.append(channel_data_subset)

    if not curated_data_rows:
        print("   ERROR: Could not build curated data. No matching data found for your selections.")
    else:
        # Combine the subsets for each chosen channel into one dataframe.
        # drop_duplicates is kept so the result matches earlier runs of this cell
        # (it also removes any rows that are exact duplicates within the master table).
        df_curated = pd.concat(curated_data_rows).drop_duplicates().reset_index(drop=True)
        print(f"   Successfully built a curated dataframe with {df_curated.shape[0]} rows.")

        # --- 2. Select and Format Columns for the Final Table ---
        # This logic is borrowed from the original Cell 11

        # Add our new column to the list of columns to keep
        intended_step5_cols = [
            'UserSessionName', 'SessionID', 'Hemisphere', 
            'BinaryChannel', # Our new curated channel group
            CHANNEL_DISPLAY_COL, # Keep original channel name for reference
            'Aligned_PKG_UnixTimestamp', 'Aligned_PKG_DateTime_Str',
            CLINICAL_STATE_COL, CLINICAL_STATE_AGGREGATED_COL,
            'Aligned_BK', 'Aligned_DK', 'Aligned_Tremor_Score',
            'Total_Daily_LEDD_mg',
            'Beta_Peak_Power_at_DominantFreq', 'Gamma_Peak_Power_at_DominantFreq',
            FOOOF_FREQ_BAND_COL, 'BestModel_AperiodicMode',
            'Offset_BestModel', 'Exponent_BestModel',
            'R2_BestModel', 'Error_BestModel'
        ]

        # Only columns actually present are kept, so a missing optional column
        # shrinks the output instead of raising a KeyError.
        final_table_cols_existing = [col for col in intended_step5_cols if col in df_curated.columns]

        df_final_curated = df_curated[final_table_cols_existing].copy()

        # Ensure UserSessionName is present and correct: it is always set to the
        # current patient-hemisphere ID, inserted as the first column when absent.
        if 'UserSessionName' not in df_final_curated.columns:
            df_final_curated.insert(0, 'UserSessionName', patient_hemisphere_id)
        else:
            df_final_curated['UserSessionName'] = patient_hemisphere_id

        # Sort the final table for consistency
        sort_by_cols = ['UserSessionName', 'BinaryChannel', 'Aligned_PKG_UnixTimestamp', FOOOF_FREQ_BAND_COL]
        sort_by_cols_existing = [col for col in sort_by_cols if col in df_final_curated.columns]
        if sort_by_cols_existing:
            df_final_curated.sort_values(by=sort_by_cols_existing, inplace=True, ignore_index=True)

        print(f"   Final curated data table created with {df_final_curated.shape[0]} rows and {df_final_curated.shape[1]} columns.")

        # --- 3. Save the Curated Data Table to a New CSV File ---
        output_filename_curated = f"{patient_hemisphere_id}_Curated_CrossSubject_DataTable_{current_datetime_str_step4}.csv"

        # Save in the same root folder as the Cell 11 output for easy access in Step 5
        output_path_curated = os.path.join(step4_analysis_root_folder, output_filename_curated)

        # Best-effort save: any failure is reported and the cell carries on to its
        # closing banner rather than aborting the notebook run.
        try:
            df_final_curated.to_csv(output_path_curated, index=False)
            print(f"\n   Successfully saved CURATED data table for Step 5 input to: {output_path_curated}")
            print("\n   Sample of this curated data table (first 5 rows):")
            print(df_final_curated.head())
        except Exception as e:
            print(f"   ERROR saving the curated data table: {e}")

# Closing banner; printed on every path through the cell, including the error paths above.
print(f"\n--- Cell 12: CURATED Final Data Table generation for {patient_hemisphere_id} complete ---")