In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import pickle
import datetime
from utils.es_utils import *
from utils.env_utils import *
from utils.train_utils import *
from marl_aquarium import aquarium_v0
from models.Buffer import Buffer, Pool
from models.Generator import GeneratorPolicy
from models.Discriminator import Discriminator

In [None]:
"""
Multi-Agent Imitation lernt nur so schnell wie das schwächste Glied.

ToDo's
- Training auf HL Data
- GitHub anlegen mit Env

Nach Urlaub:
    - Velocity in Netz einbauen
    - Struktur Netzwerk anpassen (Layer, Batch Normalization, Dropout)
    - Thesis anmelden
"""

'\nMulti-Agent Imitation lernt nur so schnell wie das schwächste Glied.\n\nToDo\'s\n- Training auf HL Data\n- GitHub anlegen mit Env\n- Unterlagen Anmeldung vorbereiten: "Imitating Predator-Prey Swarm Dynamics using Coevolutionary Generative Adversarial Imitation Learning"\n\nNach Urlaub:\n    - Velocity in Netz einbauen\n    - Struktur Netzwerk anpassen (Layer, Batch Normalization, Dropout)\n'

In [3]:
# ---------------------------------------------------------------------------
# Hyperparameters -- every tunable value of a training run lives in this cell.
# ---------------------------------------------------------------------------

# --- Environment ---
pred_count = 1
prey_count = 32
action_count = 360
total_detections = 33       # also selects the expert-data sub-folder to load
use_walls = True

# --- Trajectory generation (fills the generative buffer) ---
gt_gen_episodes = 20        # passed as num_generative_episodes
gt_clip_length = 30         # passed as clip_length

# --- Replay buffers (maximum lengths) ---
pred_buffer_size = 800
prey_buffer_size = 24_000

# --- Training loop ---
num_generations = 50
pred_batch_size = 512
prey_batch_size = 1024
gen_dis_ratio = 4           # policy-update rounds per discriminator update

# --- Early stopping ---
start_es_pred = 70
start_es_prey = 50
patience = 20

# --- ES perturbation (policy updates) ---
num_perturbations = 16
pert_clip_length = 28
sigma = 0.12
gamma = 0.9998              # per-generation multiplier applied to sigma and the policy learning rates
lr_pred_policy = 0.003
lr_prey_policy = 0.009

# --- RMSprop (discriminator optimisers) ---
lr_pred_dis = 0.0015
lr_prey_dis = 0.008
alpha = 0.99
eps_dis = 1e-8
lambda_gp_pred = 4          # presumably the gradient-penalty weight -- confirm in Discriminator.update
lambda_gp_prey = 3

In [4]:
# Create a fresh, timestamped output folder for this training run.
# Paths are assembled with os.path.join instead of hard-coded backslashes:
# on Windows the resulting strings are unchanged, while on Linux/macOS a
# literal "..\data\training" would be a single file name, not a nested path.
path = os.path.join("..", "data", "training")
timestamp = datetime.datetime.now().strftime("%d.%m.%Y_%H.%M")
folder_name = f"Training - {timestamp}"
save_dir = os.path.join(path, folder_name)
# exist_ok=True: re-running the cell within the same minute reuses the folder.
os.makedirs(save_dir, exist_ok=True)

# Expert data locations; the sub-folder is selected by `total_detections`.
_interactions_dir = os.path.join("..", "data", "processed", "pred_prey_interactions")
data_path = os.path.join(_interactions_dir, "expert_tensors", str(total_detections))
ftw_path = os.path.join(_interactions_dir, "full_track_windows", str(total_detections))

In [5]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _new_policy():
    """Create a GeneratorPolicy on `device` and initialise its parameters."""
    policy = GeneratorPolicy().to(device)
    policy.set_parameters(init=True)
    return policy


def _new_discriminator(lr):
    """Create a Discriminator on `device` together with its RMSprop optimiser.

    `lr` is the optimiser learning rate; `alpha` and `eps_dis` are read from
    the hyperparameter cell. Returns the pair (discriminator, optimiser).
    """
    discriminator = Discriminator().to(device)
    discriminator.set_parameters(init=True)
    optimiser = torch.optim.RMSprop(discriminator.parameters(), lr=lr, alpha=alpha, eps=eps_dis)
    return discriminator, optimiser


# Construction order is kept (policies first, then discriminators) so that any
# randomness consumed during initialisation is drawn in the same sequence.
pred_policy = _new_policy()
prey_policy = _new_policy()

pred_discriminator, optim_dis_pred = _new_discriminator(lr_pred_dis)
prey_discriminator, optim_dis_prey = _new_discriminator(lr_prey_dis)

# Both buffers share the same capacity settings and device.
_buffer_kwargs = dict(pred_max_length=pred_buffer_size, prey_max_length=prey_buffer_size, device=device)
expert_buffer = Buffer(**_buffer_kwargs)
generative_buffer = Buffer(**_buffer_kwargs)

# Pool of start frames, filled from the full-track-window data under `ftw_path`.
start_frame_pool = Pool(max_length=1000, device=device)
start_frame_pool.generate_startframes(ftw_path)

# One early stopper per role; they differ only in their `start_es` value.
early_stopper_pred = EarlyStoppingWasserstein(patience=patience, start_es=start_es_pred)
early_stopper_prey = EarlyStoppingWasserstein(patience=patience, start_es=start_es_prey)

In [6]:
# Load Expert Data from local storage
# Fills `expert_buffer` from the files under `data_path`, then reports how many
# entries the predator and prey parts of the buffer hold.
# NOTE(review): the saved output of this cell shows
# "ValueError: too many values to unpack (expected 4)". The unpack below expects
# two values, so the error presumably originates inside `add_expert` -- verify
# that the data under `data_path` matches the format `Buffer.add_expert` expects.
# NOTE(review): the message is printed unconditionally; the cell does not check
# whether the buffer is actually empty before loading.
print("Expert Buffer is empty, load data...")
expert_buffer.add_expert(data_path)
# lengths() is unpacked as (predator, prey) here.
len_exp_pred, len_exp_prey = expert_buffer.lengths()

print("Storage of Predator Expert Buffer: ", len_exp_pred)
print("Storage of Prey Expert Buffer: ", len_exp_prey, "\n")

Expert Buffer is empty, load data...


ValueError: too many values to unpack (expected 4)

In [None]:
# Pretrain Policies with Expert Data
# Each policy is replaced by the result of `pretrain_policy` run on the expert
# buffer for its role. The batch sizes come from the hyperparameter cell
# (`pred_batch_size` / `prey_batch_size`) instead of repeating the literals
# 512 / 1024, so changing the configuration changes pretraining as well.
print("Pretraining Policies with Behavioral Cloning on Expert Data...\n")
pred_policy = pretrain_policy(pred_policy, expert_buffer, role='predator', pred_bs=pred_batch_size, epochs=200, lr=1e-3, save_dir=save_dir)
prey_policy = pretrain_policy(prey_policy, expert_buffer, role='prey', prey_bs=prey_batch_size, epochs=200, lr=1e-3, save_dir=save_dir)

In [None]:
# Generate Trajectories for Generative Buffer
# Rolls out the current predator/prey policies from frames of the start-frame
# pool and stores the result in `generative_buffer`.
# `use_walls` is taken from the hyperparameter cell (it was hard-coded to True
# here) so this initial fill uses the same setting as the rollouts generated
# inside the training loop.
print("Generative Buffer is empty, generating data...")
generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool,
                        pred_count=pred_count, prey_count=prey_count, action_count=action_count,
                        pred_policy=pred_policy, prey_policy=prey_policy,
                        clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes,
                        use_walls=use_walls)

# lengths() is unpacked as (predator, prey) here.
len_gen_pred, len_gen_prey = generative_buffer.lengths()
print("Storage of Predator Generative Buffer: ", len_gen_pred)
print("Storage of Prey Generative Buffer: ", len_gen_prey)

Generative Buffer is empty, generating data...
Storage of Predator Generative Buffer:  600
Storage of Prey Generative Buffer:  19200


In [None]:
# Adversarial training loop.
# Per generation: one update of each discriminator, then `gen_dis_ratio` rounds
# of policy updates (each followed by fresh rollouts into the generative
# buffer), then decay of sigma / policy learning rates, logging, early-stopping
# check and a periodic checkpoint.

# Per-generation discriminator metrics, one entry per generation.
dis_metrics_pred = []
dis_metrics_prey = []

# Policy-update statistics, one entry per inner update round.
es_metrics_pred = []
es_metrics_prey = []


def _run_es_updates(policy, role, generation):
    """Run both policy updates ("pairwise", then "attention") for one role.

    Args:
        policy: the policy whose `update` method is called.
        role: "predator" or "prey", forwarded as the first argument of `update`.
        generation: current generation index, forwarded to `update`.

    Returns:
        The list returned by the "pairwise" update, extended in place with the
        result of the "attention" update.

    All remaining arguments are read from the notebook's globals at call time,
    so the decayed values of `sigma` and the policy learning rates are used.
    """
    def _update(component):
        return policy.update(role, component,
                             pred_count, prey_count, action_count,
                             pred_policy, prey_policy,
                             pred_discriminator, prey_discriminator,
                             num_perturbations, generation,
                             lr_pred_policy, lr_prey_policy,
                             sigma, gamma, clip_length=pert_clip_length,
                             use_walls=use_walls, start_frame_pool=start_frame_pool)

    stats = _update("pairwise")
    stats += _update("attention")
    return stats


for generation in range(num_generations):
    start_time = time.time()

    # Sample traj from expert and generative buffer
    expert_pred_batch, expert_prey_batch = expert_buffer.sample(pred_batch_size, prey_batch_size)
    policy_pred_batch, policy_prey_batch = generative_buffer.sample(pred_batch_size, prey_batch_size)

    # Predator discriminator update
    dis_metric_pred = pred_discriminator.update(expert_pred_batch, policy_pred_batch, optim_dis_pred, lambda_gp_pred)
    dis_metrics_pred.append(dis_metric_pred)

    # Prey discriminator update
    dis_metric_prey = prey_discriminator.update(expert_prey_batch, policy_prey_batch, optim_dis_prey, lambda_gp_prey)
    dis_metrics_prey.append(dis_metric_prey)

    for i in range(gen_dis_ratio):
        # Order matters: the predator is updated before the prey, and each
        # update receives both current policies.
        pred_stats = _run_es_updates(pred_policy, "predator", generation)
        es_metrics_pred.append(pred_stats)

        prey_stats = _run_es_updates(prey_policy, "prey", generation)
        es_metrics_prey.append(prey_stats)

        # Generate new trajectories with updated policies
        generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool,
                                pred_count=pred_count, prey_count=prey_count, action_count=action_count,
                                pred_policy=pred_policy, prey_policy=prey_policy,
                                clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes,
                                use_walls=use_walls)

    # also reduce globally
    # NOTE: this overwrites the notebook-level values; re-run the hyperparameter
    # cell before re-running this cell to start again from the configured values.
    sigma *= gamma
    lr_pred_policy *= gamma
    lr_prey_policy *= gamma

    last_epoch_duration = time.time() - start_time
    estimated_time, epoch_time = remaining_time(num_generations, last_epoch_duration, current_generation=generation)

    # Averages cover the stats of the last inner round of this generation only.
    avg_es_pred = np.mean([m['avg_reward_diff'] for m in pred_stats])
    avg_es_prey = np.mean([m['avg_reward_diff'] for m in prey_stats])

    print(f"[Generation {generation+1}/{num_generations}] - Time: {epoch_time} - Estimated Finish: {estimated_time}" )
    print(f"Predator | Avg. ES-Reward: {avg_es_pred:.4f} | Wasserstein Loss: {dis_metric_pred[0]:.4f} | Expert Scores: {dis_metric_pred[2]:.4f} | Policy Scores: {dis_metric_pred[3]:.4f}")
    print(f"Prey     | Avg. ES-Reward: {avg_es_prey:.4f} | Wasserstein Loss: {dis_metric_prey[0]:.4f} | Expert Scores: {dis_metric_prey[2]:.4f} | Policy Scores: {dis_metric_prey[3]:.4f}\n")

    # `or` short-circuits: the prey stopper is only consulted when the predator
    # stopper did not already request a stop.
    if early_stopper_pred(dis_metric_pred[0], generation, "predator") or early_stopper_prey(dis_metric_prey[0], generation, "prey"):
        break

    # `generation` is 0-based, so this fires after the 1st, 26th, 51st, ... generation.
    if generation % 25 == 0:
        save_checkpoint(save_dir, generation,
                        pred_policy, prey_policy,
                        pred_discriminator, prey_discriminator,
                        optim_dis_pred, optim_dis_prey,
                        expert_buffer, generative_buffer,
                        dis_metrics_pred, dis_metrics_prey,
                        es_metrics_pred, es_metrics_prey)


# Save models
save_models(save_dir,
            pred_policy, prey_policy,
            pred_discriminator, prey_discriminator,
            optim_dis_pred, optim_dis_prey,
            expert_buffer, generative_buffer,
            dis_metrics_pred, dis_metrics_prey,
            es_metrics_pred, es_metrics_prey)

[Generation 1/50] - Time: 19:01 - Estimated Finish: 06.08.2025 03:39:59
Predator | Avg. ES-Reward: 0.0279 | Wasserstein Loss: 3.1616 | Expert Scores: -0.0520 | Policy Scores: -0.0149
Prey     | Avg. ES-Reward: 0.7099 | Wasserstein Loss: 2.4791 | Expert Scores: 0.1916 | Policy Scores: 0.2092

Checkpoint successfully saved! 
 
[Generation 2/50] - Time: 19:22 - Estimated Finish: 06.08.2025 03:56:31
Predator | Avg. ES-Reward: 0.3146 | Wasserstein Loss: 1.7984 | Expert Scores: 0.0576 | Policy Scores: -0.4058
Prey     | Avg. ES-Reward: -0.2045 | Wasserstein Loss: 4.4741 | Expert Scores: 1.5678 | Policy Scores: -3.1331

[Generation 3/50] - Time: 19:08 - Estimated Finish: 06.08.2025 03:45:35
Predator | Avg. ES-Reward: 0.5242 | Wasserstein Loss: -1.0042 | Expert Scores: 0.4239 | Policy Scores: -1.1570
Prey     | Avg. ES-Reward: 0.4680 | Wasserstein Loss: 2.1963 | Expert Scores: -0.0532 | Policy Scores: 1.4880

[Generation 4/50] - Time: 19:06 - Estimated Finish: 06.08.2025 03:44:10
Predator | Av