## Import Modules

In [1]:
import os, re, glob, cv2, csv
import numpy as np
import cupy as cp
from cupyx.scipy import ndimage as cpx_ndimage  # Import CuPy's GPU ndimage module
import tifffile as tiff
import scipy as sp
from scipy import ndimage, io as sio
from scipy.ndimage import maximum_filter, label, find_objects
from scipy.stats import chi2, lognorm, poisson, norm
from scipy.optimize import curve_fit
from skimage.feature import peak_local_max
from natsort import natsorted
import trackpy as tp
import pandas as pd
import multiprocessing as mp
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.widgets import Slider
from matplotlib import animation, rc
import matplotlib.colors as mcolors
from tqdm import tqdm
from IPython.display import display, Image

# Set up matplotlib for animations
rc('animation', html='jshtml')

# Suppress specific warnings
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)  # ignore warnings for specific matplotlib commands which will be outdated soon


## Define functions

In [10]:
def convert_to_code_path(windows_path):
    """
    Converts a Windows file path with single backslashes to a format with double backslashes for use in Python code.

    Parameters:
    - windows_path: str
        The original Windows file path (e.g., 'C:\\Users\\YourName\\Folder').

    Returns:
    - str
        The modified path with double backslashes, suitable for Python (e.g., 'C:\\\\Users\\\\YourName\\\\Folder').
    """
    # Replace each single backslash with a double backslash
    return windows_path.replace("\\", "\\\\")



def cluster_optical_flow(us, vs, kernel_size=11, kernel_type='gaussian', normalize=False, verbose=False):
    """
    Applies a convolution to horizontal (u) and vertical (v) components of optical flow data and clusters the flow.

    Parameters:
    - us: list of 2D numpy arrays
        List of horizontal components of optical flow (u) for each frame.
    - vs: list of 2D numpy arrays
        List of vertical components of optical flow (v) for each frame.
    - kernel_type: str, optional, default='gaussian'
        Type of kernel to use for clustering. Options are 'gaussian' or 'uniform'.
    - kernel_size: int, optional, default=11
        Size of the kernel (must be odd).
    - normalize: bool, optional, default=False
        Whether to normalize the u and v vectors to avoid magnitude differences.
    - verbose: bool, optional, default=False
        If True, shows progress during computation.

    Returns:
    - cluster: list of 2D numpy arrays
        List of clustered optical flow arrays for each frame.
    """
    
    # Ensure the kernel size is odd
    if kernel_size % 2 == 0:
        print(kernel_size)
        raise ValueError("kernel_size must be an odd integer.")
    
    # Create the kernel based on the specified type
    N = kernel_size
    if kernel_type == 'gaussian':
        gaussian_window = sp.signal.windows.gaussian(N, std=5, sym=True).reshape(-1, 1)
        gaussian_kernel = gaussian_window @ gaussian_window.T
        gaussian_kernel[N // 2, N // 2] = 0  # Set the center value to 0 to avoid self-weighting
        kernel_gpu = cp.asarray(gaussian_kernel)  # Convert to CuPy array for GPU usage

    elif kernel_type == 'uniform':
        kernel_gpu = cp.ones((N, N), dtype=cp.float32)
        kernel_gpu[N // 2, N // 2] = 0  # Center weight is set to zero
        kernel_gpu = kernel_gpu / cp.sum(kernel_gpu)  # Normalize the kernel

    else:
        print(kernel_type)
        raise ValueError("Unsupported kernel_type. Choose 'gaussian' or 'uniform'.")

    # List to store the clustered results
    cluster = []

    # Iterate over each u, v optical flow component
    for u, v in tqdm(zip(us, vs), disable=(not verbose)):
        u_gpu = cp.asarray(u)  # Move u to GPU
        v_gpu = cp.asarray(v)  # Move v to GPU

        # Normalize the u and v flow vectors if required
        if normalize:
            norm = cp.sqrt(u_gpu**2 + v_gpu**2)  # Calculate the magnitude of flow
            norm[norm == 0] = 1.0  # Prevent division by zero for zero flow vectors
            u_gpu = u_gpu / norm
            v_gpu = v_gpu / norm
            u_gpu[cp.isnan(u_gpu)] = 0  # Handle NaNs after normalization
            v_gpu[cp.isnan(v_gpu)] = 0

        # Convolve the flow fields using GPU-accelerated convolution
        u_conv = cpx_ndimage.convolve(u_gpu, kernel_gpu, mode='constant', cval=0.0)
        v_conv = cpx_ndimage.convolve(v_gpu, kernel_gpu, mode='constant', cval=0.0)

        # Multiply the original flow fields by the convolved result
        u_gpu = u_gpu * u_conv
        v_gpu = v_gpu * v_conv

        # Append the result (u + v) to the cluster list and move data back to CPU
        cluster.append(cp.asnumpy(u_gpu + v_gpu))

    return cluster


 
def process_optical_flow(data_folder, file_name, kernel_type='gaussian', kernel_size=11, 
                         normalize=False, verbose=False):
    """
    Function to process optical flow and cluster the results.
    
    Args:
    - data_folder: Path to the folder containing the data.
    - file_name: Name of the .tiff image file.
    - mask_type: The type of mask to load ('neural_mask' or 'cell_mask'). Default is 'neural_mask'.
    - kernel_type: Type of kernel to use in clustering. Default is 'gaussian'.
    - kernel_size: Size of the kernel to use in clustering. Default is 11.
    - normalize: Boolean flag to indicate whether to normalize the flow. Default is False.
    - verbose: Boolean flag for verbosity. Default is False.
    
    Returns:
    - cluster: The clustered optical flow data.
    """
    
    # Define paths
    image_path = os.path.join(data_folder, file_name)
    of_path = os.path.join(data_folder, 'Op_flow')

    # Load the image
    img = tiff.imread(image_path)

    # Load the mask
    mask_path = os.path.join(data_folder, 'neural_mask.mat')
    mask = sio.loadmat(mask_path)['neural_mask']
    mask = mask.astype(float)
    mask[mask == 0] = np.NaN

    try:
        # Get the optical flow file list
        of_list = natsorted(os.listdir(of_path))
    
        # Get number of frames and span dimensions from the first optical flow file
        n_frames = len(of_list)
    
        # Check the condition for the type of optical flow file (.mat or .npz)
        if '.mat' in of_list[0]:
            y_span = np.shape(sio.loadmat(of_path+'\\0.mat')['vy'])[0]
            x_span = np.shape(sio.loadmat(of_path+'\\0.mat')['vy'])[1]
            file_flag = 0
        elif '.npz' in of_list[0]:
            y_span = np.shape(np.load(of_path+'\\0.npz')['vy'])[0]
            x_span = np.shape(np.load(of_path+'\\0.npz')['vy'])[1]
            file_flag = 1
        else:
            print ('Unknown file type found!')
    
        # Initialize arrays to store velocity fields
        vy_all = np.zeros((n_frames, y_span, x_span))
        vx_all = np.zeros((n_frames, y_span, x_span))
    
        # Load all velocity fields
        if file_flag == 0:
            for i in tqdm(range(n_frames), desc="Loading optical flow data"):
                flow_data = sio.loadmat(os.path.join(of_path, of_list[i]))
                vy_all[i, :, :] = flow_data['vy']
                vx_all[i, :, :] = flow_data['vx']
            
        elif file_flag == 1:
            for i in tqdm(range(n_frames), desc="Loading optical flow data"):
                flow_data = np.load(os.path.join(of_path, of_list[i]))
                vy_all[i, :, :] = flow_data['vy']
                vx_all[i, :, :] = flow_data['vx']
    
        # Cluster the optical flow (assuming `cluster_optical_flow` is predefined)
        cluster = cluster_optical_flow(vx_all, vy_all, kernel_size=kernel_size, kernel_type=kernel_type, normalize=normalize, 
                                       verbose=verbose)
    
        return cluster, mask, n_frames
    except FileNotFoundError as e:
        print(f"Error: {e}")
        



def detect_and_track_particles(cluster, mask, n_frames, diameter=5, minmass=1.0, separation=15, search_range=5):
    """
    Detects particles in each frame of the clustered optical flow data and tracks them across frames.
    
    Args:
    - cluster: 3D array of clustered optical flow data.
    - mask: 2D array mask to apply to each frame (NaNs are applied as zero).
    - n_frames: Number of frames in the 3D stack.
    - diameter: Approximate size of the particle in pixels. Default is 5.
    - minmass: Minimum integrated brightness (mass) of a particle to be considered. Default is 1.0.
    - separation: Minimum separation between particles in pixels. Default is 15.
    - search_range: Maximum displacement between frames in pixels for tracking. Default is 5.
    
    Returns:
    - tp_trajectories: DataFrame with tracked particle positions over time.
    """
    
    # A list to store particle locations for each frame
    particle_positions = []

    # Loop through each 2D frame to detect particles
    for t in tqdm(range(n_frames), desc="Detecting particles"):
        frame = np.nan_to_num(cluster[t] * mask, nan=0)  # Get the 2D frame from the 3D stack and apply mask
        
        # Detect particles with sub-pixel accuracy in this frame
        particles = tp.locate(frame, diameter=diameter, minmass=minmass, separation=separation)
        
        # Add the frame number to each detected particle's data
        particles['frame'] = t
        
        # Store the detected particles
        particle_positions.append(particles)

    # Combine all particle positions across frames into a single DataFrame
    peaks_df = pd.concat(particle_positions, ignore_index=True)

    # Use trackpy to link the particles over time
    tp_trajectories = tp.link(peaks_df, search_range=search_range, memory=1)

    return tp_trajectories



def visualize_and_save_trajectories(tp_trajectories_filtered, img, data_folder, figsize=(12, 12), dpi=600):
    """
    Visualize filtered particle trajectories superimposed on an image and save the plot as a high-resolution image.

    Args:
    - tp_trajectories_filtered: DataFrame containing filtered trajectories with 'x', 'y', and 'particle' columns.
    - img: The image array (e.g., the first frame of the time-lapse) to use as a background.
    - data_folder: Path to the folder where the output image will be saved.
    - filename: Name of the file to save the image as. Default is 'filtered_tracks.png' or 'filtered_tracks.svg'.
    - figsize: Tuple specifying the figure size. Default is (8, 12).
    - dpi: Resolution of the saved image. Default is 600.
    """

    # Create figure and axis
    fig, ax = plt.subplots(figsize=figsize)

    # Plot the trajectories on top of the provided image
    tp.plot_traj(tp_trajectories_filtered, superimpose=img[0], ax=ax)

    # Set equal axis scaling
    ax.set_aspect('equal')

    # Display the plot
    # plt.show()

    # Save the figure as an image file in the specified folder
    fig.savefig(f'{data_folder}\\filtered_tracks.svg', format='svg', bbox_inches='tight', dpi=dpi)
    fig.savefig(f'{data_folder}\\filtered_tracks.png', format='png', bbox_inches='tight', dpi=dpi)

    print ("Tracks plotted")
    plt.close()



def plot_particle_trajectories(tp_trajectories_filtered, data_folder, colormap_name='jet', scatter_size=80, alpha=0.5, 
                               line_color='gray', dpi=600):
    """
    Plots particle trajectories as scatter points with progressive colors and dashed lines connecting them.
    
    Args:
    - tp_trajectories_filtered: DataFrame containing filtered trajectories with 'x', 'y', 'particle', and 'frame' columns.
    - data_folder: Path to save the plot as an SVG file.
    - colormap_name: Name of the colormap to use for progressive coloring. Default is 'jet'.
    - scatter_size: Size of the scatter points. Default is 80.
    - alpha: Transparency level of scatter points. Default is 0.5.
    - line_color: Color of the dashed lines connecting the points. Default is 'grey'.
    """
    
    # Create figure and axis
    fig, ax = plt.subplots(figsize=(12, 12))
    colormap = cm.get_cmap(colormap_name)

    # Set figure background color to grey
    ax.set_facecolor('grey')

    # Group by particle ID to plot each track separately with progressive colors
    for particle_id, track in tp_trajectories_filtered.groupby('particle'):
        # Normalize frame values to use for progressive color mapping
        norm = mcolors.Normalize(vmin=track['frame'].min(), vmax=track['frame'].max())

        # Plot each point with a color based on its frame
        for i in range(1, len(track)):
            x_values = [track.iloc[i-1]['x'], track.iloc[i]['x']]
            y_values = [track.iloc[i-1]['y'], track.iloc[i]['y']]
            color = colormap(norm(track.iloc[i]['frame']))  # Map frame to color
            
            # Plot segment
            ax.plot(x_values, y_values, linestyle='--', linewidth=1, color=color)

            # Scatter plot for each point with progressive color
            ax.scatter(track.iloc[i]['x'], track.iloc[i]['y'], s=scatter_size, facecolor='none', alpha=alpha, color=color)

    # Set axis labels and title
    ax.set_xlabel('X Position')
    ax.set_ylabel('Y Position')
    ax.set_title('Particle Trajectories with Progressive Colors and Dashed Lines')

    # Ensure equal axis scaling
    ax.set_aspect('equal')
    ax.invert_yaxis()

    # Save the figure as SVG with high DPI
    fig.savefig(f'{data_folder}\\scatter_tracks.svg', format='svg', bbox_inches='tight', dpi=dpi)
    print ("Tracks plotted with progressive colors")
    plt.close()



def calculate_track_metrics(tp_trajectories_filtered, resolution, frame_interval):
    """
    Calculate and store instantaneous shifts, displacements, velocities, total track properties, 
    and other metrics for each particle's trajectory.

    Parameters:
    - tp_trajectories_filtered (pandas DataFrame): Particle tracking data with 'x', 'y', 'frame', and 'particle' columns.
    - resolution (float): Spatial resolution to convert units as required (e.g., pixels to microns).
    - frame_interval (float): Time interval between frames to calculate velocities and track duration.

    Returns:
    - tp_trajectories_filtered (pandas DataFrame): Input DataFrame with added columns:
        - 'x_shift': x-direction shift per frame for each particle.
        - 'y_shift': y-direction shift per frame for each particle.
        - 'frame_shift': Number of frames between positions in the trajectory.
        - 'displacement': Frame-to-frame displacement (Euclidean distance) in pixels.
        - 'inst_velocity': Instantaneous velocity per frame (#px./frame) for each particle.
    - track_properties (pandas DataFrame): Aggregated track-level metrics per particle, including:
        - 'track_length': Total track length, i.e., cumulative distance traveled by each particle (microns).
        - 'track_duration': Total duration of each particle's trajectory (seconds).
        - 'distance': Net displacement from initial to final position (microns).
        - 'avg_velocity': Average net velocity (net displacement divided by total time, in microns/min).
        - 'inst_velocity': Average instantaneous velocity (total track length divided by total time, in microns/min).
        - 'sinuosity': Measure of track straightness (track length divided by net displacement).
    """

    # Calculate Instantaneous Shifts (x, y, and frame)
    tp_trajectories_filtered['x_shift'] = (tp_trajectories_filtered.groupby('particle')['x'].diff().fillna(0)) 
    tp_trajectories_filtered['y_shift'] = (tp_trajectories_filtered.groupby('particle')['y'].diff().fillna(0)) 
    tp_trajectories_filtered['frame_shift']  = tp_trajectories_filtered.groupby('particle')['frame'].diff().fillna(1)
    
    # Calculate instantaneous displacement and velocity
    tp_trajectories_filtered['displacement'] = (np.sqrt(tp_trajectories_filtered['x_shift']**2 + tp_trajectories_filtered['y_shift']**2))
    tp_trajectories_filtered['displacement_scaled'] = (np.sqrt(tp_trajectories_filtered['x_shift']**2 + tp_trajectories_filtered['y_shift']**2)) / resolution
    tp_trajectories_filtered['inst_velocity'] = (tp_trajectories_filtered['displacement'] /
                                                 tp_trajectories_filtered['frame_shift'])
    
    # Drop NaN values from displacement and velocity
    tp_trajectories_filtered.dropna(subset=['displacement', 'inst_velocity'], inplace=True)


    # Calculate the direction (angle) of movement at each time step
    tp_trajectories_filtered['angle'] = np.arctan2(-tp_trajectories_filtered['y_shift'], tp_trajectories_filtered['x_shift'])
    
    # Handle zero shifts (no movement)
    tp_trajectories_filtered['angle'] = np.where(
        (tp_trajectories_filtered['x_shift'] == 0) & (tp_trajectories_filtered['y_shift'] == 0), np.nan,
        tp_trajectories_filtered['angle']
    )
    
    # Calculate track-level metrics: total distance traveled (track length), track duration, distance, average velocity, and sinuosity   
    
    # Calculate track length (total distance traversed) by summing displacements for each particle
    track_properties = tp_trajectories_filtered.groupby('particle')['displacement'].sum().reset_index()   
    track_properties['track_length'] = track_properties['displacement'] / resolution
    
    # Dropping the 'displacement' column from the new dataframe
    track_properties = track_properties.drop(columns='displacement')
    
    # Calculate track duration (total frames converted to seconds)
    track_duration = tp_trajectories_filtered.groupby('particle')['frame'].nunique().reset_index(name='track_duration')
    track_duration['track_duration'] = (track_duration['track_duration']-1) * frame_interval
    
    # Merge the track_duration back into track_properties DataFrame
    track_properties = track_properties.merge(track_duration[['particle', 'track_duration']], on='particle', how='left')

    # Calculate total displacement (net distance between start and end positions)
    track_metrics = tp_trajectories_filtered.groupby('particle').agg(
        x_start=('x', 'first'),
        y_start=('y', 'first'),
        x_end=('x', 'last'),
        y_end=('y', 'last'),
        frame_start=('frame', 'first'),
        frame_end=('frame', 'last')
    ).reset_index()
    
    # Calculate total displacement (net distance between start and end positions)
    track_properties['distance'] = np.sqrt((track_metrics['x_end'] - track_metrics['x_start'])**2 +
                                            (track_metrics['y_end'] - track_metrics['y_start'])**2) / resolution
    
    # Calculate average net speed (net displacement / total time)
    track_properties['avg_velocity'] = track_properties['distance'] * 60 / track_properties['track_duration']
    
    # Calculate the average instantaneous velocity (track length / total time)
    track_properties['inst_velocity'] =  track_properties['track_length'] * 60 / track_properties['track_duration']

    # Calculate sinuosity (track length / net displacement)
    track_properties['sinuosity'] = track_properties['track_length'] / track_properties['distance']
    
    # Now track_metrics contains 'distance', 'avg_velocity', 'avg_inst_velocity', and 'sinuosity'

    # Filter for tracks with average instantaneous velocity > 0.5 pixels/frame
    track_properties_thresholded = track_properties[track_properties['inst_velocity'] > 0.5 * 60 / (resolution * frame_interval)]
    
    # Log a success message
    print("Trajectories and track properties calculated and stored.")
    
    # Return filtered results
    return tp_trajectories_filtered, track_properties_thresholded




def calculate_angular_differences_process(tp_trajectories_final, data_folder, save_path=None):
    """
    Calculate angular differences between trajectory start angles and process orientation
    from the 'LoG_orientation.mat' file.

    Parameters:
    tp_trajectories_final (DataFrame): DataFrame containing trajectory start coordinates and angles.
    data_folder (str): Path to the folder containing the 'LoG_orientation.mat' file.
    save_path (str, optional): If provided, save the results as a CSV file to the specified path.

    Returns:
    ang_diff_process (array): Array of angular differences for start angles (wrt process orientation).
    """
    # Load process orientation from the 'LoG_orientation.mat' file (process orientation is a matrix)
    mat_file_process = os.path.join(data_folder, 'LoG_orientation.mat')
    data_process = sio.loadmat(mat_file_process)
    process_angle_array = data_process['prefAng']  # Process orientation in degrees (matrix)

    # Initialize array to store process orientations for the starting coordinates
    ang_start_arr_process = np.zeros(len(tp_trajectories_final['x']))

    # Extract process orientations for the starting coordinates of trajectories
    for i in range(len(tp_trajectories_final['x'])):
        x_start = round(tp_trajectories_final['x'].iloc[i])
        y_start = round(tp_trajectories_final['y'].iloc[i])

        # Retrieve the process orientation at the trajectory's starting position
        ang_start_arr_process[i] = process_angle_array[y_start, x_start]

    # Convert trajectory angles to degrees
    angle_values_deg = np.rad2deg(tp_trajectories_final['angle'].values)  # Convert from radians to degrees

    # Calculate angular differences with respect to process orientation
    ang_diff_process = np.abs(ang_start_arr_process - angle_values_deg)

    # Normalize angular differences to [0, 180] and then collapse to [0, 90]
    ang_diff_process = np.mod(ang_diff_process, 180)  # Normalize to [0, 180]
    ang_diff_process = np.where(ang_diff_process > 90, 180 - ang_diff_process, ang_diff_process)

    # Create a DataFrame to save the results
    df = pd.DataFrame({
        'ang_diff_process': ang_diff_process
    })

    # If a save path is provided, save the DataFrame to a CSV file
    if save_path:
        df.to_csv(save_path, index=False)

    # Return the angular differences array and the DataFrame
    return ang_diff_process



def plot_angular_differences_process(ang_diff_process, tp_trajectories_final, transparency=0.9, save_path=None):
    """
    Plot the angular differences for process orientations on a polar plot with transparency.

    Parameters:
    ang_diff_process (array): Array of angular differences for process orientation.
    tp_trajectories_final (DataFrame): DataFrame containing trajectory information, including displacement.
    transparency (float): Transparency level for the histograms (default: 0.9).
    save_path (str, optional): Path to save the plot (default is None, no saving).
    """
    # Remove NaN values from the angular difference array and corresponding displacement values
    valid_indices_process = ~np.isnan(ang_diff_process)
    ang_diff_process = ang_diff_process[valid_indices_process]
    displacement_process = tp_trajectories_final['displacement'].iloc[valid_indices_process]

    # Create a polar plot
    fig, ax = plt.subplots(subplot_kw={'projection': 'polar'}, figsize=(8, 8))

    # Convert angular differences to radians for the polar plot
    ang_diff_process_rad = np.deg2rad(ang_diff_process)

    # Define number of bins
    bins = 30

    # Calculate the histogram for process orientation angular differences
    n_process, bins_process = np.histogram(ang_diff_process_rad, bins=bins)

    # Normalize to fraction of counts within each bin (i.e., divide by the total count)
    n_process = n_process / np.sum(n_process)

    # Plot angular differences for process orientation as fraction of counts
    ax.hist(ang_diff_process_rad, bins=bins_process, weights=displacement_process, alpha=transparency, color='red',
            label='wrt Neuronal Process', histtype='step', density=False)

    # Set angular limits for the plot between 0 and 90 degrees
    ax.set_thetamin(0)  # Minimum angle (0 degrees)
    ax.set_thetamax(90)  # Maximum angle (90 degrees)

    # Customize the plot title and labels
    ax.set_title("Relative Orientation of Tracks wrt Process", fontsize=16)
    ax.set_ylabel("Counts", fontsize=14)

    # Customize the legend
    ax.legend(loc='upper right', bbox_to_anchor=(1.5, 1), fontsize=14)

    # Increase the font size for the ticks
    ax.tick_params(axis='both', which='major', labelsize=14)

    # Show the plot
    plt.tight_layout()  # Adjust layout to prevent label overlap
    
    # If a save path is provided, save the plot as a .svg file
    if save_path:
        fig.savefig(save_path, format='svg')
    
    plt.close()




# Define model functions
def power_law(x, a, k):
    return a * np.power(x, k)

def exponential_decay(x, a, b):
    return a * np.exp(-b * x)


def gaussian(x, amplitude, mean, stddev):
    """
    Gaussian function.

    Args:
    - x: Input data points.
    - amplitude: Height of the peak of the Gaussian.
    - mean: Position of the center of the peak.
    - stddev: Standard deviation (width of the peak).

    Returns:
    - y: The Gaussian function evaluated at x.
    """
    return amplitude * np.exp(-((x - mean) ** 2) / (2 * stddev ** 2))


def lognormal(x, shape, loc, scale):
    """Returns the value of the log-normal probability density function."""
    return lognorm.pdf(x, shape, loc=loc, scale=scale)


def analyze_track_duration(data_folder, tp_trajectories_final):
    """
    Analyze track duration distribution by fitting power-law and exponential decay models.
    
    Parameters:
    - tp_trajectories_final (pd.DataFrame): DataFrame containing track data with 'track_duration' column in seconds.
    - data_folder (str): Folder path to save the plot and results.
    
    Returns:
    - dict: Fit results containing parameters and R² values for power-law and exponential fits.
    """
    # Calculate histogram for track duration
    counts, bin_edges = np.histogram(
        tp_trajectories_final['track_duration'], 
        bins = np.arange(15, 40, 2), 
        density=False
    )

    # Calculate bin centers and normalize counts
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    total_counts = np.sum(counts)
    normalized_counts = counts / total_counts if total_counts > 0 else counts

    # Initialize fit result dictionary
    fit_results = {
        "power_law": {"params": None, "errors": None, "r2": None},
        "exponential": {"params": None, "errors": None, "r2": None}
    }

    # Fit the power-law and exponential decay to the histogram data
    try:
        # Filter out zero counts for fitting
        valid_bins = bin_centers[counts > 0]
        valid_norm_counts = normalized_counts[counts > 0]

        # Fit power-law model
        popt_power, pcov_power = curve_fit(
            power_law, valid_bins, valid_norm_counts, maxfev=10000
        )
        perr_power = np.sqrt(np.diag(pcov_power))

        # Calculate R² for power-law fit
        power_pred = power_law(bin_centers, *popt_power)
        ss_res_power = np.sum((normalized_counts - power_pred) ** 2)
        ss_tot_power = np.sum((normalized_counts - np.mean(normalized_counts)) ** 2)
        r2_power = 1 - (ss_res_power / ss_tot_power)

        # Store power-law fit results
        fit_results["power_law"]["params"] = popt_power
        fit_results["power_law"]["errors"] = perr_power
        fit_results["power_law"]["r2"] = r2_power

        # Fit exponential decay model with bounds to avoid overflow
        popt_exp, pcov_exp = curve_fit(
            exponential_decay, valid_bins, valid_norm_counts, bounds=(0, [np.inf, 1.0]), maxfev=10000
        )
        perr_exp = np.sqrt(np.diag(pcov_exp))

        # Calculate R² for exponential fit
        exp_pred = exponential_decay(bin_centers, *popt_exp)
        ss_res_exp = np.sum((normalized_counts - exp_pred) ** 2)
        ss_tot_exp = np.sum((normalized_counts - np.mean(normalized_counts)) ** 2)
        r2_exp = 1 - (ss_res_exp / ss_tot_exp)

        # Store exponential fit results
        fit_results["exponential"]["params"] = popt_exp
        fit_results["exponential"]["errors"] = perr_exp
        fit_results["exponential"]["r2"] = r2_exp

    except RuntimeError as e:
        print(f"Error fitting curves: {e}")

    # Plot Track Duration Distribution
    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot normalized histogram
    ax.bar(
        bin_centers, normalized_counts, width=np.diff(bin_edges), 
        color='wheat', edgecolor='k', alpha=0.7, 
        label='Normalized Duration Histogram'
    )

    # Plot fitted power-law curve
    if fit_results["power_law"]["params"] is not None:
        ax.plot(
            bin_centers, power_law(bin_centers, *fit_results["power_law"]["params"]), marker='o', 
            color='olive', label=f'Power-law Fit: y = {fit_results["power_law"]["params"][0]:.2f} * \
            x^{fit_results["power_law"]["params"][1]:.2f}'
        )

    # Plot fitted exponential curve
    if fit_results["exponential"]["params"] is not None:
        ax.plot(
            bin_centers, exponential_decay(bin_centers, *fit_results["exponential"]["params"]), marker='o', 
            color='slateblue', label=f'Exponential Fit: y = {fit_results["exponential"]["params"][0]:.2f} *exp(-{fit_results["exponential"]["params"][1]:.2f} * x)'
        )

    # Customize plot
    ax.set_xlabel('Track Duration (s)', fontsize=20)
    ax.set_ylabel('Normalized Counts', fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.tick_params(axis='both', which='minor', labelsize=8)
    ax.grid(False)
    ax.legend(fontsize=16)

    # Save the plot
    save_path = os.path.join(data_folder, 'track_durations.svg')
    if os.path.exists(save_path):
        os.remove(save_path)
    fig.savefig(save_path, format='svg', bbox_inches='tight')
    plt.close()

    # Save fit results to CSV
    results_df = pd.DataFrame({
        "Distribution": ["Power-law", "Exponential"],
        "Parameters": [
            fit_results["power_law"]["params"],
            fit_results["exponential"]["params"]
        ],
        "Uncertainties": [
            fit_results["power_law"]["errors"],
            fit_results["exponential"]["errors"]
        ],
        "R^2": [
            fit_results["power_law"]["r2"],
            fit_results["exponential"]["r2"]
        ]
    })

    csv_path = os.path.join(data_folder, 'fit_track_duration_parameters.csv')
    if os.path.exists(csv_path):
        os.remove(csv_path)

    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
    
    results_df.to_csv(csv_path, index=False)

    # return fit_results




def plot_track_total_distance_distribution(data_folder, tp_trajectories_final):
    """
    Plots the track distance distribution as a histogram of normalized counts 
    and fits log-normal and Poisson distributions to the data.

    Parameters:
    - tp_trajectories_final (pd.DataFrame): DataFrame containing track data with 'track_length' column.
    - data_folder (str): The folder path where the plot and results will be saved.

    Returns:
    - dict: Fit results containing parameters and R² values for log-normal and Poisson fits.
    """
    # Use track lengths directly, assuming they are already in microns
    track_lengths = tp_trajectories_final['track_length']

    # Define histogram bins from 0 to 25 µm
    bins = np.linspace(0, 25, 20 + 1)  # Adjust bin range as necessary
    counts, bin_edges = np.histogram(track_lengths, bins=bins, density=False)
    normalized_counts = counts / np.sum(counts) if np.sum(counts) > 0 else counts

    # Calculate bin centers for plotting
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Create figure and axis
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.bar(bin_centers, normalized_counts, width=np.diff(bin_edges), color='lavender', edgecolor='k', 
           align='center', label='Normalized Counts')

    # Fit results dictionary to store parameters and R² values
    fit_results = {
        "log_normal": {"params": None, "errors": None, "r2": None},
        "poisson": {"params": None, "r2": None}
    }

    try:
        # Log-normal fit
        shape, loc, scale = lognorm.fit(track_lengths, floc=0)  # Fit log-normal, fix location to 0
        lognorm_pdf = lognorm.pdf(bin_centers, shape, loc, scale)
        lognorm_pdf /= np.sum(lognorm_pdf)  # Normalize

        # Calculate R² for log-normal fit
        ss_res_lognorm = np.sum((normalized_counts - lognorm_pdf) ** 2)
        ss_tot_lognorm = np.sum((normalized_counts - np.mean(normalized_counts)) ** 2)
        r2_lognorm = 1 - (ss_res_lognorm / ss_tot_lognorm)

        # Store log-normal fit results and uncertainties (estimated using standard errors)
        fit_results["log_normal"]["params"] = (shape, loc, scale)
        fit_results["log_normal"]["errors"] = (np.sqrt(np.diag(lognorm.fit(
            track_lengths, floc=0, scale=scale, loc=loc))),)  # Example uncertainty calculation
        fit_results["log_normal"]["r2"] = r2_lognorm

        # Plot log-normal fit
        ax.plot(bin_centers, lognorm_pdf, color='teal', lw=2, marker='o', label=f'Log-normal Fit (R² = {r2_lognorm:.4f})')

        # Poisson fit (parameter is the mean of the track lengths in bins)
        poisson_lambda = np.mean(track_lengths)
        poisson_pmf = poisson.pmf(bin_centers.astype(int), poisson_lambda)
        poisson_pmf /= np.sum(poisson_pmf)  # Normalize

        # Calculate R² for Poisson fit
        ss_res_poisson = np.sum((normalized_counts - poisson_pmf) ** 2)
        ss_tot_poisson = np.sum((normalized_counts - np.mean(normalized_counts)) ** 2)
        r2_poisson = 1 - (ss_res_poisson / ss_tot_poisson)

        # Store Poisson fit results
        fit_results["poisson"]["params"] = (poisson_lambda,)
        fit_results["poisson"]["r2"] = r2_poisson

        # Plot Poisson fit
        ax.plot(bin_centers, poisson_pmf, color='crimson', lw=2, marker='o', label=f'Poisson Fit (R² = {r2_poisson:.4f})')

    except Exception as e:
        print(f"Error fitting distributions: {e}")

    # Set plot title and labels
    ax.set_xlabel('Track Distance (µm)', fontsize=20)
    ax.set_ylabel('Normalized Distance Counts', fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.tick_params(axis='both', which='minor', labelsize=8)
    ax.legend(fontsize=16)

    # Save as SVG
    save_path = os.path.join(data_folder, 'track_total_distance.svg')
    if os.path.exists(save_path):
        os.remove(save_path)
    fig.savefig(save_path, format='svg', bbox_inches='tight')
    # print(f"Plot saved as SVG at: {save_path}")
    plt.close()

    # Save fit results to CSV
    results_df = pd.DataFrame({
        "Distribution": ["Log-normal", "Poisson"],
        "Parameters": [fit_results["log_normal"]["params"], fit_results["poisson"]["params"]],
        "Uncertainties": [fit_results["log_normal"]["errors"], None],
        "R^2": [fit_results["log_normal"]["r2"], fit_results["poisson"]["r2"]]
    })

    csv_path = os.path.join(data_folder, 'fit_total_distance_parameters.csv')

    if os.path.exists(csv_path):
        os.remove(csv_path)

    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
    
    results_df.to_csv(csv_path, index=False)
    # print(f"Fit results saved as CSV at: {csv_path}")

    # return fit_results



def fit_and_plot_track_displacement_distribution(data_folder, tp_trajectories_final):
    """
    Plots the track distance distribution and fits power-law and exponential decay curves.
    
    Parameters:
    - tp_trajectories_final (pd.DataFrame): DataFrame containing track data with 'distance' column.
    
    Returns:
    - dict: Fit results containing parameters, errors, and R² values for power-law and exponential fits.
    """
    #  Use track distances directly, assuming they are already in microns
    distances = tp_trajectories_final['distance']

    # Define histogram bins from 0 to 7 µm
    nbins = 20
    bins = np.arange(0, 7 + 7 / nbins, 7 / nbins)
    counts, bin_edges = np.histogram(distances, bins=bins, density=False)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    normalized_counts = counts / np.sum(counts) if np.sum(counts) > 0 else counts

    # Fit results dictionary to store parameters, errors, and R² values
    fit_results = {
        "power_law": {"params": None, "errors": None, "r2": None},
        "exponential": {"params": None, "errors": None, "r2": None}
    }

    try:
        # Fit power-law to non-zero counts
        popt_power, pcov_power = curve_fit(power_law, bin_centers[counts > 0], normalized_counts[counts > 0], maxfev=10000)
        perr_power = np.sqrt(np.diag(pcov_power))  # Standard deviation errors

        # Calculate R² for power-law fit
        power_pred = power_law(bin_centers, *popt_power)
        ss_res_power = np.sum((normalized_counts - power_pred) ** 2)
        ss_tot_power = np.sum((normalized_counts - np.mean(normalized_counts)) ** 2)
        r2_power = 1 - (ss_res_power / ss_tot_power)

        # Store power-law fit results
        fit_results["power_law"]["params"] = popt_power
        fit_results["power_law"]["errors"] = perr_power
        fit_results["power_law"]["r2"] = r2_power

        # Fit exponential decay to non-zero counts
        popt_exp, pcov_exp = curve_fit(exponential_decay, bin_centers[counts > 0], normalized_counts[counts > 0], 
                                       maxfev=10000)
        perr_exp = np.sqrt(np.diag(pcov_exp))  # Standard deviation errors

        # Calculate R² for exponential fit
        exp_pred = exponential_decay(bin_centers, *popt_exp)
        ss_res_exp = np.sum((normalized_counts - exp_pred) ** 2)
        ss_tot_exp = np.sum((normalized_counts - np.mean(normalized_counts)) ** 2)
        r2_exp = 1 - (ss_res_exp / ss_tot_exp)

        # Store exponential fit results
        fit_results["exponential"]["params"] = popt_exp
        fit_results["exponential"]["errors"] = perr_exp
        fit_results["exponential"]["r2"] = r2_exp

    except RuntimeError as e:
        print(f"Error fitting curves: {e}")

    # Plot Track Final Distance Distribution
    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot normalized histogram
    ax.bar(bin_centers, normalized_counts, width=np.diff(bin_edges), color='azure', edgecolor='k', alpha=0.7, 
           label='Normalized Displacement Histogram')

    # Plot power-law fit
    if fit_results["power_law"]["params"] is not None:
        ax.plot(bin_centers, power_law(bin_centers, *fit_results["power_law"]["params"]), color='olive', marker='o',
                label=f'Power-law Fit: y = {fit_results["power_law"]["params"][0]:.2f} * x^{fit_results["power_law"]
                ["params"][1]:.2f}')

    # Plot exponential fit
    if fit_results["exponential"]["params"] is not None:
        ax.plot(bin_centers, exponential_decay(bin_centers, *fit_results["exponential"]["params"]), 
                color='slateblue', marker='o',
                label=f'Exponential Fit: y = {fit_results["exponential"]["params"][0]:.2f} *exp(-{fit_results["exponential"]["params"][1]:.2f} * x)')

    # Customize plot
    ax.set_xlabel('Track Final Displacement (µm)', fontsize=20)
    ax.set_ylabel('Counts per Bin / Total Bins', fontsize=20)  # Update y-label
    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.tick_params(axis='both', which='minor', labelsize=8)
    ax.grid(False)
    ax.legend(fontsize=16)

    # Show plot
    # plt.show()

    # Save as SVG
    save_path = data_folder + '\\track_final_distance.svg'
    if os.path.exists(save_path):
        os.remove(save_path)
    fig.savefig(save_path, format='svg', bbox_inches='tight')
    # print(f"Plot saved as SVG at: {save_path}")
    plt.close()

    # Save fit parameters and R² values to CSV
    fit_data = {
        'Fit Type': ['Power-law', 'Exponential'],
        'Parameter a': [fit_results["power_law"]["params"][0] if fit_results["power_law"]["params"] is not None else None,
                        fit_results["exponential"]["params"][0] if fit_results["exponential"]["params"] is not None else None],
        'Error a': [fit_results["power_law"]["errors"][0] if fit_results["power_law"]["errors"] is not None else None,
                     fit_results["exponential"]["errors"][0] if fit_results["exponential"]["errors"] is not None else None],
        'Parameter k/b': [fit_results["power_law"]["params"][1] if fit_results["power_law"]["params"] is not None else None,
                          fit_results["exponential"]["params"][1] if fit_results["exponential"]["params"] is not None else None],
        'Error k/b': [fit_results["power_law"]["errors"][1] if fit_results["power_law"]["errors"] is not None else None,
                       fit_results["exponential"]["errors"][1] if fit_results["exponential"]["errors"] is not None else None],
        'R²': [fit_results["power_law"]["r2"], fit_results["exponential"]["r2"]]
    }

    fit_df = pd.DataFrame(fit_data)
    fit_csv_path = data_folder + '\\fit_displacement_parameters.csv'

    if os.path.exists(fit_csv_path):
            os.remove(fit_csv_path)

    with open(fit_csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
    
    
    fit_df.to_csv(fit_csv_path, index=False)
    # print(f"Fit parameters saved to {fit_csv_path}")
    # return fit_results



def save_fit_velocity_parameters_as_csv(gaussian_params, gaussian_r2, lognormal_params, lognormal_r2, filename):
    """
    Saves the fitting parameters for velocity distributions to a CSV file.

    Parameters:
    - gaussian_params (list): List of parameters for the Gaussian fit.
    - gaussian_r2 (float): R² value for the Gaussian fit.
    - lognormal_params (list): List of parameters for the log-normal fit.
    - lognormal_r2 (float): R² value for the log-normal fit.
    - filename (str): Path to the CSV file to save the parameters.
    """
    parameters_data = {
        'Fit Type': [],
        'Parameter': [],
        'Value': [],
        'R²': []
    }

    # Adding Gaussian fit parameters
    for i, param in enumerate(gaussian_params):
        parameters_data['Fit Type'].append('Gaussian')
        parameters_data['Parameter'].append(f'Parameter {i+1}')
        parameters_data['Value'].append(param)
        parameters_data['R²'].append(gaussian_r2 if gaussian_r2 is not None else '')

    # Adding log-normal fit parameters
    for i, param in enumerate(lognormal_params):
        parameters_data['Fit Type'].append('Log-normal')
        parameters_data['Parameter'].append(f'Parameter {i+1}')
        parameters_data['Value'].append(param)
        parameters_data['R²'].append(lognormal_r2 if lognormal_r2 is not None else '')

    # Create DataFrame and save to CSV
    parameters_df = pd.DataFrame(parameters_data)
    parameters_df.to_csv(filename, index=False)

    

def plot_velocity_distributions(data_folder, tp_trajectories_final):
    # Convert velocities to desired units (μm/min)
    avg_inst_velocity = tp_trajectories_final['avg_inst_velocity'] 
    avg_velocity = tp_trajectories_final['avg_velocity'] 
    
    # Define fixed bins for both histograms
    nbins = 20
    bins_smaller = np.arange(0, 25 + (25 - 5) / nbins, (25 - 5) / nbins)
    bins_bigger = np.arange(0, 40 + (40 - 0) / nbins, (40 - 5) / nbins)
    
    # Set up the figure
    fig, ax = plt.subplots(figsize=(12, 8))
    
    # Calculate normalized counts for average instantaneous velocity distribution
    counts_inst, bins_inst = np.histogram(avg_inst_velocity, bins=bins_bigger)
    counts_inst_normalized = counts_inst / counts_inst.sum()
    ax.hist(bins_inst[:-1], bins=bins_bigger, weights=counts_inst_normalized, color='palevioletred', 
            alpha=0.5, label='Average Speed')
    
    # Fit Gaussian to the average instantaneous velocity
    bin_centers_inst = 0.5 * (bins_inst[1:] + bins_inst[:-1])
    params_inst, _ = curve_fit(gaussian, bin_centers_inst, counts_inst_normalized, 
                               p0=[1, np.mean(avg_inst_velocity), np.std(avg_inst_velocity)])
    
    # Generate x values for the fitted Gaussian
    x_fit_inst = np.linspace(bins_inst[0], bins_inst[-1], 100)
    y_fit_inst = gaussian(x_fit_inst, *params_inst)
    
    # Plot the fitted Gaussian
    ax.plot(x_fit_inst, y_fit_inst, color='slategrey', linestyle='--', linewidth=2, label='Gaussian Fit (Average Speed)')
    
    # Calculate R^2 value for the Gaussian fit
    residuals_inst = counts_inst_normalized - gaussian(bin_centers_inst, *params_inst)
    ss_res_inst = np.sum(residuals_inst**2)
    ss_tot_inst = np.sum((counts_inst_normalized - np.mean(counts_inst_normalized))**2)
    r_squared_inst = 1 - (ss_res_inst / ss_tot_inst)

    # Fit log-normal to the average instantaneous velocity
    shape, loc, scale = lognorm.fit(avg_inst_velocity, floc=0)  # shape is the shape parameter

    # Generate y values for the fitted log-normal
    y_fit_lognorm = lognormal(x_fit_inst, shape, loc, scale)  # Correct number of arguments

    # Plot the fitted log-normal
    ax.plot(x_fit_inst, y_fit_lognorm, color='teal', linestyle='--', linewidth=2, label='Log-normal Fit (Average Speed)')
    
    # Calculate R^2 value for the log-normal fit
    residuals_lognorm = counts_inst_normalized - lognormal(bin_centers_inst, shape, loc, scale)  # Correct number of arguments
    ss_res_lognorm = np.sum(residuals_lognorm**2)
    ss_tot_lognorm = np.sum((counts_inst_normalized - np.mean(counts_inst_normalized))**2)
    r_squared_lognorm = 1 - (ss_res_lognorm / ss_tot_lognorm)
    
    # Calculate normalized counts for overall average velocities
    counts_avg, bins_avg = np.histogram(avg_velocity, bins=bins_smaller)
    counts_avg_normalized = counts_avg / counts_avg.sum()
    ax.hist(bins_avg[:-1], bins=bins_smaller, weights=counts_avg_normalized, color='tan', alpha=0.5, label='Average Velocity')
    
    # Set labels and title
    plt.xlabel('Velocity (μm/min)', fontsize=20)
    plt.ylabel('Fraction of Counts', fontsize=20)
    plt.title('Normalized Distribution of Average Speed and Average Velocity', fontsize=20)
    
    # Customize tick parameters
    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.tick_params(axis='both', which='minor', labelsize=8)
    
    # Add legend
    plt.legend(fontsize=15)
    
    # Adjust layout and show the plot
    plt.tight_layout()
    
    # Save fitting parameters to CSV file
    filepath = data_folder + '\\fit_velocity_parameters.csv'
    if os.path.exists(filepath):
            os.remove(filepath)

    with open(filepath, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        
    save_fit_velocity_parameters_as_csv(params_inst, r_squared_inst, [shape, loc, scale], r_squared_lognorm, 
                                        filepath)

    # Save as SVG
    save_path = data_folder + '\\track_velocities.svg'
    if os.path.exists(save_path):
        os.remove(save_path)
    fig.savefig(save_path, format='svg', bbox_inches='tight')
    # print(f"Plot saved as SVG at: {save_path}")
    plt.close()
    


def fit_and_plot_sinuosity_distribution(data_folder, tp_trajectories_final):
    """
    Plots the inverse sinuosity distribution and fits power-law and exponential decay curves.
    
    Parameters:
    - tp_trajectories_final (pd.DataFrame): DataFrame containing track data with 'sinuosity' column.
    
    Returns:
    - dict: Fit results containing parameters, errors, and R² values for power-law and exponential fits.
    """
    # Convert sinuosity values to the desired units
    inverse_sinuosity = 1 / tp_trajectories_final['sinuosity']
    
    # Define histogram bins from 0 to a maximum value
    nbins = 20
    bins = np.linspace(0, 1, nbins)  # Adjust the upper limit based on your data range
    counts, bin_edges = np.histogram(inverse_sinuosity, bins=bins, density=False)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    normalized_counts = counts / np.sum(counts) if np.sum(counts) > 0 else counts

    # Fit results dictionary to store parameters, errors, and R² values
    fit_results = {
        "power_law": {"params": None, "errors": None, "r2": None},
        "exponential": {"params": None, "errors": None, "r2": None}
    }

    try:
        # Fit power-law to non-zero counts
        popt_power, pcov_power = curve_fit(power_law, bin_centers[counts > 0], normalized_counts[counts > 0], maxfev=10000)
        perr_power = np.sqrt(np.diag(pcov_power))  # Standard deviation errors

        # Calculate R² for power-law fit
        power_pred = power_law(bin_centers, *popt_power)
        ss_res_power = np.sum((normalized_counts - power_pred) ** 2)
        ss_tot_power = np.sum((normalized_counts - np.mean(normalized_counts)) ** 2)
        r2_power = 1 - (ss_res_power / ss_tot_power)

        # Store power-law fit results
        fit_results["power_law"]["params"] = popt_power
        fit_results["power_law"]["errors"] = perr_power
        fit_results["power_law"]["r2"] = r2_power

        # Fit exponential decay to non-zero counts
        popt_exp, pcov_exp = curve_fit(exponential_decay, bin_centers[counts > 0], normalized_counts[counts > 0], maxfev=10000)
        perr_exp = np.sqrt(np.diag(pcov_exp))  # Standard deviation errors

        # Calculate R² for exponential fit
        exp_pred = exponential_decay(bin_centers, *popt_exp)
        ss_res_exp = np.sum((normalized_counts - exp_pred) ** 2)
        ss_tot_exp = np.sum((normalized_counts - np.mean(normalized_counts)) ** 2)
        r2_exp = 1 - (ss_res_exp / ss_tot_exp)

        # Store exponential fit results
        fit_results["exponential"]["params"] = popt_exp
        fit_results["exponential"]["errors"] = perr_exp
        fit_results["exponential"]["r2"] = r2_exp

    except RuntimeError as e:
        print(f"Error fitting curves: {e}")

    # Plot Inverse Sinuosity Distribution
    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot normalized histogram
    ax.bar(bin_centers, normalized_counts, width=np.diff(bin_edges), color='lightcoral', edgecolor='k', 
           alpha=0.7, label='Normalized Inverse Sinuosity Histogram')

    # Plot power-law fit
    if fit_results["power_law"]["params"] is not None:
        ax.plot(bin_centers, power_law(bin_centers, *fit_results["power_law"]["params"]), color='olive',  marker='o',
                label=f'Power-law Fit: y = {fit_results["power_law"]["params"][0]:.2f} *x^{{{fit_results["power_law"]["params"][1]:.2f}}}')

    # Plot exponential fit
    if fit_results["exponential"]["params"] is not None:
        ax.plot(bin_centers, exponential_decay(bin_centers, *fit_results["exponential"]["params"]), 
                color='slateblue', marker='o', 
                label=f'Exponential Fit: y = {fit_results["exponential"]["params"][0]:.2f} *exp(-{fit_results["exponential"]["params"][1]:.2f} * x)')

    # Customize plot
    ax.set_title('Inverse Sinuosity Distribution with Fitted Curves' , fontsize=20)
    ax.set_xlabel('Inverse Sinuosity', fontsize=20)
    ax.set_ylabel('Counts per Bin / Total Bins', fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.tick_params(axis='both', which='minor', labelsize=8)
    ax.grid(False)
    ax.legend(fontsize=16)

    # Save fitting parameters to CSV file
    def save_fit_parameters_as_csv(fit_results, filepath):
        """
        Saves the fit results (parameters, errors, R²) to a CSV file.
        
        Args:
          fit_results (dict): Dictionary containing fit results.
          filepath (str): Path to the CSV file for saving.
        """
        with open(filepath, 'w', newline='') as csvfile:
          writer = csv.writer(csvfile)
          writer.writerow(['Fit Type', 'Parameter 1', 'Error 1', 'Parameter 2', 'Error 2', 'R²'])
          writer.writerow(['Power Law', *fit_results['power_law']['params'], *fit_results['power_law']
                           ['errors'], fit_results['power_law']['r2']])
          writer.writerow(['Exponential Decay', *fit_results['exponential']['params'], *fit_results
                           ['exponential']['errors'], fit_results['exponential']['r2']])

        if os.path.exists(filepath):
            os.remove(filepath)

        with open(filepath, 'w', newline='') as csvfile:
            writer = csv.writer(fit_results)
    
    save_fit_parameters_as_csv(fit_results, data_folder + '\\fit_sinuosity_parameters.csv')


    # Save plot as SVG
    save_path = data_folder + '\\inverse_sinuosity_distribution.svg'
    
    # Delete the old figure if it exists (combined logic)
    if os.path.exists(save_path):
        os.remove(save_path)
    
    # Save the new figure
    fig.savefig(save_path, format='svg', bbox_inches='tight')
    plt.close()
    # return fit_results


def analyze_trajectory_data(data_folder, file_name, resolution, frame_interval, kernel_size=11, 
                            min_frames=10, diameter=5, minmass=1.0, separation=15, search_range=5):
    """
    Analyze trajectory data from raw images by processing optical flow, detecting and tracking particles, 
    and plotting various distributions.

    Parameters:
    - data_folder (str): Path to the folder containing raw image files.
    - file_name (str): Name of the raw image file to analyze.
    - resolution (float): Resolution in px/um.
    - frame_interval (float): Time interval per frame in seconds.
    - kernel_size (int): Kernel size for optical flow processing (default is 11).
    - min_frames (int): Minimum number of frames for tracks to be retained (default is 10).
    - diameter (float): Diameter of the particles for detection (default is 5).
    - minmass (float): Minimum mass for particle detection (default is 1.0).
    - separation (float): Minimum separation between detected particles (default is 15).
    - search_range (float): Search range for particle tracking (default is 5).

    Returns:
    - None: Saves plots and fits to specified paths.
    """
    
    try:
        # Construct the full image path
        print("Step 1: Constructing the full image path...")
        image_path = os.path.join(data_folder, file_name)  # Ensure both parts are strings
        print(f"Image path: {image_path}")
    except Exception as e:
        raise RuntimeError(f"Error in Step 1: Constructing the full image path. Details: {e}")
    
    try:
        # Load the image
        img = tiff.imread(image_path)
        print("Step 2: Image loaded successfully.")
    except FileNotFoundError:
        raise FileNotFoundError(f"Error in Step 2: Image file not found at {image_path}.")
    except Exception as e:
        raise RuntimeError(f"Error in Step 2: Loading the image. Details: {e}")
    
    try:
        # Process optical flow and track particles
        print("Step 3: Processing optical flow and tracking particles...")
        cluster, mask, n_frames = process_optical_flow(data_folder, file_name)
    except Exception as e:
        raise RuntimeError(f"Error in Step 3: Processing optical flow. Details: {e}")
    
    try:
        # Detect and track particles
        print("Step 4: Detecting and tracking particles...")
        tp_trajectories = detect_and_track_particles(cluster, mask, n_frames, diameter=diameter, 
                                                     minmass=minmass, separation=separation, search_range=search_range)
        print(f"Particles detected and tracked. Total trajectories: {len(tp_trajectories)}.")
    except Exception as e:
        raise RuntimeError(f"Error in Step 4: Detecting and tracking particles. Details: {e}")
    
    try:
        # Filter out tracks with less than the specified number of frames
        print("Step 5: Filtering tracks...")
        tp_trajectories_pruned = tp.filter_stubs(tp_trajectories, min_frames)
        print(f"Tracks filtered. Remaining trajectories: {len(tp_trajectories_pruned)}.")
    except Exception as e:
        raise RuntimeError(f"Error in Step 5: Filtering tracks. Details: {e}")
    
    try:
        # Extract metrics from the tracks
        print("Step 6: Calculating track metrics...")
        tp_trajectories_final, track_metrics = calculate_track_metrics(tp_trajectories_pruned, resolution, frame_interval)
    except Exception as e:
        raise RuntimeError(f"Error in Step 6: Calculating track metrics. Details: {e}")


    try:
        # Calculate angular differences
        print("Step 7a : Calculating angular differences...")
        ang_diff_process = calculate_angular_differences_process(tp_trajectories_final, data_folder, save_path = data_folder+'\\angles.csv')
    
        # Plot the angular differences
        print("Step 7b : Plotting angular differences...")
        plot_angular_differences_process(ang_diff_process, tp_trajectories_final, save_path = data_folder+'\\relative_angles.svg')
        
    except Exception as e:
        print(f"An error occurred: {e}")

    # Visualize and save trajectories
    # visualize_and_save_trajectories(tp_trajectories_final, img, data_folder)
    # plot_particle_trajectories(tp_trajectories_final, data_folder, colormap_name='magma')

    # Plotting all the track property distributions
    # analyze_track_duration(data_folder, tp_trajectories_final)
    # plot_track_total_distance_distribution(data_folder, tp_trajectories_final)
    # fit_and_plot_track_displacement_distribution(data_folder, tp_trajectories_final)
    # plot_velocity_distributions(data_folder, tp_trajectories_final)
    # fit_and_plot_sinuosity_distribution(data_folder, tp_trajectories_final)


    # Save DataFrames to CSV
    save_path1 = data_folder + '\\tp_trajectories_final.csv'
    if os.path.exists(save_path1):
        os.remove(save_path1)
    tp_trajectories_final.to_csv(save_path1, index=True)
    
    save_path2 = data_folder + '\\track_metrics.csv'
    if os.path.exists(save_path2):
        os.remove(save_path2)
    track_metrics.to_csv(save_path2, index=True)





def extract_numeric(value):
    """
    Extracts numeric value from a given input. If the input is not numeric,
    it will return None.

    Args:
    - value: The input value to be processed.

    Returns:
    - float or None: The extracted numeric value, or None if the input is not valid.
    """
    # Convert to string and strip any leading/trailing whitespace
    value_str = str(value).strip()

    # Use regex to find numeric values (including decimals)
    match = re.findall(r"[-+]?\d*\.\d+|\d+", value_str)

    if match:
        return float(match[0])  # Return the first found numeric value
    else:
        return None

def get_file_names_and_params_by_div(sheet_path, div_value, sheet_name='glass'):
    """
    Get all 'file name', 'resolution', and 'frame interval' values from the specified Excel sheet where div 
    equals the given value.

    Args:
    - sheet_path: Path to the Excel sheet containing imaging details.
    - div_value: The 'div' value to filter the Excel sheet.
    - sheet_name: The sheet name in the Excel file (default: 'glass').

    Returns:
    - params: A list of tuples containing ('file name', 'resolution', 'frame interval') where div equals div_value.
    """
    # Load the Excel sheet
    df_sheet = pd.read_excel(sheet_path, sheet_name=sheet_name)

    # Filter based on the 'div' value
    df_sheet_filtered = df_sheet[df_sheet['div'] == div_value]

    # Get the parameters
    params = []
    for index in df_sheet_filtered.index:
        # Extract file name
        file_name = df_sheet_filtered['file name'][index]

        # Extract frame interval
        interval = df_sheet_filtered['frame interval'][index]
        integer_interval_value = extract_numeric(interval)

        # Extract resolution
        resolution = df_sheet_filtered['resolution'][index]
        integer_resolution_value = extract_numeric(resolution)

        # Append to params
        params.append((file_name, integer_resolution_value, integer_interval_value))

    return params



def check_level_2_subfolders_by_div_and_params(main_folder, sheet_path, div_value, resolution_range, frame_interval_range):
    """
    Check all level-2 subfolder names under level-1 subfolders that match the specified div value,
    and verify if their corresponding resolution and frame interval match the criteria.

    Args:
    - main_folder: Path to the main folder containing level-1 subfolders.
    - sheet_path: Path to the Excel sheet containing imaging details.
    - div_value: The 'div' value to filter the Excel sheet.
    - resolution_range: The resolution range to match (as a tuple or list).
    - frame_interval_range: The frame interval range to match (as a tuple or list).
    """
    # Get the file names and corresponding parameters that match the div value
    params = get_file_names_and_params_by_div(sheet_path, div_value)

    if not params:
        print(f"No parameters found for div = {div_value}.")
        return

    # Create the expected level-1 subfolder name
    level_1_subfolder_name = f"div{div_value}"

    # Loop through each level-1 subfolder
    for subfolder in os.listdir(main_folder):
        subfolder_path = os.path.join(main_folder, subfolder)

        # Check if it's a directory (level-1) and matches the expected name
        if os.path.isdir(subfolder_path) and subfolder == level_1_subfolder_name:
            # Loop through each level-2 subfolder
            for level_2_subfolder in os.listdir(subfolder_path):
                level_2_subfolder_path = os.path.join(subfolder_path, level_2_subfolder)

                if os.path.isdir(level_2_subfolder_path):
                    # Check if this level-2 folder corresponds to any file name
                    for file_name, resolution, frame_interval in params:
                        if file_name in level_2_subfolder:
                            print(f"Analyzing Level-2 Subfolder: {level_2_subfolder}")

                            # Verify if resolution and frame interval meet the criteria
                            if ((resolution_range is None or (resolution is not None and 
                                                              resolution_range[0] <= resolution <= resolution_range[1])) and
                                (frame_interval_range is None or (frame_interval is not None and frame_interval_range[0] <= 
                                                                  frame_interval <= frame_interval_range[1]))):
                                
                                print("Criteria Met")
                                
                                # Search for files ending with '_jttr_blch_corr.tiff'
                                found_file = None
                                for file in os.listdir(level_2_subfolder_path):
                                    if file.endswith('_jttr_blch_corr.tiff'):
                                        found_file = file
                                        break
                                
                                # If a file is found, proceed with analysis
                                if found_file:
                                    folder_path = convert_to_code_path(subfolder_path + '\\' + level_2_subfolder)
                                    try:
                                        analyze_trajectory_data(folder_path, found_file, resolution, frame_interval, min_frames=8)
                                    except Exception as e:
                                        print(f"Error analyzing data in {level_2_subfolder}: {e}")
                                else:
                                    print("No '_jttr_blch_corr.tiff' file found in this folder.")
                            else:
                                print("Criteria Not Met")




def analyze_all_level_2_folders_by_div(main_folder, sheet_path, div_value, resolution_range, frame_interval_range):
    """
    Analyze all level-2 folders under level-1 subfolders based on the specified div value,
    resolution range, and frame interval range.

    Args:
    - main_folder: Path to the main folder containing level-1 subfolders.
    - sheet_path: Path to the Excel sheet containing imaging details.
    - div_value: The 'div' value to filter the Excel sheet.
    - resolution_range: The resolution range to match (as a tuple or list).
    - frame_interval_range: The frame interval range to match (as a tuple or list).
    """
    # Check level-2 subfolders by div and parameters
    check_level_2_subfolders_by_div_and_params(main_folder, sheet_path, div_value, resolution_range, frame_interval_range)

## Practice example for a DIV stage

In [11]:
# Example call
main_folder = r'E:\Spandan\2D_Neurons_Paper\Glass'
sheet_path = r'E:\Spandan\Kate\NEURON MOVIES\tifNotes.xlsx'
div_value = 11
resolution_range = (2.6, 2.9)
frame_interval_range = (1.95, 2.05)

analyze_all_level_2_folders_by_div(main_folder, sheet_path, div_value, resolution_range, frame_interval_range)


Frame 379: 83 trajectories present.
Particles detected and tracked. Total trajectories: 30558.
Step 5: Filtering tracks...
Tracks filtered. Remaining trajectories: 2090.
Step 6: Calculating track metrics...
Trajectories and track properties calculated and stored.
Step 7a : Calculating angular differences...
Step 7b : Plotting angular differences...


## Main function call

In [12]:
def analyze_all_divs(main_folder, sheet_path, resolution_range, frame_interval_range):
    """
    Analyze all level-1 folders corresponding to different DIV values.

    Parameters:
    - main_folder: Path to the main folder containing level-1 subfolders.
    - sheet_path: Path to the Excel file with additional data.
    - resolution_range: Tuple containing the range of resolutions to filter.
    - frame_interval_range: Tuple containing the range of frame intervals to filter.
    """
    
    # Loop through each level-1 subfolder
    for subfolder in os.listdir(main_folder):
        subfolder_path = os.path.join(main_folder, subfolder)

        # Check if it's a directory (level-1)
        if os.path.isdir(subfolder_path):
            # Use regex to extract the DIV value from the folder name
            match = re.search(r'div(\d+)', subfolder, re.IGNORECASE)
            if match:
                div_value = int(match.group(1))  # Convert extracted string to integer
                print(f"Analyzing {subfolder} with DIV value: {div_value}")

                # Call your analysis function
                analyze_all_level_2_folders_by_div(main_folder, sheet_path, div_value, resolution_range, frame_interval_range)

# Example usage
main_folder = r'E:\Spandan\2D_Neurons_Paper\Glass'
sheet_path = r'E:\Spandan\Kate\NEURON MOVIES\tifNotes.xlsx'
resolution_range = (2.6, 2.9)
frame_interval_range = (1.8, 2.2)

analyze_all_divs(main_folder, sheet_path, resolution_range, frame_interval_range)


Frame 124: 73 trajectories present.
Particles detected and tracked. Total trajectories: 9405.
Step 5: Filtering tracks...
Tracks filtered. Remaining trajectories: 768.
Step 6: Calculating track metrics...
Trajectories and track properties calculated and stored.
Step 7a : Calculating angular differences...
Step 7b : Plotting angular differences...
Analyzing Level-2 Subfolder: control_2021_12_06_div7neurons_dish1_15mm_40x_neuron3_media
Criteria Met
Step 1: Constructing the full image path...
Image path: E:\\Spandan\\2D_Neurons_Paper\\Glass\\div7\\control_2021_12_06_div7neurons_dish1_15mm_40x_neuron3_media\div7neurons_dish1_15mm_40x_neuron3_media_jttr_blch_corr.tiff
Step 2: Image loaded successfully.
Step 3: Processing optical flow and tracking particles...
Error: [WinError 3] The system cannot find the path specified: 'E:\\\\Spandan\\\\2D_Neurons_Paper\\\\Glass\\\\div7\\\\control_2021_12_06_div7neurons_dish1_15mm_40x_neuron3_media\\Op_flow'
Error analyzing data in control_2021_12_06_div7