RLlib: Industry-Grade Reinforcement Learning with TF and Torch

RLlib is an open-source library for reinforcement learning (RL), offering support for production-level, highly distributed RL workloads, while maintaining unified and simple APIs for a large variety of industry applications.

Whether you would like to train your agents in multi-agent setups, purely from offline (historic) datasets, or using externally connected simulators, RLlib offers simple solutions for your decision-making needs.

If you either have your problem coded (in Python) as an RL environment or own lots of pre-recorded, historic behavioral data to learn from, you will be up and running in only a few days.

RLlib is already used in production by industry leaders in many different verticals, such as climate control, industrial control, manufacturing and logistics, finance, gaming, automobile, robotics, boat design, and many others.

You can also read about RLlib Key Concepts.

Installation and Setup

Install RLlib and run your first experiment on your laptop in seconds:

TensorFlow:

.. code-block:: bash

    $ conda create -n rllib python=3.8
    $ conda activate rllib
    $ pip install "ray[rllib]" tensorflow "gymnasium[atari]" "gymnasium[accept-rom-license]" atari_py
    $ # Run a test job:
    $ rllib train --run APPO --env CartPole-v0

PyTorch:

.. code-block:: bash

    $ conda create -n rllib python=3.8
    $ conda activate rllib
    $ pip install "ray[rllib]" torch "gymnasium[atari]" "gymnasium[accept-rom-license]" atari_py
    $ # Run a test job:
    $ rllib train --run APPO --env CartPole-v0 --torch

Algorithms Supported

Model-free On-policy RL:

Model-free Off-policy RL:

Model-based RL:

Offline RL:

Multi-agent:

Others:

A list of all the algorithms can be found here.

Quick First Experiment

.. testcode::

    import gymnasium as gym
    from ray.rllib.algorithms.ppo import PPOConfig


    # Define your problem using python and Farama-Foundation's gymnasium API:
    class ParrotEnv(gym.Env):
        """Environment in which an agent must learn to repeat the seen observations.

        Observations are float numbers indicating the to-be-repeated values,
        e.g. -1.0, 5.1, or 3.2.

        The action space is always the same as the observation space.

        Rewards are r=-abs(observation - action), for all steps.
        """

        def __init__(self, config):
            # Make the space (for actions and observations) configurable.
            self.action_space = config.get(
                "parrot_shriek_range", gym.spaces.Box(-1.0, 1.0, shape=(1, )))
            # Since actions should repeat observations, their spaces must be the
            # same.
            self.observation_space = self.action_space
            self.cur_obs = None
            self.episode_len = 0

        def reset(self, *, seed=None, options=None):
            """Resets the episode and returns the initial observation of the new one.
            """
            # Reset the episode len.
            self.episode_len = 0
            # Sample a random number from our observation space.
            self.cur_obs = self.observation_space.sample()
            # Return initial observation.
            return self.cur_obs, {}

        def step(self, action):
            """Takes a single step in the episode given `action`.

            Returns:
                New observation, reward, terminated-flag, truncated-flag,
                info-dict (empty).
            """
            # Set `truncated` flag after 10 steps.
            self.episode_len += 1
            terminated = False
            truncated = self.episode_len >= 10
            # r = -abs(obs - action)
            reward = -sum(abs(self.cur_obs - action))
            # Set a new observation (random sample).
            self.cur_obs = self.observation_space.sample()
            return self.cur_obs, reward, terminated, truncated, {}


    # Create an RLlib Algorithm instance from a PPOConfig to learn how to
    # act in the above environment.
    config = (
        PPOConfig()
        .environment(
            # Env class to use (here: our gym.Env sub-class from above).
            env=ParrotEnv,
            # Config dict to be passed to our custom env's constructor.
            env_config={
                "parrot_shriek_range": gym.spaces.Box(-5.0, 5.0, (1, ))
            },
        )
        # Parallelize environment rollouts.
        .rollouts(num_rollout_workers=3)
    )
    # Use the config's `build()` method to construct a PPO object.
    algo = config.build()

    # Train for n iterations (here: n=1) and report results (mean episode
    # rewards). Each episode consists of 10 guesses, each rewarded with at most
    # 0.0 (exact match between observation and action value), so the best
    # possible episode reward is 0.0. Training for more iterations should move
    # the mean episode reward towards that optimum.
    for i in range(1):
        results = algo.train()
        print(f"Iter: {i}; avg. reward={results['episode_reward_mean']}")

.. testoutput::
    :options: +MOCK

    Iter: 0; avg. reward=-41.88662799871655
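
To see where the 0.0 optimum mentioned in the comments above comes from, here is a stdlib-only sketch (no RLlib or gymnasium needed) that replays the ParrotEnv reward, r = -abs(observation - action), over 10-step episodes on the [-5.0, 5.0] range. A perfect parrot scores exactly 0.0 per episode, whereas a policy that guesses uniformly at random loses 10/3 per step on average (the mean absolute difference of two independent uniform draws over an interval of width 10), i.e. about -33.3 per episode. The helper names (``run_episode``, ``perfect_parrot``, ``random_guesser``) are hypothetical and only serve this illustration.

.. code-block:: python

    import random

    LOW, HIGH = -5.0, 5.0  # same range as the env_config in the example above
    EPISODE_LEN = 10       # ParrotEnv truncates episodes after 10 steps


    def run_episode(policy, rng):
        """Plays one ParrotEnv-like episode and returns the summed reward."""
        total = 0.0
        for _ in range(EPISODE_LEN):
            obs = rng.uniform(LOW, HIGH)
            action = policy(obs, rng)
            total += -abs(obs - action)  # r = -abs(obs - action)
        return total


    def perfect_parrot(obs, rng):
        return obs  # always repeat the observation


    def random_guesser(obs, rng):
        return rng.uniform(LOW, HIGH)  # ignore the observation entirely


    rng = random.Random(0)
    n = 20000
    best = run_episode(perfect_parrot, rng)
    baseline = sum(run_episode(random_guesser, rng) for _ in range(n)) / n
    print(f"perfect: {best:.1f}, random baseline: {baseline:.1f}")

A freshly initialized agent behaves closer to the random guesser than to the perfect parrot, which is why early training iterations report clearly negative mean rewards.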


After training, you may want to perform action computations (inference) in your environment. Below is a minimal example of how to do this. Also check out our more detailed examples here (in particular for normal models, LSTMs, and attention nets).

.. testcode::

    # Perform inference (action computations) based on given env observations.
    # Note that we are using a slightly simpler env here (-3.0 to 3.0, instead
    # of -5.0 to 5.0!), however, this should still work as the agent has
    # (hopefully) learned to "just always repeat the observation!".
    env = ParrotEnv({"parrot_shriek_range": gym.spaces.Box(-3.0, 3.0, (1, ))})
    # Get the initial observation (some value between -3.0 and 3.0).
    obs, info = env.reset()
    terminated = truncated = False
    total_reward = 0.0
    # Play one episode.
    while not terminated and not truncated:
        # Compute a single action, given the current observation
        # from the environment.
        action = algo.compute_single_action(obs)
        # Apply the computed action in the environment.
        obs, reward, terminated, truncated, info = env.step(action)
        # Sum up rewards for reporting purposes.
        total_reward += reward
    # Report results.
    print(f"Shrieked for 1 episode; total-reward={total_reward}")

.. testoutput::
    :options: +MOCK

    Shrieked for 1 episode; total-reward=-0.001


For a more detailed "60 second" example, head to our main documentation.

Highlighted Features

The following is a summary of RLlib's most striking features (for an in-depth overview, check out our documentation):

The most popular deep-learning frameworks: PyTorch and TensorFlow (tf1.x/2.x static-graph/eager/traced).

Highly distributed learning: RLlib algorithms (such as PPO or IMPALA) let you set the number of parallel rollout workers via the num_rollout_workers config parameter (num_workers in older dict-based configs), so that your workloads can run on 100s of CPUs/nodes, parallelizing and speeding up learning.
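
Conceptually, each rollout worker steps its own copy of the environment and the driver gathers the collected samples into one training batch. The stdlib-only sketch below mimics that fan-out/gather pattern. It uses threads purely to stay self-contained (RLlib itself runs its remote rollout workers as Ray actors in separate processes), and the names ``collect_fragment`` and ``FRAGMENT_LEN`` are hypothetical.

.. code-block:: python

    import random
    from concurrent.futures import ThreadPoolExecutor

    NUM_WORKERS = 3    # analogous to num_rollout_workers=3 in the example above
    FRAGMENT_LEN = 200  # steps each worker collects per round (hypothetical)


    def collect_fragment(worker_index):
        """One 'rollout worker': steps its own env copy with a random policy."""
        rng = random.Random(worker_index)  # independent RNG per worker
        samples = []
        for _ in range(FRAGMENT_LEN):
            obs = rng.uniform(-5.0, 5.0)
            action = rng.uniform(-5.0, 5.0)
            samples.append((obs, action, -abs(obs - action)))
        return samples


    # The "driver" fans sampling out to all workers and concatenates the
    # returned fragments into a single training batch.
    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as pool:
        fragments = list(pool.map(collect_fragment, range(NUM_WORKERS)))
    train_batch = [step for fragment in fragments for step in fragment]
    print(f"collected {len(train_batch)} steps from {NUM_WORKERS} workers")
    # prints: collected 600 steps from 3 workers

Adding workers grows the batch collected per round without lengthening the wall-clock time each worker spends sampling, which is where the speed-up of distributed sampling comes from.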

Vectorized (batched) and remote (parallel) environments: RLlib auto-vectorizes your gym.Envs via the num_envs_per_worker config. Environment workers can then batch observations and thus significantly speed up the action-computing forward pass. On top of that, RLlib offers the remote_worker_envs config to create single environments (within a vectorized one) as Ray actors, thus parallelizing even the env-stepping process.
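
The idea behind auto-vectorization can be sketched without RLlib: several sub-environments are stepped in lockstep, and the policy computes actions for all of them in one batched call instead of one call per environment. The classes and functions below (``ParrotSubEnv``, ``VectorEnv``, ``batched_policy``) are hypothetical, stdlib-only stand-ins, not RLlib's own vector-env classes.

.. code-block:: python

    import random


    class ParrotSubEnv:
        """Stdlib stand-in for one ParrotEnv copy (range [-5.0, 5.0])."""

        def __init__(self, seed):
            self.rng = random.Random(seed)
            self.obs = 0.0

        def reset(self):
            self.obs = self.rng.uniform(-5.0, 5.0)
            return self.obs

        def step(self, action):
            reward = -abs(self.obs - action)
            self.obs = self.rng.uniform(-5.0, 5.0)
            return self.obs, reward


    class VectorEnv:
        """Steps several sub-envs in lockstep and returns batched results."""

        def __init__(self, num_envs):
            self.envs = [ParrotSubEnv(seed=i) for i in range(num_envs)]

        def reset(self):
            return [env.reset() for env in self.envs]

        def step(self, actions):
            results = [env.step(a) for env, a in zip(self.envs, actions)]
            return [obs for obs, _ in results], [rew for _, rew in results]


    def batched_policy(obs_batch):
        # One "forward pass" for the whole batch (here: repeat every obs).
        return list(obs_batch)


    vec_env = VectorEnv(num_envs=4)  # compare: num_envs_per_worker=4
    obs_batch = vec_env.reset()
    total_reward = 0.0
    for _ in range(10):
        actions = batched_policy(obs_batch)  # a single call serves all 4 envs
        obs_batch, rewards = vec_env.step(actions)
        total_reward += sum(rewards)
    print(f"batch size: {len(obs_batch)}, total reward: {total_reward}")

With a neural-network policy, the single batched call is what pays off: one forward pass over four observations is typically much cheaper than four separate passes.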

Multi-agent RL (MARL): Convert your (custom) gym.Envs into multi-agent ones in a few simple steps and start training your agents in any of the following ways:
1) Cooperative with shared or separate policies and/or value functions.
2) Adversarial scenarios using self-play and league-based training.
3) Independent learning of neutral/co-existing agents.
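
RLlib's multi-agent envs exchange per-agent dicts keyed by agent ID, and a policy mapping function decides which policy controls which agent. The stdlib-only toy below imitates that convention (including the ``"__all__"`` done-key) to contrast a shared policy with separate per-agent policies. ``TwoParrotEnv``, ``play_episode``, and the two mapping functions are hypothetical illustrations, not RLlib classes, and RLlib's real policy_mapping_fn receives more arguments than just the agent ID.

.. code-block:: python

    import random

    AGENT_IDS = ("parrot_0", "parrot_1")


    class TwoParrotEnv:
        """Toy two-agent env: every agent must repeat its own observation.

        Observations, rewards, and done-flags are dicts keyed by agent ID; the
        extra "__all__" key tells the caller whether the whole episode is over.
        """

        def __init__(self, seed=0, episode_len=10):
            self.rng = random.Random(seed)
            self.episode_len = episode_len
            self.t = 0
            self.obs = {}

        def _sample(self):
            return {aid: self.rng.uniform(-5.0, 5.0) for aid in AGENT_IDS}

        def reset(self):
            self.t = 0
            self.obs = self._sample()
            return self.obs, {}

        def step(self, action_dict):
            rewards = {
                aid: -abs(self.obs[aid] - action_dict[aid]) for aid in AGENT_IDS
            }
            self.t += 1
            done = self.t >= self.episode_len
            terminateds = {aid: False for aid in AGENT_IDS}
            terminateds["__all__"] = False
            truncateds = {aid: done for aid in AGENT_IDS}
            truncateds["__all__"] = done
            self.obs = self._sample()
            return self.obs, rewards, terminateds, truncateds, {}


    # Every policy here simply parrots; a real setup would train them.
    POLICIES = {
        "shared": lambda o: o,
        "parrot_0": lambda o: o,
        "parrot_1": lambda o: o,
    }


    def shared_mapping(agent_id):
        return "shared"  # cooperative: all agents use the same policy


    def separate_mapping(agent_id):
        return agent_id  # one separate policy per agent


    def play_episode(env, mapping_fn):
        obs, _ = env.reset()
        returns = {aid: 0.0 for aid in AGENT_IDS}
        while True:
            actions = {aid: POLICIES[mapping_fn(aid)](o) for aid, o in obs.items()}
            obs, rewards, terminateds, truncateds, _ = env.step(actions)
            for aid, reward in rewards.items():
                returns[aid] += reward
            if terminateds["__all__"] or truncateds["__all__"]:
                return returns


    print(play_episode(TwoParrotEnv(), shared_mapping))
    print(play_episode(TwoParrotEnv(), separate_mapping))

Swapping the mapping function is the only change needed to switch between shared and separate policies; the environment itself stays the same.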

External simulators: Don't have your simulation running as a gym.Env in Python? No problem! RLlib supports an external environment API and comes with a pluggable, off-the-shelf client/server setup that allows you to run 100s of independent simulators on the "outside" (e.g. a Windows cloud) connecting to a central RLlib Policy-Server that learns and serves actions. Alternatively, actions can be computed on the client side to save on network traffic.
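
The key inversion with external simulators is that the simulator, not RLlib, drives the episode loop and queries a server for actions. The toy below sketches that control flow in a single process. Its method names loosely follow RLlib's PolicyClient (start_episode, get_action, log_returns, end_episode), but ``ToyPolicyServer`` is a hypothetical stand-in with neither networking nor learning.

.. code-block:: python

    import random
    import uuid


    class ToyPolicyServer:
        """Hypothetical in-process stand-in for a central policy server.

        It hands out actions and records the rewards that an external
        simulator reports, episode by episode.
        """

        def __init__(self):
            self.episodes = {}

        def start_episode(self):
            episode_id = uuid.uuid4().hex
            self.episodes[episode_id] = {"rewards": [], "done": False}
            return episode_id

        def get_action(self, episode_id, observation):
            return observation  # pretend the served policy is a perfect parrot

        def log_returns(self, episode_id, reward):
            self.episodes[episode_id]["rewards"].append(reward)

        def end_episode(self, episode_id, observation):
            # The final observation is unused in this toy.
            self.episodes[episode_id]["done"] = True


    # The external simulator owns the loop and only *asks* for actions.
    server = ToyPolicyServer()
    rng = random.Random(0)
    episode_id = server.start_episode()
    obs = rng.uniform(-5.0, 5.0)
    for _ in range(10):
        action = server.get_action(episode_id, obs)
        reward = -abs(obs - action)  # the simulator computes its own reward
        server.log_returns(episode_id, reward)
        obs = rng.uniform(-5.0, 5.0)
    server.end_episode(episode_id, obs)

In the real client/server setup, each of these calls would cross the network unless actions are computed on the client side.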

Offline RL and imitation learning/behavior cloning: You don't have a simulator for your particular problem, but tons of historic data recorded by a legacy (maybe non-RL/ML) system? This branch of reinforcement learning is for you! RLlib comes with several offline RL algorithms (CQL, MARWIL, and DQfD), allowing you to either purely behavior-clone your existing system or learn how to further improve over it.
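
Behavior cloning, in its simplest form, is supervised learning on logged (observation, action) pairs. As a hedged illustration (plain least squares instead of a neural network, and not how RLlib's offline algorithms are implemented), the sketch below fits a linear policy action = w * obs + b to synthetic logs from a hypothetical legacy controller that already parrots its input with a little noise.

.. code-block:: python

    import random

    # Synthetic "historic" logs from a hypothetical legacy controller that
    # repeats the observation with a little Gaussian noise.
    rng = random.Random(0)
    logs = []
    for _ in range(1000):
        obs = rng.uniform(-5.0, 5.0)
        action = obs + rng.gauss(0.0, 0.1)
        logs.append((obs, action))

    # Behavior cloning as ordinary least squares: minimize
    # sum((w * obs + b - action) ** 2) over the logged pairs.
    n = len(logs)
    mean_o = sum(o for o, _ in logs) / n
    mean_a = sum(a for _, a in logs) / n
    cov_oa = sum((o - mean_o) * (a - mean_a) for o, a in logs)
    var_o = sum((o - mean_o) ** 2 for o, _ in logs)
    w = cov_oa / var_o
    b = mean_a - w * mean_o


    def cloned_policy(obs):
        """Imitates the legacy controller using the fitted parameters."""
        return w * obs + b


    print(f"w={w:.3f}, b={b:.3f}")

The fitted parameters land close to w=1 and b=0, i.e. the cloned policy reproduces the logged behavior. Improving beyond the logged behavior, rather than merely copying it, is what offline RL algorithms add on top.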

In-Depth Documentation

For an in-depth overview of RLlib and everything it has to offer, including hands-on tutorials of important industry use cases and workflows, head over to our documentation pages.

Cite our Paper

If you've found RLlib useful for your research, please cite our paper as follows:

@inproceedings{liang2018rllib,
    Author = {Eric Liang and
              Richard Liaw and
              Robert Nishihara and
              Philipp Moritz and
              Roy Fox and
              Ken Goldberg and
              Joseph E. Gonzalez and
              Michael I. Jordan and
              Ion Stoica},
    Title = {{RLlib}: Abstractions for Distributed Reinforcement Learning},
    Booktitle = {International Conference on Machine Learning ({ICML})},
    Year = {2018}
}