In [None]:
import numpy as np
import pandas as pd
import tracktor as tr
import cv2
import sys
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

In [None]:
# This function takes the parameters below and produces three outputs:
# 1) A plot showing the distance between two shrimp over time
# 2) A printed summary of when the two shrimp are closer than `threshold` (first and last such time)
# 3) A dataframe with every timestamp and distance, for further manipulation


# Parameters defined below
# dataframe - will be the raw outputted dataframe from tracktor (usually called df)
# fps - frames per second of the video (usually 60)
# id1 - the id of the first shrimp you want to track (usually 1, 2, 3, etc...)
# id2 - the id of the second shrimp you want to track how close it gets to the first
# threshold - the maximum distance between two animals where their interaction is deemed 'important'

def _track_positions(subset, animal_id, frames):
    """Return one animal's ``pos_x``/``pos_y`` indexed by ``frames`` (NaN where absent)."""
    track = subset.loc[subset["id"] == animal_id, ["frame", "pos_x", "pos_y"]]
    # Keep the first detection in each frame, i.e. the row the original
    # implementation picked with ``.values[0]``.
    track = track.drop_duplicates("frame").set_index("frame")
    return track.reindex(frames)


def dist_between(dataframe, fps, id1, id2, threshold,
                 savepath='imgs/ex3_fig2.eps', plot=True):
    """Compute, plot and summarise the distance between two tracked animals.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Tracking output with at least the columns ``frame``, ``id``,
        ``pos_x`` and ``pos_y``. It is not modified.
    fps : float
        Frames per second of the video; frame numbers are divided by it to
        get timestamps in seconds.
    id1, id2 : int
        Identities of the two animals to compare.
    threshold : float
        Distances strictly below this value (same units as ``pos_x``/``pos_y``)
        count as "close".
    savepath : str or None, optional
        File the plot is saved to; the format is inferred from the extension
        by matplotlib. ``None`` skips saving. The default keeps the original
        location, whose ``imgs`` directory must already exist.
    plot : bool, optional
        When False nothing is drawn, saved or shown.

    Returns
    -------
    pandas.DataFrame
        Columns ``Timestamp`` (seconds) and ``Distance``, one row per frame in
        which either animal appears, in ascending frame order. Frames where
        only one of the two animals was detected get a NaN distance.
    """
    # Only the rows of the two animals of interest are needed.
    subset = dataframe[dataframe["id"].isin([id1, id2])]
    frames = np.unique(subset["frame"])

    # Align both tracks on the same frames so the subtraction is element-wise;
    # this avoids re-filtering the whole dataframe once per frame.
    first = _track_positions(subset, id1, frames)
    second = _track_positions(subset, id2, frames)
    distances = np.hypot(
        first["pos_x"].to_numpy(dtype=float) - second["pos_x"].to_numpy(dtype=float),
        first["pos_y"].to_numpy(dtype=float) - second["pos_y"].to_numpy(dtype=float),
    )

    # Timestamps make it easy to locate the events in the video.
    dist_df = pd.DataFrame({"Timestamp": frames / fps, "Distance": distances})

    if plot:
        plt.scatter(dist_df["Timestamp"], dist_df["Distance"],
                    c='#32CD32', s=5, alpha=0.5)
        plt.xlabel('Time (sec)', fontsize=16)
        plt.ylabel('Distance Between Zebra ' + str(id1) + ' and ' + str(id2), fontsize=16)
        plt.tight_layout()
        if savepath is not None:
            plt.savefig(savepath, dpi=300)
        plt.show()

    # Report the first and last moments the animals are within ``threshold``.
    # NaN distances compare False, so frames with a missing animal are ignored.
    close_times = dist_df.loc[dist_df["Distance"] < threshold, "Timestamp"]
    if close_times.empty:
        print('Zebras ' + str(id1) + ' and ' + str(id2)
              + ' never get closer than ' + str(threshold))
    else:
        print('Zebras ' + str(id1) + ' and ' + str(id2) + " get close from "
              + str(close_times.min()) + " seconds to "
              + str(close_times.max()) + " seconds")

    # Returned for further manipulation by the caller.
    return dist_df