add flattening logic so ddpg can handle image observations #23
Conversation
```diff
@@ -14,7 +14,7 @@ def rollout(env, agent, max_path_length=np.inf, animated=False, speedup=1):
     if animated:
         env.render()
     while path_length < max_path_length:
-        a, agent_info = agent.get_action(o)
+        a, agent_info = agent.get_action(env.observation_space.flatten(o))
```
This would break quite a few things, since the `get_action` method is supposed to receive just the raw observations. Why is this change needed?
Without this change I get the following error:

```
Traceback (most recent call last):
  File "/home/tom/rllab/scripts/run_experiment_lite.py", line 103, in <module>
    run_experiment(sys.argv)
  File "/home/tom/rllab/scripts/run_experiment_lite.py", line 90, in run_experiment
    maybe_iter = concretize(data)
  File "/home/tom/rllab/rllab/misc/instrument.py", line 898, in concretize
    return method(*args, **kwargs)
  File "/home/tom/rllab/rllab/algos/ddpg.py", line 263, in train
    self.evaluate(epoch, pool)
  File "/home/tom/rllab/rllab/algos/ddpg.py", line 381, in evaluate
    max_path_length=self.max_path_length,
  File "/home/tom/rllab/rllab/sampler/parallel_sampler.py", line 114, in sample_paths
    show_prog_bar=True
  File "/home/tom/rllab/rllab/sampler/stateful_pool.py", line 142, in run_collect
    result, inc = collect_once(self.G, *args)
  File "/home/tom/rllab/rllab/sampler/parallel_sampler.py", line 89, in _worker_collect_one_path
    path = rollout(G.env, G.policy, max_path_length)
  File "/home/tom/rllab/rllab/sampler/utils.py", line 17, in rollout
    a, agent_info = agent.get_action(o)
  File "/home/tom/rllab/rllab/policies/deterministic_mlp_policy.py", line 66, in get_action
    action = self._f_actions([observation])[0]
  File "/home/tom/anaconda2/envs/rllab/lib/python2.7/site-packages/theano/compile/function_module.py", line 784, in __call__
    allow_downcast=s.allow_downcast)
  File "/home/tom/anaconda2/envs/rllab/lib/python2.7/site-packages/theano/tensor/type.py", line 178, in filter
    data.shape))
TypeError: ('Bad input argument to theano function with name "/home/tom/rllab/rllab/misc/ext.py:135" at index 0 (0-based)', 'Wrong number of dimensions: expected 2, got 4 with shape (1, 96, 96, 3).')
```
In short, the policy network expects a flat input but is getting a non-flat one. An alternative would be to handle the flattening inside the `get_action` method of `DeterministicMLPPolicy`.
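For illustration, a minimal sketch of that alternative, assuming the policy can reach its observation space through the environment spec; the subclass name here is hypothetical and is only meant to show where the flattening would live:

```python
from rllab.policies.deterministic_mlp_policy import DeterministicMLPPolicy


class FlatteningMLPPolicy(DeterministicMLPPolicy):
    """Hypothetical subclass: accept raw (e.g. image-shaped) observations
    and flatten them before the underlying MLP sees them."""

    def get_action(self, observation):
        # e.g. a (96, 96, 3) frame becomes a flat vector of length 27648
        flat_obs = self.observation_space.flatten(observation)
        return super(FlatteningMLPPolicy, self).get_action(flat_obs)
```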
```diff
@@ -224,7 +224,7 @@ def train(self):
                     self.es_path_returns.append(path_return)
                     path_length = 0
                     path_return = 0
-                action = self.es.get_action(itr, observation, policy=sample_policy)  # qf=qf)
+                action = self.es.get_action(itr, self.env.observation_space.flatten(observation), policy=sample_policy)  # qf=qf)
```
Just to be a little more consistent, could this be changed to pass the raw observation to the exploration strategy? (It shouldn't require any other changes, since I believe none of the exploration strategies currently implemented make use of the observations.)
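For illustration, a rough sketch of the pattern the exploration strategies follow; this is a simplified, hypothetical strategy rather than rllab's actual OUStrategy, but it shows why the observation's shape never matters inside the strategy itself:

```python
import numpy as np


class SimpleGaussianStrategy(object):
    """Hypothetical exploration strategy: the observation is forwarded to
    the policy untouched, and only the resulting action is perturbed."""

    def __init__(self, env_spec, sigma=0.3):
        self._action_space = env_spec.action_space
        self._sigma = sigma

    def get_action(self, t, observation, policy, **kwargs):
        # The strategy never inspects the observation; it only adds noise
        # to the action the policy produces.
        action, _ = policy.get_action(observation)
        noisy_action = action + self._sigma * np.random.randn(*action.shape)
        return np.clip(noisy_action, self._action_space.low, self._action_space.high)
```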
Thanks, merged!
Adds logic to flatten observations that are not a flat vector, allowing the DDPG algorithm to work with environments that return image data in their observations. cf. #22.
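Roughly, the flattening amounts to collapsing an image-shaped observation into a single vector before it reaches the MLP policy. A minimal sketch of the idea (not the exact rllab implementation):

```python
import numpy as np


def flatten_observation(obs):
    """Rough equivalent of what the observation space's flatten() does for
    image data: collapse an (H, W, C) array into one flat vector that a
    plain MLP policy can consume."""
    return np.asarray(obs).flatten()


# Example: a 96x96 RGB frame becomes a vector of length 27648.
frame = np.zeros((96, 96, 3))
assert flatten_observation(frame).shape == (96 * 96 * 3,)
```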