# Extract and align data from Onix, Harp, Sleap, and photometry
## Status: works for Cohorts 1 and 2. Cohort 0: the onix_digital Clock column is all zeros — investigate why, and/or use timestamps instead.

In [None]:
import numpy as np
from pathlib import Path
import os
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import pandas as pd
#import harp
import plotly.express as px
from scipy.stats import mode

import gc # garbage collector for removing large variables from memory instantly 

import importlib #for force updating changed packages 
import harp_resources.process
import harp_resources.utils
from harp_resources import process, utils # Reassign to maintain direct references for force updating 
from sleap import load_and_process as lp

In [None]:
# ---- Session configuration and notebook state flags ----
# The *_downsampled / *_upsampled / unit_conversions flags let later cells skip work that has
# already been done, so a cell can be re-run without processing the same data twice.
has_heartbeat = False  # passed to utils.load_registers; set True for the Cohort 2 data (see below)
cohort0 = False  # Cohort 0 data: no OnixHarp stream, OnixAnalogFrameCount is loaded instead
cohort2 = False  # when True, the loading cell skips VideoData1/VideoData2
onix_analog_clock_downsampled = False  # set True by the downsampling cell
onix_analog_framecount_upsampled = False  # set True by the Cohort 0 framecount upsampling
common_resampled_rate = 10000 # in Hz; target rate for downsampling the ONIX analog clock and photodiode
unit_conversions = False # set True once flow sensor / encoder values are in real-world units
save_full_asynchronous_data = False # if True, saves all streams concatenated (before resampling) to full_asynchronous_data.csv

# ---- Data path: uncomment exactly one session ----

#Cohort 1 vestibular mismatch, multiple OnixDigital files 
#data_path = Path('/Users/rancze/Documents/Data/vestVR/Cohort1/VestibularMismatch_day1/B6J2718-2024-12-12T13-28-14') #multiple onix_digital file

#Cohort 1 vestibular mismatch, with clock accumulation issue marked on google sheet, seems fine though
#data_path = Path('/Users/rancze/Documents/Data/vestVR/Cohort1/VestibularMismatch_day1/B6J2719-2024-12-12T13-59-38') #multiple onix_digital file

#Cohort 1 vestibular mismatch
#data_path = Path('/Users/rancze/Documents/Data/vestVR/Cohort1/VestibularMismatch_day1/B6J2717-2024-12-12T13-00-21')

#Cohort 1 visual mismatch 
#data_path = Path('/Users/rancze/Documents/Data/vestVR/Cohort1/Visual_mismatch_day3/B6J2718-2024-12-10T12-57-02') 

# Cohort 1 visual mismatch (currently selected)
data_path = Path('/Users/rancze/Documents/Data/vestVR/Cohort1/Visual_mismatch_day3/B6J2717-2024-12-10T12-17-03')

#Cohort 0 (no OnixHarp in this Cohort)
#data_path = Path('/Users/rancze/Documents/Data/vestVR/Cohort0/Cohort0_GCaMP_example/B3M3xx-2024-08-08T10-05-26')
#cohort0 = True

#Cohort 2 (Cohort 1 animal) 
#data_path = Path('/Users/rancze/Documents/Data/vestVR/Cohort2_test/2025-02-13T12-41-57')
#has_heartbeat = True
# NOTE(review): the Cohort 2 option sets has_heartbeat but not cohort2 = True (which would skip
# video loading) — confirm whether that is intended.

# Processed photometry lives next to the session folder: <session>_processedData/photometry
photometry_path = data_path.parent / f"{data_path.name}_processedData" / "photometry"

# Per-register HARP readers. These are only needed when registers are read separately rather
# than all together as harp_streams (utils.load_registers), so they are kept commented out.
#h1_datafolder = data_path / 'HarpDataH1' #only if reading separate registers
#h2_datafolder = data_path / 'HarpDataH2' #only if reading separate registers
#h1_reader = harp.create_reader('harp_resources/h1-device.yml', epoch=harp.REFERENCE_EPOCH)
#h2_reader = harp.create_reader('harp_resources/h2-device.yml', epoch=harp.REFERENCE_EPOCH)

# Readers for each data stream; they are passed to utils.load_2 in the loading cell
session_settings_reader = utils.SessionData("SessionSettings")
experiment_events_reader = utils.TimestampedCsvReader("ExperimentEvents", columns=["Event"])
onix_framecount_reader = utils.TimestampedCsvReader("OnixAnalogFrameCount", columns=["Index"])
#photometry_reader = utils.PhotometryReader("Processed_fluorescence")
video_reader1 = utils.VideoReader("VideoData1")
video_reader2 = utils.VideoReader("VideoData2")
onix_digital_reader = utils.OnixDigitalReader("OnixDigital", columns=["Value.Clock", "Value.HubClock", 
                                                                         "Value.DigitalInputs",
                                                                         "Seconds"])
onix_harp_reader = utils.TimestampedCsvReader("OnixHarp", columns=["Clock", "HubClock", "HarpTime"])

### Load all data 

In [None]:
# ---- Load every stream for the session in data_path ----
print ("Loading session settings")
session_settings = utils.load_2(session_settings_reader, data_path) #Andrew's, creates ugly df, but used in further analysis code
print ("Loading experiment events")
experiment_events = utils.load_2(experiment_events_reader, data_path)

print ("Loading processed photometry")
photometry_data = pd.read_csv(photometry_path / 'Processed_fluorescence.csv')
photometry_data.set_index("TimeStamp", inplace=True)
photometry_data.index.name = 'Seconds'
print ("Loading processed photometry info")
photometry_info = pd.read_csv(photometry_path / 'Info.csv')
print ("Loading processed photometry events")
photometry_events = pd.read_csv(photometry_path / 'Events.csv')
photometry_events["TimeStamp"] = photometry_events["TimeStamp"] / 1000  # ms -> s
photometry_events.set_index("TimeStamp", inplace=True)
photometry_events.index.name = 'Seconds'

if not cohort2:
    print ("Loading video data 1")
    video_data1 = utils.load_2(video_reader1, data_path)
    print ("Loading video data 2")
    video_data2 = utils.load_2(video_reader2, data_path)

# ---- ONIX data ----
print ("Loading OnixDigital")
onix_digital = utils.load_2(onix_digital_reader, data_path)

if cohort0:
    print ("Loading OnixAnalogFrameClock")
    onix_analog_framecount = utils.load_2(onix_framecount_reader, data_path)

print ("Loading OnixAnalogClock")
onix_analog_clock = utils.read_OnixAnalogClock(data_path)
print ("Loading OnixAnalogData and converting to boolean photodiode array")
# method: 'adaptive' or 'threshold' (hard threshold at 120); refractory avoids multiple detections
photodiode = utils.read_OnixAnalogData(data_path, channels = [0], binarise=True, method='adaptive', refractory = 300, flip=True, verbose=False)

# ---- HARP data ----
print ("Loading H1 and H2 streams, AnalogInput removed")
harp_streams = utils.load_registers(data_path, dataframe = True, has_heartbeat = has_heartbeat, verbose = False) #loads as df, or if False, as dict
# AnalogInput is not currently used. errors="ignore" keeps this working for sessions without it.
harp_streams = harp_streams.drop(columns=["AnalogInput(39)"], errors="ignore")
harp_streams = harp_streams.dropna(how="all") # remove rows with all NaNs

# Convert the camera start/stop registers to boolean. harp_streams merges several registers, so a
# column can be NaN on rows contributed by other registers; astype(bool) alone maps NaN to True,
# so the missing values are filled with 0 (False) first.
columns_to_convert = ["StartCam0(38)", "StartCam1(38)", "StopCam0(38)", "StopCam1(38)"]
for col in columns_to_convert:
    harp_streams[col] = harp_streams[col].fillna(0).astype(bool)

# ---- Synchronising signal between HARP and ONIX (absent in Cohort 0) ----
if not cohort0:
    print ("Loading OnixHarp")
    onix_harp = utils.load_2(onix_harp_reader, data_path)
    onix_harp = utils.detect_and_remove_outliers(
        df=onix_harp,
        x_column="HarpTime",
        y_column="Clock",
        verbose=False  # True prints all outliers
    )
    onix_harp["HarpTime"] = onix_harp["HarpTime"] + 1 # known issue with current version of ONIX, harp timestamps lag 1 second
    print ("❗Reminder: HarpTime was increased by 1s to account for know issue with ONIX")

print ("✅ Done Loading")

Convert platform position and flow-sensor streams to real-world units and forward-fill missing values.

In [None]:
# Encoder readings at the 'Homing platform' event and at the experiment event that follows it.
# The second value is used below as the zero of the platform encoder.
homing_position, next_event_position = process.get_encoder_home_position(experiment_events, harp_streams)
print ("Encoder values for homing and next event positions")
print(f"Encoder value at 'Homing platform': {homing_position}")
print(f"Encoder value at the next experiment event: {next_event_position}")
print("❗ Warning: home position not tested. Likely the next event after homing reports the home position. "
      "Alternatively, save a separate experiment event when homing is finished")

if unit_conversions:
    # Guard: re-running this cell must not convert (and re-zero) the values a second time
    print("❗ Flow sensor and encoder values already converted to real-world units, skipping")
else:
    # Optical flow sensors: X -> running speed (m / s), Y -> turning speed (degrees / s)
    _flow_conversions = (
        ("OpticalTrackingRead0X(46)", process.running_unit_conversion),
        ("OpticalTrackingRead0Y(46)", process.turning_unit_conversion),
        ("OpticalTrackingRead1X(46)", process.running_unit_conversion),
        ("OpticalTrackingRead1Y(46)", process.turning_unit_conversion),
    )
    for _column, _convert in _flow_conversions:
        harp_streams[_column] = _convert(harp_streams[_column].to_numpy())

    # Platform encoder, zeroed at the event after homing
    harp_streams["Encoder(38)"] = process.encoder_unit_conversion(
        harp_streams["Encoder(38)"], next_event_position)  # FIXME: what is the real home position?

    # Forward-fill the converted columns so each row carries the last known value
    columns_to_fill = [name for name, _ in _flow_conversions] + ["Encoder(38)"]
    harp_streams[columns_to_fill] = harp_streams[columns_to_fill].ffill()

    unit_conversions = True
    print("✅ Unit conversions to real-life values done")
    del _flow_conversions, _column, _convert


### Downsample the photodiode and analog clock to the common rate (10 kHz) and align photodiode_df to HARP time. For Cohort 0, also upsample the frame count (not used for Cohort 1+).

In [None]:
# Get the ONIX analog sample rate, downsample clock + photodiode to common_resampled_rate, and
# index the photodiode trace by HARP datetime. Guarded so that re-running the cell is a no-op.
if not onix_analog_clock_downsampled:
    onix_analog_clock = (onix_analog_clock * 4) * 1e-9  # convert to seconds with 250MHz DeviceClock, set in hardware
    oac_diff = np.diff(onix_analog_clock)
    onix_analog_rate = round(1 / (np.median(oac_diff)))  # sampling rate in Hz
    # At least 1, so an analog rate at or below the common rate does not produce a factor of 0
    downsample_factor = max(1, int(onix_analog_rate / common_resampled_rate))
    print(f"onix_analog_clock rate: {onix_analog_rate}, downsample factor: {downsample_factor}")

    onix_analog_clock = process.downsample_numpy(onix_analog_clock, downsample_factor, method="mean")
    photodiode = process.downsample_numpy(photodiode, downsample_factor, method="mean")
    photodiode_df = pd.DataFrame({"Photodiode": photodiode.astype(bool)}, index=pd.Index(onix_analog_clock, name="Seconds"))

    # Timestamp of the "Sync signal started" event. Filter rows on the "Event" column: masking the
    # whole DataFrame (experiment_events == ...) keeps every row, so .index[0] is just the first event.
    sync_events = experiment_events[experiment_events["Event"] == "Sync signal started"]
    if sync_events.empty:
        raise ValueError("No 'Sync signal started' event in experiment_events; cannot align photodiode_df to HARP time.")
    sync_start_time = sync_events.index[0]

    # Add ONIX seconds to the sync start time (vectorised).
    # NOTE(review): assumes ONIX analog clock zero coincides with the sync start — confirm.
    seconds_index = photodiode_df.index.values
    timedelta_index = pd.to_timedelta(seconds_index, unit='s')
    datetime_index = sync_start_time + timedelta_index
    photodiode_df.index = datetime_index
    photodiode_df.index.name = 'Time'

    # Free the raw arrays and temporaries. They are deleted inside the branch that created them,
    # so re-running the cell does not fail on a missing name.
    del onix_analog_clock, oac_diff, photodiode, sync_events, seconds_index, timedelta_index, datetime_index
    gc.collect()
    onix_analog_clock_downsampled = True
    print("✅ Done downsampling analog_clock and photodiode, photodiode_df is now indexed to harp time")
else:
    print("❗ onix_analog_clock & photodiode already downsampled, skipping")

# framecount upsampling, only used in Cohort 0 synchronization
if cohort0:
    if not onix_analog_framecount_upsampled:
        upsample_factor = int(100 / downsample_factor)  # framecount counts every 100 analog datapoints
        df = onix_analog_framecount
        new_index = np.linspace(0, len(df) - 1, len(df) * upsample_factor)
        onix_analog_framecount = pd.DataFrame(index=new_index)
        for col in df.columns:
            onix_analog_framecount[col] = np.interp(new_index, np.arange(len(df)), df[col])
        del new_index
        gc.collect()

        # The raw photodiode / analog clock arrays were deleted above. photodiode_df was built from
        # both (data and index), so after downsampling they each had len(photodiode_df) samples.
        analog_len = len(photodiode_df)
        framecount_len = len(onix_analog_framecount)
        if framecount_len != analog_len:
            offset = framecount_len - analog_len
            if offset > 0:
                # Extra leading framecount samples: drop them
                onix_analog_framecount = onix_analog_framecount.iloc[offset:]
                print(f"Warning: analog_data and _framecount mismatch, framecount truncated by {offset * 10}! Should be OK, but see https://github.com/neurogears/vestibular-vr/issues/81 for more information.")
            else:
                # Framecount is SHORTER; iloc[offset:] with a negative offset would keep only its tail
                print(f"Warning: onix_analog_framecount is {-offset} samples shorter than analog_data; not truncated. See https://github.com/neurogears/vestibular-vr/issues/81 for more information.")
        else:
            print("onix_analog shapes are consistent!")
        onix_analog_framecount_upsampled = True
        print("✅ Done upsampling analog_frameclock")
    else:
        print("❗ onix_analog_framecount already upsampled, skipping")

gc.collect()

In [None]:
# Fit the conversions between the ONIX, HARP and photometry clocks. photometry_aligned is padded
# alongside the HARP-indexed streams later (presumably photometry on the HARP timeline — confirm
# in process.photometry_harp_onix_synchronisation).
_sync_outputs = process.photometry_harp_onix_synchronisation(
    onix_digital=onix_digital,
    onix_harp=onix_harp,
    photometry_events=photometry_events,
    photometry_data=photometry_data,
    verbose=True,
)
# FIXME: is `conversions` used anywhere?
onix_to_harp, harp_to_onix, photometry_to_onix, photometry_to_harp, conversions, photometry_aligned = _sync_outputs
del _sync_outputs

In [None]:
# Print a column/dtype/non-null-count summary of harp_streams
harp_streams.info()

In [None]:
# Find the global first and last timestamp across all streams.
# Each entry maps a stream name to {name: DataFrame}, the shape process.get_global_minmax_timestamps takes.
streams_dict = {
    'session_settings': {'session_settings': session_settings},
    'experiment_events': {'experiment_events': experiment_events},
    'video_data1': {'video_data1': video_data1},
    'video_data2': {'video_data2': video_data2},
    'harp_streams': {'harp_streams': harp_streams},
}
if not cohort0:  # onix_harp is only loaded when there is an OnixHarp stream (not Cohort 0)
    streams_dict['onix_harp'] = {'onix_harp': onix_harp}
streams_dict.update({
    'onix_digital': {'onix_digital': onix_digital},
    'photodiode_df': {'photodiode_df': photodiode_df},
    'photometry_aligned': {'photometry_aligned': photometry_aligned},
})
if cohort0:
    streams_dict['onix_analog_framecount'] = {'onix_analog_framecount': onix_analog_framecount}

global_first_timestamp, global_last_timestamp, _, _ = process.get_global_minmax_timestamps(streams_dict, print_all=False, verbose=True)

# These inputs are no longer needed once the timestamp range is known
del onix_digital, photometry_events, photometry_data
if not cohort0:
    del onix_harp
gc.collect()

# Pad these dataframes to the global first and last timestamp so they share one time range
dataframes = [video_data1, video_data2, harp_streams, photometry_aligned, photodiode_df]
padded_dataframes = [process.pad_dataframe_with_global_timestamps(df, global_first_timestamp, global_last_timestamp) for df in dataframes]
video_data1, video_data2, harp_streams, photometry_aligned, photodiode_df = padded_dataframes

del padded_dataframes, streams_dict, dataframes
gc.collect()

if save_full_asynchronous_data: #FIXME is this correct?
    alldata = pd.concat([video_data1, video_data2, harp_streams, photometry_aligned, photodiode_df], axis=1)
    alldata.to_csv(data_path / "full_asynchronous_data.csv")
    print("✅ Saved full asynchronous data to full_asynchronous_data.csv")


In [None]:
# Combine all dataframes into a single dataframe
video_data1 = video_data1.rename(columns=lambda x: f"{x}_1")
video_data2 = video_data2.rename(columns=lambda x: f"{x}_2")
alldata = pd.concat([video_data1, video_data2, harp_streams, photometry_aligned, photodiode_df], axis=1)
alldata.info()

del video_data1, video_data2, harp_streams, photometry_aligned, photodiode_df
gc.collect()

In [None]:
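# TODO: StartCam/StopCam columns are converted to bool at load time; check that the bool dtype propagates into alldata.
# TODO: convert the photodiode column to bool as well (or check that its bool dtype propagates).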

In [None]:
# Print a column/dtype/non-null-count summary of the combined alldata frame
alldata.info()

In [None]:
# Force reload the modules
importlib.reload(harp_resources.process)
importlib.reload(harp_resources.utils)
# Reassign after reloading to ensure updated references
process = harp_resources.process
utils = harp_resources.utils

In [None]:
import numpy as np
import pandas as pd

def resample_to_1khz_grid(experiment_events, photometry_data, onix_analog_clock, photodiode, harp_streams):
    """
    Resample all datasets onto a uniform 1 kHz (1 ms resolution) time grid while preserving alignment.

    - Numeric (non-boolean) columns of experiment_events, photometry_data and harp_streams are
      interpolated in time onto the grid.
    - Other columns (e.g. event labels) are placed on the grid without interpolation, so they only
      keep values whose timestamp is exactly a grid point.
    - Boolean harp_streams columns keep their original timestamps (grid points are merged in and
      forward-filled).

    Photometry seconds and ONIX nanoseconds are both converted to datetimes relative to the same
    origin (1900-01-01). The photodiode interpolation uses that same origin for the grid.

    Parameters:
        experiment_events (DataFrame): Events with a DatetimeIndex.
        photometry_data (DataFrame): Photometry data with a "TimeStamp" column in seconds.
            The caller's DataFrame is not modified.
        onix_analog_clock (ndarray): ONIX timestamps in nanoseconds.
        photodiode (ndarray): ONIX analog samples, one per onix_analog_clock entry.
        harp_streams (DataFrame): HARP data with a DatetimeIndex (numeric and boolean columns).

    Returns:
        dict: "experiment_events_resampled", "photometry_data_resampled", "photodiode_resampled",
        "harp_numeric_resampled" (all indexed by the 1 kHz grid) and "harp_bool_aligned"
        (original boolean timestamps plus grid points, forward-filled).
    """
    time_origin = "1900-01-01"

    # Copy first so adding the helper "Datetime" column does not mutate the caller's DataFrame
    photometry_data = photometry_data.copy()
    photometry_data["Datetime"] = pd.to_datetime(photometry_data["TimeStamp"], unit="s", origin=time_origin)
    photometry_data = photometry_data.set_index("Datetime").drop(columns=["TimeStamp"])

    # ONIX timestamps (ns) -> datetime, on the same origin as the photometry
    onix_analog_clock = np.asarray(onix_analog_clock)
    onix_time_index = pd.to_datetime(onix_analog_clock, unit="ns", origin=time_origin)

    # A uniform 1 kHz grid spanning every dataset
    min_time = min(
        experiment_events.index.min(),
        photometry_data.index.min(),
        onix_time_index.min(),
        harp_streams.index.min()
    )
    max_time = max(
        experiment_events.index.max(),
        photometry_data.index.max(),
        onix_time_index.max(),
        harp_streams.index.max()
    )
    common_time_grid = pd.date_range(start=min_time, end=max_time, freq="1ms")
    print(f"⚡ Resampling to 1 kHz grid: {len(common_time_grid)} time points")

    # Split harp_streams into numeric and boolean columns
    harp_numeric = harp_streams.select_dtypes(exclude=['bool'])
    harp_bool = harp_streams.select_dtypes(include=['bool'])

    def resample_numeric(df):
        """Place df on the grid, interpolating (in time) only its numeric, non-boolean columns."""
        expanded = df.reindex(df.index.union(common_time_grid))
        # Interpolating object columns (e.g. event labels) is unsupported/deprecated in pandas
        numeric_cols = [
            col for col in df.columns
            if pd.api.types.is_numeric_dtype(df[col]) and not pd.api.types.is_bool_dtype(df[col])
        ]
        if numeric_cols:
            expanded[numeric_cols] = expanded[numeric_cols].interpolate(method='time')
        return expanded.reindex(common_time_grid)

    experiment_events_resampled = resample_numeric(experiment_events)
    photometry_data_resampled = resample_numeric(photometry_data)

    # Grid times as seconds since the same origin used for onix_time_index. Casting the grid to
    # int64 would give nanoseconds since the Unix epoch, which is a different time base than the
    # ONIX clock.
    grid_seconds = ((common_time_grid - pd.Timestamp(time_origin)) / pd.Timedelta(seconds=1)).to_numpy()
    photodiode_resampled = pd.DataFrame(
        index=common_time_grid,
        data=np.interp(grid_seconds, onix_analog_clock / 1e9, photodiode),
        columns=["photodiode"]
    )

    harp_numeric_resampled = resample_numeric(harp_numeric)

    # Boolean columns keep their original timestamps; grid points are merged in and forward-filled
    harp_bool_aligned = harp_bool.reindex(harp_bool.index.union(common_time_grid)).ffill()

    return {
        "experiment_events_resampled": experiment_events_resampled,
        "photometry_data_resampled": photometry_data_resampled,
        "photodiode_resampled": photodiode_resampled,
        "harp_numeric_resampled": harp_numeric_resampled,
        "harp_bool_aligned": harp_bool_aligned
    }




In [None]:
# Example usage: resample everything onto the 1 kHz grid, then unpack each dataset.
# NOTE(review): photometry_data, onix_analog_clock, photodiode and harp_streams are deleted by
# earlier cells (`del ...`), so they must be reloaded before this cell can run.
aligned_data = resample_to_1khz_grid(experiment_events, photometry_data, onix_analog_clock, photodiode, harp_streams)

(
    experiment_events_resampled,
    photometry_data_resampled,
    photodiode_resampled,
    harp_numeric_resampled,
    harp_bool_aligned,
) = (
    aligned_data[key]
    for key in (
        "experiment_events_resampled",
        "photometry_data_resampled",
        "photodiode_resampled",
        "harp_numeric_resampled",
        "harp_bool_aligned",
    )
)


In [None]:
# Distinct event labels in this session; useful when choosing the block/halt event names below
unique_events = experiment_events["Event"].unique()
print(unique_events)

In [None]:
# ---- Plotting Parameters ----
window_start = -1  # seconds relative to each halt: start of the analysis window to plot and average
window_stop = 5    # seconds relative to each halt: end of the window
how_many_to_plot = -1  # -1 plots all; X plots the first X halt events

# Event names for this protocol, chosen from the session path. If several patterns match, the
# last one in the original if-chain wins (Cohort2_test > VestibularMismatch > Visual_mismatch),
# so the elif chain tests them in that order. An unknown protocol fails here instead of with a
# NameError in a later cell.
if "Cohort2_test" in str(data_path):
    block_start_event = "DrumWithReverseHalt block started"
    halt_event = "Apply halt: 2s"
    block_end_event = "Block timer elapsed" # Set to "no_end" to scan all events
elif "VestibularMismatch" in str(data_path):
    block_start_event = "Sync signal started"
    halt_event = "DrumWithReverseflow block started"
    block_end_event = "no_end"
elif "Visual_mismatch" in str(data_path):
    block_start_event = "DrumWithReverseHalt block started"
    halt_event = "Apply halt: 2s"
    block_end_event = "Block timer elapsed" # Set to "no_end" to scan all events
else:
    raise ValueError(
        f"Unrecognised protocol for data_path {data_path}: expected 'Visual_mismatch', "
        "'VestibularMismatch' or 'Cohort2_test' in the path."
    )
    

In [None]:
%%time

def pad_arrays(array_list):
    """Stack 1D arrays of possibly different lengths into one NaN-padded float64 2D array.

    Row i holds ``array_list[i]`` followed by NaN up to the length of the longest array.
    An empty list returns an array of shape ``(0, 0)`` instead of raising.
    """
    max_len = max((len(arr) for arr in array_list), default=0)
    padded_array = np.full((len(array_list), max_len), np.nan, dtype=np.float64)
    for i, arr in enumerate(array_list):
        padded_array[i, :len(arr)] = arr
    return padded_array

def set_axis_limits(ax_run, ax_turn, data_run, data_turn, run_range=(-0.02, 0.10), turn_range=(-45, 45)):
    """Set y-limits on the running and turning axes, expanding a minimum range when data exceeds it.

    Parameters:
        ax_run, ax_turn: axes objects with a ``set_ylim`` method.
        data_run, data_turn: array-likes of plotted values. NaN/inf values are ignored. If no
            finite value remains, the minimum range is used unchanged.
        run_range: minimum (low, high) y-range for running (default -0.02 to +0.10).
        turn_range: minimum (low, high) y-range for turning (default -45 to +45).
    """
    def _expanded(data, default_range):
        values = np.asarray(data, dtype=float).ravel()
        finite = values[np.isfinite(values)]  # inf would be an invalid axis limit
        if finite.size == 0:
            return [default_range[0], default_range[1]]
        return [min(default_range[0], finite.min()), max(default_range[1], finite.max())]

    ax_run.set_ylim(_expanded(data_run, run_range))
    ax_turn.set_ylim(_expanded(data_turn, turn_range))


# ---- Extract Halt Events ----
# For each block_start_event, collect the halt_event rows that come after it and (unless
# block_end_event == "no_end") before the first block_end_event that follows it.
# NOTE(review): with "no_end" and more than one block start, a halt is collected once per earlier
# start, so block_halts can contain duplicates — confirm whether that is intended.
block_starts = experiment_events.query("Event == @block_start_event").index.to_numpy()

halt_events_list = []
for block_start in block_starts:
    if block_end_event == "no_end":
        block_halts = experiment_events.query("Event == @halt_event and index > @block_start")
    else:
        block_end = experiment_events.query("Event == @block_end_event and index > @block_start").index.min()
        if pd.notna(block_end):
            block_halts = experiment_events.query("Event == @halt_event and index > @block_start and index < @block_end")
        else:
            block_halts = pd.DataFrame()  # No valid end event found

    if not block_halts.empty:
        halt_events_list.append(block_halts)

block_halts = pd.concat(halt_events_list) if halt_events_list else pd.DataFrame()

if block_halts.empty:
    raise ValueError(f"⚠️ No [{halt_event}] events found between [{block_start_event}] and [{block_end_event}]. "
                     f"Check if the event names are correct and exist in experiment_events.")

# Halt times as a NumPy array
halt_event_times = block_halts.index.to_numpy()

# Any negative how_many_to_plot means "plot all". Larger than available is capped. (A negative
# value other than -1 used to be passed to the slice and silently dropped the last events.)
if how_many_to_plot < 0:
    how_many_to_plot = len(block_halts)
elif how_many_to_plot > len(block_halts):
    print(f"⚠️ Warning: Requested {how_many_to_plot} halts, but only {len(block_halts)} are available. "
          "Adjusting how_many_to_plot accordingly.")
    how_many_to_plot = len(block_halts)
print(f"Found {len(block_halts)} halt events within valid blocks, plotting {how_many_to_plot}.")

# Define colors
flow_x_color, flow_y_color, photodiode_color = "blue", "orange", "grey"
z_470_color, z_560_color = "green", "red"

# Common peri-halt time axis (seconds relative to each halt). Every trial is interpolated onto the
# FULL axis, with NaN where that trial has no data, so that sample i refers to the same relative
# time in every trial. (NaN-padding only at the end would shift trials whose data starts late.)
aligned_time = np.linspace(window_start, window_stop, 500)
flow_x_aligned, flow_y_aligned, photodiode_aligned, z_470_aligned, z_560_aligned = [], [], [], [], []

# NOTE(review): harp_to_onix_clock, onix_to_harp_timestamp, onix_time_to_photometry and
# photometry_to_harp_time are not defined in this notebook (the synchronisation cell returns
# onix_to_harp, harp_to_onix, photometry_to_onix, photometry_to_harp). harp_streams,
# photometry_data, onix_analog_clock and photodiode are also deleted by earlier cells — confirm
# the intended names/inputs.

# ----------------------
# First Plot: Individual Trials
# ----------------------
fig, ax_run = plt.subplots(figsize=(10, 6))

# Extra y-axes: turning (left, offset outward), photodiode and fluorescence (right, offset outward)
ax_turn = ax_run.twinx()
ax_turn.spines.left.set_position(('outward', 60))
ax_turn.yaxis.set_label_position('left')
ax_turn.yaxis.set_ticks_position('left')

ax_photo = ax_run.twinx()
ax_photo.spines.right.set_position(('outward', 60))

ax_fluor = ax_run.twinx()
ax_fluor.spines.right.set_position(('outward', 120))

all_plotted_x_values = []
all_plotted_y_values = []

# ---- Optical flow (running / turning) ----
for idx, halt_time in enumerate(block_halts.index[:how_many_to_plot]):
    halt_time_seconds = halt_time.timestamp()
    min_time, max_time = halt_time + pd.DateOffset(seconds=window_start), halt_time + pd.DateOffset(seconds=window_stop)

    optical_x = harp_streams['OpticalTrackingRead0X(46)'].loc[min_time:max_time].dropna()
    optical_y = harp_streams['OpticalTrackingRead0Y(46)'].loc[min_time:max_time].dropna()
    if optical_x.empty or optical_y.empty:
        continue

    optical_x_rel = (optical_x.index.astype("int64") / 1e9) - halt_time_seconds
    optical_y_rel = (optical_y.index.astype("int64") / 1e9) - halt_time_seconds
    all_plotted_x_values.extend(optical_x.values)
    all_plotted_y_values.extend(optical_y.values)

    label_x = "Running (Flow X)" if idx == 0 else None
    label_y = "Turning (Flow Y)" if idx == 0 else None
    ax_run.plot(optical_x_rel, optical_x, color=flow_x_color, alpha=0.3, label=label_x)
    ax_turn.plot(optical_y_rel, optical_y, color=flow_y_color, alpha=0.3, label=label_y)

    # Each signal is interpolated over its OWN time range; outside it np.interp yields NaN
    flow_x_aligned.append(np.interp(aligned_time, optical_x_rel, optical_x, left=np.nan, right=np.nan))
    flow_y_aligned.append(np.interp(aligned_time, optical_y_rel, optical_y, left=np.nan, right=np.nan))

# Set axis limits using only the plotted data
all_plotted_x_values = np.array(all_plotted_x_values)
all_plotted_y_values = np.array(all_plotted_y_values)
set_axis_limits(ax_run, ax_turn, all_plotted_x_values, all_plotted_y_values)

ax_run.set_xlabel("Relative Time (s)")
ax_run.set_ylabel("Running X (m/s)")
ax_turn.set_ylabel("Turning Y (deg/s)")
ax_photo.set_ylabel("Photodiode Signal")
ax_photo.set_ylim([0, 1.2])
ax_fluor.set_ylabel("Fluorescence Signal")

# ---- Photodiode (ONIX clock) ----
for idx, halt_time in enumerate(block_halts.index[:how_many_to_plot]):
    halt_time_seconds = halt_time.timestamp()

    onix_sec_start_time = harp_to_onix_clock(block_halts.iloc[idx]["Seconds"] + window_start)
    onix_sec_stop_time = harp_to_onix_clock(block_halts.iloc[idx]["Seconds"] + window_stop)

    onix_sec_start_index = np.searchsorted(onix_analog_clock, onix_sec_start_time)
    onix_sec_stop_index = np.searchsorted(onix_analog_clock, onix_sec_stop_time)

    onix_time_rel = (onix_to_harp_timestamp(onix_analog_clock[onix_sec_start_index:onix_sec_stop_index])
                     .astype("int64") / 1e9) - halt_time_seconds
    photodiode_signal = photodiode[onix_sec_start_index:onix_sec_stop_index]
    if len(onix_time_rel) == 0:
        continue  # no ONIX samples in this window; np.interp needs at least one sample point

    label_photodiode = "Photodiode" if idx == 0 else None
    ax_photo.plot(onix_time_rel, photodiode_signal, color=photodiode_color, alpha=0.5, label=label_photodiode)

    photodiode_aligned.append(np.interp(aligned_time, onix_time_rel, photodiode_signal, left=np.nan, right=np.nan))

# ---- Fluorescence (photometry clock) ----
if "TimeStamp" in photometry_data.columns:
    photometry_data = photometry_data.set_index("TimeStamp")

for idx, halt_time in enumerate(block_halts.index[:how_many_to_plot]):
    halt_time_seconds = halt_time.timestamp()

    photometry_sec_start_time = onix_time_to_photometry(harp_to_onix_clock(block_halts.iloc[idx]["Seconds"] + window_start))
    photometry_sec_stop_time = onix_time_to_photometry(harp_to_onix_clock(block_halts.iloc[idx]["Seconds"] + window_stop))

    photometry_sec = photometry_data.loc[photometry_sec_start_time:photometry_sec_stop_time]
    if photometry_sec.empty:
        continue

    photometry_time_rel = (photometry_to_harp_time(photometry_sec.index).astype("int64") / 1e9) - halt_time_seconds

    label_560, label_470 = "z_560" if idx == 0 else None, "z_470" if idx == 0 else None
    ax_fluor.plot(photometry_time_rel, photometry_sec['z_560'], color=z_560_color, alpha=0.3, label=label_560)
    ax_fluor.plot(photometry_time_rel, photometry_sec['z_470'], color=z_470_color, alpha=0.3, label=label_470)

    z_560_aligned.append(np.interp(aligned_time, photometry_time_rel, photometry_sec['z_560'], left=np.nan, right=np.nan))
    z_470_aligned.append(np.interp(aligned_time, photometry_time_rel, photometry_sec['z_470'], left=np.nan, right=np.nan))

ax_run.legend(loc="upper left")
ax_turn.legend(loc="upper left", bbox_to_anchor=(0, 0.9))
ax_photo.legend(loc="center right")
ax_fluor.legend(loc="upper right")

plt.title("Individual Trials")
plt.tight_layout()
plt.show()

# ----------------------
# Second Plot: Averages with Error Shading (Proper Axes Labels & No Overlap)
# ----------------------
fig, ax_run = plt.subplots(figsize=(10, 6))

# Extra y-axes: turning (left, offset outward), photodiode and fluorescence (right, offset outward)
ax_turn = ax_run.twinx()
ax_turn.spines.left.set_position(('outward', 60))
ax_turn.yaxis.set_label_position('left')
ax_turn.yaxis.set_ticks_position('left')

ax_photo = ax_run.twinx()
ax_photo.spines.right.set_position(('outward', 60))

ax_fluor = ax_run.twinx()
ax_fluor.spines.right.set_position(('outward', 120))

# Stack per-trial traces into (n_trials, n_samples) arrays, NaN-padded to a common length
flow_x_aligned_padded = pad_arrays(flow_x_aligned)
flow_y_aligned_padded = pad_arrays(flow_y_aligned)
photodiode_aligned_padded = pad_arrays(photodiode_aligned)
z_560_aligned_padded = pad_arrays(z_560_aligned)
z_470_aligned_padded = pad_arrays(z_470_aligned)

# Number of trials with (non-NaN) data at each time point
valid_counts_560 = np.sum(~np.isnan(z_560_aligned_padded), axis=0)
valid_counts_470 = np.sum(~np.isnan(z_470_aligned_padded), axis=0)

# Last time index at which at least 80% of trials still have z_560 data
threshold = 0.8 * len(z_560_aligned_padded)  # Adjustable threshold (80%)
valid_indices = np.flatnonzero(valid_counts_560 >= threshold)
if valid_indices.size:
    adaptive_cutoff = valid_indices[-1]  # Last valid index
else:
    # No time point reaches the threshold: use the full window instead of raising IndexError
    adaptive_cutoff = aligned_time.shape[0] - 1
    print("⚠️ No time point has z_560 data in at least 80% of trials; adaptive cutoff not applied.")

# Shortest of all traces and the adaptive cutoff. adaptive_cutoff is an INDEX, so +1 turns it into
# a length that still includes that last valid sample.
min_length = min(
    aligned_time.shape[0],
    photodiode_aligned_padded.shape[1],
    z_560_aligned_padded.shape[1],
    z_470_aligned_padded.shape[1],
    flow_x_aligned_padded.shape[1],
    flow_y_aligned_padded.shape[1],
    adaptive_cutoff + 1,
)

print(f"🔍 Adaptive cutoff applied at index {adaptive_cutoff}, using min_length = {min_length}")

# Means across trials (NaN ignored)
flow_x_mean = np.nanmean(flow_x_aligned_padded, axis=0)
flow_y_mean = np.nanmean(flow_y_aligned_padded, axis=0)
photodiode_mean = np.nanmean(photodiode_aligned_padded, axis=0)
z_560_mean = np.nanmean(z_560_aligned_padded, axis=0)
z_470_mean = np.nanmean(z_470_aligned_padded, axis=0)

# SEM = std / sqrt(number of trials with data at that time point)
flow_x_sem = np.nanstd(flow_x_aligned_padded, axis=0) / np.sqrt(np.sum(~np.isnan(flow_x_aligned_padded), axis=0))
flow_y_sem = np.nanstd(flow_y_aligned_padded, axis=0) / np.sqrt(np.sum(~np.isnan(flow_y_aligned_padded), axis=0))
photodiode_sem = np.nanstd(photodiode_aligned_padded, axis=0) / np.sqrt(np.sum(~np.isnan(photodiode_aligned_padded), axis=0))
z_560_sem = np.nanstd(z_560_aligned_padded, axis=0) / np.sqrt(np.sum(~np.isnan(z_560_aligned_padded), axis=0))
z_470_sem = np.nanstd(z_470_aligned_padded, axis=0) / np.sqrt(np.sum(~np.isnan(z_470_aligned_padded), axis=0))

# Truncate the time axis, means and SEMs to min_length
aligned_time = aligned_time[:min_length]
flow_x_mean = flow_x_mean[:min_length]
flow_y_mean = flow_y_mean[:min_length]
photodiode_mean = photodiode_mean[:min_length]
z_560_mean = z_560_mean[:min_length]
z_470_mean = z_470_mean[:min_length]

flow_x_sem = flow_x_sem[:min_length]
flow_y_sem = flow_y_sem[:min_length]
photodiode_sem = photodiode_sem[:min_length]
z_560_sem = z_560_sem[:min_length]
z_470_sem = z_470_sem[:min_length]

# Running mean and SEM
ax_run.plot(aligned_time, flow_x_mean, color=flow_x_color, label="Running (Mean)")
ax_run.fill_between(aligned_time, flow_x_mean - flow_x_sem, flow_x_mean + flow_x_sem, color=flow_x_color, alpha=0.2)

# Turning mean and SEM
ax_turn.plot(aligned_time, flow_y_mean, color=flow_y_color, label="Turning (Mean)")
ax_turn.fill_between(aligned_time, flow_y_mean - flow_y_sem, flow_y_mean + flow_y_sem, color=flow_y_color, alpha=0.2)

# Photodiode mean and SEM
ax_photo.plot(aligned_time, photodiode_mean, color=photodiode_color, label="Photodiode (Mean)")
ax_photo.fill_between(aligned_time, photodiode_mean - photodiode_sem, photodiode_mean + photodiode_sem,
                      color=photodiode_color, alpha=0.2)

# Fluorescence means and SEM
ax_fluor.plot(aligned_time, z_560_mean, color=z_560_color, label="z_560 (Mean)")
ax_fluor.fill_between(aligned_time, z_560_mean - z_560_sem, z_560_mean + z_560_sem, color=z_560_color, alpha=0.2)
ax_fluor.plot(aligned_time, z_470_mean, color=z_470_color, label="z_470 (Mean)")
ax_fluor.fill_between(aligned_time, z_470_mean - z_470_sem, z_470_mean + z_470_sem, color=z_470_color, alpha=0.2)

set_axis_limits(ax_run, ax_turn,
                np.concatenate([flow_x_mean - flow_x_sem, flow_x_mean + flow_x_sem]),
                np.concatenate([flow_y_mean - flow_y_sem, flow_y_mean + flow_y_sem]))

ax_run.set_xlabel("Relative Time (s)")
ax_run.set_ylabel("Running X (m/s)")
ax_turn.set_ylabel("Turning Y (deg/s)")
ax_photo.set_ylabel("Photodiode Signal")
ax_fluor.set_ylabel("Fluorescence Signal")

ax_run.legend(loc="upper left")
ax_turn.legend(loc="upper left", bbox_to_anchor=(0, 0.9))
ax_photo.legend(loc="center right")
ax_fluor.legend(loc="upper right")

plt.title("Trial Averages")
plt.tight_layout()
plt.show()


In [None]:
# Memory-usage snapshot from the project helper, kept as pympler_memory_df for inspection
# (presumably a per-object breakdown from pympler, per the name — check the helper for its columns)
pympler_memory_df = utils.get_pympler_memory_usage()


In [None]:
# Final clean-up: run the garbage collector. Uncomment the del if onix_analog_clock is still
# defined at this point and is no longer needed.
#del onix_analog_clock
gc.collect()