# flake8: noqa
"""
This file holds several examples for the Catalog API that are used in the
catalog guide.
"""

# 1) Basic interaction with Catalogs in RLlib.
# __sphinx_doc_basic_interaction_begin__
import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

env = gym.make("CartPole-v1")
catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})

# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_actor_critic_encoder(framework="torch")
policy_head = catalog.build_pi_head(framework="torch")

# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")
# __sphinx_doc_basic_interaction_end__
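
# (Editor's sketch, beyond the original guide.) Inspect what the catalog
# produced. `latent_dims` is the Catalog property (also used in example 2
# below) giving the encoder output shape that the heads consume.
print(type(encoder).__name__, type(policy_head).__name__)
print(catalog.latent_dims)  # Encoder output shape, e.g. (256,) by default.
print(action_dist_class)  # A categorical distribution class for Discrete(2).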

# 2) Basic workflow that uses the Catalog base class and RLlib's ModelConfigs
# to build models and an action distribution, and steps through an environment.
# __sphinx_doc_modelsworkflow_begin__
import gymnasium as gym
import torch

# ENCODER_OUT is a constant we use to enumerate Encoder I/O.
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.policy.sample_batch import SampleBatch

env = gym.make("CartPole-v1")
catalog = Catalog(env.observation_space, env.action_space, model_config_dict={})

# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")

# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_encoder(framework="torch")

# Build a suitable head model for the action distribution.
# We need `env.action_space.n` action distribution inputs.
head = torch.nn.Linear(catalog.latent_dims[0], env.action_space.n)

# Now we are ready to interact with the environment.
obs, info = env.reset()

# Encoders check for state and sequence lengths for recurrent models.
# We don't need either in this case because default encoders are not recurrent.
input_batch = {SampleBatch.OBS: torch.Tensor([obs])}

# Pass the batch through our models and the action distribution.
encoding = encoder(input_batch)[ENCODER_OUT]
action_dist_inputs = head(encoding)
action_dist = action_dist_class.from_logits(action_dist_inputs)
actions = action_dist.sample().numpy()
env.step(actions[0])
# __sphinx_doc_modelsworkflow_end__
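
# (Editor's sketch, beyond the original guide.) Assuming RLlib's Distribution
# API, `logp()` returns per-action log-probabilities, which is what loss
# functions typically consume; the objects from the workflow above are reused.
sampled_actions = action_dist.sample()
log_probs = action_dist.logp(sampled_actions)  # Shape (1,) for our batch of 1.
print(log_probs)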

# 3) Demonstrates a basic workflow that uses the PPOCatalog to build models
# and an action distribution to step through an environment.
# __sphinx_doc_ppo_models_begin__
import gymnasium as gym
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

# ENCODER_OUT and ACTOR are constants we use to enumerate Encoder I/O.
from ray.rllib.core.models.base import ENCODER_OUT, ACTOR
from ray.rllib.policy.sample_batch import SampleBatch

env = gym.make("CartPole-v1")
catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})

# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_actor_critic_encoder(framework="torch")
policy_head = catalog.build_pi_head(framework="torch")

# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")

# Now we are ready to interact with the environment.
obs, info = env.reset()

# Encoders check for state and sequence lengths for recurrent models.
# We don't need either in this case because default encoders are not recurrent.
input_batch = {SampleBatch.OBS: torch.Tensor([obs])}

# Pass the batch through our models and the action distribution.
encoding = encoder(input_batch)[ENCODER_OUT][ACTOR]
action_dist_inputs = policy_head(encoding)
action_dist = action_dist_class.from_logits(action_dist_inputs)
actions = action_dist.sample().numpy()
env.step(actions[0])
# __sphinx_doc_ppo_models_end__
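
# (Editor's sketch, beyond the original guide.) Assuming PPOCatalog's
# `build_vf_head()` and the CRITIC key exported next to ACTOR, the critic
# branch of the shared actor-critic encoder output yields a value estimate.
from ray.rllib.core.models.base import CRITIC

vf_head = catalog.build_vf_head(framework="torch")
value_estimate = vf_head(encoder(input_batch)[ENCODER_OUT][CRITIC])
print(value_estimate)  # A (1, 1)-shaped tensor with the state-value estimate.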

# 4) Demonstrates how to specify a custom Catalog for an RLModule through
# AlgorithmConfig.
# __sphinx_doc_algo_configs_begin__
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec


class MyPPOCatalog(PPOCatalog):
    def __init__(self, *args, **kwargs):
        print("Hi from within PPORLModule!")
        super().__init__(*args, **kwargs)


config = (
    PPOConfig()
    .environment("CartPole-v1")
    .framework("torch")
    .rl_module(_enable_rl_module_api=True)
    .training(_enable_learner_api=True)
)
# Specify the catalog to use for the PPORLModule.
config = config.rl_module(
    rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyPPOCatalog)
)
# This is how RLlib constructs a PPORLModule.
# It will print "Hi from within PPORLModule!".
ppo = config.build()
# __sphinx_doc_algo_configs_end__
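
# (Editor's note, beyond the original guide.) Assuming the standard Algorithm
# API, the built Algorithm can be trained and shut down as usual.
result = ppo.train()
ppo.stop()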