# Adult Fish Trajectory Analysis Script

This script is designed to analyze the trajectory of fish in video recordings. It loads fish coordinate data from a CSV file, performs several calculations, and generates various plots to visualize the results.

## Instructions

1.  **Prepare your data:**
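    *   Your data must be in a CSV file. The CSV file should contain two columns for each fish, representing x and y coordinates, so with 8 fish there should be 16 columns in total. The cleaning step looks the columns up by name, so the file needs a header row naming them `x1`, `y1`, `x2`, `y2`, and so on.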
    *   Make sure the `file_path` variable in the script points to the correct CSV file.
    *   The script assumes the coordinates are in pixels.
2.  **Configure the settings:**
    *   Adjust the following variables at the start of the script to match your experimental setup:
        *   `file_path`: Path to your CSV file.
        *   `num_fish`: Number of fish tracked in the video.
        *   `tank_width_cm`: Width of the tank in centimeters.
        *   `video_width_pixels`: Width of the video frame in pixels.
        *   `tank_height_cm`: Height of the tank in centimeters.
        *   `video_height_pixels`: Height of the video frame in pixels.
        *   `fps`: Frames per second of the video recording.
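        *   `start_time`: Start of the analysis time window in seconds.
        *   `end_time`: End of the analysis time window in seconds.
        *   `prefix`: Prefix added to the column names, if any (the configuration cell sets it to `'fish_'`).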
        *   `save_path`: Folder to save plots in (folder will be created if it does not exist).
        *   `PLOT_DPI`: DPI for saving the plots.
3.  **Run the script:** Execute the script from your IDE using the Run All button.
4.  **Check the outputs:**
    *   The script will output three CSV files into the folder where the Python script is stored: `time_in_halves.csv`, `distance_speed.csv`, and `fish_analysis_results.csv`.
    *   The script will save the plots as PNG files in a folder called `plots` (or whatever you set the `save_path` variable to), which is created in the same location as the Python script.
    *   The console will print progress messages for each stage of processing.
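
The column layout from step 1 can be checked before the analysis is run. The sketch below is only an illustration and is not part of the original script: `check_layout` is a hypothetical helper, the sample values are made up, and it assumes the columns are named `x1`, `y1`, `x2`, `y2`, ... as in the sample output printed by the loading cell.

```python
import io
import pandas as pd

def check_layout(df, num_fish):
    """Return True if df has exactly the columns x1, y1, ..., in that order."""
    expected = [f"{axis}{i}" for i in range(1, num_fish + 1) for axis in ("x", "y")]
    return list(df.columns) == expected

# Tiny in-memory CSV standing in for a real tracking file (values are made up)
sample_csv = io.StringIO("x1,y1,x2,y2\n10.0,20.0,30.0,40.0\n11.0,21.0,31.0,41.0\n")
sample_df = pd.read_csv(sample_csv)

layout_ok_for_two = check_layout(sample_df, num_fish=2)    # matches: 4 columns
layout_ok_for_three = check_layout(sample_df, num_fish=3)  # mismatch: 6 expected
print(layout_ok_for_two, layout_ok_for_three)
```

Running a check like this first gives a clearer message than a `KeyError` raised part-way through the cleaning step.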


In [37]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial import distance
import os
from scipy.signal import savgol_filter

### 1. Configuration Settings

The cells below define the parameters for your analysis, such as file paths, tank dimensions, and video properties. They also configure the save folder and the DPI of the plots.

In [38]:
# Load File
file_path = '1.csv'  # Replace with your actual file

# Load the file into a DataFrame
df = pd.read_csv(file_path)

# Inspect the first rows and the summary statistics of the data
print(df.head())
print(df.describe())

        x1       y1       x2       y2       x3       y3       x4       y4
0  342.829  384.614  923.479  291.459  805.567  225.575  681.106  280.663
1  342.425  384.860  924.132  291.129  805.869  221.234  681.472  281.387
2  343.260  384.392  924.050  290.802  806.627  217.080  682.230  282.216
3  343.167  384.928  924.650  290.514  806.436  215.903  682.617  283.060
4  343.484  385.030  924.692  290.096  806.574  214.784  683.254  283.923
                 x1            y1            x2            y2            x3  \
count  20302.000000  20302.000000  25201.000000  25201.000000  27411.000000   
mean     367.690934    353.531981    475.999057    351.623925    534.474848   
std      150.091824     94.922978    208.321607    110.216519    221.498981   
min       83.056000     88.164000     88.733000     86.420000     86.107000   
25%      249.128000    279.834000    309.327000    285.629000    352.910000   
50%      333.509500    359.289000    439.678000    369.596000    522.046000   
75%

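*   **Adjust height/width:** Look at the summary statistics printed by the previous cell (especially `min` and `max`) and set `video_width_pixels` and `video_height_pixels` so that outliers and stray trajectory points fall outside the valid range and are interpolated away by the cleaning step.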
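
To judge how many points a given width and height would flag, the invalid values can be counted per column. This is a minimal sketch under assumptions: `count_invalid` is a hypothetical helper that applies the same criteria as the cleaning mask (negative, NaN, or beyond the frame size), and the demo values are invented.

```python
import numpy as np
import pandas as pd

def count_invalid(df, video_width, video_height):
    """Count negative, NaN, or out-of-frame values in every x/y column."""
    counts = {}
    for col in df.columns:
        # x columns are limited by the frame width, y columns by the height
        limit = video_width if col.startswith("x") else video_height
        values = df[col]
        invalid = values.isna() | (values < 0) | (values > limit)
        counts[col] = int(invalid.sum())
    return counts

demo = pd.DataFrame({
    "x1": [10.0, -1.0, 1200.0, 500.0],
    "y1": [20.0, np.nan, 300.0, 650.0],
})
invalid_counts = count_invalid(demo, video_width=1000, video_height=600)
print(invalid_counts)
```

If a large share of a column is flagged, the chosen dimensions are probably too tight for the recording.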


In [39]:
# Configuration settings
num_fish = 4 # Replace with your fish count
tank_width_cm = 22 # Replace with your tank width in cm
video_width_pixels = 1000 # Replace with your video width in pixels
tank_height_cm = 15 # Replace with your tank height in cm
video_height_pixels = 600 # Replace with your video height in pixels
fps = 30 # Replace with your video fps
start_time = 0 # Start of the time window in seconds
end_time = 300  # End of the time window in seconds
prefix = 'fish_' # Replace with your prefix
save_path = 'plots'  # folder where the plots will be saved
os.makedirs(save_path, exist_ok=True)  # ensures folder exists, creates it otherwise
PLOT_DPI = 600 # Set a default plot dpi


# Clean data

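def clean_trajectories(df, video_width, video_height):
    """Replace invalid coordinates (negative, NaN, or outside the frame) by interpolation.

    Note: reads the global `num_fish` set in the configuration above.
    """

    def interpolate_invalid_blocks(series, video_dimension=None):
        """Linearly interpolate each run of consecutive invalid points in `series`."""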

        # Create a mask for invalid points (negative, NaN, or exceeding video dimensions)
        mask = (series < 0) | series.isna()
        if video_dimension is not None:
            mask = mask | (series > video_dimension)

        # Identify blocks of consecutive invalid points
        invalid_blocks = (mask != mask.shift()).cumsum()

        # Iterate over each block
        for block_id, group in series.groupby(invalid_blocks):
            if mask[group.index[0]]:  # Only process invalid blocks
                start_idx = group.index[0]
                end_idx = group.index[-1]

                # Find the last valid point before the block
                prev_valid_idx = series.loc[:start_idx - 1][~mask].index.max()
                # Find the next valid point after the block
                next_valid_idx = series.loc[end_idx + 1:][~mask].index.min()

                # Handle initial NaNs (no previous valid point)
                if pd.isna(prev_valid_idx) and pd.notna(next_valid_idx):
                    # Replace initial NaNs with the first valid value
                    series.loc[:end_idx] = series.at[next_valid_idx]
                # Handle middle blocks (both previous and next valid points exist)
                elif pd.notna(prev_valid_idx) and pd.notna(next_valid_idx):
                    prev_valid_val = series.at[prev_valid_idx]
                    next_valid_val = series.at[next_valid_idx]

                    # Perform linear interpolation
                    num_points = end_idx - start_idx + 1
                    interpolated_values = np.linspace(prev_valid_val, next_valid_val, num_points + 2)[1:-1]

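                    # Fill the invalid block with interpolated values
                    series.loc[start_idx:end_idx] = interpolated_values
                # Handle trailing blocks (no later valid point): hold the last valid value
                elif pd.notna(prev_valid_idx):
                    series.loc[start_idx:] = series.at[prev_valid_idx]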

        return series

    # Process each fish's trajectory
    for fish_id in range(1, num_fish + 1):
        x_col = f'x{fish_id}'
        y_col = f'y{fish_id}'

        # Interpolate x and y coordinates, checking for outliers
        df[x_col] = interpolate_invalid_blocks(df[x_col], video_dimension=video_width)
        df[y_col] = interpolate_invalid_blocks(df[y_col], video_dimension=video_height)

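    return df

# Clean the raw trajectories and save them for the next step
df = clean_trajectories(df, video_width_pixels, video_height_pixels)

df.to_csv('trajectories_cleaned.csv', index=False)

# Point file_path at the cleaned file so that later cells load it
file_path = 'trajectories_cleaned.csv'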

### 2. Generate coordinate column names with an optional prefix and convert pixels to cm
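
The pixel-to-cm conversion in this section is a per-axis linear scaling. The sketch below illustrates it with the tank and video dimensions from the configuration cell. `pixels_to_cm` is a hypothetical stand-alone helper, and it assumes, as the script implicitly does, that the video frame covers exactly the tank.

```python
def pixels_to_cm(x_px, y_px, tank_width_cm, video_width_pixels,
                 tank_height_cm, video_height_pixels):
    """Scale one pixel coordinate pair to centimetres, axis by axis."""
    scale_x = tank_width_cm / video_width_pixels    # cm per pixel along x
    scale_y = tank_height_cm / video_height_pixels  # cm per pixel along y
    return x_px * scale_x, y_px * scale_y

# The centre of a 1000 x 600 px frame should map to the centre of a 22 x 15 cm tank
x_cm, y_cm = pixels_to_cm(500, 300, 22, 1000, 15, 600)
print(x_cm, y_cm)
```

Because the two axes use separate scale factors, a frame whose aspect ratio differs from the tank's is still converted consistently along each axis.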


In [40]:
# Generates coordinate column names with an optional prefix.
def generate_coordinate_columns(num_fish, prefix=''):
    return [f'{prefix}{axis}{i + 1}' for i in range(num_fish) for axis in ['x', 'y']]

# Loads trajectory data from a CSV, renames columns, adds frame number,
# converts pixel coordinates to cm, calculates tank center, and adds derived
# quantities (velocity, acceleration, and group metrics).
def load_and_prepare_data(file_path, num_fish, tank_width_cm, video_width_pixels, tank_height_cm, video_height_pixels, prefix=''):
    try:
        df = pd.read_csv(file_path, sep=',')
        print("Data loaded successfully.")
    except FileNotFoundError:
        raise FileNotFoundError(f"Error: File not found at {file_path}")
    except pd.errors.EmptyDataError:
        raise ValueError(f"Error: the csv file at {file_path} appears empty")
    except Exception as e:
        raise Exception(f"An unexpected error happened when loading the csv file: {e}")

    expected_columns = num_fish * 2
    if len(df.columns) != expected_columns:
        raise ValueError(f"Error: Expected {expected_columns} columns but got {len(df.columns)} in {file_path}.")

    df.columns = generate_coordinate_columns(num_fish, prefix)
    df['frame_number'] = range(len(df))
    
    # Add a constant per-frame time step (uses the global `fps` from the configuration)
    time_per_frame = 1 / fps
    df['time_diff'] = time_per_frame
    
    df = convert_pixels_to_cm(df, num_fish, tank_width_cm, video_width_pixels, tank_height_cm, video_height_pixels, prefix)
    print("Pixel to cm conversion completed.")

    for i in range(1, num_fish + 1):
        df[f'{prefix}x{i}'] = df[f'{prefix}x{i}'].interpolate(method='linear')
        df[f'{prefix}y{i}'] = df[f'{prefix}y{i}'].interpolate(method='linear')
    print("Interpolation of NaNs complete")

    df = calculate_tank_center(df, tank_width_cm, tank_height_cm)
    df = calculate_acceleration(df, num_fish, prefix, time_interval=time_per_frame)
    df = calculate_group_metrics(df, num_fish, prefix)
    
    
    print("Calculations completed.")
    return df

# Converts pixel coordinates to cm coordinates.
def convert_pixels_to_cm(df, num_fish, tank_width_cm, video_width_pixels, tank_height_cm, video_height_pixels, prefix=''):
    pixel_to_cm_x = tank_width_cm / video_width_pixels
    pixel_to_cm_y = tank_height_cm / video_height_pixels

    for i in range(1, num_fish + 1):
        df[f'{prefix}x{i}'] = df[f'{prefix}x{i}'] * pixel_to_cm_x
        df[f'{prefix}y{i}'] = df[f'{prefix}y{i}'] * pixel_to_cm_y
    return df


# Calculates and adds the tank center coordinates to the DataFrame.
def calculate_tank_center(df, tank_width_cm, tank_height_cm):
    df['tank_center_x'] = tank_width_cm / 2
    df['tank_center_y'] = tank_height_cm / 2
    return df

### 3. Calculate the acceleration and group metrics
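
As a rough illustration of the acceleration step in this section, the sketch below differentiates a synthetic position trace, smooths the velocity with a Savitzky-Golay filter (window 5, polynomial order 2, as in the script), and differentiates again. The trajectory, frame rate, and acceleration value are assumptions chosen for the example, and it uses `np.diff` rather than the script's `Series.diff()` followed by `fillna(0)`.

```python
import numpy as np
from scipy.signal import savgol_filter

fps = 30                      # assumed frame rate for the synthetic example
dt = 1 / fps
t = np.arange(60) * dt
true_accel = 2.0              # cm/s^2, chosen arbitrarily
x = 0.5 * true_accel * t**2   # position under constant acceleration

vx = np.diff(x) / dt                                   # frame-to-frame velocity
svx = savgol_filter(vx, window_length=5, polyorder=2)  # smoothed velocity
ax = np.diff(svx) / dt                                 # acceleration from smoothed velocity

print(ax.min(), ax.max())
```

On this noise-free trace the velocity is linear in time, which a second-order Savitzky-Golay fit reproduces, so the recovered acceleration stays at the chosen value. Real tracking data is noisier, and the window size then trades smoothing against temporal resolution.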


In [41]:
# Calculates the smoothed acceleration for each fish and adds it to the DataFrame.
def calculate_acceleration(df, num_fish, prefix='', time_interval=1):
    window_size = 5 
    poly_order = 2 
    for i in range(1, num_fish + 1):
        df[f'{prefix}vx{i}'] = df[f'{prefix}x{i}'].diff() / time_interval
        df[f'{prefix}vy{i}'] = df[f'{prefix}y{i}'].diff() / time_interval
        df[f'{prefix}svx{i}'] = savgol_filter(df[f'{prefix}vx{i}'].fillna(0), window_size, poly_order) 
        df[f'{prefix}svy{i}'] = savgol_filter(df[f'{prefix}vy{i}'].fillna(0), window_size, poly_order)
        df[f'{prefix}ax{i}'] = df[f'{prefix}svx{i}'].diff() / time_interval
        df[f'{prefix}ay{i}'] = df[f'{prefix}svy{i}'].diff() / time_interval
    return df

# Calculates group metrics like centroid, average distance, spread, and total pairwise distance.
def calculate_group_metrics(df, num_fish, prefix):
    x_cols = [f'{prefix}x{i}' for i in range(1, num_fish + 1)]
    y_cols = [f'{prefix}y{i}' for i in range(1, num_fish + 1)]

    df['group_centroid_x'] = df[x_cols].mean(axis=1)
    df['group_centroid_y'] = df[y_cols].mean(axis=1)
    
    # Condensed pairwise distances: pdist returns each unordered pair exactly once,
    # so the zero self-distances and duplicate entries of a full cdist matrix
    # do not bias the mean, std, or total.
    pairwise_distances = []
    for index, row in df.iterrows():
        coords = np.array([[row[f'{prefix}x{i}'], row[f'{prefix}y{i}']] for i in range(1, num_fish + 1)])
        pairwise_distances.append(distance.pdist(coords, 'euclidean'))
    df['pairwise_distances'] = pairwise_distances
    df['mean_pairwise_distance'] = df['pairwise_distances'].apply(lambda x: np.mean(x))
    df['std_pairwise_distance'] = df['pairwise_distances'].apply(lambda x: np.std(x))
    df['total_pairwise_distance'] = df['pairwise_distances'].apply(lambda x: np.sum(x))
    return df


# Calculates speed and acceleration from positional data.
def calculate_speed_acceleration(all_x, all_y, time_interval=1):
    vx = np.diff(all_x) / time_interval
    vy = np.diff(all_y) / time_interval
    speeds = np.sqrt(vx**2 + vy**2)
    ax = np.diff(vx) / time_interval
    ay = np.diff(vy) / time_interval
    acceleration = np.sqrt(ax**2 + ay**2)
    return speeds, acceleration

### 4. Calculate the total time spent in each half for each fish
*   **How it works:** Counts how many frames each fish spends in each half of the tank, then converts the frame counts to time in seconds.
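
The frame-counting idea can be sketched without pandas. This is an illustration only: `time_in_halves` is a hypothetical helper and the y values are invented. As in the script, "upper" means y greater than the midline; whether that is the top of the tank depends on the coordinate origin of the tracker.

```python
def time_in_halves(y_values, tank_height_cm, fps):
    """Return (upper_seconds, lower_seconds) from per-frame y positions in cm."""
    center_y = tank_height_cm / 2
    # Frames strictly above the midline count as "upper"; the rest as "lower"
    upper_frames = sum(1 for y in y_values if y > center_y)
    lower_frames = len(y_values) - upper_frames
    # Each frame lasts 1 / fps seconds
    return upper_frames / fps, lower_frames / fps

upper_s, lower_s = time_in_halves(
    [2.0, 9.0, 10.5, 7.5, 12.0, 3.0], tank_height_cm=15, fps=2
)
print(upper_s, lower_s)
```

A point exactly on the midline (7.5 cm here) is assigned to the lower half, so the two times always add up to the total duration.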

In [42]:
# Calculates the time spent in the upper and lower half for a single fish.
def calculate_time_in_halves_for_fish(df, fish_id, tank_height_cm, prefix):
    """Calculates time spent in upper and lower halves for a single fish."""
    y_col = f'{prefix}y{fish_id}'
    center_y = tank_height_cm / 2

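    # Create boolean conditions for filtering the DataFrame.
    # "Upper" here means y greater than the tank midline; if the tracker's y axis
    # points down the image, this corresponds to the lower part of the frame.
    upper_half_condition = df[y_col] > center_y
    # Using '~' assigns every remaining point (y <= center_y) to the lower half
    lower_half_condition = ~upper_half_condition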

    # Sum the 'time_diff' for frames where the condition is true
    upper_time = df.loc[upper_half_condition, 'time_diff'].sum()
    lower_time = df.loc[lower_half_condition, 'time_diff'].sum()
    
    return upper_time, lower_time

### 5. Calculate total distance and average speed
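
The two quantities in this section reduce to summing the step lengths between consecutive frames and dividing by the elapsed time. The sketch below works through a three-point path. The helper names and coordinates are hypothetical, and it assumes the elapsed time is the number of frame intervals divided by `fps`.

```python
import math

def path_length(xs, ys):
    """Sum of straight-line distances between consecutive positions."""
    return sum(
        math.hypot(xs[i + 1] - xs[i], ys[i + 1] - ys[i])
        for i in range(len(xs) - 1)
    )

def average_speed(xs, ys, fps):
    """Path length divided by elapsed time (frame intervals / fps)."""
    intervals = len(xs) - 1
    if intervals <= 0:
        return 0.0
    return path_length(xs, ys) * fps / intervals

xs = [0.0, 3.0, 3.0]
ys = [0.0, 4.0, 8.0]
total_cm = path_length(xs, ys)             # a 5 cm step followed by a 4 cm step
speed_cm_s = average_speed(xs, ys, fps=2)  # two intervals at 2 fps span 1 second
print(total_cm, speed_cm_s)
```

Tracking jitter adds to the summed path length even when a fish is stationary, so totals from unsmoothed coordinates tend to overestimate the distance swum.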

In [43]:
# Calculates the total distance traveled by each fish.
def calculate_total_distance(df, num_fish, prefix=''):
    total_distances = {}
    for i in range(1, num_fish + 1):
        distances = np.sqrt(
            np.diff(df[f'{prefix}x{i}'])**2 + np.diff(df[f'{prefix}y{i}'])**2
        )
        total_distances[f'fish{i}'] = np.sum(distances)
    return total_distances

# Calculates the average speed of each fish.
def calculate_average_speed(df, num_fish, prefix='', fps=30):
    average_speeds = {}
    time_per_frame = 1/fps
    for i in range(1, num_fish + 1):
        distances = np.sqrt(
            np.diff(df[f'{prefix}x{i}'])**2 + np.diff(df[f'{prefix}y{i}'])**2
        )
        total_distance = np.sum(distances)
        # N frames span N - 1 intervals, matching the N - 1 step distances above
        total_time = (len(df) - 1) * time_per_frame
        average_speeds[f'fish{i}'] = total_distance / total_time if total_time > 0 else 0
    return average_speeds

### 6. Generate plots

In [44]:
# Generates a KDE plot of the positions of a subject.
def generate_kde_plot(subject, all_x, all_y, save_path, plot_dpi = 300):
    plt.figure(figsize=(8, 6))
    sns.kdeplot(x=all_x, y=all_y, cmap='YlOrRd', fill=True, cbar=True)
    plt.title(f'KDE Plot for {subject}', fontweight='bold', fontsize=14)
    plt.xlabel("X Position (cm)")
    plt.ylabel("Y Position (cm)")
    save_path_kde = os.path.join(save_path, f'{subject}_kde_plot.png')
    plt.savefig(save_path_kde, dpi=plot_dpi, bbox_inches='tight')
    plt.close()
    
# Generates a KDE plot of the combined positions of all fish.
def generate_group_kde_plot(all_x, all_y, save_path, plot_dpi=300):
    plt.figure(figsize=(8, 6))
    sns.kdeplot(x=all_x, y=all_y, cmap='viridis', fill=True, cbar=True)
    plt.title('Group KDE Plot', fontweight='bold', fontsize=14)
    plt.xlabel("X Position (cm)")
    plt.ylabel("Y Position (cm)")
    save_path_kde = os.path.join(save_path, 'group_kde_plot.png')
    plt.savefig(save_path_kde, dpi=plot_dpi, bbox_inches='tight')
    plt.close()

# Generates additional plots
def generate_additional_plots(subject, all_x, all_y, save_path, tank_center_x, tank_center_y, fps, plot_dpi=300):
    # Radial Distribution Plot
    radial_distances = np.sqrt((all_x - tank_center_x)**2 + (all_y - tank_center_y)**2)
    plt.figure(figsize=(8, 6))
    sns.histplot(radial_distances, bins=30, kde=True, color='green')
    plt.title(f'Radial Distribution for {subject}', fontweight='bold', fontsize=14)
    plt.xlabel("Radial Distance (cm)")
    save_path_radial = os.path.join(save_path, f'{subject}_radial_distribution.png')
    plt.savefig(save_path_radial, dpi=plot_dpi, bbox_inches='tight')
    plt.close()

    # Angle of Movement Plot
    angles = np.arctan2(np.diff(all_y), np.diff(all_x))
    plt.figure(figsize=(8, 6))
    sns.histplot(angles, bins=30, kde=True, color='purple')
    plt.title(f'Angle of Movement for {subject}', fontweight='bold', fontsize=14)
    plt.xlabel("Angle (radians)")
    save_path_angle = os.path.join(save_path, f'{subject}_angle_of_movement.png')
    plt.savefig(save_path_angle, dpi=plot_dpi, bbox_inches='tight')
    plt.close()

    # Speed Time Series Plot
    speeds, acceleration = calculate_speed_acceleration(all_x, all_y, time_interval=1/fps)
    time = np.arange(len(speeds)) / fps
    plt.figure(figsize=(8, 6))
    plt.plot(time, speeds, color='blue')
    plt.title(f'Speed Time Series for {subject}', fontweight='bold', fontsize=14)
    plt.ylabel("Speed (cm/s)")
    plt.xlabel("Time (s)")
    save_path_speed = os.path.join(save_path, f'{subject}_speed_timeseries.png')
    plt.savefig(save_path_speed, dpi=plot_dpi, bbox_inches='tight')
    plt.close()

    # Acceleration Time Series Plot
    time = np.arange(len(acceleration)) / fps
    plt.figure(figsize=(8, 6))
    plt.plot(time, acceleration, color='red')
    plt.title(f'Acceleration Time Series for {subject}', fontweight='bold', fontsize=14)
    plt.ylabel("Acceleration (cm/s^2)")
    plt.xlabel("Time (s)")
    save_path_accel = os.path.join(save_path, f'{subject}_acceleration_timeseries.png')
    plt.savefig(save_path_accel, dpi=plot_dpi, bbox_inches='tight')
    plt.close()

    # Polar Plot (Direction of Movement)
    fig, ax = plt.subplots(figsize=(8, 6), subplot_kw={'projection': 'polar'})
    ax.hist(angles, bins=30, color='orange')
    plt.title(f'Polar Plot for {subject}', fontweight='bold', fontsize=14)
    save_path_polar = os.path.join(save_path, f'{subject}_polar_plot.png')
    plt.savefig(save_path_polar, dpi=plot_dpi, bbox_inches='tight')
    plt.close()

    # Occupancy Plot (Time spent in regions)
    plt.figure(figsize=(8, 6))
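    heatmap, xedges, yedges = np.histogram2d(all_x, all_y, bins=(10, 10))
    # Pass the bin edges as the extent so the axes are in cm rather than bin indices
    plt.imshow(heatmap.T, origin='lower', cmap='Blues', interpolation='nearest',
               extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], aspect='auto')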
    plt.title(f'Occupancy Plot for {subject}', fontweight='bold', fontsize=14)
    plt.xlabel("X Position (cm)")
    plt.ylabel("Y Position (cm)")
    save_path_occupancy = os.path.join(save_path, f'{subject}_occupancy_plot.png')
    plt.savefig(save_path_occupancy, dpi=plot_dpi, bbox_inches='tight')
    plt.close()
    
    # Trajectory Smoothness Plot
    smoothness = np.abs(np.diff(angles))
    time = np.arange(len(smoothness)) / fps
    plt.figure(figsize=(8, 6))  # explicit figure, consistent with the other plots
    plt.plot(time, smoothness, color='pink')
    plt.title(f'Trajectory Smoothness for {subject}', fontweight='bold', fontsize=14)
    plt.ylabel("Change in angle (radians)")
    plt.xlabel("Time (s)")
    save_path_smoothness = os.path.join(save_path, f'{subject}_smoothness_plot.png')
    plt.savefig(save_path_smoothness, dpi=plot_dpi, bbox_inches='tight')
    plt.close()


# Generates a trajectory plot overlaying all fish.
def generate_trajectory_plot(df, num_fish, prefix, save_path, plot_dpi=600):
    plt.figure(figsize=(10, 8))
    for i in range(1, num_fish + 1):
        plt.plot(df[f'{prefix}x{i}'], df[f'{prefix}y{i}'], label=f'Fish {i}')
    plt.xlabel("X Position (cm)")
    plt.ylabel("Y Position (cm)")
    plt.title("Overlaid Fish Trajectories", fontweight='bold', fontsize = 14)
    plt.legend()
    plt.gca().set_aspect('equal', adjustable='box')
    save_path_trajectory = os.path.join(save_path, 'overlaid_trajectory.png')
    plt.savefig(save_path_trajectory, dpi=plot_dpi, bbox_inches='tight')
    plt.close()


# Generates a plot of mean pairwise distance with std deviation.
def generate_spread_plot(df, num_fish, prefix, save_path, plot_dpi=600):
    time = df['frame_number'] / fps
    plt.figure(figsize=(10, 6))
    plt.plot(time, df['mean_pairwise_distance'], label='Mean distance')
    plt.fill_between(time, df['mean_pairwise_distance'] - df['std_pairwise_distance'],
                    df['mean_pairwise_distance'] + df['std_pairwise_distance'], alpha=0.4, label='Std')
    plt.legend()
    plt.title("Group pairwise distance", fontweight = 'bold', fontsize = 14)
    plt.xlabel("Time (s)")
    plt.ylabel("Distance (cm)")
    save_path_spread = os.path.join(save_path, 'pairwise_distance.png')
    plt.savefig(save_path_spread, dpi=plot_dpi, bbox_inches='tight')
    plt.close()


In [45]:
def main():
    try:
        df = load_and_prepare_data(file_path, num_fish, tank_width_cm, video_width_pixels, tank_height_cm, video_height_pixels, prefix)
        
        start_frame = int(start_time * fps)
        end_frame = int(end_time * fps)
        df_filtered = df[(df['frame_number'] >= start_frame) & (df['frame_number'] <= end_frame)].copy()
        print(f"Data filtered to {start_time} - {end_time} seconds")

        # Calculate time in halves for every fish.
        # zip(*...) yields two tuples: (fish1_upper, fish2_upper, ...) and (fish1_lower, fish2_lower, ...)
        all_upper_times, all_lower_times = zip(*[
            calculate_time_in_halves_for_fish(df_filtered, i, tank_height_cm, prefix) for i in range(1, num_fish + 1)
        ])

        # Re-structure the results into a dictionary for easy printing and CSV export
        time_in_halves = {}
        for i in range(num_fish):
            fish_name = f'fish{i+1}'
            time_in_halves[fish_name] = {
                'Upper': all_upper_times[i],
                'Lower': all_lower_times[i]
            }
        
        print("Time Spent in Upper/Lower Half (seconds):")
        for fish, times in time_in_halves.items():
            print(f"{fish}: Upper={times['Upper']:.2f}s, Lower={times['Lower']:.2f}s")

        # Calculate and print total distance
        total_distances = calculate_total_distance(df_filtered, num_fish, prefix)
        print("Total distances (cm):")
        # Use `dist` rather than `distance` to avoid shadowing the scipy.spatial.distance import
        for fish, dist in total_distances.items():
            print(f"{fish}: {dist:.2f}")

        # Calculate and print average speed
        average_speeds = calculate_average_speed(df_filtered, num_fish, prefix, fps)
        print("Average speeds (cm/s):")
        for fish, speed in average_speeds.items():
            print(f"{fish}: {speed:.2f}")

        # Output CSVs
        # 1. Time in halves CSV
        time_in_halves_df = pd.DataFrame.from_dict(time_in_halves, orient='index')
        time_in_halves_df.to_csv('time_in_halves.csv')
        print(f"Time in halves data exported to time_in_halves.csv")

        # 2. Total distance and average speed CSV
        distance_speed_df = pd.DataFrame({
            'total_distance': total_distances,
            'average_speed': average_speeds
        })
        distance_speed_df.to_csv('distance_speed.csv')
        print(f"Total distance and average speed data exported to distance_speed.csv")
        
        # 3. Full data with all the calculated variables
        output_file = 'fish_analysis_results.csv'
        df_filtered.to_csv(output_file, index=False)
        print(f"Full data exported to {output_file}")

        # Plotting
        generate_trajectory_plot(df_filtered, num_fish, prefix, save_path, plot_dpi=PLOT_DPI)
        print("Overlaid Trajectory plot generated")
        generate_spread_plot(df_filtered, num_fish, prefix, save_path, plot_dpi=PLOT_DPI)
        print("Group pairwise distance plot generated")

        all_group_x = []
        all_group_y = []
        for i in range(1, num_fish + 1):
            all_group_x.extend(df_filtered[f'{prefix}x{i}'].values)
            all_group_y.extend(df_filtered[f'{prefix}y{i}'].values)
            
        generate_group_kde_plot(all_group_x, all_group_y, save_path, plot_dpi=PLOT_DPI)
        print(f"Group kde plot generated")    
         
        for i in range(1,num_fish+1):
            all_x = df_filtered[f'{prefix}x{i}'].values
            all_y = df_filtered[f'{prefix}y{i}'].values
            subject = f'Fish {i}'
            generate_kde_plot(subject, all_x, all_y, save_path, plot_dpi=PLOT_DPI)
            tank_center_x = df_filtered['tank_center_x'].iloc[0]
            tank_center_y = df_filtered['tank_center_y'].iloc[0]
            generate_additional_plots(subject, all_x, all_y, save_path, tank_center_x, tank_center_y, fps, plot_dpi=PLOT_DPI)
            print(f"Plots generated for {subject}")
        print("All plots generated and saved")
        print("Analysis complete.")
            
    except Exception as e:
         print(f"An error occurred: {e}")

In [46]:
if __name__ == '__main__':
    main()

Data loaded successfully.
Pixel to cm conversion completed.
Interpolation of NaNs complete
Calculations completed.
Data filtered to 0 - 300 seconds
Time Spent in Upper/Lower Half (seconds):
fish1: Upper=181.43s, Lower=118.60s
fish2: Upper=194.13s, Lower=105.90s
fish3: Upper=181.67s, Lower=118.37s
fish4: Upper=156.03s, Lower=144.00s
Total distances (cm):
fish1: 162.66
fish2: 496.82
fish3: 642.23
fish4: 551.55
Average speeds (cm/s):
fish1: 0.54
fish2: 1.66
fish3: 2.14
fish4: 1.84
Time in halves data exported to time_in_halves.csv
Total distance and average speed data exported to distance_speed.csv
Full data exported to fish_analysis_results.csv
Overlaid Trajectory plot generated
Group pairwise distance plot generated
Group kde plot generated
Plots generated for Fish 1
Plots generated for Fish 2
Plots generated for Fish 3
Plots generated for Fish 4
All plots generated and saved
Analysis complete.


## Output Files

The script will generate these files in the same location as the Python script:

1.  **`time_in_halves.csv`**: Contains the time each fish spent in the upper and lower half of the tank, in seconds.
2.  **`distance_speed.csv`**: Contains the total distance and average speed for each fish in cm and cm/s respectively.
3.  **`fish_analysis_results.csv`**: Contains the coordinate data for the selected time window, converted to cm, together with all the derived metrics.
4.  **`plots`**: A folder which contains all the plots generated by the script.
     * **`overlaid_trajectory.png`**: shows the trajectories of all the fish.
     * **`pairwise_distance.png`**: shows the mean pairwise distance between fish over time.
     * **`group_kde_plot.png`**: KDE plot for spatial distribution of all fish combined
     * **`Fish i_kde_plot.png`**: (where i is the number of the fish) shows a KDE plot for the position of fish i
     * **`Fish i_radial_distribution.png`**: (where i is the number of the fish) shows the radial distribution of fish i
     * **`Fish i_angle_of_movement.png`**: (where i is the number of the fish) shows the angle of movement distribution of fish i
     * **`Fish i_speed_timeseries.png`**: (where i is the number of the fish) shows the speed of fish i over time
     * **`Fish i_acceleration_timeseries.png`**: (where i is the number of the fish) shows the acceleration of fish i over time
     * **`Fish i_polar_plot.png`**: (where i is the number of the fish) shows a polar plot of the direction of movement of fish i
     * **`Fish i_occupancy_plot.png`**: (where i is the number of the fish) shows an occupancy plot of where fish i spent most of its time
     * **`Fish i_smoothness_plot.png`**: (where i is the number of the fish) shows the smoothness of trajectory for fish i
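
The exported CSV files can be read back for further analysis. The sketch below is a hedged example rather than part of the script: it writes a small stand-in `time_in_halves.csv` with made-up numbers to a temporary folder, using the same layout as the export (fish names as the index, `Upper` and `Lower` columns), and then derives the fraction of time spent in the upper half. The `upper_fraction` column is a hypothetical addition.

```python
import os
import tempfile
import pandas as pd

# Stand-in for the script's export step, with made-up numbers
demo = pd.DataFrame.from_dict(
    {"fish1": {"Upper": 12.5, "Lower": 7.5}, "fish2": {"Upper": 4.0, "Lower": 16.0}},
    orient="index",
)

with tempfile.TemporaryDirectory() as tmp:
    csv_path = os.path.join(tmp, "time_in_halves.csv")
    demo.to_csv(csv_path)
    # The fish names were written as the index, so read them back as the index
    loaded = pd.read_csv(csv_path, index_col=0)

# Fraction of the recording each fish spent in the upper half
loaded["upper_fraction"] = loaded["Upper"] / (loaded["Upper"] + loaded["Lower"])
print(loaded)
```

With real output, replace the temporary path with the location of the exported file.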