In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import ast
import optuna
import pickle
import datetime
from utils.es_utils import *
from utils.env_utils import *
from utils.train_utils import *
from plotly.io import write_html
import optuna.visualization as vis
from marl_aquarium import aquarium_v0
from models.Buffer import Buffer, Pool
from models.Generator import GeneratorPolicy
from models.Discriminator import Discriminator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hyperparameters

#Environment
pred_count = 1
prey_count = 32 
action_count = 360
use_walls = True

# generated_trajectories
gt_gen_episodes = 20
gt_clip_length = 30

# Buffer
pred_buffer_size = 800
prey_buffer_size = 24000

# Training
pred_batch_size = 128           #     32, 512
prey_batch_size = 1024          #    256, 1024
num_generations = 200           #     80, 200
patience = 15                   #     20,   
window_size = 10                #     10,
min_slope = 0
gen_dis_ratio = 1               #      7, 5, 1

# ES-Pertrubation
num_perturbations = 40          #     32, 40
pert_clip_length = 15           #     10, 30, 100, 15
sigma = 0.25                    #    0.1, 0.2, 0.25   
gamma = 0.9997                  # 0.9997,
lr_pred_policy = 0.03           #   0.01, 0.02, 0.03
lr_prey_policy = 0.03           #   0.05, 0.03

# RMSprop
lr_pred_dis =  0.0005           #  0.001, 0.0005, 0.0001, 0.0005,
lr_prey_dis = 0.0005            #  0.001, 0.0005, 0.0001, 0.0005,
alpha=0.99                      #   0.99,
eps_dis=1e-08                   #  1e-08,
lambda_gp_pred = 5              #     10, 5
lambda_gp_prey = 5              #     10, 5
label_smoothing = True
smooth = 0.1

In [3]:
# Video Data
video = "video_8min"
num_frames=1
total_detections=33

# create training folder
path = r"..\data\training"
timestamp = datetime.datetime.now().strftime("%d.%m.%Y_%H.%M")
folder_name = f"Tuning - {timestamp}"
save_dir = os.path.join(path, folder_name)
os.makedirs(save_dir, exist_ok=True)

# Video folder
raw_video_folder = rf'..\data\raw\videos'
video_path = raw_video_folder + "\\" + video + ".mp4"

# Expert Data
data_path = rf"..\data\processed\{video}\expert_tensors"
processed_video_folder = rf'..\data\processed\{video}'
ftw_path = os.path.join(processed_video_folder, "full_track_windows", f"full_track_windows_{total_detections}.pkl")

with open(ftw_path, "rb") as f:
    full_track_windows = pickle.load(f)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pred_policy = GeneratorPolicy().to(device)
pred_policy.set_parameters(init=True)

prey_policy = GeneratorPolicy().to(device)
prey_policy.set_parameters(init=True)

pred_discriminator = Discriminator().to(device)
pred_discriminator.set_parameters(init=True)

prey_discriminator = Discriminator().to(device)
prey_discriminator.set_parameters(init=True)

expert_buffer = Buffer(pred_max_length=pred_buffer_size, prey_max_length=prey_buffer_size, device=device)
generative_buffer = Buffer(pred_max_length=pred_buffer_size, prey_max_length=prey_buffer_size, device=device)

start_frame_pool = Pool(max_length=1000, device=device)
start_frame_pool.generate_startframes(video_path, full_track_windows)

early_stopper_pred = EarlyStoppingWindow(patience=patience, window_size=window_size, min_slope=min_slope)
early_stopper_prey = EarlyStoppingWindow(patience=patience, window_size=window_size, min_slope=min_slope)

In [5]:
# Load Expert Data from local storage
print("Expert Buffer is empty, load data...")
expert_buffer.add_expert(data_path, window_size=1, detections=33)
len_exp_pred, len_exp_prey = expert_buffer.lengths()

print("Storage of Predator Expert Buffer: ", len_exp_pred)
print("Storage of Prey Expert Buffer: ", len_exp_prey, "\n")

Expert Buffer is empty, load data...
Storage of Predator Expert Buffer:  742
Storage of Prey Expert Buffer:  23744 



In [6]:
# Pretrain Policies with Expert Data
print("Pretraining Policies with Behavioral Cloning on Expert Data...\n")
pred_policy = pretrain_policy(pred_policy, expert_buffer, role='predator', pred_bs=512, epochs=120, lr=1e-3, save_dir=save_dir)
prey_policy = pretrain_policy(prey_policy, expert_buffer, role='prey', prey_bs=1024, epochs=220, lr=1e-3, save_dir=save_dir)

Pretraining Policies with Behavioral Cloning on Expert Data...



In [7]:
# Generate Trajectories for Generative Buffer
print("Generative Buffer is empty, generating data...")
generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool,
                        pred_count=pred_count, prey_count=prey_count, action_count=action_count, 
                        pred_policy=pred_policy, prey_policy=prey_policy, 
                        clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes,
                        use_walls=True)

len_gen_pred, len_gen_prey = generative_buffer.lengths()
print("Storage of Predator Generative Buffer: ", len_gen_pred)
print("Storage of Prey Generative Buffer: ", len_gen_prey)

Generative Buffer is empty, generating data...
Storage of Predator Generative Buffer:  600
Storage of Prey Generative Buffer:  19200


In [8]:
def train_gail(pred_policy=pred_policy, prey_policy=prey_policy, 
               pred_discriminator=pred_discriminator, prey_discriminator=prey_discriminator,
               expert_buffer=expert_buffer, generative_buffer=generative_buffer, start_frame_pool=start_frame_pool,
               pred_count=pred_count, prey_count=prey_count, action_count=action_count, 
               use_walls=use_walls, gt_gen_episodes=gt_gen_episodes, gt_clip_length=gt_clip_length, 
               gen_dis_ratio=gen_dis_ratio, gamma=gamma, 
               alpha=alpha, eps_dis=eps_dis, 
               label_smoothing=label_smoothing, smooth=smooth,

                num_generations=num_generations,
                pred_batch_size=pred_batch_size,
                prey_batch_size=prey_batch_size,
                num_perturbations=num_perturbations,
                pert_clip_length=pert_clip_length,
                sigma=sigma,
                lr_pred_policy=lr_pred_policy,
                lr_prey_policy=lr_prey_policy,
                lr_pred_dis=lr_pred_dis,
                lr_prey_dis=lr_prey_dis,
                lambda_gp_pred=lambda_gp_pred,
                lambda_gp_prey=lambda_gp_prey):
    
    optim_dis_pred = torch.optim.RMSprop(pred_discriminator.parameters(), lr=lr_pred_dis, alpha=alpha, eps=eps_dis)
    optim_dis_prey = torch.optim.RMSprop(prey_discriminator.parameters(), lr=lr_prey_dis, alpha=alpha, eps=eps_dis)

    dis_metrics_pred = []
    dis_metrics_prey = []

    es_metrics_pred = []
    es_metrics_prey = []

    for generation in range(num_generations):
        # Sample traj from expert and generative buffer
        expert_pred_batch, expert_prey_batch = expert_buffer.sample(pred_batch_size, prey_batch_size)
        policy_pred_batch, policy_prey_batch = generative_buffer.sample(pred_batch_size, prey_batch_size)

        # Predator discriminator update
        dis_metric_pred = pred_discriminator.update(expert_pred_batch, policy_pred_batch, optim_dis_pred, lambda_gp_pred, label_smoothing, smooth)
        dis_metrics_pred.append(dis_metric_pred)
                                        
        # Prey discriminator update
        dis_metric_prey = prey_discriminator.update(expert_prey_batch, policy_prey_batch, optim_dis_prey, lambda_gp_prey, label_smoothing, smooth)
        dis_metrics_prey.append(dis_metric_prey)

        for i in range(gen_dis_ratio):
            pred_stats = pred_policy.update("predator", "pairwise",
                                            pred_count, prey_count, action_count,
                                            pred_policy, prey_policy,
                                            pred_discriminator, prey_discriminator,
                                            num_perturbations, generation,
                                            lr_pred_policy, lr_prey_policy,
                                            sigma, gamma, clip_length=pert_clip_length,
                                            use_walls=use_walls, start_frame_pool=start_frame_pool)
            

            pred_stats += pred_policy.update("predator", "attention",
                                            pred_count, prey_count, action_count,
                                            pred_policy, prey_policy,
                                            pred_discriminator, prey_discriminator,
                                            num_perturbations, generation,
                                            lr_pred_policy, lr_prey_policy,
                                            sigma, gamma, clip_length=pert_clip_length,
                                            use_walls=use_walls, start_frame_pool=start_frame_pool)
            es_metrics_pred.append(pred_stats)


            prey_stats = prey_policy.update("prey", "pairwise",
                                            pred_count, prey_count, action_count,
                                            pred_policy, prey_policy,
                                            pred_discriminator, prey_discriminator,
                                            num_perturbations, generation,
                                            lr_pred_policy, lr_prey_policy,
                                            sigma, gamma, clip_length=pert_clip_length,
                                            use_walls=use_walls, start_frame_pool=start_frame_pool)
            
            prey_stats += prey_policy.update("prey", "attention",
                                            pred_count, prey_count, action_count,
                                            pred_policy, prey_policy,
                                            pred_discriminator, prey_discriminator,
                                            num_perturbations, generation,
                                            lr_pred_policy, lr_prey_policy,
                                            sigma, gamma, clip_length=pert_clip_length,
                                            use_walls=use_walls, start_frame_pool=start_frame_pool)
            es_metrics_prey.append(prey_stats)

            # Generate new trajectories with updated policies
            generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool,
                                    pred_count=pred_count, prey_count=prey_count, action_count=action_count, 
                                    pred_policy=pred_policy, prey_policy=prey_policy, 
                                    clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes,
                                    use_walls=use_walls)

        # also reduce globally
        sigma *= gamma
        lr_pred_policy *= gamma
        lr_prey_policy *= gamma

        return dis_metrics_pred, dis_metrics_prey, es_metrics_pred, es_metrics_prey

In [9]:
def objective(trial):
    #Environment
    pred_count        = 1
    prey_count        = 32 
    action_count      = 360
    use_walls         = True

    # generated_trajectories
    gt_gen_episodes   = 10
    gt_clip_length    = 30

    # Buffer
    pred_buffer_size  = 800
    prey_buffer_size  = 24000

    # Training
    num_generations   = trial.suggest_int("num_generations", 200, 300)
    pred_batch_size   = trial.suggest_categorical("pred_batch_size", [128, 256, 512])
    prey_batch_size   = trial.suggest_categorical("prey_batch_size", [512, 1024])
    patience          = 15
    window_size       = 10
    min_slope         = 0
    gen_dis_ratio     = trial.suggest_categorical("gen_dis_ratio", [1, 2, 3, 4, 5])

    # ES-Pertrubation
    num_perturbations = trial.suggest_categorical("num_perturbations", [8, 16, 32])
    pert_clip_length  = trial.suggest_int("pert_clip_length", 16, 28)
    sigma             = trial.suggest_float("sigma", 0.1, 0.14, log=True)
    lr_pred_policy    = trial.suggest_float("lr_pred_policy", 0.002, 0.01, log=True)
    lr_prey_policy    = trial.suggest_float("lr_prey_policy", 0.001, 0.01, log=True)
    gamma             = trial.suggest_float("gamma", 0.9994, 0.9999)

    # RMSprop
    lr_pred_dis       = trial.suggest_float("lr_pred_dis", 0.001, 0.005, log=True)
    lr_prey_dis       = trial.suggest_float("lr_prey_dis", 0.0007, 0.009, log=True)
    lambda_gp_pred    = trial.suggest_int("lambda_gp_pred", 3, 8)
    lambda_gp_prey    = trial.suggest_int("lambda_gp_prey", 1, 4)
    alpha             = 0.99
    eps_dis           = 1e-08


    dis_metrics_pred, dis_metrics_prey, es_metrics_pred, es_metrics_prey = train_gail(num_generations=num_generations,
                                                                                      pred_batch_size=pred_batch_size,
                                                                                      prey_batch_size=prey_batch_size,
                                                                                      gen_dis_ratio=gen_dis_ratio,
                                                                                      num_perturbations=num_perturbations,
                                                                                      pert_clip_length=pert_clip_length,
                                                                                      sigma=sigma,
                                                                                      lr_pred_policy=lr_pred_policy,
                                                                                      lr_prey_policy=lr_prey_policy,
                                                                                      gamma=gamma,
                                                                                      lr_pred_dis=lr_pred_dis,
                                                                                      lr_prey_dis=lr_prey_dis,
                                                                                      lambda_gp_pred=lambda_gp_pred,
                                                                                      lambda_gp_prey=lambda_gp_prey)

    #pred_wloss = abs(sum(score[0] for score in dis_metrics_pred) / len(dis_metrics_pred))
    #prey_wloss = abs(sum(score[0] for score in dis_metrics_prey) / len(dis_metrics_prey))

    policy_score_pred = abs(sum(score[3] for score in dis_metrics_pred) / len(dis_metrics_pred))
    policy_score_prey = abs(sum(score[3] for score in dis_metrics_prey) / len(dis_metrics_prey))

    expert_score_pred = abs(sum(score[2] for score in dis_metrics_pred) / len(dis_metrics_pred))
    expert_score_prey = abs(sum(score[2] for score in dis_metrics_prey) / len(dis_metrics_prey))

    dist_pred = policy_score_pred + expert_score_pred
    dist_prey = policy_score_prey + expert_score_prey

    return dist_pred, dist_prey

In [10]:
# Core count per trail = num_generative_episodes + num_pertrubations (nach Tuning Space aktuell 10 + max. 64 = 74)
# 74 * 50 = 3700

In [11]:
os.makedirs(os.path.join(save_dir, "tuning"), exist_ok=True)
tuning_path = os.path.join(save_dir, "tuning")
db_path = os.path.join(tuning_path, "optuna_results.db")
storage = f"sqlite:///{db_path}"

study = optuna.create_study(directions=["minimize", "minimize"], # Minimize Policy Scores
                            sampler=optuna.samplers.TPESampler(), # Bayesian Optimization
                            pruner=optuna.pruners.MedianPruner(), # Stopping trials that are not promising
                            storage=storage,
                            study_name="tuning results")

study.optimize(objective, n_trials=100, n_jobs=1)

file_path = os.path.join(tuning_path, "study.pkl")
with open(file_path, "wb") as f:
    pickle.dump(study, f)

[I 2025-07-30 09:22:42,319] A new study created in RDB with name: tuning results
[I 2025-07-30 09:41:52,192] Trial 0 finished with values: [0.07388617098331451, 0.13642525486648083] and parameters: {'num_generations': 235, 'pred_batch_size': 512, 'prey_batch_size': 1024, 'gen_dis_ratio': 4, 'num_perturbations': 16, 'pert_clip_length': 28, 'sigma': 0.11933908285160187, 'lr_pred_policy': 0.0029535568291141363, 'lr_prey_policy': 0.009700426446399622, 'gamma': 0.9998330781326765, 'lr_pred_dis': 0.0014574558280421523, 'lr_prey_dis': 0.008660352595609551, 'lambda_gp_pred': 4, 'lambda_gp_prey': 3}.
[I 2025-07-30 09:49:12,608] Trial 1 finished with values: [0.6850573793053627, 2.408707022666931] and parameters: {'num_generations': 288, 'pred_batch_size': 128, 'prey_batch_size': 1024, 'gen_dis_ratio': 2, 'num_perturbations': 16, 'pert_clip_length': 16, 'sigma': 0.1235894228917568, 'lr_pred_policy': 0.009641865095628905, 'lr_prey_policy': 0.0016767433260814952, 'gamma': 0.999508608182956, 'lr_pr

In [12]:
with open(file_path, "rb") as f:
    study = pickle.load(f)

print("Best hyperparameters:", study.best_trials)

best_trials_path = os.path.join(tuning_path, "best_trials.txt")
os.makedirs(os.path.dirname(best_trials_path), exist_ok=True)
with open(best_trials_path, "w") as f:
    for trial in study.best_trials:
        f.write(str(trial) + "\n\n")

#PREDATOR
fig1 = vis.plot_optimization_history(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig1, file=os.path.join(tuning_path, "opt_history_predator.html"))
fig1.show()

fig3 = vis.plot_param_importances(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig3, file=os.path.join(tuning_path, "param_importance_predator.html"))
fig3.show()

fig5 = vis.plot_slice(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig5, file=os.path.join(tuning_path, "slice_plot_predator.html"))
fig5.show()

# PREY
fig2 = vis.plot_optimization_history(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig2, file=os.path.join(tuning_path, "opt_history_prey.html"))
fig2.show()

fig4 = vis.plot_param_importances(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig4, file=os.path.join(tuning_path, "param_importance_prey.html"))
fig4.show()

fig6 = vis.plot_slice(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig6, file=os.path.join(tuning_path, "slice_plot_prey.html"))
fig6.show()

Best hyperparameters: [FrozenTrial(number=0, state=TrialState.COMPLETE, values=[0.07388617098331451, 0.13642525486648083], datetime_start=datetime.datetime(2025, 7, 30, 9, 22, 42, 334702), datetime_complete=datetime.datetime(2025, 7, 30, 9, 41, 52, 177272), params={'num_generations': 235, 'pred_batch_size': 512, 'prey_batch_size': 1024, 'gen_dis_ratio': 4, 'num_perturbations': 16, 'pert_clip_length': 28, 'sigma': 0.11933908285160187, 'lr_pred_policy': 0.0029535568291141363, 'lr_prey_policy': 0.009700426446399622, 'gamma': 0.9998330781326765, 'lr_pred_dis': 0.0014574558280421523, 'lr_prey_dis': 0.008660352595609551, 'lambda_gp_pred': 4, 'lambda_gp_prey': 3}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'num_generations': IntDistribution(high=300, log=False, low=200, step=1), 'pred_batch_size': CategoricalDistribution(choices=(128, 256, 512)), 'prey_batch_size': CategoricalDistribution(choices=(512, 1024)), 'gen_dis_ratio': CategoricalDistribution(choices=(1, 2,