In [1]:
import os
import cv2
import tqdm
import pickle
import numpy as np
from ultralytics import YOLO
import matplotlib.pyplot as plt
from utils.dataset_utils import *
from collections import defaultdict
from deep_sort_realtime.deepsort_tracker import DeepSort

In [2]:
raw_video_folder = r'..\data\raw\pred_prey_interaction'
yolo_path = r'..\models\costumized_yolo\costumized_yolo\costumized_yolo.pt'
processed_video_folder = r'..\data\processed\pred_prey_interactions'

In [3]:
model = YOLO(yolo_path)
tracker = DeepSort(max_age=30)

num_frames=1
total_detections=33

for video in os.listdir(raw_video_folder):

    print(f"\nProcessing {video}...")

    # Load the video
    video_path = os.path.join(raw_video_folder, video)
    cap = cv2.VideoCapture(video_path)
    total_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
    
    # total_frames
    os.makedirs(os.path.join(processed_video_folder, "total_frames"), exist_ok=True)
    tf_path = os.path.join(processed_video_folder, "total_frames", f"total_frames_{video}.pkl")

    if os.path.exists(tf_path):
        with open(tf_path, "rb") as f:
            total_frames = pickle.load(f)
    else:
        total_frames = []
        for frame in tqdm.tqdm(range(total_frame_count), desc="Processing frames"):
            frame_records = process_frame(cap, model, tracker, frame)
            total_frames.extend(frame_records)
        cap.release()
        with open(tf_path, "wb") as f:
            pickle.dump(total_frames, f)


    # filtered_frames
    os.makedirs(os.path.join(processed_video_folder, "filtered_frames"), exist_ok=True)
    ff_path = os.path.join(processed_video_folder, "filtered_frames", f"filtered_frames_{video}.pkl")
    ms_path = os.path.join(processed_video_folder, "filtered_frames", f"max_speed_{video}.pkl")

    if os.path.exists(ff_path) and os.path.exists(ms_path):
        with open(ff_path, "rb") as f:
            filtered_frames = pickle.load(f)

        with open(ms_path, "rb") as f:
            max_speed = pickle.load(f)
    else:
        filtered_frames, max_speed = filter_frames(total_frames)
        with open(ff_path, "wb") as f:
            pickle.dump(filtered_frames, f)

        with open(ms_path, "wb") as f:
            pickle.dump(max_speed, f)


    # full_track_windows
    os.makedirs(os.path.join(processed_video_folder, "full_track_windows"), exist_ok=True)
    os.makedirs(os.path.join(processed_video_folder, "full_track_windows",  f"{total_detections}"), exist_ok=True)
    ftw_path = os.path.join(processed_video_folder, "full_track_windows", f"{total_detections}", f"full_track_windows_{total_detections}_{video}.pkl")
    vw_path = os.path.join(processed_video_folder, "full_track_windows", f"{total_detections}", f"valid_windows_{total_detections}_{video}.pkl")

    if os.path.exists(ftw_path) and os.path.exists(vw_path):
        with open(ftw_path, "rb") as f:
            full_track_windows = pickle.load(f)

        with open(vw_path, "rb") as f:
            valid_windows = pickle.load(f)
    else:
        full_track_windows, valid_windows = find_valid_windows(filtered_frames, num_frames=num_frames, total_detections=total_detections)
        with open(ftw_path, "wb") as f:
            pickle.dump(full_track_windows, f)

        with open(vw_path, "wb") as f:
            pickle.dump(valid_windows, f)

    print(f"Found {len(valid_windows)} windows with {total_detections} continuous detections.")


Processing pred_prey_interaction_0.07.mp4...
Found 22 windows with 33 continuous detections.

Processing pred_prey_interaction_0.14.mp4...
Found 1 windows with 33 continuous detections.

Processing pred_prey_interaction_0.15.mp4...
Found 84 windows with 33 continuous detections.

Processing pred_prey_interaction_0.16.mp4...
Found 6 windows with 33 continuous detections.

Processing pred_prey_interaction_0.17.mp4...
Found 37 windows with 33 continuous detections.

Processing pred_prey_interaction_0.24.mp4...
Found 89 windows with 33 continuous detections.

Processing pred_prey_interaction_0.27.mp4...
Found 92 windows with 33 continuous detections.

Processing pred_prey_interaction_0.36.mp4...
Found 100 windows with 33 continuous detections.

Processing pred_prey_interaction_0.41.mp4...
Found 150 windows with 33 continuous detections.

Processing pred_prey_interaction_1.01.mp4...
Found 121 windows with 33 continuous detections.

Processing pred_prey_interaction_1.07.mp4...
Found 147 win

In [6]:
pred_tensors_list = []
prey_tensors_list = []

os.makedirs(os.path.join(processed_video_folder, "expert_tensors", "with_velocity"), exist_ok=True)
pred_et_path = os.path.join(processed_video_folder, "expert_tensors", "with_velocity", f"pred_tensors_velo_{num_frames}.pkl")
prey_et_path = os.path.join(processed_video_folder, "expert_tensors", "with_velocity", f"prey_tensors_velo_{num_frames}.pkl")

if os.path.exists(pred_et_path) and os.path.exists(prey_et_path):
    with open(pred_et_path, "rb") as f:
        pred_tensors = pickle.load(f)
    with open(prey_et_path, "rb") as f:
        prey_tensors = pickle.load(f)
else:
    for video in os.listdir(raw_video_folder):
        pred_tensor, prey_tensor = get_expert_tensors_velo(full_track_windows, valid_windows, width, height, window_size=num_frames)
        pred_tensors_list.append(pred_tensor)
        prey_tensors_list.append(prey_tensor)

    pred_tensors = torch.cat(pred_tensors_list, dim=0)
    prey_tensors = torch.cat(prey_tensors_list, dim=0)

    with open(pred_et_path, "wb") as f:
        pickle.dump(pred_tensors, f)
    with open(prey_et_path, "wb") as f:
        pickle.dump(prey_tensors, f)

print("Pred Tensors Shape:", tuple(pred_tensors.shape))
print("Prey Tensors Shape:", tuple(prey_tensors.shape))

Pred Tensors Shape: (21420, 1, 32, 6)
Prey Tensors Shape: (21420, 32, 32, 6)
