In [1]:
import os
import cv2
import tqdm
import pickle
import numpy as np
from ultralytics import YOLO
import matplotlib.pyplot as plt
from collections import defaultdict
from deep_sort_realtime.deepsort_tracker import DeepSort

import torch
import math
from marl_aquarium.env.utils import scale

In [2]:
def process_frame(cap, model, tracker, frame_idx, device="cpu"):
    ret, frame = cap.read()
    if not ret or frame is None:
        return []
    
    height, width = frame.shape[:2]
    result = model(frame, verbose=False, device=device)[0]

    xywh = result.boxes.xywh.cpu().numpy()
    confs = result.boxes.conf.cpu().numpy()
    cls_ids = result.boxes.cls.cpu().numpy().astype(int)

    raw_detections = list(zip(xywh, confs, cls_ids))
    tracks = tracker.update_tracks(raw_detections, frame=frame)

    records = []
    for track in tracks:
        if not track.is_confirmed():
            continue

        x_raw = track.mean[0]
        y_raw = track.mean[1]
        x = float(np.clip(x_raw, 0, width))
        y = float(np.clip(y_raw, 0, height))

        records.append({"frame":    int(frame_idx),
                        "track_id": int(track.track_id),
                        "label":    str(track.det_class),
                        "conf":     track.det_conf,
                        "x":        x,
                        "y":        y,
                        "vx":       float(track.mean[4]),
                        "vy":       float(track.mean[5]),
                        "speed":    abs(float(math.hypot(track.mean[4], track.mean[5]))), #vx, vy
                        "angle":    float(math.atan2(track.mean[5], track.mean[4]))}) #vy, vx = radians
        
    return records


def filter_frames(total_frames):
    # Use only necessary detections of Pred Head and Prey
    pred_prey_frames = [frame for frame in total_frames if frame['label'] in ('2')] #Prey 2

    # Filter detections with None confidence
    filtered_conf = [frame for frame in pred_prey_frames if frame['conf'] is not None]

    # Drop multiple Pred Detections
    best_pred_label = {}
    preys = []

    for data in filtered_conf:
        frame = data['frame']
        if data['label'] == '1': #Pred Head 1, Prey 2
            if frame not in best_pred_label or data['conf'] > best_pred_label[frame]['conf']:
                best_pred_label[frame] = data
        else:
            preys.append(data)

    best_pred_prey_frames = list(best_pred_label.values()) + preys

    all_speeds = [det["speed"] for det in best_pred_prey_frames]
    max_speed = max(all_speeds)

    return best_pred_prey_frames, max_speed


def find_valid_windows(filtered_frames, num_frames=9, total_detections=32):
    det_by_frame = defaultdict(list)
    for d in filtered_frames:
        det_by_frame[d['frame']].append(d['track_id'])
    
    all_frames = sorted(det_by_frame.keys())
    
    valid_windows =[]
    for start in range(all_frames[0], all_frames[-1] - num_frames + 1):
        track_ids = set(det_by_frame.get(start, [])) #get track_ids for the start frame
        
        # update track_ids for the next num_frames
        for f in range(start + 1, start + num_frames):
            frame_ids = set(det_by_frame.get(f, []))
            track_ids &= frame_ids
            if not track_ids:
                break #if no track_ids left, break early
        
        if len(list(track_ids)) == total_detections:
            valid_windows.append({"start_frame": start, "ids": list(track_ids)}) #save start frame of the valid window
    
    full_track_windows = []
    for window in valid_windows:
        start = window['start_frame']
        ids = set(window['ids'])
        frames = set(range(start, start + num_frames))

        window_data = [data for data in filtered_frames if data['frame'] in frames and data['track_id'] in ids]
        full_track_windows.append(window_data)
        
    return full_track_windows, valid_windows


def get_expert_features(frame, width, height, max_speed=15):
    vscale = np.vectorize(scale)

    xs = np.array([det['x'] for det in frame])
    ys = np.array([det['y'] for det in frame])

    clipped_xs = np.clip(xs, 0, width)
    clipped_ys = np.clip(ys, 0, height)

    scaled_xs = vscale(clipped_xs, 0, width, 0, 1)
    scaled_ys = vscale(clipped_ys, 0, height, 0, 1)

    vxs = np.array([det['vx'] for det in frame])
    vys = np.array([det['vy'] for det in frame])

    thetas = np.array([det['angle'] for det in frame])
    scaled_thetas = vscale(thetas, -np.pi, np.pi, -1, 1)

    cos_t = np.cos(thetas)
    sin_t = np.sin(thetas)

    dx = scaled_xs[None, :] - scaled_xs[:, None]
    dy = scaled_ys[None, :] - scaled_ys[:, None]

    rel_vx = cos_t[:, None] * vxs[None, :] + sin_t[:, None] * vys[None, :]
    rel_vy = -sin_t[:, None] * vxs[None, :] + cos_t[:, None] * vys[None, :]

    scaled_rel_vx = np.clip(rel_vx, -max_speed, max_speed) / max_speed
    scaled_rel_vy = np.clip(rel_vy, -max_speed, max_speed) / max_speed

    n = scaled_xs.shape[0]
    thetas_mat = np.tile(scaled_thetas[:, None], (1, n))
    features = np.stack([dx, dy, scaled_rel_vx, scaled_rel_vy, thetas_mat], axis=-1)

    mask = ~np.eye(n, dtype=bool)
    neigh = features[mask].reshape(n, n-1, 5)   # (N, N-1, 5)

    # Prey-only: kein echter Predator -> pred_tensor leer
    pred_tensor = torch.empty(0)                # oder None, je nach weiterem Code
    prey_tensor = torch.from_numpy(neigh)       # (N, N-1, 5) für alle Preys

    return pred_tensor, prey_tensor, scaled_xs, scaled_ys, vxs, vys



def get_expert_tensors(full_track_windows, valid_windows, width, height, max_speed=15, window_size=1):
    if len(valid_windows) == 0:
        empty_metrics = {"xs": [], "ys": [], "vxs": [], "vys": []}
        return torch.empty(0), torch.empty(0), empty_metrics
    
    start_frames = [vw['start_frame'] for vw in valid_windows]
    pred_windows = []
    prey_windows = []
    expert_metrics = {"xs": [], "ys": [], "vxs": [], "vys": []}

    for idx, start in enumerate(start_frames):
        window_detections = []
        for frame in range(start, start + window_size):
            dets = [det for det in full_track_windows[idx] if det['frame'] == frame]
            window_detections.append(dets)

        preds = []
        preys = []

        xs_all = []
        ys_all = []
        vxs_all = []
        vys_all = []

        for dets in window_detections:
            pred_tensor, prey_tensor, xs_frame, ys_frame, vxs_frame, vys_frame = get_expert_features(dets, width, height, max_speed)

            preds.append(pred_tensor)
            preys.append(prey_tensor)

            xs_all.extend(xs_frame)
            ys_all.extend(ys_frame)
            vxs_all.extend(vxs_frame)
            vys_all.extend(vys_frame)

        pred_windows.append(torch.stack(preds, dim=0))
        prey_windows.append(torch.stack(preys, dim=0))

        expert_metrics["xs"].extend(xs_all)
        expert_metrics["ys"].extend(ys_all)
        expert_metrics["vxs"].extend(vxs_all)
        expert_metrics["vys"].extend(vys_all)

    pred_tensor = torch.stack(pred_windows, dim=0)
    prey_tensor = torch.stack(prey_windows, dim=0)
    print("Expert Tensors Shapes - Pred:", tuple(pred_tensor.shape), " Prey:", tuple(prey_tensor.shape))

    #total, n_clips, agent, neigh, feat = pred_tensor.shape
    #pred_tensors = pred_tensor.reshape(total * n_clips, agent, neigh, feat)

    total, n_clips, agent, neigh, feat = prey_tensor.shape
    prey_tensors = prey_tensor.reshape(total * n_clips, agent, neigh, feat)

    return pred_tensor, prey_tensors, expert_metrics

In [3]:
yolo_path = rf'..\models\costumized_yolo\costumized_yolo\costumized_yolo.pt' 
raw_video_folder = rf'..\data\1. Data Processing\raw\prey_only' 
processed_video_folder = rf'..\data\1. Data Processing\processed\prey_only'

In [4]:
num_frames=1            # number of consecutive frames
total_detections=32     # number of total detections in frame

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = YOLO(yolo_path)
model.to(device)
tracker = DeepSort(max_age=30)

for video in os.listdir(raw_video_folder):

    print(f"\nProcessing {video}...")

    # Load the video
    video_path = os.path.join(raw_video_folder, video)
    cap = cv2.VideoCapture(video_path)
    total_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
    
    # total_frames
    os.makedirs(os.path.join(processed_video_folder, "1. total_frames"), exist_ok=True)
    tf_path = os.path.join(processed_video_folder, "1. total_frames", f"total_frames_{video}.pkl")

    if os.path.exists(tf_path):
        with open(tf_path, "rb") as f:
            total_frames = pickle.load(f)
    else:
        total_frames = []
        for frame in tqdm.tqdm(range(total_frame_count), desc="Processing frames"):
            frame_records = process_frame(cap, model, tracker, frame, device=device)
            total_frames.extend(frame_records)
        cap.release()
        with open(tf_path, "wb") as f:
            pickle.dump(total_frames, f)


    # filtered_frames
    os.makedirs(os.path.join(processed_video_folder, "2. filtered_frames"), exist_ok=True)
    ff_path = os.path.join(processed_video_folder, "2. filtered_frames", f"filtered_frames_{video}.pkl")
    ms_path = os.path.join(processed_video_folder, "2. filtered_frames", f"max_speed_{video}.pkl")

    if os.path.exists(ff_path) and os.path.exists(ms_path):
        with open(ff_path, "rb") as f:
            filtered_frames = pickle.load(f)

        with open(ms_path, "rb") as f:
            max_speed = pickle.load(f)
    else:
        filtered_frames, max_speed = filter_frames(total_frames)
        with open(ff_path, "wb") as f:
            pickle.dump(filtered_frames, f)

        with open(ms_path, "wb") as f:
            pickle.dump(max_speed, f)


    # full_track_windows
    os.makedirs(os.path.join(processed_video_folder, "3. full_track_windows"), exist_ok=True)
    ftw_path = os.path.join(processed_video_folder, "3. full_track_windows", f"full_track_windows_{video}.pkl")
    vw_path = os.path.join(processed_video_folder, "3. full_track_windows", f"valid_windows_{video}.pkl")

    if os.path.exists(ftw_path) and os.path.exists(vw_path):
        with open(ftw_path, "rb") as f:
            full_track_windows = pickle.load(f)

        with open(vw_path, "rb") as f:
            valid_windows = pickle.load(f)
    else:
        full_track_windows, valid_windows = find_valid_windows(filtered_frames, num_frames=num_frames, total_detections=total_detections)
        with open(ftw_path, "wb") as f:
            pickle.dump(full_track_windows, f)

        with open(vw_path, "wb") as f:
            pickle.dump(valid_windows, f)

    print(f"Found {len(valid_windows)} windows with {total_detections} continuous detections.")


Processing prey_only_1.1_4.30.mp4...
Found 1044 windows with 32 continuous detections.

Processing prey_only_1.2_19.21.mp4...
Found 5832 windows with 32 continuous detections.

Processing prey_only_1.3_19.00.mp4...


Processing frames: 100%|██████████| 34202/34202 [3:13:02<00:00,  2.95it/s]  


Found 6276 windows with 32 continuous detections.

Processing prey_only_2.1_19.45.mp4...


Processing frames: 100%|██████████| 35565/35565 [3:13:51<00:00,  3.06it/s]  


Found 5726 windows with 32 continuous detections.

Processing prey_only_2.2_17.59.mp4...
Found 7011 windows with 32 continuous detections.

Processing prey_only_2.3_17.59.mp4...
Found 1714 windows with 32 continuous detections.

Processing prey_only_2.4_21.00.mp4...
Found 4528 windows with 32 continuous detections.

Processing prey_only_3.1_18.59.mp4...
Found 7610 windows with 32 continuous detections.

Processing prey_only_3.2_17.44.mp4...
Found 6238 windows with 32 continuous detections.

Processing prey_only_3.3_19.45.mp4...
Found 5219 windows with 32 continuous detections.

Processing prey_only_3.4_19.30.mp4...


Processing frames: 100%|██████████| 35123/35123 [3:16:21<00:00,  2.98it/s]  


Found 6356 windows with 32 continuous detections.

Processing prey_only_4.1_18.31.mp4...


Processing frames: 100%|██████████| 33358/33358 [3:01:39<00:00,  3.06it/s]  


Found 4798 windows with 32 continuous detections.

Processing prey_only_4.2_18.12.mp4...


Processing frames: 100%|██████████| 32788/32788 [2:55:32<00:00,  3.11it/s]  


Found 6128 windows with 32 continuous detections.


In [5]:
pred_tensors_list = []
prey_tensors_list = []
expert_metrics = {}

os.makedirs(os.path.join(processed_video_folder, "expert_tensors", "yolo_detected"), exist_ok=True)
pred_et_path = os.path.join(processed_video_folder, "expert_tensors", "yolo_detected", f"pred_tensors_yd.pkl")
prey_et_path = os.path.join(processed_video_folder, "expert_tensors", "yolo_detected", f"prey_tensors_yd.pkl")
metrics_path = os.path.join(processed_video_folder, "expert_tensors", "yolo_detected", f"expert_metrics_yd.pkl")

if os.path.exists(pred_et_path) and os.path.exists(prey_et_path) and os.path.exists(metrics_path):
    with open(pred_et_path, "rb") as f:
        pred_tensors = pickle.load(f)
    with open(prey_et_path, "rb") as f:
        prey_tensors = pickle.load(f)

    with open(metrics_path, "rb") as f:
        expert_metrics = pickle.load(f)
else:
    for video in os.listdir(raw_video_folder):
        pred_tensor, prey_tensor, video_metrics = get_expert_tensors(full_track_windows, valid_windows, width, height, window_size=num_frames)
        pred_tensors_list.append(pred_tensor)
        prey_tensors_list.append(prey_tensor)
        expert_metrics[video] = video_metrics

    pred_tensors = torch.cat(pred_tensors_list, dim=0)
    prey_tensors = torch.cat(prey_tensors_list, dim=0)

    with open(pred_et_path, "wb") as f:
        pickle.dump(pred_tensors, f)
    with open(prey_et_path, "wb") as f:
        pickle.dump(prey_tensors, f)
    with open(metrics_path, "wb") as f:
        pickle.dump(expert_metrics, f)

print("Pred Tensors Shape:", tuple(pred_tensors.shape))
print("Prey Tensors Shape:", tuple(prey_tensors.shape))
prey_tensors

Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Expert Tensors Shapes - Pred: (6128, 1, 0)  Prey: (6128, 1, 32, 31, 5)
Pred Tensors Shape: (79664, 1, 0)
Prey Tensors Shape: (79664, 32, 31, 5)


tensor([[[[ 1.3788e-02, -7.6601e-02,  1.2871e-01,  2.1498e-01,  8.0007e-02],
          [ 6.3807e-03, -2.5207e-01,  3.0846e-02,  3.0431e-01,  8.0007e-02],
          [ 1.0608e-02, -4.5453e-01, -1.8516e-02,  2.1503e-01,  8.0007e-02],
          ...,
          [ 1.8957e-02, -5.3434e-02,  1.1009e-01,  9.9437e-02,  8.0007e-02],
          [-3.7973e-02, -3.2250e-02,  1.6609e-02,  2.8559e-02,  8.0007e-02],
          [-8.3908e-02, -1.0670e-01, -5.9113e-03, -3.9020e-03,  8.0007e-02]],

         [[-1.3788e-02,  7.6601e-02,  1.7229e-01, -2.8777e-01,  4.0828e-01],
          [-7.4068e-03, -1.7547e-01,  2.7693e-01,  1.2986e-01,  4.0828e-01],
          [-3.1795e-03, -3.7793e-01,  1.7498e-01,  1.2635e-01,  4.0828e-01],
          ...,
          [ 5.1697e-03,  2.3166e-02,  1.4186e-01, -4.3371e-02,  4.0828e-01],
          [-5.1761e-02,  4.4351e-02,  3.3035e-02,  4.1984e-04,  4.0828e-01],
          [-9.7695e-02, -3.0096e-02, -6.3844e-03,  3.0673e-03,  4.0828e-01]],

         [[-6.3807e-03,  2.5207e-01,  3.38