Skip to content

Commit

Permalink
Bugfix/parallel launcher for linux (#1141)
Browse files Browse the repository at this point in the history
Closes: #1121
  • Loading branch information
MischaPanch committed May 8, 2024
2 parents d58ae16 + bf3859a commit f0b7abe
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 49 deletions.
10 changes: 7 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
## Release 1.1.0

### Api Extensions
- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074 #1141
- The module `evaluation.launchers` for parallelization is currently in alpha state.
- `data`:
- `Batch`:
- Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098
Expand All @@ -24,7 +26,7 @@
- `SamplingConfig`:
- Add support for `batch_size=None`. #1077
- Add `training_seed` for explicit seeding of training and test environments, the `test_seed` is inferred from `training_seed`. #1074
- `highlevel.experiment`:
- `experiment`:
- `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and
which determines the default run name and therefore the persistence subdirectory.
It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than
Expand All @@ -34,8 +36,8 @@
- Add method `build_seeded_collection` for the sound creation of multiple
experiments with varying random seeds #1131
- Add method `copy` to facilitate the creation of multiple experiments from a single builder #1131
- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074
- The module `evaluation.launchers` for parallelization is currently in alpha state.
- `env`:
- Added new `VectorEnvType` called `SUBPROC_SHARED_MEM_AUTO` and used in for Atari and Mujoco venv creation. #1141
- Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
- `utils`:
- `net.continuous.Critic`:
Expand All @@ -45,6 +47,8 @@
to reuse the actor's preprocessing network #1128
- `torch_utils` (new module)
- Added context managers `torch_train_mode` and `policy_within_training_step` #1123
- `print`
- `DataclassPPrintMixin` now supports outputting a string, not just printing the pretty repr. #1141

### Fixes
- `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,
Expand Down
3 changes: 2 additions & 1 deletion examples/atari/atari_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def __init__(
frame_stack: int,
scale: bool = False,
use_envpool_if_available: bool = True,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
) -> None:
assert "NoFrameskip" in task
self.frame_stack = frame_stack
Expand All @@ -412,7 +413,7 @@ def __init__(
task=task,
train_seed=train_seed,
test_seed=test_seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
venv_type=venv_type,
envpool_factory=envpool_factory,
)

Expand Down
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
train_seed: int,
test_seed: int,
obs_norm: bool = True,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
) -> None:
super().__init__(
task=task,
Expand Down
42 changes: 24 additions & 18 deletions examples/mujoco/mujoco_ppo_hl_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,20 @@
are non-intersecting with the seeds of the other experiments.
Each experiment's results are stored in a separate subdirectory.
The final results are aggregated and turned into useful statistics with the rliable API.
The final results are aggregated and turned into useful statistics with the rliable package.
The call to `eval_experiments` will load the results from the log directory and
create an interp-quantile mean plot for the returns as well as a performance profile plot.
These plots are saved in the log directory and displayed in the console.
"""

import os
import sys

import torch

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.evaluation.launcher import RegisteredExpLauncher
from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType
from tianshou.highlevel.experiment import (
ExperimentConfig,
PPOExperimentBuilder,
Expand All @@ -39,14 +37,14 @@


def main(
num_experiments: int = 2,
run_experiments_sequentially: bool = True,
num_experiments: int = 5,
run_experiments_sequentially: bool = False,
) -> RLiableExperimentResult:
""":param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
LIMITATIONS: currently, the parallel execution does not seem to work properly on linux.
It might generally be undesired to run multiple experiments in parallel on the same machine,
as a single experiment already uses all available CPU cores by default.
:return: the directory where the results are stored
""":param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds.
:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
If a single experiment is set to use all available CPU cores,
it might be undesired to run multiple experiments in parallel on the same machine,
:return: an object containing rliable-based evaluation results
"""
task = "Ant-v4"
persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag()))
Expand All @@ -57,22 +55,19 @@ def main(
num_epochs=1,
step_per_epoch=5000,
batch_size=64,
num_train_envs=10,
num_test_envs=10,
num_test_episodes=10,
num_train_envs=5,
num_test_envs=5,
num_test_episodes=5,
buffer_size=4096,
step_per_collect=2048,
repeat_per_collect=10,
repeat_per_collect=1,
)

env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
if sys.platform == "darwin"
else VectorEnvType.SUBPROC_SHARED_MEM,
)

hidden_sizes = (64, 64)
Expand Down Expand Up @@ -108,7 +103,18 @@ def main(
launcher = RegisteredExpLauncher.sequential.create_launcher()
else:
launcher = RegisteredExpLauncher.joblib.create_launcher()
experiment_collection.run(launcher)
successful_experiment_stats = experiment_collection.run(launcher)
log.info(f"Successfully completed {len(successful_experiment_stats)} experiments.")

num_successful_experiments = len(successful_experiment_stats)
for i, info_stats in enumerate(successful_experiment_stats, start=1):
if info_stats is not None:
log.info(f"Training stats for successful experiment {i}/{num_successful_experiments}:")
log.info(info_stats.pprints_asdict())
else:
log.info(
f"No training stats available for successful experiment {i}/{num_successful_experiments}.",
)

rliable_result = RLiableExperimentResult.load_from_disk(persistence_dir)
rliable_result.eval_results(show_plots=True, save_plots=True)
Expand Down
104 changes: 89 additions & 15 deletions tianshou/evaluation/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from copy import copy
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Literal

from joblib import Parallel, delayed

from tianshou.data import InfoStats
from tianshou.highlevel.experiment import Experiment

log = logging.getLogger(__name__)
Expand All @@ -26,19 +27,89 @@ class JoblibConfig:


class ExpLauncher(ABC):
def __init__(
self,
experiment_runner: Callable[
[Experiment],
InfoStats | None,
] = lambda exp: exp.run().trainer_result,
):
""":param experiment_runner: can be used to override the default way in which an experiment is executed.
Can be useful e.g., if one wants to use the high-level interfaces to setup an experiment (or an experiment
collection) and tinker with it prior to execution. This need often arises when prototyping with mechanisms
that are not yet supported by the high-level interfaces.
Passing this allows arbitrary things to happen during experiment execution, so use it with caution!
"""
self.experiment_runner = experiment_runner

@abstractmethod
def launch(self, experiments: Sequence[Experiment]) -> None:
pass
def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
"""Should call `self.experiment_runner` for each experiment in experiments and aggregate the results."""

def _safe_execute(self, exp: Experiment) -> InfoStats | None | Literal["failed"]:
try:
return self.experiment_runner(exp)
except BaseException as e:
log.error(f"Failed to run experiment {exp}.", exc_info=e)
return "failed"

@staticmethod
def _return_from_successful_and_failed_exps(
successful_exp_stats: list[InfoStats | None],
failed_exps: list[Experiment],
) -> list[InfoStats | None]:
if not successful_exp_stats:
raise RuntimeError("All experiments failed, see error logs for more details.")
if failed_exps:
log.error(
f"Failed to run the following "
f"{len(failed_exps)}/{len(successful_exp_stats) + len(failed_exps)} experiments: {failed_exps}. "
f"See the logs for more details. "
f"Returning the results of {len(successful_exp_stats)} successful experiments.",
)
return successful_exp_stats

def launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
"""Will return the results of successfully executed experiments.
If a single experiment is passed, will not use parallelism and run it in the main process.
Failed experiments will be logged, and a RuntimeError is only raised if all experiments have failed.
"""
if len(experiments) == 1:
log.info(
"A single experiment is being run, will not use parallelism and run it in the main process.",
)
return [self.experiment_runner(experiments[0])]
return self._launch(experiments)


class SequentialExpLauncher(ExpLauncher):
def launch(self, experiments: Sequence[Experiment]) -> None:
"""Convenience wrapper around a simple for loop to run experiments sequentially."""

def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
successful_exp_stats = []
failed_exps = []
for exp in experiments:
exp.run()
for exp in experiments:
exp_stats = self._safe_execute(exp)
if exp_stats == "failed":
failed_exps.append(exp)
else:
successful_exp_stats.append(exp_stats)
# noinspection PyTypeChecker
return self._return_from_successful_and_failed_exps(successful_exp_stats, failed_exps)


class JoblibExpLauncher(ExpLauncher):
def __init__(self, joblib_cfg: JoblibConfig | None = None) -> None:
def __init__(
self,
joblib_cfg: JoblibConfig | None = None,
experiment_runner: Callable[
[Experiment],
InfoStats | None,
] = lambda exp: exp.run().trainer_result,
) -> None:
super().__init__(experiment_runner=experiment_runner)
self.joblib_cfg = copy(joblib_cfg) if joblib_cfg is not None else JoblibConfig()
# Joblib's backend is hard-coded to loky since the threading backend produces different results
if self.joblib_cfg.backend != "loky":
Expand All @@ -48,15 +119,18 @@ def __init__(self, joblib_cfg: JoblibConfig | None = None) -> None:
)
self.joblib_cfg.backend = "loky"

def launch(self, experiments: Sequence[Experiment]) -> None:
Parallel(**asdict(self.joblib_cfg))(delayed(self._safe_execute)(exp) for exp in experiments)

@staticmethod
def _safe_execute(exp: Experiment) -> None:
try:
exp.run()
except BaseException as e:
log.error(e)
def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
results = Parallel(**asdict(self.joblib_cfg))(
delayed(self._safe_execute)(exp) for exp in experiments
)
successful_exps = []
failed_exps = []
for exp, result in zip(experiments, results, strict=True):
if result == "failed":
failed_exps.append(exp)
else:
successful_exps.append(result)
return self._return_from_successful_and_failed_exps(successful_exps, failed_exps)


class RegisteredExpLauncher(Enum):
Expand Down
2 changes: 1 addition & 1 deletion tianshou/highlevel/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class SamplingConfig(ToStringMixin):
"""
controls, within one gradient update step of an on-policy algorithm, the number of times an
actual gradient update is applied using the full collected dataset, i.e. if the parameter is
`n`, then the collected data shall be used five times to update the policy within the same
5, then the collected data shall be used five times to update the policy within the same
training step.
The parameter is ignored and may be set to None for off-policy and offline algorithms.
Expand Down
13 changes: 11 additions & 2 deletions tianshou/highlevel/env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import platform
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from enum import Enum
Expand Down Expand Up @@ -66,13 +67,15 @@ class VectorEnvType(Enum):
"""Vectorized environment without parallelization; environments are processed sequentially"""
SUBPROC = "subproc"
"""Parallelization based on `subprocess`"""
SUBPROC_SHARED_MEM = "shmem"
SUBPROC_SHARED_MEM_DEFAULT_CONTEXT = "shmem"
"""Parallelization based on `subprocess` with shared memory"""
SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork"
"""Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn`
by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"""
RAY = "ray"
"""Parallelization based on the `ray` library"""
SUBPROC_SHARED_MEM_AUTO = "subproc_shared_mem_auto"
"""Parallelization based on `subprocess` with shared memory, using default context on windows and fork context otherwise"""

def create_venv(
self,
Expand All @@ -83,10 +86,16 @@ def create_venv(
return DummyVectorEnv(factories)
case VectorEnvType.SUBPROC:
return SubprocVectorEnv(factories)
case VectorEnvType.SUBPROC_SHARED_MEM:
case VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT:
return SubprocVectorEnv(factories, share_memory=True)
case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT:
return SubprocVectorEnv(factories, share_memory=True, context="fork")
case VectorEnvType.SUBPROC_SHARED_MEM_AUTO:
if platform.system().lower() == "windows":
selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT
else:
selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
return selected_venv_type.create_venv(factories)
case VectorEnvType.RAY:
return RayVectorEnv(factories)
case _:
Expand Down
10 changes: 7 additions & 3 deletions tianshou/highlevel/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
if TYPE_CHECKING:
from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher


log = logging.getLogger(__name__)


Expand Down Expand Up @@ -356,15 +355,20 @@ def _watch_agent(


class ExperimentCollection:
"""Shallow wrapper around a list of experiments providing a simple interface for running them with a launcher."""

def __init__(self, experiments: list[Experiment]):
self.experiments = experiments

def run(self, launcher: Union["ExpLauncher", "RegisteredExpLauncher"]) -> None:
def run(
self,
launcher: Union["ExpLauncher", "RegisteredExpLauncher"],
) -> list[InfoStats | None]:
from tianshou.evaluation.launcher import RegisteredExpLauncher

if isinstance(launcher, RegisteredExpLauncher):
launcher = launcher.create_launcher()
launcher.launch(experiments=self.experiments)
return launcher.launch(experiments=self.experiments)


class ExperimentBuilder:
Expand Down

0 comments on commit f0b7abe

Please sign in to comment.