In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.callbacks import CheckpointCallback

from lib.environment.attacker.environment import Environment

import numpy as np
import os
from datetime import datetime

from lib.helpers.configuration_helper import ConfigurationHelper

In [None]:
# Identifiers for this training run; both are embedded in the output path
# built by get_task_models_path().
task_training_name = "attacker"
algorithm_name = "ppo"

In [None]:
def create_env():
    """Build a zero-argument factory that constructs a fresh ``Environment``.

    The vectorized-environment wrapper is handed callables rather than
    instances, so construction is deferred until the returned function is
    actually invoked.

    Returns:
        A callable taking no arguments that returns a new ``Environment``.
    """
    def make_environment():
        # Instantiated lazily, only when the factory itself is called.
        return Environment()

    return make_environment

In [None]:
def create_folder_if_not_exists(folder_path):
    """Create ``folder_path`` (and any missing parents) unless it already exists.

    Args:
        folder_path: Path of the directory to create.

    Raises:
        FileExistsError: If ``folder_path`` exists but is not a directory.
    """
    # exist_ok=True removes the check-then-create race of a separate
    # os.path.exists() test and still succeeds when the directory is present.
    os.makedirs(folder_path, exist_ok=True)

In [None]:
def get_task_models_path():
    """Return the output directory path for the current training run.

    The path has the form ``models/<task>/<algorithm>/<timestamp>``. The task
    and algorithm come from the module-level ``task_training_name`` and
    ``algorithm_name``; the timestamp is the current time from
    ``datetime.now()`` written as ``year_month_day_hour_minute_second``
    without zero padding.

    Returns:
        The relative path as a string. Nothing is created on disk.
    """
    now = datetime.now()
    stamp_fields = (
        now.year,
        now.month,
        now.day,
        now.hour,
        now.minute,
        now.second,
    )
    run_stamp = "_".join(str(field) for field in stamp_fields)

    return f"models/{task_training_name}/{algorithm_name}/{run_stamp}"

In [None]:
# Resolve this run's timestamped output directory once, so that every later
# cell using save_path refers to the same location.
save_path = get_task_models_path()

# Make sure the directory exists before anything is written into it.
create_folder_if_not_exists(save_path)

In [None]:
# Number of environment factories handed to the vectorized environment.
num_threads = 1
# Filename prefix shared by the periodic checkpoints and the final saved model.
model_filename_prefix = "PPO_model"

In [None]:
# Vectorized environment built from num_threads independent factories.
env = SubprocVecEnv(
    [create_env() for _ in range(num_threads)]
)

# Size of the last axis of the action space; the bare expression below is
# only there so the notebook displays the value.
n_actions = env.action_space.shape[-1]
n_actions

In [None]:
# Training hyperparameters, kept together so a run can be tuned in one place.
policy = "MlpPolicy"    # policy identifier handed to the PPO constructor
learning_rate = 4e-4    # optimizer step size
batch_size = 128        # minibatch size
gamma = 0.99            # discount factor
gae_lambda = 0.95       # GAE lambda coefficient
clip_range = 0.2        # PPO clipping parameter

In [None]:
# Set load_model to True and point loaded_model_path at a saved model to
# initialise the new model's parameters from it instead of starting fresh.
load_model = False
loaded_model_path = ""

In [None]:
# Build the PPO agent from the hyperparameters configured above.
# learning_rate has to be passed explicitly; when it is omitted PPO silently
# falls back to its own default and the configured value is never used.
model = PPO(
    policy=policy,
    env=env,
    learning_rate=learning_rate,
    gamma=gamma,
    gae_lambda=gae_lambda,
    clip_range=clip_range,
    batch_size=batch_size,
)

# Optionally initialise the freshly built model from previously saved
# parameters.
if load_model:
    model.set_parameters(loaded_model_path)

In [None]:
# Overall training budget, in environment timesteps.
total_timesteps = 2_000_000

In [None]:
# Number of intermediate checkpoints to write over the whole run.
saved_model_number = 10

# CheckpointCallback counts calls to env.step(), and each call advances all
# num_threads environments at once, hence the division by num_threads.
# Clamped to 1 so a short run can never produce a zero interval.
save_freq = max(1, total_timesteps // (saved_model_number * num_threads))

# learn() measures log_interval in training iterations (rollouts), not in
# timesteps. One iteration collects model.n_steps steps from every
# environment, so convert the timestep budget into iterations before asking
# for roughly ten log dumps across the run.
timesteps_per_iteration = model.n_steps * num_threads
total_iterations = max(1, total_timesteps // timesteps_per_iteration)
log_interval = max(1, total_iterations // 10)

In [None]:
# Callback that writes intermediate model checkpoints during training, using
# the interval, output directory and filename prefix defined above.
checkpoint_callback = CheckpointCallback(
    save_freq=save_freq,
    save_path=save_path,
    name_prefix=model_filename_prefix)

In [None]:
# Run training for the configured timestep budget, with the checkpoint
# callback attached and a progress bar enabled.
model.learn(
    total_timesteps=total_timesteps,
    log_interval=log_interval,
    callback=checkpoint_callback,
    progress_bar=True)

In [None]:
# Save the final model next to the intermediate checkpoints, in the same run
# directory and under the same filename prefix.
model.save(f"{save_path}/{model_filename_prefix}")