In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import torch
from pathlib import Path
from ema_pytorch import EMA
from datetime import datetime
from utils.sim_utils import *
from utils.eval_utils import *
from utils.train_utils import *
from utils.couzin_utils import *
from utils.vec_sim_utils import *
from utils.encoder_utils import *
from utils.dataset_utils import *
from geomloss import SamplesLoss
from utils.mmd_loss import MMDLoss
from models.Generator import ModularPolicy
from models.Discriminator import Discriminator

  from .autonotebook import tqdm as notebook_tqdm




In [2]:
# ---------------------------------------------------------------------------
# Notebook configuration. Every name below is a module-level global that
# later cells read directly, so the names themselves must not change.
# ---------------------------------------------------------------------------

# --- Expert ---
max_steps = 300

# --- Training ---
num_generations = 4000
gamma = 0.999
deterministic = False  # BC pretrain
performance_eval = 5
num_perturbations = 64

# --- Prey: policy ---
lr_prey_policy = 2e-4
sigma_prey = 0.1

# --- Prey: discriminator ---
prey_dis_balance_factor = 2
prey_noise = 0.005
lr_prey_disc = 5e-4  # passed as `lr` to the prey discriminator's RMSprop optimiser
lambda_gp_prey = 5
prey_update_mode = "avoid"

# --- Predator: policy ---
lr_pred_policy = 1e-4
sigma_pred = 0.08

# --- Predator: discriminator ---
pred_dis_balance_factor = 2
pred_noise = 0.005
lr_pred_disc = 2e-4  # passed as `lr` to the predator discriminator's RMSprop optimiser
lambda_gp_pred = 10
pred_update_mode = "mean"

# --- Environment ---
height = 50
width = 50
prey_speed = 5
pred_speed = 5
step_size = 0.5
theta_dot_max = 0.5
# Per-step turn limit: theta_dot_max scaled by the step size, plus a tiny
# epsilon (presumably so values exactly at the limit survive float
# comparisons -- TODO confirm against the simulator).
max_turn = float(theta_dot_max * step_size) + 1e-12

pert_steps = 100
init_steps = 500  # used as `max_steps` for the Couzin expert rollout

# Positional bundle of environment settings; the element order is part of the
# contract with whatever unpacks it (consumers are not visible in this cell).
env_settings = (
    height,
    width,
    prey_speed,
    pred_speed,
    step_size,
    max_turn,
    pert_steps,
)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
%matplotlib tk
exp_pred_sequence, exp_prey_sequence, couzin_metrics, actions, init_pool = run_couzin_simulation(
                                        visualization="off", 
                                        max_steps=init_steps, 
                                        constant_speed=prey_speed, shark_speed=pred_speed, 
                                        area_width=width, area_height=height,
                                        dt = step_size,
                                        alpha=0.01,
                                        theta_dot_max=theta_dot_max, theta_dot_max_shark=theta_dot_max,
                                        number_of_sharks=1, n=32)

exp_pred_sequence = exp_pred_sequence.to(device)
exp_prey_sequence = exp_prey_sequence.to(device)
init_pool = init_pool.to(device)

print("\nPred Shape:", exp_pred_sequence.shape)
print("Prey Shape:", exp_prey_sequence.shape)

exp_pred_tensor = sliding_window(exp_pred_sequence, window_size=10)
exp_prey_tensor = sliding_window(exp_prey_sequence, window_size=10)

print("\nPred Tensor Shape:", exp_pred_tensor.shape)
print("Prey Tensor Shape:", exp_prey_tensor.shape)



Pred Shape: torch.Size([500, 1, 32, 5])
Prey Shape: torch.Size([500, 32, 32, 6])

Pred Tensor Shape: torch.Size([491, 10, 1, 32, 5])
Prey Tensor Shape: torch.Size([491, 10, 32, 32, 6])


In [4]:
# One augmentation module is shared by both encoder training runs.
aug = TrajectoryAugmentation(noise_std=0.01, neigh_drop=0.10, feat_drop=0.05).to(device)

# --- Prey encoder -----------------------------------------------------------
# NOTE: statement order is kept exactly as before; module construction and
# training presumably consume RNG state, so reordering could change results.
prey_encoder = TransitionEncoder(features=5, embd_dim=32, z=32).to(device)
prey_projector = VicRegProjector(input_dim=64).to(device)
prey_optimizer = torch.optim.Adam(
    [*prey_encoder.parameters(), *prey_projector.parameters()],
    lr=1e-3,
    weight_decay=1e-6,
)
train_encoder(
    prey_encoder,
    prey_projector,
    aug=aug,
    exp_tensor=exp_prey_tensor,
    epochs=1200,
    optimizer=prey_optimizer,
    role="prey",
)

# Freeze: no gradients flow into the encoder's parameters from here on.
for p in prey_encoder.parameters():
    p.requires_grad_(False)

print("Prey Encoder trained & frozen.\n")

# --- Predator encoder -------------------------------------------------------
pred_encoder = TransitionEncoder(features=4, embd_dim=32, z=32).to(device)
pred_projector = VicRegProjector(input_dim=64).to(device)
pred_optimizer = torch.optim.Adam(
    [*pred_encoder.parameters(), *pred_projector.parameters()],
    lr=1e-3,
    weight_decay=1e-6,
)
train_encoder(
    pred_encoder,
    pred_projector,
    aug=aug,
    exp_tensor=exp_pred_tensor,
    epochs=800,
    optimizer=pred_optimizer,
    role="predator",
)

# Freeze the predator encoder the same way.
for p in pred_encoder.parameters():
    p.requires_grad_(False)

print("Predator Encoder trained & frozen.")

epoch 010: loss=23.101410 sim=0.0550 std=1.2995 cov=0.4469 std_mean=0.350
epoch 020: loss=22.217934 sim=0.0427 std=1.2902 cov=0.3597 std_mean=0.355
epoch 030: loss=21.511539 sim=0.0409 std=1.2332 cov=0.3982 std_mean=0.383
epoch 040: loss=21.125330 sim=0.0470 std=1.1800 cov=0.4501 std_mean=0.410
epoch 050: loss=20.637005 sim=0.0476 std=1.1417 cov=0.4644 std_mean=0.429
epoch 060: loss=20.217930 sim=0.0431 std=1.1076 cov=0.5051 std_mean=0.446
epoch 070: loss=20.332235 sim=0.0508 std=1.0937 cov=0.5313 std_mean=0.453
epoch 080: loss=20.114101 sim=0.0484 std=1.0799 cov=0.5414 std_mean=0.460
epoch 090: loss=20.111403 sim=0.0522 std=1.0651 cov=0.5658 std_mean=0.467
epoch 100: loss=19.892855 sim=0.0527 std=1.0612 cov=0.5317 std_mean=0.469
epoch 110: loss=19.920315 sim=0.0508 std=1.0565 cov=0.5606 std_mean=0.472
epoch 120: loss=19.913464 sim=0.0552 std=1.0376 cov=0.5937 std_mean=0.481
epoch 130: loss=20.070309 sim=0.0615 std=1.0589 cov=0.5296 std_mean=0.471
epoch 140: loss=19.719280 sim=0.0577 s

In [5]:
# Prey discriminator, built on the prey encoder created in the encoder cell,
# with its own RMSprop optimiser.
# `set_parameters(init=True)` is a project method -- presumably an initial
# parameter setup; confirm in models/Discriminator.
prey_discriminator = Discriminator(
    encoder=prey_encoder,
    role="prey",
    z_dim=32,
).to(device)
prey_discriminator.set_parameters(init=True)
optim_disc_prey = torch.optim.RMSprop(
    prey_discriminator.parameters(),
    lr=lr_prey_disc,
    alpha=0.99,
    eps=1e-08,
)

# Predator discriminator: same construction, predator encoder and learning rate.
pred_discriminator = Discriminator(
    encoder=pred_encoder,
    role="predator",
    z_dim=32,
).to(device)
pred_discriminator.set_parameters(init=True)
optim_disc_pred = torch.optim.RMSprop(
    pred_discriminator.parameters(),
    lr=lr_pred_disc,
    alpha=0.99,
    eps=1e-08,
)

In [None]:
def discriminator_reward(discriminator, gen_tensor, mode="mean", lambda_avoid=None, lambda_attack=None):
    """Score generated windows with a discriminator, optionally adding a distance term.

    Args:
        discriminator: Callable applied to ``gen_tensor``. Its output is averaged
            over dims 1 and 2, so it must have at least three dimensions.
        gen_tensor: 5-D tensor. For the distance terms, index 1 and 2 of the last
            axis are read as the x/y offsets to each neighbour, the last entry
            along axis 1 is dropped, and axis 3 indexes neighbours.
            (In this notebook the windowed tensors are
            ``[windows, window_size, agents, neighbours, features]`` -- TODO
            confirm the axis meaning against ``sliding_window``.)
        mode: One of ``"mean"``, ``"avoid"`` or ``"attack"``.
        lambda_avoid: Weight of the avoidance term; required when ``mode="avoid"``.
        lambda_attack: Weight of the attack term; required when ``mode="attack"``.

    Returns:
        The discriminator output averaged over dims 1 and 2, plus (for
        ``"avoid"`` / ``"attack"``) the weighted distance term.

    Raises:
        ValueError: If ``mode`` is not recognised, or if the lambda weight
            required by the selected mode is ``None``. These cases previously
            fell through every branch and returned ``None`` silently.
    """
    # Validate up front so a misconfigured call fails loudly instead of
    # handing `None` to the caller.
    if mode not in ("mean", "avoid", "attack"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'mean', 'avoid' or 'attack'."
        )
    if mode == "avoid" and lambda_avoid is None:
        raise ValueError("mode='avoid' requires lambda_avoid to be set.")
    if mode == "attack" and lambda_attack is None:
        raise ValueError("mode='attack' requires lambda_attack to be set.")

    # Mean discriminator score per leading-axis entry.
    matrix = discriminator(gen_tensor)
    dis_reward = matrix.mean(dim=(1, 2))

    if mode == "mean":
        return dis_reward

    # Euclidean distance to every neighbour, shared by both shaped modes.
    # The epsilon is added after the sqrt, so it only keeps the distance
    # strictly positive.
    dx = gen_tensor[:, :-1, :, :, 1]
    dy = gen_tensor[:, :-1, :, :, 2]
    dist = torch.sqrt(dx**2 + dy**2) + 1e-8

    if mode == "avoid":
        # Neighbour slot 0 is treated as the predator: a larger distance
        # yields a larger reward.
        pred_dist = dist[:, :, :, 0]
        avoid_reward = pred_dist.mean(dim=(1, 2))
        return dis_reward + lambda_avoid * avoid_reward

    # mode == "attack": reward closeness to the nearest prey.
    # NOTE(review): slot 0 is skipped here. The predator tensor in this
    # notebook has 32 neighbour slots for n=32 prey; if every slot is a prey,
    # this drops the first one -- confirm the neighbour layout.
    prey_dist = dist[:, :, :, 1:]
    nearest_prey_dist = prey_dist.min(dim=-1).values
    attack_reward = (-nearest_prey_dist).mean(dim=(1, 2))
    return dis_reward + lambda_attack * attack_reward

In [23]:
# Sanity check: show the windowed predator tensor's shape
# (recorded output: torch.Size([491, 10, 1, 32, 5])).
exp_pred_tensor.shape

torch.Size([491, 10, 1, 32, 5])

In [24]:
# Score the expert windows with both shaped rewards: prey use the avoidance
# term, the predator uses the attack term.
# NOTE(review): the saved output under this cell lists per-term means
# ("Avoidance Reward Mean", ...) that discriminator_reward, as defined in this
# notebook, does not print -- the output looks stale; re-run to refresh it.
prey_r = discriminator_reward(
    prey_discriminator,
    exp_prey_tensor,
    mode="avoid",
    lambda_avoid=0.2,
)
pred_r = discriminator_reward(
    pred_discriminator,
    exp_pred_tensor,
    mode="attack",
    lambda_attack=0.1,
)

Avoidance Reward Mean: 0.6321227550506592
Discriminator Reward Mean: 0.12390846759080887
Combined Reward Mean: 0.2503330111503601

Attack Reward Mean: -0.2360522747039795
Discriminator Reward Mean: 0.06180749833583832
Combined Reward Mean: 0.03820226714015007
