
add flattening logic so ddpg can handle image observations #23

Merged
merged 3 commits into rll:master on Jul 1, 2016

Conversation

singulaire (Contributor):

Adds logic to flatten observations that are not already flat vectors, allowing the DDPG algorithm to work with environments that return image data in their observations. Cf. #22.
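The flattening the PR relies on can be sketched as follows. This is a minimal stand-in for what a space's `flatten` does for image observations, not rllab's actual implementation; `flatten_observation` is a hypothetical helper:

```python
import numpy as np

# Hypothetical stand-in for an observation space's flatten():
# collapse an image observation (e.g. shape (96, 96, 3)) into the
# 1-D vector an MLP policy network expects as input.
def flatten_observation(obs):
    return np.asarray(obs).reshape(-1)

obs = np.zeros((96, 96, 3))
print(flatten_observation(obs).shape)  # (27648,)
```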

@@ -14,7 +14,7 @@ def rollout(env, agent, max_path_length=np.inf, animated=False, speedup=1):
     if animated:
         env.render()
     while path_length < max_path_length:
-        a, agent_info = agent.get_action(o)
+        a, agent_info = agent.get_action(env.observation_space.flatten(o))
Member:

This would break quite a few things, since the get_action method is supposed to receive just the raw observations. Why is this change needed?

singulaire (Contributor, Author):

Without this change, I get the following error:

Traceback (most recent call last):
  File "/home/tom/rllab/scripts/run_experiment_lite.py", line 103, in <module>
    run_experiment(sys.argv)
  File "/home/tom/rllab/scripts/run_experiment_lite.py", line 90, in run_experiment
    maybe_iter = concretize(data)
  File "/home/tom/rllab/rllab/misc/instrument.py", line 898, in concretize
    return method(*args, **kwargs)
  File "/home/tom/rllab/rllab/algos/ddpg.py", line 263, in train
    self.evaluate(epoch, pool)
  File "/home/tom/rllab/rllab/algos/ddpg.py", line 381, in evaluate
    max_path_length=self.max_path_length,
  File "/home/tom/rllab/rllab/sampler/parallel_sampler.py", line 114, in sample_paths
    show_prog_bar=True
  File "/home/tom/rllab/rllab/sampler/stateful_pool.py", line 142, in run_collect
    result, inc = collect_once(self.G, *args)
  File "/home/tom/rllab/rllab/sampler/parallel_sampler.py", line 89, in _worker_collect_one_path
    path = rollout(G.env, G.policy, max_path_length)
  File "/home/tom/rllab/rllab/sampler/utils.py", line 17, in rollout
    a, agent_info = agent.get_action(o)
  File "/home/tom/rllab/rllab/policies/deterministic_mlp_policy.py", line 66, in get_action
    action = self._f_actions([observation])[0]
  File "/home/tom/anaconda2/envs/rllab/lib/python2.7/site-packages/theano/compile/function_module.py", line 784, in __call__
    allow_downcast=s.allow_downcast)
  File "/home/tom/anaconda2/envs/rllab/lib/python2.7/site-packages/theano/tensor/type.py", line 178, in filter
    data.shape))
TypeError: ('Bad input argument to theano function with name "/home/tom/rllab/rllab/misc/ext.py:135" at index 0 (0-based)', 'Wrong number of dimensions: expected 2, got 4 with shape (1, 96, 96, 3).')

In short, the policy network expects a flat input but receives a non-flat one. An alternative would be to handle the flattening inside the get_action function of DeterministicMLPPolicy.
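That alternative can be sketched roughly like this. The class and attribute names are illustrative, not rllab's actual API, and a random linear map stands in for the Theano-compiled action function:

```python
import numpy as np

# Sketch of the alternative: flatten inside the policy's get_action
# instead of in the sampler. `FlatteningMLPPolicy` is a hypothetical
# toy; a random linear map replaces the compiled `_f_actions`.
class FlatteningMLPPolicy:
    def __init__(self, obs_dim, action_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.standard_normal((obs_dim, action_dim))

    def get_action(self, observation):
        flat = np.asarray(observation).reshape(-1)  # accept any observation shape
        action = flat @ self.W
        return action, dict()

policy = FlatteningMLPPolicy(obs_dim=96 * 96 * 3, action_dim=2)
action, info = policy.get_action(np.zeros((96, 96, 3)))
print(action.shape)  # (2,)
```

With this design the sampler can keep calling `get_action(o)` on the raw observation, at the cost of every policy repeating the flattening step.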

@@ -224,7 +224,7 @@ def train(self):
                     self.es_path_returns.append(path_return)
                     path_length = 0
                     path_return = 0
-                action = self.es.get_action(itr, observation, policy=sample_policy)  # qf=qf)
+                action = self.es.get_action(itr, self.env.observation_space.flatten(observation), policy=sample_policy)  # qf=qf)
Member:

Just to be a little consistent, could this be changed to pass the raw observations to the exploration strategy? (This shouldn't require any other changes, since I believe none of the implemented exploration strategies make use of the observations...)
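The reviewer's point can be illustrated with a minimal sketch: a noise-based exploration strategy only perturbs the policy's action and never inspects the observation, so forwarding the raw (unflattened) observation is safe. Names here are hypothetical, not rllab's exact API:

```python
import numpy as np

# Illustrative exploration strategy: the observation passes straight
# through to the policy (which handles any flattening it needs), and
# only the resulting action is perturbed with Gaussian noise.
class GaussianExplorationStrategy:
    def __init__(self, sigma=0.1, seed=0):
        self.sigma = sigma
        self.rng = np.random.default_rng(seed)

    def get_action(self, itr, observation, policy):
        action, _ = policy.get_action(observation)  # raw observation forwarded untouched
        return action + self.sigma * self.rng.standard_normal(action.shape)
```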

dementrock merged commit 327446a into rll:master on Jul 1, 2016
dementrock (Member):

Thanks, merged!
