# Serving reinforcement learning policy models
In this example, we train a reinforcement learning model and serve it
using Ray Serve.

We then instantiate an environment and step through it by querying the served model
for actions via HTTP.

Let's start by installing our dependencies:

In [1]:
!pip install -qU "ray[rllib,serve]" gymnasium

Now we can run some imports:

In [2]:
import gymnasium as gym
import numpy as np
import requests

from ray import serve
from ray.air.checkpoint import Checkpoint
from ray.air.config import RunConfig, ScalingConfig
from ray.air.result import Result
from ray.serve import PredictorDeployment
from ray.train.rl.rl_predictor import RLPredictor
from ray.train.rl.rl_trainer import RLTrainer
from ray.tune.tuner import Tuner

Since we'll be serving a reinforcement learning policy, we need to train one first. We define a simple training function that runs online reinforcement learning of a PPO agent on the `CartPole-v1` environment and stops after five training iterations.

In [3]:
def train_rl_ppo_online(num_workers: int, use_gpu: bool = False) -> Result:
    print("Starting online training")
    trainer = RLTrainer(
        run_config=RunConfig(stop={"training_iteration": 5}),
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
        algorithm="PPO",
        config={
            "env": "CartPole-v1",
            "framework": "tf",
        },
    )
    # Todo (krfricke/xwjiang): Enable checkpoint config in RunConfig
    # result = trainer.fit()
    tuner = Tuner(
        trainer,
        _tuner_kwargs={"checkpoint_at_end": True},
    )
    result = tuner.fit()[0]
    return result

Once we have a trained checkpoint, we can serve it using Ray Serve:

In [4]:
def serve_rl_model(checkpoint: Checkpoint, name="RLModel") -> str:
    """Serve an RL model and return the deployment URI.

    This function will start Ray Serve and deploy a model wrapper
    that loads the RL checkpoint into an RLPredictor.
    """
    serve.run(
        PredictorDeployment.options(name=name).bind(
            RLPredictor, checkpoint
        )
    )
    return "http://localhost:8000/"

And to make sure everything works well, we can kick off an evaluation run on a fresh environment. This will query the served policy model to obtain actions using HTTP.

In [5]:
def evaluate_served_policy(endpoint_uri: str, num_episodes: int = 3) -> list:
    """Evaluate a served RL policy on a local environment.

    This function will create an RL environment and step through it.
    To obtain the actions, it will query the deployed RL model.
    """
    env = gym.make("CartPole-v1")

    rewards = []
    for _ in range(num_episodes):
        obs, _ = env.reset()
        reward = 0.0
        terminated = truncated = False
        while not terminated and not truncated:
            action = query_action(endpoint_uri, obs)
            obs, r, terminated, truncated, _ = env.step(action)
            reward += r
        rewards.append(reward)

    return rewards


def query_action(endpoint_uri: str, obs: np.ndarray):
    """Perform inference on a served RL model.

    This will send an HTTP request to the Ray Serve endpoint of the served
    RL policy model and return the result.
    """
    action_dict = requests.post(endpoint_uri, json={"array": obs.tolist()}).json()
    return action_dict
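
The request format used by `query_action` can be tried out without a running Ray Serve deployment. The sketch below is a stand-in, not Ray Serve: a small standard-library HTTP server (the hypothetical `StubPolicyHandler`) accepts the same `{"array": [...]}` JSON body and answers with a hand-written rule instead of a trained policy. The response shape (a bare JSON integer), the heuristic, and all helper names are assumptions made for illustration. It uses `urllib` rather than `requests` only to stay dependency-free.

```python
import json
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer


class StubPolicyHandler(BaseHTTPRequestHandler):
    """Hypothetical stand-in for the served policy (not Ray Serve itself)."""

    def do_POST(self):
        # Read and decode the JSON body: {"array": [obs values]}.
        length = int(self.headers.get("Content-Length", 0))
        payload = json.loads(self.rfile.read(length))
        obs = payload["array"]

        # Assumed heuristic: CartPole observations are
        # [cart position, cart velocity, pole angle, pole angular velocity];
        # push right (1) if the pole leans right, otherwise push left (0).
        action = 1 if obs[2] > 0 else 0

        body = json.dumps(action).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        # Silence the default per-request logging.
        pass


def start_stub_server():
    """Start the stub on a free local port and return (server, endpoint URI)."""
    server = HTTPServer(("127.0.0.1", 0), StubPolicyHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    host, port = server.server_address
    return server, f"http://{host}:{port}/"


def query_stub_action(endpoint_uri, obs):
    """POST an observation as {"array": [...]} and return the decoded JSON reply."""
    request = urllib.request.Request(
        endpoint_uri,
        data=json.dumps({"array": list(obs)}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    # An empty ProxyHandler keeps the local request away from any system proxy.
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    with opener.open(request, timeout=5) as response:
        return json.loads(response.read())


if __name__ == "__main__":
    server, uri = start_stub_server()
    try:
        print(query_stub_action(uri, [0.0, 0.0, 0.05, 0.0]))
    finally:
        server.shutdown()
        server.server_close()
```

A stub like this is only useful for checking the client-side plumbing (payload encoding, URL handling, response decoding). The format actually returned by the deployed predictor should be verified against the real endpoint.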

Let's put it all together. First, we train the model:

In [6]:
num_workers = 2
use_gpu = False

result = train_rl_ppo_online(num_workers=num_workers, use_gpu=use_gpu)



Starting online training


2022-05-19 14:19:35,724	INFO services.py:1483 -- View the Ray dashboard at http://127.0.0.1:8269


| Trial name | status | loc | iter | total time (s) | ts | reward | episode_reward_max | episode_reward_min | episode_len_mean |
|---|---|---|---|---|---|---|---|---|---|
| AIRPPOTrainer_55884_00000 | TERMINATED | 127.0.0.1:15610 | 5 | 16.4897 | 20000 | 131.8 | 200 | 16 | 131.8 |


(raylet) 2022-05-19 14:19:39,542	INFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=51686 --object-store-name=/tmp/ray/session_2022-05-19_14-19-32_884042_15394/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_14-19-32_884042_15394/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=52347 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:65218 --redis-password=5241590000000000 --startup-token=16 --runtime-env-hash=-2010331134
(AIRPPOTrainer pid=15610) 2022-05-19 14:19:47,485	INFO trainer.py:1728 -- Your framework setting is 'tf', meaning you are using static-graph mode. Set framework='tf2' to enable eager execution with tf2.x. You may also then want to set eager_tracing=True in order to reach similar execution sp

Result for AIRPPOTrainer_55884_00000:
  agent_timesteps_total: 4000
  counters:
    num_agent_steps_sampled: 4000
    num_agent_steps_trained: 4000
    num_env_steps_sampled: 4000
    num_env_steps_trained: 4000
  custom_metrics: {}
  date: 2022-05-19_14-20-01
  done: false
  episode_len_mean: 20.4020618556701
  episode_media: {}
  episode_reward_max: 91.0
  episode_reward_mean: 20.4020618556701
  episode_reward_min: 9.0
  episodes_this_iter: 194
  episodes_total: 194
  experiment_id: 91a6faca48864f6aa47a7847d8741683
  hostname: Kais-MacBook-Pro.local
  info:
    learner:
      default_policy:
        custom_metrics: {}
        learner_stats:
          cur_kl_coeff: 0.20000000298023224
          cur_lr: 4.999999873689376e-05
          entropy: 0.6655290722846985
          entropy_coeff: 0.0
          kl: 0.028071347624063492
          model: {}
          policy_loss: -0.04146554693579674
          total_loss: 8.68990707397461
          vf_explained_var: 0.010860291309654713
          v

2022-05-19 14:20:14,687	INFO tune.py:753 -- Total run time: 36.43 seconds (35.98 seconds for the tuning loop).


Then, we serve it using Ray Serve:

In [7]:
endpoint_uri = serve_rl_model(result.checkpoint)

(ServeController pid=15625) INFO 2022-05-19 14:20:16,749 controller 15625 checkpoint_path.py:17 - Using RayInternalKVStore for controller checkpoint and recovery.
(ServeController pid=15625) INFO 2022-05-19 14:20:16,751 controller 15625 http_state.py:115 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:SERVE_PROXY_ACTOR-node:127.0.0.1-0' on node 'node:127.0.0.1-0' listening on '127.0.0.1:8000'
(HTTPProxyActor pid=15630) INFO:     Started server process [15630]
(ServeController pid=15625) INFO 2022-05-19 14:20:26,056 controller 15625 deployment_state.py:1217 - Adding 1 replicas to deployment 'RLModel'.
(RLModel pid=15633) 2022-05-19 14:20:34,700	INFO trainer.py:1728 -- Your framework setting is 'tf', meaning you are using static-graph mode. Set framework='tf2' to enable eager execution with tf2.x. You may also then want to set eager_tracing=True in order to reach similar execution speed as with static-graph mode.

And then we evaluate the served model on a fresh environment:

In [8]:
rewards = evaluate_served_policy(endpoint_uri=endpoint_uri)

(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:45,215 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 307 3.4ms
(RLModel pid=15633) INFO 2022-05-19 14:20:45,214 RLModel RLModel#OeYEbL replica.py:483 - HANDLE __call__ OK 0.3ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:45,253 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 200 33.6ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:45,260 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 307 2.5ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:45,267 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 200 3.8ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:45,273 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 307 2.3ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:45,280 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 200 4.1ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:45,285

In [9]:
print("Episode rewards:", rewards)

Episode rewards: [200.0, 200.0, 200.0]
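
CartPole gives a reward of +1 for every step, so each entry in this list is the number of steps the pole stayed up in that episode. With more episodes, a short summary is easier to read than the raw list. The helper below (the name `summarize_rewards` is made up for this sketch) computes the mean, minimum, and maximum, using the rewards printed above as sample input.

```python
from statistics import mean


def summarize_rewards(episode_rewards):
    """Return the mean, minimum, and maximum of a list of episode rewards."""
    return {
        "mean": mean(episode_rewards),
        "min": min(episode_rewards),
        "max": max(episode_rewards),
    }


# Sample input: the episode rewards shown in the output above.
sample_rewards = [200.0, 200.0, 200.0]
summary = summarize_rewards(sample_rewards)
print(summary)
```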


After we're done, we can shut down Ray Serve.

In [10]:
serve.shutdown()

(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:52,369 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 200 4.0ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:52,375 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 307 2.6ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:52,381 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 200 4.2ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:52,387 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 307 2.1ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:52,393 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 200 3.8ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:52,398 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 307 2.3ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20:52,404 http_proxy 127.0.0.1 http_proxy.py:320 - POST /RLModel 200 3.7ms
(HTTPProxyActor pid=15630) INFO 2022-05-19 14:20: