In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import ast
import optuna
import pickle
import datetime
import custom_marl_aquarium
from utils.es_utils import *
from utils.env_utils import *
from utils.eval_utils import *
from utils.train_utils import *
from models.Buffer import Buffer, Pool
from plotly.io import write_html
import optuna.visualization as vis
from models.PreyOnlyPolicy import PreyOnlyPolicy
from models.Discriminator import Discriminator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# generated_trajectories
gt_gen_episodes = 16
gt_clip_length  = 30

# Training
num_generations = 400
prey_batch_size = 512

# Policy Update (ES)
num_perturbations = 32
pert_clip_length  = 30
sigma             = 0.1
gamma             = 0.9995
lr_prey_pin       = 1e-4
lr_prey_an        = 5e-5   # attention etwas kleiner als pairwise

# Discriminator Update
lr_prey_dis     = 1e-4
alpha           = 0.99
eps_dis         = 1e-8
lambda_gp_prey  = 5

In [3]:
# Create training folder
path = rf"..\data\2. Training\tuning\GAIL"
timestamp = datetime.datetime.now().strftime("%d.%m.%Y_%H.%M")
folder_name = f"GAIL Tuning - {timestamp}"
save_dir = os.path.join(path, folder_name)
os.makedirs(save_dir, exist_ok=True)

# Expert Data
mill_path = rf'..\data\1. Data Processing\processed\couzin\reversed'
traj_path = rf'..\data\1. Data Processing\processed\prey_only\expert_tensors\yolo_detected'
ftw_path = rf'..\data\1. Data Processing\processed\prey_only\3. full_track_windows'

In [4]:
device = torch.device("cpu") # Use CPU for GAIL training due to PoolThreading issues with GPU

prey_policy = PreyOnlyPolicy().to(device)
prey_policy.set_parameters(init=True)

prey_discriminator = Discriminator(neigh=32).to(device)
prey_discriminator.set_parameters(init=True)
optim_dis_prey = torch.optim.RMSprop(prey_discriminator.parameters(), lr=lr_prey_dis, alpha=alpha, eps=eps_dis)

expert_buffer = Buffer(prey_max_length=100000, device=device)

len_gb_prey = gt_gen_episodes * gt_clip_length * 3 * 32 # completly update reply buffer after three generations
generative_buffer = Buffer(prey_max_length=len_gb_prey, device=device)

start_frame_pool = Pool(max_length=13000, device=device)
start_frame_pool.generate_startframes(ftw_path)

In [5]:
# Load Expert Data from local storage
print("Expert Buffer is empty, load data...")
expert_buffer.add_expert(mill_path)

print("Buffer Size:", len(expert_buffer.prey_buffer))

Expert Buffer is empty, load data...
Buffer Size: 32000


In [6]:
# Generate Trajectories for Generative Buffer
print("Generative Buffer is empty, generating data...")
generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool, prey_policy=prey_policy, 
                      clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes, use_walls=True)

Generative Buffer is empty, generating data...


In [7]:
def train_gail(prey_policy=prey_policy, prey_discriminator=prey_discriminator,
               expert_buffer=expert_buffer, generative_buffer=generative_buffer, start_frame_pool=start_frame_pool,
               gt_gen_episodes=gt_gen_episodes, gt_clip_length=gt_clip_length, 
               gamma=gamma, alpha=alpha, eps_dis=eps_dis,

                num_generations=num_generations,
                prey_batch_size=prey_batch_size,
                num_perturbations=num_perturbations,
                pert_clip_length=pert_clip_length,
                sigma=sigma,
                lr_prey_pin = lr_prey_pin,
                lr_prey_an = lr_prey_an,
                lr_prey_dis=lr_prey_dis,
                lambda_gp_prey=lambda_gp_prey):
    
    optim_dis_prey = torch.optim.RMSprop(prey_discriminator.parameters(), lr=lr_prey_dis, alpha=alpha, eps=eps_dis)

    dis_metrics_prey = []
    es_metrics_prey = []

    for generation in range(num_generations):
        start_time = time.time()
        
        # Sample traj from expert and generative buffer
        expert_prey_batch = expert_buffer.sample(prey_batch_size)
        policy_prey_batch = expert_buffer.sample(prey_batch_size)

        # Prey discriminator update
        dis_metric_prey = prey_discriminator.update(expert_prey_batch, policy_prey_batch, optim_dis_prey, lambda_gp_prey)
        dis_metrics_prey.append(dis_metric_prey)

        prey_stats = prey_policy.update("prey", "prey_pairwise",
                                        prey_policy, prey_discriminator,
                                        num_perturbations, generation, lr_prey_pin,
                                        sigma, clip_length=pert_clip_length,
                                        use_walls=True, start_frame_pool=start_frame_pool)
        
        prey_stats += prey_policy.update("prey", "prey_attention",
                                        prey_policy, prey_discriminator,
                                        num_perturbations, generation, lr_prey_an,
                                        sigma, clip_length=pert_clip_length,
                                        use_walls=True, start_frame_pool=start_frame_pool)
        es_metrics_prey.append(prey_stats)

        # Generate new trajectories with updated policies
        generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool, prey_policy=prey_policy, 
                            clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes, use_walls=True)
        
        polarization, angular_momentum = run_prey_policy(make_env(pred_count=0, use_walls=True, start_frame_pool=start_frame_pool), prey_policy, steps=30)
            
        lr_prey_pin *= gamma
        lr_prey_an *= gamma
        sigma *= gamma

        return dis_metrics_prey, es_metrics_prey, polarization, angular_momentum

In [8]:
def objective(trial):
    try:
        # Training
        lr_prey_pin       = trial.suggest_float("lr_prey_pairwise", 5e-5, 3e-4, log=True)
        lr_prey_an        = trial.suggest_float("lr_prey_attention", 3e-5, 2e-4, log=True)
        lr_prey_dis       = trial.suggest_float("lr_prey_dis", 1e-5, 1e-4, log=True)


        dis_metrics_prey, es_metrics_prey, polarization, angular_momentum = train_gail(num_generations=num_generations,
                                                                                        prey_batch_size=prey_batch_size,
                                                                                        num_perturbations=num_perturbations,
                                                                                        pert_clip_length=pert_clip_length,
                                                                                        sigma=sigma,
                                                                                        lr_prey_pin = lr_prey_pin,
                                                                                        lr_prey_an = lr_prey_an,
                                                                                        gamma=gamma,
                                                                                        lr_prey_dis=lr_prey_dis,
                                                                                        lambda_gp_prey=lambda_gp_prey)
        
        expert_pol = 0.69
        expert_am = 0.23

        polarization_diff = abs(polarization - expert_pol)
        angular_momentum_diff = abs(angular_momentum - expert_am)

        return polarization_diff, angular_momentum_diff

    except optuna.TrialPruned:
        raise

    except Exception as e:
        print(f"[WARNING] Trial failed due to error: {e}")
        raise optuna.TrialPruned()

In [9]:
os.makedirs(os.path.join(save_dir, "tuning"), exist_ok=True)
tuning_path = os.path.join(save_dir, "tuning")
db_path = os.path.join(tuning_path, "optuna_results.db")
storage = f"sqlite:///{db_path}"

study = optuna.create_study(directions=["minimize", "minimize"], # Minimize Policy Scores
                            sampler=optuna.samplers.TPESampler(), # Bayesian Optimization
                            pruner=optuna.pruners.MedianPruner(), # Stopping trials that are not promising
                            storage=storage,
                            study_name="tuning results")

study.optimize(objective, n_trials=10, n_jobs=1)

file_path = os.path.join(tuning_path, "study.pkl")
with open(file_path, "wb") as f:
    pickle.dump(study, f)

[I 2025-12-16 10:55:57,958] A new study created in RDB with name: tuning results
[I 2025-12-16 11:02:38,557] Trial 0 finished with values: [0.30998820066452026, 0.2298719584941864] and parameters: {'lr_prey_pairwise': 8.805743710084083e-05, 'lr_prey_attention': 3.0897074546996505e-05, 'lr_prey_dis': 1.4734668766303705e-05}.
[I 2025-12-16 11:09:18,340] Trial 1 finished with values: [0.30914074182510376, 0.22937741875648499] and parameters: {'lr_prey_pairwise': 5.3667954221743755e-05, 'lr_prey_attention': 5.224471186870188e-05, 'lr_prey_dis': 2.3230415970671627e-05}.
[I 2025-12-16 11:15:54,139] Trial 2 finished with values: [0.05883282423019409, 0.14747068285942078] and parameters: {'lr_prey_pairwise': 0.0002886287750485167, 'lr_prey_attention': 0.0001197403310059681, 'lr_prey_dis': 7.56643482964404e-05}.
[I 2025-12-16 11:22:29,955] Trial 3 finished with values: [0.26924291253089905, 0.1399441957473755] and parameters: {'lr_prey_pairwise': 0.0001566635419206595, 'lr_prey_attention': 3.02

In [10]:
#Best hyperparameters: 
# [FrozenTrial(number=3, state=TrialState.COMPLETE, values=[0.08558011054992676, 0.19384074211120605], 
# datetime_start=datetime.datetime(2025, 12, 16, 1, 20, 50, 963181), datetime_complete=datetime.datetime(2025, 12, 16, 1, 27, 25, 339681), 
# params={'prey_batch_size': 512, 'num_perturbations': 32, 'lr_prey_pairwise': 0.00012022084326506861, 'lr_prey_attention': 7.692881981890137e-05, 
# 'lr_prey_dis': 3.295850739600134e-05, 'lambda_gp_prey': 7}, user_attrs={}, system_attrs={}, intermediate_values={}, 
# distributions={'prey_batch_size': CategoricalDistribution(choices=(128, 256, 512, 1024)), 'num_perturbations': CategoricalDistribution(choices=(16, 32, 64)), 
# 'lr_prey_pairwise': FloatDistribution(high=0.01, log=True, low=1e-05, step=None), 'lr_prey_attention': FloatDistribution(high=0.01, log=True, low=1e-05, step=None), 
# 'lr_prey_dis': FloatDistribution(high=0.001, log=True, low=1e-06, step=None), 'lambda_gp_prey': IntDistribution(high=10, log=False, low=1, step=1)}, trial_id=4, value=None), 

# FrozenTrial(number=27, state=TrialState.COMPLETE, values=[0.1538769006729126, 0.19259414076805115], datetime_start=datetime.datetime(2025, 12, 16, 4, 12, 18, 777008), 
# datetime_complete=datetime.datetime(2025, 12, 16, 4, 24, 10, 207367), params={'prey_batch_size': 1024, 'num_perturbations': 64, 'lr_prey_pairwise': 0.00045944936863592475, 
# 'lr_prey_attention': 0.0008269524435643975, 'lr_prey_dis': 9.890626741460674e-05, 'lambda_gp_prey': 8}, user_attrs={}, system_attrs={}, intermediate_values={}, 
# distributions={'prey_batch_size': CategoricalDistribution(choices=(128, 256, 512, 1024)), 'num_perturbations': CategoricalDistribution(choices=(16, 32, 64)), 
# 'lr_prey_pairwise': FloatDistribution(high=0.01, log=True, low=1e-05, step=None), 'lr_prey_attention': FloatDistribution(high=0.01, log=True, low=1e-05, step=None), 
# 'lr_prey_dis': FloatDistribution(high=0.001, log=True, low=1e-06, step=None), 'lambda_gp_prey': IntDistribution(high=10, log=False, low=1, step=1)}, trial_id=28, value=None), 

# FrozenTrial(number=46, state=TrialState.COMPLETE, values=[0.17033612728118896, 0.1818726658821106], datetime_start=datetime.datetime(2025, 12, 16, 6, 39, 17, 23671), 
# datetime_complete=datetime.datetime(2025, 12, 16, 6, 45, 54, 291899), params={'prey_batch_size': 512, 'num_perturbations': 32, 'lr_prey_pairwise': 1.982248657709893e-05, 
# 'lr_prey_attention': 4.002435058081766e-05, 'lr_prey_dis': 3.4642956614979435e-05, 'lambda_gp_prey': 5}, user_attrs={}, system_attrs={}, intermediate_values={}, 
# distributions={'prey_batch_size': CategoricalDistribution(choices=(128, 256, 512, 1024)), 'num_perturbations': CategoricalDistribution(choices=(16, 32, 64)), 
# 'lr_prey_pairwise': FloatDistribution(high=0.01, log=True, low=1e-05, step=None), 'lr_prey_attention': FloatDistribution(high=0.01, log=True, low=1e-05, step=None), 
# 'lr_prey_dis': FloatDistribution(high=0.001, log=True, low=1e-06, step=None), 'lambda_gp_prey': IntDistribution(high=10, log=False, low=1, step=1)}, trial_id=47, value=None)]

In [11]:
with open(file_path, "rb") as f:
    study = pickle.load(f)

print("Best hyperparameters:", study.best_trials)

best_trials_path = os.path.join(tuning_path, "best_trials.txt")
os.makedirs(os.path.dirname(best_trials_path), exist_ok=True)
with open(best_trials_path, "w") as f:
    for trial in study.best_trials:
        f.write(str(trial) + "\n\n")

# PREDATOR
fig1 = vis.plot_optimization_history(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig1, file=os.path.join(tuning_path, "opt_history_predator.html"))
fig1.show()

fig3 = vis.plot_param_importances(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig3, file=os.path.join(tuning_path, "param_importance_predator.html"))
fig3.show()

fig5 = vis.plot_slice(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig5, file=os.path.join(tuning_path, "slice_plot_predator.html"))
fig5.show()

fig7 = vis.plot_parallel_coordinate(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig7, file=os.path.join(tuning_path, "parallel_coordinate_predator.html"))
fig7.show()

fig9 = vis.plot_contour(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig9, file=os.path.join(tuning_path, "contour_predator.html"))
fig9.show()


# PREY
fig2 = vis.plot_optimization_history(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig2, file=os.path.join(tuning_path, "opt_history_prey.html"))
fig2.show()

fig4 = vis.plot_param_importances(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig4, file=os.path.join(tuning_path, "param_importance_prey.html"))
fig4.show()

fig6 = vis.plot_slice(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig6, file=os.path.join(tuning_path, "slice_plot_prey.html"))
fig6.show()

fig8 = vis.plot_parallel_coordinate(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig8, file=os.path.join(tuning_path, "parallel_coordinate_prey.html"))
fig8.show()

fig10 = vis.plot_contour(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig10, file=os.path.join(tuning_path, "contour_prey.html"))
fig10.show()


Best hyperparameters: [FrozenTrial(number=2, state=TrialState.COMPLETE, values=[0.05883282423019409, 0.14747068285942078], datetime_start=datetime.datetime(2025, 12, 16, 11, 9, 18, 358621), datetime_complete=datetime.datetime(2025, 12, 16, 11, 15, 54, 111467), params={'lr_prey_pairwise': 0.0002886287750485167, 'lr_prey_attention': 0.0001197403310059681, 'lr_prey_dis': 7.56643482964404e-05}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'lr_prey_pairwise': FloatDistribution(high=0.0003, log=True, low=5e-05, step=None), 'lr_prey_attention': FloatDistribution(high=0.0002, log=True, low=3e-05, step=None), 'lr_prey_dis': FloatDistribution(high=0.0001, log=True, low=1e-05, step=None)}, trial_id=3, value=None), FrozenTrial(number=3, state=TrialState.COMPLETE, values=[0.26924291253089905, 0.1399441957473755], datetime_start=datetime.datetime(2025, 12, 16, 11, 15, 54, 139484), datetime_complete=datetime.datetime(2025, 12, 16, 11, 22, 29, 935323), params={'lr_prey_pairwi