# Raw Data Processing Stage Analysis Notebook

Purpose: Load data at different intermediate processing points from `raw_data_processor.py`
for a specific subject and visualize selected sensors/modalities to verify the steps.


In [None]:
# %% Imports
import os
import sys
import yaml
import pandas as pd
import numpy as np
from datetime import timedelta, datetime, time
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from scipy.signal import butter, sosfiltfilt
from ipywidgets import interact, Dropdown, IntText, FloatText, Button, VBox, HBox, Output, DatePicker, Label
import logging
import re

# --- Add project root to path to import project modules ---
# Adjust this path if your script is located elsewhere relative to 'src'
# Find the directory this code lives in. `__file__` does not exist in
# interactive sessions (e.g. notebooks), so the working directory is used there.
try:
    script_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    script_dir = os.getcwd()

# The project root is taken to be the parent of that directory.
project_root = os.path.abspath(os.path.join(script_dir, os.pardir))

# Put the project root at the front of sys.path, but only once.
if project_root not in sys.path:
    print(f"Adding project root to path: {project_root}")
    sys.path.insert(0, project_root)

# --- Import project modules ---
try:
    # These are project-local modules: they only resolve when the directory that
    # contains them is importable (e.g. running from 'src' or 'src' on PYTHONPATH).
    import utils
    import config_loader
    # Processing-step functions from raw_data_processor that this notebook calls
    # to reproduce the pipeline one stage at a time.
    from raw_data_processor import (
        correct_timestamp_drift,  # applied per timestamp during drift correction
        process_file_numeric_time, # Needed by loaders
        data_loader_no_dir,
        data_loader_with_dir,
        select_data_loader,  # picks the loader for a given sensor name
        process_modality_duplicates,
        handle_missing_data_interpolation, # Kept for potential other uses, but removed from main flow
        modify_modality_names, # Added import
        butter_lowpass_sos,
        apply_filter_combined
    )
except ImportError as e:
    # A missing project module makes the rest of the notebook unusable:
    # explain how to fix the import path, then stop.
    print(f"Error importing project modules: {e}")
    print("Please ensure:")
    print(f"1. You are running this script/notebook from the '{os.path.basename(project_root)}/src' directory, OR")
    print(f"2. The project root directory '{project_root}' is in your PYTHONPATH.")
    # NOTE: sys.exit raises SystemExit, which also aborts the current notebook cell.
    sys.exit(1)
except Exception as e:
    # Any other error raised while the project modules execute at import time.
    print(f"An unexpected error occurred during imports: {e}")
    sys.exit(1)


# --- Basic Setup ---
# Root logger: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Default start date (year, month, day) and time of day (hour, minute, second)
# that the start date/time widgets further down are pre-filled with.
SYNC_Y, SYNC_Month, SYNC_D = 2024, 2, 7
SYNC_H, SYNC_M, SYNC_S = 11, 37, 0
# --- Configuration Loading ---
try:
    # Both files are expected directly under the project root determined earlier.
    config_file_path = os.path.join(project_root, 'config.yaml')
    sync_params_path = os.path.join(project_root, 'Sync_Parameters.yaml')

    if not os.path.exists(config_file_path):
        raise FileNotFoundError(f"Config file not found at {config_file_path}")
    if not os.path.exists(sync_params_path):
        raise FileNotFoundError(f"Sync parameters file not found at {sync_params_path}")

    config = config_loader.load_config(config_file_path)
    with open(sync_params_path, 'r') as f:
        # safe_load returns None for an empty document; use an empty mapping
        # instead so the later `.keys()` / `.get()` calls keep working.
        sync_params = yaml.safe_load(f) or {}
    if not isinstance(sync_params, dict):
        # The rest of the notebook looks subjects up by key, so anything but a
        # mapping is unusable; fail here with a clear message.
        raise TypeError(
            f"Sync parameters in {sync_params_path} must be a mapping, got {type(sync_params).__name__}"
        )

except FileNotFoundError as e:
    # Without both files nothing below can run.
    logging.error(e)
    sys.exit(1)
except Exception as e:
    logging.error(f"Error loading configuration: {e}", exc_info=True)
    sys.exit(1)

# --- Load Global Labels ---
global_labels_df = pd.DataFrame()  # Stays empty whenever labels cannot be loaded
try:
    # The labels path comes from config.yaml and is relative to the project root.
    global_labels_file = config.get('global_labels_file')
    if not global_labels_file:
        logging.warning("No 'global_labels_file' entry in config. Cannot display labels.")
    else:
        global_labels_path = os.path.join(project_root, global_labels_file)
        if not os.path.exists(global_labels_path):
            logging.warning(f"Global labels file not found at {global_labels_path}. Cannot display labels.")
        else:
            logging.info(f"Loading global labels from: {global_labels_path}")
            label_time_columns = ['Real_Start_Time', 'Real_End_Time']
            global_labels_df = pd.read_csv(global_labels_path)

            # Parse explicitly with errors='coerce': unparsable values become NaT
            # (read_csv's parse_dates would leave the column as plain strings instead).
            # Set dayfirst=True if dates are DD/MM, False if MM/DD.
            for time_col in label_time_columns:
                global_labels_df[time_col] = pd.to_datetime(
                    global_labels_df[time_col], errors='coerce', dayfirst=False
                )

            # Drop rows where time parsing failed
            rows_before = len(global_labels_df)
            global_labels_df.dropna(subset=label_time_columns, inplace=True)
            rows_after = len(global_labels_df)
            if rows_before > rows_after:
                logging.warning(f"Dropped {rows_before - rows_after} rows from labels file due to parsing errors in time columns.")

            # Ensure TZ is naive for comparison with plot axes
            for time_col in label_time_columns:
                global_labels_df[time_col] = global_labels_df[time_col].dt.tz_localize(None)
            logging.info(f"Loaded {len(global_labels_df)} valid global labels.")

except FileNotFoundError as e:
    logging.error(e)
    global_labels_df = pd.DataFrame()
    logging.warning("Proceeding without global labels for plotting.")
except Exception as e:
    logging.error(f"Error loading or parsing global labels: {e}", exc_info=True)
    global_labels_df = pd.DataFrame() # Ensure it's an empty df on error
    logging.warning("Proceeding without global labels for plotting.")

# --- Helper Function to Get Data at Stages ---

def get_sensor_data_at_stages(subject_id, sensor_name, start_unix_ts, duration_secs, config, sync_params):
    """
    Load one sensor's raw data and reproduce the preprocessing steps stage by stage
    for a time window, mirroring the preprocessing order of raw_data_processor.py.

    Every stage is computed on the requested window padded on both sides by a buffer
    (half the duration, at least 10 s) and is only cut back to the exact window at the
    end -- presumably so resampling/filtering edge effects fall outside the returned data.

    Args:
        subject_id (str): The subject ID (e.g., 'OutSense-608').
        sensor_name (str): The name of the sensor modality (e.g., 'mbient_imu_wc_accelerometer').
        start_unix_ts (float): The starting timestamp (UNIX seconds) for the analysis window.
        duration_secs (float): The duration of the analysis window in seconds.
        config (dict): Loaded configuration from config.yaml.
        sync_params (dict): Loaded synchronization parameters from Sync_Parameters.yaml.

    Returns:
        dict: One DataFrame per stage, sliced to [start, start + duration):
              'raw_initial', 'indexed_sorted', 'duplicates_handled',
              'columns_renamed', 'resampled', 'imputed', 'filtered'.
              'raw_initial' keeps its 'time' and 'corrected_time' columns; the other
              stages are indexed by corrected time.
              Returns None when the configuration is unusable, the raw data cannot be
              loaded, no data falls in the buffered window, or time correction/indexing
              fails. Later stages do not abort: duplicate handling, renaming, imputation
              and filtering fall back to their input on error, resampling to an empty frame.
    """
    import gc  # Only used here, to release the full raw recording early

    logging.info(f"Processing {subject_id} - {sensor_name} from {start_unix_ts} for {duration_secs}s")
    data_stages_buffered = {}

    # --- Configuration Values (validated before any path or file is touched) ---
    raw_data_parsing_config = config.get('raw_data_parsing_config') or {}
    sensor_settings = raw_data_parsing_config.get(sensor_name)
    if not sensor_settings:
        logging.error(f"No parsing config found for sensor '{sensor_name}' in config.yaml.")
        return None

    downsample_freq = config.get('downsample_freq', 25) # Default from notebook, main script uses 20
    if not isinstance(downsample_freq, (int, float)) or not downsample_freq > 0:
        logging.error(f"Invalid downsample_freq: {downsample_freq}")
        return None
    target_freq_interval = pd.Timedelta(seconds=1.0 / downsample_freq)

    raw_data_input_dir = config.get('raw_data_input_dir')
    if not raw_data_input_dir:
        # os.path.join would raise TypeError on None; report the real problem instead.
        logging.error("Missing 'raw_data_input_dir' in config.yaml.")
        return None
    subject_dir = os.path.join(project_root, raw_data_input_dir, subject_id)

    # `or {}` also covers YAML keys that are present but empty (parsed as None).
    subject_correction_params = sync_params.get(subject_id) or {}
    sensor_corr_params = subject_correction_params.get(sensor_name) or {}
    filter_params = config.get('filter_parameters') or {}
    highcut = filter_params.get('highcut_kinematic', 9.9)
    filter_order = filter_params.get('filter_order', 4)

    buffer_secs = max(duration_secs * 0.5, 10.0)

    # --- 1. Load Raw Data & Calculate Corrected Time ---
    logging.debug(f"Loading raw data for {sensor_name}...")
    try:
        loader = select_data_loader(sensor_name)
        df_raw_full_sensor = loader(subject_dir, sensor_name, sensor_settings)
    except Exception as e:
        logging.error(f"Error loading raw data for {sensor_name}: {e}", exc_info=True)
        return None

    if df_raw_full_sensor is None or df_raw_full_sensor.empty or 'time' not in df_raw_full_sensor.columns:
        logging.warning(f"No raw data loaded or 'time' column missing for {sensor_name}.")
        return None

    logging.debug("Calculating corrected timestamps for all raw data...")
    try:
        time_unit = sensor_corr_params.get('unit', 's')
        time_col_num = df_raw_full_sensor['time'].astype(float)
        # Work in seconds: values declared as milliseconds are scaled down first.
        time_col_num_shifted = time_col_num / 1000.0 if time_unit == 'ms' else time_col_num
        # Constant offset from the sync parameters (0 when none is configured).
        time_col_num_shifted = time_col_num_shifted + sensor_corr_params.get('shift', 0)
        drift_params = sensor_corr_params.get('drift')
        time_col_final_num = time_col_num_shifted.copy()

        if drift_params and all(k in drift_params for k in ['t0', 't1', 'drift_secs']):
            t0_dt = pd.to_datetime(drift_params['t0'], errors='coerce')
            t1_dt = pd.to_datetime(drift_params['t1'], errors='coerce')
            drift_secs = drift_params.get('drift_secs')
            if pd.notna(t0_dt) and pd.notna(t1_dt):
                t0_num = t0_dt.timestamp()
                t1_num = t1_dt.timestamp()
                time_col_final_num = time_col_num_shifted.apply(correct_timestamp_drift, args=(t0_num, t1_num, drift_secs))
            else:
                logging.warning(f"Invalid drift timestamps. Skipping drift correction.")
        else:
            logging.debug("No valid drift parameters. Skipping drift correction.")

        corrected_timestamps_dt = pd.to_datetime(time_col_final_num, unit='s', errors='coerce')
        df_raw_full_sensor['corrected_time'] = corrected_timestamps_dt
        # Rows whose corrected time could not be computed cannot be placed on the axis.
        df_raw_full_sensor.dropna(subset=['corrected_time'], inplace=True)
    except Exception as e:
        logging.error(f"Error during initial time correction: {e}", exc_info=True)
        return None

    # Filter by user window + buffer (on corrected_time)
    corrected_start_dt_buffered = pd.to_datetime(start_unix_ts - buffer_secs, unit='s')
    corrected_end_dt_buffered = pd.to_datetime(start_unix_ts + duration_secs + buffer_secs, unit='s')
    in_buffered_window = (
        (df_raw_full_sensor['corrected_time'] >= corrected_start_dt_buffered) &
        (df_raw_full_sensor['corrected_time'] <= corrected_end_dt_buffered)
    )
    df_raw_buffered_stage_data = df_raw_full_sensor[in_buffered_window].copy()

    # The full recording is no longer needed; free it before the heavier stages.
    del df_raw_full_sensor, in_buffered_window
    gc.collect()

    if df_raw_buffered_stage_data.empty:
        logging.warning(f"No raw data in buffered corrected time range for {sensor_name}.")
        return None
    data_stages_buffered['raw_initial'] = df_raw_buffered_stage_data

    # --- 2. Create Indexed & Sorted Data ---
    logging.debug("Creating indexed & sorted data stage...")
    try:
        # Use the already windowed raw data which has 'corrected_time'
        df_indexed_sorted_buffered = data_stages_buffered['raw_initial'].set_index('corrected_time').sort_index()
        # The original 'time' column is not sensor data; keep it out of later stages.
        if 'time' in df_indexed_sorted_buffered.columns:
            df_indexed_sorted_buffered = df_indexed_sorted_buffered.drop(columns=['time'])
        data_stages_buffered['indexed_sorted'] = df_indexed_sorted_buffered
    except Exception as e:
        logging.error(f"Error creating indexed_sorted data: {e}", exc_info=True)
        return None

    # --- 3. Handle Duplicates ---
    logging.debug("Handling duplicates...")
    df_input_for_duplicates = data_stages_buffered['indexed_sorted']
    if not df_input_for_duplicates.empty:
        try:
            sensor_orig_sr = sensor_settings.get('sample_rate', downsample_freq)
            df_duplicates_handled = process_modality_duplicates(df_input_for_duplicates.copy(), sensor_orig_sr)
            data_stages_buffered['duplicates_handled'] = df_duplicates_handled
        except Exception as e:
            logging.error(f"Error handling duplicates: {e}", exc_info=True)
            data_stages_buffered['duplicates_handled'] = df_input_for_duplicates.copy() # Pass previous on error
    else:
        logging.warning("Skipping duplicate handling, input is empty.")
        data_stages_buffered['duplicates_handled'] = pd.DataFrame(index=df_input_for_duplicates.index)

    # --- 4. Modify Modality Names ---
    logging.debug("Modifying modality names...")
    df_input_for_rename = data_stages_buffered.get('duplicates_handled')
    if df_input_for_rename is not None and not df_input_for_rename.empty:
        try:
            # modify_modality_names expects the original sensor_name from config, not a new_prefix
            df_renamed_or_tuple = modify_modality_names(df_input_for_rename.copy(), sensor_name)
            # The helper has been observed to return (sensor_name, df) as well as a bare DataFrame.
            if isinstance(df_renamed_or_tuple, tuple) and len(df_renamed_or_tuple) == 2 and isinstance(df_renamed_or_tuple[1], pd.DataFrame):
                logging.info(f"modify_modality_names returned a tuple for {sensor_name}. Extracting DataFrame.")
                df_renamed = df_renamed_or_tuple[1]
            elif isinstance(df_renamed_or_tuple, pd.DataFrame):
                df_renamed = df_renamed_or_tuple
            else:
                logging.error(f"modify_modality_names returned an unexpected type: {type(df_renamed_or_tuple)}. Expected DataFrame or (str, DataFrame). Using input for rename as fallback.")
                df_renamed = df_input_for_rename.copy() # Fallback
            data_stages_buffered['columns_renamed'] = df_renamed
        except Exception as e:
            logging.error(f"Error renaming columns: {e}", exc_info=True)
            data_stages_buffered['columns_renamed'] = df_input_for_rename.copy()
    else:
        logging.warning("Skipping column renaming, input is empty.")
        data_stages_buffered['columns_renamed'] = pd.DataFrame(index=df_input_for_rename.index if df_input_for_rename is not None else None)

    # --- 5. Resample Data ---
    logging.debug("Resampling data...")
    # Stage 4 always stores a DataFrame, so no type checks are needed here.
    df_input_for_resample = data_stages_buffered['columns_renamed']
    # Default result: empty frame with a datetime index and the input's columns.
    # It is kept when the input is empty or resampling raises.
    df_resampled = pd.DataFrame(index=pd.to_datetime([]), columns=df_input_for_resample.columns)

    if df_input_for_resample.empty:
        logging.warning("Skipping resampling, input DataFrame (columns_renamed) is empty.")
    else:
        try:
            df_resampled_mean = df_input_for_resample.resample(target_freq_interval).mean()

            resample_start_time = df_resampled_mean.index.min() if not df_resampled_mean.empty else df_input_for_resample.index.min()
            resample_end_time = df_resampled_mean.index.max() if not df_resampled_mean.empty else df_input_for_resample.index.max()

            if pd.notna(resample_start_time) and pd.notna(resample_end_time) and resample_start_time <= resample_end_time:
                # Reindex onto a regular grid snapped to the target interval.
                target_index_for_reindex = pd.date_range(
                    start=resample_start_time.floor(target_freq_interval),
                    end=resample_end_time.ceil(target_freq_interval),
                    freq=target_freq_interval
                )
                df_resampled = df_resampled_mean.reindex(target_index_for_reindex)
            elif not df_resampled_mean.empty: # If mean is not empty but range is bad, use its index
                df_resampled = df_resampled_mean
            # else: keep the empty default
        except Exception as e:
            logging.error(f"Error during resampling: {e}", exc_info=True)
            df_resampled = pd.DataFrame(index=pd.to_datetime([]), columns=df_input_for_resample.columns)
    data_stages_buffered['resampled'] = df_resampled

    # --- 6. Impute Data ---
    logging.debug("Imputing data...")
    df_input_for_imputation = data_stages_buffered['resampled']
    if not df_input_for_imputation.empty:
        try:
            df_imputed = df_input_for_imputation.copy()
            # Imputation logic from raw_data_processor.py: fill gaps of up to
            # 2 * downsample_freq samples forward, then backward, then zero the rest.
            fill_limit = int(downsample_freq * 2)
            df_imputed.ffill(limit=fill_limit, inplace=True)
            df_imputed.bfill(limit=fill_limit, inplace=True)
            df_imputed.fillna(0, inplace=True)
            data_stages_buffered['imputed'] = df_imputed
        except Exception as e:
            logging.error(f"Error during imputation: {e}", exc_info=True)
            data_stages_buffered['imputed'] = df_input_for_imputation.copy()
    else:
        logging.warning("Skipping imputation, input is empty.")
        # Empty placeholder that keeps the input's index and columns.
        data_stages_buffered['imputed'] = pd.DataFrame(index=df_input_for_imputation.index).reindex(columns=df_input_for_imputation.columns)

    # --- 7. Apply Filter ---
    logging.debug("Applying low-pass filter...")
    df_input_for_filter = data_stages_buffered['imputed']
    if not df_input_for_filter.empty:
        try:
            sos = butter_lowpass_sos(highcut, downsample_freq, filter_order)
            columns_to_filter = df_input_for_filter.select_dtypes(include=np.number).columns.tolist()
            if sos is None:
                logging.warning("SOS filter coefficients are None. Storing unfiltered.")
                data_stages_buffered['filtered'] = df_input_for_filter.copy()
            elif not columns_to_filter:
                logging.warning("No numeric columns to filter. Storing unfiltered.")
                data_stages_buffered['filtered'] = df_input_for_filter.copy()
            else:
                df_filtered = apply_filter_combined(df_input_for_filter.copy(), sos, columns_to_filter)
                data_stages_buffered['filtered'] = df_filtered
        except Exception as e:
            logging.error(f"Error during filtering: {e}", exc_info=True)
            data_stages_buffered['filtered'] = df_input_for_filter.copy()
    else:
        logging.warning("Skipping filtering, input is empty.")
        data_stages_buffered['filtered'] = pd.DataFrame(index=df_input_for_filter.index).reindex(columns=df_input_for_filter.columns)

    # --- 8. Slice all stages to Final Requested Window (without buffer) ---
    logging.debug("Slicing buffered data to final exact window using UNIX timestamp...")
    final_data_stages = {}
    final_start_dt = pd.to_datetime(start_unix_ts, unit='s')
    final_end_dt = pd.to_datetime(start_unix_ts + duration_secs, unit='s')

    # Keys of data_stages_buffered, in processing order.
    processing_stage_keys = ['raw_initial', 'indexed_sorted', 'duplicates_handled',
                             'columns_renamed', 'resampled', 'imputed', 'filtered']

    for stage_key in processing_stage_keys:
        df_buffered_stage = data_stages_buffered.get(stage_key)
        if df_buffered_stage is None:
            final_data_stages[stage_key] = pd.DataFrame()
            logging.debug(f"Stage '{stage_key}' is empty or None, skipping final slice.")
            continue

        logging.debug(f"Slicing stage '{stage_key}'...")
        try:
            if stage_key == 'raw_initial':
                # 'raw_initial' is not time-indexed: slice on its 'corrected_time' column,
                # which is kept in the result for inspection.
                df_final_slice = df_buffered_stage[
                    (df_buffered_stage['corrected_time'] >= final_start_dt) &
                    (df_buffered_stage['corrected_time'] < final_end_dt) # Half-open window, same as the other stages
                ].copy()
            elif isinstance(df_buffered_stage.index, pd.DatetimeIndex):
                # Slice other stages based on their DatetimeIndex
                df_final_slice = df_buffered_stage[
                    (df_buffered_stage.index >= final_start_dt) &
                    (df_buffered_stage.index < final_end_dt)
                ].copy()
            else:
                logging.warning(f"DataFrame for stage '{stage_key}' has unexpected index: {type(df_buffered_stage.index)}. Skipping.")
                df_final_slice = pd.DataFrame()

            final_data_stages[stage_key] = df_final_slice
            logging.debug(f" Stage '{stage_key}' final shape: {df_final_slice.shape}")
        except Exception as slice_err:
            logging.error(f"Error slicing stage '{stage_key}': {slice_err}", exc_info=True)
            final_data_stages[stage_key] = pd.DataFrame()

    logging.info(f"Successfully processed stages for {subject_id} - {sensor_name}")
    return final_data_stages


# --- Plotting Function ---
def plot_data_stages(data_stages_dict, subject_id, sensor_name, start_unix_ts, start_dt, duration_secs, config, subject_labels):
    """Plot each processing stage in its own subplot, with label intervals shaded.

    Args:
        data_stages_dict (dict): Stage name -> DataFrame, as returned by
            get_sensor_data_at_stages. Nothing is drawn when it is empty or None.
        subject_id (str): Subject ID, used in the figure title.
        sensor_name (str): Sensor name, used in the figure title.
        start_unix_ts (float): Start of the plotted window (UNIX seconds).
        start_dt: Not used by this function.
        duration_secs (float): Length of the plotted window in seconds.
        config (dict): Not used by this function.
        subject_labels (pd.DataFrame | None): Label intervals with 'Real_Start_Time',
            'Real_End_Time' and (optionally) 'Label' columns. Not modified.

    Returns:
        None. The figure is shown with plt.show().
    """

    if not data_stages_dict:
        logging.warning("No data provided to plot_data_stages.")
        return

    # Stages to plot, in order
    stages_to_plot = ['raw_initial', 'indexed_sorted', 'duplicates_handled',
                      'columns_renamed', 'resampled', 'imputed', 'filtered']
    num_stages = len(stages_to_plot)

    # Window bounds as naive timestamps derived from UNIX seconds (i.e. UTC wall time).
    plot_window_start_dt = pd.to_datetime(start_unix_ts, unit='s')
    plot_window_end_dt = pd.to_datetime(start_unix_ts + duration_secs, unit='s')
    xlim_for_corrected_axes = (plot_window_start_dt, plot_window_end_dt)

    # --- Label preparation ---
    unique_labels_in_plot_range = []
    relevant_labels = pd.DataFrame()
    if subject_labels is not None and not subject_labels.empty:
        # Work on a copy so the caller's DataFrame is never modified.
        labels = subject_labels.copy()
        labels['Real_Start_Time'] = pd.to_datetime(labels['Real_Start_Time'])
        labels['Real_End_Time'] = pd.to_datetime(labels['Real_End_Time'])
        # Keep every label interval that overlaps the plotted window.
        relevant_labels = labels[
            (labels['Real_Start_Time'] < plot_window_end_dt) &
            (labels['Real_End_Time'] > plot_window_start_dt)
        ]
        if not relevant_labels.empty and 'Label' in relevant_labels.columns:
            # dropna: a missing label cannot be sorted together with label strings.
            unique_labels_in_plot_range = sorted(relevant_labels['Label'].dropna().unique())
    # --- End Label preparation ---

    # One colour per distinct label; colours repeat after ten labels.
    prop_cycle = plt.get_cmap('tab10')
    num_colors_to_map = min(len(unique_labels_in_plot_range), 10)
    label_colors_list = [prop_cycle(i / num_colors_to_map) for i in range(num_colors_to_map)]
    label_color_map = {label: label_colors_list[i % num_colors_to_map] for i, label in enumerate(unique_labels_in_plot_range)}

    fig, axes = plt.subplots(num_stages, 1, figsize=(15, 5 * num_stages), sharex=False)
    # NOTE(review): the title shows the start in local time (datetime.fromtimestamp),
    # while the corrected-time axes use UTC-derived naive timestamps -- confirm intended.
    fig.suptitle(f"Data Processing Stages for {subject_id} - {sensor_name}\nWindow: {datetime.fromtimestamp(start_unix_ts)} (Duration: {duration_secs}s)", fontsize=16)

    for ax, stage_key in zip(axes, stages_to_plot):
        ax.set_title(f"Stage: {stage_key.replace('_', ' ').capitalize()}") # Prettify title
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.set_ylabel("Value")

        df = data_stages_dict.get(stage_key)

        if df is None or df.empty:
            ax.text(0.5, 0.5, 'No data available for this stage/window', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
            # Keep the x-range consistent on empty plots; 'raw_initial' has its own (original) time axis.
            if stage_key != 'raw_initial':
                ax.set_xlim(xlim_for_corrected_axes)
            continue

        if stage_key == 'raw_initial' and 'time' in df.columns:
            # For 'raw_initial', x-axis is original 'time', converted to datetime.
            # NOTE(review): 'time' is interpreted as seconds here; sensors whose raw
            # unit is milliseconds would need scaling first -- confirm.
            time_data_for_plot = pd.to_datetime(df['time'], unit='s', errors='coerce')
            valid_time_mask = time_data_for_plot.notna()
            time_data_for_plot = time_data_for_plot[valid_time_mask]
            df_to_plot_values = df[valid_time_mask]
            ax.set_xlabel("Original Timestamp (from 'time' column, as Datetime)")
            if not time_data_for_plot.empty:
                ax.set_xlim(time_data_for_plot.min(), time_data_for_plot.max())
        elif isinstance(df.index, pd.DatetimeIndex):
            # For all other stages, x-axis is the DatetimeIndex (corrected time)
            time_data_for_plot = df.index
            df_to_plot_values = df
            ax.set_xlabel("Corrected Timestamp (DatetimeIndex)")
            ax.set_xlim(xlim_for_corrected_axes)
        else:
            ax.text(0.5, 0.5, "Cannot determine time axis", transform=ax.transAxes)
            logging.warning(f"Cannot determine time axis for stage '{stage_key}'. Index: {type(df.index)}")
            continue

        columns_to_plot_now = df_to_plot_values.select_dtypes(include=np.number).columns.tolist()
        # For raw_initial, 'time' and 'corrected_time' are axes, not signals.
        if stage_key == 'raw_initial':
            columns_to_plot_now = [c for c in columns_to_plot_now if c not in ('time', 'corrected_time')]

        line_handles, line_labels_legend = [], []
        plotted_something = False
        if not df_to_plot_values.empty and not time_data_for_plot.empty:
            for col in columns_to_plot_now:
                line, = ax.plot(time_data_for_plot, df_to_plot_values[col], label=col, alpha=0.8)
                line_handles.append(line)
                line_labels_legend.append(col)
                plotted_something = True

        # Label Spans (axvspan) - only for stages with corrected time index
        unique_labels_plotted_on_axis = set()
        if stage_key != 'raw_initial' and not relevant_labels.empty:
            for _, label_row in relevant_labels.iterrows():
                # .get: the 'Label' column is optional; unlabeled spans are drawn in silver.
                label_name = label_row.get('Label')
                color = label_color_map.get(label_name, 'silver')
                ax.axvspan(label_row['Real_Start_Time'], label_row['Real_End_Time'], color=color, alpha=0.3, zorder=0)
                if label_name in label_color_map:
                    unique_labels_plotted_on_axis.add(label_name)

        # Create Legend: signal lines first, then one patch per label shown on this axis
        label_patches = [
            Patch(color=label_color_map[label_name_patch], alpha=0.3, label=f"{label_name_patch}")
            for label_name_patch in sorted(unique_labels_plotted_on_axis)
        ]
        combined_handles = line_handles + label_patches
        combined_labels_legend = line_labels_legend + [p.get_label() for p in label_patches]

        if combined_handles:
            ax.legend(combined_handles, combined_labels_legend, loc='upper left', bbox_to_anchor=(1.01, 1), fontsize='small', title="Legend")
        elif not plotted_something:
            ax.text(0.5, 0.5, 'No numeric data or labels to plot', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)

    plt.tight_layout(rect=[0, 0.03, 0.9, 0.95]) # Adjust for suptitle and legend
    plt.show() # Ensure plot is displayed

# --- Widgets Setup and Interaction ---

# Get available subjects and sensors (as before)
# Subjects come from the top-level keys of the sync parameters; sensors from the
# keys of 'raw_data_parsing_config' in the main config.
available_subjects = sorted(list(sync_params.keys()))
available_sensors = sorted(list(config.get('raw_data_parsing_config', {}).keys()))

# Create Widgets
subject_dropdown = Dropdown(options=available_subjects, description='Subject:')
sensor_dropdown = Dropdown(options=available_sensors, description='Sensor:')

# --- Date/Time Widgets ---
# The start date/time widgets are pre-filled from the SYNC_* constants.
# NOTE(review): `now` is not referenced in the visible part of this file -- confirm
# whether it is still needed.
now = datetime.now().replace(microsecond=0)
start_date_picker = DatePicker(description='Start Date:', value=datetime(SYNC_Y, SYNC_Month, SYNC_D), tooltip="Select the starting date")
# NOTE(review): plain IntText may not enforce min/max (BoundedIntText does) -- confirm;
# the click handler checks the hour/minute/second ranges itself.
start_hour_text = IntText(value=SYNC_H, description='Hour (0-23):', min=0, max=23, step=1, style={'description_width': 'initial'}, layout={'width': '150px'})
start_minute_text = IntText(value=SYNC_M, description='Min (0-59):', min=0, max=59, step=1, style={'description_width': 'initial'}, layout={'width': '150px'})
start_second_text = IntText(value=SYNC_S, description='Sec (0-59):', min=0, max=59, step=1, style={'description_width': 'initial'}, layout={'width': '150px'})
# --- End Date/Time Widgets ---

# --- Label Shift Widget ---
# Hours (fractional allowed), minutes and seconds by which labels are shifted.
label_hour_shift_text = FloatText(value=0.0, description='Label Shift (h):', step=0.5, style={'description_width': 'initial'}, tooltip="Enter hours to shift labels (e.g., -2 or 1.5)")
label_minute_shift_text = IntText(value=0, description='(min):', step=1, style={'description_width': 'initial'}, layout={'width': '120px'}, tooltip="Enter minutes to shift labels")
label_second_shift_text = IntText(value=0, description='(sec):', step=1, style={'description_width': 'initial'}, layout={'width': '120px'}, tooltip="Enter seconds to shift labels")

# Length of the analysis window in seconds (the click handler requires it to be positive).
duration_text = FloatText(value=120.0, description='Duration (s):', step=1.0, style={'description_width': 'initial'})
run_button = Button(description='Analyze Window')
plot_output = Output() # To display plots within the widget area

# Layout
# Group date/time widgets together for better organization
time_input_box = HBox([start_hour_text, start_minute_text, start_second_text], layout={'justify_content': 'space-around'})
datetime_controls = VBox([start_date_picker, time_input_box], layout={'border': '1px solid lightgray', 'padding': '5px', 'margin_bottom': '5px'})

# Combine all controls, top to bottom: subject/sensor, start date/time,
# duration, label shift, run button.
label_shift_controls = HBox([label_hour_shift_text, label_minute_shift_text, label_second_shift_text], layout={'justify_content': 'flex-start'})
controls = VBox([
    HBox([subject_dropdown, sensor_dropdown]),
    datetime_controls, # Date picker plus hour/minute/second inputs
    duration_text,
    label_shift_controls,
    run_button
])

# Button Click Handler
def on_run_button_clicked(b):
    """Run the staged-data analysis for the window selected in the widgets.

    Reads subject, sensor, start date/time, duration and the manual label
    shift from the module-level widgets, validates them, builds the start
    timestamp, shifts and filters the global labels for the subject, fetches
    the sensor data at each processing stage and plots it. Everything runs
    inside ``plot_output`` so prints and plots appear in the widget area.

    Args:
        b: Argument supplied by ``run_button.on_click``; not used.

    Returns:
        None. Returns early, after printing a message, when the inputs are
        invalid or the start datetime cannot be constructed.
    """
    with plot_output:
        plot_output.clear_output(wait=True) # Clear previous plot/output
        subject = subject_dropdown.value
        sensor = sensor_dropdown.value
        duration = duration_text.value
        label_shift_hours = label_hour_shift_text.value
        label_shift_minutes = label_minute_shift_text.value
        label_shift_seconds = label_second_shift_text.value

        # --- Read Date/Time Widgets and Construct Timestamp ---
        start_date = start_date_picker.value
        start_hour = start_hour_text.value
        start_minute = start_minute_text.value
        start_second = start_second_text.value

        # --- Input Validation ---
        # Collect every problem first so the user sees all of them at once.
        error_messages = []
        if start_date is None:
            error_messages.append("Please select a valid start date.")
        # Combine checks for time components
        if not (0 <= start_hour <= 23 and 0 <= start_minute <= 59 and 0 <= start_second <= 59):
            error_messages.append("Invalid hour (0-23), minute (0-59), or second (0-59) value.")
        if duration <= 0:
            error_messages.append("Duration must be positive.")
        if not subject:
            error_messages.append("Please select a subject.")
        if not sensor:
            error_messages.append("Please select a sensor.")

        if error_messages:
            print("Input Errors Found:")
            for msg in error_messages:
                print(f"- {msg}")
            return # Stop processing if errors exist

        # --- Construct datetime and timestamp ---
        try:
            # Combine date and time components into a datetime object
            start_dt = datetime.combine(start_date, time(start_hour, start_minute, start_second))
            # Convert the datetime object to a UNIX timestamp (float).
            # NOTE(review): start_dt is naive, so .timestamp() interprets it in
            # the machine's local timezone -- confirm this matches the sensor data.
            start_ts = start_dt.timestamp()
            logging.info(f"Selected Start DateTime: {start_dt}, Corresponding UNIX Timestamp: {start_ts}")
        except Exception as e:
            print(f"\nError: Could not construct valid datetime from inputs: {e}")
            logging.error(f"Error constructing datetime from DatePicker/IntText: {e}", exc_info=True)
            return
        # --- End Timestamp Construction ---

        # --- >>> Apply Label Shift and Filter Labels <<< ---
        subject_labels_shifted_filtered = pd.DataFrame() # Default to empty
        if not global_labels_df.empty and 'Video_File' in global_labels_df.columns and 'Label' in global_labels_df.columns:
            try:
                subject_labels_shifted_filtered = _shift_and_filter_labels(
                    global_labels_df,
                    subject,
                    label_shift_hours,
                    label_shift_minutes,
                    label_shift_seconds,
                    sync_params,
                )
                logging.info(f"Filtered {len(subject_labels_shifted_filtered)} labels for subject {subject} (after all shifts).")
            except Exception as e:
                # Label problems must not block the sensor plots: log and
                # continue with no labels.
                logging.error(f"Error processing labels for subject {subject}: {e}", exc_info=True)
                subject_labels_shifted_filtered = pd.DataFrame() # Ensure it's empty on error
        else:
            logging.warning("Global labels DataFrame is empty or essential columns ('Video_File', 'Label') missing. Cannot filter labels.")
        # --- >>> End Label Shift and Filtering <<< ---

        print("-" * 30)
        print(f"DEBUG: Passing {len(subject_labels_shifted_filtered)} labels to plot function.")
        if not subject_labels_shifted_filtered.empty:
            print("DEBUG: First 5 filtered labels passed:")
            print(subject_labels_shifted_filtered.head().to_string())
        print("-" * 30)

        print(f"\nStarting analysis for Subject: {subject}, Sensor: {sensor}...")

        data_dict = get_sensor_data_at_stages(subject, sensor, start_ts, duration, config, sync_params)

        if data_dict:
            _print_stage_summary(data_dict)
            print("Plotting data stages...")
            plot_data_stages(data_dict, subject, sensor, start_ts, start_dt, duration, config, subject_labels_shifted_filtered)
            print("\nAnalysis complete.")
        else:
            print("\nFailed to process data for the selected window. Check logs for details.")


def _parse_label_time_shift(shift_value):
    """Parse a ``'<H>h <M>min <S>s'`` shift value into total seconds.

    Args:
        shift_value: The configured shift. It is converted with ``str()`` and
            stripped, so a non-string value does not raise.

    Returns:
        A tuple ``(total_seconds, parsed_any, leftover)``: the summed shift in
        seconds (0 when nothing was recognised), whether at least one of the
        h/min/s components was found, and any stripped text after the
        recognised prefix.
    """
    text = str(shift_value).strip()
    # Every component is optional, so this pattern matches ANY string (possibly
    # with an empty match). A successful match therefore proves nothing; the
    # caller has to look at `parsed_any` and `leftover`.
    match = re.match(r'(?:(-?\d+)h)?\s*(?:(-?\d+)min)?\s*(?:(-?\d+)s)?', text)
    parts = match.groups()
    hours, minutes, seconds = (int(part or 0) for part in parts)
    total_seconds = (hours * 3600) + (minutes * 60) + seconds
    parsed_any = any(part is not None for part in parts)
    leftover = text[match.end():].strip()
    return total_seconds, parsed_any, leftover


def _shift_and_filter_labels(labels_df, subject, shift_hours, shift_minutes, shift_seconds, sync_parameters):
    """Return the subject's labels with the manual and configured shifts applied.

    Args:
        labels_df: Labels table; must contain 'Video_File', 'Real_Start_Time',
            'Real_End_Time' and 'Label'. It is copied, never modified.
        subject: Subject identifier, matched as a plain substring of
            'Video_File'.
        shift_hours, shift_minutes, shift_seconds: Manual shift components,
            applied to every label before the subject filter.
        sync_parameters: Mapping of subject -> settings; the optional
            'Label_Time_Shift' entry is applied to the subject's labels only.

    Returns:
        DataFrame with columns 'Real_Start_Time', 'Real_End_Time', 'Label',
        without rows whose start or end time is missing.

    Raises:
        Whatever pandas or the mapping lookups raise (e.g. KeyError for a
        missing column); the caller handles it.
    """
    # 1. Create a copy to avoid modifying the caller's DataFrame
    labels_to_shift = labels_df.copy()

    # 2. Apply the manual shift if it's non-zero
    total_shift_seconds_manual = (shift_hours * 3600) + (shift_minutes * 60) + shift_seconds
    if total_shift_seconds_manual != 0.0:
        logging.info(f"Applying manual label time shift of {shift_hours} hours, {shift_minutes} minutes, {shift_seconds} seconds.")
        time_delta_shift_manual = pd.Timedelta(seconds=total_shift_seconds_manual)
        # Ensure columns are datetime before shifting
        for column in ('Real_Start_Time', 'Real_End_Time'):
            labels_to_shift[column] = pd.to_datetime(labels_to_shift[column]) + time_delta_shift_manual
    else:
        logging.info("Manual label time shift is 0. Using original label times for this step.")

    # 3. Filter the (potentially shifted) labels for the subject
    # The 48h recording is filtered by its base name in Video_File.
    subject_for_filter = 'OutSense-425' if subject == 'OutSense-425_48h' else subject
    subject_labels = labels_to_shift[
        labels_to_shift['Video_File'].astype(str).str.contains(subject_for_filter, na=False, regex=False)
    ].copy()

    # 4. Apply Label_Time_Shift from Sync_Parameters.yaml to this subject-specific subset
    label_time_shift_str = sync_parameters.get(subject, {}).get('Label_Time_Shift', '0h 0min 0s')
    logging.info(f"Applying Label_Time_Shift from Sync_Parameters for subject {subject}: {label_time_shift_str}")
    total_sync_shift_seconds, parsed_any, leftover = _parse_label_time_shift(label_time_shift_str)
    if leftover and not parsed_any:
        # Nothing recognisable at all: say so instead of silently using 0.
        logging.warning(f"Could not parse Label_Time_Shift for subject {subject}: {label_time_shift_str}")
    else:
        if leftover:
            # The recognised prefix is still applied; only the remainder is ignored.
            logging.warning(f"Label_Time_Shift for subject {subject} was only partially parsed; ignoring unparsed text {leftover!r}: {label_time_shift_str}")
        if total_sync_shift_seconds != 0:
            logging.info(f"Shifting labels by {total_sync_shift_seconds} seconds based on Sync_Parameters.")
            time_delta_sync_shift = pd.Timedelta(seconds=total_sync_shift_seconds)
            for column in ('Real_Start_Time', 'Real_End_Time'):
                subject_labels[column] = pd.to_datetime(subject_labels[column]) + time_delta_sync_shift
        else:
            logging.info("No additional label time shift from Sync_Parameters (shift is 0 seconds).")

    # 5. Keep only necessary columns and drop NaNs from time columns *after* all shifts
    return subject_labels[['Real_Start_Time', 'Real_End_Time', 'Label']].dropna(subset=['Real_Start_Time', 'Real_End_Time'])


def _print_stage_summary(data_dict):
    """Print the shape and covered time range of each stage's DataFrame.

    Args:
        data_dict: Mapping of stage name -> DataFrame (or None).
    """
    print("--- DataFrame Shapes After Processing ---")
    for stage, df_stage_data in data_dict.items():
        print(f"DEBUG: Stage '{stage}' shape: {df_stage_data.shape if df_stage_data is not None else 'None'}")
        if df_stage_data is None or df_stage_data.empty:
            continue
        if isinstance(df_stage_data.index, pd.DatetimeIndex):
            print(f"DEBUG: Stage '{stage}' index range: Min={df_stage_data.index.min()}, Max={df_stage_data.index.max()}")
        elif stage == 'raw_initial' and 'time' in df_stage_data.columns:
            # This stage has no DatetimeIndex; its 'time' column is read as
            # seconds (unit='s') to report the range.
            try:
                raw_time_dt = pd.to_datetime(df_stage_data['time'], unit='s', errors='coerce')
                print(f"DEBUG: Stage '{stage}' original 'time' range: Min={raw_time_dt.min()}, Max={raw_time_dt.max()}")
            except Exception:
                print(f"DEBUG: Stage '{stage}' original 'time' could not be fully converted.")
    print("--- End DataFrame Shapes ---")

# Register the click handler before the UI is rendered.
run_button.on_click(on_run_button_clicked)

# --- Display Widgets ---
# This should be the last part of your script if running interactively
# Controls on top, the plot/print output area underneath.
analysis_ui = VBox([controls, plot_output])
display(analysis_ui)

Adding project root to path: /scai_data3/scratch/stirnimann_r


2025-06-24 16:30:55,058 - INFO - Configuration loaded successfully from /scai_data3/scratch/stirnimann_r/config.yaml
2025-06-24 16:30:55,130 - INFO - Loading global labels from: /scai_data3/scratch/stirnimann_r/All_Videos_with_Labels_Real_Time_Corrected_Labels.csv
  global_labels_df = pd.read_csv(
2025-06-24 16:30:55,163 - INFO - Loaded 7205 valid global labels.


VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Subject:', options=('OutSense-036', 'OutSen…