In [1]:
import os
import cv2
import tqdm
import torch
import pickle
import numpy as np
from ultralytics import YOLO
import matplotlib.pyplot as plt
from utils.dataset_utils import *
from collections import defaultdict
from deep_sort_realtime.deepsort_tracker import DeepSort

In [2]:
# Project paths, all relative to the notebook's working directory.
# They are assembled with os.path.join so the separator matches the host OS;
# on Windows this yields exactly the same strings as the previous
# backslash-literal versions.
yolo_path = os.path.join("..", "models", "costumized_yolo", "costumized_yolo", "costumized_yolo.pt")  # YOLO weights file
raw_video_folder = os.path.join("..", "data", "1. Data Processing", "raw", "pred_prey_interaction")  # input videos
processed_video_folder = os.path.join("..", "data", "1. Data Processing", "processed", "video")  # per-video caches

In [7]:
# Build predator / prey expert tensors from every raw video.
#
# Per video the cell:
#   1. runs process_frame on every frame (or loads the pickled result),
#   2. runs filter_frames on those records (or loads the pickled result),
#   3. looks for valid windows and converts them with get_expert_tensors.
# The per-video tensors are then concatenated along dim 0.
#
# Names defined here and used by later cells: pred_tensor, prey_tensor,
# coordinates, window_len.

num_frames = 1            # number of consecutive frames
total_detections = 33     # number of total detections in frame
window_len = 10           # length of extracted windows

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = YOLO(yolo_path)
model.to(device)
# NOTE(review): a single tracker instance is passed to process_frame for every
# video, so any state it keeps carries over from one video to the next.
# Confirm whether it should be reset / re-created per video.
tracker = DeepSort(max_age=30)

pred_tensors_all = []
prey_tensors_all = []
total_coordinates = []

# The cache folders do not depend on the video, so create them once.
total_frames_dir = os.path.join(processed_video_folder, "1. total_frames")
filtered_frames_dir = os.path.join(processed_video_folder, "2. filtered_frames")
os.makedirs(total_frames_dir, exist_ok=True)
os.makedirs(filtered_frames_dir, exist_ok=True)

# sorted() makes the processing order (and therefore the row order of the
# concatenated tensors) independent of the file system's listing order.
for video in sorted(os.listdir(raw_video_folder)):

    print(f"\nProcessing {video}...")

    # Load the video
    video_path = os.path.join(raw_video_folder, video)
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            # Stray non-video files or unreadable videos would otherwise
            # report a frame count and dimensions of 0.
            print(f"Could not open {video}; skipping.")
            continue

        total_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # total_frames
        tf_path = os.path.join(total_frames_dir, f"total_frames_{video}.pkl")

        if os.path.exists(tf_path):
            # pickle.load can execute arbitrary code: only load caches that
            # were written by this notebook.
            with open(tf_path, "rb") as f:
                total_frames = pickle.load(f)
        else:
            total_frames = []
            for frame in tqdm.tqdm(range(total_frame_count), desc="Processing frames"):
                frame_records = process_frame(cap, model, tracker, frame)
                total_frames.extend(frame_records)
            with open(tf_path, "wb") as f:
                pickle.dump(total_frames, f)
    finally:
        # Release the capture on every path (cache hit, cache miss, skip or
        # exception); previously it was only released on a cache miss.
        cap.release()

    # filtered_frames
    ff_path = os.path.join(filtered_frames_dir, f"filtered_frames_{video}.pkl")
    ms_path = os.path.join(filtered_frames_dir, f"max_speed_{video}.pkl")

    if os.path.exists(ff_path) and os.path.exists(ms_path):
        with open(ff_path, "rb") as f:
            filtered_frames = pickle.load(f)

        with open(ms_path, "rb") as f:
            max_speed = pickle.load(f)
    else:
        filtered_frames, max_speed = filter_frames(total_frames)
        with open(ff_path, "wb") as f:
            pickle.dump(filtered_frames, f)

        with open(ms_path, "wb") as f:
            pickle.dump(max_speed, f)

    # Use the configured detection count instead of a second hard-coded 33,
    # so changing the setting above actually takes effect.
    valid_episodes = find_valid_windows(filtered_frames, window_len=window_len, total_detections=total_detections)

    if not valid_episodes:
        print("No valid episodes found.")
        continue

    extracted_windows = extract_windows(valid_episodes, window_len=window_len)
    print(f"Extracted {len(extracted_windows)} windows with length {window_len}.")

    # NOTE(review): the per-video max_speed computed / loaded above is not
    # used; a fixed max_speed=10 is passed instead. Confirm this is intended.
    pred, prey, video_coordinates = get_expert_tensors(
        filtered_frames, extracted_windows, width, height,
        max_speed=10, window_size=window_len,
    )
    pred_tensors_all.append(pred)
    prey_tensors_all.append(prey)
    total_coordinates.append(video_coordinates)

if not pred_tensors_all:
    # torch.cat fails on an empty list; report the actual cause instead.
    raise RuntimeError(
        f"No valid windows of length {window_len} were found in any video "
        f"under {raw_video_folder!r}; nothing to concatenate."
    )

pred_tensor = torch.cat(pred_tensors_all, dim=0)
prey_tensor = torch.cat(prey_tensors_all, dim=0)

# Prepend a one-channel flag to the last axis of the 5-D prey tensor.
n, window, agents, neighs, feature = prey_tensor.shape
flag = torch.zeros((n, window, agents, neighs, 1), dtype=prey_tensor.dtype, device=prey_tensor.device)
flag[:, :, :, 0, 0] = 1   # predator always first neighbor
prey_tensor = torch.cat([flag, prey_tensor], dim=-1)

# Merge the first two axes of the 4-D coordinate tensor into one.
coordinates_all = torch.cat(total_coordinates, dim=0)
n, window, agents, coord_dim = coordinates_all.shape
coordinates = coordinates_all.reshape(n * window, agents, coord_dim)

print(f"\nPredator Tensor: {pred_tensor.shape}")
print(f"Prey Tensor: {prey_tensor.shape}")


Processing pred_prey_interaction_0.07.mp4...
No valid episodes found.

Processing pred_prey_interaction_0.14.mp4...
No valid episodes found.

Processing pred_prey_interaction_0.15.mp4...
No valid episodes found.

Processing pred_prey_interaction_0.16.mp4...
No valid episodes found.

Processing pred_prey_interaction_0.17.mp4...
No valid episodes found.

Processing pred_prey_interaction_0.24.mp4...
No valid episodes found.

Processing pred_prey_interaction_0.27.mp4...
Extracted 6 windows with length 5.

Processing pred_prey_interaction_0.36.mp4...
Extracted 1 windows with length 5.

Processing pred_prey_interaction_0.41.mp4...
Extracted 1 windows with length 5.

Processing pred_prey_interaction_1.01.mp4...
Extracted 1 windows with length 5.

Processing pred_prey_interaction_1.07.mp4...
Extracted 2 windows with length 5.

Processing pred_prey_interaction_1.09.mp4...
No valid episodes found.

Processing pred_prey_interaction_1.11.mp4...
Extracted 7 windows with length 5.

Processing pred_

In [8]:
# Save the expert tensors and the initial-position pool to disk.
# Paths are assembled with os.path.join so the separator matches the host OS;
# on Windows the resulting strings are the same as the previous literals.
window_path = os.path.join("..", "data", "1. Data Processing", "processed", "video", "expert_tensors", "windows")
window_folder = os.path.join(window_path, f"{window_len} windows")
os.makedirs(window_folder, exist_ok=True)

# File names record the window length and the size of the tensor's first dimension.
pred_file = os.path.join(window_folder, f"pred_tensor_w{window_len}_n{len(pred_tensor)}.pt")
prey_file = os.path.join(window_folder, f"prey_tensor_w{window_len}_n{len(prey_tensor)}.pt")

torch.save(pred_tensor, pred_file)
torch.save(prey_tensor, prey_file)

init_pool_path = os.path.join("..", "data", "1. Data Processing", "processed", "init_pool", "init_pool.pt")
# Create the target folder first, as is done for the windows folder above;
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(init_pool_path), exist_ok=True)
torch.save(coordinates, init_pool_path)