In [None]:
# -*- coding: utf-8 -*-
# --- Step 4: Within-Subject Analysis (Revised) ---
# --- Cell 1: Imports and Global Styling Setup ---

import pandas as pd
import numpy as np
import os
import glob
import re # For natural sorting and safe filenames
import warnings
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from scipy.stats import pearsonr, spearmanr, mannwhitneyu # For bivariate correlations
import pingouin as pg # For partial correlations
import statsmodels.formula.api as smf # For multiple linear regression
from statsmodels.graphics.regressionplots import plot_partregress_grid # For partial regression plots

# --- Suppress Warnings ---
# Category-wide filters are registered first, then the message-specific ones,
# in the same order as before.
for _ignored_warning_category in (FutureWarning, UserWarning):
    warnings.simplefilter(action='ignore', category=_ignored_warning_category)
for _ignored_warning_message in (
    "Degrees of freedom <= 0 for slice",
    "p-value may not be accurate for N > 5000",
):
    warnings.filterwarnings("ignore", message=_ignored_warning_message)

# --- Plotting Style and Global Parameters ---
# The seaborn theme is applied first; the rcParams overrides below are then
# layered on top of it.
sns.set_theme(style="whitegrid")

plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
    'axes.titlesize': 14,      # Titles for individual subplots
    'axes.labelsize': 12,      # X/Y axis labels
    'xtick.labelsize': 10,     # X-axis tick labels
    'ytick.labelsize': 10,     # Y-axis tick labels
    'legend.fontsize': 10,     # Legend font size
    'figure.titlesize': 18,    # For suptitle on overview (figure-level) plots
    'figure.dpi': 100,         # For inline viewing
    'savefig.dpi': 600,        # For saving figures
})

print("Cell 1: Imports and global styling setup complete.")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 2: User Input, File Paths, and Analysis Parameter Definitions ---

import os
from datetime import datetime
import re # For safe filenames

# --- User Input ---
# This notebook processes ONE patient-hemisphere at a time.
# patient_hemisphere_id must match the identifier embedded in the Step 3 master CSV file name,
# e.g. "RCS20R" for MASTER_FOOOF_PKG_results_RCS20R_... .csv
# (Hardcoded for development; replace with input() if an interactive/batch setup is needed.)
patient_hemisphere_id = "RCS02L" # <<< USER: SET THIS FOR THE CURRENT PATIENT-HEMISPHERE

# Validate BEFORE the ID is used anywhere: an empty or whitespace-only ID would otherwise
# be echoed and then baked into the input file name and the output folder name.
if not patient_hemisphere_id or not patient_hemisphere_id.strip():
    raise ValueError("Patient-Hemisphere ID cannot be empty.")
print(f"Processing data for Patient-Hemisphere ID: {patient_hemisphere_id}")

# --- Path Configuration ---
user_home = os.path.expanduser("~")
project_base_path = ".."  # relative to the notebook's working directory

# Path to the output folder from Step 3 (where MASTER_FOOOF_PKG_results... files are)
step3_output_version_tag = "neural_pkg_aligned" # <<< USER: Ensure this matches Step 3's tag
step3_master_csv_base_folder = os.path.join(project_base_path, 'Working', f'step3_fooof_results_{step3_output_version_tag}')

master_csv_filename = f"MASTER_FOOOF_PKG_results_{patient_hemisphere_id}_{step3_output_version_tag}.csv"
master_csv_path_to_load = os.path.join(step3_master_csv_base_folder, master_csv_filename)
print(f"Attempting to load master data from: {master_csv_path_to_load}")

# Output folder for Step 4 results and plots (created inside the Step 3 results folder)
step4_analysis_root_folder = os.path.join(step3_master_csv_base_folder, "step4_within_subject_analysis")
os.makedirs(step4_analysis_root_folder, exist_ok=True)

# Each run of this cell gets its own timestamped plot folder, so earlier runs are not overwritten.
current_datetime_str_step4 = datetime.now().strftime('%Y%m%d_%H%M%S')
session_plot_folder_name_step4 = f"{patient_hemisphere_id}_plots_{current_datetime_str_step4}"
analysis_session_plot_folder_step4 = os.path.join(step4_analysis_root_folder, session_plot_folder_name_step4)
os.makedirs(analysis_session_plot_folder_step4, exist_ok=True)
print(f"Step 4 plots will be saved in: {analysis_session_plot_folder_step4}")

# --- Analysis Parameters ---
# These should align with what's available in your master CSV from Step 3.
# Each *_METRICS_COLS dict maps a master-CSV column name (key) to the
# human-readable label (value) written into result tables.

# Aperiodic (1/f) fit parameters.
APERIODIC_METRICS_COLS = {
    'Exponent_BestModel': 'Aperiodic Exponent',
    'Offset_BestModel': 'Aperiodic Offset',
}
# PKG symptom scores aligned to the neural data.
PKG_METRICS_COLS = {
    'Aligned_BK': 'PKG BK Score',
    'Aligned_DK': 'PKG DK Score',
    'Aligned_Tremor_Score': 'PKG Tremor Score'
}
# Oscillatory peak-power metrics; Cell 5A uses these as partial-correlation covariates.
OSCILLATORY_METRICS_COLS = {
    'Beta_Peak_Power_at_DominantFreq': 'Beta Peak Power',
    'Gamma_Peak_Power_at_DominantFreq': 'Gamma Peak Power'
}

# Names of the grouping/label columns in the master DataFrame.
CHANNEL_COL = 'Channel' # Raw channel key like 'TD_key0'
CHANNEL_DISPLAY_COL = 'Channel_Display' # User-friendly labels like 'STN_DBS_2-0'
FOOOF_FREQ_BAND_COL = 'FreqRangeLabel'
CLINICAL_STATE_COL = 'Clinical_State_2min_Window'
CLINICAL_STATE_AGGREGATED_COL = 'Clinical_State_Aggregated' # From Step 3

# Desired order for iterations and plotting
# ORDERED_CHANNEL_LABELS will be derived from data in Cell 3, or can be hardcoded here if preferred:
# e.g., ORDERED_CHANNEL_LABELS = ['STN_DBS_2-0', 'STN_DBS_3-1', 'Cortical_ECoG_10-8', 'Cortical_ECoG_11-9']
ORDERED_FREQ_LABELS = ["LowFreq", "MidFreq", "WideFreq"] # From Step 3 FOOOF bands

# --- Statistical Thresholds ---
MIN_SAMPLES_FOR_CORR = 5       # Minimum data points for a correlation to be considered reliable
P_VALUE_THRESHOLD = 0.05       # Default significance cut-off for plot annotations

# --- Styling Parameters (can be expanded) ---
# Plot color per metric, keyed by the same column names as the *_METRICS_COLS dicts.
COLOR_PALETTE_STEP4 = {
    'Exponent_BestModel': 'darkslateblue',
    'Offset_BestModel': 'mediumseagreen',
    'Beta_Peak_Power_at_DominantFreq': 'goldenrod',
    'Gamma_Peak_Power_at_DominantFreq': 'firebrick',
    'Aligned_BK': 'steelblue',
    'Aligned_DK': 'orangered',
    'Aligned_Tremor_Score': 'mediumpurple'
}
DOT_ALPHA_STEP4 = 0.5
REG_CI_ALPHA_STEP4 = 0.15
BOX_FILL_ALPHA_STEP4 = 0.6
REG_LINE_THICKNESS_STEP4 = 2.0

# Background colors of the statistics text box drawn by annotate_correlation_on_plot (Cell 4):
# khaki when p < threshold, ivory otherwise. Not used for heatmap cell colors.
SIGNIFICANT_P_VAL_BG_COLOR_STEP4 = 'khaki' # Not directly used for heatmap cell color, but for annotations
DEFAULT_P_VAL_BG_COLOR_STEP4 = 'ivory'

print(f"\nStep 4 analysis parameters and paths configured for {patient_hemisphere_id}.")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 3: Data Loading and Initial Preprocessing ---

import pandas as pd
import numpy as np
import os
import re # For natural sort key

def load_and_preprocess_step4_data(file_path, patient_hemisphere_id_val, r2_threshold=0.5):
    """Load the Step 3 master CSV and prepare it for the Step 4 analyses.

    Processing order:
      1. Read the CSV. Warn if its single 'SessionID' differs from the expected ID;
         add the column when it is absent.
      2. Coerce metric columns to numeric (unparseable values become NaN) and cast
         the key label columns to ``str``.
      3. Keep only rows with ``R2_BestModel >= r2_threshold`` (when that column exists).
      4. Create the channel display column from the raw channel key if it is missing.
      5. Derive the channel label order used for iteration and plotting.
      6. Drop rows with NaN in any aperiodic / PKG / oscillatory metric column present.

    Args:
        file_path: Path of the master CSV written by Step 3.
        patient_hemisphere_id_val: Expected patient-hemisphere identifier.
        r2_threshold: Minimum FOOOF fit R^2 a row needs to be kept (default 0.5).

    Returns:
        tuple: ``(df, ordered_channel_labels)``. ``(None, [])`` is returned when the
        file does not exist or any processing step raises.
    """
    # Imported locally: the error handler below needs it and no cell imports it,
    # so the handler itself used to fail with NameError.
    import traceback

    if not os.path.exists(file_path):
        print(f"ERROR: Master CSV file from Step 3 not found at {file_path}")
        return None, []

    try:
        df = pd.read_csv(file_path)
        print(f"Successfully loaded {file_path}. Initial shape: {df.shape}")

        # --- Verify Patient ID consistency (optional but good check) ---
        if 'SessionID' in df.columns and df['SessionID'].nunique() == 1:
            csv_session_id = df['SessionID'].unique()[0]
            if csv_session_id != patient_hemisphere_id_val:
                print(f"Warning: SessionID in CSV ({csv_session_id}) differs from expected ({patient_hemisphere_id_val}). Proceeding with CSV data.")
        elif 'SessionID' not in df.columns:
            print(f"Warning: 'SessionID' column not found. Adding it based on patient_hemisphere_id_val: {patient_hemisphere_id_val}")
            df['SessionID'] = patient_hemisphere_id_val

        # --- Data Type Conversions & Cleaning ---
        cols_to_numeric = (
            list(APERIODIC_METRICS_COLS.keys()) +
            list(PKG_METRICS_COLS.keys()) +
            list(OSCILLATORY_METRICS_COLS.keys()) +
            ['Total_Daily_LEDD_mg', 'R2_BestModel', 'Error_BestModel', 'Num_Peaks_BestModel']
        )
        for col in cols_to_numeric:
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors='coerce')
            else:
                print(f"Warning: Expected numeric column '{col}' not found in master_df.")

        # Ensure key categorical columns are strings.
        # NOTE(review): astype(str) also turns missing values into the literal text 'nan';
        # confirm downstream code is fine with that for these label columns.
        for col in [CHANNEL_COL, CHANNEL_DISPLAY_COL, FOOOF_FREQ_BAND_COL, CLINICAL_STATE_COL, CLINICAL_STATE_AGGREGATED_COL, 'Hemisphere', 'BestModel_AperiodicMode']:
            if col in df.columns:
                df[col] = df[col].astype(str)
            else:
                print(f"Warning: Expected categorical column '{col}' not found.")

        # --- Filtering based on FOOOF fit quality ---
        # Rows whose R2 is NaN fail the comparison and are removed as well.
        initial_rows = len(df)
        if 'R2_BestModel' in df.columns:
            df = df[df['R2_BestModel'] >= r2_threshold].copy()
            print(f"Filtered by R2_BestModel >= {r2_threshold}. Rows changed from {initial_rows} to {len(df)}.")
        else:
            print("Warning: 'R2_BestModel' column not found. Cannot filter by FOOOF fit quality.")

        # --- Create Channel_Display if not present (Step 3 should normally provide it) ---
        if CHANNEL_DISPLAY_COL not in df.columns and CHANNEL_COL in df.columns:
            print(f"'{CHANNEL_DISPLAY_COL}' not found. Creating from '{CHANNEL_COL}'. Hardcoded map will be used if available.")
            # Hardcoded map (ensure this is consistent with Step 3 and your data)
            channel_mapping_step4 = {
                'TD_key0': 'STN_DBS_2-0', 'TD_key1': 'STN_DBS_3-1',
                'TD_key2': 'Cortical_ECoG_10-8', 'TD_key3': 'Cortical_ECoG_11-9'
            }
            # Keys without a mapping keep their raw channel key as the display label.
            df[CHANNEL_DISPLAY_COL] = df[CHANNEL_COL].map(channel_mapping_step4).fillna(df[CHANNEL_COL])

        # --- Determine ordered channel labels from the data ---
        if CHANNEL_DISPLAY_COL in df.columns:
            def natural_sort_key_step4(s):
                # Split into digit / non-digit runs so that e.g. '2' sorts before '10'.
                return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', str(s))]

            # Known labels come first, in this fixed order.
            defined_channel_map_order = {
                'STN_DBS_2-0': 0, 'STN_DBS_3-1': 1,
                'Cortical_ECoG_10-8': 2, 'Cortical_ECoG_11-9': 3
            }
            unique_ch_labels_data = df[CHANNEL_DISPLAY_COL].unique()
            ordered_ch_labels_from_data = sorted(
                [ch for ch in unique_ch_labels_data if ch in defined_channel_map_order],
                key=lambda x: defined_channel_map_order[x]
            )
            # Any other labels found in the data follow, naturally sorted.
            ordered_ch_labels_from_data.extend(
                sorted([ch for ch in unique_ch_labels_data if ch not in defined_channel_map_order],
                       key=natural_sort_key_step4)
            )
            print(f"Derived ORDERED_CHANNEL_LABELS for plots: {ordered_ch_labels_from_data}")
        else:
            print(f"ERROR: '{CHANNEL_DISPLAY_COL}' not found. Cannot determine channel order.")
            ordered_ch_labels_from_data = []

        # --- Drop rows where essential metrics for correlation/regression are NaN ---
        # The analysis loops iterate over all of these metrics, so a row is kept only
        # when every present metric has a value.
        key_metrics_for_analysis = (
            list(APERIODIC_METRICS_COLS.keys()) +
            list(PKG_METRICS_COLS.keys()) +
            list(OSCILLATORY_METRICS_COLS.keys())
        )
        key_metrics_present = [col for col in key_metrics_for_analysis if col in df.columns]

        rows_before_na_essential_drop = len(df)
        if key_metrics_present:
            df.dropna(subset=key_metrics_present, how='any', inplace=True) # Drop if ANY of these are NaN for a row
            print(f"Dropped rows with NaNs in any of {key_metrics_present}. Rows changed from {rows_before_na_essential_drop} to {len(df)}.")
        else:
            print("Warning: No key metrics found to check for NaNs. Data might be incomplete.")

        print(f"Final master_df for Step 4 shape: {df.shape}")
        if df.empty:
            print("Warning: DataFrame is empty after preprocessing. Subsequent analyses might fail.")

        return df, ordered_ch_labels_from_data

    except Exception as e:
        # Best-effort loader: report the failure with its traceback and signal it via None.
        print(f"Error loading or processing the CSV file '{file_path}' for Step 4: {e}")
        print(traceback.format_exc())
        return None, []

# --- Load and Preprocess Data ---
master_df_step4, ORDERED_CHANNEL_LABELS = load_and_preprocess_step4_data(master_csv_path_to_load, patient_hemisphere_id)

# --- Redefine Clinical States point-by-point from the aligned BK / DK scores ---
# Runs only when data was loaded and both PKG score columns exist.
if master_df_step4 is not None and not master_df_step4.empty and \
   'Aligned_BK' in master_df_step4.columns and 'Aligned_DK' in master_df_step4.columns:

    print("\n--- Redefining Clinical States Point-by-Point ---")

    # Ensure Aligned_BK and Aligned_DK are numeric (the loader already does this; kept as a safeguard)
    master_df_step4['Aligned_BK'] = pd.to_numeric(master_df_step4['Aligned_BK'], errors='coerce')
    master_df_step4['Aligned_DK'] = pd.to_numeric(master_df_step4['Aligned_DK'], errors='coerce')

    # 1. Define criteria for "general mobile candidate"
    # These are based on the original Step 3 Cell 3a logic before windowing.
    # NOTE(review): this mask has no BK < 80 restriction, so rows with BK >= 80 and DK >= 7
    # (labelled "Sleep" by the point-wise function below) still contribute to the DK
    # percentiles. Confirm that this is intended.
    is_general_mobile_candidate_series = (
        (master_df_step4['Aligned_BK'] <= 26) | (master_df_step4['Aligned_DK'] >= 7)
    )

    # 2. Isolate "Mobile Candidate" Data for Percentile Calculation
    df_mobile_for_percentiles = master_df_step4[is_general_mobile_candidate_series & \
                                                master_df_step4['Aligned_DK'].notna()].copy()

    p30_dk_mobile_only = np.nan
    p70_dk_mobile_only = np.nan

    if not df_mobile_for_percentiles.empty and len(df_mobile_for_percentiles['Aligned_DK'].dropna()) > 1: # Need at least 2 points for percentile
        p30_dk_mobile_only = np.percentile(df_mobile_for_percentiles['Aligned_DK'].dropna(), 30)
        p70_dk_mobile_only = np.percentile(df_mobile_for_percentiles['Aligned_DK'].dropna(), 70)
        print(f"  Calculated new mobile-only DK percentiles: p30={p30_dk_mobile_only:.2f}, p70={p70_dk_mobile_only:.2f}")
    else:
        print("  Warning: Not enough 'mobile candidate' data with valid DK scores to calculate new percentiles. States might not be redefined accurately.")

    # 3. Define a function for point-by-point state assignment
    def redefine_clinical_state_pointwise(row, p30_mobile, p70_mobile):
        """Return the clinical-state label for one row from its BK and DK scores.

        Precedence: missing BK/DK -> "Other"; BK >= 80 -> "Sleep";
        26 < BK < 80 with DK < 7 -> "Immobile"; otherwise a mobile state that is
        split by the DK percentiles (<= p30, > p70, in between), or
        "Mobile (Generic)" when the percentiles are NaN.
        """
        bk = row['Aligned_BK']
        dk = row['Aligned_DK']

        if pd.isna(bk) or pd.isna(dk):
            return "Other" # Placeholder for missing PKG data

        # Original criteria from Step 3 Cell 3a
        is_sleep = (bk >= 80)
        is_immobile = (bk > 26) & (bk < 80) & (dk < 7)
        is_general_mobile = (bk <= 26) | (dk >= 7)

        if is_sleep:
            return "Sleep"
        elif is_immobile:
            return "Immobile"
        elif is_general_mobile:
            if pd.notna(p30_mobile) and pd.notna(p70_mobile): # Only assign refined if percentiles are valid
                if dk <= p30_mobile:
                    return "Non-Dyskinetic Mobile"
                elif dk > p70_mobile:
                    return "Dyskinetic Mobile"
                else: # dk is > p30_mobile and <= p70_mobile
                    return "Transitional Mobile"
            else: # Fallback if percentiles could not be calculated
                return "Mobile (Generic)"
        else:
            # The three conditions above appear to cover every non-missing BK/DK pair,
            # so this branch looks unreachable; kept as a defensive default.
            return "Other"

    # 4. Apply the function to re-calculate CLINICAL_STATE_COL ('Clinical_State_2min_Window', Cell 2)
    master_df_step4[CLINICAL_STATE_COL] = master_df_step4.apply(
        lambda row: redefine_clinical_state_pointwise(row, p30_dk_mobile_only, p70_dk_mobile_only), axis=1
    )
    print(f"  '{CLINICAL_STATE_COL}' column has been redefined point-by-point.")

    # 5. Update CLINICAL_STATE_AGGREGATED_COL ('Clinical_State_Aggregated'):
    # the three refined mobile states collapse into one label, everything else passes through.
    new_refined_mobile_states = ["Non-Dyskinetic Mobile", "Transitional Mobile", "Dyskinetic Mobile"]

    def aggregate_new_mobile_states(state_redefined):
        """Map any refined mobile state to "Mobile (All Types)"; return other states unchanged."""
        if state_redefined in new_refined_mobile_states:
            return "Mobile (All Types)" # This name is from previous logic, can be adjusted
        return state_redefined

    master_df_step4[CLINICAL_STATE_AGGREGATED_COL] = master_df_step4[CLINICAL_STATE_COL].apply(aggregate_new_mobile_states)
    print(f"  '{CLINICAL_STATE_AGGREGATED_COL}' column has been updated based on new states.")

    # 6. Print value counts of the new distribution
    print("\n  New distribution of 'Clinical_State_2min_Window':")
    print(master_df_step4[CLINICAL_STATE_COL].value_counts(dropna=False))
    print("\n  New distribution of 'Clinical_State_Aggregated':")
    print(master_df_step4[CLINICAL_STATE_AGGREGATED_COL].value_counts(dropna=False))

else:
    if master_df_step4 is None or master_df_step4.empty:
        print("\nSkipping clinical state redefinition because master_df_step4 is not loaded or is empty.")
    else:
        print("\nSkipping clinical state redefinition because 'Aligned_BK' or 'Aligned_DK' columns are missing.")

# --- End of clinical state redefinition ---

# Show a short preview of the key columns, or report that no data is available.
if master_df_step4 is None or master_df_step4.empty:
    # Only a message is printed here; execution continues unless sys.exit() is enabled.
    print("Halting Step 4 script as master data could not be loaded or is empty after preprocessing.")
    # sys.exit() # Uncomment to halt execution if master_df_step4 is not loaded
else:
    print("\nFirst 5 rows of the processed master DataFrame for Step 4 (showing key cols):")
    cols_to_show_step4 = [
        CHANNEL_DISPLAY_COL,
        FOOOF_FREQ_BAND_COL,
        *APERIODIC_METRICS_COLS,
        *PKG_METRICS_COLS,
        *OSCILLATORY_METRICS_COLS,
        'Total_Daily_LEDD_mg',
    ]
    # Restrict the preview to columns that actually exist in the loaded frame.
    cols_to_show_step4_present = [
        col for col in cols_to_show_step4
        if col in master_df_step4.columns
    ]
    print(master_df_step4[cols_to_show_step4_present].head())

print("\nCell 3: Data loading and preprocessing complete for Step 4.")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 4: Helper Functions for Step 4 ---

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches # For legend patches if needed
from scipy.stats import spearmanr
import pingouin as pg # For partial correlation

def calculate_spearman_with_n(data_df, col1, col2, min_samples=MIN_SAMPLES_FOR_CORR):
    """Spearman correlation between two columns, computed on complete pairs only.

    Returns:
        tuple: ``(rho, p_value, N)`` where N is the number of rows with both values
        present. rho and p_value are NaN when N < min_samples, when the
        coefficient is undefined (NaN), or when the computation raises ValueError.
    """
    complete_pairs = data_df[[col1, col2]].dropna()
    sample_count = len(complete_pairs)
    no_result = (np.nan, np.nan, sample_count)

    if sample_count < min_samples:
        return no_result

    try:
        coefficient, significance = spearmanr(complete_pairs[col1], complete_pairs[col2])
        # An undefined coefficient (e.g. no variance in one column) is reported as "no result".
        coefficient_undefined = np.isnan(coefficient)
        return no_result if coefficient_undefined else (coefficient, significance, sample_count)
    except ValueError:
        # Other input problems raised by the computation are treated the same way.
        return no_result

def calculate_partial_spearman(data_df, x_col, y_col, covar_cols, min_samples=MIN_SAMPLES_FOR_CORR):
    """Partial Spearman correlation of x_col vs y_col, controlling for covar_cols.

    Rows with a missing value in any involved column are dropped first.

    Args:
        data_df: DataFrame holding all involved columns.
        x_col, y_col: Names of the two columns to correlate.
        covar_cols: Covariate column name(s); a list, or a single name as a string.
        min_samples: Minimum number of complete rows required.

    Returns:
        tuple: ``(partial_rho, p_value, N)`` with N the number of complete rows.
        partial_rho and p_value are NaN when N < min_samples, when any involved
        column is constant, or when the computation fails (a RuntimeWarning is
        emitted in that last case instead of failing silently).
    """
    # Imported locally so this helper does not depend on another cell having imported it.
    import warnings

    # Accept a single covariate given as a plain string.
    covar_list = [covar_cols] if isinstance(covar_cols, str) else list(covar_cols)
    all_cols_for_partial = [x_col, y_col] + covar_list
    partial_data = data_df[all_cols_for_partial].dropna()
    n_points = len(partial_data)

    if n_points < min_samples:
        return np.nan, np.nan, n_points # partial_rho, p-value, N
    try:
        # A constant column makes the partial correlation undefined.
        if not all(partial_data[col].nunique() > 1 for col in all_cols_for_partial):
            return np.nan, np.nan, n_points

        pcorr_result = pg.partial_corr(data=partial_data, x=x_col, y=y_col, covar=covar_list, method='spearman')
        # NOTE(review): the p-value column is 'p-val' in the pingouin versions this notebook
        # was written against; some releases may name it 'p_val'. Accept both so a rename
        # cannot silently turn every result into NaN. Verify against the installed version.
        p_col = next((name for name in ('p-val', 'p_val') if name in pcorr_result.columns), None)
        if p_col is None:
            raise KeyError(f"no p-value column in partial_corr output: {list(pcorr_result.columns)}")
        rho = pcorr_result['r'].iloc[0]
        p_value = pcorr_result[p_col].iloc[0]
        return rho, p_value, n_points
    except Exception as e:
        # Best-effort: keep the analysis loop running, but make the failure visible.
        # RuntimeWarning is used because Cell 1 suppresses UserWarning globally.
        warnings.warn(
            f"Partial correlation failed for {x_col} vs {y_col} (N={n_points}): {e!r}",
            RuntimeWarning,
        )
        return np.nan, np.nan, n_points


def annotate_correlation_on_plot(ax, rho, p_value, N_val, test_type="Spearman ρ", 
                                 x_pos=0.97, y_pos=0.97, fontsize=9,
                                 sig_threshold=P_VALUE_THRESHOLD):
    """Draw a boxed correlation summary (coefficient, p-value, N) on ``ax``.

    The text is placed in axes coordinates at (x_pos, y_pos), anchored top-right.
    When rho or p_value is missing, an "N/A" label is shown instead. The box uses
    the "significant" background color only when p_value < sig_threshold.
    """
    if pd.isna(rho) or pd.isna(p_value):
        label = f"{test_type}: N/A (N={N_val})"
        box_color = DEFAULT_P_VAL_BG_COLOR_STEP4 # from Cell 2
    else:
        # First matching cut-off wins: *** (<0.001), ** (<0.01), * (< sig_threshold).
        significance_marks = ""
        for cutoff, marks in ((0.001, "***"), (0.01, "**"), (sig_threshold, "*")):
            if p_value < cutoff:
                significance_marks = marks
                break
        label = f"{test_type}={rho:.2f}{significance_marks}\np={p_value:.3g}\n(N={N_val})"
        if p_value < sig_threshold:
            box_color = SIGNIFICANT_P_VAL_BG_COLOR_STEP4
        else:
            box_color = DEFAULT_P_VAL_BG_COLOR_STEP4

    box_style = dict(boxstyle='round,pad=0.3', fc=box_color, alpha=0.85, edgecolor='darkgrey')
    ax.text(
        x_pos, y_pos, label,
        transform=ax.transAxes,
        fontsize=fontsize,
        verticalalignment='top',
        horizontalalignment='right',
        bbox=box_style,
    )

def get_safe_filename_step4(base_name):
    """Create a filesystem-safe filename fragment from ``base_name``.

    Characters other than word characters, whitespace and hyphens are removed,
    surrounding whitespace is stripped, and every remaining whitespace character
    or hyphen is replaced by an underscore (one underscore per character).

    Args:
        base_name: Any object; it is converted with ``str()`` first.

    Returns:
        str: The sanitized name (may be empty if nothing survives the cleaning).
    """
    cleaned = re.sub(r'[^\w\s-]', '', str(base_name)).strip()
    # Replace ALL whitespace (tabs/newlines included), not only ' ': the first pattern
    # keeps every \s character, so interior tabs/newlines used to end up in the filename.
    return re.sub(r'[\s-]', '_', cleaned)

def trim_data_for_boxplot_visualization(df_group, value_col):
    """Return the rows of ``df_group`` whose ``value_col`` lies within the 1.5*IQR fences.

    Intended for cleaner box-plot visuals only. The input is returned unchanged
    when it is empty, has fewer than two rows, has only missing values in
    ``value_col``, or has an IQR of zero.
    """
    if df_group.empty:
        return df_group

    values = df_group[value_col]
    if values.isnull().all() or len(df_group) < 2:
        return df_group

    first_quartile = values.quantile(0.25)
    third_quartile = values.quantile(0.75)
    spread = third_quartile - first_quartile
    if spread == 0: # All central values identical: nothing sensible to trim
        return df_group

    # Inclusive fences; rows with a missing value fall outside and are dropped.
    within_fences = values.between(first_quartile - 1.5 * spread,
                                   third_quartile + 1.5 * spread)
    return df_group[within_fences]


print("Cell 4: Helper functions for Step 4 defined.")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 5_PREAMBLE: Definitions for State-Specific Analyses (Revised Order: No Sleep, 4 Separate Mobile/Immobile States) ---
# This cell should be run after Cell 4 (Helper Functions) and before the new state-specific analysis cells.

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.formula.api as smf
import re

# --- Target clinical states, in analysis/plot order (4 states, Sleep excluded) ---
TARGET_CLINICAL_STATES_ORDERED = ["Immobile", "Non-Dyskinetic Mobile", "Transitional Mobile", "Dyskinetic Mobile"]

# Independent copy of the target list (no states are combined in this setup);
# intended for the initial filtering of master_df_step4.
ORIGINAL_STATES_FOR_ANALYSIS = list(TARGET_CLINICAL_STATES_ORDERED)

# No combining needed for this setup
STATES_TO_COMBINE_MAPPING = dict()
NEW_COMBINED_STATE_NAME = None


# --- Clinical state colors ---
# Lookup by state name; the order in which they are applied follows TARGET_CLINICAL_STATES_ORDERED.
NEW_CLINICAL_STATE_COLORS_FOR_PLOTTING = {
    'Immobile': '#40E0D0',              # Turquoise
    'Transitional Mobile': '#FFD700',   # Gold
    'Non-Dyskinetic Mobile': '#32CD32', # LimeGreen
    'Dyskinetic Mobile': '#FF6347',     # Tomato
    # Entries below are fallbacks for states outside the target list
    'Sleep': '#4169E1',                 # RoyalBlue
    'Other': '#C0C0C0',                 # Silver
    'Mobile (All Types)': 'darkgreen'   # Aggregated view, if ever used
}


# --- PKG symptom colors ---
# Taken from COLOR_PALETTE_STEP4, with a per-symptom fallback if the palette lacks the key.
PKG_SYMPTOM_COLORS = {
    pkg_col: COLOR_PALETTE_STEP4.get(pkg_col, fallback_color)
    for pkg_col, fallback_color in (
        ('Aligned_BK', 'steelblue'),
        ('Aligned_DK', 'orangered'),
        ('Aligned_Tremor_Score', 'mediumpurple'),
    )
}

# Root folder for all state-specific outputs, inside this session's plot folder
STATE_SPECIFIC_ANALYSIS_DIR = os.path.join(analysis_session_plot_folder_step4, "State_Specific_Analyses")
os.makedirs(STATE_SPECIFIC_ANALYSIS_DIR, exist_ok=True)

print("Cell 5_PREAMBLE: Definitions for state-specific analyses (4 States - No Sleep, Separate Mobile, New Order) are set.")
print(f"Target clinical states for analysis (NEW ORDER): {TARGET_CLINICAL_STATES_ORDERED}")
print(f"Colors for clinical states: {NEW_CLINICAL_STATE_COLORS_FOR_PLOTTING}") # This dict remains the same, order of use changes
print(f"State-specific outputs will be saved in subdirectories of: {STATE_SPECIFIC_ANALYSIS_DIR}")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 5A (New): State-Specific Correlation Calculations (Revised: No Sleep, 4 Separate Mobile/Immobile States) ---
# Calculates Bivariate and Partial correlations for each of the TARGET_CLINICAL_STATES_ORDERED.

print("\n--- Cell 5A (New): Starting State-Specific Correlation Calculations (4 States - No Sleep, Separate Mobile) ---")

# Cell 3 leaves master_df_step4 as None when loading fails, so test for None before touching .empty.
if 'master_df_step4' not in locals() or master_df_step4 is None or master_df_step4.empty:
    print("master_df_step4 not available or empty. Skipping Cell 5A.")
else:
    # 1. Keep only rows whose clinical state is one of TARGET_CLINICAL_STATES_ORDERED
    #    (membership and order are whatever that list currently defines).
    master_df_step4_filtered_states = master_df_step4[master_df_step4[CLINICAL_STATE_COL].isin(TARGET_CLINICAL_STATES_ORDERED)].copy()

    if not master_df_step4_filtered_states.empty:
        # Ensure the column is categorical with the specified order for consistent processing
        master_df_step4_filtered_states[CLINICAL_STATE_COL] = pd.Categorical(
            master_df_step4_filtered_states[CLINICAL_STATE_COL],
            categories=TARGET_CLINICAL_STATES_ORDERED,
            ordered=True
        )
        # Values outside the category list become NaN in the conversion above; drop them.
        master_df_step4_filtered_states.dropna(subset=[CLINICAL_STATE_COL], inplace=True)

        # Add datetime_for_avg if not present (needed for Cell 5B)
        if 'datetime_for_avg' not in master_df_step4_filtered_states.columns:
            if 'Aligned_PKG_UnixTimestamp' in master_df_step4_filtered_states.columns:
                master_df_step4_filtered_states['datetime_for_avg'] = pd.to_datetime(
                    master_df_step4_filtered_states['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce'
                )
            else:
                print("Warning in Cell 5A: 'Aligned_PKG_UnixTimestamp' missing, 'datetime_for_avg' cannot be created. Cell 5B averaging might fail.")
                master_df_step4_filtered_states['datetime_for_avg'] = pd.NaT

    if master_df_step4_filtered_states.empty:
        print(f"No data found after filtering for target clinical states: {TARGET_CLINICAL_STATES_ORDERED}. Skipping Cell 5A.")
    else:
        print(f"Filtered data for target clinical states. Shape: {master_df_step4_filtered_states.shape}. Unique states: {master_df_step4_filtered_states[CLINICAL_STATE_COL].unique()}")

        # Output folder for the per-state correlation tables
        state_corr_csv_dir = os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "Correlation_CSVs_by_State")
        os.makedirs(state_corr_csv_dir, exist_ok=True)

        # One record (dict) per state x channel x band x metric pair; converted to DataFrames at the end.
        all_bivariate_ap_pkg_by_state_results = []
        all_partial_ap_pkg_by_state_results = []
        all_bivariate_ap_osc_by_state_results = []

        # Fallback: build the display-label column from an 'electrode_labels' map if one exists in the namespace.
        if CHANNEL_DISPLAY_COL not in master_df_step4_filtered_states.columns:
            if CHANNEL_COL in master_df_step4_filtered_states.columns and 'electrode_labels' in locals():
                master_df_step4_filtered_states[CHANNEL_DISPLAY_COL] = master_df_step4_filtered_states[CHANNEL_COL].map(electrode_labels).fillna(master_df_step4_filtered_states[CHANNEL_COL])
                print(f"'{CHANNEL_DISPLAY_COL}' created from '{CHANNEL_COL}' using electrode_labels map for state-specific analysis.")

        for state_current in TARGET_CLINICAL_STATES_ORDERED: # Loop over the target state categories
            df_state = master_df_step4_filtered_states[master_df_step4_filtered_states[CLINICAL_STATE_COL] == state_current]
            if df_state.empty:
                print(f"  No data for Clinical State: {state_current}. Skipping correlations for this state.")
                continue
            print(f"\nProcessing Clinical State: {state_current} (N={len(df_state)})")

            for channel_label in ORDERED_CHANNEL_LABELS:
                df_channel_state = df_state[df_state[CHANNEL_DISPLAY_COL] == channel_label]
                if df_channel_state.empty:
                    continue

                for freq_label in ORDERED_FREQ_LABELS:
                    df_channel_freq_state = df_channel_state[df_channel_state[FOOOF_FREQ_BAND_COL] == freq_label].copy()
                    if df_channel_freq_state.empty:
                        continue

                    # Part 1: Bivariate Aperiodic vs. PKG
                    for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                        for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                            if ap_col in df_channel_freq_state.columns and pkg_col in df_channel_freq_state.columns:
                                rho, p_val, N = calculate_spearman_with_n(df_channel_freq_state, ap_col, pkg_col)
                                all_bivariate_ap_pkg_by_state_results.append({
                                    'ClinicalState': state_current, 'Channel': channel_label, 'FreqBand': freq_label,
                                    'AperiodicMetric': ap_name, 'PKGMetric': pkg_name,
                                    'SpearmanRho': rho, 'PValue': p_val, 'N': N
                                })

                    # Part 2: Partial Aperiodic vs. PKG (controlling for the oscillatory metrics: Beta, Gamma)
                    # Run only when every oscillatory covariate column is available.
                    covariates_partial_corr_state = [col for col in OSCILLATORY_METRICS_COLS.keys() if col in df_channel_freq_state.columns]
                    if len(covariates_partial_corr_state) == len(OSCILLATORY_METRICS_COLS):
                        for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                            for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                                if ap_col in df_channel_freq_state.columns and pkg_col in df_channel_freq_state.columns:
                                    # The helper drops incomplete rows and applies the minimum-sample
                                    # check itself, returning (NaN, NaN, N) when N is too small.
                                    partial_rho, partial_p_val, N_partial = calculate_partial_spearman(
                                        df_channel_freq_state, ap_col, pkg_col, covariates_partial_corr_state
                                    )
                                    all_partial_ap_pkg_by_state_results.append({
                                        'ClinicalState': state_current, 'Channel': channel_label, 'FreqBand': freq_label,
                                        'AperiodicMetric': ap_name, 'PKGMetric': pkg_name,
                                        'PartialSpearmanRho_vs_BetaGamma': partial_rho, 'PartialPValue': partial_p_val, 'N_Partial': N_partial
                                    })

                    # Part 3: Bivariate Aperiodic vs. Oscillatory
                    for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                        for osc_col, osc_name in OSCILLATORY_METRICS_COLS.items():
                            if ap_col in df_channel_freq_state.columns and osc_col in df_channel_freq_state.columns:
                                rho_ap_osc, p_val_ap_osc, N_ap_osc = calculate_spearman_with_n(df_channel_freq_state, ap_col, osc_col)
                                all_bivariate_ap_osc_by_state_results.append({
                                    'ClinicalState': state_current, 'Channel': channel_label, 'FreqBand': freq_label,
                                    'AperiodicMetric': ap_name, 'OscillatoryMetric': osc_name,
                                    'SpearmanRho': rho_ap_osc, 'PValue': p_val_ap_osc, 'N': N_ap_osc
                                })
        # Write each non-empty result table to its own CSV
        if all_bivariate_ap_pkg_by_state_results:
            df_bivar_ap_pkg_state = pd.DataFrame(all_bivariate_ap_pkg_by_state_results)
            df_bivar_ap_pkg_state.to_csv(os.path.join(state_corr_csv_dir, f"{patient_hemisphere_id}_Bivariate_AP_vs_PKG_ByState.csv"), index=False)
            print(f"\nSaved Bivariate AP vs PKG by State (4 States - No Sleep) results for {patient_hemisphere_id}.")

        if all_partial_ap_pkg_by_state_results:
            df_partial_ap_pkg_state = pd.DataFrame(all_partial_ap_pkg_by_state_results)
            df_partial_ap_pkg_state.to_csv(os.path.join(state_corr_csv_dir, f"{patient_hemisphere_id}_Partial_AP_vs_PKG_ByState.csv"), index=False)
            print(f"Saved Partial AP vs PKG (controlling Beta, Gamma) by State (4 States - No Sleep) results for {patient_hemisphere_id}.")

        if all_bivariate_ap_osc_by_state_results:
            df_bivar_ap_osc_state = pd.DataFrame(all_bivariate_ap_osc_by_state_results)
            df_bivar_ap_osc_state.to_csv(os.path.join(state_corr_csv_dir, f"{patient_hemisphere_id}_Bivariate_AP_vs_Oscillatory_ByState.csv"), index=False)
            print(f"Saved Bivariate AP vs Oscillatory by State (4 States - No Sleep) results for {patient_hemisphere_id}.")

print("\\n--- Cell 5A (New): State-Specific Correlation Calculations (4 States - No Sleep, Separate Mobile) Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 5B (New): State-Specific Overview Scatter Plots (Aperiodic vs. PKG) (Revised: No Sleep, 4 Separate Mobile/Immobile States) ---
# Generates overview scatter plots: one figure per (PKG_Symptom, Aperiodic_Metric, Channel, FreqBand),
# with one subplot per clinical state in TARGET_CLINICAL_STATES_ORDERED. Y-axes are standardized within each figure.
# Scatter plot points are 10-minute averages. Regression line uses all granular data.
#
# Reads (from earlier cells): master_df_step4_filtered_states, df_bivar_ap_pkg_state,
#   PKG_METRICS_COLS, APERIODIC_METRICS_COLS, ORDERED_CHANNEL_LABELS, ORDERED_FREQ_LABELS,
#   TARGET_CLINICAL_STATES_ORDERED, MIN_SAMPLES_FOR_CORR and the plotting helpers/constants.
# Writes: one PNG per figure under STATE_SPECIFIC_ANALYSIS_DIR/Overview_Scatter_AP_vs_PKG_by_State.


def padded_axis_limits_step4(values, pad_fraction=0.1, flat_pad=0.1):
    """Return (low, high) axis limits that enclose `values` with some padding.

    Parameters
    ----------
    values : sequence of float
        Data values the axis must show (NaNs are ignored).
    pad_fraction : float
        Fraction of the data range added below the minimum and above the maximum.
    flat_pad : float
        Absolute padding used when all values are identical (zero range).

    Returns
    -------
    tuple
        (low, high); (nan, nan) when `values` is empty or contains only NaNs.
    """
    if len(values) == 0:
        return (np.nan, np.nan)
    low = np.nanmin(values)
    high = np.nanmax(values)
    if np.isnan(low) or np.isnan(high):
        return (np.nan, np.nan)
    span = high - low
    padding = span * pad_fraction if span > 0 else flat_pad
    return (low - padding, high + padding)


print("\n--- Cell 5B (New): Starting State-Specific Overview Scatter Plot Generation (4 States - No Sleep, 10-min avg pts) ---")

# Ensure master_df_step4_filtered_states has the correct states from the updated Cell 5A
if 'master_df_step4_filtered_states' not in locals() or master_df_step4_filtered_states.empty:
    print("master_df_step4_filtered_states (with 4 states - No Sleep) not available or empty. Skipping Cell 5B.")
else:
    # Missing correlation results only cost us the rho/p annotations: fall back to an empty
    # lookup table and keep plotting (previously this branch silently skipped every plot).
    if 'df_bivar_ap_pkg_state' not in locals() or df_bivar_ap_pkg_state.empty:
        print("Bivariate AP vs PKG by State correlation results (df_bivar_ap_pkg_state) not found. Skipping Cell 5B plot annotations.")
        df_bivar_ap_pkg_state = pd.DataFrame(columns=['ClinicalState', 'Channel', 'FreqBand', 'AperiodicMetric', 'PKGMetric', 'SpearmanRho', 'PValue', 'N'])

    plot_subdir_overview_scatter_state = os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "Overview_Scatter_AP_vs_PKG_by_State")
    os.makedirs(plot_subdir_overview_scatter_state, exist_ok=True)

    # 'datetime_for_avg' is required both for NaN filtering and for the 10-minute averaging below.
    if 'datetime_for_avg' not in master_df_step4_filtered_states.columns:
        print("ERROR in Cell 5B: 'datetime_for_avg' column missing from master_df_step4_filtered_states. Averaging will fail.")
    else:
        n_states_overview = len(TARGET_CLINICAL_STATES_ORDERED)

        for pkg_col_overview, pkg_name_overview in PKG_METRICS_COLS.items():
            if pkg_col_overview not in master_df_step4_filtered_states.columns:
                print(f"PKG Metric {pkg_name_overview} ({pkg_col_overview}) not found in data. Skipping its overview plots.")
                continue

            for ap_metric_overview_col, ap_metric_overview_name in APERIODIC_METRICS_COLS.items():
                if ap_metric_overview_col not in master_df_step4_filtered_states.columns:
                    print(f"Aperiodic Metric {ap_metric_overview_name} ({ap_metric_overview_col}) not found. Skipping its overview plots.")
                    continue

                print(f"\nGenerating overview plots for: {ap_metric_overview_name} vs. {pkg_name_overview}")
                required_cols_overview = [ap_metric_overview_col, pkg_col_overview, 'datetime_for_avg']

                for channel_label_overview in ORDERED_CHANNEL_LABELS:
                    for freq_label_overview in ORDERED_FREQ_LABELS:

                        # Filter once per (channel, freq) and once per state; the same frames feed
                        # both the shared y-limit computation and the plotting pass.
                        df_channel_freq_overview = master_df_step4_filtered_states[
                            (master_df_step4_filtered_states[CHANNEL_DISPLAY_COL] == channel_label_overview) &
                            (master_df_step4_filtered_states[FOOOF_FREQ_BAND_COL] == freq_label_overview)
                        ]
                        granular_by_state = {}
                        plottable_by_state = {}
                        for state_overview in TARGET_CLINICAL_STATES_ORDERED:
                            df_state_granular = df_channel_freq_overview[
                                df_channel_freq_overview[CLINICAL_STATE_COL] == state_overview
                            ].dropna(subset=required_cols_overview)
                            granular_by_state[state_overview] = df_state_granular
                            plottable_by_state[state_overview] = (
                                not df_state_granular.empty and len(df_state_granular) >= MIN_SAMPLES_FOR_CORR
                            )

                        # Nothing would be saved for this combination, so do not build a figure at all.
                        if not any(plottable_by_state.values()):
                            continue

                        # Shared y-range is derived only from states that will actually be plotted.
                        all_ap_values_for_ylim = []
                        for state_overview in TARGET_CLINICAL_STATES_ORDERED:
                            if plottable_by_state[state_overview]:
                                all_ap_values_for_ylim.extend(
                                    granular_by_state[state_overview][ap_metric_overview_col].tolist()
                                )
                        min_y, max_y = padded_axis_limits_step4(all_ap_values_for_ylim)
                        apply_shared_ylim = not pd.isna(min_y) and not pd.isna(max_y)

                        fig_width = max(15, 4 * n_states_overview)
                        # squeeze=False always returns a 2-D array, so a single-state list needs no special case.
                        fig, axes_grid = plt.subplots(1, n_states_overview,
                                                      figsize=(fig_width, 5.5), sharey=False, squeeze=False)
                        axes = axes_grid[0]

                        fig.suptitle(f"{ap_metric_overview_name} vs. {pkg_name_overview}\nChannel: {channel_label_overview} - Freq: {freq_label_overview} - Patient: {patient_hemisphere_id}\n(Scatter points are 10-min averages; Regression on all raw data)",
                                     fontsize=plt.rcParams['figure.titlesize'] * 0.85, y=1.05)

                        for i, state_overview in enumerate(TARGET_CLINICAL_STATES_ORDERED):
                            ax = axes[i]
                            df_plot_data_granular_plot = granular_by_state[state_overview]

                            if plottable_by_state[state_overview]:
                                # Look up the pre-computed bivariate statistics for the annotation.
                                corr_stats_row = df_bivar_ap_pkg_state[
                                    (df_bivar_ap_pkg_state['ClinicalState'] == state_overview) &
                                    (df_bivar_ap_pkg_state['Channel'] == channel_label_overview) &
                                    (df_bivar_ap_pkg_state['FreqBand'] == freq_label_overview) &
                                    (df_bivar_ap_pkg_state['AperiodicMetric'] == ap_metric_overview_name) &
                                    (df_bivar_ap_pkg_state['PKGMetric'] == pkg_name_overview)
                                ]
                                if corr_stats_row.empty:
                                    rho, p_val, N_val = np.nan, np.nan, len(df_plot_data_granular_plot)
                                else:
                                    rho = corr_stats_row['SpearmanRho'].iloc[0]
                                    p_val = corr_stats_row['PValue'].iloc[0]
                                    N_val = corr_stats_row['N'].iloc[0]

                                # Scatter points: 10-minute bin means; fall back to granular rows if binning fails
                                # (e.g. the column is not datetime-like).
                                try:
                                    df_averaged_points_plot = (
                                        df_plot_data_granular_plot.set_index('datetime_for_avg')
                                        .groupby(pd.Grouper(freq='10min'))[[ap_metric_overview_col, pkg_col_overview]]
                                        .mean()
                                        .dropna()
                                    )
                                except Exception as e_avg:
                                    print(f"Warning: 10-min averaging failed for {channel_label_overview}, {freq_label_overview}, {state_overview}. Plotting granular points. Error: {e_avg}")
                                    df_averaged_points_plot = df_plot_data_granular_plot

                                if not df_averaged_points_plot.empty:
                                    sns.scatterplot(data=df_averaged_points_plot, x=pkg_col_overview, y=ap_metric_overview_col,
                                                    color=NEW_CLINICAL_STATE_COLORS_FOR_PLOTTING.get(state_overview, 'grey'),
                                                    alpha=DOT_ALPHA_STEP4 + 0.2, s=40, edgecolor='black', linewidths=0.5, ax=ax, legend=False)

                                # Regression line is fitted on the un-averaged rows.
                                sns.regplot(data=df_plot_data_granular_plot, x=pkg_col_overview, y=ap_metric_overview_col, scatter=False, ax=ax,
                                            line_kws={'color': 'black', 'linewidth': 1.5, 'alpha': 0.6})

                                annotate_correlation_on_plot(ax, rho, p_val, N_val, fontsize=8)
                                # Only the middle panel carries the x-label to avoid repetition.
                                ax.set_xlabel(pkg_name_overview if i == n_states_overview // 2 else "", fontsize=plt.rcParams['axes.labelsize']*0.9)
                            else:
                                # Distinguish a truly empty panel from one that merely has too few rows.
                                placeholder_text = "No Data" if df_plot_data_granular_plot.empty else "N < min_samples"
                                ax.text(0.5, 0.5, placeholder_text,
                                        horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, fontsize=9)

                            ax.set_title(state_overview, fontsize=plt.rcParams['axes.titlesize']*0.8)
                            if apply_shared_ylim:
                                ax.set_ylim(min_y, max_y)

                            # With a shared y-range only the first panel needs y tick labels.
                            if i == 0:
                                ax.set_ylabel(ap_metric_overview_name, fontsize=plt.rcParams['axes.labelsize'])
                            else:
                                ax.set_ylabel("")
                                ax.set_yticklabels([])

                            ax.tick_params(axis='x', labelsize=plt.rcParams['xtick.labelsize']*0.9)
                            ax.tick_params(axis='y', labelsize=plt.rcParams['ytick.labelsize']*0.9)

                        fig.tight_layout(rect=[0, 0.03, 1, 0.93])
                        safe_ap_name = get_safe_filename_step4(ap_metric_overview_name)
                        safe_pkg_name = get_safe_filename_step4(pkg_name_overview)
                        safe_ch_name = get_safe_filename_step4(channel_label_overview)
                        safe_freq_name = get_safe_filename_step4(freq_label_overview)
                        plot_filename = f"Overview_{safe_ap_name}_vs_{safe_pkg_name}_{safe_ch_name}_{safe_freq_name}.png"
                        # The suptitle is anchored above the canvas (y=1.05); a tight bounding box keeps it in the saved file.
                        fig.savefig(os.path.join(plot_subdir_overview_scatter_state, plot_filename), bbox_inches='tight')
                        plt.close(fig)

print("\n--- Cell 5B (New): State-Specific Overview Scatter Plot Generation (4 States - No Sleep, 10-min avg pts) Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 5C (Revised Logic v3): State-Specific Heatmaps for Exponent/Offset vs. PKG ---
# Generates heatmaps of Partial Spearman correlations (Aperiodic vs. PKG, controlling for Beta/Gamma).
# One figure per (aperiodic metric, clinical state, frequency band); rows are channels,
# columns are the short PKG names (BK, DK, Tremor).
# Significance stars are bold. Annotations show rho and stars only.
#
# Reads (from earlier cells): df_partial_ap_pkg_state, PKG_METRICS_COLS, APERIODIC_METRICS_COLS,
#   TARGET_CLINICAL_STATES_ORDERED, ORDERED_FREQ_LABELS, ORDERED_CHANNEL_LABELS,
#   MIN_SAMPLES_FOR_CORR, P_VALUE_THRESHOLD.
# Writes: PNG files under STATE_SPECIFIC_ANALYSIS_DIR/Heatmaps_PartialCorrelations_by_State_Revised.


def format_rho_annotation_step4(rho_val, p_val, n_val, min_samples, alpha):
    """Build the text shown in one heatmap cell.

    Parameters
    ----------
    rho_val : float
        Correlation coefficient for the cell; NaN yields "N/A".
    p_val : float
        P-value belonging to `rho_val` (may be NaN).
    n_val : float or int
        Sample size behind the correlation (may be NaN).
    min_samples : int
        Minimum sample size required before any stars are shown.
    alpha : float
        Threshold for a single star; 0.01 and 0.001 give two and three stars.

    Returns
    -------
    str
        "N/A", or rho with two decimals followed by bold mathtext stars when significant.
    """
    if pd.isna(rho_val):
        return "N/A"

    stars_str = ""
    if pd.notna(p_val) and pd.notna(n_val) and n_val >= min_samples:
        if p_val < 0.001:
            stars_str = "***"
        elif p_val < 0.01:
            stars_str = "**"
        elif p_val < alpha:
            stars_str = "*"

    # Matplotlib mathtext renders \mathbf{...} as bold.
    stars_formatted_str = f"$\\mathbf{{{stars_str}}}$" if stars_str else ""
    return f"{rho_val:.2f}{stars_formatted_str}"


print("\n--- Cell 5C (Revised Logic v3): Starting State-Specific Heatmap Generation for Aperiodic vs. PKG ---")

if 'df_partial_ap_pkg_state' not in locals() or df_partial_ap_pkg_state.empty:
    print("State-specific partial correlation data (df_partial_ap_pkg_state) not available or empty. Skipping Cell 5C.")
else:
    plot_subdir_heatmaps_state = os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "Heatmaps_PartialCorrelations_by_State_Revised")
    os.makedirs(plot_subdir_heatmaps_state, exist_ok=True)

    # Define short PKG names for heatmap columns and their desired order
    PKG_DISPLAY_NAME_MAP = {
        'PKG BK Score': 'BK',
        'PKG DK Score': 'DK',
        'PKG Tremor Score': 'Tremor'
    }
    ORDERED_PKG_METRIC_ORIGINAL_NAMES = [
        PKG_METRICS_COLS.get('Aligned_BK', 'PKG BK Score'),
        PKG_METRICS_COLS.get('Aligned_DK', 'PKG DK Score'),
        PKG_METRICS_COLS.get('Aligned_Tremor_Score', 'PKG Tremor Score')
    ]
    ORDERED_PKG_SHORT_NAMES_FOR_HEATMAP = [PKG_DISPLAY_NAME_MAP[name] for name in ORDERED_PKG_METRIC_ORIGINAL_NAMES if name in PKG_DISPLAY_NAME_MAP]

    for ap_metric_key_loop, ap_metric_name_display_loop in APERIODIC_METRICS_COLS.items():

        print(f"\n=== Generating Heatmaps for Aperiodic Metric: {ap_metric_name_display_loop} ===")

        for state_heatmap in TARGET_CLINICAL_STATES_ORDERED:
            df_current_state_single_ap_corrs = df_partial_ap_pkg_state[
                (df_partial_ap_pkg_state['ClinicalState'] == state_heatmap) &
                (df_partial_ap_pkg_state['AperiodicMetric'] == ap_metric_name_display_loop)
            ]
            if df_current_state_single_ap_corrs.empty:
                continue

            for freq_label_heatmap in ORDERED_FREQ_LABELS:
                df_heatmap_data_filtered = df_current_state_single_ap_corrs[
                    df_current_state_single_ap_corrs['FreqBand'] == freq_label_heatmap
                ].copy()
                if df_heatmap_data_filtered.empty:
                    continue

                print(f"  Generating heatmap for Aperiodic: {ap_metric_name_display_loop}, State: {state_heatmap}, Freq: {freq_label_heatmap}")

                # Keep only PKG metrics that have a short display name.
                df_heatmap_data_filtered['PKGMetricShort'] = df_heatmap_data_filtered['PKGMetric'].map(PKG_DISPLAY_NAME_MAP)
                df_heatmap_data_filtered = df_heatmap_data_filtered[df_heatmap_data_filtered['PKGMetricShort'].notna()]
                if df_heatmap_data_filtered.empty:
                    continue

                pivot_index_col = CHANNEL_DISPLAY_COL if CHANNEL_DISPLAY_COL in df_heatmap_data_filtered.columns else 'Channel'

                try:
                    # Channel x PKG matrix of rho, ordered like the rest of the notebook's figures.
                    heatmap_pivot_rho = df_heatmap_data_filtered.pivot_table(
                        index=pivot_index_col,
                        columns='PKGMetricShort',
                        values='PartialSpearmanRho_vs_BetaGamma',
                        aggfunc='first'
                    )
                    heatmap_pivot_rho = heatmap_pivot_rho.reindex(
                        index=[ch for ch in ORDERED_CHANNEL_LABELS if ch in heatmap_pivot_rho.index],
                        columns=[name for name in ORDERED_PKG_SHORT_NAMES_FOR_HEATMAP if name in heatmap_pivot_rho.columns]
                    ).dropna(how='all', axis=0).dropna(how='all', axis=1)

                    # P-values and N are only looked up by label below, so this table needs no reordering.
                    heatmap_pivot_annot_underlying_data = df_heatmap_data_filtered.pivot_table(
                        index=pivot_index_col,
                        columns='PKGMetricShort',
                        values=['PartialPValue', 'N_Partial'],
                        aggfunc='first'
                    )
                except Exception as e_pivot_state:
                    print(f"    Error pivoting data for heatmap (Aperiodic: {ap_metric_name_display_loop}, State: {state_heatmap}, Freq: {freq_label_heatmap}): {e_pivot_state}. Skipping.")
                    continue

                if heatmap_pivot_rho.empty:
                    continue

                # Build the per-cell annotation strings (rho plus significance stars).
                annot_rows = heatmap_pivot_annot_underlying_data.index
                annot_cols = heatmap_pivot_annot_underlying_data.columns
                annot_text_final = heatmap_pivot_rho.copy().astype(object)
                for r_idx in heatmap_pivot_rho.index:
                    row_has_stats = r_idx in annot_rows
                    for c_idx_pkg_short_name in heatmap_pivot_rho.columns:
                        p_val_col_tuple = ('PartialPValue', c_idx_pkg_short_name)
                        n_val_col_tuple = ('N_Partial', c_idx_pkg_short_name)

                        # Missing statistics default to "not significant" (NaN p-value, N of 0).
                        p_val_val = np.nan
                        n_val = 0
                        if row_has_stats and p_val_col_tuple in annot_cols:
                            p_val_val = heatmap_pivot_annot_underlying_data.loc[r_idx, p_val_col_tuple]
                        if row_has_stats and n_val_col_tuple in annot_cols:
                            n_val = heatmap_pivot_annot_underlying_data.loc[r_idx, n_val_col_tuple]

                        annot_text_final.loc[r_idx, c_idx_pkg_short_name] = format_rho_annotation_step4(
                            heatmap_pivot_rho.loc[r_idx, c_idx_pkg_short_name],
                            p_val_val, n_val, MIN_SAMPLES_FOR_CORR, P_VALUE_THRESHOLD
                        )

                # Explicit figure/axes: later calls do not depend on which axes is "current" after the colorbar is added.
                fig, ax = plt.subplots(figsize=(max(8, heatmap_pivot_rho.shape[1] * 2.5), max(6, heatmap_pivot_rho.shape[0] * 0.9)))
                sns.heatmap(heatmap_pivot_rho.astype(float), annot=annot_text_final, fmt='s',
                            cmap="coolwarm_r", center=0, vmin=-1, vmax=1,
                            linewidths=.5, linecolor='grey', cbar_kws={'label': f"Partial Spearman ρ\n({ap_metric_name_display_loop} vs. PKG, ctrl Beta,Gamma)"},
                            annot_kws={"size": 9}, ax=ax)

                ax.set_title(f"Partial Corr: {ap_metric_name_display_loop} vs. PKG (ctrl Beta,Gamma)\nState: {state_heatmap} - Freq: {freq_label_heatmap} - Patient: {patient_hemisphere_id}",
                             fontsize=plt.rcParams['figure.titlesize']*0.80)
                ax.set_ylabel("Channel", fontsize=plt.rcParams['axes.labelsize'])
                ax.set_xlabel("PKG Symptom", fontsize=plt.rcParams['axes.labelsize'])
                plt.setp(ax.get_xticklabels(), rotation=0, ha="center", fontsize=plt.rcParams['xtick.labelsize']*0.9)
                plt.setp(ax.get_yticklabels(), rotation=0, fontsize=plt.rcParams['ytick.labelsize']*0.9)
                fig.tight_layout(rect=[0, 0, 1, 0.92])

                safe_ap_name_file = get_safe_filename_step4(ap_metric_name_display_loop)
                safe_state_name_file = get_safe_filename_step4(state_heatmap)
                safe_freq_name_file = get_safe_filename_step4(freq_label_heatmap)

                plot_filename_heatmap = f"Heatmap_PartialCorr_{safe_ap_name_file}_PKG_{safe_state_name_file}_{safe_freq_name_file}.png"
                fig.savefig(os.path.join(plot_subdir_heatmaps_state, plot_filename_heatmap))
                plt.close(fig)

print("\n--- Cell 5C (Revised Logic v3): State-Specific Heatmap Generation for Aperiodic vs. PKG Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 5D (New): Plot Spearman Coefficients vs. Clinical States (Revised for 4 States) ---
# Generates line plots showing how Spearman Rho (Aperiodic vs PKG) changes across the clinical
# states listed in TARGET_CLINICAL_STATES_ORDERED. One figure per (aperiodic metric, channel,
# frequency band), with one line per PKG metric.
#
# Reads (from earlier cells): df_bivar_ap_pkg_state, APERIODIC_METRICS_COLS, PKG_METRICS_COLS,
#   ORDERED_CHANNEL_LABELS, ORDERED_FREQ_LABELS, TARGET_CLINICAL_STATES_ORDERED, PKG_SYMPTOM_COLORS.
# Writes: PNG files under STATE_SPECIFIC_ANALYSIS_DIR/SpearmanCoeff_vs_ClinicalState.


def mean_rho_by_state_step4(df_corr, ordered_states):
    """Average 'SpearmanRho' per 'ClinicalState' and return it in `ordered_states` order.

    Parameters
    ----------
    df_corr : pandas.DataFrame
        Must contain the columns 'ClinicalState' and 'SpearmanRho'.
    ordered_states : sequence
        State labels defining the order (and extent) of the result.

    Returns
    -------
    pandas.Series
        Indexed by `ordered_states`; states without a finite rho are NaN. Unlike a
        pivot_table with its default dropna=True, this never loses the value column when
        every rho is NaN, so callers can safely test `.isnull().all()`.
    """
    # Plain object keys: rows whose state is NaN (e.g. not one of the categories) are ignored by groupby.
    state_keys = df_corr['ClinicalState'].astype(object)
    mean_rho = df_corr['SpearmanRho'].groupby(state_keys).mean()
    return mean_rho.reindex(list(ordered_states))


print("\n--- Cell 5D (New): Starting Spearman Coefficients vs. Clinical States Plot Generation (4 States) ---")

# Ensure df_bivar_ap_pkg_state is the remapped version from the updated Cell 5A
if 'df_bivar_ap_pkg_state' not in locals() or df_bivar_ap_pkg_state.empty:
    print("State-specific bivariate AP vs PKG correlation data (df_bivar_ap_pkg_state) not available. Skipping Cell 5D.")
else:
    plot_subdir_coeff_vs_state = os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "SpearmanCoeff_vs_ClinicalState")
    os.makedirs(plot_subdir_coeff_vs_state, exist_ok=True)

    # Working copy whose state column is restricted to (and ordered by) the target states;
    # any other state label becomes NaN and therefore drops out of the per-state averages.
    df_bivar_ap_pkg_state_ordered_for_coeff_plot = df_bivar_ap_pkg_state.copy()
    df_bivar_ap_pkg_state_ordered_for_coeff_plot['ClinicalState'] = pd.Categorical(
        df_bivar_ap_pkg_state_ordered_for_coeff_plot['ClinicalState'],
        categories=TARGET_CLINICAL_STATES_ORDERED,
        ordered=True
    )

    for ap_metric_key, ap_metric_name_display in APERIODIC_METRICS_COLS.items():
        print(f"\nProcessing coefficient plots for Aperiodic Metric: {ap_metric_name_display}")

        df_ap_metric_subset_coeff = df_bivar_ap_pkg_state_ordered_for_coeff_plot[
            df_bivar_ap_pkg_state_ordered_for_coeff_plot['AperiodicMetric'] == ap_metric_name_display
        ]
        if df_ap_metric_subset_coeff.empty:
            print(f"  No data for aperiodic metric: {ap_metric_name_display}. Skipping its plot.")
            continue

        for channel_label_coeff_plot in ORDERED_CHANNEL_LABELS:
            for freq_label_coeff_plot in ORDERED_FREQ_LABELS:

                df_plot_final_coeff = df_ap_metric_subset_coeff[
                    (df_ap_metric_subset_coeff['Channel'] == channel_label_coeff_plot) &
                    (df_ap_metric_subset_coeff['FreqBand'] == freq_label_coeff_plot)
                ]
                if df_plot_final_coeff.empty:
                    continue

                # Collect the drawable lines first so a figure is only created when needed.
                lines_to_draw = []
                for pkg_col_coeff, pkg_name_coeff_display in PKG_METRICS_COLS.items():
                    df_pkg_symptom_subset_coeff = df_plot_final_coeff[
                        df_plot_final_coeff['PKGMetric'] == pkg_name_coeff_display
                    ]
                    if df_pkg_symptom_subset_coeff.empty:
                        continue

                    rho_by_state = mean_rho_by_state_step4(df_pkg_symptom_subset_coeff, TARGET_CLINICAL_STATES_ORDERED)
                    if rho_by_state.isnull().all():
                        continue
                    lines_to_draw.append((pkg_col_coeff, pkg_name_coeff_display, rho_by_state))

                if not lines_to_draw:
                    continue

                fig, ax = plt.subplots(figsize=(10, 6))
                for pkg_col_coeff, pkg_name_coeff_display, rho_by_state in lines_to_draw:
                    # String x-values give a categorical axis; NaN rho values leave gaps in the line.
                    ax.plot(rho_by_state.index.astype(str), rho_by_state.values,
                            marker='o', markersize=7, linestyle='-', linewidth=1.5,
                            label=pkg_name_coeff_display,
                            color=PKG_SYMPTOM_COLORS.get(pkg_col_coeff, 'grey'))

                ax.set_title(f"Spearman ρ ({ap_metric_name_display} vs. PKG) by Clinical State\n{channel_label_coeff_plot} - {freq_label_coeff_plot} - Patient: {patient_hemisphere_id}",
                             fontsize=plt.rcParams['axes.titlesize']*0.9)
                ax.set_xlabel("Clinical State", fontsize=plt.rcParams['axes.labelsize'])
                ax.set_ylabel("Spearman Correlation Coefficient (ρ)", fontsize=plt.rcParams['axes.labelsize'])
                # Legend sits to the right of the axes; tight_layout below reserves room for it.
                ax.legend(title="PKG Symptom", loc='center left', bbox_to_anchor=(1, 0.5))
                ax.grid(True, linestyle='--', alpha=0.6)
                plt.setp(ax.get_xticklabels(), rotation=20, ha="right", fontsize=plt.rcParams['xtick.labelsize']*0.9)
                ax.axhline(0, color='black', linewidth=0.8, linestyle='--')
                ax.set_ylim(-1.05, 1.05)  # Full correlation range always visible

                fig.tight_layout(rect=[0, 0, 0.85, 0.95])

                safe_ap_name_coeff = get_safe_filename_step4(ap_metric_name_display)
                safe_ch_name_coeff = get_safe_filename_step4(channel_label_coeff_plot)
                safe_freq_name_coeff = get_safe_filename_step4(freq_label_coeff_plot)

                plot_filename_coeff = f"CoeffVsState_{safe_ap_name_coeff}_{safe_ch_name_coeff}_{safe_freq_name_coeff}.png"
                fig.savefig(os.path.join(plot_subdir_coeff_vs_state, plot_filename_coeff))
                plt.close(fig)

print("\n--- Cell 5D (New): Spearman Coefficients vs. Clinical States Plot Generation (4 States) Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 5 (Revised with 5-min Averaged Scatter Points): Bivariate & Partial Correlations, Aperiodic vs. Beta/Gamma ---
# Iterates through Channel_Display AND FreqRangeLabel.
# Scatter points are time-bin averages; Spearman statistics and regression lines use all granular data.
#
# NOTE(review): earlier text in this cell called the bins "15-min" while the grouping
# frequency in the code was '5T' (5 minutes). The computation is unchanged (5-minute bins);
# only the messages were aligned with it. Confirm that 5 minutes is the intended window.
# NOTE(review): the later "PKG Zero Exclusion" version of Cell 5 saves plots into the same
# two folders under the same file names, so running both cells overwrites the plots
# written here. The CSV outputs use different names and are not affected.

# Bin width for the scatter layer. "5min" is the non-deprecated spelling of the "5T" alias.
SCATTER_AVG_FREQ_C5 = "5min"
# Human-readable form of the bin width, used in console messages.
SCATTER_AVG_LABEL_C5 = "5-min"


def _c5_average_points_for_scatter(df_granular, x_col, y_col, failure_context):
    """Return the points to draw in the scatter layer of a Cell 5 plot.

    Rows are grouped into SCATTER_AVG_FREQ_C5 bins on 'datetime_for_avg_c5' and the
    means of x_col / y_col are returned (bins with a missing mean are dropped).

    Falls back to df_granular itself when the timestamp column is absent or all-null,
    when the averaging raises (a warning naming failure_context is printed), or when
    the averaging yields no rows.
    """
    time_col = 'datetime_for_avg_c5'
    if time_col not in df_granular.columns or df_granular[time_col].isnull().all():
        return df_granular
    try:
        df_binned = (
            df_granular.set_index(time_col)
            .groupby(pd.Grouper(freq=SCATTER_AVG_FREQ_C5))[[y_col, x_col]]
            .mean()
            .dropna()
        )
    except Exception as e_avg:
        # Deliberate best-effort: a failed averaging must not abort the whole cell.
        print(f"      Warning: {SCATTER_AVG_LABEL_C5} averaging failed for {failure_context}. Plotting granular. Error: {e_avg}")
        return df_granular
    return df_granular if df_binned.empty else df_binned


def _c5_save_scatter_with_regression(df_granular, x_col, x_name, y_col, y_name,
                                     channel_label, freq_label, rho, p_val, n_obs,
                                     out_dir, plot_kind, annotate_kwargs=None):
    """Draw one bivariate plot (binned scatter + granular regression line) and save it.

    Parameters
    ----------
    df_granular : DataFrame already restricted to rows where x_col and y_col are non-null.
    x_col, x_name : column and display name plotted on the x axis (PKG or oscillatory metric).
    y_col, y_name : column and display name plotted on the y axis (aperiodic metric).
    channel_label, freq_label : used in the title and in the output file name.
    rho, p_val, n_obs : statistics forwarded to annotate_correlation_on_plot.
    out_dir : folder the PNG is written to.
    plot_kind : short label used only in the averaging-failure warning.
    annotate_kwargs : optional extra keyword arguments for annotate_correlation_on_plot.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        df_points = _c5_average_points_for_scatter(
            df_granular, x_col, y_col, f"{plot_kind} ({y_name} vs {x_name})"
        )
        if not df_points.empty:
            sns.scatterplot(data=df_points, x=x_col, y=y_col,
                            color=COLOR_PALETTE_STEP4.get(y_col, 'grey'),
                            alpha=DOT_ALPHA_STEP4 + 0.1, s=40,
                            edgecolor='k', linewidths=0.5, ax=ax)

        # The regression line is always fitted on the granular rows, not on the bin means.
        sns.regplot(data=df_granular, x=x_col, y=y_col, scatter=False, ax=ax,
                    line_kws={'color': COLOR_PALETTE_STEP4.get(x_col, 'black'),
                              'linewidth': REG_LINE_THICKNESS_STEP4, 'alpha': 0.7})

        annotate_correlation_on_plot(ax, rho, p_val, n_obs, **(annotate_kwargs or {}))
        # Real newline so the channel / band go on a second title line.
        ax.set_title(f"{y_name} vs. {x_name}\n{channel_label} ({freq_label})", fontsize=plt.rcParams['axes.titlesize'])
        ax.set_xlabel(x_name, fontsize=plt.rcParams['axes.labelsize'])
        ax.set_ylabel(y_name, fontsize=plt.rcParams['axes.labelsize'])
        fig.tight_layout()

        safe_ch = get_safe_filename_step4(channel_label)
        safe_y = get_safe_filename_step4(y_name)
        safe_x = get_safe_filename_step4(x_name)
        plot_filename = f"Bivar_{safe_y}_vs_{safe_x}_{safe_ch}_{freq_label}.png"
        fig.savefig(os.path.join(out_dir, plot_filename))
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)


print(f"\n--- Cell 5 (Revised with {SCATTER_AVG_LABEL_C5} Averaged Scatter Points): Starting Correlation Analyses ---")

if 'master_df_step4' not in locals() or master_df_step4 is None or master_df_step4.empty:
    print("master_df_step4 not available or empty. Skipping Cell 5.")
else:
    # Work on a copy so the datetime helper column never leaks into master_df_step4.
    master_df_step4_processed_c5 = master_df_step4.copy()
    if 'Aligned_PKG_UnixTimestamp' in master_df_step4_processed_c5.columns:
        master_df_step4_processed_c5['datetime_for_avg_c5'] = pd.to_datetime(
            master_df_step4_processed_c5['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce'
        )
        if master_df_step4_processed_c5['datetime_for_avg_c5'].isnull().all():
            print(f"Warning in Cell 5: 'datetime_for_avg_c5' could not be created (all NaT). {SCATTER_AVG_LABEL_C5} averaging is unavailable; scatter plots will use granular points.")
    else:
        print(f"Warning in Cell 5: 'Aligned_PKG_UnixTimestamp' not found. Cannot perform {SCATTER_AVG_LABEL_C5} averaging for scatter plots. Will plot granular points.")
        master_df_step4_processed_c5['datetime_for_avg_c5'] = pd.NaT

    # Output folders for the plots produced by this cell.
    plot_subdir_bivariate_ap_pkg = os.path.join(analysis_session_plot_folder_step4, "Bivariate_AP_vs_PKG")
    plot_subdir_bivariate_ap_osc = os.path.join(analysis_session_plot_folder_step4, "Bivariate_AP_vs_Oscillatory")
    os.makedirs(plot_subdir_bivariate_ap_pkg, exist_ok=True)
    os.makedirs(plot_subdir_bivariate_ap_osc, exist_ok=True)

    # One dict per (channel, band, metric pair); turned into DataFrames at the end.
    all_bivariate_ap_pkg_results = []
    all_partial_ap_pkg_results = []
    all_bivariate_ap_osc_results = []

    for channel_label in ORDERED_CHANNEL_LABELS:
        df_channel = master_df_step4_processed_c5[master_df_step4_processed_c5[CHANNEL_DISPLAY_COL] == channel_label]
        if df_channel.empty:
            print(f"\nNo data for channel: {channel_label}. Skipping.")
            continue

        print(f"\nProcessing Channel: {channel_label}")

        for freq_label in ORDERED_FREQ_LABELS:
            df_channel_freq = df_channel[df_channel[FOOOF_FREQ_BAND_COL] == freq_label].copy()
            if df_channel_freq.empty:
                print(f"  No data for Freq Band: {freq_label} in Channel: {channel_label}. Skipping.")
                continue
            print(f"  Processing Freq Band: {freq_label}")

            # --- Part 5.1: Bivariate Spearman Correlations (Aperiodic vs. PKG) ---
            print(f"    Part 5.1: Bivariate Aperiodic vs. PKG Correlations for {channel_label} ({freq_label})")
            for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                    # Guard against a KeyError in dropna(subset=...) when a metric column is absent.
                    if pkg_col not in df_channel_freq.columns or ap_col not in df_channel_freq.columns:
                        print(f"      Skipping {ap_name} vs {pkg_name}: Column missing.")
                        continue

                    df_granular_for_corr_pkg = df_channel_freq.dropna(subset=[ap_col, pkg_col])
                    rho, p_val, N = calculate_spearman_with_n(df_granular_for_corr_pkg, ap_col, pkg_col)

                    all_bivariate_ap_pkg_results.append({
                        'Channel': channel_label, 'FreqBand': freq_label,
                        'AperiodicMetric': ap_name, 'PKGMetric': pkg_name,
                        'SpearmanRho': rho, 'PValue': p_val, 'N': N
                    })

                    if N >= MIN_SAMPLES_FOR_CORR:
                        _c5_save_scatter_with_regression(
                            df_granular_for_corr_pkg,
                            x_col=pkg_col, x_name=pkg_name,
                            y_col=ap_col, y_name=ap_name,
                            channel_label=channel_label, freq_label=freq_label,
                            rho=rho, p_val=p_val, n_obs=N,
                            out_dir=plot_subdir_bivariate_ap_pkg,
                            plot_kind="PKG plot",
                        )

            # --- Part 5.2: Partial Spearman Correlations (Aperiodic vs. PKG, controlling oscillatory metrics) ---
            print(f"    Part 5.2: Partial Aperiodic vs. PKG Correlations for {channel_label} ({freq_label})")
            covariates_partial_corr = [col for col in OSCILLATORY_METRICS_COLS.keys() if col in df_channel_freq.columns]
            if len(covariates_partial_corr) == 2:
                for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                    for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                        if pkg_col not in df_channel_freq.columns or ap_col not in df_channel_freq.columns:
                            continue  # Already reported in Part 5.1.

                        data_for_partial = df_channel_freq[[ap_col, pkg_col] + covariates_partial_corr].dropna()
                        if len(data_for_partial) < MIN_SAMPLES_FOR_CORR:
                            partial_rho, partial_p_val, N_partial = np.nan, np.nan, len(data_for_partial)
                        else:
                            partial_rho, partial_p_val, N_partial = calculate_partial_spearman(
                                data_for_partial, ap_col, pkg_col, covariates_partial_corr
                            )
                        all_partial_ap_pkg_results.append({
                            'Channel': channel_label, 'FreqBand': freq_label,
                            'AperiodicMetric': ap_name, 'PKGMetric': pkg_name,
                            'PartialSpearmanRho_vs_BetaGamma': partial_rho, 'PartialPValue': partial_p_val, 'N_Partial': N_partial
                        })
            else:
                # Make the skip visible; otherwise the partial-correlation CSV is silently incomplete.
                print(f"      Skipping partial correlations: need exactly 2 oscillatory covariates, found {covariates_partial_corr}.")

            # --- Part 5.3: Bivariate Spearman Correlations (Aperiodic vs. Oscillatory) ---
            print(f"    Part 5.3: Bivariate Aperiodic vs. Oscillatory Correlations for {channel_label} ({freq_label})")
            for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                for osc_col, osc_name in OSCILLATORY_METRICS_COLS.items():
                    if osc_col not in df_channel_freq.columns or ap_col not in df_channel_freq.columns:
                        continue

                    df_granular_for_corr_osc = df_channel_freq.dropna(subset=[ap_col, osc_col])
                    rho_ap_osc, p_val_ap_osc, N_ap_osc = calculate_spearman_with_n(df_granular_for_corr_osc, ap_col, osc_col)

                    all_bivariate_ap_osc_results.append({
                        'Channel': channel_label, 'FreqBand': freq_label,
                        'AperiodicMetric': ap_name, 'OscillatoryMetric': osc_name,
                        'SpearmanRho': rho_ap_osc, 'PValue': p_val_ap_osc, 'N': N_ap_osc
                    })

                    if N_ap_osc >= MIN_SAMPLES_FOR_CORR:
                        _c5_save_scatter_with_regression(
                            df_granular_for_corr_osc,
                            x_col=osc_col, x_name=osc_name,
                            y_col=ap_col, y_name=ap_name,
                            channel_label=channel_label, freq_label=freq_label,
                            rho=rho_ap_osc, p_val=p_val_ap_osc, n_obs=N_ap_osc,
                            out_dir=plot_subdir_bivariate_ap_osc,
                            plot_kind="Oscillatory plot",
                            annotate_kwargs={'test_type': "Spearman ρ"},
                        )

    # Persist each result table only when at least one row was collected.
    if all_bivariate_ap_pkg_results:
        df_bivar_ap_pkg = pd.DataFrame(all_bivariate_ap_pkg_results)
        df_bivar_ap_pkg.to_csv(os.path.join(analysis_session_plot_folder_step4, f"{patient_hemisphere_id}_Bivariate_AP_vs_PKG_Correlations_Cell5Original.csv"), index=False)
        print(f"\nSaved Bivariate AP vs PKG correlation results (from original Cell 5 structure) for {patient_hemisphere_id}.")

    if all_partial_ap_pkg_results:
        df_partial_ap_pkg = pd.DataFrame(all_partial_ap_pkg_results)
        df_partial_ap_pkg.to_csv(os.path.join(analysis_session_plot_folder_step4, f"{patient_hemisphere_id}_Partial_AP_vs_PKG_Correlations_Cell5Original.csv"), index=False)
        print(f"Saved Partial AP vs PKG (controlling Beta, Gamma) correlation results (from original Cell 5 structure) for {patient_hemisphere_id}.")

    if all_bivariate_ap_osc_results:
        df_bivar_ap_osc = pd.DataFrame(all_bivariate_ap_osc_results)
        df_bivar_ap_osc.to_csv(os.path.join(analysis_session_plot_folder_step4, f"{patient_hemisphere_id}_Bivariate_AP_vs_Oscillatory_Correlations_Cell5Original.csv"), index=False)
        print(f"Saved Bivariate AP vs Oscillatory correlation results (from original Cell 5 structure) for {patient_hemisphere_id}.")

print(f"\n--- Cell 5 (Revised with {SCATTER_AVG_LABEL_C5} Averaged Scatter Points): Correlation Analyses Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 5 (Revised with 5-min Averaged Scatter Points & PKG Zero Exclusion): Bivariate & Partial Correlations, Aperiodic vs. Beta/Gamma ---
# Iterates through Channel_Display AND FreqRangeLabel.
# Scatter points are time-bin averages; Spearman statistics and regression lines use all granular data.
# MODIFICATION: For analyses involving a specific PKG score, data points where that PKG score is zero are excluded.
#
# NOTE(review): earlier text in this cell called the bins "15-min" while the grouping
# frequency in the code was '5T' (5 minutes). The computation is unchanged (5-minute bins);
# only the messages were aligned with it. Confirm that 5 minutes is the intended window.

# Bin width for the scatter layer. "5min" is the non-deprecated spelling of the "5T" alias.
SCATTER_AVG_FREQ_C5 = "5min"
# Human-readable form of the bin width, used in console messages.
SCATTER_AVG_LABEL_C5 = "5-min"


def _c5nz_average_points_for_scatter(df_granular, x_col, y_col, failure_context):
    """Return the points to draw in the scatter layer of a Cell 5 plot.

    Rows are grouped into SCATTER_AVG_FREQ_C5 bins on 'datetime_for_avg_c5' and the
    means of x_col / y_col are returned (bins with a missing mean are dropped).

    Falls back to df_granular itself when the timestamp column is absent or all-null,
    when the averaging raises (a warning naming failure_context is printed), or when
    the averaging yields no rows. The frame is only read, so no copy is made.
    """
    time_col = 'datetime_for_avg_c5'
    if time_col not in df_granular.columns or df_granular[time_col].isnull().all():
        return df_granular
    try:
        df_binned = (
            df_granular.set_index(time_col)
            .groupby(pd.Grouper(freq=SCATTER_AVG_FREQ_C5))[[y_col, x_col]]
            .mean()
            .dropna()
        )
    except Exception as e_avg:
        # Deliberate best-effort: a failed averaging must not abort the whole cell.
        print(f"      Warning: {SCATTER_AVG_LABEL_C5} averaging failed for {failure_context}. Plotting granular. Error: {e_avg}")
        return df_granular
    return df_granular if df_binned.empty else df_binned


def _c5nz_save_scatter_with_regression(df_granular, x_col, x_name, y_col, y_name,
                                       channel_label, freq_label, rho, p_val, n_obs,
                                       out_dir, plot_kind, annotate_kwargs=None):
    """Draw one bivariate plot (binned scatter + granular regression line) and save it.

    Parameters
    ----------
    df_granular : DataFrame already restricted to the rows used for the correlation
        (non-null x_col / y_col; for PKG plots also with zero PKG scores removed).
    x_col, x_name : column and display name plotted on the x axis (PKG or oscillatory metric).
    y_col, y_name : column and display name plotted on the y axis (aperiodic metric).
    channel_label, freq_label : used in the title and in the output file name.
    rho, p_val, n_obs : statistics forwarded to annotate_correlation_on_plot.
    out_dir : folder the PNG is written to.
    plot_kind : short label used only in the averaging-failure warning.
    annotate_kwargs : optional extra keyword arguments for annotate_correlation_on_plot.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        df_points = _c5nz_average_points_for_scatter(
            df_granular, x_col, y_col, f"{plot_kind} ({y_name} vs {x_name})"
        )
        if not df_points.empty:
            sns.scatterplot(data=df_points, x=x_col, y=y_col,
                            color=COLOR_PALETTE_STEP4.get(y_col, 'grey'),
                            alpha=DOT_ALPHA_STEP4 + 0.1, s=40,
                            edgecolor='k', linewidths=0.5, ax=ax)

        # The regression line is always fitted on the granular rows, not on the bin means.
        sns.regplot(data=df_granular, x=x_col, y=y_col, scatter=False, ax=ax,
                    line_kws={'color': COLOR_PALETTE_STEP4.get(x_col, 'black'),
                              'linewidth': REG_LINE_THICKNESS_STEP4, 'alpha': 0.7})

        annotate_correlation_on_plot(ax, rho, p_val, n_obs, **(annotate_kwargs or {}))
        ax.set_title(f"{y_name} vs. {x_name}\n{channel_label} ({freq_label})", fontsize=plt.rcParams['axes.titlesize'])
        ax.set_xlabel(x_name, fontsize=plt.rcParams['axes.labelsize'])
        ax.set_ylabel(y_name, fontsize=plt.rcParams['axes.labelsize'])
        fig.tight_layout()

        safe_ch = get_safe_filename_step4(channel_label)
        safe_y = get_safe_filename_step4(y_name)
        safe_x = get_safe_filename_step4(x_name)
        plot_filename = f"Bivar_{safe_y}_vs_{safe_x}_{safe_ch}_{freq_label}.png"
        fig.savefig(os.path.join(out_dir, plot_filename))
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)


print(f"\n--- Cell 5 (Revised with {SCATTER_AVG_LABEL_C5} Averaged Scatter Points & PKG Zero Exclusion): Starting Correlation Analyses ---")

if 'master_df_step4' not in locals() or master_df_step4 is None or master_df_step4.empty:
    print("master_df_step4 not available or empty. Skipping Cell 5.")
else:
    # Work on a copy so the datetime helper column never leaks into master_df_step4.
    master_df_step4_processed_c5 = master_df_step4.copy()
    if 'Aligned_PKG_UnixTimestamp' in master_df_step4_processed_c5.columns:
        master_df_step4_processed_c5['datetime_for_avg_c5'] = pd.to_datetime(
            master_df_step4_processed_c5['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce'
        )
        if master_df_step4_processed_c5['datetime_for_avg_c5'].isnull().all():
            print(f"Warning in Cell 5: 'datetime_for_avg_c5' could not be created (all NaT). {SCATTER_AVG_LABEL_C5} averaging is unavailable; scatter plots will use granular points.")
    else:
        print(f"Warning in Cell 5: 'Aligned_PKG_UnixTimestamp' not found. Cannot perform {SCATTER_AVG_LABEL_C5} averaging for scatter plots. Will plot granular points.")
        master_df_step4_processed_c5['datetime_for_avg_c5'] = pd.NaT

    # Output folders for the plots produced by this cell.
    plot_subdir_bivariate_ap_pkg = os.path.join(analysis_session_plot_folder_step4, "Bivariate_AP_vs_PKG")
    plot_subdir_bivariate_ap_osc = os.path.join(analysis_session_plot_folder_step4, "Bivariate_AP_vs_Oscillatory")
    os.makedirs(plot_subdir_bivariate_ap_pkg, exist_ok=True)
    os.makedirs(plot_subdir_bivariate_ap_osc, exist_ok=True)

    # One dict per (channel, band, metric pair); turned into DataFrames at the end.
    all_bivariate_ap_pkg_results = []
    all_partial_ap_pkg_results = []
    all_bivariate_ap_osc_results = []

    for channel_label in ORDERED_CHANNEL_LABELS:
        df_channel = master_df_step4_processed_c5[master_df_step4_processed_c5[CHANNEL_DISPLAY_COL] == channel_label]
        if df_channel.empty:
            print(f"\nNo data for channel: {channel_label}. Skipping.")
            continue

        print(f"\nProcessing Channel: {channel_label}")

        for freq_label in ORDERED_FREQ_LABELS:
            df_channel_freq = df_channel[df_channel[FOOOF_FREQ_BAND_COL] == freq_label].copy()
            if df_channel_freq.empty:
                print(f"  No data for Freq Band: {freq_label} in Channel: {channel_label}. Skipping.")
                continue
            print(f"  Processing Freq Band: {freq_label}")

            # --- Part 5.1: Bivariate Spearman Correlations (Aperiodic vs. PKG) ---
            print(f"    Part 5.1: Bivariate Aperiodic vs. PKG Correlations for {channel_label} ({freq_label})")
            for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                    if pkg_col not in df_channel_freq.columns or ap_col not in df_channel_freq.columns:
                        print(f"      Skipping {ap_name} vs {pkg_name}: Column missing.")
                        continue

                    # Exclude rows where the current PKG score is exactly zero, then rows
                    # with a missing value in either variable. (NaN != 0 is True, so the
                    # NaN rows survive the first filter and are removed by dropna.)
                    df_temp_no_zero_pkg = df_channel_freq[df_channel_freq[pkg_col] != 0]
                    df_granular_for_corr_pkg = df_temp_no_zero_pkg.dropna(subset=[ap_col, pkg_col])

                    rho, p_val, N = calculate_spearman_with_n(df_granular_for_corr_pkg, ap_col, pkg_col)

                    all_bivariate_ap_pkg_results.append({
                        'Channel': channel_label, 'FreqBand': freq_label,
                        'AperiodicMetric': ap_name, 'PKGMetric': pkg_name,
                        'SpearmanRho': rho, 'PValue': p_val, 'N': N
                    })

                    if N >= MIN_SAMPLES_FOR_CORR:
                        _c5nz_save_scatter_with_regression(
                            df_granular_for_corr_pkg,
                            x_col=pkg_col, x_name=pkg_name,
                            y_col=ap_col, y_name=ap_name,
                            channel_label=channel_label, freq_label=freq_label,
                            rho=rho, p_val=p_val, n_obs=N,
                            out_dir=plot_subdir_bivariate_ap_pkg,
                            plot_kind="PKG plot",
                        )

            # --- Part 5.2: Partial Aperiodic vs. PKG Correlations ---
            print(f"    Part 5.2: Partial Aperiodic vs. PKG Correlations for {channel_label} ({freq_label})")
            covariates_partial_corr = [col for col in OSCILLATORY_METRICS_COLS.keys() if col in df_channel_freq.columns]
            if len(covariates_partial_corr) == 2:
                for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                    for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                        if pkg_col not in df_channel_freq.columns or ap_col not in df_channel_freq.columns:
                            continue  # Already reported in Part 5.1.

                        # Same zero-exclusion as Part 5.1. Row filtering cannot remove columns,
                        # and every column used below was checked for presence above, so no
                        # further "column present" safeguard is needed here.
                        df_temp_no_zero_pkg_partial = df_channel_freq[df_channel_freq[pkg_col] != 0]
                        cols_for_partial_corr = [ap_col, pkg_col] + covariates_partial_corr
                        data_for_partial = df_temp_no_zero_pkg_partial[cols_for_partial_corr].dropna()

                        # Too few rows, or a constant key variable, makes the correlation undefined.
                        if len(data_for_partial) < MIN_SAMPLES_FOR_CORR or not all(
                            data_for_partial[c].nunique() > 1 for c in (ap_col, pkg_col)
                        ):
                            partial_rho, partial_p_val, N_partial = np.nan, np.nan, len(data_for_partial)
                        else:
                            partial_rho, partial_p_val, N_partial = calculate_partial_spearman(
                                data_for_partial, ap_col, pkg_col, covariates_partial_corr
                            )
                        all_partial_ap_pkg_results.append({
                            'Channel': channel_label, 'FreqBand': freq_label,
                            'AperiodicMetric': ap_name, 'PKGMetric': pkg_name,
                            'PartialSpearmanRho_vs_BetaGamma': partial_rho, 'PartialPValue': partial_p_val, 'N_Partial': N_partial
                        })
            else:
                # Make the skip visible; otherwise the partial-correlation CSV is silently incomplete.
                print(f"      Skipping partial correlations: need exactly 2 oscillatory covariates, found {covariates_partial_corr}.")

            # --- Part 5.3: Bivariate Aperiodic vs. Oscillatory Correlations ---
            # Not affected by the PKG zero filter: no PKG score enters these correlations.
            print(f"    Part 5.3: Bivariate Aperiodic vs. Oscillatory Correlations for {channel_label} ({freq_label})")
            for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                for osc_col, osc_name in OSCILLATORY_METRICS_COLS.items():
                    if osc_col not in df_channel_freq.columns or ap_col not in df_channel_freq.columns:
                        continue

                    df_granular_for_corr_osc = df_channel_freq.dropna(subset=[ap_col, osc_col])
                    rho_ap_osc, p_val_ap_osc, N_ap_osc = calculate_spearman_with_n(df_granular_for_corr_osc, ap_col, osc_col)

                    all_bivariate_ap_osc_results.append({
                        'Channel': channel_label, 'FreqBand': freq_label,
                        'AperiodicMetric': ap_name, 'OscillatoryMetric': osc_name,
                        'SpearmanRho': rho_ap_osc, 'PValue': p_val_ap_osc, 'N': N_ap_osc
                    })

                    if N_ap_osc >= MIN_SAMPLES_FOR_CORR:
                        _c5nz_save_scatter_with_regression(
                            df_granular_for_corr_osc,
                            x_col=osc_col, x_name=osc_name,
                            y_col=ap_col, y_name=ap_name,
                            channel_label=channel_label, freq_label=freq_label,
                            rho=rho_ap_osc, p_val=p_val_ap_osc, n_obs=N_ap_osc,
                            out_dir=plot_subdir_bivariate_ap_osc,
                            plot_kind="Oscillatory plot",
                            annotate_kwargs={'test_type': "Spearman ρ"},
                        )

    # Persist each result table only when at least one row was collected.
    if all_bivariate_ap_pkg_results:
        df_bivar_ap_pkg = pd.DataFrame(all_bivariate_ap_pkg_results)
        df_bivar_ap_pkg.to_csv(os.path.join(analysis_session_plot_folder_step4, f"{patient_hemisphere_id}_Bivariate_AP_vs_PKG_Correlations_NoZerosInPKG_Cell5.csv"), index=False)
        print(f"\nSaved Bivariate AP vs PKG correlation results (zeros in PKG excluded) for {patient_hemisphere_id}.")

    if all_partial_ap_pkg_results:
        df_partial_ap_pkg = pd.DataFrame(all_partial_ap_pkg_results)
        df_partial_ap_pkg.to_csv(os.path.join(analysis_session_plot_folder_step4, f"{patient_hemisphere_id}_Partial_AP_vs_PKG_Correlations_NoZerosInPKG_Cell5.csv"), index=False)
        print(f"Saved Partial AP vs PKG (zeros in PKG excluded) correlation results for {patient_hemisphere_id}.")

    if all_bivariate_ap_osc_results:
        df_bivar_ap_osc = pd.DataFrame(all_bivariate_ap_osc_results)
        # File name carries no "NoZerosInPKG" tag because this table is not affected by the PKG zero filter.
        df_bivar_ap_osc.to_csv(os.path.join(analysis_session_plot_folder_step4, f"{patient_hemisphere_id}_Bivariate_AP_vs_Oscillatory_Correlations_Cell5.csv"), index=False)
        print(f"Saved Bivariate AP vs Oscillatory correlation results for {patient_hemisphere_id}.")

print(f"\n--- Cell 5 (Revised with {SCATTER_AVG_LABEL_C5} Averaged Scatter Points & PKG Zero Exclusion): Correlation Analyses Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 6 (Streamlined V3 - Tiered MLR with Oscillatory-Only Model): Multiple Linear Regression ---
# Focuses on separate models for Exponent and Offset, and introduces an Oscillatory-Only tier.

import pandas as pd
import numpy as np
import os
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt
import seaborn as sns 
from scipy.stats import spearmanr 

print("\n--- Cell 6 (Streamlined V3 - Tiered MLR with Oscillatory-Only Model): Starting Analyses ---\n")

# <<< --- USER TOGGLE --- >>>
# When False, only TARGET_FREQ_BAND_IF_NOT_ALL is modelled; when True, every label in
# ORDERED_FREQ_LABELS is processed (see the freq_bands_to_process assignment further down).
ANALYZE_ALL_FREQ_BANDS = False
TARGET_FREQ_BAND_IF_NOT_ALL = "WideFreq"
# <<< ------------------ >>>

# Skip the whole cell when the master table was never created or holds no rows.
if 'master_df_step4' not in locals() or master_df_step4.empty:
    print("master_df_step4 not available or empty. Skipping Cell 6.")
else:
    # Output folders: one for the MLR step, plus a sub-folder for the Exponent-vs-Offset plots.
    plot_subdir_mlr = os.path.join(analysis_session_plot_folder_step4, "MultipleLinearRegression_PKG_on_Neural_STREAMLINED_V3") # New folder name
    os.makedirs(plot_subdir_mlr, exist_ok=True)
    
    plot_subdir_ap_corr = os.path.join(plot_subdir_mlr, "Aperiodic_Intercorrelations")
    os.makedirs(plot_subdir_ap_corr, exist_ok=True)
    
    # Accumulators filled during the loops below and written to CSV at the end of the cell:
    # one dict per fitted model term, and one dict per channel/band Exponent-vs-Offset correlation.
    mlr_results_list_streamlined_v3 = [] 
    exp_offset_corr_results_list = [] 

    # Keep only the configured metric columns that are actually present in the master table.
    available_aperiodic_cols = [col for col in APERIODIC_METRICS_COLS.keys() if col in master_df_step4.columns]
    available_oscillatory_cols = [col for col in OSCILLATORY_METRICS_COLS.keys() if col in master_df_step4.columns]

    # Column names of the predictors used by the model tiers in this cell.
    exponent_col_name = 'Exponent_BestModel' 
    offset_col_name = 'Offset_BestModel'
    beta_col_name = 'Beta_Peak_Power_at_DominantFreq'
    gamma_col_name = 'Gamma_Peak_Power_at_DominantFreq'
    # Helper function (retained from previous streamlined version)
    def fit_and_interpret_mlr_focused_v3(data_df, dv_col, dv_name, predictors, model_tier_label,
                                          channel="N/A", freq_band="N/A"):
        """Fit one OLS model tier (``dv_col ~ predictors``), print it, and record per-term results.

        Parameters
        ----------
        data_df : pandas.DataFrame
            Rows to model; only ``dv_col`` and ``predictors`` are used, and rows with a
            missing value in any of them are dropped before fitting.
        dv_col : str
            Column used as the dependent variable (left-hand side of the formula).
        dv_name : str
            Display name of the dependent variable, used in messages and result rows.
        predictors : list of str
            Columns joined with ``+`` to form the right-hand side of the formula.
        model_tier_label : str
            Label of the model tier, used in messages and result rows.
        channel, freq_band : str, optional
            Labels copied into the printed output and the result rows.

        Returns
        -------
        None
            Results are reported through side effects only: text is printed and one dict
            per non-intercept term is appended to ``mlr_results_list_streamlined_v3``
            from the enclosing scope.

        Notes
        -----
        The model is skipped (with a printed reason) when the predictor list is empty,
        when a predictor or the DV column is missing, when fewer than
        ``len(predictors) + 10`` complete rows remain, or when a predictor is constant.
        Any exception raised while fitting or reporting is caught and printed so that a
        single failing model does not stop the remaining ones.
        """
        if not predictors:
            # Predictor list is genuinely empty (e.g. no oscillatory features available).
            print(f"      Skipping {dv_name} ({model_tier_label}): No valid predictors provided for this model tier.")
            return

        missing_p = [p for p in predictors if p not in data_df.columns]
        if missing_p:
            print(f"      Skipping {dv_name} ({model_tier_label}): Predictor(s) {missing_p} not found.")
            return

        # Previously a missing DV column raised KeyError at the column selection below,
        # outside the try block; report it like the other skip conditions instead.
        if dv_col not in data_df.columns:
            print(f"      Skipping {dv_name} ({model_tier_label}): DV column '{dv_col}' not found.")
            return

        formula_to_fit = f"{dv_col} ~ {' + '.join(predictors)}"

        # De-duplicate while preserving order so a repeated name does not select a column twice.
        model_cols_to_check = [dv_col] + predictors
        unique_model_cols_to_check = list(dict.fromkeys(model_cols_to_check))
        current_data_for_model = data_df[unique_model_cols_to_check].dropna(how='any').copy()

        # Require at least 10 complete rows beyond the number of predictors.
        if len(current_data_for_model) < (len(predictors) + 10):
            print(f"      Skipping {dv_name} ({model_tier_label}): Insufficient data ({len(current_data_for_model)}) for {len(predictors)} predictors.")
            return

        for pred_check in predictors:
            if pred_check in current_data_for_model.columns and current_data_for_model[pred_check].nunique() < 2:
                print(f"      Skipping {dv_name} ({model_tier_label}): Constant predictor {pred_check} found.")
                return

        print(f"      --- Model Tier: {model_tier_label} ---")
        print(f"      Fitting for DV '{dv_name}' with formula: {formula_to_fit}")

        try:
            model_fit = smf.ols(formula=formula_to_fit, data=current_data_for_model).fit()
            # Computed once per model; the original re-evaluated conf_int() several times per term.
            conf_int_table = model_fit.conf_int()

            print(f"\n      --- MLR Summary: {dv_name} ({model_tier_label}) ---")
            print(f"      Channel: {channel}, Freq Band: {freq_band}")
            print(f"      N = {model_fit.nobs}")
            print(model_fit.summary())
            print("      --- End of Statsmodels Summary ---\n")

            interpretation_string = f"      --- Interpretation for {dv_name} ({model_tier_label} - {channel}, {freq_band}) ---\n"
            interpretation_string += f"      Overall Model Fit: Explains {model_fit.rsquared_adj * 100:.1f}% of variance (Adj. R² = {model_fit.rsquared_adj:.3f}, N = {model_fit.nobs}).\n"
            interpretation_string += "      Predictor Contributions:\n"

            for term_in_model in model_fit.params.index:
                if term_in_model == 'Intercept':
                    continue
                coeff = model_fit.params.get(term_in_model, np.nan)
                pval = model_fit.pvalues.get(term_in_model, np.nan)
                # Prefer the configured display name; fall back to the raw term.
                term_name_display = APERIODIC_METRICS_COLS.get(term_in_model, OSCILLATORY_METRICS_COLS.get(term_in_model, term_in_model))

                interpretation_string += f"        - {term_name_display} (Term: {term_in_model}):\n"
                interpretation_string += f"          Coefficient: {coeff:.3f}, P-value: {pval:.3g}"

                if pd.notna(pval) and pd.notna(coeff):
                    is_significant = pval < P_VALUE_THRESHOLD
                    direction = 'increase' if coeff > 0 else 'decrease' if coeff < 0 else 'no change'
                    significance_text = ' (significant)' if is_significant else ' (not significant)'
                    if is_significant:
                        interpretation_string += f"\n          Interpretation: A 1-unit increase in {term_name_display} is associated with a {abs(coeff):.3f} unit {direction} in {dv_name}{significance_text}, controlling for other model predictors.\n"
                    else:
                        interpretation_string += f"\n          Interpretation: {term_name_display} was not a significant predictor of {dv_name} (p={pval:.3g}), controlling for other model predictors.\n"
                else:
                    interpretation_string += "\n          Interpretation: Stats not available.\n"

                term_has_ci = term_in_model in conf_int_table.index
                mlr_results_list_streamlined_v3.append({
                    'Channel': channel, 'FreqBand_Aperiodics': freq_band, 'PKG_Symptom_DV': dv_name,
                    'Model_Tier': model_tier_label, 'Formula': formula_to_fit,
                    'Predictor_Term': term_in_model,
                    'Predictor_Name_Display': term_name_display,
                    'Coefficient': coeff, 'StdErr': model_fit.bse.get(term_in_model, np.nan), 'PValue': pval,
                    'Conf_Int_Lower': conf_int_table.loc[term_in_model, 0] if term_has_ci else np.nan,
                    'Conf_Int_Upper': conf_int_table.loc[term_in_model, 1] if term_has_ci else np.nan,
                    'N_model': model_fit.nobs, 'R_squared_adj_model': model_fit.rsquared_adj
                })
            print(interpretation_string)
            print("      --- End of Interpretation ---\n")
        except Exception as e_mlr_fit:
            # Deliberate best-effort: report the failure and let the caller continue with other models.
            print(f"      ERROR fitting MLR for {dv_name} ({model_tier_label}): {e_mlr_fit}")


    # Bands to model: only the configured target band unless the user toggle asks for all of them.
    freq_bands_to_process = ORDERED_FREQ_LABELS if ANALYZE_ALL_FREQ_BANDS else [TARGET_FREQ_BAND_IF_NOT_ALL]
    if not ANALYZE_ALL_FREQ_BANDS: print(f"--- Analyzing ONLY for Freq Band: {TARGET_FREQ_BAND_IF_NOT_ALL} ---")

    # Availability of the two aperiodic predictors does not change inside the loops.
    has_exponent_predictor = exponent_col_name in available_aperiodic_cols
    has_offset_predictor = offset_col_name in available_aperiodic_cols

    for channel_label_iter in ORDERED_CHANNEL_LABELS:
        df_channel_mlr_main = master_df_step4[master_df_step4[CHANNEL_DISPLAY_COL] == channel_label_iter]
        if df_channel_mlr_main.empty: continue
        print(f"\n>>> Processing MLR for Channel: {channel_label_iter} <<<")

        for freq_label_iter in freq_bands_to_process:
            df_channel_freq_mlr_main = df_channel_mlr_main[df_channel_mlr_main[FOOOF_FREQ_BAND_COL] == freq_label_iter].copy()
            if df_channel_freq_mlr_main.empty: continue
            print(f"\n  --- Freq Band: {freq_label_iter} ---")

            # Sanity check: how strongly are the two aperiodic parameters related in this stratum?
            if exponent_col_name in df_channel_freq_mlr_main.columns and offset_col_name in df_channel_freq_mlr_main.columns:
                print(f"    Sanity Check: Bivariate Correlation between {exponent_col_name} and {offset_col_name}")
                exp_offset_data = df_channel_freq_mlr_main[[exponent_col_name, offset_col_name]].dropna()
                if len(exp_offset_data) >= MIN_SAMPLES_FOR_CORR:
                    rho_eo, p_eo, N_eo = calculate_spearman_with_n(exp_offset_data, exponent_col_name, offset_col_name)
                    print(f"      Spearman Rho (Exponent vs Offset): {rho_eo:.3f}, P-value: {p_eo:.3g}, N: {N_eo}")
                    exp_offset_corr_results_list.append({
                        'Channel': channel_label_iter, 'FreqBand': freq_label_iter,
                        'SpearmanRho_ExpOff': rho_eo, 'PValue_ExpOff': p_eo, 'N_ExpOff': N_eo
                    })
                    # Scatter + regression line of Exponent (y) against Offset (x), saved per channel/band.
                    exponent_display_name = APERIODIC_METRICS_COLS.get(exponent_col_name, exponent_col_name)
                    offset_display_name = APERIODIC_METRICS_COLS.get(offset_col_name, offset_col_name)
                    plt.figure(figsize=(7,6))
                    ax_eo = sns.scatterplot(data=exp_offset_data, x=offset_col_name, y=exponent_col_name, alpha=0.5)
                    sns.regplot(data=exp_offset_data, x=offset_col_name, y=exponent_col_name, scatter=False, ax=ax_eo, color='red')
                    annotate_correlation_on_plot(ax_eo, rho_eo, p_eo, N_eo)
                    ax_eo.set_title(f"{exponent_display_name} vs. {offset_display_name}\n{channel_label_iter} ({freq_label_iter})")
                    ax_eo.set_xlabel(offset_display_name)
                    ax_eo.set_ylabel(exponent_display_name)
                    plt.tight_layout()
                    safe_ch_eo = get_safe_filename_step4(channel_label_iter)
                    plt.savefig(os.path.join(plot_subdir_ap_corr, f"Bivar_Exponent_vs_Offset_{safe_ch_eo}_{freq_label_iter}.png"))
                    plt.close()
                    print("        Saved Exponent vs Offset plot.")
                else:
                    print("      Insufficient data for Exponent vs Offset correlation.")

            # Oscillatory predictors usable in this stratum. The list depends only on the
            # stratum's columns, so it is built once here instead of once per tier per PKG metric.
            oscillatory_predictors = [
                p for p in (beta_col_name, gamma_col_name)
                if p in available_oscillatory_cols and p in df_channel_freq_mlr_main.columns
            ]

            # (tier label, predictor list) pairs in execution order. Aperiodic-only tiers need
            # the aperiodic column to be available; "+ Oscillatory" tiers additionally need at
            # least one oscillatory predictor.
            tier_specs = []
            if has_exponent_predictor:
                tier_specs.append(("Tier 1: Exponent Only", [exponent_col_name]))
            if oscillatory_predictors:
                tier_specs.append(("Tier 1b: Oscillatory Only", list(oscillatory_predictors)))
            if has_exponent_predictor and oscillatory_predictors:
                tier_specs.append(("Tier 1c: Exponent + Oscillatory", [exponent_col_name] + oscillatory_predictors))
            if has_offset_predictor:
                tier_specs.append(("Tier 2: Offset Only", [offset_col_name]))
            if has_offset_predictor and oscillatory_predictors:
                tier_specs.append(("Tier 2b: Offset + Oscillatory", [offset_col_name] + oscillatory_predictors))

            for pkg_col_iter, pkg_name_iter in PKG_METRICS_COLS.items():
                if pkg_col_iter not in df_channel_freq_mlr_main.columns: continue
                print(f"\n    --- Predicting: {pkg_name_iter} (DV: {pkg_col_iter}) ---")

                for tier_label, tier_predictors in tier_specs:
                    # Pass a fresh list so every call owns its predictor list.
                    fit_and_interpret_mlr_focused_v3(df_channel_freq_mlr_main, pkg_col_iter, pkg_name_iter,
                                          list(tier_predictors), tier_label,
                                          channel=channel_label_iter, freq_band=freq_label_iter)
                
    # Write the Exponent-vs-Offset sanity-check correlations to CSV and show a short preview.
    if exp_offset_corr_results_list:
        csv_filename_exp_offset_corr = f"{patient_hemisphere_id}_Exponent_vs_Offset_Correlations_Step6.csv"
        exp_offset_csv_path = os.path.join(analysis_session_plot_folder_step4, csv_filename_exp_offset_corr)
        df_exp_offset_corr = pd.DataFrame(exp_offset_corr_results_list)
        df_exp_offset_corr.to_csv(exp_offset_csv_path, index=False)
        print(f"\nSaved Exponent vs Offset bivariate correlation results to {csv_filename_exp_offset_corr}.")
        print("Sample of Exponent vs Offset correlations:")
        print(df_exp_offset_corr.head())

    # Write the per-term MLR results to CSV, or report that nothing was fitted.
    if not mlr_results_list_streamlined_v3:
        print("\nNo Streamlined V3 MLR models were successfully fitted or no results to save for Step 6.")
    else:
        csv_filename_mlr_streamlined_v3 = f"{patient_hemisphere_id}_MLR_Streamlined_V3_Results_Step6.csv"
        mlr_csv_path = os.path.join(analysis_session_plot_folder_step4, csv_filename_mlr_streamlined_v3)
        df_mlr_results_step6_streamlined_v3 = pd.DataFrame(mlr_results_list_streamlined_v3)
        df_mlr_results_step6_streamlined_v3.to_csv(mlr_csv_path, index=False)
        print(f"\nSaved Step 6 Streamlined V3 MLR results for {patient_hemisphere_id} to {csv_filename_mlr_streamlined_v3}.")
        # Preview only the headline columns, skipping any that are absent from the table.
        cols_to_show_mlr = ['Channel', 'FreqBand_Aperiodics', 'PKG_Symptom_DV', 'Model_Tier',
                            'Predictor_Term', 'Coefficient', 'PValue', 'R_squared_adj_model', 'N_model']
        preview_cols_present = [
            c for c in cols_to_show_mlr
            if c in df_mlr_results_step6_streamlined_v3.columns
        ]
        print(df_mlr_results_step6_streamlined_v3[preview_cols_present].head())

print(f"\n--- Cell 6 (Streamlined V3 - Tiered MLR with Oscillatory-Only Model): Analyses Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 7: Box Plots - Aperiodic Metric Distributions by Channel (Exploratory) ---
# Iterates through Channel_Display AND FreqRangeLabel

print("\n--- Cell 7: Generating Box Plots for Aperiodic Metric Distributions by Channel (Exploratory) ---")

# Availability guard, consistent with the other Step 4 cells: check that the name exists
# before touching it, so a fresh kernel skips this cell instead of raising NameError.
if 'master_df_step4' not in locals() or master_df_step4 is None or master_df_step4.empty:
    print("master_df_step4 not available or empty. Skipping Cell 7.")
else:
    plot_subdir_ap_dist = os.path.join(analysis_session_plot_folder_step4, "Distributions_Aperiodic_by_Channel")
    os.makedirs(plot_subdir_ap_dist, exist_ok=True)

    # One figure per (frequency band, aperiodic metric); channels are laid out along the x-axis.
    for freq_label in ORDERED_FREQ_LABELS:
        df_freq_band_c7 = master_df_step4[master_df_step4[FOOOF_FREQ_BAND_COL] == freq_label].copy()
        if df_freq_band_c7.empty:
            print(f"No data for Freq Band: {freq_label}. Skipping box plots for this band.")
            continue

        print(f"\nProcessing Aperiodic Box Plots for Freq Band: {freq_label}")

        for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
            if ap_col not in df_freq_band_c7.columns or df_freq_band_c7[ap_col].isnull().all():
                print(f"  Skipping {ap_name}: column not found or all NaN for {freq_label}.")
                continue

            plt.figure(figsize=(max(8, len(ORDERED_CHANNEL_LABELS) * 1.2), 6)) # Adjust width based on num channels

            # Boxes are drawn from per-channel trimmed data (helper from Cell 4).
            boxplot_viz_data_list_c7 = []
            for ch_lab_c7 in ORDERED_CHANNEL_LABELS:
                ch_data_c7 = df_freq_band_c7[df_freq_band_c7[CHANNEL_DISPLAY_COL] == ch_lab_c7]
                if not ch_data_c7.empty:
                    trimmed_ch_data_c7 = trim_data_for_boxplot_visualization(ch_data_c7, ap_col)
                    boxplot_viz_data_list_c7.append(trimmed_ch_data_c7)

            boxplot_df_viz_c7 = pd.concat(boxplot_viz_data_list_c7, ignore_index=True) if boxplot_viz_data_list_c7 else pd.DataFrame()

            if not boxplot_df_viz_c7.empty and not boxplot_df_viz_c7[ap_col].isnull().all():
                # Metric colour for edges/whiskers; the same colour with an alpha channel fills the box.
                current_ap_color_c7 = COLOR_PALETTE_STEP4.get(ap_col, 'grey')
                current_ap_face_color_rgba_c7 = list(sns.color_palette([current_ap_color_c7])[0]) + [BOX_FILL_ALPHA_STEP4]

                sns.boxplot(data=boxplot_df_viz_c7, x=CHANNEL_DISPLAY_COL, y=ap_col,
                            order=ORDERED_CHANNEL_LABELS, showfliers=False,
                            boxprops={'facecolor': tuple(current_ap_face_color_rgba_c7), 'edgecolor': current_ap_color_c7, 'linewidth': 1.5},
                            whiskerprops={'color': current_ap_color_c7, 'linewidth': 1.5},
                            capprops={'color': current_ap_color_c7, 'linewidth': 1.5},
                            medianprops={'color': 'black', 'linewidth': 1.5} # Make median more distinct
                           )
                # Overlay the untrimmed (NaN-filtered) points, so dots may extend beyond the boxes.
                strip_data_c7 = df_freq_band_c7.dropna(subset=[ap_col])
                if not strip_data_c7.empty:
                    sns.stripplot(data=strip_data_c7, x=CHANNEL_DISPLAY_COL, y=ap_col,
                                  order=ORDERED_CHANNEL_LABELS, color=current_ap_color_c7,
                                  alpha=DOT_ALPHA_STEP4*0.5, jitter=0.2, size=4, marker='o', linewidth=0)
            else:
                plt.text(0.5, 0.5, "No Data to Plot", ha='center', va='center', transform=plt.gca().transAxes)

            plt.title(f"{ap_name} Distribution by Channel\n{patient_hemisphere_id} - Freq Band: {freq_label}", fontsize=plt.rcParams['axes.titlesize'])
            plt.xlabel("Channel", fontsize=plt.rcParams['axes.labelsize'])
            plt.ylabel(ap_name, fontsize=plt.rcParams['axes.labelsize'])
            plt.xticks(rotation=45, ha="right")
            plt.tight_layout()

            safe_ap = get_safe_filename_step4(ap_name)
            plot_filename_ap_dist = f"Box_ApDist_{safe_ap}_{freq_label}.png"
            plt.savefig(os.path.join(plot_subdir_ap_dist, plot_filename_ap_dist))
            # Close each figure so figures do not accumulate across loop iterations.
            plt.close()
            print(f"  Saved box plot for {ap_name} ({freq_label})")

print("\n--- Cell 7: Aperiodic Metric Distribution Box Plots Complete ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 9 (MODIFIED for 5 States - Sleep, Imm, NDM, TM, DM & 10-min avg points) ---

# Start banner; "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n.
print("\n--- Cell 9 (Specific 5 States - Sleep, Immobile, NDM, TM, DM): Box Plots ---")

from scipy.stats import kruskal
import scikit_posthocs as sp

# Constants assumed from your environment (define if not globally available)
# Fallback defaults: each name is assigned only when it is not already defined in this
# namespace, so values set by earlier cells take precedence over the ones below.
if 'ORDERED_CHANNEL_LABELS' not in locals(): ORDERED_CHANNEL_LABELS = [] # Should be defined from Cell 3 of Step 4
if 'ORDERED_FREQ_LABELS' not in locals(): ORDERED_FREQ_LABELS = ["LowFreq", "MidFreq", "WideFreq"] # Should be from Cell 2 of Step 4
if 'APERIODIC_METRICS_COLS' not in locals(): APERIODIC_METRICS_COLS = {} # Should be from Cell 2 of Step 4
if 'CLINICAL_STATE_COL' not in locals(): CLINICAL_STATE_COL = 'Clinical_State_2min_Window' # Should be from Cell 2 of Step 4
if 'CHANNEL_DISPLAY_COL' not in locals(): CHANNEL_DISPLAY_COL = 'Channel_Display' # Should be from Cell 2 of Step 4
if 'FOOOF_FREQ_BAND_COL' not in locals(): FOOOF_FREQ_BAND_COL = 'FreqRangeLabel' # Should be from Cell 2 of Step 4
if 'patient_hemisphere_id' not in locals(): patient_hemisphere_id = "UnknownPatient" # Should be from Cell 2 of Step 4
if 'analysis_session_plot_folder_step4' not in locals(): analysis_session_plot_folder_step4 = "./step4_plots" # Fallback

# Overview grid size: defaults to the number of known channels / bands, or 2 / 3 when those lists are empty.
if 'N_CHANNELS_PER_OVERVIEW' not in locals(): N_CHANNELS_PER_OVERVIEW = len(ORDERED_CHANNEL_LABELS) if ORDERED_CHANNEL_LABELS else 2
if 'N_FREQ_BANDS_PER_OVERVIEW' not in locals(): N_FREQ_BANDS_PER_OVERVIEW = len(ORDERED_FREQ_LABELS) if ORDERED_FREQ_LABELS else 3
# NOTE(review): Cell 7 reads BOX_FILL_ALPHA_STEP4 / DOT_ALPHA_STEP4, while this cell uses the
# un-suffixed names below — confirm whether these are meant to be the same settings.
if 'BOX_FILL_ALPHA' not in locals(): BOX_FILL_ALPHA = 0.7
if 'BOXPLOT_LINE_THICKNESS' not in locals(): BOXPLOT_LINE_THICKNESS = 1.5
if 'DOT_ALPHA' not in locals(): DOT_ALPHA = 0.5
# Minimum number of non-NaN values a state needs to be included in the Kruskal-Wallis test.
if 'MIN_SAMPLES_FOR_GROUP_COMPARISON' not in locals(): MIN_SAMPLES_FOR_GROUP_COMPARISON = 5
if 'P_VALUE_THRESHOLD' not in locals(): P_VALUE_THRESHOLD = 0.05

if 'annotate_p_value' not in locals(): # Fallback definition
    def annotate_p_value(ax, p_val, sig_threshold=0.05, custom_text=None, fontsize=8, y_pos=0.9, x_pos=0.98, N_val=""):
        """Write a p-value annotation box onto ``ax``.

        Parameters
        ----------
        ax : object with ``text`` and ``transAxes``
            Axes to annotate; positions are given in axes coordinates.
        p_val : float
            P-value to display; NaN/None is shown as "P: N/A".
        sig_threshold : float, optional
            Values below this get a single star and a highlighted (khaki) background.
        custom_text : str, optional
            When truthy, shown verbatim instead of the generated p-value text.
        fontsize : int, optional
            Font size of the annotation.
        y_pos, x_pos : float, optional
            Anchor of the text (top-right aligned) in axes coordinates.
        N_val : optional
            When truthy, appended on a second line as "N: <value>".

        Notes
        -----
        Stars: ``***`` below 0.001, ``**`` below 0.01, ``*`` below ``sig_threshold``.
        The N line previously used "\\n", which rendered a literal backslash-n inside
        the box; it now uses a real line break.
        """
        if custom_text:
            text_to_display = custom_text
        elif pd.isna(p_val):
            text_to_display = "P: N/A"
            if N_val: text_to_display += f"\nN: {N_val}"
        else:
            stars = ""
            if p_val < 0.001: stars = "***"
            elif p_val < 0.01: stars = "**"
            elif p_val < sig_threshold: stars = "*"
            text_to_display = f"P: {p_val:.2e}{stars}"
            if N_val: text_to_display += f"\nN: {N_val}"
        # Highlight the box only for a usable p-value below the significance threshold.
        bg_color_ann = 'khaki' if not pd.isna(p_val) and p_val < sig_threshold else 'ivory'
        ax.text(x_pos, y_pos, text_to_display, transform=ax.transAxes, fontsize=fontsize,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round,pad=0.2', fc=bg_color_ann, alpha=0.7))

# --- Cell 9 Specific State Definitions (5 States) ---
# Clinical states kept for this cell, in the left-to-right order used for the box plots
# and as the category order when the state column is filtered below.
CELL9_TARGET_STATES_ORDERED = [ # New 5-state order for this cell
    "Sleep",
    "Immobile",
    "Non-Dyskinetic Mobile",
    "Transitional Mobile",
    "Dyskinetic Mobile"
]

# State -> hex colour mapping passed as the seaborn palette for the box and strip plots.
CELL9_STATE_COLORS = { # Colors for Cell 9's specific 5 states
    'Sleep': '#4169E1',                 # RoyalBlue
    'Immobile': '#40E0D0',              # Turquoise
    'Non-Dyskinetic Mobile': '#32CD32', # LimeGreen
    'Transitional Mobile': '#FFD700',   # Gold
    'Dyskinetic Mobile': '#FF6347',     # Tomato
    # Add other potential states with a default color if they might appear
    # ('Other' is not in CELL9_TARGET_STATES_ORDERED, so rows with that state are filtered out below.)
    'Other': '#C0C0C0'
}

if 'master_df_step4' not in locals() or master_df_step4.empty:
    print("master_df_step4 (base data) not available or empty. Skipping Cell 9.")
else:
    # Prepare data specifically for Cell 9's 5-state configuration
    # No remapping needed, just filtering for these 5 states
    df_cell9_input = master_df_step4[master_df_step4[CLINICAL_STATE_COL].isin(CELL9_TARGET_STATES_ORDERED)].copy()

    if not df_cell9_input.empty:
        df_cell9_input.loc[:, CLINICAL_STATE_COL] = pd.Categorical(
            df_cell9_input[CLINICAL_STATE_COL],
            categories=CELL9_TARGET_STATES_ORDERED,
            ordered=True
        )
        df_cell9_input.dropna(subset=[CLINICAL_STATE_COL], inplace=True)

        if 'datetime_for_avg' not in df_cell9_input.columns:
            if 'Aligned_PKG_UnixTimestamp' in df_cell9_input.columns:
                df_cell9_input.loc[:, 'datetime_for_avg'] = pd.to_datetime(
                    df_cell9_input['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce'
                )
            else:
                print("ERROR in Cell 9 Prep: 'Aligned_PKG_UnixTimestamp' missing. 'datetime_for_avg' cannot be created.")
                df_cell9_input.loc[:, 'datetime_for_avg'] = pd.NaT
    
    if 'df_cell9_input' not in locals() or df_cell9_input.empty:
        print(f"No data found after filtering for Cell 9's target clinical states: {CELL9_TARGET_STATES_ORDERED}. Skipping Cell 9.")
    else:
        print(f"Data prepared for Cell 9. Shape: {df_cell9_input.shape}. Unique states: {df_cell9_input[CLINICAL_STATE_COL].unique()}")
        
        # UPDATED output directory name for these specific 5-state plots
        plot_subdir_ap_dist_cell9_5states = os.path.join(analysis_session_plot_folder_step4, "Distributions_Cell9_5States_Sl_Im_NDM_TM_DM")
        os.makedirs(plot_subdir_ap_dist_cell9_5states, exist_ok=True)

        if not ORDERED_CHANNEL_LABELS:
            first_two_channels, last_two_channels = [], []
        else:
            first_two_channels = ORDERED_CHANNEL_LABELS[:2]
            last_two_channels = ORDERED_CHANNEL_LABELS[-2:] if len(ORDERED_CHANNEL_LABELS) > 2 else []
        
        print(f"  Exponent filter (<=3.0) for 'Exponent_BestModel' on channels: {first_two_channels if first_two_channels else 'None'}.")
        if last_two_channels:
            print(f"  Exponent filter (<=5.0) for 'Exponent_BestModel' on channels: {last_two_channels}.")
        print(f"  Strip plot points will be 10-minute averages.")
        print(f"  Y-axis for box plots will be fixed to (0, 5.5).")

        subplot_base_width_inches = 2.5 # Adjusted for potentially 5 boxes
        subplot_height_inches = subplot_base_width_inches * (16/9.0) 

        all_kruskal_wallis_results_cell9 = []

        for ap_metric_col, ap_metric_name in APERIODIC_METRICS_COLS.items():
            y_label_ap_metric = f'{ap_metric_name}'
            print(f"\\n  Processing {ap_metric_name} across Cell 9's 5 Clinical States")

            channels_to_plot_in_fig = ORDERED_CHANNEL_LABELS[:N_CHANNELS_PER_OVERVIEW]
            if not channels_to_plot_in_fig: 
                print(f"    No channels selected for overview plot of {ap_metric_name}. Skipping.")
                continue
            
            num_plot_rows = len(channels_to_plot_in_fig)
            num_plot_cols = N_FREQ_BANDS_PER_OVERVIEW

            fig_overview, axes_overview = plt.subplots(
                num_plot_rows, num_plot_cols, 
                figsize=(num_plot_cols * subplot_base_width_inches, num_plot_rows * subplot_height_inches), 
                sharey=False, 
                squeeze=False
            )
            fig_overview.suptitle(f"Overview (Cell 9): {ap_metric_name} by 5 Clinical States (K-W & Dunn's)\\nPatient: {patient_hemisphere_id} (Y-axis 0-5.5, Pts: 10-min avg)",
                                  fontsize=plt.rcParams['figure.titlesize']*0.85, y=1.04) # Adjusted y for suptitle

            for ch_idx, channel_label_plot in enumerate(channels_to_plot_in_fig):
                df_channel_current = df_cell9_input[df_cell9_input[CHANNEL_DISPLAY_COL] == channel_label_plot]
                
                for fr_idx, freq_label_plot in enumerate(ORDERED_FREQ_LABELS): 
                    ax_current = axes_overview[ch_idx, fr_idx]
                    df_stratum_plot = df_channel_current[df_channel_current[FOOOF_FREQ_BAND_COL] == freq_label_plot].copy()
                    
                    # Apply exponent filtering by setting out-of-range values to NaN, then dropping NaNs for that metric
                    if ap_metric_col == 'Exponent_BestModel':
                        if channel_label_plot in first_two_channels:
                            df_stratum_plot.loc[df_stratum_plot[ap_metric_col] > 3.0, ap_metric_col] = np.nan
                        elif channel_label_plot in last_two_channels and channel_label_plot not in first_two_channels:
                            df_stratum_plot.loc[df_stratum_plot[ap_metric_col] > 5.0, ap_metric_col] = np.nan
                        df_stratum_plot.dropna(subset=[ap_metric_col], inplace=True)
                    
                    ax_current.set_title(f"{channel_label_plot}\\n{freq_label_plot}", fontsize=plt.rcParams['axes.titlesize'] * 0.7)
                    if ch_idx == num_plot_rows - 1: ax_current.set_xlabel("Clinical State", fontsize=plt.rcParams['axes.labelsize'] * 0.75)
                    else: ax_current.set_xlabel("")
                    if fr_idx == 0: ax_current.set_ylabel(y_label_ap_metric, fontsize=plt.rcParams['axes.labelsize'] * 0.8)
                    else: ax_current.set_ylabel("")
                    
                    ax_current.set_ylim(0, 5.5) # Y-AXIS LIMIT

                    if df_stratum_plot.empty or \
                       df_stratum_plot[ap_metric_col].isnull().all() or \
                       ('datetime_for_avg' in df_stratum_plot.columns and df_stratum_plot['datetime_for_avg'].isnull().all()):
                        ax_current.text(0.5, 0.5, "Insufficient Data", ha='center', va='center', transform=ax_current.transAxes, fontsize=8)
                        ax_current.set_xticks([])
                        continue

                    boxplot_viz_data_list = []
                    for state_val_box in CELL9_TARGET_STATES_ORDERED: 
                        state_data_for_box_plot = df_stratum_plot[df_stratum_plot[CLINICAL_STATE_COL] == state_val_box]
                        if not state_data_for_box_plot.empty:
                            trimmed_data = trim_data_for_boxplot_visualization(state_data_for_box_plot, ap_metric_col)
                            trimmed_data.loc[:, CLINICAL_STATE_COL] = state_val_box 
                            boxplot_viz_data_list.append(trimmed_data)
                    final_boxplot_viz_df = pd.concat(boxplot_viz_data_list) if boxplot_viz_data_list else pd.DataFrame()

                    strip_plot_resampled_list = []
                    if 'datetime_for_avg' in df_stratum_plot.columns:
                        for state_val_strip in CELL9_TARGET_STATES_ORDERED: 
                            state_data_for_strip_plot = df_stratum_plot[df_stratum_plot[CLINICAL_STATE_COL] == state_val_strip]
                            if not state_data_for_strip_plot.empty and not state_data_for_strip_plot['datetime_for_avg'].isnull().all():
                                try:
                                    resampled_data = state_data_for_strip_plot.set_index('datetime_for_avg')\
                                                                    .resample('10T')[[ap_metric_col]].mean().dropna()
                                    if not resampled_data.empty:
                                        resampled_data[CLINICAL_STATE_COL] = state_val_strip
                                        strip_plot_resampled_list.append(resampled_data)
                                except Exception as e_resample:
                                    print(f"    Resampling error for {channel_label_plot}, {freq_label_plot}, {state_val_strip}: {e_resample}")
                    df_for_stripplot_resampled = pd.concat(strip_plot_resampled_list) if strip_plot_resampled_list else pd.DataFrame()

                    if not final_boxplot_viz_df.empty and not final_boxplot_viz_df[ap_metric_col].isnull().all():
                        sns.boxplot(data=final_boxplot_viz_df, x=CLINICAL_STATE_COL, y=ap_metric_col, 
                                    order=CELL9_TARGET_STATES_ORDERED, 
                                    palette=CELL9_STATE_COLORS, 
                                    showfliers=False, width=0.7, ax=ax_current, # width can be adjusted for 5 groups
                                    boxprops={'alpha': BOX_FILL_ALPHA, 'linewidth': BOXPLOT_LINE_THICKNESS}, 
                                    medianprops={'linewidth': BOXPLOT_LINE_THICKNESS, 'color':'black'})
                    
                    if not df_for_stripplot_resampled.empty and not df_for_stripplot_resampled[ap_metric_col].isnull().all():
                        sns.stripplot(data=df_for_stripplot_resampled, x=CLINICAL_STATE_COL, y=ap_metric_col, 
                                      order=CELL9_TARGET_STATES_ORDERED,
                                      palette=CELL9_STATE_COLORS, 
                                      jitter=0.1, alpha=DOT_ALPHA - 0.2 if DOT_ALPHA > 0.2 else 0.1, # Adjusted jitter and alpha
                                      size=3.0, ax=ax_current) # Adjusted size
                    
                    ax_current.tick_params(axis='y', labelsize=plt.rcParams['ytick.labelsize'] * 0.75)
                    # Create shorter labels for x-axis if needed, now for 5 states
                    xtick_labels_c9_updated = [s.replace("Non-Dyskinetic Mobile", "NDM")
                                               .replace("Transitional Mobile", "TM")
                                               .replace("Dyskinetic Mobile", "DM")
                                               .replace("Immobile", "Imm.") 
                                               .replace(" Mobile", "\nMob.") # General catch for mobile if any other variations
                                               for s in CELL9_TARGET_STATES_ORDERED]
                    ax_current.set_xticklabels(xtick_labels_c9_updated, rotation=45, ha="right", fontsize=plt.rcParams['xtick.labelsize'] * 0.65)

                    groups_for_stat_test_kw, group_names_for_stat_test_kw, group_ns_kw = [], [], {}
                    for state_val_kw in CELL9_TARGET_STATES_ORDERED: 
                        data_kw = df_stratum_plot[df_stratum_plot[CLINICAL_STATE_COL] == state_val_kw][ap_metric_col].dropna()
                        group_ns_kw[state_val_kw] = len(data_kw)
                        if len(data_kw) >= MIN_SAMPLES_FOR_GROUP_COMPARISON:
                            groups_for_stat_test_kw.append(data_kw)
                            group_names_for_stat_test_kw.append(state_val_kw)
                    
                    kw_h_val, kw_p_val_stat, annotation_text_kw = np.nan, np.nan, ""
                    if len(groups_for_stat_test_kw) >= 2: 
                        try:
                            kw_h_val, kw_p_val_stat = kruskal(*groups_for_stat_test_kw)
                            annotation_text_kw += f"K-W P: {kw_p_val_stat:.1e}" # Slightly more precision for K-W P
                            if kw_p_val_stat < P_VALUE_THRESHOLD: annotation_text_kw += "*"
                            
                            stat_entry_kw = {'PatientHemisphereID': patient_hemisphere_id,
                                'Channel': channel_label_plot, 'FrequencyBand': freq_label_plot, 'AperiodicMetric': ap_metric_name,
                                'Kruskal_H_statistic': kw_h_val, 'Kruskal_p_value': kw_p_val_stat,
                                'Compared_Groups_Kruskal': ", ".join(group_names_for_stat_test_kw),
                                'Group_Ns_Kruskal': "; ".join([f"{name.replace(' ','')[:10]}:{group_ns_kw[name]}" for name in group_names_for_stat_test_kw])}

                            if kw_p_val_stat < P_VALUE_THRESHOLD and len(groups_for_stat_test_kw) > 2:
                                dunn_results_df = sp.posthoc_dunn(groups_for_stat_test_kw, p_adjust='bonferroni')
                                dunn_results_df.columns = group_names_for_stat_test_kw
                                dunn_results_df.index = group_names_for_stat_test_kw
                                
                                pairs_to_annotate_dunn = [] # Define relevant pairs for 5 states
                                if "Immobile" in group_names_for_stat_test_kw:
                                    for other_st in ["Non-Dyskinetic Mobile", "Transitional Mobile", "Dyskinetic Mobile", "Sleep"]:
                                        if other_st in group_names_for_stat_test_kw: pairs_to_annotate_dunn.append(("Immobile", other_st))
                                if "Sleep" in group_names_for_stat_test_kw:
                                     for mobile_st in ["Non-Dyskinetic Mobile", "Transitional Mobile", "Dyskinetic Mobile"]:
                                         if mobile_st in group_names_for_stat_test_kw: pairs_to_annotate_dunn.append(("Sleep", mobile_st))
                                
                                annotation_text_kw += "\nDunn's (Bonf):"
                                annot_count = 0
                                unique_dunn_pairs_annotated = set()

                                for g1_dunn, g2_dunn in pairs_to_annotate_dunn:
                                    # Ensure pair is ordered to avoid duplicate checks (e.g. (A,B) vs (B,A)) for display
                                    sorted_pair = tuple(sorted((g1_dunn, g2_dunn)))
                                    if annot_count < 2 and sorted_pair not in unique_dunn_pairs_annotated and \
                                       g1_dunn in dunn_results_df.index and g2_dunn in dunn_results_df.columns:
                                        pair_p_val_dunn = dunn_results_df.loc[g1_dunn, g2_dunn]
                                        if pd.notna(pair_p_val_dunn):
                                            g1_short = g1_dunn.replace("Non-Dyskinetic Mobile", "NDM").replace("Transitional Mobile","TM").replace("Dyskinetic Mobile","DM")[:3]
                                            g2_short = g2_dunn.replace("Non-Dyskinetic Mobile", "NDM").replace("Transitional Mobile","TM").replace("Dyskinetic Mobile","DM")[:3]
                                            pair_text_dunn = f"\n{g1_short}v{g2_short}:{pair_p_val_dunn:.1e}"
                                            if pair_p_val_dunn < P_VALUE_THRESHOLD: pair_text_dunn += "*"
                                            annotation_text_kw += pair_text_dunn
                                            annot_count +=1
                                            unique_dunn_pairs_annotated.add(sorted_pair)
                                        stat_entry_kw[f"Dunn_{g1_dunn.replace(' ','').replace('-','')}_vs_{g2_dunn.replace(' ','').replace('-','')}_p_adj"] = pair_p_val_dunn
                            all_kruskal_wallis_results_cell9.append(stat_entry_kw)
                        except ValueError as e_stat_kw:
                            annotation_text_kw = "Stat Err"
                            print(f"    Stat error for K-W: {channel_label_plot}, {freq_label_plot}: {e_stat_kw}")
                    else:
                        annotation_text_kw = "N<2 valid grps"
                    
                    annotate_p_value(ax_current, kw_p_val_stat if pd.notna(kw_p_val_stat) else 1.0, 
                                     sig_threshold=P_VALUE_THRESHOLD,
                                     fontsize=4.5, y_pos=0.99, x_pos=0.99, # Even smaller for more text
                                     custom_text=annotation_text_kw)
            
            plt.tight_layout(rect=[0, 0.03, 1, 0.93]) 
            overview_filename_kw_cell9 = f"Overview_Box_Cell9_5States_{ap_metric_name.replace(' ','').replace('/','_')}.png"
            plt.savefig(os.path.join(plot_subdir_ap_dist_cell9_5states, overview_filename_kw_cell9), dpi=300)
            print(f"  Saved Cell 9 (5-State) overview box plot for {ap_metric_name}: {overview_filename_kw_cell9}")
            plt.close(fig_overview)

        if all_kruskal_wallis_results_cell9:
            df_all_kw_stats_cell9 = pd.DataFrame(all_kruskal_wallis_results_cell9)
            kw_stats_filename_cell9 = f"{patient_hemisphere_id}_Cell9_5States_KruskalWallis_Dunn_Stats.csv"
            kw_stats_save_path_cell9 = os.path.join(plot_subdir_ap_dist_cell9_5states, kw_stats_filename_cell9)
            df_all_kw_stats_cell9.to_csv(kw_stats_save_path_cell9, index=False)
            print(f"\\n  Saved all Cell 9 (5-State) Kruskal-Wallis & Dunn's statistical results to: {kw_stats_save_path_cell9}")

# End-of-cell banner for Cell 9. The leading "\n" is a real newline escape so a blank line
# separates the banner from the preceding output, matching the banners printed by other cells.
print("\n--- Cell 9 (MODIFIED for 5 States - Sleep, Immobile, NDM, TM, DM & 10-min avg points): Box Plots by Clinical States generation attempt complete. ---")

In [None]:
# -*- coding: utf-8 -*-
# --- Cell 11: Generate Final Data Table for Cross-Subject Analysis (Input for Step 5) ---
# Purpose: select the Step 5 column subset from master_df_step4, tag every row with
# patient_hemisphere_id as 'UserSessionName', sort the rows, and write the result as a CSV
# into step4_analysis_root_folder.
#
# Reads (defined in earlier cells): master_df_step4, CHANNEL_DISPLAY_COL, CLINICAL_STATE_COL,
# CLINICAL_STATE_AGGREGATED_COL, FOOOF_FREQ_BAND_COL, patient_hemisphere_id,
# current_datetime_str_step4, step4_analysis_root_folder, analysis_session_plot_folder_step4.
# Writes: final_data_table_for_step5, output_filename_for_step5, output_path_for_step5.

print("\n--- Cell 11: Preparing Data Table for Step 5 (Cross-Subject Analysis) ---")

if master_df_step4 is None or master_df_step4.empty:
    print("master_df_step4 not available or empty. Cannot generate final table for Step 5.")
else:
    # Columns intended for the Step 5 table. Per the original notes this list mirrors
    # 'master_table_columns' from Step 3 and must be kept in sync with it by hand; only the
    # columns ACTUALLY PRESENT in master_df_step4 are used below.
    intended_step5_cols = [
        'SessionID', 'Hemisphere', 'Channel', CHANNEL_DISPLAY_COL, # CHANNEL_DISPLAY_COL is 'ElectrodeLabel' or similar
        'Neural_Segment_Start_Unixtime', 'Neural_Segment_End_Unixtime',
        'Neural_Segment_Duration_Sec', 'FS',
        # PSD_Data_Str and Frequency_Vector_Str are usually too large for group analysis files
        # 'PSD_Data_Str', 'Frequency_Vector_Str',
        'Aligned_PKG_UnixTimestamp', 'Aligned_PKG_DateTime_Str',
        CLINICAL_STATE_COL, CLINICAL_STATE_AGGREGATED_COL, # Clinical states
        'Aligned_BK', 'Aligned_DK', 'Aligned_Tremor_Score', 'Aligned_Tremor',
        # New Metrics
        'Total_Daily_LEDD_mg',
        'Beta_Peak_Power_at_DominantFreq',
        'Gamma_Peak_Power_at_DominantFreq',
        # FOOOF Results - these are per FreqRangeLabel, so FreqRangeLabel must be included
        FOOOF_FREQ_BAND_COL, 'FreqLow', 'FreqHigh', # FreqRangeLabel is critical here
        'BestModel_AperiodicMode',
        'Offset_BestModel', 'Knee_BestModel', 'Exponent_BestModel',
        'R2_BestModel', 'Error_BestModel', 'Num_Peaks_BestModel',
        # Optionally, detailed fixed/knee params if needed for specific Step 5 checks:
        # 'Offset_Fixed', 'Exponent_Fixed', 'R2_Fixed', 'Error_Fixed', 'Num_Peaks_Fixed',
        # 'Offset_Knee', 'Knee_Knee', 'Exponent_Knee', 'R2_Knee', 'Error_Knee', 'Num_Peaks_Knee',
        # 'ErrorMsg_FOOOF'
    ]
    # Several entries come from configurable constants. If two of them resolve to the same
    # label (or to one of the literal names), selecting the list as-is would duplicate that
    # column and make the later sort fail on a non-unique label. Drop repeats, keep order.
    intended_step5_cols = list(dict.fromkeys(intended_step5_cols))

    available_master_cols = set(master_df_step4.columns)
    final_table_cols_step5_existing = [col for col in intended_step5_cols if col in available_master_cols]

    # Report intended columns that are absent instead of dropping them silently.
    missing_step5_cols = [col for col in intended_step5_cols if col not in available_master_cols]
    if missing_step5_cols:
        print(f"  Note: {len(missing_step5_cols)} intended column(s) not found in master_df_step4 and omitted: {missing_step5_cols}")

    if not final_table_cols_step5_existing:
        print("Warning: No columns identified for the Step 5 data table. It will be empty.")
        final_data_table_for_step5 = pd.DataFrame()
    else:
        final_data_table_for_step5 = master_df_step4[final_table_cols_step5_existing].copy()

        # Tag every row with this file's patient/hemisphere identifier. A new column goes
        # first; an existing 'UserSessionName' column keeps its position and is overwritten.
        if 'UserSessionName' not in final_data_table_for_step5.columns:
            final_data_table_for_step5.insert(0, 'UserSessionName', patient_hemisphere_id)
        else:
            final_data_table_for_step5['UserSessionName'] = patient_hemisphere_id

        # Optional: sort the table by whichever of the preferred keys are present.
        sort_by_cols_step5 = ['UserSessionName', 'Aligned_PKG_UnixTimestamp', CHANNEL_DISPLAY_COL, FOOOF_FREQ_BAND_COL]
        sort_by_cols_step5_existing = [
            col for col in dict.fromkeys(sort_by_cols_step5)
            if col in final_data_table_for_step5.columns
        ]
        if sort_by_cols_step5_existing:
            # Reassign rather than sorting in place so re-running the cell stays predictable.
            final_data_table_for_step5 = final_data_table_for_step5.sort_values(
                by=sort_by_cols_step5_existing, ignore_index=True
            )

        print(f"  Final data table for Step 5 created with {final_data_table_for_step5.shape[0]} rows and {final_data_table_for_step5.shape[1]} columns.")
        print(f"  Columns included: {final_data_table_for_step5.columns.tolist()}")

    # The file goes in the root of the Step 4 analysis folder (not the patient-specific plot
    # folder), i.e. the location where Step 5 is expected to glob per-subject tables.
    output_filename_for_step5 = f"{patient_hemisphere_id}_CrossSubjectAnalysis_DataTable_{current_datetime_str_step4}.csv"
    output_path_for_step5 = os.path.join(step4_analysis_root_folder, output_filename_for_step5)

    if final_data_table_for_step5.empty:
        # A table with no columns would be written as a CSV with no header or rows.
        # NOTE(review): presumably Step 5 cannot parse such a file - confirm. It is not written.
        print(f"  Final data table is empty; nothing written to: {output_path_for_step5}")
    else:
        try:
            final_data_table_for_step5.to_csv(output_path_for_step5, index=False)
            print(f"  Successfully saved final data table for Step 5 input to: {output_path_for_step5}")
            print("\n  Sample of this final data table (first 5 rows):")
            print(final_data_table_for_step5.head())
        except Exception as e_save_final_step4:
            # Best-effort save: report the failure and let the cell finish.
            print(f"  ERROR saving the final data table for Step 5 input: {e_save_final_step4}")

print(f"\n--- Cell 11: Final Data Table generation for {patient_hemisphere_id} complete ---")
print(f"\n--- All Step 4 processing for {patient_hemisphere_id} complete. Outputs are in {analysis_session_plot_folder_step4} and {step4_analysis_root_folder} ---")