In [32]:
# import argparse
import difflib
import importlib
import os
import time
import uuid

import gymnasium as gym
import numpy as np
import stable_baselines3 as sb3
import torch as th
from stable_baselines3.common.utils import set_random_seed

# Register custom envs
# import rl_zoo3.import_envs  # noqa: F401
from rl_zoo3.exp_manager import ExperimentManager
# from rl_zoo3.utils import ALGOS, StoreDict

from dataclasses import dataclass

In [33]:
@dataclass
class RlZooArgs:
    """Bundle of run options with a default for every field.

    Each attribute is described by the comment directly above it. Several
    descriptions refer to command-line flags (e.g. ``--n-trials``), so this
    presumably mirrors the options of a command-line training script --
    TODO confirm against the script it was derived from.
    """

    # RL Algorithm
    algo: str = "ppo"
    # Environment ID
    env: str = "CartPole-v1"
    # Tensorboard log dir
    tensorboard_log: str = ""
    # Path to a pretrained agent to continue training
    trained_agent: str = ""
    # When using HER with online sampling the last trajectory in the replay
    # buffer will be truncated after reloading the replay buffer.
    truncate_last_trajectory: bool = True
    # Overwrite the number of timesteps
    n_timesteps: int = -1
    # Number of threads for PyTorch (-1 to use default)
    num_threads: int = -1
    # Override log interval (default: -1, no change)
    log_interval: int = -1
    # Evaluate the agent every n steps (if negative, no evaluation).
    # During hyperparameter optimization n-evaluations is used instead.
    eval_freq: int = 25000
    # Path to save the evaluation log and optimal policy for each
    # hyperparameter tried during optimization. Disabled if no argument is passed.
    optimization_log_path: str = ""
    # Number of episodes to use for evaluation
    eval_episodes: int = 5
    # Number of environments for evaluation
    n_eval_envs: int = 1
    # Save the model every n steps (if negative, no checkpoint)
    save_freq: int = -1
    # Save the replay buffer too (when applicable)
    save_replay_buffer: bool = False
    # Log folder
    log_folder: str = "logs"
    # Random generator seed (a previous default of -1 is noted in the original comment)
    seed: int = 44113
    # VecEnv type
    vec_env: str = "dummy"
    # PyTorch device to use (ex: cpu, cuda...)
    device: str = "auto"
    # Number of trials for optimizing hyperparameters. This applies to each
    # optimization runner, not the entire optimization process.
    n_trials: int = 500
    # Number of (potentially pruned) trials for optimizing hyperparameters.
    # This applies to the entire optimization process and takes precedence
    # over --n-trials if set.
    max_total_trials: int = 0
    # Run hyperparameters search
    optimize_hyperparameters: bool = False
    # Disable hyperparameter optimization plots
    no_optim_plots: bool = False
    # Number of parallel jobs when optimizing hyperparameters
    n_jobs: int = 1
    # Sampler to use when optimizing hyperparameters
    sampler: str = "tpe"
    # Pruner to use when optimizing hyperparameters
    pruner: str = "median"
    # Number of trials before using optuna sampler
    n_startup_trials: int = 10
    # Training policies are evaluated every n-timesteps // n_evaluations steps
    # when doing hyperparameter optimization. Default is 1 evaluation per 100k timesteps.
    n_evaluations: int = 0
    # Database storage path if distributed optimization should be used
    storage: str = ""
    # Study name for distributed optimization
    study_name: str = ""
    # Verbose mode (0: no output, 1: INFO)
    verbose: int = 0
    # Additional external Gym environment package modules to import
    gym_packages: str = ""
    # Optional keyword argument to pass to the env constructor
    env_kwargs: str = ""
    # Optional keyword argument to pass to the env constructor for evaluation
    eval_env_kwargs: str = ""
    # Overwrite hyperparameter (e.g. learning_rate:0.01 train_freq:10)
    hyperparams: str = ""
    # Custom yaml file or python package from which the hyperparameters will
    # be loaded. We expect that python packages contain a dictionary called
    # 'hyperparams' which contains a key for each environment.
    conf_file: str = ""
    # Ensure that the run has a unique ID
    uuid: bool = False
    # If toggled, this experiment will be tracked with Weights and Biases
    track: bool = False
    # The wandb's project name
    wandb_project_name: str = "sb3"
    # The entity (team) of wandb's project
    wandb_entity: str = ""
    # If toggled, display a progress bar using tqdm and rich
    progress: bool = False
    # Tags for wandb run, e.g.: -tags optimized pr-123
    wandb_tags: str = ""

|  RL Algo |  BipedalWalker-v3 | LunarLander-v2 | LunarLanderContinuous-v2 |  BipedalWalkerHardcore-v3 | CarRacing-v0 |
|----------|--------------|----------------|------------|--------------|--------------------------|
| ARS      |  | :heavy_check_mark: | | :heavy_check_mark: | |
| A2C      | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
| PPO      | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
| DQN      | N/A | :heavy_check_mark: | N/A | N/A | N/A |
| QR-DQN   | N/A | :heavy_check_mark: | N/A | N/A | N/A |
| DDPG     | :heavy_check_mark: | N/A | :heavy_check_mark: | | |
| SAC      | :heavy_check_mark: | N/A | :heavy_check_mark: | :heavy_check_mark: | |
| TD3      | :heavy_check_mark: | N/A | :heavy_check_mark: | :heavy_check_mark: | |
| TQC      | :heavy_check_mark: | N/A | :heavy_check_mark: | :heavy_check_mark: | |
| TRPO     | | :heavy_check_mark: | :heavy_check_mark: | | |


In [49]:
# Maps each algorithm name to the list of environment IDs it is run on.
# Algorithms are grouped by the set of environments they share; every key
# still receives its own independent list, in the same key order as before.
algos = {
    algo_name: list(env_ids)
    for algo_names, env_ids in (
        (
            ("ars", "a2c", "ppo"),
            ("CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0"),
        ),
        (
            ("dqn", "qrdqn"),
            ("CartPole-v1", "MountainCar-v0", "Acrobot-v1"),
        ),
        (
            ("ddpg", "sac", "td3", "tqc"),
            ("Pendulum-v1", "MountainCarContinuous-v0"),
        ),
        (
            ("trpo",),
            ("CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0"),
        ),
    )
    for algo_name in algo_names
}



In [50]:
# For every (algorithm, environment) pair in `algos`, build an experiment with
# ExperimentManager and append the name and shape of each model parameter to
# `output_str`. The loop stops after `setup_experiment()`; nothing in this cell
# calls a training method on the returned model.
output_str = ""

for key_name in algos:
    for env_name in algos[key_name]:
        print(f"Running {key_name} on {env_name}")
        output_str += f"Running {key_name} on {env_name}\n"
        args = RlZooArgs(algo=key_name, env=env_name)

        # Going through custom gym packages to let them register in the global registry.
        # `gym_packages` is declared as a string on RlZooArgs: iterating it directly
        # would try to import one module per *character*. Split it on whitespace
        # instead; a list/tuple of module names is still accepted unchanged.
        gym_packages = args.gym_packages
        if isinstance(gym_packages, str):
            gym_packages = gym_packages.split()
        for env_module in gym_packages:
            importlib.import_module(env_module)

        env_id = args.env
        registered_envs = set(gym.envs.registry.keys())

        # If the environment is not found, suggest the closest match
        if env_id not in registered_envs:
            try:
                closest_match = difflib.get_close_matches(env_id, registered_envs, n=1)[0]
            except IndexError:
                closest_match = "'no close match found...'"
            raise ValueError(f"{env_id} not found in gym registry, you maybe meant {closest_match}?")

        # Unique id to ensure there is no race condition for the folder creation
        uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
        if args.seed < 0:
            # Seed but with a random one
            args.seed = np.random.randint(2**32 - 1, dtype="int64").item()  # type: ignore[attr-defined]

        set_random_seed(args.seed)

        exp_manager = ExperimentManager(
            args,
            args.algo,
            env_id,
            args.log_folder,
            args.tensorboard_log,
            args.n_timesteps,
            args.eval_freq,
            args.eval_episodes,
            args.save_freq,
            args.hyperparams,
            args.env_kwargs,
            args.eval_env_kwargs,
            args.trained_agent,
            args.optimize_hyperparameters,
            args.storage,
            args.study_name,
            args.n_trials,
            args.max_total_trials,
            args.n_jobs,
            args.sampler,
            args.pruner,
            args.optimization_log_path,
            n_startup_trials=args.n_startup_trials,
            n_evaluations=args.n_evaluations,
            truncate_last_trajectory=args.truncate_last_trajectory,
            uuid_str=uuid_str,
            seed=args.seed,
            log_interval=args.log_interval,
            save_replay_buffer=args.save_replay_buffer,
            verbose=args.verbose,
            vec_env_type=args.vec_env,
            n_eval_envs=args.n_eval_envs,
            no_optim_plots=args.no_optim_plots,
            device=args.device,
            config=args.conf_file,
            show_progress=args.progress,
        )

        # Prepare experiment and launch hyperparameter optimization if needed
        results = exp_manager.setup_experiment()
        if results is not None:
            model, saved_hyperparams = results

            # A model was created: record its parameters group by group.
            if model is not None:
                for key, values in model.get_parameters().items():
                    # Only the group name is echoed live; the per-entry lines
                    # are collected in `output_str` for the next cell to print.
                    print(key, end=":\n")
                    output_str += f"{key}:\n"
                    for inner_key, value in values.items():
                        if isinstance(value, th.Tensor):
                            # Tensors are summarised by shape rather than dumped.
                            output_str += f"\t {inner_key} {value.shape}\n"
                        else:
                            # NOTE(review): the dashed separator is emitted only for
                            # non-tensor entries, not once per run -- confirm this is
                            # intended before relying on it as a run delimiter.
                            output_str += f"\t {inner_key} {value}\n\n-----------------------------\n\n"

Running ars on CartPole-v1
Loading hyperparameters from: /home/startung/code/rl-baselines3-zoo-with-pruning/hyperparams/ars.yml
Default hyperparameters for environment (ones being tuned will be overridden):
OrderedDict([('n_delta', 2),
             ('n_envs', 1),
             ('n_timesteps', 50000.0),
             ('policy', 'LinearPolicy')])
Log path: logs/ars/CartPole-v1_4
policy:
Running ars on MountainCar-v0
Loading hyperparameters from: /home/startung/code/rl-baselines3-zoo-with-pruning/hyperparams/ars.yml
Default hyperparameters for environment (ones being tuned will be overridden):
OrderedDict([('delta_std', 0.1),
             ('learning_rate', 0.018),
             ('n_delta', 8),
             ('n_envs', 1),
             ('n_timesteps', 500000.0),
             ('n_top', 1),
             ('normalize', 'dict(norm_obs=True, norm_reward=False)'),
             ('policy', 'MlpPolicy'),
             ('policy_kwargs', 'dict(net_arch=[16])'),
             ('zero_policy', False)])
Log pat

In [51]:
# Show the parameter summary accumulated in `output_str` by the experiment loop.
print(output_str)

Running ars on CartPole-v1
policy:
	 action_net.0.weight torch.Size([2, 4])
Running ars on MountainCar-v0
policy:
	 action_net.0.weight torch.Size([16, 2])
	 action_net.0.bias torch.Size([16])
	 action_net.2.weight torch.Size([3, 16])
	 action_net.2.bias torch.Size([3])
Running ars on Acrobot-v1
policy:
	 action_net.0.weight torch.Size([16, 6])
	 action_net.0.bias torch.Size([16])
	 action_net.2.weight torch.Size([3, 16])
	 action_net.2.bias torch.Size([3])
Running ars on Pendulum-v1
policy:
	 action_net.0.weight torch.Size([16, 3])
	 action_net.0.bias torch.Size([16])
	 action_net.2.weight torch.Size([1, 16])
	 action_net.2.bias torch.Size([1])
Running ars on MountainCarContinuous-v0
policy:
	 action_net.0.weight torch.Size([16, 2])
	 action_net.0.bias torch.Size([16])
	 action_net.2.weight torch.Size([1, 16])
	 action_net.2.bias torch.Size([1])
Running a2c on CartPole-v1
policy:
	 mlp_extractor.policy_net.0.weight torch.Size([64, 4])
	 mlp_extractor.policy_net.0.bias torch.Size([64])