In [1]:
import os
import time
import pylab
import torch
import pickle
import numpy as np
from math import *
import matplotlib as mpl
from cycler import cycler
from numpy.linalg import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from models.ModularNetworks import PairwiseInteraction, Attention
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from pathlib import Path
from torch.distributions import Normal
from datetime import datetime

mpl.use('TkAgg')
from utils.couzin_utils import run_couzin_simulation

In [2]:
class ModularPolicy(nn.Module):
    """Policy that blends the outputs of a pairwise network with attention weights.

    ``pairwise`` returns a pre-activation action ``mu`` and a raw scale ``sigma``;
    ``attention`` returns logits that are soft-maxed over dim 1. The result is the
    dim-1 weighted sum of ``sigmoid(action)``, so it lies in [0, 1].

    Args:
        features: number of input features, forwarded to both sub-networks.
    """

    def __init__(self, features=4):
        super(ModularPolicy, self).__init__()
        self.pairwise = PairwiseInteraction(features)
        self.attention = Attention(features)

    def forward(self, states, deterministic=True):
        """Return one action per row of ``states``.

        Args:
            states: input features; softmax and reduction run over dim 1
                (presumably the neighbour axis — confirm against the sub-networks).
            deterministic: if True use ``mu``; otherwise sample
                ``mu + sigma * N(0, 1)`` before the sigmoid.

        Returns:
            ``sum_dim1(softmax(attention) * sigmoid(action))``, in [0, 1].
        """
        mu, raw_sigma = self.pairwise(states)
        # softplus keeps the scale positive; the epsilon keeps it away from zero
        sigma = F.softplus(raw_sigma) + 1e-6

        attn = torch.softmax(self.attention(states), dim=1)

        if deterministic:
            pre_activation = mu
        else:
            pre_activation = mu + sigma * torch.randn_like(mu)

        return (torch.sigmoid(pre_activation) * attn).sum(dim=1)  # [0,1]

    def set_parameters(self, init=True):
        """Re-initialise every sub-module exposing ``reset_parameters``.

        Acts only when ``init`` is the object ``True`` (identity check).
        """
        if init is not True:
            return
        resettable = [m for m in self.modules() if hasattr(m, 'reset_parameters')]
        for submodule in resettable:
            submodule.reset_parameters()

In [3]:
# Load the trained policies of one training run and the pool of initial
# configurations used to seed the rollouts below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

folder = "2026.01.06_15.22"
# Build paths from components: a raw Windows string with backslashes is a single
# odd filename on POSIX systems.
base_dir = Path("..") / "data" / "2. Training" / "training" / "VideoPredPrey - GAIL"
run_dir = base_dir / folder

# NOTE: torch.load unpickles objects — only load trusted checkpoints.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
prey_path = run_dir / "prey_policy.pth"
prey_policy = ModularPolicy(features=5).to(device)
prey_policy.load_state_dict(torch.load(prey_path, map_location=device))

# The predator policy is optional. Bind the name to None so later cells can test
# `pred_policy is not None` instead of failing with a NameError.
pred_policy = None
pred_path = run_dir / "pred_policy.pth"
if pred_path.exists():
    pred_policy = ModularPolicy(features=4).to(device)
    pred_policy.load_state_dict(torch.load(pred_path, map_location=device))

init_pool_path = Path("..") / "data" / "1. Data Processing" / "processed" / "init_pool" / "init_pool.pt"
init_pool = torch.load(init_pool_path, map_location="cpu")

In [4]:
import os
import time
import pylab
import torch
import pickle
import numpy as np
from math import *
import matplotlib as mpl
from cycler import cycler
from numpy.linalg import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from copy import deepcopy


class Agent:
    """A point agent moving at constant speed inside a rectangular arena.

    Attributes:
        id: identifier given at construction.
        pos: float64 ``[x, y]`` drawn uniformly in ``[0, area_width) x [0, area_height)``.
        theta: heading in radians drawn uniformly in ``[-pi, pi)``.
        vel: float64 velocity ``speed * [cos(theta), sin(theta)]``.
    """

    def __init__(self, agent_id, speed, area_width, area_height):
        self.id = agent_id
        # Draw order (x, y, heading) matters for reproducibility under np.random.seed.
        x0 = np.random.uniform(0, area_width)
        y0 = np.random.uniform(0, area_height)
        self.pos = np.array([x0, y0], dtype=np.float64)

        self.theta = np.random.uniform(-np.pi, np.pi)
        unit_heading = np.array([np.cos(self.theta), np.sin(self.theta)], dtype=np.float64)
        self.vel = unit_heading * speed

    def update_position(self, step_size):
        """Advance ``pos`` in place by ``vel * step_size``."""
        displacement = self.vel * step_size
        self.pos += displacement


def apply_turnrate_on_theta(agent, action, speed, max_turn=np.pi):
    """Turn ``agent`` by an action in [0, 1] and rebuild its velocity.

    An action of 0.5 keeps the heading; 0 and 1 turn by ``-max_turn`` and
    ``+max_turn``. The heading is wrapped to ``[-pi, pi)`` and ``agent.vel`` is set
    to ``speed * [cos(theta), sin(theta)]``. Mutates ``agent`` in place.
    """
    turn = (action - 0.5) * 2.0 * max_turn
    new_theta = (agent.theta + turn + np.pi) % (2 * np.pi) - np.pi
    agent.theta = new_theta
    agent.vel = np.array([np.cos(new_theta), np.sin(new_theta)]) * speed



def enforce_walls(agent, area_width, area_height):
    """Keep ``agent`` inside ``[0, area_width] x [0, area_height]`` with reflection.

    A vertical-wall crossing clamps x and maps ``theta -> pi - theta``. A
    horizontal-wall crossing clamps y and maps ``theta -> -theta``. After any
    bounce the heading is wrapped to ``[-pi, pi)`` and ``agent.vel`` is rebuilt with
    its previous magnitude. Mutates ``agent.pos``, ``agent.theta`` and ``agent.vel``.

    NOTE: a later cell redefines ``enforce_walls(pos, theta, w, h)`` for batched
    tensors; after that cell runs, this name refers to the tensor version.
    """
    x, y = agent.pos[0], agent.pos[1]

    hit_vertical = x < 0 or x > area_width
    if hit_vertical:
        agent.pos[0] = 0 if x < 0 else area_width
        agent.theta = np.pi - agent.theta

    hit_horizontal = y < 0 or y > area_height
    if hit_horizontal:
        agent.pos[1] = 0 if y < 0 else area_height
        agent.theta = -agent.theta

    if not (hit_vertical or hit_horizontal):
        return

    agent.theta = (agent.theta + np.pi) % (2 * np.pi) - np.pi
    magnitude = float(np.linalg.norm(agent.vel))
    direction = np.array([np.cos(agent.theta), np.sin(agent.theta)], dtype=np.float64)
    agent.vel = direction * magnitude


def get_state_tensors(prey_log_step, pred_log_step, n_pred=1, 
                      area_width=50, area_height=50,
                      prey_speed=5, pred_speed=5, max_speed_norm=15,
                      mask=None):
    """Build per-agent neighbour features for one step of the NumPy environment.

    Args:
        prey_log_step: array ``[n_prey, 6]`` with columns ``x, y, vx, vy, dir_x, dir_y``.
        pred_log_step: array ``[n_pred, 6]`` with the same columns (may have 0 rows).
        n_pred: number of predators; they occupy the first rows of the stacked array.
        area_width, area_height: arena size used to normalise positions.
        prey_speed, pred_speed: unused; kept for call compatibility.
        max_speed_norm: relative velocities are clipped to ``±max_speed_norm`` and
            divided by it.
        mask: boolean ``[N, N]`` selecting the ``N-1`` neighbours of each agent.
            Defaults to every agent except itself (``~eye(N)``). Previously ``None``
            crashed in the reshape.

    Returns:
        ``(pred_tensor, prey_tensor)`` float32 torch tensors of shapes
        ``[n_pred, N-1, 4]`` and ``[n_prey, N-1, 4 + (n_pred > 0)]``. Features are
        ``dx, dy`` (neighbour minus focal, normalised by the arena size) and the
        neighbour velocity rotated into the focal agent's heading frame. When
        predators exist, prey get a leading 0/1 column flagging predator neighbours.
    """
    combined = np.vstack([pred_log_step, prey_log_step]).astype(np.float32)  # [N,6]
    n_agents = combined.shape[0]

    if mask is None:
        mask = ~np.eye(n_agents, dtype=bool)

    xs, ys = combined[:, 0], combined[:, 1]
    vxs, vys = combined[:, 2], combined[:, 3]
    cos_t, sin_t = combined[:, 4], combined[:, 5]  # already float32 after the cast above

    xs_scaled = xs / float(area_width)
    ys_scaled = ys / float(area_height)

    # [i, j] = neighbour j relative to focal agent i
    dx = xs_scaled[None, :] - xs_scaled[:, None]
    dy = ys_scaled[None, :] - ys_scaled[:, None]

    # neighbour j's velocity expressed in focal agent i's heading frame
    rel_vx = cos_t[:, None] * vxs[None, :] + sin_t[:, None] * vys[None, :]
    rel_vy = -sin_t[:, None] * vxs[None, :] + cos_t[:, None] * vys[None, :]

    speed = max_speed_norm  # same scaling as expert
    rel_vx = np.clip(rel_vx, -speed, speed) / speed
    rel_vy = np.clip(rel_vy, -speed, speed) / speed

    features = np.stack([dx, dy, rel_vx, rel_vy], axis=-1).astype(np.float32)
    neigh = features[mask].reshape(n_agents, n_agents - 1, 4)

    tensor = torch.from_numpy(neigh)  # already float32

    pred_tensor = tensor[:n_pred]
    prey_tensor = tensor[n_pred:]

    if n_pred > 0:
        # Neighbours are listed in agent order without self, so for every prey the
        # first n_pred neighbour slots are the predators.
        agents, neighs, _ = prey_tensor.shape
        flag = torch.zeros((agents, neighs, 1), dtype=prey_tensor.dtype, device=prey_tensor.device)
        flag[:, :n_pred, 0] = 1
        prey_tensor = torch.cat([flag, prey_tensor], dim=-1)

    return pred_tensor, prey_tensor


def apply_init_pool(init_pool, pred, prey, area_width=50, area_height=50, seed=None):
    """Place agents using one randomly chosen frame of ``init_pool``.

    A frame index is drawn with ``torch.randint``. The first
    ``len(pred) + len(prey)`` entries of that frame (normalised ``[x, y]``) are
    scaled to the arena and the group centroid is moved to the arena centre. The
    positions are then written to ``pred`` first and then ``prey`` as float64 arrays.

    Args:
        init_pool: tensor ``[steps, agents, 2]`` of normalised positions. It is NOT
            modified (the original scaled a view in place, so every call rescaled
            the shared pool).
        pred, prey: lists of objects with a writable ``pos`` attribute.
        area_width, area_height: arena size.
        seed: if given, the frame is drawn from a private ``torch.Generator``
            seeded with it (the global RNG is untouched); otherwise the global
            torch RNG is used, as before.

    Raises:
        ValueError: if a frame holds fewer agents than ``len(pred) + len(prey)``.
    """
    steps, pool_agents, _ = init_pool.shape
    n_agents = len(pred) + len(prey)
    if n_agents > pool_agents:
        raise ValueError(f"init_pool frames hold {pool_agents} agents, but {n_agents} are needed")

    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    frame = torch.randint(steps, (1,), generator=generator).item()

    # clone(): integer indexing returns a view, so the in-place scaling below
    # would otherwise write into init_pool itself.
    positions = init_pool[frame, :n_agents].clone()  # [agents, 2]

    positions[:, 0] *= float(area_width)
    positions[:, 1] *= float(area_height)

    center_env = torch.tensor([area_width * 0.5, area_height * 0.5], dtype=positions.dtype, device=positions.device)
    center_pos = positions.mean(dim=0)
    positions = positions + (center_env - center_pos)

    # predators first, then prey — same order as the agent index convention
    for agent, row in zip(list(pred) + list(prey), positions):
        agent.pos = row.detach().cpu().numpy().astype(np.float64)


def _kinematic_log(agents):
    """Stack ``[x, y, vx, vy, dir_x, dir_y]`` float32 rows for a list of agents."""
    if not agents:
        return np.empty((0, 6), dtype=np.float32)
    pos = np.asarray([a.pos for a in agents], dtype=np.float32)
    vel = np.asarray([a.vel for a in agents], dtype=np.float32)
    direction = vel / (np.linalg.norm(vel, axis=1, keepdims=True) + 1e-12)
    return np.concatenate([pos, vel, direction], axis=1)


def run_env_simulation(prey_policy=None, pred_policy=None, 
                       n_prey=32, n_pred=1, step_size=1.0,
                       max_steps=100, seed=None, deterministic=False,
                       prey_speed=15, pred_speed=15, 
                       area_width=50, area_height=50, 
                       visualization='off', init_pool=None):
    """Roll out prey/predator policies in the per-agent NumPy reference environment.

    Each step builds neighbour features with ``get_state_tensors``, queries CPU
    copies of the policies (the caller's modules are not moved), turns every agent
    with ``apply_turnrate_on_theta``, advances it and reflects it off the walls.

    Args:
        prey_policy: required policy for the prey.
        pred_policy: policy for the predators; required when ``n_pred > 0``.
        n_prey, n_pred: number of agents of each kind.
        step_size: Euler step length.
        max_steps: number of simulated steps.
        seed: if given, seeds NumPy (agent init) and torch (sampling, pool frame).
        deterministic: forwarded to the policies.
        prey_speed, pred_speed: constant agent speeds.
        area_width, area_height: arena size.
        visualization: ``'on'`` draws a quiver plot every step.
        init_pool: optional tensor of initial configurations (see ``apply_init_pool``).

    Returns:
        ``(pred_tensor, prey_tensor)``:
        - ``pred_tensor`` has shape ``[max_steps, n_pred, N-1, 5]``, or is ``None``
          when ``n_pred == 0``.
        - ``prey_tensor`` has shape ``[max_steps, n_prey, N-1, 6]`` (5 channels when
          ``n_pred == 0``).
        The last channel is the agent's action, repeated over neighbours.
        ``max_steps == 0`` now yields empty tensors instead of UnboundLocalError.

    Raises:
        ValueError: if a required policy is missing.

    NOTE: this calls the NumPy-era ``get_state_tensors`` and per-agent
    ``enforce_walls(agent, w, h)``. A later cell rebinds both names to tensor
    versions with different signatures; re-run this cell before calling again.
    """
    if prey_policy is None:
        raise ValueError("prey_policy is required")
    if n_pred > 0 and pred_policy is None:
        raise ValueError("pred_policy is required when n_pred > 0")

    if seed is not None:
        np.random.seed(seed)  # agent init
        torch.manual_seed(seed)  # CPU

    prey_policy = deepcopy(prey_policy).to("cpu")
    pred_policy = deepcopy(pred_policy).to("cpu") if pred_policy is not None else None

    prey = [Agent(i, prey_speed, area_width, area_height) for i in range(n_prey)]
    pred = [Agent(i, pred_speed, area_width, area_height) for i in range(n_pred)]

    if init_pool is not None:
        apply_init_pool(init_pool, pred, prey, area_width=area_width, area_height=area_height)

    n_agents = n_prey + n_pred
    neigh = n_agents - 1
    # prey carry an extra "is predator" flag channel when predators exist
    prey_channels = 6 if n_pred > 0 else 5
    prey_traj = torch.empty((max_steps, n_prey, neigh, prey_channels), dtype=torch.float32)
    pred_traj = torch.empty((max_steps, n_pred, neigh, 5), dtype=torch.float32) if n_pred > 0 else None

    if visualization == 'on':
        fig, ax = plt.subplots()

    mask = ~np.eye(n_agents, dtype=bool)

    for t in range(max_steps):
        prey_log_t = _kinematic_log(prey)      # [n_prey, 6]
        predator_log_t = _kinematic_log(pred)  # [n_pred, 6], empty without predators

        if visualization == 'on':
            ax.clear()
            ax.quiver(prey_log_t[:, 0], prey_log_t[:, 1],
                      prey_log_t[:, 2], prey_log_t[:, 3],
                      scale=120, width=0.01,
                      headwidth=3, headlength=3, headaxislength=3)
            if n_pred > 0:
                ax.quiver(predator_log_t[:, 0], predator_log_t[:, 1],
                          predator_log_t[:, 2], predator_log_t[:, 3],
                          color="#FF0000",
                          scale=120, width=0.01,
                          headwidth=3, headlength=3, headaxislength=3)
            ax.set_aspect('equal', 'box')
            ax.set_xlim(0, area_width)
            ax.set_ylim(0, area_height)
            plt.pause(0.00000001)

        pred_states, prey_states = get_state_tensors(prey_log_t, predator_log_t, 
                                                     n_pred=n_pred, 
                                                     area_width=area_width, area_height=area_height, 
                                                     prey_speed=prey_speed, pred_speed=pred_speed,
                                                     mask=mask)

        with torch.inference_mode():
            if n_pred > 0:
                pred_actions = pred_policy.forward(pred_states, deterministic=deterministic)
                pred_traj[t, :, :, :4] = pred_states
                pred_traj[t, :, :, 4:] = pred_actions.unsqueeze(1).expand(-1, neigh, -1)

            prey_actions = prey_policy.forward(prey_states, deterministic=deterministic)
            n_state = prey_states.shape[-1]  # 5 with the predator flag, else 4
            prey_traj[t, :, :, :n_state] = prey_states
            prey_traj[t, :, :, n_state:] = prey_actions.unsqueeze(1).expand(-1, neigh, -1)

        prey_turns = prey_actions.squeeze(-1).detach().cpu().numpy()
        for agent, action in zip(prey, prey_turns):
            apply_turnrate_on_theta(agent, action, prey_speed)

        if n_pred > 0:
            pred_turns = pred_actions.squeeze(-1).detach().cpu().numpy()
            for predator, action in zip(pred, pred_turns):
                apply_turnrate_on_theta(predator, action, pred_speed)

        for agent in prey:
            agent.update_position(step_size=step_size)
            enforce_walls(agent, area_width, area_height)

        for predator in pred:
            predator.update_position(step_size=step_size)
            enforce_walls(predator, area_width, area_height)

    return pred_traj, prey_traj

In [5]:
# Reload the initial-configuration pool and run one reference rollout with the
# NumPy environment. Paths are built from components so they resolve on any OS.
init_pool_path = Path("..") / "data" / "1. Data Processing" / "processed" / "init_pool" / "init_pool.pt"
init_pool = torch.load(init_pool_path, map_location="cpu")  # all rollouts run on the CPU

gen_pred_tensor, gen_prey_tensor = run_env_simulation(visualization='off', init_pool=init_pool,
                                                      prey_policy=prey_policy, pred_policy=pred_policy, 
                                                      n_prey=32, n_pred=1, 
                                                      pred_speed=5, prey_speed=5,
                                                      max_steps=50, step_size=3,
                                                      area_width=2160, area_height=2160)

In [6]:
# Time 32 sequential NumPy rollouts (50 steps each) as a baseline.
# perf_counter is monotonic and high-resolution; time.time can jump with clock changes.
N_TIMED_RUNS = 32
start = time.perf_counter()
for _ in range(N_TIMED_RUNS):
    gen_pred_tensor, gen_prey_tensor = run_env_simulation(visualization='off', init_pool=init_pool,
                                                        prey_policy=prey_policy, pred_policy=pred_policy, 
                                                        n_prey=32, n_pred=1, 
                                                        pred_speed=5, prey_speed=5,
                                                        max_steps=50, step_size=3,
                                                        area_width=2160, area_height=2160)
end = time.perf_counter()

print("Time:", (end - start))

Time: 2.558122396469116


In [None]:
def velocity_from_theta(theta, speed):
    """Return ``speed * [cos(theta), sin(theta)]`` stacked on a new last axis.

    ``theta`` is a tensor; ``speed`` is a tensor or number broadcastable to it.
    """
    components = (torch.cos(theta) * speed, torch.sin(theta) * speed)
    return torch.stack(components, dim=-1)


def apply_turnrate(theta, action, max_turn=torch.pi):
    """Turn headings by actions in [0, 1] and wrap the result to ``[-pi, pi)``.

    An action of 0.5 keeps the heading; 0 / 1 turn by ``-max_turn`` / ``+max_turn``.
    Returns a new tensor; ``theta`` is not modified.
    """
    turned = theta + (action - 0.5) * 2.0 * max_turn
    return (turned + torch.pi) % (2 * torch.pi) - torch.pi


def enforce_walls(pos, theta, area_width, area_height):
    """Clamp batched positions into the arena and reflect headings at the walls.

    Args:
        pos: tensor ``[..., 2+]`` whose columns 0 and 1 are ``x`` and ``y``.
            Not modified: the original clamped it in place, silently mutating the
            caller's tensor.
        theta: headings with the shape of ``pos[..., 0]``.
        area_width, area_height: the valid region is ``[0, W] x [0, H]``.

    Returns:
        ``(pos, theta)``: a clamped copy of ``pos`` and the reflected headings,
        wrapped to ``[-pi, pi)``. Crossing a vertical wall maps
        ``theta -> pi - theta``; crossing a horizontal wall maps ``theta -> -theta``.

    NOTE: this shadows the earlier per-agent ``enforce_walls(agent, w, h)`` used by
    ``run_env_simulation``.
    """
    bounced_x = (pos[..., 0] < 0) | (pos[..., 0] > area_width)
    bounced_y = (pos[..., 1] < 0) | (pos[..., 1] > area_height)

    clamped = pos.clone()
    clamped[..., 0] = pos[..., 0].clamp(0.0, float(area_width))
    clamped[..., 1] = pos[..., 1].clamp(0.0, float(area_height))

    theta = torch.where(bounced_x, torch.pi - theta, theta)
    theta = torch.where(bounced_y, -theta, theta)

    theta = (theta + torch.pi) % (2 * torch.pi) - torch.pi
    return clamped, theta


def get_state_tensors(prey_log_step, pred_log_step, n_pred=1, 
                      area_width=50, area_height=50, 
                      max_speed_norm=15, neigh_idx=None):
    """Batched (torch) neighbour-feature builder.

    Args:
        prey_log_step: tensor ``[B, n_prey, 6]`` with ``x, y, vx, vy, cos, sin``.
        pred_log_step: tensor ``[B, n_pred, 6]`` with the same columns (may have 0 agents).
        n_pred: number of predators (the first agents after concatenation).
        area_width, area_height: arena size used to normalise positions.
        max_speed_norm: clip/normalisation bound for relative velocities.
        neigh_idx: long tensor ``[N, N-1]`` listing each agent's neighbours.
            Defaults to every other agent in ascending order (previously ``None``
            crashed).

    Returns:
        ``(pred_tensor, prey_tensor)`` of shapes ``[B, n_pred, N-1, 4]`` and
        ``[B, n_prey, N-1, 4 + (n_pred > 0)]``. Features are ``dx, dy`` (neighbour
        minus focal, normalised) and the neighbour velocity in the focal agent's
        heading frame. When predators exist, prey get a leading 0/1 column flagging
        predator neighbours.

    NOTE: this shadows the earlier NumPy ``get_state_tensors`` used by
    ``run_env_simulation``.
    """
    device = prey_log_step.device
    combined = torch.cat([pred_log_step, prey_log_step], dim=1)
    batch, n_agents, _ = combined.shape
    n_neigh = n_agents - 1

    if neigh_idx is None:
        idx = torch.arange(n_agents, device=device)
        off_diag = ~torch.eye(n_agents, dtype=torch.bool, device=device)
        neigh_idx = idx.repeat(n_agents, 1)[off_diag].view(n_agents, n_neigh)
    neigh_idx = neigh_idx.to(device)

    xs, ys   = combined[..., 0], combined[..., 1]
    vxs, vys = combined[..., 2], combined[..., 3]
    cos_t, sin_t = combined[..., 4], combined[..., 5]

    xs_scaled = xs / float(area_width)
    ys_scaled = ys / float(area_height)

    # [b, i, j] = neighbour j minus focal agent i, matching the NumPy reference
    # (xs[None, :] - xs[:, None]). The previous unsqueeze(2) - unsqueeze(1) gave
    # focal minus neighbour, i.e. the opposite sign from the velocity terms' frame.
    dx = xs_scaled.unsqueeze(1) - xs_scaled.unsqueeze(2)
    dy = ys_scaled.unsqueeze(1) - ys_scaled.unsqueeze(2)

    # [b, i, j] = neighbour j's velocity rotated into focal agent i's heading frame
    rel_vx = cos_t.unsqueeze(2) * vxs.unsqueeze(1) + sin_t.unsqueeze(2) * vys.unsqueeze(1)
    rel_vy = -sin_t.unsqueeze(2) * vxs.unsqueeze(1) + cos_t.unsqueeze(2) * vys.unsqueeze(1)

    speed = float(max_speed_norm)
    rel_vx = torch.clamp(rel_vx, -speed, speed) / speed
    rel_vy = torch.clamp(rel_vy, -speed, speed) / speed

    features = torch.stack([dx, dy, rel_vx, rel_vy], dim=-1)

    gather_idx = neigh_idx.view(1, n_agents, n_neigh, 1).expand(batch, n_agents, n_neigh, 4)
    neigh = features.gather(dim=2, index=gather_idx)

    pred_tensor = neigh[:, :n_pred]
    prey_tensor = neigh[:, n_pred:]

    if n_pred > 0:
        mask_pred_neigh = (neigh_idx < n_pred).view(1, n_agents, n_neigh, 1)
        prey_mask = mask_pred_neigh[:, n_pred:].expand(batch, -1, -1, -1).to(prey_tensor.dtype)
        prey_tensor = torch.cat([prey_mask, prey_tensor], dim=-1)

    return pred_tensor, prey_tensor


def init_positions(init_pool, batch=32, area_width=2160, area_height=2160, device="cpu"):
    """Sample ``batch`` starting configurations from ``init_pool``.

    Frames are drawn with replacement via ``torch.randint``. Each frame's normalised
    coordinates are scaled to the arena and shifted so the group centroid sits at
    the arena centre.

    Returns:
        float32 tensor ``[batch, agents, 2]`` on ``device``, where ``agents`` is the
        number of agents stored per frame. ``init_pool`` is not modified.

    NOTE: a later cell redefines ``init_positions`` with a ``mode`` argument whose
    default (``"dual"``) doubles the batch.
    """
    n_frames, _, _ = init_pool.shape

    frame_idx = torch.randint(n_frames, (batch,), device=device)  # (B,)
    sampled = init_pool[frame_idx].clone().to(device)               # (B,N,2)

    sampled[..., 0] *= float(area_width)
    sampled[..., 1] *= float(area_height)

    arena_center = torch.tensor([area_width * 0.5, area_height * 0.5],
                                dtype=sampled.dtype, device=device).view(1, 1, 2)
    offset = arena_center - sampled.mean(dim=1, keepdim=True)        # (B,1,2)

    return (sampled + offset).to(torch.float32)


def run_env_vectorized(prey_policy=None, pred_policy=None, 
                       n_prey=32, n_pred=1, 
                       step_size=1.0, batch=32,
                       max_steps=100, seed=None, deterministic=False,
                       prey_speed=5, pred_speed=5, 
                       area_width=2160, area_height=2160, 
                       init_pool=None):
    """Run ``batch`` independent rollouts in lock-step with batched tensor ops (CPU).

    Agents ``[:n_pred]`` are predators and the rest are prey. Start positions come
    from ``init_pool`` (via ``init_positions``); start headings are uniform in
    ``[-pi, pi)``. Each step computes neighbour features, queries the policies on
    the flattened ``batch * agents`` inputs, turns, moves and reflects at the walls.

    Returns:
        ``(pred_tensor, prey_tensor)``:
        - ``pred_tensor`` has shape ``[batch, max_steps, n_pred, N-1, 5]``, or is
          ``None`` when ``n_pred == 0``.
        - ``prey_tensor`` has shape ``[batch, max_steps, n_prey, N-1, 6]`` (5 channels
          when ``n_pred == 0``; previously that case allocated a step-major tensor
          but indexed it batch-major).
        The last channel is the action.

    Raises:
        ValueError: if ``init_pool`` is missing or its frames hold fewer than
            ``n_prey + n_pred`` agents.
    """
    if init_pool is None:
        raise ValueError("init_pool is required to place the agents")

    n_agents = n_prey + n_pred
    n_neigh = n_agents - 1
    if init_pool.shape[1] < n_agents:
        raise ValueError(
            f"init_pool frames hold {init_pool.shape[1]} agents, but {n_agents} are needed"
        )

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    device = torch.device("cpu")

    # Keyword arguments only: the old call passed n_agents positionally into
    # `area_width` and raised a TypeError. Slicing the pool first centres each
    # start on the simulated agents. `[:batch]` keeps this correct if the later
    # init_positions(..., mode="dual") definition (which doubles the batch) is bound.
    positions = init_positions(init_pool[:, :n_agents], batch=batch,
                               area_width=area_width, area_height=area_height,
                               device=device)[:batch].to(device)
    theta = (torch.rand((batch, n_agents), dtype=torch.float32, device=device) * 2 * torch.pi) - torch.pi
    speed = torch.full((batch, n_agents), float(prey_speed), dtype=torch.float32, device=device)
    if n_pred > 0:
        speed[:, :n_pred] = float(pred_speed)

    # Trajectories are batch-major. Prey carry an extra "is predator" flag channel
    # when predators exist; the last channel is always the action.
    prey_channels = 6 if n_pred > 0 else 5
    prey_traj = torch.empty((batch, max_steps, n_prey, n_neigh, prey_channels), dtype=torch.float32, device=device)
    pred_traj = (torch.empty((batch, max_steps, n_pred, n_neigh, 5), dtype=torch.float32, device=device)
                 if n_pred > 0 else None)

    # neigh_idx[i] lists every agent except i, in ascending order
    idx = torch.arange(n_agents, device=device)
    neigh_idx = idx.repeat(n_agents, 1)
    neigh_idx = neigh_idx[~torch.eye(n_agents, dtype=torch.bool, device=device)].view(n_agents, n_neigh)

    with torch.inference_mode():
        for t in range(max_steps):
            vel = velocity_from_theta(theta, speed)

            prey_vel_now = vel[:, n_pred:]
            prey_dir = prey_vel_now / (torch.linalg.norm(prey_vel_now, dim=-1, keepdim=True) + 1e-12)
            prey_log_t = torch.cat([positions[:, n_pred:], prey_vel_now, prey_dir], dim=-1)

            if n_pred > 0:
                pred_vel_now = vel[:, :n_pred]
                pred_dir = pred_vel_now / (torch.linalg.norm(pred_vel_now, dim=-1, keepdim=True) + 1e-12)
                predator_log_t = torch.cat([positions[:, :n_pred], pred_vel_now, pred_dir], dim=-1)
            else:
                predator_log_t = torch.empty((batch, 0, 6), dtype=torch.float32, device=device)

            pred_states, prey_states = get_state_tensors(prey_log_t, predator_log_t,
                                                         n_pred=n_pred,
                                                         area_width=area_width, area_height=area_height,
                                                         max_speed_norm=15,
                                                         neigh_idx=neigh_idx)
            n_prey_feat = prey_states.shape[-1]  # 5 with the predator flag, else 4

            if n_pred > 0:
                pred_in = pred_states.reshape(batch * n_pred, n_neigh, 4)
                pred_actions = pred_policy.forward(pred_in, deterministic=deterministic).view(batch, n_pred, 1)

            prey_in = prey_states.reshape(batch * n_prey, n_neigh, n_prey_feat)
            prey_actions = prey_policy.forward(prey_in, deterministic=deterministic).view(batch, n_prey, 1)

            if n_pred > 0:
                pred_traj[:, t, :, :, :4] = pred_states
                pred_traj[:, t, :, :, 4:] = pred_actions.unsqueeze(3).expand(-1, -1, n_neigh, -1)
            prey_traj[:, t, :, :, :n_prey_feat] = prey_states
            prey_traj[:, t, :, :, n_prey_feat:] = prey_actions.unsqueeze(3).expand(-1, -1, n_neigh, -1)

            theta[:, n_pred:] = apply_turnrate(theta[:, n_pred:], prey_actions.squeeze(-1))
            if n_pred > 0:
                theta[:, :n_pred] = apply_turnrate(theta[:, :n_pred], pred_actions.squeeze(-1))

            vel = velocity_from_theta(theta, speed)
            positions = positions + vel * float(step_size)
            positions, theta = enforce_walls(positions, theta, area_width, area_height)

    return pred_traj, prey_traj

In [8]:
# Every rollout below runs on the CPU, so move both policies there once.
prey_policy = prey_policy.to("cpu")
if pred_policy is not None:
    pred_policy = pred_policy.to("cpu")

In [9]:
# Time one batched rollout of 32 environments. Note max_steps=100 here, twice the
# 50 steps of the sequential NumPy timing above.
# perf_counter is monotonic and high-resolution; time.time can jump with clock changes.
start = time.perf_counter()
gen_pred_tensor, gen_prey_tensor = run_env_vectorized(init_pool=init_pool, batch=32,
                                                    prey_policy=prey_policy, pred_policy=pred_policy, 
                                                    n_prey=32, n_pred=1, 
                                                    pred_speed=5, prey_speed=5,
                                                    max_steps=100, step_size=3,
                                                    area_width=2160, area_height=2160)
end = time.perf_counter()

print("Time:", (end - start))

Time: 1.1570827960968018


In [95]:
from collections import OrderedDict
from torch.func import functional_call, vmap

def init_positions(init_pool, batch=32, area_width=2160, area_height=2160,
                   mode="dual", device="cpu"):
    """Sample ``batch`` starting configurations, optionally duplicated for antithetic ES.

    Frames are drawn with replacement via ``torch.randint``, scaled to the arena and
    centred on the arena midpoint. With ``mode == "dual"`` the batch is stacked
    twice, so rows ``[B:]`` repeat rows ``[:B]`` and each +eps / -eps pair can start
    from the same positions.

    Returns:
        float32 tensor ``[B, agents, 2]``, or ``[2B, agents, 2]`` in dual mode.
        ``init_pool`` is not modified.

    NOTE: this redefines the earlier ``init_positions`` (which has no ``mode``);
    callers that omit ``mode`` now get the doubled batch.
    """
    n_frames, _, _ = init_pool.shape

    frame_idx = torch.randint(n_frames, (batch,), device=device)  # (B,)
    sampled = init_pool[frame_idx].clone().to(device)               # (B,N,2)

    sampled[..., 0] *= float(area_width)
    sampled[..., 1] *= float(area_height)

    arena_center = torch.tensor([area_width * 0.5, area_height * 0.5],
                                dtype=sampled.dtype, device=device).view(1, 1, 2)
    sampled = sampled + (arena_center - sampled.mean(dim=1, keepdim=True))

    if mode == "dual":
        sampled = torch.cat([sampled, sampled], dim=0)  # (2B, N, 2)

    return sampled.to(torch.float32)

def policy_perturbation(pred_policy, prey_policy, 
                        role="prey", module="pairwise", 
                        sigma=0.1, num_perturbations=32,
                        device="cpu"):
    """Build antithetic parameter perturbations of one sub-network for ES.

    The perturbed policy is ``prey_policy`` when ``role == "prey"``, else
    ``pred_policy``. The perturbed sub-network is ``pairwise`` when
    ``module == "pairwise"``, else ``attention``.

    For each of ``num_perturbations`` draws, Gaussian noise ``eps`` is sampled for
    every floating-point *parameter* of that sub-network. Two full state dicts are
    produced: ``base + sigma * eps`` and ``base - sigma * eps``. Every other entry
    (the other sub-network and all buffers) is the unperturbed tensor.

    Returns:
        ``(pert_list_all, epsilons)``:
        - ``pert_list_all`` holds the ``num_perturbations`` positive dicts followed
          by the matching negative dicts.
        - ``epsilons[i]`` is draw ``i``'s noise for ALL perturbed parameters,
          flattened and concatenated in ``<sub-network>.parameters()`` order (the
          ``parameters_to_vector`` layout). The original stored only the last
          parameter's noise.

    Raises:
        ValueError: if the selected sub-network has no floating-point parameters.
    """
    policy = prey_policy if role == "prey" else pred_policy
    prefix = "pairwise." if module == "pairwise" else "attention."

    # Perturb parameters only, in named_parameters order (== sub-network.parameters()
    # order). The old state_dict-based selection would also have perturbed
    # floating-point buffers (e.g. running statistics).
    param_keys = [name for name, p in policy.named_parameters()
                  if name.startswith(prefix) and p.is_floating_point()]
    if not param_keys:
        raise ValueError(f"no floating-point parameters under '{prefix}' to perturb")

    base = OrderedDict()
    for k, v in policy.state_dict().items():
        if torch.is_tensor(v):
            base[k] = v.detach().to(device)

    pos_list = []
    neg_list = []
    epsilons = []

    for _ in range(num_perturbations):
        pos = base.copy()
        neg = base.copy()
        noise = []

        for k in param_keys:
            eps = torch.randn_like(base[k])
            pos[k] = base[k] + float(sigma) * eps
            neg[k] = base[k] - float(sigma) * eps
            noise.append(eps.reshape(-1))

        pos_list.append(pos)
        neg_list.append(neg)
        epsilons.append(torch.cat(noise))

    pert_list_all = pos_list + neg_list

    return pert_list_all, epsilons


def batch_policy_forward(policy, states, pert_list, deterministic=False):
    """Evaluate ``policy`` once per parameter set, vectorised with ``torch.func.vmap``.

    Args:
        policy: module whose ``forward`` accepts a ``deterministic`` keyword.
        states: tensor whose dim 0 indexes environments; slice ``i`` is evaluated
            with the tensors in ``pert_list[i]``.
        pert_list: sequence of state dicts with identical keys, one per environment.
        deterministic: forwarded to ``policy.forward``.

    Returns:
        Per-environment outputs stacked along dim 0. Sampling noise is drawn
        independently per environment (``randomness="different"``).
    """
    stacked = OrderedDict()
    for name in pert_list[0].keys():
        stacked[name] = torch.stack([params[name] for params in pert_list], dim=0)

    call_kwargs = {"deterministic": deterministic}

    def _single_env(params, env_states):
        return functional_call(policy, params, (env_states,), kwargs=call_kwargs)

    batched_call = vmap(_single_env, in_dims=(0, 0), randomness="different")
    return batched_call(stacked, states)


def run_batch_env(prey_policy=None, pred_policy=None, 
                       n_prey=32, n_pred=1, 
                       step_size=1.0, batch=32,
                       max_steps=100, seed=None, deterministic=False,
                       prey_speed=5, pred_speed=5, 
                       area_width=2160, area_height=2160, 
                       init_pos=None, pert_list=None, role="prey"):
    """Batched rollouts in which environment ``b`` may use perturbed parameters ``pert_list[b]``.

    The policy selected by ``role`` (``"prey"``, or ``"pred"``/``"predator"``) is
    evaluated per environment with its own parameter dict via
    ``batch_policy_forward``. The other policy uses its own parameters on the
    flattened batch. With ``pert_list=None`` no perturbation is applied.

    Args:
        init_pos: tensor ``[batch, n_prey + n_pred, 2]`` of starting positions.
        pert_list: ``batch`` state dicts, or ``None``.
        The remaining arguments are as in the vectorised environment.

    Returns:
        ``(pred_tensor, prey_tensor)``:
        - ``pred_tensor`` has shape ``[batch, max_steps, n_pred, N-1, 5]``, or is
          ``None`` when ``n_pred == 0``.
        - ``prey_tensor`` has shape ``[batch, max_steps, n_prey, N-1, 6]`` (5 channels
          when ``n_pred == 0``; previously that case allocated a step-major tensor
          but indexed it batch-major).

    Raises:
        ValueError: if ``init_pos`` is missing or does not match ``(batch, N, 2)``.

    NOTE(review): initial headings are drawn independently per environment, so an
    antithetic (+eps, -eps) pair shares start positions but not start headings.
    """
    if init_pos is None:
        raise ValueError("init_pos is required")

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    device = torch.device("cpu")

    n_agents = n_prey + n_pred
    n_neigh = n_agents - 1

    positions = init_pos.to(device)
    if tuple(positions.shape[:2]) != (batch, n_agents):
        raise ValueError(
            f"init_pos has shape {tuple(positions.shape)}; expected ({batch}, {n_agents}, 2)"
        )
    theta = (torch.rand((batch, n_agents), dtype=torch.float32, device=device) * 2 * torch.pi) - torch.pi
    speed = torch.full((batch, n_agents), float(prey_speed), dtype=torch.float32, device=device)
    if n_pred > 0:
        speed[:, :n_pred] = float(pred_speed)

    prey_channels = 6 if n_pred > 0 else 5
    prey_traj = torch.empty((batch, max_steps, n_prey, n_neigh, prey_channels), dtype=torch.float32, device=device)
    pred_traj = (torch.empty((batch, max_steps, n_pred, n_neigh, 5), dtype=torch.float32, device=device)
                 if n_pred > 0 else None)

    # neigh_idx[i] lists every agent except i, in ascending order
    idx = torch.arange(n_agents, device=device)
    neigh_idx = idx.repeat(n_agents, 1)
    neigh_idx = neigh_idx[~torch.eye(n_agents, dtype=torch.bool, device=device)].view(n_agents, n_neigh)

    perturb_pred = role in ("pred", "predator") and pert_list is not None
    perturb_prey = role == "prey" and pert_list is not None

    with torch.inference_mode():
        for t in range(max_steps):
            vel = velocity_from_theta(theta, speed)

            prey_vel_now = vel[:, n_pred:]
            prey_dir = prey_vel_now / (torch.linalg.norm(prey_vel_now, dim=-1, keepdim=True) + 1e-12)
            prey_log_t = torch.cat([positions[:, n_pred:], prey_vel_now, prey_dir], dim=-1)

            if n_pred > 0:
                pred_vel_now = vel[:, :n_pred]
                pred_dir = pred_vel_now / (torch.linalg.norm(pred_vel_now, dim=-1, keepdim=True) + 1e-12)
                predator_log_t = torch.cat([positions[:, :n_pred], pred_vel_now, pred_dir], dim=-1)
            else:
                predator_log_t = torch.empty((batch, 0, 6), dtype=torch.float32, device=device)

            pred_states, prey_states = get_state_tensors(prey_log_t, predator_log_t,
                                                         n_pred=n_pred,
                                                         area_width=area_width, area_height=area_height,
                                                         max_speed_norm=15,
                                                         neigh_idx=neigh_idx)
            n_prey_feat = prey_states.shape[-1]  # 5 with the predator flag, else 4

            if n_pred > 0:
                if perturb_pred:
                    # one parameter dict per environment, vmapped over dim 0
                    pred_actions = batch_policy_forward(pred_policy, pred_states, pert_list,
                                                        deterministic=deterministic)
                else:
                    # reshape, not view: pred_states is a slice of a larger tensor,
                    # and view() can fail to merge (batch, n_pred) when n_pred > 1
                    pred_actions = pred_policy.forward(
                        pred_states.reshape(batch * n_pred, n_neigh, 4),
                        deterministic=deterministic
                    )
                pred_actions = pred_actions.reshape(batch, n_pred, 1)

            if perturb_prey:
                prey_actions = batch_policy_forward(prey_policy, prey_states, pert_list,
                                                    deterministic=deterministic)
            else:
                prey_actions = prey_policy.forward(
                    prey_states.reshape(batch * n_prey, n_neigh, n_prey_feat),
                    deterministic=deterministic
                )
            prey_actions = prey_actions.reshape(batch, n_prey, 1)

            if n_pred > 0:
                pred_traj[:, t, :, :, :4] = pred_states
                pred_traj[:, t, :, :, 4:] = pred_actions.unsqueeze(3).expand(-1, -1, n_neigh, -1)
            prey_traj[:, t, :, :, :n_prey_feat] = prey_states
            prey_traj[:, t, :, :, n_prey_feat:] = prey_actions.unsqueeze(3).expand(-1, -1, n_neigh, -1)

            theta[:, n_pred:] = apply_turnrate(theta[:, n_pred:], prey_actions.squeeze(-1))
            if n_pred > 0:
                theta[:, :n_pred] = apply_turnrate(theta[:, :n_pred], pred_actions.squeeze(-1))

            vel = velocity_from_theta(theta, speed)
            positions = positions + vel * float(step_size)
            positions, theta = enforce_walls(positions, theta, area_width, area_height)

    return pred_traj, prey_traj


def apply_perturbations(prey_policy, pred_policy, init_pos, 
                        role, module, device,
                        sigma, num_perturbations, **env_kwargs):
    """Sample antithetic perturbations and roll out one environment per perturbation.

    Calls ``policy_perturbation`` to obtain ``2 * num_perturbations`` parameter dicts
    (positive draws first, then their negatives). Then runs ``run_batch_env`` with
    ``batch = 2 * num_perturbations`` so environment ``b`` uses dict ``b``.

    Args:
        init_pos: starting positions with ``2 * num_perturbations`` rows (e.g. from
            ``init_positions(..., batch=num_perturbations, mode="dual")``).
        role, module, device, sigma, num_perturbations: forwarded to
            ``policy_perturbation``.
        env_kwargs: extra keyword arguments for ``run_batch_env`` (e.g. ``n_prey``,
            ``step_size``, ``max_steps``, speeds, arena size). Without them
            ``run_batch_env``'s defaults apply, as before.

    Returns:
        ``(pred_rollouts, prey_rollouts, epsilons)``.
    """
    pert_list, epsilons = policy_perturbation(pred_policy, prey_policy,
                                              role=role, module=module,
                                              sigma=sigma, num_perturbations=num_perturbations,
                                              device=device)

    pred_rollouts, prey_rollouts = run_batch_env(prey_policy=prey_policy, pred_policy=pred_policy,
                                                 batch=2 * num_perturbations, init_pos=init_pos,
                                                 pert_list=pert_list, role=role, **env_kwargs)

    return pred_rollouts, prey_rollouts, epsilons

In [96]:
# Paths are built from components so they resolve on any OS.
init_pool_path = Path("..") / "data" / "1. Data Processing" / "processed" / "init_pool" / "init_pool.pt"
init_pool = torch.load(init_pool_path, map_location="cpu")  # rollouts run on the CPU

# mode="dual" stacks two copies of the 32 sampled starts (64 rows), so each
# +eps / -eps pair of the 32 perturbations starts from the same positions.
init_pos = init_positions(init_pool, batch=32, mode="dual")

pred_rollouts, prey_rollouts, epsilons = apply_perturbations(prey_policy, pred_policy, init_pos,
                               role="prey", module="pairwise", device="cpu",
                               sigma=0.1, num_perturbations=32)

In [None]:
def optimize_es(pred_policy, prey_policy, 
                role, module, 
                discriminator, lr, 
                sigma, num_perturbations, 
                device="cuda", init_pos=None):
    """Perform one antithetic evolution-strategies update of a policy sub-network.

    The target sub-network is ``<role>_policy.<module>`` (``pairwise`` when
    ``module == 'pairwise'``, else ``attention``). Steps:
    - perturb it with ``apply_perturbations``;
    - score the rollouts with ``discriminator_reward(..., mode="top")``;
    - turn the per-pair reward differences (+eps minus -eps) into standardised
      ranks;
    - write the vector returned by ``gradient_estimate`` back into the
      sub-network's parameters.

    The update is skipped when the reward differences are (near-)constant.

    Returns:
        dict of diagnostics rounded to 6 decimals: reward-difference statistics,
        the applied update norm, and the ``clip_ratio`` / ``delta_raw_norm`` /
        ``max_delta_norm`` values from ``gradient_estimate``'s metrics.

    NOTE(review): ``device`` defaults to ``"cuda"`` and is forwarded to
    ``policy_perturbation``, while the batched rollouts run on the CPU. Pass
    ``device="cpu"`` (as the notebook does) to avoid a device mismatch.
    """
    owner = prey_policy if role == "prey" else pred_policy
    network = owner.pairwise if module == 'pairwise' else owner.attention

    theta = nn.utils.parameters_to_vector(network.parameters())

    pred_rollouts, prey_rollouts, epsilons = apply_perturbations(prey_policy, pred_policy, init_pos,
                                role=role, module=module, device=device,
                                sigma=sigma, num_perturbations=num_perturbations)

    scored_rollouts = prey_rollouts if role == "prey" else pred_rollouts
    reward = discriminator_reward(discriminator, scored_rollouts, mode="top")

    # first half of the batch used +eps, second half the matching -eps
    diffs = (reward[:num_perturbations] - reward[num_perturbations:]).detach()

    # centred ranks make the update insensitive to the reward's scale
    ranks = torch.argsort(torch.argsort(diffs)).float()
    ranks_norm = (ranks - ranks.mean()) / (ranks.std() + 1e-8)

    theta_est, grad_metrics = gradient_estimate(theta, ranks_norm, epsilons, sigma, lr, num_perturbations)

    diff_std = diffs.std(unbiased=False)
    # near-constant differences carry no signal; updating would be a random walk
    if diff_std < 1e-6:
        theta_est = theta

    nn.utils.vector_to_parameters(theta_est, network.parameters())

    def _r6(value):
        return round(value, 6)

    return {"diff_min": _r6(diffs.min().item()),
            "diff_max": _r6(diffs.max().item()),
            "diff_mean": _r6(diffs.mean().item()),
            "diff_std": _r6(diff_std.item()),
            "delta_norm": _r6((theta_est - theta).norm().item()),
            "clip_ratio": _r6(grad_metrics["clip_ratio"]),
            "delta_raw_norm": _r6(grad_metrics["delta_raw_norm"]),
            "max_delta_norm": _r6(grad_metrics["max_delta_norm"])
        }
