In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import ast
import optuna
import pickle
import datetime
from utils.es_utils import *
from utils.env_utils import *
from utils.train_utils import *
from plotly.io import write_html
import optuna.visualization as vis
from marl_aquarium import aquarium_v0
from models.Buffer import Buffer, Pool
from models.PreyPolicy import PreyPolicy
from models.PredatorPolicy import PredatorPolicy
from models.Discriminator import Discriminator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hyperparameters

# generated_trajectories
gt_gen_episodes = 50
gt_clip_length = 30

# Training
num_generations = 50
pred_batch_size = 256
prey_batch_size = 1024
gen_dis_ratio = 1

# Polcy Update
num_perturbations = 32
pert_clip_length = 24
sigma = 0.07                  
gamma = 0.9995             
lr_pred_pin = 0.003
lr_pred_an = 0.003       
lr_prey_pin = 0.003
lr_prey_an = 0.003 
lr_prey_pred_pin = 0.003

# Discriminator Update
lr_pred_dis =  5e-3
lr_prey_dis =  5e-3
alpha = 0.99
eps_dis = 1e-08
lambda_gp_pred = 5
lambda_gp_prey = 5

# Early Stopping
start_es_pred = 1000
start_es_prey = 1000
patience = 1000

In [3]:
# Create training folder
path = rf"..\data\2. Training\tuning\GAIL"
timestamp = datetime.datetime.now().strftime("%d.%m.%Y_%H.%M")
folder_name = f"GAIL Tuning - {timestamp}"
save_dir = os.path.join(path, folder_name)
os.makedirs(save_dir, exist_ok=True)

# Expert Data
traj_path = rf'..\data\1. Data Processing\processed\video\expert_tensors\yolo_detected'
hl_path = rf'..\data\1. Data Processing\processed\video\expert_tensors\hand_labeled'
ftw_path = rf'..\data\1. Data Processing\processed\video\3. full_track_windows'

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pred_policy = PredatorPolicy().to(device)
pred_policy.set_parameters(init=True)

prey_policy = PreyPolicy().to(device)
prey_policy.set_parameters(init=True)

pred_discriminator = Discriminator().to(device)
pred_discriminator.set_parameters(init=True)
optim_dis_pred = torch.optim.RMSprop(pred_discriminator.parameters(), lr=lr_pred_dis, alpha=alpha, eps=eps_dis)

prey_discriminator = Discriminator().to(device)
prey_discriminator.set_parameters(init=True)
optim_dis_prey = torch.optim.RMSprop(prey_discriminator.parameters(), lr=lr_prey_dis, alpha=alpha, eps=eps_dis)

expert_buffer = Buffer(pred_max_length=23000, prey_max_length=73000, device=device)

len_gb_pred = gt_gen_episodes * gt_clip_length * 3
len_gb_prey = gt_gen_episodes * gt_clip_length * 3 * 32 # completly update reply buffer after three generations
generative_buffer = Buffer(pred_max_length=len_gb_pred, prey_max_length=len_gb_prey, device=device)

start_frame_pool = Pool(max_length=13000, device=device)
start_frame_pool.generate_startframes(ftw_path)

early_stopper_pred = EarlyStoppingWasserstein(patience=patience, start_es=start_es_pred)
early_stopper_prey = EarlyStoppingWasserstein(patience=patience, start_es=start_es_prey)

In [5]:
# Load Expert Data from local storage
print("Expert Buffer is empty, load data...")
expert_buffer.add_expert(traj_path)
expert_buffer.clear(p=50)               # Reduce ratio of non-attack data by 90%. now ~equal
expert_buffer.add_expert(hl_path)       # hand-labeled data | Pred: 1057 | Prey: 33824

len_exp_pred, len_exp_prey = expert_buffer.lengths()
print("Storage of Predator Expert Buffer: ", len_exp_pred)
print("Storage of Prey Expert Buffer: ", len_exp_prey, "\n")

Expert Buffer is empty, load data...
Storage of Predator Expert Buffer:  11767
Storage of Prey Expert Buffer:  70324 



In [6]:
# Generate Trajectories for Generative Buffer
print("Generative Buffer is empty, generating data...")
generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool,
                      pred_policy=pred_policy, prey_policy=prey_policy, 
                      clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes,
                      use_walls=True)

len_gen_pred, len_gen_prey = generative_buffer.lengths()
print("Storage of Predator Generative Buffer: ", len_gen_pred)
print("Storage of Prey Generative Buffer: ", len_gen_prey)

Generative Buffer is empty, generating data...
Storage of Predator Generative Buffer:  1500
Storage of Prey Generative Buffer:  48000


In [7]:
def train_gail(pred_policy=pred_policy, prey_policy=prey_policy, 
               pred_discriminator=pred_discriminator, prey_discriminator=prey_discriminator,
               expert_buffer=expert_buffer, generative_buffer=generative_buffer, start_frame_pool=start_frame_pool,
               gt_gen_episodes=gt_gen_episodes, gt_clip_length=gt_clip_length, 
               gen_dis_ratio=gen_dis_ratio, gamma=gamma, 
               alpha=alpha, eps_dis=eps_dis,

                num_generations=num_generations,
                pred_batch_size=pred_batch_size,
                prey_batch_size=prey_batch_size,
                num_perturbations=num_perturbations,
                pert_clip_length=pert_clip_length,
                sigma=sigma,
                lr_pred_pin = lr_pred_pin,
                lr_pred_an = lr_pred_an,
                lr_prey_pin = lr_prey_pin,
                lr_prey_an = lr_prey_an,
                lr_prey_pred_pin = lr_prey_pred_pin,
                lr_pred_dis=lr_pred_dis,
                lr_prey_dis=lr_prey_dis,
                lambda_gp_pred=lambda_gp_pred,
                lambda_gp_prey=lambda_gp_prey):
    
    optim_dis_pred = torch.optim.RMSprop(pred_discriminator.parameters(), lr=lr_pred_dis, alpha=alpha, eps=eps_dis)
    optim_dis_prey = torch.optim.RMSprop(prey_discriminator.parameters(), lr=lr_prey_dis, alpha=alpha, eps=eps_dis)

    dis_metrics_pred = []
    dis_metrics_prey = []

    es_metrics_pred = []
    es_metrics_prey = []

    for generation in range(num_generations):
        start_time = time.time()
        
        # Sample traj from expert and generative buffer
        expert_pred_batch, expert_prey_batch = expert_buffer.sample(pred_batch_size, prey_batch_size)
        policy_pred_batch, policy_prey_batch = generative_buffer.sample(pred_batch_size, prey_batch_size)

        # Predator discriminator update
        dis_metric_pred = pred_discriminator.update(expert_pred_batch, policy_pred_batch, optim_dis_pred, lambda_gp_pred)
        dis_metrics_pred.append(dis_metric_pred)
                                        
        # Prey discriminator update
        dis_metric_prey = prey_discriminator.update(expert_prey_batch, policy_prey_batch, optim_dis_prey, lambda_gp_prey)
        dis_metrics_prey.append(dis_metric_prey)

        for i in range(gen_dis_ratio):
            pred_stats = pred_policy.update("predator", "pairwise",
                                            pred_policy, prey_policy,
                                            pred_discriminator, prey_discriminator,
                                            num_perturbations, generation, lr_pred_pin,
                                            sigma, clip_length=pert_clip_length,
                                            use_walls=True, start_frame_pool=start_frame_pool)
            

            pred_stats += pred_policy.update("predator", "attention",
                                            pred_policy, prey_policy,
                                            pred_discriminator, prey_discriminator,
                                            num_perturbations, generation, lr_pred_an,
                                            sigma, clip_length=pert_clip_length,
                                            use_walls=True, start_frame_pool=start_frame_pool)
            es_metrics_pred.append(pred_stats)


            prey_stats = prey_policy.update("prey", "prey_pairwise",
                                            pred_policy, prey_policy,
                                            pred_discriminator, prey_discriminator,
                                            num_perturbations, generation, lr_prey_pin,
                                            sigma, clip_length=pert_clip_length,
                                            use_walls=True, start_frame_pool=start_frame_pool)
            
            prey_stats += prey_policy.update("prey", "prey_attention",
                                            pred_policy, prey_policy,
                                            pred_discriminator, prey_discriminator,
                                            num_perturbations, generation, lr_prey_an,
                                            sigma, clip_length=pert_clip_length,
                                            use_walls=True, start_frame_pool=start_frame_pool)
            es_metrics_prey.append(prey_stats)

            prey_stats += prey_policy.update("prey", "pred_pairwise",
                                            pred_policy, prey_policy,
                                            pred_discriminator, prey_discriminator,
                                            num_perturbations, generation, lr_prey_pred_pin,
                                            sigma, clip_length=pert_clip_length,
                                            use_walls=True, start_frame_pool=start_frame_pool)
            es_metrics_prey.append(prey_stats)


            # Generate new trajectories with updated policies
            generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool,
                                pred_policy=pred_policy, prey_policy=prey_policy, 
                                clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes,
                                use_walls=True)

        lr_pred_pin *= gamma
        lr_pred_an *= gamma
        lr_prey_pin *= gamma
        lr_prey_an *= gamma
        lr_prey_pred_pin *= gamma
        sigma *= gamma

        return dis_metrics_pred, dis_metrics_prey, es_metrics_pred, es_metrics_prey

In [8]:
def objective(trial):
    try:
        # Training
        num_generations   = trial.suggest_int("num_generations", 20, 100)
        pred_batch_size   = trial.suggest_categorical("pred_batch_size", [64, 128, 256, 512])
        prey_batch_size   = trial.suggest_categorical("prey_batch_size", [256, 512, 1024])

        # Generator Update
        num_perturbations = trial.suggest_categorical("num_perturbations", [16, 32, 64])
        pert_clip_length  = trial.suggest_int("pert_clip_length", 10, 40)
        sigma             = trial.suggest_float("sigma", 0.03, 0.10)
        lr_pred_pin       = trial.suggest_float("lr_pred_pairwise", 1e-4, 5e-3, log=True)
        lr_pred_an        = trial.suggest_float("lr_pred_attention", 5e-5, 2e-3, log=True)
        lr_prey_pin       = trial.suggest_float("lr_prey_pairwise", 1e-4, 5e-3, log=True)
        lr_prey_an        = trial.suggest_float("lr_prey_attention", 5e-5, 2e-3, log=True)
        lr_prey_pred_pin  = trial.suggest_float("lr_prey_pred_pin", 1e-4, 5e-3, log=True)

        # Discriminator Update
        lr_pred_dis       = trial.suggest_float("lr_pred_dis", 1e-5, 5e-4, log=True)
        lr_prey_dis       = trial.suggest_float("lr_prey_dis", 1e-5, 5e-4, log=True)
        lambda_gp_pred    = trial.suggest_int("lambda_gp_pred", 1, 10)
        lambda_gp_prey    = trial.suggest_int("lambda_gp_prey", 1, 10)


        dis_metrics_pred, dis_metrics_prey, es_metrics_pred, es_metrics_prey = train_gail(num_generations=num_generations,
                                                                                        pred_batch_size=pred_batch_size,
                                                                                        prey_batch_size=prey_batch_size,
                                                                                        gen_dis_ratio=gen_dis_ratio,
                                                                                        num_perturbations=num_perturbations,
                                                                                        pert_clip_length=pert_clip_length,
                                                                                        sigma=sigma,
                                                                                        lr_pred_pin = lr_pred_pin,
                                                                                        lr_pred_an = lr_pred_an,
                                                                                        lr_prey_pin = lr_prey_pin,
                                                                                        lr_prey_an = lr_prey_an,
                                                                                        lr_prey_pred_pin = lr_prey_pred_pin,
                                                                                        gamma=gamma,
                                                                                        lr_pred_dis=lr_pred_dis,
                                                                                        lr_prey_dis=lr_prey_dis,
                                                                                        lambda_gp_pred=lambda_gp_pred,
                                                                                        lambda_gp_prey=lambda_gp_prey)

        policy_score_pred = sum(score[3] for score in dis_metrics_pred) / len(dis_metrics_pred)
        policy_score_prey = sum(score[3] for score in dis_metrics_prey) / len(dis_metrics_prey)

        expert_score_pred = sum(score[2] for score in dis_metrics_pred) / len(dis_metrics_pred)
        expert_score_prey = sum(score[2] for score in dis_metrics_prey) / len(dis_metrics_prey)

        dist_pred = abs(policy_score_pred - expert_score_pred)
        dist_prey = abs(policy_score_prey - expert_score_prey)

        return dist_pred, dist_prey

    except optuna.TrialPruned:
        raise

    except Exception as e:
        print(f"[WARNING] Trial failed due to error: {e}")
        raise optuna.TrialPruned()

In [None]:
os.makedirs(os.path.join(save_dir, "tuning"), exist_ok=True)
tuning_path = os.path.join(save_dir, "tuning")
db_path = os.path.join(tuning_path, "optuna_results.db")
storage = f"sqlite:///{db_path}"

study = optuna.create_study(directions=["minimize", "minimize"], # Minimize Policy Scores
                            sampler=optuna.samplers.TPESampler(), # Bayesian Optimization
                            pruner=optuna.pruners.MedianPruner(), # Stopping trials that are not promising
                            storage=storage,
                            study_name="tuning results")

study.optimize(objective, n_trials=60, n_jobs=1)

file_path = os.path.join(tuning_path, "study.pkl")
with open(file_path, "wb") as f:
    pickle.dump(study, f)

[I 2025-12-04 19:48:22,669] A new study created in RDB with name: tuning results
[I 2025-12-04 20:15:10,997] Trial 0 finished with values: [0.013515546917915344, 0.014300642535090446] and parameters: {'num_generations': 77, 'pred_batch_size': 512, 'prey_batch_size': 1024, 'num_perturbations': 64, 'pert_clip_length': 23, 'sigma': 0.09973615422416492, 'lr_pred_pairwise': 0.0003907997813396514, 'lr_pred_attention': 0.00019852568251309247, 'lr_prey_pairwise': 0.0007101940796792549, 'lr_prey_attention': 6.370430759677056e-05, 'lr_prey_pred_pin': 0.0041223351076510395, 'lr_pred_dis': 0.00014789427828733372, 'lr_prey_dis': 0.00014284306174162245, 'lambda_gp_pred': 2, 'lambda_gp_prey': 5}.
[I 2025-12-04 20:45:12,794] Trial 1 finished with values: [0.040757521986961365, 0.039101822301745415] and parameters: {'num_generations': 51, 'pred_batch_size': 64, 'prey_batch_size': 512, 'num_perturbations': 64, 'pert_clip_length': 38, 'sigma': 0.06593345246128343, 'lr_pred_pairwise': 0.000293134374427557

tensor([[inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [nan],
        [nan]], grad_fn=<SliceBackward0>)


[I 2025-12-04 21:03:22,205] Trial 4 pruned. 


tensor([[inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [nan],
        [nan]], grad_fn=<SliceBackward0>)


[I 2025-12-04 21:04:44,787] Trial 5 pruned. 


tensor([[inf],
        [nan],
        [inf],
        [inf],
        [nan],
        [inf],
        [inf],
        [nan],
        [inf],
        [inf],
        [inf],
        [inf],
        [nan],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [inf],
        [nan],
        [nan],
        [nan]], grad_fn=<SliceBackward0>)


In [None]:
with open(file_path, "rb") as f:
    study = pickle.load(f)

print("Best hyperparameters:", study.best_trials)

best_trials_path = os.path.join(tuning_path, "best_trials.txt")
os.makedirs(os.path.dirname(best_trials_path), exist_ok=True)
with open(best_trials_path, "w") as f:
    for trial in study.best_trials:
        f.write(str(trial) + "\n\n")

# PREDATOR
fig1 = vis.plot_optimization_history(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig1, file=os.path.join(tuning_path, "opt_history_predator.html"))
fig1.show()

fig3 = vis.plot_param_importances(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig3, file=os.path.join(tuning_path, "param_importance_predator.html"))
fig3.show()

fig5 = vis.plot_slice(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig5, file=os.path.join(tuning_path, "slice_plot_predator.html"))
fig5.show()

fig7 = vis.plot_parallel_coordinate(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig7, file=os.path.join(tuning_path, "parallel_coordinate_predator.html"))
fig7.show()

fig9 = vis.plot_contour(study, target=lambda t: t.values[0], target_name="Predator Loss")
write_html(fig9, file=os.path.join(tuning_path, "contour_predator.html"))
fig9.show()


# PREY
fig2 = vis.plot_optimization_history(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig2, file=os.path.join(tuning_path, "opt_history_prey.html"))
fig2.show()

fig4 = vis.plot_param_importances(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig4, file=os.path.join(tuning_path, "param_importance_prey.html"))
fig4.show()

fig6 = vis.plot_slice(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig6, file=os.path.join(tuning_path, "slice_plot_prey.html"))
fig6.show()

fig8 = vis.plot_parallel_coordinate(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig8, file=os.path.join(tuning_path, "parallel_coordinate_prey.html"))
fig8.show()

fig10 = vis.plot_contour(study, target=lambda t: t.values[1], target_name="Prey Loss")
write_html(fig10, file=os.path.join(tuning_path, "contour_prey.html"))
fig10.show()


Best hyperparameters: [FrozenTrial(number=0, state=TrialState.COMPLETE, values=[0.29102087020874023, 0.06575947068631649], datetime_start=datetime.datetime(2025, 8, 20, 13, 4, 9, 126821), datetime_complete=datetime.datetime(2025, 8, 20, 13, 45, 30, 837164), params={'num_generations': 242, 'pred_batch_size': 256, 'prey_batch_size': 64, 'gen_dis_ratio': 4, 'num_perturbations': 32, 'pert_clip_length': 17, 'sigma': 0.1780527546981858, 'lr_pred_policy': 0.007724243721786319, 'lr_prey_policy': 0.0013129670302420356, 'gamma': 0.9997628275960591, 'lr_pred_dis': 0.0013561814426744576, 'lr_prey_dis': 0.0010917400364471109, 'lambda_gp_pred': 5, 'lambda_gp_prey': 7}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'num_generations': IntDistribution(high=300, log=False, low=100, step=1), 'pred_batch_size': CategoricalDistribution(choices=(64, 128, 256, 512, 1024)), 'prey_batch_size': CategoricalDistribution(choices=(64, 128, 256, 512, 1024)), 'gen_dis_ratio': IntDistribution(