In [1]:
# -*- coding: utf-8 -*-
# --- Cell 1: Imports and Helper Functions ---

import os
import sys
import numpy as np
import json
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from datetime import datetime, timezone, timedelta
import pytz
import traceback
import random
from scipy import stats
from scipy.signal import find_peaks # For peak finding in flattened spectra


# --- Import FOOOF ---
# Import the spectral model class under the legacy name `FOOOF`, plus the
# aperiodic generator. Any failure here stops the notebook, because every later
# cell depends on these names.
try:
    from specparam.objs import SpectralModel as FOOOF   # drop-in alias
    from specparam.sim.gen import gen_aperiodic
    # Report only the names that are really imported above (there is no 'Bands' import).
    print("Successfully imported 'FOOOF' class and 'gen_aperiodic'.")
except ImportError as e:
    print(f"ERROR: Could not import 'FOOOF': {e}"); sys.exit()
except Exception as e:
    # Anything other than a plain ImportError raised while importing the package.
    print(f"An unexpected error occurred during FOOOF import: {e}"); sys.exit()

# --- Helper Function to add metrics text to plots ---
def add_metrics_to_plot(ax, fm_obj, fit_type="", model_choice=None, best_model_flag=False):
    """Write a three-line summary box in the top-right corner of ``ax``.

    The box lists the model label, the aperiodic parameters and the fit
    metrics (R2 and error) read from ``fm_obj``.

    Box colour:
      * 'lightcoral' when there is no fitted model or ``model_choice == "NoFit"``
      * 'lightgreen' when ``best_model_flag`` is true
      * 'wheat' otherwise

    Parameters
    ----------
    ax : object exposing ``text(...)`` and ``transAxes`` (a Matplotlib Axes).
    fm_obj : fitted model object, or a falsy value / object with ``has_model``
        false when no fit is available.
    fit_type : str, label used when ``model_choice`` is not given.
    model_choice : str or None, label of the selected model.
    best_model_flag : bool, marks this model as the selected best one.
    """
    def _draw_box(text, face_color):
        # Single place that knows the box position and styling.
        ax.text(0.98, 0.98, text, transform=ax.transAxes, fontsize=8,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round,pad=0.3', fc=face_color, alpha=0.75))

    # Guard clause: nothing was fitted, so only placeholder values can be shown.
    if not fm_obj or not fm_obj.has_model:
        if model_choice:
            header = f"Best: NoFit ({model_choice})"
        else:
            header = f"Model: {fit_type or 'N/A'}"
        _draw_box(f"{header}\nAP: NaN\nFit: R2=NaN, Err=NaN", 'lightcoral')
        return

    ap_params = fm_obj.aperiodic_params_
    r_sq = fm_obj.r_squared_
    err = fm_obj.error_
    mode = fm_obj.aperiodic_mode

    if model_choice == "NoFit":
        face_color = 'lightcoral'
    else:
        face_color = 'lightgreen' if best_model_flag else 'wheat'

    header = (f"Model: {model_choice}" if model_choice else f"Type: {fit_type}") + f" (Mode: {mode})"

    if mode == 'fixed' and len(ap_params) == 2:
        ap_text = f"Off={ap_params[0]:.2f}, Exp={ap_params[1]:.2f}"
    elif mode == 'knee' and len(ap_params) == 3:
        # Flag a questionable knee: knee parameter below (lowest fitted frequency + 2)
        # or a near-zero exponent. Only evaluated when the model carries frequencies.
        questionable = False
        if hasattr(fm_obj, 'freqs') and fm_obj.freqs is not None and len(fm_obj.freqs) > 0:
            lowest_fit_freq = fm_obj.freq_range[0]
            questionable = ap_params[1] < (lowest_fit_freq + 2) or abs(ap_params[2]) < 0.1
        suffix = " (Knee?)" if questionable else ""
        ap_text = f"Off={ap_params[0]:.2f}, Knee={ap_params[1]:.1f}, Exp={ap_params[2]:.2f}{suffix}"
    else:
        # Unexpected mode / parameter count: dump the rounded parameters as-is.
        ap_text = f"Params={np.round(ap_params, 2)}"

    _draw_box(f"{header}\nAP: {ap_text}\nFit: R2={r_sq:.3f}, Err={err:.2e}", face_color)

# Matplotlib settings: figure.dpi for on-screen rendering, savefig.dpi for saved files.
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 600

print("Cell 1: Imports and helper functions defined.")

# --- Subject (patient-hemisphere) selection ---
# The ID is the patient code followed by a hemisphere letter, e.g. "RCS05L" or "RCS20R".
# It can be supplied through the STEP3_SUBJECT_ID environment variable so the notebook
# can be re-run for another subject without editing code; the default is unchanged.
# For interactive use, replace with:
# subject_id_input = input("Enter subject ID (e.g., RCS08R or RCS20 for RCS20R/L): ").strip().upper()
subject_id_input = os.getenv("STEP3_SUBJECT_ID", "RCS05L").strip().upper()
print(f"User specified subject-hemisphere ID: {subject_id_input}")
session_id = subject_id_input # session_id in this script context refers to the patient-hemisphere identifier

# Root folder holding the Step 2 outputs and receiving the Step 3 results.
# Overridable via STEP3_PROJECT_BASE_PATH so the notebook is not tied to one machine.
project_base_path = os.getenv("STEP3_PROJECT_BASE_PATH", "/home/jackson/step2_final")

# --- End of Cell 1 ---

Successfully imported 'FOOOF' class, 'Bands', and 'gen_aperiodic'.
Cell 1: Imports and helper functions defined.
User specified subject-hemisphere ID: RCS05L


In [2]:
# -*- coding: utf-8 -*-
# --- Cell 2: Configuration ---

import os
import pandas as pd
import numpy as np  # Required for np.inf

# Use specparam's SpectralModel under the legacy FOOOF alias (no fooof fallback is implemented in this cell)

from specparam.objs import SpectralModel as FOOOF
USING_SPECPARAM = True

# --- Regularization ---
# Use a big default so you can SEE an effect; tune down later.
REG_LAMBDA = float(os.environ.get("FOOOF_REG_LAMBDA", "5000"))


def set_lambda(fm, lam=REG_LAMBDA):
    """Attach the regularization weight λ to a model instance and return it.

    Parameters
    ----------
    fm : model object that accepts new attributes.
    lam : value convertible to ``float``; defaults to ``REG_LAMBDA``.

    Returns
    -------
    float
        The weight that was stored on ``fm``.
    """
    weight = float(lam)
    # The patched fitting algorithm is expected to read this attribute from the instance.
    fm.regularization_weight = weight
    # Human-readable tag so the active λ can be printed later.
    fm._reg_marker = "λ=" + str(weight)
    return weight


# -----------------------------------------------------------------------------
# Path Configuration
# -----------------------------------------------------------------------------
# Directory of this file when run as a .py script; the working directory when
# run from a notebook (where __file__ does not exist).
try:
    current_script_path = os.path.dirname(os.path.abspath(__file__))
except NameError:
    current_script_path = os.getcwd()

print(f"Project base path determined as: {project_base_path}")

# --- Input locations, derived from session_id / project_base_path (Cell 1) ---
# session_id is "<patient code><hemisphere letter>", e.g. "RCS05L".
patient_code = session_id[:-1]
hemisphere_code_short = session_id[-1]
hemisphere_long = {"R": "Right", "L": "Left"}.get(hemisphere_code_short, "UnknownHemi")

# Name of the dedicated Step 2 folder. It is computed but not used for the path
# below: the Step 2 files are currently read straight from project_base_path.
step2_output_folder_name = (
    "step2_preprocessed_data_120s_neural_aligned_"
    f"{patient_code}{hemisphere_code_short}_{hemisphere_long}"
    "_AllSessions_newnaming_tester_right_neww"
)
step2_output_path = project_base_path

print(f"Expecting Step 2 output CSV/JSON files in: {step2_output_path}")

# --- Output location for this run ---
output_version_tag = "neural_pkg_aligned_finalstep3_bushlab5000"
fooof_output_folder_py_base = os.path.join(project_base_path, 'step3_fooof_results_' + output_version_tag)
os.makedirs(fooof_output_folder_py_base, exist_ok=True)
print(f"Base output folder for new FOOOF results: {fooof_output_folder_py_base}")

# --- Analysis defaults (Cell 3 starts from these and may replace them from the params JSON) ---
freq_ranges_default = dict(
    LowFreq=[10, 40],
    MidFreq=[30, 90],
    WideFreq=[10, 90],
)
channels_to_process_default = list()
electrode_labels_default = dict()
neural_hemisphere_default = hemisphere_long
target_neural_segment_duration_default = 30.0
pkg_interpolation_interval_default = 30.0

# --- FOOOF/SpecParam Settings (no regularization passed to __init__) ---
# Keyword arguments shared by both model variants; the regularization weight is
# deliberately absent because it is set on the instance after construction.
common_new_fooof_params = dict(
    peak_width_limits=[2.0, 8.0],
    max_n_peaks=np.inf,
    min_peak_height=0.0,
    peak_threshold=2.0,
    verbose=False,
)
# The two variants differ only in the aperiodic mode.
basic_fooof_settings = dict(common_new_fooof_params, aperiodic_mode='fixed')
knee_fooof_settings = dict(common_new_fooof_params, aperiodic_mode='knee')

# Build a throw-away model for each variant and show its settings. If the model
# class has no print_settings(), fall back to printing the raw keyword dict.
for _heading, _mode_settings in (
    ("\n--- Basic FOOOF Settings (Fixed Mode) ---", basic_fooof_settings),
    ("\n--- Advanced FOOOF Settings (Knee Mode) ---", knee_fooof_settings),
):
    print(_heading)
    tmp = FOOOF(**_mode_settings)
    if hasattr(tmp, "print_settings"):
        tmp.print_settings()
    else:
        print(_mode_settings, f"(λ will be set post-init to {REG_LAMBDA})")

# --- NEW: Beta/Gamma Peak Feature Extraction Parameters ---
# Frequency bands (Hz) for peak features, and the frequency regions used as their baselines.
BETA_BAND = [13, 30]
GAMMA_BAND = [60, 90]
BETA_BASELINE_FREQ_REGION = [10, 12]
GAMMA_BASELINE_FREQ_REGION = [55, 59]
PEAK_SIG_SD_THRESHOLD = 1.0

# --- Master table schema: column order of the results table ---
master_table_columns = (
    # recording identity
    ['SessionID', 'Hemisphere', 'Channel', 'ElectrodeLabel']
    # neural segment timing and FS
    + ['Neural_Segment_Start_Unixtime', 'Neural_Segment_End_Unixtime',
       'Neural_Segment_Duration_Sec', 'FS']
    # spectra serialized as strings
    + ['PSD_Data_Str', 'Frequency_Vector_Str']
    # aligned PKG time stamp and derived clinical state
    + ['Aligned_PKG_UnixTimestamp', 'Aligned_PKG_DateTime_Str', 'Clinical_State_2min_Window']
    # aligned PKG scores
    + ['Aligned_BK', 'Aligned_DK', 'Aligned_Tremor_Score', 'Aligned_Tremor']
    # medication and band-peak features
    + ['Total_Daily_LEDD_mg',
       'Beta_Peak_Power_at_DominantFreq',
       'Gamma_Peak_Power_at_DominantFreq']
    # fitted frequency range
    + ['FreqRangeLabel', 'FreqLow', 'FreqHigh']
    # selected ("best") model
    + ['BestModel_AperiodicMode',
       'Offset_BestModel', 'Knee_BestModel', 'Exponent_BestModel',
       'R2_BestModel', 'Error_BestModel', 'Num_Peaks_BestModel']
    # fixed-mode fit
    + ['Offset_Fixed', 'Exponent_Fixed',
       'R2_Fixed', 'Error_Fixed', 'Num_Peaks_Fixed']
    # knee-mode fit
    + ['Offset_Knee', 'Knee_Knee', 'Exponent_Knee',
       'R2_Knee', 'Error_Knee', 'Num_Peaks_Knee']
    # fit error message
    + ['ErrorMsg_FOOOF']
)

# --- Plotting configuration used by the visualization cells ---
NUM_REPRESENTATIVE_SEGMENTS_PER_CHANNEL = 3
APERIODIC_PLOT_TIME_BIN_MINUTES = 30
# NOTE(review): 'H' and '30T' are the legacy pandas offset aliases; recent pandas
# releases warn about them in favour of 'h' and '30min'. Values kept unchanged here.
HOURLY_AVG_FREQ = 'H'
THIRTY_MIN_AVG_FREQ = '30T'
BOXPLOT_PALETTE = "Set2"
HISTOGRAM_BINS = 50

# Refined mobile sub-states and the single label used when they are pooled.
REFINED_MOBILE_STATES = [
    'Dyskinetic Mobile',
    'Non-Dyskinetic Mobile',
    'Transitional Mobile',
]
AGGREGATED_MOBILE_NAME = 'Mobile (All Types)'

print("\nCell 2: Configuration complete with updated SpecParam settings and new metric parameters.")
# --- End of Cell 2 ---
# --- Environment check: report which specparam code is actually loaded ---
import inspect
import specparam.objs.algorithm as _alg
from specparam.objs import SpectralModel as _SM

# File locations of the imported model class and of the algorithm module.
print("Specparam SpectralModel file:", inspect.getfile(_SM))
print("Algorithm file:", inspect.getfile(_alg))
# Source text of the peak-guess fitting routine, searched by the two checks below.
_src = inspect.getsource(_alg.SpectralFitAlgorithm._fit_peak_guess)
# Plain substring checks on that source text.
# NOTE(review): presumably confirms the locally patched, regularized algorithm is the one installed — confirm.
print("Algorithm uses least_squares?  ", "least_squares(" in _src)
print("Algorithm mentions reg_weight? ", "regularization_weight" in _src)


Project base path determined as: /home/jackson/step2_final
Expecting Step 2 output CSV/JSON files in: /home/jackson/step2_final
Base output folder for new FOOOF results: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000

--- Basic FOOOF Settings (Fixed Mode) ---
                                                                                                  
                                       specparam - SETTINGS                                       
                                                                                                  
                                  Peak Width Limits : [2.0, 8.0]                                  
                                    Max Number of Peaks : inf                                     
                                    Minimum Peak Height : 0.0                                     
                                       Peak Threshold: 2.0                                        
      

In [3]:
# -*- coding: utf-8 -*-
# --- Cell 3: Load Parameters and Data from CSV & JSON, Calculate Daily LEDD ---

import json
import numpy as np
import pandas as pd
import os
import sys
import pytz
import re # Import regular expressions for parsing
from datetime import datetime, timezone, time

# --- Initialize with defaults (from Cell 2) ---
# Working copies of the Cell 2 defaults; the params JSON may replace some of them later.
params_from_matlab = {}
loaded_freq_ranges = freq_ranges_default.copy()
neural_hemisphere = neural_hemisphere_default # Derived from session_id in Cell 2
target_neural_segment_duration = target_neural_segment_duration_default
pkg_interpolation_interval = pkg_interpolation_interval_default

# --- Construct expected filenames and paths (session_id is patient-hemisphere from Cell 1) ---
# The hemisphere part of the filename follows the hemisphere derived from session_id
# ("Left"/"Right") instead of being fixed to "Left", so a right-hemisphere session_id
# (e.g. "RCS20R") resolves to its own Step 2 files.
target_csv_filename = f"Step2_Aligned120sPSDs_{session_id}_{neural_hemisphere}_AllSessions.csv"
params_json_filename = f"Step2_params_120sPSDs_{session_id}_{neural_hemisphere}_AllSessions.json"

csv_file_path = os.path.join(step2_output_path, target_csv_filename)
params_json_path = os.path.join(step2_output_path, params_json_filename)

# --- Check if files exist ---
# The CSV is mandatory; the params JSON is optional (defaults are used without it).
if not os.path.exists(csv_file_path):
    print(f"ERROR: Main data CSV file not found at {csv_file_path}. Cannot proceed.")
    sys.exit()
print(f"\nTarget CSV file: {target_csv_filename}")

if not os.path.exists(params_json_path):
    print(f"Warning: Parameters JSON file not found at {params_json_path}. Using defaults for freq_ranges, segment durations.")
else:
    print(f"Target JSON_params file: {params_json_filename}")

# --- Load Main Data from Step 2 CSV file ---
fooof_input_df = pd.DataFrame()
try:
    fooof_input_df = pd.read_csv(csv_file_path)
    print(f"Successfully loaded CSV from Step 2. Shape: {fooof_input_df.shape}")
    if fooof_input_df.empty:
        print("ERROR: Loaded CSV is empty. Cannot proceed.")
        sys.exit()

    # Keep whatever 'PatientID' the CSV carried (or session_id if that column is absent),
    # then stamp every row with the session_id chosen in Cell 1.
    fooof_input_df['SessionID_original_from_csv'] = fooof_input_df.get('PatientID', session_id)
    fooof_input_df['SessionID'] = session_id

    # Non-null hemisphere entries found in the CSV (empty when the column is missing).
    hemisphere_in_csv = (
        fooof_input_df['Hemisphere'].dropna()
        if 'Hemisphere' in fooof_input_df.columns
        else pd.Series(dtype=object)
    )
    if hemisphere_in_csv.empty:
        print("Warning: 'Hemisphere' column not found or empty in CSV. Assigning based on session_id.")
    else:
        csv_hemisphere = hemisphere_in_csv.iloc[0]
        if csv_hemisphere.lower() != neural_hemisphere.lower():
            print(f"Warning: Hemisphere in CSV ({csv_hemisphere}) differs from expected based on session_id ({neural_hemisphere}). Using value from session_id.")
    # In every case the hemisphere derived from session_id wins.
    fooof_input_df['Hemisphere'] = neural_hemisphere

except Exception as e:
    print(f"ERROR: Could not load or parse CSV file {csv_file_path}: {e}")
    print(traceback.format_exc())
    sys.exit()

# --- Load Parameters from JSON file (if found) ---
electrode_info_matlab = {}
if not os.path.exists(params_json_path):
    print(f"\nParameters JSON file not found at {params_json_path}. Using defaults.")
else:
    print(f"\nLoading parameters from JSON: {params_json_path}")
    try:
        with open(params_json_path, 'r') as f:
            params_from_matlab = json.load(f)
        print("Parameters JSON loaded successfully.")
        # Each value falls back to the one already in place when its key is absent.
        loaded_freq_ranges = params_from_matlab.get(
            'freq_ranges_defined_for_fooof', loaded_freq_ranges)
        electrode_info_matlab = params_from_matlab.get('electrode_info_used', {})
        target_neural_segment_duration = params_from_matlab.get(
            'target_neural_segment_duration_sec', target_neural_segment_duration)
    except Exception as e:
        # A broken JSON is not fatal: report it and carry on with the values already set.
        print(f"ERROR: Could not load or parse parameters JSON file {params_json_path}: {e}")

# Frequency ranges that the fitting cells iterate over.
ITERATIVE_FREQ_BANDS = loaded_freq_ranges
print(f"\nUsing Session ID (Patient-Hemisphere): {session_id}")
print(f"Using Hemisphere: {neural_hemisphere}")
print(f"Using freq_ranges for FOOOF: {ITERATIVE_FREQ_BANDS}")

# --- Create dynamic Output Folders ---
# Folder layout: <base>/<session_id>/<hemisphere>/{visualizations_step3,new_analysis_figures,new_analysis_data}
session_specific_output_folder = os.path.join(fooof_output_folder_py_base, str(session_id), str(neural_hemisphere))
visualization_folder_session_specific = os.path.join(session_specific_output_folder, 'visualizations_step3')
aperiodic_plot_output_folder_base = visualization_folder_session_specific
FIGURES_OUTPUT_PATH_NEW_CELLS = os.path.join(session_specific_output_folder, 'new_analysis_figures') # For Cell 5a,b,c plots
DATA_OUTPUT_PATH_NEW_CELLS = os.path.join(session_specific_output_folder, 'new_analysis_data') # For Cell 4b outputs
for folder_path in [session_specific_output_folder, visualization_folder_session_specific, FIGURES_OUTPUT_PATH_NEW_CELLS, DATA_OUTPUT_PATH_NEW_CELLS]:
    # exist_ok=True creates the folder only when missing, without a separate
    # exists() check that another process could invalidate in between.
    os.makedirs(folder_path, exist_ok=True)
print(f"Step 3 session-specific outputs will be in: {session_specific_output_folder}")

# One master results CSV per session_id, stored directly in the base output folder.
master_csv_filename_patient = f"MASTER_FOOOF_PKG_results_{session_id}_{output_version_tag}.csv"
master_csv_path_patient_specific = os.path.join(fooof_output_folder_py_base, master_csv_filename_patient)
print(f"Master results CSV for {session_id} will be saved/updated at: {master_csv_path_patient_specific}")

# --- Parse PSD_Data and Frequency_Vector ---
def parse_float_string(s):
    """Parse a ';'-separated string of numbers into a 1-D float array.

    Parameters
    ----------
    s : str
        Text such as ``"1.0;2.5;3"``.

    Returns
    -------
    numpy.ndarray
        The parsed values, or an empty array when ``s`` cannot be split
        (e.g. a NaN/None cell) or any token is not a valid number.
    """
    try:
        return np.array(s.split(';'), dtype=float)
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: s is not a text string; ValueError: non-numeric token.
        # Only these are swallowed, so KeyboardInterrupt/SystemExit still propagate.
        return np.array([])

# Both serialized columns are mandatory: stop as soon as one is missing,
# otherwise decode it into a column of numpy arrays.
if 'PSD_Data_Str' not in fooof_input_df.columns:
    print("ERROR: 'PSD_Data_Str' column not found. Cannot run FOOOF."); sys.exit()
fooof_input_df['PSD_Data'] = fooof_input_df['PSD_Data_Str'].apply(parse_float_string) # Linear power
print("Parsed 'PSD_Data_Str' into numpy arrays.")

if 'Frequency_Vector_Str' not in fooof_input_df.columns:
    print("ERROR: 'Frequency_Vector_Str' column not found. Cannot run FOOOF."); sys.exit()
fooof_input_df['Frequency_Vector_Raw'] = fooof_input_df['Frequency_Vector_Str'].apply(parse_float_string)
print("Parsed 'Frequency_Vector_Str' into numpy arrays (as Frequency_Vector_Raw).")

# --- Add/Verify Neural_Segment_Duration_Sec & DateTime columns ---
# ... (rest of original Cell 3 for duration, datetime, electrode labels, channels to process) ...
# Segment duration: coerce an existing column to numeric; otherwise derive it from
# the end/start Unix times when both are present.
if 'Neural_Segment_Duration_Sec' in fooof_input_df.columns:
    fooof_input_df['Neural_Segment_Duration_Sec'] = pd.to_numeric(
        fooof_input_df['Neural_Segment_Duration_Sec'], errors='coerce')
elif {'Neural_Segment_Start_Unixtime', 'Neural_Segment_End_Unixtime'}.issubset(fooof_input_df.columns):
    fooof_input_df['Neural_Segment_Duration_Sec'] = (
        fooof_input_df['Neural_Segment_End_Unixtime'] - fooof_input_df['Neural_Segment_Start_Unixtime']
    )
    print("Derived 'Neural_Segment_Duration_Sec' from timestamps.")

# Readable UTC datetimes for each Unix-seconds column that is present;
# values that cannot be converted become NaT.
for _unix_col, _datetime_col in (
    ('Neural_Segment_Start_Unixtime', 'Neural_Segment_Start_DateTime_UTC'),
    ('Aligned_PKG_UnixTimestamp', 'Aligned_PKG_DateTime_UTC'),  # For clinical state alignment
):
    if _unix_col in fooof_input_df.columns:
        fooof_input_df[_datetime_col] = pd.to_datetime(
            fooof_input_df[_unix_col], unit='s', utc=True, errors='coerce')


# Extract unique channels and their electrode labels
# Start from the Cell 2 default so `electrode_labels` is always defined, even when
# the CSV lacks the label column (it was previously read before any assignment
# in the fallback branch, raising NameError on a fresh kernel).
electrode_labels = dict(electrode_labels_default)
if 'Channel' in fooof_input_df.columns and 'ElectrodeLabel' in fooof_input_df.columns:
    # One label per channel, taken from the distinct (Channel, ElectrodeLabel) pairs.
    unique_channels_df = fooof_input_df[['Channel', 'ElectrodeLabel']].drop_duplicates().set_index('Channel')
    electrode_labels = unique_channels_df['ElectrodeLabel'].to_dict()
    print(f"Derived electrode_labels from CSV: {electrode_labels}")
    channels_to_process = list(electrode_labels.keys())
    print(f"Channels to process based on CSV: {channels_to_process}")
else:
    print("Warning: 'Channel' or 'ElectrodeLabel' columns not found. Plotting labels might be generic.")
    channels_to_process = list(fooof_input_df['Channel'].unique()) if 'Channel' in fooof_input_df.columns else []
    # No labels available: label each channel with its own name.
    if not electrode_labels and channels_to_process:
        electrode_labels = {ch: ch for ch in channels_to_process}


print(f"\nSample of fooof_input_df head after initial processing (first 2 rows):")
print(fooof_input_df.head(2))
print(f"Columns in fooof_input_df: {fooof_input_df.columns.tolist()}")

print("\nCell 3: Data loading, initial parsing, and Daily LEDD calculation complete.")
# --- End of Cell 3 ---


Target CSV file: Step2_Aligned120sPSDs_RCS05L_Left_AllSessions.csv
Target JSON_params file: Step2_params_120sPSDs_RCS05L_Left_AllSessions.json
Successfully loaded CSV from Step 2. Shape: (11368, 15)

Loading parameters from JSON: /home/jackson/step2_final/Step2_params_120sPSDs_RCS05L_Left_AllSessions.json
Parameters JSON loaded successfully.

Using Session ID (Patient-Hemisphere): RCS05L
Using Hemisphere: Left
Using freq_ranges for FOOOF: {'LowFreq': [10, 40], 'MidFreq': [30, 90], 'WideFreq': [10, 90]}
Step 3 session-specific outputs will be in: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left
Master results CSV for RCS05L will be saved/updated at: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/MASTER_FOOOF_PKG_results_RCS05L_neural_pkg_aligned_finalstep3_bushlab5000.csv
Parsed 'PSD_Data_Str' into numpy arrays.
Parsed 'Frequency_Vector_Str' into numpy arrays (as Frequency_Vector_Raw).
Derived '

In [4]:
# # -*- coding: utf-8 -*-
# # Assume pandas as pd and numpy as np are imported
# # Assume session_id is defined
# # Define REFINED_MOBILE_STATES and AGGREGATED_MOBILE_NAME
# REFINED_MOBILE_STATES = ["Dyskinetic Mobile", "Non-Dyskinetic Mobile", "Transitional Mobile"]
# AGGREGATED_MOBILE_NAME = "Mobile (All Types)"

# print("\n--- Cell 3a: Clinical State Derivation (Point-by-Point Method) ---")

# if 'fooof_input_df' not in locals() or fooof_input_df.empty:
#     print("fooof_input_df is empty or not defined. Skipping state derivation.")
#     if 'fooof_input_df' in locals() and not fooof_input_df.empty:
#         fooof_input_df['Clinical_State_2min_Window'] = "StateProcessingError"
#         fooof_input_df['Clinical_State_Aggregated'] = "StateProcessingError"
# else:
#     required_pkg_cols = ['Aligned_BK', 'Aligned_DK', 'Aligned_PKG_UnixTimestamp']
#     if not all(col in fooof_input_df.columns for col in required_pkg_cols):
#         print(f"ERROR: Missing one or more required PKG columns: {required_pkg_cols}")
#         fooof_input_df['Clinical_State_2min_Window'] = "DataMissingForPKGState"
#         fooof_input_df['Clinical_State_Aggregated'] = "DataMissingForPKGState"
#     else:
#         # Ensure PKG scores are numeric
#         fooof_input_df['Aligned_BK'] = pd.to_numeric(fooof_input_df['Aligned_BK'], errors='coerce')
#         fooof_input_df['Aligned_DK'] = pd.to_numeric(fooof_input_df['Aligned_DK'], errors='coerce')

#         # --- Step 1: Calculate DK Percentiles for Mobile Candidates ---
#         # Define general mobile candidates based on BK <= 26 OR DK >= 7
#         is_general_mobile_candidate = (
#             (fooof_input_df['Aligned_BK'] <= 26) | 
#             (fooof_input_df['Aligned_DK'] >= 7)
#         )
        
#         # Get DK scores for mobile candidates only
#         mobile_candidate_dk_scores = fooof_input_df.loc[
#             is_general_mobile_candidate & fooof_input_df['Aligned_DK'].notna(), 
#             'Aligned_DK'
#         ]
        
#         if len(mobile_candidate_dk_scores) > 1:
#             p30_dk = np.percentile(mobile_candidate_dk_scores, 30)
#             p70_dk = np.percentile(mobile_candidate_dk_scores, 70)
#             print(f"  DK percentiles for mobile candidates: 30th={p30_dk:.2f}, 70th={p70_dk:.2f}")
#         else:
#             print("  Warning: Not enough mobile candidate data to calculate DK percentiles.")
#             p30_dk = -np.inf
#             p70_dk = np.inf

#         # --- Step 2: Apply Point-by-Point State Assignment ---
#         # This method assigns states based on individual PKG values at each timepoint
#         def assign_clinical_state_point(row):
#             """
#             Assigns clinical state based on PKG scores at a single timepoint.
#             This is a point-by-point method without temporal windowing.
#             """
#             bk = row['Aligned_BK']
#             dk = row['Aligned_DK']
            
#             if pd.isna(bk) or pd.isna(dk):
#                 return "Other"
            
#             # Sleep state: BK >= 80
#             if bk >= 80:
#                 return "Sleep"
            
#             # Immobile state: BK > 26 AND BK < 80 AND DK < 7
#             elif (bk > 26) and (bk < 80) and (dk < 7):
#                 return "Immobile"
            
#             # Mobile states: BK <= 26 OR DK >= 7
#             elif (bk <= 26) or (dk >= 7):
#                 # Subdivide mobile states based on DK percentiles
#                 if pd.notna(p30_dk) and pd.notna(p70_dk):
#                     if dk <= p30_dk:
#                         return "Non-Dyskinetic Mobile"
#                     elif dk > p70_dk:
#                         return "Dyskinetic Mobile"
#                     else:  # p30_dk < dk <= p70_dk
#                         return "Transitional Mobile"
#                 else:
#                     # If percentiles couldn't be calculated, use generic mobile
#                     return "Mobile (Generic)"
            
#             else:
#                 return "Other"


#         # Apply the state assignment function to each row
#         fooof_input_df['Clinical_State_2min_Window'] = fooof_input_df.apply(
#             assign_clinical_state_point, axis=1
#         )
#         # After assigning Clinical_State_2min_Window
#         mask_other = fooof_input_df['Clinical_State_2min_Window'] == "Other"
#         if mask_other.any():
#             # Reason tags
#             reasons = np.where(
#                 fooof_input_df.loc[mask_other, 'Aligned_BK'].isna() & fooof_input_df.loc[mask_other, 'Aligned_DK'].isna(),
#                 'BK & DK missing',
#                 np.where(
#                     fooof_input_df.loc[mask_other, 'Aligned_BK'].isna(),
#                     'BK missing',
#                     np.where(fooof_input_df.loc[mask_other, 'Aligned_DK'].isna(), 'DK missing', 'Rule fallback')
#                 )
#             )
#             fooof_input_df.loc[mask_other, 'Other_Reason'] = reasons
        
#             # Print a compact report
#             print("\n  'Other' classification diagnostics:")
#             print(fooof_input_df.loc[mask_other, ['Aligned_PKG_UnixTimestamp','Channel','ElectrodeLabel','Aligned_BK','Aligned_DK','Other_Reason']])
        
#             print("\n  'Other' breakdown:")
#             print(fooof_input_df.loc[mask_other, 'Other_Reason'].value_counts())
#         else:
#             print("\n  No rows classified as 'Other'.")

#         # 🔍 NEW: Print all rows classified as "Other"
#         other_rows = fooof_input_df[fooof_input_df['Clinical_State_2min_Window'] == "Other"]
#         if not other_rows.empty:
#             print(f"\n  Found {len(other_rows)} rows classified as 'Other':")
#             print(other_rows)
#         else:
#             print("\n  No rows classified as 'Other'.")

        
#         print(f"  Assigned clinical states using point-by-point method.")
#         print(f"  Note: Column is named 'Clinical_State_2min_Window' for compatibility,")
#         print(f"        but uses instantaneous PKG values, not temporal windowing.")

#         # --- Step 3: Create Aggregated Clinical States ---
#         # Group all refined mobile states into one category
#         def create_aggregated_state(state):
#             """
#             Aggregates refined mobile states into a single 'Mobile (All Types)' category.
#             Non-mobile states remain unchanged.
#             """
#             if state in REFINED_MOBILE_STATES:
#                 return AGGREGATED_MOBILE_NAME
#             elif state == "Mobile (Generic)":
#                 # Also aggregate the generic mobile fallback
#                 return AGGREGATED_MOBILE_NAME
#             else:
#                 return state

#         fooof_input_df['Clinical_State_Aggregated'] = fooof_input_df['Clinical_State_2min_Window'].apply(
#             create_aggregated_state
#         )

#         # --- Step 4: Display Results ---
#         print("\n  Clinical state distribution (Clinical_State_2min_Window):")
#         print(fooof_input_df['Clinical_State_2min_Window'].value_counts(dropna=False))
        
#         print("\n  Aggregated clinical state distribution:")
#         print(fooof_input_df['Clinical_State_Aggregated'].value_counts(dropna=False))

#         # --- Step 5: Verify Data Integrity ---
#         n_missing_states = fooof_input_df['Clinical_State_2min_Window'].isna().sum()
#         if n_missing_states > 0:
#             print(f"\n  Warning: {n_missing_states} rows have missing clinical states.")

# # Define clinical state colors for plotting
# CLINICAL_STATE_COLORS = {
#     'Sleep': '#4169E1',                 # RoyalBlue
#     'Immobile': '#40E0D0',              # Turquoise
#     'Non-Dyskinetic Mobile': '#32CD32', # LimeGreen
#     'Transitional Mobile': '#FFD700',   # Gold
#     'Dyskinetic Mobile': '#FF6347',     # Tomato
#     'Mobile (All Types)': 'darkgreen',  # For aggregated view
#     'Mobile (Generic)': 'darkgreen',    # Fallback mobile state
#     'Other': '#C0C0C0',                 # Silver
#     'DataMissingForPKGState': '#F5F5F5',
#     'StateProcessingError': '#E0E0E0'
# }

# print(f"\nClinical state colors defined for plotting: {list(CLINICAL_STATE_COLORS.keys())}")
# print("--- Cell 3a: Clinical State Derivation Complete ---")

In [5]:
# -*- coding: utf-8 -*-
# Assume pandas as pd and numpy as np are imported
# Assume session_id is defined
# Define REFINED_MOBILE_STATES and AGGREGATED_MOBILE_NAME
# Label used when the refined mobile sub-states are pooled, and the sub-states themselves.
AGGREGATED_MOBILE_NAME = "Mobile (All Types)"
REFINED_MOBILE_STATES = [
    "Dyskinetic Mobile",
    "Non-Dyskinetic Mobile",
    "Transitional Mobile",
]

print("\n--- Cell 3a: Clinical State Derivation (Point-by-Point Method; ABSOLUTE RULES) ---")

# ---- Absolute PKG thresholds (edit here to re-tune the state rules) ----
BK_SLEEP = 80          # BK at or above this is Sleep
BK_ON_MAX = 26         # BK below this passes the objective ON gate
DK_NONDYSK_MAX = 7     # DK below this is non-dyskinetic
DK_TRANSITION_MAX = 9  # DK from 7 to 9 is transitional; above 9 is dyskinetic

if 'fooof_input_df' not in locals() or fooof_input_df.empty:
    print("fooof_input_df is empty or not defined. Skipping state derivation.")
    if 'fooof_input_df' in locals() and not fooof_input_df.empty:
        fooof_input_df['Clinical_State_2min_Window'] = "StateProcessingError"
        fooof_input_df['Clinical_State_Aggregated'] = "StateProcessingError"
else:
    required_pkg_cols = ['Aligned_BK', 'Aligned_DK', 'Aligned_PKG_UnixTimestamp']
    if not all(col in fooof_input_df.columns for col in required_pkg_cols):
        print(f"ERROR: Missing one or more required PKG columns: {required_pkg_cols}")
        fooof_input_df['Clinical_State_2min_Window'] = "DataMissingForPKGState"
        fooof_input_df['Clinical_State_Aggregated'] = "DataMissingForPKGState"
    else:
        # Ensure PKG scores are numeric
        fooof_input_df['Aligned_BK'] = pd.to_numeric(fooof_input_df['Aligned_BK'], errors='coerce')
        fooof_input_df['Aligned_DK'] = pd.to_numeric(fooof_input_df['Aligned_DK'], errors='coerce')

        print(f"  Using ABSOLUTE thresholds -> Sleep BK≥{BK_SLEEP}, ON BK<{BK_ON_MAX}, DK bands: <{DK_NONDYSK_MAX}, {DK_NONDYSK_MAX}–{DK_TRANSITION_MAX}, >{DK_TRANSITION_MAX}")

        # --- Step 1: Apply Point-by-Point State Assignment (ABSOLUTE, ON-gated) ---
        def assign_clinical_state_point(row):
            """
            Classify one row's PKG scores into a clinical state (absolute rules, no percentiles).

            Reads row['Aligned_BK'] and row['Aligned_DK'] and returns one of
            "Sleep", "Immobile", "Non-Dyskinetic Mobile", "Transitional Mobile",
            "Dyskinetic Mobile" or "Other":

              Sleep:          BK ≥ BK_SLEEP
              Immobile/quiet: BK_ON_MAX ≤ BK < BK_SLEEP AND DK < DK_NONDYSK_MAX
              Mobile (ON):    BK < BK_ON_MAX, then:
                 DK < DK_NONDYSK_MAX                      -> Non-Dyskinetic Mobile
                 DK_NONDYSK_MAX ≤ DK ≤ DK_TRANSITION_MAX  -> Transitional Mobile
                 DK > DK_TRANSITION_MAX                   -> Dyskinetic Mobile
              Other:          BK or DK missing, or BK_ON_MAX ≤ BK < BK_SLEEP with
                              DK ≥ DK_NONDYSK_MAX (diagnosed below)
            """
            bk = row['Aligned_BK']
            dk = row['Aligned_DK']

            if pd.isna(bk) or pd.isna(dk):
                return "Other"

            # Sleep
            if bk >= BK_SLEEP:
                return "Sleep"

            # Immobile / quiet (objective OFF/immobile).
            # The lower bound is inclusive so BK == BK_ON_MAX is not left unclassified:
            # the ON gate below is strictly BK < BK_ON_MAX, and the 'Other' diagnostics
            # already treat BK ≥ BK_ON_MAX as the OFF/immobile side.
            if (BK_ON_MAX <= bk < BK_SLEEP) and (dk < DK_NONDYSK_MAX):
                return "Immobile"

            # ON-gated mobile states
            if bk < BK_ON_MAX:
                if dk < DK_NONDYSK_MAX:
                    return "Non-Dyskinetic Mobile"
                elif dk <= DK_TRANSITION_MAX:
                    return "Transitional Mobile"
                else:  # dk > DK_TRANSITION_MAX
                    return "Dyskinetic Mobile"

            # Remaining case: high DK while BK≥26 (OFF/immobile range) -> Other
            return "Other"

        fooof_input_df['Clinical_State_2min_Window'] = fooof_input_df.apply(assign_clinical_state_point, axis=1)

        # After assigning Clinical_State_2min_Window: tag every 'Other' row with the reason it
        # matched none of the state rules, then print a compact report.
        mask_other = fooof_input_df['Clinical_State_2min_Window'] == "Other"
        if mask_other.any():
            other_bk = fooof_input_df.loc[mask_other, 'Aligned_BK']
            other_dk = fooof_input_df.loc[mask_other, 'Aligned_DK']
            bk_missing = other_bk.isna().to_numpy()
            dk_missing = other_dk.isna().to_numpy()
            # Comparisons against NaN evaluate to False, so rows with a missing score never match here.
            high_dk_while_off = ((other_bk >= BK_ON_MAX) & (other_dk >= DK_NONDYSK_MAX)).to_numpy()

            # np.select takes the FIRST matching condition, so the order below is the precedence:
            # both missing > BK missing > DK missing > high DK while OFF > fallback.
            reasons = np.select(
                [bk_missing & dk_missing, bk_missing, dk_missing, high_dk_while_off],
                ['BK & DK missing', 'BK missing', 'DK missing', 'High DK while OFF/immobile (BK≥26)'],
                default='Rule fallback',
            )
            fooof_input_df.loc[mask_other, 'Other_Reason'] = reasons

            # Compact report
            print("\n  'Other' classification diagnostics:")
            print(fooof_input_df.loc[mask_other, ['Aligned_PKG_UnixTimestamp','Channel','ElectrodeLabel','Aligned_BK','Aligned_DK','Other_Reason']])

            print("\n  'Other' breakdown:")
            print(fooof_input_df.loc[mask_other, 'Other_Reason'].value_counts())

            print(f"\n  Found {int(mask_other.sum())} rows classified as 'Other'.")
        else:
            # Reported once (the mask was previously recomputed and this message printed twice).
            print("\n  No rows classified as 'Other'.")

        print(f"\n  Assigned clinical states using ABSOLUTE, ON-gated point-by-point method.")
        print(f"  Note: Column is named 'Clinical_State_2min_Window' for compatibility; no temporal smoothing applied here.")

        # --- Step 2: Create Aggregated Clinical States ---
        def create_aggregated_state(state):
            """
            Collapse every mobile sub-state into the single aggregated mobile label.

            Any state listed in REFINED_MOBILE_STATES, as well as the literal
            "Mobile (Generic)", maps to AGGREGATED_MOBILE_NAME; every other state
            is returned unchanged.
            """
            is_mobile = state in REFINED_MOBILE_STATES or state == "Mobile (Generic)"
            return AGGREGATED_MOBILE_NAME if is_mobile else state

        fooof_input_df['Clinical_State_Aggregated'] = fooof_input_df['Clinical_State_2min_Window'].apply(create_aggregated_state)

        # --- Step 3: Display Results ---
        print("\n  Clinical state distribution (Clinical_State_2min_Window):")
        print(fooof_input_df['Clinical_State_2min_Window'].value_counts(dropna=False))

        print("\n  Aggregated clinical state distribution:")
        print(fooof_input_df['Clinical_State_Aggregated'].value_counts(dropna=False))

        # --- Step 4: Verify Data Integrity ---
        n_missing_states = fooof_input_df['Clinical_State_2min_Window'].isna().sum()
        if n_missing_states > 0:
            print(f"\n  Warning: {n_missing_states} rows have missing clinical states.")

# Plot colour for each clinical-state label, keyed by the exact label string.
# Covers the point-by-point states assigned above ('Sleep', 'Immobile', the three mobile
# sub-states, 'Other') and the 'DataMissingForPKGState' placeholder.
# NOTE(review): 'Mobile (All Types)' presumably equals AGGREGATED_MOBILE_NAME, and
# 'StateProcessingError' is not assigned anywhere in the visible code — confirm both.
CLINICAL_STATE_COLORS = {
    'Sleep': '#4169E1',                 # RoyalBlue
    'Immobile': '#40E0D0',              # Turquoise
    'Non-Dyskinetic Mobile': '#32CD32', # LimeGreen
    'Transitional Mobile': '#FFD700',   # Gold
    'Dyskinetic Mobile': '#FF6347',     # Tomato
    'Mobile (All Types)': 'darkgreen',  # For aggregated view
    'Mobile (Generic)': 'darkgreen',    # Fallback mobile state
    'Other': '#C0C0C0',                 # Silver
    'DataMissingForPKGState': '#F5F5F5',
    'StateProcessingError': '#E0E0E0'
}

print(f"\nClinical state colors defined for plotting: {list(CLINICAL_STATE_COLORS.keys())}")
print("--- Cell 3a: Clinical State Derivation Complete ---")



--- Cell 3a: Clinical State Derivation (Point-by-Point Method; ABSOLUTE RULES) ---
  Using ABSOLUTE thresholds -> Sleep BK≥80, ON BK<26, DK bands: <7, 7–9, >9

  'Other' classification diagnostics:
       Aligned_PKG_UnixTimestamp           Channel    ElectrodeLabel  \
0                     1572121800  key0_contact_2_0  key0_contact_2_0   
1                     1572121830  key0_contact_2_0  key0_contact_2_0   
2                     1572121860  key0_contact_2_0  key0_contact_2_0   
3                     1572121890  key0_contact_2_0  key0_contact_2_0   
172                   1572126960  key0_contact_2_0  key0_contact_2_0   
...                          ...               ...               ...   
11338                 1572393210  key1_contact_3_0  key1_contact_3_0   
11339                 1572393240  key1_contact_3_0  key1_contact_3_0   
11340                 1572393270  key1_contact_3_0  key1_contact_3_0   
11351                 1572393600  key1_contact_3_0  key1_contact_3_0   
11360    

In [6]:
import numpy as np, warnings, inspect


def _fit_one(lam):
    """Fit one fixed synthetic spectrum with regularization weight `lam` and summarise the fit.

    The spectrum is a log-log linear background plus a Gaussian bump centred at 20 Hz and
    fixed-seed noise, so two calls differ only through `lam`.

    Returns:
        dict with keys 'lam', 'R2', 'Err', 'NumPeaks', 'MeanBWHz', 'reg_seen'
        (the `regularization_weight` attribute read back from the model, or None)
        and 'src' (file that defines the model class).
    """
    # --- Synthetic log10 power spectrum (deterministic: seed 0) ---
    f_axis = np.linspace(1, 90, 900)
    background = 1.0 - 1.2*np.log10(f_axis)
    bump_20hz = np.exp(-0.5*((f_axis-20)/4)**2) * 0.35
    noise = np.random.default_rng(0).normal(0, 0.03, f_axis.size)
    log_power = background + bump_20hz + noise
    linear_power = 10**log_power

    # --- Knee-mode fit restricted to 10-40 Hz, with every warning silenced ---
    model = FOOOF(aperiodic_mode='knee', peak_width_limits=(2, 12), max_n_peaks=np.inf,
                  min_peak_height=0.0, peak_threshold=2.0, verbose=False)
    set_lambda(model, lam)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model.fit(f_axis, linear_power, freq_range=[10, 40])

    # --- Summary metrics ---
    peaks = model.peak_params_
    peak_count = len(peaks) if peaks is not None else 0

    gauss = model.gaussian_params_
    if gauss is None or gauss.size == 0:
        mean_bandwidth = np.nan
    else:
        # Third Gaussian column doubled, reported as the bandwidth in Hz.
        mean_bandwidth = float(np.mean(gauss[:, 2] * 2.0))

    summary = {}
    summary["lam"] = lam
    summary["R2"] = model.r_squared_
    summary["Err"] = model.error_
    summary["NumPeaks"] = peak_count
    summary["MeanBWHz"] = mean_bandwidth
    summary["reg_seen"] = getattr(model, "regularization_weight", None)
    summary["src"] = inspect.getfile(model.__class__)
    return summary

rows = [_fit_one(0.0), _fit_one(REG_LAMBDA)]
import pandas as pd
print(pd.DataFrame(rows))


      lam        R2       Err  NumPeaks  MeanBWHz  reg_seen  \
0     0.0  0.986766  0.024891         5  3.089145       0.0   
1  5000.0  0.977706  0.032333         5  2.324085    5000.0   

                                                 src  
0  /home/jackson/fooof_specparam_regularisation/s...  
1  /home/jackson/fooof_specparam_regularisation/s...  


In [7]:
# -*- coding: utf-8 -*-
# --- (PARALLEL) Cell 4b: Beta/Gamma Peak Feature Extraction ---
# Calculates Beta_Peak_Power_at_DominantFreq & Gamma_Peak_Power_at_DominantFreq
# Parallelized across segments and channels with joblib.

import os
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from joblib import Parallel, delayed
from scipy.signal import find_peaks

# Prefer specparam; fall back to fooof

from specparam.objs import SpectralModel as FOOOF
from specparam.sim.gen import gen_aperiodic
USING_SPECPARAM = True


# Regularization
REG_LAMBDA = float(os.getenv("FOOOF_REG_LAMBDA", "5000"))
def set_lambda(fm, lam=REG_LAMBDA):
    """Set the regularization weight on a spectral model before fitting.

    The weight is written to `fm._model.regularization_weight` when that attribute
    exists, otherwise to `fm.regularization_weight` when that exists.

    Args:
        fm: Model object to configure.
        lam: Regularization weight. Defaults to REG_LAMBDA as it was when this
            function was defined (env var FOOOF_REG_LAMBDA, default 5000).

    Returns:
        bool: True if the weight was applied, False if `fm` exposes no
        `regularization_weight` attribute (the fit then runs unregularised).

    Warns:
        RuntimeWarning: when a non-zero `lam` could not be applied, so an
        unregularised run is not mistaken for a regularised one.
    """
    inner_model = getattr(fm, "_model", None)
    if inner_model is not None and hasattr(inner_model, "regularization_weight"):
        inner_model.regularization_weight = lam
        return True
    if hasattr(fm, "regularization_weight"):
        fm.regularization_weight = lam
        return True
    if lam:
        warnings.warn(
            f"set_lambda: {type(fm).__name__} exposes no 'regularization_weight'; "
            f"lambda={lam} was NOT applied and the fit will be unregularised.",
            RuntimeWarning,
            stacklevel=2,
        )
    return False

print("\n--- Cell 4b (Parallel): Starting Beta/Gamma Peak Feature Extraction ---")

# -----------------------------------------------------------
# Prerequisites & defaults (will fall back if not already set)
# -----------------------------------------------------------
if 'fooof_input_df' not in locals() or fooof_input_df.empty:
    print("fooof_input_df is empty or not defined. Skipping Beta/Gamma feature extraction.")
    # Defined-but-empty frame: still create the two output columns.
    # The previous inner test ("defined and not empty") negated the outer condition
    # and therefore could never be true.
    if 'fooof_input_df' in locals():
        fooof_input_df['Beta_Peak_Power_at_DominantFreq'] = np.nan
        fooof_input_df['Gamma_Peak_Power_at_DominantFreq'] = np.nan
else:
    # Required columns
    required_cols_for_cell4b = ['PSD_Data', 'Frequency_Vector_Raw', 'SessionID', 'Channel', 'ElectrodeLabel']
    if not all(col in fooof_input_df.columns for col in required_cols_for_cell4b):
        missing_cols = [col for col in required_cols_for_cell4b if col not in fooof_input_df.columns]
        print(f"ERROR: Missing required columns for Beta/Gamma extraction: {missing_cols}. Skipping.")
        fooof_input_df['Beta_Peak_Power_at_DominantFreq'] = np.nan
        fooof_input_df['Gamma_Peak_Power_at_DominantFreq'] = np.nan
    else:
        # Threading knobs (avoid BLAS oversubscription; ideally set in shell before Python)
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("MKL_NUM_THREADS", "1")
        os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
        os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

        # Parallel workers: all cores but one by default (override with env N_JOBS).
        # os.cpu_count() returns None when the count cannot be determined; treat that as 1 core
        # instead of failing on `None - 1`.
        N_JOBS = int(os.getenv("N_JOBS", max((os.cpu_count() or 1) - 1, 1)))
        print(f"[Cell 4b] Using {N_JOBS} parallel workers")

        # Baseline and bands (fallbacks if not in session)
        BETA_BAND = globals().get('BETA_BAND', [13, 30])
        GAMMA_BAND = globals().get('GAMMA_BAND', [60, 90])
        BETA_BASELINE_FREQ_REGION  = globals().get('BETA_BASELINE_FREQ_REGION',  [10, 12])
        GAMMA_BASELINE_FREQ_REGION = globals().get('GAMMA_BASELINE_FREQ_REGION', [55, 59])
        PEAK_SIG_SD_THRESHOLD = globals().get('PEAK_SIG_SD_THRESHOLD', 1.0)

        # FOOOF settings (must exist from Cell 2; else fallback to sane defaults)
        if 'common_new_fooof_params' not in globals():
            common_new_fooof_params = {
                'peak_width_limits': [2.0, 8.0],
                'max_n_peaks': np.inf,
                'min_peak_height': 0.0,
                'peak_threshold': 2.0,
                'verbose': False,
                'regularization_weight': 1e-2,   # <- your new knob; 0.0 keeps stock behavior
            }

        # Output folder for debug plots
        debug_plot_folder = os.path.join(visualization_folder_session_specific, 'cell4b_debug_plots')
        os.makedirs(debug_plot_folder, exist_ok=True)
        print(f"Debug plots for Cell 4b will be saved to: {debug_plot_folder}")

        # Map display labels if needed
        if 'Channel_Display' not in fooof_input_df.columns:
            if 'ElectrodeLabel' in fooof_input_df.columns:
                fooof_input_df['Channel_Display'] = fooof_input_df['ElectrodeLabel'].where(
                    fooof_input_df['ElectrodeLabel'].notna(), fooof_input_df['Channel']
                )
            else:
                fooof_input_df['Channel_Display'] = fooof_input_df['Channel']

        # -----------------------------------------------------------
        # Helpers
        # -----------------------------------------------------------
        def _fit_and_flatten_segment(idx, psd_linear, freqs, fooof_params):
            """Fit FOOOF(knee) on [10,90] Hz; return flattened log spectrum + baseline means."""
            try:
                if psd_linear is None or freqs is None:
                    return None
                if len(psd_linear) == 0 or len(freqs) == 0:
                    return None

                psd_linear = np.asarray(psd_linear)
                freqs = np.asarray(freqs)
                psd_log = np.log10(np.clip(psd_linear, 1e-16, None))

                fm = FOOOF(**{**fooof_params, 'aperiodic_mode': 'knee'})
                set_lambda(fm)  # <-- enable regularization before fit
                lam_used = getattr(fm, "regularization_weight", np.nan)
                # (optional) print for the first few segments
                if idx in (0, 1, 2):
                    print(f"[4b] reg marker: {getattr(fm, '_reg_marker', '')}  (used {lam_used})")

                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', "divide by zero encountered in power", RuntimeWarning)
                    warnings.filterwarnings('ignore', "invalid value encountered in log10", RuntimeWarning)
                    fm.fit(freqs, psd_linear, freq_range=[10, 90])

                if not fm.has_model:
                    return None

                # IMPORTANT: aperiodic in same mode we fit
                ap_fit_log = gen_aperiodic(freqs, fm.aperiodic_params_, aperiodic_mode=fm.aperiodic_mode)
                flat = psd_log - ap_fit_log

                # Baseline means
                gamma_mask = (freqs >= GAMMA_BASELINE_FREQ_REGION[0]) & (freqs <= GAMMA_BASELINE_FREQ_REGION[1])
                beta_mask  = (freqs >= BETA_BASELINE_FREQ_REGION[0])  & (freqs <= BETA_BASELINE_FREQ_REGION[1])
                gamma_base = np.nanmean(flat[gamma_mask]) if np.any(gamma_mask) else np.nan
                beta_base  = np.nanmean(flat[beta_mask])  if np.any(beta_mask)  else np.nan

                return {'idx': int(idx), 'freqs': freqs, 'flat': flat,
                        'gamma_base': gamma_base, 'beta_base': beta_base}
            except Exception:
                return None

        def _peaks_from_flat(flat_rec, beta_thr, gamma_thr):
            """Find beta/gamma peak frequencies in a flattened spectrum given thresholds."""
            freqs = flat_rec['freqs']; flat = flat_rec['flat']
            beta_freq, gamma_freq = np.nan, np.nan

            # Beta
            if not np.isnan(beta_thr):
                mask = (freqs >= BETA_BAND[0]) & (freqs <= BETA_BAND[1])
                if np.any(mask):
                    vals = flat[mask]
                    good = ~np.isnan(vals)
                    if np.any(good):
                        peaks, props = find_peaks(vals[good], height=beta_thr, prominence=0.02)
                        if peaks.size > 0:
                            fvec = freqs[mask][good]
                            beta_freq = fvec[peaks[np.argmax(props['peak_heights'])]]

            # Gamma
            if not np.isnan(gamma_thr):
                mask = (freqs >= GAMMA_BAND[0]) & (freqs <= GAMMA_BAND[1])
                if np.any(mask):
                    vals = flat[mask]
                    good = ~np.isnan(vals)
                    if np.any(good):
                        peaks, props = find_peaks(vals[good], height=gamma_thr, prominence=0.02)
                        if peaks.size > 0:
                            fvec = freqs[mask][good]
                            gamma_freq = fvec[peaks[np.argmax(props['peak_heights'])]]

            return beta_freq, gamma_freq

        def _dominant_mode(freqs_array, q=0.5):
            """Quantize to q-Hz bins and return the modal frequency (NaN if empty)."""
            if freqs_array.size == 0:
                return np.nan
            qfreqs = np.round(freqs_array / q) * q
            vals, counts = np.unique(qfreqs, return_counts=True)
            return vals[np.argmax(counts)]

        def _extract_power(idx, psd_linear, freqs, dom_beta, dom_gamma):
            """Return (idx, beta_log_power, gamma_log_power) for one segment."""
            if psd_linear is None or freqs is None:
                return idx, np.nan, np.nan
            psd_linear = np.asarray(psd_linear)
            freqs = np.asarray(freqs)
            if psd_linear.size == 0 or freqs.size == 0:
                return idx, np.nan, np.nan
            psd_log = np.log10(np.clip(psd_linear, 1e-16, None))
            b = psd_log[np.argmin(np.abs(freqs - dom_beta))]  if not pd.isna(dom_beta)  else np.nan
            g = psd_log[np.argmin(np.abs(freqs - dom_gamma))] if not pd.isna(dom_gamma) else np.nan
            return idx, b, g

        # -----------------------------------------------------------
        # Per-channel processing (parallel inside each channel)
        # -----------------------------------------------------------
        # Channel key -> dominant beta / gamma frequency (NaN when none was found).
        dominant_beta_freqs_channel = {}
        dominant_gamma_freqs_channel = {}

        print("  Calculating dominant peak frequencies per channel (from flattened spectra)...")
        unique_channels_for_dom_freq = fooof_input_df['Channel'].unique()

        # For each channel: (1) fit and flatten every segment, (2) derive channel-level peak
        # thresholds from the baseline regions, (3) find per-segment peaks, (4) take the
        # modal peak frequency as the channel's dominant frequency.
        for channel_key in tqdm(unique_channels_for_dom_freq, desc="Processing Channels for Dom. Freq"):
            # reset_index() gives positions 0..n-1, which become the 'idx' of each result below.
            df_channel = fooof_input_df[fooof_input_df['Channel'] == channel_key].copy().reset_index()
            el_label_for_channel = df_channel['ElectrodeLabel'].iloc[0]

            # Representative indices for debug plots: first, middle and last segment (de-duplicated)
            num_segments_in_channel = len(df_channel)
            idxs = []
            if num_segments_in_channel > 0: idxs.append(0)
            if num_segments_in_channel > 2: idxs.append(num_segments_in_channel // 2)
            if num_segments_in_channel > 1: idxs.append(num_segments_in_channel - 1)
            indices_to_plot = sorted(list(set(idxs)))

            # PASS 1: Parallel FOOOF fits -> flattened spectra + baseline means
            results = Parallel(n_jobs=N_JOBS, backend='loky')(
                delayed(_fit_and_flatten_segment)(
                    int(i), row['PSD_Data'], row['Frequency_Vector_Raw'], common_new_fooof_params
                ) for i, row in df_channel.iterrows()
            )
            # Segments that could not be fitted come back as None and are dropped here.
            results = [r for r in results if r is not None]
            if not results:
                dominant_beta_freqs_channel[channel_key] = np.nan
                dominant_gamma_freqs_channel[channel_key] = np.nan
                print(f"    Channel {el_label_for_channel}: No dominant beta/gamma peak found (no flattened spectra).")
                continue

            # Channel-level baselines and thresholds: mean and SD, across segments, of each
            # segment's mean flattened power in the baseline region.
            gamma_bases = np.array([r['gamma_base'] for r in results], dtype=float)
            beta_bases  = np.array([r['beta_base']  for r in results], dtype=float)
            mean_gamma = np.nanmean(gamma_bases) if gamma_bases.size else np.nan
            std_gamma  = np.nanstd(gamma_bases)  if gamma_bases.size else np.nan
            mean_beta  = np.nanmean(beta_bases)  if beta_bases.size  else np.nan
            std_beta   = np.nanstd(beta_bases)   if beta_bases.size  else np.nan

            if np.isnan(mean_beta) or np.isnan(std_beta):
                print(f"    Channel {el_label_for_channel}: Beta baseline NaN — skipping beta peak search.")
            if np.isnan(mean_gamma) or np.isnan(std_gamma):
                print(f"    Channel {el_label_for_channel}: Gamma baseline NaN — skipping gamma peak search.")

            # Threshold = baseline mean + PEAK_SIG_SD_THRESHOLD * baseline SD; a NaN threshold
            # disables the peak search for that band in _peaks_from_flat.
            beta_thr  = mean_beta  + PEAK_SIG_SD_THRESHOLD * std_beta   if not (np.isnan(mean_beta)  or np.isnan(std_beta))  else np.nan
            gamma_thr = mean_gamma + PEAK_SIG_SD_THRESHOLD * std_gamma  if not (np.isnan(mean_gamma) or np.isnan(std_gamma)) else np.nan

            # PASS 2: Parallel peak detection per flattened spectrum
            peak_pairs = Parallel(n_jobs=N_JOBS, backend='loky')(
                delayed(_peaks_from_flat)(r, beta_thr, gamma_thr) for r in results
            )

            # Keep only segments where a peak was actually found.
            beta_list  = np.array([p[0] for p in peak_pairs if not np.isnan(p[0])], dtype=float)
            gamma_list = np.array([p[1] for p in peak_pairs if not np.isnan(p[1])], dtype=float)

            # Dominant frequency = most common peak frequency after quantizing to 0.5 Hz.
            dominant_beta_freqs_channel[channel_key]  = _dominant_mode(beta_list, q=0.5)
            dominant_gamma_freqs_channel[channel_key] = _dominant_mode(gamma_list, q=0.5)

            # Debug plots (serial)
            for plot_idx in indices_to_plot:
                # A representative segment may be absent if its fit failed in PASS 1.
                rec = next((r for r in results if r['idx'] == plot_idx), None)
                if rec is None: continue
                freqs, flat = rec['freqs'], rec['flat']
                plt.figure(figsize=(12, 5))
                plt.plot(freqs, flat, color='k', label='Flattened')
                if not np.isnan(beta_thr):
                    plt.axvspan(BETA_BAND[0], BETA_BAND[1], color='red', alpha=0.1, label='Beta band')
                    plt.axhline(beta_thr, linestyle='--', color='r', label='Beta thr')
                if not np.isnan(gamma_thr):
                    plt.axvspan(GAMMA_BAND[0], GAMMA_BAND[1], color='green', alpha=0.1, label='Gamma band')
                    plt.axhline(gamma_thr, linestyle='--', color='g', label='Gamma thr')
                plt.title(f'Debug: Flattened Spectrum - Ch: {el_label_for_channel} (Seg {plot_idx})')
                plt.xlabel('Frequency (Hz)'); plt.ylabel('Log Power (aperiodic removed)')
                plt.xlim(0, 100); plt.grid(True, linestyle=':', alpha=0.6); plt.legend()
                plt.savefig(os.path.join(debug_plot_folder, f'Debug_{el_label_for_channel}_Seg{plot_idx}_Parallel.png'))
                # Close the figure so figures do not accumulate across channels.
                plt.close()

            # Console summary
            if pd.isna(dominant_beta_freqs_channel[channel_key]):
                print(f"    Channel {el_label_for_channel}: No dominant beta peak found.")
            else:
                print(f"    Channel {el_label_for_channel}: Dominant beta ≈ {dominant_beta_freqs_channel[channel_key]:.2f} Hz.")
            if pd.isna(dominant_gamma_freqs_channel[channel_key]):
                print(f"    Channel {el_label_for_channel}: No dominant gamma peak found.")
            else:
                print(f"    Channel {el_label_for_channel}: Dominant gamma ≈ {dominant_gamma_freqs_channel[channel_key]:.2f} Hz.")

        # -----------------------------------------------------------
        # Extract power at dominant frequencies (optionally parallel)
        # -----------------------------------------------------------
        # Start from all-NaN columns; they are overwritten row by row from the results below.
        fooof_input_df['Beta_Peak_Power_at_DominantFreq']  = np.nan
        fooof_input_df['Gamma_Peak_Power_at_DominantFreq'] = np.nan

        print("  Extracting power at dominant frequencies from original log-scaled PSDs...")
        # Build tasks: one tuple per row, carrying the row label and that row's
        # channel-level dominant frequencies (NaN when the channel has none).
        tasks = []
        for i, row_seg in fooof_input_df.iterrows():
            ch = row_seg['Channel']
            dom_beta  = dominant_beta_freqs_channel.get(ch, np.nan)
            dom_gamma = dominant_gamma_freqs_channel.get(ch, np.nan)
            tasks.append((i, row_seg['PSD_Data'], row_seg['Frequency_Vector_Raw'], dom_beta, dom_gamma))

        results_pow = Parallel(n_jobs=N_JOBS, backend='loky')(
            delayed(_extract_power)(i, psd, freqs, db, dg) for (i, psd, freqs, db, dg) in tasks
        )

        # Write results back by row label.
        # NOTE(review): .at[idx, ...] assumes the index labels of fooof_input_df are unique — confirm.
        for idx, b, g in results_pow:
            fooof_input_df.at[idx, 'Beta_Peak_Power_at_DominantFreq']  = b
            fooof_input_df.at[idx, 'Gamma_Peak_Power_at_DominantFreq'] = g

        # Done
        print("  Finished extracting Beta/Gamma peak powers.")
        print(f"  Example Beta Peak Powers:  {fooof_input_df['Beta_Peak_Power_at_DominantFreq'].dropna().unique()[:5]}")
        print(f"  Example Gamma Peak Powers: {fooof_input_df['Gamma_Peak_Power_at_DominantFreq'].dropna().unique()[:5]}")

print("\n--- Cell 4b (Parallel): Beta/Gamma Peak Feature Extraction Complete ---")



--- Cell 4b (Parallel): Starting Beta/Gamma Peak Feature Extraction ---
[Cell 4b] Using 79 parallel workers
Debug plots for Cell 4b will be saved to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/visualizations_step3/cell4b_debug_plots
  Calculating dominant peak frequencies per channel (from flattened spectra)...


Processing Channels for Dom. Freq:   0%|          | 0/4 [00:00<?, ?it/s]

  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - 

    Channel key0_contact_2_0: Dominant beta ≈ 19.50 Hz.
    Channel key0_contact_2_0: Dominant gamma ≈ 89.00 Hz.


  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - 

    Channel key2_contact_10_8: Dominant beta ≈ 27.50 Hz.
    Channel key2_contact_10_8: Dominant gamma ≈ 62.00 Hz.


  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - 

    Channel key3_contact_11_9: Dominant beta ≈ 27.50 Hz.
    Channel key3_contact_11_9: Dominant gamma ≈ 62.00 Hz.


  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - np.log10(knee + xs**exp)
  ys = offset - 

    Channel key1_contact_3_0: Dominant beta ≈ 19.50 Hz.
    Channel key1_contact_3_0: Dominant gamma ≈ 88.50 Hz.
  Extracting power at dominant frequencies from original log-scaled PSDs...
  Finished extracting Beta/Gamma peak powers.
  Example Beta Peak Powers:  [-6.00656092 -6.04660339 -6.09718119 -6.14001679 -6.30258945]
  Example Gamma Peak Powers: [-7.89965292 -7.93988907 -7.96394622 -7.88589485 -7.87125151]

--- Cell 4b (Parallel): Beta/Gamma Peak Feature Extraction Complete ---


In [8]:
# -*- coding: utf-8 -*-
# --- (PARALLEL) Cell 4: Primary FOOOF Aperiodic Fitting & Hump Analysis w/ Progress ---

import time, json, os, sys, warnings
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
# Prefer specparam; fall back to fooof

from specparam.objs import SpectralModel as FOOOF
from specparam.sim.gen import gen_aperiodic
USING_SPECPARAM = True

from tqdm.auto import tqdm


# tqdm-joblib progress hook (graceful no-op fallback if not installed)
try:
    from tqdm_joblib import tqdm_joblib
except Exception:
    from contextlib import contextmanager
    @contextmanager
    def tqdm_joblib(*args, **kwargs):
        yield

print("Starting Primary FOOOF Aperiodic Fitting & Hump Analysis (Parallel+Progress) ---")
start_time_cell4_modified = time.time()

# Avoid BLAS over-threading (each worker uses 1 thread)
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

# Use all but one core by default (override with env N_JOBS).
# os.cpu_count() returns None when the count cannot be determined; treat that as 1 core
# instead of failing on `None - 1`.
N_JOBS = int(os.getenv("N_JOBS", max((os.cpu_count() or 1) - 1, 1)))
print(f"[Cell 4] Using {N_JOBS} parallel workers")

# --- Helper: contiguous oscillatory humps from a fitted FOOOF model ---
def find_oscillatory_humps(fm):
    """Locate contiguous frequency runs where the model rises above its aperiodic fit.

    The periodic component is `fm.modeled_spectrum_` minus the aperiodic fit regenerated
    with `gen_aperiodic` in the model's own aperiodic mode (both in log10 power). Every
    maximal run of adjacent bins whose periodic component exceeds 1e-2 is one hump.

    Args:
        fm: Fitted model exposing `has_model`, `freqs`, `aperiodic_params_`,
            `aperiodic_mode` and `modeled_spectrum_`.

    Returns:
        list[dict]: One dict per hump, in ascending frequency order, with keys
        'hump_start_freq', 'hump_end_freq', 'hump_width', 'hump_max_power_freq'
        and 'hump_max_power_val'. Empty when there is no model or no bin
        exceeds the threshold.
    """
    if not fm.has_model:
        return []
    ap_fit = gen_aperiodic(fm.freqs, fm.aperiodic_params_, aperiodic_mode=fm.aperiodic_mode)
    periodic = fm.modeled_spectrum_ - ap_fit  # both in log10 power
    positive_idx = np.where(periodic > 1e-2)[0]  # ~2.3% above background
    if positive_idx.size == 0:
        return []

    # Split the above-threshold bin indices wherever consecutive indices are not adjacent;
    # each resulting run (including the trailing one) is a single hump.
    run_starts = np.where(np.diff(positive_idx) > 1)[0] + 1
    humps = []
    for run in np.split(positive_idx, run_starts):
        max_ix = run[np.argmax(periodic[run])]
        start_freq, end_freq = fm.freqs[run[0]], fm.freqs[run[-1]]
        humps.append({
            'hump_start_freq': start_freq,
            'hump_end_freq': end_freq,
            'hump_width': end_freq - start_freq,
            'hump_max_power_freq': fm.freqs[max_ix],
            'hump_max_power_val': periodic[max_ix],
        })
    return humps

# --- Preconditions ---
if 'fooof_input_df' not in locals() or fooof_input_df.empty:
    sys.exit("ERROR: fooof_input_df not defined or empty in Cell 4. Run previous cells.")
if 'ITERATIVE_FREQ_BANDS' not in locals() or not isinstance(ITERATIVE_FREQ_BANDS, dict):
    sys.exit("ERROR: ITERATIVE_FREQ_BANDS not defined. Run previous cells.")

# --- Settings fallbacks (if not already defined upstream) ---
if 'common_new_fooof_params' not in globals():
    common_new_fooof_params = {
        'peak_width_limits': [2.0, 8.0],
        'max_n_peaks': np.inf,
        'min_peak_height': 0.0,
        'peak_threshold': 2.0,
        'verbose': False,
    }
if 'basic_fooof_settings' not in globals():
    basic_fooof_settings = {**common_new_fooof_params, 'aperiodic_mode': 'fixed'}
if 'knee_fooof_settings' not in globals():
    knee_fooof_settings = {**common_new_fooof_params, 'aperiodic_mode': 'knee'}

mode_to_settings = {'fixed': basic_fooof_settings, 'knee': knee_fooof_settings}

# --- Containers ---
collected_hump_results_cell4 = []
fine_grain_aperiodic_results_cell4 = []
all_raw_psds_for_averaging_cell4 = []

print(f"Processing {len(fooof_input_df)} PSD segments from fooof_input_df (for aperiodic params & humps).")

# --- Pre-collect raw PSD metadata (serial; cheap) ---
def _raw_entry_from_row(row):
    psd = row['PSD_Data']; freqs = row['Frequency_Vector_Raw']
    if not isinstance(psd, np.ndarray) or psd.size == 0: return None
    if not isinstance(freqs, np.ndarray) or freqs.size == 0: return None
    if psd.size != freqs.size: return None
    ch = row['Channel']
    el = electrode_labels.get(ch, ch) if 'electrode_labels' in globals() else ch
    return {
        'timestamp_unix': row.get('Neural_Segment_Start_Unixtime'),
        'datetime_utc': row.get('Neural_Segment_Start_DateTime_UTC', pd.NaT),
        'channel': ch, 'electrode_label': el,
        'freqs': freqs.copy(), 'psd': psd.copy()
    }

# Keep one raw-PSD record per usable row; rows rejected by the helper are dropped.
all_raw_psds_for_averaging_cell4.extend(
    entry
    for entry in (_raw_entry_from_row(series) for _, series in fooof_input_df.iterrows())
    if entry is not None
)

# --- Parallel task: process one segment across all bands × modes ---
def _process_segment(row):
    try:
        psd = row['PSD_Data']; freqs = row['Frequency_Vector_Raw']
        if not (isinstance(psd, np.ndarray) and isinstance(freqs, np.ndarray)): return [], []
        if psd.size == 0 or freqs.size == 0 or psd.size != freqs.size: return [], []

        ch = row['Channel']
        el = electrode_labels.get(ch, ch) if 'electrode_labels' in globals() else ch
        ts_pkg = row.get('Aligned_PKG_UnixTimestamp')

        ap_list, hump_list = [], []
        for band_label, band_range in ITERATIVE_FREQ_BANDS.items():
            for ap_mode_key, settings in mode_to_settings.items():
                fm = FOOOF(**settings)
                set_lambda(fm)  # <-- enable regularization
                lam_used = getattr(fm, "regularization_weight", np.nan)

                try:
                    with warnings.catch_warnings():
                        warnings.filterwarnings('ignore', "divide by zero encountered in power", RuntimeWarning)
                        warnings.filterwarnings('ignore', "invalid value encountered in log10", RuntimeWarning)
                        fm.fit(freqs, psd, freq_range=band_range)
                except Exception:
                    # record a failed fit row (optional but keeps downstream merges stable)
                    ap_list.append({
                        'timestamp_unix': ts_pkg, 'channel': ch, 'electrode_label': el,
                        'freq_band_label': band_label, 'aperiodic_mode': ap_mode_key,
                        'r_squared': np.nan, 'fit_error': np.nan,
                        'aperiodic_offset': np.nan, 'aperiodic_exponent': np.nan,
                        'aperiodic_knee': np.nan, 'ap_reg_lambda': lam_used,
                        'num_model_peaks': 0
                    })
                    continue

                if not fm.has_model:
                    ap_list.append({
                        'timestamp_unix': ts_pkg, 'channel': ch, 'electrode_label': el,
                        'freq_band_label': band_label, 'aperiodic_mode': ap_mode_key,
                        'r_squared': np.nan, 'fit_error': np.nan,
                        'aperiodic_offset': np.nan, 'aperiodic_exponent': np.nan,
                        'aperiodic_knee': np.nan, 'ap_reg_lambda': lam_used,
                        'num_model_peaks': 0
                    })
                    continue

                ap_params = fm.aperiodic_params_
                ap_list.append({
                    'timestamp_unix': ts_pkg, 'channel': ch, 'electrode_label': el,
                    'freq_band_label': band_label, 'aperiodic_mode': ap_mode_key,
                    'r_squared': fm.r_squared_, 'fit_error': fm.error_,
                    'aperiodic_offset': ap_params[0],
                    'aperiodic_exponent': ap_params[1] if ap_mode_key == 'fixed' else ap_params[2],
                    'aperiodic_knee': ap_params[1] if ap_mode_key == 'knee' else np.nan,'ap_reg_lambda': lam_used,
                    'num_model_peaks': len(fm.peak_params_) if (fm.peak_params_ is not None and getattr(fm.peak_params_, "ndim", 0) == 2) else 0
                })

                base = {
                    'timestamp_unix': ts_pkg, 'channel': ch, 'electrode_label': el,
                    'session_id': session_id, 'hemisphere': neural_hemisphere,
                    'freq_band_label': band_label, 'aperiodic_mode': ap_mode_key,
                    'r_squared': fm.r_squared_, 'fit_error': fm.error_,
                }
                humps = find_oscillatory_humps(fm)
                if humps:
                    for h in humps:
                        row_h = base.copy(); row_h.update(h); hump_list.append(row_h)
                else:
                    row_h = base.copy()
                    row_h.update({'hump_start_freq': np.nan, 'hump_end_freq': np.nan, 'hump_width': np.nan,
                                  'hump_max_power_freq': np.nan, 'hump_max_power_val': np.nan})
                    hump_list.append(row_h)

        return ap_list, hump_list
    except Exception:
        return [], []

# --- Fan the segments out to the workers; the progress bar is sized to one tick per segment ---
_segment_jobs = (delayed(_process_segment)(row) for _, row in fooof_input_df.iterrows())
with tqdm_joblib(tqdm(desc="Main FOOOF Fitting (segments)", total=len(fooof_input_df))):
    results = Parallel(n_jobs=N_JOBS, backend='loky')(_segment_jobs)

# --- Flatten the per-segment result pairs into the cell-level containers ---
for segment_ap_rows, segment_hump_rows in results:
    fine_grain_aperiodic_results_cell4.extend(segment_ap_rows)
    collected_hump_results_cell4.extend(segment_hump_rows)


def _array_to_json_text(values):
    """Serialize an ndarray (via tolist) or any other JSON-compatible value to a JSON string."""
    payload = values.tolist() if isinstance(values, np.ndarray) else values
    return json.dumps(payload)


# --- Hump results -> CSV ---
df_hump_results_cell4 = pd.DataFrame(collected_hump_results_cell4)
if df_hump_results_cell4.empty:
    print("\nNo hump results generated in Cell 4.")
else:
    # Placeholder rows carry NaN widths, so only real humps are counted here.
    n_humps = df_hump_results_cell4['hump_width'].count()
    print(f"\nSuccessfully processed and found {n_humps} oscillatory humps (out of {len(df_hump_results_cell4)} total entries).")
    hump_results_filename_cell4 = os.path.join(
        DATA_OUTPUT_PATH_NEW_CELLS, f"{session_id}_{neural_hemisphere}_fooof_hump_results_from_cell4.csv"
    )
    df_hump_results_cell4.to_csv(hump_results_filename_cell4, index=False)
    print(f"Saved oscillatory hump results to: {hump_results_filename_cell4}")

# --- Aperiodic fit results -> CSV ---
df_fine_grain_results_cell4 = pd.DataFrame(fine_grain_aperiodic_results_cell4)
if df_fine_grain_results_cell4.empty:
    print("\nNo fine-grain aperiodic results generated in Cell 4.")
else:
    print(f"\nSuccessfully collected {len(df_fine_grain_results_cell4)} aperiodic model fits (fixed & knee per band).")
    fine_grain_results_filepath_cell4 = os.path.join(
        DATA_OUTPUT_PATH_NEW_CELLS, f"{session_id}_{neural_hemisphere}_fooof_aperiodic_fits_from_cell4.csv"
    )
    df_fine_grain_results_cell4.to_csv(fine_grain_results_filepath_cell4, index=False)
    print(f"Saved fine-grained aperiodic results to: {fine_grain_results_filepath_cell4}")

# --- Raw PSDs -> parquet, for later averaging (Cells 5b/5c) ---
df_all_raw_psds_cell4 = pd.DataFrame(all_raw_psds_for_averaging_cell4)
if df_all_raw_psds_cell4.empty:
    print("\nNo raw PSDs collected in Cell 4.")
else:
    print(f"\nCollected {len(df_all_raw_psds_cell4)} raw PSD segments for potential averaging.")
    # The array-valued columns are stored as JSON strings in the parquet file.
    for _array_column in ('freqs', 'psd'):
        df_all_raw_psds_cell4[_array_column] = df_all_raw_psds_cell4[_array_column].apply(_array_to_json_text)
    raw_psds_filename_cell4 = os.path.join(
        DATA_OUTPUT_PATH_NEW_CELLS, f"{session_id}_{neural_hemisphere}_raw_psds_for_averaging_from_cell4.parquet"
    )
    df_all_raw_psds_cell4.to_parquet(raw_psds_filename_cell4, index=False, engine='fastparquet')
    print(f"Saved raw PSDs for averaging to: {raw_psds_filename_cell4}")

end_time_cell4_modified = time.time()
print(f"\nCell 4 execution time: {end_time_cell4_modified - start_time_cell4_modified:.2f} seconds.")
print("--- Cell 4 (Parallel+Progress): Primary FOOOF Aperiodic Fitting & Hump Analysis Complete ---")
# --- End of Cell 4 (Parallel+Progress) ---


Starting Primary FOOOF Aperiodic Fitting & Hump Analysis (Parallel+Progress) ---
[Cell 4] Using 79 parallel workers
Processing 11368 PSD segments from fooof_input_df (for aperiodic params & humps).


Main FOOOF Fitting (segments):   0%|          | 0/11368 [00:00<?, ?it/s]

  0%|          | 0/11368 [00:00<?, ?it/s]


Successfully processed and found 145840 oscillatory humps (out of 145876 total entries).
Saved oscillatory hump results to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_data/RCS05L_Left_fooof_hump_results_from_cell4.csv

Successfully collected 68208 aperiodic model fits (fixed & knee per band).
Saved fine-grained aperiodic results to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_data/RCS05L_Left_fooof_aperiodic_fits_from_cell4.csv

Collected 11368 raw PSD segments for potential averaging.
Saved raw PSDs for averaging to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_data/RCS05L_Left_raw_psds_for_averaging_from_cell4.parquet

Cell 4 execution time: 312.87 seconds.
--- Cell 4 (Parallel+Progress): Primary FOOOF Aperiodic Fitting & Hump Analysis Complete ---


In [9]:
# -*- coding: utf-8 -*-
# --- Cell 5a1: Processing and Summarizing Hump Data ---

import pandas as pd
import os
import sys  # For sys.exit
import fastparquet
print("\n--- Cell 5a1: Starting Hump Data Processing ---")

# Load Data Generated by MODIFIED Cell 4
hump_results_input_filename = os.path.join(
    DATA_OUTPUT_PATH_NEW_CELLS,
    f"{session_id}_{neural_hemisphere}_fooof_hump_results_from_cell4.csv"
)
summary_stats_output_filename = os.path.join(
    DATA_OUTPUT_PATH_NEW_CELLS,
    f"{session_id}_{neural_hemisphere}_hump_width_summary_stats_from_cell5a1.csv"
)

# Reset this cell's outputs up front. The plotting code further down keys off
# `summary_df_widths_5a1`; without the reset, a re-run in a live kernel whose
# input file is missing or unusable would silently plot a stale frame.
df_hump_results_for_5a1 = pd.DataFrame()
summary_df_widths_5a1 = pd.DataFrame()

if not os.path.exists(hump_results_input_filename):
    print(f"Warning: Hump results file from MODIFIED Cell 4 not found: {hump_results_input_filename}. Skipping Cell 5a1.")
else:
    try:
        df_hump_results_for_5a1 = pd.read_csv(hump_results_input_filename)
        print(f"Successfully loaded hump results from: {hump_results_input_filename}")
    except FileNotFoundError:
        # Only reachable if the file disappears between the existence check and the read.
        print(f"ERROR: Hump results file not found at {hump_results_input_filename}. Cannot create plots. Please run MODIFIED Cell 4.")
        sys.exit()
    except pd.errors.EmptyDataError:
        # A zero-byte/column-less CSV is treated like an empty result set
        # (reported by the empty-frame branch below) instead of raising.
        df_hump_results_for_5a1 = pd.DataFrame()

    if not df_hump_results_for_5a1.empty and 'hump_width' in df_hump_results_for_5a1.columns:
        print("\nSummary Statistics for Oscillatory Hump Width (Hz):")
        group_by_cols_5a1 = ['electrode_label', 'freq_band_label', 'aperiodic_mode']
        if 'electrode_label' not in df_hump_results_for_5a1.columns:
            if 'channel' in df_hump_results_for_5a1.columns:
                print("Warning: 'electrode_label' not found, grouping by 'channel' instead for hump stats.")
                group_by_cols_5a1[0] = 'channel'
            else:
                print("Warning: Neither 'electrode_label' nor 'channel' found for grouping hump stats. Skipping.")
                df_hump_results_for_5a1 = pd.DataFrame()  # Make it empty

        if not df_hump_results_for_5a1.empty:
            # Rows with NaN width are the "no hump found" placeholders written by Cell 4.
            summary_df_widths_5a1 = df_hump_results_for_5a1.dropna(subset=['hump_width'])
            if not summary_df_widths_5a1.empty:
                summary_stats_df_5a1 = summary_df_widths_5a1.groupby(
                    group_by_cols_5a1
                )['hump_width'].describe()
                print("Note: '50%' in the table below represents the Median.")
                print(summary_stats_df_5a1)
                summary_stats_df_5a1.to_csv(summary_stats_output_filename)
                print(f"\nSaved hump width summary statistics to: {summary_stats_output_filename}")
            else:
                print("No valid hump widths (non-NaN) found to generate summary statistics.")
    elif df_hump_results_for_5a1.empty:
        print("Loaded hump results DataFrame (df_hump_results_for_5a1) is empty.")
    else:  # df_hump_results_for_5a1 not empty, but 'hump_width' missing
        print("Warning: 'hump_width' column missing in loaded hump results. Cannot generate summary.")

import seaborn as sns
import matplotlib.pyplot as plt

# Only plot if the summary_df_widths_5a1 DataFrame exists and is not empty
if 'summary_df_widths_5a1' in locals() and not summary_df_widths_5a1.empty:
    plot_df = summary_df_widths_5a1.copy()

    # Pick faceter column for rows: prefer 'electrode_label', else 'channel'
    row_faceter = 'electrode_label' if 'electrode_label' in plot_df.columns else (
        'channel' if 'channel' in plot_df.columns else None
    )

    if row_faceter is None:
        print("No faceter ('electrode_label' or 'channel') available for row split. Skipping plot.")
    elif "freq_band_label" not in plot_df.columns:
        # FacetGrid(col="freq_band_label") cannot be built without the column,
        # so report the problem and skip the plot instead of failing inside seaborn.
        print("Missing 'freq_band_label' in plot_df; cannot facet by column.")
    else:
        plot_df[row_faceter] = plot_df[row_faceter].astype(str)

        # Order columns by the standard labels when present; else use sorted unique labels
        unique_bands = plot_df["freq_band_label"].unique().tolist()
        default_order = ["LowFreq", "MidFreq", "WideFreq"]
        col_order = [b for b in default_order if b in unique_bands] or sorted(unique_bands)

        # Facet by electrode/channel (rows) and frequency band (cols); hue = aperiodic mode
        g = sns.FacetGrid(
            plot_df,
            row=row_faceter,
            col="freq_band_label",
            hue="aperiodic_mode",
            margin_titles=True,
            sharex=True, sharey=False,
            col_order=col_order,
            height=2.6, aspect=1.25
        )
        g.map_dataframe(sns.histplot, x="hump_width", bins=60, alpha=0.7, element="step", stat="count")
        g.add_legend(title="Aperiodic Mode")
        g.set_axis_labels("Oscillatory Hump Width (Hz)", "Count")

        # --- Set custom y-axis limits based on frequency band ---
        # The band is taken from the facet key rather than the axes title: with
        # margin_titles=True only the top row of axes carries a column title, so a
        # title-based lookup leaves every other row without the limits.
        ylim_by_band_5a1 = {
            "LowFreq": (0, 800),
            "MidFreq": (0, 3500),
            "WideFreq": (0, 1600),
        }
        for facet_key, ax in g.axes_dict.items():
            # With both row and col faceting the key is a (row_name, col_name) tuple.
            band_name = str(facet_key[1] if isinstance(facet_key, tuple) else facet_key)
            # Substring match, as before, so decorated band labels still pick up a limit.
            band_ylim = next(
                (limits for label, limits in ylim_by_band_5a1.items() if label in band_name),
                None,
            )
            if band_ylim is not None:
                ax.set_ylim(*band_ylim)

        g.set_titles(row_template=f'{row_faceter.replace("_"," ").title()}: {{row_name}}', col_template='{col_name}')
        g.fig.subplots_adjust(top=0.93)
        g.fig.suptitle(
            f"{session_id} - Oscillatory Hump Widths by {row_faceter.replace('_',' ').title()} "
            f"(Histogram per {row_faceter.replace('_',' ').title()}, Band, Model)"
        )

        # Save the FacetGrid figure itself (one subplot per electrode/channel and band)
        hist_plot_path = os.path.join(
            DATA_OUTPUT_PATH_NEW_CELLS,
            f"{session_id}_{neural_hemisphere}_hump_width_hist_split_by_{row_faceter}.png"
        )
        g.fig.savefig(hist_plot_path, dpi=300, bbox_inches="tight")
        plt.close(g.fig)  # release the figure once it is on disk
        print(f"Saved histogram plot to: {hist_plot_path}")

print("\n--- Cell 5a1: Hump Data Processing Complete ---")
# --- End of Cell 5a1 ---



--- Cell 5a1: Starting Hump Data Processing ---
Successfully loaded hump results from: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_data/RCS05L_Left_fooof_hump_results_from_cell4.csv

Summary Statistics for Oscillatory Hump Width (Hz):
Note: '50%' in the table below represents the Median.
                                                    count       mean  \
electrode_label   freq_band_label aperiodic_mode                       
key0_contact_2_0  LowFreq         fixed            3191.0  17.101222   
                                  knee             4569.0   8.377107   
                  MidFreq         fixed            9442.0   9.060210   
                                  knee            10078.0   8.934759   
                  WideFreq        fixed            5361.0  21.321115   
                                  knee             7404.0  16.488385   
key1_contact_3_0  LowFreq         fixed            2997.0  18.324

In [10]:
# -*- coding: utf-8 -*-
# --- Cell 5a2: Plots of Hump Width ---

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys # For sys.exit
import time # For timing
def _histplot_with_kde(data, color=None, **kwargs):
    """
    Plot histogram + KDE for hump_width column.
    Compatible with seaborn.FacetGrid.map_dataframe.
    """
    import seaborn as sns
    import matplotlib.pyplot as plt

    x = data["hump_width"].dropna().to_numpy()  # Avoid pandas multi-dimensional indexing error

    if len(x) == 0:
        return  # nothing to plot

    # Plot histogram
    sns.histplot(x=x, bins=HISTOGRAM_BINS_5A2, stat="count", color=color, alpha=0.6, **kwargs)

    # Plot KDE
    sns.kdeplot(x=x, color=color, fill=True, alpha=0.3, linewidth=2, **kwargs)
print("\n--- Cell 5a2: Generating Hump Width Plots ---")
start_time_cell5a2 = time.time()

# Widths above the limit are excluded from the plot; bin count is read by _histplot_with_kde.
WIDTH_LIMIT_HZ_5A2 = 60.0
HISTOGRAM_BINS_5A2 = 60

hump_results_input_filename_5a2 = os.path.join(DATA_OUTPUT_PATH_NEW_CELLS, f"{session_id}_{neural_hemisphere}_fooof_hump_results_from_cell4.csv")
# Figures for this cell go into their own subfolder.
histogram_save_path_5a2 = os.path.join(FIGURES_OUTPUT_PATH_NEW_CELLS, 'histograms_hump_width_cell5a2')
if not os.path.exists(histogram_save_path_5a2):
    os.makedirs(histogram_save_path_5a2)

# --- Load the hump table written by Cell 4 ---
if os.path.exists(hump_results_input_filename_5a2):
    try:
        df_hump_results_for_5a2 = pd.read_csv(hump_results_input_filename_5a2)
        print(f"Successfully loaded hump results from: {hump_results_input_filename_5a2}")
    except FileNotFoundError:
        print(f"ERROR: Hump results file not found at {hump_results_input_filename_5a2}. Cannot create plots.")
        sys.exit()
else:
    print(f"Warning: Hump results file from MODIFIED Cell 4 not found: {hump_results_input_filename_5a2}. Skipping Cell 5a2.")
    df_hump_results_for_5a2 = pd.DataFrame()

# --- Plot (or explain why not) ---
if df_hump_results_for_5a2.empty:
    print("Loaded hump results DataFrame (df_hump_results_for_5a2) is empty.")
elif 'hump_width' not in df_hump_results_for_5a2.columns:
    print("Warning: 'hump_width' column missing in loaded hump results. Cannot create histograms.")
else:
    print(f"\nGenerating histograms with universal smoothing...")
    # Drop rows without a width, then keep widths up to the limit.
    df_plot_5a2 = (
        df_hump_results_for_5a2
        .dropna(subset=['hump_width'])
        .copy()
        .loc[lambda frame: frame['hump_width'] <= WIDTH_LIMIT_HZ_5A2]
    )

    # Hue column: electrode label when available, otherwise the channel.
    hue_col_5a2 = next(
        (candidate for candidate in ('electrode_label', 'channel') if candidate in df_plot_5a2.columns),
        None,
    )
    if hue_col_5a2 is None:
        print("Warning: Neither 'electrode_label' nor 'channel' found for hue in hump width plots. Plots may be incorrect.")

    if not hue_col_5a2:
        print("Skipping hump width plot generation as no suitable hue column was found.")
    elif df_plot_5a2.empty:
        print("No data left to plot for hump widths after filtering.")
    else:
        # Columns = frequency band, rows = aperiodic mode, hue = electrode/channel.
        g = sns.FacetGrid(
            df_plot_5a2,
            col="freq_band_label", row="aperiodic_mode",
            hue=hue_col_5a2,
            margin_titles=True, height=4, aspect=1.2,
            sharey=True, legend_out=True
        )
        g.map_dataframe(_histplot_with_kde)

        g.set_axis_labels("Oscillatory Hump Width (Hz)", "Count")
        g.set_titles(col_template="{col_name}", row_template="{row_name}")
        g.set(xlim=(0, WIDTH_LIMIT_HZ_5A2))
        g.add_legend(title=hue_col_5a2.replace('_', ' ').title())  # e.g. "Electrode Label"
        g.fig.suptitle(f'{session_id} - Distribution of Oscillatory Hump Widths (Humps from Cell 4)', y=1.03)

        # --- Custom y-axis limits, chosen per FacetGrid column (band) ---
        ylim_by_band = {
            "LowFreq":  (0, 800),
            "MidFreq":  (0, 3500),
            "WideFreq": (0, 1600),
        }

        n_rows = len(getattr(g, "row_names", [])) or 1
        n_cols = len(g.col_names)

        # Walk the grid column by column; bands without an entry keep their limits.
        for col_idx in range(n_cols):
            band_limits = ylim_by_band.get(g.col_names[col_idx])
            if band_limits is None:
                continue
            for row_idx in range(n_rows):
                g.facet_axis(row_idx, col_idx).set_ylim(*band_limits)

        plot_filename_final_5a2 = os.path.join(histogram_save_path_5a2, f"{session_id}_smoothed_hump_width_histogram_cell5a2.png")
        g.savefig(plot_filename_final_5a2, dpi=300)  # 300 dpi for presentation quality
        plt.close(g.fig)
        print(f"Presentation-ready histogram saved to: {plot_filename_final_5a2}")

end_time_cell5a2 = time.time()
print(f"\nCell 5a2 execution time: {end_time_cell5a2 - start_time_cell5a2:.2f} seconds.")
print("--- Cell 5a2: Hump Width Plotting Complete ---")
# --- End of Cell 5a2 ---


--- Cell 5a2: Generating Hump Width Plots ---
Successfully loaded hump results from: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_data/RCS05L_Left_fooof_hump_results_from_cell4.csv

Generating histograms with universal smoothing...
Presentation-ready histogram saved to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_figures/histograms_hump_width_cell5a2/RCS05L_smoothed_hump_width_histogram_cell5a2.png

Cell 5a2 execution time: 5.19 seconds.
--- Cell 5a2: Hump Width Plotting Complete ---


In [11]:
# -*- coding: utf-8 -*-
# --- Cell 5a3 (MODIFIED): Statistical Summary of Peak Metrics by Contact ---

import pandas as pd
import numpy as np
import os

print("\n--- Cell 5a3: Generating Peak Metrics Summary (long/tidy) ---")

# === Inputs/Outputs ===
hump_results_input_filename = os.path.join(
    DATA_OUTPUT_PATH_NEW_CELLS,
    f"{session_id}_{neural_hemisphere}_fooof_hump_results_from_cell4.csv"
)

# Wide summary (per metric with columns of stats)
summary_wide_out = os.path.join(
    DATA_OUTPUT_PATH_NEW_CELLS,
    f"{session_id}_{neural_hemisphere}_hump_metrics_summary_contactwise_cell5a3.csv"
)

# Long/tidy summary (preferred for comparisons in 5a4)
summary_long_out = os.path.join(
    DATA_OUTPUT_PATH_NEW_CELLS,
    f"{session_id}_{neural_hemisphere}_hump_metrics_summary_contactwise_cell5a3_LONG.csv"
)

# Default to empty summaries so every early-exit path below leaves both names defined.
df_summary_wide = pd.DataFrame()
df_summary_long = pd.DataFrame()

if not os.path.exists(hump_results_input_filename):
    print(f"ERROR: No hump results file found at {hump_results_input_filename}")
else:
    df = pd.read_csv(hump_results_input_filename)

    # Prefer electrode_label; fallback to channel
    group_key = "electrode_label" if "electrode_label" in df.columns else "channel"

    if "hump_width" not in df.columns:
        print("ERROR: hump_width column missing in results file.")
    elif group_key not in df.columns:
        print("ERROR: Neither 'electrode_label' nor 'channel' present.")
    else:
        # --- Clean & standardize ---
        df = df.dropna(subset=["hump_width"])
        df = df[df["hump_width"] <= 60.0]  # same cutoff you used elsewhere
        group_cols = [group_key, "freq_band_label", "aperiodic_mode"]

        if df.empty:
            # Without this guard the wide merge below would call reduce() on an
            # empty list and raise TypeError. Nothing is written in this case.
            print("WARNING: No hump widths <= 60.0 Hz remain after filtering; nothing to summarize.")
        else:
            # --- Dynamically detect optional metric columns ---
            # NOTE(review): the hump table written by Cell 4 names its peak columns
            # 'hump_max_power_freq' / 'hump_max_power_val', which match none of the
            # centre-frequency/amplitude candidates below, so those two metrics are
            # not summarized for that file — confirm whether that is intended.
            # Peak center frequency
            cf_candidates = ["hump_center_freq", "center_frequency", "peak_cf", "hump_cf"]
            cf_col = next((c for c in cf_candidates if c in df.columns), None)

            # Peak amplitude / power
            amp_candidates = ["hump_amplitude", "peak_power", "hump_power", "peak_amplitude"]
            amp_col = next((c for c in amp_candidates if c in df.columns), None)

            # Goodness-of-fit / error
            r2_candidates = ["r_squared", "r2", "R2"]
            r2_col = next((c for c in r2_candidates if c in df.columns), None)

            err_candidates = ["error", "mse", "fit_error", "err"]
            err_col = next((c for c in err_candidates if c in df.columns), None)

            # Any other metrics you want to keep an eye on can be added above

            # --- Helper: standard summary stats for a numeric column ---
            def summarize_numeric(col_name, friendly_name):
                """Summarize df[col_name] per group_cols.

                Returns ``(wide, long)``: ``wide`` has one row per group with
                ``{friendly_name}_<Stat>`` columns; ``long`` is the tidy form with
                columns group_cols + ['metric', 'stat', 'value']. Returns
                ``(None, None)`` when the column has no non-NaN values.
                """
                dfx = df.dropna(subset=[col_name]).copy()
                if dfx.empty:
                    return None, None

                def iqr(x):
                    return np.percentile(x, 75) - np.percentile(x, 25)

                agg = dfx.groupby(group_cols)[col_name].agg(
                    N="count",
                    Mean="mean",
                    Median="median",
                    Std="std",
                    Min="min",
                    Max="max",
                    Q25=lambda x: np.percentile(x, 25),
                    Q75=lambda x: np.percentile(x, 75),
                    IQR=iqr
                ).reset_index()

                # Wide with prefixed columns for clarity
                stat_names = [c for c in agg.columns if c not in group_cols]
                wide = agg.rename(columns={s: f"{friendly_name}_{s}" for s in stat_names})

                # Long/tidy for easy plotting/comparison later
                long = agg.melt(
                    id_vars=group_cols,
                    var_name="stat",
                    value_name="value"
                )
                long["metric"] = friendly_name
                long = long[group_cols + ["metric", "stat", "value"]]
                return wide, long

            # --- Always summarize width ---
            width_wide, width_long = summarize_numeric("hump_width", "width")
            widelist_wide = [width_wide]
            widelist_long = [width_long]

            # --- Optional metrics if present (same order as before) ---
            optional_metrics = [
                (cf_col, "center_freq"),
                (amp_col, "amplitude"),
                (r2_col, "r2"),
                (err_col, "error"),
            ]
            for metric_col, metric_name in optional_metrics:
                if not metric_col:
                    continue
                metric_wide, metric_long = summarize_numeric(metric_col, metric_name)
                if metric_wide is not None:
                    widelist_wide.append(metric_wide)
                    widelist_long.append(metric_long)

            # --- Derive proportion of "narrow" peaks (width thresholds) ---
            thresholds = [5.0, 10.0, 20.0]
            prop_frames = []
            for thr in thresholds:
                dthr = df.assign(narrow=(df["hump_width"] <= thr).astype(int))
                prop = dthr.groupby(group_cols)["narrow"].mean().reset_index()
                prop["metric"] = f"prop_width_le_{int(thr)}Hz"
                prop["stat"] = "Mean"
                prop = prop.rename(columns={"narrow": "value"})
                prop_frames.append(prop)

            if prop_frames:
                prop_long = pd.concat(prop_frames, ignore_index=True)
                widelist_long.append(prop_long)

            # --- Combine & Save ---
            # Wide: merge on keys
            from functools import reduce
            df_summary_wide = reduce(
                lambda left, right: pd.merge(left, right, on=group_cols, how="outer"),
                [w for w in widelist_wide if w is not None]
            )
            df_summary_long = pd.concat([l for l in widelist_long if l is not None], ignore_index=True)

            df_summary_wide.to_csv(summary_wide_out, index=False)
            df_summary_long.to_csv(summary_long_out, index=False)

            print(f"Saved WIDE summary to: {summary_wide_out}")
            print(f"Saved LONG summary to: {summary_long_out}")

# Quick peek (best-effort: a display problem must not fail the cell, but say so)
try:
    from IPython.display import display
    if not df_summary_long.empty:
        display(df_summary_long.head(20))
except Exception as preview_error:
    print(f"(Preview of df_summary_long skipped: {preview_error})")



--- Cell 5a3: Generating Peak Metrics Summary (long/tidy) ---
Saved WIDE summary to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_data/RCS05L_Left_hump_metrics_summary_contactwise_cell5a3.csv
Saved LONG summary to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_data/RCS05L_Left_hump_metrics_summary_contactwise_cell5a3_LONG.csv


Unnamed: 0,electrode_label,freq_band_label,aperiodic_mode,metric,stat,value
0,key0_contact_2_0,LowFreq,fixed,width,N,3191.0
1,key0_contact_2_0,LowFreq,knee,width,N,4569.0
2,key0_contact_2_0,MidFreq,fixed,width,N,9442.0
3,key0_contact_2_0,MidFreq,knee,width,N,10078.0
4,key0_contact_2_0,WideFreq,fixed,width,N,5299.0
5,key0_contact_2_0,WideFreq,knee,width,N,7400.0
6,key1_contact_3_0,LowFreq,fixed,width,N,2997.0
7,key1_contact_3_0,LowFreq,knee,width,N,3888.0
8,key1_contact_3_0,MidFreq,fixed,width,N,9463.0
9,key1_contact_3_0,MidFreq,knee,width,N,10295.0


In [21]:
# -*- coding: utf-8 -*-
# --- Cell 5a4 (MINIMAL + REGIONS): Multi-subject collation + M1 vs STN comparison ---

import os, re
import pandas as pd
import numpy as np
from scipy.stats import wilcoxon, friedmanchisquare, mannwhitneyu
from statsmodels.stats.multitest import multipletests
import matplotlib.pyplot as plt
import seaborn as sns

print("\n--- Cell 5a4 (MINIMAL + REGIONS) ---")

# ====== 1) INPUTS (your lists unchanged) ======
file_5a3_FOOOF_LIST = [
    "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_fooof/RCS02R/Right/new_analysis_data/RCS02R_Right_hump_metrics_summary_contactwise_cell5a3_LONG.csv",
    "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_fooof/RCS02L/Left/new_analysis_data/RCS02L_Left_hump_metrics_summary_contactwise_cell5a3_LONG.csv",
    "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_fooof/RCS05R/Right/new_analysis_data/RCS05R_Right_hump_metrics_summary_contactwise_cell5a3_LONG.csv",
    "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_fooof/RCS05L/Left/new_analysis_data/RCS05L_Left_hump_metrics_summary_contactwise_cell5a3_LONG.csv",
    "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_fooof/RCS06R/Right/new_analysis_data/RCS06R_Right_hump_metrics_summary_contactwise_cell5a3_LONG.csv",
    "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_fooof/RCS06L/Left/new_analysis_data/RCS06L_Left_hump_metrics_summary_contactwise_cell5a3_LONG.csv",
]
file_5a3_SPECPARAM_LIST = [
    "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS02R/Right/new_analysis_data/RCS02R_Right_hump_metrics_summary_contactwise_cell5a3_LONG.csv",
    "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS02L/Left/new_analysis_data/RCS02L_Left_hump_metrics_summary_contactwise_cell5a3_LONG.csv",
    "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05R/Right/new_analysis_data/RCS05R_Right_hump_metrics_summary_contactwise_cell5a3_LONG.csv",
    "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_data/RCS05L_Left_hump_metrics_summary_contactwise_cell5a3_LONG.csv",
    "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS06R/Right/new_analysis_data/RCS06R_Right_hump_metrics_summary_contactwise_cell5a3_LONG.csv",
    "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS06L/Left/new_analysis_data/RCS06L_Left_hump_metrics_summary_contactwise_cell5a3_LONG.csv",
]

# Filename prefix for the pooled (multi-subject) outputs written by this cell.
PROJECT_LABEL = "ALL_SUBJECTS"
# Tabular outputs are written next to the FIRST FOOOF input file (cwd if the list is
# empty), i.e. the pooled CSVs end up inside that one subject's data folder.
BASE_OUT_DIR = os.path.dirname(file_5a3_FOOOF_LIST[0]) if file_5a3_FOOOF_LIST else os.getcwd()
# Figures go to a folder one level above BASE_OUT_DIR; it is created eagerly.
FIGURES_DIR  = os.path.normpath(os.path.join(BASE_OUT_DIR, "..", "figures_5a4_min"))
os.makedirs(FIGURES_DIR, exist_ok=True)

# ====== 2) HELPERS ======
def _parse_subject_hemi_from_path(path: str):
    base = os.path.basename(path)
    m = re.search(r'(RCS\d{2}[LR])_(Left|Right)', base, re.IGNORECASE)
    if m: return m.group(1), m.group(2).title()
    parts = [p for p in path.split(os.sep) if p]
    subj = next((p for p in parts if re.match(r'RCS\d{2}[LR]', p, re.IGNORECASE)), "UNKNOWN")
    hemi = next((p for p in parts if p.lower() in {"left", "right"}), "UNKNOWN").title()
    return subj, hemi

def load_many(paths, pipeline_label):
    """Read several 5a3 LONG CSVs and stack them into one tagged table.

    Each existing file is read and tagged with `pipeline`, `subject_hemi`,
    `subject` and `hemi` (the latter three parsed from its path); missing files
    are reported and skipped. Rows whose `value` is not numeric are dropped,
    `value` is cast to float, and `aperiodic_mode` (if present) is lower-cased.

    Raises:
        FileNotFoundError: if none of `paths` exists.
    """
    tagged_tables = []
    for csv_path in paths:
        if os.path.exists(csv_path):
            subject_id, side = _parse_subject_hemi_from_path(csv_path)
            tagged_tables.append(
                pd.read_csv(csv_path).assign(
                    pipeline=pipeline_label,
                    subject_hemi=f"{subject_id}_{side}",
                    subject=subject_id,
                    hemi=side,
                )
            )
        else:
            print(f"WARNING: missing 5a3 LONG (skip): {csv_path}")

    if len(tagged_tables) == 0:
        raise FileNotFoundError(f"No valid 5a3 LONG for '{pipeline_label}'.")

    stacked = pd.concat(tagged_tables, ignore_index=True)
    is_numeric = pd.to_numeric(stacked["value"], errors="coerce").notnull()
    stacked = stacked[is_numeric].copy()
    stacked["value"] = stacked["value"].astype(float)
    if "aperiodic_mode" in stacked.columns:
        stacked["aperiodic_mode"] = stacked["aperiodic_mode"].astype(str).str.lower()
    return stacked

def safe_wilcoxon(x):
    """Wilcoxon signed-rank test on paired differences that never raises from the test itself.

    Returns (statistic, p). An empty or all-(near-)zero input short-circuits to
    (0.0, 1.0); if scipy's `wilcoxon` raises, (nan, nan) is returned instead.
    """
    diffs = np.asarray(x)
    nothing_to_test = diffs.size == 0 or np.allclose(diffs, 0)
    if nothing_to_test:
        return (0.0, 1.0)
    try:
        w_stat, p_value = wilcoxon(diffs)
    except Exception:
        return (np.nan, np.nan)
    return (w_stat, p_value)

def cliffs_delta(x, y):
    """Cliff's delta effect size between two samples, in [-1, 1].

    Computed as mean(sign(x_i - y_j)) over all pairs (i, j): +1 means every x
    exceeds every y, -1 the reverse, 0 means no dominance.

    Args:
        x, y: array-likes of numbers; flattened to 1D before comparison.

    Returns:
        The delta as a float, or NaN if either sample is empty (there are no
        pairs to compare). NaN values inside the inputs propagate to the result.
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    if x.size == 0 or y.size == 0:
        # No pairs exist; return NaN explicitly instead of evaluating 0/0.
        return np.nan
    diffs = x[:, None] - y[None, :]  # all pairwise differences, shape (len(x), len(y))
    return np.sign(diffs).sum() / (x.size * y.size)

# Region labelling from electrode/channel string
def classify_region_from_label(lbl: str):
    """Classify an electrode/channel label as 'STN', 'M1', 'BRIDGE' or 'UNKNOWN'.

    The contact numbers are the last two integers found in the final
    word-token of the label (tokens are runs of letters/digits/underscore); if
    that token has no digits, all digits in the label are considered. A single
    number is treated as a pair of itself.

    Rules on the (lo, hi) pair:
        hi <= 3              -> "STN"
        lo >= 4 and hi <= 11 -> "M1"
        anything else        -> "BRIDGE" (straddles 3/4 or exceeds 11)
    Missing labels or labels without digits give "UNKNOWN".

    Note: a hyphen splits tokens, so 'E3-4' is read from its last token '4' alone.
    """
    if pd.isna(lbl):
        return "UNKNOWN"
    text = str(lbl)

    word_tokens = re.findall(r'\w+', text)
    trailing_token = word_tokens[-1] if word_tokens else ""
    digit_runs = re.findall(r'\d+', trailing_token) or re.findall(r'\d+', text)
    if not digit_runs:
        return "UNKNOWN"

    second = int(digit_runs[-1])
    first = int(digit_runs[-2]) if len(digit_runs) > 1 else second
    lo, hi = sorted((first, second))

    if hi <= 3:
        region = "STN"
    elif lo >= 4 and hi <= 11:
        region = "M1"
    else:
        region = "BRIDGE"
    return region

# ====== 3) LOAD & UNIFY ======
df_fooof = load_many(file_5a3_FOOOF_LIST, "fooof")            # fixed + knee
df_specp = load_many(file_5a3_SPECPARAM_LIST, "specparam5000") # knee only

# The grouping column is sliced out of BOTH tables via keep_cols below, so it has to
# exist in both; accepting it from only one table would fail later with a KeyError.
group_key = next(
    (col for col in ("electrode_label", "channel")
     if col in df_fooof.columns and col in df_specp.columns),
    None,
)
if group_key is None:
    raise ValueError(
        "Expected 'electrode_label' or 'channel' in 5a3 outputs "
        "(present in both the FOOOF and specparam tables)."
    )

keep_cols = ["subject_hemi", "subject", "hemi", group_key, "freq_band_label", "aperiodic_mode", "metric", "stat", "value"]
# One slice per (pipeline, aperiodic mode), each tagged with a `method` label.
f_fixed = df_fooof[df_fooof["aperiodic_mode"] == "fixed"][keep_cols].copy(); f_fixed["method"] = "fooof_fixed"
f_knee  = df_fooof[df_fooof["aperiodic_mode"] == "knee"][keep_cols].copy();  f_knee["method"]  = "fooof_knee"
s_knee  = df_specp[df_specp["aperiodic_mode"] == "knee"][keep_cols].copy();  s_knee["method"]  = "specparam_knee"
df_all  = pd.concat([f_fixed, f_knee, s_knee], ignore_index=True)

combined_out_csv = os.path.join(BASE_OUT_DIR, f"{PROJECT_LABEL}_5a4_COMBINED_LONG.csv")
df_all.to_csv(combined_out_csv, index=False)
print(f"Saved combined LONG: {combined_out_csv}")

# ====== 4) WIDTH (Median) and REGION TAG ======
width_med = df_all[(df_all["metric"]=="width") & (df_all["stat"]=="Median")].copy()
width_med["region"] = width_med[group_key].apply(classify_region_from_label)

# Keep only STN/M1; drop bridges/unknowns
width_med = width_med[width_med["region"].isin(["STN", "M1"])].copy()

# ====== 5) Region-wise paired tests: fixed vs knee ======
pair_keys = ["subject_hemi", group_key, "freq_band_label", "region"]

def region_wilcoxon(df_region, region_name):
    """Paired FOOOF fixed-vs-knee comparison of `value` within one region.

    Rows of `df_region` belonging to `region_name` and to the methods
    'fooof_fixed' / 'fooof_knee' are pivoted to one row per
    (subject_hemi, <group_key>, freq_band_label) with one column per method
    (first value wins on duplicates). Only rows having both methods form pairs.

    Args:
        df_region: long table with at least the pair_keys columns plus
            'method' and 'value'.
        region_name: value of the 'region' column to analyse.

    Returns:
        (pv, summary): `pv` is the pair table including
        'diff_knee_minus_fixed'; `summary` is a dict with N_pairs, mean
        difference, Cohen's dz, the Wilcoxon p from `safe_wilcoxon`, and
        percent-reduction statistics. With zero pairs the descriptive entries
        are NaN and p is whatever `safe_wilcoxon` returns for empty input.
    """
    methods = ["fooof_fixed", "fooof_knee"]
    id_cols = pair_keys[:-1]  # 'region' is constant within the subset, so it is not an index level
    in_scope = (df_region["region"] == region_name) & df_region["method"].isin(methods)
    df_sub = df_region[in_scope].copy()

    if df_sub.empty:
        # Nothing to pivot: build the empty pair table explicitly so the method
        # columns exist and the steps below work instead of raising KeyError.
        pv = pd.DataFrame(columns=id_cols + methods)
        pv[methods] = pv[methods].astype(float)
    else:
        pv = df_sub.pivot_table(index=id_cols, columns="method", values="value", aggfunc="first")
        # A method with no rows in this region yields no column; add it as NaN
        # so unmatched rows are dropped below rather than crashing.
        pv = pv.reindex(columns=methods).reset_index()

    pv = pv.dropna(subset=methods).copy()
    pv["diff_knee_minus_fixed"] = pv["fooof_knee"] - pv["fooof_fixed"]
    _, p = safe_wilcoxon(pv["diff_knee_minus_fixed"].values)

    if len(pv) == 0:
        md = dz = median_pct = prop_narrower = np.nan
        iqr_pct = (np.nan, np.nan)
    else:
        md = pv["diff_knee_minus_fixed"].mean()
        sd = pv["diff_knee_minus_fixed"].std(ddof=1)
        dz = np.nan if sd == 0 else md / sd
        # Percent reduction of the knee value relative to the fixed value.
        pct = 100.0 * (pv["fooof_fixed"] - pv["fooof_knee"]) / pv["fooof_fixed"]
        median_pct = np.nanmedian(pct)
        iqr_pct = (np.nanpercentile(pct, 25), np.nanpercentile(pct, 75))
        prop_narrower = 100.0 * np.mean(pct > 0)

    summary = {
        "region": region_name, "N_pairs": len(pv),
        "mean_diff_Hz": md, "cohens_dz": dz, "p": p,
        "median_pct_reduction": median_pct,
        "IQR_pct_reduction": iqr_pct,
        "prop_narrower_%": prop_narrower
    }
    return pv, summary

# Run the paired fixed-vs-knee comparison once per region.
pv_stn, sum_stn = region_wilcoxon(width_med, "STN")
pv_m1,  sum_m1  = region_wilcoxon(width_med, "M1")

# Benjamini-Hochberg FDR across the two regional p-values; attach q and the
# reject flag to each region's summary (q first, then the flag).
pvals = [sum_stn["p"], sum_m1["p"]]
rej, qvals = multipletests(pvals, method="fdr_bh")[:2]
for region_summary, q_value, is_rejected in zip((sum_stn, sum_m1), qvals, rej):
    region_summary["q"] = q_value
    region_summary["rejected_0.05"] = is_rejected

regional_stats = pd.DataFrame([sum_stn, sum_m1])
regional_stats_out = os.path.join(BASE_OUT_DIR, f"{PROJECT_LABEL}_REGIONAL_fixed_vs_knee_stats.csv")
regional_stats.to_csv(regional_stats_out, index=False)
print(f"Saved regional stats: {regional_stats_out}")
print(regional_stats)

# ====== 6) Does the knee effect differ between regions? (Δwidth STN vs M1) ======
# Unpaired comparison of the per-pair (knee - fixed) differences between regions.
d_stn = pv_stn["diff_knee_minus_fixed"].to_numpy()
d_m1  = pv_m1["diff_knee_minus_fixed"].to_numpy()
if len(d_stn) == 0 or len(d_m1) == 0:
    print("Insufficient pairs to compare region effects.")
else:
    u_stat, p_diff = mannwhitneyu(d_stn, d_m1, alternative="two-sided")
    cd = cliffs_delta(d_stn, d_m1)
    diff_of_diff_out = os.path.join(BASE_OUT_DIR, f"{PROJECT_LABEL}_REGIONAL_diff_of_diffs_mannwhitney.csv")
    diff_of_diff_row = {"U":u_stat, "p":p_diff, "cliffs_delta":cd, "N_STN":len(d_stn), "N_M1":len(d_m1)}
    pd.DataFrame([diff_of_diff_row]).to_csv(diff_of_diff_out, index=False)
    print(f"Saved region Δ effect test: {diff_of_diff_out}")
    print(f"Region difference (Δ knee−fixed): Mann–Whitney U={u_stat:.1f}, p={p_diff:.3e}, Cliff's δ={cd:.3f}")

# ====== 7) ONE PLOT: STN/M1 fixed vs knee with significance ======
sns.set_context("talk")

# One row per median width, labelled "<REGION> <fixed|knee>" for the x axis.
plot_df = (
    width_med.loc[width_med["method"].isin(["fooof_fixed", "fooof_knee"])]
    .rename(columns={"value": "median_width"})
    .assign(method_clean=lambda d: d["method"].map({"fooof_fixed": "fixed", "fooof_knee": "knee"}))
    .assign(group=lambda d: d["region"].str.upper() + " " + d["method_clean"])
)

order = ["STN fixed", "STN knee", "M1 fixed", "M1 knee"]
fig, ax = plt.subplots(figsize=(9, 6))
sns.boxplot(data=plot_df, x="group", y="median_width", order=order, ax=ax)
# Overlay the individual points, small and faint, on top of the boxes.
sns.swarmplot(data=plot_df, x="group", y="median_width", order=order,
              color="k", alpha=0.25, size=2, ax=ax)
ax.set_title("Median peak widths: STN vs M1 (FOOOF fixed vs knee)")
ax.set_ylabel("Median Peak Width (Hz)")
ax.set_xlabel("")
ax.grid(True, axis="y", alpha=0.3)

# Vertical anchor for the significance brackets; the bracket drawing itself is
# currently disabled (kept below, commented out, for re-enabling).
ymax = plot_df["median_width"].max()
bump = (ymax*0.05) if ymax>0 else 1.0

# def annotate_sig(ax, x1, x2, y, text, h=0.8):
#     ax.plot([x1, x1, x2, x2], [y, y+h, y+h, y], lw=1.5, c="k")
#     ax.text((x1+x2)/2., y+h*1.1, text, ha='center', va='bottom')
#
# # STN bracket (positions 0 vs 1)
# stn_text = f"q={sum_stn['q']:.2e}, dz={sum_stn['cohens_dz']:.2f}"
# annotate_sig(ax, 0, 1, ymax + bump, stn_text)
#
# # M1 bracket (positions 2 vs 3)
# m1_text  = f"q={sum_m1['q']:.2e}, dz={sum_m1['cohens_dz']:.2f}"
# annotate_sig(ax, 2, 3, ymax + bump*3, m1_text)

fig_path = os.path.join(FIGURES_DIR, f"{PROJECT_LABEL}_STN_M1_fixed_vs_knee_box.png")
fig.tight_layout()
fig.savefig(fig_path, dpi=300)
plt.close(fig)
print(f"Saved figure: {fig_path}")

print("\n--- DONE (regions) ---")



--- Cell 5a4 (MINIMAL + REGIONS) ---
Saved combined LONG: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_fooof/RCS02R/Right/new_analysis_data/ALL_SUBJECTS_5a4_COMBINED_LONG.csv
Saved regional stats: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_fooof/RCS02R/Right/new_analysis_data/ALL_SUBJECTS_REGIONAL_fixed_vs_knee_stats.csv
  region  N_pairs  mean_diff_Hz  cohens_dz         p  median_pct_reduction  \
0    STN       36     -5.388889  -0.852956  0.000003             21.803922   
1     M1       36    -12.458333  -0.902894  0.000002             44.472617   

                        IQR_pct_reduction  prop_narrower_%         q  \
0  (9.625668449197862, 36.59476117103236)        83.333333  0.000003   
1  (23.511904761904763, 70.0257731958763)        86.111111  0.000003   

   rejected_0.05  
0           True  
1           True  
Saved region Δ effect test: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3

In [13]:
# -*- coding: utf-8 -*-
# --- Cell 5b: Visualizing Average PSDs by Clinical State ---
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import warnings
from tqdm.notebook import tqdm

NEW_CLINICAL_STATE_COLORS = {
    'Dyskinetic Mobile': '#d62728',        # red
    'Immobile': '#1f77b4',                 # blue
    'Non-Dyskinetic Mobile': '#2ca02c',    # green
    'Transitional Mobile': '#ff7f0e',      # orange
    'Unknown': '#808080'                   # gray fallback
}

# >>> Minimum number of segments to include a line in the plot
MIN_SEGMENTS_FOR_PLOT = 100

print("\n--- Cell 5b: Starting Average PSD Visualization by Clinical State ---")

# --- Load raw PSDs collected in Cell 4 ---
# On any problem (file missing, read error, required columns absent) an empty
# DataFrame is left in df_all_raw_psds_for_c5b so the rest of the cell can skip cleanly.
raw_psds_input_filename_c5b = os.path.join(
    DATA_OUTPUT_PATH_NEW_CELLS,
    f"{session_id}_{neural_hemisphere}_raw_psds_for_averaging_from_cell4.parquet"
)
if os.path.exists(raw_psds_input_filename_c5b):
    try:
        df_all_raw_psds_for_c5b = pd.read_parquet(raw_psds_input_filename_c5b, engine='fastparquet')
        print(f"Successfully loaded raw PSDs from: {raw_psds_input_filename_c5b}. Shape: {df_all_raw_psds_for_c5b.shape}")
        if not {'psd', 'freqs'}.issubset(df_all_raw_psds_for_c5b.columns):
            print("ERROR: 'psd' or 'freqs' column missing in loaded raw PSDs. Cannot proceed with Cell 5b.")
            df_all_raw_psds_for_c5b = pd.DataFrame()
    except Exception as load_err:
        print(f"ERROR loading raw PSDs parquet for Cell 5b: {load_err}")
        df_all_raw_psds_for_c5b = pd.DataFrame()
else:
    print(f"ERROR: Raw PSDs file from Cell 4 not found: {raw_psds_input_filename_c5b}. Skipping Cell 5b.")
    df_all_raw_psds_for_c5b = pd.DataFrame()  # Create empty to allow script to run

if df_all_raw_psds_for_c5b.empty:
    print("No raw PSD data available for Cell 5b. Skipping averaging and plotting.")
elif 'fooof_input_df' not in locals() or fooof_input_df.empty:
    print("ERROR: fooof_input_df (for clinical states) not available in Cell 5b. Skipping.")
else:
    # --- Merge raw PSDs with Clinical State Information ---
    # Attaches the clinical-state labels from fooof_input_df to every raw PSD row by
    # nearest timestamp within each channel. Leaves df_merged_psds_for_avg defined
    # (possibly empty) on every path.
    print("Merging raw PSDs with clinical state information...")
    clinical_cols_to_merge = ['Aligned_PKG_UnixTimestamp', 'Channel',
                              'Clinical_State_2min_Window', 'Clinical_State_Aggregated']
    if not all(col in fooof_input_df.columns for col in clinical_cols_to_merge):
        print(f"ERROR: One or more clinical columns missing from fooof_input_df: {clinical_cols_to_merge}. Skipping merge.")
        df_merged_psds_for_avg = pd.DataFrame()
    else:
        df_clinical_states_c5b = fooof_input_df[clinical_cols_to_merge].drop_duplicates()

        # --- Align column names and types ---
        # Note: this rebinds df_all_raw_psds_for_c5b itself (renamed columns, int64
        # timestamps), so later cells see the renamed frame, not the one as loaded.
        df_all_raw_psds_for_c5b = df_all_raw_psds_for_c5b.rename(columns={
            'channel': 'Channel',
            'timestamp_unix': 'Aligned_PKG_UnixTimestamp'
        })
        df_all_raw_psds_for_c5b['Aligned_PKG_UnixTimestamp'] = df_all_raw_psds_for_c5b['Aligned_PKG_UnixTimestamp'].astype('int64')
        df_clinical_states_c5b['Aligned_PKG_UnixTimestamp'] = df_clinical_states_c5b['Aligned_PKG_UnixTimestamp'].astype('int64')

        # --- Merge using nearest timestamp match within 30s tolerance ---
        # Both sides are sorted on the merge key first, as merge_asof requires.
        # NOTE(review): `tolerance` is expressed in the integer units of
        # Aligned_PKG_UnixTimestamp; 30000 is 30 s only if those timestamps are in
        # milliseconds (it would be ~8.3 h for second-resolution timestamps) — confirm.
        df_all_raw_psds_for_c5b_sorted = df_all_raw_psds_for_c5b.sort_values('Aligned_PKG_UnixTimestamp')
        df_clinical_states_c5b_sorted = df_clinical_states_c5b.sort_values('Aligned_PKG_UnixTimestamp')
        df_merged_psds_for_avg = pd.merge_asof(
            df_all_raw_psds_for_c5b_sorted,
            df_clinical_states_c5b_sorted,
            on='Aligned_PKG_UnixTimestamp',
            by='Channel',
            direction='nearest',
            tolerance=30000  # assumed milliseconds — TODO confirm timestamp unit
        )

        # --- Fill any unmatched labels with 'Unknown' ---
        # Only the 2-min-window label is filled; Clinical_State_Aggregated keeps NaN
        # for PSD rows without a match inside the tolerance.
        df_merged_psds_for_avg['Clinical_State_2min_Window'] = df_merged_psds_for_avg['Clinical_State_2min_Window'].fillna('Unknown')

        # Derive electrode labels from the channel key when the PSD table has none.
        if 'electrode_label' not in df_merged_psds_for_avg.columns and electrode_labels:
            df_merged_psds_for_avg['electrode_label'] = df_merged_psds_for_avg['Channel'].map(electrode_labels)

        if df_merged_psds_for_avg.empty:
            print("No data after merging PSDs with clinical states.")
        else:
            print(f"Successfully merged. Shape of merged data for averaging: {df_merged_psds_for_avg.shape}")
            # Repeats the fill done above; nothing is left to fill at this point.
            df_merged_psds_for_avg['Clinical_State_2min_Window'] = df_merged_psds_for_avg['Clinical_State_2min_Window'].fillna('Unknown')

    if not df_merged_psds_for_avg.empty:
        # --- Calculate Average PSDs ---
        print("Calculating average PSDs per channel and clinical state...")
        example_freqs_c5b = df_merged_psds_for_avg['freqs'].iloc[0]
        if isinstance(example_freqs_c5b, str):
            try:
                example_freqs_c5b = np.array(json.loads(example_freqs_c5b))
            except Exception as e:
                print(f"ERROR parsing frequency vector: {e}")
                example_freqs_c5b = np.array([])

        def average_psds(psd_list):
            """Average a collection of PSD vectors element-wise.

            Each item may be a numpy array, a list/tuple of numbers, or a
            JSON-encoded string of numbers. Items of any other type (e.g. None)
            are skipped; items that fail to convert are reported and skipped.

            Returns:
                (mean_psd, n_used): the element-wise mean over the usable items
                and how many were used, or (None, 0) if none were usable.
            """
            psd_arrays = []
            for item in psd_list:
                if isinstance(item, np.ndarray):
                    psd_arrays.append(item)
                elif isinstance(item, (list, tuple)):
                    # Plain sequences are valid PSDs too; previously they were
                    # dropped silently, which understated the segment count.
                    try:
                        psd_arrays.append(np.asarray(item, dtype=float))
                    except Exception as e:
                        print(f"Error converting PSD sequence: {e}")
                elif isinstance(item, str):
                    try:
                        psd_arrays.append(np.array(json.loads(item)))
                    except Exception as e:
                        print(f"Error parsing PSD string: {e}")
            if not psd_arrays:
                return None, 0
            return np.mean(np.array(psd_arrays), axis=0), len(psd_arrays)

        # Group by channel, electrode label and clinical state. The loop below unpacks
        # a three-part group key, so all three grouping columns must always be present.
        grouping_cols_c5b = ['Channel', 'electrode_label', 'Clinical_State_2min_Window']
        if 'electrode_label' not in df_merged_psds_for_avg.columns:
            print("Warning: 'electrode_label' not found for grouping in Cell 5b. Using 'Channel'.")
            # Stand in the channel key as the label instead of shortening the grouping
            # list: a two-column key made the three-part unpack below raise IndexError.
            # The resulting groups are the same as grouping by Channel and state alone.
            df_merged_psds_for_avg['electrode_label'] = df_merged_psds_for_avg['Channel']

        averaged_psd_results_c5b = []
        for name, group in tqdm(df_merged_psds_for_avg.groupby(grouping_cols_c5b, observed=True), desc="Averaging PSDs"):
            channel_key, el_label, clinical_state = name
            avg_psd_linear, n_segments = average_psds(group['psd'].tolist())
            if avg_psd_linear is not None:
                averaged_psd_results_c5b.append({
                    'Channel': channel_key,
                    'ElectrodeLabel': el_label if el_label else channel_key,
                    'Clinical_State': clinical_state,
                    'Average_PSD_Linear': avg_psd_linear,
                    # The frequency axis of the first merged row is reused for every group.
                    'Frequencies': example_freqs_c5b,
                    'Num_Segments_Averaged': n_segments
                })

        df_averaged_psds_c5b = pd.DataFrame(averaged_psd_results_c5b)

        if df_averaged_psds_c5b.empty:
            print("No averaged PSDs were generated.")
        else:
            print(f"Generated {len(df_averaged_psds_c5b)} averaged PSDs.")

            # --- Plot Average PSDs ---
            output_folder_c5b_avg_psd_plots = os.path.join(FIGURES_OUTPUT_PATH_NEW_CELLS, 'average_psds_by_state_cell5b')
            os.makedirs(output_folder_c5b_avg_psd_plots, exist_ok=True)
            print(f"Average PSD plots will be saved in: {output_folder_c5b_avg_psd_plots}")

            unique_channels_to_plot_c5b = df_averaged_psds_c5b['Channel'].unique()
            # One figure per channel: log10 average PSD, one line per clinical state
            # that has at least MIN_SEGMENTS_FOR_PLOT averaged segments.
            for ch_key in unique_channels_to_plot_c5b:
                is_this_channel = df_averaged_psds_c5b['Channel'] == ch_key
                df_channel_all = df_averaged_psds_c5b[is_this_channel]

                has_enough_segments = df_channel_all['Num_Segments_Averaged'] >= MIN_SEGMENTS_FOR_PLOT
                df_channel = df_channel_all[has_enough_segments].copy()
                if df_channel.empty:
                    print(f"Skipping Channel {ch_key}: no states with N >= {MIN_SEGMENTS_FOR_PLOT}.")
                    continue

                el_label_for_plot = df_channel['ElectrodeLabel'].iloc[0]
                fig_c5b, ax_c5b = plt.subplots(figsize=(12, 7))

                # df_channel is non-empty here, so at least one line is always drawn.
                state_rows = zip(
                    df_channel['Clinical_State'],
                    df_channel['Frequencies'],
                    df_channel['Average_PSD_Linear'],
                    df_channel['Num_Segments_Averaged'],
                )
                for state, freqs, psd_linear, n_seg in state_rows:
                    ax_c5b.plot(
                        freqs,
                        np.log10(psd_linear),
                        label=f"{state} (N={n_seg})",
                        color=NEW_CLINICAL_STATE_COLORS.get(state, '#808080'),
                        linewidth=2,
                    )

                ax_c5b.set_title(f"Average PSD by Clinical State - {session_id} - Channel {el_label_for_plot}")
                ax_c5b.set_xlabel("Frequency (Hz)")
                ax_c5b.set_ylabel("Log10 Power Spectral Density")
                ax_c5b.set_xlim(0, 100)
                ax_c5b.legend(title=f"Clinical State (N ≥ {MIN_SEGMENTS_FOR_PLOT})")
                ax_c5b.grid(True, which="both", ls="--", alpha=0.5)
                fig_c5b.tight_layout()

                plot_filename = os.path.join(
                    output_folder_c5b_avg_psd_plots,
                    f"{session_id}_{el_label_for_plot}_avg_psd_by_state.png"
                )
                fig_c5b.savefig(plot_filename)
                plt.close(fig_c5b)

            print("Finished plotting average PSDs by clinical state.")

# Hand-off to Cell 5c: report availability and guarantee the name exists.
_c5b_avg_missing = 'df_averaged_psds_c5b' not in locals() or df_averaged_psds_c5b.empty
if _c5b_avg_missing:
    print("Cell 5b: Average PSDs DataFrame is NOT available for Cell 5c.")
    df_averaged_psds_c5b = pd.DataFrame()  # Ensure it's defined for 5c check
else:
    print("Cell 5b: Average PSDs DataFrame is available for Cell 5c.")

print("\n--- Cell 5b: Average PSD Visualization Complete ---")



--- Cell 5b: Starting Average PSD Visualization by Clinical State ---
Successfully loaded raw PSDs from: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_data/RCS05L_Left_raw_psds_for_averaging_from_cell4.parquet. Shape: (11368, 6)
Merging raw PSDs with clinical state information...
Successfully merged. Shape of merged data for averaging: (11368, 8)
Calculating average PSDs per channel and clinical state...


Averaging PSDs:   0%|          | 0/24 [00:00<?, ?it/s]

Generated 24 averaged PSDs.
Average PSD plots will be saved in: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_figures/average_psds_by_state_cell5b
Finished plotting average PSDs by clinical state.
Cell 5b: Average PSDs DataFrame is available for Cell 5c.

--- Cell 5b: Average PSD Visualization Complete ---


In [14]:
print("Raw PSDs columns:", df_all_raw_psds_for_c5b.columns.tolist())
print("Clinical states columns:", df_clinical_states_c5b.columns.tolist())

Raw PSDs columns: ['Aligned_PKG_UnixTimestamp', 'datetime_utc', 'Channel', 'electrode_label', 'freqs', 'psd']
Clinical states columns: ['Aligned_PKG_UnixTimestamp', 'Channel', 'Clinical_State_2min_Window', 'Clinical_State_Aggregated']


In [15]:
# -*- coding: utf-8 -*-
# --- Cell 5c: Run FOOOF on Averaged PSDs and Plot (with N>=100 gating & dual legends) ---

import os
import sys
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Uses specparam's SpectralModel under the FOOOF alias (no fooof fallback is implemented in this cell)

from specparam.objs import SpectralModel as FOOOF
USING_SPECPARAM = True

from tqdm.notebook import tqdm  # make sure tqdm is available

MIN_SEGMENTS_FOR_PLOT = 100  # keep in sync with 5b
# Regularization weight; overridable through the FOOOF_REG_LAMBDA environment variable.
REG_LAMBDA = float(os.getenv("FOOOF_REG_LAMBDA", "5000"))

def set_lambda(fm, lam=REG_LAMBDA):
    """Set the regularization weight on a model object, if it exposes one.

    The attribute `regularization_weight` is looked up first on `fm._model`
    and then on `fm` itself; the first one found is set to `lam`.

    Args:
        fm: model object to configure.
        lam: weight to assign (defaults to REG_LAMBDA).

    Returns:
        True if an attribute was set, False if `fm` exposes no
        `regularization_weight` in either place (nothing is changed then), so
        callers can detect that the weight was not applied.
    """
    if hasattr(fm, "_model") and hasattr(fm._model, "regularization_weight"):
        fm._model.regularization_weight = lam
        return True
    if hasattr(fm, "regularization_weight"):
        fm.regularization_weight = lam
        return True
    return False

print("\n--- Cell 5c: Starting FOOOF on Averaged PSDs ---")

if 'df_averaged_psds_c5b' not in locals() or df_averaged_psds_c5b.empty:
    print("ERROR: Averaged PSDs DataFrame (df_averaged_psds_c5b) not available from Cell 5b. Skipping Cell 5c.")
else:
    print(f"Processing {len(df_averaged_psds_c5b)} averaged PSDs for FOOOF fitting "
          f"(will skip rows with N < {MIN_SEGMENTS_FOR_PLOT}).")

    # Output folder
    output_folder_c5c_fooof_avg_plots = os.path.join(
        FIGURES_OUTPUT_PATH_NEW_CELLS, 'fooof_on_average_psds_cell5c'
    )
    os.makedirs(output_folder_c5c_fooof_avg_plots, exist_ok=True)
    print(f"FOOOF plots on averaged PSDs will be saved in: {output_folder_c5c_fooof_avg_plots}")

    # Fit range: prefer the 'WideFreq' band; otherwise fall back to the first
    # configured band, or to 10-90 Hz when no bands are configured at all.
    requested_fit_range_label_c5c = 'WideFreq'
    if requested_fit_range_label_c5c in ITERATIVE_FREQ_BANDS:
        default_fit_range_label_c5c = requested_fit_range_label_c5c
        current_freq_range_c5c = ITERATIVE_FREQ_BANDS[requested_fit_range_label_c5c]
    else:
        if ITERATIVE_FREQ_BANDS:
            default_fit_range_label_c5c, current_freq_range_c5c = next(iter(ITERATIVE_FREQ_BANDS.items()))
        else:
            current_freq_range_c5c = [10, 90]
            default_fit_range_label_c5c = f"{current_freq_range_c5c[0]}-{current_freq_range_c5c[1]}Hz"
        # Report the label that was actually missing, not the fallback that replaced it.
        print(f"Warning: '{requested_fit_range_label_c5c}' not found in ITERATIVE_FREQ_BANDS. Using: {current_freq_range_c5c}")

    fooofed_average_results_c5c = []

    def _fit_mode_on_axis_c5c(ax, mode, settings, freqs, spectrum, meta):
        """Fit one aperiodic mode to an averaged PSD and draw the result on `ax`.

        Args:
            ax: matplotlib Axes to draw on.
            mode: 'fixed' or 'knee'; used for titles, messages and the result row.
            settings: keyword arguments for the model constructor.
            freqs, spectrum: frequency axis and linear-power PSD to fit.
            meta: dict with 'Channel', 'ElectrodeLabel' and 'Clinical_State',
                copied into the result row.

        Returns:
            A result-row dict, or None when no model was obtained or when
            fitting/plotting raised (the error is printed and the panel is
            titled "(No Fit)").
        """
        mode_title = mode.capitalize()
        param_names = ('Offset', 'Knee', 'Exponent') if mode == 'knee' else ('Offset', 'Exponent')

        fm = FOOOF(**settings)
        try:
            set_lambda(fm)  # best-effort: enable regularization if the model supports it
        except Exception as e_lambda:
            print(f"  Warning: could not set regularization weight for '{mode}' model: {e_lambda}")

        try:
            fm.fit(freqs, spectrum, freq_range=current_freq_range_c5c)
            if not fm.has_model:
                return None
            fm.plot(ax=ax, plt_log=True, add_legend=False)

            # Spectrum legend (bottom-left)
            handles_spec = [
                plt.Line2D([0], [0], color='black', label='Original Spectrum'),
                plt.Line2D([0], [0], color='red', label='Full Model Fit'),
                plt.Line2D([0], [0], color='blue', linestyle='--', label='Aperiodic Fit'),
            ]
            leg_spec = ax.legend(handles=handles_spec, loc='lower left', fontsize=8, frameon=True)
            ax.add_artist(leg_spec)

            # Metrics legend (top-right), rendered as a text-only box.
            metrics_text = f"Model: {mode_title}\nR² = {fm.r_squared_:.3f}\nErr = {fm.error_:.2e}"
            leg_metrics = ax.legend([metrics_text], loc='upper right', fontsize=8, frameon=True, handlelength=0)
            # Matplotlib renamed Legend.legendHandles to legend_handles; support both
            # so a missing attribute is not misreported as a failed fit.
            metric_handles = getattr(leg_metrics, "legend_handles", None)
            if metric_handles is None:
                metric_handles = getattr(leg_metrics, "legendHandles", [])
            for h in metric_handles:
                h.set_visible(False)
            ax.add_artist(leg_metrics)

            ax.set_title(f"Aperiodic Mode: {mode_title}")
            ax.grid(True, which="both", ls="--", alpha=0.5)

            fit_row = dict(meta)
            fit_row['AperiodicMode'] = mode
            fit_row['R2'] = fm.r_squared_
            fit_row['Error'] = fm.error_
            for idx, param_name in enumerate(param_names):
                fit_row[param_name] = fm.aperiodic_params_[idx]
            fit_row['NumPeaks'] = len(fm.peak_params_)
            # Same recorded value for both modes: the weight passed to set_lambda.
            # (The knee row used to read it from the *fixed* model object.)
            fit_row['RegLambda'] = REG_LAMBDA
            return fit_row
        except Exception as e_fit:
            print(f"  Warning: FOOOF '{mode}' fit failed for avg PSD Ch {meta['ElectrodeLabel']}, State {meta['Clinical_State']}. Error: {e_fit}")
            ax.set_title(f"Aperiodic Mode: {mode_title} (No Fit)")
            return None

    for _, row_avg_psd in tqdm(df_averaged_psds_c5b.iterrows(),
                               total=len(df_averaged_psds_c5b),
                               desc="FOOOFing Avg PSDs"):

        # ---- EARLY GATE: skip low-N before making any figure ----
        if row_avg_psd['Num_Segments_Averaged'] < MIN_SEGMENTS_FOR_PLOT:
            continue

        channel_key = row_avg_psd['Channel']
        el_label = row_avg_psd['ElectrodeLabel']
        clinical_state = row_avg_psd['Clinical_State']
        avg_psd_linear = row_avg_psd['Average_PSD_Linear']
        freqs = row_avg_psd['Frequencies']
        n_segments = row_avg_psd['Num_Segments_Averaged']

        # Basic sanity: both vectors present, decodable, non-empty and equally long.
        if avg_psd_linear is None or freqs is None:
            continue
        if isinstance(avg_psd_linear, str):
            try: avg_psd_linear = np.array(json.loads(avg_psd_linear))
            except Exception: continue
        if isinstance(freqs, str):
            try: freqs = np.array(json.loads(freqs))
            except Exception: continue
        if len(avg_psd_linear) == 0 or len(freqs) == 0 or len(avg_psd_linear) != len(freqs):
            continue

        # ---- Create figure only for rows we will actually process ----
        fig_avg, axes_avg = plt.subplots(1, 2, figsize=(15, 6), gridspec_kw={'width_ratios': [1, 1]})
        try:
            fig_avg.suptitle(
                f"FOOOF on Averaged PSD: {session_id} - Ch {el_label} - State: {clinical_state} (N={n_segments})\n"
                f"Fit Range: {default_fit_range_label_c5c} ({current_freq_range_c5c[0]}-{current_freq_range_c5c[1]} Hz)",
                fontsize=12
            )

            row_meta = {'Channel': channel_key, 'ElectrodeLabel': el_label, 'Clinical_State': clinical_state}
            models_fitted_this_spectrum = 0

            # Left panel: fixed mode; right panel: knee mode (results appended in that order).
            mode_panels = (
                (axes_avg[0], 'fixed', basic_fooof_settings),
                (axes_avg[1], 'knee', knee_fooof_settings),
            )
            for ax_mode, mode_name, mode_settings in mode_panels:
                fit_row = _fit_mode_on_axis_c5c(ax_mode, mode_name, mode_settings,
                                                freqs, avg_psd_linear, row_meta)
                if fit_row is not None:
                    fooofed_average_results_c5c.append(fit_row)
                    models_fitted_this_spectrum += 1

            # ---- Save only if at least one model fit ----
            if models_fitted_this_spectrum > 0:
                fig_avg.tight_layout(rect=[0, 0, 1, 0.96])
                plot_filename_avg_fooof = os.path.join(
                    output_folder_c5c_fooof_avg_plots,
                    f"{session_id}_{el_label}_{clinical_state.replace(' ', '_')}_avg_psd_fooof.png"
                )
                fig_avg.savefig(plot_filename_avg_fooof)
        finally:
            # Always release the figure, even if something above raised.
            plt.close(fig_avg)

    # ---- Save results table ----
    # One row per successful model fit; written to CSV only when there is at least one.
    df_fooofed_average_results_c5c = pd.DataFrame(fooofed_average_results_c5c)
    if df_fooofed_average_results_c5c.empty:
        print("\nNo FOOOF models were successfully fitted to the averaged PSDs.")
    else:
        print(f"\nGenerated {len(df_fooofed_average_results_c5c)} FOOOF model fits for averaged PSDs.")
        avg_fooof_results_filename = os.path.join(
            DATA_OUTPUT_PATH_NEW_CELLS,
            f"{session_id}_{neural_hemisphere}_fooof_fits_on_average_psds_from_cell5c.csv"
        )
        df_fooofed_average_results_c5c.to_csv(avg_fooof_results_filename, index=False)
        print(f"Saved FOOOF results on averaged PSDs to: {avg_fooof_results_filename}")

print("\n--- Cell 5c: FOOOF on Averaged PSDs Complete ---")



--- Cell 5c: Starting FOOOF on Averaged PSDs ---
Processing 24 averaged PSDs for FOOOF fitting (will skip rows with N < 100).
FOOOF plots on averaged PSDs will be saved in: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_figures/fooof_on_average_psds_cell5c


FOOOFing Avg PSDs:   0%|          | 0/24 [00:00<?, ?it/s]


Generated 32 FOOOF model fits for averaged PSDs.
Saved FOOOF results on averaged PSDs to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/RCS05L/Left/new_analysis_data/RCS05L_Left_fooof_fits_on_average_psds_from_cell5c.csv

--- Cell 5c: FOOOF on Averaged PSDs Complete ---


In [16]:
# -*- coding: utf-8 -*-
# --- Cell 8: Generate and Save Master DataFrame ---
# This cell now combines:
# 1. Aperiodic parameters (fixed, knee, best) from df_fine_grain_results_cell4
# 2. Original segment info, Total_Daily_LEDD_mg, Beta_Peak_Power, Gamma_Peak_Power from fooof_input_df

import pandas as pd # Ensure pandas is imported for pd.isna
import numpy as np  # Ensure numpy is imported for np.nan
import sys          # For sys.exit

print("\n--- Cell 8: Starting Master DataFrame Generation (with new metrics) ---")

# Inputs expected in the notebook namespace:
#   * fooof_input_df (Cell 3): segment table carrying LEDD, Beta_Peak_Power, Gamma_Peak_Power
#   * df_fine_grain_results_cell4 (modified Cell 4): detailed aperiodic fit results
# When either one is absent or empty, a stand-in is created below so the cell
# keeps running instead of stopping.

if not ('fooof_input_df' in locals() and not fooof_input_df.empty):
    print("CRITICAL WARNING in Cell 8: fooof_input_df is not available or empty. Creating a dummy DataFrame for LEDD prompt testing.")
    # Three-row, all-NaN LEDD frame; the hard stop (sys.exit) that used to sit
    # here is disabled, so execution continues with this dummy.
    fooof_input_df = pd.DataFrame({'Total_Daily_LEDD_mg': [np.nan] * 3})

if not ('df_fine_grain_results_cell4' in locals() and not df_fine_grain_results_cell4.empty):
    print("Warning in Cell 8: df_fine_grain_results_cell4 not available or empty. Master table may lack detailed FOOOF results.")
    # Zero-row frame that still carries the expected column names, so the later
    # steps can proceed and simply yield NaNs for the FOOOF columns.
    df_fine_grain_results_cell4 = pd.DataFrame(columns=[
        'timestamp_unix', 'channel', 'electrode_label', 'freq_band_label', 'aperiodic_mode',
        'r_squared', 'fit_error', 'aperiodic_offset', 'aperiodic_knee',
        'aperiodic_exponent', 'num_model_peaks',
    ])


# Interactive LEDD override. The switch is hard-wired to False and nothing in
# this block turns it on, so the prompt loop below does not run as written.
trigger_ledd_input_prompt = False

if 'fooof_input_df' in locals() and ('Total_Daily_LEDD_mg' not in fooof_input_df.columns):
    print("Warning: 'Total_Daily_LEDD_mg' column missing from fooof_input_df. Cannot prompt for LEDD override.")
    # Create the column (all NaN) so later column selections can find it.
    fooof_input_df['Total_Daily_LEDD_mg'] = np.nan

# Keep asking until the user types a number or 'skip'. The flag is never changed
# inside the loop, so the loop is entered only when the flag is set beforehand.
while trigger_ledd_input_prompt:
    try:
        user_ledd_value_str = input("Please enter a default LEDD value to propagate for the entire 'Total_Daily_LEDD_mg' column (e.g., 0 or 500), or type 'skip' to leave as is: ")
        if user_ledd_value_str.lower() != 'skip':
            # float() raises ValueError on non-numeric text, which re-prompts.
            user_ledd_value = float(user_ledd_value_str)
            fooof_input_df['Total_Daily_LEDD_mg'] = user_ledd_value
            print(f"'Total_Daily_LEDD_mg' column has been filled with {user_ledd_value}.")
        else:
            print("Skipping LEDD override. 'Total_Daily_LEDD_mg' will remain as is (likely NaNs).")
        break
    except ValueError:
        print("Invalid input. Please enter a numeric value for LEDD or 'skip'.")

# --- 1. Pivot the long-format FOOOF results to one row per segment/channel/band ---
# df_fine_grain_results_cell4 holds one row per aperiodic_mode; the pivot turns
# each value column into a '<value>_<mode>' column (e.g. 'r_squared_knee').
expected_pivot_values = ['r_squared', 'fit_error', 'aperiodic_offset', 'aperiodic_knee', 'aperiodic_exponent', 'num_model_peaks']
actual_pivot_values = [val for val in expected_pivot_values if val in df_fine_grain_results_cell4.columns]
pivot_index_cols = ['timestamp_unix', 'channel', 'electrode_label', 'freq_band_label']

if not actual_pivot_values or df_fine_grain_results_cell4.empty:
    # Nothing to pivot: either the value columns are absent or the frame has no
    # rows. Build an empty frame that only carries the index columns; the
    # '<value>_<mode>' placeholders are added below.
    if not actual_pivot_values:
        print("Warning: No value columns found in df_fine_grain_results_cell4 for pivoting. FOOOF params will be empty.")
    else:
        print("Warning: df_fine_grain_results_cell4 has no rows to pivot. FOOOF params will be empty.")
    df_pivot_foof_params = pd.DataFrame(columns=pivot_index_cols)
else:
    df_pivot_foof_params = df_fine_grain_results_cell4.pivot_table(
        index=pivot_index_cols,
        columns='aperiodic_mode',
        values=actual_pivot_values,
        aggfunc='first'
    ).reset_index()
    # Flatten the (value, mode) MultiIndex; index columns have an empty second level.
    df_pivot_foof_params.columns = [f'{col[0]}_{col[1]}'.rstrip('_') if col[1] else col[0] for col in df_pivot_foof_params.columns]

# pivot_table (dropna=True by default) silently drops any column that is entirely
# NaN, e.g. 'r_squared_knee' when every knee fit failed. Re-add whatever is
# missing so that the downstream best-model step always sees both modes.
for mode in ['fixed', 'knee']:
    for val_col in expected_pivot_values:
        pivoted_col = f'{val_col}_{mode}'
        if pivoted_col not in df_pivot_foof_params.columns:
            df_pivot_foof_params[pivoted_col] = np.nan


rename_dict_cell8 = {
    'r_squared_fixed': 'R2_Fixed', 'fit_error_fixed': 'Error_Fixed',
    'aperiodic_offset_fixed': 'Offset_Fixed', 'aperiodic_exponent_fixed': 'Exponent_Fixed',
    # Previously had no mapping, so this column kept its raw name in the master table.
    'aperiodic_knee_fixed': 'Knee_Fixed',
    'num_model_peaks_fixed': 'Num_Peaks_Fixed',
    'r_squared_knee': 'R2_Knee', 'fit_error_knee': 'Error_Knee',
    'aperiodic_offset_knee': 'Offset_Knee', 'aperiodic_knee_knee': 'Knee_Knee',
    'aperiodic_exponent_knee': 'Exponent_Knee',
    'num_model_peaks_knee': 'Num_Peaks_Knee',
    'channel': 'Channel', 'freq_band_label': 'FreqRangeLabel',
    'timestamp_unix': 'Aligned_PKG_UnixTimestamp'
}
df_pivot_foof_params = df_pivot_foof_params.rename(columns=rename_dict_cell8)

# --- 2. Determine the "Best Model" from pivoted FOOOF params ---
# For every row, pick the aperiodic mode with the higher R-squared (ties go to
# 'fixed') and copy that mode's parameters into the *_BestModel columns.
# Rows where neither mode produced an R-squared are labelled 'n/a'.
print("Determining best aperiodic model based on R-squared...")
best_model_param_cols = ['BestModel_AperiodicMode', 'Offset_BestModel', 'Exponent_BestModel',
                         'Knee_BestModel', 'R2_BestModel', 'Error_BestModel', 'Num_Peaks_BestModel']

if 'R2_Fixed' in df_pivot_foof_params.columns and 'R2_Knee' in df_pivot_foof_params.columns:
    # A missing R2 is treated as -inf so that any real fit beats a failed one.
    r2_fixed_fillna = df_pivot_foof_params['R2_Fixed'].fillna(-np.inf)
    r2_knee_fillna = df_pivot_foof_params['R2_Knee'].fillna(-np.inf)

    knee_is_better = r2_knee_fillna > r2_fixed_fillna
    # 'fixed' may only win when it actually has an R2. Without this check two
    # failed fits compare as -inf >= -inf and the row is labelled 'fixed',
    # which made the 'n/a' default unreachable.
    fixed_is_best = df_pivot_foof_params['R2_Fixed'].notna() & ~knee_is_better
    conditions_best_model = [knee_is_better, fixed_is_best]
    choices_mode_best = ['knee', 'fixed']
    df_pivot_foof_params['BestModel_AperiodicMode'] = np.select(conditions_best_model, choices_mode_best, default='n/a')

    # Make sure every per-mode column referenced below exists.
    for mode in ['Fixed', 'Knee']:
        for param in ['Offset', 'Exponent', 'Knee', 'R2', 'Error', 'Num_Peaks']:
            col_name = f"{param}_{mode}"
            if col_name not in df_pivot_foof_params:
                df_pivot_foof_params[col_name] = np.nan

    best_is_knee = (df_pivot_foof_params['BestModel_AperiodicMode'] == 'knee').to_numpy()
    no_model_rows = (df_pivot_foof_params['BestModel_AperiodicMode'] == 'n/a').to_numpy()
    for param in ['Offset', 'Exponent', 'Knee', 'R2', 'Error', 'Num_Peaks']:
        # The knee parameter only exists for the knee model; a 'fixed' winner gets NaN.
        fixed_side = np.nan if param == 'Knee' else df_pivot_foof_params[f'{param}_Fixed']
        best_values = np.where(best_is_knee, df_pivot_foof_params[f'{param}_Knee'], fixed_side)
        if no_model_rows.any():
            # Rows without any usable fit must not inherit values from either mode.
            best_values = np.where(no_model_rows, np.nan, best_values)
        df_pivot_foof_params[f'{param}_BestModel'] = best_values
else:
    print("Warning: R2_Fixed or R2_Knee not found in pivoted FOOOF data. Best model columns will be NaN.")
    for col in best_model_param_cols:
        df_pivot_foof_params[col] = np.nan
df_pivot_foof_params['ErrorMsg_FOOOF'] = ''

# --- 3. Merge FOOOF results with original segment info from fooof_input_df ---
# Left-merge: every pivoted FOOOF row is kept, and segment-level columns are
# attached where (Aligned_PKG_UnixTimestamp, Channel) match.
print("Merging pivoted FOOOF results with main segment data (including LEDD, Beta, Gamma)...")
segment_base_cols = [
    'Aligned_PKG_UnixTimestamp', 'Channel', 'ElectrodeLabel',
    'SessionID', 'Hemisphere',
    'Neural_Segment_Start_Unixtime', 'Neural_Segment_End_Unixtime',
    'Neural_Segment_Duration_Sec', 'FS', 'PSD_Data_Str', 'Frequency_Vector_Str',
    'Aligned_PKG_DateTime_Str', 'Clinical_State_2min_Window', 'Clinical_State_Aggregated',
    'Aligned_BK', 'Aligned_DK', 'Aligned_Tremor_Score', 'Aligned_Tremor',
    'Total_Daily_LEDD_mg',
    'Beta_Peak_Power_at_DominantFreq',
    'Gamma_Peak_Power_at_DominantFreq'
]
merge_key_cols = ['Aligned_PKG_UnixTimestamp', 'Channel']
segment_base_cols_present = [col for col in segment_base_cols if col in fooof_input_df.columns]
# Explicit copy: columns are assigned on this frame below, which must not be a
# slice of fooof_input_df (avoids SettingWithCopyWarning and accidental write-through).
df_segment_info_to_merge = fooof_input_df[segment_base_cols_present].copy()

# Ensure merge keys exist in BOTH frames. Previously only the pivoted frame was
# back-filled, so a fooof_input_df without the keys (e.g. the dummy frame built
# at the top of this cell) raised KeyError in drop_duplicates.
for key_col in merge_key_cols:
    if key_col not in df_segment_info_to_merge.columns:
        print(f"Warning: merge key '{key_col}' is missing from fooof_input_df; segment info cannot be aligned on it.")
        df_segment_info_to_merge[key_col] = np.nan
    if key_col not in df_pivot_foof_params.columns:
        df_pivot_foof_params[key_col] = np.nan  # Add if missing

# One segment-info row per (timestamp, channel) so the left merge cannot multiply rows.
df_segment_info_to_merge = df_segment_info_to_merge.drop_duplicates(subset=merge_key_cols).copy()

# Align key dtypes on both sides: nullable integer timestamps, string channels.
df_pivot_foof_params['Aligned_PKG_UnixTimestamp'] = pd.to_numeric(df_pivot_foof_params['Aligned_PKG_UnixTimestamp'], errors='coerce').astype('Int64')
df_segment_info_to_merge['Aligned_PKG_UnixTimestamp'] = pd.to_numeric(df_segment_info_to_merge['Aligned_PKG_UnixTimestamp'], errors='coerce').astype('Int64')
df_pivot_foof_params['Channel'] = df_pivot_foof_params['Channel'].astype(str)
df_segment_info_to_merge['Channel'] = df_segment_info_to_merge['Channel'].astype(str)


master_df_final = pd.merge(
    df_pivot_foof_params,
    df_segment_info_to_merge,
    on=merge_key_cols,
    how='left'
)

# Attach FreqLow / FreqHigh by looking up each row's FreqRangeLabel in
# ITERATIVE_FREQ_BANDS (label -> sequence whose first two items are used as
# the low and high edge). Both columns are NaN when no lookup is possible.
if not ('ITERATIVE_FREQ_BANDS' in locals() and isinstance(ITERATIVE_FREQ_BANDS, dict)):
    print("Warning: ITERATIVE_FREQ_BANDS not defined or not a dict, FreqLow/High cannot be mapped.")
    master_df_final['FreqLow'] = np.nan
    master_df_final['FreqHigh'] = np.nan
else:
    band_map_low = {label: edges[0] for label, edges in ITERATIVE_FREQ_BANDS.items()}
    band_map_high = {label: edges[1] for label, edges in ITERATIVE_FREQ_BANDS.items()}
    if 'FreqRangeLabel' not in master_df_final.columns:
        # No label column to look up: leave both edges empty.
        master_df_final['FreqLow'] = np.nan
        master_df_final['FreqHigh'] = np.nan
    else:
        band_labels = master_df_final['FreqRangeLabel']
        master_df_final['FreqLow'] = band_labels.map(band_map_low)
        master_df_final['FreqHigh'] = band_labels.map(band_map_high)


# --- 4. Finalize and Save Master Table ---
# Column order: first the columns listed in master_table_columns (when that list
# exists and the column was generated), then every remaining column of
# master_df_final in its existing order. The result is written to
# master_csv_path_patient_specific.

# session_id is only optionally present (it is guarded below as well), so use a
# label that is always defined for the log messages.
session_label_cell8 = session_id if 'session_id' in locals() else 'UNKNOWN_SESSION'

if 'master_table_columns' in locals():
    preferred_column_order = list(master_table_columns)
else:
    print("Warning: master_table_columns is not defined. Keeping the column order of master_df_final.")
    preferred_column_order = []

existing_cols_in_master_df = set(master_df_final.columns)
final_ordered_and_existing_cols = []
already_placed_cols = set()  # set lookup; also prevents duplicated output columns
for col in preferred_column_order:
    if col not in existing_cols_in_master_df:
        print(f"Info: Column '{col}' from master_table_columns definition is not in the generated master_df_final. It will be missing.")
    elif col not in already_placed_cols:
        final_ordered_and_existing_cols.append(col)
        already_placed_cols.add(col)

# Ensure all columns from master_df_final are included, even if not in master_table_columns initially
for col in master_df_final.columns:
    if col not in already_placed_cols:
        final_ordered_and_existing_cols.append(col)  # Add to end
        already_placed_cols.add(col)

master_df_to_save = master_df_final[final_ordered_and_existing_cols].copy()

# UserSessionName from Cell 1 (session_id)
if 'session_id' in locals() and 'UserSessionName' not in master_df_to_save.columns:
    master_df_to_save.insert(0, 'UserSessionName', session_id)

if not master_df_to_save.empty:
    print(f"Master DataFrame for {session_label_cell8} prepared. Shape: {master_df_to_save.shape}")
    try:
        # master_csv_path_patient_specific is defined in Cell 3
        master_df_to_save.to_csv(master_csv_path_patient_specific, index=False)
        print(f"Successfully saved master data for {session_label_cell8} to: {master_csv_path_patient_specific}")
        print("\nSample of the final Master DataFrame (first 5 rows):")
        print(master_df_to_save.head())
    except Exception as e_save_master:
        # Best-effort save: report the failure (including an undefined output
        # path) and let the notebook continue.
        print(f"ERROR saving the master CSV for {session_label_cell8}: {e_save_master}")
else:
    print(f"Warning: Master DataFrame for {session_label_cell8} is empty. Nothing to save.")

print("\n--- Cell 8: Master DataFrame Generation Complete (with new metrics and LEDD override prompt) ---")
# --- End of Cell 8 ---


--- Cell 8: Starting Master DataFrame Generation (with new metrics) ---
Determining best aperiodic model based on R-squared...
Merging pivoted FOOOF results with main segment data (including LEDD, Beta, Gamma)...
Master DataFrame for RCS05L prepared. Shape: (34104, 46)
Successfully saved master data for RCS05L to: /home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/MASTER_FOOOF_PKG_results_RCS05L_neural_pkg_aligned_finalstep3_bushlab5000.csv

Sample of the final Master DataFrame (first 5 rows):
  UserSessionName SessionID Hemisphere           Channel    ElectrodeLabel  \
0          RCS05L    RCS05L       Left  key0_contact_2_0  key0_contact_2_0   
1          RCS05L    RCS05L       Left  key0_contact_2_0  key0_contact_2_0   
2          RCS05L    RCS05L       Left  key0_contact_2_0  key0_contact_2_0   
3          RCS05L    RCS05L       Left  key1_contact_3_0  key1_contact_3_0   
4          RCS05L    RCS05L       Left  key1_contact_3_0  key1_contact_3_

In [1]:
import pandas as pd

# Quick look at a master results CSV: load only its first rows and show them.
# NOTE(review): the path is absolute and machine-specific, and the file name
# suggests a combined multi-patient (cohort) table - confirm which cell or
# script writes it, since it is not produced by the cell above.
file_path = "/home/jackson/step2_final/step3_fooof_results_neural_pkg_aligned_finalstep3_bushlab5000/MASTER_FOOOF_PKG_results_COHORT_RCS02_05_06_neural_pkg_aligned_finalstep3_bushlab5000.csv"

# nrows=10 stops parsing after ten data rows, so the whole file is not loaded.
df_head = pd.read_csv(file_path, nrows=10)

# Bare last expression: the notebook renders the frame as the cell output.
df_head


Unnamed: 0,UserSessionName,SessionID,Hemisphere,Channel,ElectrodeLabel,Neural_Segment_Start_Unixtime,Neural_Segment_End_Unixtime,Neural_Segment_Duration_Sec,FS,PSD_Data_Str,...,Offset_Knee,Knee_Knee,Exponent_Knee,R2_Knee,Error_Knee,Num_Peaks_Knee,ErrorMsg_FOOOF,electrode_label,Knee_Fixed,Clinical_State_Aggregated
0,RCS05R,RCS05R,Right,Contact_10_8,Contact_10_8,1572122000.0,1572121860,119.999,250,2.66052542e-02;1.33773726e-02;6.45884420e-06;5...,...,-1.447832,12420.203688,3.375362,0.984525,0.038797,4,,Contact_10_8,,Other
1,RCS05R,RCS05R,Right,Contact_10_8,Contact_10_8,1572122000.0,1572121860,119.999,250,2.66052542e-02;1.33773726e-02;6.45884420e-06;5...,...,-2.52669,2128.144321,2.64661,0.98498,0.031981,5,,Contact_10_8,,Other
2,RCS05R,RCS05R,Right,Contact_10_8,Contact_10_8,1572122000.0,1572121860,119.999,250,2.66052542e-02;1.33773726e-02;6.45884420e-06;5...,...,-2.717127,366.239869,2.549288,0.995171,0.033937,9,,Contact_10_8,,Other
3,RCS05R,RCS05R,Right,Contact_11_9,Contact_11_9,1572122000.0,1572121860,119.999,250,8.53280283e-03;4.28750245e-03;7.49003494e-06;6...,...,-2.214448,1173.974911,2.821459,0.984792,0.036355,5,,Contact_11_9,,Other
4,RCS05R,RCS05R,Right,Contact_11_9,Contact_11_9,1572122000.0,1572121860,119.999,250,8.53280283e-03;4.28750245e-03;7.49003494e-06;6...,...,-2.096759,4826.07049,2.877989,0.98798,0.031753,7,,Contact_11_9,,Other
5,RCS05R,RCS05R,Right,Contact_11_9,Contact_11_9,1572122000.0,1572121860,119.999,250,8.53280283e-03;4.28750245e-03;7.49003494e-06;6...,...,-2.319679,873.754785,2.762552,0.995746,0.032656,13,,Contact_11_9,,Other
6,RCS05R,RCS05R,Right,Contact_2_0,Contact_2_0,1572122000.0,1572121860,119.999,250,2.00126204e-02;1.00843672e-02;1.58514413e-05;1...,...,-4.852147,-19.191762,1.633683,0.986332,0.033085,5,,Contact_2_0,,Other
7,RCS05R,RCS05R,Right,Contact_2_0,Contact_2_0,1572122000.0,1572121860,119.999,250,2.00126204e-02;1.00843672e-02;1.58514413e-05;1...,...,-5.202423,-35.405554,1.458241,0.981473,0.024593,7,,Contact_2_0,,Other
8,RCS05R,RCS05R,Right,Contact_2_0,Contact_2_0,1572122000.0,1572121860,119.999,250,2.00126204e-02;1.00843672e-02;1.58514413e-05;1...,...,-4.899903,-19.201748,1.609013,0.993985,0.027846,9,,Contact_2_0,,Other
9,RCS05R,RCS05R,Right,Contact_3_0,Contact_3_0,1572122000.0,1572121860,119.999,250,3.78174015e-02;1.90292956e-02;1.66311725e-05;1...,...,-5.580832,-5.193934,0.940326,0.965594,0.038339,2,,Contact_3_0,,Other
