Skip to content
Branch: master
Find file Copy path
Find file Copy path
1 contributor

Users who have contributed to this file

97 lines (77 sloc) 2.95 KB
"""Example of a custom gym environment and model. Run this for a demo.
This example shows:
  - using a custom environment
  - using a custom model
  - using Tune for grid search

You can visualize experiment results in ~/ray_results using TensorBoard.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import gym
from ray.rllib.models import ModelCatalog
# NOTE(review): the module paths of the next two imports were lost in
# extraction; restored to the RLlib TF model package layout -- confirm
# against the installed Ray version.
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from gym.spaces import Discrete, Box

import ray
from ray import tune
from ray.rllib.utils import try_import_tf
from ray.tune import grid_search
# Module-level handle produced by RLlib's `try_import_tf` helper (presumably
# the TensorFlow module, or a fallback when it is unavailable -- TODO confirm).
tf = try_import_tf()
class SimpleCorridor(gym.Env):
    """Toy custom environment: walk to the end of a one-dimensional corridor.

    The agent starts at position 0. Action 1 moves one step forward; action 0
    moves one step back, but never below position 0. The episode is done once
    the position reaches ``config["corridor_length"]``.
    """

    def __init__(self, config):
        """Create the corridor.

        Args:
            config: Env config mapping; ``config["corridor_length"]`` is the
                position at which the episode ends.
        """
        self.action_space = Discrete(2)
        self.end_pos = config["corridor_length"]
        self.observation_space = Box(
            0.0, self.end_pos, shape=(1, ), dtype=np.float32)
        self.cur_pos = 0

    def reset(self):
        """Move the agent back to position 0 and return the observation."""
        self.cur_pos = 0
        observation = [self.cur_pos]
        return observation

    def step(self, action):
        """Apply one action.

        Args:
            action: 0 to step back (ignored at position 0), 1 to step forward.

        Returns:
            Tuple ``(observation, reward, done, info)`` where the observation
            is ``[cur_pos]``, the reward is 1 on the step that finishes the
            episode and 0 otherwise, and ``info`` is an empty dict.
        """
        assert action in [0, 1], action
        stepping_back = action == 0 and self.cur_pos > 0
        if stepping_back:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        done = self.cur_pos >= self.end_pos
        if done:
            reward = 1
        else:
            reward = 0
        return [self.cur_pos], reward, done, {}
class CustomModel(TFModelV2):
    """Example of a custom model that just delegates to a fc-net.

    ``forward`` and ``value_function`` are passed straight through to a
    wrapped ``FullyConnectedNetwork`` built from the same constructor
    arguments.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Build the wrapped fully connected network.

        All five arguments are forwarded unchanged, first to the ``TFModelV2``
        base constructor and then to ``FullyConnectedNetwork``.
        """
        super(CustomModel, self).__init__(obs_space, action_space, num_outputs,
                                          model_config, name)
        self.model = FullyConnectedNetwork(obs_space, action_space,
                                           num_outputs, model_config, name)
        # The weights are created by the wrapped network, so hand them to this
        # model as well; otherwise this wrapper reports no variables of its
        # own. NOTE(review): relies on TFModelV2.register_variables() and
        # TFModelV2.variables() -- confirm against the installed Ray version.
        self.register_variables(self.model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Return the wrapped network's ``forward`` result unchanged."""
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Return the wrapped network's ``value_function`` result unchanged."""
        return self.model.value_function()
if __name__ == "__main__":
    # Can also register the env creator function explicitly with:
    # register_env("corridor", lambda config: SimpleCorridor(config))
    ray.init()
    ModelCatalog.register_custom_model("my_model", CustomModel)
    # NOTE(review): the ray.init()/tune.run(...) scaffolding and every closing
    # bracket were missing from this block and have been reconstructed from
    # the surviving keys. The trainer name "PPO" and the placement of
    # "vf_share_layers" at the top level of the config (rather than inside
    # "model") are assumptions -- confirm before relying on them.
    tune.run(
        "PPO",
        # Stop each trial after this many environment timesteps.
        stop={
            "timesteps_total": 10000,
        },
        config={
            "env": SimpleCorridor,  # or "corridor" if registered above
            "model": {
                "custom_model": "my_model",
            },
            "vf_share_layers": True,
            "lr": grid_search([1e-2, 1e-4, 1e-6]),  # try different lrs
            "num_workers": 1,  # parallelism
            "env_config": {
                "corridor_length": 5,
            },
        },
    )
You can’t perform that action at this time.