In [47]:
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.callbacks import CheckpointCallback

from lib.environment.attacker.environment import Environment

import os
from datetime import datetime

In [48]:
# Identifiers for this training run. They name the output directory
# (models/<task>/<algorithm>/...) and the model filename prefix.
task_training_name, algorithm_name = "attacker", "PPO"

In [49]:
def create_env():
    """Return a zero-argument factory that builds a fresh ``Environment``.

    The factory itself (not an environment instance) is what is handed to
    ``SubprocVecEnv``, so the environment is only constructed when the
    factory is called.
    """
    def make_environment():
        environment = Environment()
        return environment

    return make_environment

In [50]:
def create_folder_if_not_exists(folder_path):
    """Create ``folder_path`` (and any missing parents) unless it already exists.

    ``exist_ok=True`` replaces a separate existence check. This removes the
    window between "check" and "create" in which another process could create
    the directory and make ``os.makedirs`` raise.

    Args:
        folder_path: Directory path to ensure exists.

    Raises:
        FileExistsError: If ``folder_path`` exists but is not a directory.
    """
    os.makedirs(folder_path, exist_ok=True)

In [51]:
def get_task_models_path():
    """Build a timestamped directory path for this run's saved models.

    The result has the form
    ``models/<task_training_name>/<algorithm_name>/<Y>_<M>_<D>_<h>_<m>_<s>``.
    Timestamp components are NOT zero-padded (e.g. ``2024_6_3_16_46_10``),
    matching the naming of previously saved runs.
    """
    now = datetime.now()
    stamp_parts = (now.year, now.month, now.day, now.hour, now.minute, now.second)
    datetime_name = "_".join(str(part) for part in stamp_parts)
    return f"models/{task_training_name}/{algorithm_name}/{datetime_name}"

In [52]:
# Resolve this run's timestamped output directory once, and make sure it
# exists on disk before anything is written into it.
save_path = get_task_models_path()
create_folder_if_not_exists(folder_path=save_path)

In [53]:
# Number of environment factories handed to SubprocVecEnv (one worker each).
num_threads = 60
# Filename stem shared by the periodic checkpoints and the final saved model.
model_filename_prefix = algorithm_name + "_model"

In [54]:
# One environment factory per worker; SubprocVecEnv runs them in subprocesses.
env = SubprocVecEnv([create_env() for _ in range(num_threads)])

# Last-axis size of the action space, shown as the cell output.
# NOTE(review): assumes a Box-like action space with a non-empty shape.
n_actions = env.action_space.shape[-1]
n_actions

2

In [55]:
# PPO hyperparameters for this run.
policy = "MlpPolicy"
learning_rate = 4e-4
gamma = 0.99        # discount factor
gae_lambda = 0.95   # GAE lambda
clip_range = 0.2    # PPO clipping range
batch_size = 128    # minibatch size

In [56]:
# Checkpoint whose saved parameters are loaded into the new model when
# load_model is True (warm start instead of training from scratch).
loaded_model_path = "models/attacker/PPO/2024_6_3_16_46_10/PPO_model_90000000_steps"
load_model = True

In [57]:
# Build the PPO agent on the vectorised environment.
# learning_rate must be passed explicitly. Otherwise PPO falls back to its own
# default, and the value configured in the hyperparameter cell is ignored.
model = PPO(
    policy=policy,
    env=env,
    learning_rate=learning_rate,
    gamma=gamma,
    gae_lambda=gae_lambda,
    clip_range=clip_range,
    batch_size=batch_size)

# Optionally warm-start by loading previously saved parameters into the
# freshly constructed model.
if load_model:
    model.set_parameters(loaded_model_path)

In [58]:
# Total environment timesteps passed to model.learn (200 million).
total_timesteps = 200 * 1_000_000

In [59]:
saved_model_number = 20

# CheckpointCallback counts callback calls (one per vectorised env.step), and
# each call advances num_threads timesteps. Dividing by num_threads therefore
# yields roughly `saved_model_number` checkpoints over the run. Clamp to >= 1:
# a zero save_freq would make the callback's modulo check divide by zero on
# short runs.
save_freq = max(1, total_timesteps // (saved_model_number * num_threads))

# PPO's log_interval counts training iterations (one rollout of
# n_steps * n_envs timesteps each), not timesteps. A value like
# `total_timesteps // 10` exceeds the number of iterations, so no log would ever
# be dumped. Aim for about 10 log dumps over the whole run instead.
# NOTE(review): with PPO's default verbose=0 and no tensorboard_log, SB3 may
# not emit these logs anywhere. Confirm that a logger output is configured.
rollout_timesteps = model.n_steps * model.n_envs
n_iterations = -(-total_timesteps // rollout_timesteps)  # ceiling division
log_interval = max(1, n_iterations // 10)

In [60]:
# Periodically save checkpoints named "<model_filename_prefix>_<steps>_steps"
# into this run's output directory.
checkpoint_callback = CheckpointCallback(
    name_prefix=model_filename_prefix,
    save_path=save_path,
    save_freq=save_freq,
)

In [None]:
# Main training run. The checkpoint callback saves intermediate models as
# training progresses; the returned model is the cell's displayed output.
model.learn(
    callback=checkpoint_callback,
    log_interval=log_interval,
    total_timesteps=total_timesteps,
)


ODE Message 3: LCP internal error, s <= 0 (s=0.0000e+00)

ODE Message 3: LCP internal error, s <= 0 (s=0.0000e+00)

ODE Message 3: LCP internal error, s <= 0 (s=-0.0000e+00)

ODE Message 3: LCP internal error, s <= 0 (s=0.0000e+00)

ODE Message 3: LCP internal error, s <= 0 (s=0.0000e+00)

ODE Message 3: LCP internal error, s <= 0 (s=-0.0000e+00)

ODE Message 3: LCP internal error, s <= 0 (s=0.0000e+00)
Bad pipe message: %s [b'\xb4*l\x08\x8a\xb0W\xef\x16\xc4\xe6N\t\xa4\x83;\xe3\xc9 \x06v\xcc\xdaN\xa9\xa4\x1f9\xae\xee\xeb\xbe\xb0\xeb\xf8\xb5\x02\xe0e\x83^o\xb1{\x87"\xbew\x96\xcf\x11\x00', b'\x02\x13\x03\x13\x01\x00\xff', b'']
Bad pipe message: %s [b"Ng$U\xbf\xe2R\xca\xcd\xbf\xf2O\xe4E\xabi\x0e\x03\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0s\xc0w\x00\xc4\x00\xc3\xc0#\xc0'\x00g\x00@\xc0r\xc0v\x00\xbe\x00\xbd\xc0\n\xc

In [None]:
# Save the final trained model next to the periodic checkpoints.
model.save(save_path + "/" + model_filename_prefix)