What is the problem?

It seems that the rendering and recording procedure as laid out here doesn't work when the environment is a MultiAgentEnv. I tried with a custom environment, and then made a simple version based on the example, but the MP4s simply don't appear. It works normally when the environment is single-agent and inherits from gym.Env.
One thing I have tried (which also doesn't work) is making the env inherit from both gym.Env and MultiAgentEnv, roughly as in the sketch below.
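A minimal sketch of that attempt (illustrative; the class name is made up and the method bodies are the same as in the repro script further down):

```python
import gym

from ray.rllib import MultiAgentEnv


# Sketch of the dual-inheritance attempt: RLlib still treats this as a
# MultiAgentEnv, and no MP4s get written during evaluation.
class DualInheritanceEnv(gym.Env, MultiAgentEnv):
    metadata = {"render.modes": ["rgb_array"]}
    # ... same __init__/reset/step/render as in the repro script below ...
```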
Ray version and other system information (Python version, TensorFlow version, OS): Python 3.8, Ray 1.3.0, Torch 1.8.1, macOS 11.2.1
Reproduction (REQUIRED)
```python
import argparse

import numpy as np

import ray
from gym.spaces import Box, Discrete
from ray import tune
from ray.rllib import MultiAgentEnv

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--stop-iters", type=int, default=10)
parser.add_argument("--stop-timesteps", type=int, default=10000)
parser.add_argument("--stop-reward", type=float, default=9.0)


class CustomRenderedEnv(MultiAgentEnv):
    """Example of a custom env, for which you can specify rendering behavior."""

    metadata = {
        "render.modes": ["rgb_array"],
    }

    def __init__(self, config):
        self.end_pos = config.get("corridor_length", 10)
        self.max_steps = config.get("max_steps", 100)
        self.cur_pos = 0
        self.steps = 0
        self.action_space = Discrete(2)
        self.observation_space = Box(0.0, 999.0, shape=(1, ), dtype=np.float32)

    def reset(self):
        self.cur_pos = 0.0
        self.steps = 0
        obs_dict = {"agent": [self.cur_pos]}
        return obs_dict

    def step(self, actions):
        action = actions["agent"]
        self.steps += 1
        assert action in [0, 1], action
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1.0
        elif action == 1:
            self.cur_pos += 1.0
        done = self.cur_pos >= self.end_pos or \
            self.steps >= self.max_steps
        obs_dict = {"agent": [self.cur_pos]}
        done_dict = {"agent": done, "__all__": done}
        reward_dict = {"agent": 10.0 if done else -0.1}
        return obs_dict, reward_dict, done_dict, {}

    def render(self, mode="rgb"):
        return np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)


if __name__ == "__main__":
    # Note: Recording and rendering in this example
    # should work for both local_mode=True|False.
    ray.init(num_cpus=4)
    args = parser.parse_args()

    obs_space = Box(0.0, 999.0, shape=(1, ), dtype=np.float32)
    act_space = Discrete(2)
    policies = {"shared_policy": (None, obs_space, act_space, {})}
    policy_ids = list(policies.keys())

    # Example config causing the issue.
    config = {
        # Also try common gym envs like: "CartPole-v0" or "Pendulum-v0".
        "env": CustomRenderedEnv,
        "env_config": {
            "corridor_length": 10,
            "max_steps": 100,
        },
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": (lambda agent_id: "shared_policy"),
        },
        # Evaluate once per training iteration.
        "evaluation_interval": 1,
        # Run evaluation on (at least) two episodes.
        "evaluation_num_episodes": 2,
        # ... using one evaluation worker (setting this to 0 will cause
        # evaluation to run on the local evaluation worker, blocking
        # training until evaluation is done).
        "evaluation_num_workers": 1,
        # Special evaluation config. Keys specified here will override
        # the same keys in the main config, but only for evaluation.
        "evaluation_config": {
            # Store videos in this relative directory inside the default
            # output dir (~/ray_results/...). Alternatively, specify an
            # absolute path, set this to True to use the default output
            # dir, or set it to False to record nothing.
            "record_env": "videos",
            # "record_env": "/Users/xyz/my_videos/",
            # Render the env while evaluating.
            # Note that this will always only render the 1st RolloutWorker's
            # env and only the 1st sub-env in a vectorized env.
            "render_env": True,
        },
        "num_workers": 1,
        # Use a vectorized env with 2 sub-envs.
        "num_envs_per_worker": 2,
        "framework": args.framework,
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run("PPO", config=config, stop=stop)
```
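For contrast, a sketch of the single-agent case that does work (reusing `args` and `stop` from the script above; the recording settings are identical):

```python
# Same evaluation/recording settings, but on a single-agent gym.Env:
# this variant writes MP4s under ~/ray_results/.../videos as expected.
single_agent_config = {
    "env": "CartPole-v0",
    "evaluation_interval": 1,
    "evaluation_num_episodes": 2,
    "evaluation_num_workers": 1,
    "evaluation_config": {
        "record_env": "videos",
        "render_env": True,
    },
    "num_workers": 1,
    "framework": args.framework,
}
tune.run("PPO", config=single_agent_config, stop=stop)
```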
- [x] I have verified my script runs in a clean environment and reproduces the issue.
- [x] I have verified the issue also occurs with the latest wheels.
I believe I am also facing the same problem and have opened an issue here.
Hey @rfali and @RedTachyon, could you take a look at this PR here and let me know whether it would fix your issue?
Fixing this was a little tricky. The reason is that RLlib uses gym's Monitor wrapper, which only works on gym.Env objects. MultiAgentEnv is not a gym.Env (even though it almost looks like one), so I had to create a new child wrapper that can handle a MultiAgentEnv (which, e.g., returns dict rewards, not floats).
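To illustrate why the stock wrapper can't be reused, here is a minimal sketch (illustrative names, not the actual PR code) that drives gym's VideoRecorder directly and tolerates the dict returns:

```python
from gym.wrappers.monitoring.video_recorder import VideoRecorder

from ray.rllib import MultiAgentEnv


class MultiAgentMonitor:
    """Sketch of a Monitor-like recording wrapper for MultiAgentEnv."""

    def __init__(self, env: MultiAgentEnv, video_base_path: str):
        self.env = env
        # VideoRecorder only needs render(mode="rgb_array") and a
        # metadata["render.modes"] entry, both of which the repro env has.
        self.recorder = VideoRecorder(env=env, base_path=video_base_path)

    def reset(self):
        obs_dict = self.env.reset()  # dict of per-agent obs, not an array
        self.recorder.capture_frame()
        return obs_dict

    def step(self, action_dict):
        # A MultiAgentEnv returns dicts (incl. the special "__all__" done
        # key); gym's Monitor breaks on these non-scalar rewards/dones.
        obs, rewards, dones, infos = self.env.step(action_dict)
        self.recorder.capture_frame()
        if dones["__all__"]:
            self.recorder.close()  # flushes the .mp4 to disk
        return obs, rewards, dones, infos
```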
At a glance, that seems to be it. The code doesn't work just yet (it is WIP, after all), but I'll be happy to check the functionality when it's done.