In [None]:
# %pip install stable_baselines3
# %pip install stable_baselines3[extra]
# %pip install ../../../rSoccer/
# %pip install ipywidgets
# %pip install tqdm
# %pip install shapely

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from lib.domain.behavior_callback import BehaviorCallback

from lib.environment.attacker.environment import Environment

import os
from datetime import datetime

In [3]:
# Run identifiers: both are path components of the output directory
# ("models/<task_training_name>/<algorithm_name>/<timestamp>"), and
# algorithm_name is additionally passed to the callback as its model name.
task_training_name = "attacker"
algorithm_name = "PPO"

In [4]:
def create_env():
    """Return a zero-argument factory that builds a new ``Environment``.

    The vectorised-environment wrapper takes a list of callables rather than
    ready-made environments, so construction is deferred until the factory
    is invoked.
    """
    def make_environment():
        environment = Environment()
        return environment

    return make_environment

In [5]:
def create_folder_if_not_exists(folder_path):
    """Ensure that ``folder_path`` exists as a directory.

    Missing intermediate directories are created as well. Calling this on a
    directory that already exists is a no-op.

    Args:
        folder_path: Path of the directory to create.

    Raises:
        FileExistsError: If ``folder_path`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test and still fails loudly when a regular file
    # occupies the path.
    os.makedirs(folder_path, exist_ok=True)

In [6]:
def get_task_models_path():
    """Build the output directory path for the current training run.

    Returns:
        A relative path of the form
        ``models/<task_training_name>/<algorithm_name>/<timestamp>`` where the
        timestamp is the current local time written as
        ``year_month_day_hour_minute_second`` without zero padding.
    """
    now = datetime.now()

    # Join the raw integer fields so that no zero padding is introduced.
    timestamp_fields = (
        now.year,
        now.month,
        now.day,
        now.hour,
        now.minute,
        now.second,
    )
    run_stamp = "_".join(str(field) for field in timestamp_fields)

    return "/".join(("models", task_training_name, algorithm_name, run_stamp))

In [7]:
# Resolve this run's timestamped output directory once, then make sure it
# exists on disk; the same path is later handed to the checkpoint callback.
save_path = get_task_models_path()

create_folder_if_not_exists(save_path)

In [8]:
# Number of environment factories handed to SubprocVecEnv; also used as a
# divisor when computing the checkpoint frequency.
num_threads = 2

In [None]:
# Build the vectorised environment from `num_threads` independent factories.
env = SubprocVecEnv([create_env() for i in range(num_threads)])

# Size of the last axis of the action space shape; displayed as the cell output.
n_actions = env.action_space.shape[-1]
n_actions

In [10]:
# PPO hyperparameters, collected in one place so they can be tuned per run.
gae_lambda = 0.95  # lambda of Generalized Advantage Estimation
gamma = 0.99  # discount factor
learning_rate = 0.0004  # optimizer step size
clip_range = 0.2  # PPO clipping parameter
policy = "MlpPolicy"  # policy class name understood by stable_baselines3
batch_size = 128  # minibatch size

In [11]:
# Warm-start switch: when True, the parameters stored at `loaded_model_path`
# are loaded into the newly constructed model before training.
load_model = False
loaded_model_path = "models/attacker/PPO/2024_6_11_13_34_28/PPO_model"

In [12]:
# Build the PPO agent from the hyperparameters configured above.
# learning_rate has to be passed explicitly: when it is omitted PPO silently
# uses its own library default instead of the value configured in this notebook.
model = PPO(
    policy=policy,
    env=env,
    learning_rate=learning_rate,
    gamma=gamma,
    gae_lambda=gae_lambda,
    clip_range=clip_range,
    batch_size=batch_size)

# Optional warm start: overwrite the freshly initialised parameters with the
# ones saved by a previous run.
if load_model:
    model.set_parameters(loaded_model_path)

In [13]:
# Total training budget passed to model.learn(); also the basis for the
# checkpoint and logging schedule.
total_timesteps = 200_000_000

In [14]:
# Checkpoint schedule: spread `saved_model_number` checkpoints over the run.
# The budget is divided by num_threads because every vectorised step advances
# all environments at once.
# NOTE(review): this assumes BehaviorCallback counts check_freq in vectorised
# steps - confirm against its implementation.
saved_model_number = 20
save_freq = total_timesteps // (saved_model_number * num_threads)

# Logging schedule: for on-policy algorithms such as PPO, `log_interval` is a
# number of training iterations (one iteration = n_steps * n_envs timesteps),
# not a number of timesteps. Convert the budget to iterations first so that
# logs are dumped about 10 times per run instead of never.
timesteps_per_iteration = model.n_steps * num_threads
total_iterations = max(1, total_timesteps // timesteps_per_iteration)
log_interval = max(1, total_iterations // 10)

In [15]:
# Project-specific training callback, configured with the checkpoint frequency,
# the run's output directory and the algorithm name as the model name.
# NOTE(review): number_robot_blue / number_robot_yellow are presumably the team
# sizes used by the callback - verify against BehaviorCallback.
checkpoint_callback = BehaviorCallback(
    check_freq=save_freq,
    save_path=save_path,
    model_name=algorithm_name,
    number_robot_blue=3,
    number_robot_yellow=3)

In [None]:
# Run training for the full timestep budget with the checkpoint callback
# attached and a progress bar displayed.
model.learn(
    total_timesteps=total_timesteps,
    log_interval=log_interval,
    callback=checkpoint_callback,
    progress_bar=True)