# Fish Tracking Analysis

This Jupyter notebook analyzes fish tracking data exported as a coordinate file from tracking software. It covers individual fish metrics, group dynamics, and visualization of movement patterns.


## Instructions

1.  **Environment Setup:**
    *   Make sure you have Python 3.6 or later installed.
    *   Install the required libraries using pip:

        ```bash
        pip install pandas numpy matplotlib seaborn scipy
        ```
    *   Place your tracking CSV (for example `trajectory.csv`) in the same directory as this notebook, or set `FILE_PATH` to its location.
2.  **Configuration:**
    *   Review the `Configuration and Input Data` section and replace placeholders with your specific data.
        *   `FRAME_RATE`: Frames per second of your video
        *   `VIDEO_WIDTH_PIXELS`: Width of the video in pixels
        *   `VIDEO_HEIGHT_PIXELS`: Height of the video in pixels
        *   `TANK_WIDTH_CM`: Width of the tank in cm
        *   `TANK_HEIGHT_CM`: Height of the tank in cm
        *   `FILE_PATH`: Path to your CSV file containing fish tracking data
        *   `OUTPUT_CSV_FILE_SUMMARY`: Name for output CSV file containing summary data
        *   `OUTPUT_CSV_FILE_OVER_TIME`: Name for output CSV file containing fish movement data over time
        *   `DENSITY_PLOT_FILE`: Output file path for the density heatmap
        *   `TANK_PLOT_FILE`: Output file path for the tank structure plot
        *   `PLOT_SAVE_PATH`: Output directory for individual fish plots
        *   `START_FRAME` & `END_FRAME` (optional): First and last frame of the range to analyze; set both to `None` to analyze every frame.
        *   `PLOT_DPI`: DPI for saving plots
3.  **Running the Notebook:**
    *   Open this notebook in Jupyter Lab or Jupyter Notebook.
    *   Execute all cells: "Cell" > "Run All" in the classic Notebook interface, or "Run" > "Run All Cells" in JupyterLab.

## Input Data Format

The input CSV file (for example `trajectory.csv`) is expected to follow this layout:

*   Each row represents a single frame of the video.
*   Columns hold the x and y coordinates of each fish: `x1`, `y1`, `x2`, `y2`, ...
*   Columns must appear in that order, one x/y pair per fish, and the number of pairs must match `num_fish` in `main()`.
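
The layout described above can be checked before running the analysis. The following is a minimal sketch, not part of the original notebook: `infer_num_fish` is a hypothetical helper name, and the in-memory CSV holds made-up values that stand in for a real tracking file.

```python
import io

import pandas as pd


def infer_num_fish(columns):
    """Return the number of fish implied by columns named x1, y1, x2, y2, ...

    Raises ValueError if the columns do not follow that exact order.
    """
    columns = list(columns)
    if len(columns) == 0 or len(columns) % 2 != 0:
        raise ValueError("expected an even, non-zero number of coordinate columns")
    num_fish = len(columns) // 2
    expected = [f"{axis}{i}" for i in range(1, num_fish + 1) for axis in ("x", "y")]
    if columns != expected:
        raise ValueError(f"expected columns {expected}, got {columns}")
    return num_fish


# Tiny in-memory CSV standing in for a real tracking file (made-up values).
sample_csv = io.StringIO("x1,y1,x2,y2\n10.0,20.0,30.0,40.0\n11.0,21.0,31.0,41.0\n")
sample_df = pd.read_csv(sample_csv)
num_fish = infer_num_fish(sample_df.columns)
print(num_fish)  # 2
```

A check like this fails early when a file has an odd number of columns or a swapped x/y pair, instead of silently mislabelling coordinates later on.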

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial import distance
import os

## Step 1. Configuration and Input Data
Enter the values relevant to your study.

In [2]:
FRAME_RATE = 30  # Frames per second (replace with your value)
VIDEO_WIDTH_PIXELS = 950  # Width of the video in pixels (replace with your value)
VIDEO_HEIGHT_PIXELS = 575  # Height of the video in pixels (replace with your value)
TANK_WIDTH_CM = 14  # Width of the tank in cm (replace with your value)
TANK_HEIGHT_CM = 8  # Height of the tank in cm (replace with your value)
FILE_PATH = '1.csv'  # Path to your CSV file
OUTPUT_CSV_FILE_SUMMARY = 'individual_fish_summary.csv'  # Output file for stats
OUTPUT_CSV_FILE_OVER_TIME = 'fish_movement_over_time.csv'
DENSITY_PLOT_FILE = 'density_heatmap.png'  # Density plot output path
TANK_PLOT_FILE = 'tank_structure_plot.png'  # Tank plot output path
START_FRAME = 0  # Start frame for time range analysis
END_FRAME = None  # End frame for time range analysis
PLOT_DPI = 600  # DPI for saving plot
PLOT_SAVE_PATH = 'individual_fish_plots' # output path for individual fish plots

## Step 2. Cleaning and Interpolation of Missing Data
Missing values (NaN) in the CSV file are filled by linear interpolation.

In [7]:
# Assuming your data is stored in a CSV file
df = pd.read_csv(FILE_PATH, sep=',')

# Check for NaN values in the dataset
print("Before interpolation:")
print(df.isna().sum())

# Interpolate the NaN values (linear interpolation by default)
df_interpolated = df.interpolate()

# Check if the NaN values have been filled
print("After interpolation:")
print(df_interpolated.isna().sum())

# Save the cleaned data to a new CSV file
df_interpolated.to_csv('trial_interpolated_cleaned.csv', index=False)
FILE_PATH = 'trial_interpolated_cleaned.csv'
print(df_interpolated.head())

Before interpolation:
x1    0
y1    0
x2    0
y2    0
x3    0
y3    0
x4    0
y4    0
dtype: int64
After interpolation:
x1    0
y1    0
x2    0
y2    0
x3    0
y3    0
x4    0
y4    0
dtype: int64
        x1       y1       x2       y2       x3       y3       x4       y4
0  342.829  384.614  923.479  291.459  805.567  225.575  681.106  280.663
1  342.425  384.860  924.132  291.129  805.869  221.234  681.472  281.387
2  343.260  384.392  924.050  290.802  806.627  217.080  682.230  282.216
3  343.167  384.928  924.650  290.514  806.436  215.903  682.617  283.060
4  343.484  385.030  924.692  290.096  806.574  214.784  683.254  283.923
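
The recorded run above had no missing values, so the effect of interpolation is not visible there. The sketch below uses a made-up single-column track to show what the default `interpolate()` call does with gaps at the start, middle, and end, and how the `limit_direction="both"` option changes that. Whether back-filling a leading gap is appropriate depends on the recording, so treat the second call as an option rather than a recommendation.

```python
import numpy as np
import pandas as pd

# Made-up track with a gap at the start, in the middle, and at the end.
track = pd.DataFrame({"x1": [np.nan, 1.0, np.nan, 3.0, np.nan]})

# Same call as in the cell above: fills forward only, so the leading gap remains.
default_fill = track.interpolate()

# Also fills the leading gap, using the first valid value.
both_fill = track.interpolate(limit_direction="both")

print(default_fill["x1"].tolist())
print(both_fill["x1"].tolist())
```

If leading NaN values remain after cleaning, the per-frame distance and speed for those frames will also be NaN in later steps.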


## Step 3. Data Loading and Preparation


In [8]:
def load_and_prepare_data(file_path, num_fish):
    """Loads the coordinate data, renames columns, adds frame number,
       and converts pixel coordinates to cm.
    """
    df = pd.read_csv(file_path, sep=',')  # Reads a comma separated file
    # Rename columns to 'x1','y1','x2','y2' ...
    df.columns = [f'{axis}{i + 1}' for i in range(num_fish) for axis in ['x', 'y']]
    df['frame_number'] = range(len(df))  # create the frame column
    df = convert_pixels_to_cm(df, num_fish)  # convert the coordinates from pixel to cm
    return df

def convert_pixels_to_cm(df, num_fish):
    """Converts pixel coordinates to cm coordinates."""
    pixel_to_cm_x = TANK_WIDTH_CM / VIDEO_WIDTH_PIXELS
    pixel_to_cm_y = TANK_HEIGHT_CM / VIDEO_HEIGHT_PIXELS

    for i in range(1, num_fish + 1):
        df[f'x{i}'] = df[f'x{i}'] * pixel_to_cm_x
        df[f'y{i}'] = df[f'y{i}'] * pixel_to_cm_y
    return df
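
As a worked example of the conversion, the sketch below applies the same per-axis scaling to two points, using the example values from the configuration cell. `pixels_to_cm` is a hypothetical standalone helper, and the calculation assumes the video frame covers exactly the tank area, which the notebook's conversion also implies.

```python
def pixels_to_cm(x_px, y_px, video_size_px=(950, 575), tank_size_cm=(14, 8)):
    """Scale pixel coordinates to centimetres, one factor per axis."""
    scale_x = tank_size_cm[0] / video_size_px[0]  # cm per pixel, horizontal
    scale_y = tank_size_cm[1] / video_size_px[1]  # cm per pixel, vertical
    return x_px * scale_x, y_px * scale_y


# The bottom-right pixel of the frame should land on the far corner of the tank.
corner_cm = pixels_to_cm(950, 575)

# First tracked position of fish 1 from the preview table in Step 2.
fish1_cm = pixels_to_cm(342.829, 384.614)

print(corner_cm, fish1_cm)
```

With these example values the two factors differ slightly (roughly 0.0147 cm per pixel horizontally and 0.0139 vertically), so it may be worth confirming that the configured tank and video dimensions describe the same region.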

## Step 4. Individual Fish Analysis


In [9]:
def calculate_movement_metrics(df, fish_number):
    """Calculates speed, acceleration, and movement angles."""
    x_col = f'x{fish_number}'
    y_col = f'y{fish_number}'
    dx = df[x_col].diff()
    dy = df[y_col].diff()
    df[f'distance{fish_number}'] = np.sqrt(dx ** 2 + dy ** 2)
    df[f'speed{fish_number}'] = df[f'distance{fish_number}'] * FRAME_RATE
    df[f'acceleration{fish_number}'] = df[f'speed{fish_number}'].diff() * FRAME_RATE
    angles_rad = np.arctan2(dy, dx)
    angles_deg = np.degrees(angles_rad)
    df[f'angle{fish_number}'] = angles_deg
    return df

def calculate_polarity(df, fish_number):
    """Calculates the polarity of movement (forward, backward, other)."""
    # Uses an orientation column named e.g. 'orientation1' if one is present
    orientation_col = f'orientation{fish_number}'
    angle_col = f'angle{fish_number}'
    if orientation_col in df.columns:
        # If you have real orientations
        relative_angle = (df[angle_col] - df[orientation_col]) % 360
        df[f'polarity{fish_number}'] = np.select(
            [
                (relative_angle > 315) | (relative_angle < 45),  # Forward motion
                (relative_angle > 135) & (relative_angle < 225)  # Backward motion
            ],
            ["forward", "backward"],
            default="other"
        )
    else:
        # Without orientation data, use the heading itself as an approximation.
        # np.arctan2 yields angles in (-180, 180], so wrap to [0, 360) before
        # comparing against the thresholds below.
        heading = df[angle_col] % 360
        df[f'polarity{fish_number}'] = np.select(
            [
                (heading > 315) | (heading < 45),  # Forward motion
                (heading > 135) & (heading < 225)  # Backward motion
            ],
            ["forward", "backward"],
            default="other"
        )
    return df


def calculate_cumulative_distance(df, fish_number):
    """Calculates the cumulative distance travelled by fish."""
    df[f'cumulative_distance{fish_number}'] = df[f'distance{fish_number}'].cumsum()
    return df
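
To make the units concrete, the sketch below runs the same formulas on a made-up three-point track. The frame rate `FPS` and the positions are assumptions chosen so the arithmetic is easy to follow; the column names are simplified and do not match the notebook's per-fish names.

```python
import numpy as np
import pandas as pd

FPS = 30  # assumed frame rate for this toy example

# Made-up positions in cm: a 3-4-5 step, then a step of 4 cm back along y.
toy = pd.DataFrame({"x": [0.0, 3.0, 3.0], "y": [0.0, 4.0, 0.0]})

dx = toy["x"].diff()
dy = toy["y"].diff()
toy["distance"] = np.sqrt(dx**2 + dy**2)          # cm travelled since the previous frame
toy["speed"] = toy["distance"] * FPS              # cm/s
toy["acceleration"] = toy["speed"].diff() * FPS   # cm/s^2
toy["angle"] = np.degrees(np.arctan2(dy, dx))     # heading in (-180, 180]
toy["angle_360"] = toy["angle"] % 360             # same heading mapped to [0, 360)

print(toy)
```

The first row is NaN for every derived column because there is no previous frame to difference against. The last row has a heading of -90 degrees from `arctan2`, which maps to 270 degrees after wrapping; thresholds written for a 0 to 360 range only behave as intended on the wrapped value.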

## Step 5. Group Analysis

In [10]:
def calculate_group_metrics(df, num_fish):
    """Calculates group metrics like centroid, average distance, spread."""
    x_cols = [f'x{i}' for i in range(1, num_fish + 1)]
    y_cols = [f'y{i}' for i in range(1, num_fish + 1)]

    df['group_centroid_x'] = df[x_cols].mean(axis=1)
    df['group_centroid_y'] = df[y_cols].mean(axis=1)
    # Calculate pairwise distances
    pairwise_distances = []
    for index, row in df.iterrows():
        coordinates = []
        for i in range(1, num_fish + 1):
            coordinates.append((row[f'x{i}'], row[f'y{i}']))
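        # Use each unordered pair once; self-pairs (distance 0) and mirrored
        # duplicates would otherwise bias the mean and standard deviation.
        distances = [distance.euclidean(p1, p2)
                     for a, p1 in enumerate(coordinates)
                     for p2 in coordinates[a + 1:]]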
        pairwise_distances.append(distances)
    df['pairwise_distances'] = pairwise_distances
    df['mean_pairwise_distance'] = df['pairwise_distances'].apply(lambda x: np.mean(x))
    df['std_pairwise_distance'] = df['pairwise_distances'].apply(lambda x: np.std(x))
    return df
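
For a single frame, the pairwise distances can also be computed with SciPy's `pdist`, which returns one value per unordered pair. This is an alternative sketch using made-up positions, not the notebook's own implementation.

```python
import numpy as np
from scipy.spatial.distance import pdist

# Made-up positions (cm) of three fish in a single frame.
frame_positions = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])

# pdist returns one distance per unordered pair: (0,1), (0,2), (1,2).
pair_distances = pdist(frame_positions)

mean_distance = pair_distances.mean()
spread = pair_distances.std()
print(pair_distances, mean_distance, spread)
```

Here the three distances are 5, 10, and 5 cm. Averaging over all nine ordered combinations instead, including each fish paired with itself, would give a noticeably smaller mean, which is why only unique pairs are counted.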


## Step 6. Time Spent Analysis and Stats for Time Range

In [11]:
def calculate_time_spent(df, num_fish, start_frame=None, end_frame=None):
    """Calculates the time spent by each fish in the upper and lower halves of the tank."""
    # Apply each bound independently so a single None leaves that side open
    lo = 0 if start_frame is None else start_frame
    hi = df['frame_number'].max() if end_frame is None else end_frame
    df_subset = df[df['frame_number'].between(lo, hi)]
    stats = []
    for i in range(1, num_fish + 1):
        mid_y = TANK_HEIGHT_CM / 2
        upper_time = len(df_subset[df_subset[f'y{i}'] < mid_y]) / FRAME_RATE
        lower_time = len(df_subset[df_subset[f'y{i}'] >= mid_y]) / FRAME_RATE
        stats.append({'fish_number': i, 'upper_time': upper_time, 'lower_time': lower_time})
    return pd.DataFrame(stats)

def calculate_stats_for_time_range(df, num_fish, start_frame=None, end_frame=None):
    """Calculates average speed, dispersion, cumulative distance for the group"""
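    # Apply each bound independently so a single None leaves that side open.
    # Copy the subset because new columns are added to it below.
    lo = 0 if start_frame is None else start_frame
    hi = df['frame_number'].max() if end_frame is None else end_frame
    df_subset = df[df['frame_number'].between(lo, hi)].copy()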
    # Group average speed calculation
    speed_cols = [f'speed{i}' for i in range(1, num_fish + 1)]
    df_subset['group_speed'] = df_subset[speed_cols].mean(axis=1)
    avg_group_speed = df_subset['group_speed'].mean()
    # Group average spread
    avg_group_dispersion = df_subset['std_pairwise_distance'].mean()
    # Group average cumulative distance
    cum_dist_cols = [f'cumulative_distance{i}' for i in range(1, num_fish + 1)]
    df_subset['sum_cumulative_distance'] = df_subset[cum_dist_cols].sum(axis=1)
    avg_cumulative_distance = df_subset['sum_cumulative_distance'].mean()
    # create stats for the time range
    num_frames = len(df_subset)
    time_in_seconds = num_frames / FRAME_RATE
    stats = {'average_group_speed': avg_group_speed,
             'group_dispersion': avg_group_dispersion,
             'group_cumulative_distance': avg_cumulative_distance,
             'total_frames': num_frames,
             'total_time_seconds': time_in_seconds
             }
    return pd.DataFrame([stats])

def calculate_individual_stats_for_time_range(df, num_fish, start_frame=None, end_frame=None):
    """Calculates average speed, average acceleration, total distance for individual fish"""
    # Apply each bound independently so a single None leaves that side open
    lo = 0 if start_frame is None else start_frame
    hi = df['frame_number'].max() if end_frame is None else end_frame
    df_subset = df[df['frame_number'].between(lo, hi)]

    stats = []
    for i in range(1, num_fish + 1):
        avg_speed = df_subset[f'speed{i}'].mean()
        avg_acceleration = df_subset[f'acceleration{i}'].mean()
        total_distance = df_subset[f'distance{i}'].sum()
        stats.append({'fish_number': i, 'average_speed': avg_speed,
                      'average_acceleration': avg_acceleration, 'total_distance': total_distance})
    return pd.DataFrame(stats)
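
All three functions above start by selecting a frame window. One way to keep that logic in a single place is a small helper like the sketch below; `select_frames` is a hypothetical name, and it treats a bound left as `None` as open on that side.

```python
import pandas as pd


def select_frames(df, start_frame=None, end_frame=None):
    """Return a copy of the rows whose frame_number lies in [start_frame, end_frame].

    A bound left as None is treated as open on that side.
    """
    mask = pd.Series(True, index=df.index)
    if start_frame is not None:
        mask &= df["frame_number"] >= start_frame
    if end_frame is not None:
        mask &= df["frame_number"] <= end_frame
    return df[mask].copy()


frames = pd.DataFrame({"frame_number": range(10)})
print(len(select_frames(frames)))           # no bounds: all 10 rows
print(len(select_frames(frames, 3, None)))  # frames 3..9
print(len(select_frames(frames, 3, 5)))     # frames 3, 4, 5
```

Returning a copy means a caller can add columns to the subset without modifying the full DataFrame.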


## Step 7. Additional Plotting Functions

In [12]:
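def calculate_speed_acceleration(all_x, all_y):
    """Returns per-frame speed (cm/s) and acceleration (cm/s^2) arrays."""
    dx = np.diff(all_x)
    dy = np.diff(all_y)
    distances = np.sqrt(dx**2 + dy**2)
    # Scale by the frame rate so units match calculate_movement_metrics
    speeds = distances * FRAME_RATE
    acceleration = np.diff(speeds) * FRAME_RATE
    return speeds, acceleration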

def generate_kde_plot(fish_num, all_x, all_y, save_path):
    plt.figure(figsize=(8, 6), dpi=PLOT_DPI)
    sns.kdeplot(x=all_x, y=all_y, cmap='YlOrRd', fill=True, cbar=True)
    plt.title(f'KDE Plot for Fish {fish_num}', fontweight='bold', fontsize=14)
    plt.axis('off')
    save_path_kde = os.path.join(save_path, f'fish{fish_num}_kde_plot.png')
    plt.savefig(save_path_kde, dpi=PLOT_DPI, bbox_inches='tight')
    plt.close()

def generate_additional_plots(fish_num, all_x, all_y, save_path):
    # Radial Distribution Plot
    radial_distances = np.sqrt((all_x - TANK_WIDTH_CM/2)**2 + (all_y - TANK_HEIGHT_CM/2)**2)
    plt.figure(figsize=(8, 6), dpi=PLOT_DPI)
    sns.histplot(radial_distances, bins=30, kde=True, color='green')
    plt.title(f'Radial Distribution for Fish {fish_num}', fontweight='bold', fontsize=14)
    plt.axis('off')
    save_path_radial = os.path.join(save_path, f'fish{fish_num}_radial_distribution.png')
    plt.savefig(save_path_radial, dpi=PLOT_DPI, bbox_inches='tight')
    plt.close()

    # Angle of Movement Plot
    angles = np.arctan2(np.diff(all_y), np.diff(all_x))
    plt.figure(figsize=(8, 6), dpi=PLOT_DPI)
    sns.histplot(angles, bins=30, kde=True, color='purple')
    plt.title(f'Angle of Movement for Fish {fish_num}', fontweight='bold', fontsize=14)
    plt.axis('off')
    save_path_angle = os.path.join(save_path, f'fish{fish_num}_angle_of_movement.png')
    plt.savefig(save_path_angle, dpi=PLOT_DPI, bbox_inches='tight')
    plt.close()

    # Speed Time Series Plot
    speeds, acceleration = calculate_speed_acceleration(all_x, all_y)
    plt.figure(figsize=(8, 6), dpi=PLOT_DPI)
    plt.plot(speeds, color='blue')
    plt.title(f'Speed Time Series for Fish {fish_num}', fontweight='bold', fontsize=14)
    plt.axis('off')
    save_path_speed = os.path.join(save_path, f'fish{fish_num}_speed_timeseries.png')
    plt.savefig(save_path_speed, dpi=PLOT_DPI, bbox_inches='tight')
    plt.close()

    # Acceleration Time Series Plot
    plt.figure(figsize=(8, 6), dpi=PLOT_DPI)
    plt.plot(acceleration, color='red')
    plt.title(f'Acceleration Time Series for Fish {fish_num}', fontweight='bold', fontsize=14)
    plt.axis('off')
    save_path_accel = os.path.join(save_path, f'fish{fish_num}_acceleration_timeseries.png')
    plt.savefig(save_path_accel, dpi=PLOT_DPI, bbox_inches='tight')
    plt.close()

    # Polar Plot (Direction of Movement)
    fig, ax = plt.subplots(figsize=(8, 6), dpi=PLOT_DPI, subplot_kw={'projection': 'polar'})
    ax.hist(angles, bins=30, color='orange')
    plt.title(f'Polar Plot for Fish {fish_num}', fontweight='bold', fontsize=14)
    save_path_polar = os.path.join(save_path, f'fish{fish_num}_polar_plot.png')
    plt.savefig(save_path_polar, dpi=PLOT_DPI, bbox_inches='tight')
    plt.close()

    # Occupancy Plot (Time spent in regions)
    plt.figure(figsize=(8, 6), dpi=PLOT_DPI)
    heatmap, xedges, yedges = np.histogram2d(all_x, all_y, bins=(8, 8))
    plt.imshow(heatmap.T, origin='lower', cmap='Blues', interpolation='nearest')
    plt.title(f'Occupancy Plot for Fish {fish_num}', fontweight='bold', fontsize=14)
    plt.axis('off')
    save_path_occupancy = os.path.join(save_path, f'fish{fish_num}_occupancy_plot.png')
    plt.savefig(save_path_occupancy, dpi=PLOT_DPI, bbox_inches='tight')
    plt.close()

    # Trajectory Smoothness Plot
    smoothness = np.abs(np.diff(angles))
    plt.figure(figsize=(8, 6), dpi=PLOT_DPI)
    plt.plot(smoothness, color='pink')
    plt.title(f'Trajectory Smoothness for Fish {fish_num}', fontweight='bold', fontsize=14)
    plt.axis('off')
    save_path_smoothness = os.path.join(save_path, f'fish{fish_num}_smoothness_plot.png')
    plt.savefig(save_path_smoothness, dpi=PLOT_DPI, bbox_inches='tight')
    plt.close()
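
The smoothness plot is based on `np.abs(np.diff(angles))`, where the angles come from `arctan2` and therefore jump when the heading crosses the boundary between +pi and -pi. The sketch below shows one common way to wrap the difference so a small turn across that boundary is not reported as a near-full rotation. `wrapped_angle_diff` is a hypothetical helper and the headings are made up.

```python
import numpy as np


def wrapped_angle_diff(angles_rad):
    """Frame-to-frame heading change in radians, wrapped into [-pi, pi)."""
    raw = np.diff(angles_rad)
    return (raw + np.pi) % (2 * np.pi) - np.pi


# Made-up headings that cross the +/-pi boundary: 170 degrees, then -170 degrees.
headings = np.radians([170.0, -170.0])

raw_change = np.abs(np.diff(headings))                  # about 340 degrees
wrapped_change = np.abs(wrapped_angle_diff(headings))   # about 20 degrees

print(np.degrees(raw_change), np.degrees(wrapped_change))
```

Whether to use the wrapped value in the smoothness plot is a design choice; it mainly matters for fish that often swim close to the negative x direction, where the unwrapped difference produces spikes.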


## Step 8. Main Analysis

In [13]:
def main():
    num_fish = 4
    df = load_and_prepare_data(FILE_PATH, num_fish)

    for i in range(1, num_fish + 1):
        df = calculate_movement_metrics(df, i)
        df = calculate_polarity(df, i)
        df = calculate_cumulative_distance(df, i)

    df = calculate_group_metrics(df, num_fish)

    # Calculate Time Spent in halves
    time_spent_df = calculate_time_spent(df, num_fish, START_FRAME, END_FRAME)

    # Calculate individual stats for time range
    individual_stats_df = calculate_individual_stats_for_time_range(df, num_fish, START_FRAME, END_FRAME)
    # Saving Data to Summary CSV
    # Merge on fish_number so the key column appears only once in the summary
    combined_df = time_spent_df.merge(individual_stats_df, on='fish_number')
    save_data_to_csv(combined_df, OUTPUT_CSV_FILE_SUMMARY)
    # Saving Data to Over Time CSV 
    save_data_to_csv(df, OUTPUT_CSV_FILE_OVER_TIME)
    # Calculate group stats for time range
    group_stats_df = calculate_stats_for_time_range(df, num_fish, START_FRAME, END_FRAME)
    # Plotting Density and Tank Structure 
    plot_density_heatmap(df, num_fish, DENSITY_PLOT_FILE)
    plot_tank_structure(df, num_fish, TANK_PLOT_FILE)
    # Generate additional plots
    os.makedirs(PLOT_SAVE_PATH, exist_ok=True)
    for i in range(1, num_fish + 1):
        # Drop frames where either coordinate is missing so x and y stay aligned
        xy = df[[f'x{i}', f'y{i}']].dropna()
        all_x = xy[f'x{i}']
        all_y = xy[f'y{i}']
        generate_kde_plot(i, all_x, all_y, PLOT_SAVE_PATH)
        generate_additional_plots(i, all_x, all_y, PLOT_SAVE_PATH)

    print("Analysis Complete!")
    print(f"Individual fish stats are saved in {OUTPUT_CSV_FILE_SUMMARY}")
    print(f"Fish movement data over time is saved in {OUTPUT_CSV_FILE_OVER_TIME}")
    print(f"Density plot is saved in {DENSITY_PLOT_FILE}")
    print(f"Tank plot is saved in {TANK_PLOT_FILE}")
    print(f"Individual fish plots are saved in {PLOT_SAVE_PATH}")


# Saving Data to CSV Functions
def save_data_to_csv(df, file_path):
    """Saves the DataFrame with all calculated stats to a CSV file."""
    df.to_csv(file_path, index=False)

# Plotting Functions 

def plot_density_heatmap(df, num_fish, output_path):
    """Generates and saves a 2D density heatmap of fish positions."""
    x_coords = []
    y_coords = []
    for i in range(1, num_fish + 1):
        x_coords.extend(df[f'x{i}'])
        y_coords.extend(df[f'y{i}'])
    plt.figure(figsize=(8, 6), dpi=PLOT_DPI)
    plt.hist2d(x_coords, y_coords, bins=10, cmap='viridis')
    plt.colorbar(label='Density')
    plt.title('Fish Position Density Heatmap')
    plt.xlabel('X Position (cm)')
    plt.ylabel('Y Position (cm)')
    plt.gca().invert_yaxis()
    plt.savefig(output_path, dpi=PLOT_DPI)
    plt.close()

def plot_tank_structure(df, num_fish, output_path):
    """Generates and saves a plot with tank structure, quadrants, and center rectangle."""
    plt.figure(figsize=(8, 6), dpi=PLOT_DPI)
    # Draw the tank rectangle
    plt.plot([0, TANK_WIDTH_CM, TANK_WIDTH_CM, 0, 0], [0, 0, TANK_HEIGHT_CM, TANK_HEIGHT_CM, 0],
             color='black', linestyle='--', label='Tank Boundary')
    # Define quadrant size and offsets based on tank dimensions
    quadrant_width = TANK_WIDTH_CM / 4
    quadrant_height = TANK_HEIGHT_CM / 4
    offset = 0  # Set to zero so it starts from corners
    # Draw Quadrants at each corner
    quadrant_coords = [
        (offset, offset),  # Bottom-left
        (offset, TANK_HEIGHT_CM - quadrant_height),  # Top-left
        (TANK_WIDTH_CM - quadrant_width, TANK_HEIGHT_CM - quadrant_height),  # Top-right
        (TANK_WIDTH_CM - quadrant_width, offset)  # Bottom-right
    ]
    for x, y in quadrant_coords:
        plt.gca().add_patch(plt.Rectangle((x, y), quadrant_width, quadrant_height,
                                         facecolor='lightcoral', alpha=0.4, edgecolor='black',
                                         label='Corner Quadrants' if x == 0 and y == 0 else None))
    # Define center rectangle dimensions
    center_rect_width = TANK_WIDTH_CM / 2
    center_rect_height = TANK_HEIGHT_CM / 2
    center_rect_x = (TANK_WIDTH_CM - center_rect_width) / 2
    center_rect_y = (TANK_HEIGHT_CM - center_rect_height) / 2
    # Draw Center Rectangle
    plt.gca().add_patch(plt.Rectangle((center_rect_x, center_rect_y),
                                        center_rect_width, center_rect_height,
                                        facecolor='lightblue', alpha=0.5,
                                        edgecolor='black', label='Center Rectangle'))
    # plot trajectories
    for i in range(1, num_fish + 1):
        plt.plot(df[f'x{i}'], df[f'y{i}'], alpha=0.4, label=f'Trajectory Fish {i}')
    plt.title("Tank Structure, Quadrants, and Center Rectangle")
    plt.xlabel("X Position (cm)")
    plt.ylabel("Y Position (cm)")
    plt.gca().invert_yaxis()  # Invert y-axis
    plt.legend(loc='upper right', fontsize='small')
    plt.savefig(output_path, dpi=PLOT_DPI)
    plt.close()
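
The tank structure plot draws a center rectangle that is half the tank's width and height, but the notebook does not compute how much time a fish spends inside it. As an illustration only, the sketch below estimates that fraction from positions in cm. `fraction_in_center` is a hypothetical helper, the tank size repeats the example configuration values, and the positions are made up.

```python
import numpy as np

TANK_W, TANK_H = 14.0, 8.0  # example tank size in cm, matching the configuration cell


def fraction_in_center(x_cm, y_cm, tank_w=TANK_W, tank_h=TANK_H):
    """Fraction of samples inside the centred rectangle of half the tank's width and height."""
    x_cm = np.asarray(x_cm, dtype=float)
    y_cm = np.asarray(y_cm, dtype=float)
    inside_x = (x_cm >= tank_w / 4) & (x_cm <= 3 * tank_w / 4)
    inside_y = (y_cm >= tank_h / 4) & (y_cm <= 3 * tank_h / 4)
    return float(np.mean(inside_x & inside_y))


# Made-up positions: two in the centre rectangle, two near the corners.
xs = [7.0, 6.0, 0.5, 13.5]
ys = [4.0, 3.0, 0.5, 7.5]
print(fraction_in_center(xs, ys))  # 0.5
```

Multiplying the fraction by the number of frames and dividing by the frame rate would give the time in seconds, in the same way as the upper and lower half calculation in Step 6.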


## Step 9. Run

In [14]:
if __name__ == '__main__':
    main()

Analysis Complete!
Individual fish stats are saved in individual_fish_summary.csv
Fish movement data over time is saved in fish_movement_over_time.csv
Density plot is saved in density_heatmap.png
Tank plot is saved in tank_structure_plot.png
Individual fish plots are saved in individual_fish_plots
