# Adult Fish Trajectory Analysis Script

This script is designed to analyze the trajectory of fish in video recordings. It loads fish coordinate data from a CSV file, performs several calculations, and generates various plots to visualize the results.

## Instructions

1.  **Prepare your data:**
    *   Your data must be in a CSV file with two columns per fish, holding its x and y coordinates, so 8 fish give 16 columns in total. The first row must be a header naming the columns `x1`, `y1`, `x2`, `y2`, and so on, because the cleaning step looks the columns up by these names.
    *   Make sure the `file_path` variable in the script points to the correct CSV file.
    *   The script assumes the coordinates are in pixels.
2.  **Configure the settings:**
    *   Adjust the following variables at the start of the script to match your experimental setup:
        *   `file_path`: Path to your CSV file.
        *   `num_fish`: Number of fish tracked in the video.
        *   `tank_width_cm`: Width of the tank in centimeters.
        *   `video_width_pixels`: Width of the video frame in pixels.
        *   `tank_height_cm`: Height of the tank in centimeters.
        *   `video_height_pixels`: Height of the video frame in pixels.
        *   `fps`: Frames per second of the video recording.
        *   `prefix`: Prefix added to the column names (set to `'fish_'` in the script).
        *   `save_path`: Folder to save plots in (folder will be created if it does not exist).
        *   `PLOT_DPI`: DPI for saving the plots.
        *   `start_time`, `end_time`: Start and end of the analysis window in seconds.
        *   `center_zone_percentage`: Fraction of the tank's width and height that counts as the center zone.
        *   `freezing_threshold_speed`, `freezing_min_duration_s`: Speed (cm/s) below which a fish counts as immobile, and the minimum duration (s) for an immobile period to count as a freeze.
        *   `turn_threshold_degrees`: Minimum change in direction (degrees) counted as a turn.
3.  **Run the script:** Run every cell in order, for example with your IDE's "Run All" button.
4.  **Check the outputs:**
    *   The script writes the following CSV files to the current working directory (normally the folder containing the Python script): `trajectories_cleaned.csv`, `time_in_halves.csv`, `time_in_zones.csv`, `freezing_analysis.csv`, `turning_analysis.csv`, `distance_speed.csv`, and `fish_analysis_results.csv`.
    *   The script saves the plots as PNG files in a folder called `plots` (or whatever you set `save_path` to), which is created relative to the same location.
    *   The console will print progress messages for each stage of processing.
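
The layout from step 1 can be checked before running the analysis. The sketch below is a hypothetical helper (`check_layout` is not part of the script). It assumes the header row names the columns `x1, y1, x2, y2, ...`, which is what the cleaning step's `df[f'x{fish_id}']` lookups rely on.

```python
import pandas as pd

def check_layout(df, num_fish):
    """Return a list of problems found in a coordinate table (empty if none).

    Hypothetical helper, not part of the script. Assumes the columns are
    named x1, y1, x2, y2, ... in that order.
    """
    expected = [f"{axis}{i}" for i in range(1, num_fish + 1) for axis in ("x", "y")]
    problems = []
    if list(df.columns) != expected:
        problems.append(f"expected columns {expected}, found {list(df.columns)}")
    # Coordinates must be numeric; text in a column usually means a stray header or note
    non_numeric = [c for c in df.columns if not pd.api.types.is_numeric_dtype(df[c])]
    if non_numeric:
        problems.append(f"non-numeric columns: {non_numeric}")
    return problems

# Example with a tiny two-fish table
demo = pd.DataFrame({"x1": [1.0, 2.0], "y1": [3.0, 4.0], "x2": [5.0, 6.0], "y2": [7.0, 8.0]})
print(check_layout(demo, num_fish=2))  # prints []
```

On real data, the same check could be run as `check_layout(pd.read_csv(file_path), num_fish)` right after loading.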


In [42]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial import distance
import os
from scipy.signal import savgol_filter

### 1. Configuration Settings

These lines at the beginning of the script define the parameters for your analysis, such as file paths, tank dimensions, and video properties. They also configure the save folder and the DPI of the plots.

In [43]:
# Load File
file_path = '1.csv'  # Replace with your actual file

# Load the file into a DataFrame
df = pd.read_csv(file_path)

# Look at the description of the data
print(df.head())
print(df.describe())

        x1       y1       x2       y2       x3       y3       x4       y4
0  342.829  384.614  923.479  291.459  805.567  225.575  681.106  280.663
1  342.425  384.860  924.132  291.129  805.869  221.234  681.472  281.387
2  343.260  384.392  924.050  290.802  806.627  217.080  682.230  282.216
3  343.167  384.928  924.650  290.514  806.436  215.903  682.617  283.060
4  343.484  385.030  924.692  290.096  806.574  214.784  683.254  283.923
                 x1            y1            x2            y2            x3  \
count  20302.000000  20302.000000  25201.000000  25201.000000  27411.000000   
mean     367.690934    353.531981    475.999057    351.623925    534.474848   
std      150.091824     94.922978    208.321607    110.216519    221.498981   
min       83.056000     88.164000     88.733000     86.420000     86.107000   
25%      249.128000    279.834000    309.327000    285.629000    352.910000   
50%      333.509500    359.289000    439.678000    369.596000    522.046000   
75%

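*   **Adjust the video height/width:** Check the summary printed by the previous cell (especially the `min` and `max` rows) and set `video_width_pixels` and `video_height_pixels` accordingly. The cleaning step treats coordinates that are negative or beyond these dimensions as outliers and replaces them by interpolation.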
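
To judge how much data a given width/height setting would affect, a sketch like the following counts the invalid points per column. It uses the same conditions as the cleaning step (missing, negative, or larger than the frame dimension). `invalid_fraction` is a hypothetical helper and assumes the columns are named `x1, y1, ...`.

```python
import numpy as np
import pandas as pd

def invalid_fraction(df, video_width, video_height):
    """Fraction of frames per column that are missing, negative, or beyond the frame."""
    result = {}
    for col in df.columns:
        # x columns are compared against the width, y columns against the height
        limit = video_width if col.startswith("x") else video_height
        s = df[col]
        invalid = s.isna() | (s < 0) | (s > limit)
        result[col] = invalid.mean()
    return pd.Series(result)

# Tiny synthetic example: x1 has one NaN and one point beyond 1000 px, y1 has one negative value
demo = pd.DataFrame({"x1": [10.0, np.nan, 1200.0, 500.0], "y1": [20.0, 30.0, -5.0, 40.0]})
print(invalid_fraction(demo, video_width=1000, video_height=600))
```

A high fraction for one fish suggests long interpolated stretches, so results for that fish may deserve extra caution.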


In [44]:
# Configuration settings
num_fish = 4              # The number of fish you tracked.
tank_width_cm = 22        # The real-world width of your tank in cm.
video_width_pixels = 1000 # The width of your video in pixels.
tank_height_cm = 15       # The real-world height of your tank in cm.
video_height_pixels = 600 # The height of your video in pixels.
fps = 30                  # The frames-per-second of your video.
start_time = 0            # Analysis window start time in seconds.
end_time = 300            # Analysis window end time in seconds.
prefix = 'fish_'          # A prefix for DataFrame columns (can be left as is).
save_path = 'plots'       # The name of the folder where plots will be saved.
os.makedirs(save_path, exist_ok=True) # ensures folder exists, creates it otherwise
PLOT_DPI = 600            # The resolution (dots-per-inch) for saved plots.
center_zone_percentage = 0.5 # 0.5 means the center zone spans the central 50% of the tank's width and height (25% of its area).
# Freezing Behavior:
freezing_threshold_speed = 0.5  # cm/s. Speed below which a fish is considered 'immobile'.
freezing_min_duration_s = 2     # seconds. Minimum duration for a period of immobility to be a 'freeze'.
# Turning Preference:
turn_threshold_degrees = 30 # Angle in degrees. A change in direction greater than this is counted as a 'turn'.


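# Clean data
def clean_trajectories(df, video_width, video_height):
    """Replace missing, negative, or out-of-frame coordinates by linear interpolation.

    Leading invalid blocks are filled with the first valid value; trailing
    invalid blocks are left unchanged. Uses the module-level num_fish.
    """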
    def interpolate_invalid_blocks(series, video_dimension=None):
        mask = (series < 0) | series.isna()
        if video_dimension is not None:
            mask = mask | (series > video_dimension)
        invalid_blocks = (mask != mask.shift()).cumsum()
        for _, group in series.groupby(invalid_blocks):
            if mask[group.index[0]]:
                start_idx, end_idx = group.index[0], group.index[-1]
                prev_valid_idx = series.loc[:start_idx - 1][~mask].index.max()
                next_valid_idx = series.loc[end_idx + 1:][~mask].index.min()
                if pd.isna(prev_valid_idx) and pd.notna(next_valid_idx):
                    series.loc[:end_idx] = series.at[next_valid_idx]
                elif pd.notna(prev_valid_idx) and pd.notna(next_valid_idx):
                    prev_val, next_val = series.at[prev_valid_idx], series.at[next_valid_idx]
                    num_points = end_idx - start_idx + 1
                    interpolated_values = np.linspace(prev_val, next_val, num_points + 2)[1:-1]
                    series.loc[start_idx:end_idx] = interpolated_values
        return series
    for fish_id in range(1, num_fish + 1):
        df[f'x{fish_id}'] = interpolate_invalid_blocks(df[f'x{fish_id}'], video_width)
        df[f'y{fish_id}'] = interpolate_invalid_blocks(df[f'y{fish_id}'], video_height)
    return df

df = clean_trajectories(df, video_width_pixels, video_height_pixels)
df.to_csv('trajectories_cleaned.csv', index=False)
file_path = 'trajectories_cleaned.csv'

### 2. Generate coordinate column names with an optional prefix and convert pixels to cm
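
As a worked example of the conversion, the sketch below uses the settings from the configuration cell (22 cm across 1000 px, 15 cm across 600 px). Dividing the tank size by the video size implicitly assumes that the tank fills the whole video frame. The helper `to_cm` is hypothetical and only illustrates the arithmetic.

```python
tank_width_cm, video_width_pixels = 22, 1000
tank_height_cm, video_height_pixels = 15, 600

cm_per_px_x = tank_width_cm / video_width_pixels    # 0.022 cm per pixel
cm_per_px_y = tank_height_cm / video_height_pixels  # 0.025 cm per pixel

def to_cm(x_px, y_px):
    """Convert one pixel coordinate pair to centimeters."""
    return x_px * cm_per_px_x, y_px * cm_per_px_y

# The center of the frame maps to the center of the tank
x_cm, y_cm = to_cm(500, 300)
print(round(x_cm, 3), round(y_cm, 3))  # 11.0 7.5
```

Note that the two scale factors differ here, so distances along x and y are scaled differently.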


In [45]:
# Function to generate coordinate column names based on the number of fish and an optional prefix
def generate_coordinate_columns(num_fish, prefix=''):
    return [f'{prefix}{axis}{i + 1}' for i in range(num_fish) for axis in ['x', 'y']]

def load_and_prepare_data(file_path, num_fish, tank_width_cm, video_width_pixels, tank_height_cm, video_height_pixels, fps, prefix=''):
    try:
        df = pd.read_csv(file_path, sep=',')
        print("Data loaded successfully.")
    except Exception as e:
        raise Exception(f"Error loading CSV: {e}")

    expected_columns = num_fish * 2
    if len(df.columns) != expected_columns:
        raise ValueError(f"Error: Expected {expected_columns} columns, found {len(df.columns)}.")

    df.columns = generate_coordinate_columns(num_fish, prefix)
    df['frame_number'] = range(len(df))
    time_per_frame = 1 / fps
    df['time_diff'] = time_per_frame

    df = convert_pixels_to_cm(df, num_fish, tank_width_cm, video_width_pixels, tank_height_cm, video_height_pixels, prefix)
    print("Pixel to cm conversion completed.")

    for i in range(1, num_fish + 1):
        df[f'{prefix}x{i}'] = df[f'{prefix}x{i}'].interpolate(method='linear')
        df[f'{prefix}y{i}'] = df[f'{prefix}y{i}'].interpolate(method='linear')
    print("Interpolation of NaNs complete.")

    df = calculate_tank_center(df, tank_width_cm, tank_height_cm)
    df = calculate_acceleration(df, num_fish, prefix, time_interval=time_per_frame)
    df = calculate_group_metrics(df, num_fish, prefix)
    df = add_instantaneous_kinematics(df, num_fish, prefix, time_per_frame)
    print("Instantaneous speed and angle calculated.")

    print("Data preparation completed.")
    return df

def convert_pixels_to_cm(df, num_fish, tank_width_cm, video_width_pixels, tank_height_cm, video_height_pixels, prefix=''):
    pixel_to_cm_x = tank_width_cm / video_width_pixels
    pixel_to_cm_y = tank_height_cm / video_height_pixels
    for i in range(1, num_fish + 1):
        df[f'{prefix}x{i}'] *= pixel_to_cm_x
        df[f'{prefix}y{i}'] *= pixel_to_cm_y
    return df

def calculate_tank_center(df, tank_width_cm, tank_height_cm):
    df['tank_center_x'] = tank_width_cm / 2
    df['tank_center_y'] = tank_height_cm / 2
    return df

### 3. Calculate the metrics


In [46]:
# Function to calculate acceleration from position data
# This function calculates the velocity and acceleration for each fish based on their x and y coordinates.
def calculate_acceleration(df, num_fish, prefix='', time_interval=1):
    window_size, poly_order = 5, 2
    for i in range(1, num_fish + 1):
        df[f'{prefix}vx{i}'] = df[f'{prefix}x{i}'].diff() / time_interval
        df[f'{prefix}vy{i}'] = df[f'{prefix}y{i}'].diff() / time_interval
        df[f'{prefix}svx{i}'] = savgol_filter(df[f'{prefix}vx{i}'].fillna(0), window_size, poly_order)
        df[f'{prefix}svy{i}'] = savgol_filter(df[f'{prefix}vy{i}'].fillna(0), window_size, poly_order)
        df[f'{prefix}ax{i}'] = df[f'{prefix}svx{i}'].diff() / time_interval
        df[f'{prefix}ay{i}'] = df[f'{prefix}svy{i}'].diff() / time_interval
    return df

# Function to calculate group metrics
# This function calculates the centroid, pairwise distances, and other metrics for a group of fish
# based on their x and y coordinates.
def calculate_group_metrics(df, num_fish, prefix):
    x_cols = [f'{prefix}x{i}' for i in range(1, num_fish + 1)]
    y_cols = [f'{prefix}y{i}' for i in range(1, num_fish + 1)]
    df['group_centroid_x'] = df[x_cols].mean(axis=1)
    df['group_centroid_y'] = df[y_cols].mean(axis=1)
    pairwise_distances = []
    for _, row in df.iterrows():
        coords = np.array([[row[f'{prefix}x{i}'], row[f'{prefix}y{i}']] for i in range(1, num_fish + 1)])
        # pdist returns each unique pair once, so self-distances (zeros) and
        # mirrored duplicates do not bias the mean, std, or total
        pairwise_distances.append(distance.pdist(coords, 'euclidean'))
    df['pairwise_distances'] = pairwise_distances
    df['mean_pairwise_distance'] = df['pairwise_distances'].apply(np.mean)
    df['std_pairwise_distance'] = df['pairwise_distances'].apply(np.std)
    df['total_pairwise_distance'] = df['pairwise_distances'].apply(np.sum)
    return df

# Function to calculate speed and acceleration
# This function calculates the instantaneous speed and acceleration for each fish based on their x and y coordinates
# and returns the speeds and accelerations as numpy arrays.
def calculate_speed_acceleration(all_x, all_y, time_interval=1):
    vx = np.diff(all_x) / time_interval
    vy = np.diff(all_y) / time_interval
    speeds = np.sqrt(vx**2 + vy**2)
    ax = np.diff(vx) / time_interval
    ay = np.diff(vy) / time_interval
    acceleration = np.sqrt(ax**2 + ay**2)
    return speeds, acceleration

# Function to add instantaneous speed and angle for each fish
# This function calculates and adds instantaneous speed and movement angle for each fish.
def add_instantaneous_kinematics(df, num_fish, prefix, time_interval):
    """Calculates and adds instantaneous speed and movement angle for each fish."""
    for i in range(1, num_fish + 1):
        x_col, y_col = f'{prefix}x{i}', f'{prefix}y{i}'
        dx = df[x_col].diff()
        dy = df[y_col].diff()
        dist = np.sqrt(dx**2 + dy**2)
        df[f'{prefix}speed{i}'] = dist / time_interval
        df[f'{prefix}angle{i}'] = np.arctan2(dy, dx)
    return df

# Function to calculate time spent in center vs. periphery
# This function calculates the time spent in the center and periphery of the tank for each fish
# based on the specified center zone percentage.
def calculate_time_in_zones(df, fish_id, tank_width_cm, tank_height_cm, center_zone_pct, prefix):
    """Calculates time spent in center vs. periphery."""
    x_col, y_col = f'{prefix}x{fish_id}', f'{prefix}y{fish_id}'
    center_x, center_y = tank_width_cm / 2, tank_height_cm / 2
    half_width = (tank_width_cm * center_zone_pct) / 2
    half_height = (tank_height_cm * center_zone_pct) / 2
    center_x_min, center_x_max = center_x - half_width, center_x + half_width
    center_y_min, center_y_max = center_y - half_height, center_y + half_height
    in_center = (
        (df[x_col] >= center_x_min) & (df[x_col] <= center_x_max) &
        (df[y_col] >= center_y_min) & (df[y_col] <= center_y_max)
    )
    time_in_center = df.loc[in_center, 'time_diff'].sum()
    total_time = df['time_diff'].sum()
    time_in_periphery = total_time - time_in_center
    return time_in_center, time_in_periphery

# Function to calculate freezing behavior
# This function identifies and quantifies freezing behavior based on speed thresholds and duration.
def calculate_freezing_bouts(df, fish_id, threshold_speed, min_duration_s, fps, prefix):
    """Identifies and quantifies freezing behavior."""
    speed_col = f'{prefix}speed{fish_id}'
    is_immobile = df[speed_col] < threshold_speed
    immobile_blocks = (is_immobile != is_immobile.shift()).cumsum()
    total_freeze_time, num_bouts = 0, 0
    for _, group in df.groupby(immobile_blocks):
        if is_immobile.loc[group.index[0]]:  # label-based lookup: the filtered frame keeps its original index
            duration = len(group) / fps
            if duration >= min_duration_s:
                num_bouts += 1
                total_freeze_time += duration
    return total_freeze_time, num_bouts

# Function to calculate turning preference
# This function counts left and right turns and calculates the bias index based on the angle changes.
def calculate_turning_preference(df, fish_id, turn_threshold_rad, prefix):
    """Counts left and right turns and calculates bias."""
    angle_col = f'{prefix}angle{fish_id}'
    delta_angle = df[angle_col].diff()
    delta_angle = (delta_angle + np.pi) % (2 * np.pi) - np.pi
    left_turns = (delta_angle > turn_threshold_rad).sum()
    right_turns = (delta_angle < -turn_threshold_rad).sum()
    bias_index = (left_turns - right_turns) / (left_turns + right_turns) if (left_turns + right_turns) > 0 else 0
    return left_turns, right_turns, bias_index

# Function to calculate time spent in upper and lower halves of the tank for each fish
# This function calculates the time spent in the upper and lower halves of the tank for each fish
# based on the y-coordinate and the tank height.
def calculate_time_in_halves_for_fish(df, fish_id, tank_height_cm, prefix):
    y_col = f'{prefix}y{fish_id}'
    center_y = tank_height_cm / 2
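    # Assumes y increases upward. If the tracker uses image coordinates
    # (origin at top-left, y increasing downward), swap the two labels.
    upper_half_condition = df[y_col] > center_y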
    lower_half_condition = ~upper_half_condition
    upper_time = df.loc[upper_half_condition, 'time_diff'].sum()
    lower_time = df.loc[lower_half_condition, 'time_diff'].sum()
    return upper_time, lower_time

# Function to calculate total distance traveled by each fish
# This function calculates the total distance traveled by each fish based on their x and y coordinates.
def calculate_total_distance(df, num_fish, prefix=''):
    total_distances = {}
    for i in range(1, num_fish + 1):
        distances = np.sqrt(np.diff(df[f'{prefix}x{i}'])**2 + np.diff(df[f'{prefix}y{i}'])**2)
        total_distances[f'fish{i}'] = np.sum(distances)
    return total_distances

# Function to calculate average speed for each fish
# This function calculates the average speed of each fish based on their x and y coordinates.
# It computes the total distance traveled and divides it by the total time of the recording.
def calculate_average_speed(df, num_fish, prefix='', fps=30):
    average_speeds = {}
    time_per_frame = 1/fps
    for i in range(1, num_fish + 1):
        total_distance = np.sum(np.sqrt(np.diff(df[f'{prefix}x{i}'])**2 + np.diff(df[f'{prefix}y{i}'])**2))
        total_time = len(df)*time_per_frame
        average_speeds[f'fish{i}'] = total_distance / total_time if total_time > 0 else 0
    return average_speeds

### 4. Generate plots

In [47]:
# Function to generate KDE plot for an individual fish
# This function creates a Kernel Density Estimate (KDE) plot for the x and y positions
def generate_kde_plot(subject, all_x, all_y, save_path, plot_dpi = 300):
    plt.figure(figsize=(8, 6))
    sns.kdeplot(x=all_x, y=all_y, cmap='YlOrRd', fill=True, cbar=True)
    plt.title(f'KDE Plot for {subject}', fontweight='bold', fontsize=14)
    plt.xlabel("X Position (cm)")
    plt.ylabel("Y Position (cm)")
    plt.savefig(os.path.join(save_path, f'{subject}_kde_plot.png'), dpi=plot_dpi, bbox_inches='tight')
    plt.close()

# Function to generate KDE plot for a group of fish
# This function creates a Kernel Density Estimate (KDE) plot for the x and y positions
def generate_group_kde_plot(all_x, all_y, save_path, plot_dpi=300):
    plt.figure(figsize=(8, 6))
    sns.kdeplot(x=all_x, y=all_y, cmap='viridis', fill=True, cbar=True)
    plt.title('Group KDE Plot', fontweight='bold', fontsize=14)
    plt.xlabel("X Position (cm)")
    plt.ylabel("Y Position (cm)")
    plt.savefig(os.path.join(save_path, 'group_kde_plot.png'), dpi=plot_dpi, bbox_inches='tight')
    plt.close()

# Function to generate additional plots for a single subject
# This function generates radial distribution, angle of movement, and speed time series plots for a single subject.
# It also saves these plots in the specified directory.
# The radial distribution plot shows the distribution of distances from the tank center,
# the angle of movement plot shows the distribution of angles between consecutive positions,
# and the speed time series plot shows the speed of the fish over time.
def generate_additional_plots(subject, all_x, all_y, save_path, tank_center_x, tank_center_y, fps, plot_dpi=300):
    # Radial Distribution Plot
    radial_distances = np.sqrt((all_x - tank_center_x)**2 + (all_y - tank_center_y)**2)
    plt.figure(figsize=(8, 6))
    sns.histplot(radial_distances, bins=30, kde=True, color='green')
    plt.title(f'Radial Distribution for {subject}', fontweight='bold', fontsize=14)
    plt.xlabel("Radial Distance (cm)")
    plt.savefig(os.path.join(save_path, f'{subject}_radial_distribution.png'), dpi=plot_dpi, bbox_inches='tight')
    plt.close()

    # Angle of Movement Plot
    angles = np.arctan2(np.diff(all_y), np.diff(all_x))
    plt.figure(figsize=(8, 6))
    sns.histplot(angles, bins=30, kde=True, color='purple')
    plt.title(f'Angle of Movement for {subject}', fontweight='bold', fontsize=14)
    plt.xlabel("Angle (radians)")
    plt.savefig(os.path.join(save_path, f'{subject}_angle_of_movement.png'), dpi=plot_dpi, bbox_inches='tight')
    plt.close()

    # Speed Time Series Plot
    speeds, _ = calculate_speed_acceleration(all_x, all_y, time_interval=1/fps)
    time = np.arange(len(speeds)) / fps
    plt.figure(figsize=(8, 6))
    plt.plot(time, speeds, color='blue')
    plt.title(f'Speed Time Series for {subject}', fontweight='bold', fontsize=14)
    plt.ylabel("Speed (cm/s)")
    plt.xlabel("Time (s)")
    plt.savefig(os.path.join(save_path, f'{subject}_speed_timeseries.png'), dpi=plot_dpi, bbox_inches='tight')
    plt.close()

# Function to generate trajectory plot for a group of fish
# This function generates a plot showing the trajectories of all fish overlaid on the same axes.
# It saves the plot in the specified directory.
# The trajectories are plotted with different colors for each fish, and the axes are set to equal aspect ratio
# to accurately represent the spatial relationships between the fish.
def generate_trajectory_plot(df, num_fish, prefix, save_path, plot_dpi=600):
    plt.figure(figsize=(10, 8))
    for i in range(1, num_fish + 1):
        plt.plot(df[f'{prefix}x{i}'], df[f'{prefix}y{i}'], label=f'Fish {i}')
    plt.xlabel("X Position (cm)")
    plt.ylabel("Y Position (cm)")
    plt.title("Overlaid Fish Trajectories", fontweight='bold', fontsize = 14)
    plt.legend()
    plt.gca().set_aspect('equal', adjustable='box')
    plt.savefig(os.path.join(save_path, 'overlaid_trajectory.png'), dpi=plot_dpi, bbox_inches='tight')
    plt.close()

# Function to generate spread plot
# This function generates a plot showing the mean pairwise distance between fish over time
def generate_spread_plot(df, num_fish, prefix, save_path, plot_dpi=600):
    time = df['frame_number'] * df['time_diff']  # seconds; avoids relying on the global fps
    plt.figure(figsize=(10, 6))
    plt.plot(time, df['mean_pairwise_distance'], label='Mean distance')
    plt.fill_between(time, df['mean_pairwise_distance'] - df['std_pairwise_distance'],
                    df['mean_pairwise_distance'] + df['std_pairwise_distance'], alpha=0.4, label='Std')
    plt.legend()
    plt.title("Group pairwise distance", fontweight = 'bold', fontsize = 14)
    plt.xlabel("Time (s)")
    plt.ylabel("Distance (cm)")
    plt.savefig(os.path.join(save_path, 'pairwise_distance.png'), dpi=plot_dpi, bbox_inches='tight')
    plt.close()


### 5. Main function to run the analysis

In [48]:
# Main function to run the analysis
# This function orchestrates the loading, processing, and analysis of the fish tracking data.
# It applies various calculations such as time in halves, zones, freezing behavior, and turning preference.
# It also generates plots and exports the results to CSV files.
def main():
    try:
        df = load_and_prepare_data(file_path, num_fish, tank_width_cm, video_width_pixels, tank_height_cm, video_height_pixels, fps, prefix)

        start_frame = int(start_time * fps)
        end_frame = int(end_time * fps)
        df_filtered = df[(df['frame_number'] >= start_frame) & (df['frame_number'] <= end_frame)].copy()
        print(f"Data filtered to {start_time} - {end_time} seconds")

        time_in_halves, time_in_zones_results, freezing_results, turning_results = {}, {}, {}, {}
        turn_threshold_rad = np.deg2rad(turn_threshold_degrees)

        for i in range(1, num_fish + 1):
            fish_name = f'fish{i}'
            upper, lower = calculate_time_in_halves_for_fish(df_filtered, i, tank_height_cm, prefix)
            time_in_halves[fish_name] = {'Upper': upper, 'Lower': lower}
            center, periphery = calculate_time_in_zones(df_filtered, i, tank_width_cm, tank_height_cm, center_zone_percentage, prefix)
            time_in_zones_results[fish_name] = {'Center': center, 'Periphery': periphery}
            freeze_time, bouts = calculate_freezing_bouts(df_filtered, i, freezing_threshold_speed, freezing_min_duration_s, fps, prefix)
            freezing_results[fish_name] = {'total_freeze_time_s': freeze_time, 'num_bouts': bouts}
            left, right, bias = calculate_turning_preference(df_filtered, i, turn_threshold_rad, prefix)
            turning_results[fish_name] = {'left_turns': left, 'right_turns': right, 'turn_bias_index': bias}

        print("\n--- Analysis Results ---")
        total_distances = calculate_total_distance(df_filtered, num_fish, prefix)
        average_speeds = calculate_average_speed(df_filtered, num_fish, prefix, fps)
        
        for i in range(1, num_fish + 1):
            fish_name = f'fish{i}'
            print(f"\n--- Fish {i} ---")
            print(f"  Total Distance: {total_distances.get(fish_name, 0):.2f} cm")
            print(f"  Average Speed: {average_speeds.get(fish_name, 0):.2f} cm/s")
            print(f"  Time in Halves: Upper={time_in_halves[fish_name]['Upper']:.2f}s, Lower={time_in_halves[fish_name]['Lower']:.2f}s")
            print(f"  Time in Zones: Center={time_in_zones_results[fish_name]['Center']:.2f}s, Periphery={time_in_zones_results[fish_name]['Periphery']:.2f}s")
            print(f"  Freezing: Total Time={freezing_results[fish_name]['total_freeze_time_s']:.2f}s, Bouts={freezing_results[fish_name]['num_bouts']}")
            print(f"  Turning: Left={turning_results[fish_name]['left_turns']}, Right={turning_results[fish_name]['right_turns']}, Bias={turning_results[fish_name]['turn_bias_index']:.2f}")

        # --- Export to CSV ---
        pd.DataFrame.from_dict(time_in_halves, orient='index').to_csv('time_in_halves.csv')
        pd.DataFrame.from_dict(time_in_zones_results, orient='index').to_csv('time_in_zones.csv')
        pd.DataFrame.from_dict(freezing_results, orient='index').to_csv('freezing_analysis.csv')
        pd.DataFrame.from_dict(turning_results, orient='index').to_csv('turning_analysis.csv')
        distance_speed_df = pd.DataFrame({'total_distance': total_distances, 'average_speed': average_speeds})
        distance_speed_df.to_csv('distance_speed.csv')
        df_filtered.to_csv('fish_analysis_results.csv', index=False)
        print("\nAll summary data exported to CSV files.")

        # --- Plotting ---
        print("\nGenerating plots...")
        generate_trajectory_plot(df_filtered, num_fish, prefix, save_path, plot_dpi=PLOT_DPI)
        generate_spread_plot(df_filtered, num_fish, prefix, save_path, plot_dpi=PLOT_DPI)
        all_group_x = [df_filtered[f'{prefix}x{i}'].values for i in range(1, num_fish + 1)]
        all_group_y = [df_filtered[f'{prefix}y{i}'].values for i in range(1, num_fish + 1)]
        generate_group_kde_plot(np.concatenate(all_group_x), np.concatenate(all_group_y), save_path, plot_dpi=PLOT_DPI)

        for i in range(1, num_fish + 1):
            subject = f'Fish {i}'
            all_x = df_filtered[f'{prefix}x{i}'].values
            all_y = df_filtered[f'{prefix}y{i}'].values
            generate_kde_plot(subject, all_x, all_y, save_path, plot_dpi=PLOT_DPI)
            generate_additional_plots(subject, all_x, all_y, save_path, df_filtered['tank_center_x'].iloc[0], df_filtered['tank_center_y'].iloc[0], fps, plot_dpi=PLOT_DPI)
            print(f"Plots generated for {subject}")

        print("\nAnalysis complete.")

    except Exception as e:
        print(f"An error occurred: {e}")

In [49]:
if __name__ == '__main__':
    main()

Data loaded successfully.
Pixel to cm conversion completed.
Interpolation of NaNs complete.
Instantaneous speed and angle calculated.
Data preparation completed.
Data filtered to 0 - 300 seconds

--- Analysis Results ---

--- Fish 1 ---
  Total Distance: 162.66 cm
  Average Speed: 0.54 cm/s
  Time in Halves: Upper=181.43s, Lower=118.60s
  Time in Zones: Center=239.90s, Periphery=60.13s
  Freezing: Total Time=91.87s, Bouts=18
  Turning: Left=1211, Right=1200, Bias=0.00

--- Fish 2 ---
  Total Distance: 496.82 cm
  Average Speed: 1.66 cm/s
  Time in Halves: Upper=194.13s, Lower=105.90s
  Time in Zones: Center=122.63s, Periphery=177.40s
  Freezing: Total Time=9.77s, Bouts=4
  Turning: Left=355, Right=349, Bias=0.01

--- Fish 3 ---
  Total Distance: 642.23 cm
  Average Speed: 2.14 cm/s
  Time in Halves: Upper=181.67s, Lower=118.37s
  Time in Zones: Center=126.80s, Periphery=173.23s
  Freezing: Total Time=3.07s, Bouts=1
  Turning: Left=334, Right=360, Bias=-0.04

--- Fish 4 ---
  Total Dist

## Output Files

The script generates the following files and folders in the current working directory (normally the same directory as the Python script):

### 1. `plots` (Folder)

A folder containing all the plots generated by the script.

*   **`overlaid_trajectory.png`**: Shows the trajectories of all fish on a single plot, useful for a quick visual overview of movement patterns.
*   **`pairwise_distance.png`**: A time-series plot showing the mean distance between all pairs of fish, with the standard deviation shaded. This is a key metric for shoaling behavior.
*   **`group_kde_plot.png`**: A 2D Kernel Density Estimate (KDE) "heatmap" showing the spatial distribution for all fish combined. Brighter colors indicate where the group spent the most time.
*   **`Fish i_kde_plot.png`**: An individual KDE plot for each fish, showing its specific spatial preference.
*   **`Fish i_radial_distribution.png`**: A histogram showing how much time each fish spent at various distances from the tank's center.
*   **`Fish i_angle_of_movement.png`**: A histogram of each fish's movement direction (in radians) between consecutive frames.
*   **`Fish i_speed_timeseries.png`**: A line graph showing the speed of an individual fish over the course of the trial. This is useful for spotting periods of high activity or freezing.

*(Note: `i` in the filenames corresponds to the fish number, e.g., `Fish 1_kde_plot.png`)*

### 2. Summary Data Files (`.csv`)

These files provide high-level summaries of key behavioral metrics, with one row per fish.

*   **`distance_speed.csv`**: Contains the total distance traveled (cm) and the average speed (cm/s) for each fish.
*   **`time_in_halves.csv`**: Shows the total time (in seconds) each fish spent in the **Upper** and **Lower** halves of the tank.
*   **`time_in_zones.csv`**: Shows the total time (in seconds) each fish spent in the defined **Center** and **Periphery** zones of the tank. A key metric for thigmotaxis and anxiety.
*   **`freezing_analysis.csv`**: Contains the total time spent frozen (`total_freeze_time_s`) and the total number of freezing episodes (`num_bouts`) for each fish.
*   **`turning_analysis.csv`**: Provides counts for `left_turns` and `right_turns` and calculates a `turn_bias_index` (from -1 for all right to +1 for all left).
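
Because each summary file has one row per fish, the files can be joined into a single table for statistics or plotting. The following is a minimal sketch, assuming the five files listed above sit in one folder and were written with the fish name as the first column (as `to_csv` does for a DataFrame index). `combine_summaries` is a hypothetical helper, not part of the script.

```python
import os
import pandas as pd

SUMMARY_FILES = [
    "distance_speed.csv",
    "time_in_halves.csv",
    "time_in_zones.csv",
    "freezing_analysis.csv",
    "turning_analysis.csv",
]

def combine_summaries(folder="."):
    """Join the per-fish summary CSVs column-wise on the fish name index."""
    tables = [
        pd.read_csv(os.path.join(folder, name), index_col=0)
        for name in SUMMARY_FILES
    ]
    # axis=1 aligns rows on the shared index (fish1, fish2, ...)
    return pd.concat(tables, axis=1)

# Usage after the analysis has run (the output file name is an arbitrary choice):
# combined = combine_summaries()
# combined.to_csv("summary_all.csv")
```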

### 3. Detailed & Intermediate Files

*   **`fish_analysis_results.csv`**: This is the main detailed data output. It contains the coordinate data (converted to cm) along with all derived metrics (velocity, acceleration, speed, angle, pairwise distances, etc.) for **every frame** of the analysis window.
*   **`trajectories_cleaned.csv`**: An intermediate file containing the raw coordinate data after the initial cleaning and interpolation step. This file is used as the input for the main analysis but is typically not needed for final interpretation.
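
The per-frame file lends itself to follow-up analyses that the script does not perform. As one illustration, the sketch below averages a fish's instantaneous speed over fixed time bins. It assumes the default `'fish_'` prefix, so the speed column of fish 1 is `fish_speed1`. `mean_speed_per_bin` is a hypothetical helper, and the demo table is synthetic rather than real output.

```python
import numpy as np
import pandas as pd

def mean_speed_per_bin(df, fish_id, fps, bin_seconds=60, prefix="fish_"):
    """Average the per-frame speed column of one fish over fixed time bins."""
    time_s = df["frame_number"] / fps
    # Label each frame with the start time of the bin it falls into
    bin_start = (time_s // bin_seconds) * bin_seconds
    return df[f"{prefix}speed{fish_id}"].groupby(bin_start).mean()

# Synthetic stand-in for fish_analysis_results.csv: 4 frames at 1 fps.
# The first speed is NaN, as in the script, because speed comes from a frame-to-frame diff.
demo = pd.DataFrame({"frame_number": [0, 1, 2, 3], "fish_speed1": [np.nan, 1.0, 3.0, 5.0]})
print(mean_speed_per_bin(demo, fish_id=1, fps=1, bin_seconds=2))
```

With real output, the table would come from `pd.read_csv('fish_analysis_results.csv')` and `fps` should match the value used in the analysis.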