In [None]:
#imports
# This notebook SAVES halt-aligned data and baselined data as CSV together with PLOTS, unlike the previous SANDBOX_2_noSLEAP.
#---------------------------------------------------------------------------------------------------#
from pathlib import Path
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.io as pio
import plotly.express as px
import plotly.subplots as sp
import math
from pprint import pprint

from matplotlib.collections import LineCollection
import seaborn as sns
import traceback
import gc
from typing import Optional, Tuple, List, Dict, Any

from plotly.subplots import make_subplots
from scipy.stats import mode
from scipy.integrate import cumulative_trapezoid
from scipy.signal import correlate
import json
%config Completer.use_jedi = False  # Fixes autocomplete issues
%config InlineBackend.figure_format = 'retina'  # Improves plot resolution

import gc # garbage collector for removing large variables from memory instantly 
import importlib #for force updating changed packages 

#import harp
import harp_resources.process
import harp_resources.utils
from harp_resources import process, utils # Reassign to maintain direct references for force updating 
#from sleap import load_and_process as lp

In [None]:
#-------------------------------
# data paths setup
#-------------------------------
# Top-level directories to scan. Every sub-directory whose name does not end
# in '_processedData' is treated as a raw-data folder.
data_dirs = [  # Add your data directories here
    # Path('~/RANCZLAB-NAS/data/ONIX/20250409_Cohort3_rotation/Vestibular_mismatch_day1').expanduser(),
    Path('/Volumes/RanczLab2/20241125_Cohort1_rotation/Visual_mismatch_day4').expanduser()
    # Path('/Volumes/RanczLab2/20250409_Cohort3_rotation/Visual_mismatch_day4').expanduser()
]

# Gather the raw-data folders found in each data directory
rawdata_paths = []
for data_dir in data_dirs:
    subdirs = [
        entry
        for entry in data_dir.iterdir()
        if entry.is_dir() and not entry.name.endswith('_processedData')
    ]
    rawdata_paths += subdirs

# For a raw folder <name>, the data to load lives in the sibling folder
# <name>_processedData/downsampled_data
data_paths = [
    raw.parent / f"{raw.name}_processedData" / "downsampled_data"
    for raw in rawdata_paths
]

# Show the resolved paths, one per line
print("Processed Data Paths:")
pprint(data_paths)

#-------------------------------
# initial variables setup
#-------------------------------
# Plotting window around each event and the baseline interval (all in seconds)
time_window_start = -5  # s, FOR PLOTTING PURPOSES
time_window_end = 10  # s, FOR PLOTTING PURPOSES
baseline_window = (-1, 0)  # s, FOR baselining averages
plot_width = 14

# Event selection and resampling
# event_name options: Apply halt: 2s, No halt, DrumWithReverseflow block started, DrumBase block started
event_name = "No halt"
vestibular_mismatch = False
common_resampled_rate = 1000  # in Hz
plot_fig1 = False

# Saccade detection parameters
framerate = 59.77  # Hz (in the future, should come from saved data)
threshold = 65  # px/s FIXME make this adaptive
refractory_period = pd.Timedelta(100, unit="ms")  # 100 ms as a pd.Timedelta, for use with a datetime index
plot_saccade_detection_QC = False


In [None]:
# load downsampled data for each data path
#-------------------------------
# Maps each processed-data path to a dict with its loaded tables and mouse name.
# Paths that fail to load are reported and left out of this dict.
loaded_data = {}

for idx, data_path in enumerate(data_paths, start=1):
    print(f"\nProcessing data path {idx}/{len(data_paths)}: {data_path}")
    try:
        # Read the five parquet tables stored under this data path
        photometry_tracking_encoder_data = pd.read_parquet(data_path / "photometry_tracking_encoder_data.parquet", engine="pyarrow")
        camera_photodiode_data = pd.read_parquet(data_path / "camera_photodiode_data.parquet", engine="pyarrow")
        experiment_events = pd.read_parquet(data_path / "experiment_events.parquet", engine="pyarrow")
        photometry_info = pd.read_parquet(data_path / "photometry_info.parquet", engine="pyarrow")
        session_settings = pd.read_parquet(data_path / "session_settings.parquet", engine="pyarrow")
        # Pass each stored metadata entry through process.safe_from_json
        session_settings["metadata"] = session_settings["metadata"].apply(process.safe_from_json)

        print(f"✅ Successfully loaded all parquet files for {data_path.name}")

        # Check the spacing between consecutive events named `event_name`
        event_times = experiment_events.loc[experiment_events["Event"] == event_name].index
        if len(event_times) <= 1:
            # Zero or one event: there is no interval to measure
            print(f"ℹ️ INFO: Found {len(event_times)} events with name '{event_name}' - not enough to calculate differences")
        else:
            # Intervals between successive events, in seconds
            time_diffs = event_times.to_series().diff().dropna().dt.total_seconds()
            if time_diffs.lt(10).any():
                print(f"⚠️ Warning: Some '{event_name}' events are less than 10 seconds apart. Consider applying a filter to events.")

        # Check experiment events and get mouse name
        mouse_name = process.check_exp_events(experiment_events, photometry_info, verbose=True)

        # Keep everything loaded for this path under one entry
        loaded_data[data_path] = dict(
            photometry_tracking_encoder_data=photometry_tracking_encoder_data,
            camera_photodiode_data=camera_photodiode_data,
            experiment_events=experiment_events,
            photometry_info=photometry_info,
            session_settings=session_settings,
            mouse_name=mouse_name,
        )

    except Exception as e:
        # Report the failure and carry on with the next data path
        print(f"⚠️ ERROR processing data path {data_path}: {e!s}")

print(f"\n✅ Finished loading data for all {len(loaded_data)} successfully processed data paths")

In [None]:
# create DFs and plot figure for each data path
#---------------------------------------------------
# For every successfully loaded data path this cell:
#   1. determines the halt (event) times, either directly from the experiment
#      events or from the photodiode signal via process.analyze_photodiode,
#   2. stores them, with the photodiode delay statistics and the session name,
#      in `data_path_variables`,
#   3. optionally draws/saves figure 1 via process.plot_figure_1.
# Paths that were not loaded, have no matching events, or fail during analysis
# are reported and skipped.

# Dictionary to store analysis results for each data path
data_path_variables = {}

for idx, data_path in enumerate(data_paths, start=1):
    print(f"\n--------- Processing analysis for data path {idx}/{len(data_paths)}: {data_path} ---------")
    
    # Skip if data wasn't successfully loaded for this path
    if data_path not in loaded_data:
        print(f"⚠️ Skipping analysis for {data_path} - data not loaded successfully")
        continue
    
    try:
        # Extract data from loaded_data dictionary
        photometry_tracking_encoder_data = loaded_data[data_path]["photometry_tracking_encoder_data"]
        camera_photodiode_data = loaded_data[data_path]["camera_photodiode_data"]
        experiment_events = loaded_data[data_path]["experiment_events"]
        mouse_name = loaded_data[data_path]["mouse_name"]
        session_name = f"{mouse_name}_{data_path.name}"  # Assuming session_name is constructed this way
        
        # Create dataframe to analyze
        df_to_analyze = photometry_tracking_encoder_data["Photodiode_int"]  # Using downsampled values in common time grid
        # df_to_analyze = camera_photodiode_data["Photodiode"]  # Use async raw values if needed for troubleshooting
        
        # Determine halt times based on different conditions
        if vestibular_mismatch or event_name == "No halt":  # Determine halt times based on experiment events
            events_matching_name = experiment_events[experiment_events["Event"] == event_name]
            if events_matching_name.empty:
                print(f"⚠️ WARNING: No events found with name '{event_name}', skipping this data path")
                continue
                
            photodiode_halts = events_matching_name.index.tolist()
            # Snap each event timestamp to the closest timestamp of the downsampled data
            nearest_indices = photometry_tracking_encoder_data.index.get_indexer(photodiode_halts, method='nearest')
            photodiode_halts = photometry_tracking_encoder_data.index[nearest_indices]  # Align to downsampled data time grid
            print(f"ℹ️ INFO: vestibular MM or 'No halt', no signal in the photodiode, using experiment events for MM times")
            # No photodiode-based detection in this branch, so no delay statistics exist
            photodiode_delay_min = photodiode_delay_avg = photodiode_delay_max = None
        else:  # Determine exact halt times based on photodiode signal
            try:
                photodiode_halts, photodiode_delay_min, photodiode_delay_avg, photodiode_delay_max = process.analyze_photodiode(
                    df_to_analyze, experiment_events, event_name, plot=True
                )
                print(f"✅ Successfully analyzed photodiode signal for {data_path.name}")
            except Exception as e:
                print(f"⚠️ ERROR analyzing photodiode signal: {str(e)}")
                continue
        
        # Store analysis results
        data_path_variables[data_path] = {
            "photodiode_halts": photodiode_halts,
            "photodiode_delay_min": photodiode_delay_min,
            "photodiode_delay_avg": photodiode_delay_avg,
            "photodiode_delay_max": photodiode_delay_max,
            "session_name": session_name
        }
        
        # Plot figure if requested
        if plot_fig1:
            # `save_path` was previously never defined before this call, so on a
            # fresh kernel the plot always failed with a NameError. Figures now go
            # to the parent folder of the data path being processed.
            # NOTE(review): confirm this is the intended output folder for figure 1.
            save_path = data_path.parent
            try:
                process.plot_figure_1(
                    photometry_tracking_encoder_data, 
                    session_name, 
                    save_path, 
                    common_resampled_rate, 
                    photodiode_halts, 
                    save_figure=True, 
                    show_figure=True, 
                    downsample_factor=50
                )
                print(f"✅ Successfully created figure 1 for {data_path.name}")
            except Exception as e:
                # A plotting failure must not discard the halt times stored above
                print(f"⚠️ ERROR creating figure 1: {str(e)}")
        else:
            print(f"ℹ️ INFO: skipping figure 1 for {data_path.name}")
        
        # Clean up to free memory
        del df_to_analyze
        gc.collect()
        
        print(f"✅ Completed analysis for data path: {data_path}")
        
    except Exception as e:
        print(f"⚠️ ERROR during analysis of {data_path}: {str(e)}")

print(f"\n✅ Finished analyzing all {len(data_path_variables)} successfully processed data paths")

In [None]:
# #----oldcode-----------------------------------------------
# # Create aligned data and plot fluorescence traces for each data path (with improved error handling)
# #---------------------------------------------------

# # Check if required variables exist
# required_vars = ['time_window_start', 'time_window_end']
# for var in required_vars:
#     if var not in globals():
#         print(f"⚠️ ERROR: Required variable '{var}' is not defined")
#         # Set default values
#         if var == 'time_window_start':
#             time_window_start = -5  # Default: 5 seconds before halt
#             print(f"  Setting default value: {var} = {time_window_start}")
#         elif var == 'time_window_end':
#             time_window_end = 10  # Default: 10 seconds after halt
#             print(f"  Setting default value: {var} = {time_window_end}")

# # Define save_plots if not already defined
# if 'save_plots' not in globals():
#     save_plots = False
#     print(f"Setting default value: save_plots = {save_plots}")

# for idx, data_path in enumerate(data_paths, start=1):
#     print(f"\n--------- Creating fluorescence plots for data path {idx}/{len(data_paths)}: {data_path} ---------")
    
#     # Skip if data wasn't successfully analyzed for this path
#     if data_path not in data_path_variables:
#         print(f"⚠️ Skipping fluorescence plots for {data_path} - analysis not completed successfully")
#         continue
    
#     try:
#         # Verify data exists before attempting to extract it
#         if data_path not in loaded_data:
#             print(f"⚠️ ERROR: data_path {data_path} not found in loaded_data")
#             continue
            
#         if "photometry_tracking_encoder_data" not in loaded_data[data_path]:
#             print(f"⚠️ ERROR: 'photometry_tracking_encoder_data' not found in loaded_data[{data_path}]")
#             continue
            
#         if "photodiode_halts" not in data_path_variables[data_path]:
#             print(f"⚠️ ERROR: 'photodiode_halts' not found in data_path_variables[{data_path}]")
#             continue
            
#         # Extract data from loaded_data and data_path_variables dictionaries
#         photometry_tracking_encoder_data = loaded_data[data_path]["photometry_tracking_encoder_data"]
#         photodiode_halts = data_path_variables[data_path]["photodiode_halts"]
        
#         # Check if session_name exists in data_path_variables
#         if "session_name" in data_path_variables[data_path]:
#             session_name = data_path_variables[data_path]["session_name"]
#         else:
#             # Try to generate session_name
#             if "mouse_name" in loaded_data[data_path]:
#                 mouse_name = loaded_data[data_path]["mouse_name"]
#                 session_name = f"{mouse_name}_{data_path.name}"
#             else:
#                 session_name = f"session_{data_path.name}"
#                 print(f"⚠️ WARNING: Using generic session name: {session_name}")
        
#         print(f"Creating aligned data for {len(photodiode_halts)} events in {data_path.name}")
        
#         # Check if required columns exist in photometry data
#         required_columns = ["z_470", "z_560"]
#         missing_columns = [col for col in required_columns if col not in photometry_tracking_encoder_data.columns]
#         if missing_columns:
#             print(f"⚠️ ERROR: Missing columns in photometry_tracking_encoder_data: {missing_columns}")
#             print(f"Available columns: {photometry_tracking_encoder_data.columns.tolist()}")
#             continue
        
#         # --- Data Alignment ---
#         aligned_data = []
#         for halt_time in photodiode_halts:
#             try:
#                 window_data = photometry_tracking_encoder_data.loc[
#                     (photometry_tracking_encoder_data.index >= halt_time + pd.Timedelta(seconds=time_window_start)) &
#                     (photometry_tracking_encoder_data.index <= halt_time + pd.Timedelta(seconds=time_window_end))
#                 ].copy()
                
#                 if window_data.empty:
#                     print(f"⚠️ WARNING: No data found for window around halt time {halt_time}")
#                     continue
                    
#                 window_data["Time (s)"] = (window_data.index - halt_time).total_seconds()
#                 window_data["Halt Time"] = halt_time
#                 aligned_data.append(window_data)
#             except Exception as e:
#                 print(f"⚠️ ERROR processing halt time {halt_time}: {str(e)}")
#                 continue
        
#         if not aligned_data:
#             print(f"⚠️ WARNING: No aligned data created for {data_path.name}, skipping plotting")
#             continue
            
#         aligned_df = pd.concat(aligned_data, ignore_index=True)
        
#         # --- Subplot Grid Setup ---
#         n_events = len(photodiode_halts)
#         n_cols = 4
#         n_rows = math.ceil(n_events / n_cols)
        
#         print(f"Creating subplot grid with {n_rows} rows and {n_cols} columns for {n_events} events")
        
#         # Create subplots with a single (default) y-axis in each cell
#         specs = [[{} for _ in range(n_cols)] for _ in range(n_rows)]
#         subplot_titles = [f'Event: {halt_time}' for halt_time in photodiode_halts]
#         fig = sp.make_subplots(rows=n_rows, cols=n_cols, subplot_titles=subplot_titles, specs=specs)
        
#         # Base extra is used to create unique axis IDs starting after the auto-assigned primary axes
#         base_extra = n_events + 1
        
#         for i, halt_time in enumerate(photodiode_halts):
#             try:
#                 row = (i // n_cols) + 1
#                 col = (i % n_cols) + 1
                
#                 subset = aligned_df[aligned_df["Halt Time"] == halt_time]
#                 if subset.empty:
#                     print(f"⚠️ WARNING: No data found for halt time {halt_time}")
#                     continue
                
#                 # Verify data exists in subset
#                 for col_name in ["Time (s)", "z_470", "z_560"]:
#                     if col_name not in subset.columns:
#                         print(f"⚠️ ERROR: Column '{col_name}' not found in aligned data subset")
#                         print(f"Available columns: {subset.columns.tolist()}")
#                         raise KeyError(f"Missing column: {col_name}")
                    
#                 # -- Fluorescence Traces on Primary y-axis --
#                 fig.add_trace(
#                     go.Scatter(
#                         x=subset["Time (s)"],
#                         y=subset["z_470"],
#                         mode='lines',
#                         name='z_470',
#                         line=dict(color='green')
#                     ),
#                     row=row, col=col
#                 )
                
#                 fig.add_trace(
#                     go.Scatter(
#                         x=subset["Time (s)"],
#                         y=subset["z_560"],
#                         mode='lines',
#                         name='z_560',
#                         line=dict(color='red')
#                     ),
#                     row=row, col=col
#                 )
                
#                 # Determine the subplot's x-axis anchor
#                 xaxis_number = (row - 1) * n_cols + col
#                 x_anchor = "x" if xaxis_number == 1 else f"x{xaxis_number}"
                
#                 # The primary y-axis for this subplot
#                 primary_y = "y" if xaxis_number == 1 else f"y{xaxis_number}"
                
#             except Exception as e:
#                 print(f"⚠️ ERROR processing subplot for halt time {halt_time}: {str(e)}")
#                 continue
        
#         # --- Update Common Axis Labels ---
#         fig.update_xaxes(title_text="Time (s)")
#         fig.update_yaxes(title_text="Fluorescence (z-score)")
        
#         # Create a descriptive title that includes session name
#         title = f"Fluorescence for each event - {session_name}"
        
#         fig.update_layout(
#             height=400 * n_rows,
#             width=350 * n_cols,
#             title_text=title,
#             template='plotly_white'
#         )
        
#         # Save the figure if needed
#         if save_plots:
#             try:
#                 # Ensure save_path exists
#                 if 'save_path' not in globals() or save_path is None:
#                     from pathlib import Path
#                     save_path = Path('./output')
#                     save_path.mkdir(exist_ok=True)
#                     print(f"Creating default save_path: {save_path}")
                    
#                 output_file = save_path / f"{session_name}_fluorescence_events.html"
#                 fig.write_html(str(output_file))
#                 print(f"✅ Saved fluorescence plot to {output_file}")
#             except Exception as e:
#                 print(f"⚠️ ERROR saving fluorescence plot: {str(e)}")
        
#         # Display the figure
#         fig.show()
        
#         # Clean up to free memory
#         del aligned_data
#         gc.collect()
        
#         print(f"✅ Completed fluorescence plots for data path: {data_path}")
        
#     except Exception as e:
#         import traceback
#         print(f"⚠️ ERROR creating fluorescence plots for {data_path}: {str(e)}")
#         print("Detailed error traceback:")
#         traceback.print_exc()  # This will print the full stack trace

# print(f"\n✅ Finished creating fluorescence plots for all successfully processed data paths")

In [None]:
# #----oldcode-----------------------------------------------
# # Create aligned data and plot comprehensive figures for each data path
# #---------------------------------------------------

# print(f"Using time window: {time_window_start}s to {time_window_end}s relative to halt")

# # Iterate through each data path
# for idx, data_path in enumerate(data_paths, start=1):
#     print(f"\n--------- Processing {idx}/{len(data_paths)}: {data_path} ---------")
    
#     if data_path not in data_path_variables:
#         print(f"⚠️ Skipping {data_path} - no analysis data found")
#         continue
    
#     try:
#         # Extract data
#         data = loaded_data[data_path]
#         vars_ = data_path_variables[data_path]
        
#         df = data["photometry_tracking_encoder_data"]
#         halts = vars_["photodiode_halts"]
        
#         session_name = vars_.get("session_name")
#         if not session_name:
#             mouse_name = data.get("mouse_name", "unknown_mouse")
#             session_name = f"{mouse_name}_{data_path.stem}"
#             print(f"⚠️ No session_name found, using: {session_name}")

#         event_name = event_name
#         print(f"Aligning {len(halts)} events for session '{session_name}'")

#         # Align data to each halt event
#         aligned_data = []
#         for i, halt_time in enumerate(halts):
#             window = df.loc[
#                 (df.index >= halt_time + pd.Timedelta(seconds=time_window_start)) &
#                 (df.index <= halt_time + pd.Timedelta(seconds=time_window_end))
#             ].copy()

#             if window.empty:
#                 print(f"⚠️ No data in window around halt {halt_time}")
#                 continue

#             window["Time (s)"] = (window.index - halt_time).total_seconds()
#             window["Halt Time"] = halt_time
#             aligned_data.append(window)

#         if not aligned_data:
#             print(f"⚠️ No aligned data generated for {session_name}, skipping")
#             continue

#         aligned_df = pd.concat(aligned_data, ignore_index=True)

#         # Save CSV in same folder as data
#         aligned_dir = data_path.parent / "aligned_data"
#         aligned_dir.mkdir(exist_ok=True)

#         aligned_file = aligned_dir / f"{session_name}_{event_name}_aligned.csv"
#         aligned_df.to_csv(aligned_file, index=False)
#         print(f"✅ Saved aligned data to {aligned_file}")

#         # Fill in missing columns with dummy data (except required)
#         required_columns = ["Time (s)", "Photodiode_int", "z_470", "z_560", "Motor_Velocity", "Velocity_0X", "Velocity_0Y"]
#         for col in required_columns:
#             if col not in aligned_df.columns:
#                 print(f"⚠️ Missing column: {col}, adding zeros")
#                 aligned_df[col] = 0

#         # Compute group mean and SEM
#         mean_df = aligned_df.groupby("Time (s)").mean()
#         sem_df = aligned_df.groupby("Time (s)").sem()

#         # Create figure
#         print(f"📈 Creating plot for {session_name}")
#         fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharex=True)

#         ## Plot 1: Individual Traces
#         ax1 = axes[0]
#         for halt in aligned_df["Halt Time"].unique():
#             subset = aligned_df[aligned_df["Halt Time"] == halt]
#             ax1.plot(subset["Time (s)"], subset["Photodiode_int"], color='grey', alpha=0.5)

#         ax1.set_title('Photodiode, z_470, and z_560')
#         ax1.set_xlabel("Time (s)")
#         ax1.set_ylabel("Photodiode")

#         ax1_2 = ax1.twinx()
#         for halt in aligned_df["Halt Time"].unique():
#             subset = aligned_df[aligned_df["Halt Time"] == halt]
#             ax1_2.plot(subset["Time (s)"], subset["z_470"], color='green', alpha=0.5)
#             ax1_2.plot(subset["Time (s)"], subset["z_560"], color='red', alpha=0.5)

#         ax1_2.set_ylabel("Fluorescence (z-score)", color='green')

#         ## Plot 2: Mean + SEM
#         ax2 = axes[1]
#         ax2.plot(mean_df.index, mean_df["Photodiode_int"], color='grey')
#         ax2.fill_between(mean_df.index, mean_df["Photodiode_int"] - sem_df["Photodiode_int"],
#                          mean_df["Photodiode_int"] + sem_df["Photodiode_int"], color='grey', alpha=0.2)
#         ax2.set_xlabel("Time (s)")
#         ax2.set_ylabel("Photodiode")
#         ax2.set_title("Mean & SEM")

#         ax2_2 = ax2.twinx()
#         ax2_2.plot(mean_df.index, mean_df["z_470"], color='green')
#         ax2_2.fill_between(mean_df.index, mean_df["z_470"] - sem_df["z_470"], 
#                            mean_df["z_470"] + sem_df["z_470"], color='green', alpha=0.2)

#         ax2_2.plot(mean_df.index, mean_df["z_560"], color='red')
#         ax2_2.fill_between(mean_df.index, mean_df["z_560"] - sem_df["z_560"], 
#                            mean_df["z_560"] + sem_df["z_560"], color='red', alpha=0.2)

#         ax2_2.set_ylabel("Fluorescence (z-score)", color='green')

#         # Save figure in same folder as data
#         fig.suptitle(f"{session_name} - {event_name}")
#         fig.tight_layout()
#         figure_file = data_path.parent / f"{session_name}_{event_name}.pdf"
#         fig.savefig(figure_file, dpi=300)
#         plt.close(fig)
#         print(f"✅ Saved figure to {figure_file}")

#     except Exception as e:
#         import traceback
#         print(f"❌ ERROR processing {data_path}: {str(e)}")
#         traceback.print_exc()

# print("\n✅ Finished all data paths.")


In [None]:
# -----oldcode--------------------------------------------------
# #separates RIGHT vs LEFT TURNS, creates heatmaps, and comprehensive (Mean +- sem) aligned data for each data path
# #---------------------------------------------------
# from matplotlib.collections import LineCollection
# import seaborn as sns

# import traceback
# def process_aligned_data(df, halt_time, time_window_start, time_window_end):
#     """Process a single halt event efficiently."""
#     window_start = halt_time + pd.Timedelta(seconds=time_window_start)
#     window_end = halt_time + pd.Timedelta(seconds=time_window_end)
#     mask = (df.index >= window_start) & (df.index <= window_end)
    
#     if not mask.any():
#         return None
    
#     window = df.loc[mask].copy()
#     window["Time (s)"] = (window.index - halt_time).total_seconds()
#     window["Halt Time"] = halt_time
#     return window

# def create_heatmap(pivot_data, session_name, event_name, channel, save_path, figsize=(10, 6)):
#     """Create and save heatmap efficiently."""
#     # Base normalize
#     baseline_cols = (pivot_data.columns >= -1) & (pivot_data.columns < 0)
#     if baseline_cols.any():
#         baseline_means = pivot_data.loc[:, baseline_cols].mean(axis=1)
#         normalized_data = pivot_data.subtract(baseline_means, axis=0)
#     else:
#         normalized_data = pivot_data
    
#     # Create figure with optimized settings
#     fig, ax = plt.subplots(figsize=figsize)
    
#     # Use fewer colors for better performance
#     sns.heatmap(normalized_data, cmap="RdBu_r", center=0, ax=ax, 
#                 cbar_kws={'label': f'Normalized {channel}'},
#                 rasterized=True)  # Rasterize for smaller file size
    
#     ax.set_title(f"Heatmap ({channel}) - {session_name}")
#     ax.set_xlabel("Time (s)")
#     ax.set_ylabel("Event")
    
#     # Set y-axis ticks to simple numbers instead of datetime
#     n_events = len(normalized_data.index)
#     y_tick_positions = range(0, n_events, max(1, n_events // 10))  # Show ~10 ticks max
#     y_tick_labels = [str(i+1) for i in y_tick_positions]  # Event numbers 1, 2, 3, etc.
#     ax.set_yticks(y_tick_positions)
#     ax.set_yticklabels(y_tick_labels)
    
#     # Add event line if 0 exists in columns
#     if 0 in normalized_data.columns:
#         zero_idx = list(normalized_data.columns).index(0)
#         ax.axvline(zero_idx, linestyle='--', color='black', alpha=0.7)
    
#     # Optimize tick labels - show every 2nd second
#     time_cols = normalized_data.columns
#     tick_indices = [i for i, val in enumerate(time_cols) 
#                    if isinstance(val, (int, float)) and val % 2 == 0]
#     tick_labels = [f"{int(time_cols[i])}" for i in tick_indices]
    
#     ax.set_xticks(tick_indices)
#     ax.set_xticklabels(tick_labels, rotation=45)
    
#     plt.tight_layout()
#     plt.savefig(save_path, dpi=200, bbox_inches='tight')  # Lower DPI for efficiency
#     plt.close(fig)
    
#     return normalized_data

# def create_summary_plot(aligned_df, session_name, event_name, save_path):
#     """Create summary plots efficiently."""
#     # Pre-compute grouped statistics once
#     grouped = aligned_df.groupby("Time (s)")
#     mean_df = grouped.mean()
#     sem_df = grouped.sem()
    
#     fig, axes = plt.subplots(1, 4, figsize=(28, 6), sharex=True)
#     # Separate left vs right turn halts based on Motor_Velocity before time0
#     left_turn_halts = []
#     right_turn_halts = []
#     for halt in aligned_df["Halt Time"].unique():
#         subset = aligned_df[(aligned_df["Halt Time"] == halt) & 
#                             (aligned_df["Time (s)"] >= -1) & 
#                             (aligned_df["Time (s)"] < 0)]
#         mean_velocity = subset["Motor_Velocity"].mean()
        
#         if mean_velocity < 0:  # Negative mean velocity indicates left turn
#             left_turn_halts.append(halt)
#         elif mean_velocity > 0:  # Positive mean velocity indicates right turn
#             right_turn_halts.append(halt)

#     # Create separate DataFrames for left and right turn halts
#     left_turn_df = aligned_df[aligned_df["Halt Time"].isin(left_turn_halts)]
#     right_turn_df = aligned_df[aligned_df["Halt Time"].isin(right_turn_halts)]

#     # Save left and right turn DataFrames to CSV
#     left_turn_file = aligned_dir / f"{session_name}_{event_name}_left_turns.csv"
#     right_turn_file = aligned_dir / f"{session_name}_{event_name}_right_turns.csv"

#     left_turn_df.to_csv(left_turn_file, index=False, float_format='%.4f')
#     right_turn_df.to_csv(right_turn_file, index=False, float_format='%.4f')

#     print(f"Saved left turn data to {left_turn_file}")
#     print(f"Saved right turn data to {right_turn_file}")
    
#     # Create subplots for z-score data (row 1) and dfF data (row 2)
#     fig, axes = plt.subplots(2, 3, figsize=(21, 12), sharex=True)
    
#     # Row 1: Z-score data
#     # Left plot - Individual traces for left turn halts (z_470, z_560, and Motor_Velocity)
#     ax1 = axes[0, 0]
#     left_lines_470 = []
#     left_lines_560 = []
#     left_lines_motor = []
#     for halt in left_turn_df["Halt Time"].unique():
#         subset = left_turn_df[left_turn_df["Halt Time"] == halt]
#         time_vals = subset["Time (s)"].values
#         left_lines_470.append(list(zip(time_vals, subset["z_470"].values)))
#         left_lines_560.append(list(zip(time_vals, subset["z_560"].values)))
#         left_lines_motor.append(list(zip(time_vals, subset["Motor_Velocity"].values)))

#     left_lc_470 = LineCollection(left_lines_470, colors='cornflowerblue', alpha=0.3, linewidths=1)
#     left_lc_560 = LineCollection(left_lines_560, colors='red', alpha=0.3, linewidths=1)
#     ax1.add_collection(left_lc_470)
#     ax1.add_collection(left_lc_560)

#     # Create a secondary y-axis for motor velocity
#     ax1_motor = ax1.twinx()
#     left_lc_motor = LineCollection(left_lines_motor, colors='slategray', alpha=0.3, linewidths=1)
#     ax1_motor.add_collection(left_lc_motor)
#     ax1_motor.set_ylabel("Motor Velocity", color='slategray')
#     ax1_motor.tick_params(axis='y', labelcolor='slategray')
#     ax1_motor.autoscale()  # Automatically fit the axis to the plot

#     ax1.set_title('Left Turn Traces (z_470, z_560, Motor_Velocity)')
#     ax1.set_xlabel("Time (s)")
#     ax1.set_ylabel("Fluorescence (z-score)")
#     ax1.axvline(0, linestyle='--', color='black', alpha=0.7)
#     ax1.autoscale()
    
#     # Right plot - Individual traces for right turn halts (z_470 & z_560)
#     ax2 = axes[0, 1]
#     right_lines_470 = []
#     right_lines_560 = []
#     right_lines_motor = []
#     for halt in right_turn_df["Halt Time"].unique():
#         subset = right_turn_df[right_turn_df["Halt Time"] == halt]
#         time_vals = subset["Time (s)"].values
#         right_lines_470.append(list(zip(time_vals, subset["z_470"].values)))
#         right_lines_560.append(list(zip(time_vals, subset["z_560"].values)))
#         right_lines_motor.append(list(zip(time_vals, subset["Motor_Velocity"].values)))
#     right_lc_470 = LineCollection(right_lines_470, colors='cornflowerblue', alpha=0.3, linewidths=1)
#     right_lc_560 = LineCollection(right_lines_560, colors='red', alpha=0.3, linewidths=1)
#     right_lc_motor = LineCollection(right_lines_motor, colors='slategray', alpha=0.3, linewidths=1)
#     ax2.add_collection(right_lc_470)
#     ax2.add_collection(right_lc_560)
    
#     # Add secondary y-axis for motor velocity
#     ax2_motor = ax2.twinx()
#     ax2_motor.add_collection(right_lc_motor)
#     ax2_motor.set_ylabel("Motor Velocity", color='slategray')
#     ax2_motor.tick_params(axis='y', labelcolor='slategray')
#     ax2_motor.autoscale()  # Automatically fit the axis to the plot
    
#     ax2.set_title('Right Turn Traces (z_470, z_560, Motor_Velocity)')
#     ax2.set_xlabel("Time (s)")
#     ax2.set_ylabel("Fluorescence (z-score)")
#     ax2.axvline(0, linestyle='--', color='black', alpha=0.7)
#     ax2.autoscale()
    
#     # Middle plot - Mean ± SEM for left and right turns (z_470 & z_560) with motor data
#     ax3 = axes[0, 2]
#     time_index = mean_df.index.values
#     for color, channel in [('cornflowerblue', 'z_470'), ('red', 'z_560')]:
#         ax3.plot(time_index, left_turn_df.groupby("Time (s)").mean()[channel], color=color, linestyle='--', linewidth=2, label=f"Left {channel}")
#         ax3.fill_between(time_index,
#                          left_turn_df.groupby("Time (s)").mean()[channel] - left_turn_df.groupby("Time (s)").sem()[channel],
#                          left_turn_df.groupby("Time (s)").mean()[channel] + left_turn_df.groupby("Time (s)").sem()[channel], 
#                          color=color, alpha=0.2)
#         ax3.plot(time_index, right_turn_df.groupby("Time (s)").mean()[channel], color=color, linewidth=2, label=f"Right {channel}")
#         ax3.fill_between(time_index,
#                          right_turn_df.groupby("Time (s)").mean()[channel] - right_turn_df.groupby("Time (s)").sem()[channel],
#                          right_turn_df.groupby("Time (s)").mean()[channel] + right_turn_df.groupby("Time (s)").sem()[channel], 
#                          color=color, alpha=0.2)
#     ax3.axvline(0, linestyle='--', color='black', alpha=0.7)
#     ax3.set_xlabel("Time (s)")
#     ax3.set_ylabel("Fluorescence (z-score)")
#     ax3.set_title("Mean ± SEM (z_470 & z_560)")
#     ax3.legend(loc='upper left')  # Add legend for z_470 and z_560

#     # Add secondary y-axis for motor data with SEM
#     ax3_motor = ax3.twinx()
#     motor_color = 'slategray'
#     ax3_motor.plot(time_index, left_turn_df.groupby("Time (s)").mean()["Motor_Velocity"], color=motor_color, linestyle='--', linewidth=1.5, label="Left Motor")
#     ax3_motor.fill_between(time_index,
#                            left_turn_df.groupby("Time (s)").mean()["Motor_Velocity"] - left_turn_df.groupby("Time (s)").sem()["Motor_Velocity"],
#                            left_turn_df.groupby("Time (s)").mean()["Motor_Velocity"] + left_turn_df.groupby("Time (s)").sem()["Motor_Velocity"],
#                            color=motor_color, alpha=0.2)
#     ax3_motor.plot(time_index, right_turn_df.groupby("Time (s)").mean()["Motor_Velocity"], color=motor_color, linewidth=1.5, label="Right Motor")
#     ax3_motor.fill_between(time_index,
#                            right_turn_df.groupby("Time (s)").mean()["Motor_Velocity"] - right_turn_df.groupby("Time (s)").sem()["Motor_Velocity"],
#                            right_turn_df.groupby("Time (s)").mean()["Motor_Velocity"] + right_turn_df.groupby("Time (s)").sem()["Motor_Velocity"],
#                            color=motor_color, alpha=0.2)
#     ax3_motor.set_ylabel("Motor Velocity", color=motor_color)
#     ax3_motor.tick_params(axis='y', labelcolor=motor_color)
#     ax3_motor.axhline(0, linestyle='--', color='gray', alpha=0.5)
#     ax3_motor.legend(loc='upper right')  # Add legend for motor data
    
#     # # Row 2: dfF data
#     # # Left plot - Individual traces for left turn halts (dfF_470 & dfF_560)
#     # ax4 = axes[1, 0]
#     # left_lines_dfF_470 = []
#     # left_lines_dfF_560 = []
#     # for halt in left_turn_df["Halt Time"].unique():
#     #     subset = left_turn_df[left_turn_df["Halt Time"] == halt]
#     #     time_vals = subset["Time (s)"].values
#     #     left_lines_dfF_470.append(list(zip(time_vals, subset["dfF_470"].values)))
#     #     left_lines_dfF_560.append(list(zip(time_vals, subset["dfF_560"].values)))
#     # left_lc_dfF_470 = LineCollection(left_lines_dfF_470, colors='blue', alpha=0.3, linewidths=1)
#     # left_lc_dfF_560 = LineCollection(left_lines_dfF_560, colors='orange', alpha=0.3, linewidths=1)
#     # ax4.add_collection(left_lc_dfF_470)
#     # ax4.add_collection(left_lc_dfF_560)
#     # ax4.set_title('Left Turn Traces (dfF_470 & dfF_560)')
#     # ax4.set_xlabel("Time (s)")
#     # ax4.set_ylabel("Fluorescence (dfF)")
#     # ax4.axvline(0, linestyle='--', color='black', alpha=0.7)
#     # ax4.autoscale()
    
#     # # Right plot - Individual traces for right turn halts (dfF_470 & dfF_560)
#     # ax5 = axes[1, 1]
#     # right_lines_dfF_470 = []
#     # right_lines_dfF_560 = []
#     # for halt in right_turn_df["Halt Time"].unique():
#     #     subset = right_turn_df[right_turn_df["Halt Time"] == halt]
#     #     time_vals = subset["Time (s)"].values
#     #     right_lines_dfF_470.append(list(zip(time_vals, subset["dfF_470"].values)))
#     #     right_lines_dfF_560.append(list(zip(time_vals, subset["dfF_560"].values)))
#     # right_lc_dfF_470 = LineCollection(right_lines_dfF_470, colors='blue', alpha=0.3, linewidths=1)
#     # right_lc_dfF_560 = LineCollection(right_lines_dfF_560, colors='orange', alpha=0.3, linewidths=1)
#     # ax5.add_collection(right_lc_dfF_470)
#     # ax5.add_collection(right_lc_dfF_560)
#     # ax5.set_title('Right Turn Traces (dfF_470 & dfF_560)')
#     # ax5.set_xlabel("Time (s)")
#     # ax5.set_ylabel("Fluorescence (dfF)")
#     # ax5.axvline(0, linestyle='--', color='black', alpha=0.7)
#     # ax5.autoscale()
    
#     # # Middle plot - Mean ± SEM for left and right turns (dfF_470 & dfF_560)
#     # ax6 = axes[1, 2]
#     # for color, channel in [('blue', 'dfF_470'), ('orange', 'dfF_560')]:
#     #     ax6.plot(time_index, left_turn_df.groupby("Time (s)").mean()[channel], color=color, linestyle='--', linewidth=2, label=f"Left {channel}")
#     #     ax6.fill_between(time_index,
#     #                      left_turn_df.groupby("Time (s)").mean()[channel] - left_turn_df.groupby("Time (s)").sem()[channel],
#     #                      left_turn_df.groupby("Time (s)").mean()[channel] + left_turn_df.groupby("Time (s)").sem()[channel], 
#     #                      color=color, alpha=0.2)
#     #     ax6.plot(time_index, right_turn_df.groupby("Time (s)").mean()[channel], color=color, linewidth=2, label=f"Right {channel}")
#     #     ax6.fill_between(time_index,
#     #                      right_turn_df.groupby("Time (s)").mean()[channel] - right_turn_df.groupby("Time (s)").sem()[channel],
#     #                      right_turn_df.groupby("Time (s)").mean()[channel] + right_turn_df.groupby("Time (s)").sem()[channel], 
#     #                      color=color, alpha=0.2)
#     # ax6.axvline(0, linestyle='--', color='black', alpha=0.7)
#     # ax6.set_xlabel("Time (s)")
#     # ax6.set_ylabel("Fluorescence (dfF)")
#     # ax6.set_title("Mean ± SEM (dfF_470 & dfF_560)")
#     # ax6.legend()
    
#     fig.suptitle(f"{session_name} - {event_name}")
#     plt.tight_layout()
#     plt.savefig(save_path, dpi=200, bbox_inches='tight')
#     plt.close(fig)
#     plt.ioff()  # Turn off interactive mode to suppress figure display

# # Required columns defined once
# REQUIRED_COLUMNS = ["Time (s)", "Photodiode_int", "z_470", "z_560","dfF_470", "dfF_560",
#                    "Motor_Velocity", "Velocity_0X", "Velocity_0Y"]

# for idx, data_path in enumerate(data_paths, start=1):
#     print(f"\n--------- Processing {idx}/{len(data_paths)}: {data_path} ---------")

#     if data_path not in data_path_variables:
#         print(f"Skipping {data_path} - no analysis data found")
#         continue

#     try:
#         # Load data references
#         data = loaded_data[data_path]
#         vars_ = data_path_variables[data_path]
#         df = data["photometry_tracking_encoder_data"]
#         halts = vars_["photodiode_halts"]

#         # Get session name
#         session_name = vars_.get("session_name")
#         if not session_name:
#             mouse_name = data.get("mouse_name", "unknown_mouse")
#             session_name = f"{mouse_name}_{data_path.stem}"
#             print(f"No session_name found, using: {session_name}")

#         print(f"Aligning {len(halts)} events for session '{session_name}'")

#         # Process all halt events efficiently
#         aligned_data = []
#         for halt_time in halts:
#             window_data = process_aligned_data(df, halt_time, time_window_start, time_window_end)
#             if window_data is not None:
#                 aligned_data.append(window_data)

#         if not aligned_data:
#             print(f"No aligned data generated for {session_name}, skipping")
#             continue

#         # Concatenate once
#         aligned_df = pd.concat(aligned_data, ignore_index=True)

#         # Save aligned data
#         aligned_dir = data_path.parent / "aligned_data"
#         aligned_dir.mkdir(exist_ok=True)
#         aligned_file = aligned_dir / f"{session_name}_{event_name}_aligned.csv"
        
#         # Use efficient CSV writing
#         aligned_df.to_csv(aligned_file, index=False, float_format='%.4f')
#         print(f"Saved aligned data to {aligned_file}")

#         # Create plots
#         print(f"📈 Creating plots for {session_name}")
        
#         # Summary plot
#         summary_path = data_path.parent / f"{session_name}_{event_name}.pdf"
#         create_summary_plot(aligned_df, session_name, event_name, summary_path)
#         print(f"Saved summary plot")

#         # Heatmaps - only create if we have enough data
#         unique_halts = aligned_df["Halt Time"].nunique()
#         if unique_halts > 1:
#             # z_470 heatmap
#             pivot_470 = aligned_df.pivot_table(index="Halt Time", columns="Time (s)", 
#                                              values="z_470", aggfunc='first')
#             heatmap_470_path = data_path.parent / f"{session_name}_{event_name}_heatmap_z_470.pdf"
#             create_heatmap(pivot_470, session_name, event_name, "z_470", heatmap_470_path)
#             print(f"Saved z_470 heatmap")

#             # z_560 heatmap  
#             pivot_560 = aligned_df.pivot_table(index="Halt Time", columns="Time (s)", 
#                                              values="z_560", aggfunc='first')
#             heatmap_560_path = data_path.parent / f"{session_name}_{event_name}_heatmap_z_560.pdf"
#             create_heatmap(pivot_560, session_name, event_name, "z_560", heatmap_560_path)
#             print(f"Saved z_560 heatmap")

#             # df470 heatmap
#             pivot_470df = aligned_df.pivot_table(index="Halt Time", columns="Time (s)", 
#                                              values="dfF_470", aggfunc='first')
#             heatmap_470df_path = data_path.parent / f"{session_name}_{event_name}_heatmap_dfF_470.pdf"
#             create_heatmap(pivot_470df, session_name, event_name, "dfF_470", heatmap_470df_path)
#             print(f"Saved df_470 heatmap")

#             #df560 heatmap
#             pivot_560df = aligned_df.pivot_table(index="Halt Time", columns="Time (s)", 
#                                             values="dfF_560", aggfunc='first')
#             heatmap_560df_path = data_path.parent / f"{session_name}_{event_name}_heatmap_df_560.pdf"
#             create_heatmap(pivot_560df, session_name, event_name, "dfF_560", heatmap_560df_path)
#             print(f"Saved df_560 heatmap")
#         else:
#             print("Insufficient data for heatmaps (need >1 event)")

#         # Explicit memory cleanup
#         del aligned_data, aligned_df
#         if 'pivot_470' in locals():
#             del pivot_470
#         if 'pivot_560' in locals():
#             del pivot_560
#         if 'pivot_470df' in locals():
#             del pivot_470df
#         if 'pivot_560df' in locals():
#             del pivot_560df
#         gc.collect()

#     except Exception as e:
#         print(f"ERROR processing {data_path}: {str(e)}")
#         traceback.print_exc()
#         # Clean up on error
#         gc.collect()

# print("✅ Finished all data paths.")

In [None]:
# Separates RIGHT vs LEFT turns, saves the aligned data, and creates heatmaps and mean ± SEM summary plots for each data path
#----------------------------------------------------
"""
Refactored photometry analysis code for processing aligned behavioral data.
Separates left vs right turns, creates heatmaps, and generates comprehensive plots.
"""
class PhotometryAnalyzer:
    """Align photometry/behavioural data to halt events and summarise it.

    For each session the analyzer cuts a time window around every halt event,
    splits the events into left/right turns by the sign of the pre-event motor
    velocity, writes CSV files, and saves a summary figure plus per-channel
    heatmaps.

    The class relies on names imported at notebook level (``pd``, ``np``,
    ``math``, ``plt``, ``sns``, ``LineCollection``, ``gc``, ``traceback``).
    Plotting-library types are quoted in annotations so that defining the class
    does not itself require the plotting imports.
    """

    # Class constants
    # NOTE(review): this list is not checked against the input dataframe anywhere
    # in this class; it is kept for reference / external users.
    REQUIRED_COLUMNS = [
        "Time (s)", "Photodiode_int", "z_470", "z_560", 
        "dfF_470", "dfF_560", "Motor_Velocity", "Velocity_0X", "Velocity_0Y"
    ]
    
    FLUORESCENCE_CHANNELS = {
        'z_470': {'color': 'cornflowerblue', 'label': 'z_470'},
        'z_560': {'color': 'red', 'label': 'z_560'},
        'dfF_470': {'color': 'blue', 'label': 'dfF_470'},
        'dfF_560': {'color': 'orange', 'label': 'dfF_560'}
    }

    # Colour used for everything drawn on the motor-velocity axes (traces, means, labels).
    MOTOR_COLOR = 'slategray'

    # (start, end) in seconds relative to the event; start inclusive, end exclusive.
    # Used both to classify turn direction and as the heatmap baseline.
    PRE_EVENT_WINDOW = (-1.0, 0.0)
    
    def __init__(self, time_window: Optional[Tuple[float, float]] = None):
        """
        Initialize analyzer with time window parameters.
        
        Args:
            time_window: Tuple of (start, end) times relative to event (seconds).
                When omitted, the notebook globals ``time_window_start`` and
                ``time_window_end`` are read at call time. (They used to be
                captured once, when the class statement ran, which made the class
                definition fail if they were not defined yet and silently kept
                stale values if they changed afterwards.)

        Raises:
            NameError: if ``time_window`` is omitted and the notebook globals
                are not defined.
        """
        if time_window is None:
            # Late-bound fallback to the notebook configuration.
            time_window = (time_window_start, time_window_end)
        self.time_window_start, self.time_window_end = time_window
        
    def process_aligned_data(self, df: pd.DataFrame, halt_time: pd.Timestamp) -> Optional[pd.DataFrame]:
        """
        Cut the window around a single halt event out of ``df``.
        
        Args:
            df: Main dataframe with photometry and behavioral data.
                Assumes a datetime-like index comparable with ``halt_time`` --
                TODO confirm against the loader.
            halt_time: Timestamp of the halt event
            
        Returns:
            Copy of the rows inside the window (both ends inclusive) with two
            added columns, "Time (s)" (seconds relative to the halt) and
            "Halt Time"; None if no row falls inside the window.
        """
        window_start = halt_time + pd.Timedelta(seconds=self.time_window_start)
        window_end = halt_time + pd.Timedelta(seconds=self.time_window_end)
        mask = (df.index >= window_start) & (df.index <= window_end)
        
        if not mask.any():
            return None
        
        window = df.loc[mask].copy()
        window["Time (s)"] = (window.index - halt_time).total_seconds()
        window["Halt Time"] = halt_time
        return window
    
    def separate_turns(self, aligned_df: pd.DataFrame) -> Tuple[List[pd.Timestamp], List[pd.Timestamp]]:
        """
        Separate halt events into left and right turns based on motor velocity.

        The mean "Motor_Velocity" inside ``PRE_EVENT_WINDOW`` decides the class:
        negative means left, positive means right. Events with no samples in the
        window, or with a mean of exactly zero or NaN, end up in neither list.
        
        Args:
            aligned_df: Aligned dataframe with all halt events
            
        Returns:
            Tuple of (left_turn_halts, right_turn_halts), each in order of first
            appearance in ``aligned_df``.
        """
        window_lo, window_hi = self.PRE_EVENT_WINDOW
        event_times = aligned_df["Time (s)"]
        # The time filter does not depend on the halt, so apply it once up front.
        pre_event = aligned_df[(event_times >= window_lo) & (event_times < window_hi)]

        left_turn_halts = []
        right_turn_halts = []
        
        for halt in aligned_df["Halt Time"].unique():
            velocities = pre_event.loc[pre_event["Halt Time"] == halt, "Motor_Velocity"]
            if velocities.empty:
                continue
                
            mean_velocity = velocities.mean()
            
            if mean_velocity < 0:  # Negative = left turn
                left_turn_halts.append(halt)
            elif mean_velocity > 0:  # Positive = right turn
                right_turn_halts.append(halt)
        
        return left_turn_halts, right_turn_halts
    
    def save_turn_data(self, aligned_df: pd.DataFrame, left_turns: List, right_turns: List, 
                      session_name: str, event_name: str, output_dir: Path) -> None:
        """Save separated turn data to CSV files only if turns are detected.

        Args:
            aligned_df: Aligned dataframe with all halt events
            left_turns: Halt times classified as left turns
            right_turns: Halt times classified as right turns
            session_name: Session identifier used in the file name
            event_name: Event type name used in the file name
            output_dir: Directory the CSV files are written to
        """
        for side, halts in (("left", left_turns), ("right", right_turns)):
            if not halts:
                print(f"No {side} turns detected - no CSV file saved")
                continue

            turn_df = aligned_df[aligned_df["Halt Time"].isin(halts)]
            csv_path = output_dir / f"{session_name}_{event_name}_{side}_turns.csv"
            turn_df.to_csv(csv_path, index=False, float_format='%.4f')
            print(f"Saved {len(halts)} {side} turns to {csv_path}")

    @staticmethod
    def _nearest_time_position(time_values: np.ndarray, target: float) -> Optional[int]:
        """Return the position of the value closest to ``target``.

        Returns None when ``time_values`` is empty or ``target`` lies outside its
        range. An exact match returns the position of its first occurrence.
        """
        if time_values.size == 0:
            return None
        if target < time_values.min() or target > time_values.max():
            return None
        return int(np.argmin(np.abs(time_values - target)))

    @staticmethod
    def _time_ticks(time_values: np.ndarray, step: int = 2) -> Tuple[List[int], List[str]]:
        """Compute x tick positions/labels at multiples of ``step`` seconds.

        Each multiple of ``step`` inside the covered time range is placed at the
        column whose time is closest to it, so ticks still appear when no column
        is exactly divisible by ``step``. If two multiples map to the same
        column, only the first is kept.

        Returns:
            (positions, labels) as parallel lists.
        """
        if time_values.size == 0:
            return [], []

        first = math.ceil(time_values.min() / step)
        last = math.floor(time_values.max() / step)

        positions: List[int] = []
        labels: List[str] = []
        for multiple in range(first, last + 1):
            tick_time = multiple * step
            position = int(np.argmin(np.abs(time_values - tick_time)))
            if position in positions:
                continue
            positions.append(position)
            labels.append(f"{int(tick_time)}")
        return positions, labels

    @staticmethod
    def _mean_sem(df: pd.DataFrame, column: str) -> pd.DataFrame:
        """Return per-"Time (s)" mean and SEM of ``column`` (columns 'mean', 'sem').

        Only the requested column is aggregated, so unrelated (possibly
        non-numeric) columns are never touched.
        """
        return df.groupby("Time (s)")[column].agg(['mean', 'sem'])
    
    def create_heatmap(self, pivot_data: pd.DataFrame, session_name: str, event_name: str, 
                      channel: str, save_path: Path, figsize: Tuple[int, int] = (10, 6)) -> pd.DataFrame:
        """
        Create and save normalized heatmap.

        Each row has its mean over ``PRE_EVENT_WINDOW`` subtracted; if no column
        falls in that window the data is plotted unchanged.
        
        Args:
            pivot_data: Pivoted data (events x time); columns are times in seconds
            session_name: Name of session
            event_name: Name of event type (currently not used in the figure)
            channel: Channel name (e.g., 'z_470')
            save_path: Path to save figure
            figsize: Figure size tuple
            
        Returns:
            Normalized data used for heatmap
        """
        # Baseline normalization
        window_lo, window_hi = self.PRE_EVENT_WINDOW
        baseline_cols = (pivot_data.columns >= window_lo) & (pivot_data.columns < window_hi)
        if baseline_cols.any():
            baseline_means = pivot_data.loc[:, baseline_cols].mean(axis=1)
            normalized_data = pivot_data.subtract(baseline_means, axis=0)
        else:
            normalized_data = pivot_data
        
        fig, ax = plt.subplots(figsize=figsize)
        try:
            sns.heatmap(
                normalized_data, 
                cmap="RdBu_r", 
                center=0, 
                ax=ax,
                cbar_kws={'label': f'Normalized {channel}'},
                rasterized=True
            )
            
            ax.set_title(f"Heatmap ({channel}) - {session_name}")
            ax.set_xlabel("Time (s)")
            ax.set_ylabel("Event")
            
            # Label at most roughly ten rows, numbered from 1
            n_events = len(normalized_data.index)
            y_positions = range(0, n_events, max(1, n_events // 10))
            y_labels = [str(i+1) for i in y_positions]
            ax.set_yticks(y_positions)
            ax.set_yticklabels(y_labels)
            
            time_values = np.asarray(normalized_data.columns, dtype=float)

            # Event line at the column closest to time 0 (exact match not required)
            zero_idx = self._nearest_time_position(time_values, 0.0)
            if zero_idx is not None:
                ax.axvline(zero_idx, linestyle='--', color='black', alpha=0.7)
            
            # X ticks every 2 s, snapped to the nearest column
            tick_indices, tick_labels = self._time_ticks(time_values, step=2)
            ax.set_xticks(tick_indices)
            ax.set_xticklabels(tick_labels, rotation=45)
            
            fig.tight_layout()
            fig.savefig(save_path, dpi=200, bbox_inches='tight')
        finally:
            # Close even when plotting/saving fails so figures do not pile up
            # across sessions (callers catch the exception and carry on).
            plt.close(fig)
        
        return normalized_data
    
    def create_line_collections(self, df: pd.DataFrame, channels: List[str]) -> Dict[str, "LineCollection"]:
        """Create line collections for individual traces.

        Args:
            df: Aligned dataframe (one block of rows per "Halt Time")
            channels: Column names to build a collection for

        Returns:
            Mapping of channel name to a LineCollection holding one line per
            halt event. Fluorescence channels use their configured colour,
            'Motor_Velocity' uses ``MOTOR_COLOR`` and anything else is gray.
        """
        # Split by event once instead of once per channel.
        per_event = [df[df["Halt Time"] == halt] for halt in df["Halt Time"].unique()]

        line_collections = {}
        for channel in channels:
            lines = [
                list(zip(subset["Time (s)"].values, subset[channel].values))
                for subset in per_event
            ]

            if channel in self.FLUORESCENCE_CHANNELS:
                color = self.FLUORESCENCE_CHANNELS[channel].get('color', 'gray')
            elif channel == 'Motor_Velocity':
                # Match the colour of the motor axis label/ticks.
                color = self.MOTOR_COLOR
            else:
                color = 'gray'
            line_collections[channel] = LineCollection(lines, colors=color, alpha=0.3, linewidths=1)
        
        return line_collections

    def _plot_mean_sem(self, ax: "plt.Axes", df: pd.DataFrame, column: str, *, color: str,
                       label: str, line_style: str = '-', linewidth: float = 2) -> None:
        """Draw the mean trace of ``column`` with a shaded ± SEM band on ``ax``."""
        stats = self._mean_sem(df, column)
        time_index = stats.index.values
        mean_vals = stats['mean']
        sem_vals = stats['sem']

        ax.plot(time_index, mean_vals, color=color, linestyle=line_style,
                linewidth=linewidth, label=label)
        ax.fill_between(time_index, mean_vals - sem_vals, mean_vals + sem_vals,
                        color=color, alpha=0.2)
    
    def add_mean_sem_plot(self, ax: "plt.Axes", df: pd.DataFrame, channels: List[str], 
                         turn_type: str, line_style: str = '-') -> None:
        """Add mean ± SEM traces to axis.

        Args:
            ax: Axis to draw on
            df: Aligned dataframe for one turn type
            channels: Fluorescence channels to plot (keys of FLUORESCENCE_CHANNELS)
            turn_type: Prefix for the legend label (e.g. "Left")
            line_style: Matplotlib line style for the mean trace

        Raises:
            KeyError: if a channel is not in FLUORESCENCE_CHANNELS.
        """
        for channel in channels:
            self._plot_mean_sem(
                ax, df, channel,
                color=self.FLUORESCENCE_CHANNELS[channel]['color'],
                label=f"{turn_type} {channel}",
                line_style=line_style,
                linewidth=2,
            )

    def _plot_turn_traces(self, ax: "plt.Axes", turn_df: pd.DataFrame,
                          channels: List[str], title: str) -> None:
        """Draw per-event fluorescence traces on ``ax`` and motor velocity on a twin axis."""
        collections = self.create_line_collections(turn_df, channels + ['Motor_Velocity'])

        for channel in channels:
            ax.add_collection(collections[channel])

        # Motor velocity lives on a secondary y-axis
        motor_ax = ax.twinx()
        motor_ax.add_collection(collections['Motor_Velocity'])
        motor_ax.set_ylabel("Motor Velocity", color=self.MOTOR_COLOR)
        motor_ax.tick_params(axis='y', labelcolor=self.MOTOR_COLOR)
        motor_ax.autoscale()

        ax.set_title(title)
        ax.set_ylabel("Fluorescence (z-score)")
        ax.axvline(0, linestyle='--', color='black', alpha=0.7)
        ax.autoscale()
    
    def create_summary_plot(self, aligned_df: pd.DataFrame, session_name: str, 
                          event_name: str, save_path: Path) -> None:
        """
        Create comprehensive summary plots comparing left vs right turns.

        The figure has three panels: individual left-turn traces, individual
        right-turn traces, and a mean ± SEM comparison. Nothing is saved when no
        event can be classified as a turn.
        
        Args:
            aligned_df: Aligned dataframe with all events
            session_name: Session identifier
            event_name: Event type name
            save_path: Path to save the plot
        """
        left_turns, right_turns = self.separate_turns(aligned_df)
        
        if not left_turns and not right_turns:
            print("No valid turns found for summary plot")
            return
        
        left_df = aligned_df[aligned_df["Halt Time"].isin(left_turns)]
        right_df = aligned_df[aligned_df["Halt Time"].isin(right_turns)]
        z_channels = ['z_470', 'z_560']
        # (legend prefix, data, line style) for the comparison panel
        turn_sets = (("Left", left_df, '--'), ("Right", right_df, '-'))
        
        fig, axes = plt.subplots(1, 3, figsize=(21, 6), sharex=True)
        try:
            # Panels 1 and 2 - individual traces per turn direction
            if not left_df.empty:
                self._plot_turn_traces(axes[0], left_df, z_channels,
                                       f'Left Turn Traces (n={len(left_turns)})')
            if not right_df.empty:
                self._plot_turn_traces(axes[1], right_df, z_channels,
                                       f'Right Turn Traces (n={len(right_turns)})')
            
            # Panel 3 - Mean ± SEM comparison
            ax3 = axes[2]
            for turn_type, turn_df, line_style in turn_sets:
                if not turn_df.empty:
                    self.add_mean_sem_plot(ax3, turn_df, z_channels, turn_type, line_style)
            
            # Motor velocity comparison on a secondary axis
            ax3_motor = ax3.twinx()
            for turn_type, turn_df, line_style in turn_sets:
                if not turn_df.empty:
                    self._plot_mean_sem(
                        ax3_motor, turn_df, "Motor_Velocity",
                        color=self.MOTOR_COLOR,
                        label=f"{turn_type} Motor",
                        line_style=line_style,
                        linewidth=1.5,
                    )
            
            ax3_motor.set_ylabel("Motor Velocity", color=self.MOTOR_COLOR)
            ax3_motor.tick_params(axis='y', labelcolor=self.MOTOR_COLOR)
            ax3_motor.axhline(0, linestyle='--', color='gray', alpha=0.5)
            ax3_motor.legend(loc='upper right')
            
            ax3.axvline(0, linestyle='--', color='black', alpha=0.7)
            ax3.set_xlabel("Time (s)")
            ax3.set_ylabel("Fluorescence (z-score)")
            ax3.set_title("Mean ± SEM Comparison")
            ax3.legend(loc='upper left')
            
            # Format all x-axes
            for ax in axes:
                ax.set_xlabel("Time (s)")
            
            fig.suptitle(f"{session_name} - {event_name}")
            fig.tight_layout()
            fig.savefig(save_path, dpi=200, bbox_inches='tight')
        finally:
            # Always release the figure, also when plotting or saving fails.
            plt.close(fig)

        # Kept from the original: switches matplotlib interactive mode off globally.
        plt.ioff()
    
    def process_session(self, data_path: Path, data: Dict[str, Any], 
                       variables: Dict[str, Any], event_name: str = "halt") -> None:
        """
        Process a complete session of data.

        Aligns every halt, writes the aligned CSV (and per-direction CSVs) to
        ``<data_path.parent>/aligned_data``, and saves the summary plot and
        heatmaps next to the session directory. Any exception is printed with
        its traceback and swallowed so that other sessions can still run.
        
        Args:
            data_path: Path to the data directory
            data: Dictionary containing loaded data; reads
                "photometry_tracking_encoder_data" and, as a fallback for the
                session name, "mouse_name"
            variables: Dictionary containing analysis variables; reads
                "photodiode_halts" and optionally "session_name"
            event_name: Name of the event type being analyzed
        """
        print(f"\n--------- Processing: {data_path} ---------")
        
        try:
            # Extract data components
            df = data["photometry_tracking_encoder_data"]
            halts = variables["photodiode_halts"]
            
            # Get session name
            session_name = variables.get("session_name")
            if not session_name:
                mouse_name = data.get("mouse_name", "unknown_mouse")
                session_name = f"{mouse_name}_{data_path.stem}"
                print(f"No session_name found, using: {session_name}")
            
            print(f"Aligning {len(halts)} events for session '{session_name}'")
            
            # Process all halt events, skipping those with no data in the window
            aligned_data = []
            for halt_time in halts:
                window_data = self.process_aligned_data(df, halt_time)
                if window_data is not None:
                    aligned_data.append(window_data)
            
            if not aligned_data:
                print(f"No aligned data generated for {session_name}, skipping")
                return
            
            # Combine all aligned data
            aligned_df = pd.concat(aligned_data, ignore_index=True)
            
            # Create output directory
            aligned_dir = data_path.parent / "aligned_data"
            aligned_dir.mkdir(exist_ok=True)
            
            # Save main aligned data
            aligned_file = aligned_dir / f"{session_name}_{event_name}_aligned.csv"
            aligned_df.to_csv(aligned_file, index=False, float_format='%.4f')
            print(f"Saved aligned data to {aligned_file}")
            
            # Separate and save turn data
            left_turns, right_turns = self.separate_turns(aligned_df)
            self.save_turn_data(aligned_df, left_turns, right_turns, 
                              session_name, event_name, aligned_dir)
            
            # Create summary plot
            print(f"📈 Creating plots for {session_name}")
            summary_path = data_path.parent / f"{session_name}_{event_name}.pdf"
            self.create_summary_plot(aligned_df, session_name, event_name, summary_path)
            print(f"Saved summary plot to {summary_path}")
            
            # Create heatmaps if sufficient data
            unique_halts = aligned_df["Halt Time"].nunique()
            if unique_halts > 1:
                self._create_all_heatmaps(aligned_df, session_name, event_name, data_path)
            else:
                print("Insufficient data for heatmaps (need >1 event)")
            
            # Cleanup
            del aligned_data, aligned_df
            gc.collect()
            
        except Exception as e:
            # Best effort: report and continue with the next session.
            print(f"ERROR processing {data_path}: {str(e)}")
            traceback.print_exc()
            gc.collect()
    
    def _create_all_heatmaps(self, aligned_df: pd.DataFrame, session_name: str, 
                           event_name: str, data_path: Path) -> None:
        """Create all heatmaps for different channels.

        One PDF per channel is written into ``data_path.parent``. A failure for
        one channel is printed and does not stop the remaining channels.
        """
        heatmap_channels = ['z_470', 'z_560', 'dfF_470', 'dfF_560']
        
        for channel in heatmap_channels:
            try:
                # Events as rows, relative time as columns
                pivot_data = aligned_df.pivot_table(
                    index="Halt Time", 
                    columns="Time (s)", 
                    values=channel, 
                    aggfunc='first'
                )
                
                # Create heatmap
                heatmap_path = data_path.parent / f"{session_name}_{event_name}_heatmap_{channel}.pdf"
                self.create_heatmap(pivot_data, session_name, event_name, channel, heatmap_path)
                print(f"Saved {channel} heatmap")
                
                # Cleanup
                del pivot_data
                
            except Exception as e:
                print(f"Error creating {channel} heatmap: {e}")
        
        gc.collect()


def main(data_paths: List[Path], loaded_data: Dict, data_path_variables: Dict,
         event_name: str = "halt", time_window: Tuple[float, float] = (-5, 10)):
    """
    Run the photometry analysis over every session path in turn.

    One ``PhotometryAnalyzer`` is built from ``time_window`` and reused for all
    paths. A path is handed to ``process_session`` only when it is a key of
    both ``data_path_variables`` and ``loaded_data``; otherwise a skip message
    is printed and the loop moves on.

    Args:
        data_paths: Session directories to process, in order
        loaded_data: Mapping from session path to its loaded data
        data_path_variables: Mapping from session path to its analysis variables
        event_name: Event label forwarded to ``process_session`` (default: "halt")
        time_window: (start, end) in seconds relative to the event, passed to
            the analyzer constructor
    """
    session_analyzer = PhotometryAnalyzer(time_window)

    for position, session_path in enumerate(data_paths, start=1):
        # Banner announcing which session is being handled.
        print(f"\n{'='*60}")
        print(f"Processing {position}/{len(data_paths)}: {session_path.name}")
        print(f"{'='*60}")

        if session_path not in data_path_variables:
            print(f"Skipping {session_path} - no analysis data found")
        elif session_path not in loaded_data:
            print(f"Skipping {session_path} - no loaded data found")
        else:
            session_analyzer.process_session(
                session_path,
                loaded_data[session_path],
                data_path_variables[session_path],
                event_name,
            )

    print("\n✅ Finished processing all data paths.")

In [None]:
# Run the alignment pipeline on every loaded session, using the event window
# bounds defined earlier in the notebook.
time_window = (time_window_start, time_window_end)
main(
    data_paths,
    loaded_data,
    data_path_variables,
    event_name=event_name,
    time_window=time_window,
)

In [None]:
#----oldcode (not executed): earlier draft of the baselining/plotting routine, kept commented out for reference----
#  def baseline_aligned_data(aligned_df, baseline_window, mouse_name):
#     # ---------------- Baseline Correction ----------------
#     baseline_df = aligned_df[
#         (aligned_df["Time (s)"] >= baseline_window[0]) & 
#         (aligned_df["Time (s)"] <= baseline_window[1])
#     ].groupby("Halt Time").mean()

#     for signal_name in ["z_470", "z_560", "Motor_Velocity", "Velocity_0X", "Velocity_0Y"]:
#         aligned_df[f"{signal_name}_Baseline"] = aligned_df[signal_name] - aligned_df["Halt Time"].map(baseline_df[signal_name])

#     # ---------------- Mean and SEM ----------------
#     mean_baseline_df = aligned_df.groupby("Time (s)").mean()
#     sem_baseline_df = aligned_df.groupby("Time (s)").sem()

#     def get_symmetric_ylim(mean_data, sem_data):
#         max_abs_value = max(
#             abs(mean_data).max() + sem_data.max(),
#             abs(mean_data).min() - sem_data.min()
#         )
#         return (-max_abs_value, max_abs_value)

#     # ---------------- Plotting ----------------
#     fig, ax = plt.subplots(figsize=(plot_width, 6))

#     ax.plot(mean_baseline_df.index, mean_baseline_df["Photodiode_int"], color='grey', alpha=0.8)
#     ax.fill_between(mean_baseline_df.index,
#                     mean_baseline_df["Photodiode_int"] - sem_baseline_df["Photodiode_int"],
#                     mean_baseline_df["Photodiode_int"] + sem_baseline_df["Photodiode_int"],
#                     color='grey', alpha=0.2)

#     ax.set_xlabel('Time (s) relative to halt')
#     ax.set_ylabel('Photodiode', color='grey')
#     ax.set_title(f'Baselined Mean & SEM of All Signals - {mouse_name}')

#     # z_470 and z_560
#     ax2 = ax.twinx()
#     ax2.plot(mean_baseline_df.index, mean_baseline_df["z_470_Baseline"], color='green', alpha=0.8)
#     ax2.fill_between(mean_baseline_df.index,
#                      mean_baseline_df["z_470_Baseline"] - sem_baseline_df["z_470_Baseline"],
#                      mean_baseline_df["z_470_Baseline"] + sem_baseline_df["z_470_Baseline"],
#                      color='green', alpha=0.2)
#     ax2.plot(mean_baseline_df.index, mean_baseline_df["z_560_Baseline"], color='red', alpha=0.8)
#     ax2.fill_between(mean_baseline_df.index,
#                      mean_baseline_df["z_560_Baseline"] - sem_baseline_df["z_560_Baseline"],
#                      mean_baseline_df["z_560_Baseline"] + sem_baseline_df["z_560_Baseline"],
#                      color='red', alpha=0.2)
#     ax2.set_ylabel('Fluorescence (z-score, red 560nm)', color='green')
#     ax2.set_ylim(get_symmetric_ylim(
#         pd.concat([mean_baseline_df["z_470_Baseline"], mean_baseline_df["z_560_Baseline"]]),
#         pd.concat([sem_baseline_df["z_470_Baseline"], sem_baseline_df["z_560_Baseline"]])
#     ))
#     ax2.yaxis.label.set_color('green')

#     # Motor velocity
#     ax3 = ax.twinx()
#     ax3.spines['right'].set_position(('outward', 50))
#     ax3.plot(mean_baseline_df.index, mean_baseline_df["Motor_Velocity_Baseline"], color='#00008B', alpha=0.8)
#     ax3.fill_between(mean_baseline_df.index,
#                      mean_baseline_df["Motor_Velocity_Baseline"] - sem_baseline_df["Motor_Velocity_Baseline"],
#                      mean_baseline_df["Motor_Velocity_Baseline"] + sem_baseline_df["Motor_Velocity_Baseline"],
#                      color='#00008B', alpha=0.2)
#     ax3.set_ylabel('Motor Velocity (deg/s²)', color='#00008B')
#     ax3.set_ylim(get_symmetric_ylim(mean_baseline_df["Motor_Velocity_Baseline"], sem_baseline_df["Motor_Velocity_Baseline"]))
#     ax3.yaxis.label.set_color('#00008B')

#     # Running velocity (Velocity_0X)
#     ax4 = ax.twinx()
#     ax4.spines['right'].set_position(('outward', 100))
#     ax4.plot(mean_baseline_df.index, mean_baseline_df["Velocity_0X_Baseline"] * 1000, color='orange', alpha=0.8)
#     ax4.fill_between(mean_baseline_df.index,
#                      (mean_baseline_df["Velocity_0X_Baseline"] - sem_baseline_df["Velocity_0X_Baseline"]) * 1000,
#                      (mean_baseline_df["Velocity_0X_Baseline"] + sem_baseline_df["Velocity_0X_Baseline"]) * 1000,
#                      color='orange', alpha=0.2)
#     ax4.set_ylabel('Running velocity (mm/s²) WRONG SCALE?', color='orange')
#     ax4.set_ylim(get_symmetric_ylim(mean_baseline_df["Velocity_0X_Baseline"] * 1000, sem_baseline_df["Velocity_0X_Baseline"] * 1000))
#     ax4.yaxis.label.set_color('orange')

#     # Turning velocity (Velocity_0Y)
#     ax5 = ax.twinx()
#     ax5.spines['right'].set_position(('outward', 150))
#     ax5.plot(mean_baseline_df.index, mean_baseline_df["Velocity_0Y_Baseline"], color='#4682B4', alpha=0.8)
#     ax5.fill_between(mean_baseline_df.index,
#                      mean_baseline_df["Velocity_0Y_Baseline"] - sem_baseline_df["Velocity_0Y_Baseline"],
#                      mean_baseline_df["Velocity_0Y_Baseline"] + sem_baseline_df["Velocity_0Y_Baseline"],
#                      color='#4682B4', alpha=0.2)
#     ax5.set_ylabel('Turning velocity (deg/s²) WRONG SCALE?', color='#4682B4')
#     ax5.set_ylim(get_symmetric_ylim(mean_baseline_df["Velocity_0Y_Baseline"], sem_baseline_df["Velocity_0Y_Baseline"]))
#     ax5.yaxis.label.set_color('#4682B4')

#     fig.tight_layout()

#     # Save the figure
#     try:          
#         figure_file = data_path.parent / f"{session_name}_{event_name}_baselined.pdf"
#         fig.savefig(figure_file, dpi=1200, bbox_inches='tight')
#         print(f"✅ Saved figure to {figure_file}")
#     except Exception as e:
#         print(f"⚠️ ERROR saving figure: {str(e)}")    

#     plt.close(fig)
#     return fig


In [None]:
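#----oldcode (not executed): previous versions of process_aligned_data_folders / baseline_aligned_data_simple; the live BASELINING cell below redefines both----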
# #--------baselining definitions--------------------------------------------
# def process_aligned_data_folders(data_dirs, baseline_window, event_name=event_name, plot_width=12, create_plots=True):
#     """
#     Process all aligned_data folders and generate baseline plots.
    
#     Parameters:
#     -----------
#     data_dirs : list
#         List of Path objects pointing to your main data directories
#     baseline_window : tuple
#         Tuple of (start_time, end_time) for baseline window
#     event_name : str
#         Event name for file naming (default: "halt")
#     plot_width : int
#         Width of the plot in inches
#     create_plots : bool
#         Whether to create and save plots (default: True)
#     """
    
#     results = {
#         'processed': [],
#         'errors': [],
#         'total_folders': 0
#     }
    
#     # Find all aligned_data folders
#     aligned_folders = []
#     for data_dir in data_dirs:
#         print(f"Searching in: {data_dir}")
#         # Find all aligned_data folders recursively
#         found_folders = list(data_dir.rglob("aligned_data"))
#         aligned_folders.extend(found_folders)
#         print(f"  Found {len(found_folders)} aligned_data folders")
    
#     results['total_folders'] = len(aligned_folders)
#     print(f"\nTotal aligned_data folders found: {len(aligned_folders)}")
    
#     for aligned_folder in aligned_folders:
#         try:
#             print(f"\n📁 Processing folder: {aligned_folder}")
            
#             # Find only the original aligned CSV files (exclude already processed baselined files)
#             all_csv_files = list(aligned_folder.glob("*.csv"))
#             csv_files = [f for f in all_csv_files if not f.name.endswith('_baselined_data.csv')]
            
#             if not csv_files:
#                 print(f"  ⚠️  No original aligned CSV files found in {aligned_folder}")
#                 print(f"  Available files: {[f.name for f in all_csv_files]}")
#                 results['errors'].append({
#                     'folder': str(aligned_folder),
#                     'error': 'No original aligned CSV files found',
#                     'status': 'skipped'
#                 })
#                 continue
#             print(f"  Found {len(csv_files)} aligned CSV files to process")
            
#             for csv_file in csv_files:
#                 try:
#                     # Check if the CSV file name matches the event name
#                     if event_name not in csv_file.name:
#                         print(f"    ⚠️ Skipping {csv_file.name} as it does not match the event name '{event_name}'")
#                         continue
                    
#                     print(f"    📊 Processing: {csv_file.name}")
                    
#                     # Load the data
#                     aligned_df = pd.read_csv(csv_file)
                    
#                     # Create aligned DataFrames for left and right turns
#                     left_turns_csv = csv_file.with_name(csv_file.stem.replace("aligned", "left_turns") + ".csv")
#                     right_turns_csv = csv_file.with_name(csv_file.stem.replace("aligned", "right_turns") + ".csv")
                    
#                     if left_turns_csv.exists():
#                         print(f"    📂 Found left turns CSV: {left_turns_csv.name}")
#                         left_turns_df = pd.read_csv(left_turns_csv)
#                     else:
#                         print(f"    ⚠️ Left turns CSV not found: {left_turns_csv.name}")
#                         left_turns_df = None
                    
#                     if right_turns_csv.exists():
#                         print(f"    📂 Found right turns CSV: {right_turns_csv.name}")
#                         right_turns_df = pd.read_csv(right_turns_csv)
#                     else:
#                         print(f"    ⚠️ Right turns CSV not found: {right_turns_csv.name}")
#                         right_turns_df = None
                    
#                     # Clean up the mouse name (remove extra suffixes)
#                     mouse_name = csv_file.stem.replace('_aligned', '').replace('_downsampled_data_Apply halt: 2s', '').split('_')[0]
#                     # Get session name from the folder structure
#                     session_name = aligned_folder.parent.name
#                     # Check if required columns exist
#                     required_columns = ["Time (s)", "Halt Time", "z_470", "z_560", "Motor_Velocity", 
#                                       "Velocity_0X", "Velocity_0Y", "Photodiode_int"]
#                     missing_columns = [col for col in required_columns if col not in aligned_df.columns]
                    
#                     if missing_columns:
#                         print(f"    ⚠️  Missing columns: {missing_columns}")
#                         print(f"    Available columns: {list(aligned_df.columns)}")
#                         raise ValueError(f"Missing required columns: {missing_columns}")
                    
#                     # Process the data and create plot
#                     fig = baseline_aligned_data_simple(
#                         aligned_df=aligned_df,
#                         baseline_window=baseline_window,
#                         mouse_name=mouse_name,
#                         session_name=session_name,
#                         event_name=event_name,
#                         output_folder=aligned_folder,
#                         csv_file=csv_file,
#                         plot_width=plot_width,
#                         create_plots=create_plots
#                     )
                    
#                     results['processed'].append({
#                         'file': str(csv_file),
#                         'mouse_name': mouse_name,
#                         'session_name': session_name,
#                         'folder': str(aligned_folder),
#                         'status': 'success'
#                     })
                    
#                 except Exception as e:
#                     error_info = {
#                         'file': str(csv_file),
#                         'error': str(e),
#                         'status': 'failed'
#                     }
#                     results['errors'].append(error_info)
#                     print(f"    ❌ Error processing {csv_file.name}: {str(e)}")
                    
#         except Exception as e:
#             error_info = {
#                 'folder': str(aligned_folder),
#                 'error': str(e),
#                 'status': 'failed'
#             }
#             results['errors'].append(error_info)
#             print(f"❌ Error accessing {aligned_folder}: {str(e)}")
    
#     # Print summary
#     print(f"\n{'='*60}")
#     print(f"PROCESSING SUMMARY")
#     print(f"{'='*60}")
#     print(f"Total aligned_data folders: {results['total_folders']}")
#     print(f"Successfully processed files: {len(results['processed'])}")
#     print(f"Errors encountered: {len(results['errors'])}")
    
#     if results['errors']:
#         print(f"\nErrors:")
#         for error in results['errors']:
#             if 'file' in error:
#                 print(f"  - File {Path(error['file']).name}: {error['error']}")
#             else:
#                 print(f"  - Folder {Path(error['folder']).name}: {error['error']}")
    
#     if results['processed']:
#         print(f"\nSuccessfully processed:")
#         for proc in results['processed']:
#             print(f"  - {proc['mouse_name']} in {Path(proc['folder']).parent.name}")
    
#     return results

# def baseline_aligned_data_simple(aligned_df, baseline_window, mouse_name, session_name, event_name, output_folder, csv_file, plot_width=12, create_plots=True):
# # def baseline_aligned_data_simple(aligned_df, baseline_window, mouse_name, session_name, event_name, output_folder, plot_width=12, create_plots=True):
#     """
#     Simple baseline correction and plotting function.
    
#     Parameters:
#     -----------
#     create_plots : bool
#         Whether to create and save plots (default: True)
#     """
    
#     # # Check if baseline file already exists and skip if it does
#     # baseline_data_file = output_folder / f"{mouse_name}_{event_name}_baselined_data.csv"
#     # if baseline_data_file.exists():
#     #     print(f"      ⚠️  Baseline file already exists, skipping: {baseline_data_file.name}")
#     #     return None

#     # # Check if the file corresponds to left or right turn apply halt events
#     # if "left_turns" in csv_file.name or "right_turns" in csv_file.name:
#     #     print(f"      📂 Processing left/right turn apply halt file: {csv_file.name}")
#     # else:
#     #     print(f"      ⚠️  Skipping eft/right turn apply halt file: {csv_file.name}")
#     #     return None
#     print(f"      🔄 Performing baseline correction...")

#     # ---------------- Baseline Correction ----------------
#     # Make a copy to avoid modifying the original data
#     aligned_df_copy = aligned_df.copy()
#     left_turns_file = align
    
#     baseline_df = aligned_df_copy[
#         (aligned_df_copy["Time (s)"] >= baseline_window[0]) & 
#         (aligned_df_copy["Time (s)"] <= baseline_window[1])
#     ].groupby("Halt Time").mean(numeric_only=True)

#     baseline_df_left_turns = aligned_df_copy[
#         (aligned_df_copy["Time (s)"] >= baseline_window[0]) & 
#         (aligned_df_copy["Time (s)"] <= baseline_window[1])
#     ].groupby("Halt Time").mean(numeric_only=True)

#     # Create baseline-corrected columns
#     for signal_name in ["z_470", "z_560", "Motor_Velocity", "Velocity_0X", "Velocity_0Y"]:
#         if signal_name in aligned_df_copy.columns:
#             aligned_df_copy[f"{signal_name}_Baseline"] = aligned_df_copy[signal_name] - aligned_df_copy["Halt Time"].map(baseline_df[signal_name])
#         else:
#             print(f"      ⚠️  Column {signal_name} not found, skipping...")

#     # Define the baseline data file path
#     # Define the baseline data file path
#     baseline_data_file = output_folder / f"{mouse_name}_{event_name}_baselined_data.csv"
#     left_turns_file = output_folder / f"{mouse_name}_{event_name}_left_turns_baselined_data.csv"
#     right_turns_file = output_folder / f"{mouse_name}_{event_name}_right_turns_baselined_data.csv"

#     # Save the baseline-corrected data
#     aligned_df_copy.to_csv(baseline_data_file, index=False)
#     print(f"      💾 Saved baseline data to: {baseline_data_file.name}")

#     # ---------------- Mean and SEM ----------------
#     # Select only numeric columns for aggregation
#     numeric_columns = aligned_df_copy.select_dtypes(include=['number']).columns
#     mean_baseline_df = aligned_df_copy.groupby("Time (s)")[numeric_columns].mean()
#     sem_baseline_df = aligned_df_copy.groupby("Time (s)")[numeric_columns].sem()

#     def get_symmetric_ylim(mean_data, sem_data):
#         max_abs_value = max(
#             abs(mean_data).max() + sem_data.max(),
#             abs(mean_data).min() - sem_data.min()
#         )
#         return (-max_abs_value, max_abs_value)

#     print(f"      📊 Creating plot...")

#     # ---------------- Plotting ----------------
#     fig, ax = plt.subplots(figsize=(plot_width, 6))

#     # Photodiode
#     ax.plot(mean_baseline_df.index, mean_baseline_df["Photodiode_int"], color='grey', alpha=0.8, linewidth=2)
#     ax.fill_between(mean_baseline_df.index,
#                     mean_baseline_df["Photodiode_int"] - sem_baseline_df["Photodiode_int"],
#                     mean_baseline_df["Photodiode_int"] + sem_baseline_df["Photodiode_int"],
#                     color='grey', alpha=0.2)

#     ax.set_xlabel('Time (s) relative to halt')
#     ax.set_ylabel('Photodiode', color='grey')
#     ax.set_title(f'Baselined Signals - {mouse_name} ({session_name})')

#     # z_470 and z_560 (Fluorescence)
#     ax2 = ax.twinx()
#     ax2.plot(mean_baseline_df.index, mean_baseline_df["z_470_Baseline"], color='green', alpha=0.8, linewidth=2, label='470nm')
#     ax2.fill_between(mean_baseline_df.index,
#                      mean_baseline_df["z_470_Baseline"] - sem_baseline_df["z_470_Baseline"],
#                      mean_baseline_df["z_470_Baseline"] + sem_baseline_df["z_470_Baseline"],
#                      color='green', alpha=0.2)
#     ax2.plot(mean_baseline_df.index, mean_baseline_df["z_560_Baseline"], color='red', alpha=0.8, linewidth=2, label='560nm')
#     ax2.fill_between(mean_baseline_df.index,
#                      mean_baseline_df["z_560_Baseline"] - sem_baseline_df["z_560_Baseline"],
#                      mean_baseline_df["z_560_Baseline"] + sem_baseline_df["z_560_Baseline"],
#                      color='red', alpha=0.2)
#     ax2.set_ylabel('Fluorescence (z-score)', color='green')
#     ax2.set_ylim(get_symmetric_ylim(
#         pd.concat([mean_baseline_df["z_470_Baseline"], mean_baseline_df["z_560_Baseline"]]),
#         pd.concat([sem_baseline_df["z_470_Baseline"], sem_baseline_df["z_560_Baseline"]])
#     ))
#     ax2.yaxis.label.set_color('green')

#     # Motor velocity
#     ax3 = ax.twinx()
#     ax3.spines['right'].set_position(('outward', 50))
#     ax3.plot(mean_baseline_df.index, mean_baseline_df["Motor_Velocity_Baseline"], color='#00008B', alpha=0.8, linewidth=2)
#     ax3.fill_between(mean_baseline_df.index,
#                      mean_baseline_df["Motor_Velocity_Baseline"] - sem_baseline_df["Motor_Velocity_Baseline"],
#                      mean_baseline_df["Motor_Velocity_Baseline"] + sem_baseline_df["Motor_Velocity_Baseline"],
#                      color='#00008B', alpha=0.2)
#     ax3.set_ylabel('Motor Velocity (deg/s²)', color='#00008B')
#     ax3.set_ylim(get_symmetric_ylim(mean_baseline_df["Motor_Velocity_Baseline"], sem_baseline_df["Motor_Velocity_Baseline"]))
#     ax3.yaxis.label.set_color('#00008B')

#     # Running velocity (Velocity_0X)
#     ax4 = ax.twinx()
#     ax4.spines['right'].set_position(('outward', 100))
#     ax4.plot(mean_baseline_df.index, mean_baseline_df["Velocity_0X_Baseline"] * 1000, color='orange', alpha=0.8, linewidth=2)
#     ax4.fill_between(mean_baseline_df.index,
#                      (mean_baseline_df["Velocity_0X_Baseline"] - sem_baseline_df["Velocity_0X_Baseline"]) * 1000,
#                      (mean_baseline_df["Velocity_0X_Baseline"] + sem_baseline_df["Velocity_0X_Baseline"]) * 1000,
#                      color='orange', alpha=0.2)
#     ax4.set_ylabel('Running velocity (mm/s²)', color='orange')
#     ax4.set_ylim(get_symmetric_ylim(mean_baseline_df["Velocity_0X_Baseline"] * 1000, sem_baseline_df["Velocity_0X_Baseline"] * 1000))
#     ax4.yaxis.label.set_color('orange')

#     # Turning velocity (Velocity_0Y)
#     ax5 = ax.twinx()
#     ax5.spines['right'].set_position(('outward', 150))
#     ax5.plot(mean_baseline_df.index, mean_baseline_df["Velocity_0Y_Baseline"], color='#4682B4', alpha=0.8, linewidth=2)
#     ax5.fill_between(mean_baseline_df.index,
#                      mean_baseline_df["Velocity_0Y_Baseline"] - sem_baseline_df["Velocity_0Y_Baseline"],
#                      mean_baseline_df["Velocity_0Y_Baseline"] + sem_baseline_df["Velocity_0Y_Baseline"],
#                      color='#4682B4', alpha=0.2)
#     ax5.set_ylabel('Turning velocity (deg/s²)', color='#4682B4')
#     ax5.set_ylim(get_symmetric_ylim(mean_baseline_df["Velocity_0Y_Baseline"], sem_baseline_df["Velocity_0Y_Baseline"]))
#     ax5.yaxis.label.set_color('#4682B4')

#     # Add vertical line at event time (t=0)
#     ax.axvline(x=0, color='black', linestyle='--', alpha=0.5, linewidth=1)

#     fig.tight_layout()

#     figure_file = output_folder / f"{session_name}_{event_name}_baselined.pdf"

#     # Save the figure
#     fig.savefig(figure_file, format='pdf', bbox_inches='tight')
#     print(f"      💾 Saved plot to: {figure_file.name}")
#     plt.close(fig)
#     return fig


# # Example usage:
# if __name__ == "__main__":
#     # Process all aligned_data folders
#     results = process_aligned_data_folders(
#         data_dirs=data_dirs,
#         baseline_window=baseline_window,
#         event_name=event_name,
#         plot_width=plot_width,
#         create_plots=True  # Ensure this is set to True to save plots
#     )
    
#     print(f"\n🎉 Processing complete!")

In [None]:
# BASELINING
#----------------------------------------------------
def _find_aligned_data_folders(data_dirs):
    """Return every directory named ``aligned_data`` found recursively under ``data_dirs``."""
    aligned_folders = []
    for data_dir in data_dirs:
        print(f"Searching in: {data_dir}")
        # rglob matches files as well as directories; a plain file that happens
        # to be called "aligned_data" cannot contain CSVs, so keep directories only.
        found_folders = [p for p in data_dir.rglob("aligned_data") if p.is_dir()]
        aligned_folders.extend(found_folders)
        print(f"  Found {len(found_folders)} aligned_data folders")
    return aligned_folders


def _load_turn_csv(csv_file, side):
    """Load the ``<side>``-turns companion of an aligned CSV, or return None.

    The companion name is derived by replacing '_aligned' in the stem of
    ``csv_file`` with '_<side>_turns' (``side`` is 'left' or 'right').
    """
    turns_csv = csv_file.with_name(csv_file.stem.replace('_aligned', f'_{side}_turns') + ".csv")

    if turns_csv == csv_file:
        # The stem has no '_aligned', so the replace was a no-op and the derived
        # path is the main CSV itself. Loading it would pass the full aligned
        # data off as turn-specific data, so treat it as "no companion".
        print(f"    ⚠️ {side.capitalize()} turns CSV skipped: {csv_file.name} has no '_aligned' in its name")
        return None

    if turns_csv.exists():
        print(f"    📂 Found {side} turns CSV: {turns_csv.name}")
        return pd.read_csv(turns_csv)

    print(f"    ⚠️ {side.capitalize()} turns CSV not found: {turns_csv.name}")
    return None


def _print_baselining_summary(results):
    """Print the end-of-run report for a ``process_aligned_data_folders`` results dict."""
    print(f"\n{'='*60}")
    print(f"PROCESSING SUMMARY")
    print(f"{'='*60}")
    print(f"Total aligned_data folders: {results['total_folders']}")
    print(f"Successfully processed files: {len(results['processed'])}")
    print(f"Errors encountered: {len(results['errors'])}")

    if results['errors']:
        print(f"\nErrors:")
        for error in results['errors']:
            # File-level entries carry a 'file' key, folder-level entries a 'folder' key.
            if 'file' in error:
                print(f"  - File {Path(error['file']).name}: {error['error']}")
            else:
                print(f"  - Folder {Path(error['folder']).name}: {error['error']}")

    if results['processed']:
        print(f"\nSuccessfully processed:")
        for proc in results['processed']:
            print(f"  - {proc['mouse_name']} in {Path(proc['folder']).parent.name}")


def process_aligned_data_folders(data_dirs, baseline_window, event_name=event_name, plot_width=12, create_plots=True):
    """
    Process all aligned_data folders and run baseline correction on each aligned CSV.

    Every directory named ``aligned_data`` below ``data_dirs`` is scanned for
    CSV files. Files ending in '_baselined_data.csv', '_left_turns.csv' or
    '_right_turns.csv' are ignored, as are files whose name does not contain
    ``event_name``. Each remaining file is loaded, its left/right turn
    companion CSVs are loaded when they exist, the required columns are
    checked, and everything is handed to ``baseline_aligned_data_simple``.
    A failure on one file or folder is recorded and does not stop the run.

    Parameters:
    -----------
    data_dirs : list
        List of Path objects pointing to your main data directories
    baseline_window : tuple
        Tuple of (start_time, end_time) for baseline window
    event_name : str
        Event name used to select CSV files and for file naming. The default
        is whatever the notebook-level ``event_name`` variable held when this
        cell was executed, so that variable must be defined beforehand.
    plot_width : int
        Width of the plot in inches
    create_plots : bool
        Forwarded to ``baseline_aligned_data_simple`` (default: True)

    Returns:
    --------
    dict
        'processed': list of dicts (file, mouse_name, session_name, folder, status),
        'errors': list of dicts describing failed files/folders and folders
        skipped for lack of candidate CSVs,
        'total_folders': number of aligned_data directories found.
    """
    results = {
        'processed': [],
        'errors': [],
        'total_folders': 0
    }

    # Find all aligned_data folders
    aligned_folders = _find_aligned_data_folders(data_dirs)
    results['total_folders'] = len(aligned_folders)
    print(f"\nTotal aligned_data folders found: {len(aligned_folders)}")

    # Outputs of earlier runs and the turn-specific companions are not inputs.
    excluded_suffixes = ('_baselined_data.csv', '_left_turns.csv', '_right_turns.csv')
    required_columns = ["Time (s)", "Halt Time", "z_470", "z_560", "Motor_Velocity",
                        "Velocity_0X", "Velocity_0Y", "Photodiode_int"]

    for aligned_folder in aligned_folders:
        try:
            print(f"\n📁 Processing folder: {aligned_folder}")

            # Find only the original aligned CSV files
            all_csv_files = list(aligned_folder.glob("*.csv"))
            csv_files = [f for f in all_csv_files if not f.name.endswith(excluded_suffixes)]

            if not csv_files:
                print(f"  ⚠️  No original aligned CSV files found in {aligned_folder}")
                print(f"  Available files: {[f.name for f in all_csv_files]}")
                results['errors'].append({
                    'folder': str(aligned_folder),
                    'error': 'No original aligned CSV files found',
                    'status': 'skipped'
                })
                continue
            print(f"  Found {len(csv_files)} aligned CSV files to process")

            for csv_file in csv_files:
                try:
                    # Check if the CSV file name matches the event name
                    if event_name not in csv_file.name:
                        print(f"    ⚠️ Skipping {csv_file.name} as it does not match the event name '{event_name}'")
                        continue

                    print(f"    📊 Processing: {csv_file.name}")

                    # Load the data
                    aligned_df = pd.read_csv(csv_file)

                    # Companion files for left and right turns (None when absent)
                    left_turns_df = _load_turn_csv(csv_file, "left")
                    right_turns_df = _load_turn_csv(csv_file, "right")

                    # Clean up the mouse name (remove extra suffixes)
                    mouse_name = csv_file.stem.replace('_aligned', '').replace('_downsampled_data_Apply halt: 2s', '').split('_')[0]
                    # Get session name from the folder structure
                    session_name = aligned_folder.parent.name

                    # Check if required columns exist
                    missing_columns = [col for col in required_columns if col not in aligned_df.columns]

                    if missing_columns:
                        print(f"    ⚠️  Missing columns: {missing_columns}")
                        print(f"    Available columns: {list(aligned_df.columns)}")
                        raise ValueError(f"Missing required columns: {missing_columns}")

                    # Baseline the data (and plot); the return value is not needed here.
                    baseline_aligned_data_simple(
                        aligned_df=aligned_df,
                        left_turns_df=left_turns_df,
                        right_turns_df=right_turns_df,
                        baseline_window=baseline_window,
                        mouse_name=mouse_name,
                        session_name=session_name,
                        event_name=event_name,
                        output_folder=aligned_folder,
                        csv_file=csv_file,
                        plot_width=plot_width,
                        create_plots=create_plots
                    )

                    results['processed'].append({
                        'file': str(csv_file),
                        'mouse_name': mouse_name,
                        'session_name': session_name,
                        'folder': str(aligned_folder),
                        'status': 'success'
                    })

                except Exception as e:
                    # Best-effort batch run: record the failure and move on to the next file.
                    error_info = {
                        'file': str(csv_file),
                        'error': str(e),
                        'status': 'failed'
                    }
                    results['errors'].append(error_info)
                    print(f"    ❌ Error processing {csv_file.name}: {str(e)}")

        except Exception as e:
            error_info = {
                'folder': str(aligned_folder),
                'error': str(e),
                'status': 'failed'
            }
            results['errors'].append(error_info)
            print(f"❌ Error accessing {aligned_folder}: {str(e)}")

    _print_baselining_summary(results)

    return results

def baseline_aligned_data_simple(aligned_df, left_turns_df, right_turns_df, baseline_window, mouse_name, session_name, event_name, output_folder, csv_file, plot_width=12, create_plots=True):
    """
    Baseline-correct event-aligned data, save it as CSV and optionally plot mean ± SEM.

    For every event (rows sharing one "Halt Time" value) the mean of each signal
    inside ``baseline_window`` is subtracted from all samples of that event and
    stored in a new ``<signal>_Baseline`` column. The signals handled are
    z_470, z_560, Motor_Velocity, Velocity_0X and Velocity_0Y; a signal whose
    column is absent is skipped with a warning. Input dataframes are not modified.

    Files written to ``output_folder``:
      * ``{mouse_name}_{event_name}_baselined_data.csv``
      * ``{mouse_name}_{event_name}_left_turns_baselined_data.csv``  (if left_turns_df is given)
      * ``{mouse_name}_{event_name}_right_turns_baselined_data.csv`` (if right_turns_df is given)
      * ``{session_name}_{event_name}_baselined.pdf``                (if create_plots)

    Parameters:
    -----------
    aligned_df : pd.DataFrame
        Main aligned data; must contain "Time (s)" and "Halt Time" columns
    left_turns_df : pd.DataFrame or None
        Left turns data (same layout as aligned_df), baselined and saved only
    right_turns_df : pd.DataFrame or None
        Right turns data (same layout as aligned_df), baselined and saved only
    baseline_window : tuple
        Tuple of (start_time, end_time) for baseline window, both ends inclusive
    mouse_name : str
        Mouse name for file naming
    session_name : str
        Session name, used in the plot title and the PDF file name
    event_name : str
        Event name for file naming
    output_folder : Path or str
        Output folder path; created if it does not exist yet
    csv_file : Path
        Original CSV file path (currently unused, kept for call compatibility)
    plot_width : int
        Width of the plot in inches
    create_plots : bool
        Whether to create and save plots (default: True)

    Returns:
    --------
    matplotlib.figure.Figure or None
        The saved figure (already closed, so it is not displayed inline and does
        not accumulate in memory during batch runs), or None if create_plots is False.
    """

    print(f"      🔄 Performing baseline correction...")

    # Accept plain strings too, and make sure the destination exists before any write.
    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    def baseline_dataframe(df, baseline_window, mouse_name, event_name, output_folder, suffix=""):
        """Baseline one dataframe per event, save it as CSV and return the corrected copy."""
        # Work on a copy so the caller's dataframe stays untouched
        df_copy = df.copy()

        # Per-event mean of every numeric column inside the (inclusive) baseline window
        in_window = (
            (df_copy["Time (s)"] >= baseline_window[0]) &
            (df_copy["Time (s)"] <= baseline_window[1])
        )
        baseline_df = df_copy[in_window].groupby("Halt Time").mean(numeric_only=True)

        # Subtract each event's own baseline; events without samples in the window map to NaN
        for signal_name in ["z_470", "z_560", "Motor_Velocity", "Velocity_0X", "Velocity_0Y"]:
            if signal_name in df_copy.columns:
                df_copy[f"{signal_name}_Baseline"] = df_copy[signal_name] - df_copy["Halt Time"].map(baseline_df[signal_name])
            else:
                print(f"      ⚠️  Column {signal_name} not found in {suffix} data, skipping...")

        # The suffix distinguishes left/right turn files from the main file
        if suffix:
            baseline_data_file = output_folder / f"{mouse_name}_{event_name}_{suffix}_baselined_data.csv"
        else:
            baseline_data_file = output_folder / f"{mouse_name}_{event_name}_baselined_data.csv"

        df_copy.to_csv(baseline_data_file, index=False)
        print(f"      💾 Saved {suffix} baseline data to: {baseline_data_file.name}")

        return df_copy

    # ---------------- Baseline Correction ----------------
    # Main aligned data (also the only dataset that is plotted)
    aligned_df_baselined = baseline_dataframe(aligned_df, baseline_window, mouse_name, event_name, output_folder)

    # Left / right turn subsets are only baselined and saved
    if left_turns_df is not None:
        print(f"      🔄 Processing left turns data...")
        baseline_dataframe(left_turns_df, baseline_window, mouse_name, event_name, output_folder, "left_turns")

    if right_turns_df is not None:
        print(f"      🔄 Processing right turns data...")
        baseline_dataframe(right_turns_df, baseline_window, mouse_name, event_name, output_folder, "right_turns")

    # Nothing below is needed unless a figure was requested
    if not create_plots:
        return None

    print(f"      📊 Creating plot...")

    # ---------------- Mean and SEM across events (main aligned data) ----------------
    numeric_columns = aligned_df_baselined.select_dtypes(include=['number']).columns
    mean_baseline_df = aligned_df_baselined.groupby("Time (s)")[numeric_columns].mean()
    sem_baseline_df = aligned_df_baselined.groupby("Time (s)")[numeric_columns].sem()

    def has_trace(column):
        """True if `column` can be plotted; otherwise warn and return False."""
        if column in mean_baseline_df.columns:
            return True
        print(f"      ⚠️  Column {column} not available for plotting, skipping...")
        return False

    def get_symmetric_ylim(mean_data, sem_data):
        """Return (-m, m) with m = max|mean| + max SEM, or None if no finite non-zero extent."""
        largest_sem = sem_data.max()
        # SEM is NaN everywhere when there is only one event (ddof=1); plot without a band margin
        if pd.isna(largest_sem):
            largest_sem = 0.0
        max_abs_value = mean_data.abs().max() + largest_sem
        # NaN/inf limits make set_ylim raise, and zero limits are degenerate: let matplotlib autoscale
        if not np.isfinite(max_abs_value) or max_abs_value <= 0:
            return None
        return (-max_abs_value, max_abs_value)

    def apply_symmetric_ylim(axis, mean_data, sem_data):
        """Centre `axis` on zero when usable limits exist, else keep autoscaling."""
        limits = get_symmetric_ylim(mean_data, sem_data)
        if limits is not None:
            axis.set_ylim(limits)

    def plot_mean_sem(axis, column, color, scale=1, label=None):
        """Draw the mean line and the ± SEM band of `column` on `axis`."""
        mean_trace = mean_baseline_df[column]
        sem_trace = sem_baseline_df[column]
        axis.plot(mean_baseline_df.index, mean_trace * scale, color=color, alpha=0.8, linewidth=2, label=label)
        axis.fill_between(mean_baseline_df.index,
                          (mean_trace - sem_trace) * scale,
                          (mean_trace + sem_trace) * scale,
                          color=color, alpha=0.2)

    # ---------------- Plotting ----------------
    fig, ax = plt.subplots(figsize=(plot_width, 6))
    try:
        # Photodiode (not baselined) on the primary axis
        if has_trace("Photodiode_int"):
            plot_mean_sem(ax, "Photodiode_int", 'grey')

        ax.set_xlabel('Time (s) relative to halt')
        ax.set_ylabel('Photodiode', color='grey')
        ax.set_title(f'Baselined Signals - {mouse_name} ({session_name})')

        # z_470 and z_560 (Fluorescence) share one twin axis
        fluorescence_traces = [
            (column, color, label)
            for column, color, label in [
                ("z_470_Baseline", 'green', '470nm'),
                ("z_560_Baseline", 'red', '560nm'),
            ]
            if has_trace(column)
        ]
        if fluorescence_traces:
            ax2 = ax.twinx()
            for column, color, label in fluorescence_traces:
                plot_mean_sem(ax2, column, color, label=label)
            ax2.set_ylabel('Fluorescence (z-score)', color='green')
            apply_symmetric_ylim(
                ax2,
                pd.concat([mean_baseline_df[column] for column, _, _ in fluorescence_traces]),
                pd.concat([sem_baseline_df[column] for column, _, _ in fluorescence_traces])
            )
            ax2.yaxis.label.set_color('green')

        # One extra right-hand axis per velocity signal: (column, y label, colour, scale, spine offset).
        # Velocity_0X is scaled by 1000 for display (label says mm).
        # NOTE(review): labels read "/s²" although the quantities are named velocities — confirm units.
        velocity_axes = [
            ("Motor_Velocity_Baseline", 'Motor Velocity (deg/s²)', '#00008B', 1, 50),
            ("Velocity_0X_Baseline", 'Running velocity (mm/s²)', 'orange', 1000, 100),
            ("Velocity_0Y_Baseline", 'Turning velocity (deg/s²)', '#4682B4', 1, 150),
        ]
        for column, ylabel, color, scale, offset in velocity_axes:
            if not has_trace(column):
                continue
            velocity_ax = ax.twinx()
            velocity_ax.spines['right'].set_position(('outward', offset))
            plot_mean_sem(velocity_ax, column, color, scale=scale)
            velocity_ax.set_ylabel(ylabel, color=color)
            apply_symmetric_ylim(velocity_ax, mean_baseline_df[column] * scale, sem_baseline_df[column] * scale)
            velocity_ax.yaxis.label.set_color(color)

        # Add vertical line at event time (t=0)
        ax.axvline(x=0, color='black', linestyle='--', alpha=0.5, linewidth=1)

        fig.tight_layout()

        figure_file = output_folder / f"{session_name}_{event_name}_baselined.pdf"

        # Save the figure
        fig.savefig(figure_file, format='pdf', bbox_inches='tight')
        print(f"      💾 Saved plot to: {figure_file.name}")
    finally:
        # Always release the figure, even if plotting or saving failed, so batch runs do not leak memory
        plt.close(fig)

    return fig



In [None]:
# baseline usage:
# Runs the batch baseline step over every aligned_data folder found under `data_dirs`
# (set in the data-paths cell at the top of the notebook).
# NOTE(review): `baseline_window`, `event_name` and `plot_width` are not defined in this
# cell — presumably they come from an earlier configuration cell; confirm that cell has
# been run before executing this one.
if __name__ == "__main__":
    # Process all aligned_data folders; `results` is the summary dict returned by
    # process_aligned_data_folders (it holds 'processed', 'errors' and 'total_folders').
    results = process_aligned_data_folders(
        data_dirs=data_dirs,
        baseline_window=baseline_window,
        event_name=event_name,
        plot_width=plot_width,
        create_plots=True  # Ensure this is set to True to save plots
    )
    
    print(f"\n🎉 Processing complete!")