In [1]:
# -*- coding: utf-8 -*-
# --- Cell 1: Centralized Imports and Global Configuration ---

# --- Part 1: All Library Imports ---
# Python Core Libraries
import os
import glob
import re
import warnings
from datetime import datetime, time

# Data Handling & Scientific Computing
import numpy as np
import pandas as pd
import pytz
from scipy.stats import (pearsonr, spearmanr, mannwhitneyu, t, kruskal)
import scikit_posthocs as sp

# Statistical Modeling
import pingouin as pg
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.graphics.regressionplots import plot_partregress_grid
from statsmodels.stats.multitest import fdrcorrection
from statsmodels.stats.anova import anova_lm

# Plotting & Visualization
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.ticker import FormatStrFormatter
import seaborn as sns

# --- Part 2: Warning Filters ---
# The two simplefilter calls silence an entire warning category for the session.
# NOTE(review): this also hides deprecation notices raised by the imported
# libraries, not only the specific messages filtered further down.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
# Message-based filters: `message` is a regular expression that is matched
# against the start of the warning text.
warnings.filterwarnings("ignore", message="Degrees of freedom <= 0 for slice")
warnings.filterwarnings("ignore", message="p-value may not be accurate for N > 5000")
warnings.filterwarnings("ignore", message="invalid value encountered in scalar divide")
warnings.filterwarnings("ignore", message="Confidence interval might not be reliable for bootstrap samples with fewer than 50 elements.")

# --- Part 3: User Input and Path Configuration ---
# Label embedded in the master CSV filename and in the plot folder name below.
patient_hemisphere_id = "COHORT_RCS02_05_06"  # merged cohort label for outputs

# Root directory under which the Step 3 results folder is expected.
project_base_path = "/home/jackson/step2_final"
step3_output_version_tag = "neural_pkg_aligned_finalstep3_bushlab5000" # <<< USER: Ensure this matches Step 3's tag

# Derived Paths (no user input needed below this line)
# Step 3 results folder: <project_base_path>/step3_fooof_results_<tag>
step3_master_csv_base_folder = os.path.join(project_base_path, f'step3_fooof_results_{step3_output_version_tag}')
# Master CSV name is built from both the cohort label and the Step 3 tag.
master_csv_filename = f"MASTER_FOOOF_PKG_results_{patient_hemisphere_id}_{step3_output_version_tag}.csv"
master_csv_path_to_load = os.path.join(step3_master_csv_base_folder, master_csv_filename)
# Step 4 outputs live in a sub-folder of the Step 3 results folder.
# The folder is created as a side effect of running this cell.
step4_analysis_root_folder = os.path.join(step3_master_csv_base_folder, 'step4_within_subject')
os.makedirs(step4_analysis_root_folder, exist_ok=True)
# Each execution of this cell gets its own timestamped plot folder
# (local time, formatted as YYYYMMDD_HHMMSS), which is also created here.
current_datetime_str_step4 = datetime.now().strftime('%Y%m%d_%H%M%S')
session_plot_folder_name_step4 = f"{patient_hemisphere_id}_plots_{current_datetime_str_step4}"
analysis_session_plot_folder_step4 = os.path.join(step4_analysis_root_folder, session_plot_folder_name_step4)
os.makedirs(analysis_session_plot_folder_step4, exist_ok=True)

# --- Part 4: Column Name and Metric Definitions ---
# Metric Column Dictionaries
# Each dictionary maps a dataframe column name (key) to a display label (value).
# The keys of all three dictionaries are the columns that the loader casts to
# numeric and requires to be non-NaN.
APERIODIC_METRICS_COLS = {
    'Exponent_BestModel': 'Aperiodic Exponent',
    'Offset_BestModel': 'Aperiodic Offset',
}
PKG_METRICS_COLS = {
    'Aligned_BK': 'PKG BK Score',
    'Aligned_DK': 'PKG DK Score',
    'Aligned_Tremor_Score': 'PKG Tremor Score'
}
OSCILLATORY_METRICS_COLS = {
    'Beta_Peak_Power_at_DominantFreq': 'Beta Peak Power',
    'Gamma_Peak_Power_at_DominantFreq': 'Gamma Peak Power'
}
APERIODIC_METRICS_TO_PLOT = ['Exponent_BestModel'] # Used in daily exponent plots

# Key Column Names
# CHANNEL_COL is the raw channel column; CHANNEL_DISPLAY_COL holds the
# normalized "Contact_<a>_<b>" label that the loader builds when it is missing.
CHANNEL_COL = 'Channel'
CHANNEL_DISPLAY_COL = 'Channel_Display'
FOOOF_FREQ_BAND_COL = 'FreqRangeLabel'
# CLINICAL_STATE_COL is one of the grouping keys of the 10-minute aggregation;
# CLINICAL_STATE_AGGREGATED_COL is carried through as a categorical column.
CLINICAL_STATE_COL = 'Clinical_State_2min_Window'
CLINICAL_STATE_AGGREGATED_COL = 'Clinical_State_Aggregated'

# --- Part 5: Analysis and Plotting Parameters ---
# Ordering for Iterations and Plots
ORDERED_FREQ_LABELS = ["LowFreq", "MidFreq", "WideFreq"]
# Normalized contact labels (works for Contact_*_* and keyX_contact_*_*)
# The loader keeps this order for labels present in the data and appends any
# other labels it finds after them in natural-sort order.
CHANNEL_ORDER_LIST = ['Contact_2_0', 'Contact_3_0', 'Contact_3_1', 'Contact_10_8', 'Contact_11_9']  # keep any that exist in data
# Label -> position in CHANNEL_ORDER_LIST.
CHANNEL_ORDER_MAP = {lab: i for i, lab in enumerate(CHANNEL_ORDER_LIST)}
# Region name -> contact labels belonging to that region.
CHANNEL_GROUP_MAP = {
    'STN': ['Contact_2_0', 'Contact_3_0', 'Contact_3_1'],
    'M1':  ['Contact_10_8', 'Contact_11_9']
}

# Statistical Thresholds
P_VALUE_THRESHOLD = 0.05             # default significance cut-off for plot annotations
MIN_SAMPLES_FOR_CORR = 5             # default minimum N for the correlation helpers
MIN_SAMPLES_FOR_GROUP_COMPARISON = 5
R2_FILTER_THRESHOLD = 0.5            # presumably a minimum model-fit R^2; usage not shown in this cell

# Clinical State Definitions and Ordering
# ALL_CLINICAL_STATES_ORDERED is TARGET_CLINICAL_STATES_ORDERED with "Sleep" prepended.
TARGET_CLINICAL_STATES_ORDERED = ["Immobile", "Non-Dyskinetic Mobile", "Transitional Mobile", "Dyskinetic Mobile"]
ALL_CLINICAL_STATES_ORDERED = ["Sleep", "Immobile", "Non-Dyskinetic Mobile", "Transitional Mobile", "Dyskinetic Mobile"]
# NOTE(review): SYMPTOM_ORDER still lists the tremor score, while the legend map
# and display order below have the tremor entries commented out.
SYMPTOM_ORDER = ['PKG BK Score', 'PKG DK Score', 'PKG Tremor Score']
SYMPTOM_LEGEND_MAP = {'PKG BK Score': 'Bradykinesia', 'PKG DK Score': 'Dyskinesia'} #, 'PKG Tremor Score': 'Tremor'}
SYMPTOM_DISPLAY_ORDER = ['Bradykinesia', 'Dyskinesia']#, 'Tremor']

# Daily Plot Parameters
# Timezone used to convert the Unix timestamps to local wall-clock time.
SF_TZ = pytz.timezone('America/Los_Angeles')
PLOTTING_INTERVAL_MINUTES = 10
MIN_POINTS_FOR_CI = 2
CONFIDENCE_LEVEL_CI = 0.95
GAP_THRESHOLD_BINS = 2

# MLR Analysis Toggles
# Each ANALYZE_ALL_* flag is paired with the single band named by TARGET_*.
# Presumably the target band is used when the flag is False; the consuming
# code is not shown in this cell.
ANALYZE_ALL_FREQ_BANDS_GLOBAL = False
TARGET_FREQ_BAND_GLOBAL = "WideFreq"
ANALYZE_ALL_FREQ_BANDS_STATE_SPECIFIC = False
TARGET_FREQ_BAND_STATE_SPECIFIC = "WideFreq"

# --- Part 6: Global Plotting Style Configuration ---
# sns.set_theme is called first; the rcParams assignments after it then set
# the individual font and resolution values.
sns.set_theme(style="whitegrid")
plt.rcParams['font.family'] = 'sans-serif'
# Fonts are tried in list order, so DejaVu Sans is the fallback for Arial.
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
plt.rcParams['axes.titlesize'] = 14
plt.rcParams['axes.labelsize'] = 12
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['ytick.labelsize'] = 10
plt.rcParams['legend.fontsize'] = 10
plt.rcParams['figure.titlesize'] = 18
# On-screen figures use 100 dpi; figures written with savefig use 600 dpi.
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 600

# Color Palettes
# Keyed by metric column name (same keys as the metric dictionaries above).
BASE_COLOR_PALETTE = {
    'Exponent_BestModel': 'darkslateblue',
    'Offset_BestModel': 'mediumseagreen',
    'Beta_Peak_Power_at_DominantFreq': 'goldenrod',
    'Gamma_Peak_Power_at_DominantFreq': 'firebrick',
    'Aligned_BK': 'steelblue',
    'Aligned_DK': 'orangered',
    'Aligned_Tremor_Score': 'mediumpurple'
}
# Keyed by clinical-state label; also includes 'Other' and 'Mobile (All Types)'.
CLINICAL_STATE_COLORS = {
    'Immobile': '#40E0D0',              # Turquoise
    'Non-Dyskinetic Mobile': '#32CD32', # LimeGreen
    'Transitional Mobile': '#FFD700',   # Gold
    'Dyskinetic Mobile': '#FF6347',     # Tomato
    'Sleep': '#4169E1',                 # RoyalBlue
    'Other': '#C0C0C0',                 # Silver
    'Mobile (All Types)': 'darkgreen'
}
# PKG symptom colors are taken from BASE_COLOR_PALETTE; the .get() defaults
# repeat the same color names.
PKG_SYMPTOM_COLORS = {
    'Aligned_BK': BASE_COLOR_PALETTE.get('Aligned_BK', 'steelblue'),
    'Aligned_DK': BASE_COLOR_PALETTE.get('Aligned_DK', 'orangered'),
    'Aligned_Tremor_Score': BASE_COLOR_PALETTE.get('Aligned_Tremor_Score', 'mediumpurple')
}

# Other Plotting Style Constants
DOT_ALPHA = 0.5
REG_CI_ALPHA = 0.15
BOX_FILL_ALPHA = 0.6
BOXPLOT_LINE_THICKNESS = 2.25
REG_LINE_THICKNESS = 2.0
# Background colors of the statistics text box drawn by the annotation helper:
# the first is used when p is below the significance threshold, else the second.
SIGNIFICANT_P_VAL_BG_COLOR = 'khaki'
DEFAULT_P_VAL_BG_COLOR = 'ivory'

# ==== Global analysis toggles (for region-only outputs, minimal plotting) ====
# These values are echoed by Cell 2; MIN_SAMPLES_PER_BIN is applied by the
# loader when it filters the 10-minute bins.
ANALYSIS_LEVEL = 'region'      # 'region' or 'contact'
EXCLUDE_MIXED = True           # drop cross-pair bins labeled 'Mixed' when ANALYSIS_LEVEL == 'region'
ENABLE_PLOTS = True            # master switch for plotting; set False to skip all plots
PLOT_GROUPS = ['STN', 'M1']    # which regions to include in any plots
DISABLE_CONTACT_PLOTS = True   # enforce: never generate per-contact plots
MIN_SAMPLES_PER_BIN = 2        # 10-min bin must have at least this many rows to be kept
# ============================================================================

# Marker line confirming that the whole configuration cell executed.
print("Cell 1: All imports and global parameters have been defined.")


Cell 1: All imports and global parameters have been defined.


In [2]:
import os, pathlib
# Report where the notebook is running and whether the configured Step 3
# folder exists and can be written to.
_step3_dir_path = pathlib.Path(step3_master_csv_base_folder)

print("CWD:", os.getcwd())
print("project_base_path:", project_base_path)
print("resolved base:", pathlib.Path(project_base_path).resolve())

print("tag:", step3_output_version_tag)
print("step3 folder:", os.path.abspath(step3_master_csv_base_folder))
_step3_exists = os.path.exists(step3_master_csv_base_folder)
_step3_is_dir = os.path.isdir(step3_master_csv_base_folder)
print("exists?", _step3_exists, "is_dir?", _step3_is_dir)

# Write permission is checked on the folder itself and on its containing folder.
parent = _step3_dir_path.resolve().parent
_parent_writable = os.access(str(parent), os.W_OK)
print("parent writable?", _parent_writable)
_step3_writable = os.access(str(_step3_dir_path), os.W_OK)
print("step3 folder writable?", _step3_writable)


CWD: /home/jackson
project_base_path: /home/jackson/step2_final
resolved base: /home/jackson/step2_final
tag: neural_pkg_aligned_finalstep3_bushlab5000
step3 folder: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000
exists? True is_dir? True
parent writable? True
step3 folder writable? True


In [3]:
# -*- coding: utf-8 -*-
# --- Cell 2: User Input, File Paths, and Analysis Parameter Definitions ---

# --- User Input ---
# This script processes ONE patient-hemisphere at a time.
# NOTE(review): the paths echoed below were derived in Cell 1 from the value
# assigned there; changing the label on this line alone does not rebuild them.
patient_hemisphere_id = "COHORT_RCS02_05_06"  # merged cohort label for outputs

print(f"Processing data for Patient-Hemisphere ID: {patient_hemisphere_id}")

if not patient_hemisphere_id:
    raise ValueError("Patient-Hemisphere ID cannot be empty.")

# --- Path Configuration ---
print(f"Project base path determined as: {project_base_path}")
print(f"Attempting to load master data from: {master_csv_path_to_load}")
print(f"Step 4 plots will be saved in: {analysis_session_plot_folder_step4}")

# Plot styling used downstream (kept here to avoid surprises)
DOT_ALPHA_STEP4 = 0.5
REG_CI_ALPHA_STEP4 = 0.15
BOX_FILL_ALPHA_STEP4 = 0.6
REG_LINE_THICKNESS_STEP4 = 2.0

# --- Echo core analysis toggles so logs are explicit ---
# Each toggle is read from the notebook namespace; when Cell 1 did not define
# it, a fallback value is used instead (only ANALYSIS_LEVEL warns about it).
_notebook_namespace = globals()

if 'ANALYSIS_LEVEL' in _notebook_namespace:
    _lvl = _notebook_namespace['ANALYSIS_LEVEL']
else:
    _lvl = 'contact'
    print("WARNING: ANALYSIS_LEVEL not defined in Cell 1; defaulting to 'contact'.")

_exclude_mixed = _notebook_namespace.get('EXCLUDE_MIXED', False)
_min_per_bin = _notebook_namespace.get('MIN_SAMPLES_PER_BIN', 1)
_enable_plots = _notebook_namespace.get('ENABLE_PLOTS', True)

_config_banner_lines = (
    "\n=== Step 4 Analysis Config ===",
    f"Analysis level         : {_lvl}  (region -> STN vs M1; contact -> per-contact)",
    f"Exclude 'Mixed' bins   : {_exclude_mixed}",
    f"Min samples per 10-min : {_min_per_bin}",
    f"Enable plotting        : {_enable_plots}",
    "==============================\n",
)
for _banner_line in _config_banner_lines:
    print(_banner_line)

print(f"Step 4 analysis parameters and paths configured for {patient_hemisphere_id}.")


Processing data for Patient-Hemisphere ID: COHORT_RCS02_05_06
Project base path determined as: /home/jackson/step2_final
Attempting to load master data from: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/MASTER_FOOOF_PKG_results_COHORT_RCS02_05_06_neural_pkg_aligned_finalstep3_bushlab5000.csv
Step 4 plots will be saved in: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044

=== Step 4 Analysis Config ===
Analysis level         : region  (region -> STN vs M1; contact -> per-contact)
Exclude 'Mixed' bins   : True
Min samples per 10-min : 2
Enable plotting        : True

Step 4 analysis parameters and paths configured for COHORT_RCS02_05_06.


In [4]:
# -*- coding: utf-8 -*-
# --- Cell 3: Data Loading and Initial Preprocessing (Stable, fast, SF-aligned bins) ---
import traceback

# Diagnostics verbosity for the loader: 'off' | 'fast' | 'full'
DIAG_MODE = 'off'

# Probe for Polars exactly once at module scope. The loader copies HAS_POLARS
# into a local flag, so this constant must not be reassigned inside functions.
try:
    import polars as pl
except Exception:
    # Any import-time failure selects the pandas aggregation path.
    HAS_POLARS = False
    print("Polars not available: using optimized pandas path.")
else:
    HAS_POLARS = True
    print("Polars detected: using multi-threaded aggregation.")

def _first_contact_label(s: str):
    s = '' if s is None else str(s)
    m = re.search(r'(?:^|_)contact_(\d+)_(\d+)', s, flags=re.IGNORECASE)
    return f"Contact_{m.group(1)}_{m.group(2)}" if m else None

def _normalize_contact_label(s: str) -> str:
    """Return the normalized ``Contact_<a>_<b>`` label for *s*, or ``str(s)`` if none is found."""
    label = _first_contact_label(s)
    if label:
        return label
    return str(s)

def _ensure_channel_display(df_in: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df_in* that has a normalized CHANNEL_DISPLAY_COL column.

    Resolution order:
      1. If the display column already exists, its values are normalized in place
         on the copy.
      2. Otherwise a fixed list of likely column names, followed by every column
         whose name contains 'contact' or 'electrode', is tried in order.
      3. Otherwise every object-dtype column is scanned.

    A source column is accepted as soon as at least one of its values contains a
    contact pattern; values without a pattern keep their string form.

    Raises:
        KeyError: if no column yields a contact pattern.
    """
    result = df_in.copy()

    # Case 1: the display column is already there, only normalize it.
    if CHANNEL_DISPLAY_COL in result.columns:
        result[CHANNEL_DISPLAY_COL] = result[CHANNEL_DISPLAY_COL].map(_normalize_contact_label)
        return result

    # Case 2: well-known names first, then anything that looks contact-related.
    named_candidates = [
        CHANNEL_COL, 'electrode_label', 'electrode', 'ChannelLabel',
        'Channel_Name', 'ChannelDisplay'
    ]
    named_candidates.extend(
        name for name in result.columns
        if 'contact' in name.lower() or 'electrode' in name.lower()
    )

    for source_col in named_candidates:
        if source_col not in result.columns:
            continue
        parsed = result[source_col].map(_first_contact_label)
        if not parsed.notna().any():
            continue
        result[CHANNEL_DISPLAY_COL] = parsed.fillna(result[source_col].astype(str))
        print(f"Built '{CHANNEL_DISPLAY_COL}' from column '{source_col}'.")
        return result

    # Case 3: last resort, look through every object-dtype column.
    object_columns = [name for name in result.columns if result[name].dtype == object]
    for source_col in object_columns:
        parsed = result[source_col].map(_first_contact_label)
        if not parsed.notna().any():
            continue
        result[CHANNEL_DISPLAY_COL] = parsed.fillna(result[source_col].astype(str))
        print(f"Built '{CHANNEL_DISPLAY_COL}' by scanning column '{source_col}'.")
        return result

    raise KeyError(
        f"Could not construct '{CHANNEL_DISPLAY_COL}'. "
        f"None of the columns contained contact_#_# patterns. "
        f"Please set CHANNEL_COL to the correct source."
    )

def load_and_preprocess_step4_data(file_path, patient_hemisphere_id_val):
    """Load the Step 3 master CSV and aggregate it into 10-minute bins.

    Pipeline:
      1. Read the CSV (all columns) and make sure a 'SessionID' column exists.
      2. Build/normalize the channel display column.
      3. Cast the metric columns to numeric and the key categoricals to str.
      4. Drop rows that miss any of the key metric columns.
      5. Group by (channel display, frequency band, clinical state, 10-min bin),
         take the mean of every numeric column, and keep only bins with at
         least MIN_SAMPLES_PER_BIN rows. Polars is used when available,
         otherwise (or when the Polars path raises) pandas is used.
      6. Add a 'Region' column ('STN' / 'M1' / 'Mixed') derived from the two
         contact numbers in the channel display label.

    The categorical pass-through columns are reduced with ``first()`` in the
    Polars path and with mode-then-first in the pandas path.

    Args:
        file_path: Path of the Step 3 master CSV.
        patient_hemisphere_id_val: Expected session label; used for a
            non-fatal consistency warning and to fill a missing 'SessionID'.

    Returns:
        ``(master_df_step4, ORDERED_CHANNEL_LABELS)``. On a missing file or any
        processing error the error is printed and ``(None, [])`` is returned.
    """
    if not os.path.exists(file_path):
        print(f"ERROR: Master CSV file from Step 3 not found at {file_path}")
        return None, []

    try:
        # Read ALL columns to keep schema identical to the original (~46 cols)
        df = pd.read_csv(file_path)
        print(f"Successfully loaded {file_path}. Initial shape: {df.shape}")

        # Ensure Patient ID presence (non-fatal)
        if 'SessionID' in df.columns and df['SessionID'].nunique() == 1:
            csv_session_id = df['SessionID'].iloc[0]
            if csv_session_id != patient_hemisphere_id_val:
                print(f"Warning: SessionID in CSV ({csv_session_id}) != expected ({patient_hemisphere_id_val}). Proceeding.")
        elif 'SessionID' not in df.columns:
            df['SessionID'] = patient_hemisphere_id_val

        # Build Channel_Display robustly
        df = _ensure_channel_display(df)

        # Cast numeric metrics (unparseable values become NaN)
        cols_to_numeric = (
            list(APERIODIC_METRICS_COLS.keys()) +
            list(PKG_METRICS_COLS.keys()) +
            list(OSCILLATORY_METRICS_COLS.keys()) +
            ['Total_Daily_LEDD_mg', 'R2_BestModel', 'Error_BestModel', 'Num_Peaks_BestModel']
        )
        for col in cols_to_numeric:
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors='coerce')

        # Ensure key categoricals are strings
        for col in [CHANNEL_COL, CHANNEL_DISPLAY_COL, FOOOF_FREQ_BAND_COL,
                    CLINICAL_STATE_COL, CLINICAL_STATE_AGGREGATED_COL, 'Hemisphere', 'BestModel_AperiodicMode']:
            if col in df.columns:
                df[col] = df[col].astype(str)

        # Establish ordered channel list (prefer preset; append leftovers)
        if CHANNEL_DISPLAY_COL in df.columns:
            preferred = CHANNEL_ORDER_LIST
            present = list(pd.Index(df[CHANNEL_DISPLAY_COL].dropna().unique()))
            ordered_ch = [c for c in preferred if c in present]
            leftovers = [c for c in present if c not in ordered_ch]
            # Natural sort key so that e.g. "Contact_10_8" sorts after "Contact_9_8".
            def _natkey(s): return [int(t) if t.isdigit() else t.lower() for t in re.split('([0-9]+)', str(s))]
            ordered_ch.extend(sorted(leftovers, key=_natkey))
            ORDERED_CHANNEL_LABELS = ordered_ch
            print(f"Derived ORDERED_CHANNEL_LABELS for plots: {ORDERED_CHANNEL_LABELS}")
        else:
            ORDERED_CHANNEL_LABELS = []
            print(f"ERROR: '{CHANNEL_DISPLAY_COL}' not found. ORDERED_CHANNEL_LABELS empty.")

        # Drop rows missing any key analysis metrics
        key_metrics_for_analysis = (
            list(APERIODIC_METRICS_COLS.keys()) +
            list(PKG_METRICS_COLS.keys()) +
            list(OSCILLATORY_METRICS_COLS.keys())
        )
        present_keys = [c for c in key_metrics_for_analysis if c in df.columns]
        before = len(df)
        if present_keys:
            df.dropna(subset=present_keys, how='any', inplace=True)
            print(f"Dropped rows with NaNs in any of {present_keys}. Rows {before} -> {len(df)}.")
        else:
            print("Warning: No key metrics found to NaN-filter; proceeding.")

        print(f"Final master_df (pre-binning) shape: {df.shape}")

        # =========================
        # 10-min Aggregation (SF wall-clock aligned; Polars or pandas)
        # =========================
        if 'Aligned_PKG_UnixTimestamp' not in df.columns:
            raise KeyError("'Aligned_PKG_UnixTimestamp' not found in dataframe")

        # Group keys for independent timelines: per-contact × freq × clinical state
        group_keys = [CHANNEL_DISPLAY_COL, FOOOF_FREQ_BAND_COL, CLINICAL_STATE_COL]

        # Prefer Polars, but use a local flag so we don't shadow the module constant
        use_polars = HAS_POLARS

        # Ensure no duplicate column names before any engine conversion
        if df.columns.duplicated().any():
            df = df.loc[:, ~df.columns.duplicated()].copy()
            print("Deduplicated pandas columns before aggregation.")

        # Choose categorical passthrough columns you want to keep after aggregation
        passthru_cat_cols = [
            'SessionID', 'Hemisphere', 'BestModel_AperiodicMode',
            CLINICAL_STATE_AGGREGATED_COL
        ]
        passthru_cat_cols = [c for c in passthru_cat_cols if c in df.columns]

        if use_polars:
            try:
                # Build a minimal frame for Polars with unique columns
                num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
                # Remove raw timestamp from numeric aggregation list
                for _drop in ['Aligned_PKG_UnixTimestamp']:
                    if _drop in num_cols:
                        num_cols.remove(_drop)

                cols_keep = []
                for c in (group_keys + ['Aligned_PKG_UnixTimestamp'] + num_cols):
                    if c in df.columns and c not in cols_keep:
                        cols_keep.append(c)

                # To Polars (no index)
                pl_df = pl.from_pandas(df[cols_keep], include_index=False)

                # Create SF-local 10-min bin via timezone-aware datetime & truncate
                pl_df = pl_df.with_columns(
                    pl.from_epoch(pl.col("Aligned_PKG_UnixTimestamp"), time_unit="s")
                      .dt.replace_time_zone("UTC")
                      .dt.convert_time_zone("America/Los_Angeles")
                      .dt.truncate("10m")
                      .alias("_bin_dt")
                )

                # Aggregate: mean of numerics + row count per bin
                agg_exprs = [pl.col(c).mean().alias(c) for c in num_cols] + [pl.len().alias("_count")]
                grouped = (
                    pl_df.group_by(group_keys + ["_bin_dt"], maintain_order=False)
                         .agg(agg_exprs)
                         .filter(pl.col("_count") >= MIN_SAMPLES_PER_BIN)
                ).to_pandas()

                grouped = grouped.rename(columns={"_bin_dt": "Datetime_10min"})

                # ---- carry-through categorical/meta columns via first() per bin ----
                if passthru_cat_cols:
                    keep_cols_for_cats = []
                    for c in (group_keys + ['Aligned_PKG_UnixTimestamp'] + passthru_cat_cols):
                        if c in df.columns and c not in keep_cols_for_cats:
                            keep_cols_for_cats.append(c)

                    pl_cat = pl.from_pandas(df[keep_cols_for_cats], include_index=False).with_columns(
                        pl.from_epoch(pl.col("Aligned_PKG_UnixTimestamp"), time_unit="s")
                          .dt.replace_time_zone("UTC")
                          .dt.convert_time_zone("America/Los_Angeles")
                          .dt.truncate("10m")
                          .alias("_bin_dt")
                    )
                    cat_grouped = (
                        pl_cat.group_by(group_keys + ["_bin_dt"], maintain_order=False)
                             .agg([pl.col(c).first().alias(c) for c in passthru_cat_cols])
                    ).to_pandas().rename(columns={"_bin_dt": "Datetime_10min"})

                    # Left merge keeps only the bins that survived the count filter.
                    grouped = grouped.merge(
                        cat_grouped,
                        on=group_keys + ["Datetime_10min"],
                        how="left",
                        validate="one_to_one"
                    )

                # Region mapping (vectorized): both contact numbers <= 3 -> STN,
                # both >= 4 -> M1, anything else -> Mixed
                _m = grouped[CHANNEL_DISPLAY_COL].str.extract(r'Contact_(\d+)_(\d+)', expand=True)
                _a = pd.to_numeric(_m[0], errors='coerce')
                _b = pd.to_numeric(_m[1], errors='coerce')
                grouped['Region'] = np.where((_a <= 3) & (_b <= 3), 'STN',
                                      np.where((_a >= 4) & (_b >= 4), 'M1', 'Mixed'))

                master_df_step4 = grouped

            except Exception as e_polars:
                print(f"Polars path failed ({type(e_polars).__name__}: {e_polars}). Falling back to pandas path...")
                use_polars = False  # local flag only

        if not use_polars:
            # ---------- PANDAS FAST PATH (same SF-aligned semantics) ----------
            # Floor to 10-minute bins in UTC first, then convert to SF local time.
            # America/Los_Angeles offsets are whole hours, so UTC-aligned 10-min
            # bins coincide with the wall-clock bins. Flooring the tz-aware local
            # times directly can raise on the repeated (ambiguous) hour at the
            # end of daylight saving time.
            _dt_utc = pd.to_datetime(df['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce')
            # '10min' replaces the deprecated '10T' offset alias.
            df['_bin_dt'] = _dt_utc.dt.floor('10min').dt.tz_convert(SF_TZ)

            num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
            for _drop in ['Aligned_PKG_UnixTimestamp']:
                if _drop in num_cols:
                    num_cols.remove(_drop)

            gb = df.groupby(group_keys + ['_bin_dt'], observed=True, sort=False)
            num_mean = gb[num_cols].mean()
            # Count ROWS per bin (same as pl.len() in the Polars path); counting
            # the non-null values of one numeric column would under-count bins
            # whenever that column contains NaN.
            counts = gb.size().rename('_count')
            agg_df = num_mean.join(counts)
            agg_df = agg_df[agg_df['_count'] >= MIN_SAMPLES_PER_BIN].reset_index()

            agg_df = agg_df.rename(columns={'_bin_dt': 'Datetime_10min'})

            # ---- carry-through categorical/meta columns via mode/first per bin ----
            if passthru_cat_cols:
                def _mode_or_first(s):
                    # Most frequent value; falls back to the first non-null value.
                    m = s.mode(dropna=True)
                    if not m.empty:
                        return m.iat[0]
                    s = s.dropna()
                    return s.iloc[0] if not s.empty else np.nan

                cats = (
                    df.groupby(group_keys + ['_bin_dt'], observed=True, sort=False)[passthru_cat_cols]
                      .agg(_mode_or_first)
                      .reset_index()
                      .rename(columns={'_bin_dt': 'Datetime_10min'})
                )
                agg_df = agg_df.merge(
                    cats,
                    on=group_keys + ['Datetime_10min'],
                    how='left',
                    validate='one_to_one'
                )

            # Region mapping (same rule as in the Polars path)
            _m = agg_df[CHANNEL_DISPLAY_COL].str.extract(r'Contact_(\d+)_(\d+)', expand=True)
            _a = pd.to_numeric(_m[0], errors='coerce')
            _b = pd.to_numeric(_m[1], errors='coerce')
            agg_df['Region'] = np.where((_a <= 3) & (_b <= 3), 'STN',
                                 np.where((_a >= 4) & (_b >= 4), 'M1', 'Mixed'))

            master_df_step4 = agg_df

        print(f"Step 4.5 complete. 10-min binned shape: {master_df_step4.shape}")

        # ------------------------------
        # Diagnostics (light & optional)
        # ------------------------------
        if DIAG_MODE != 'off':
            print("\n=== 10-min Aggregation Diagnostics ===")
            try:
                print(f"Raw rows: {len(df):,} | Aggregated rows: {len(master_df_step4):,}")
                raw_t = pd.to_datetime(df['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce').dt.tz_convert(SF_TZ)
                agg_t = pd.to_datetime(master_df_step4['Datetime_10min'], errors='coerce')
                print(f"RAW span: {raw_t.min()} -> {raw_t.max()}")
                print(f"AGG span: {agg_t.min()} -> {agg_t.max()}")

                print("\nPer-contact 10-min bin counts (top 10):")
                print(
                    master_df_step4.groupby(CHANNEL_DISPLAY_COL)['Datetime_10min']
                        .nunique().sort_values(ascending=False).head(10)
                )

                if DIAG_MODE == 'full':
                    raw_counts = (
                        df.assign(_dt=raw_t)
                          .groupby([CHANNEL_DISPLAY_COL, FOOOF_FREQ_BAND_COL, CLINICAL_STATE_COL])['_dt'].count()
                          .rename('N_raw').reset_index()
                    )
                    agg_counts = (
                        master_df_step4
                          .groupby([CHANNEL_DISPLAY_COL, FOOOF_FREQ_BAND_COL, CLINICAL_STATE_COL])['Datetime_10min'].nunique()
                          .rename('N_bins_10min').reset_index()
                    )
                    diag = raw_counts.merge(agg_counts, how='outer')
                    diag['N_raw'] = diag['N_raw'].fillna(0).astype(int)
                    diag['N_bins_10min'] = diag['N_bins_10min'].fillna(0).astype(int)
                    print("\nLowest 10 groups by N_bins_10min:")
                    print(diag.sort_values('N_bins_10min').head(10))
            except Exception as _e:
                # Diagnostics are best-effort and must never abort the load.
                print("Diagnostics skipped:", _e)

        # --- Preview a few rows for sanity (cheap) ---
        if not master_df_step4.empty:
            cols_to_show = [CHANNEL_DISPLAY_COL, 'Region', FOOOF_FREQ_BAND_COL, CLINICAL_STATE_COL, 'Datetime_10min'] \
                           + list(APERIODIC_METRICS_COLS.keys()) \
                           + list(PKG_METRICS_COLS.keys()) \
                           + list(OSCILLATORY_METRICS_COLS.keys()) \
                           + ['Total_Daily_LEDD_mg']
            cols_to_show = [c for c in cols_to_show if c in master_df_step4.columns]
            print("\nPreview (first 5 rows):")
            print(master_df_step4[cols_to_show].head())
        else:
            print("Warning: Aggregated master_df_step4 is empty.")

        return master_df_step4, ORDERED_CHANNEL_LABELS

    except Exception as e:
        # Top-level boundary: report the full traceback and signal failure to the caller.
        print(f"Error loading or processing the CSV file '{file_path}' for Step 4: {e}")
        print(traceback.format_exc())
        return None, []

# --- Load and Preprocess Data ---
# Run the loader once; both results are kept as notebook-level variables.
master_df_step4, ORDERED_CHANNEL_LABELS = load_and_preprocess_step4_data(
    master_csv_path_to_load,
    patient_hemisphere_id,
)

# Only a message is printed when loading failed or produced no rows.
if master_df_step4 is None or master_df_step4.empty:
    print("Halting Step 4 script as master data could not be loaded or is empty after preprocessing.")
else:
    print("\nCell 3: SF-aligned per-contact/state 10-min aggregation complete.")


Polars detected: using multi-threaded aggregation.
Successfully loaded /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/MASTER_FOOOF_PKG_results_COHORT_RCS02_05_06_neural_pkg_aligned_finalstep3_bushlab5000.csv. Initial shape: (429660, 46)
Built 'Channel_Display' from column 'Channel'.
Derived ORDERED_CHANNEL_LABELS for plots: ['Contact_2_0', 'Contact_3_0', 'Contact_3_1', 'Contact_10_8', 'Contact_11_9']
Dropped rows with NaNs in any of ['Exponent_BestModel', 'Offset_BestModel', 'Aligned_BK', 'Aligned_DK', 'Aligned_Tremor_Score', 'Beta_Peak_Power_at_DominantFreq', 'Gamma_Peak_Power_at_DominantFreq']. Rows 429660 -> 429564.
Final master_df (pre-binning) shape: (429564, 47)
Step 4.5 complete. 10-min binned shape: (32436, 42)

Preview (first 5 rows):
  Channel_Display Region FreqRangeLabel Clinical_State_2min_Window  \
0     Contact_2_0    STN        MidFreq                      Other   
1    Contact_10_8     M1        LowFreq      Non-Dyskinetic Mobil

In [5]:
# -*- coding: utf-8 -*-
# --- Cell 4: Helper Functions for Step 4 (region-aware, safe for STN vs M1) ---
# ---- GLOBAL: exclude tremor everywhere ----
# NOTE(review): none of the helpers defined in this cell read these three
# names; presumably later cells use them to filter tremor out. Confirm there.
EXCLUDE_TREMOR = True
# Display labels that identify the tremor metric.
TREMOR_LABELS = {'PKG Tremor Score', 'Tremor'}
# Dataframe column names that identify the tremor metric.
TREMOR_COL_KEYS = {'Aligned_Tremor_Score', 'Aligned_Tremor'}   # any PKG tremor columns in your DF

def print_10min_agg_diagnostics(df_raw, df_10, group_keys=(CHANNEL_DISPLAY_COL, FOOOF_FREQ_BAND_COL, CLINICAL_STATE_COL)):
    """Print a comparison of raw rows and 10-minute bins per group.

    Args:
        df_raw: Un-aggregated dataframe; must contain 'Aligned_PKG_UnixTimestamp'
            and the grouping columns.
        df_10: Aggregated dataframe; must contain 'Datetime_10min' and the
            grouping columns.
        group_keys: Grouping columns. Replaced by
            ('Region', frequency band, clinical state) when
            ANALYSIS_LEVEL == 'region'.

    Raises:
        KeyError: if 'Datetime_10min' is missing from *df_10*.
    """
    if ANALYSIS_LEVEL == 'region':
        group_keys = ('Region', FOOOF_FREQ_BAND_COL, CLINICAL_STATE_COL)
    keys = list(group_keys)

    print("\n=== 10-min Aggregation Diagnostics ===")

    # Number of raw rows (with a timestamp) in every group.
    raw_local_time = pd.to_datetime(
        df_raw['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce'
    ).dt.tz_convert(SF_TZ)
    raw_with_time = df_raw.assign(dt=raw_local_time).dropna(subset=['Aligned_PKG_UnixTimestamp'])
    raw_counts = (
        raw_with_time.groupby(keys)['Aligned_PKG_UnixTimestamp']
        .count()
        .rename('N_raw')
        .reset_index()
    )

    # Number of distinct 10-minute bins in every group.
    if 'Datetime_10min' not in df_10.columns:
        raise KeyError("Expected 'Datetime_10min' in aggregated dataframe")
    bins_per_group = df_10.groupby(keys)['Datetime_10min'].nunique()
    agg_counts = bins_per_group.rename('N_bins_10min').reset_index()

    # Outer merge so that groups present on only one side still show up (as 0).
    diag = raw_counts.merge(agg_counts, on=keys, how='outer')
    for count_col in ('N_raw', 'N_bins_10min'):
        diag[count_col] = diag[count_col].fillna(0).astype(int)
    # Bins divided by max(N_raw / 20, 1), as a percentage.
    denominator = np.maximum(diag['N_raw'] / 20.0, 1)
    diag['approx_cov_pct'] = (diag['N_bins_10min'] / denominator) * 100.0

    print(diag.sort_values('approx_cov_pct').head(10))
    print("\nLowest coverage groups shown above (good to spot where data go missing).")

    has_raw = diag['N_raw'] > 0
    has_no_bins = diag['N_bins_10min'] == 0
    vanished = diag[has_raw & has_no_bins]
    if not vanished.empty:
        print("\nWARNING: groups with raw data but 0 aggregated bins (likely due to MIN_SAMPLES_PER_BIN):")
        print(vanished.head(20))

    print(f"\nRaw rows: {len(df_raw):,} | Aggregated rows (10-min bins): {len(df_10):,}")

    # Time span of each frame; a failure on one frame is reported, not raised.
    for key_name, frame in (('RAW', df_raw), ('AGG', df_10)):
        try:
            if key_name == 'RAW':
                stamps = pd.to_datetime(
                    frame['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce'
                ).dt.tz_convert(SF_TZ)
            else:
                stamps = pd.to_datetime(frame['Datetime_10min'], errors='coerce')
            print(f"{key_name} time span: {stamps.min()}  ->  {stamps.max()}")
        except Exception as e:
            print(f"{key_name} time span: unavailable ({e})")


def calculate_spearman_with_n(data_df, col1, col2, min_samples=MIN_SAMPLES_FOR_CORR):
    """Spearman correlation between two columns on their complete pairs.

    Returns:
        ``(rho, p_value, n)`` where *n* is the number of rows with both values
        present. ``rho`` and ``p_value`` are NaN when ``n < min_samples``, when
        the coefficient itself is NaN, or when the computation raises ValueError.
    """
    complete_pairs = data_df[[col1, col2]].dropna()
    n_points = len(complete_pairs)
    not_available = (np.nan, np.nan, n_points)

    if n_points < min_samples:
        return not_available

    try:
        rho, p_value = spearmanr(complete_pairs[col1], complete_pairs[col2])
        # The NaN check stays inside the try so a ValueError raised while
        # evaluating it is handled the same way as one from spearmanr.
        if np.isnan(rho):
            return not_available
        return rho, p_value, n_points
    except ValueError:
        return not_available


def calculate_partial_spearman(data_df, x_col, y_col, covar_cols, min_samples=MIN_SAMPLES_FOR_CORR):
    """Partial Spearman correlation of *x_col* and *y_col* given *covar_cols*.

    Rows with a missing value in any involved column are dropped first.

    Returns:
        ``(rho, p_value, n)`` where *n* is the number of complete rows. ``rho``
        and ``p_value`` are NaN when ``n < min_samples``, when any involved
        column has at most one distinct value, or when the computation raises.
    """
    involved_cols = [x_col, y_col] + covar_cols
    complete_rows = data_df[involved_cols].dropna()
    n_points = len(complete_rows)
    not_available = (np.nan, np.nan, n_points)

    if n_points < min_samples:
        return not_available

    try:
        # Every involved column needs more than one distinct value.
        for col in involved_cols:
            if col in complete_rows and not (complete_rows[col].nunique() > 1):
                return not_available

        stats_row = pg.partial_corr(
            data=complete_rows, x=x_col, y=y_col, covar=covar_cols, method='spearman'
        )
        rho = stats_row['r'].iloc[0]
        p_value = stats_row['p-val'].iloc[0]
    except Exception:
        # Best effort: any failure is reported as "no result".
        return not_available

    return rho, p_value, n_points


def annotate_correlation_on_plot(ax, rho, p_value, N_val, test_type="Spearman ρ",
                                 x_pos=0.97, y_pos=0.97, fontsize=9,
                                 sig_threshold=P_VALUE_THRESHOLD):
    """Write a correlation summary box onto a matplotlib axis.

    Does nothing when the module-level ``ENABLE_PLOTS`` flag is falsy.

    Args:
        ax: Axis to annotate.
        rho: Correlation coefficient (may be NaN).
        p_value: Associated p-value (may be NaN).
        N_val: Sample size shown in the label.
        test_type: Label prefix naming the statistic.
        x_pos, y_pos: Anchor position in axes coordinates (top-right aligned).
        fontsize: Font size of the annotation.
        sig_threshold: p-value cut-off for the single-star marker and for the
            highlighted background colour.
    """
    if not ENABLE_PLOTS:
        return

    if pd.isna(rho) or pd.isna(p_value):
        label = f"{test_type}: N/A (N={N_val})"
        box_color = DEFAULT_P_VAL_BG_COLOR
    else:
        # First matching cut-off wins, from strictest to loosest.
        star_cutoffs = ((0.001, "***"), (0.01, "**"), (sig_threshold, "*"))
        stars = next((marks for cutoff, marks in star_cutoffs if p_value < cutoff), "")
        label = f"{test_type}={rho:.2f}{stars}\np={p_value:.3g}\n(N={N_val})"
        if p_value < sig_threshold:
            box_color = SIGNIFICANT_P_VAL_BG_COLOR
        else:
            box_color = DEFAULT_P_VAL_BG_COLOR

    box_style = dict(boxstyle='round,pad=0.3', fc=box_color, alpha=0.85, edgecolor='darkgrey')
    ax.text(
        x_pos, y_pos, label,
        transform=ax.transAxes,
        fontsize=fontsize,
        verticalalignment='top',
        horizontalalignment='right',
        bbox=box_style,
    )


def get_safe_filename_step4(base_name):
    """Turn an arbitrary label into a token usable inside a filename.

    The input is converted with ``str()``; every character that is not a word
    character, whitespace or a hyphen is removed; surrounding whitespace is
    stripped; remaining spaces and hyphens become underscores.
    """
    text = str(base_name)
    kept = re.sub(r'[^\w\s-]', '', text)
    trimmed = kept.strip()
    # Spaces and hyphens both map to '_' in a single pass.
    return trimmed.translate(str.maketrans({' ': '_', '-': '_'}))


def trim_data_for_boxplot_visualization(df_group, value_col):
    """Drop rows outside the 1.5*IQR fences of ``value_col`` (display only).

    The input frame is returned unchanged when it is empty, when the column is
    entirely null, when it has fewer than two rows, or when the interquartile
    range is zero. Otherwise only rows whose value lies within
    ``[Q1 - 1.5*IQR, Q3 + 1.5*IQR]`` (inclusive) are kept; rows with a missing
    value fail the comparison and are dropped.
    """
    if df_group.empty or df_group[value_col].isnull().all() or len(df_group) < 2:
        return df_group

    values = df_group[value_col]
    first_quartile, third_quartile = values.quantile([0.25, 0.75])
    spread = third_quartile - first_quartile
    if spread == 0:
        return df_group

    fence_width = 1.5 * spread
    inside_fences = values.between(first_quartile - fence_width, third_quartile + fence_width)
    return df_group[inside_fences]


# Status message emitted once the Cell 4 helper definitions above have been executed.
print("Cell 4: Helper functions for Step 4 defined (region-aware, plotting guarded).")


Cell 4: Helper functions for Step 4 defined (region-aware, plotting guarded).


In [6]:
# -*- coding: utf-8 -*-
# --- Cell 5_PREAMBLE: Definitions for State-Specific Analyses (Region-level, Revised Order) ---

# --- Clinical states under analysis, in display order (Sleep excluded; four motor states kept) ---
TARGET_CLINICAL_STATES_ORDERED = [
    "Immobile", "Non-Dyskinetic Mobile", "Transitional Mobile", "Dyskinetic Mobile",
]

# Independent copy so later edits to one list do not affect the other.
ORIGINAL_STATES_FOR_ANALYSIS = list(TARGET_CLINICAL_STATES_ORDERED)

# No state merging is configured for this setup.
STATES_TO_COMBINE_MAPPING = dict()
NEW_COMBINED_STATE_NAME = None

# --- Plot colours per clinical state ---
# The four target states come first (same order as above), followed by
# fallback entries for states that are not expected in this analysis.
NEW_CLINICAL_STATE_COLORS_FOR_PLOTTING = {
    **dict(zip(
        TARGET_CLINICAL_STATES_ORDERED,
        ('#40E0D0',    # Turquoise
         '#32CD32',    # LimeGreen
         '#FFD700',    # Gold
         '#FF6347'),   # Tomato
    )),
    'Sleep': '#4169E1',                  # RoyalBlue (excluded state)
    'Other': '#C0C0C0',                  # Silver
    'Mobile (All Types)': 'darkgreen',   # aggregated view, if ever used
}

# --- PKG symptom colours: palette entry when present, named fallback otherwise ---
PKG_SYMPTOM_COLORS = {
    pkg_col: BASE_COLOR_PALETTE.get(pkg_col, fallback_color)
    for pkg_col, fallback_color in (
        ('Aligned_BK', 'steelblue'),
        ('Aligned_DK', 'orangered'),
        ('Aligned_Tremor_Score', 'mediumpurple'),
    )
}

# --- Output directory for state-specific analyses (same folder as the Step 4 session plots) ---
STATE_SPECIFIC_ANALYSIS_DIR = os.path.join(analysis_session_plot_folder_step4)
os.makedirs(STATE_SPECIFIC_ANALYSIS_DIR, exist_ok=True)

# --- Grouping key and labels depend on the analysis level ---
if ANALYSIS_LEVEL == 'region':
    GROUP_KEY_COL = 'Region'
    ORDERED_GROUP_LABELS = ['STN', 'M1']
    if EXCLUDE_MIXED:
        # Keep only the two pure regions.
        ORDERED_GROUP_LABELS = [label for label in ORDERED_GROUP_LABELS if label in ['STN', 'M1']]
else:
    GROUP_KEY_COL = CHANNEL_DISPLAY_COL
    ORDERED_GROUP_LABELS = ORDERED_CHANNEL_LABELS

print("Cell 5_PREAMBLE: Definitions for state-specific analyses (region-level, 4 motor states) are set.")
print(f"Target clinical states: {TARGET_CLINICAL_STATES_ORDERED}")
print(f"Colors for clinical states: {NEW_CLINICAL_STATE_COLORS_FOR_PLOTTING}")
print(f"Grouping key: {GROUP_KEY_COL}, Labels: {ORDERED_GROUP_LABELS}")
print(f"State-specific outputs will be saved in: {STATE_SPECIFIC_ANALYSIS_DIR}")


Cell 5_PREAMBLE: Definitions for state-specific analyses (region-level, 4 motor states) are set.
Target clinical states: ['Immobile', 'Non-Dyskinetic Mobile', 'Transitional Mobile', 'Dyskinetic Mobile']
Colors for clinical states: {'Immobile': '#40E0D0', 'Non-Dyskinetic Mobile': '#32CD32', 'Transitional Mobile': '#FFD700', 'Dyskinetic Mobile': '#FF6347', 'Sleep': '#4169E1', 'Other': '#C0C0C0', 'Mobile (All Types)': 'darkgreen'}
Grouping key: Region, Labels: ['STN', 'M1']
State-specific outputs will be saved in: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044


In [7]:
# -*- coding: utf-8 -*-
# --- Cell 5A (Region-level, Revised): State-Specific Correlations with FDR Correction ---
# Only STN vs M1 (no per-contact loops). Applies BH FDR across all tests.

import pandas as pd
import numpy as np
import os
from statsmodels.stats.multitest import fdrcorrection

# Cell 5A: for every (clinical state, region, frequency band) subset, compute
# three families of Spearman tests (aperiodic vs PKG, partial aperiodic vs PKG
# controlling for the oscillatory metrics, aperiodic vs oscillatory), apply one
# Benjamini-Hochberg FDR correction across all collected tests, and write one
# CSV per family.
print("\n--- Cell 5A (Region-level): Starting State-Specific Correlation Calculations with FDR Correction ---")

# --- Safety checks ---
# At notebook top level locals() is the module namespace, so this tests whether
# an earlier cell has defined master_df_step4.
if 'master_df_step4' not in locals() or master_df_step4.empty:
    print("CRITICAL ERROR: master_df_step4 not available or empty. Cannot proceed with Cell 5A.")
else:
    # Ensure grouping is region-level
    # Work on a copy so the filters below never touch master_df_step4.
    df_analysis = master_df_step4.copy()
    if EXCLUDE_MIXED and 'Region' in df_analysis.columns:
        df_analysis = df_analysis[df_analysis['Region'].isin(['STN', 'M1'])]

    # Filter by states of interest
    df_analysis = df_analysis[df_analysis[CLINICAL_STATE_COL].isin(TARGET_CLINICAL_STATES_ORDERED)].copy()
    if df_analysis.empty:
        print(f"No data found for states {TARGET_CLINICAL_STATES_ORDERED}. Skipping Cell 5A.")
    else:
        # Categorical state ordering
        df_analysis[CLINICAL_STATE_COL] = pd.Categorical(
            df_analysis[CLINICAL_STATE_COL],
            categories=TARGET_CLINICAL_STATES_ORDERED,
            ordered=True
        )

        print(f"Filtered region-level data for target states. Shape: {df_analysis.shape}")

        # Output directory
        state_corr_csv_dir = os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "Correlation_CSVs_by_State_FDR_Corrected")
        os.makedirs(state_corr_csv_dir, exist_ok=True)

        # --- Step 1: Collect all correlation results ---
        # Each test appends one dict; rows with NaN statistics are kept so the
        # CSVs still list tests that could not be computed.
        all_results = []
        for state_current in TARGET_CLINICAL_STATES_ORDERED:
            df_state = df_analysis[df_analysis[CLINICAL_STATE_COL] == state_current]
            if df_state.empty:
                continue

            # NOTE(review): the loop iterates ORDERED_GROUP_LABELS but matches
            # against the hard-coded 'Region' column rather than GROUP_KEY_COL;
            # this only lines up when the analysis is region-level, and raises
            # KeyError if 'Region' is absent — confirm this is intended.
            for region in ORDERED_GROUP_LABELS:
                if region not in df_state['Region'].unique():
                    continue
                df_region_state = df_state[df_state['Region'] == region]

                for freq_label in ORDERED_FREQ_LABELS:
                    df_region_freq = df_region_state[df_region_state[FOOOF_FREQ_BAND_COL] == freq_label].copy()
                    if df_region_freq.empty:
                        continue

                    # --- Bivariate Aperiodic vs PKG ---
                    for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                        for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                            if ap_col in df_region_freq.columns and pkg_col in df_region_freq.columns:
                                rho, p_val, N = calculate_spearman_with_n(df_region_freq, ap_col, pkg_col)
                                all_results.append({
                                    'TestType': 'Bivar_AP_PKG',
                                    'ClinicalState': state_current,
                                    # The 'Channel' field carries the region label here.
                                    'Channel': region,
                                    'FreqBand': freq_label,
                                    'Metric1': ap_name,
                                    'Metric2': pkg_name,
                                    'Rho': rho,
                                    'P_Value_Original': p_val,
                                    'N': N
                                })

                    # --- Partial Aperiodic vs PKG (controlling for oscillatory) ---
                    # Partial tests run only when exactly two oscillatory columns are
                    # present in the frame (presumably beta and gamma, per the output
                    # column name used below — verify against OSCILLATORY_METRICS_COLS).
                    covars = [c for c in OSCILLATORY_METRICS_COLS.keys() if c in df_region_freq.columns]
                    if len(covars) == 2:
                        for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                            for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                                if ap_col in df_region_freq.columns and pkg_col in df_region_freq.columns:
                                    prho, ppval, Np = calculate_partial_spearman(df_region_freq, ap_col, pkg_col, covars)
                                    all_results.append({
                                        'TestType': 'Partial_AP_PKG',
                                        'ClinicalState': state_current,
                                        'Channel': region,
                                        'FreqBand': freq_label,
                                        'Metric1': ap_name,
                                        'Metric2': pkg_name,
                                        'Rho': prho,
                                        'P_Value_Original': ppval,
                                        'N': Np
                                    })

                    # --- Bivariate Aperiodic vs Oscillatory ---
                    for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                        for osc_col, osc_name in OSCILLATORY_METRICS_COLS.items():
                            if ap_col in df_region_freq.columns and osc_col in df_region_freq.columns:
                                rho2, p2, N2 = calculate_spearman_with_n(df_region_freq, ap_col, osc_col)
                                all_results.append({
                                    'TestType': 'Bivar_AP_Osc',
                                    'ClinicalState': state_current,
                                    'Channel': region,
                                    'FreqBand': freq_label,
                                    'Metric1': ap_name,
                                    'Metric2': osc_name,
                                    'Rho': rho2,
                                    'P_Value_Original': p2,
                                    'N': N2
                                })

        print(f"\nCollected {len(all_results)} tests for FDR correction.")

        # --- Step 2: Apply FDR correction ---
        # A single correction is run over the pooled p-values of all three test
        # families and all states; tests with a NaN p-value are left out of the
        # correction and keep a NaN adjusted value.
        if not all_results:
            print("No correlation results generated. Skipping FDR.")
        else:
            df_all = pd.DataFrame(all_results)
            valid_mask = df_all['P_Value_Original'].notna()
            if valid_mask.sum() == 0:
                df_all['P_Value_FDR_Adjusted'] = np.nan
                df_all['Significant_FDR_0.05'] = False
            else:
                # NOTE(review): the rejection array `rej` is not used; significance
                # is re-derived below from the adjusted p-values with a strict `<`.
                rej, pvals_corr = fdrcorrection(df_all.loc[valid_mask, 'P_Value_Original'], alpha=0.05)
                df_all['P_Value_FDR_Adjusted'] = np.nan
                df_all.loc[valid_mask, 'P_Value_FDR_Adjusted'] = pvals_corr
                # NaN adjusted values compare as False, i.e. not significant.
                df_all['Significant_FDR_0.05'] = df_all['P_Value_FDR_Adjusted'] < 0.05

            print("FDR correction applied.")

            # --- Step 3: Split and save ---
            # Maps each TestType to (file-name stem, name for Metric1, name for
            # Metric2, name for the Rho column) used in that family's CSV.
            out_map = {
                'Bivar_AP_PKG': ("Bivariate_AP_vs_PKG", 'AperiodicMetric', 'PKGMetric', 'SpearmanRho'),
                'Partial_AP_PKG': ("Partial_AP_vs_PKG", 'AperiodicMetric', 'PKGMetric', 'PartialSpearmanRho_vs_BetaGamma'),
                'Bivar_AP_Osc': ("Bivariate_AP_vs_Oscillatory", 'AperiodicMetric', 'OscillatoryMetric', 'SpearmanRho')
            }

            for ttype, (fname, m1, m2, rho_name) in out_map.items():
                sub = df_all[df_all['TestType'] == ttype].copy()
                if sub.empty:
                    continue
                sub.rename(columns={'Metric1': m1, 'Metric2': m2, 'Rho': rho_name}, inplace=True)
                sub = sub.drop(columns='TestType')
                outpath = os.path.join(state_corr_csv_dir, f"{patient_hemisphere_id}_{fname}_ByState_FDR_Region.csv")
                sub.to_csv(outpath, index=False)
                print(f"Saved {ttype} results to {outpath}")

            print("\nSample of FDR-corrected Bivariate AP vs PKG results:")
            print(df_all[df_all['TestType'] == 'Bivar_AP_PKG'].head())

# Printed unconditionally, including when the cell was skipped above.
print("\n--- Cell 5A (Region-level, FDR Correction) Complete ---")



--- Cell 5A (Region-level): Starting State-Specific Correlation Calculations with FDR Correction ---
Filtered region-level data for target states. Shape: (24996, 42)

Collected 384 tests for FDR correction.
FDR correction applied.
Saved Bivar_AP_PKG results to /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/Correlation_CSVs_by_State_FDR_Corrected/COHORT_RCS02_05_06_Bivariate_AP_vs_PKG_ByState_FDR_Region.csv
Saved Partial_AP_PKG results to /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/Correlation_CSVs_by_State_FDR_Corrected/COHORT_RCS02_05_06_Partial_AP_vs_PKG_ByState_FDR_Region.csv
Saved Bivar_AP_Osc results to /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/Correlation_CSVs_by_Stat

In [8]:
# -*- coding: utf-8 -*-
# --- Cell 5B-prep: Build datetime_for_avg robustly (Region-level, no plotting here) ---

# Cell 5B-prep: filter master_df_step4 to the target clinical states (and to
# STN/M1 when EXCLUDE_MIXED is set), attach a tz-aware 'datetime_for_avg'
# column, and publish the result as master_df_step4_filtered_states.
#
# Raises:
#   RuntimeError: master_df_step4 is missing, None or empty.
#   KeyError: the clinical-state column or both candidate time columns are missing.

# Start from the aggregated master_df_step4 produced in Cell 3
if 'master_df_step4' not in locals() or master_df_step4 is None or master_df_step4.empty:
    raise RuntimeError("master_df_step4 is not available. Run Cell 3 first.")

# Region-level subset (respect EXCLUDE_MIXED) and keep only target states
df_for_5b = master_df_step4.copy()
if EXCLUDE_MIXED and 'Region' in df_for_5b.columns:
    df_for_5b = df_for_5b[df_for_5b['Region'].isin(['STN', 'M1'])]

if CLINICAL_STATE_COL not in df_for_5b.columns:
    raise KeyError(f"'{CLINICAL_STATE_COL}' not found in aggregated dataframe.")

df_for_5b = df_for_5b[df_for_5b[CLINICAL_STATE_COL].isin(TARGET_CLINICAL_STATES_ORDERED)].copy()
df_for_5b[CLINICAL_STATE_COL] = pd.Categorical(
    df_for_5b[CLINICAL_STATE_COL],
    categories=TARGET_CLINICAL_STATES_ORDERED,
    ordered=True
)

# Prefer raw seconds if available (rare here), else reuse the 10-min bins from Cell 3
if 'Aligned_PKG_UnixTimestamp' in df_for_5b.columns:
    # Floor to 10-minute bins while still in UTC, then convert to SF_TZ.
    # Flooring a tz-aware local series rounds wall-clock times and re-localizes
    # them, which raises for timestamps in the repeated hour at the end of DST.
    # Flooring in UTC avoids that and yields the same bins as long as the
    # zone's UTC offset is a multiple of 10 minutes.
    # NOTE(review): assumes SF_TZ is a whole-hour-offset zone (e.g. US Pacific) — confirm.
    # '10min' replaces the deprecated '10T' frequency alias.
    dt_sf = (
        pd.to_datetime(df_for_5b['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce')
          .dt.floor('10min')
          .dt.tz_convert(SF_TZ)
    )
    df_for_5b['datetime_for_avg'] = dt_sf
elif 'Datetime_10min' in df_for_5b.columns:
    # Reuse the existing bins; ensure tz-awareness and SF timezone
    dt_any = pd.to_datetime(df_for_5b['Datetime_10min'], errors='coerce')
    # If timezone-naive, localize to SF; if UTC or other tz-aware, convert to SF
    # NOTE(review): tz_localize uses pandas' default DST handling, so naive
    # wall times that are ambiguous or nonexistent in SF_TZ will raise here.
    if dt_any.dt.tz is None:
        dt_any = dt_any.dt.tz_localize(SF_TZ)
    else:
        dt_any = dt_any.dt.tz_convert(SF_TZ)
    df_for_5b['datetime_for_avg'] = dt_any
else:
    raise KeyError("Neither 'Aligned_PKG_UnixTimestamp' nor 'Datetime_10min' is present. "
                   "Cannot construct 'datetime_for_avg'.")

# Unparseable timestamps were coerced to NaT above; report how many there are.
print("datetime_for_avg ready. NaT rows:", df_for_5b['datetime_for_avg'].isna().sum())

# Make this the canonical filtered frame for any downstream use (e.g., optional plots or summaries)
master_df_step4_filtered_states = df_for_5b

# (Optional, fast) sanity: show the first 3 rows to confirm tz and columns
_preview_cols = [
    'Region', CHANNEL_DISPLAY_COL, FOOOF_FREQ_BAND_COL, CLINICAL_STATE_COL,
    'datetime_for_avg'
] + [c for c in ['Exponent_BestModel','Aligned_BK','Aligned_DK','Aligned_Tremor_Score'] if c in master_df_step4_filtered_states.columns]
print("\n5B-prep preview:")
print(master_df_step4_filtered_states[_preview_cols].head(3))


datetime_for_avg ready. NaT rows: 0

5B-prep preview:
  Region Channel_Display FreqRangeLabel Clinical_State_2min_Window  \
1     M1    Contact_10_8        LowFreq      Non-Dyskinetic Mobile   
2     M1    Contact_10_8        LowFreq                   Immobile   
4    STN     Contact_3_1        LowFreq                   Immobile   

           datetime_for_avg  Exponent_BestModel  Aligned_BK  Aligned_DK  \
1 2019-10-13 19:00:00-07:00            2.401817   19.270455    1.890909   
2 2019-05-20 23:40:00-07:00            4.797258   55.161667    0.000000   
4 2019-05-20 16:30:00-07:00            7.482851   45.010000    0.048333   

   Aligned_Tremor_Score  
1                   0.0  
2                   0.0  
4                   0.0  


In [9]:
# -*- coding: utf-8 -*-
# --- Cell 5B (Region-level, Extended): Overview Scatter Plots ---
# Adds Beta vs BK and Gamma vs DK plots (in addition to Aperiodic vs PKG).
# Plots STN vs M1 only, one figure per (metric pair, region, freq band).
# Scatter points = 10-min averages, regression line = all available points.

# Cell 5B: one figure per (metric pair, region, frequency band), with one panel
# per target clinical state. Each panel shows 10-minute averaged points, a
# regression line fitted on all rows of that state, and a Spearman annotation.
# The neural metric (x_col) is drawn on the vertical axis and the PKG score
# (y_col) on the horizontal axis.
if not ENABLE_PLOTS:
    print("ENABLE_PLOTS=False -> skipping Cell 5B entirely.")
elif 'master_df_step4_filtered_states' not in locals() or master_df_step4_filtered_states.empty:
    print("master_df_step4_filtered_states not available/empty. Skipping Cell 5B.")
elif 'datetime_for_avg' not in master_df_step4_filtered_states.columns:
    print("ERROR: 'datetime_for_avg' missing. Run 5B-prep first.")
else:
    df_plot_src = master_df_step4_filtered_states.copy()
    if EXCLUDE_MIXED and 'Region' in df_plot_src.columns:
        df_plot_src = df_plot_src[df_plot_src['Region'].isin(['STN', 'M1'])]

    plot_subdir = os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "Overview_Scatter_AllPairs_Region")
    os.makedirs(plot_subdir, exist_ok=True)

    # --- Build pairs to plot ---
    # 1. Aperiodic vs PKG (every combination)
    pairs_to_plot = [
        (ap_col, ap_name, pkg_col, pkg_name)
        for ap_col, ap_name in APERIODIC_METRICS_COLS.items()
        for pkg_col, pkg_name in PKG_METRICS_COLS.items()
    ]

    # 2. Explicit Beta vs BK
    if 'Beta_Peak_Power_at_DominantFreq' in df_plot_src.columns and 'Aligned_BK' in df_plot_src.columns:
        pairs_to_plot.append(('Beta_Peak_Power_at_DominantFreq', 'Beta Peak Power',
                              'Aligned_BK', 'PKG BK Score'))

    # 3. Explicit Gamma vs DK
    if 'Gamma_Peak_Power_at_DominantFreq' in df_plot_src.columns and 'Aligned_DK' in df_plot_src.columns:
        pairs_to_plot.append(('Gamma_Peak_Power_at_DominantFreq', 'Gamma Peak Power',
                              'Aligned_DK', 'PKG DK Score'))

    n_states = len(TARGET_CLINICAL_STATES_ORDERED)

    # --- Loop through all pairs ---
    for x_col, x_name, y_col, y_name in pairs_to_plot:
        if x_col not in df_plot_src.columns or y_col not in df_plot_src.columns:
            continue

        print(f"\nGenerating overview plots for: {x_name} vs. {y_name} (Region-level)")
        for region_label in ['STN', 'M1']:
            df_region = df_plot_src[df_plot_src['Region'] == region_label]
            if df_region.empty:
                continue

            for freq_label in ORDERED_FREQ_LABELS:
                df_rf = df_region[df_region[FOOOF_FREQ_BAND_COL] == freq_label].copy()
                if df_rf.empty:
                    continue

                # Prepare each state's rows and 10-min averages once; they are
                # needed both for the shared axis range and for drawing.
                state_rows = {}
                state_points = {}
                for state in TARGET_CLINICAL_STATES_ORDERED:
                    sub = df_rf[df_rf[CLINICAL_STATE_COL] == state].dropna(subset=[x_col, y_col, 'datetime_for_avg'])
                    state_rows[state] = sub
                    if not sub.empty and len(sub) >= MIN_SAMPLES_FOR_CORR:
                        # '10min' replaces the deprecated '10T' frequency alias.
                        state_points[state] = (sub.set_index('datetime_for_avg')
                                                  .groupby(pd.Grouper(freq='10min'))[[x_col, y_col]]
                                                  .mean()
                                                  .dropna())

                # Standardize the vertical axis across states. The vertical axis
                # shows x_col, so its range is taken from the averaged x_col values.
                all_y_vals = []
                for pts in state_points.values():
                    if not pts.empty:
                        all_y_vals.extend(pts[x_col].tolist())
                y_min, y_max = (np.nan, np.nan)
                if len(all_y_vals) > 0:
                    _mn, _mx = float(np.nanmin(all_y_vals)), float(np.nanmax(all_y_vals))
                    # 10% padding; the floor keeps a non-zero range for constant data.
                    pad = (max(_mx - _mn, 1e-6)) * 0.10
                    y_min, y_max = _mn - pad, _mx + pad

                # Figure with one subplot per state
                fig, axes = plt.subplots(1, n_states,
                                         figsize=(max(15, 4*n_states), 5.5),
                                         sharey=False)
                if n_states == 1:
                    axes = [axes]
                fig.suptitle(f"{x_name} vs. {y_name}\nRegion: {region_label} - Freq: {freq_label} - Patient: {patient_hemisphere_id}",
                             fontsize=plt.rcParams['figure.titlesize']*0.85, y=1.05)

                any_valid = False
                for i, state in enumerate(TARGET_CLINICAL_STATES_ORDERED):
                    ax = axes[i]
                    sub = state_rows[state]

                    if state in state_points:
                        any_valid = True
                        # 10-min averaged points
                        pts = state_points[state]
                        if not pts.empty:
                            sns.scatterplot(data=pts, x=y_col, y=x_col,
                                            color=NEW_CLINICAL_STATE_COLORS_FOR_PLOTTING.get(state, 'grey'),
                                            alpha=DOT_ALPHA+0.2, s=40, edgecolor='black', linewidths=0.5, ax=ax, legend=False)
                        # Regression on all rows of this state (not the averages)
                        sns.regplot(data=sub, x=y_col, y=x_col, scatter=False, ax=ax,
                                    line_kws={'color': 'black','linewidth':1.5,'alpha':0.6})
                        # Annotate rho/p
                        rho, p_val, N_val = calculate_spearman_with_n(sub, x_col, y_col)
                        annotate_correlation_on_plot(ax, rho, p_val, N_val, fontsize=8)
                    else:
                        # Check emptiness first: an empty frame would otherwise always
                        # be reported as "N < min_samples" and never as "No Data".
                        ax.text(0.5,0.5,"No Data" if sub.empty else "N < min_samples",
                                ha='center',va='center',transform=ax.transAxes,fontsize=9)

                    # Apply the common range to every panel, placeholders included:
                    # tick labels are hidden on all but the first panel, so all
                    # panels must share the range the first panel displays.
                    if not pd.isna(y_min) and not pd.isna(y_max):
                        ax.set_ylim(y_min, y_max)

                    ax.set_title(state, fontsize=plt.rcParams['axes.titlesize']*0.8)
                    if i==0: ax.set_ylabel(x_name)
                    else:
                        ax.set_ylabel(""); ax.set_yticklabels([])
                    # Only the middle panel carries the horizontal-axis label.
                    ax.set_xlabel(y_name if i==n_states//2 else "")
                    ax.tick_params(axis='x', labelsize=plt.rcParams['xtick.labelsize']*0.9)
                    ax.tick_params(axis='y', labelsize=plt.rcParams['ytick.labelsize']*0.9)

                # Save only figures with at least one drawable panel; always release the figure.
                if any_valid:
                    plt.tight_layout(rect=[0,0.03,1,0.93])
                    fname = f"Overview_{get_safe_filename_step4(x_name)}_vs_{get_safe_filename_step4(y_name)}_{region_label}_{freq_label}.png"
                    plt.savefig(os.path.join(plot_subdir, fname))
                plt.close(fig)

# Printed unconditionally, including when the cell was skipped above.
print("\n--- Cell 5B (Region-level, Extended with Beta vs BK & Gamma vs DK) Complete ---")



Generating overview plots for: Aperiodic Exponent vs. PKG BK Score (Region-level)

Generating overview plots for: Aperiodic Exponent vs. PKG DK Score (Region-level)

Generating overview plots for: Aperiodic Exponent vs. PKG Tremor Score (Region-level)

Generating overview plots for: Aperiodic Offset vs. PKG BK Score (Region-level)

Generating overview plots for: Aperiodic Offset vs. PKG DK Score (Region-level)

Generating overview plots for: Aperiodic Offset vs. PKG Tremor Score (Region-level)

Generating overview plots for: Beta Peak Power vs. PKG BK Score (Region-level)

Generating overview plots for: Gamma Peak Power vs. PKG DK Score (Region-level)

--- Cell 5B (Region-level, Extended with Beta vs BK & Gamma vs DK) Complete ---


In [13]:
# -*- coding: utf-8 -*-
# --- Cell 5 (Region-level, FDR, minimal plots): Bivariate & Partial Correlations ---
# Runs tests per Region (STN/M1) and FreqRangeLabel. No contact-level logic.

# Cell 5 setup: announce the cell, make sure ENABLE_PLOTS exists, define the
# tremor-exclusion settings and plot colours, and pick the source DataFrame.
print("\n--- Cell 5 (Region-level, FDR): Starting Correlation Analyses ---")

from statsmodels.stats.multitest import fdrcorrection

# ============ Config ============
# Referencing the bare name raises NameError when no earlier cell defined it;
# in that case plotting defaults to off.
try:
    ENABLE_PLOTS
except NameError:
    ENABLE_PLOTS = False  # default: off unless you turned it on earlier

# Exclude Tremor everywhere in Cell 5
EXCLUDE_TREMOR = True
# Column names / display names treated as tremor-related.
TREMOR_COL_KEYS = ['Aligned_Tremor', 'Aligned_Tremor_Score', 'PKG Tremor Score']

font_scale_factor = 0.6
# Copy so the override below does not alter BASE_COLOR_PALETTE itself.
MODIFIED_COLOR_PALETTE = BASE_COLOR_PALETTE.copy()
if 'Aligned_Tremor_Score' in MODIFIED_COLOR_PALETTE:
    MODIFIED_COLOR_PALETTE['Aligned_Tremor_Score'] = 'green'
SIGNIFICANT_P_VAL_BG_COLOR_STEP4 = 'khaki'
DEFAULT_P_VAL_BG_COLOR_STEP4 = 'ivory'

# ---------- Choose Source DF ----------
# NOTE(review): in the visible part of this notebook, df_analysis is assigned in
# Cell 5A, where it is already filtered to the target clinical states (and to
# STN/M1 when EXCLUDE_MIXED is set). Which frame is used therefore depends on
# whether Cell 5A ran earlier in the session — confirm that is intended.
if 'df_analysis' in globals():
    src = df_analysis.copy()   # reuse the frame left in the session by an earlier cell
else:
    src = master_df_step4.copy()

if src is None or src.empty:
    print("No data available for Cell 5. Skipping.")
else:
    # Keep only STN/M1 rows if Region exists
    if 'Region' in src.columns:
        src = src[src['Region'].isin(['STN', 'M1'])].copy()

    # Drop tremor columns from the working DF (so we never accidentally use them)
    if EXCLUDE_TREMOR:
        src = src.drop(columns=[c for c in TREMOR_COL_KEYS if c in src.columns], errors='ignore')

    # Ensure we have frequency band column
    if FOOOF_FREQ_BAND_COL not in src.columns:
        print(f"Missing '{FOOOF_FREQ_BAND_COL}'. Skipping Cell 5.")
    else:
        # --- Build per-row datetime key for optional 10-min averaging in plots ---
        if 'Aligned_PKG_UnixTimestamp' in src.columns:
            src['datetime_for_avg_c5'] = (
                pd.to_datetime(src['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce')
                  .dt.tz_convert(SF_TZ)
            ).dt.floor('10T')
        elif 'Datetime_10min' in src.columns:
            dt_tmp = pd.to_datetime(src['Datetime_10min'], errors='coerce', utc=True)
            src['datetime_for_avg_c5'] = dt_tmp.dt.tz_convert(SF_TZ)
        else:
            src['datetime_for_avg_c5'] = pd.NaT  # plotting will fall back to granular

        # ---- Containers ----
        all_bivar_ap_pkg = []
        all_partial_ap_pkg = []
        all_bivar_ap_osc = []
        all_bivar_osc_pkg = []  # NEW: Beta/Gamma vs BK/DK

        all_pvals = []
        pmap = []  # (family, idx)

        # Helper: append p-val tracking
        def _push_p(val, family, idx):
            if not pd.isna(val):
                all_pvals.append(val)
                pmap.append((family, idx))

        # Build PKG metric mapping, optionally excluding tremor
        pkg_map = {k: v for k, v in PKG_METRICS_COLS.items()}
        if EXCLUDE_TREMOR:
            pkg_map = {k: v for k, v in pkg_map.items() if 'Tremor' not in v}

        # Identify osc columns for convenience
        beta_key = None
        gamma_key = None
        for k, v in OSCILLATORY_METRICS_COLS.items():
            if 'Beta' in v and k in src.columns:  beta_key = k
            if 'Gamma' in v and k in src.columns: gamma_key = k

        # ===== Iterate Region × Freq =====
        regions = ['STN', 'M1'] if 'Region' in src.columns else [None]
        for region in regions:
            df_region = src if region is None else src[src['Region'] == region]
            if df_region.empty:
                continue

            for freq in ORDERED_FREQ_LABELS:
                df_rf = df_region[df_region[FOOOF_FREQ_BAND_COL] == freq].copy()
                if df_rf.empty:
                    continue

                # --- Family 1: Bivariate (Aperiodic vs PKG) ---
                for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                    if ap_col not in df_rf.columns: 
                        continue
                    for pkg_col, pkg_name in pkg_map.items():
                        if pkg_col not in df_rf.columns: 
                            continue
                        df_pair = df_rf[[ap_col, pkg_col]].dropna()
                        rho, pval, N = calculate_spearman_with_n(df_pair, ap_col, pkg_col)
                        all_bivar_ap_pkg.append({
                            'Region': region if region is not None else 'All',
                            'FreqBand': freq,
                            'AperiodicMetric': ap_name,
                            'PKGMetric': pkg_name,
                            'SpearmanRho': rho,
                            'PValue': pval,
                            'N': N,
                            'TestType': 'Bivariate_AP_PKG'
                        })
                        _push_p(pval, 'bivar_ap_pkg', len(all_bivar_ap_pkg) - 1)

                # --- Family 2: Partial (Aperiodic vs PKG | Beta, Gamma) ---
                covars = [c for c in OSCILLATORY_METRICS_COLS.keys() if c in df_rf.columns]
                if len(covars) == 2:
                    for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                        if ap_col not in df_rf.columns: 
                            continue
                        for pkg_col, pkg_name in pkg_map.items():
                            if pkg_col not in df_rf.columns: 
                                continue
                            cols = [ap_col, pkg_col] + covars
                            df_sub = df_rf[cols].dropna()
                            if len(df_sub) < MIN_SAMPLES_FOR_CORR or not all(df_sub[c].nunique() > 1 for c in [ap_col, pkg_col]):
                                prho, ppval, Np = (np.nan, np.nan, len(df_sub))
                            else:
                                prho, ppval, Np = calculate_partial_spearman(df_sub, ap_col, pkg_col, covars)
                            all_partial_ap_pkg.append({
                                'Region': region if region is not None else 'All',
                                'FreqBand': freq,
                                'AperiodicMetric': ap_name,
                                'PKGMetric': pkg_name,
                                'PartialSpearmanRho_vs_BetaGamma': prho,
                                'PartialPValue': ppval,
                                'N_Partial': Np,
                                'TestType': 'Partial_AP_PKG'
                            })
                            _push_p(ppval, 'partial_ap_pkg', len(all_partial_ap_pkg) - 1)

                # --- Family 3: Bivariate (Aperiodic vs Oscillatory) ---
                # One Spearman test per (aperiodic column, oscillatory column) pair that
                # exists in the current Region x FreqBand slice `df_rf`.
                for ap_col, ap_name in APERIODIC_METRICS_COLS.items():
                    if ap_col not in df_rf.columns:
                        continue
                    for osc_col, osc_name in OSCILLATORY_METRICS_COLS.items():
                        if osc_col not in df_rf.columns:
                            continue
                        # Pairwise-complete rows only.
                        df_pair = df_rf[[ap_col, osc_col]].dropna()
                        rho, pval, N = calculate_spearman_with_n(df_pair, ap_col, osc_col)
                        all_bivar_ap_osc.append({
                            'Region': region if region is not None else 'All',
                            'FreqBand': freq,
                            'AperiodicMetric': ap_name,
                            'OscillatoryMetric': osc_name,
                            'SpearmanRho': rho,
                            'PValue': pval,
                            'N': N,
                            'TestType': 'Bivariate_AP_Osc'
                        })
                        # Register the p-value under its family tag together with the index
                        # of the record just appended, so the FDR stage can write the
                        # corrected value back to this record.
                        # NOTE(review): `_push_p` is defined outside this view; presumably it
                        # skips NaN p-values -- confirm.
                        _push_p(pval, 'bivar_ap_osc', len(all_bivar_ap_osc) - 1)

                # --- Family 4 (NEW): Bivariate (Oscillatory vs PKG) → Beta/Gamma vs BK/DK ---
                # Only the two columns named by `beta_key` / `gamma_key` are tested here
                # (either may be None, in which case it is skipped).
                for pkg_col, pkg_name in pkg_map.items():
                    if pkg_col not in df_rf.columns:
                        continue
                    for osc_col, osc_name in [(beta_key, 'Beta_Peak_Power_at_DominantFreq'),
                                              (gamma_key, 'Gamma_Peak_Power_at_DominantFreq')]:
                        if osc_col is None or osc_col not in df_rf.columns:
                            continue
                        df_pair = df_rf[[osc_col, pkg_col]].dropna()
                        rho, pval, N = calculate_spearman_with_n(df_pair, osc_col, pkg_col)
                        all_bivar_osc_pkg.append({
                            'Region': region if region is not None else 'All',
                            'FreqBand': freq,
                            # Display name from OSCILLATORY_METRICS_COLS when the column is
                            # mapped there; otherwise the literal fallback name above is stored.
                            'OscillatoryMetric': OSCILLATORY_METRICS_COLS.get(osc_col, osc_name),
                            'PKGMetric': pkg_name,
                            'SpearmanRho': rho,
                            'PValue': pval,
                            'N': N,
                            'TestType': 'Bivariate_Osc_PKG'
                        })
                        _push_p(pval, 'bivar_osc_pkg', len(all_bivar_osc_pkg) - 1)

        # ===== FDR across all tests =====
        # All families are pooled into one correction; `pmap[i]` holds the
        # (family tag, record index) that the i-th pooled p-value belongs to.
        print(f"\nCollected {len(all_pvals)} p-values for FDR correction.")
        if not all_pvals:
            print("No valid p-values to correct.")
        else:
            rejected, pvals_corr = fdrcorrection(all_pvals, alpha=0.05, method='indep', is_sorted=False)
            # family tag -> (result records, name of the corrected-p column)
            fdr_targets = {
                'bivar_ap_pkg': (all_bivar_ap_pkg, 'PValue_FDR'),
                'partial_ap_pkg': (all_partial_ap_pkg, 'PartialPValue_FDR'),
                'bivar_ap_osc': (all_bivar_ap_osc, 'PValue_FDR'),
                'bivar_osc_pkg': (all_bivar_osc_pkg, 'PValue_FDR'),
            }
            for pos, (fam, idx) in enumerate(pmap):
                target = fdr_targets.get(fam)
                if target is None:
                    # Unknown family tags are ignored.
                    continue
                records, fdr_col = target
                records[idx][fdr_col] = pvals_corr[pos]
                records[idx]['Significant_FDR'] = bool(rejected[pos])
            print("FDR correction applied.")

        # ===== Save CSVs =====
        # Each non-empty result list becomes a DataFrame written to `out_dir`;
        # an empty list yields an empty DataFrame and no file.
        out_dir = analysis_session_plot_folder_step4
        _result_frames = []
        for _rows, _csv_name in (
            (all_bivar_ap_pkg, f"{patient_hemisphere_id}_Bivariate_AP_vs_PKG_byRegion_FDR.csv"),
            (all_partial_ap_pkg, f"{patient_hemisphere_id}_Partial_AP_vs_PKG_byRegion_FDR.csv"),
            (all_bivar_ap_osc, f"{patient_hemisphere_id}_Bivariate_AP_vs_Osc_byRegion_FDR.csv"),
            (all_bivar_osc_pkg, f"{patient_hemisphere_id}_Bivariate_Osc_vs_PKG_byRegion_FDR.csv"),
        ):
            if not _rows:
                _result_frames.append(pd.DataFrame())
                continue
            _frame = pd.DataFrame(_rows)
            _csv_path = os.path.join(out_dir, _csv_name)
            _frame.to_csv(_csv_path, index=False)
            print("Saved:", _csv_path)
            _result_frames.append(_frame)

        # Keep the four result frames available under their usual names.
        df_bivar_ap_pkg, df_partial_ap_pkg, df_bivar_ap_osc, df_bivar_osc_pkg = _result_frames

        # ===== Minimal plotting (only if enabled) =====
        # One scatter + regression figure is saved for every FDR-significant record
        # of the three bivariate families. The three families share a single helper.
        if ENABLE_PLOTS:

            def _column_for_display_name(display_name, col_to_display):
                """Map a metric's display name back to its dataframe column.

                Returns `display_name` unchanged when no entry of `col_to_display`
                carries it. This covers records stored under a fallback name
                rather than a mapped display name, which would otherwise make a
                `[...][0]` reverse lookup raise IndexError.
                """
                for col, shown in col_to_display.items():
                    if shown == display_name:
                        return col
                return display_name

            def _plot_fdr_significant_pairs(records, plot_dir, y_spec, x_spec):
                """Save one figure per FDR-significant record in `records`.

                records  : list of result dicts; each needs 'Region', 'FreqBand',
                           'N', 'SpearmanRho' and the two name keys given in the
                           specs ('Significant_FDR' / 'PValue_FDR' are optional).
                plot_dir : output folder, created if it does not exist.
                y_spec, x_spec : tuples of
                           (record key holding the metric's display name,
                            column -> display-name mapping,
                            text removed from the display name for the axis label).
                """
                os.makedirs(plot_dir, exist_ok=True)
                y_key, y_map, y_strip = y_spec
                x_key, x_map, x_strip = x_spec

                for rec in records:
                    if not (rec.get('Significant_FDR', False) and rec['N'] >= MIN_SAMPLES_FOR_CORR):
                        continue

                    region_label = rec['Region']
                    freq_label = rec['FreqBand']
                    y_name, x_name = rec[y_key], rec[x_key]
                    y_col = _column_for_display_name(y_name, y_map)
                    x_col = _column_for_display_name(x_name, x_map)

                    # Boolean filtering already returns new frames, so `src` is
                    # never modified and no defensive copy is needed.
                    dfp = src
                    if 'Region' in dfp.columns:
                        dfp = dfp[dfp['Region'] == region_label]
                    if y_col not in dfp.columns or x_col not in dfp.columns:
                        print(f"Skipping plot: column for '{y_name}' or '{x_name}' not found.")
                        continue
                    dfp = dfp[dfp[FOOOF_FREQ_BAND_COL] == freq_label].dropna(subset=[y_col, x_col])
                    if dfp.empty:
                        continue

                    fig, ax = plt.subplots(figsize=(6, 6))
                    try:
                        ax.grid(False)

                        # Scatter points are 10-minute bin means when a usable
                        # timestamp column exists; the regression line below is
                        # always fitted on the unbinned rows.
                        pts = dfp
                        if ('datetime_for_avg_c5' in dfp.columns
                                and not dfp['datetime_for_avg_c5'].isnull().all()):
                            try:
                                # '10min' is the non-deprecated spelling of '10T'.
                                pts = (dfp.set_index('datetime_for_avg_c5')
                                          .groupby(pd.Grouper(freq='10min'))[[y_col, x_col]]
                                          .mean().dropna())
                            except Exception:
                                # Deliberate best-effort: fall back to raw points.
                                pts = dfp

                        if not pts.empty:
                            sns.scatterplot(data=pts, x=x_col, y=y_col,
                                            alpha=DOT_ALPHA+0.1, s=40, edgecolor='k', linewidths=0.5, ax=ax)
                        sns.regplot(data=dfp, x=x_col, y=y_col, scatter=False, ax=ax,
                                    line_kws={'color': 'black', 'linewidth': REG_LINE_THICKNESS_STEP4, 'alpha': 0.7})

                        # The annotated p-value is the FDR-corrected one.
                        annotate_correlation_on_plot(ax, rec['SpearmanRho'], rec.get('PValue_FDR', np.nan), rec['N'],
                                                     test_type="Spearman ρ (FDR)", fontsize=9*font_scale_factor)
                        ax.set_xlabel(x_name.replace(x_strip, ''))
                        ax.set_ylabel(y_name.replace(y_strip, ''))
                        ax.tick_params(axis='both', which='major',
                                       labelsize=plt.rcParams['xtick.labelsize']*font_scale_factor)

                        fname = (f"Bivar_{get_safe_filename_step4(y_name)}_vs_{get_safe_filename_step4(x_name)}"
                                 f"_{region_label}_{freq_label}_FDRsig.png")
                        fig.tight_layout()
                        fig.savefig(os.path.join(plot_dir, fname))
                    finally:
                        # Always release the figure, even if plotting raised.
                        plt.close(fig)

            _ap_spec = ('AperiodicMetric', APERIODIC_METRICS_COLS, 'Aperiodic ')
            _osc_spec = ('OscillatoryMetric', OSCILLATORY_METRICS_COLS, ' at DominantFreq')
            _pkg_spec = ('PKGMetric', pkg_map, 'PKG ')

            # AP vs PKG (FDR-significant only)
            plot_dir_ap_pkg = os.path.join(out_dir, "Region_Bivar_AP_vs_PKG_FDRsig")
            _plot_fdr_significant_pairs(all_bivar_ap_pkg, plot_dir_ap_pkg, y_spec=_ap_spec, x_spec=_pkg_spec)

            # Osc (Beta/Gamma) vs PKG (FDR-significant only)
            plot_dir_osc_pkg = os.path.join(out_dir, "Region_Bivar_Osc_vs_PKG_FDRsig")
            _plot_fdr_significant_pairs(all_bivar_osc_pkg, plot_dir_osc_pkg, y_spec=_osc_spec, x_spec=_pkg_spec)

            # AP vs Osc (FDR-significant only)
            plot_dir_ap_osc = os.path.join(out_dir, "Region_Bivar_AP_vs_Osc_FDRsig")
            _plot_fdr_significant_pairs(all_bivar_ap_osc, plot_dir_ap_osc, y_spec=_ap_spec, x_spec=_osc_spec)

print("\n--- Cell 5 (Region-level, FDR): Correlation Analyses Complete ---")



--- Cell 5 (Region-level, FDR): Starting Correlation Analyses ---

Collected 96 p-values for FDR correction.
FDR correction applied.
Saved: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/COHORT_RCS02_05_06_Bivariate_AP_vs_PKG_byRegion_FDR.csv
Saved: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/COHORT_RCS02_05_06_Partial_AP_vs_PKG_byRegion_FDR.csv
Saved: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/COHORT_RCS02_05_06_Bivariate_AP_vs_Osc_byRegion_FDR.csv
Saved: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/COHORT_RCS02_05_06_Bivariate_Osc_vs_PKG_byRegion_FDR.csv

--- Cell 5 (Re

In [14]:
# -*- coding: utf-8 -*-
# --- Cell 6 (Region-level, Streamlined V4 with LRT): Multiple Linear Regression ---
# Compares nested models with Likelihood Ratio Tests for each Region (STN/M1) × FreqRangeLabel.

import pandas as pd
import numpy as np
import os
import statsmodels.formula.api as smf
from statsmodels.stats.anova import anova_lm
import matplotlib.pyplot as plt
import seaborn as sns

print("\n--- Cell 6 (Region-level, Streamlined V4 with LRT): Starting Analyses ---\n")

# p-value threshold for interpretation notes (does not affect model fitting)
P_VALUE_THRESHOLD = 0.05
print(f"p-value threshold is {P_VALUE_THRESHOLD}.")

# <<< --- USER TOGGLES --- >>>
ANALYZE_ALL_FREQ_BANDS = False
TARGET_FREQ_BAND_IF_NOT_ALL = "WideFreq"
MLR_ENABLE_PLOTS = False  # keep plots off by default
# <<< -------------------- >>>

# ---------- Source data (Region-level if available) ----------
# Prefer `df_analysis` when it is defined and non-empty; otherwise fall back to
# `master_df_step4` (or None when that is not available either).
if 'df_analysis' in globals() and df_analysis is not None and not df_analysis.empty:
    src6 = df_analysis.copy()
else:
    src6 = master_df_step4.copy() if ('master_df_step4' in globals() and master_df_step4 is not None) else None

if src6 is None or src6.empty:
    print("No data available for Cell 6. Skipping.")
else:
    # Ensure Region exists; if not, derive from Channel_Display (same rule as Cell 3)
    # Both contact numbers <= 3 -> 'STN', both >= 4 -> 'M1', anything else -> 'Mixed'.
    # Rows whose channel label does not match the pattern yield NaN contacts, fail
    # both comparisons, and therefore also end up as 'Mixed'.
    if 'Region' not in src6.columns and CHANNEL_DISPLAY_COL in src6.columns:
        _m = src6[CHANNEL_DISPLAY_COL].astype(str).str.extract(r'Contact_(\d+)_(\d+)', expand=True)
        _a = pd.to_numeric(_m[0], errors='coerce')
        _b = pd.to_numeric(_m[1], errors='coerce')
        src6['Region'] = np.where((_a <= 3) & (_b <= 3), 'STN',
                           np.where((_a >= 4) & (_b >= 4), 'M1', 'Mixed'))
    # Keep only STN and M1 rows; every other Region value (e.g. 'Mixed') is dropped.
    if 'Region' in src6.columns:
        src6 = src6[src6['Region'].isin(['STN', 'M1'])].copy()

    # Frequency bands to run
    freq_bands_to_process = ORDERED_FREQ_LABELS if ANALYZE_ALL_FREQ_BANDS else [TARGET_FREQ_BAND_IF_NOT_ALL]
    if not ANALYZE_ALL_FREQ_BANDS:
        print(f"--- Analyzing ONLY Freq Band: {TARGET_FREQ_BAND_IF_NOT_ALL} ---")

    # Column names (ensure presence)
    exponent_col_name = 'Exponent_BestModel'
    offset_col_name   = 'Offset_BestModel'
    beta_col_name     = 'Beta_Peak_Power_at_DominantFreq'
    gamma_col_name    = 'Gamma_Peak_Power_at_DominantFreq'

    # Subsets of the columns above that actually exist in the source frame.
    available_aperiodic_cols   = [c for c in [exponent_col_name, offset_col_name] if c in src6.columns]
    available_oscillatory_cols = [c for c in [beta_col_name, gamma_col_name] if c in src6.columns]

    # Every model in this cell contains the oscillatory predictors, so without
    # at least one of them there is nothing to fit.
    if not available_oscillatory_cols:
        print("No oscillatory predictors present (Beta/Gamma). Skipping Cell 6.")
    else:
        # Optional: small diagnostic plots folder (kept off by default)
        # The parent folder is created unconditionally because the result CSVs
        # of this cell are written into it as well.
        plot_subdir_mlr = os.path.join(analysis_session_plot_folder_step4, "MLR_byRegion_V4")
        os.makedirs(plot_subdir_mlr, exist_ok=True)
        if MLR_ENABLE_PLOTS:
            plot_subdir_ap_corr = os.path.join(plot_subdir_mlr, "Aperiodic_Intercorrelations")
            os.makedirs(plot_subdir_ap_corr, exist_ok=True)
        else:
            plot_subdir_ap_corr = None

        # Accumulators: one dict per model summary / coefficient row, and one
        # dict per nested-model comparison.
        mlr_rows = []
        lrt_rows = []
        # Helper: fit OLS safely
        def _fit_ols(df_in, dv_col, predictors):
            """Return statsmodels fit or None if not enough data / issues."""
            if any(p not in df_in.columns for p in predictors):
                return None
            use_cols = [dv_col] + predictors
            dfm = df_in[use_cols].dropna(how='any').copy()
            # Require at least (#predictors + 10) samples (rule of thumb)
            if len(dfm) < (len(predictors) + 10):
                return None
            # No constant predictors
            for p in predictors:
                if dfm[p].nunique() < 2:
                    return None
            try:
                formula = f"{dv_col} ~ {' + '.join(predictors)}"
                fit = smf.ols(formula=formula, data=dfm).fit()
                return fit
            except Exception:
                return None

        # Iterate Region × Freq
        # For every Region x FreqBand slice and every PKG metric (the DV):
        #   Tier 1: DV ~ Beta + Gamma                 (reduced model)
        #   Tier 2: DV ~ Exponent + Beta + Gamma      (full model 1)
        #   Tier 3: DV ~ Offset   + Beta + Gamma      (full model 2)
        # Each full model is compared with the reduced model via `anova_lm`
        # (a nested-model F-test; the outputs keep the historical "LRT" naming).
        #
        # (aperiodic column, comparison label, coefficient-table tier label)
        nested_comparisons = [
            (exponent_col_name, 'Exponent + Osc vs. Osc Only', 'Tier 2: Exponent + Osc'),
            (offset_col_name, 'Offset + Osc vs. Osc Only', 'Tier 3: Offset + Osc'),
        ]
        # (x column, y column, title prefix, file-name prefix) of the optional plots
        intercorr_plots = [
            (exponent_col_name, offset_col_name, "Aperiodic Inter-corr", "AperiodicCorr_Exponent_vs_Offset"),
            (beta_col_name, gamma_col_name, "Oscillatory Inter-corr", "OscillatoryCorr_Beta_vs_Gamma"),
        ]

        for region in (['STN', 'M1'] if 'Region' in src6.columns else ['All']):
            df_region = src6 if region == 'All' else src6[src6['Region'] == region]
            if df_region.empty:
                continue

            for freq_label in freq_bands_to_process:
                if FOOOF_FREQ_BAND_COL not in df_region.columns:
                    print(f"Missing '{FOOOF_FREQ_BAND_COL}'. Skipping.")
                    continue
                df_rf = df_region[df_region[FOOOF_FREQ_BAND_COL] == freq_label].copy()
                if df_rf.empty:
                    continue

                # Optional: quick inter-corr plots (Aperiodic and Oscillatory); gated
                if MLR_ENABLE_PLOTS and plot_subdir_ap_corr:
                    for x_col, y_col, title_prefix, file_prefix in intercorr_plots:
                        if x_col not in df_rf.columns or y_col not in df_rf.columns:
                            continue
                        df_xy = df_rf[[x_col, y_col]].dropna()
                        if len(df_xy) < MIN_SAMPLES_FOR_CORR:
                            continue
                        fig, ax = plt.subplots(figsize=(6, 6))
                        try:
                            ax.grid(False)
                            sns.scatterplot(data=df_xy, x=x_col, y=y_col,
                                            alpha=DOT_ALPHA_STEP4, s=40, edgecolor='k', linewidths=0.5, ax=ax)
                            sns.regplot(data=df_xy, x=x_col, y=y_col, scatter=False, ax=ax,
                                        line_kws={'color':'black','linewidth':REG_LINE_THICKNESS_STEP4,'alpha':0.7})
                            ax.set_title(f"{title_prefix} — {region} [{freq_label}]")
                            out_png = os.path.join(plot_subdir_ap_corr,
                                                   f"{file_prefix}_{region}_{freq_label}.png")
                            fig.tight_layout()
                            fig.savefig(out_png)
                        finally:
                            # Release the figure even if plotting raised.
                            plt.close(fig)

                # ---------- Fit models per DV ----------
                for pkg_col, pkg_name in PKG_METRICS_COLS.items():
                    if pkg_col not in df_rf.columns:
                        continue

                    # Reduced model (Oscillatory only), on every row that is
                    # complete for the DV and the oscillatory predictors.
                    osc_predictors = [c for c in [beta_col_name, gamma_col_name] if c in df_rf.columns]
                    reduced_fit = _fit_ols(df_rf, pkg_col, osc_predictors)
                    if reduced_fit is None:
                        # Without a reduced model there is nothing to compare against.
                        continue

                    mlr_rows.append({
                        'Region': region, 'FreqBand': freq_label, 'PKG_Symptom_DV': pkg_name,
                        'Model_Tier': 'Tier 1: Osc Only',
                        'Formula': f"{pkg_col} ~ {' + '.join(osc_predictors)}",
                        'N_model': int(reduced_fit.nobs),
                        'Adj_R2': float(reduced_fit.rsquared_adj),
                        'AIC': float(reduced_fit.aic),
                        'BIC': float(reduced_fit.bic)
                    })

                    for ap_col, comparison_label, tier_label in nested_comparisons:
                        if ap_col not in df_rf.columns:
                            continue

                        full_predictors = [ap_col] + osc_predictors
                        # Rows complete for the FULL model. Both models of a nested
                        # comparison must be fitted on exactly these rows; otherwise
                        # the F-test and the AIC/BIC/adj-R2 differences compare
                        # models estimated on different data.
                        df_cmp = df_rf.dropna(subset=[pkg_col] + full_predictors)
                        full_fit = _fit_ols(df_cmp, pkg_col, full_predictors)
                        if full_fit is None:
                            continue

                        if int(full_fit.nobs) == int(reduced_fit.nobs):
                            # The full model's rows are a subset of the reduced
                            # model's rows, so equal counts mean identical rows.
                            reduced_cmp = reduced_fit
                        else:
                            reduced_cmp = _fit_ols(df_cmp, pkg_col, osc_predictors)
                            if reduced_cmp is None:
                                continue

                        try:
                            lrt_tbl = anova_lm(reduced_cmp, full_fit)  # nested comparison
                            f_stat = float(lrt_tbl.iloc[1]['F'])
                            p_val  = float(lrt_tbl.iloc[1]['Pr(>F)'])
                        except Exception as err:
                            # Best-effort: keep the row with NaN statistics, but say why.
                            print(f"  Nested F-test failed for {region} [{freq_label}] "
                                  f"{pkg_name} ({comparison_label}): {err}")
                            f_stat, p_val = (np.nan, np.nan)

                        lrt_rows.append({
                            'Region': region, 'FreqBand': freq_label, 'PKG_Symptom_DV': pkg_name,
                            'Comparison': comparison_label,
                            'F_statistic': f_stat, 'P_value': p_val,
                            'Delta_AdjR2': float(full_fit.rsquared_adj - reduced_cmp.rsquared_adj),
                            'AdjR2_Reduced': float(reduced_cmp.rsquared_adj),
                            'AdjR2_Full': float(full_fit.rsquared_adj),
                            'AIC_Reduced': float(reduced_cmp.aic), 'AIC_Full': float(full_fit.aic),
                            'BIC_Reduced': float(reduced_cmp.bic), 'BIC_Full': float(full_fit.bic),
                            'N_reduced': int(reduced_cmp.nobs), 'N_full': int(full_fit.nobs)
                        })

                        # Store full model coefficients too (intercept omitted).
                        conf_int = full_fit.conf_int()  # computed once per model
                        formula_full = f"{pkg_col} ~ {' + '.join(full_predictors)}"
                        for term in full_fit.params.index:
                            if term == 'Intercept':
                                continue
                            mlr_rows.append({
                                'Region': region, 'FreqBand': freq_label, 'PKG_Symptom_DV': pkg_name,
                                'Model_Tier': tier_label,
                                'Formula': formula_full,
                                'Predictor_Term': term,
                                'Predictor_Name_Display': APERIODIC_METRICS_COLS.get(term, OSCILLATORY_METRICS_COLS.get(term, term)),
                                'Coefficient': float(full_fit.params.get(term, np.nan)),
                                'StdErr': float(full_fit.bse.get(term, np.nan)),
                                'PValue': float(full_fit.pvalues.get(term, np.nan)),
                                'Conf_Int_Lower': float(conf_int.loc[term, 0]) if term in conf_int.index else np.nan,
                                'Conf_Int_Upper': float(conf_int.loc[term, 1]) if term in conf_int.index else np.nan,
                                'N_model': int(full_fit.nobs),
                                'Adj_R2': float(full_fit.rsquared_adj),
                                'AIC': float(full_fit.aic),
                                'BIC': float(full_fit.bic)
                            })

        # ------ Save outputs ------
        # Nested-comparison table first, then the model summary / coefficient
        # table; each is written to `plot_subdir_mlr` and previewed with head().
        if not lrt_rows:
            print("\nNo Likelihood Ratio Tests were performed (insufficient data or predictors missing).")
        else:
            df_lrt = pd.DataFrame(lrt_rows)
            csv_lrt = f"{patient_hemisphere_id}_MLR_LRT_Results_byRegion_Step6.csv"
            lrt_path = os.path.join(plot_subdir_mlr, csv_lrt)
            df_lrt.to_csv(lrt_path, index=False)
            print(f"\nSaved LRT results: {csv_lrt}")
            print(df_lrt.head())

        if not mlr_rows:
            print("No MLR models were successfully fitted or nothing to save for coefficients.")
        else:
            df_mlr = pd.DataFrame(mlr_rows)
            csv_mlr = f"{patient_hemisphere_id}_MLR_Coefficients_byRegion_Step6.csv"
            mlr_path = os.path.join(plot_subdir_mlr, csv_mlr)
            df_mlr.to_csv(mlr_path, index=False)
            print(f"Saved model summaries/coefficients: {csv_mlr}")
            # Preview only the columns that are present in the table.
            cols_show = ['Region','FreqBand','PKG_Symptom_DV','Model_Tier','Predictor_Term','Coefficient','PValue','Adj_R2','N_model']
            preview_cols = [c for c in cols_show if c in df_mlr.columns]
            print(df_mlr[preview_cols].head())

print("\n--- Cell 6 (Region-level, Streamlined V4 with LRT): Analyses Complete ---")



--- Cell 6 (Region-level, Streamlined V4 with LRT): Starting Analyses ---

p-value threshold is 0.05.
--- Analyzing ONLY Freq Band: WideFreq ---

Saved LRT results: COHORT_RCS02_05_06_MLR_LRT_Results_byRegion_Step6.csv
  Region  FreqBand    PKG_Symptom_DV                   Comparison  \
0    STN  WideFreq      PKG BK Score  Exponent + Osc vs. Osc Only   
1    STN  WideFreq      PKG BK Score    Offset + Osc vs. Osc Only   
2    STN  WideFreq      PKG DK Score  Exponent + Osc vs. Osc Only   
3    STN  WideFreq      PKG DK Score    Offset + Osc vs. Osc Only   
4    STN  WideFreq  PKG Tremor Score  Exponent + Osc vs. Osc Only   

   F_statistic       P_value  Delta_AdjR2  AdjR2_Reduced  AdjR2_Full  \
0   235.373151  1.019709e-51     0.052536       0.014308    0.066844   
1   234.864531  1.298544e-51     0.052428       0.014308    0.066736   
2     0.944866  3.310861e-01    -0.000013       0.027536    0.027523   
3     0.799906  3.711733e-01    -0.000047       0.027536    0.027489   
4    

In [15]:
# -*- coding: utf-8 -*-
# --- Cell 6A (Revised, Region-level): State-Specific MLR with LRT + FDR ---
# Runs nested OLS models per Clinical State × Region (STN/M1) × FreqRangeLabel.
# Reduced:  PKG ~ Beta + Gamma
# Full #1:  PKG ~ Exponent + Beta + Gamma
# Full #2:  PKG ~ Offset  + Beta + Gamma
# Performs LRTs (anova_lm) and applies BH-FDR across ALL state-specific LRT p-values.

import pandas as pd
import numpy as np
import os
import statsmodels.formula.api as smf
from statsmodels.stats.anova import anova_lm
from statsmodels.stats.multitest import fdrcorrection

print("\n--- Cell 6A (Region-level): State-Specific MLR & LRT with FDR ---\n")

# <<< --- USER TOGGLES --- >>>
# When False, only TARGET_FREQ_BAND_IF_NOT_ALL_STATE_SPECIFIC is analyzed.
ANALYZE_ALL_FREQ_BANDS_STATE_SPECIFIC = False
TARGET_FREQ_BAND_IF_NOT_ALL_STATE_SPECIFIC = "WideFreq"
# Clinical states with fewer rows than this are skipped.
MIN_STATE_ROWS = 20
# <<< ---------------------- >>>

# ---- Source dataframe ----
# Use a copy of df_analysis when it exists and has rows; otherwise fall back
# to a copy of master_df_step4, or None when neither is available.
_df_analysis_6a = globals().get('df_analysis')
if _df_analysis_6a is not None and not _df_analysis_6a.empty:
    src6a = _df_analysis_6a.copy()
else:
    _master_df_6a = globals().get('master_df_step4')
    src6a = None if _master_df_6a is None else _master_df_6a.copy()

if src6a is None or src6a.empty:
    print("No data available. Skipping Cell 6A.")
else:
    # Derive Region from the contact pair in the channel label when it is not
    # already present: both contacts <= 3 -> 'STN', both >= 4 -> 'M1',
    # anything else (including unparsable labels) -> 'Mixed'.
    if 'Region' not in src6a.columns and CHANNEL_DISPLAY_COL in src6a.columns:
        _m = src6a[CHANNEL_DISPLAY_COL].astype(str).str.extract(r'Contact_(\d+)_(\d+)', expand=True)
        _a = pd.to_numeric(_m[0], errors='coerce')
        _b = pd.to_numeric(_m[1], errors='coerce')
        src6a['Region'] = np.where((_a <= 3) & (_b <= 3), 'STN',
                            np.where((_a >= 4) & (_b >= 4), 'M1', 'Mixed'))
    # Only STN and M1 rows take part in the analysis.
    if 'Region' in src6a.columns:
        src6a = src6a[src6a['Region'].isin(['STN', 'M1'])].copy()

    # Frequency bands to run
    freq_bands = ORDERED_FREQ_LABELS if ANALYZE_ALL_FREQ_BANDS_STATE_SPECIFIC else [TARGET_FREQ_BAND_IF_NOT_ALL_STATE_SPECIFIC]
    if not ANALYZE_ALL_FREQ_BANDS_STATE_SPECIFIC:
        print(f"--- Analyzing ONLY Freq Band: {TARGET_FREQ_BAND_IF_NOT_ALL_STATE_SPECIFIC} ---")

    # Column names
    exponent_col = 'Exponent_BestModel'
    offset_col   = 'Offset_BestModel'
    beta_col     = 'Beta_Peak_Power_at_DominantFreq'
    gamma_col    = 'Gamma_Peak_Power_at_DominantFreq'

    # Predictors present?
    have_beta  = beta_col  in src6a.columns
    have_gamma = gamma_col in src6a.columns

    if 'Region' not in src6a.columns:
        # The per-region loop below indexes df_state['Region']; without the
        # column it would raise KeyError, so bail out with a clear message.
        print("No 'Region' column available (and it could not be derived). Skipping Cell 6A.")
    elif not (have_beta or have_gamma):
        print("No oscillatory predictors (Beta/Gamma) available. Skipping Cell 6A.")
    else:
        # Output folder
        outdir_states = os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "MLR_State_Specific_byRegion_FDR")
        os.makedirs(outdir_states, exist_ok=True)

        # ------------------------------ Helpers ------------------------------
        def _fit_ols(df_in, dv_col, predictors):
            """Fit ``dv_col ~ predictors`` by OLS on complete rows.

            Returns the fitted results object, or None when a predictor column
            is missing, fewer than ``len(predictors) + 10`` complete rows
            remain, any predictor is constant, or the fit itself fails.
            """
            if any(p not in df_in.columns for p in predictors):
                return None
            dfm = df_in[[dv_col] + predictors].dropna(how='any').copy()
            # Rule of thumb: at least (#predictors + 10) samples
            if len(dfm) < (len(predictors) + 10):
                return None
            # A constant predictor is collinear with the intercept.
            if any(dfm[p].nunique() < 2 for p in predictors):
                return None
            formula = f"{dv_col} ~ {' + '.join(predictors)}"
            try:
                return smf.ols(formula=formula, data=dfm).fit()
            except Exception as err:
                # Best-effort: report why the model was dropped instead of hiding it.
                print(f"  OLS fit failed for '{formula}' ({type(err).__name__}: {err}); skipping.")
                return None

        def _nested_f_test(reduced_fit, full_fit):
            """Return (F, p) from anova_lm(reduced, full); (nan, nan) on failure."""
            try:
                lrt_tbl = anova_lm(reduced_fit, full_fit)  # nested comparison
                return float(lrt_tbl.iloc[1]['F']), float(lrt_tbl.iloc[1]['Pr(>F)'])
            except Exception as err:
                print(f"  Nested model comparison failed ({type(err).__name__}: {err}); recording NaN.")
                return np.nan, np.nan

        def _coefficient_rows(fit, row_base, tier_label, formula):
            """Build one output row per non-intercept term of a fitted model."""
            ci = fit.conf_int()  # computed once per model, not once per cell
            rows = []
            for term in fit.params.index:
                if term == 'Intercept':
                    continue
                has_ci = term in ci.index
                rows.append({
                    **row_base,
                    'Model_Tier': tier_label,
                    'Formula': formula,
                    'Predictor_Term': term,
                    'Predictor_Name_Display': APERIODIC_METRICS_COLS.get(term, OSCILLATORY_METRICS_COLS.get(term, term)),
                    'Coefficient': float(fit.params.get(term, np.nan)),
                    'StdErr': float(fit.bse.get(term, np.nan)),
                    'PValue': float(fit.pvalues.get(term, np.nan)),
                    'Conf_Int_Lower': float(ci.loc[term, 0]) if has_ci else np.nan,
                    'Conf_Int_Upper': float(ci.loc[term, 1]) if has_ci else np.nan,
                    'N_model': int(fit.nobs),
                    'Adj_R2': float(fit.rsquared_adj),
                    'AIC': float(fit.aic), 'BIC': float(fit.bic)
                })
            return rows

        # (aperiodic predictor, LRT comparison label, model tier label);
        # the order fixes the row order in both output tables.
        aperiodic_specs = [
            (exponent_col, 'Exponent + Osc vs. Osc Only', 'Tier 2: Exponent + Osc'),
            (offset_col,   'Offset + Osc vs. Osc Only',   'Tier 3: Offset + Osc'),
        ]

        # Collectors
        coeff_rows = []   # model summaries / coefficients
        lrt_rows   = []   # LRT results (will get FDR)

        # Iterate states
        for state in TARGET_CLINICAL_STATES_ORDERED:
            if CLINICAL_STATE_COL not in src6a.columns:
                print(f"Missing '{CLINICAL_STATE_COL}'. Skipping all states.")
                break

            df_state = src6a[src6a[CLINICAL_STATE_COL] == state].copy()
            if df_state.empty or len(df_state) < MIN_STATE_ROWS:
                print(f"SKIP state '{state}' (N={len(df_state)})")
                continue

            for region in ['STN', 'M1']:
                df_sr = df_state[df_state['Region'] == region]
                if df_sr.empty:
                    continue

                for band in freq_bands:
                    if FOOOF_FREQ_BAND_COL not in df_sr.columns:
                        print(f"Missing '{FOOOF_FREQ_BAND_COL}'. Skipping.")
                        continue
                    df_srf = df_sr[df_sr[FOOOF_FREQ_BAND_COL] == band].copy()
                    if df_srf.empty:
                        continue

                    # Osc-only predictors available in this slice?
                    osc_predictors = [c for c in [beta_col, gamma_col] if c in df_srf.columns]
                    if not osc_predictors:
                        continue
                    osc_rhs = ' + '.join(osc_predictors)

                    # Run per PKG DV
                    for dv_col, dv_name in PKG_METRICS_COLS.items():
                        if dv_col not in df_srf.columns:
                            continue

                        reduced = _fit_ols(df_srf, dv_col, osc_predictors)
                        if reduced is None:
                            # Without a reduced model there is nothing to store or compare.
                            continue

                        row_base = {
                            'ClinicalState': state, 'Region': region,
                            'FreqBand': band, 'PKG_Symptom_DV': dv_name,
                        }

                        # Reduced model summary (no per-term coefficients)
                        coeff_rows.append({
                            **row_base,
                            'Model_Tier': 'Tier 1: Osc Only',
                            'Formula': f"{dv_col} ~ {osc_rhs}",
                            'N_model': int(reduced.nobs),
                            'Adj_R2': float(reduced.rsquared_adj),
                            'AIC': float(reduced.aic), 'BIC': float(reduced.bic)
                        })

                        for ap_col, comparison, tier_label in aperiodic_specs:
                            if ap_col not in df_srf.columns:
                                continue
                            full = _fit_ols(df_srf, dv_col, [ap_col] + osc_predictors)
                            if full is None:
                                continue

                            # The nested F-test is only meaningful when both models use
                            # the same observations. If the aperiodic column has extra
                            # NaNs, refit the reduced model on the full model's rows.
                            reduced_cmp = reduced
                            if int(reduced.nobs) != int(full.nobs):
                                reduced_cmp = _fit_ols(df_srf.dropna(subset=[ap_col]), dv_col, osc_predictors)

                            if reduced_cmp is None:
                                print(f"  No row-matched reduced model for {state}/{region}/{band}/{dv_name} "
                                      f"({comparison}); LRT row skipped.")
                            else:
                                f_stat, p_val = _nested_f_test(reduced_cmp, full)
                                lrt_rows.append({
                                    **row_base,
                                    'Comparison': comparison,
                                    'F_statistic': f_stat, 'P_value': p_val,
                                    'Delta_AdjR2': float(full.rsquared_adj - reduced_cmp.rsquared_adj),
                                    'AdjR2_Reduced': float(reduced_cmp.rsquared_adj), 'AdjR2_Full': float(full.rsquared_adj),
                                    'AIC_Reduced': float(reduced_cmp.aic), 'AIC_Full': float(full.aic),
                                    'BIC_Reduced': float(reduced_cmp.bic), 'BIC_Full': float(full.bic),
                                    'N_reduced': int(reduced_cmp.nobs), 'N_full': int(full.nobs)
                                })

                            # Store coefficients of the full model
                            coeff_rows.extend(_coefficient_rows(
                                full, row_base, tier_label,
                                f"{dv_col} ~ {ap_col} + {osc_rhs}",
                            ))

        # ---------- FDR over ALL state-specific LRT p-values ----------
        print(f"\nCollected {len(lrt_rows)} state-specific LRT rows.")
        if lrt_rows:
            df_lrt = pd.DataFrame(lrt_rows)
            # Rows whose comparison failed carry NaN p-values; they are left out
            # of the correction and keep NaN / False in the FDR columns.
            pvals = df_lrt['P_value'].values
            mask  = np.isfinite(pvals)
            df_lrt['P_value_FDR'] = np.nan
            df_lrt['Significant_FDR'] = False
            if mask.any():
                rej, p_corr = fdrcorrection(pvals[mask], alpha=0.05, method='indep', is_sorted=False)
                df_lrt.loc[mask, 'P_value_FDR'] = p_corr
                df_lrt.loc[mask, 'Significant_FDR'] = rej
                print("Applied BH-FDR to LRT p-values.")

            # Save LRT CSV
            csv_lrt = f"{patient_hemisphere_id}_MLR_LRT_StateSpecific_byRegion_Results_FDR_Step6A.csv"
            df_lrt.to_csv(os.path.join(outdir_states, csv_lrt), index=False)
            print(f"Saved LRT (with FDR) to {csv_lrt}")
            print(df_lrt.head())
        else:
            print("No LRT rows to correct/save.")

        # Save model coefficients CSV
        if coeff_rows:
            df_coeff = pd.DataFrame(coeff_rows)
            csv_coeff = f"{patient_hemisphere_id}_MLR_StateSpecific_byRegion_ModelCoeffs_Step6A.csv"
            df_coeff.to_csv(os.path.join(outdir_states, csv_coeff), index=False)
            print(f"Saved model summaries/coefficients to {csv_coeff}")
            show_cols = ['ClinicalState','Region','FreqBand','PKG_Symptom_DV','Model_Tier','Predictor_Term','Coefficient','PValue','Adj_R2','N_model']
            print(df_coeff[[c for c in show_cols if c in df_coeff.columns]].head())
        else:
            print("No model coefficient rows to save.")

print("\n--- Cell 6A (Region-level): Complete ---")



--- Cell 6A (Region-level): State-Specific MLR & LRT with FDR ---

--- Analyzing ONLY Freq Band: WideFreq ---

Collected 48 state-specific LRT rows.
Applied BH-FDR to LRT p-values.
Saved LRT (with FDR) to COHORT_RCS02_05_06_MLR_LRT_StateSpecific_byRegion_Results_FDR_Step6A.csv
  ClinicalState Region  FreqBand    PKG_Symptom_DV  \
0      Immobile    STN  WideFreq      PKG BK Score   
1      Immobile    STN  WideFreq      PKG BK Score   
2      Immobile    STN  WideFreq      PKG DK Score   
3      Immobile    STN  WideFreq      PKG DK Score   
4      Immobile    STN  WideFreq  PKG Tremor Score   

                    Comparison  F_statistic       P_value  Delta_AdjR2  \
0  Exponent + Osc vs. Osc Only    49.701417  2.692950e-12     0.030119   
1    Offset + Osc vs. Osc Only    60.364672  1.429803e-14     0.036469   
2  Exponent + Osc vs. Osc Only     2.297676  1.297729e-01     0.000812   
3    Offset + Osc vs. Osc Only     0.086513  7.686974e-01    -0.000573   
4  Exponent + Osc vs. Osc 

In [17]:
# -*- coding: utf-8 -*-
# --- Cell 7 (Region-first): Visualization of MLR and LRT Results ---
# Prefers Region (STN/M1) instead of per-contact Channel. If Region is absent,
# falls back to Channel. No individual contact plots are produced when Region exists.

from matplotlib.patches import Patch
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns

print("\n--- Cell 7 (Region-first): Starting Visualization of Regression Results ---\n")

# --- Global Font Size Adjustment (legible but not huge) ---
# Base point sizes per rcParam; each one is multiplied by font_scale_factor.
font_scale_factor = 1.6
_base_font_points = {
    'font.size': 10,
    'axes.labelsize': 10,
    'axes.titlesize': 12,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 9,
    'legend.title_fontsize': 10,
}
plt.rcParams.update(
    {rc_key: points * font_scale_factor for rc_key, points in _base_font_points.items()}
)

# Significance cut-off used when marking results in the plots.
P_VALUE_THRESHOLD = 0.05

# PKG dependent-variable label -> name shown in plots.
SYMPTOM_LEGEND_MAP = {
    'PKG BK Score': 'Bradykinesia',
    'PKG DK Score': 'Dyskinesia',
    # 'PKG Tremor Score': 'Tremor'
}
# Preferred plotting order of the display names.
SYMPTOM_DISPLAY_ORDER = [
    'Bradykinesia',
    'Dyskinesia',
    # 'Tremor',
]
# Order enforced when Region is present.
REGION_ORDER = ['STN', 'M1']

# ---------------------------- Helpers ----------------------------
def _pick_group_dim(df):
    """Prefer Region; otherwise fall back to Channel."""
    if 'Region' in df.columns and df['Region'].notna().any():
        return 'Region'
    return 'Channel' if 'Channel' in df.columns else None

def _apply_symptom_display(df):
    """Add an ordered categorical 'Symptom_Display' column to ``df`` in place.

    'PKG_Symptom_DV' values are translated through SYMPTOM_LEGEND_MAP, keeping
    the raw value where no mapping exists. Categories are the entries of
    SYMPTOM_DISPLAY_ORDER that occur in the data; if none occur, the sorted
    unique display values are used instead. Returns the same ``df`` object.
    """
    raw_labels = df['PKG_Symptom_DV']
    df['Symptom_Display'] = raw_labels.map(SYMPTOM_LEGEND_MAP).fillna(raw_labels)

    observed = df['Symptom_Display'].unique()
    categories = [name for name in SYMPTOM_DISPLAY_ORDER if name in observed]
    if not categories:
        # None of the preferred names occur: fall back to alphabetical order.
        categories = sorted(df['Symptom_Display'].dropna().unique().tolist())

    df['Symptom_Display'] = pd.Categorical(
        df['Symptom_Display'], categories=categories, ordered=True
    )
    return df

def _order_group_dim(df, group_dim):
    """Set order for Region/Channel categories."""
    if group_dim == 'Region':
        cats = [g for g in REGION_ORDER if g in df['Region'].unique()]
        if not cats:
            cats = sorted(df['Region'].dropna().unique().tolist())
        df['Region'] = pd.Categorical(df['Region'], categories=cats, ordered=True)
    elif group_dim == 'Channel':
        # Keep existing order if already categorical; else alphabetical
        if not (hasattr(df['Channel'], 'cat') and len(getattr(df['Channel'], 'cat').categories) > 0):
            cats = sorted(df['Channel'].dropna().unique().tolist())
            df['Channel'] = pd.Categorical(df['Channel'], categories=cats, ordered=True)
    return df

def _prep_for_plot(df):
    """Return ``(prepared_copy, group_dim)`` for a results dataframe.

    Works on a copy: adds the ordered 'Symptom_Display' column, picks the
    grouping column via _pick_group_dim, orders that column, and adds a
    constant 'ClinicalState' of 'Overall Results' when the column is missing.

    Raises:
        KeyError: if neither 'Region' nor 'Channel' can be used for grouping.
    """
    prepared = _apply_symptom_display(df.copy())

    group_dim = _pick_group_dim(prepared)
    if group_dim is None:
        raise KeyError("Neither 'Region' nor 'Channel' found in the results dataframe.")

    prepared = _order_group_dim(prepared, group_dim)
    if 'ClinicalState' not in prepared.columns:
        # Global (non state-specific) results are plotted as a single facet.
        prepared['ClinicalState'] = 'Overall Results'
    return prepared, group_dim

# ----------------- Plot 1: Dot–Whisker for Exponent/Offset -----------------
def plot_coefficient_dot_whisker(df_mlr, predictor_to_plot, output_path):
    """Save a dot-and-whisker plot of one predictor's coefficients with their CIs.

    One subplot is drawn per clinical state. Within a subplot the x axis is the
    grouping dimension (Region, else Channel) and each symptom is drawn as a
    dodged point with asymmetric error bars from Conf_Int_Lower/Upper.

    Args:
        df_mlr: model-coefficient table; must contain 'Coefficient',
            'Conf_Int_Lower', 'Conf_Int_Upper', 'Predictor_Term' and
            'PKG_Symptom_DV'.
        predictor_to_plot: value of 'Predictor_Term' to select.
        output_path: file path the figure is written to.

    Returns:
        None. Prints a message and returns early when required columns are
        missing or no plottable rows remain.
    """
    predictor_name_map = {'Exponent_BestModel': 'Exponent', 'Offset_BestModel': 'Offset'}
    display_name = predictor_name_map.get(predictor_to_plot, predictor_to_plot)

    need = ['Coefficient', 'Conf_Int_Lower', 'Conf_Int_Upper', 'Predictor_Term', 'PKG_Symptom_DV']
    missing = [c for c in need if c not in df_mlr.columns]
    if missing:
        print(f"[DotWhisker] Missing columns {missing}. Skipping {display_name}.")
        return

    df = df_mlr[df_mlr['Predictor_Term'] == predictor_to_plot].copy()
    if df.empty:
        print(f"[DotWhisker] No rows for {display_name}.")
        return

    df, group_dim = _prep_for_plot(df)
    rows_before = len(df)
    df = df.dropna(subset=[group_dim, 'Coefficient', 'Conf_Int_Lower', 'Conf_Int_Upper', 'Symptom_Display'])
    if df.empty:
        print(f"[DotWhisker] No valid rows after dropna for {display_name}.")
        return
    if rows_before != len(df):
        print(f"[DotWhisker] Dropped {rows_before - len(df)} rows with missing values.")

    states = list(pd.unique(df['ClinicalState']))
    groups = list(df[group_dim].cat.categories)
    symptoms = list(df['Symptom_Display'].cat.categories)
    colors = dict(zip(symptoms, sns.color_palette('bright', len(symptoms))))

    # Horizontal offsets that spread the symptoms around each group tick.
    # np.linspace(a, b, 1) returns [a], which would push a lone symptom off
    # its tick, so a single symptom is centred explicitly.
    dodge = 0.5
    if len(symptoms) > 1:
        pos = np.linspace(-dodge / 2, dodge / 2, len(symptoms))
    else:
        pos = np.zeros(len(symptoms))

    fig, axes = plt.subplots(len(states), 1, figsize=(14, 6 + 3*len(states)), sharex=True, squeeze=False)
    axes = axes.flatten()

    for ax, state in zip(axes, states):
        ax.set_title(state, pad=14)
        sub = df[df['ClinicalState'] == state]

        for s_idx, sym in enumerate(symptoms):
            dd = sub[sub['Symptom_Display'] == sym]
            if dd.empty:
                continue
            # Align x positions by category code of group_dim
            x = dd[group_dim].cat.codes.values + pos[s_idx]
            y = dd['Coefficient'].values
            # errorbar expects distances from the point, not absolute bounds.
            lo = (dd['Coefficient'] - dd['Conf_Int_Lower']).values
            hi = (dd['Conf_Int_Upper'] - dd['Coefficient']).values

            ax.errorbar(x=x, y=y, yerr=[lo, hi],
                        fmt='o', color=colors[sym], capsize=6, markersize=8, linestyle='none')

        ax.axhline(0, ls='--', color='black', lw=1.6, zorder=0)
        ax.set_xticks(np.arange(len(groups)))
        ax.set_xticklabels(groups, rotation=0)
        ax.set_ylabel("Coefficient (95% CI)")
        ax.grid(axis='y', linestyle=':', alpha=0.6)

    # One figure-level legend built from proxy handles
    if symptoms:
        handles = [plt.Line2D([0], [0], marker='o', linestyle='None', markersize=8, color=colors[s]) for s in symptoms]
        fig.legend(handles, symptoms, title="Symptom", bbox_to_anchor=(1.02, 0.92), loc='upper left')

    xlbl = "Region" if group_dim == 'Region' else "Channel"
    fig.supxlabel(xlbl)
    fig.suptitle(f"{display_name} Coefficients vs PKG", y=1.02)
    # Act on this figure explicitly rather than on pyplot's current figure.
    fig.tight_layout(rect=[0, 0, 0.88, 0.98])
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f"[DotWhisker] Saved: {output_path}")

# ------------- Plot 2: Orthogonal Value (Added R² by Exponent) -------------
def plot_orthogonal_value_stacked_bar(df_mlr, df_lrt, output_path):
    """Save a faceted bar plot of adjusted R² with the gain from adding the exponent.

    For every (ClinicalState, Region/Channel, symptom) the reduced model's
    adjusted R² ('Tier 1: Osc Only') is drawn as a grey bar and the increase
    obtained by the 'Tier 2: Exponent + Osc' model (clipped at 0) is stacked on
    top as a hatched cap. A '*' marks combinations whose exponent LRT p-value
    (FDR-corrected column preferred) is below P_VALUE_THRESHOLD. The plotted
    values are also written to ``<output_path stem>_VALUES.csv``.

    Args:
        df_mlr: model summary/coefficient table with 'Model_Tier',
            'PKG_Symptom_DV', an adjusted-R² column and Region or Channel.
        df_lrt: nested-model comparison table with a p-value column.
        output_path: file path the figure is written to.

    Returns:
        None. Prints a message and returns early when the inputs do not
        contain what is needed.
    """
    print("[Orthogonal] Building stacked bar for Added R² by Exponent...")

    # Alternate spellings of the tier labels -> the labels used in the CSVs.
    tier_aliases = {
        'Tier1: Oscillatory Only': 'Tier 1: Osc Only',
        'Tier 1: Oscillatory Only': 'Tier 1: Osc Only',
        'Oscillatory Only': 'Tier 1: Osc Only',
        'Osc Only': 'Tier 1: Osc Only',
        'Tier 2: Exponent + Oscillatory': 'Tier 2: Exponent + Osc',
        'Exponent + Osc': 'Tier 2: Exponent + Osc',
        'Tier 3: Offset + Oscillatory': 'Tier 3: Offset + Osc',
        'Offset + Osc': 'Tier 3: Offset + Osc',
    }

    def _norm(df):
        """Return a copy with key text columns Unicode/whitespace-normalized."""
        df = df.copy()
        for c in ['ClinicalState', 'Channel', 'Region', 'PKG_Symptom_DV', 'Model_Tier', 'Comparison']:
            if c in df.columns:
                df[c] = (df[c].astype('string')
                                .str.normalize('NFKC')
                                .str.strip()
                                .str.replace(r'\s+', ' ', regex=True))
        if 'Model_Tier' in df.columns:
            df['Model_Tier'] = df['Model_Tier'].replace(tier_aliases)
        return df

    def _find_adj_r2(df):
        """Return the first known adjusted-R² column name present, else None."""
        for k in ['Adj_R2', 'R_squared_adj_model', 'R_squared_adj', 'AdjR2', 'R2_adj']:
            if k in df.columns:
                return k
        return None

    df_mlr = _norm(df_mlr)
    df_lrt = _norm(df_lrt)

    # prefer Region grouping
    group_dim = _pick_group_dim(df_mlr)
    if group_dim is None:
        print("[Orthogonal] No Region/Channel column in MLR results; aborting.")
        return

    df_mlr = _apply_symptom_display(df_mlr)
    df_mlr = _order_group_dim(df_mlr, group_dim)
    if 'ClinicalState' not in df_mlr.columns:
        df_mlr['ClinicalState'] = 'Overall Results'
    if 'ClinicalState' not in df_lrt.columns:
        df_lrt['ClinicalState'] = 'Overall Results'

    r2col = _find_adj_r2(df_mlr)
    if r2col is None:
        print(f"[Orthogonal] No adjusted R² col found in MLR results. Columns: {list(df_mlr.columns)}")
        return

    # One adjusted R² per key and tier (tier-2/3 tables hold one row per term,
    # all carrying the same model-level R², so the first row is enough).
    key = ['ClinicalState', group_dim, 'PKG_Symptom_DV']
    reduced = (df_mlr[df_mlr['Model_Tier'] == 'Tier 1: Osc Only']
               .drop_duplicates(subset=key)[key + [r2col]]
               .rename(columns={r2col: 'R2_Reduced'}))

    full_exp = (df_mlr[df_mlr['Model_Tier'] == 'Tier 2: Exponent + Osc']
                .drop_duplicates(subset=key)[key + [r2col]]
                .rename(columns={r2col: 'R2_Full'}))

    if reduced.empty or full_exp.empty:
        print("[Orthogonal] Reduced or full model rows missing.")
        return

    df_r2 = pd.merge(reduced, full_exp, on=key, how='inner')
    if df_r2.empty:
        print("[Orthogonal] No overlap between reduced and full rows.")
        return

    # A negative gain cannot be drawn as a stacked cap, so it is shown as 0.
    df_r2['R2_Added_by_Exponent'] = (df_r2['R2_Full'] - df_r2['R2_Reduced']).clip(lower=0)
    df_r2['Symptom_Display'] = df_r2['PKG_Symptom_DV'].map(SYMPTOM_LEGEND_MAP).fillna(df_r2['PKG_Symptom_DV'])

    # pick p-value column from LRT (Exponent + Osc vs Osc Only)
    if 'Comparison' in df_lrt.columns:
        comparison = df_lrt['Comparison'].str.lower()
        mask = comparison.str.contains('exponent', na=False) & comparison.str.contains('osc', na=False)
        df_lrt_exp = df_lrt[mask].copy()
    else:
        df_lrt_exp = df_lrt.copy()

    pcol = None
    for c in ['P_value_FDR', 'P_value', 'pval', 'p_adj', 'P_FDR']:
        if c in df_lrt_exp.columns:
            pcol = c
            break

    lrt_key = ['ClinicalState', group_dim, 'PKG_Symptom_DV']
    lrt_has_keys = all(c in df_lrt_exp.columns for c in lrt_key)
    if pcol and lrt_has_keys and not df_lrt_exp.empty:
        # Keep one LRT row per key (same first-row rule as the R² tables) so
        # the left merge cannot multiply the rows being plotted.
        df_lrt_small = (df_lrt_exp[lrt_key + [pcol]]
                        .drop_duplicates(subset=lrt_key)
                        .rename(columns={pcol: 'P_value_any'}))
        df_plot = pd.merge(df_r2, df_lrt_small, on=lrt_key, how='left')
        df_plot['is_significant'] = df_plot['P_value_any'] < P_VALUE_THRESHOLD
    else:
        if pcol and not lrt_has_keys:
            print(f"[Orthogonal] LRT results lack key columns {lrt_key}; significance not marked.")
        df_plot = df_r2.copy()
        df_plot['is_significant'] = False

    # categories
    df_plot = _order_group_dim(df_plot, group_dim)
    df_plot = _apply_symptom_display(df_plot)

    # coercions
    for c in ['R2_Full', 'R2_Reduced', 'R2_Added_by_Exponent']:
        df_plot[c] = pd.to_numeric(df_plot[c], errors='coerce')
    df_plot = df_plot.dropna(subset=[group_dim, 'Symptom_Display', 'R2_Full', 'R2_Reduced'])
    if df_plot.empty:
        print("[Orthogonal] Nothing to plot after cleaning.")
        return

    # facet per clinical state
    groups = list(df_plot[group_dim].cat.categories)
    symptoms = list(pd.Categorical(df_plot['Symptom_Display']).categories)
    n_sym = len(symptoms)

    g = sns.FacetGrid(df_plot, col='ClinicalState', col_wrap=2, height=6.2, aspect=1.6, sharey=True)
    custom_palette = {
        'Bradykinesia': '#b0b0b0',   # light grey
        'Dyskinesia':  '#707070',    # dark grey
        # 'Tremor': '#404040'         # optional
    }
    # Any symptom without a fixed colour gets a neutral grey so the palette
    # always covers every hue level.
    base_palette = {s: custom_palette.get(s, '#909090') for s in symptoms}

    # base bars: reduced
    g.map_dataframe(
        sns.barplot, x=group_dim, y='R2_Reduced', hue='Symptom_Display',
        palette=base_palette, dodge=0.8, errorbar=None, alpha=0.55, zorder=1,
        order=groups, hue_order=symptoms
    )

    # add colored hatched cap = Added R² by exponent
    cap_palette = sns.color_palette('viridis', n_sym)
    for ax, state_name in zip(g.axes.flat, g.col_names):
        sub = df_plot[df_plot['ClinicalState'] == state_name]
        records = {}
        for _, rec in sub.iterrows():
            records.setdefault((str(rec[group_dim]), str(rec['Symptom_Display'])), rec)

        # Identify each base bar from its geometry instead of its position in
        # ax.patches: seaborn adds dodged bars one hue level at a time, so a
        # group-major index would pair caps/stars with the wrong bars. Bars
        # sit at integer group positions, offset by whole bar widths per hue.
        for p in list(ax.patches):  # snapshot: ax.bar below appends patches
            width, height = p.get_width(), p.get_height()
            center = p.get_x() + width / 2
            if not (np.isfinite(width) and np.isfinite(height) and np.isfinite(center)) or width <= 0:
                continue
            g_idx = int(round(center))
            s_idx = int(round((center - g_idx) / width + (n_sym - 1) / 2))
            if not (0 <= g_idx < len(groups) and 0 <= s_idx < n_sym):
                continue
            rec = records.get((str(groups[g_idx]), str(symptoms[s_idx])))
            if rec is None:
                continue

            h_add = float(rec['R2_Added_by_Exponent'])
            if h_add > 0:
                ax.bar(
                    center, h_add, width=width,
                    bottom=height, align='center',
                    color=cap_palette[s_idx], hatch='///', linewidth=0, zorder=2
                )
            else:
                h_add = 0.0
            # mark significance just above the top of the stacked bar
            if bool(rec['is_significant']):
                ax.text(center, height + h_add + 0.01, '*',
                        ha='center', va='bottom', fontsize=14 * font_scale_factor)

        ax.set_xticklabels(groups, rotation=0)
        ax.set_ylabel("Adjusted R-squared")
        ax.grid(axis='y', linestyle=':', alpha=0.6)

    # legends
    g.add_legend(title="Symptom Score")
    pattern_legend = [Patch(facecolor='lightgray', alpha=0.55, label='Reduced (Osc-only)'),
                      Patch(facecolor='white', edgecolor='black', hatch='///', label='Added by Exponent')]
    g.fig.legend(handles=pattern_legend,
                 labels=[h.get_label() for h in pattern_legend],
                 loc='upper left', bbox_to_anchor=(0.88, 0.92))

    g.set_titles("{col_name}", pad=14)
    g.fig.suptitle("Orthogonal Value of Exponent (Added R²)", y=1.02)
    g.fig.tight_layout(rect=[0, 0, 0.96, 0.98])

    # Also dump a CSV of exactly what was plotted (handy to audit)
    try:
        out_csv = os.path.splitext(output_path)[0] + "_VALUES.csv"
        dump_cols = ['ClinicalState', group_dim, 'Symptom_Display', 'R2_Reduced', 'R2_Full', 'R2_Added_by_Exponent']
        if 'P_value_any' in df_plot.columns:
            dump_cols.append('P_value_any')
        if 'is_significant' in df_plot.columns:
            dump_cols.append('is_significant')
        (df_plot.sort_values(['ClinicalState', group_dim, 'Symptom_Display'])[dump_cols]
                .to_csv(out_csv, index=False))
        print(f"[Orthogonal] Values CSV saved: {out_csv}")
    except Exception as e:
        # The CSV is an audit aid only; never let it block the figure.
        print(f"[Orthogonal] Could not write values CSV: {e}")

    # Save and close this grid's figure explicitly.
    g.fig.savefig(output_path, bbox_inches='tight')
    plt.close(g.fig)
    print(f"[Orthogonal] Saved: {output_path}")

# ---------------------------- Runner ----------------------------
def run_visualizations(analysis_type, results_dir, mlr_filename, lrt_filename, output_dir):
    """Load one MLR/LRT result pair from disk and render its summary figures.

    Parameters
    ----------
    analysis_type : str
        Label used in console messages and in the output file names.
    results_dir : str
        Folder that holds the two CSV files.
    mlr_filename, lrt_filename : str
        File names of the coefficient table and of the LRT table.
    output_dir : str
        Folder the PNG files are written to.

    Returns
    -------
    None
        Returns early, after printing a SKIP line, when either CSV is missing.
    """
    banner = '=' * 18
    print(f"\n{banner}\nGenerating {analysis_type} plots\n{banner}")

    # Only the two reads can raise FileNotFoundError, so only they are guarded.
    csv_paths = [os.path.join(results_dir, fname) for fname in (mlr_filename, lrt_filename)]
    try:
        df_mlr, df_lrt = (pd.read_csv(path) for path in csv_paths)
    except FileNotFoundError as missing:
        print(f"SKIP {analysis_type}: Missing file: {missing.filename}")
        return

    # Optionally remove tremor rows from both tables before anything is drawn.
    if EXCLUDE_TREMOR:
        def _without_tremor(frame):
            if 'PKG_Symptom_DV' not in frame.columns:
                return frame
            return frame[~frame['PKG_Symptom_DV'].isin(TREMOR_LABELS)]

        df_mlr = _without_tremor(df_mlr)
        df_lrt = _without_tremor(df_lrt)

    # Quick sanity: show the first eight column names of each table.
    for tag, frame in (("MLR", df_mlr), ("LRT", df_lrt)):
        print(f"[{analysis_type}] {tag} cols: {list(frame.columns)[:8]} ...")

    # Dot–whisker plots need per-term coefficient rows.
    has_coefficients = {'Predictor_Term', 'Coefficient'}.issubset(df_mlr.columns)
    if not has_coefficients:
        print(f"[{analysis_type}] No coefficient rows found; skipping dot–whisker plots.")
    else:
        targets = (
            ('Exponent_BestModel',
             os.path.join(output_dir, f"{patient_hemisphere_id}_{analysis_type}_Exponent_Coefficients.png")),
            ('Offset_BestModel',
             os.path.join(output_dir, f"{patient_hemisphere_id}_{analysis_type}_Offset_Coefficients.png")),
        )
        for term, png_path in targets:
            plot_coefficient_dot_whisker(df_mlr, term, png_path)

    # Stacked bar of the adjusted R² added by the exponent term.
    ortho_out = os.path.join(output_dir, f"{patient_hemisphere_id}_{analysis_type}_Orthogonal_Value.png")
    plot_orthogonal_value_stacked_bar(df_mlr, df_lrt, ortho_out)

# ---------------------------- Main --------------------------------
# Entry point for Cell 7. At module/notebook top level `locals()` is the global
# namespace, so the second condition checks that the configuration cell defining
# `patient_hemisphere_id` has been run.
if __name__ == "__main__" and 'patient_hemisphere_id' in locals():
    # --- Global (Cell 6) results ---
    # NOTE(review): `analysis_session_plot_folder_step4` is defined outside this
    # chunk; presumably by an earlier Step 4 setup cell — confirm.
    global_dir = os.path.join(analysis_session_plot_folder_step4, "MLR_byRegion_V4")
    global_mlr_file = f"{patient_hemisphere_id}_MLR_Coefficients_byRegion_Step6.csv"
    global_lrt_file = f"{patient_hemisphere_id}_MLR_LRT_Results_byRegion_Step6.csv"
    os.makedirs(global_dir, exist_ok=True)
    # Results are read from and plots are written to the same folder.
    run_visualizations("Global", global_dir, global_mlr_file, global_lrt_file, global_dir)


    # --- State-specific (Cell 6A) results ---
    # Candidate folders are tried in order; the region-level folder comes first.
    # NOTE(review): `STATE_SPECIFIC_ANALYSIS_DIR` is defined outside this chunk.
    state_dir_candidates = [
        os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "MLR_State_Specific_byRegion_FDR"),     # region-level 6A (recommended)
        os.path.join(STATE_SPECIFIC_ANALYSIS_DIR, "MultipleLinearRegression_State_Specific_FDR")  # fallback for the older 6A
    ]
    for state_dir in state_dir_candidates:
        if os.path.isdir(state_dir):
            # Region-level file names (first), else older ones
            mlr_files = [
                f"{patient_hemisphere_id}_MLR_StateSpecific_byRegion_ModelCoeffs_Step6A.csv",
                f"{patient_hemisphere_id}_MLR_StateSpecific_Results_Step6A.csv"
            ]
            lrt_files = [
                f"{patient_hemisphere_id}_MLR_LRT_StateSpecific_byRegion_Results_FDR_Step6A.csv",
                f"{patient_hemisphere_id}_MLR_LRT_StateSpecific_Results_FDR_Step6A.csv"
            ]
            # First existing file of each kind, or None when neither name exists.
            # The two lookups are independent, so a region-level MLR file can be
            # paired with an older-style LRT file (and vice versa).
            mlr_found = next((f for f in mlr_files if os.path.exists(os.path.join(state_dir, f))), None)
            lrt_found = next((f for f in lrt_files if os.path.exists(os.path.join(state_dir, f))), None)
            if mlr_found and lrt_found:
                run_visualizations("StateSpecific", state_dir, mlr_found, lrt_found, state_dir)
                # Stop at the first folder that yields a complete pair.
                break
    else:
        # for/else: reached only when the loop ended without `break`, i.e. no
        # candidate folder contained both an MLR and an LRT file.
        print("SKIP StateSpecific: No state-specific results directory with expected files was found.")

else:
    print("Skipping plot generation because __main__ guard did not pass or key variables are missing.")

# Printed unconditionally, including when plotting was skipped above.
print("\n--- Cell 7 (Region-first): Visualizations Complete ---")



--- Cell 7 (Region-first): Starting Visualization of Regression Results ---


Generating Global plots
[Global] MLR cols: ['Region', 'FreqBand', 'PKG_Symptom_DV', 'Model_Tier', 'Formula', 'N_model', 'Adj_R2', 'AIC'] ...
[Global] LRT cols: ['Region', 'FreqBand', 'PKG_Symptom_DV', 'Comparison', 'F_statistic', 'P_value', 'Delta_AdjR2', 'AdjR2_Reduced'] ...
[DotWhisker] Saved: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/MLR_byRegion_V4/COHORT_RCS02_05_06_Global_Exponent_Coefficients.png
[DotWhisker] Saved: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/MLR_byRegion_V4/COHORT_RCS02_05_06_Global_Offset_Coefficients.png
[Orthogonal] Building stacked bar for Added R² by Exponent...
[Orthogonal] Values CSV saved: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_b

In [18]:
# -*- coding: utf-8 -*-
# --- Cell 8 (Region-first & Safe Fallback): Curated Visualization (STN vs M1 only) ---

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns

print("\n--- Cell 8 (Region-first): Starting Curated Visualization (STN vs M1) ---\n")

# --- Global Font / Style ---
# Every base font size below is multiplied by this factor; the star annotation in
# plot_curated_orthogonal_value also scales with it.
font_scale_factor = 3
# rcParams are process-wide: this changes matplotlib defaults for all figures
# created after this point, not only the ones drawn in this cell.
plt.rcParams.update({
    'font.size': 10 * font_scale_factor, 'axes.labelsize': 10 * font_scale_factor,
    'axes.titlesize': 12 * font_scale_factor, 'xtick.labelsize': 10 * font_scale_factor,
    'ytick.labelsize': 8 * font_scale_factor, 'legend.fontsize': 9 * font_scale_factor,
    'legend.title_fontsize': 10 * font_scale_factor,
})
# Significance cut-off applied to the LRT p-value in run_curated_visualizations.
P_VALUE_THRESHOLD = 0.05

# Left-to-right order of regions on the x axis of the curated plots.
REGION_ORDER = ['STN', 'M1']
# Contact names per region.
# NOTE(review): not referenced by the plotting functions visible in this cell —
# confirm whether it is still needed.
CHANNEL_GROUP_MAP = {
    'STN': ['Contact_2_0', 'Contact_3_0', 'Contact_3_1'],
    'M1':  ['Contact_10_8', 'Contact_11_9']
}

# Raw PKG symptom labels. Tremor is still listed here although it is commented
# out of the legend map and display order below.
SYMPTOM_ORDER = ['PKG BK Score', 'PKG DK Score', 'PKG Tremor Score']
# Raw label -> display name. Labels missing from this map become NaN when it is
# applied with Series.map in run_curated_visualizations.
SYMPTOM_LEGEND_MAP = {
    'PKG BK Score': 'Bradykinesia',
    'PKG DK Score': 'Dyskinesia',
    # 'PKG Tremor Score': 'Tremor'
}
# Order of the display names within each region group and in the legend.
SYMPTOM_DISPLAY_ORDER = ['Bradykinesia', 'Dyskinesia']#, 'Tremor']

def _pick_p_col(df, prefer_fdr=True):
    """Return the name of the first recognised p-value column in *df*.

    FDR-adjusted column names are searched first when *prefer_fdr* is true;
    raw p-value names are always searched afterwards. Returns ``None`` when
    none of the known names is present.
    """
    fdr_names = ('P_value_FDR', 'P_FDR', 'p_adj', 'p_fdr')
    raw_names = ('P_value', 'p_value', 'pval', 'P')
    search_order = fdr_names + raw_names if prefer_fdr else raw_names
    return next((name for name in search_order if name in df.columns), None)

# ---------- Plotters ----------
def plot_curated_coefficients(df_plot, predictor_to_plot, output_path):
    """Draw a dodged dot-and-whisker plot of one predictor's coefficients by region.

    Parameters
    ----------
    df_plot : pandas.DataFrame
        Must contain 'Coefficient', 'Conf_Int_Lower', 'Conf_Int_Upper',
        'Predictor_Term', 'Symptom_Display' and 'Region'.
    predictor_to_plot : str
        Value of 'Predictor_Term' to keep (e.g. 'Exponent_BestModel').
    output_path : str
        Destination of the saved figure.

    Returns
    -------
    None
        Prints a message and returns without drawing when required columns are
        missing or when no plottable row remains.
    """
    predictor_name_map = {'Exponent_BestModel': 'Exponent', 'Offset_BestModel': 'Offset'}
    predictor_display_name = predictor_name_map.get(predictor_to_plot, predictor_to_plot)
    print(f"\nGenerating Curated Dot-and-Whisker: '{predictor_display_name}' (Region)…")

    req = {'Coefficient', 'Conf_Int_Lower', 'Conf_Int_Upper',
           'Predictor_Term', 'Symptom_Display', 'Region'}
    missing = sorted(req - set(df_plot.columns))
    if missing:
        print(f"Missing columns for coefficient plot: {missing}")
        return

    df_plot = df_plot[df_plot['Predictor_Term'] == predictor_to_plot].copy()
    if df_plot.empty:
        print("No data to plot. Skipping.")
        return

    df_plot['Symptom_Display'] = pd.Categorical(df_plot['Symptom_Display'],
                                                categories=SYMPTOM_DISPLAY_ORDER, ordered=True)
    observed_regions = set(df_plot['Region'].dropna().unique())
    cats = [r for r in REGION_ORDER if r in observed_regions]
    if not cats:
        cats = sorted(observed_regions)
    df_plot['Region'] = pd.Categorical(df_plot['Region'], categories=cats, ordered=True)

    # Values outside the category lists became NaN above. A NaN Region has
    # category code -1 and would be drawn left of the first tick, so such rows
    # are removed instead of being plotted at a meaningless position.
    df_plot = df_plot.dropna(subset=['Region', 'Symptom_Display'])
    if df_plot.empty:
        print("No data to plot. Skipping.")
        return
    df_plot = df_plot.sort_values(['Region', 'Symptom_Display'])

    regions = list(df_plot['Region'].cat.categories)
    symptoms = list(df_plot['Symptom_Display'].cat.categories)
    symptom_colors = dict(zip(symptoms, sns.color_palette('bright', len(symptoms))))

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    idx = np.arange(len(regions))
    dodge = 0.4
    if len(symptoms) > 1:
        pos = np.linspace(-dodge / 2, dodge / 2, len(symptoms))
    else:
        # np.linspace(a, b, 1) returns [a]; a lone series belongs on the tick.
        pos = np.zeros(len(symptoms))

    for s_i, sym in enumerate(symptoms):
        sub = df_plot[df_plot['Symptom_Display'] == sym]
        if sub.empty:
            continue
        x = sub['Region'].cat.codes.values + pos[s_i]
        y = sub['Coefficient'].values
        # errorbar expects distances from the point, not absolute CI bounds.
        lo = (sub['Coefficient'] - sub['Conf_Int_Lower']).values
        hi = (sub['Conf_Int_Upper'] - sub['Coefficient']).values
        ax.errorbar(x=x, y=y, yerr=[lo, hi], fmt='o',
                    color=symptom_colors[sym], label=sym,
                    capsize=8, markersize=14, linestyle='none',
                    linewidth=2.5, markeredgewidth=2.5)

    ax.axhline(0, ls='--', color='black', lw=2, zorder=0)
    ax.set_xticks(idx)
    ax.set_xticklabels(regions)
    ax.set_ylabel("Regression Coefficient (95% CI)")
    ax.grid(axis='y', linestyle=':', alpha=0.7)
    fig.legend(title="Symptom Score", bbox_to_anchor=(1.02, 0.9), loc='upper left')
    # Operate on the figure handle rather than on pyplot's "current figure".
    fig.tight_layout(rect=[0, 0, 0.85, 1])
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved plot to: {output_path}")

def plot_curated_orthogonal_value(df_plot, output_path):
    """Draw the curated stacked bar of adjusted R² added by the exponent term.

    Each (Region, Symptom) cell gets a base bar of height ``R2_Reduced``, a
    hatched cap of height ``R2_Added_by_Exponent`` when that value is positive,
    and a red star at ``R2_Full`` when ``is_significant`` is true.

    Parameters
    ----------
    df_plot : pandas.DataFrame or None
        Must contain 'Region', 'Symptom_Display', 'R2_Reduced', 'R2_Full',
        'R2_Added_by_Exponent' and 'is_significant'. When several rows share a
        (Region, Symptom) cell, the cap and star use the first one.
    output_path : str
        Destination of the saved figure.

    Returns
    -------
    None
        Prints a message and returns without drawing when there is no usable
        data or when required columns are missing.
    """
    print("Generating Curated Stacked Bar (Added R² by Exponent)…")

    if df_plot is None or df_plot.empty:
        print("No data to plot. Skipping.")
        return

    need = {'Region', 'Symptom_Display', 'R2_Reduced', 'R2_Full',
            'R2_Added_by_Exponent', 'is_significant'}
    missing = sorted(need - set(df_plot.columns))
    if missing:
        print(f"Missing columns for orthogonal plot: {missing}")
        return

    df_plot = df_plot.copy()

    # ---- Coerce numeric (in case they came in as strings) ----
    for c in ['R2_Reduced', 'R2_Full', 'R2_Added_by_Exponent']:
        df_plot[c] = pd.to_numeric(df_plot[c], errors='coerce')
    df_plot = df_plot.dropna(subset=['Region', 'Symptom_Display', 'R2_Reduced', 'R2_Full'])
    if df_plot.empty:
        print("Nothing to plot after cleaning.")
        return

    # ---- Ensure categorical dtypes & ordering BEFORE plotting ----
    observed_sym = set(df_plot['Symptom_Display'].unique())
    present_sym = [s for s in SYMPTOM_DISPLAY_ORDER if s in observed_sym]
    if not present_sym:  # fall back to alphabetical if none match expected order
        present_sym = sorted(observed_sym)
    df_plot['Symptom_Display'] = pd.Categorical(df_plot['Symptom_Display'],
                                                categories=present_sym, ordered=True)

    observed_reg = set(df_plot['Region'].unique())
    present_reg = [r for r in REGION_ORDER if r in observed_reg]
    if not present_reg:
        present_reg = sorted(observed_reg)
    df_plot['Region'] = pd.Categorical(df_plot['Region'],
                                       categories=present_reg, ordered=True)

    df_plot = df_plot.sort_values(['Region', 'Symptom_Display'])
    regions = list(df_plot['Region'].cat.categories)
    symptoms = list(df_plot['Symptom_Display'].cat.categories)
    n_sym = len(symptoms)

    # One grey per symptom, shared by base bars and caps. Symptoms that only
    # appear via the alphabetical fallback get a darker grey instead of
    # triggering a palette lookup failure.
    known_greys = {
        'Bradykinesia': '#b0b0b0',  # light grey
        'Dyskinesia': '#707070',    # darker grey
    }
    symptom_colors = {s: known_greys.get(s, '#404040') for s in symptoms}

    # ---- Plot ----
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    group_width = 0.8  # seaborn's default total width of one x-category group
    sns.barplot(
        data=df_plot, x='Region', y='R2_Reduced', hue='Symptom_Display',
        palette=symptom_colors, dodge=True, width=group_width,
        errorbar=None, ax=ax, alpha=0.55, legend=False,
        order=regions, hue_order=symptoms
    )

    # seaborn adds hue-nested bars one hue level at a time, so ax.patches is NOT
    # in (region, symptom) order and may lack bars for missing cells. Identify
    # each bar by its geometry instead: the nearest integer to its centre is the
    # region index, and the offset from that integer, in units of one dodge
    # slot, is the symptom index.
    slot = group_width / n_sym
    bar_by_cell = {}
    for p in list(ax.patches):
        w = p.get_width()
        center = p.get_x() + w / 2
        if not (np.isfinite(w) and np.isfinite(center)) or w <= 0:
            continue
        r_i = int(round(center))
        s_i = int(round((center - r_i) / slot + (n_sym - 1) / 2))
        if 0 <= r_i < len(regions) and 0 <= s_i < n_sym:
            bar_by_cell[(r_i, s_i)] = p

    for r_i, reg in enumerate(regions):
        for s_i, sym in enumerate(symptoms):
            bar = bar_by_cell.get((r_i, s_i))
            rows = df_plot[(df_plot['Region'] == reg) & (df_plot['Symptom_Display'] == sym)]
            if bar is None or rows.empty:
                continue
            base_height = bar.get_height()
            if not np.isfinite(base_height):
                continue
            first = rows.iloc[0]
            x_center = bar.get_x() + bar.get_width() / 2

            # Hatched cap on top of the base bar (NaN compares False and is skipped).
            h_add = float(first['R2_Added_by_Exponent'])
            if h_add > 0:
                ax.bar(
                    x_center, h_add, width=bar.get_width(),
                    bottom=base_height, align='center',
                    color=symptom_colors[sym],
                    hatch='///', linewidth=0, zorder=2
                )

            # Significance star at total height (R2_Full).
            flag = first['is_significant']
            if pd.notna(flag) and bool(flag):
                ax.text(
                    x_center, float(first['R2_Full']) + 0.005, '*',
                    ha='center', va='bottom', color='red',
                    fontsize=18 * font_scale_factor, zorder=3
                )

    ax.set_ylabel("Adjusted R-squared")
    ax.set_xlabel("")
    ax.grid(axis='y', linestyle=':', alpha=0.7)
    fig.tight_layout()
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved plot to: {output_path}")


# ---------- Main ----------
def run_curated_visualizations(results_dir, mlr_filename, lrt_filename, output_dir):
    """Load the MLR/LRT CSV pair and draw the curated region-level figures.

    Parameters
    ----------
    results_dir : str
        Folder containing the two CSV files.
    mlr_filename, lrt_filename : str
        File names of the coefficient table and of the LRT table.
    output_dir : str
        Folder the PNG files are written to (created if necessary).

    Returns
    -------
    dict or None
        ``{'mode': 'region', 'selections': None}`` once all region plots were
        attempted. ``None`` when a CSV is missing, when there is no usable
        'Region' column, or when the orthogonal plot cannot be built.
    """
    try:
        df_mlr = pd.read_csv(os.path.join(results_dir, mlr_filename))
        df_lrt = pd.read_csv(os.path.join(results_dir, lrt_filename))
    except FileNotFoundError as e:
        print(f"SKIPPING: Missing file: {e.filename}. Run Cell 6/6A first.")
        return None

    if EXCLUDE_TREMOR:
        if 'PKG_Symptom_DV' in df_mlr.columns:
            df_mlr = df_mlr[~df_mlr['PKG_Symptom_DV'].isin(TREMOR_LABELS)].copy()
        if 'PKG_Symptom_DV' in df_lrt.columns:
            df_lrt = df_lrt[~df_lrt['PKG_Symptom_DV'].isin(TREMOR_LABELS)].copy()

    # CONTACT FALLBACK
    # No contact-level plotting is implemented here: the function only reports
    # the situation and returns.
    if 'Region' not in df_mlr.columns or not df_mlr['Region'].notna().any():
        print("Region not found — falling back to contact-level selection (plots still STN vs M1).")
        return None

    # REGION BRANCH
    print("Region column found — generating STN vs M1 plots (no contact selection).")
    os.makedirs(output_dir, exist_ok=True)

    # The coefficient plots need no R² information, so they are drawn first and
    # are not lost when only the orthogonal plot has to be skipped.
    df_coeff = df_mlr.copy()
    df_coeff['Symptom_Display'] = df_coeff['PKG_Symptom_DV'].map(SYMPTOM_LEGEND_MAP)
    plot_curated_coefficients(df_coeff, 'Exponent_BestModel',
                              os.path.join(output_dir, f"{patient_hemisphere_id}_Curated_Exponent_Coefficients_REGION.png"))
    plot_curated_coefficients(df_coeff, 'Offset_BestModel',
                              os.path.join(output_dir, f"{patient_hemisphere_id}_Curated_Offset_Coefficients_REGION.png"))

    r2_candidates = ['R_squared_adj_model', 'R_squared_adj', 'Adj_R2', 'R2_adj', 'AdjR2']
    r2_col = next((c for c in r2_candidates if c in df_mlr.columns), None)
    if r2_col is None:
        print("No adjusted R² col; skipping orthogonal plot.")
        return None
    if 'Model_Tier' not in df_mlr.columns:
        # Comparing a missing column's default against a string yields a bare
        # bool, and indexing the frame with it raises KeyError; bail out cleanly.
        print("No 'Model_Tier' column; skipping orthogonal plot.")
        return None

    # NOTE(review): rows are keyed on (Region, symptom) only. If the tables
    # hold several rows per key (e.g. one per frequency band), the first row
    # wins — confirm that this is the intended selection.
    key = ['Region', 'PKG_Symptom_DV']
    tier = df_mlr['Model_Tier']
    reduced = (df_mlr.loc[tier == 'Tier 1: Osc Only', key + [r2_col]]
               .drop_duplicates(subset=key)
               .rename(columns={r2_col: 'R2_Reduced'}))
    full = (df_mlr.loc[tier == 'Tier 2: Exponent + Osc', key + [r2_col]]
            .drop_duplicates(subset=key)
            .rename(columns={r2_col: 'R2_Full'}))
    df_r2 = pd.merge(reduced, full, on=key, how='inner')
    df_r2['Symptom_Display'] = df_r2['PKG_Symptom_DV'].map(SYMPTOM_LEGEND_MAP)
    # Negative differences are shown as "no added value" rather than a negative cap.
    df_r2['R2_Added_by_Exponent'] = (df_r2['R2_Full'] - df_r2['R2_Reduced']).clip(lower=0)

    pcol = _pick_p_col(df_lrt, prefer_fdr=True)
    if pcol and {'Comparison', *key}.issubset(df_lrt.columns):
        comparison = df_lrt['Comparison'].astype(str).str.lower()
        mask = comparison.str.contains('exponent') & comparison.str.contains('osc')
        # One p-value per key keeps the left merge from multiplying rows; the
        # first match is kept, which is also the row the plot reads.
        lrt_small = (df_lrt.loc[mask, key + [pcol]]
                     .rename(columns={pcol: 'P_any'})
                     .drop_duplicates(subset=key))
        df_ortho = pd.merge(df_r2, lrt_small, on=key, how='left')
    else:
        df_ortho = df_r2.copy()
        df_ortho['P_any'] = np.nan
    # A missing or unparsable p-value compares False, i.e. "not significant".
    df_ortho['is_significant'] = pd.to_numeric(df_ortho['P_any'], errors='coerce') < P_VALUE_THRESHOLD

    plot_curated_orthogonal_value(df_ortho,
                                  os.path.join(output_dir, f"{patient_hemisphere_id}_Curated_Orthogonal_Value_REGION.png"))
    return {'mode': 'region', 'selections': None}

# ---------- Execute ----------
# Entry point for Cell 8. At module/notebook top level `locals()` is the global
# namespace, so the second condition checks that `patient_hemisphere_id` exists.
if __name__ == "__main__" and 'patient_hemisphere_id' in locals():
    # NOTE(review): `analysis_session_plot_folder_step4` is defined outside this
    # chunk. The folder and file names below repeat the ones used for the
    # "Global" results in Cell 7.
    global_results_dir = os.path.join(analysis_session_plot_folder_step4, "MLR_byRegion_V4")
    global_mlr_file = f"{patient_hemisphere_id}_MLR_Coefficients_byRegion_Step6.csv"
    global_lrt_file = f"{patient_hemisphere_id}_MLR_LRT_Results_byRegion_Step6.csv"
    # Curated figures go into a sub-folder next to the result CSVs.
    curated_output_dir = os.path.join(global_results_dir, "Curated_Result_Plots")
    os.makedirs(curated_output_dir, exist_ok=True)

    # Either None or a dict describing the mode that was used.
    user_selections = run_curated_visualizations(global_results_dir, global_mlr_file, global_lrt_file, curated_output_dir)

    # Only a console message depends on the result; nothing is written to disk
    # here for `user_selections`.
    if user_selections is not None:
        print("\n--- Cell 8 (Region-first): Complete. 'user_selections' saved. ---")
    else:
        print("\n--- Cell 8 (Region-first): Finished with no selections saved. ---")
else:
    print("Skipping plot generation as script is not being run directly or key variables are missing.")



--- Cell 8 (Region-first): Starting Curated Visualization (STN vs M1) ---

Region column found — generating STN vs M1 plots (no contact selection).

Generating Curated Dot-and-Whisker: 'Exponent' (Region)…
Saved plot to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/MLR_byRegion_V4/Curated_Result_Plots/COHORT_RCS02_05_06_Curated_Exponent_Coefficients_REGION.png

Generating Curated Dot-and-Whisker: 'Offset' (Region)…
Saved plot to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/MLR_byRegion_V4/Curated_Result_Plots/COHORT_RCS02_05_06_Curated_Offset_Coefficients_REGION.png
Generating Curated Stacked Bar (Added R² by Exponent)…
Saved plot to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_15

In [19]:
# -*- coding: utf-8 -*-
# --- Cell 9 (Revised V3, Region-first): Individual Box Plots with Bootstrapped Median CIs ---
# Improvements:
# • Region-first (STN vs M1); falls back to per-contact if Region cannot be derived
# • IQR outlier filtering per (Group × State)
# • Bootstrapped 95% CI for the median overlaid on each box
# • Optional Kruskal–Wallis across states (per group) + CSV exports

print("\n--- Cell 9 (Revised V3, Region-first): Generating Exponent Box Plots with Median CIs ---")

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import seaborn as sns
from scipy.stats import kruskal
import scikit_posthocs as sp
import warnings

# ---------------- Config ----------------
# Silence two specific warning messages for the remainder of the session.
warnings.filterwarnings("ignore", message="invalid value encountered in scalar divide")
warnings.filterwarnings("ignore", message="Confidence interval might not be reliable for bootstrap samples with fewer than 50 elements.")

# NOTE(review): P_VALUE_THRESHOLD is not read by the Cell 9 code in this chunk.
P_VALUE_THRESHOLD = 0.05
# A state needs at least this many values to take part in the Kruskal–Wallis test.
MIN_SAMPLES_FOR_GROUP_COMPARISON = 5

# Columns / labels from earlier cells
CLINICAL_STATE_COL   = 'Clinical_State_2min_Window'
CHANNEL_DISPLAY_COL  = 'Channel_Display'
FOOOF_FREQ_BAND_COL  = 'FreqRangeLabel'
AP_COL               = 'Exponent_BestModel'   # aperiodic exponent column
AP_LABEL             = 'Aperiodic Exponent'

# Frequency-range labels, processed in this order (one pass of the main loop each).
ORDERED_FREQ_LABELS = ["LowFreq", "MidFreq", "WideFreq"]
# States kept for plotting, in x-axis order; rows with any other state are dropped.
CELL9_TARGET_STATES_ORDERED = ["Sleep", "Immobile", "Non-Dyskinetic Mobile", "Transitional Mobile", "Dyskinetic Mobile"]
# Box/strip colour per state.
CELL9_STATE_COLORS = {
    'Sleep': '#4169E1', 'Immobile': '#40E0D0', 'Non-Dyskinetic Mobile': '#32CD32',
    'Transitional Mobile': '#FFD700', 'Dyskinetic Mobile': '#FF6347'
}

# Region mapping (contact -> region); two naming styles of contact label are covered.
CONTACT_TO_REGION = {
    'Contact_2_0': 'STN', 'Contact_3_0': 'STN', 'Contact_3_1': 'STN',
    'STN_DBS_2-0': 'STN', 'STN_DBS_3-1': 'STN',
    'Contact_10_8': 'M1', 'Contact_11_9': 'M1',
    'Cortical_ECoG_10-8': 'M1', 'Cortical_ECoG_11-9': 'M1',
}
# Category order applied to the Region column when grouping by region.
REGION_ORDER = ['STN', 'M1']

# Short tick labels for states
STATE_TICK_SHORT = {
    "Sleep": "Sleep", "Immobile": "Imm", "Non-Dyskinetic Mobile": "NDM",
    "Transitional Mobile": "TM", "Dyskinetic Mobile": "DM"
}

# --- Plot style (bigger but readable) ---
# rcParams are process-wide, so these sizes replace any set by earlier cells.
plt.rcParams.update({
    'font.size': 18, 'axes.labelsize': 20, 'axes.titlesize': 22,
    'xtick.labelsize': 16, 'ytick.labelsize': 16, 'legend.fontsize': 16,
    'legend.title_fontsize': 16,
})
BOX_FILL_ALPHA = 0.7            # opacity of the box faces
BOXPLOT_LINE_THICKNESS = 2.0    # line width of box, median, whiskers and caps
DOT_ALPHA = 0.5                 # opacity of the overlaid 10-minute average points

# ---------------- Helpers ----------------
def bootstrap_median_ci(data, n_boot=2000, ci=0.95, random_state=None):
    """Percentile-bootstrap confidence interval for the median of *data*.

    NaNs are discarded first. Returns ``(median, ci_low, ci_high)`` as floats,
    or three NaNs when fewer than two values remain. *random_state* is passed
    to ``numpy.random.default_rng``, so an integer seed makes the result
    repeatable.
    """
    values = np.asarray(data, dtype=float)
    values = values[~np.isnan(values)]
    n_obs = values.size
    if n_obs < 2:
        return (np.nan,) * 3

    generator = np.random.default_rng(random_state)
    # One row per bootstrap replicate, each the same size as the sample.
    resamples = generator.choice(values, size=(n_boot, n_obs), replace=True)
    replicate_medians = np.median(resamples, axis=1)

    lower_pct = (1-ci)/2 * 100.0
    upper_pct = (1+ci)/2 * 100.0
    ci_low = np.percentile(replicate_medians, lower_pct)
    ci_high = np.percentile(replicate_medians, upper_pct)
    return float(np.median(values)), float(ci_low), float(ci_high)

def filter_outliers_iqr_per_group(df, group_cols, value_col, factor=1.5):
    """Remove IQR outliers of *value_col* within each group of *group_cols*.

    A value is an outlier when it lies below ``Q1 - factor*IQR`` or above
    ``Q3 + factor*IQR`` of its own group. Rows whose grouping key is NaN form a
    group of their own (``dropna=False``).

    Parameters
    ----------
    df : pandas.DataFrame
        Input table; it is not modified.
    group_cols : list of str
        Columns whose unique combinations define the groups.
    value_col : str
        Numeric column that is screened.
    factor : float, default 1.5
        Multiplier of the inter-quartile range.

    Returns
    -------
    (pandas.DataFrame, int)
        The rows that are neither outliers nor NaN in *value_col*, in their
        original order with all columns kept, and the number of outliers
        removed. Rows that were already NaN are dropped but not counted.
    """
    df = df.copy()
    if df.empty:
        return df, 0

    # Broadcast per-group quartiles back onto the rows. This avoids mutating
    # group frames inside groupby.apply, which pandas documents as unsupported,
    # and it does not depend on apply handing the grouping columns to the
    # callback.
    grouped_values = df.groupby(group_cols, dropna=False)[value_col]
    q1 = grouped_values.transform(lambda s: s.quantile(0.25))
    q3 = grouped_values.transform(lambda s: s.quantile(0.75))
    iqr = q3 - q1
    lower = q1 - factor * iqr
    upper = q3 + factor * iqr

    # Comparisons involving NaN are False, so missing values are never counted.
    is_outlier = (df[value_col] < lower) | (df[value_col] > upper)
    removed = int(is_outlier.sum())

    kept = df.loc[~is_outlier].dropna(subset=[value_col])
    return kept, removed

def ensure_categorical(df, col, order=None):
    """Return a copy of *df* in which *col* has a categorical dtype.

    With a non-empty *order*, the categories are the entries of *order* that
    actually occur in the column, in that order and flagged as ordered. If
    none occur, the sorted observed values are used instead. Values outside
    the chosen categories become NaN. Without *order*, the column is converted
    to an unordered categorical unless it already is categorical.
    """
    df = df.copy()
    if order:
        # Compute the observed values once instead of once per entry of `order`.
        observed = set(df[col].dropna().unique())
        present = [c for c in order if c in observed]
        if not present:
            present = sorted(observed)
        df[col] = pd.Categorical(df[col], categories=present, ordered=True)
    elif not isinstance(df[col].dtype, pd.CategoricalDtype):
        # isinstance on the dtype replaces the deprecated
        # pd.api.types.is_categorical_dtype helper.
        df[col] = pd.Categorical(df[col])
    return df

# ---------------- Data prep ----------------
# Cell 9 driver: for every frequency range and every group (region, or contact
# as a fallback) draw one box plot of the aperiodic exponent across clinical
# states, overlay a bootstrap CI of the median, run a Kruskal–Wallis test
# across states, and finally export the collected numbers as CSV.
if 'master_df_step4' not in locals() or master_df_step4 is None or master_df_step4.empty:
    print("ERROR: master_df_step4 is not available. Please run previous cells.")
else:
    # Work on a copy so the shared master table is left untouched by this cell.
    df0 = master_df_step4.copy()

    # Restrict to target states (keep order)
    df0 = df0[df0[CLINICAL_STATE_COL].isin(CELL9_TARGET_STATES_ORDERED)].copy()

    # Derive Region (if not already present)
    if 'Region' not in df0.columns:
        # Start from Channel_Display when possible
        if CHANNEL_DISPLAY_COL in df0.columns:
            df0['Region'] = df0[CHANNEL_DISPLAY_COL].map(CONTACT_TO_REGION)
        else:
            df0['Region'] = np.nan

    # If Region is still missing for some rows, try the 'Contact' column
    if df0['Region'].isna().any() and 'Contact' in df0.columns:
        df0.loc[df0['Region'].isna(), 'Region'] = df0.loc[df0['Region'].isna(), 'Contact'].map(CONTACT_TO_REGION)

    # Build output folder
    plot_dir = os.path.join(analysis_session_plot_folder_step4, "Exponent_BoxPlots_with_MedianCI_REGION")
    os.makedirs(plot_dir, exist_ok=True)
    print(f"  Plots will be saved to: {plot_dir}")

    # Timestamp used to form the 10-minute averages of the optional point overlay
    if 'datetime_for_avg' not in df0.columns and 'Aligned_PKG_UnixTimestamp' in df0.columns:
        df0['datetime_for_avg'] = pd.to_datetime(df0['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce')

    # Collect stats for CSVs
    summary_rows = []
    kw_rows = []

    # ---------------- Main loop (Region-first; falls back to Channel if Region is all NaN) ---------------
    group_dim = 'Region' if ('Region' in df0.columns and df0['Region'].notna().any()) else CHANNEL_DISPLAY_COL
    print(f"  Grouping by: {group_dim}")

    # Respect ordering
    if group_dim == 'Region':
        df0 = ensure_categorical(df0, 'Region', REGION_ORDER)
    else:
        # for contacts, alphabetical
        df0 = ensure_categorical(df0, CHANNEL_DISPLAY_COL, None)

    # Frequency loop
    for freq_label in ORDERED_FREQ_LABELS:
        df_f = df0[df0[FOOOF_FREQ_BAND_COL] == freq_label].copy()
        if df_f.empty:
            print(f"  No data for {freq_label}; skipping.")
            continue

        # Outlier filter per (Group × State)
        before_n = len(df_f)
        df_f, n_removed = filter_outliers_iqr_per_group(
            df_f, group_cols=[group_dim, CLINICAL_STATE_COL], value_col=AP_COL, factor=1.5
        )
        if before_n > 0:
            print(f"  {freq_label}: removed {n_removed} outliers ({(n_removed/max(before_n,1))*100:.1f}%) via IQR per ({group_dim}×State).")

        if df_f.empty:
            print(f"  No valid data after filtering for {freq_label}; skipping.")
            continue

        # Per-group plotting (one figure per group)
        for grp in df_f[group_dim].dropna().unique():
            df_g = df_f[df_f[group_dim] == grp].copy()
            if df_g.empty:
                continue

            # order states
            df_g = ensure_categorical(df_g, CLINICAL_STATE_COL, CELL9_TARGET_STATES_ORDERED)

            # Optional 10-min averaged points. '10min' is the non-deprecated
            # spelling of the former '10T' offset alias.
            if 'datetime_for_avg' in df_g.columns and not df_g['datetime_for_avg'].isnull().all():
                pts = (df_g.set_index('datetime_for_avg')
                           .groupby([pd.Grouper(freq='10min'), CLINICAL_STATE_COL])[[AP_COL]]
                           .mean().dropna().reset_index())
            else:
                pts = pd.DataFrame()

            # ---- Figure ----
            fig, ax = plt.subplots(figsize=(9, 9))
            sns.boxplot(
                data=df_g, x=CLINICAL_STATE_COL, y=AP_COL,
                order=CELL9_TARGET_STATES_ORDERED, palette=CELL9_STATE_COLORS,
                showfliers=False, width=0.55, ax=ax,
                boxprops={'alpha': BOX_FILL_ALPHA, 'linewidth': BOXPLOT_LINE_THICKNESS},
                medianprops={'linewidth': BOXPLOT_LINE_THICKNESS, 'color':'black'},
                whiskerprops={'linewidth': BOXPLOT_LINE_THICKNESS},
                capprops={'linewidth': BOXPLOT_LINE_THICKNESS}
            )

            if not pts.empty:
                sns.stripplot(
                    data=pts, x=CLINICAL_STATE_COL, y=AP_COL,
                    order=CELL9_TARGET_STATES_ORDERED, palette=CELL9_STATE_COLORS,
                    jitter=0.15, alpha=DOT_ALPHA, size=4.0, ax=ax, legend=False
                )

            # ---- Bootstrap CI per state (also feeds the CSV summary) ----
            # NOTE(review): no seed is passed to bootstrap_median_ci, so the CI
            # bounds differ slightly from run to run.
            xticks = ax.get_xticks()
            for i, state in enumerate(CELL9_TARGET_STATES_ORDERED):
                v = df_g.loc[df_g[CLINICAL_STATE_COL] == state, AP_COL].dropna()
                if len(v) == 0:
                    continue
                if len(v) >= 10:
                    # Enough data: bootstrap the CI and draw it on the box.
                    med, lo, hi = bootstrap_median_ci(v, n_boot=2000, ci=0.95)
                    if np.isfinite(med):
                        ax.errorbar(
                            x=xticks[i], y=med, yerr=[[med - lo], [hi - med]],
                            fmt='o', color='black', ecolor='black',
                            capsize=6, elinewidth=1.8, markersize=6, zorder=10
                        )
                else:
                    # Too few points for a CI: report the plain median only.
                    med, lo, hi = float(np.median(v)), np.nan, np.nan
                summary_rows.append({
                    'GroupDim': group_dim, 'Group': grp, 'FreqBand': freq_label,
                    'State': state, 'N': int(len(v)), 'Median': med, 'CI_low': lo, 'CI_high': hi
                })

            # ---- Stats: Kruskal–Wallis across states within this group ----
            stat_text = ""
            groups = [df_g.loc[df_g[CLINICAL_STATE_COL]==s, AP_COL].dropna().values for s in CELL9_TARGET_STATES_ORDERED]
            groups_nonempty = [g for g in groups if len(g) >= MIN_SAMPLES_FOR_GROUP_COMPARISON]
            if len(groups_nonempty) >= 2:
                try:
                    H, p = kruskal(*groups_nonempty)
                except ValueError as kw_err:
                    # kruskal can reject degenerate input (for example when all
                    # values are identical); skip the test for this figure
                    # instead of aborting the whole loop with the figure open.
                    print(f"  Kruskal–Wallis skipped for {grp} [{freq_label}]: {kw_err}")
                else:
                    stat_text = f"Kruskal–Wallis: H={H:.2f}, p={p:.3g}"
                    kw_rows.append({
                        'GroupDim': group_dim, 'Group': grp, 'FreqBand': freq_label,
                        'Test': 'Kruskal-Wallis', 'H': H, 'p_value': p,
                        'N_groups_ge_min': len(groups_nonempty)
                    })
                # Optional: Dunn posthoc (only if you want; can be commented out)
                # try:
                #     dunn = sp.posthoc_dunn(df_g[[CLINICAL_STATE_COL, AP_COL]].dropna(), val_col=AP_COL, group_col=CLINICAL_STATE_COL, p_adjust='fdr_bh')
                #     # You could write dunn to CSV per (grp,freq) if desired.
                # except Exception:
                #     pass

            # ---- Finalize ----
            ax.set_title(f"{AP_LABEL} — {grp} [{freq_label}]\n{stat_text}", pad=12)
            ax.set_ylabel("Exponent")
            ax.set_xlabel("Clinical State")
            ax.tick_params(axis='y', labelsize=16)
            ax.set_xticklabels([STATE_TICK_SHORT.get(s, s) for s in CELL9_TARGET_STATES_ORDERED],
                               rotation=35, ha="right")

            # Legend (fixed, hand-built entries; not derived from the drawn artists)
            legend_elements = [
                mpatches.Patch(facecolor='grey', alpha=BOX_FILL_ALPHA, label='Median & IQR (box)'),
                mlines.Line2D([], [], color='grey', marker='o', linestyle='None', markersize=6, label='10-min average'),
                mlines.Line2D([], [], color='black', marker='_', markersize=12, linestyle='None', label='Median 95% CI (bootstrap)')
            ]
            ax.legend(handles=legend_elements, loc='upper right', fontsize=14, frameon=True)

            fig.tight_layout()
            safe_grp  = str(grp).replace(' ', '_').replace('-', '_')
            safe_freq = str(freq_label).replace(' ', '_')
            out_png = os.path.join(plot_dir, f"{patient_hemisphere_id}_{group_dim}_{safe_grp}_{safe_freq}_Exponent_with_CI.png")
            fig.savefig(out_png, dpi=300)
            plt.close(fig)

    # ---------------- Write CSV outputs ----------------
    if summary_rows:
        df_summary = pd.DataFrame(summary_rows)
        out_csv = os.path.join(plot_dir, f"{patient_hemisphere_id}_Cell9_Medians_CI_by_{group_dim}.csv")
        df_summary.to_csv(out_csv, index=False)
        print(f"  Saved medians/CI table to: {out_csv}")

    if kw_rows:
        df_kw = pd.DataFrame(kw_rows)
        out_kw = os.path.join(plot_dir, f"{patient_hemisphere_id}_Cell9_Kruskal_by_{group_dim}.csv")
        df_kw.to_csv(out_kw, index=False)
        print(f"  Saved Kruskal–Wallis summary to: {out_kw}")

    print("\n--- Cell 9 (Revised V3, Region-first): Completed. ---")



--- Cell 9 (Revised V3, Region-first): Generating Exponent Box Plots with Median CIs ---
  Plots will be saved to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/Exponent_BoxPlots_with_MedianCI_REGION
  Grouping by: Region
  LowFreq: removed 554 outliers (5.5%) via IQR per (Region×State).
  MidFreq: removed 73 outliers (0.7%) via IQR per (Region×State).
  WideFreq: removed 209 outliers (2.1%) via IQR per (Region×State).
  Saved medians/CI table to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/Exponent_BoxPlots_with_MedianCI_REGION/COHORT_RCS02_05_06_Cell9_Medians_CI_by_Region.csv
  Saved Kruskal–Wallis summary to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_plots_20250822_155044/Exponent_BoxPlots_w

In [20]:
# -*- coding: utf-8 -*-
# --- Cell 11: Generate Final Data Table for Cross-Subject Analysis (Input for Step 5) ---
# This cell prepares the output from Step 4 to be used as input for Step 5.
# The master_df_step4 already contains all necessary information, including
# aperiodic metrics for EACH FreqRangeLabel, LEDD, Beta, Gamma.

print("\n--- Cell 11: Preparing Data Table for Step 5 (Cross-Subject Analysis) ---")

if master_df_step4 is None or master_df_step4.empty:
    print("master_df_step4 not available or empty. Cannot generate final table for Step 5.")
else:
    # Columns to include in the output for Step 5.
    # They should match 'master_table_columns' from Step 3 Cell 2, plus
    # 'UserSessionName' from Step 3 Cell 8. That list must be kept in sync with
    # the columns Step 3 actually produces, so below we only select the
    # intended columns that are ACTUALLY PRESENT in master_df_step4.

    # --- Make script robust to optional columns and add Region helper ---
    # If the aggregated-state column name isn't defined upstream, define a
    # default name (and a placeholder column) to avoid a NameError further down.
    try:
        _ = CLINICAL_STATE_AGGREGATED_COL
    except NameError:
        CLINICAL_STATE_AGGREGATED_COL = 'Clinical_State_Aggregated'
        if CLINICAL_STATE_AGGREGATED_COL not in master_df_step4.columns:
            master_df_step4[CLINICAL_STATE_AGGREGATED_COL] = pd.NA

    # If a human-readable datetime string is not present, synthesize it from
    # the unix timestamp (seconds, interpreted as UTC).
    # The timezone is dropped with .dt.tz_localize(None): casting a tz-aware
    # Series with .astype('datetime64[ns]') raises TypeError on pandas >= 2.0.
    if ('Aligned_PKG_DateTime_Str' not in master_df_step4.columns) and ('Aligned_PKG_UnixTimestamp' in master_df_step4.columns):
        master_df_step4['Aligned_PKG_DateTime_Str'] = pd.to_datetime(
            master_df_step4['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce'
        ).dt.tz_localize(None).astype(str)

    # Add Region if not present (derived from the channel display label).
    if 'Region' not in master_df_step4.columns:
        def _to_region(lbl):
            """Map a channel display label to 'STN', 'M1', or pd.NA."""
            if pd.isna(lbl):
                return pd.NA
            s = str(lbl)
            # 'STN' is checked first, so it wins if a label matches both tests.
            if 'STN' in s:
                return 'STN'
            if 'Cortical' in s or 'ECoG' in s or 'M1' in s:
                return 'M1'
            return pd.NA
        master_df_step4['Region'] = master_df_step4[CHANNEL_DISPLAY_COL].map(_to_region)

    # ---- Your original intended columns list (unchanged) ----
    intended_step5_cols = [
        'SessionID', 'Hemisphere', 'Channel', CHANNEL_DISPLAY_COL,
        'Neural_Segment_Start_Unixtime', 'Neural_Segment_End_Unixtime',
        'Neural_Segment_Duration_Sec', 'FS',
        'Aligned_PKG_UnixTimestamp', 'Aligned_PKG_DateTime_Str',
        CLINICAL_STATE_COL, CLINICAL_STATE_AGGREGATED_COL,
        'Aligned_BK', 'Aligned_DK', 'Aligned_Tremor_Score', 'Aligned_Tremor',
        'Total_Daily_LEDD_mg',
        'Beta_Peak_Power_at_DominantFreq', 'Gamma_Peak_Power_at_DominantFreq',
        FOOOF_FREQ_BAND_COL, 'FreqLow', 'FreqHigh',
        'BestModel_AperiodicMode',
        'Offset_BestModel', 'Knee_BestModel', 'Exponent_BestModel',
        'R2_BestModel', 'Error_BestModel', 'Num_Peaks_BestModel',
        # Convenience field for Step 5 groupings
        'Region',
    ]

    # Keep only the intended columns that exist, preserving the intended order.
    final_table_cols_step5_existing = [col for col in intended_step5_cols if col in master_df_step4.columns]

    if not final_table_cols_step5_existing:
        print("Warning: No columns identified for the Step 5 data table. It will be empty.")
        final_data_table_for_step5 = pd.DataFrame()
    else:
        final_data_table_for_step5 = master_df_step4[final_table_cols_step5_existing].copy()

        # Add 'UserSessionName' which was previously added in Step 3 Cell 8.
        # Here, we re-affirm it as the patient_hemisphere_id for this file.
        if 'UserSessionName' not in final_data_table_for_step5.columns:
            final_data_table_for_step5.insert(0, 'UserSessionName', patient_hemisphere_id)
        else:  # If it was somehow carried over from a loaded file that already had it
            final_data_table_for_step5['UserSessionName'] = patient_hemisphere_id

        # Optional: Sort the table (only by the sort keys that are present)
        sort_by_cols_step5 = ['UserSessionName', 'Aligned_PKG_UnixTimestamp', CHANNEL_DISPLAY_COL, FOOOF_FREQ_BAND_COL]
        sort_by_cols_step5_existing = [col for col in sort_by_cols_step5 if col in final_data_table_for_step5.columns]
        if sort_by_cols_step5_existing:
            final_data_table_for_step5.sort_values(by=sort_by_cols_step5_existing, inplace=True, ignore_index=True)

        print(f"  Final data table for Step 5 created with {final_data_table_for_step5.shape[0]} rows and {final_data_table_for_step5.shape[1]} columns.")
        print(f"  Columns included: {final_data_table_for_step5.columns.tolist()}")

    # Define filename and save. The file goes in the root of the Step 4
    # analysis folder (step4_analysis_root_folder) rather than the per-run plot
    # folder, so that Step 5 can glob the tables of multiple subjects from one place.
    output_filename_for_step5 = f"{patient_hemisphere_id}_CrossSubjectAnalysis_DataTable_{current_datetime_str_step4}.csv"
    output_path_for_step5 = os.path.join(step4_analysis_root_folder, output_filename_for_step5)

    # Best-effort save: a write failure is reported but does not abort the notebook.
    try:
        final_data_table_for_step5.to_csv(output_path_for_step5, index=False)
        print(f"  Successfully saved final data table for Step 5 input to: {output_path_for_step5}")
        print("\n  Sample of this final data table (first 5 rows):")
        print(final_data_table_for_step5.head())
    except Exception as e_save_final_step4:
        print(f"  ERROR saving the final data table for Step 5 input: {e_save_final_step4}")

print(f"\n--- Cell 11: Final Data Table generation for {patient_hemisphere_id} complete ---")
print(f"\n--- All Step 4 processing for {patient_hemisphere_id} complete. Outputs are in {analysis_session_plot_folder_step4} and {step4_analysis_root_folder} ---")


--- Cell 11: Preparing Data Table for Step 5 (Cross-Subject Analysis) ---
  Final data table for Step 5 created with 32436 rows and 28 columns.
  Columns included: ['UserSessionName', 'SessionID', 'Hemisphere', 'Channel_Display', 'Neural_Segment_Start_Unixtime', 'Neural_Segment_End_Unixtime', 'Neural_Segment_Duration_Sec', 'FS', 'Clinical_State_2min_Window', 'Clinical_State_Aggregated', 'Aligned_BK', 'Aligned_DK', 'Aligned_Tremor_Score', 'Aligned_Tremor', 'Total_Daily_LEDD_mg', 'Beta_Peak_Power_at_DominantFreq', 'Gamma_Peak_Power_at_DominantFreq', 'FreqRangeLabel', 'FreqLow', 'FreqHigh', 'BestModel_AperiodicMode', 'Offset_BestModel', 'Knee_BestModel', 'Exponent_BestModel', 'R2_BestModel', 'Error_BestModel', 'Num_Peaks_BestModel', 'Region']
  Successfully saved final data table for Step 5 input to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_CrossSubjectAnalysis_DataTable_20250822_155044.csv

  Sample o

In [21]:
# -*- coding: utf-8 -*-
# --- Cell 12 (New, Region-aware): Generate CURATED Final Data Table for Cross-Subject Analysis ---
# Works with Cell 8 "Region-first" (no contact selection) OR with contact-based selections.

# NOTE: Cell 1 already imports these three modules under the same names, so
# these re-imports are redundant when the notebook is run top to bottom.
# numpy (np) is not referenced anywhere else in this cell.
import pandas as pd
import numpy as np
import os

# Banner marking the start of Cell 12 in the notebook output.
print("\n--- Cell 12: Preparing CURATED Data Table for Step 5 (Cross-Subject Analysis) ---")

# ---------- Helpers & robustness ----------
def _ensure_agg_and_datetime(df):
    """Add the optional aggregated-state and datetime-string columns to ``df``.

    ``df`` is modified in place and nothing is returned.

    * Aggregated clinical state: if the module-level name
      ``CLINICAL_STATE_AGGREGATED_COL`` is undefined it is set to
      ``'Clinical_State_Aggregated'``; if ``df`` has no column of that name,
      one is created and filled with ``pd.NA``.
    * ``'Aligned_PKG_DateTime_Str'``: created only when it is missing and
      ``'Aligned_PKG_UnixTimestamp'`` is present. The timestamps are read as
      seconds since the epoch in UTC (unparseable values are coerced to NaT)
      and rendered as timezone-naive strings.
    """
    # Aggregated state name might not be defined upstream; create if missing.
    try:
        _ = CLINICAL_STATE_AGGREGATED_COL
    except NameError:
        # Written through globals() so the name lookups below (and in later
        # cells) resolve without making it a local of this function.
        globals()['CLINICAL_STATE_AGGREGATED_COL'] = 'Clinical_State_Aggregated'
    if CLINICAL_STATE_AGGREGATED_COL not in df.columns:
        df[CLINICAL_STATE_AGGREGATED_COL] = pd.NA

    # Human-readable datetime string from Unix ts (UTC).
    if ('Aligned_PKG_DateTime_Str' not in df.columns) and ('Aligned_PKG_UnixTimestamp' in df.columns):
        utc_times = pd.to_datetime(
            df['Aligned_PKG_UnixTimestamp'], unit='s', utc=True, errors='coerce'
        )
        # Drop the timezone explicitly: .astype('datetime64[ns]') on a tz-aware
        # Series raises TypeError on pandas >= 2.0.
        df['Aligned_PKG_DateTime_Str'] = utc_times.dt.tz_localize(None).astype(str)

def _ensure_region(df):
    """Return ``df`` with a 'Region' column, deriving it when absent.

    When 'Region' is missing, each value of ``df[CHANNEL_DISPLAY_COL]`` is
    classified as 'STN' (label contains 'STN'), 'M1' (label contains
    'Cortical', 'ECoG' or 'M1'), or ``pd.NA`` (missing or unrecognised label).
    The column is added in place and the same frame is returned.
    """
    if 'Region' in df.columns:
        return df

    cortical_markers = ('Cortical', 'ECoG', 'M1')

    def _classify(label):
        if pd.isna(label):
            return pd.NA
        text = str(label)
        # The STN test runs first, so it takes precedence over cortical markers.
        if 'STN' in text:
            return 'STN'
        if any(marker in text for marker in cortical_markers):
            return 'M1'
        return pd.NA

    df['Region'] = df[CHANNEL_DISPLAY_COL].map(_classify)
    return df

def _pick_pcol(df_lrt):
    """Return the first known p-value column name found in ``df_lrt.columns``.

    Candidate names are tried in the fixed order listed below; ``None`` is
    returned when none of them is present.
    """
    candidates = ('P_value_FDR', 'P_FDR', 'p_adj', 'p_fdr', 'P_value', 'p_value', 'pval', 'P')
    return next((name for name in candidates if name in df_lrt.columns), None)

# ---------- Prereqs ----------
# Build the curated Step 5 table from master_df_step4 and write it to CSV.
# Rows are chosen either from explicit per-symptom channel picks
# (user_selections from Cell 8) or, when there are none, by Region (STN / M1).
if 'master_df_step4' not in locals() or master_df_step4 is None or master_df_step4.empty:
    print("\nERROR: master_df_step4 not available or empty. Cannot generate curated table.")
else:
    # Make sure optional columns exist (both helpers modify the frame in place)
    _ensure_agg_and_datetime(master_df_step4)
    master_df_step4 = _ensure_region(master_df_step4)

    # Decide how to curate based on user_selections from Cell 8
    sel_obj = None
    if 'user_selections' in locals():
        sel_obj = user_selections

    mode = None
    sel_map = None
    if isinstance(sel_obj, dict) and 'mode' in sel_obj:
        mode = sel_obj.get('mode')
        sel_map = sel_obj.get('selections') if isinstance(sel_obj.get('selections'), dict) else None
    elif isinstance(sel_obj, dict):
        # Older Cell 8 returned the raw selections dict
        sel_map = sel_obj

    curated_data_rows = []

    if isinstance(sel_map, dict) and any(sel_map.values()):
        # ----- CONTACT FALLBACK MODE (explicit per-symptom channel picks) -----
        print("   Using contact selections from Cell 8 to build curated dataset...")
        for symptom_dv, choices in sel_map.items():
            # any() above only guarantees ONE truthy entry; the others may be
            # None/empty (or not a mapping at all), so skip those instead of
            # failing on choices.items().
            if not isinstance(choices, dict):
                continue
            for group_name, chosen_channel in choices.items():
                sub = master_df_step4[master_df_step4[CHANNEL_DISPLAY_COL] == chosen_channel].copy()
                if sub.empty:
                    continue
                sub['BinaryChannel'] = group_name  # 'STN' or 'M1'
                curated_data_rows.append(sub)
        curation_mode_used = "contact_fallback"
    else:
        # ----- REGION MODE (auto-curate by Region) -----
        print("   No explicit contact selections found (Region mode).")
        print("   Curating by Region: including all rows where Region ∈ {STN, M1} and setting BinaryChannel = Region.")
        sub = master_df_step4[master_df_step4['Region'].isin(['STN', 'M1'])].copy()
        # Only keep a non-empty subset; otherwise the "no matching data" branch
        # below could never trigger in Region mode and an empty CSV would be written.
        if not sub.empty:
            sub['BinaryChannel'] = sub['Region']
            curated_data_rows.append(sub)
        curation_mode_used = "region"

    if not curated_data_rows:
        print("   ERROR: No matching data found for curation.")
    else:
        # The same channel can be picked for several symptoms in contact mode,
        # so identical rows are collapsed after concatenation.
        df_curated = pd.concat(curated_data_rows, ignore_index=True).drop_duplicates().reset_index(drop=True)
        print(f"   Curated dataframe built ({curation_mode_used}) with {df_curated.shape[0]} rows.")

        # ---------- Select & format columns ----------
        intended_cols = [
            'UserSessionName', 'SessionID', 'Hemisphere',
            'BinaryChannel',           # new curated field (STN/M1)
            'Region',                  # keep Region for convenience
            CHANNEL_DISPLAY_COL,       # original channel label
            'Aligned_PKG_UnixTimestamp', 'Aligned_PKG_DateTime_Str',
            CLINICAL_STATE_COL, CLINICAL_STATE_AGGREGATED_COL,
            'Aligned_BK', 'Aligned_DK', 'Aligned_Tremor_Score',
            'Total_Daily_LEDD_mg',
            'Beta_Peak_Power_at_DominantFreq', 'Gamma_Peak_Power_at_DominantFreq',
            FOOOF_FREQ_BAND_COL, 'BestModel_AperiodicMode',
            'Offset_BestModel', 'Exponent_BestModel',
            'R2_BestModel', 'Error_BestModel'
        ]
        # Keep only the intended columns that exist, in the intended order.
        keep_cols = [c for c in intended_cols if c in df_curated.columns]
        df_final_curated = df_curated[keep_cols].copy()

        # Ensure UserSessionName (always set to this run's patient_hemisphere_id)
        if 'UserSessionName' not in df_final_curated.columns:
            df_final_curated.insert(0, 'UserSessionName', patient_hemisphere_id)
        else:
            df_final_curated['UserSessionName'] = patient_hemisphere_id

        # Sort for consistency (only by the sort keys that are present)
        sort_candidates = ['UserSessionName', 'BinaryChannel', 'Aligned_PKG_UnixTimestamp', FOOOF_FREQ_BAND_COL]
        sort_present = [c for c in sort_candidates if c in df_final_curated.columns]
        if sort_present:
            df_final_curated.sort_values(by=sort_present, inplace=True, ignore_index=True)

        print(f"   Final curated data table has {df_final_curated.shape[0]} rows and {df_final_curated.shape[1]} columns.")
        print(f"   Columns: {list(df_final_curated.columns)}")

        # ---------- Save ----------
        output_filename_curated = f"{patient_hemisphere_id}_Curated_CrossSubject_DataTable_{current_datetime_str_step4}.csv"
        output_path_curated = os.path.join(step4_analysis_root_folder, output_filename_curated)

        # Best-effort save: a write failure is reported but does not abort the notebook.
        try:
            df_final_curated.to_csv(output_path_curated, index=False)
            print(f"\n   Successfully saved CURATED data table for Step 5 input to: {output_path_curated}")
            print("\n   Sample (first 5 rows):")
            print(df_final_curated.head())
        except Exception as e:
            print(f"   ERROR saving the curated data table: {e}")

print(f"\n--- Cell 12: CURATED Final Data Table generation for {patient_hemisphere_id} complete ---")



--- Cell 12: Preparing CURATED Data Table for Step 5 (Cross-Subject Analysis) ---
   No explicit contact selections found (Region mode).
   Curating by Region: including all rows where Region ∈ {STN, M1} and setting BinaryChannel = Region.
   Curated dataframe built (region) with 32436 rows.
   Final curated data table has 32436 rows and 20 columns.
   Columns: ['UserSessionName', 'SessionID', 'Hemisphere', 'BinaryChannel', 'Region', 'Channel_Display', 'Clinical_State_2min_Window', 'Clinical_State_Aggregated', 'Aligned_BK', 'Aligned_DK', 'Aligned_Tremor_Score', 'Total_Daily_LEDD_mg', 'Beta_Peak_Power_at_DominantFreq', 'Gamma_Peak_Power_at_DominantFreq', 'FreqRangeLabel', 'BestModel_AperiodicMode', 'Offset_BestModel', 'Exponent_BestModel', 'R2_BestModel', 'Error_BestModel']

   Successfully saved CURATED data table for Step 5 input to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/step4_within_subject/COHORT_RCS02_05_06_Curated_CrossSubject_Dat