In [None]:
import os
import random
import sys
sys.path.insert(0, os.path.abspath('..'))

import numpy as np
import torch
from pathlib import Path
from ema_pytorch import EMA
from datetime import datetime
from utils.sim_utils import *
from utils.eval_utils import *
from utils.train_utils import *
from utils.couzin_utils import *
from utils.vec_sim_utils import *
from utils.encoder_utils import *
from utils.dataset_utils import *
from geomloss import SamplesLoss
from utils.mmd_loss import MMDLoss
from models.Generator import ModularPolicy
from models.Discriminator import Discriminator

  from .autonotebook import tqdm as notebook_tqdm




In [2]:
# --- Expert -------------------------------------------------------------------
max_steps = 300

# --- Training -----------------------------------------------------------------
num_generations = 4000
gamma = 0.999
deterministic = False  # BC pretrain
performance_eval = 5
num_perturbations = 64

### Prey ###
lr_prey_policy = 2e-4
sigma_prey = 0.1

prey_dis_balance_factor = 2
prey_noise = 0.005
lr_prey_disc = 5e-4
lambda_gp_prey = 5
# Presumably a discriminator_reward mode ("mean" / "avoid" / "attack") — confirm.
prey_update_mode = "avoid"


### Predator ###
lr_pred_policy = 1e-4
sigma_pred = 0.08

pred_dis_balance_factor = 2
pred_noise = 0.005
lr_pred_disc = 2e-4
lambda_gp_pred = 10
# Presumably a discriminator_reward mode ("mean" / "avoid" / "attack") — confirm.
pred_update_mode = "mean"


# --- Env settings -------------------------------------------------------------
height = 50
width = 50
prey_speed = 5
pred_speed = 5
step_size = 0.5
theta_dot_max = 0.5
# Maximum heading change per step = turn rate * step size. The 1e-12 slack is
# presumably there so a turn of exactly theta_dot_max * step_size still passes a
# `<=` check despite float rounding — confirm against the env code.
max_turn = float(theta_dot_max * step_size) + 1e-12

pert_steps = 100
init_steps = 500  # used as max_steps of the expert run_couzin_simulation call

env_settings = (height, width, prey_speed, pred_speed, step_size, max_turn, pert_steps)

# --- Reproducibility ----------------------------------------------------------
# Seed every RNG the notebook may draw from (Python, NumPy, torch) so that
# Restart & Run All gives repeatable encoder/discriminator initialisation and,
# presumably, repeatable expert rollouts. Bit-exact GPU results additionally
# require deterministic algorithms, which are not enabled here.
random_seed = 0
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)  # seeds the CPU and all CUDA generators

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
%matplotlib tk
# NOTE(review): the Tk backend is selected although visualization="off" is passed
# below — possibly a leftover from interactive runs; confirm it is still needed.

# Run the expert simulation for `init_steps` steps with number_of_sharks=1 and
# n=32, using the speeds / turn rate / arena size from the config cell. Besides the
# predator and prey sequences it returns metrics, actions and an `init_pool`
# (their contents are defined in utils — not visible here).
exp_pred_sequence, exp_prey_sequence, couzin_metrics, actions, init_pool = run_couzin_simulation(
                                        visualization="off", 
                                        max_steps=init_steps, 
                                        constant_speed=prey_speed, shark_speed=pred_speed, 
                                        area_width=width, area_height=height,
                                        dt = step_size,
                                        alpha=0.01,
                                        theta_dot_max=theta_dot_max, theta_dot_max_shark=theta_dot_max,
                                        number_of_sharks=1, n=32)

# Move the expert data to the training device.
exp_pred_sequence = exp_pred_sequence.to(device)
exp_prey_sequence = exp_prey_sequence.to(device)
init_pool = init_pool.to(device)

# Per the captured output: pred is (500, 1, 32, 5) and prey is (500, 32, 32, 6).
print("\nPred Shape:", exp_pred_sequence.shape)
print("Prey Shape:", exp_prey_sequence.shape)

# Cut the rollouts into windows of 10 consecutive steps; the printed shapes
# (500 -> 491 = 500 - 10 + 1) indicate stride-1 overlapping windows.
# NOTE(review): the window size 10 is a magic number — consider a config constant.
exp_pred_tensor = sliding_window(exp_pred_sequence, window_size=10)
exp_prey_tensor = sliding_window(exp_prey_sequence, window_size=10)

print("\nPred Tensor Shape:", exp_pred_tensor.shape)
print("Prey Tensor Shape:", exp_prey_tensor.shape)



Pred Shape: torch.Size([500, 1, 32, 5])
Prey Shape: torch.Size([500, 32, 32, 6])

Pred Tensor Shape: torch.Size([491, 10, 1, 32, 5])
Prey Tensor Shape: torch.Size([491, 10, 32, 32, 6])


In [5]:
# Pretrain one TransitionEncoder per role on the expert windows (with a
# VicRegProjector head), then freeze the encoder so it can be reused inside
# that role's discriminator. The augmentation module is shared by both roles.
aug = TrajectoryAugmentation(noise_std=0.01, neigh_drop=0.10, feat_drop=0.05).to(device)


def _train_frozen_encoder(features, exp_tensor, epochs, role, augmentation):
    """Build, pretrain and freeze a TransitionEncoder for one role.

    Args:
        features: input feature count passed to TransitionEncoder.
        exp_tensor: expert windows to train on.
        epochs: number of epochs for ``train_encoder``.
        role: role label forwarded to ``train_encoder`` ("prey" / "predator").
        augmentation: augmentation module applied during training.

    Returns:
        (encoder, projector, optimizer). Only the encoder is frozen; the projector
        and optimizer are returned so the notebook keeps the same globals as before.
    """
    encoder = TransitionEncoder(features=features, embd_dim=32, z=32).to(device)
    projector = VicRegProjector(input_dim=64).to(device)
    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(projector.parameters()),
        lr=1e-3,
        weight_decay=1e-6,
    )
    train_encoder(encoder, projector, aug=augmentation, exp_tensor=exp_tensor,
                  epochs=epochs, optimizer=optimizer, role=role)
    # Same effect as setting p.requires_grad = False on every encoder parameter.
    encoder.requires_grad_(False)
    return encoder, projector, optimizer


# NOTE(review): features=5 / 4 while the prey / predator tensors have a last dim
# of 6 / 5 — presumably the encoder ignores one channel; confirm in encoder_utils.
prey_encoder, prey_projector, prey_optimizer = _train_frozen_encoder(
    features=5, exp_tensor=exp_prey_tensor, epochs=1200, role="prey", augmentation=aug)
print("Prey Encoder trained & frozen.\n")

pred_encoder, pred_projector, pred_optimizer = _train_frozen_encoder(
    features=4, exp_tensor=exp_pred_tensor, epochs=800, role="predator", augmentation=aug)
print("Predator Encoder trained & frozen.")

epoch 010: loss=22.635822 sim=0.0469 std=1.2981 cov=0.3984 std_mean=0.351
epoch 020: loss=21.943047 sim=0.0436 std=1.2534 cov=0.4103 std_mean=0.373
epoch 030: loss=21.222828 sim=0.0426 std=1.2041 cov=0.4190 std_mean=0.398
epoch 040: loss=20.749269 sim=0.0434 std=1.1622 cov=0.4463 std_mean=0.419
epoch 050: loss=20.569157 sim=0.0488 std=1.1379 cov=0.4563 std_mean=0.431
epoch 060: loss=20.033260 sim=0.0433 std=1.0872 cov=0.5288 std_mean=0.456
epoch 070: loss=19.889326 sim=0.0513 std=1.0633 cov=0.5313 std_mean=0.468
epoch 080: loss=19.827898 sim=0.0542 std=1.0590 cov=0.5174 std_mean=0.471
epoch 090: loss=19.882730 sim=0.0570 std=1.0394 cov=0.5731 std_mean=0.480
epoch 100: loss=19.722971 sim=0.0530 std=1.0380 cov=0.5655 std_mean=0.481
epoch 110: loss=19.772097 sim=0.0574 std=1.0335 cov=0.5668 std_mean=0.483
epoch 120: loss=19.629839 sim=0.0577 std=1.0246 cov=0.5635 std_mean=0.488
epoch 130: loss=19.511559 sim=0.0598 std=1.0127 cov=0.5650 std_mean=0.494
epoch 140: loss=19.472654 sim=0.0605 s

In [6]:
def _make_discriminator(encoder, role, lr):
    """Wrap ``encoder`` in a freshly initialised Discriminator and pair it with RMSprop.

    Returns:
        (discriminator, optimizer) for the given role.
    """
    disc = Discriminator(encoder=encoder, role=role, z_dim=32).to(device)
    disc.set_parameters(init=True)
    opt = torch.optim.RMSprop(disc.parameters(), lr=lr, alpha=0.99, eps=1e-08)
    return disc, opt


# Prey first, then predator — same construction order as before.
prey_discriminator, optim_disc_prey = _make_discriminator(prey_encoder, "prey", lr_prey_disc)
pred_discriminator, optim_disc_pred = _make_discriminator(pred_encoder, "predator", lr_pred_disc)

In [11]:
def discriminator_reward(discriminator, gen_tensor, mode="mean", alpha_coeff=1.0, r_avoid=0.12, verbose=True):
    """Per-sample reward from a discriminator, optionally with a predator-avoidance term.

    This is the first version; a later cell in the notebook redefines
    ``discriminator_reward`` with more modes and options.

    Args:
        discriminator: callable whose output for ``gen_tensor`` is averaged over
            dims 1 and 2 to give one score per batch element.
        gen_tensor: batch of windows indexed as (batch, time, agent, neighbour,
            feature) in "avoid" mode. Features 1 / 2 are read as dx / dy, and
            neighbour slot 0 is treated as the predator (assumption — confirm against
            the tensor builder). The last time step of each window is not used.
        mode: "mean" (discriminator score only) or "avoid" (score plus a scaled
            penalty for a predator closer than ``r_avoid``).
        alpha_coeff: relative weight of the avoidance term after std matching.
        r_avoid: avoidance radius, in the same units as dx / dy.
        verbose: print diagnostic statistics in "avoid" mode (default: on, as before).

    Returns:
        1-D tensor with one reward per batch element.

    Raises:
        ValueError: if ``mode`` is not "mean" or "avoid" (previously returned None).
    """
    if mode not in ("mean", "avoid"):
        raise ValueError(f"unknown mode {mode!r}; expected 'mean' or 'avoid'")

    eps = 1e-8

    def _log(label, value):
        # Exploration diagnostics; silence with verbose=False.
        if verbose:
            print(label, value.item())

    def _spread(x):
        # Sample std of a single element is NaN, which would poison alpha; treat it
        # as zero spread so the eps floor applies instead.
        x = x.detach()
        spread = x.std() if x.numel() > 1 else x.new_zeros(())
        return spread.clamp(min=eps)

    matrix = discriminator(gen_tensor)
    dis_reward = matrix.mean(dim=(1, 2))
    if mode == "mean":
        return dis_reward

    _log("Discriminator Reward Mean:", dis_reward.mean())

    dx = gen_tensor[:, :-1, :, :, 1]
    dy = gen_tensor[:, :-1, :, :, 2]
    _log("DX Mean:", dx.mean())
    _log("DY Mean:", dy.mean())

    dist = torch.sqrt(dx**2 + dy**2)
    _log("Distance Mean:", dist.mean())
    pred_dist = dist[:, :, :, 0]

    # Zero outside r_avoid, linearly more negative the closer the predator gets.
    avoid_reward = (-torch.relu(r_avoid - pred_dist)).mean(dim=(1, 2))
    _log("Avoidance Reward Mean:", avoid_reward.mean())

    # Subtracting the batch mean does not change the std, so the raw term is
    # scaled directly (the reward below also uses the uncentred term).
    dis_scale = _spread(dis_reward)
    avoid_scale = _spread(avoid_reward)
    _log("Discriminator Reward Scale:", dis_scale)
    _log("Avoidance Reward Scale:", avoid_scale)

    # Rescale so the avoidance term's batch spread matches the discriminator's.
    alpha = alpha_coeff * dis_scale / avoid_scale
    _log("Alpha:", alpha)

    reward = dis_reward + alpha * avoid_reward
    _log("Total Reward Mean:", reward.mean())
    return reward

In [12]:
# Smoke test of the discriminator_reward defined in the previous cell on the expert
# windows: prey use the "avoid" shaping (prints diagnostics), predator the plain
# mean discriminator score.
# NOTE(review): autograd tracks these calls; wrap in torch.no_grad() if the results
# are only inspected.
prey_r = discriminator_reward(prey_discriminator, exp_prey_tensor, mode="avoid")
pred_r = discriminator_reward(pred_discriminator, exp_pred_tensor, mode="mean")

Discriminator Reward Mean: -0.010941068641841412
DX Mean: 0.0007546551642008126
DY Mean: 0.0019254297949373722
Distance Mean: 0.46906596422195435
Avoidance Reward Mean: -0.0012265769764780998
Discriminator Reward Scale: 4.19268362747971e-05
Avoidance Reward Scale: 0.0014619167195633054
Alpha: 0.028679359704256058
Total Reward Mean: -0.010976246558129787


In [None]:
import torch

def discriminator_reward(
    discriminator,
    gen_tensor,
    mode="mean",
    alpha_coeff=1.0,
    r_avoid=0.5,
    r_avoid_quantile=None,
    tau_attack=0.05,
    scale_mode="meanabs",
    center=True,
    eps=1e-8,
):
    """Per-sample reward = discriminator score, plus an optional scaled shaping term.

    The discriminator output is averaged over dims 1 and 2 to give one score per
    batch element. In "avoid" / "attack" mode, a shaping term is built from the
    relative-position features of every window step except the last. The term is
    optionally centred and then rescaled so its magnitude matches the score:

        reward = dis_reward + alpha_coeff * scale(dis_reward) / scale(term) * term

    Args:
        discriminator: callable whose output for ``gen_tensor`` has dims 1 and 2
            averaged away.
        gen_tensor: batch of windows indexed as (batch, time, ..., feature).
            Layout assumptions (TODO confirm against the tensor builder):
            - "avoid": features 1 / 2 are dx / dy, and neighbour slot 0 is the predator.
            - "attack": features 0 / 1 are dx / dy, and the last non-feature dim
              enumerates prey.
        mode: one of
            - "mean": score only.
            - "avoid": penalise a predator closer than ``r_avoid``.
            - "attack": reward a small soft-min distance to the prey.
        alpha_coeff: relative weight of the shaping term after rescaling.
        r_avoid: avoidance radius, in the same units as dx / dy.
        r_avoid_quantile: if given, overrides ``r_avoid`` with this quantile of
            the predator distances observed in the batch.
        tau_attack: soft-min temperature; smaller values approach a hard min.
        scale_mode: "meanabs" (mean absolute value) or "std" (sample std), used to
            measure both the score and the shaping term.
        center: subtract the (detached) batch mean from the shaping term.
        eps: floor for the scales, and offset inside the sqrt so distance gradients
            stay finite at zero.

    Returns:
        1-D tensor with one reward per batch element.

    Raises:
        ValueError: for an unknown ``mode`` or ``scale_mode``. Previously an unknown
            ``mode`` returned None and an unknown ``scale_mode`` silently behaved
            like "meanabs".
    """
    if mode not in ("mean", "avoid", "attack"):
        raise ValueError(f"unknown mode {mode!r}; expected 'mean', 'avoid' or 'attack'")
    if scale_mode not in ("meanabs", "std"):
        raise ValueError(f"unknown scale_mode {scale_mode!r}; expected 'meanabs' or 'std'")

    def _scale(x):
        x = x.detach()
        if scale_mode == "std":
            # Sample std of a single element is NaN; treat it as zero spread so
            # the eps floor applies instead of turning alpha into NaN.
            spread = x.std() if x.numel() > 1 else x.new_zeros(())
            return spread.clamp(min=eps)
        return x.abs().mean().clamp(min=eps)

    dis_reward = discriminator(gen_tensor).mean(dim=(1, 2))
    if mode == "mean":
        return dis_reward

    # Shaping terms use every step of the window except the last one.
    gt = gen_tensor[:, :-1]

    if mode == "avoid":
        dx, dy = gt[..., 1], gt[..., 2]
        pred_dist = torch.sqrt(dx * dx + dy * dy + eps)[..., 0]
        if r_avoid_quantile is not None:
            r_avoid = torch.quantile(pred_dist.reshape(-1), float(r_avoid_quantile)).item()
        # Zero outside r_avoid, linearly more negative the closer the predator gets.
        term = (-torch.relu(r_avoid - pred_dist)).mean(dim=(1, 2))
    else:  # "attack"
        dx, dy = gt[..., 0], gt[..., 1]
        dist = torch.sqrt(dx * dx + dy * dy + eps)
        # Smooth minimum over the last dim: -tau * logsumexp(-d / tau) ~= min(d).
        softmin = -tau_attack * torch.logsumexp(-dist / tau_attack, dim=-1)
        term = (-softmin).mean(dim=(1, 2))

    if center:
        term = term - term.mean().detach()

    # NOTE(review): with scale_mode="meanabs", dis_reward is measured uncentred while
    # term is (by default) centred. alpha therefore tracks the score's offset rather
    # than its spread; use scale_mode="std" if spread matching is the intent.
    alpha = alpha_coeff * (_scale(dis_reward) / _scale(term))
    return dis_reward + alpha * term


In [29]:
# Smoke test of the extended discriminator_reward (previous cell) on the expert
# windows: prey use the "avoid" shaping, predator the "attack" soft-min shaping.
# NOTE(review): autograd tracks these calls; wrap in torch.no_grad() if the results
# are only inspected.
prey_r = discriminator_reward(prey_discriminator, exp_prey_tensor, mode="avoid")
pred_r = discriminator_reward(pred_discriminator, exp_pred_tensor, mode="attack")