In [7]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

from utils.OpenAI_ES import *
from utils.env_utils import *
from utils.train_utils import *
from models.Buffer import Buffer
from marl_aquarium.aquarium_v0 import parallel_env
from models.Generator import GeneratorPolicy
from models.Discriminator import Discriminator

In [None]:
"""
ToDo's
- Schwarmmetriken berechnen
- CUDA-Toolkit installieren, damit Modell auf GPU

Falls Training keine Fortschritte:
    - Drehwinkel auf 30% beschränken 15% links 15% rechts
    - Nur nearest Neighbors verwenden
    - Struktur Netzwerk anpassen (Layer, Batch Normalization, Dropout)
    - PPO statt ES
    - Policy-Reward-Normalisierung über die Zeit (manchmal in Literatur empfohlen)
    - Clipping des Rewards durch Huber-Klut
    - Zu Beginn des Trainings Policies mit Expertendaten trainieren! - richtig gute Idee
"""

"\nToDo's\n- Schwarmmetriken berechnen\n- CUDA-Toolkit installieren, damit Modell auf GPU\n\nFalls Training keine Fortschritte:\n    - Drehwinkel auf 30% beschr채nken 15% links 15% rechts\n    - Nur nearest Neighbors verwenden\n    - Struktur Netzwerk anpassen (Layer, Batch Normalization, Dropout)\n    - PPO statt ES\n"

In [None]:
# =============================================================================
# Hyperparameters
# Every tunable value of the run lives in this cell; later cells only read
# them (sigma and the two policy learning rates are additionally decayed by
# gamma inside the training loop).
# =============================================================================

# --- Environment -------------------------------------------------------------
pred_count = 1                 # predators in the environment
prey_count = 32                # prey in the environment
action_count = 360             # forwarded as `action_count` to the env, the rollouts and the policy updates
num_frames = 9                 # clip length of both buffers / frames per generated trajectory
inital_gen_episodes = 38       # episodes used to pre-fill the generative buffer; 38 == len(expert_data)
num_generative_episodes = 5    # fresh episodes generated after every generation

# --- Buffer ------------------------------------------------------------------
buffer_size = 150              # max_length of the expert and the generative buffer

# --- Training ----------------------------------------------------------------
batch_size = 32                # samples drawn from each buffer per discriminator update
num_generations = 80           # convergence in Wu after 80
patience = 20                  # patience handed to both EarlyStopping instances
gen_dis_ratio = 5              # 1:5 -> policy update rounds per discriminator update

# --- ES perturbation ---------------------------------------------------------
num_perturbations = 32         # 32 for 8 CPU cores; 30 in the Wu paper
sigma = 0.1                    # passed to policy.update as `sigma`; decayed by gamma each generation
gamma = 0.9997                 # multiplicative per-generation decay for sigma and the policy learning rates
lr_pred_policy = 0.01
lr_prey_policy = 0.05          # high to balance training between predator and prey policy

# --- RMSprop (discriminators) ------------------------------------------------
lr_pred_dis = 0.0005
lr_prey_dis = 0.0005
alpha = 0.99                   # `alpha` argument of torch.optim.RMSprop
eps = 1e-8                     # `eps` argument of torch.optim.RMSprop
lambda_gp_pred = 10            # last argument of pred_discriminator.update (presumably gradient-penalty weight - confirm)
lambda_gp_prey = 10            # last argument of prey_discriminator.update (presumably gradient-penalty weight - confirm)

In [10]:
# Training folder and data locations.
#
# All paths are assembled with os.path.join instead of raw Windows literals
# (r"..\data\training"): on POSIX a backslash is an ordinary file-name
# character, so the literal form would create/look up one oddly named entry in
# the current directory. On Windows os.path.join yields exactly the same
# strings as before.
data_root = os.path.join("..", "data")

# Parent directory that collects the output folders of all training runs.
path = os.path.join(data_root, "training")

# One output folder per run, named after the start time (day.month.year_hour.minute).
timestamp = datetime.datetime.now().strftime("%d.%m.%Y_%H.%M")
folder_name = f"Training - {timestamp}"

# exist_ok=True: re-running the cell within the same minute reuses the folder
# instead of raising.
save_dir = os.path.join(path, folder_name)
os.makedirs(save_dir, exist_ok=True)

# Source of the expert data (read by expert_buffer.add_expert) and the
# location used by the optional Buffer.load calls.
data_path = os.path.join(data_root, "processed", "video_8min", "expert_tensors")
buffer_path = os.path.join(data_root, "buffer")

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Generator policies -------------------------------------------------------
# Construction order (predator policy, prey policy, predator discriminator,
# prey discriminator) is kept fixed so that any random initialisation consumes
# the RNG in the same sequence on every run.
pred_policy = GeneratorPolicy().to(device)
pred_policy.set_parameters(init=True)
optim_policy_pred = torch.optim.Adam(pred_policy.parameters(), lr=lr_pred_policy)

prey_policy = GeneratorPolicy().to(device)
prey_policy.set_parameters(init=True)
optim_policy_prey = torch.optim.Adam(prey_policy.parameters(), lr=lr_prey_policy)

# --- Discriminators (trained with RMSprop) --------------------------------------
pred_discriminator = Discriminator().to(device)
pred_discriminator.set_parameters(init=True)
optim_dis_pred = torch.optim.RMSprop(pred_discriminator.parameters(), lr=lr_pred_dis, alpha=alpha, eps=eps)

prey_discriminator = Discriminator().to(device)
prey_discriminator.set_parameters(init=True)
optim_dis_prey = torch.optim.RMSprop(prey_discriminator.parameters(), lr=lr_prey_dis, alpha=alpha, eps=eps)

# --- Environment ----------------------------------------------------------------
# Use the `action_count` hyperparameter (previously a hardcoded 360) so the
# environment always agrees with the value handed to generate_trajectories
# and policy.update.
env = parallel_env(predator_count=pred_count, prey_count=prey_count, action_count=action_count)

# --- Buffers --------------------------------------------------------------------
expert_buffer = Buffer(clip_length=num_frames, max_length=buffer_size, device=device)
generative_buffer = Buffer(clip_length=num_frames, max_length=buffer_size, device=device)

# --- Early stopping (one stopper per role) --------------------------------------
early_stopper_pred = EarlyStopping(patience=patience)
early_stopper_prey = EarlyStopping(patience=patience)

# Optional: restore previously saved buffers instead of rebuilding them.
# While these stay commented out both buffers start empty, so the two
# branches below always run.
#expert_buffer.load(buffer_path, type="expert", device=device)
#generative_buffer.load(buffer_path, type="generative", device=device)

if len(expert_buffer) == 0:
    print("Expert Buffer is empty, load data...")
    # NOTE(review): 33 equals pred_count + prey_count (1 + 32) with the current
    # hyperparameters - confirm whether it should be derived from them.
    expert_buffer.add_expert(data_path, detections=33)

    print("Storage of Expert Buffer: ", len(expert_buffer), "\n")

if len(generative_buffer) == 0:
    print("Generative Buffer is empty, generating data...")
    # Pre-fill the generative buffer with rollouts of the freshly initialised policies.
    generate_trajectories(buffer=generative_buffer,
                          pred_count=pred_count, prey_count=prey_count, action_count=action_count,
                          pred_policy=pred_policy, prey_policy=prey_policy,
                          num_frames=num_frames, num_generative_episodes=inital_gen_episodes)

    print("Storage of Generative Buffer: ", len(generative_buffer))

Expert Buffer is empty, load data...
Storage of Expert Buffer:  38 

Generative Buffer is empty, generating data...
Storage of Generative Buffer:  38


In [12]:
# Histories collected over the whole run; both are handed to
# save_checkpoint / save_models.
dis_metrics_pred = []
dis_metrics_prey = []

es_metrics_pred = []
es_metrics_prey = []

# Modes each ES round calls policy.update with, in this order.
ES_UPDATE_MODES = ("pairwise", "attention")
# A checkpoint is written whenever generation % CHECKPOINT_INTERVAL == 0.
CHECKPOINT_INTERVAL = 50


def run_es_round(policy, role, generation, sigma, lr_pred, lr_prey):
    """Run one ES update round for ``policy`` and return the combined stats.

    ``policy.update`` is called once per entry of ``ES_UPDATE_MODES``; the
    result of each later call is appended to the first result with ``+=``.

    Args:
        policy: the policy whose ``update`` method is invoked.
        role: role string forwarded as first argument ("predator" or "prey").
        generation: index of the current generation.
        sigma: current sigma value forwarded to ``update``.
        lr_pred: current predator policy learning rate.
        lr_prey: current prey policy learning rate.

    Returns:
        The object returned by the first ``update`` call, extended by the
        results of the remaining calls.
    """
    stats = None
    for mode in ES_UPDATE_MODES:
        mode_stats = policy.update(role, mode,
                                   pred_count, prey_count, action_count,
                                   pred_policy, prey_policy,
                                   pred_discriminator, prey_discriminator,
                                   num_perturbations, generation,
                                   lr_pred, lr_prey,
                                   sigma, gamma)
        if stats is None:
            stats = mode_stats
        else:
            stats += mode_stats
    return stats


for generation in range(num_generations):
    start_time = time.time()

    # Sample trajectories from the expert and the generative buffer.
    expert_pred_batch, expert_prey_batch = expert_buffer.sample(batch_size)
    policy_pred_batch, policy_prey_batch = generative_buffer.sample(batch_size)

    # One discriminator update per role and generation.
    dis_metric_pred = pred_discriminator.update(expert_pred_batch, policy_pred_batch, optim_dis_pred, lambda_gp_pred)
    dis_metrics_pred.append(dis_metric_pred)

    dis_metric_prey = prey_discriminator.update(expert_prey_batch, policy_prey_batch, optim_dis_prey, lambda_gp_prey)
    dis_metrics_prey.append(dis_metric_prey)

    # gen_dis_ratio policy rounds per discriminator update; within a round the
    # predator is updated before the prey.
    for _ in range(gen_dis_ratio):
        pred_stats = run_es_round(pred_policy, "predator", generation,
                                  sigma, lr_pred_policy, lr_prey_policy)
        es_metrics_pred.append(pred_stats)

        prey_stats = run_es_round(prey_policy, "prey", generation,
                                  sigma, lr_pred_policy, lr_prey_policy)
        es_metrics_prey.append(prey_stats)

    # Decay sigma and both policy learning rates once per generation.
    # NOTE(review): policy.update also receives `generation` and `gamma`, so it
    # may apply a decay of its own - confirm the two do not compound unintentionally.
    sigma *= gamma
    lr_pred_policy *= gamma
    lr_prey_policy *= gamma

    # Generate new trajectories with the updated policies.
    generate_trajectories(buffer=generative_buffer,
                          pred_count=pred_count, prey_count=prey_count, action_count=action_count,
                          pred_policy=pred_policy, prey_policy=prey_policy,
                          num_frames=num_frames, num_generative_episodes=num_generative_episodes)

    # Score only the episodes that were just generated.
    pred_batch, prey_batch = generative_buffer.get_latest(num_generative_episodes)
    r_pred, r_prey = discriminator_reward(pred_batch, prey_batch, pred_discriminator, prey_discriminator)

    finish_time, time_str = remaining_time(start_time, num_generations, generation)

    # pred_stats / prey_stats hold the stats of the LAST inner round only, so
    # these averages (and the early-stopping signal) ignore the earlier rounds
    # of this generation. They are unbound if gen_dis_ratio is 0.
    avg_es_pred = np.mean([m['avg_reward_diff'] for m in pred_stats])
    avg_es_prey = np.mean([m['avg_reward_diff'] for m in prey_stats])

    print(f"[Generation {generation+1}/{num_generations}] - Time: {time_str} - Estimated Finish: {finish_time}" )
    print(f"Predator | Avg. ES-Reward: {avg_es_pred:.4f} | DisReward (Fitness): {r_pred:.4f} | DisLoss: {dis_metric_pred[0]:.4f} | LR: {lr_pred_policy:.4f} | Sigma: {sigma:.4f}")
    print(f"Prey     | Avg. ES-Reward: {avg_es_prey:.4f} | DisReward (Fitness): {r_prey:.4f} | DisLoss: {dis_metric_prey[0]:.4f} | LR: {lr_prey_policy:.4f} | Sigma: {sigma:.4f} \n")

    # Stop as soon as either role's early stopper fires; the final save below
    # still runs after the break.
    if early_stopper_pred(avg_es_pred, "predator") or early_stopper_prey(avg_es_prey, "prey"):
        break

    if generation % CHECKPOINT_INTERVAL == 0:
        save_checkpoint(save_dir, generation,
                        pred_policy, prey_policy,
                        pred_discriminator, prey_discriminator,
                        optim_dis_pred, optim_dis_prey,
                        expert_buffer, generative_buffer,
                        dis_metrics_pred, dis_metrics_prey,
                        es_metrics_pred, es_metrics_prey)


# Save the final models, optimizers, buffers and metric histories.
save_models(save_dir,
            pred_policy, prey_policy,
            pred_discriminator, prey_discriminator,
            optim_dis_pred, optim_dis_prey,
            expert_buffer, generative_buffer,
            dis_metrics_pred, dis_metrics_prey,
            es_metrics_pred, es_metrics_prey)

[Generation 1/150] - Time: 13:02 - Estimated Finish: 2025-07-13 03:57:47
Predator | Avg. ES-Reward: -0.0019 | DisReward (Fitness): 0.0910 | DisLoss: 5.1951 | LR: 0.0010 | Sigma: 0.9900
Prey     | Avg. ES-Reward: -0.0021 | DisReward (Fitness): 0.1422 | DisLoss: 38.9255 | LR: 0.0005 | Sigma: 0.9900 

Checkpoint successfully saved! 
 
[Generation 2/150] - Time: 06:35 - Estimated Finish: 2025-07-12 12:02:44
Predator | Avg. ES-Reward: -0.0033 | DisReward (Fitness): 0.0894 | DisLoss: 4.8232 | LR: 0.0010 | Sigma: 0.9801
Prey     | Avg. ES-Reward: 0.0007 | DisReward (Fitness): 0.1452 | DisLoss: 32.1730 | LR: 0.0005 | Sigma: 0.9801 

[Generation 3/150] - Time: 04:13 - Estimated Finish: 2025-07-12 06:21:35
Predator | Avg. ES-Reward: 0.0068 | DisReward (Fitness): 0.0873 | DisLoss: 4.5734 | LR: 0.0010 | Sigma: 0.9703
Prey     | Avg. ES-Reward: 0.0030 | DisReward (Fitness): 0.1483 | DisLoss: 28.4427 | LR: 0.0005 | Sigma: 0.9703 

[Generation 4/150] - Time: 03:12 - Estimated Finish: 2025-07-12 04:01