In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from scipy.spatial import distance
import pickle

# Directory holding the cleaned pose-estimation tracks.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
folder = r'D:\datos_GPetersii\datos_GPetersii\Fish4\Object\Pose Estimation'
# Later cells open files by relative name, so make this the working directory.
os.chdir(folder)
files = glob.glob('*clean.h5')
files.sort()
print(f'hay {len(files)} archivos')

In [None]:
# Load the FB-DOE file.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
pkl_name = 'fish4_FB-DOE.pkl'   # change to the appropriate file name
with open(pkl_name, 'rb') as pkl_file:
    FB_doe = pickle.load(pkl_file)


In [None]:
# Display the first matched track file name (raises IndexError if nothing matched).
files[0]

In [None]:
def get_centroids(track, n_keypoints):
    """Compute one centroid per frame as the median of all keypoint positions.

    Parameters
    ----------
    track : pandas.DataFrame
        Pose table whose columns are a 3-level MultiIndex: level 1 holds the
        body-part names and level 2 contains ``'x'`` and ``'y'``. Only the
        first label found on level 0 is used.
    n_keypoints : int
        Expected number of distinct body parts in ``track``.

    Returns
    -------
    list of [x, y]
        One ``[median_x, median_y]`` pair per row of ``track``, in row order.

    Raises
    ------
    ValueError
        If ``n_keypoints`` differs from the number of distinct body parts.
    """
    top_label = track.columns.get_level_values(0)[0]
    # Computed once; np.unique also fixes the (sorted) keypoint order.
    bodyparts = np.unique(np.array(list(track.columns.get_level_values(1))))
    if len(bodyparts) != n_keypoints:
        raise ValueError(
            f'n_keypoints={n_keypoints} but track contains {len(bodyparts)} body parts'
        )

    def _stack(coord):
        # Positional arrays (not Series) so the row index of `track` is never
        # used for alignment; a non-default index previously produced NaNs.
        return np.column_stack(
            [track[top_label, part, coord].to_numpy() for part in bodyparts]
        )

    # np.median is not NaN-aware: a frame with any NaN coordinate gives NaN.
    median_xposition = np.median(_stack('x'), axis=1)
    median_yposition = np.median(_stack('y'), axis=1)

    return [[x, y] for x, y in zip(median_xposition, median_yposition)]

In [None]:
# `track` was not defined anywhere earlier in the notebook, so a fresh kernel
# failed here with NameError. Load it explicitly from the first cleaned file.
# NOTE(review): assumes the .h5 files are pandas HDF stores holding a single
# table — confirm, and pick the intended file if it is not files[0].
track = pd.read_hdf(files[0])
centroids = get_centroids(track, n_keypoints=6)

In [None]:
def calculate_velocity(centroids, sf, pix_to_cm):
    """Scale frame-to-frame centroid displacement into a velocity-like value.

    Parameters
    ----------
    centroids : sequence of [x, y]
        One position per frame.
    sf : number
        Sampling frequency used to derive the time term.
    pix_to_cm : number
        Factor each Euclidean displacement is multiplied by.

    Returns
    -------
    list
        ``len(centroids) - 1`` values: the distance between consecutive
        centroids, times ``pix_to_cm``, divided by ``len(centroids) / sf``.
    """
    # NOTE(review): the divisor is the duration of the whole recording, not the
    # inter-frame interval (1 / sf), so the output depends on recording length.
    # Other cells compare these values against fixed thresholds, so confirm the
    # intended units before changing the formula.
    total_time = len(centroids) / sf
    consecutive = zip(centroids[1:], centroids[:-1])
    return [
        distance.euclidean(nxt, prev) * pix_to_cm / total_time
        for nxt, prev in consecutive
    ]

In [None]:
# NOTE(review): sf=150 here, while the video cells use 50 fps — confirm which
# rate the tracks were sampled at. pix_to_cm=12 multiplies pixel distances,
# whereas positions are divided by 12 elsewhere — confirm the direction.
velocity = calculate_velocity(centroids, sf=150, pix_to_cm=12)

In [None]:
%matplotlib widget
plt.figure()
plt.plot(velocity)
plt.scatter(range(len(velocity)), velocity)
plt.plot(range(len(velocity)), [1 for x in velocity], c='r')

In [None]:
# Per-frame table: centroid coordinates divided by 12 and passed through
# round(), alongside the velocity value for that frame. zip stops at the
# shorter input, so the last centroid (which has no velocity) is dropped.
x_scaled = [round(point[0] / 12) for point in centroids]
y_scaled = [round(point[1] / 12) for point in centroids]
vel_per_frame = pd.DataFrame(
    list(zip(x_scaled, y_scaled, velocity)),
    columns=['x', 'y', 'v'],
)
vel_per_frame

In [None]:
def plot_map(grid, objCoordinates, cmap, label, filename, vmax=None, vmin=None):
    """Show ``grid`` as an image with a labelled colour bar in a new figure.

    Parameters
    ----------
    grid : array-like
        2-D data passed to ``plt.imshow`` (drawn with ``origin='lower'``).
    objCoordinates, filename
        Currently unused; only the disabled lines at the end refer to them.
    cmap
        Colour map forwarded to ``plt.imshow``.
    label : str
        Text for the colour bar.
    vmax, vmin : optional
        Colour-scale limits forwarded to ``plt.imshow``.
    """
    fig, _ = plt.subplots()  # new figure; becomes pyplot's current target
    image = plt.imshow(grid, origin='lower', cmap=cmap, vmin=vmin, vmax=vmax)
    plt.colorbar(image).set_label(label)
    # Disabled: object marker and SVG export.
    #plt.scatter(objCoordinates[1]/10, objCoordinates[0]/10, s=100, c='k')
    #fig.savefig(filename, format='svg', dpi=1200)

In [None]:
# Indices of the frames whose velocity value exceeds 0.1.
idx_movement = []
for frame, speed in enumerate(velocity):
    if speed > 0.1:
        idx_movement.append(frame)
idx_movement

In [None]:
# Gap, in frames, between consecutive above-threshold frames. Gaps shorter
# than 50 are flagged 1, all others 0 (presumably "same movement episode").
duration_moments = [
    later - earlier
    for later, earlier in zip(idx_movement[1:], idx_movement[:-1])
]
single_moments = [int(gap < 50) for gap in duration_moments]
single_moments

In [None]:
# Group consecutive 1-flags in single_moments into episodes.
# single_moments[i] == 1 links idx_movement[i] and idx_movement[i + 1], so a
# run of flags at positions a..b spans idx_movement[a] .. idx_movement[b + 1].
#
# The previous version seeded `start` with 0 unconditionally, compared the
# first flag against the LAST one (index -1 wraps around), and never closed a
# run that reached the end of the list, so start/end pairs could be misaligned.
start = []
end = []
run_start = None
for i, flag in enumerate(single_moments):
    if flag == 1:
        if run_start is None:
            run_start = i          # a new run opens here
    elif run_start is not None:
        start.append(run_start)
        end.append(i)              # run covered run_start..i-1, so end = (i-1)+1
        run_start = None
if run_start is not None:
    # Run reaches the end of the flags; idx_movement has one more entry than
    # single_moments, so this index is valid.
    start.append(run_start)
    end.append(len(single_moments))

# (first frame, last frame) of every episode, as indices into `velocity`.
frames_moments = [(idx_movement[s], idx_movement[e]) for s, e in zip(start, end)]

In [None]:
# Display the (start frame, end frame) pairs found above.
frames_moments

In [None]:
import seaborn as sns
def plot_trayectories(vel_per_frame, start_frame, end_frame):
    """Plot the x/y trajectory between two frames, coloured and sized by 'v'.

    Parameters
    ----------
    vel_per_frame : pandas.DataFrame
        Table with columns ``'x'``, ``'y'`` and ``'v'``, one row per frame.
    start_frame, end_frame : int
        Inclusive row positions of the segment to draw. Values before the
        first row are clamped to 0; values past the last row are truncated
        by ``iloc``.
    """
    # Clamp at 0: a negative bound would make iloc count from the END of the
    # table (callers pad the window with `start - N`, which can go negative).
    first = max(start_frame, 0)
    stop = max(end_frame + 1, 0)
    data = vel_per_frame.iloc[first:stop, :]

    fig, ax = plt.subplots()
    sns.scatterplot(
        data=data,
        x='x',
        y='y',
        size='v',
        hue='v',
        ax=ax,  # draw on this figure's axes explicitly
    )
    # Faint dashed line joining the points in frame order.
    ax.plot(data['x'], data['y'], c='pink', linewidth=.5, linestyle='--', alpha=.5)
    # Fixed reference marker at (439.2, 207.53) divided by 12.
    # NOTE(review): presumably the object location — confirm.
    ax.scatter(x=439.2/12, y=207.53/12, c='crimson', s=150, marker='*')
    ax.set_ylim([0, 63])
    ax.set_xlim([0, 63])
    plt.show()

In [None]:
import subprocess

def extract_frames(input_video, output_video, start_frame, end_frame, fps,
                   ffmpeg_bin='/opt/local/bin/ffmpeg'):
    """Write frames ``start_frame``..``end_frame`` of a video to a new file via ffmpeg.

    Parameters
    ----------
    input_video : str
        Path of the source video.
    output_video : str
        Path of the clip to create.
    start_frame, end_frame : int
        Frame-number bounds given to ffmpeg's ``between(n, start, end)``.
    fps : number
        Value passed to ffmpeg's ``-r`` output option.
    ffmpeg_bin : str, optional
        ffmpeg executable to run; defaults to the previously hard-coded path.
    """
    # Raw f-string: the backslashes before the commas must reach ffmpeg
    # literally. In a non-raw string '\,' is an invalid escape sequence
    # (a SyntaxWarning on recent Python versions).
    frame_filter = (
        rf"select='between(n\,{start_frame}\,{end_frame})',setpts=PTS-STARTPTS"
    )
    ffmpeg_cmd = [
        ffmpeg_bin,
        '-i', input_video,
        '-vf', frame_filter,
        '-r', str(fps),
        output_video
    ]

    # Argument list (no shell), so the paths need no quoting. The exit status
    # is not checked, as before.
    subprocess.run(ffmpeg_cmd)

In [None]:
# Folder with the raw .avi recordings.
# NOTE(review): absolute, machine-specific path, and it points at Fish1 while
# the pose data folder is Fish4 — confirm this is the intended pairing.
vid_folder = '/Volumes/Expansion/datos_GPetersii/datos_GPetersii/Fish1/Object/raw/50fps/'
videos = glob.glob(vid_folder + '*.avi')
videos.sort()

In [None]:
from random import Random, sample

# Pick up to 5 episodes to inspect. A dedicated, seeded generator keeps the
# selection identical across Restart & Run All without touching the global
# random state; asking for more items than exist would raise ValueError, so
# the sample size is capped at the number of episodes available.
n_examples = min(5, len(frames_moments))
rand_frames_moments = Random(0).sample(frames_moments, n_examples)
rand_frames_moments

In [None]:
# Plot every sampled episode with 500 frames of context on each side.
for moment in rand_frames_moments:
    # Clamp the padded start at 0: a negative position would be read by iloc
    # as an offset from the end of the table and select the wrong rows.
    window_start = max(moment[0] - 500, 0)
    plot_trayectories(vel_per_frame, start_frame=window_start, end_frame=moment[1] + 500)

In [None]:
#vid_folder = '/Volumes/Expansion/datos_GPetersii/datos_GPetersii/Fish1/Object/raw/'
#videos = glob.glob(vid_folder+'*.avi')
#for vid in videos :
# Cut one clip per sampled episode out of the first video.
input_video = videos[0]
# Drop the directory with os.path.basename instead of a hard-coded character
# offset (74 was the exact length of vid_folder), so the name stays correct if
# the folder changes. The last 9 characters of the file name are also removed.
# NOTE(review): presumably a fixed suffix plus '.avi' — confirm.
video_stem = os.path.basename(input_video)[:-9]
fps = 50           # Desired frames per second for output video
for i, moment in enumerate(rand_frames_moments):
    output_video = video_stem + '_moment_' + str(i) + '.avi'
    start_frame = moment[0] - 100 # Start frame index (100 frames of padding)
    end_frame = moment[1] + 100   # End frame index (100 frames of padding)

    extract_frames(input_video, output_video, start_frame, end_frame, fps)
