In [1]:
# import argparse
import difflib
import importlib
import os
import time
import uuid
from dataclasses import dataclass, field

import gymnasium as gym
import numpy as np
import stable_baselines3 as sb3
import torch as th
from stable_baselines3.common.utils import set_random_seed

# Register custom envs
# import rl_zoo3.import_envs  # noqa: F401
from rl_zoo3.exp_manager import ExperimentManager
# from rl_zoo3.utils import ALGOS, StoreDict

In [2]:
@dataclass
class LotteryTicketArgs:
    """Settings bundle for the lottery-ticket pruning experiments.

    All fields have defaults, so ``LotteryTicketArgs()`` yields a usable
    configuration; the positional order is ``prune_ratio``, ``prune_type``,
    ``random_reinitialization``, ``one_shot_pruning``, ``iterative_pruning``.
    """

    # Pruning ratio (default 0.2)
    prune_ratio: float = 0.2
    # Pruning scope: global, layerwise
    prune_type: str = 'global'
    # Flag for random re-initialization
    random_reinitialization: bool = False
    # Flag for one-shot pruning
    one_shot_pruning: bool = False
    # Flag for iterative pruning
    iterative_pruning: bool = False

In [3]:
@dataclass
class RlZooArgs:
    """Notebook stand-in for the command-line arguments of the RL Zoo training script.

    Every field has a default, so ``RlZooArgs()`` is a complete configuration and
    individual options are overridden by keyword, e.g.
    ``RlZooArgs(algo='a2c', env='CartPole-v1', seed=44113, verbose=0)``.

    ``gym_packages`` and ``wandb_tags`` are *lists* of names: the notebook iterates
    over the former and unpacks the latter, so a plain string would be consumed one
    character at a time. Each instance gets its own empty list by default.
    """

    algo: str = 'ppo' # RL Algorithm
    env: str = 'CartPole-v1' # environment ID
    tensorboard_log: str = '' # Tensorboard log dir
    trained_agent: str = '' # Path to a pretrained agent to continue training
    truncate_last_trajectory: bool = True # When using HER with online sampling the last trajectory in the replay buffer will be truncated after reloading the replay buffer.
    n_timesteps: int = -1 # Overwrite the number of timesteps
    num_threads: int = -1 # Number of threads for PyTorch (-1 to use default)
    log_interval: int = -1 # Override log interval (default: -1, no change)
    eval_freq: int = 25000 # Evaluate the agent every n steps (if negative, no evaluation). During hyperparameter optimization n-evaluations is used instead
    optimization_log_path: str = '' # Path to save the evaluation log and optimal policy for each hyperparameter tried during optimization. Disabled if no argument is passed.
    eval_episodes: int = 5 # Number of episodes to use for evaluation
    n_eval_envs: int = 1 # Number of environments for evaluation
    save_freq: int = -1 # Save the model every n steps (if negative, no checkpoint)
    save_replay_buffer: bool = False # Save the replay buffer too (when applicable)
    log_folder: str = 'logs' # Log folder
    seed: int = -1 # Random generator seed
    vec_env: str = 'dummy' # VecEnv type
    device: str = 'auto' # PyTorch device to be used (ex: cpu, cuda...)
    n_trials: int = 500 # Number of trials for optimizing hyperparameters. This applies to each optimization runner, not the entire optimization process.
    max_total_trials: int = 0 # Number of (potentially pruned) trials for optimizing hyperparameters. This applies to the entire optimization process and takes precedence over --n-trials if set.
    optimize_hyperparameters: bool = False # Run hyperparameters search
    no_optim_plots: bool = False # Disable hyperparameter optimization plots
    n_jobs: int = 1 # Number of parallel jobs when optimizing hyperparameters
    sampler: str = 'tpe' # Sampler to use when optimizing hyperparameters
    pruner: str = 'median' # Pruner to use when optimizing hyperparameters
    n_startup_trials: int = 10 # Number of trials before using optuna sampler
    n_evaluations: int = 0 # Training policies are evaluated every n-timesteps // n_evaluations steps when doing hyperparameter optimization. Default is 1 evaluation per 100k timesteps.
    storage: str = '' # Database storage path if distributed optimization should be used
    study_name: str = '' # Study name for distributed optimization
    verbose: int = 1 # Verbose mode (0: no output, 1: INFO)
    # Additional external Gym environment package modules to import (list of module names).
    # A list, not a str: the notebook loops over it and imports each element.
    gym_packages: list = field(default_factory=list)
    # NOTE(review): the three fields below are forwarded unchanged to ExperimentManager.
    # The "key:value" example suggests a mapping is expected rather than a raw string —
    # confirm before passing a non-empty value.
    env_kwargs: str = '' # Optional keyword argument to pass to the env constructor
    eval_env_kwargs: str = '' # Optional keyword argument to pass to the env constructor for evaluation
    hyperparams: str = '' # Overwrite hyperparameter (e.g. learning_rate:0.01 train_freq:10)
    conf_file: str = '' # Custom yaml file or python package from which the hyperparameters will be loaded. We expect that python packages contain a dictionary called 'hyperparams' which contains a key for each environment.
    uuid: bool = False # Ensure that the run has a unique ID
    track: bool = False # if toggled, this experiment will be tracked with Weights and Biases
    wandb_project_name: str = 'sb3' # the wandb's project name
    wandb_entity: str = '' # the entity (team) of wandb's project
    progress: bool = False # if toggled, display a progress bar using tqdm and rich
    # Tags for wandb run, e.g.: ['optimized', 'pr-123'] (list of tag names).
    # A list, not a str: the notebook unpacks it element by element into the tag list.
    wandb_tags: list = field(default_factory=list)

In [4]:
# Experiment selection for this run: algorithm, environment and a fixed seed.
algo_name, env_name, seed = 'a2c', 'CartPole-v1', 44113

# Everything not listed here keeps the RlZooArgs default; verbose=0 is passed explicitly.
args = RlZooArgs(
    algo=algo_name,
    env=env_name,
    seed=seed,
    verbose=0,
)

In [5]:
# Going through custom gym packages to let them register in the global registry.
# `args.gym_packages` may be a whitespace-separated string, a list of module names, or
# empty/None; normalise it first, because looping over a raw string would try to import
# it one character at a time.
if isinstance(args.gym_packages, str):
    gym_packages = args.gym_packages.split()
else:
    gym_packages = list(args.gym_packages or [])
for env_module in gym_packages:
    importlib.import_module(env_module)

env_id = args.env
registered_envs = set(gym.envs.registry.keys())

# If the environment is not found, suggest the closest match
if env_id not in registered_envs:
    try:
        closest_match = difflib.get_close_matches(env_id, registered_envs, n=1)[0]
    except IndexError:
        # get_close_matches returned an empty list
        closest_match = "'no close match found...'"
    raise ValueError(f"{env_id} not found in gym registry, you maybe meant {closest_match}?")

# Unique id to ensure there is no race condition for the folder creation
uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
if args.seed < 0:
    # Seed but with a random one; the drawn value is written back onto args
    args.seed = np.random.randint(2**32 - 1, dtype="int64").item()  # type: ignore[attr-defined]

In [6]:
# Apply the seed stored on args through stable_baselines3's set_random_seed helper,
# then echo it so the value used by this run is visible in the notebook output.
set_random_seed(args.seed)
print(f"Seed: {args.seed}")

Seed: 44113


In [7]:
# Setting num threads to 1 makes things run faster on cpu
if args.num_threads > 0:
    if args.verbose > 1:
        print(f"Setting torch.num_threads to {args.num_threads}")
    th.set_num_threads(args.num_threads)

# A pretrained agent, when given, has to be an existing .zip archive
if args.trained_agent != "":
    trained_agent_is_zip_file = args.trained_agent.endswith(".zip") and os.path.isfile(args.trained_agent)
    assert trained_agent_is_zip_file, "The trained_agent must be a valid path to a .zip file"

# Banner with the environment id framed by ten '=' characters on each side
banner_rule = "=" * 10
print(banner_rule, env_id, banner_rule)




In [8]:
# Optional Weights & Biases tracking; only runs (and only imports wandb) when args.track is set.
if args.track:
    try:
        import wandb
    except ImportError as e:
        raise ImportError(
            "if you want to use Weights & Biases to track experiment, please install W&B via `pip install wandb`"
        ) from e

    run_name = f"{args.env}__{args.algo}__{args.seed}__{int(time.time())}"
    # `args.wandb_tags` may be a whitespace-separated string, a list of tags, or empty/None.
    # Normalise before unpacking: splatting a raw string would yield one tag per character.
    if isinstance(args.wandb_tags, str):
        wandb_tags = args.wandb_tags.split()
    else:
        wandb_tags = list(args.wandb_tags or [])
    tags = [*wandb_tags, f"v{sb3.__version__}"]
    run = wandb.init(
        name=run_name,
        project=args.wandb_project_name,
        entity=args.wandb_entity,
        tags=tags,
        config=vars(args),
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=True,  # auto-upload the videos of agents playing the game
        save_code=True,  # optional
    )
    # Point the tensorboard log at a per-run folder named after the W&B run
    args.tensorboard_log = f"runs/{run_name}"

In [9]:
# Build the ExperimentManager from the values collected on `args`.
# The first 22 arguments (args ... optimization_log_path) are passed positionally, so their
# order must match the constructor's parameter order; the remaining ones are keywords.
# `env_id` and `uuid_str` come from the earlier setup cell rather than directly from `args`.
exp_manager = ExperimentManager(
    args,
    args.algo,
    env_id,
    args.log_folder,
    args.tensorboard_log,
    args.n_timesteps,
    args.eval_freq,
    args.eval_episodes,
    args.save_freq,
    args.hyperparams,
    args.env_kwargs,
    args.eval_env_kwargs,
    args.trained_agent,
    args.optimize_hyperparameters,
    args.storage,
    args.study_name,
    args.n_trials,
    args.max_total_trials,
    args.n_jobs,
    args.sampler,
    args.pruner,
    args.optimization_log_path,
    n_startup_trials=args.n_startup_trials,
    n_evaluations=args.n_evaluations,
    truncate_last_trajectory=args.truncate_last_trajectory,
    uuid_str=uuid_str,
    seed=args.seed,
    log_interval=args.log_interval,
    save_replay_buffer=args.save_replay_buffer,
    verbose=args.verbose,
    vec_env_type=args.vec_env,
    n_eval_envs=args.n_eval_envs,
    no_optim_plots=args.no_optim_plots,
    device=args.device,
    config=args.conf_file,
    show_progress=args.progress,
)

In [10]:
# Prepare the experiment. When setup_experiment() returns nothing, hyperparameter
# optimization is launched instead of a normal training run.
results = exp_manager.setup_experiment()
if results is None:
    exp_manager.hyperparameters_optimization()
else:
    model, saved_hyperparams = results

    if args.track:
        # Keep the loaded hyperparameters on args so they are included in the run config
        args.saved_hyperparams = saved_hyperparams
        assert run is not None  # make mypy happy
        run.config.setdefaults(vars(args))

    # Normal training: dump the parameters before learning, train, then save
    if model is not None:
        print(model.get_parameters())
        exp_manager.learn(model)
        exp_manager.save_trained_model(model)

Loading hyperparameters from: /home/startung/code/rl-baselines3-zoo-with-pruning/hyperparams/a2c.yml
Default hyperparameters for environment (ones being tuned will be overridden):
OrderedDict([('ent_coef', 0.0),
             ('n_envs', 8),
             ('n_timesteps', 500000.0),
             ('policy', 'MlpPolicy')])
Log path: logs/a2c/CartPole-v1_5
{'policy': OrderedDict([('mlp_extractor.policy_net.0.weight', tensor([[ 0.2180, -0.0485, -0.1344,  0.2340],
        [-0.1108, -0.1717, -0.3098, -0.0630],
        [ 0.1240,  0.0232,  0.4054, -0.1411],
        [-0.3963,  0.1493, -0.1833,  0.0957],
        [-0.0559, -0.1295, -0.0832, -0.0530],
        [ 0.1620,  0.0709,  0.0997, -0.2296],
        [ 0.3864, -0.1370, -0.1669,  0.3586],
        [-0.1112,  0.1236, -0.3404, -0.3605],
        [ 0.0078, -0.0128,  0.0159,  0.1040],
        [-0.0577,  0.0594, -0.0091,  0.2051],
        [ 0.0449,  0.3293, -0.4155, -0.2879],
        [ 0.0757,  0.2835, -0.2232,  0.1891],
        [-0.1683,  0.1304, -0.2382

In [11]:
# Display the model's parameters again; this cell comes after the one that calls
# exp_manager.learn(model), presumably to compare against the dump printed before training.
# Relies on `model` having been bound by that earlier cell.
model.get_parameters()

{'policy': OrderedDict([('mlp_extractor.policy_net.0.weight',
               tensor([[ 0.2708,  0.2043,  0.3010,  0.3366],
                       [-0.1561, -0.4346, -0.9361, -0.2228],
                       [ 0.1481,  0.2417,  0.8993, -0.0233],
                       [-0.3334,  0.3595,  0.2630,  0.2797],
                       [-0.0929, -0.3852, -0.6941, -0.2021],
                       [ 0.0956, -0.1461, -0.2858, -0.3727],
                       [ 0.4348,  0.1067,  0.2996,  0.5109],
                       [-0.1604, -0.1264, -0.9827, -0.5385],
                       [ 0.0604,  0.2485,  0.6208,  0.2662],
                       [ 0.0144,  0.3472,  0.5791,  0.3120],
                       [-0.0033,  0.1075, -0.8948, -0.4553],
                       [ 0.0801,  0.4850,  0.0907,  0.1744],
                       [-0.0603,  0.3749,  0.0040,  0.1180],
                       [ 0.0800,  0.5435,  0.0762,  0.0281],
                       [ 0.2117, -0.2252, -0.5676, -0.4255],
                       