In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import os
import ast
import torch
import pickle
import keyboard
import numpy as np
import pandas as pd
from utils.env_utils import *
from utils.eval_utils import *
import matplotlib.pyplot as plt
from utils.train_utils import *
from utils.couzin_utils import *
from marl_aquarium import aquarium_v0
from models.Buffer import Pool, Buffer
from sklearn.preprocessing import MinMaxScaler

In [2]:
# Directory holding the pre-processed expert (YOLO-detected) feature tensors.
# Built with os.path.join instead of a backslash literal so the path also resolves on
# POSIX systems; on Windows it produces exactly the same string as before.
expert_tensors_path = os.path.join(
    "..", "data", "1. Data Processing", "processed", "video", "expert_tensors", "yolo_detected"
)

# NOTE: pickle.load can execute arbitrary code -- only load these files from a trusted source.
with open(os.path.join(expert_tensors_path, "pred_tensors_yd.pkl"), "rb") as f:
    expert_pred_tensors = pickle.load(f)

with open(os.path.join(expert_tensors_path, "prey_tensors_yd.pkl"), "rb") as f:
    expert_prey_tensors = pickle.load(f)

In [3]:
# Per-feature min/max of the expert tensors, pooled over all frames and all neighbours.
# The last axis holds the features named below, in this order.
feature_names = ["dx", "dy", "rel_vx", "rel_vy", "theta_norm"]
num_features = len(feature_names)

# Predator: [N, 1, 32, 5] -> [N*32, 5]      (all neighbours, all frames)
# Prey:     [N, 32, 32, 5] -> [N*32*32, 5]
# reshape (unlike view) also works if the unpickled tensors are not contiguous.
pred_flat = expert_pred_tensors.reshape(-1, num_features)
prey_flat = expert_prey_tensors.reshape(-1, num_features)

for header, flat in (("=== PREDATOR FEATURES ===", pred_flat),
                     ("\n=== PREY FEATURES ===", prey_flat)):
    print(header)
    for i, name in enumerate(feature_names):
        column = flat[:, i]
        print(f"{name:10s}: min = {column.min().item(): .6f}, max = {column.max().item(): .6f}")


=== PREDATOR FEATURES ===
dx        : min = -0.296633, max =  0.809462
dy        : min = -0.540936, max =  0.373218
rel_vx    : min = -0.437960, max =  0.694731
rel_vy    : min = -0.901806, max =  0.766194
theta_norm: min = -0.986111, max =  0.998831

=== PREY FEATURES ===
dx        : min = -0.870899, max =  0.870899
dy        : min = -0.646195, max =  0.646195
rel_vx    : min = -0.830117, max =  0.904384
rel_vy    : min = -0.904289, max =  0.736103
theta_norm: min = -0.999922, max =  0.999985


In [4]:
def run_policies_in_steps2(env, pred_policy, prey_policy, steps=200, render=True):
    """Roll out a predator and a prey policy in ``env`` and log metrics for every step.

    Each step reads the global environment state and derives observation features with
    ``get_eval_features``. It feeds the first four feature channels to each policy,
    converts the continuous actions to discrete ones
    (``continuous_to_discrete(..., 360, ...)``) and advances ``env`` with the resulting
    action dict.

    Args:
        env: Parallel multi-agent environment exposing ``state()``, ``agents``, ``step()``,
            ``render()`` and ``close()``. It is closed before this function returns.
        pred_policy: Predator policy; ``forward`` must return (action, mu, sigma, weights).
        prey_policy: Prey policy; ``forward`` must return
            (action, mu, sigma, weights, pred_gain).
        steps: Maximum number of environment steps to execute.
        render: If True, render each step and allow aborting the rollout with the 'q' key.

    Returns:
        Tuple ``(metrics, pred_tensor, prey_tensor)``.
        ``metrics`` holds one dict per executed step. Each dict contains swarm metrics
        computed from the positions/velocities read *before* that step's ``env.step``,
        plus the policy outputs.
        ``pred_tensor`` / ``prey_tensor`` are the feature tensors of the LAST executed step
        only, or ``None`` if no step was executed.
    """
    import warnings

    if render:
        print("Press 'q' to end simulation.")

    metrics = []
    # Pre-set so the return statement works even if no step runs (steps <= 0, or 'q'
    # pressed before the first step); previously this raised UnboundLocalError.
    pred_tensor, prey_tensor = None, None

    for _ in range(steps):
        if render and keyboard.is_pressed('q'):
            break

        global_state = env.state().item()
        pred_tensor, prey_tensor, xs, ys, dx, dy, vxs, vys = get_eval_features(global_state)

        # Pure inference: without no_grad, every mu/sigma/weights tensor stored in `metrics`
        # would keep its autograd graph alive for the whole rollout.
        with torch.no_grad():
            # Only the first 4 feature channels go to the policies
            # (presumably dx, dy, rel_vx, rel_vy, i.e. theta_norm is excluded -- TODO confirm).
            pred_states = pred_tensor[..., :4]
            action_pred, mu_pred, sigma_pred, weights_pred = pred_policy.forward(pred_states)
            dis_pred = continuous_to_discrete(action_pred, 360, role='predator')

            prey_states = prey_tensor[..., :4]
            action_prey, mu_prey, sigma_prey, weights_prey, pred_gain = prey_policy.forward(prey_states)
            dis_prey = continuous_to_discrete(action_prey, 360, role='prey')

        # Action dictionary: the i-th prey action is assigned to the i-th sorted prey name.
        # NOTE(review): sorted() orders names lexicographically (e.g. 'prey_10' before
        # 'prey_2' if agents are named like that); confirm this matches the prey row order
        # produced by get_eval_features.
        action_dict = {'predator_0': dis_pred}
        prey_agents = sorted(agent for agent in env.agents if agent.startswith("prey"))
        for i, agent_name in enumerate(prey_agents):
            action_dict[agent_name] = dis_prey[i]

        env.step(action_dict)

        # Log metrics (swarm metrics use the state observed before env.step above)
        metrics.append({
            "polarization": compute_polarization(vxs, vys),
            "angular_momentum": compute_angular_momentum(xs, ys, vxs, vys),
            "degree_of_sparsity": degree_of_sparsity(xs, ys),
            "distance_to_predator": distance_to_predator(xs, ys),
            "escape_alignment": escape_alignment(xs, ys, vxs, vys),
            "actions": (dis_pred, dis_prey),
            "mu": (mu_pred, mu_prey),
            "sigma": (sigma_pred, sigma_prey),
            "weights": (weights_pred, weights_prey),
            "pred_gain": pred_gain,
            "xs": xs,
            "ys": ys,
            "dx": dx,
            "dy": dy,
            "vxs": vxs,
            "vys": vys
        })

        # Render only if the caller asked for it
        if render:
            env.render()

    # Closing is best-effort, but report failures instead of silently swallowing them
    # (and, unlike a bare except, let KeyboardInterrupt/SystemExit propagate).
    try:
        env.close()
    except Exception as exc:
        warnings.warn(f"env.close() failed: {exc!r}")

    return metrics, pred_tensor, prey_tensor

In [5]:
# Run names of the training sessions to evaluate.
gail_folder = "GAIL Training - 19.11.2025_18.52 - Couzin Data"
bc_folder = "BC Training - 19.11.2025_18.52 - Couzin Data"

# Paths are built with os.path.join (instead of backslash literals) so they also resolve on
# POSIX systems; on Windows the resulting strings are identical to the previous literals.
model_folder = os.path.join("..", "data", "2. Training", "training")
gail_path = os.path.join(model_folder, "GAIL", gail_folder)
bc_path = os.path.join(model_folder, "BC", bc_folder)

video_processed_dir = os.path.join("..", "data", "1. Data Processing", "processed", "video")
expert_path = os.path.join(video_processed_dir, "expert_tensors", "yolo_detected", "expert_metrics_yd.pkl")
ftw_path = os.path.join(video_processed_dir, "3. full_track_windows")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pool of start configurations built from the full-track windows; one is sampled below to
# reset the environment.
# NOTE(review): if Pool.sample is random, the rollout is not reproducible without a seed -- confirm.
start_frame_pool = Pool(max_length=12100, device=device)
start_frame_pool.generate_startframes(ftw_path)

env = parallel_env(predator_count=1, prey_count=32, action_count=360, use_walls=True)
positions = start_frame_pool.sample(n=1)
env.reset(options=positions)
print("Environment initialized")


num_steps = 100

# GAIL Simulation
# NOTE(review): this loads the bc_*.pt checkpoints from the *GAIL* run folder -- confirm that
# these are the intended GAIL policies and not a leftover BC initialisation.
# weights_only=False unpickles whole module objects: only load trusted checkpoints.
gail_pred_policy = torch.load(os.path.join(gail_path, "bc_pred_policy.pt"), weights_only=False)
gail_prey_policy = torch.load(os.path.join(gail_path, "bc_prey_policy.pt"), weights_only=False)
metrics, pred_tensor, prey_tensor = run_policies_in_steps2(env, gail_pred_policy, gail_prey_policy, steps=num_steps, render=False)
print("GAIL Simulation done!")

Environment initialized
GAIL Simulation done!


In [6]:
# Per-feature min/max of the SIMULATED features, for comparison with the expert ranges.
# Caveat: run_policies_in_steps2 returns pred_tensor / prey_tensor from the LAST executed
# step only, so these ranges cover a single frame (all neighbours), not the whole rollout.
# In the saved output, for example, the predator's theta_norm has min == max. The ranges
# are therefore not directly comparable with the expert ranges, which pool every frame.
feature_names = ["dx", "dy", "rel_vx", "rel_vy", "theta_norm"]
num_features = len(feature_names)

# Flatten everything except the feature axis; reshape (unlike view) also works for
# non-contiguous tensors.
pred_flat = pred_tensor.reshape(-1, num_features)
prey_flat = prey_tensor.reshape(-1, num_features)

for header, flat in (("=== PREDATOR FEATURES ===", pred_flat),
                     ("\n=== PREY FEATURES ===", prey_flat)):
    print(header)
    for i, name in enumerate(feature_names):
        column = flat[:, i]
        print(f"{name:10s}: min = {column.min().item(): .6f}, max = {column.max().item(): .6f}")

=== PREDATOR FEATURES ===
dx        : min =  0.086037, max =  0.975513
dy        : min = -0.238320, max =  0.236286
rel_vx    : min = -0.597337, max =  0.599670
rel_vy    : min = -0.597033, max =  0.599821
theta_norm: min =  0.614444, max =  0.614444

=== PREY FEATURES ===
dx        : min = -0.975513, max =  0.889476
dy        : min = -0.474606, max =  0.474606
rel_vx    : min = -0.599704, max =  0.599996
rel_vy    : min = -0.599999, max =  0.599999
theta_norm: min = -0.980000, max =  0.976667
