In [None]:
# Import packages and connect to the IBL
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import yaml
import random
import pandas as pd
import seaborn as sns
import cv2
import ibllib.io.video as vidio
import matplotlib.gridspec as gridspec
import matplotlib.colors as mcolors
import umap.umap_ as umap
from daart.data import DataGenerator
from scipy.interpolate import interp1d
from daart.transforms import ZScore
from daart.models import Segmenter
from one.api import ONE
from brainbox.io.one import SessionLoader
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from sklearn.metrics import f1_score
from matplotlib.patches import Patch
from matplotlib.colors import ListedColormap
from matplotlib.lines import Line2D
from collections import Counter

# Connect to the IBL database
ONE.setup(base_url='https://openalyx.internationalbrainlab.org', silent=True)
one = ONE(password='international')

# Load the Model
model_dir = '/Users/zacharyzusin/Documents/NeuroscienceResearch/version_0'
model_file = os.path.join(model_dir, 'best_val_model.pt')

# Read the hyperparameters that were saved next to the checkpoint.
# A context manager is used so the file handle is closed as soon as parsing finishes.
hparams_file = os.path.join(model_dir, 'hparams.yaml')
with open(hparams_file, 'rb') as hparams_fh:
    hparams = yaml.safe_load(hparams_fh)

# Prefer the first CUDA device when one is available, otherwise run on CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model = Segmenter(hparams)
# map_location keeps every tensor on the storage it was deserialized into;
# the explicit model.to(device) below then moves the weights to the chosen device.
model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))
model.to(device)
model.eval()

In [68]:
# Session Information (dropbox dataset)
# Session eids (UUID strings) to process. The trailing comment on each entry
# gives the lab / subject / date / session number of that session. Two entries
# are commented out because those sessions have no trials table.
dropbox_eids = [
    '46794e05-3f6a-4d35-afb3-9165091a5a74',  # churchlandlab/CSHL045/2020-02-27-001
    #'c7b0e1a3-4d4d-4a76-9339-e73d0ed5425b',  # cortexlab/KS020/2020-02-06/001 (no trials table)
    'db4df448-e449-4a6f-a0e7-288711e7a75a',  # danlab/DY_009/2020-02-27/001
    '54238fd6-d2d0-4408-b1a9-d19d24fd29ce',  # danlab_DY_018_2020-10-15-001
    'f3ce3197-d534-4618-bf81-b687555d1883',  # hoferlab_SWC_043_2020-09-15-001
    '493170a6-fd94-4ee4-884f-cc018c17eeb9',  # hoferlab_SWC_061_2020-11-23-001
    '7cb81727-2097-4b52-b480-c89867b5b34c',  # mrsicflogellab_SWC_052_2020-10-22-001
    #'1735d2be-b388-411a-896a-60b01eaa1cfe',  # mrsicflogellab_SWC_058_2020-12-11-001 (no trials table)
    'ff96bfe1-d925-4553-94b5-bf8297adf259',  # wittenlab/ibl_witten_26/2021-01-27-002
    '73918ae1-e4fd-4c18-b132-00cb555b1ad2',  # wittenlab_ibl_witten_27_2021-01-21-001
]

# Maps each session eid to its '<lab>_<subject>_<date>-<number>' name. The
# functions in this notebook look the name up as `sess_id` and use it in plot
# titles and output file paths. Unlike `dropbox_eids`, this mapping also keeps
# the two sessions that are commented out of the list above.
dropbox_marker_paths = {
    '46794e05-3f6a-4d35-afb3-9165091a5a74': 'churchlandlab_CSHL045_2020-02-27-001',
    'c7b0e1a3-4d4d-4a76-9339-e73d0ed5425b': 'cortexlab_KS020_2020-02-06-001',
    'db4df448-e449-4a6f-a0e7-288711e7a75a': 'danlab_DY_009_2020-02-27-001',
    '54238fd6-d2d0-4408-b1a9-d19d24fd29ce': 'danlab_DY_018_2020-10-15-001',
    'f3ce3197-d534-4618-bf81-b687555d1883': 'hoferlab_SWC_043_2020-09-15-001',
    '493170a6-fd94-4ee4-884f-cc018c17eeb9': 'hoferlab_SWC_061_2020-11-23-001',
    '7cb81727-2097-4b52-b480-c89867b5b34c': 'mrsicflogellab_SWC_052_2020-10-22-001',
    '1735d2be-b388-411a-896a-60b01eaa1cfe': 'mrsicflogellab_SWC_058_2020-12-11-001',
    'ff96bfe1-d925-4553-94b5-bf8297adf259': 'wittenlab_ibl_witten_26_2021-01-27-002',
    '73918ae1-e4fd-4c18-b132-00cb555b1ad2': 'wittenlab_ibl_witten_27_2021-01-21-001'
}

In [69]:
# Function to create data generator
def create_data_generator(sess_id, input_file):
    """Build a single-session daart ``DataGenerator`` for a marker file.

    Args:
        sess_id: identifier passed to the generator as its only session id.
        input_file: path of the file that provides the 'markers' signal.

    Returns:
        A ``DataGenerator`` with one 'markers' signal, a ``ZScore`` transform,
        and batching/sequence settings taken from the notebook-level
        ``hparams`` dict, placed on the notebook-level ``device``.
    """
    # One session with one signal: every argument is a one-element list
    session_signals = ['markers']
    session_transforms = [ZScore()]
    session_paths = [input_file]

    # Batching / sequence options are the ones stored with the model's hparams
    generator_options = {
        'batch_size': hparams['batch_size'],
        'sequence_length': hparams['sequence_length'],
        'sequence_pad': hparams['sequence_pad'],
        'trial_splits': hparams['trial_splits'],
        'input_type': hparams['input_type'],
        'device': device,
    }

    return DataGenerator(
        [sess_id], [session_signals], [session_transforms], [session_paths],
        **generator_options
    )

In [70]:
# Function to run inference
def get_states(model, data_gen):
    """Run the model on a data generator and return the most likely state per row.

    Args:
        model: object exposing ``predict_labels(data_gen, return_scores=True)``.
        data_gen: data generator handed straight to ``predict_labels``.

    Returns:
        1-D integer array holding, for each stacked row of scores, the index
        of the highest-scoring class.
    """
    prediction = model.predict_labels(data_gen, return_scores=True)
    # 'labels' holds one entry per dataset; take the first and stack its chunks row-wise
    stacked_scores = np.vstack(prediction['labels'][0])
    return stacked_scores.argmax(axis=1)

In [71]:
# Function to calculate average durations
def calculate_average_durations(states, unique_states=None):
    """Return the mean run length (in samples) of each state in a label sequence.

    A "run" is a maximal stretch of consecutive samples equal to the state.

    Args:
        states: 1-D sequence (list or array) of state labels.
        unique_states: iterable of the state values to report, in output order.
            Defaults to ``[0, 1, 2, 3, 4]``.

    Returns:
        A list with one entry per requested state: the mean run length, or 0
        when the state never occurs.
    """
    if unique_states is None:
        unique_states = [0, 1, 2, 3, 4]

    states_arr = np.asarray(states)
    average_durations = []

    for state in unique_states:
        # Pad the membership mask with False on both sides so every run has both a
        # rising and a falling edge, including runs touching the start or the end.
        in_state = np.concatenate(([False], states_arr == state, [False]))
        edges = np.flatnonzero(in_state[1:] != in_state[:-1])
        # Edges alternate start, stop, start, stop, ...; stop - start is the run length
        durations = edges[1::2] - edges[::2]

        if durations.size:
            average_durations.append(np.mean(durations))
        else:
            average_durations.append(0)

    return average_durations

In [72]:
# Functions to Perform Smoothing
def non_uniform_savgol(x, y, window, polynom):
    """Savitzky-Golay style smoothing for samples that may be non-uniformly spaced.

    For every sample that has ``window // 2`` neighbours on each side, a
    polynomial of degree ``polynom`` is least-squares fitted to the ``window``
    samples centred on it, and the fitted value at that sample is used. The
    first and last ``window // 2`` samples, which have no full window, are
    copied from ``y`` unchanged.

    Args:
        x: 1-D array of sample positions (same length as ``y``).
        y: 1-D array of sample values.
        window: odd integer number of samples per local fit.
        polynom: integer polynomial degree; must be smaller than ``window``.

    Returns:
        1-D float array of the same length as ``y`` with the smoothed values.

    Raises:
        ValueError: on length mismatch, too few samples, an even ``window``, or
            ``polynom >= window``.
        TypeError: if ``window`` or ``polynom`` is not an ``int``.
    """
    if len(x) != len(y):
        raise ValueError('"x" and "y" must be of the same size')
    if len(x) < window:
        raise ValueError('The data size must be larger than the window size')
    if type(window) is not int:
        raise TypeError('"window" must be an integer')
    if window % 2 == 0:
        raise ValueError('The "window" must be an odd integer')
    if type(polynom) is not int:
        raise TypeError('"polynom" must be an integer')
    if polynom >= window:
        raise ValueError('"polynom" must be less than "window"')

    x = np.asarray(x)
    y = np.asarray(y)

    n_samples = len(x)
    half_window = window // 2
    n_coeffs = polynom + 1  # degree-d polynomial has d + 1 coefficients

    # Design matrix of the local fit; the constant column never changes
    design = np.empty((window, n_coeffs))
    design[:, 0] = 1.0
    y_smoothed = np.full(n_samples, np.nan)

    for i in range(half_window, n_samples - half_window):
        lo, hi = i - half_window, i + half_window + 1
        # Centre the window on x[i] so the constant coefficient is the fitted value at x[i]
        offsets = x[lo:hi] - x[i]
        for j in range(1, n_coeffs):
            design[:, j] = offsets ** j
        # Solve the normal equations directly instead of forming an explicit inverse
        coeffs = np.linalg.solve(design.T @ design, design.T @ y[lo:hi])
        y_smoothed[i] = coeffs[0]

    # Only the samples without a full window keep their raw value. Index
    # half_window and index n_samples - half_window - 1 are smoothed above
    # and must not be overwritten.
    y_smoothed[:half_window] = y[:half_window]
    y_smoothed[n_samples - half_window:] = y[n_samples - half_window:]

    return y_smoothed

# Function to smooth and interpolate the marker signals
def smooth_interpolate_signal_sg(signal, window=31, order=3, interp_kind='linear'):
    """Smooth the non-NaN samples of a 1-D signal and fill its NaN gaps.

    The non-NaN samples are smoothed with ``non_uniform_savgol`` using their
    sample indices as positions. The smoothed samples are then interpolated
    (and extrapolated at the ends) onto every sample index.

    Args:
        signal: 1-D array that may contain NaNs.
        window: odd window length passed to ``non_uniform_savgol``.
        order: polynomial order passed to ``non_uniform_savgol``.
        interp_kind: ``kind`` argument for ``scipy.interpolate.interp1d``.

    Returns:
        A new array of the same length. If fewer than ``window`` samples are
        non-NaN, an unmodified copy of ``signal`` is returned and a message
        is printed.
    """
    signal_noisy_w_nans = np.copy(signal)
    timestamps = np.arange(signal_noisy_w_nans.shape[0])
    good_idxs = np.where(~np.isnan(signal_noisy_w_nans))[0]

    if len(good_idxs) < window:
        print('Not enough non-nan indices to filter; returning original signal')
        return signal_noisy_w_nans

    good_times = timestamps[good_idxs]
    signal_smooth_nonans = non_uniform_savgol(
        good_times, signal_noisy_w_nans[good_idxs], window=window, polynom=order)

    # Interpolating the smoothed samples over all indices both reproduces them
    # at the valid positions and fills the gaps, so no intermediate NaN-holding
    # copy of the signal is needed.
    interpolater = interp1d(
        good_times, signal_smooth_nonans, kind=interp_kind, fill_value='extrapolate')
    return interpolater(timestamps)

In [73]:
# Function to extract and/or smooth ibl data
def extract_marker_data(eid, l_thresh, view, paw, smooth):
    """Build the z-scored feature table for one session and write it to CSV.

    The features are the paw x/y position, the wheel velocity resampled at the
    camera times, and the frame-to-frame difference of each of those three.

    Args:
        eid: session eid passed to ``SessionLoader``.
        l_thresh: likelihood threshold passed to ``SessionLoader.load_pose``.
        view: camera view name; pose is read from ``sl.pose[f'{view}Camera']``.
        paw: marker name prefix; columns ``f'{paw}_x'`` and ``f'{paw}_y'`` are used.
        smooth: if truthy, the paw traces are smoothed and gap-filled with
            ``smooth_interpolate_signal_sg(window=7)`` and the file gets the
            ``_features_smooth.csv`` suffix; otherwise ``_features.csv``.

    Returns:
        Path of the CSV file that was written.
    """
    # Load the pose data
    sl = SessionLoader(one=one, eid=eid)
    sl.load_pose(likelihood_thr=l_thresh, views=[view])
    pose = sl.pose[f'{view}Camera']
    times = pose.times.to_numpy()
    # Copy so the in-place smoothing below cannot modify the loader's own table
    markers = pose.loc[:, [f'{paw}_x', f'{paw}_y']].to_numpy(copy=True)

    # Load wheel data and resample the velocity at the camera frame times
    sl.load_wheel()
    interpolator = interp1d(
        sl.wheel.times.to_numpy(), sl.wheel.velocity.to_numpy(), fill_value='extrapolate')
    wh_vel = interpolator(times)

    if smooth:
        # Smooth the marker data, one coordinate at a time
        for col in range(markers.shape[1]):
            markers[:, col] = smooth_interpolate_signal_sg(markers[:, col], window=7)

    # Append frame-to-frame differences; the first frame has no predecessor, so it gets 0
    markers_comb = np.hstack([markers, wh_vel[:, None]])
    velocity = np.vstack([np.zeros((1, markers_comb.shape[1])), np.diff(markers_comb, axis=0)])
    markers_comb = np.hstack([markers_comb, velocity])

    # NaN-aware statistics: samples that are still NaN (presumably frames below the
    # likelihood threshold -- confirm against SessionLoader) must not turn the whole
    # z-scored column into NaN.
    # NOTE(review): a constant column still has zero std and divides by zero.
    col_mean = np.nanmean(markers_comb, axis=0)
    col_std = np.nanstd(markers_comb, axis=0)
    markers_z = (markers_comb - col_mean) / col_std

    feature_names = ['paw_x_pos', 'paw_y_pos', 'wheel_vel', 'paw_x_vel', 'paw_y_vel', 'wheel_acc']
    df = pd.DataFrame(markers_z, columns=feature_names)

    features_dir = '/Users/zacharyzusin/Documents/NeuroscienceResearch/daart/features'
    suffix = '_features_smooth.csv' if smooth else '_features.csv'
    markers_file = f'{features_dir}/{eid}{suffix}'
    df.to_csv(markers_file)
    return markers_file

In [74]:
# Function to create a heatmap of the predicted states
def create_predicted_states_heatmap(eid, data_dict):
    """Plot predicted state sequences as a heatmap, one row per source, and save it.

    When more than one source is given, the macro-averaged F1 score of every
    pair of sources is written below the plot.

    Args:
        eid: session eid; must be a key of ``dropbox_marker_paths``.
        data_dict: mapping from a source name to its sequence of state labels.
            The colour-bar ticks assume labels in the range 1..4.

    Side effects:
        Shows the figure and writes ``<sess_id>_heatmap_states.png`` into the
        session's ``repro_ephys_analysis`` directory, creating it if needed.
    """
    # Convert eid to sess_id internally
    sess_id = dropbox_marker_paths[eid]

    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']
    cmap = sns.color_palette("tab10", n_colors=4)

    # One heatmap row per source
    keys = list(data_dict.keys())
    if len(keys) > 1:
        data = np.vstack([data_dict[key] for key in keys])
    else:
        # A single 1-D sequence becomes a 1 x T row; 2-D input is left as is
        data = np.atleast_2d(data_dict[keys[0]])

    # Pairwise F1 scores; the dict stays empty when there is only one source
    f1_scores = {
        f'{key1} vs {key2}': f1_score(data_dict[key1], data_dict[key2], average='macro')
        for idx, key1 in enumerate(keys)
        for key2 in keys[idx + 1:]
    }

    # Plot the heatmap
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.heatmap(data, cmap=cmap, cbar=True, yticklabels=keys, annot=False, ax=ax)

    # Put one tick in the middle of each of the four colour bands spanning 1..4
    colorbar = ax.collections[0].colorbar
    colorbar.set_ticks([1.375, 2.125, 2.875, 3.625])
    colorbar.set_ticklabels(state_labels)

    ax.set_title(f'Heatmap of Discrete States for {sess_id}')
    ax.set_xlabel('Time')
    ax.set_ylabel('Source')

    if f1_scores:
        f1_text = '\n\n\n' + 'F1 Scores: ' + ', '.join([f'{key}: {f1:.2f}' for key, f1 in f1_scores.items()])
        ax.text(0.5, -0.1, f1_text, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)

    # Create the session directory first so savefig cannot fail on a missing folder,
    # and save before showing so the file never depends on the figure surviving show()
    output_dir = f'/Users/zacharyzusin/Documents/NeuroscienceResearch/repro_ephys_analysis/{sess_id}'
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f'{sess_id}_heatmap_states.png'))
    plt.show()
    


In [75]:
# Function to create a plot of the frequency of each state
def create_state_frequency_plot(eid, data_dict):
    """Plot, per source, how many frames were assigned to each of states 1..4.

    Args:
        eid: session eid; must be a key of ``dropbox_marker_paths``.
        data_dict: mapping from a source name to its sequence of state labels.

    Side effects:
        Shows a grouped bar chart and writes ``<sess_id>_state_frequency.png``
        into the session's ``repro_ephys_analysis`` directory, creating it if needed.
    """
    sess_id = dropbox_marker_paths[eid]

    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']
    unique_states = [1, 2, 3, 4]

    # Frames per state for every source (np.asarray so plain lists compare element-wise)
    state_counts = {
        key: [int(np.sum(np.asarray(values) == state)) for state in unique_states]
        for key, values in data_dict.items()
    }

    fig, ax = plt.subplots(figsize=(12, 6))
    width = 0.25  # the width of the bars

    # Left edge of each state's group; one empty bar slot separates the groups
    group_starts = np.arange(len(state_labels)) * (len(data_dict) + 1) * width

    # One bar per source inside every group
    for i, key in enumerate(data_dict):
        ax.bar(group_starts + i * width, state_counts[key], width=width, edgecolor='grey', label=key)

    ax.set_xlabel('States', fontweight='bold')
    ax.set_ylabel('Frames', fontweight='bold')
    ax.set_title(f'Time Spent in Each State for {sess_id}', fontweight='bold')
    # Centre the tick under each group of bars
    ax.set_xticks(group_starts + width * (len(data_dict) - 1) / 2)
    ax.set_xticklabels(state_labels)

    # The bars carry labels, so draw the legend that identifies each source
    if data_dict:
        ax.legend(title='Source')

    # Create the session directory first so savefig cannot fail on a missing folder
    output_dir = f'/Users/zacharyzusin/Documents/NeuroscienceResearch/repro_ephys_analysis/{sess_id}'
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f'{sess_id}_state_frequency.png'))
    plt.show()

In [76]:
# Function to create a plot of the average duration spent in each state
def create_state_duration_plot(eid, data_dict):
    """Plot, per source, the average run length (in frames) of each of states 1..4.

    Args:
        eid: session eid; must be a key of ``dropbox_marker_paths``.
        data_dict: mapping from a source name to its sequence of state labels.

    Side effects:
        Shows a grouped bar chart and writes ``<sess_id>_state_duration.png``
        into the session's ``repro_ephys_analysis`` directory, creating it if needed.
    """
    sess_id = dropbox_marker_paths[eid]

    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']

    # calculate_average_durations reports states 0..4; drop state 0 to match the four labels
    avg_durations = {
        key: calculate_average_durations(values)[1:]
        for key, values in data_dict.items()
    }

    fig, ax = plt.subplots(figsize=(12, 6))
    width = 0.25  # the width of the bars

    # Left edge of each state's group; one empty bar slot separates the groups
    group_starts = np.arange(len(state_labels)) * (len(data_dict) + 1) * width

    # One bar per source inside every group
    for i, key in enumerate(data_dict):
        ax.bar(group_starts + i * width, avg_durations[key], width=width, edgecolor='grey', label=key)

    ax.set_xlabel('States', fontweight='bold')
    ax.set_ylabel('Frames', fontweight='bold')
    ax.set_title(f'Average Time Spent in Each State for {sess_id}', fontweight='bold')
    # Centre the tick under each group of bars
    ax.set_xticks(group_starts + width * (len(data_dict) - 1) / 2)
    ax.set_xticklabels(state_labels)

    # The bars carry labels, so draw the legend that identifies each source
    if data_dict:
        ax.legend(title='Source')

    # Create the session directory first so savefig cannot fail on a missing folder
    output_dir = f'/Users/zacharyzusin/Documents/NeuroscienceResearch/repro_ephys_analysis/{sess_id}'
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f'{sess_id}_state_duration.png'))
    plt.show()

In [77]:
# Function to create video with predictions overlayed
def create_video_comparing_predictions(eid, states_dropbox_truncated, states_ibl_truncated, states_ibl_smooth_truncated):
    """Write a short video with three sets of state predictions drawn on every frame.

    At most the first 1000 frames of the session's 'left' video are used, and
    never more frames than the shortest of the three prediction sequences.

    Args:
        eid: session eid; used to resolve the video URL and, through
            ``dropbox_marker_paths``, the output file name.
        states_dropbox_truncated: per-frame state labels from the dropbox markers.
        states_ibl_truncated: per-frame state labels from the IBL markers.
        states_ibl_smooth_truncated: per-frame state labels from the smoothed IBL markers.
            Each label ``s`` is shown as ``state_labels[s - 1]``.

    Side effects:
        Writes ``<session name>_three_datasets.mp4`` into the
        ``dataset_comparison_videos`` directory and prints its path.
    """
    # State labels
    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']

    max_frames = 1000
    # The rendered figure must have exactly the pixel size given to the video writer.
    # DPI is set explicitly so the size does not depend on the active rcParams.
    fig_size_inches = (10, 6)
    fig_dpi = 100
    frame_size = (fig_size_inches[0] * fig_dpi, fig_size_inches[1] * fig_dpi)  # (1000, 600)

    # Resolve the video
    label = 'left'
    video_path = vidio.url_from_eid(eid, one=one)[label]

    # Define the output video path and create directory
    output_dir = '/Users/zacharyzusin/Documents/NeuroscienceResearch/video_analysis/dataset_comparison_videos'
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, dropbox_marker_paths[eid] + '_three_datasets.mp4')

    def overlay_predictions(frame, frame_num, dropbox_pred, ibl_pred, ibl_smooth_pred):
        """Render one video frame with the three predictions as an RGB uint8 image."""
        fig = Figure(figsize=fig_size_inches, dpi=fig_dpi)
        canvas = FigureCanvas(fig)
        ax = fig.gca()

        # Display the video frame (OpenCV delivers BGR, matplotlib expects RGB)
        ax.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        # Add text for predictions
        ax.text(10, 30, f"Dropbox: {state_labels[dropbox_pred-1]}", color='white', fontsize=12, bbox=dict(facecolor='blue', alpha=0.5))
        ax.text(10, 60, f"IBL: {state_labels[ibl_pred-1]}", color='white', fontsize=12, bbox=dict(facecolor='red', alpha=0.5))
        ax.text(10, 90, f"IBL Smooth: {state_labels[ibl_smooth_pred-1]}", color='white', fontsize=12, bbox=dict(facecolor='green', alpha=0.5))

        # Add frame number
        ax.text(10, frame.shape[0] - 10, f"Frame: {frame_num}", color='white', fontsize=12)

        # Remove axes
        ax.axis('off')

        # Rasterize, then drop the alpha channel. buffer_rgba() is used instead of
        # tostring_rgb(), which is deprecated in recent Matplotlib releases.
        canvas.draw()
        rgba = np.asarray(canvas.buffer_rgba())
        return np.ascontiguousarray(rgba[..., :3])

    # Never index past the end of any of the prediction sequences
    n_frames = min(
        max_frames,
        len(states_dropbox_truncated),
        len(states_ibl_truncated),
        len(states_ibl_smooth_truncated),
    )

    cap = cv2.VideoCapture(video_path)
    # Set up video writer
    # NOTE(review): the output rate is fixed at 30 fps regardless of the source video's rate.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, 30.0, frame_size)

    try:
        for i in range(n_frames):
            ret, frame = cap.read()
            if not ret:
                break

            # Overlay predictions
            frame_with_pred = overlay_predictions(
                frame, i, states_dropbox_truncated[i], states_ibl_truncated[i], states_ibl_smooth_truncated[i])

            # Write frame (the writer expects BGR)
            out.write(cv2.cvtColor(frame_with_pred, cv2.COLOR_RGB2BGR))
    finally:
        # Release everything, even if rendering a frame failed
        cap.release()
        out.release()
        cv2.destroyAllWindows()

    print(f"Video creation completed. Check '{output_path}'")

In [78]:
# Function to create video with trial markers
def create_video_with_trial_markers(eid, states_ibl_smooth,  l_thresh, view, paw):
    """Write a video of selected trials annotated with state, trial events and paw position.

    The user is prompted (via ``input``) for the trials to include. For each
    selected trial the frames between the trial's start and end are copied
    from the session's 'left' video, preceded by a one-second black title card
    showing the trial number. Each copied frame is annotated with the predicted
    state, the frame number, a caption on frames that coincide with stimulus
    onset / first movement / response, and a filled circle at the paw position.

    Args:
        eid: session eid; must be a key of ``dropbox_marker_paths``.
        states_ibl_smooth: per-frame state labels, indexed by video frame
            number; label ``s`` is shown as ``state_labels[int(s) - 1]``.
        l_thresh: likelihood threshold passed to ``SessionLoader.load_pose``.
        view: camera view whose pose table provides the paw position.
        paw: marker name prefix; columns ``f'{paw}_x'`` / ``f'{paw}_y'`` are used.

    Side effects:
        Blocks on ``input``; writes ``<sess_id>_predictions.mp4`` into the
        session's ``repro_ephys_analysis`` directory (created if missing) and
        prints progress and the final path.
    """
    sess_id = dropbox_marker_paths[eid]
    
    # State labels
    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']
    
    # Load the video
    label = 'left'
    video_path = vidio.url_from_eid(eid, one=one)[label]
    cap = cv2.VideoCapture(video_path)

    # Get the frame rate of the video
    input_fps = cap.get(cv2.CAP_PROP_FPS)

    # Define the output video path and create directory
    output_dir = f'/Users/zacharyzusin/Documents/NeuroscienceResearch/repro_ephys_analysis/{sess_id}'
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, dropbox_marker_paths[eid] + '_predictions.mp4')

    # Set up video writer with the same frame rate and frame size as the input video
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, input_fps, (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))

    # Function to overlay predictions, paw markers, and trial number on each frame of the video
    def overlay_predictions(frame, frame_num, ibl_smooth_pred=None, trial_number=None, text=None, paw_position=None):
        """Draw the requested annotations onto ``frame`` in place and return it."""
        # Add text for prediction if ibl_smooth_pred is not None
        if ibl_smooth_pred is not None:
            cv2.putText(frame, f"IBL Smooth: {state_labels[int(ibl_smooth_pred)-1]}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv2.LINE_AA)

        # Add frame number
        cv2.putText(frame, f"Frame: {frame_num}", (10, frame.shape[0] - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

        # Add "Trial n" text at the center of the frame if trial_number is not None
        if trial_number is not None:
            trial_text = f"Trial {trial_number}"
            textsize = cv2.getTextSize(trial_text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0]
            textX = (frame.shape[1] - textsize[0]) // 2
            textY = (frame.shape[0] + textsize[1]) // 2
            cv2.putText(frame, trial_text, (textX, textY), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

        # Add additional text if provided (horizontally centred, near the top)
        if text is not None:
            textsize = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0]
            textX = (frame.shape[1] - textsize[0]) // 2
            textY = 50  # Top of the frame
            cv2.putText(frame, text, (textX, textY), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

        # Add paw marker if provided and not NaN
        if paw_position is not None:
            if not np.isnan(paw_position[0]) and not np.isnan(paw_position[1]):
                # Filled circle (thickness -1), radius 15; (0, 0, 255) is red in OpenCV's BGR order
                cv2.circle(frame, (int(paw_position[0]), int(paw_position[1])), 15, (0, 0, 255), -1)

        return frame

    # Load trial event times
    trials = one.load_object(eid, 'trials')
    trial_start_times = trials['intervals'][:, 0]
    trial_end_times = trials['intervals'][:, 1]
    stim_on_times = trials['stimOn_times']
    first_movement_times = trials['firstMovement_times']
    response_times = trials['response_times']
    feedback_types = trials['feedbackType']

    # Convert times to frame indices
    # NOTE(review): time * fps assumes video frame 0 is at time 0 and a constant
    # frame rate -- confirm against the camera timestamps. Any NaN event time
    # would be cast to an arbitrary integer by astype(int).
    trial_start_frame_indices = (trial_start_times * input_fps).astype(int)
    trial_end_frame_indices = (trial_end_times * input_fps).astype(int)
    stim_on_frame_indices = (stim_on_times * input_fps).astype(int)
    first_movement_frame_indices = (first_movement_times * input_fps).astype(int)
    response_frame_indices = (response_times * input_fps).astype(int)

    # Get user input for specific trials or ranges of trials
    # Example input: "1,3-5,10" for trials 1, 3 to 5, and 10 or "All" for all trials
    user_input = input("Enter trial numbers or ranges (e.g., '1,3-5,10' or 'All'): ")

    if user_input.strip().lower() == 'all':
        trial_numbers = list(range(1, len(trial_start_times) + 1))
    else:
        # Parse a comma-separated list of 1-based trial numbers and inclusive 'a-b' ranges
        trial_numbers = []
        for part in user_input.split(','):
            if '-' in part:
                start, end = map(int, part.split('-'))
                trial_numbers.extend(range(start, end + 1))
            else:
                trial_numbers.append(int(part))
        
    # Get the corresponding frame ranges for the specified trials
    # (trial numbers beyond the last trial are silently ignored)
    trial_frame_ranges = []
    end_frame = 0
    for trial_num in trial_numbers:
        trial_index = trial_num - 1
        if trial_index < len(trial_start_frame_indices):
            start_frame = trial_start_frame_indices[trial_index]
            end_frame = trial_end_frame_indices[trial_index]
            trial_frame_ranges.append((start_frame, end_frame))

    # Load paw markers
    sl = SessionLoader(one=one, eid=eid)
    sl.load_pose(likelihood_thr=l_thresh, views=[view])
    paw_positions = sl.pose[f'{view}Camera'].loc[:, [f'{paw}_x', f'{paw}_y']].to_numpy()

    # Iterate through each specified trial frame range
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    for start_frame, end_frame in trial_frame_ranges:
        frame_num = start_frame
        # Seek once to the first frame of the trial; subsequent reads are sequential
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        while frame_num <= end_frame and frame_num < frame_count:
            ret, frame = cap.read()
            if not ret or frame is None:
                print(f"Skipping frame {frame_num} due to read error.")
                frame_num += 1
                continue

            # Determine the trial number from trial_start_frame_indices
            # (the first trial whose start frame equals this range's start frame)
            trial_number_index = np.where(trial_start_frame_indices == start_frame)[0]
            if trial_number_index.size > 0:
                trial_number = trial_number_index[0] + 1

                if frame_num == start_frame:
                    # Insert a one-second screen with trial number
                    for _ in range(int(input_fps)):
                        frame_with_pred = overlay_predictions(np.zeros_like(frame), frame_num, trial_number=trial_number)
                        out.write(frame_with_pred)

                # Caption frames that coincide with a trial event; the first matching event wins
                text = None
                if frame_num in stim_on_frame_indices:
                    text = "Stimulus Onset"
                elif frame_num in first_movement_frame_indices:
                    text = "First Movement Detected"
                elif frame_num in response_frame_indices:
                    response_index = np.where(response_frame_indices == frame_num)[0][0]
                    feedback = feedback_types[response_index]
                    text = f"Response Recorded: {'Positive' if feedback == 1 else 'Negative'}"

                paw_pos = paw_positions[frame_num] if frame_num < len(paw_positions) else None

                # Frames beyond the end of the prediction array are written without a state label
                if frame_num < len(states_ibl_smooth):
                    frame_with_pred = overlay_predictions(frame, frame_num, states_ibl_smooth[frame_num], text=text, paw_position=paw_pos)
                else:
                    frame_with_pred = overlay_predictions(frame, frame_num, text=text, paw_position=paw_pos)

                out.write(frame_with_pred)

            frame_num += 1

            if frame_num % 1000 == 0:
                print(f"Processed frame {frame_num}/{frame_count}")

    # Release everything
    cap.release()
    out.release()
    cv2.destroyAllWindows()

    print(f"Video creation completed. Check '{output_path}'")

In [79]:
# Function to create a heatmap of the wheel speed aligned to movement onset
def create_wheel_speed_heatmap(eid, l_thresh, view, paw):
    """Plot wheel speed around movement onset as a trials x time heatmap and save it.

    The wheel velocity is read from the feature CSV written by
    ``extract_marker_data(..., smooth=True)``, in which every column is
    z-scored; the plotted quantity is the absolute value of that column.
    Up to the first 150 trials are shown, from 200 ms before to 1000 ms after
    each trial's first-movement time, resampled to a fixed number of columns.

    Args:
        eid: session eid; must be a key of ``dropbox_marker_paths``.
        l_thresh: likelihood threshold passed to ``SessionLoader.load_pose``.
        view: camera view whose frame times define the time base.
        paw: marker name prefix forwarded to ``extract_marker_data``.

    Side effects:
        Shows the figure and writes ``<sess_id>_wheel_speed_heatmap.png`` into
        the ``figures`` directory, creating it if needed.
    """
    sess_id = dropbox_marker_paths[eid]

    # Load times
    sl = SessionLoader(one=one, eid=eid)
    sl.load_pose(likelihood_thr=l_thresh, views=[view])
    times = sl.pose[f'{view}Camera'].times.to_numpy()

    # Load marker data (one feature row per camera frame)
    file = extract_marker_data(eid, l_thresh, view, paw, smooth=True)
    df = pd.read_csv(file)
    wh_speed = np.abs(df['wheel_vel'].to_numpy())

    # Load first movement times
    trials = one.load_object(eid, 'trials')
    first_movement_times = trials['firstMovement_times']

    # Define the window around movement onset
    window = (-0.2, 1.0)  # 200 ms before to 1000 ms after movement onset
    num_frames = int((window[1] - window[0]) * 60)  # 60 Hz sampling rate
    max_trials = 150

    # One row per trial actually available. NaN marks "no data" so that skipped or
    # missing trials are drawn blank instead of looking like zero wheel speed.
    onsets = first_movement_times[:max_trials]
    aligned_wheel_vel = np.full((len(onsets), num_frames), np.nan)

    # Common time axis relative to onset; identical for every trial
    interp_times = np.linspace(window[0], window[1], num_frames)

    for i, onset in enumerate(onsets):
        # A NaN onset produces an all-False mask and the trial is skipped
        mask = (times >= (onset + window[0])) & (times <= (onset + window[1]))
        if not mask.any():
            continue

        # Interpolate to have a consistent number of frames
        aligned_wheel_vel[i, :] = np.interp(interp_times, times[mask] - onset, wh_speed[mask])

    fig, ax = plt.subplots(figsize=(15, 10))
    # The plotted column is z-scored, so it is unitless rather than rad/s
    sns.heatmap(aligned_wheel_vel, cmap='coolwarm', cbar_kws={'label': 'Wheel Speed (|z-scored velocity|)'}, ax=ax)
    # Dashed line at the column corresponding to movement onset (t = 0)
    ax.axvline(x=num_frames * -window[0] / (window[1] - window[0]), color='k', linestyle='--')
    ax.set_xlabel('Time (frames)')
    ax.set_ylabel('Trial')
    ax.set_title(f'Wheel Speed Aligned to Movement Onset for {sess_id}')

    # Create the output directory first so savefig cannot fail on a missing folder
    output_dir = '/Users/zacharyzusin/Documents/NeuroscienceResearch/figures'
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f'{sess_id}_wheel_speed_heatmap.png'))
    plt.show()

In [80]:
# Function to create a heatmap of the predicted states aligned to a specified trial event
def create_states_aligned_heatmap(eid, l_thresh, view, paw, align_event='firstMovement_times', split_by_feedback=False):
    """Plot predicted states aligned to a trial event as trials x time heatmaps.

    The features for the session are generated with
    ``extract_marker_data(..., smooth=True)``, the notebook-level ``model`` is
    run on them, and the predicted states are clipped to 1..4. For every trial
    the states from 0.5 s before to 2.0 s after the alignment event are
    extracted; samples outside the trial's own interval are set to NaN.

    Args:
        eid: session eid; must be a key of ``dropbox_marker_paths``.
        l_thresh: likelihood threshold passed to ``SessionLoader.load_pose``.
        view: camera view whose frame times define the time base.
        paw: marker name prefix forwarded to ``extract_marker_data``.
        align_event: key of the trials object holding the alignment times.
        split_by_feedback: if True, trials with ``feedbackType == 1`` and all
            other trials are plotted in two separate figures. If False, every
            trial goes into the first figure, which is still titled and saved
            as "correct".

    Side effects:
        Prints the trials that were skipped, shows the figure(s) and saves
        them as PNG files in the session's ``repro_ephys_analysis`` directory.
    """
    sess_id = dropbox_marker_paths[eid]

    # State labels
    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']
    
    # Load times
    sl = SessionLoader(one=one, eid=eid)
    sl.load_pose(likelihood_thr=l_thresh, views=[view])
    times = sl.pose[f'{view}Camera'].times.to_numpy()
    
    # Load marker data
    file = extract_marker_data(eid, l_thresh, view, paw, smooth=True)
    
    # Load trial data
    trials = one.load_object(eid, 'trials')
    trial_start_times = trials['intervals'][:, 0]
    trial_end_times = trials['intervals'][:, 1]
    align_times = trials[align_event]
    feedback_types = trials['feedbackType']

    # Run the model on the smoothed features; predictions are forced into 1..4,
    # so a predicted 0 is displayed as state 1
    data_gen = create_data_generator(sess_id, file)
    states_ibl_smooth = get_states(model, data_gen)
    states_ibl_smooth = np.clip(states_ibl_smooth, 1, 4)

    # Create a custom color palette
    cmap = sns.color_palette("tab10", n_colors=4)
    
    # Align states to the specified event
    aligned_states_correct = []
    aligned_states_incorrect = []
    skipped_trials = []
    pre_time = 0.5  # 500 ms before the alignment event
    post_time = 2.0  # 2000 ms after the alignment event
    sampling_rate = 1 / np.median(np.diff(times))  # Approximate sampling rate (Hz)
    pre_frames = int(pre_time * sampling_rate)
    post_frames = int(post_time * sampling_rate)
    # NOTE(review): the row length assumes 60 Hz while pre_frames uses the rate
    # estimated from `times`; the two only agree when the camera runs at 60 Hz.
    max_length = int((pre_time + post_time) * 60)  # 60 Hz sampling rate

    for i, align_time in enumerate(align_times):
        # Locate the window [align - pre_time, align + post_time) in frame indices
        start_time = align_time - pre_time
        end_time = align_time + post_time
        start_idx = np.searchsorted(times, start_time)
        end_idx = np.searchsorted(times, end_time)
        align_idx = np.searchsorted(times, align_time)
        
        trial_states = states_ibl_smooth[start_idx:end_idx]
        align_offset = align_idx - start_idx

        # Convert to float array (needed so NaN can be stored)
        trial_states = trial_states.astype(float)

        # Pad or truncate to ensure same length
        if len(trial_states) < max_length:
            pad_len = max_length - len(trial_states)
            trial_states = np.pad(trial_states, (0, pad_len), mode='constant', constant_values=np.nan)
        elif len(trial_states) > max_length:
            trial_states = trial_states[:max_length]

        # Apply mask for time points outside of the trial
        trial_mask = np.full(max_length, False)
        trial_duration_mask = (times[start_idx:end_idx] >= trial_start_times[i]) & (times[start_idx:end_idx] <= trial_end_times[i])
        min_len = min(len(trial_duration_mask), len(trial_mask))
        trial_mask[:min_len] = trial_duration_mask[:min_len]
        trial_states[~trial_mask] = np.nan

        if np.all(np.isnan(trial_states)):
            skipped_trials.append((i, "All NaNs after alignment"))
        else:
            if split_by_feedback:
                if feedback_types[i] == 1:
                    aligned_states_correct.append((trial_states, align_offset))
                else:
                    aligned_states_incorrect.append((trial_states, align_offset))
            else:
                # Without splitting, every trial is collected in the "correct" list
                aligned_states_correct.append((trial_states, align_offset))

    # Convert list of tuples to 2D NumPy arrays (align_offset is carried along but not used here)
    aligned_states_correct_array = np.full((len(aligned_states_correct), max_length), np.nan)
    for i, (trial_states, align_offset) in enumerate(aligned_states_correct):
        aligned_states_correct_array[i, :len(trial_states)] = trial_states
    
    if split_by_feedback:
        aligned_states_incorrect_array = np.full((len(aligned_states_incorrect), max_length), np.nan)
        for i, (trial_states, align_offset) in enumerate(aligned_states_incorrect):
            aligned_states_incorrect_array[i, :len(trial_states)] = trial_states

    # Print skipped trials and reasons
    for trial_idx, reason in skipped_trials:
        print(f"Skipped trial {trial_idx}: {reason}")

    # Plot the heatmap for correct trials
    fig_correct = plt.figure(figsize=(15, 10))
    ax_correct = sns.heatmap(aligned_states_correct_array, cmap=cmap, cbar=True, yticklabels=np.arange(1, len(aligned_states_correct_array) + 1, 20), annot=False)

    # Customize the colorbar: one tick in the middle of each of the four colour bands spanning 1..4
    colorbar_correct = ax_correct.collections[0].colorbar
    colorbar_correct.set_ticks([1.375, 2.125, 2.875, 3.625])
    colorbar_correct.set_ticklabels(state_labels)

    # Set x-axis ticks and labels (seconds relative to the alignment event)
    ticks = np.linspace(0, max_length, num=6)
    ticklabels = np.round(np.linspace(-pre_time, post_time, num=6), 2)
    ax_correct.set_xticks(ticks)
    ax_correct.set_xticklabels(ticklabels)
    
    # Set y-axis ticks and labels (trial numbers); this replaces the yticklabels passed to heatmap above
    ax_correct.set_yticks(np.arange(0, len(aligned_states_correct_array), 50) + 0.5)
    ax_correct.set_yticklabels(np.arange(0, len(aligned_states_correct_array), 50))

    plt.title(f'Aligned Discrete States for {sess_id} (Correct Trials)')
    plt.xlabel(f'Time (seconds) relative to {align_event}')
    plt.ylabel('Trial')
    # Dashed line at the alignment event
    plt.axvline(x=pre_frames, color='k', linestyle='--', linewidth=1)
    plt.show()
    fig_correct.savefig(f'/Users/zacharyzusin/Documents/NeuroscienceResearch/repro_ephys_analysis/{sess_id}/{sess_id}_{align_event}_aligned_heatmap_correct.png')

    # Plot the heatmap for incorrect trials (if splitting by feedback)
    if split_by_feedback:
        fig_incorrect = plt.figure(figsize=(15, 10))
        ax_incorrect = sns.heatmap(aligned_states_incorrect_array, cmap=cmap, cbar=True, yticklabels=np.arange(1, len(aligned_states_incorrect_array) + 1, 20), annot=False)

        # Customize the colorbar
        colorbar_incorrect = ax_incorrect.collections[0].colorbar
        colorbar_incorrect.set_ticks([1.375, 2.125, 2.875, 3.625])
        colorbar_incorrect.set_ticklabels(state_labels)

        # Set x-axis ticks and labels
        ax_incorrect.set_xticks(ticks)
        ax_incorrect.set_xticklabels(ticklabels)
        
        # Set y-axis ticks and labels (trial numbers)
        ax_incorrect.set_yticks(np.arange(0, len(aligned_states_incorrect_array), 50) + 0.5)
        ax_incorrect.set_yticklabels(np.arange(0, len(aligned_states_incorrect_array), 50))

        plt.title(f'Aligned Discrete States for {sess_id} (Incorrect Trials)')
        plt.xlabel(f'Time (seconds) relative to {align_event}')
        plt.ylabel('Trial')
        plt.axvline(x=pre_frames, color='k', linestyle='--', linewidth=1)
        plt.show()
        fig_incorrect.savefig(f'/Users/zacharyzusin/Documents/NeuroscienceResearch/repro_ephys_analysis/{sess_id}/{sess_id}_{align_event}_aligned_heatmap_incorrect.png')


In [81]:
# Function to create a plot of the length of each trial interval (stim onset to first movement, first movement to response; response to trial end)
def create_trial_interval_plot(eid):
    """Plot and save the mean duration of three within-trial intervals for a session.

    The intervals are stimulus onset to first movement, first movement to
    response, and response to trial end, all computed from the session's
    `trials` object. Non-finite durations are dropped before averaging.

    Args:
        eid: Session identifier; must be a key of `dropbox_marker_paths`.

    Side effects:
        Writes `{sess_id}_trial_intervals.png` into the session's
        `repro_ephys_analysis` folder (created if missing) and displays the
        bar chart.
    """
    sess_id = dropbox_marker_paths[eid]

    # Load trial data
    trials = one.load_object(eid, 'trials')
    stim_on_times = trials['stimOn_times']
    first_movement_times = trials['firstMovement_times']
    response_times = trials['response_times']
    trial_end_times = trials['intervals'][:, 1]

    def mean_of_finite(durations):
        # Average only the finite samples; an interval with none yields NaN
        # without tripping numpy's empty-mean warning.
        finite = durations[np.isfinite(durations)]
        return np.mean(finite) if finite.size else np.nan

    # Calculate the average length of each trial interval
    avg_stim_to_first_movement = mean_of_finite(first_movement_times - stim_on_times)
    avg_first_movement_to_response = mean_of_finite(response_times - first_movement_times)
    avg_response_to_end = mean_of_finite(trial_end_times - response_times)

    # Create a DataFrame for easier plotting
    df = pd.DataFrame({
        'Interval': ['Stimulus to First Movement', 'First Movement to Response', 'Response to Trial End'],
        'Average Duration (s)': [avg_stim_to_first_movement, avg_first_movement_to_response, avg_response_to_end]
    })

    # Plot using DataFrame
    fig, ax = plt.subplots(figsize=(10, 6))
    df.plot(kind='bar', x='Interval', y='Average Duration (s)', ax=ax, edgecolor='grey', legend=False)

    # Customize plot
    ax.set_title(f'Average Trial Interval Durations for session {sess_id}', fontweight='bold')
    ax.set_xlabel('Intervals', fontweight='bold')
    ax.set_ylabel('Seconds', fontweight='bold')

    # Add some space above the highest bar. set_ylim rejects NaN limits, so the
    # rescale is skipped when no interval produced a positive finite mean.
    tallest = df['Average Duration (s)'].max()
    if np.isfinite(tallest) and tallest > 0:
        ax.set_ylim(0, tallest * 1.1)

    # Adjust x-axis labels
    ax.tick_params(axis='x', labelrotation=0)  # Set rotation to 0 for horizontal labels

    plt.tight_layout()

    # Save before showing, and make sure the per-session folder exists so the
    # first run for a new session does not fail with FileNotFoundError.
    out_dir = f'/Users/zacharyzusin/Documents/NeuroscienceResearch/repro_ephys_analysis/{sess_id}'
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_dir}/{sess_id}_trial_intervals.png')

    plt.show()
    # Release the figure; this function is meant to be called once per session.
    plt.close(fig)

In [None]:
# Main Loop
# Root folder for analysis outputs.
save_dir = '/Users/zacharyzusin/Documents/NeuroscienceResearch'

# Likelihood thresholds passed to the marker-extraction / plotting helpers:
# `l_thresh` for the call below, `l_thresh_smooth` for the (currently
# commented-out) smoothed-marker steps further down in this cell.
l_thresh = 0.0
l_thresh_smooth = 0.9
# Camera view and paw whose markers are extracted. These are notebook globals
# that other cells (e.g. the UMAP `get_proportions` functions) also read.
view = 'left'
paw = 'paw_r'
if view == 'right':
    raise NotImplementedError

# NOTE(review): `eids` and `dropbox_marker_paths` are defined in the
# "Session Information" cell, which appears later in the notebook — that cell
# must be run before this one.
for eid in eids:
    sess_id = dropbox_marker_paths[eid]
    print(f'Eid: {eid} -- Session: {sess_id}')

    # Extract the marker data from the ibl database
    #ibl_markers_file = extract_marker_data(eid, l_thresh, view, paw, smooth=False)
    # NOTE(review): the variable is named "smooth" but the call uses
    # `l_thresh` and `smooth=False` — confirm which configuration is intended.
    ibl_smooth_markers_file = extract_marker_data(eid, l_thresh, view, paw, smooth=False)

    # Dropbox Data
    #dropbox_dir = '/Users/zacharyzusin/Documents/NeuroscienceResearch/dropbox'
    #dropbox_markers_file = os.path.join(dropbox_dir, 'features-posvel', sess_id + '_labeled.csv')

    # Build the Data Generators
    #data_gen_dropbox = create_data_generator(sess_id, dropbox_markers_file)
    #data_gen_ibl = create_data_generator(sess_id, ibl_markers_file)
    #data_gen_ibl_smooth = create_data_generator(sess_id, ibl_smooth_markers_file)

    # Run Inference
    #states_dropbox = get_states(model, data_gen_dropbox)
    #states_ibl = get_states(model, data_gen_ibl)
    #states_ibl_smooth = get_states(model, data_gen_ibl_smooth)

    # Ensure state values are within the correct range
    #states_dropbox = np.clip(states_dropbox, 1, 4)
    #states_ibl = np.clip(states_ibl, 1, 4)
    #states_ibl_smooth = np.clip(states_ibl_smooth, 1, 4)

    # Truncate the arrays to match the length of the shortest one
    #min_length = min(len(states_dropbox), len(states_ibl), len(states_ibl_smooth))
    #states_dropbox_truncated = states_dropbox[:min_length]
    #states_ibl_truncated = states_ibl[:min_length]
    #states_ibl_smooth_truncated = states_ibl_smooth[:min_length]
    
    #data_dict = {'Dropbox States': states_dropbox_truncated, 'IBL States': states_ibl_truncated, 'IBL Smooth States': states_ibl_smooth_truncated}
    #data_dict = {'IBL Smooth States': states_ibl_smooth}

    # Create a predicted state heatmap
    #create_predicted_states_heatmap(eid, data_dict)

    # Create a state frequency plot
    #create_state_frequency_plot(eid, data_dict)

    # Create a state duration plot
    #create_state_duration_plot(eid, data_dict)

    # Create a video with predictions overlayed
    #create_video_comparing_predictions(eid, states_dropbox_truncated, states_ibl_truncated, states_ibl_smooth_truncated)

    # Create a video with trial markers
    #create_video_with_trial_markers(eid, states_ibl_smooth, l_thresh_smooth, view, paw)

    # Create a wheel speed heatmap
    #create_wheel_speed_heatmap(eid, l_thresh_smooth, view, paw)

    # Create a heatmap of the predicted states aligned to the first movement times
    #create_first_movement_aligned_heatmap(eid, l_thresh_smooth, view, paw)

    # Create a heatmap of the predicted states aligned to trial event times (which include 'goCueTrigger_times', 'stimOff_times', 'response_times', 'goCue_times', 'firstMovement_times', 'stimOn_times' and 'feedback_times')
    #create_states_aligned_heatmap(eid, l_thresh_smooth, view, paw, align_event='response_times', split_by_feedback=True)

In [33]:
# Session Information (repro-ephys)
# Session eids iterated by the analysis loops in this notebook. The trailing
# comment on each entry is the author's lab/site label for that session.
eids = [
    'db4df448-e449-4a6f-a0e7-288711e7a75a',  # Berkeley
    'd23a44ef-1402-4ed7-97f5-47e9a7a504d9',  # Berkeley
    '4a45c8ba-db6f-4f11-9403-56e06a33dfa4',  # Berkeley
    'e535fb62-e245-4a48-b119-88ce62a6fe67',  # Berkeley
    '54238fd6-d2d0-4408-b1a9-d19d24fd29ce',  # Berkeley
    'b03fbc44-3d8e-4a6c-8a50-5ea3498568e0',  # Berkeley
    '30c4e2ab-dffc-499d-aae4-e51d6b3218c2',  # CCU
    'd0ea3148-948d-4817-94f8-dcaf2342bbbe',  # CCU
    'a4a74102-2af5-45dc-9e41-ef7f5aed88be',  # CCU
    '746d1902-fa59-4cab-b0aa-013be36060d5',  # CCU
    '88224abb-5746-431f-9c17-17d7ef806e6a',  # CCU
    '0802ced5-33a3-405e-8336-b65ebc5cb07c',  # CCU
    'ee40aece-cffd-4edb-a4b6-155f158c666a',  # CCU
    'c7248e09-8c0d-40f2-9eb4-700a8973d8c8',  # CCU
    '72cb5550-43b4-4ef0-add5-e4adfdfb5e02',  # CCU
    'dda5fc59-f09a-4256-9fb5-66c67667a466',  # CSHL(C)
    '4b7fbad4-f6de-43b4-9b15-c7c7ef44db4b',  # CSHL(C)
    'f312aaec-3b6f-44b3-86b4-3a0c119c0438',  # CSHL(C)
    '4b00df29-3769-43be-bb40-128b1cba6d35',  # CSHL(C)
    'ecb5520d-1358-434c-95ec-93687ecd1396',  # CSHL(C)
    '51e53aff-1d5d-4182-a684-aba783d50ae5',  # NYU
    'f140a2ec-fd49-4814-994a-fe3476f14e66',  # NYU
    'a8a8af78-16de-4841-ab07-fde4b5281a03',  # NYU
    '61e11a11-ab65-48fb-ae08-3cb80662e5d6',  # NYU
    '73918ae1-e4fd-4c18-b132-00cb555b1ad2',  # Princeton
    'd9f0c293-df4c-410a-846d-842e47c6b502',  # Princeton
    'dac3a4c1-b666-4de0-87e8-8c514483cacf',  # SWC(H)
    '6f09ba7e-e3ce-44b0-932b-c003fb44fb89',  # SWC(H)
    '56b57c38-2699-4091-90a8-aba35103155e',  # SWC(M)
    '3638d102-e8b6-4230-8742-e548cd87a949',  # SWC(M)
    '7cb81727-2097-4b52-b480-c89867b5b34c',  # SWC(M)
    '781b35fd-e1f0-4d14-b2bb-95b7263082bb',  # UCL
    '3f859b5c-e73a-4044-b49e-34bb81e96715',  # UCL
    'b22f694e-4a34-4142-ab9d-2556c3487086',  # UCL
    '0a018f12-ee06-4b11-97aa-bbbff5448e9f',  # UCL
    'aad23144-0e52-4eac-80c5-c4ee2decb198',  # UCL
    'b196a2ad-511b-4e90-ac99-b5a29ad25c22',  # UCL
    'e45481fa-be22-4365-972c-e7404ed8ab5a',  # UCL
    'd04feec7-d0b7-4f35-af89-0232dd975bf0',  # UCL
    '1b715600-0cbc-442c-bd00-5b0ac2865de1',  # UCL
    'c7bf2d49-4937-4597-b307-9f39cb1c7b16',  # UCL
    '8928f98a-b411-497e-aa4b-aa752434686d',  # UCL
    'ebce500b-c530-47de-8cb1-963c552703ea',  # UCLA
    'dc962048-89bb-4e6a-96a9-b062a2be1426',  # UCLA
    '6899a67d-2e53-4215-a52a-c7021b5da5d4',  # UCLA
    '15b69921-d471-4ded-8814-2adad954bcd8',  # UCLA
    '5ae68c54-2897-4d3a-8120-426150704385',  # UCLA
    'ca4ecb4c-4b60-4723-9b9e-2c54a6290a53',  # UCLA
    '824cf03d-4012-4ab1-b499-c83a92c5589e',  # UCLA
    '3bcb81b4-d9ca-4fc9-a1cd-353a966239ca',  # UW
    'f115196e-8dfe-4d2a-8af3-8206d93c1729',  # UW
    '9b528ad0-4599-4a55-9148-96cc1d93fb24',  # UW
    '3e6a97d3-3991-49e2-b346-6948cb4580fb',  # UW
]

# Maps a session eid to the session label (`lab_subject_date-number`) that the
# rest of the notebook uses for titles and output paths.
# Each eid appears exactly once: a dict literal silently keeps only the last
# value for a repeated key, so repeated entries can hide a mismatch.
dropbox_marker_paths = {
    'db4df448-e449-4a6f-a0e7-288711e7a75a': 'danlab_DY_009_2020-02-27-001',
    'd23a44ef-1402-4ed7-97f5-47e9a7a504d9': 'danlab_DY_016_2020-09-12-001',
    '4a45c8ba-db6f-4f11-9403-56e06a33dfa4': 'danlab_DY_020_2020-09-29-001',
    'e535fb62-e245-4a48-b119-88ce62a6fe67': 'danlab_DY_013_2020-03-12-001',
    '54238fd6-d2d0-4408-b1a9-d19d24fd29ce': 'danlab_DY_018_2020-10-15-001',
    'b03fbc44-3d8e-4a6c-8a50-5ea3498568e0': 'danlab_DY_010_2020-01-27-001',
    '30c4e2ab-dffc-499d-aae4-e51d6b3218c2': 'mainenlab_ZFM-02370_2021-04-29-001',
    'd0ea3148-948d-4817-94f8-dcaf2342bbbe': 'mainenlab_ZFM-01936_2021-01-19-001',
    'a4a74102-2af5-45dc-9e41-ef7f5aed88be': 'mainenlab_ZFM-02368_2021-06-01-001',
    '746d1902-fa59-4cab-b0aa-013be36060d5': 'mainenlab_ZFM-01592_2020-10-20-001',
    '88224abb-5746-431f-9c17-17d7ef806e6a': 'mainenlab_ZFM-02372_2021-06-01-002',
    '0802ced5-33a3-405e-8336-b65ebc5cb07c': 'mainenlab_ZFM-02373_2021-06-23-001',
    'ee40aece-cffd-4edb-a4b6-155f158c666a': 'mainenlab_ZM_2241_2020-01-30-001',
    'c7248e09-8c0d-40f2-9eb4-700a8973d8c8': 'mainenlab_ZM_3001_2020-08-05-001',
    '72cb5550-43b4-4ef0-add5-e4adfdfb5e02': 'mainenlab_ZFM-02369_2021-05-19-001',
    'dda5fc59-f09a-4256-9fb5-66c67667a466': 'churchlandlab_CSHL059_2020-03-06-001',
    '4b7fbad4-f6de-43b4-9b15-c7c7ef44db4b': 'churchlandlab_CSHL049_2020-01-08-001',
    'f312aaec-3b6f-44b3-86b4-3a0c119c0438': 'churchlandlab_CSHL058_2020-07-07-001',
    '4b00df29-3769-43be-bb40-128b1cba6d35': 'churchlandlab_CSHL052_2020-02-21-001',
    'ecb5520d-1358-434c-95ec-93687ecd1396': 'churchlandlab_CSHL051_2020-02-05-001',
    '51e53aff-1d5d-4182-a684-aba783d50ae5': 'angelakilab_NYU-45_2021-07-19-001',
    'f140a2ec-fd49-4814-994a-fe3476f14e66': 'angelakilab_NYU-47_2021-06-21-003',
    'a8a8af78-16de-4841-ab07-fde4b5281a03': 'angelakilab_NYU-12_2020-01-22-001',
    '61e11a11-ab65-48fb-ae08-3cb80662e5d6': 'angelakilab_NYU-21_2020-08-10-002',
    '73918ae1-e4fd-4c18-b132-00cb555b1ad2': 'wittenlab_ibl_witten_27_2021-01-21-001',
    'd9f0c293-df4c-410a-846d-842e47c6b502': 'wittenlab_ibl_witten_25_2020-12-07-002',
    'dac3a4c1-b666-4de0-87e8-8c514483cacf': 'hoferlab_SWC_060_2020-11-24-001',
    '6f09ba7e-e3ce-44b0-932b-c003fb44fb89': 'hoferlab_SWC_043_2020-09-16-002',
    '56b57c38-2699-4091-90a8-aba35103155e': 'mrsicflogellab_SWC_054_2020-10-05-001',
    '3638d102-e8b6-4230-8742-e548cd87a949': 'mrsicflogellab_SWC_058_2020-12-07-001',
    '7cb81727-2097-4b52-b480-c89867b5b34c': 'mrsicflogellab_SWC_052_2020-10-22-001',
    '781b35fd-e1f0-4d14-b2bb-95b7263082bb': 'cortexlab_KS044_2020-12-09-001',
    '3f859b5c-e73a-4044-b49e-34bb81e96715': 'cortexlab_KS094_2022-06-17-001',
    'b22f694e-4a34-4142-ab9d-2556c3487086': 'cortexlab_KS055_2021-05-02-001',
    '0a018f12-ee06-4b11-97aa-bbbff5448e9f': 'cortexlab_KS051_2021-05-11-001',
    'aad23144-0e52-4eac-80c5-c4ee2decb198': 'cortexlab_KS023_2019-12-10-001',
    'b196a2ad-511b-4e90-ac99-b5a29ad25c22': 'cortexlab_KS084_2022-02-01-001',
    'e45481fa-be22-4365-972c-e7404ed8ab5a': 'cortexlab_KS086_2022-03-15-001',
    'd04feec7-d0b7-4f35-af89-0232dd975bf0': 'cortexlab_KS089_2022-03-19-001',
    '1b715600-0cbc-442c-bd00-5b0ac2865de1': 'cortexlab_KS084_2022-01-31-001',
    'c7bf2d49-4937-4597-b307-9f39cb1c7b16': 'cortexlab_KS074_2021-11-22-001',
    '8928f98a-b411-497e-aa4b-aa752434686d': 'cortexlab_KS096_2022-06-17-001',
    'ebce500b-c530-47de-8cb1-963c552703ea': 'churchlandlab_ucla_MFD_09_2023-10-19-001',
    'dc962048-89bb-4e6a-96a9-b062a2be1426': 'churchlandlab_ucla_UCLA048_2022-08-16-001',
    '6899a67d-2e53-4215-a52a-c7021b5da5d4': 'churchlandlab_ucla_MFD_06_2023-08-29-001',
    '15b69921-d471-4ded-8814-2adad954bcd8': 'churchlandlab_ucla_MFD_07_2023-08-31-001',
    '5ae68c54-2897-4d3a-8120-426150704385': 'churchlandlab_ucla_MFD_08_2023-09-07-001',
    'ca4ecb4c-4b60-4723-9b9e-2c54a6290a53': 'churchlandlab_ucla_MFD_05_2023-08-16-001',
    '824cf03d-4012-4ab1-b499-c83a92c5589e': 'churchlandlab_ucla_UCLA011_2021-07-20-001',
    '3bcb81b4-d9ca-4fc9-a1cd-353a966239ca': 'steinmetzlab_NR_0024_2023-01-19-001',
    'f115196e-8dfe-4d2a-8af3-8206d93c1729': 'steinmetzlab_NR_0021_2022-06-23-003',
    '9b528ad0-4599-4a55-9148-96cc1d93fb24': 'steinmetzlab_NR_0019_2022-04-29-001',
    '3e6a97d3-3991-49e2-b346-6948cb4580fb': 'steinmetzlab_NR_0020_2022-05-08-001',
    # Dropbox dataset sessions follow. Four dropbox sessions are already listed
    # above with identical labels and are therefore not repeated here:
    # db4df448 (danlab_DY_009), 54238fd6 (danlab_DY_018),
    # 7cb81727 (mrsicflogellab_SWC_052) and 73918ae1 (wittenlab_ibl_witten_27).
    '46794e05-3f6a-4d35-afb3-9165091a5a74': 'churchlandlab_CSHL045_2020-02-27-001',
    'f3ce3197-d534-4618-bf81-b687555d1883': 'hoferlab_SWC_043_2020-09-15-001',
    '493170a6-fd94-4ee4-884f-cc018c17eeb9': 'hoferlab_SWC_061_2020-11-23-001',
    'ff96bfe1-d925-4553-94b5-bf8297adf259': 'wittenlab_ibl_witten_26_2021-01-27-002',
}

In [17]:
# Ethogram graphing functions
# Four-colour discrete colormap shared by the ethogram helpers below;
# `graph_states` reads the module-level `cmap` when drawing the state strip.
palette = sns.color_palette("tab10", n_colors=4)
cmap = ListedColormap(palette.as_hex())

def graph_etho_img(fig, states, markers, marker_names, m_inds, trial_start_times, go_cue_times, feedback_times, start=0, length=600):
    """Draw a two-row ethogram on `fig`: predicted states above, marker traces below.

    Args:
        fig: Matplotlib figure that receives the two stacked axes.
        states: Per-frame state ids; the colour strip is scaled from 1 to the
            number of state labels (4).
        markers: Per-frame marker array, indexed as `markers[frame, column]`.
        marker_names: Tick labels for the plotted marker traces.
        m_inds: Column indices of `markers` to plot.
        trial_start_times, go_cue_times, feedback_times: Event positions that
            are forwarded to `graph_markers` and compared against frame
            indices there — presumably already converted to frames; confirm
            against the caller.
        start: First frame of the displayed window.
        length: Number of frames in the displayed window.
    """
    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']
    state_colors = sns.color_palette("tab10", n_colors=len(state_labels))

    n_rows = 2
    outer_grid = gridspec.GridSpec(n_rows, 1, figure=fig, height_ratios=[.2, .8])
    axes = [fig.add_subplot(outer_grid[i]) for i in range(n_rows)]

    # Scale the colour strip by the full label set rather than by the largest
    # state that happens to occur: with a data-dependent maximum, a recording
    # that never reaches the last state would stretch the remaining states
    # over all four colours and no longer match the legend below.
    n_classes = len(state_labels)

    # Graph DAART model predictions
    data = states[start:(start+length)]
    graph_states(axes[0], data, n_classes)
    axes[0].set(ylabel="DAART")
    axes[0].yaxis.label.set(rotation='horizontal', ha='right')

    # Graph markers
    markers = markers[start:(start+length)]  # Use markers as is
    graph_markers(axes[1], markers, marker_names, m_inds, trial_start_times, go_cue_times, feedback_times, start=start)

    # Add a legend for the states
    handles = [Patch(facecolor=color, edgecolor='k', label=label) for color, label in zip(state_colors, state_labels)]
    axes[0].legend(handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

    # Set x-axis label
    axes[1].set_xlabel('Time (frames)', fontsize=12)

def graph_states(ax, states, n_classes):
    """Render a state sequence as a one-row colour strip on `ax`.

    Args:
        ax: Matplotlib axes to draw on.
        states: Array of per-frame state ids for the displayed window.
        n_classes: Upper end of the colour scale; the lower end is fixed at 1.
    """
    # A leading axis turns the sequence into a single-row image for imshow.
    strip = states[None, :]
    ax.imshow(strip, cmap=cmap, vmin=1, vmax=n_classes, aspect='auto', interpolation='none')
    # The strip carries no axis information of its own, so hide both tick sets.
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])

def graph_markers(ax, markers, marker_names, m_inds, trial_start_times, go_cue_times, feedback_times, start):
    """Plot stacked, min-max normalised marker traces with trial-event lines.

    Each selected marker column is rescaled to a lane of unit height using its
    range within the displayed window, and the lanes are stacked vertically.

    Args:
        ax: Matplotlib axes to draw on.
        markers: Array indexed as `markers[frame, column]`, already cropped to
            the displayed window.
        marker_names: Y tick labels, one per entry of `m_inds`.
        m_inds: Column indices of `markers` to plot.
        trial_start_times, go_cue_times, feedback_times: Arrays of event
            positions. They are compared directly with frame indices, so they
            are presumably expressed in frames — confirm against the caller.
        start: Frame index of the first row of `markers`.
    """
    height = 1.0
    num_markers = len(m_inds)
    y_offsets = np.arange(num_markers) * height
    end = start + len(markers)

    # Per-column range over the window. The NaN-aware reductions keep a few
    # missing samples from turning the whole normalised trace into NaN.
    y_min = np.nanmin(markers, axis=0)
    y_max = np.nanmax(markers, axis=0)
    # A flat trace has zero range; dividing by 1 instead keeps it on its
    # lane's baseline rather than producing NaN/inf.
    span = y_max - y_min
    span = np.where(span == 0, 1.0, span)

    frames = np.arange(len(markers)) + start
    for idx, m_idx in enumerate(m_inds):
        norm_marker_data = (markers[:, m_idx] - y_min[m_idx]) / span[m_idx] * height
        ax.plot(frames, norm_marker_data + y_offsets[idx], color='k', linewidth=2)

    # Add vertical lines for all trial start, go cue, and feedback events that
    # fall within the displayed window, and build matching legend entries.
    events = [
        (trial_start_times, 'r', 'Trial Start'),
        (go_cue_times, 'b', 'Go Cue'),
        (feedback_times, 'g', 'Feedback'),
    ]
    handles = []
    for event_positions, color, label in events:
        in_window = event_positions[(event_positions >= start) & (event_positions < end)]
        ax.vlines(in_window, ymin=-height, ymax=num_markers*height, color=color, linestyle='--')
        # Proxy line so the legend shows a dashed line like the one drawn,
        # instead of a filled rectangle.
        handles.append(Line2D([0], [0], color=color, linestyle='--', label=label))

    # Set x-axis limits and ticks
    ax.set_xlim([start, end])

    num_ticks = 6
    tick_positions = np.linspace(start, end, num_ticks).astype(int)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_positions)

    ax.spines[['right', 'left']].set_visible(False)
    ax.set_yticks(y_offsets + height / 2)
    ax.set_yticklabels(marker_names)
    ax.yaxis.tick_left()
    ax.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

In [None]:
# Second Camera Loop
save_dir = '/Users/zacharyzusin/Documents/NeuroscienceResearch'
l_thresh_smooth = 0.9
views = ['left', 'right']
paws = ['paw_r', 'paw_l']
# Camera view and paw used for state inference and the annotated video. They
# are assigned before the loop so the first iteration does not depend on
# values left in the kernel by another cell.
view = 'left'
paw = 'paw_r'

for eid in eids[:1]:
    sess_id = dropbox_marker_paths[eid]
    print(f'Eid: {eid} -- Session: {sess_id}')

    # Load the pose data for both cameras
    sl = SessionLoader(one=one, eid=eid)
    sl.load_pose(likelihood_thr=l_thresh_smooth, views=views)

    # Load the left camera view
    left_signal = sl.pose[f'{views[0]}Camera'].loc[:, (f'{paws[0]}_x', f'{paws[0]}_y')].to_numpy()
    left_timestamps = sl.pose[f'{views[0]}Camera'].times.to_numpy()

    # Load the right camera view
    right_signal = sl.pose[f'{views[1]}Camera'].loc[:, (f'{paws[1]}_x', f'{paws[1]}_y')].to_numpy()
    right_timestamps = sl.pose[f'{views[1]}Camera'].times.to_numpy()

    # Interpolate the right signal to match the left timestamps
    interpolator = interp1d(right_timestamps, right_signal, axis=0, kind='linear', fill_value='extrapolate')
    right_signal_interpolated = interpolator(left_timestamps)

    # Mirror only the x column (column 0) of the right-camera paw; the y
    # column (column 1) is left untouched, since subtracting it from the frame
    # width would not be a horizontal flip.
    right_signal_interpolated_flipped = right_signal_interpolated.copy()
    right_signal_interpolated_flipped[:, 0] = 640 - right_signal_interpolated_flipped[:, 0]  # Assuming width of 640
    # Z-score each column with NaN-aware statistics: load_pose is given a
    # likelihood threshold and presumably blanks low-likelihood samples to NaN
    # (confirm), which would otherwise make the plain mean/std NaN.
    right_signal_interpolated_zscored = (right_signal_interpolated_flipped - np.nanmean(right_signal_interpolated_flipped, axis=0)) / np.nanstd(right_signal_interpolated_flipped, axis=0)

    # Load wheel data and resample at marker times
    sl.load_wheel()
    wheel_velocity_signal = sl.wheel.velocity.to_numpy()
    wheel_velocity_timestamps = sl.wheel.times.to_numpy()
    wheel_velocity_interpolator = interp1d(wheel_velocity_timestamps, wheel_velocity_signal, kind='linear', fill_value='extrapolate')
    wheel_velocity_interpolated = wheel_velocity_interpolator(left_timestamps)

    # Run state inference on the smoothed markers of the selected view/paw
    ibl_smooth_markers_file = extract_marker_data(eid, l_thresh_smooth, view, paw, smooth=True)
    data_gen_ibl_smooth = create_data_generator(sess_id, ibl_smooth_markers_file)
    states_ibl_smooth = np.clip(get_states(model, data_gen_ibl_smooth), 1, 4)

    # Create a video with trial markers and the predicted states
    create_video_with_trial_markers(eid, states_ibl_smooth, l_thresh_smooth, view, paw)

    

In [18]:
# Function to create histograms of time spent in states during different periods of the trial
def create_histograms(eid, states_ibl_smooth):
    """Plot per-state histograms of time spent in each state during two trial periods.

    For every trial, the time spent in each of the four states is measured
    between first movement and feedback, and between feedback and trial end.
    For each period a 2x2 figure (one panel per state) shows the proportion of
    trials falling into each duration bin; the figure is saved and displayed.

    Trials whose period boundaries are not finite are excluded from that
    period, and proportions are taken over the remaining trials.

    Args:
        eid: Session identifier, used to load the `trials` object and to look
            up the session label in `dropbox_marker_paths`.
        states_ibl_smooth: Array of per-frame state ids, where ids 1..4
            correspond to Still, Move, Wheel Turn and Groom.

    NOTE(review): event times are converted to frame indices as
    `time * 60`, which assumes the state array starts at session time 0 and
    is sampled at exactly 60 fps — confirm against the video timestamps.
    """
    # Resolve the label from `eid` instead of reading whatever `sess_id`
    # another cell left in the kernel; fall back to the eid if it is unmapped.
    sess_id = dropbox_marker_paths.get(eid, eid)
    input_fps = 60
    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']

    # Load trial event times
    trials = one.load_object(eid, 'trials')
    trial_end_times = np.asarray(trials['intervals'][:, 1], dtype=float)
    first_movement_times = np.asarray(trials['firstMovement_times'], dtype=float)
    feedback_times = np.asarray(trials['feedback_times'], dtype=float)

    # Define bin edges (in seconds) for histograms
    bin_edges = [0, 0.025, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 5, np.inf]
    bin_labels = ['0-25 ms', '25-50 ms', '50-100 ms', '100-200 ms', '200-300 ms', '300-400 ms', '400-500 ms', '0.5-1 s', '1-2 s', '2-5 s', '5+ s']

    def calculate_time_in_state(start_times, end_times, states):
        # Seconds spent in each state per trial, for trials with finite bounds.
        # A NaN time cast to int becomes an arbitrary huge index and would
        # silently select the wrong frames, so such trials are dropped first.
        valid = np.isfinite(start_times) & np.isfinite(end_times)
        start_frames = (start_times[valid] * input_fps).astype(int)
        end_frames = (end_times[valid] * input_fps).astype(int)
        time_in_state = np.zeros((len(start_frames), len(state_labels)))
        for i, (start, end) in enumerate(zip(start_frames, end_frames)):
            trial_states = states[start:end]
            for state in range(len(state_labels)):
                time_in_state[i, state] = np.sum(trial_states == state + 1) / input_fps
        return time_in_state

    def plot_histogram(data, period):
        # One panel per state; bar height is the share of included trials.
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        axes = axes.flatten()
        n_included = max(len(data), 1)  # avoid 0/0 when no trial qualifies
        positions = np.arange(len(bin_labels))
        for state in range(len(state_labels)):
            hist, _ = np.histogram(data[:, state], bins=bin_edges)
            axes[state].bar(positions, hist / n_included, align='center')
            axes[state].set_title(f'{state_labels[state]}')
            axes[state].set_xlabel('Time spent in state')
            axes[state].set_ylabel('Proportion of trials')
            # Fix the tick positions before attaching the text labels.
            axes[state].set_xticks(positions)
            axes[state].set_xticklabels(bin_labels, rotation=45, ha='right')

        plt.tight_layout()
        plt.suptitle(f'State Distribution during {period} for Session {sess_id}', y=1.05, fontsize=16)

        out_dir = f'/Users/zacharyzusin/Documents/NeuroscienceResearch/repro_ephys_analysis/{sess_id}'
        os.makedirs(out_dir, exist_ok=True)
        # The suptitle sits above the figure (y=1.05); a tight bounding box
        # keeps it inside the saved image.
        fig.savefig(f'{out_dir}/{sess_id}_state_distribution_{period}.png', bbox_inches='tight')
        plt.show()
        plt.close(fig)

    # Calculate time in each state for both periods
    time_in_state_first_to_feedback = calculate_time_in_state(first_movement_times, feedback_times, states_ibl_smooth)
    time_in_state_feedback_to_end = calculate_time_in_state(feedback_times, trial_end_times, states_ibl_smooth)

    plot_histogram(time_in_state_first_to_feedback, 'first movement to feedback')
    plot_histogram(time_in_state_feedback_to_end, 'feedback to trial end')

In [None]:
# UMAP Clustering

def calculate_time_in_state(frame_indices, states, input_fps=60):
    """Return the seconds spent in each of states 1..4 within every frame window.

    Args:
        frame_indices: Sized sequence of `(start, end)` frame pairs; each pair
            selects `states[start:end]`.
        states: Array of per-frame state ids.
        input_fps: Frames per second used to convert frame counts to seconds.

    Returns:
        Array of shape `(len(frame_indices), 4)`; column `k` holds the time
        spent in state `k + 1`.
    """
    state_ids = range(1, 5)
    durations = np.zeros((len(frame_indices), 4))
    for row, (first, last) in enumerate(frame_indices):
        window = states[first:last]
        # Frame count per state, converted to seconds
        durations[row, :] = [np.sum(window == state_id) / input_fps for state_id in state_ids]
    return durations

def get_proportions(eid):
    """Build the 88-dimensional state-duration feature vector for one session.

    Runs state inference on the session's smoothed markers, measures the time
    spent in each of the four states during two trial periods (first movement
    to feedback, feedback to trial end), histograms those durations into 11
    bins, and flattens the `(2 periods, 4 states, 11 bins)` array.

    Trials whose period boundaries are not finite are excluded from that
    period, and each histogram is normalised by the number of trials that
    were included.

    Args:
        eid: Session identifier.

    Returns:
        1-D array of length 88 (period-major, then state, then bin).

    Reads the notebook globals `l_thresh_smooth`, `view`, `paw` and `model`.

    NOTE(review): event times are converted to frame indices as `time * 60`,
    which assumes the state array starts at session time 0 at exactly 60 fps.
    NOTE(review): `create_data_generator` receives `eid` here, while the
    second-camera loop passes `sess_id` — confirm which is expected.
    """
    input_fps = 60

    ibl_smooth_markers_file = extract_marker_data(eid, l_thresh_smooth, view, paw, smooth=True)
    data_gen_ibl_smooth = create_data_generator(eid, ibl_smooth_markers_file)
    states_ibl_smooth = np.clip(get_states(model, data_gen_ibl_smooth), 1, 4)

    # Load trial event times
    trials = one.load_object(eid, 'trials')
    trial_end_times = np.asarray(trials['intervals'][:, 1], dtype=float)
    first_movement_times = np.asarray(trials['firstMovement_times'], dtype=float)
    feedback_times = np.asarray(trials['feedback_times'], dtype=float)

    # Define bin edges (in seconds) for histograms
    bin_edges = [0, 0.025, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 5, np.inf]

    def period_histograms(start_times, end_times):
        # (4 states, 11 bins) histogram of per-trial time in state for one period.
        # A NaN time cast to int becomes an arbitrary huge index and would
        # silently select the wrong frames, so such trials are dropped first.
        valid = np.isfinite(start_times) & np.isfinite(end_times)
        start_frames = (start_times[valid] * input_fps).astype(int)
        end_frames = (end_times[valid] * input_fps).astype(int)
        windows = list(zip(start_frames, end_frames))
        time_in_state = calculate_time_in_state(windows, states_ibl_smooth, input_fps)
        n_included = max(len(windows), 1)  # avoid 0/0 when no trial qualifies
        return np.array([
            np.histogram(time_in_state[:, state], bins=bin_edges)[0] / n_included
            for state in range(4)
        ])

    histograms = np.stack([
        period_histograms(first_movement_times, feedback_times),
        period_histograms(feedback_times, trial_end_times),
    ])

    # Flatten the histograms to create the 88-dim vector
    proportions = histograms.flatten()

    return proportions

def process_sessions_and_umap(eids, dropbox_eids):
    """Embed every session's state-duration feature vector with UMAP and plot it.

    Args:
        eids: List of session eids plotted in blue (group 'Standard').
        dropbox_eids: List of session eids plotted in red (group 'Dropbox').

    Computes `get_proportions(eid)` for each session, reduces the stacked
    vectors to 2-D with UMAP (`random_state=42`), and shows a scatter plot in
    which every point is annotated with its eid.

    NOTE(review): an eid present in both lists is processed and plotted twice,
    once per group.
    """
    # Combine the eids and dropbox_eids
    all_eids = eids + dropbox_eids

    # One feature vector per session
    data = np.array([get_proportions(eid) for eid in all_eids])

    # Create labels for the groups
    labels = ['Standard'] * len(eids) + ['Dropbox'] * len(dropbox_eids)

    # Apply UMAP to reduce to 2D
    reducer = umap.UMAP(random_state=42)
    embedding = reducer.fit_transform(data)

    # Create a DataFrame for easier handling
    df = pd.DataFrame(embedding, columns=['UMAP1', 'UMAP2'])
    df['Group'] = labels
    df['SessionID'] = list(all_eids)

    # Plotting the results. No `label=` is passed to scatter: the legend is
    # built from explicit proxy handles below.
    group_colors = {'Standard': 'blue', 'Dropbox': 'red'}
    plt.figure(figsize=(12, 10))
    plt.scatter(df['UMAP1'], df['UMAP2'], c=df['Group'].map(group_colors))

    # Annotate each point with the session ID
    for _, row in df.iterrows():
        plt.annotate(row['SessionID'], (row['UMAP1'], row['UMAP2']), fontsize=8, alpha=0.75)

    # Create a legend with one round marker per group
    legend_handles = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor=group_colors['Standard'], markersize=10, label='Repro Ephys'),
        Line2D([0], [0], marker='o', color='w', markerfacecolor=group_colors['Dropbox'], markersize=10, label='Dropbox'),
    ]

    plt.legend(handles=legend_handles)
    plt.title('UMAP Projection of Session Proportions')
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    plt.show()

# Run the UMAP analysis over the repro-ephys sessions (`eids`) and the dropbox
# sessions (`dropbox_eids`); both lists are defined in the session-information cells.
process_sessions_and_umap(eids, dropbox_eids)

In [None]:
# UMAP Clustering (odd vs even trials)

def calculate_time_in_state(frame_indices, states, input_fps=60):
    """Measure, per frame window, how long each of states 1..4 was occupied.

    Args:
        frame_indices: Sized sequence of `(start, end)` frame pairs; each pair
            selects `states[start:end]`.
        states: Array of per-frame state ids.
        input_fps: Frames per second used to convert frame counts to seconds.

    Returns:
        Array of shape `(len(frame_indices), 4)` in seconds; column `k`
        corresponds to state `k + 1`.
    """
    seconds_per_state = np.zeros((len(frame_indices), 4))
    # Each `row` is a view into `seconds_per_state`, so writes land in place.
    for row, window in zip(seconds_per_state, frame_indices):
        lo, hi = window
        segment = states[lo:hi]
        for col in range(4):
            # Convert frames to seconds
            row[col] = np.count_nonzero(segment == col + 1) / input_fps
    return seconds_per_state

def get_proportions(eid, trial_indices):
    """Build the 88-dimensional state-duration feature vector for a subset of trials.

    Runs state inference on the session's smoothed markers, restricts the
    trial events to `trial_indices`, measures the time spent in each of the
    four states during two trial periods (first movement to feedback,
    feedback to trial end), histograms those durations into 11 bins, and
    flattens the `(2 periods, 4 states, 11 bins)` array.

    Selected trials whose period boundaries are not finite are excluded from
    that period, and each histogram is normalised by the number of trials
    that were included.

    Args:
        eid: Session identifier.
        trial_indices: Index array selecting the trials to use.

    Returns:
        1-D array of length 88 (period-major, then state, then bin).

    Reads the notebook globals `l_thresh_smooth`, `view`, `paw` and `model`.

    NOTE(review): state inference is repeated on every call, so callers that
    split one session into several trial subsets pay for it each time.
    NOTE(review): event times are converted to frame indices as `time * 60`,
    which assumes the state array starts at session time 0 at exactly 60 fps.
    """
    input_fps = 60

    ibl_smooth_markers_file = extract_marker_data(eid, l_thresh_smooth, view, paw, smooth=True)
    data_gen_ibl_smooth = create_data_generator(eid, ibl_smooth_markers_file)
    states_ibl_smooth = np.clip(get_states(model, data_gen_ibl_smooth), 1, 4)

    # Load trial event times for the selected trials
    trials = one.load_object(eid, 'trials')
    trial_end_times = np.asarray(trials['intervals'][:, 1], dtype=float)[trial_indices]
    first_movement_times = np.asarray(trials['firstMovement_times'], dtype=float)[trial_indices]
    feedback_times = np.asarray(trials['feedback_times'], dtype=float)[trial_indices]

    # Define bin edges (in seconds) for histograms
    bin_edges = [0, 0.025, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 5, np.inf]

    def period_histograms(start_times, end_times):
        # (4 states, 11 bins) histogram of per-trial time in state for one period.
        # A NaN time cast to int becomes an arbitrary huge index and would
        # silently select the wrong frames, so such trials are dropped first.
        valid = np.isfinite(start_times) & np.isfinite(end_times)
        start_frames = (start_times[valid] * input_fps).astype(int)
        end_frames = (end_times[valid] * input_fps).astype(int)
        windows = list(zip(start_frames, end_frames))
        time_in_state = calculate_time_in_state(windows, states_ibl_smooth, input_fps)
        n_included = max(len(windows), 1)  # avoid 0/0 when no trial qualifies
        return np.array([
            np.histogram(time_in_state[:, state], bins=bin_edges)[0] / n_included
            for state in range(4)
        ])

    histograms = np.stack([
        period_histograms(first_movement_times, feedback_times),
        period_histograms(feedback_times, trial_end_times),
    ])

    # Flatten the histograms to create the 88-dim vector
    proportions = histograms.flatten()
    return proportions

def process_sessions_and_umap(eids, dropbox_eids):
    """Embed even-trial and odd-trial feature vectors of every session with one UMAP.

    For each session the trials are split by index parity, a feature vector is
    computed for each half with `get_proportions`, and all vectors are reduced
    to 2-D together. Even halves are drawn in blue, odd halves in red, and the
    two halves of a session are joined by a dashed line.

    Args:
        eids: List of session eids in the 'Standard' group.
        dropbox_eids: List of session eids in the 'Dropbox' group.
    """
    all_eids = eids + dropbox_eids
    axes_names = ['UMAP1', 'UMAP2']

    # One feature vector per session and trial parity
    even_rows, odd_rows = [], []
    for eid in all_eids:
        trials = one.load_object(eid, 'trials')
        num_trials = len(trials['intervals'])
        even_trials = np.arange(0, num_trials, 2)
        odd_trials = np.arange(1, num_trials, 2)
        even_rows.append(get_proportions(eid, even_trials))
        odd_rows.append(get_proportions(eid, odd_trials))

    even_data = np.array(even_rows)
    odd_data = np.array(odd_rows)

    # Fit a single UMAP on even rows followed by odd rows, then split it back
    embedding = umap.UMAP(random_state=42).fit_transform(np.concatenate([even_data, odd_data]))
    n_sessions = len(even_data)

    df_even = pd.DataFrame(embedding[:n_sessions], columns=axes_names)
    df_even['Group'] = ['Standard'] * len(eids) + ['Dropbox'] * len(dropbox_eids)
    df_even['SessionID'] = list(all_eids)

    df_odd = pd.DataFrame(embedding[n_sessions:], columns=axes_names)
    df_odd['Group'] = ['Standard'] * len(eids) + ['Dropbox'] * len(dropbox_eids)
    df_odd['SessionID'] = list(all_eids)

    # Plotting the results
    plt.figure(figsize=(12, 10))
    plt.scatter(df_even['UMAP1'], df_even['UMAP2'], c='blue', label='Even Trials')
    plt.scatter(df_odd['UMAP1'], df_odd['UMAP2'], c='red', label='Odd Trials')

    # Connect the even and odd points of each session with a dashed line
    even_points = df_even[axes_names].to_numpy()
    odd_points = df_odd[axes_names].to_numpy()
    for (even_x, even_y), (odd_x, odd_y) in zip(even_points, odd_points):
        plt.plot([even_x, odd_x], [even_y, odd_y], 'k--', alpha=0.5)

    plt.legend()
    plt.title('UMAP Projection of Even and Odd Trial Proportions')
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    plt.show()

# Run the even/odd-trial UMAP analysis. The functions defined in this cell
# reuse the names of the previous UMAP cell's functions and replace them
# once this cell has been executed.
process_sessions_and_umap(eids, dropbox_eids)


In [None]:
# Diagnostic 1: Trace Plots
# Four-colour discrete colormap for the trace plots; `graph_states` reads the
# module-level `cmap`. This repeats the definition from the earlier
# "Ethogram graphing functions" cell.
palette = sns.color_palette("tab10", n_colors=4)
cmap = ListedColormap(palette.as_hex())

def graph_etho_img(fig, states, markers, marker_names, m_inds, trial_start_times, go_cue_times, feedback_times, start=0, length=600):
    """Draw the trace-plot diagnostic: a state strip above stacked marker traces.

    Args:
        fig: Matplotlib figure that receives the two stacked axes.
        states: Per-frame state ids; the colour strip is scaled from 1 to the
            number of state labels (4).
        markers: Per-frame marker array, indexed as `markers[frame, column]`.
        marker_names: Tick labels for the plotted marker traces.
        m_inds: Column indices of `markers` to plot.
        trial_start_times, go_cue_times, feedback_times: Event positions
            forwarded to `graph_markers`, where they are compared against
            frame indices — presumably already in frames; confirm with caller.
        start: First frame of the displayed window.
        length: Number of frames in the displayed window.
    """
    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']
    state_colors = sns.color_palette("tab10", n_colors=len(state_labels))

    n_rows = 2
    outer_grid = gridspec.GridSpec(n_rows, 1, figure=fig, height_ratios=[.2, .8])
    axes = [fig.add_subplot(outer_grid[i]) for i in range(n_rows)]

    # Use the size of the label set as the top of the colour scale. Taking the
    # largest observed state instead would shift the colours away from the
    # legend whenever the highest state is absent from the recording.
    n_classes = len(state_labels)

    # Graph DAART model predictions
    data = states[start:(start+length)]
    graph_states(axes[0], data, n_classes)
    axes[0].set(ylabel="DAART")
    axes[0].yaxis.label.set(rotation='horizontal', ha='right')

    # Graph markers
    markers = markers[start:(start+length)]
    graph_markers(axes[1], markers, marker_names, m_inds, trial_start_times, go_cue_times, feedback_times, start=start)

    # Add a legend for the states
    handles = [Patch(facecolor=color, edgecolor='k', label=label) for color, label in zip(state_colors, state_labels)]
    axes[0].legend(handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    axes[1].set_xlabel('Time (frames)', fontsize=12)

def graph_states(ax, states, n_classes):
    """Render a 1-D state sequence as a single-row colour strip on `ax`.

    Args:
        ax: matplotlib axes to draw on.
        states: 1-D numpy array of state codes; the colour scale starts at 1.
        n_classes: largest state code, used as the top of the colour scale.
    """
    # imshow needs 2-D input, so the sequence becomes a 1 x n_frames image.
    # Colours come from the module-level `cmap`.
    ax.imshow(
        states[None, :], aspect='auto',
        cmap=cmap, interpolation='none',
        vmin=1, vmax=n_classes)
    ax.set_xticks([])
    ax.set_yticks([])

def graph_markers(ax, markers, marker_names, m_inds, trial_start_times, go_cue_times, feedback_times, start):
    """Plot stacked, individually normalised marker traces with trial-event lines.

    Args:
        ax: matplotlib axes to draw on.
        markers: 2-D array (frames x traces) covering the window being drawn.
        marker_names: y-tick labels, one per entry of `m_inds`.
        m_inds: column indices of `markers` to plot, bottom to top.
        trial_start_times, go_cue_times, feedback_times: event positions in the
            same units as the x axis (frame index offset by `start`).
        start: x-axis value of the first row of `markers`.
    """
    markers = np.asarray(markers)
    n_frames = len(markers)
    end = start + n_frames

    # Each trace gets its own unit-height band, stacked vertically
    num_markers = len(m_inds)
    height = 1.0
    y_offsets = np.arange(num_markers) * height

    frames = np.arange(n_frames) + start
    for idx, m_idx in enumerate(m_inds):
        marker_data = markers[:, m_idx]
        finite = np.isfinite(marker_data)
        if not finite.any():
            # Empty window or all-NaN trace: nothing drawable
            continue
        # Min/max over finite samples only, so a single NaN does not blank the trace
        lo = marker_data[finite].min()
        hi = marker_data[finite].max()
        span = hi - lo
        if span > 0:
            norm_marker_data = (marker_data - lo) / span * height
        else:
            # Constant trace: draw it mid-band instead of dividing by zero
            norm_marker_data = np.where(finite, height / 2, np.nan)
        ax.plot(frames, norm_marker_data + y_offsets[idx], color='k', linewidth=2)

    # Vertical lines for every trial start, go cue, and feedback event inside the window.
    # Legend handles are dashed lines so they look like what is drawn.
    event_styles = [
        (trial_start_times, 'r', 'Trial Start'),
        (go_cue_times, 'b', 'Go Cue'),
        (feedback_times, 'g', 'Feedback'),
    ]
    handles = []
    for event_frames, color, label in event_styles:
        event_frames = np.asarray(event_frames)
        visible = event_frames[(event_frames >= start) & (event_frames < end)]
        ax.vlines(visible, ymin=-height, ymax=num_markers*height, color=color, linestyle='--')
        handles.append(Line2D([0], [0], color=color, linestyle='--', label=label))

    # Set x-axis limits and ticks
    ax.set_xlim([start, end])

    num_ticks = 6
    tick_positions = np.linspace(start, end, num_ticks).astype(int)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_positions)

    ax.spines[['right', 'left']].set_visible(False)
    ax.set_yticks(y_offsets + height / 2)
    ax.set_yticklabels(marker_names)
    ax.yaxis.tick_left()
    ax.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)


save_dir = '/Users/zacharyzusin/Documents/NeuroscienceResearch/repro_ephys_analysis'

l_thresh_smooth = 0.9
view = 'left'
paw = 'paw_r'

if view == 'right':
    raise NotImplementedError

# For each session: predict states, rebuild the z-scored marker/wheel traces, and draw
# one ethogram figure per trial.
for eid in eids[:1]:
    print(f'Eid: {eid}')
    sess_id = dropbox_marker_paths[eid]
    print(f'Session: {sess_id}')

    # Extract the marker data from the ibl database
    ibl_smooth_markers_file = extract_marker_data(eid, l_thresh_smooth, view, paw, smooth=True)

    # Build the Data Generators
    data_gen_ibl_smooth = create_data_generator(sess_id, ibl_smooth_markers_file)

    # Run Inference
    states_ibl_smooth = get_states(model, data_gen_ibl_smooth)

    # Ensure state values are within the correct range
    states_ibl_smooth = np.clip(states_ibl_smooth, 1, 4)

    data_dict = {'IBL Smooth States': states_ibl_smooth}

    # Load the pose data (same likelihood threshold as the marker extraction above)
    sl = SessionLoader(one=one, eid=eid)
    sl.load_pose(likelihood_thr=l_thresh_smooth, views=[view])
    times = sl.pose[f'{view}Camera'].times.to_numpy()
    markers = sl.pose[f'{view}Camera'].loc[:, (f'{paw}_x', f'{paw}_y')].to_numpy()

    # Load wheel data
    sl.load_wheel()
    wh_times = sl.wheel.times.to_numpy()
    wh_vel_oversampled = sl.wheel.velocity.to_numpy()

    # Resample wheel data at marker times
    interpolator = interp1d(wh_times, wh_vel_oversampled, fill_value='extrapolate')
    wh_vel = interpolator(times)

    # Smooth the marker data
    markers[:, 0] = smooth_interpolate_signal_sg(markers[:, 0], window=7)
    markers[:, 1] = smooth_interpolate_signal_sg(markers[:, 1], window=7)

    # Stack positions + wheel velocity, append their frame-to-frame differences, z-score
    markers_comb = np.hstack([markers, wh_vel[:, None]])
    velocity = np.vstack([np.array([0, 0, 0]), np.diff(markers_comb, axis=0)])
    markers_comb = np.hstack([markers_comb, velocity])
    markers_z = (markers_comb - np.mean(markers_comb, axis=0)) / np.std(markers_comb, axis=0)

    # Only the first three columns (positions and wheel velocity) are plotted
    marker_names = ['paw_x_pos', 'paw_y_pos', 'wheel_vel']
    m_inds = [0, 1, 2]

    # Load trial data
    trials = one.load_object(eid, 'trials')
    trial_start_times = trials['intervals'][:, 0]
    go_cue_times = trials['goCue_times']
    feedback_times = trials['feedback_times']

    # Convert trial event times to camera frame indices by looking them up in the camera
    # timestamps. Multiplying by a nominal frame rate is wrong whenever the camera runs at
    # a different rate or its clock does not start at zero. Events with no timestamp (NaN)
    # are dropped rather than cast to an arbitrary integer.
    # Assumes `times` is sorted ascending -- TODO confirm for sessions with timestamp glitches.
    trial_start_frames = np.searchsorted(times, trial_start_times[~np.isnan(trial_start_times)])
    go_cue_frames = np.searchsorted(times, go_cue_times[~np.isnan(go_cue_times)])
    feedback_frames = np.searchsorted(times, feedback_times[~np.isnan(feedback_times)])

    # Frames available in both the predictions and the traces
    n_frames = min(len(states_ibl_smooth), len(markers_z))

    # Iterate through each trial and plot
    for trial_idx in range(len(trial_start_frames)):
        start_frame = trial_start_frames[trial_idx]
        if trial_idx < len(trial_start_frames) - 1:
            end_frame = trial_start_frames[trial_idx + 1]
        else:
            end_frame = start_frame + 600  # Assuming a default length of 600 frames if it's the last trial
        end_frame = min(end_frame, n_frames)

        # Skip windows with no data (trial beyond the recording or zero-length interval)
        if end_frame <= start_frame:
            continue

        fig = plt.figure(figsize=(10, 6))
        graph_etho_img(fig, states_ibl_smooth, markers_z, marker_names, m_inds, trial_start_frames, go_cue_frames, feedback_frames, start=start_frame, length=int(end_frame-start_frame))
        plt.show()
        # Release the figure so a long session does not accumulate hundreds of open figures
        plt.close(fig)


In [None]:
# Trace Plots Loop

save_dir = '/Users/zacharyzusin/Documents/NeuroscienceResearch/repro_ephys_analysis'

l_thresh_smooth = 0.9
view = 'left'
paw = 'paw_r'

if view == 'right':
    raise NotImplementedError

# For each session: predict states, rebuild the z-scored marker/wheel traces, and draw
# one ethogram figure per trial.
for eid in eids[:1]:
    print(f'Eid: {eid}')
    sess_id = dropbox_marker_paths[eid]
    print(f'Session: {sess_id}')

    # Extract the marker data from the ibl database
    ibl_smooth_markers_file = extract_marker_data(eid, l_thresh_smooth, view, paw, smooth=True)

    # Build the Data Generators
    data_gen_ibl_smooth = create_data_generator(sess_id, ibl_smooth_markers_file)

    # Run Inference
    states_ibl_smooth = get_states(model, data_gen_ibl_smooth)

    # Ensure state values are within the correct range
    states_ibl_smooth = np.clip(states_ibl_smooth, 1, 4)

    data_dict = {'IBL Smooth States': states_ibl_smooth}

    # Load the pose data (same likelihood threshold as the marker extraction above)
    sl = SessionLoader(one=one, eid=eid)
    sl.load_pose(likelihood_thr=l_thresh_smooth, views=[view])
    times = sl.pose[f'{view}Camera'].times.to_numpy()
    markers = sl.pose[f'{view}Camera'].loc[:, (f'{paw}_x', f'{paw}_y')].to_numpy()

    # Load wheel data
    sl.load_wheel()
    wh_times = sl.wheel.times.to_numpy()
    wh_vel_oversampled = sl.wheel.velocity.to_numpy()

    # Resample wheel data at marker times
    interpolator = interp1d(wh_times, wh_vel_oversampled, fill_value='extrapolate')
    wh_vel = interpolator(times)

    # Smooth the marker data
    markers[:, 0] = smooth_interpolate_signal_sg(markers[:, 0], window=7)
    markers[:, 1] = smooth_interpolate_signal_sg(markers[:, 1], window=7)

    # Stack positions + wheel velocity, append their frame-to-frame differences, z-score
    markers_comb = np.hstack([markers, wh_vel[:, None]])
    velocity = np.vstack([np.array([0, 0, 0]), np.diff(markers_comb, axis=0)])
    markers_comb = np.hstack([markers_comb, velocity])
    markers_z = (markers_comb - np.mean(markers_comb, axis=0)) / np.std(markers_comb, axis=0)

    # Only the first three columns (positions and wheel velocity) are plotted
    marker_names = ['paw_x_pos', 'paw_y_pos', 'wheel_vel']
    m_inds = [0, 1, 2]

    # Load trial data
    trials = one.load_object(eid, 'trials')
    trial_start_times = trials['intervals'][:, 0]
    go_cue_times = trials['goCue_times']
    feedback_times = trials['feedback_times']

    # Convert trial event times to camera frame indices by looking them up in the camera
    # timestamps. Multiplying by a nominal frame rate is wrong whenever the camera runs at
    # a different rate or its clock does not start at zero. Events with no timestamp (NaN)
    # are dropped rather than cast to an arbitrary integer.
    # Assumes `times` is sorted ascending -- TODO confirm for sessions with timestamp glitches.
    trial_start_frames = np.searchsorted(times, trial_start_times[~np.isnan(trial_start_times)])
    go_cue_frames = np.searchsorted(times, go_cue_times[~np.isnan(go_cue_times)])
    feedback_frames = np.searchsorted(times, feedback_times[~np.isnan(feedback_times)])

    # Frames available in both the predictions and the traces
    n_frames = min(len(states_ibl_smooth), len(markers_z))

    # Iterate through each trial and plot
    for trial_idx in range(len(trial_start_frames)):
        start_frame = trial_start_frames[trial_idx]
        if trial_idx < len(trial_start_frames) - 1:
            end_frame = trial_start_frames[trial_idx + 1]
        else:
            end_frame = start_frame + 600  # Assuming a default length of 600 frames if it's the last trial
        end_frame = min(end_frame, n_frames)

        # Skip windows with no data (trial beyond the recording or zero-length interval)
        if end_frame <= start_frame:
            continue

        fig = plt.figure(figsize=(10, 6))
        graph_etho_img(fig, states_ibl_smooth, markers_z, marker_names, m_inds, trial_start_frames, go_cue_frames, feedback_frames, start=start_frame, length=int(end_frame-start_frame))
        plt.show()
        # Release the figure so a long session does not accumulate hundreds of open figures
        plt.close(fig)

In [85]:
# Diagnostic 2: Per State Scatterplots of Paw Locations

# Function to overlay paw locations on a video frame
def overlay_paw_locations_on_frame(frame, paw_coordinates, color):
    """Stamp a filled radius-3 dot on `frame` for each (x, y) pair that has no NaN.

    `frame` is modified in place; pairs with a NaN coordinate are skipped.
    """
    for point in paw_coordinates:
        px, py = point
        if np.isnan(px) or np.isnan(py):
            continue
        centre = (int(px), int(py))
        cv2.circle(frame, centre, 3, color, -1)

# Function to scatter paw locations on video frames for each state
def scatter_paw_locations_on_video_frame(states, markers, video_path):
    """Show, per predicted state, all paw positions overlaid on one random video frame.

    Args:
        states: 1-D array of per-frame state codes (1-4 map to named states).
        markers: 2-D array of per-frame (x, y) paw coordinates.
        video_path: path or URL accepted by cv2.VideoCapture.

    Prints an error and returns None if the video cannot be opened or read.
    """
    colors = {
        'Still': (255, 0, 0),       # Blue
        'Move': (0, 165, 255),      # Orange
        'Wheel Turn': (0, 255, 0),  # Green
        'Groom': (0, 0, 255)        # Red
    }
    state_to_label = {1: 'Still', 2: 'Move', 3: 'Wheel Turn', 4: 'Groom'}

    cap = cv2.VideoCapture(video_path)
    # The capture is only needed to grab one frame; release it on every exit path
    try:
        if not cap.isOpened():
            print(f"Error: Cannot open video file {video_path}")
            return

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            # randint(0, -1) would raise; treat an unknown/empty frame count as unreadable
            print("Error: Cannot read the frame from video")
            return
        random_frame_number = random.randint(0, total_frames - 1)

        cap.set(cv2.CAP_PROP_POS_FRAMES, random_frame_number)
        ret, frame = cap.read()
    finally:
        cap.release()

    if not ret:
        print("Error: Cannot read the frame from video")
        return

    # Predictions and markers can differ in length; keep only frames present in both
    states = np.asarray(states)
    markers = np.asarray(markers)
    n_common = min(len(states), len(markers))
    states = states[:n_common]
    markers = markers[:n_common]

    unique_states = np.unique(states)
    if len(unique_states) == 0:
        return

    # squeeze=False keeps `axes` 2-D even when only one state is present
    fig, axes = plt.subplots(1, len(unique_states), figsize=(20, 5), squeeze=False)
    for ax, state in zip(axes[0], unique_states):
        paw_coordinates = markers[states == state]

        overlay_frame = frame.copy()
        # Unknown codes get a generic label and white dots instead of a KeyError
        label = state_to_label.get(int(state), f'State {state}')
        overlay_paw_locations_on_frame(overlay_frame, paw_coordinates, colors.get(label, (255, 255, 255)))

        ax.imshow(cv2.cvtColor(overlay_frame, cv2.COLOR_BGR2RGB))
        ax.set_title(f'{label} Paw Locations')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [84]:
# Diagnostic 3: Wheel Position Conditioned on State Transition

def plot_relative_wheel_position_conditioned_on_frequent_transitions(eid, wh_times, wh_pos, states, times, trials, pre_window=(-0.25, 0), post_window=(0, 1.5), transition_threshold=5, sampling_rate=60):
    """Plot stimulus-aligned relative wheel position, one figure per frequent state transition.

    For each trial with a non-zero choice, the most common state in `pre_window` and in
    `post_window` (both relative to stimulus onset) is computed. Trials where the two
    differ define a (pre, post) transition. A transition is plotted when, for at least
    one choice value, it occurs `transition_threshold` or more times.

    Args:
        eid: session id, used to look up the session label in `dropbox_marker_paths`.
        wh_times, wh_pos: wheel timestamps and positions.
        states: per-frame state codes (1-4), as a numpy array.
        times: per-frame timestamps matching `states`.
        trials: mapping with 'stimOn_times' and 'choice' arrays.
        pre_window, post_window: (start, stop) offsets in seconds around stimulus onset.
        transition_threshold: minimum per-choice count for a transition to be plotted.
        sampling_rate: samples per second of the common time grid used for plotting.
    """
    sess_id = dropbox_marker_paths[eid]
    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']
    cmap = sns.color_palette("tab10", n_colors=4)

    # Align states and times BEFORE resampling the wheel, so every per-frame array
    # (and every boolean mask built from `times`) has the same length
    min_length = min(len(times), len(states))
    times = times[:min_length]
    states = states[:min_length]

    # Resample wheel data at frame times
    interpolator = interp1d(wh_times, wh_pos, fill_value='extrapolate')
    wh_pos_resampled = interpolator(times)

    stimOn_times = trials['stimOn_times']
    choices = trials['choice']

    # Common time grid on which every trial is drawn
    num_pre_frames = int((pre_window[1] - pre_window[0]) * sampling_rate)
    num_post_frames = int((post_window[1] - post_window[0]) * sampling_rate)
    interp_times = np.linspace(pre_window[0], post_window[1], num_pre_frames + num_post_frames)

    # Single pass over trials: record (pre_state, post_state, onset, choice) per transition trial
    events = []
    for onset, choice in zip(stimOn_times, choices):
        if choice == 0:  # Skip trials with no turn
            continue

        pre_states = states[(times >= (onset + pre_window[0])) & (times <= (onset + pre_window[1]))]
        post_states = states[(times >= (onset + post_window[0])) & (times <= (onset + post_window[1]))]
        if len(pre_states) == 0 or len(post_states) == 0:
            continue

        most_common_pre_state = Counter(pre_states.tolist()).most_common(1)[0][0]
        most_common_post_state = Counter(post_states.tolist()).most_common(1)[0][0]
        if most_common_pre_state != most_common_post_state:
            events.append((most_common_pre_state, most_common_post_state, onset, choice))

    # Count per (pre, post, choice), then keep each qualifying (pre, post) pair once.
    # Without de-duplication a pair is plotted twice when both choices pass the threshold.
    transition_counts = Counter((pre, post, choice) for pre, post, _, choice in events)
    frequent_transitions = list(dict.fromkeys(
        (pre, post) for (pre, post, _), count in transition_counts.items() if count >= transition_threshold
    ))

    max_abs_val = 0

    for pre_state, post_state in frequent_transitions:
        transition_times = [(onset, choice) for pre, post, onset, choice in events
                            if pre == pre_state and post == post_state]

        plt.figure(figsize=(15, 10))

        # State sequence of the last plotted trial, shown as a bar under the traces
        bar_states = None

        for onset, choice in transition_times:
            mask = (times >= (onset + pre_window[0])) & (times <= (onset + post_window[1]))
            if not mask.any():
                continue

            onset_times = times[mask] - onset
            onset_wheel_pos = wh_pos_resampled[mask]
            onset_states = states[mask]

            # Put the wheel trace on the common grid
            interp_wheel_pos = np.interp(interp_times, onset_times, onset_wheel_pos)

            # States are categorical: take the nearest frame's label instead of linearly
            # interpolating between codes; grid points outside the data become -1
            nearest = np.abs(onset_times[None, :] - interp_times[:, None]).argmin(axis=1)
            outside = (interp_times < onset_times[0]) | (interp_times > onset_times[-1])
            bar_states = np.where(outside, -1, onset_states[nearest])

            # Adjust to relative wheel position
            relative_wheel_pos = interp_wheel_pos - interp_wheel_pos[0]

            # Update the maximum absolute value for y-axis limits
            max_abs_val = max(max_abs_val, np.max(np.abs(relative_wheel_pos)))

            # Determine the color based on the choice
            color = 'red' if choice == 1 else 'blue'  # Right turn is red, left turn is blue

            # Plot each trial's relative wheel position
            plt.plot(interp_times, relative_wheel_pos, alpha=0.5, color=color)

        # Plot the state bar at the bottom (only if at least one trial was drawn)
        if bar_states is not None:
            state_colors = np.array([cmap[int(s)-1] if s != -1 else (1, 1, 1) for s in bar_states])  # Use white for undefined states
            plt.imshow([state_colors], aspect='auto', extent=[pre_window[0], post_window[1], -max_abs_val, -max_abs_val - 0.1], interpolation='nearest')

        # Set symmetrical y-axis limits
        plt.ylim(-max_abs_val - 0.1, max_abs_val)

        # Add a vertical line at stimulus onset time
        plt.axvline(x=0, color='k', linestyle='--')

        # Add labels and title
        plt.xlabel('Time (seconds)')
        plt.ylabel('Relative Wheel Position (radians)')
        plt.title(f'Relative Wheel Position Over Time Conditioned on Transition {state_labels[pre_state-1]} to {state_labels[post_state-1]} for Session: {sess_id}')

        # Add legend
        legend_elements = [
            Line2D([0], [0], color='red', lw=2, label='Right Turn'),
            Line2D([0], [0], color='blue', lw=2, label='Left Turn'),
            Line2D([0], [0], color='k', lw=2, linestyle='--', label='Stimulus Onset'),
            Line2D([0], [0], color=cmap[0], lw=2, label='Still'),
            Line2D([0], [0], color=cmap[1], lw=2, label='Move'),
            Line2D([0], [0], color=cmap[2], lw=2, label='Wheel Turn'),
            Line2D([0], [0], color=cmap[3], lw=2, label='Groom')
        ]
        plt.legend(handles=legend_elements)

        # Show plot
        plt.show()


In [83]:
# Diagnostic 4: Paw Speed Conditioned on State Transition

def plot_relative_paw_speed_conditioned_on_state_transitions(eid, times, paw_x_vel, paw_y_vel, states, trials, pre_window=(-0.5, 0), post_window=(0, 1.5), transition_threshold=5, sampling_rate=60):
    """Plot stimulus-aligned paw speed, one figure per frequent state transition.

    A trial contributes a (pre, post) transition when its choice is non-zero and the
    most common state before stimulus onset differs from the most common state after
    it. A transition is plotted when, for at least one choice value, it occurs
    `transition_threshold` or more times.

    Args:
        eid: session id, used to look up the session label in `dropbox_marker_paths`.
        times: per-frame timestamps.
        paw_x_vel, paw_y_vel: per-frame paw velocity components.
        states: per-frame state codes (1-4), as a numpy array.
        trials: mapping with 'stimOn_times', 'choice' and 'feedbackType' arrays.
        pre_window, post_window: (start, stop) offsets in seconds around stimulus onset.
        transition_threshold: minimum per-choice count for a transition to be plotted.
        sampling_rate: samples per second of the common time grid used for plotting.
    """
    sess_id = dropbox_marker_paths[eid]
    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']

    paw_speed = np.sqrt(paw_x_vel ** 2 + paw_y_vel ** 2)

    stimOn_times = trials['stimOn_times']
    choices = trials['choice']
    feedbackTypes = trials['feedbackType']

    # Truncate every per-frame array to the same length so masks built from `times`
    # can index all of them
    min_length = min(len(times), len(states), len(paw_speed))
    times = times[:min_length]
    states = states[:min_length]
    paw_speed = paw_speed[:min_length]

    num_pre_frames = int((pre_window[1] - pre_window[0]) * sampling_rate)
    num_post_frames = int((post_window[1] - post_window[0]) * sampling_rate)
    interp_times = np.linspace(pre_window[0], post_window[1], num_pre_frames + num_post_frames)

    # Plot and legend share one colour table so the two cannot drift apart
    correct_color = 'black'
    incorrect_color = 'grey'

    # Single pass over trials: (pre_state, post_state, onset, choice, feedback)
    events = []
    for onset, choice, feedback in zip(stimOn_times, choices, feedbackTypes):
        if choice == 0:
            continue

        pre_states = states[(times >= (onset + pre_window[0])) & (times <= (onset + pre_window[1]))]
        post_states = states[(times >= (onset + post_window[0])) & (times <= (onset + post_window[1]))]
        if len(pre_states) == 0 or len(post_states) == 0:
            continue

        most_common_pre_state = Counter(pre_states.tolist()).most_common(1)[0][0]
        most_common_post_state = Counter(post_states.tolist()).most_common(1)[0][0]
        if most_common_pre_state != most_common_post_state:
            events.append((most_common_pre_state, most_common_post_state, onset, choice, feedback))

    # Count per (pre, post, choice), then keep each qualifying (pre, post) pair once
    # so a pair is not plotted twice when both choices pass the threshold
    transition_counts = Counter((pre, post, choice) for pre, post, _, choice, _ in events)
    frequent_transitions = list(dict.fromkeys(
        (pre, post) for (pre, post, _), count in transition_counts.items() if count >= transition_threshold
    ))

    max_abs_val = 0

    for pre_state, post_state in frequent_transitions:
        transition_times = [(onset, choice, feedback) for pre, post, onset, choice, feedback in events
                            if pre == pre_state and post == post_state]

        plt.figure(figsize=(15, 10))

        for onset, choice, feedback in transition_times:
            mask = (times >= (onset + pre_window[0])) & (times <= (onset + post_window[1]))
            if not mask.any():
                continue

            onset_times = times[mask] - onset
            onset_paw_speed = paw_speed[mask]

            interp_paw_speed = np.interp(interp_times, onset_times, onset_paw_speed)

            # Ignore NaNs when tracking the y-limit so one bad trace does not poison it
            finite_speed = interp_paw_speed[np.isfinite(interp_paw_speed)]
            if finite_speed.size:
                max_abs_val = max(max_abs_val, finite_speed.max())

            color_paw = correct_color if feedback == 1 else incorrect_color

            plt.plot(interp_times, interp_paw_speed, alpha=0.5, color=color_paw)

        # Identical limits (0, 0) make matplotlib warn; leave autoscale in that case
        if max_abs_val > 0:
            plt.ylim(0, max_abs_val)

        plt.axvline(x=0, color='red', linestyle='--')

        plt.xlabel('Time (seconds)')
        plt.ylabel('Paw Speed')
        plt.title(f'Paw Speed Over Time Conditioned on Transition {state_labels[pre_state-1]} to {state_labels[post_state-1]} for Session: {sess_id}')

        legend_elements = [
            Line2D([0], [0], color=correct_color, lw=2, label='Correct Trial'),
            Line2D([0], [0], color=incorrect_color, lw=2, label='Incorrect Trial'),
            Line2D([0], [0], color='red', lw=2, linestyle='--', label='Stimulus Onset')
        ]
        plt.legend(handles=legend_elements)
        plt.show()

In [None]:
# Diagnostics Loop

# Run every diagnostic (paw scatter, wheel position, paw speed, aligned heatmap) per session
for eid in eids[:1]:
    print(f'Eid: {eid}')

    # Look the session label up for THIS eid; relying on a `sess_id` left over from an
    # earlier cell silently pairs the wrong session once more than one eid is processed
    sess_id = dropbox_marker_paths[eid]

    #Session Loader Parameters
    l_thresh_smooth = 0.9
    view = 'left'
    paw = 'paw_r'

    # Load camera timestamps and paw markers
    sl = SessionLoader(one=one, eid=eid)
    sl.load_pose(likelihood_thr=l_thresh_smooth, views=[view])
    times = sl.pose[f'{view}Camera'].times.to_numpy()
    markers = sl.pose[f'{view}Camera'].loc[:, (f'{paw}_x', f'{paw}_y')].to_numpy()

    trials = one.load_object(eid, 'trials')

    # Load wheel data
    sl.load_wheel()
    wh_times = sl.wheel.times.to_numpy()
    wh_pos = sl.wheel.position.to_numpy()

    # Generate state predictions
    ibl_smooth_markers_file = extract_marker_data(eid, l_thresh_smooth, view, paw, smooth=True)
    data_gen_ibl_smooth = create_data_generator(sess_id, ibl_smooth_markers_file)
    states_ibl_smooth = np.clip(get_states(model, data_gen_ibl_smooth), 1, 4)
    n_states = len(states_ibl_smooth)

    # Provide the path to the session's video
    video_path = vidio.url_from_eid(eid, one=one)[view]

    scatter_paw_locations_on_video_frame(states_ibl_smooth, markers, video_path)
    plot_relative_wheel_position_conditioned_on_frequent_transitions(eid, wh_times, wh_pos, states_ibl_smooth, times[:n_states], trials)

    # Paw velocities come from the marker file extracted above; reuse it instead of
    # running the identical extraction a second time
    file = ibl_smooth_markers_file
    df = pd.read_csv(file)
    paw_x_vel = df['paw_x_vel'].to_numpy()
    paw_y_vel = df['paw_y_vel'].to_numpy()
    plot_relative_paw_speed_conditioned_on_state_transitions(eid, times[:n_states], paw_x_vel[:n_states], paw_y_vel[:n_states], states_ibl_smooth, trials)
    create_states_aligned_heatmap(eid, l_thresh_smooth, view, paw, align_event='stimOn_times', split_by_feedback=True)

In [None]:
# (INCOMPLETE)

def extract_state_runs(states, min_length=20, transition_threshold=20):
    """Find long runs of each state and candidate transitions between runs.

    The first and last 100 frames are ignored. Within the remaining frames, every
    maximal run of one state lasting at least `min_length` frames is recorded as an
    index array.

    Transition rule as implemented: when a run ends and the incoming state already has
    at least one recorded run, the recorded runs of the ending state are scanned in
    order. The first whose last index lies within `transition_threshold` frames of the
    ending run's start yields a transition placed midway between those two indices.
    NOTE(review): this cell is marked INCOMPLETE, so this heuristic may not be final.

    Args:
        states: 1-D sequence of non-negative integer state codes.
        min_length: minimum run length (frames) for a run to be recorded.
        transition_threshold: maximum frame distance used by the transition rule.

    Returns:
        dict with "state_snippets" (list indexed by state code, each a list of index
        arrays) and "transitions" (list of (frame, from_state, to_state) tuples).
        Inputs too short to contain any frame between the two buffers give empty results.
    """
    buffer = 100
    n_states = len(states)

    K = int(np.max(states) + 1) if n_states else 0
    state_snippets = [[] for _ in range(K)]
    transitions = []

    # Nothing lies between the leading and trailing buffers: return empty results
    # instead of failing on states[buffer]
    if n_states <= 2 * buffer:
        return {
            "state_snippets": state_snippets,
            "transitions": transitions
        }

    i_end = n_states - buffer
    i_beg = buffer
    curr_state = states[i_beg]
    curr_len = 1

    for i in range(i_beg + 1, i_end):
        next_state = states[i]
        if next_state != curr_state:
            # Record indices if state duration long enough
            if curr_len >= min_length:
                state_snippets[curr_state].append(np.arange(i_beg, i))

            # Check for nearby different state chunks
            if len(state_snippets[next_state]) > 0:
                for prev_chunk in state_snippets[curr_state]:
                    if np.abs(prev_chunk[-1] - i_beg) <= transition_threshold:
                        transition_frame = (prev_chunk[-1] + i_beg) // 2
                        transitions.append((transition_frame, curr_state, next_state))
                        break

            i_beg = i
            curr_state = next_state
            curr_len = 1
        else:
            curr_len += 1

    # End of trial cleanup: the run still open at the trailing buffer
    if curr_len >= min_length:
        state_snippets[curr_state].append(np.arange(i_beg, i_end))

    return {
        "state_snippets": state_snippets,
        "transitions": transitions
    }

import matplotlib.pyplot as plt
import numpy as np
from scipy import interpolate

def plot_wheel_around_state_transitions(one, eid, model, sess_id, min_length=20, transition_threshold=20, window_size=100):
    """Plot wheel position in a window of frames around each detected state transition.

    Args:
        one: ONE client used to load the session.
        eid: session id.
        model: trained segmenter passed to `get_states`.
        sess_id: session label passed to `create_data_generator`.
        min_length, transition_threshold: forwarded to `extract_state_runs`.
        window_size: total window width in frames, centred on the transition.
    """
    # Session Loader Parameters
    l_thresh_smooth = 0.9
    view = 'left'
    paw = 'paw_r'

    # Load camera frame timestamps
    sl = SessionLoader(one=one, eid=eid)
    sl.load_pose(likelihood_thr=l_thresh_smooth, views=[view])
    times = sl.pose[f'{view}Camera'].times.to_numpy()

    # Load wheel data
    sl.load_wheel()
    wh_times = sl.wheel.times.to_numpy()
    wh_pos = sl.wheel.position.to_numpy()

    # Generate state predictions
    ibl_smooth_markers_file = extract_marker_data(eid, l_thresh_smooth, view, paw, smooth=True)
    data_gen_ibl_smooth = create_data_generator(sess_id, ibl_smooth_markers_file)
    states_ibl_smooth = np.clip(get_states(model, data_gen_ibl_smooth), 1, 4)

    # Identify state transitions
    state_info = extract_state_runs(states_ibl_smooth, min_length, transition_threshold)
    transitions = state_info["transitions"]

    # Resample the wheel position onto camera frames so it can be indexed by frame
    # number. The previous code paired camera-indexed x values with wheel-indexed y
    # values, which have different lengths and meanings.
    wh_pos_frames = np.interp(times, wh_times, wh_pos)
    n_frames = min(len(states_ibl_smooth), len(wh_pos_frames))

    # Plot wheel position around state transitions
    fig, ax = plt.subplots(figsize=(10, 6))

    seen_labels = set()
    for transition_frame, curr_state, next_state in transitions:
        start_frame = max(0, transition_frame - window_size // 2)
        end_frame = min(n_frames, transition_frame + window_size // 2)
        if end_frame <= start_frame:
            continue
        window_frames = np.arange(start_frame, end_frame)
        wheel_window = wh_pos_frames[start_frame:end_frame]

        # One legend entry per transition type instead of one per occurrence
        label = f'Transition {curr_state}->{next_state}'
        if label in seen_labels:
            label = '_nolegend_'
        else:
            seen_labels.add(label)

        ax.plot(window_frames - transition_frame, wheel_window, label=label)

    ax.axvline(x=0, color='r', linestyle='--')
    ax.set_xlabel('Frames around transition')

    ax.set_ylabel('Wheel position')
    if seen_labels:
        ax.legend()
    ax.set_title('Wheel Position Around State Transitions')
    plt.show()

# NOTE(review): `eid` and `sess_id` are not assigned in this cell; the call relies on
# values left in the kernel by an earlier cell's loop.
plot_wheel_around_state_transitions(one, eid, model, sess_id)

In [None]:
# (INCOMPLETE)

def extract_state_runs(states, min_length=20):
    """Find runs of each state lasting at least `min_length` frames.

    The first and last 100 frames are ignored.

    Args:
        states: 1-D sequence of non-negative integer state codes.
        min_length: minimum run length (frames) for a run to be recorded.

    Returns:
        (state_snippets, end_indices): two lists indexed by state code. The first holds
        the index array of every recorded run; the second holds the last index of each
        of those runs. Inputs too short to contain any frame between the two buffers
        give empty lists.
    """
    buffer = 100
    n_states = len(states)

    K = int(np.max(states) + 1) if n_states else 0
    state_snippets = [[] for _ in range(K)]
    end_indices = [[] for _ in range(K)]

    # Nothing lies between the leading and trailing buffers
    if n_states <= 2 * buffer:
        return state_snippets, end_indices

    i_end = n_states - buffer  # exclusive end of the analysed region
    i_beg = buffer
    curr_state = states[i_beg]
    curr_len = 1
    for i in range(i_beg + 1, i_end):
        next_state = states[i]
        if next_state != curr_state:
            # record indices if state duration long enough
            if curr_len >= min_length:
                state_snippets[curr_state].append(np.arange(i_beg, i))
                end_indices[curr_state].append(i - 1)  # Store end index of the chunk
            i_beg = i
            curr_state = next_state
            curr_len = 1
        else:
            curr_len += 1

    # end of trial cleanup: the open run extends to i_end (exclusive). Using the loop
    # variable here dropped the run's final frame and made its end index one too small.
    if curr_len >= min_length:
        state_snippets[curr_state].append(np.arange(i_beg, i_end))
        end_indices[curr_state].append(i_end - 1)
    return state_snippets, end_indices

def plot_absolute_wheel_position_conditioned_on_frequent_transitions(eid, wh_times, wh_pos, states, times, pre_window=(-0.25, 0), post_window=(0, 1.5), transition_threshold=5, sampling_rate=60, min_chunk_length=20, max_frame_distance=30):
    """Plot wheel position around frequent transitions between long state runs.

    Runs come from `extract_state_runs`. A transition is a pair of runs of different
    states where the second starts 1 to `max_frame_distance` frames after the first
    ends. Each (pre, post) pair seen at least `transition_threshold` times gets one
    figure, aligned to the start time of the second run.

    Args:
        eid: session id, used to look up the session label in `dropbox_marker_paths`.
        wh_times, wh_pos: wheel timestamps and positions.
        states: per-frame state codes (1-4).
        times: per-frame timestamps matching `states`.
        pre_window, post_window: (start, stop) offsets in seconds around the transition.
        transition_threshold: minimum count for a transition to be plotted.
        sampling_rate: samples per second of the common time grid used for plotting.
        min_chunk_length: minimum run length passed to `extract_state_runs`.
        max_frame_distance: largest allowed gap (frames) between the two runs.
    """
    sess_id = dropbox_marker_paths[eid]
    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']
    cmap = sns.color_palette("tab10", n_colors=4)

    def _label(state):
        # State codes are 1-based; the label/colour tables are 0-based
        return state_labels[state - 1] if 1 <= state <= len(state_labels) else f'State {state}'

    def _color(state):
        return cmap[state - 1] if 1 <= state <= len(cmap) else 'k'

    # Keep per-frame arrays the same length so run indices are valid for `times`
    min_length = min(len(times), len(states))
    times = times[:min_length]
    states = states[:min_length]

    # Resample wheel data at frame times
    interpolator = interp1d(wh_times, wh_pos, fill_value='extrapolate')
    wh_pos_resampled = interpolator(times)

    # Common time grid on which every occurrence is drawn
    num_pre_frames = int((pre_window[1] - pre_window[0]) * sampling_rate)
    num_post_frames = int((post_window[1] - post_window[0]) * sampling_rate)
    interp_times = np.linspace(pre_window[0], post_window[1], num_pre_frames + num_post_frames)

    # Extract state runs
    state_snippets, end_indices = extract_state_runs(states, min_length=min_chunk_length)

    # Identify transitions between contiguous chunks of different states within the frame distance
    transitions = []

    for i in range(len(state_snippets)):
        for chunk, end_idx in zip(state_snippets[i], end_indices[i]):
            # Look for chunks in other states that start within the max_frame_distance
            for j in range(len(state_snippets)):
                if i == j:  # Ensure different states
                    continue
                for next_chunk in state_snippets[j]:
                    start_idx = next_chunk[0]
                    if 0 < start_idx - end_idx <= max_frame_distance:
                        transitions.append((i, j, chunk, next_chunk, times[start_idx]))

    # Count transitions
    transition_counts = Counter((pre, post) for pre, post, _, _, _ in transitions)

    # Filter transitions that occur frequently enough
    frequent_transitions = [(pre, post) for (pre, post), count in transition_counts.items() if count >= transition_threshold]

    max_abs_val = 0

    for pre_state, post_state in frequent_transitions:
        transition_times = [time for pre, post, _, _, time in transitions if pre == pre_state and post == post_state]

        plt.figure(figsize=(15, 10))

        # Traces are coloured by the state being left
        color = _color(pre_state)

        for onset in transition_times:
            mask = (times >= (onset + pre_window[0])) & (times <= (onset + post_window[1]))
            if not mask.any():
                continue

            onset_times = times[mask] - onset
            onset_wheel_pos = wh_pos_resampled[mask]

            # Interpolate to have a consistent number of frames
            interp_wheel_pos = np.interp(interp_times, onset_times, onset_wheel_pos)

            # Update the maximum absolute value for y-axis limits
            max_abs_val = max(max_abs_val, np.max(np.abs(interp_wheel_pos)))

            # Plot each trial's absolute wheel position
            plt.plot(interp_times, interp_wheel_pos, alpha=0.5, color=color)

        # Set symmetrical y-axis limits
        plt.ylim(-max_abs_val - 0.1, max_abs_val)

        # Add a vertical line at transition time
        plt.axvline(x=0, color='k', linestyle='--')

        # Add labels and title
        plt.xlabel('Time (seconds)')
        plt.ylabel('Absolute Wheel Position (radians)')
        plt.title(f'Absolute Wheel Position Over Time Conditioned on Transition {_label(pre_state)} to {_label(post_state)} for Session: {sess_id}')

        # Add legend
        legend_elements = [
            Line2D([0], [0], color='k', lw=2, linestyle='--', label='State Transition'),
            Line2D([0], [0], color=cmap[0], lw=2, label='Still'),
            Line2D([0], [0], color=cmap[1], lw=2, label='Move'),
            Line2D([0], [0], color=cmap[2], lw=2, label='Wheel Turn'),
            Line2D([0], [0], color=cmap[3], lw=2, label='Groom')
        ]
        plt.legend(handles=legend_elements)

        # Show plot
        plt.show()

# Session Loader Parameters
l_thresh_smooth = 0.9
view = 'left'
paw = 'paw_r'

# NOTE(review): `eid` is not assigned in this cell; it is whatever an earlier loop left behind.
# Derive the session label from that eid so the two cannot disagree.
sess_id = dropbox_marker_paths[eid]

# Load camera timestamps and paw markers
sl = SessionLoader(one=one, eid=eid)
sl.load_pose(likelihood_thr=l_thresh_smooth, views=[view])
times = sl.pose[f'{view}Camera'].times.to_numpy()
markers = sl.pose[f'{view}Camera'].loc[:, (f'{paw}_x', f'{paw}_y')].to_numpy()

trials = one.load_object(eid, 'trials')

# Load wheel data
sl.load_wheel()
wh_times = sl.wheel.times.to_numpy()
wh_pos = sl.wheel.position.to_numpy()

# Generate state predictions
ibl_smooth_markers_file = extract_marker_data(eid, l_thresh_smooth, view, paw, smooth=True)
data_gen_ibl_smooth = create_data_generator(sess_id, ibl_smooth_markers_file)
states_ibl_smooth = np.clip(get_states(model, data_gen_ibl_smooth), 1, 4)

# Fractional wheel-sample index at each camera time, minus the last four entries.
# NOTE(review): not used by the call below, and the reason for [:-4] is not visible here.
wh_times_frames = np.interp(times, wh_times, np.arange(len(wh_times)))[:-4]
plot_absolute_wheel_position_conditioned_on_frequent_transitions(eid, wh_times, wh_pos, states_ibl_smooth, times[:len(states_ibl_smooth)])


In [None]:
# (INCOMPLETE)

def extract_state_runs(states, min_length=20, buffer=100):
    """
    Find contiguous runs of identical state labels.

    Only frames in ``[buffer, len(states) - buffer)`` are scanned, so runs are
    clipped to that window.

    Parameters
    ----------
    states : array-like of int
        One non-negative integer state label per frame.
    min_length : int
        Minimum number of frames a run must span to be kept.
    buffer : int
        Number of frames ignored at each end of ``states``.

    Returns
    -------
    list of list of np.ndarray
        ``result[k]`` holds one array of frame indices per kept run of state
        ``k``; the outer list has ``max(states) + 1`` entries. An empty list
        is returned for empty input.
    """
    states = np.asarray(states)
    if states.size == 0:
        return []

    n_states = int(np.max(states) + 1)
    state_snippets = [[] for _ in range(n_states)]

    i_stop = len(states) - buffer  # exclusive end of the scanned window
    if i_stop <= buffer:
        # Input too short to leave any frames between the two buffers
        return state_snippets

    i_beg = buffer
    curr_state = states[i_beg]
    for i in range(i_beg + 1, i_stop):
        next_state = states[i]
        if next_state != curr_state:
            # The run covers frames i_beg .. i - 1, i.e. i - i_beg frames
            if i - i_beg >= min_length:
                state_snippets[int(curr_state)].append(np.arange(i_beg, i))
            i_beg = i
            curr_state = next_state

    # The last run extends through frame i_stop - 1 inclusive; use i_stop as the
    # exclusive bound so that final frame is not dropped.
    if i_stop - i_beg >= min_length:
        state_snippets[int(curr_state)].append(np.arange(i_beg, i_stop))
    return state_snippets

def plot_relative_wheel_positions_around_transitions(eid, wh_times, wh_pos, states, times, pre_window=(-0.25, 0), post_window=(0, 1.5), max_frame_distance=30):
    """
    Plot wheel position around state transitions, one figure per (pre, post) state pair.

    A transition is a pair of runs (from ``extract_state_runs`` with
    ``min_length=20``) of different states where the second run starts between
    1 and ``max_frame_distance`` frames after the first one ends. The transition
    time is the timestamp of the first frame of the second run.

    Parameters
    ----------
    eid : str
        Session identifier; accepted for interface compatibility, not used.
    wh_times, wh_pos : array-like
        Wheel sample times and positions; linearly interpolated onto ``times``.
    states : array-like of int
        Per-frame state labels. Labels are mapped to names and colours with
        ``state - 1`` (1=Still, 2=Move, 3=Wheel Turn, 4=Groom), matching the
        1..4 clipping applied by the caller in this notebook.
    times : np.ndarray
        Per-frame timestamps; must be at least as long as ``states``.
    pre_window, post_window : tuple of float
        (start, end) offsets from the transition time, in the units of ``times``.
    max_frame_distance : int
        Largest allowed gap, in frames, between the end of the pre run and the
        start of the post run.
    """
    # Imported locally because this cell does not import defaultdict itself
    from collections import defaultdict

    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']
    cmap = plt.get_cmap("tab10")

    # Resample wheel position at the frame timestamps
    interpolator = interp1d(wh_times, wh_pos, fill_value='extrapolate')
    wh_pos_resampled = interpolator(times)

    state_snippets = extract_state_runs(states, min_length=20)

    # Group transition times by (pre_state, post_state)
    transition_groups = defaultdict(list)
    for pre_state, pre_runs in enumerate(state_snippets):
        for pre_run in pre_runs:
            end_idx = pre_run[-1]
            for post_state, post_runs in enumerate(state_snippets):
                if post_state == pre_state:
                    continue
                for post_run in post_runs:
                    start_idx = post_run[0]
                    if 0 < start_idx - end_idx <= max_frame_distance:
                        transition_groups[(pre_state, post_state)].append(times[start_idx])

    for (pre_state, post_state), transition_times in transition_groups.items():
        plt.figure(figsize=(15, 8))

        pre_name = state_labels[pre_state - 1]
        post_name = state_labels[post_state - 1]
        # Only the first trace of each kind gets a legend entry; every trace is drawn
        pre_labelled = False
        post_labelled = False

        for transition_time in transition_times:
            pre_mask = (times >= (transition_time + pre_window[0])) & (times <= (transition_time + pre_window[1]))
            if np.any(pre_mask):
                pre_times = times[pre_mask] - transition_time
                pre_wheel_pos = wh_pos_resampled[pre_mask]
                # Position relative to the first sample of the pre window
                relative_pre_wheel_pos = pre_wheel_pos - pre_wheel_pos[0]
                plt.plot(pre_times, relative_pre_wheel_pos, alpha=0.5, color=cmap(pre_state - 1), linestyle='--',
                         label='_nolegend_' if pre_labelled else f'{pre_name} (Pre-Transition)')
                pre_labelled = True

            post_mask = (times >= (transition_time + post_window[0])) & (times <= (transition_time + post_window[1]))
            if np.any(post_mask):
                post_times = times[post_mask] - transition_time
                post_wheel_pos = wh_pos_resampled[post_mask]
                # Position relative to the first sample of the post window
                relative_post_wheel_pos = post_wheel_pos - post_wheel_pos[0]
                plt.plot(post_times, relative_post_wheel_pos, alpha=0.5, color=cmap(post_state - 1), linestyle='-',
                         label='_nolegend_' if post_labelled else f'{post_name} (Post-Transition)')
                post_labelled = True

        plt.xlabel('Time (seconds)')
        plt.ylabel('Relative Wheel Position (radians)')
        plt.title(f'Relative Wheel Position Around Transition {pre_name} to {post_name}')

        if pre_labelled or post_labelled:
            plt.legend(loc='best')

        plt.show()


# Session loader parameters
# NOTE(review): `eid`, `eids`, `sess_id`, `one`, `model` and the helpers called below
# are not defined in this cell; they must already exist in the kernel.
l_thresh_smooth = 0.9  # passed as `likelihood_thr` to load_pose and to extract_marker_data
view = 'left'
paw = 'paw_r'

# Load pose data and the camera timestamps for the selected view
sl = SessionLoader(one=one, eid=eid)
sl.load_pose(likelihood_thr=l_thresh_smooth, views=[view])
times = sl.pose[f'{view}Camera'].times.to_numpy()
# x/y columns of the selected paw (not used by the plotting call in this cell)
markers = sl.pose[f'{view}Camera'].loc[:, (f'{paw}_x', f'{paw}_y')].to_numpy()

# Trials object (not used by the plotting call in this cell)
trials = one.load_object(eid, 'trials')

# Load wheel times and positions
sl.load_wheel()
wh_times = sl.wheel.times.to_numpy()
wh_pos = sl.wheel.position.to_numpy()

# Generate state predictions; labels are clipped into [1, 4]
ibl_smooth_markers_file = extract_marker_data(eid, l_thresh_smooth, view, paw, smooth=True)
data_gen_ibl_smooth = create_data_generator(sess_id, ibl_smooth_markers_file)
states_ibl_smooth = np.clip(get_states(model, data_gen_ibl_smooth), 1, 4)

# Fractional wheel-sample index at each camera timestamp, with the last 4 entries dropped
# (not used by the plotting call below)
wh_times_frames = np.interp(times, wh_times, np.arange(len(wh_times)))[:-4]

# NOTE(review): data above was loaded for `eid`, but `eids[0]` is passed here — confirm
# they refer to the same session.
plot_relative_wheel_positions_around_transitions(eids[0], wh_times, wh_pos, states_ibl_smooth, times)

In [None]:
# (INCOMPLETE)

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from collections import defaultdict

def extract_state_runs(states, min_length=20, buffer=100):
    """
    Find contiguous runs of identical state labels.

    Only frames in ``[buffer, len(states) - buffer)`` are scanned, so runs are
    clipped to that window.

    Parameters
    ----------
    states : array-like of int
        One non-negative integer state label per frame.
    min_length : int
        Minimum number of frames a run must span to be kept.
    buffer : int
        Number of frames ignored at each end of ``states``.

    Returns
    -------
    list of list of np.ndarray
        ``result[k]`` holds one array of frame indices per kept run of state
        ``k``; the outer list has ``max(states) + 1`` entries. An empty list
        is returned for empty input.
    """
    states = np.asarray(states)
    if states.size == 0:
        return []

    n_states = int(np.max(states) + 1)
    state_snippets = [[] for _ in range(n_states)]

    i_stop = len(states) - buffer  # exclusive end of the scanned window
    if i_stop <= buffer:
        # Input too short to leave any frames between the two buffers
        return state_snippets

    i_beg = buffer
    curr_state = states[i_beg]
    for i in range(i_beg + 1, i_stop):
        next_state = states[i]
        if next_state != curr_state:
            # The run covers frames i_beg .. i - 1, i.e. i - i_beg frames
            if i - i_beg >= min_length:
                state_snippets[int(curr_state)].append(np.arange(i_beg, i))
            i_beg = i
            curr_state = next_state

    # The last run extends through frame i_stop - 1 inclusive; use i_stop as the
    # exclusive bound so that final frame is not dropped.
    if i_stop - i_beg >= min_length:
        state_snippets[int(curr_state)].append(np.arange(i_beg, i_stop))
    return state_snippets

def plot_relative_wheel_positions_around_transitions(eid, wh_times, wh_pos, states, times, pre_window=(-0.25, 0), post_window=(0, 1.5), max_frame_distance=30):
    """
    Plot wheel position around Still <-> Move / Wheel Turn transitions.

    A transition is a pair of runs (from ``extract_state_runs`` with
    ``min_length=20``) of different states where the second run starts between
    1 and ``max_frame_distance`` frames after the first one ends. Only
    transitions from state 1 to state 2 or 3, and from state 2 or 3 to state 1,
    are plotted, one figure per (pre, post) pair. Each trace is shown relative
    to the wheel position at the start of that transition's pre window, so the
    pre and post segments join up.

    Parameters
    ----------
    eid : str
        Session identifier; accepted for interface compatibility, not used.
    wh_times, wh_pos : array-like
        Wheel sample times and positions; linearly interpolated onto ``times``.
    states : array-like of int
        Per-frame state labels coded 1=Still, 2=Move, 3=Wheel Turn, 4=Groom.
    times : np.ndarray
        Per-frame timestamps; must be at least as long as ``states``.
    pre_window, post_window : tuple of float
        (start, end) offsets from the transition time, in the units of ``times``.
    max_frame_distance : int
        Largest allowed gap, in frames, between the end of the pre run and the
        start of the post run.
    """
    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']
    cmap = plt.get_cmap("tab10")

    # Resample wheel position at the frame timestamps
    interpolator = interp1d(wh_times, wh_pos, fill_value='extrapolate')
    wh_pos_resampled = interpolator(times)

    state_snippets = extract_state_runs(states, min_length=20)

    # Group transition times by (pre_state, post_state), keeping only
    # Still -> Move/Wheel Turn and Move/Wheel Turn -> Still
    transition_groups = defaultdict(list)
    for pre_state, pre_runs in enumerate(state_snippets):
        for pre_run in pre_runs:
            end_idx = pre_run[-1]
            for post_state, post_runs in enumerate(state_snippets):
                if post_state == pre_state:
                    continue
                wanted = (pre_state == 1 and post_state in (2, 3)) or (pre_state in (2, 3) and post_state == 1)
                if not wanted:
                    continue
                for post_run in post_runs:
                    start_idx = post_run[0]
                    if 0 < start_idx - end_idx <= max_frame_distance:
                        transition_groups[(pre_state, post_state)].append(times[start_idx])

    # Plot wheel positions for each transition group
    for (pre_state, post_state), transition_times in transition_groups.items():
        plt.figure(figsize=(15, 8))

        pre_name = state_labels[pre_state - 1]
        post_name = state_labels[post_state - 1]
        # Only the first trace of each kind gets a legend entry
        pre_labelled = False
        post_labelled = False

        for transition_time in transition_times:
            # Baseline for this transition; reset every iteration so a value from a
            # previous transition can never leak into the current one.
            baseline = None

            pre_mask = (times >= (transition_time + pre_window[0])) & (times <= (transition_time + pre_window[1]))
            if np.any(pre_mask):
                pre_times = times[pre_mask] - transition_time
                pre_wheel_pos = wh_pos_resampled[pre_mask]
                baseline = pre_wheel_pos[0]
                plt.plot(pre_times, pre_wheel_pos - baseline, alpha=0.5, color=cmap(pre_state - 1),
                         label='_nolegend_' if pre_labelled else f'{pre_name} (Pre-Transition)')
                pre_labelled = True

            post_mask = (times >= (transition_time + post_window[0])) & (times <= (transition_time + post_window[1]))
            if np.any(post_mask):
                post_times = times[post_mask] - transition_time
                post_wheel_pos = wh_pos_resampled[post_mask]
                if baseline is None:
                    # No pre-window samples: fall back to the start of the post window
                    baseline = post_wheel_pos[0]
                plt.plot(post_times, post_wheel_pos - baseline, alpha=0.5, color=cmap(post_state - 1),
                         label='_nolegend_' if post_labelled else f'{post_name} (Post-Transition)')
                post_labelled = True

        plt.xlabel('Time (seconds)')
        plt.ylabel('Relative Wheel Position (radians)')
        plt.title(f'Relative Wheel Position Around Transition {pre_name} to {post_name}')

        if pre_labelled or post_labelled:
            plt.legend(loc='best')
        plt.show()

# Session loader parameters
# NOTE(review): `eid`, `eids`, `sess_id`, `one`, `model` and the helpers called below
# are not defined in this cell; they must already exist in the kernel.
l_thresh_smooth = 0.9  # passed as `likelihood_thr` to load_pose and to extract_marker_data
view = 'left'
paw = 'paw_r'

# Load pose data and the camera timestamps for the selected view
sl = SessionLoader(one=one, eid=eid)
sl.load_pose(likelihood_thr=l_thresh_smooth, views=[view])
times = sl.pose[f'{view}Camera'].times.to_numpy()
# x/y columns of the selected paw (not used by the plotting call in this cell)
markers = sl.pose[f'{view}Camera'].loc[:, (f'{paw}_x', f'{paw}_y')].to_numpy()

# Trials object (not used by the plotting call in this cell)
trials = one.load_object(eid, 'trials')

# Load wheel times and positions
sl.load_wheel()
wh_times = sl.wheel.times.to_numpy()
wh_pos = sl.wheel.position.to_numpy()

# Generate state predictions; labels are clipped into [1, 4]
ibl_smooth_markers_file = extract_marker_data(eid, l_thresh_smooth, view, paw, smooth=True)
data_gen_ibl_smooth = create_data_generator(sess_id, ibl_smooth_markers_file)
states_ibl_smooth = np.clip(get_states(model, data_gen_ibl_smooth), 1, 4)

# Fractional wheel-sample index at each camera timestamp, with the last 4 entries dropped
# (not used by the plotting call below)
wh_times_frames = np.interp(times, wh_times, np.arange(len(wh_times)))[:-4]

# NOTE(review): the wheel is loaded a second time here and `wh_times` is reassigned.
sl.load_wheel()
wh_times = sl.wheel.times.to_numpy()
wh_vel_oversampled = sl.wheel.velocity.to_numpy()

# Wheel velocity resampled at the camera timestamps (not passed to the call below)
interpolator = interp1d(wh_times, wh_vel_oversampled, fill_value='extrapolate')
wh_vel = interpolator(times)

# NOTE(review): data above was loaded for `eid`, but `eids[0]` is passed here — confirm
# they refer to the same session.
plot_relative_wheel_positions_around_transitions(eids[0], wh_times, wh_pos, states_ibl_smooth, times)


In [None]:
# (INCOMPLETE)

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from collections import defaultdict

def extract_state_runs(states, min_length=20, buffer=100):
    """
    Find contiguous chunks of data with the same state.

    Only frames in ``[buffer, len(states) - buffer)`` are scanned, so runs are
    clipped to that window.

    Parameters
    ----------
    states : array-like of int
        One non-negative integer state label per frame.
    min_length : int
        Minimum number of frames a run must span to be kept.
    buffer : int
        Number of frames ignored at each end of ``states``.

    Returns
    -------
    list of list of np.ndarray
        ``result[k]`` holds one array of frame indices per kept run of state
        ``k``; the outer list has ``max(states) + 1`` entries. An empty list
        is returned for empty input.
    """
    states = np.asarray(states)
    if states.size == 0:
        return []

    n_states = int(np.max(states) + 1)
    state_snippets = [[] for _ in range(n_states)]

    i_stop = len(states) - buffer  # exclusive end of the scanned window
    if i_stop <= buffer:
        # Input too short to leave any frames between the two buffers
        return state_snippets

    i_beg = buffer
    curr_state = states[i_beg]
    for i in range(i_beg + 1, i_stop):
        next_state = states[i]
        if next_state != curr_state:
            # The run covers frames i_beg .. i - 1, i.e. i - i_beg frames
            if i - i_beg >= min_length:
                state_snippets[int(curr_state)].append(np.arange(i_beg, i))
            i_beg = i
            curr_state = next_state

    # The last run extends through frame i_stop - 1 inclusive; use i_stop as the
    # exclusive bound so that final frame is not dropped.
    if i_stop - i_beg >= min_length:
        state_snippets[int(curr_state)].append(np.arange(i_beg, i_stop))
    return state_snippets

def plot_paw_speed_around_transitions(eid, paw_x_vel, paw_y_vel, states, times, pre_window=(-0.25, 0), post_window=(0, 0.25), max_frame_distance=30):
    """
    Plot paw speed around Still to Move and Move to Still transitions.

    A transition is a pair of runs (from ``extract_state_runs`` with
    ``min_length=20``) of different states where the second run starts between
    1 and ``max_frame_distance`` frames after the first one ends. Only
    transitions between state 1 and state 2 are plotted, one figure per
    direction. Speed is plotted as computed, without baseline subtraction.

    Parameters
    ----------
    eid : str
        Session identifier; accepted for interface compatibility, not used.
    paw_x_vel, paw_y_vel : array-like
        Per-frame paw velocity components; must have the same length as ``times``
        because they are indexed with masks built from ``times``.
    states : array-like of int
        Per-frame state labels coded 1=Still, 2=Move, 3=Wheel Turn, 4=Groom.
    times : np.ndarray
        Per-frame timestamps; must be at least as long as ``states``.
    pre_window, post_window : tuple of float
        (start, end) offsets from the transition time, in the units of ``times``.
    max_frame_distance : int
        Largest allowed gap, in frames, between the end of the pre run and the
        start of the post run.
    """
    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']
    cmap = plt.get_cmap("tab10")

    # Speed is the Euclidean norm of the velocity vector
    paw_x_vel = np.asarray(paw_x_vel, dtype=float)
    paw_y_vel = np.asarray(paw_y_vel, dtype=float)
    paw_speed = np.sqrt(paw_x_vel ** 2 + paw_y_vel ** 2)

    state_snippets = extract_state_runs(states, min_length=20)

    # Group transition times by (pre_state, post_state), keeping only
    # Still -> Move and Move -> Still
    wanted_pairs = {(1, 2), (2, 1)}
    transition_groups = defaultdict(list)
    for pre_state, pre_runs in enumerate(state_snippets):
        for pre_run in pre_runs:
            end_idx = pre_run[-1]
            for post_state, post_runs in enumerate(state_snippets):
                if (pre_state, post_state) not in wanted_pairs:
                    continue
                for post_run in post_runs:
                    start_idx = post_run[0]
                    if 0 < start_idx - end_idx <= max_frame_distance:
                        transition_groups[(pre_state, post_state)].append(times[start_idx])

    # Plot paw speeds for each transition group
    for (pre_state, post_state), transition_times in transition_groups.items():
        plt.figure(figsize=(15, 8))

        pre_name = state_labels[pre_state - 1]
        post_name = state_labels[post_state - 1]
        # Labels are attached to the first trace of each kind so the legend does not
        # depend on the order in which lines were added to the axes.
        pre_labelled = False
        post_labelled = False

        for transition_time in transition_times:
            pre_mask = (times >= (transition_time + pre_window[0])) & (times <= (transition_time + pre_window[1]))
            post_mask = (times >= (transition_time + post_window[0])) & (times <= (transition_time + post_window[1]))

            if np.any(pre_mask):
                plt.plot(times[pre_mask] - transition_time, paw_speed[pre_mask], alpha=0.5, color=cmap(pre_state - 1),
                         label='_nolegend_' if pre_labelled else f'{pre_name} (Pre-Transition)')
                pre_labelled = True

            if np.any(post_mask):
                plt.plot(times[post_mask] - transition_time, paw_speed[post_mask], alpha=0.5, color=cmap(post_state - 1),
                         label='_nolegend_' if post_labelled else f'{post_name} (Post-Transition)')
                post_labelled = True

        plt.xlabel('Time (seconds)')
        plt.ylabel('Relative Paw Speed')
        plt.title(f'Relative Paw Speed Around Transition {pre_name} to {post_name}')
        if pre_labelled or post_labelled:
            plt.legend(loc='best')
        plt.show()

def plot_wheel_speed_around_transitions(eid, wh_times, wh_vel, states, times, pre_window=(-0.25, 0), post_window=(0, 0.25), max_frame_distance=30):
    """
    Plot wheel velocity around Still to Wheel Turn and Wheel Turn to Still transitions.

    A transition is a pair of runs (from ``extract_state_runs`` with
    ``min_length=20``) of different states where the second run starts between
    1 and ``max_frame_distance`` frames after the first one ends. Only
    transitions between state 1 and state 3 are plotted, one figure per
    direction. ``wh_vel`` is plotted as given (no absolute value, no baseline
    subtraction).

    Parameters
    ----------
    eid : str
        Session identifier; accepted for interface compatibility, not used.
    wh_times, wh_vel : array-like
        Wheel sample times and velocities. If their lengths differ, the longer
        one is truncated to the shorter before interpolation onto ``times``.
    states : array-like of int
        Per-frame state labels coded 1=Still, 2=Move, 3=Wheel Turn, 4=Groom.
    times : np.ndarray
        Per-frame timestamps; must be at least as long as ``states``.
    pre_window, post_window : tuple of float
        (start, end) offsets from the transition time, in the units of ``times``.
    max_frame_distance : int
        Largest allowed gap, in frames, between the end of the pre run and the
        start of the post run.
    """
    state_labels = ['Still', 'Move', 'Wheel Turn', 'Groom']
    cmap = plt.get_cmap("tab10")

    # interp1d requires equal-length inputs; trim both to the common length
    n_common = min(len(wh_times), len(wh_vel))
    wh_times = wh_times[:n_common]
    wh_vel = wh_vel[:n_common]

    # Resample wheel velocity at the frame timestamps
    interpolator = interp1d(wh_times, wh_vel, fill_value='extrapolate')
    wh_speed_resampled = interpolator(times)

    state_snippets = extract_state_runs(states, min_length=20)

    # Group transition times by (pre_state, post_state), keeping only
    # Still -> Wheel Turn and Wheel Turn -> Still
    wanted_pairs = {(1, 3), (3, 1)}
    transition_groups = defaultdict(list)
    for pre_state, pre_runs in enumerate(state_snippets):
        for pre_run in pre_runs:
            end_idx = pre_run[-1]
            for post_state, post_runs in enumerate(state_snippets):
                if (pre_state, post_state) not in wanted_pairs:
                    continue
                for post_run in post_runs:
                    start_idx = post_run[0]
                    if 0 < start_idx - end_idx <= max_frame_distance:
                        transition_groups[(pre_state, post_state)].append(times[start_idx])

    # Plot wheel speeds for each transition group
    for (pre_state, post_state), transition_times in transition_groups.items():
        plt.figure(figsize=(15, 8))

        pre_name = state_labels[pre_state - 1]
        post_name = state_labels[post_state - 1]
        # Labels are attached to the first trace of each kind so the legend does not
        # depend on the order in which lines were added to the axes.
        pre_labelled = False
        post_labelled = False

        for transition_time in transition_times:
            pre_mask = (times >= (transition_time + pre_window[0])) & (times <= (transition_time + pre_window[1]))
            post_mask = (times >= (transition_time + post_window[0])) & (times <= (transition_time + post_window[1]))

            if np.any(pre_mask):
                plt.plot(times[pre_mask] - transition_time, wh_speed_resampled[pre_mask], alpha=0.5, color=cmap(pre_state - 1),
                         label='_nolegend_' if pre_labelled else f'{pre_name} (Pre-Transition)')
                pre_labelled = True

            if np.any(post_mask):
                plt.plot(times[post_mask] - transition_time, wh_speed_resampled[post_mask], alpha=0.5, color=cmap(post_state - 1),
                         label='_nolegend_' if post_labelled else f'{post_name} (Post-Transition)')
                post_labelled = True

        plt.xlabel('Time (seconds)')
        plt.ylabel('Relative Wheel Speed')
        plt.title(f'Relative Wheel Speed Around Transition {pre_name} to {post_name}')
        if pre_labelled or post_labelled:
            plt.legend(loc='best')
        plt.show()



# Session Loader Parameters
# NOTE(review): `eid`, `one`, `paw_x_vel`, `paw_y_vel`, `wh_times` and
# `states_ibl_smooth` are not defined in this cell; they must already exist in the kernel.
l_thresh_smooth = 0.9  # passed as `likelihood_thr` to load_pose and to extract_marker_data
view = 'left'
paw = 'paw_r'

# Load pose data and the camera timestamps for the selected view
sl = SessionLoader(one=one, eid=eid)
sl.load_pose(likelihood_thr=l_thresh_smooth, views=[view])
times = sl.pose[f'{view}Camera'].times.to_numpy()


# Load the marker CSV written by extract_marker_data and read its `wheel_vel` column
file = extract_marker_data(eid, l_thresh_smooth, view, paw, smooth=True)
df = pd.read_csv(file)
wh_vel = df['wheel_vel'].to_numpy()
# Absolute wheel velocity (not passed to either plotting call below)
wh_speed = np.abs(wh_vel)


# The string literal below is disabled code (an alternative that resamples
# SessionLoader wheel velocity at the camera timestamps); it is never executed.
"""

# Load marker data
sl.load_wheel()
wh_times = sl.wheel.times.to_numpy()
wh_vel_oversampled = sl.wheel.velocity.to_numpy()

# Resample wheel data at marker times
interpolator = interp1d(wh_times, wh_vel_oversampled, fill_value='extrapolate')
wh_vel = interpolator(times)"""

plot_paw_speed_around_transitions(eid, paw_x_vel, paw_y_vel, states_ibl_smooth, times)
# NOTE(review): `wh_vel` comes from the marker CSV while `wh_times` comes from an
# earlier cell; the callee truncates them to a common length — confirm that they
# are sampled on the same time base.
plot_wheel_speed_around_transitions(eid, wh_times, wh_vel, states_ibl_smooth, times)