In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import pickle
import datetime
from utils.es_utils import *
from utils.env_utils import *
from utils.train_utils import *
from marl_aquarium import aquarium_v0
from models.Buffer import Buffer, Pool
from models.Generator import GeneratorPolicy
from models.Discriminator import Discriminator

In [None]:
"""
Multi-Agent Imitation lernt nur so schnell wie das schwächste Glied.

ToDo's
- Wasserstein Loss
- Händisch Predator Aktion labeln

Falls Training keine Fortschritte:
    - Discriminator Prey-Input (32, 32, 5)?
    - CUDA-Toolkit installieren, damit Modell auf GPU
    - Struktur Netzwerk anpassen (Layer, Batch Normalization, Dropout)
    - PPO statt ES
    - WGAN Loss
    - Abwechselndes Training von Modulen und Policies
"""

"\nMulti-Agent Imitation lernt nur so schnell wie das schw채chste Glied.\n\nToDo's\n- Tuning bauen (Ray)\n- H채ndisch Predator Aktion labeln\n\nFalls Training keine Fortschritte:\n    - Discriminator Prey-Input (32, 32, 5)?\n    - CUDA-Toolkit installieren, damit Modell auf GPU\n    - Struktur Netzwerk anpassen (Layer, Batch Normalization, Dropout)\n    - PPO statt ES\n    - WGAN Loss\n    - Abwechselndes Training von Modulen und Policies\n"

In [3]:
# Hyperparameters
# All tunable values of the run live in this cell; the later cells only read them.
# NOTE(review): the numbers in the trailing comments look like values tried in
# earlier runs - confirm before relying on them.

# Environment: agent counts, action count and wall flag that are forwarded to
# generate_trajectories and to the policy updates in the training loop.
pred_count = 1
prey_count = 32 
action_count = 360
use_walls = True

# Generated trajectories: passed to generate_trajectories as
# num_generative_episodes (gt_gen_episodes) and clip_length (gt_clip_length).
gt_gen_episodes = 10
gt_clip_length = 30

# Buffer: pred_max_length / prey_max_length of the expert and generative Buffer.
pred_buffer_size = 800
prey_buffer_size = 24000

# Training
# Batch sizes are the arguments of Buffer.sample; num_generations is the length
# of the outer training loop; patience / window_size / min_slope configure the
# EarlyStoppingWindow objects; gen_dis_ratio is the number of policy-update
# rounds executed after each discriminator update.
pred_batch_size = 128           #     32, 512
prey_batch_size = 1024          #    256, 1024
num_generations = 200           #     80, 200
patience = 15                   #     20,
window_size = 10                #     10,
min_slope = 0
gen_dis_ratio = 1               #      7, 5, 1

# ES perturbation: arguments of GeneratorPolicy.update.
# sigma and both policy learning rates are additionally multiplied by gamma at
# the end of every generation in the training loop.
num_perturbations = 40          #     32, 40
pert_clip_length = 15           #     10, 30, 100, 15
sigma = 0.25                    #    0.1, 0.2, 0.25
gamma = 0.9997                  # 0.9997,
lr_pred_policy = 0.03           #   0.01, 0.02, 0.03
lr_prey_policy = 0.03           #   0.05, 0.03

# RMSprop: lr / alpha / eps of the two discriminator optimizers.
# lambda_gp_*, label_smoothing and smooth are handed to Discriminator.update
# (presumably gradient-penalty weight and label-smoothing amount - confirm in
# models/Discriminator).
lr_pred_dis =  0.0005           #  0.001, 0.0005, 0.0001, 0.0005,
lr_prey_dis = 0.0005            #  0.001, 0.0005, 0.0001, 0.0005,
alpha=0.99                      #   0.99,
eps_dis=1e-08                   #  1e-08,
lambda_gp_pred = 5              #     10, 5
lambda_gp_prey = 5              #     10, 5
label_smoothing = True
smooth = 0.1

In [4]:
# Video data: which recording to train on and how its tracking data was cut.
video = "video_8min"
num_frames = 1
total_detections = 33

# Create a time-stamped folder for this training run.
# Paths are assembled with os.path.join instead of hard-coded backslashes so the
# notebook also works on non-Windows machines (on Windows the resulting strings
# are identical to the previous raw-string literals).
path = os.path.join("..", "data", "training")
timestamp = datetime.datetime.now().strftime("%d.%m.%Y_%H.%M")
folder_name = f"Training - {timestamp}"
save_dir = os.path.join(path, folder_name)
os.makedirs(save_dir, exist_ok=True)

# Raw video file
raw_video_folder = os.path.join("..", "data", "raw", "videos")
video_path = os.path.join(raw_video_folder, video + ".mp4")

# Expert data derived from the video
processed_video_folder = os.path.join("..", "data", "processed", video)
data_path = os.path.join(processed_video_folder, "expert_tensors")
ftw_path = os.path.join(processed_video_folder, "full_track_windows", f"full_track_windows_{total_detections}.pkl")

# SECURITY: pickle.load can execute arbitrary code contained in the file.
# Only load track-window files that were produced by this project.
with open(ftw_path, "rb") as f:
    full_track_windows = pickle.load(f)

In [5]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _build_on_device(module_cls):
    """Instantiate `module_cls` on `device` and run its initial parameter setup."""
    module = module_cls().to(device)
    module.set_parameters(init=True)
    return module


# Keyword sets shared by the predator and the prey variant of each object.
_rmsprop_kwargs = dict(alpha=alpha, eps=eps_dis)
_buffer_kwargs = dict(pred_max_length=pred_buffer_size,
                      prey_max_length=prey_buffer_size,
                      device=device)
_stopper_kwargs = dict(patience=patience, window_size=window_size, min_slope=min_slope)

# Generator policies (construction order is kept: predator first, then prey).
pred_policy = _build_on_device(GeneratorPolicy)
prey_policy = _build_on_device(GeneratorPolicy)

# Discriminators, each with its own RMSprop optimizer.
pred_discriminator = _build_on_device(Discriminator)
optim_dis_pred = torch.optim.RMSprop(pred_discriminator.parameters(),
                                     lr=lr_pred_dis, **_rmsprop_kwargs)

prey_discriminator = _build_on_device(Discriminator)
optim_dis_prey = torch.optim.RMSprop(prey_discriminator.parameters(),
                                     lr=lr_prey_dis, **_rmsprop_kwargs)

# One buffer for expert samples, one for policy-generated samples.
expert_buffer = Buffer(**_buffer_kwargs)
generative_buffer = Buffer(**_buffer_kwargs)

# Pool of start frames, filled from the video and its full-track windows.
start_frame_pool = Pool(max_length=1000, device=device)
start_frame_pool.generate_startframes(video_path, full_track_windows)

# Independent early-stopping trackers for the two roles.
early_stopper_pred = EarlyStoppingWindow(**_stopper_kwargs)
early_stopper_prey = EarlyStoppingWindow(**_stopper_kwargs)

In [6]:
# Load the expert data from local storage into the expert buffer.
print("Expert Buffer is empty, load data...")
# Use the shared total_detections constant (it also selects the track-window
# pickle) instead of repeating the literal 33, so both always stay in sync.
expert_buffer.add_expert(data_path, window_size=1, detections=total_detections)
len_exp_pred, len_exp_prey = expert_buffer.lengths()

# Report how many predator / prey samples the buffer now holds.
print("Storage of Predator Expert Buffer: ", len_exp_pred)
print("Storage of Prey Expert Buffer: ", len_exp_prey, "\n")

Expert Buffer is empty, load data...
Storage of Predator Expert Buffer:  742
Storage of Prey Expert Buffer:  23744 



In [7]:
# Warm-start both policies with behavioral cloning on the expert buffer.
# NOTE: each policy is replaced by its pretrained version, so re-running this
# cell pretrains the already-pretrained policies again.
print("Pretraining Policies with Behavioral Cloning on Expert Data...\n")

# Settings shared by the predator and the prey pretraining run.
_bc_shared = dict(lr=1e-3, save_dir=save_dir)

pred_policy = pretrain_policy(
    pred_policy, expert_buffer,
    role='predator', pred_bs=512, epochs=120,
    **_bc_shared,
)
prey_policy = pretrain_policy(
    prey_policy, expert_buffer,
    role='prey', prey_bs=1024, epochs=220,
    **_bc_shared,
)

Pretraining Policies with Behavioral Cloning on Expert Data...



In [8]:
# Fill the generative buffer with rollouts of the (pretrained) policies.
print("Generative Buffer is empty, generating data...")
# use_walls comes from the hyperparameter cell (previously hard-coded to True
# here), so these rollouts use the same environment setting as the rollouts
# generated inside the training loop.
generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool,
                        pred_count=pred_count, prey_count=prey_count, action_count=action_count,
                        pred_policy=pred_policy, prey_policy=prey_policy,
                        clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes,
                        use_walls=use_walls)

# Report how many predator / prey samples were generated.
len_gen_pred, len_gen_prey = generative_buffer.lengths()
print("Storage of Predator Generative Buffer: ", len_gen_pred)
print("Storage of Prey Generative Buffer: ", len_gen_prey)

Generative Buffer is empty, generating data...
Storage of Predator Generative Buffer:  300
Storage of Prey Generative Buffer:  9600


In [9]:
# Adversarial training loop.
# Per generation: one discriminator update per role, then gen_dis_ratio rounds of
# ES policy updates followed by fresh rollouts into the generative buffer.

# Histories that are written to checkpoints and to the final model save.
dis_metrics_pred = []
dis_metrics_prey = []

es_metrics_pred = []
es_metrics_prey = []


def _es_update(policy, role, component, generation):
    """Run one ES update of `policy` and return the stats it reports.

    policy:     the GeneratorPolicy to update (pred_policy or prey_policy).
    role:       "predator" or "prey".
    component:  "pairwise" or "attention" (second positional argument of
                GeneratorPolicy.update; presumably the sub-module to perturb -
                confirm in models/Generator).
    generation: index of the current generation.

    sigma and the learning rates are read from the notebook globals at call
    time, so the per-generation decay applied in the loop below is picked up.
    """
    return policy.update(role, component,
                         pred_count, prey_count, action_count,
                         pred_policy, prey_policy,
                         pred_discriminator, prey_discriminator,
                         num_perturbations, generation,
                         lr_pred_policy, lr_prey_policy,
                         sigma, gamma, clip_length=pert_clip_length,
                         use_walls=use_walls, start_frame_pool=start_frame_pool)


def _mean_reward_diff(stats):
    """Mean of the 'avg_reward_diff' entries in `stats`; NaN when `stats` is empty."""
    if not stats:
        return float("nan")
    return np.mean([m['avg_reward_diff'] for m in stats])


for generation in range(num_generations):
    start_time = time.time()

    # Reset per generation so the report below never reads undefined or stale
    # stats when no policy update ran (gen_dis_ratio < 1).
    pred_stats, prey_stats = [], []

    # Sample trajectories from the expert and the generative buffer
    expert_pred_batch, expert_prey_batch = expert_buffer.sample(pred_batch_size, prey_batch_size)
    policy_pred_batch, policy_prey_batch = generative_buffer.sample(pred_batch_size, prey_batch_size)

    # Predator discriminator update
    dis_metric_pred = pred_discriminator.update(expert_pred_batch, policy_pred_batch, optim_dis_pred, lambda_gp_pred, label_smoothing, smooth)
    dis_metrics_pred.append(dis_metric_pred)

    # Prey discriminator update
    dis_metric_prey = prey_discriminator.update(expert_prey_batch, policy_prey_batch, optim_dis_prey, lambda_gp_prey, label_smoothing, smooth)
    dis_metrics_prey.append(dis_metric_prey)

    for _ in range(gen_dis_ratio):
        # Predator policy: "pairwise" first, then "attention"; stats are concatenated.
        pred_stats = _es_update(pred_policy, "predator", "pairwise", generation)
        pred_stats += _es_update(pred_policy, "predator", "attention", generation)
        es_metrics_pred.append(pred_stats)

        # Prey policy: same two-step update.
        prey_stats = _es_update(prey_policy, "prey", "pairwise", generation)
        prey_stats += _es_update(prey_policy, "prey", "attention", generation)
        es_metrics_prey.append(prey_stats)

        # Generate new trajectories with the updated policies
        generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool,
                                pred_count=pred_count, prey_count=prey_count, action_count=action_count,
                                pred_policy=pred_policy, prey_policy=prey_policy,
                                clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes,
                                use_walls=use_walls)

    # Also decay the exploration noise and learning rates globally.
    # NOTE: this overwrites the hyperparameters in place, so re-running this cell
    # without re-running the hyperparameter cell continues from the decayed values.
    sigma *= gamma
    lr_pred_policy *= gamma
    lr_prey_policy *= gamma

    last_epoch_duration = time.time() - start_time
    estimated_time, epoch_time = remaining_time(num_generations, last_epoch_duration, current_generation=generation)

    # Mean reward difference of the most recent policy-update round of this generation.
    avg_es_pred = _mean_reward_diff(pred_stats)
    avg_es_prey = _mean_reward_diff(prey_stats)

    # Entries 0, 2 and 3 of the discriminator metrics are reported as loss,
    # expert score and policy score.
    print(f"[Generation {generation+1}/{num_generations}] - Time: {epoch_time} - Estimated Finish: {estimated_time}" )
    print(f"Predator | Avg. ES-Reward: {avg_es_pred:.4f} | DisLoss: {dis_metric_pred[0]:.4f} | Expert Scores: {dis_metric_pred[2]:.4f} | Policy Scores: {dis_metric_pred[3]:.4f}")
    print(f"Prey     | Avg. ES-Reward: {avg_es_prey:.4f} | DisLoss: {dis_metric_prey[0]:.4f} | Expert Scores: {dis_metric_prey[2]:.4f} | Policy Scores: {dis_metric_prey[3]:.4f}\n")

    # Early stopping is currently disabled; early_stopper_pred / early_stopper_prey
    # are created during setup but are not consulted here.
    # if early_stopper_pred(avg_es_pred, "predator") or early_stopper_prey(avg_es_prey, "prey"):
    #     break

    # Checkpoint at generation index 0, 50, 100, ... (i.e. also right after the
    # very first generation).
    if generation % 50 == 0:
        save_checkpoint(save_dir, generation,
                        pred_policy, prey_policy,
                        pred_discriminator, prey_discriminator,
                        optim_dis_pred, optim_dis_prey,
                        expert_buffer, generative_buffer,
                        dis_metrics_pred, dis_metrics_prey,
                        es_metrics_pred, es_metrics_prey)


# Save the final models, optimizers, buffers and metric histories
save_models(save_dir,
            pred_policy, prey_policy,
            pred_discriminator, prey_discriminator,
            optim_dis_pred, optim_dis_prey,
            expert_buffer, generative_buffer,
            dis_metrics_pred, dis_metrics_prey,
            es_metrics_pred, es_metrics_prey)

[Generation 1/200] - Time: 4:55 - Estimated Finish: 26.07.2025 03:06:53
Predator | Avg. ES-Reward: 0.0192 | DisLoss: 6.2049 | Expert Scores: 0.5183 | Policy Scores: 0.5286
Prey     | Avg. ES-Reward: 0.0830 | DisLoss: 6.1002 | Expert Scores: 0.4781 | Policy Scores: 0.4610

Checkpoint successfully saved! 
 
[Generation 2/200] - Time: 4:40 - Estimated Finish: 26.07.2025 02:16:53
Predator | Avg. ES-Reward: -0.0236 | DisLoss: 6.0806 | Expert Scores: 0.5271 | Policy Scores: 0.4955
Prey     | Avg. ES-Reward: -0.0477 | DisLoss: 5.8654 | Expert Scores: 0.4773 | Policy Scores: 0.4156

[Generation 3/200] - Time: 4:57 - Estimated Finish: 26.07.2025 03:13:39
Predator | Avg. ES-Reward: 0.1401 | DisLoss: 5.8806 | Expert Scores: 0.5382 | Policy Scores: 0.4662
Prey     | Avg. ES-Reward: 0.1406 | DisLoss: 5.6549 | Expert Scores: 0.4788 | Policy Scores: 0.4143

[Generation 4/200] - Time: 4:53 - Estimated Finish: 26.07.2025 02:59:29
Predator | Avg. ES-Reward: 0.1894 | DisLoss: 5.6372 | Expert Scores: 0.55