In [5]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import pickle
import datetime
from utils.es_utils import *
from utils.env_utils import *
from utils.train_utils import *
from models.Buffer import Buffer
from models.PredatorPolicy import PredatorPolicy
from models.PreyPolicy import PreyPolicy
from torch.utils.data import TensorDataset, DataLoader, random_split
import matplotlib.pyplot as plt
from models.Buffer import Pool
from utils.env_utils import *
from utils.eval_utils import *
from utils.train_utils import pretrain_policy_with_validation

import math
import torch
import scipy.stats
import torch.nn as nn
from utils.es_utils import *
from utils.env_utils import *
import torch.nn.functional as F
from torch.distributions import Normal
from multiprocessing import Pool, set_start_method
from models.ModularNetworks import PairwiseInteraction, Attention, PredatorInteraction

In [6]:
# Expert Data
traj_path = rf'..\data\1. Data Processing\processed\video\expert_tensors\yolo_detected'
couzin_path = rf'..\data\1. Data Processing\processed\couzin'
hl_path = rf'..\data\1. Data Processing\processed\video\expert_tensors\hand_labeled'
ftw_path = rf'..\data\1. Data Processing\processed\video\3. full_track_windows'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
expert_buffer = Buffer(pred_max_length=23000, prey_max_length=200000, device=device)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gail_folder = "GAIL Training - 19.11.2025_18.52 - Couzin Data"
bc_folder = "BC Training - 19.11.2025_18.52 - Couzin Data"

model_folder = rf"..\data\2. Training\training"
gail_path = os.path.join(model_folder, "GAIL", gail_folder)
bc_path = os.path.join(model_folder, "BC", bc_folder)
pred_policy = torch.load(os.path.join(gail_path, "gail_pred_policy.pt"), weights_only=False)

prey_policy = PreyPolicy().to(device)
prey_policy.set_parameters(init=True)

In [10]:
# Load Expert Data from local storage
print("Expert Buffer is empty, load data...")
expert_buffer.add_expert(couzin_path)

len_exp_pred, len_exp_prey = expert_buffer.lengths()
print("Storage of Predator Expert Buffer: ", len_exp_pred)
print("Storage of Prey Expert Buffer: ", len_exp_prey, "\n")

_, expert_prey_batch = expert_buffer.sample(1, 1)
states = expert_prey_batch[..., :4]

Expert Buffer is empty, load data...
Storage of Predator Expert Buffer:  1000
Storage of Prey Expert Buffer:  32000 



In [None]:
def get_pred_gain(states):
    weights = torch.softmax(pred_policy.attention(states).view(1, 32), dim=1)
    w_min = weights.min()
    w_max = weights.max()
    scaled_weights = (weights - w_min) / (w_max - w_min) * 0.9 + 0.05  # scale to [0.05, 0.95]
    return scaled_weights

action_prey, mu_prey, sigma_prey, weights_prey, pred_gain = prey_policy.forward(states, weights)
print(weights)
print(weights.mean())
print(action_prey)

# clip auf 0.05 - 0.95

Pred Gain from Attention Weights:  tensor([0.2459], grad_fn=<MeanBackward1>)
tensor([[0.9500, 0.2109, 0.2646, 0.1213, 0.2121, 0.2457, 0.3799, 0.1408, 0.1417,
         0.1917, 0.3720, 0.2876, 0.4599, 0.0622, 0.2665, 0.0890, 0.1499, 0.2622,
         0.3348, 0.1979, 0.1511, 0.2943, 0.1040, 0.3105, 0.1617, 0.3126, 0.2420,
         0.2925, 0.0826, 0.1810, 0.3447, 0.0500]], grad_fn=<AddBackward0>)
tensor(0.2459, grad_fn=<MeanBackward0>)
tensor([-0.7072], grad_fn=<AddBackward0>)


In [18]:
agents, neigh, feat = states.shape                  # Shape: (32,32,4)

device = states.device
dtype  = states.dtype

##### Predator #####
pred_states = states[:, 0, :]                               # Shape: (32,1,4)
mu_pred, sigma_pred = prey_policy.pred_pairwise(pred_states)       # mu=32, simga=32
sampled_pred_action = Normal(mu_pred, sigma_pred).sample()  # actions=32
pred_actions = torch.tanh(sampled_pred_action) * math.pi    # Value Range [-pi:pi]
pred_action_flat = pred_actions.squeeze(-1)
print("Pred Actions: ", pred_action_flat)

if weights is not None:
    pred_gain = weights
else:
    pred_gain = torch.full((agents,), 1/33, device=states.device, dtype=states.dtype) # treat every action equal
print("Pred Gain: ", pred_gain)

##### Prey #####
prey_states = states[:, 1:, :]                                      # Shape: (32,31,4)
prey_states_flat   = prey_states.reshape(agents * (neigh-1), feat)  # Shape: (32*31,4)

mu_prey, sigma_prey = prey_policy.prey_pairwise(prey_states_flat)          # mu=32*31, simga=32*31
sampled_prey_action = Normal(mu_prey, sigma_prey).sample()          # actions=32*31
prey_actions = (torch.tanh(sampled_prey_action) * math.pi).view(agents, neigh - 1, 1)
print("Prey Actions: ", prey_actions.squeeze(-1))

prey_weight_logits = prey_policy.prey_attention(prey_states_flat)
prey_weight_logits = prey_weight_logits.view(agents, neigh-1)       # [A, N-1]
prey_weights = torch.softmax(prey_weight_logits, dim=1).view(agents, neigh-1, 1)
print("Prey Weights: ", prey_weights.squeeze(-1))

##### Action Aggregation #####

# Aggregation of Prey Actions per Prey
prey_actions_nei = prey_actions.squeeze(-1)
prey_weights_nei = prey_weights.squeeze(-1)
prey_action_per_prey = (prey_actions_nei * prey_weights_nei).sum(dim=1)
print("Prey Action per Prey: ", prey_action_per_prey)

final_action = pred_gain * pred_action_flat + (1.0 - pred_gain) * prey_action_per_prey
print("Final Action: ", final_action)

Pred Actions:  tensor([-2.4492])
Pred Gain:  tensor([[0.9500, 0.2109, 0.2646, 0.1213, 0.2121, 0.2457, 0.3799, 0.1408, 0.1417,
         0.1917, 0.3720, 0.2876, 0.4599, 0.0622, 0.2665, 0.0890, 0.1499, 0.2622,
         0.3348, 0.1979, 0.1511, 0.2943, 0.1040, 0.3105, 0.1617, 0.3126, 0.2420,
         0.2925, 0.0826, 0.1810, 0.3447, 0.0500]], grad_fn=<AddBackward0>)
Prey Actions:  tensor([[ 1.8440e+00, -2.5470e+00, -1.2248e+00, -1.4999e+00, -3.8322e-01,
         -9.0433e-01, -1.0150e+00,  2.7126e+00, -1.9035e+00,  2.6387e+00,
         -8.2270e-01,  1.1807e+00,  1.8885e+00,  8.8653e-01, -3.7311e-02,
          2.9208e+00, -1.6126e+00, -3.9730e-01,  2.1645e+00, -2.3790e+00,
          2.1130e+00, -1.4247e+00, -1.1093e+00, -3.4951e-02,  1.5150e+00,
          1.8907e+00, -9.9560e-04,  9.6808e-01, -1.2022e+00, -8.2934e-01,
         -1.6889e+00]])
Prey Weights:  tensor([[0.0323, 0.0323, 0.0323, 0.0323, 0.0323, 0.0323, 0.0323, 0.0323, 0.0323,
         0.0323, 0.0323, 0.0323, 0.0323, 0.0323, 0.0323, 0