In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import os
import time
import random
import datetime

import numpy as np
import torch
import tqdm

from utils.OpenAI_ES import *
from utils.env_utils import *
from utils.train_utils import *
from models.Buffer import Buffer
from marl_aquarium import aquarium_v0
from models.Generator import GeneratorPolicy
from models.Discriminator import Discriminator

In [None]:
"""
ToDo's
- Anscheinend wird in prey_to_tensor der Räuber nicht berücksichtigt! - anscheinend auch bei generativen Daten (get_state_actions)

Beschleunigen:
- Parallelization von ES einbauen
- parallel_step() in env nutzen!
- Samples für generative Buffer direkt während update threaden
- to.device() einbauen (auch get_rollouts)
- get_env_state_actions() mit Numpy Arrays umsetzen statt df

- Early-Stopping

Anschließend (während training):
- Evaluation der models
- File schreiben indem trainiere Policies geladen und Env gezeigt wird
"""

"\nToDo's\n- Anscheinend wird in prey_to_tensor der Räuber nicht berücksichtigt! - anscheinend auch bei generativen Daten (get_state_actions)\n\nBeschleunigen:\n- Parallelization von ES einbauen\n- parallel_step() in env nutzen!\n- Samples für generative Buffer direkt während update threaden\n- to.device() einbauen (auch get_rollouts)\n- get_env_state_actions() mit Numpy Arrays umsetzen statt df\n\nAnschließend (während training):\n- Evaluation der models\n- File schreiben indem trainiere Policies geladen und Env gezeigt wird\n"

In [3]:
# Hyperparameters

# Reproducibility: seed the RNGs this notebook touches directly, before any
# model is initialised. NOTE(review): the aquarium env may keep its own RNG,
# so full determinism is not guaranteed — confirm.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the CPU and all CUDA generators

# Environment
pred_count = 1
prey_count = 32
num_frames = 9                # clip length used by the buffers and by get_rollouts
num_generative_episodes = 8   # rollouts generated per epoch (previously tried 150)

# Buffer
buffer_size = 200  # max clips per buffer; the expert data loaded 150 clips in the recorded run
num_trim = 20      # NOTE(review): not referenced by the cells shown here — confirm it is still needed

# Training
batch_size = 8  # clips sampled from each buffer per discriminator update
epochs = 3

# Policy (OpenAI-ES)
lr_pred_policy = 0.1  # like Wu paper
lr_prey_policy = 0.1
sigma = 0.1           # presumably the ES perturbation std — confirm in OpenAI_ES
gamma = 0.99
num_generations = 2   # 10
num_perturbations = 2  # 30 in Wu paper

# Discriminator
lr_pred_dis = 0.0001
lr_prey_dis = 0.0001

# Presumably gradient-penalty weights for the discriminator update — confirm.
lambda_gp_pred = 1
lambda_gp_prey = 1


In [4]:
# Training folder: every run gets its own timestamped sub-folder.
# os.path.join keeps the paths valid on Windows and POSIX alike
# (the former raw "..\\data\\..." strings only worked on Windows).
path = os.path.join("..", "data", "training")
timestamp = datetime.datetime.now().strftime("%d.%m.%Y_%H.%M")
folder_name = f"Training - {timestamp}"

save_dir = os.path.join(path, folder_name)
os.makedirs(save_dir, exist_ok=True)

# Pre-processed expert tensors (32-prey recording).
data_path = os.path.join("..", "data", "processed", "video_8min", "tensors", "32 Preys")

In [5]:
# Models, optimizers, environment and replay buffers.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator policies. NOTE(review): the training loop passes lr_*_policy
# straight to `update` (OpenAI-ES) and never references optim_policy_pred /
# optim_policy_prey — confirm whether these optimizers are still needed.
pred_policy = GeneratorPolicy().to(device)
pred_policy.set_parameters(init=True)
optim_policy_pred = torch.optim.Adam(pred_policy.parameters(), lr=lr_pred_policy)

prey_policy = GeneratorPolicy().to(device)
prey_policy.set_parameters(init=True)
optim_policy_prey = torch.optim.Adam(prey_policy.parameters(), lr=lr_prey_policy)

# Discriminators; their Adam optimizers are passed to `update` in the training loop.
pred_discriminator = Discriminator().to(device)
pred_discriminator.set_parameters(init=True)
optim_dis_pred = torch.optim.Adam(pred_discriminator.parameters(), lr=lr_pred_dis)

prey_discriminator = Discriminator().to(device)
prey_discriminator.set_parameters(init=True)
optim_dis_prey = torch.optim.Adam(prey_discriminator.parameters(), lr=lr_prey_dis)

env = aquarium_v0.env(predator_count=pred_count, prey_count=prey_count, action_count=360)

# Buffer handling
expert_buffer = Buffer(clip_length=num_frames, max_length=buffer_size)
generative_buffer = Buffer(clip_length=num_frames, max_length=buffer_size)
buffer_dir = os.path.join(save_dir, "buffers")

# Try to reuse saved buffers. NOTE: save_dir carries a fresh timestamp, so this
# usually finds nothing and we fall through to rebuilding below.
try:
    expert_buffer.load(buffer_dir, "expert")
    generative_buffer.load(buffer_dir, "generative")
except Exception as err:  # not a bare except: Ctrl-C must still interrupt the notebook
    print(f"Could not load saved buffers ({err!r}), rebuilding them...")

# Fill whatever is still empty — also covers a load that succeeded but yielded no data.
if len(expert_buffer) == 0:
    print("Expert Buffer is empty, load data...")
    expert_buffer.add_expert(data_path)
    print("Storage of Expert Buffer: ", len(expert_buffer), "\n")

if len(generative_buffer) == 0:
    print("Generative Buffer is empty, generating data...")
    for _ in tqdm.tqdm(range(num_generative_episodes)):
        pred_tensors, prey_tensors = get_rollouts(env, pred_policy, prey_policy, num_frames=num_frames, render=False)
        generative_buffer.add_generative(pred_tensors, prey_tensors)
    print("Storage of Generative Buffer: ", len(generative_buffer))

Expert Buffer is empty, load data...
Storage of Expert Buffer:  150 

Generative Buffer is empty, generating data...


100%|██████████| 8/8 [00:25<00:00,  3.24s/it]

Storage of Generative Buffer:  8





In [None]:
# Adversarial training loop: per epoch, update both discriminators on
# expert vs. generated clips, update both policies with OpenAI-ES, then
# refill the generative buffer with rollouts of the updated policies.
log_every = 5          # print metrics every N epochs (the last epoch is always printed)
checkpoint_every = 50  # save a checkpoint every N epochs

losses_pred_discriminator = []
losses_prey_discriminator = []

es_metrics_pred = []
es_metrics_prey = []


def _es_update(policy, agent):
    """Run the OpenAI-ES update of `policy` for `agent` ("predator" / "prey").

    Calls ``policy.update`` for the "pairwise" part first and then for the
    "attention" part (presumably sub-networks of GeneratorPolicy — confirm),
    and returns the concatenated stats of both calls.
    """
    stats = []
    for part in ("pairwise", "attention"):
        stats += policy.update(agent, part, env,
                               pred_policy, prey_policy,
                               pred_discriminator, prey_discriminator,
                               num_generations, num_perturbations,
                               lr_pred_policy, lr_prey_policy,
                               sigma, gamma)
    return stats


for i in range(epochs):
    print(f"[Epoch {i+1}/{epochs}]")
    start_time = time.time()

    # Sample clips from the expert and the generative buffer
    expert_pred_batch, expert_prey_batch = expert_buffer.sample(batch_size, device=device)
    policy_pred_batch, policy_prey_batch = generative_buffer.sample(batch_size, device=device)

    # Discriminator updates (predator, then prey)
    print("Updating Discriminators...")
    loss_pred_discriminator = pred_discriminator.update(expert_pred_batch, policy_pred_batch, optim_dis_pred, lambda_gp_pred)
    losses_pred_discriminator.append(loss_pred_discriminator)

    loss_prey_discriminator = prey_discriminator.update(expert_prey_batch, policy_prey_batch, optim_dis_prey, lambda_gp_prey)
    losses_prey_discriminator.append(loss_prey_discriminator)

    # Policy updates with OpenAI-ES (predator, then prey)
    print("Updating Policies...")
    pred_stats = _es_update(pred_policy, "predator")
    es_metrics_pred.append(pred_stats)

    prey_stats = _es_update(prey_policy, "prey")
    es_metrics_prey.append(prey_stats)

    # Generate new trajectories with the updated policies
    print("Generating new trajectories... \n")
    pred_tensors = []
    prey_tensors = []
    for _ in range(num_generative_episodes):
        pred_tensor, prey_tensor = get_rollouts(env, pred_policy, prey_policy, num_frames=num_frames, render=False)
        pred_tensors.append(pred_tensor)
        prey_tensors.append(prey_tensor)
        generative_buffer.add_generative(pred_tensor, prey_tensor)

    # Discriminator-based fitness of the fresh rollouts (used for logging only)
    pred_batch = torch.stack(pred_tensors, dim=0).to(device)
    prey_batch = torch.stack(prey_tensors, dim=0).to(device)
    r_pred, r_prey = discriminator_reward(pred_batch, prey_batch, pred_discriminator, prey_discriminator)

    finish_time = remaining_time(start_time, epochs, i)

    if i % log_every == 0 or i == epochs - 1:
        avg_es_pred = np.mean([m['avg_reward_diff'] for m in pred_stats])
        avg_es_prey = np.mean([m['avg_reward_diff'] for m in prey_stats])

        print(f"Epoch {i+1}/{epochs}, Estimated Finish: {finish_time}")
        print(f"Predator | Avg. ES-Reward: {avg_es_pred:.2f} | DisReward (Fitness): {r_pred:.2f} | DisLoss: {loss_pred_discriminator:.2f} | LR: {lr_pred_policy} | Sigma: {sigma}")
        print(f"Prey     | Avg. ES-Reward: {avg_es_prey:.2f} | DisReward (Fitness): {r_prey:.2f} | DisLoss: {loss_prey_discriminator:.2f} | LR: {lr_prey_policy} | Sigma: {sigma} \n")

    # Save into this run's timestamped folder (where the setup cell also looks for
    # buffers). NOTE(review): previously the shared base `path` was passed, so all
    # runs wrote into one folder — confirm save_checkpoint adds no sub-folder of its own.
    if i % checkpoint_every == 0:
        save_checkpoint(save_dir, i,
                        pred_policy, prey_policy,
                        pred_discriminator, prey_discriminator,
                        optim_dis_pred, optim_dis_prey,
                        expert_buffer, generative_buffer,
                        losses_pred_discriminator, losses_prey_discriminator,
                        es_metrics_pred, es_metrics_prey)


# Save final models into this run's folder
save_models(save_dir,
            pred_policy, prey_policy,
            pred_discriminator, prey_discriminator,
            optim_dis_pred, optim_dis_prey,
            expert_buffer, generative_buffer,
            losses_pred_discriminator, losses_prey_discriminator,
            es_metrics_pred, es_metrics_prey)

Starting Epoch 1/3
Updating Discriminators...
Updating Policies...
Generating new trajectories... 


 Checkpoint successfully saved!
Iteration 0, Estimated Finish: 2025-06-30 16:06:10
Predator | Avg. ES-Reward: -0.00 | DisReward (Fitness): -0.19 | DisLoss: 25.60 | LR: 0.1 | Sigma: 0.1
Prey     | Avg. ES-Reward: 0.01 | DisReward (Fitness): -0.03 | DisLoss: 1.14 | LR: 0.1 | Sigma: 0.1 

Starting Epoch 2/3
Updating Discriminators...
Updating Policies...
Generating new trajectories... 

Starting Epoch 3/3
Updating Discriminators...
Updating Policies...
Generating new trajectories... 


 Checkpoint successfully saved!
Iteration 2, Estimated Finish: 2025-06-30 16:05:56
Predator | Avg. ES-Reward: 0.00 | DisReward (Fitness): -0.08 | DisLoss: 28.49 | LR: 0.1 | Sigma: 0.1
Prey     | Avg. ES-Reward: -0.01 | DisReward (Fitness): 0.09 | DisLoss: 0.94 | LR: 0.1 | Sigma: 0.1 

