In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import pickle
import datetime
from utils.es_utils import *
from utils.env_utils import *
from utils.train_utils import *
from marl_aquarium import aquarium_v0
from models.Buffer import Buffer, Pool
from models.PreyPolicy import PreyPolicy
from models.PredatorPolicy import PredatorPolicy
from models.Discriminator import Discriminator

In [2]:
# Open TODOs for this training notebook (English translation of the German notes in
# the string below; the string itself is left unchanged because it is this cell's
# displayed output):
# - The discriminator output is raw (no sigmoid), as in WGAN — which is what actually
#   matters for the Wasserstein loss.
# - How should mode collapse be handled?
# - Standardize the plot design (MPI colors).
#
# - The discriminator converges towards an output of 0.5, i.e. it is effectively
#   flipping a coin on every case.
# - Check the speed and velocity metrics, both in the scaling and as input-vector features!
"""
ToDo's
- Discriminator Output is raw nicht, sigmoid wie in WGAN - Was eig. für Wasserstein Loss relevant ist.
- Umgang mit Mode Collapse?
- Design von plots standardisieren - MPI Farben

- Discriminator konvergiert werden Output 0.5, heißt er macht Coin-Flip bei jedem Case.
- Speed Metriken und Velocity Metric bei Scaling und als Input-Vector prüfen!
"""

"\nToDo's\n- Discriminator Output is raw nicht, sigmoid wie in WGAN - Was eig. für Wasserstein Loss relevant ist.\n- Umgang mit Mode Collapse?\n- Design von plots standardisieren - MPI Farben\n\n- Discriminator konvergiert werden Output 0.5, heißt er macht Coin-Flip bei jedem Case.\n- Speed Metriken und Velocity Metric bei Scaling und als Input-Vector prüfen!\n"

In [3]:
# Hyperparameters — every tunable value of the run lives in this cell.

# Environment
pred_count = 1          # number of predator agents
prey_count = 32         # number of prey agents
action_count = 360      # number of discrete actions passed to policies/rollouts (presumably heading bins — confirm)
use_walls = True        # forwarded as `use_walls` to trajectory generation and ES rollouts

# Generated trajectories (arguments of generate_trajectories)
gt_gen_episodes = 4     # passed as num_generative_episodes
gt_clip_length = 5      # passed as clip_length

# Buffers (maximum stored lengths per role: pred_max_length / prey_max_length)
pred_buffer_size = 23000
prey_buffer_size = 73000

# Training
num_generations = 250
pred_batch_size = 16    # batch drawn from each buffer per generation for the predator discriminator
prey_batch_size = 16    # batch drawn from each buffer per generation for the prey discriminator
gen_dis_ratio = 4       # policy (generator) update rounds per discriminator update

# Early stopping (EarlyStoppingWasserstein)
start_es_pred = 70      # passed as start_es for the predator stopper
start_es_prey = 50      # passed as start_es for the prey stopper
patience = 20

# ES perturbation (evolution-strategies policy updates)
num_perturbations = 32
pert_clip_length = 16   # clip_length used for the perturbation rollouts
sigma = 0.17            # perturbation scale (presumably noise std — confirm); decayed by gamma each generation
gamma = 0.9998          # per-generation decay factor for sigma and both policy learning rates
lr_pred_policy = 0.007
lr_prey_policy = 0.001

# Discriminator optimizers (torch.optim.RMSprop)
lr_pred_dis = 0.014
lr_prey_dis = 0.001
alpha = 0.99            # RMSprop smoothing constant
eps_dis = 1e-08         # RMSprop epsilon
lambda_gp_pred = 6      # passed to the predator discriminator update (presumably the gradient-penalty weight)
lambda_gp_prey = 7      # passed to the prey discriminator update (presumably the gradient-penalty weight)

In [None]:
# Create a timestamped output folder for this training run.
# Paths are assembled component-wise with os.path.join instead of hard-coded
# Windows backslash literals, so they resolve to real directories on POSIX too
# (on Windows the resulting strings are identical to the previous literals).
path = os.path.join("..", "data", "2. Training", "training", "GAIL")
timestamp = datetime.datetime.now().strftime("%d.%m.%Y_%H.%M")
folder_name = f"GAIL Training - {timestamp}"
save_dir = os.path.join(path, folder_name)
# NOTE(review): the timestamp has minute resolution and exist_ok=True is set, so two
# runs started within the same minute write into (and overwrite) the same folder.
os.makedirs(save_dir, exist_ok=True)

# Expert data locations
_processed_video_dir = os.path.join("..", "data", "1. Data Processing", "processed", "video")
traj_path = os.path.join(_processed_video_dir, "expert_tensors", "yolo_detected")  # expert tensors from the yolo_detected folder
hl_path = os.path.join(_processed_video_dir, "expert_tensors", "hand_labeled")     # hand-labeled expert tensors
ftw_path = os.path.join(_processed_video_dir, "3. full_track_windows")            # source for the start-frame pool

In [5]:
# Build both policies, both discriminators (each with its own RMSprop optimizer),
# the expert/generative buffers, the start-frame pool and the two early stoppers.
# Objects are created in the same order as before, so any random initialisation
# done by set_parameters(init=True) consumes the RNG in the same sequence.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _init_on_device(module):
    """Move `module` to `device`, initialise its parameters and return it."""
    placed = module.to(device)
    placed.set_parameters(init=True)
    return placed


pred_policy = _init_on_device(PredatorPolicy())
prey_policy = _init_on_device(PreyPolicy())

pred_discriminator = _init_on_device(Discriminator())
optim_dis_pred = torch.optim.RMSprop(pred_discriminator.parameters(),
                                     lr=lr_pred_dis, alpha=alpha, eps=eps_dis)

prey_discriminator = _init_on_device(Discriminator())
optim_dis_prey = torch.optim.RMSprop(prey_discriminator.parameters(),
                                     lr=lr_prey_dis, alpha=alpha, eps=eps_dis)

# Expert and generative buffers share the same capacity and device.
_buffer_kwargs = dict(pred_max_length=pred_buffer_size,
                      prey_max_length=prey_buffer_size,
                      device=device)
expert_buffer = Buffer(**_buffer_kwargs)
generative_buffer = Buffer(**_buffer_kwargs)

start_frame_pool = Pool(max_length=13000, device=device)
start_frame_pool.generate_startframes(ftw_path)

# One stopper per role; only the warm-up generation (start_es) differs.
early_stopper_pred, early_stopper_prey = (
    EarlyStoppingWasserstein(patience=patience, start_es=first_generation)
    for first_generation in (start_es_pred, start_es_prey)
)

In [6]:
# Load expert data from local storage.
# Guarded on the buffer being empty: add_expert() appears to append (it is called
# twice below to accumulate two sources), so re-running this cell unguarded would
# load the same expert trajectories a second time on top of the existing ones.
_exp_pred_before, _exp_prey_before = expert_buffer.lengths()
if _exp_pred_before == 0 and _exp_prey_before == 0:
    print("Expert Buffer is empty, load data...")
    expert_buffer.add_expert(traj_path)
    # NOTE(review): the original comment claimed "reduce ratio of non-attack data by 90%,
    # now ~equal", but p=50 is passed — confirm what `p` means in Buffer.clear and align
    # either the value or the comment.
    expert_buffer.clear(p=50)
    expert_buffer.add_expert(hl_path)       # hand-labeled data | Pred: 1057 | Prey: 33824
else:
    print("Expert Buffer already contains data, skipping load.")

len_exp_pred, len_exp_prey = expert_buffer.lengths()
print("Storage of Predator Expert Buffer: ", len_exp_pred)
print("Storage of Prey Expert Buffer: ", len_exp_prey, "\n")

Expert Buffer is empty, load data...
Storage of Predator Expert Buffer:  11767
Storage of Prey Expert Buffer:  70324 



In [7]:
# Seed the generative buffer with trajectories from the freshly initialised policies.
# `use_walls` now comes from the hyperparameter cell instead of being hard-coded to
# True here, so this cell and the training loop always roll out in the same environment.
# Guarded on the buffer being empty so the message is truthful and re-running the
# cell does not keep piling extra initial rollouts into the buffer.
_gen_pred_before, _gen_prey_before = generative_buffer.lengths()
if _gen_pred_before == 0 and _gen_prey_before == 0:
    print("Generative Buffer is empty, generating data...")
    generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool,
                          pred_count=pred_count, prey_count=prey_count, action_count=action_count,
                          pred_policy=pred_policy, prey_policy=prey_policy,
                          clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes,
                          use_walls=use_walls)
else:
    print("Generative Buffer already contains data, skipping generation.")

len_gen_pred, len_gen_prey = generative_buffer.lengths()
print("Storage of Predator Generative Buffer: ", len_gen_pred)
print("Storage of Prey Generative Buffer: ", len_gen_prey)

Generative Buffer is empty, generating data...
[DEBUG] Prey prey_0 action=242
[DEBUG] Prey prey_1 action=118
[DEBUG] Prey prey_2 action=257
[DEBUG] Prey prey_3 action=164
[DEBUG] Prey prey_4 action=148
[DEBUG] Prey prey_5 action=151
[DEBUG] Prey prey_6 action=102
[DEBUG] Prey prey_7 action=163
[DEBUG] Prey prey_8 action=189
[DEBUG] Prey prey_9 action=188
[DEBUG] Prey prey_10 action=117
[DEBUG] Prey prey_11 action=222
[DEBUG] Prey prey_12 action=131
[DEBUG] Prey prey_13 action=260
[DEBUG] Prey prey_14 action=234
[DEBUG] Prey prey_15 action=190
[DEBUG] Prey prey_16 action=196
[DEBUG] Prey prey_17 action=213
[DEBUG] Prey prey_18 action=163
[DEBUG] Prey prey_19 action=168
[DEBUG] Prey prey_20 action=113
[DEBUG] Prey prey_21 action=260
[DEBUG] Prey prey_22 action=229
[DEBUG] Prey prey_23 action=216
[DEBUG] Prey prey_24 action=240
[DEBUG] Prey prey_25 action=101
[DEBUG] Prey prey_26 action=118
[DEBUG] Prey prey_27 action=129
[DEBUG] Prey prey_28 action=195
[DEBUG] Prey prey_29 action=179
[DE

In [8]:
# Adversarial training loop. Per generation:
#   1. one discriminator update per role (Wasserstein loss, weight lambda_gp_*),
#   2. `gen_dis_ratio` rounds of ES policy updates (pairwise + attention modules for
#      predator and prey), each followed by fresh rollouts into the generative buffer,
#   3. annealing of the ES noise scale and policy learning rates by `gamma`.
#
# NOTE(review): the last recorded run of this cell failed inside `update` with
# "TypeError: parallel_get_rollouts() got multiple values for argument 'clip_length'".
# That call lives in the policy/rollout code, which is not visible here — check how
# `update` forwards its positional arguments and `clip_length` to
# `parallel_get_rollouts` (e.g. one positional argument too many landing on `clip_length`).
dis_metrics_pred = []
dis_metrics_prey = []

es_metrics_pred = []
es_metrics_prey = []

# Anneal local copies of the ES hyperparameters. Decaying the config-cell globals in
# place meant every re-run of this cell silently started from already-decayed values.
sigma_t = sigma
lr_pred_t = lr_pred_policy
lr_prey_t = lr_prey_policy

checkpoint_every = 25   # save a checkpoint at generations 0, 25, 50, ...


def _es_update(policy, role, module, gen, sigma_cur, lr_pred_cur, lr_prey_cur):
    """Run one ES update of `module` ("pairwise" or "attention") of `policy` for `role`.

    Forwards the shared environment/model arguments exactly as the previous inline
    calls did and returns whatever `policy.update` returns (apparently a list of
    per-update stat dicts, judging by how the result is consumed below).
    """
    return policy.update(role, module,
                         pred_count, prey_count, action_count,
                         pred_policy, prey_policy,
                         pred_discriminator, prey_discriminator,
                         num_perturbations, gen,
                         lr_pred_cur, lr_prey_cur,
                         sigma_cur, gamma, clip_length=pert_clip_length,
                         use_walls=use_walls, start_frame_pool=start_frame_pool)


for generation in range(num_generations):
    start_time = time.time()

    # Sample traj from expert and generative buffer
    expert_pred_batch, expert_prey_batch = expert_buffer.sample(pred_batch_size, prey_batch_size)
    policy_pred_batch, policy_prey_batch = generative_buffer.sample(pred_batch_size, prey_batch_size)

    # Discriminator updates; the returned metrics are read below as
    # [0] Wasserstein loss, [2] expert scores, [3] policy scores.
    dis_metric_pred = pred_discriminator.update(expert_pred_batch, policy_pred_batch, optim_dis_pred, lambda_gp_pred)
    dis_metrics_pred.append(dis_metric_pred)

    dis_metric_prey = prey_discriminator.update(expert_prey_batch, policy_prey_batch, optim_dis_prey, lambda_gp_prey)
    dis_metrics_prey.append(dis_metric_prey)

    # Initialised so the summary below is defined even when gen_dis_ratio == 0.
    pred_stats, prey_stats = [], []
    for _ in range(gen_dis_ratio):
        pred_stats = _es_update(pred_policy, "predator", "pairwise", generation, sigma_t, lr_pred_t, lr_prey_t)
        pred_stats += _es_update(pred_policy, "predator", "attention", generation, sigma_t, lr_pred_t, lr_prey_t)
        es_metrics_pred.append(pred_stats)

        prey_stats = _es_update(prey_policy, "prey", "pairwise", generation, sigma_t, lr_pred_t, lr_prey_t)
        prey_stats += _es_update(prey_policy, "prey", "attention", generation, sigma_t, lr_pred_t, lr_prey_t)
        es_metrics_prey.append(prey_stats)

        # Generate new trajectories with updated policies
        generate_trajectories(buffer=generative_buffer, start_frame_pool=start_frame_pool,
                              pred_count=pred_count, prey_count=prey_count, action_count=action_count,
                              pred_policy=pred_policy, prey_policy=prey_policy,
                              clip_length=gt_clip_length, num_generative_episodes=gt_gen_episodes,
                              use_walls=use_walls)

    # Anneal ES noise scale and policy learning rates once per generation.
    sigma_t *= gamma
    lr_pred_t *= gamma
    lr_prey_t *= gamma

    last_epoch_duration = time.time() - start_time
    estimated_time, epoch_time = remaining_time(num_generations, last_epoch_duration, current_generation=generation)

    # Averaged over the stats of the last generator round of this generation only
    # (all rounds are kept in es_metrics_*).
    avg_es_pred = np.mean([m['avg_reward_diff'] for m in pred_stats]) if pred_stats else float("nan")
    avg_es_prey = np.mean([m['avg_reward_diff'] for m in prey_stats]) if prey_stats else float("nan")

    print(f"[Generation {generation+1}/{num_generations}] - Time: {epoch_time} - Estimated Finish: {estimated_time}" )
    print(f"Predator | Avg. ES-Reward: {avg_es_pred:.4f} | Wasserstein Loss: {dis_metric_pred[0]:.4f} | Expert Scores: {dis_metric_pred[2]:.4f} | Policy Scores: {dis_metric_pred[3]:.4f}")
    print(f"Prey     | Avg. ES-Reward: {avg_es_prey:.4f} | Wasserstein Loss: {dis_metric_prey[0]:.4f} | Expert Scores: {dis_metric_prey[2]:.4f} | Policy Scores: {dis_metric_prey[3]:.4f}\n")

    # Stop as soon as either role's discriminator loss stops improving.
    if early_stopper_pred(dis_metric_pred[0], generation, "predator") or early_stopper_prey(dis_metric_prey[0], generation, "prey"):
        break

    if generation % checkpoint_every == 0:
        save_checkpoint(save_dir, generation,
                        pred_policy, prey_policy,
                        pred_discriminator, prey_discriminator,
                        optim_dis_pred, optim_dis_prey,
                        expert_buffer, generative_buffer,
                        dis_metrics_pred, dis_metrics_prey,
                        es_metrics_pred, es_metrics_prey)


# Save models
save_models(save_dir,
            pred_policy, prey_policy,
            pred_discriminator, prey_discriminator,
            optim_dis_pred, optim_dis_prey,
            expert_buffer, generative_buffer,
            dis_metrics_pred, dis_metrics_prey,
            es_metrics_pred, es_metrics_prey)

TypeError: parallel_get_rollouts() got multiple values for argument 'clip_length'