In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import pickle
import datetime
import custom_marl_aquarium
from utils.es_utils import *
from utils.env_utils import *
from utils.train_utils import *
from models.Buffer import Buffer, Pool
from models.PreyOnlyPolicy import PreyOnlyPolicy
from models.Discriminator import Discriminator

In [2]:
# Hyperparameters

# generated_trajectories
gt_gen_episodes = 3
gt_clip_length = 30

# Training
num_generations = 10
prey_batch_size = 512

# Polcy Update
num_perturbations = 30
pert_clip_length = 30
sigma = 0.1
gamma = 0.9997
lr_prey_pin = 0.1
lr_prey_an  = 0.1

# Discriminator Update
lr_prey_dis = 1.0e-4
alpha = 0.99
eps_dis = 1e-08
lambda_gp_prey = 9

In [3]:
# Create training folder
path = rf"..\data\2. Training\training\GAIL"
timestamp = datetime.datetime.now().strftime("%d.%m.%Y_%H.%M")
folder_name = f"GAIL - {timestamp} - Prey Only"
save_dir = os.path.join(path, folder_name)
os.makedirs(save_dir, exist_ok=True)

# Expert Data
mill_path = rf'..\data\1. Data Processing\processed\couzin\reversed'
traj_path = rf'..\data\1. Data Processing\processed\prey_only\expert_tensors\yolo_detected'
ftw_path = rf'..\data\1. Data Processing\processed\prey_only\3. full_track_windows'

In [4]:
device = torch.device("cpu") # Use CPU for GAIL training due to PoolThreading issues with GPU

prey_policy = PreyOnlyPolicy().to(device)
prey_policy.set_parameters(init=True)

prey_discriminator = Discriminator(neigh=32).to(device)
prey_discriminator.set_parameters(init=True)
optim_dis_prey = torch.optim.RMSprop(prey_discriminator.parameters(), lr=lr_prey_dis, alpha=alpha, eps=eps_dis)

expert_buffer = Buffer(prey_max_length=100000, device=device)

len_gb_prey = gt_gen_episodes * gt_clip_length * 3 * 32 # completly update reply buffer after three generations
generative_buffer = Buffer(prey_max_length=len_gb_prey, device=device)

start_frame_pool = Pool(max_length=13000, device=device)
start_frame_pool.generate_startframes(ftw_path)

In [5]:
# Load Expert Data from local storage
print("Expert Buffer is empty, load data...")
expert_buffer.add_expert(mill_path)

print("Buffer Size:", len(expert_buffer.prey_buffer))

Expert Buffer is empty, load data...
Buffer Size: 32000


In [6]:
# Generate Trajectories for Generative Buffer
print("Generative Buffer is empty, generating data...")
generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool, prey_policy=prey_policy, 
                      clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes, use_walls=True)

Generative Buffer is empty, generating data...


In [8]:
dis_metrics_prey = []
es_metrics_prey = []

for generation in range(num_generations):
    start_time = time.time()
    
    # Sample traj from expert and generative buffer
    expert_prey_batch = expert_buffer.sample(prey_batch_size)
    policy_prey_batch = expert_buffer.sample(prey_batch_size)

    # Prey discriminator update
    dis_metric_prey = prey_discriminator.update(expert_prey_batch, policy_prey_batch, optim_dis_prey, lambda_gp_prey)
    dis_metrics_prey.append(dis_metric_prey)

    prey_stats = prey_policy.update("prey", "prey_pairwise",
                                    prey_policy, prey_discriminator,
                                    num_perturbations, generation, lr_prey_pin,
                                    sigma, clip_length=pert_clip_length,
                                    use_walls=True, start_frame_pool=start_frame_pool)
    print(f"[PREY] Pairwise Interaction Network updated!")
    
    prey_stats += prey_policy.update("prey", "prey_attention",
                                    prey_policy, prey_discriminator,
                                    num_perturbations, generation, lr_prey_an,
                                    sigma, clip_length=pert_clip_length,
                                    use_walls=True, start_frame_pool=start_frame_pool)
    es_metrics_prey.append(prey_stats)
    print(f"[PREY] Attention Network updated!\n")

    # Generate new trajectories with updated policies
    generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool, prey_policy=prey_policy, 
                            clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes, use_walls=True)
        
    lr_prey_pin *= gamma
    lr_prey_an *= gamma
    sigma *= gamma

    last_epoch_duration = time.time() - start_time
    estimated_time, epoch_time = remaining_time(num_generations, last_epoch_duration, current_generation=generation)

    avg_es_prey = np.mean([m['reward_mean'] for m in prey_stats])
    avg_grad_norm_prey = np.mean([m['grad_norm'] for m in prey_stats])
        
    print(f"[Generation {generation+1}/{num_generations}] - Time: {epoch_time} - Estimated Finish: {estimated_time}" )
    print(f"Prey     | ES-Reward: {avg_es_prey:.3f} | Wasserstein Loss: {dis_metric_prey[0]:.3f} | GP: {dis_metric_prey[1]:.3f} | Grad: {avg_grad_norm_prey:.3f} | Expert Scores: {dis_metric_prey[2]:.3f} | Policy Scores: {dis_metric_prey[3]:.3f}\n")

[PREY] Pairwise Interaction Network updated!
[PREY] Attention Network updated!

[Generation 1/10] - Time: 5:18 - Estimated Finish: 16.12.2025 13:20:07
Prey     | ES-Reward: -0.000 | Wasserstein Loss: 7.597 | GP: 0.844 | Grad: 265548.094 | Expert Scores: -0.142 | Policy Scores: -0.141

[PREY] Pairwise Interaction Network updated!
[PREY] Attention Network updated!

[Generation 2/10] - Time: 5:24 - Estimated Finish: 16.12.2025 13:20:57
Prey     | ES-Reward: 0.001 | Wasserstein Loss: 7.472 | GP: 0.830 | Grad: 70291.650 | Expert Scores: -0.142 | Policy Scores: -0.140

[PREY] Pairwise Interaction Network updated!
[PREY] Attention Network updated!

[Generation 3/10] - Time: 5:15 - Estimated Finish: 16.12.2025 13:19:46
Prey     | ES-Reward: -0.002 | Wasserstein Loss: 7.354 | GP: 0.817 | Grad: 69472.441 | Expert Scores: -0.140 | Policy Scores: -0.139

[PREY] Pairwise Interaction Network updated!
[PREY] Attention Network updated!

[Generation 4/10] - Time: 5:18 - Estimated Finish: 16.12.2025 13:

In [9]:
# Save models
save_models(save_dir,
            prey_policy,
            prey_discriminator,
            optim_dis_prey,
            expert_buffer, 
            generative_buffer,
            dis_metrics_prey,
            es_metrics_prey)

Models successfully saved!
Training done!
