
Sampler should not flatten observations and actions #967

Merged
ahtsan merged 3 commits into master from fix_batch_sampler on Nov 2, 2019

Conversation

@ahtsan (Contributor) commented Oct 29, 2019

To make the API consistent across all samplers, samplers should not flatten observations or actions; flattening is done by the algorithms.
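
A minimal sketch of the intended division of labor (illustrative code, not the exact garage implementation; assumes a gym-style env and a garage-style agent API):

    import numpy as np

    def rollout(env, agent, max_path_length):
        """Sample one path, keeping raw (unflattened) observations/actions."""
        observations, actions = [], []
        obs = env.reset()
        agent.reset()
        for _ in range(max_path_length):
            action, _ = agent.get_action(obs)
            next_obs, _, done, _ = env.step(action)
            observations.append(obs)   # stored raw, e.g. an image stays (H, W, C)
            actions.append(action)     # stored raw as well
            obs = next_obs
            if done:
                break
        return dict(observations=np.asarray(observations),
                    actions=np.asarray(actions))

    # An algorithm that needs flat inputs flattens at consumption time:
    # flat_obs = env.observation_space.flatten_n(path['observations'])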

@ahtsan requested a review from a team as a code owner on October 29, 2019 05:04
@codecov (bot) commented Oct 29, 2019

Codecov Report

Merging #967 into master will increase coverage by 0.09%.
The diff coverage is 100%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #967      +/-   ##
==========================================
+ Coverage   84.49%   84.58%   +0.09%     
==========================================
  Files         156      156              
  Lines        7467     7468       +1     
  Branches      938      938              
==========================================
+ Hits         6309     6317       +8     
+ Misses        966      964       -2     
+ Partials      192      187       -5
Impacted Files                                            Coverage Δ
src/garage/sampler/utils.py                               82.35% <100%> (+10.35%) ⬆️
src/garage/np/policies/base.py                            85% <100%> (ø) ⬆️
...rc/garage/sampler/off_policy_vectorized_sampler.py     100% <0%> (+1.17%) ⬆️
.../exploration_strategies/epsilon_greedy_strategy.py     100% <0%> (+3.7%) ⬆️

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 393e9c2...8789f0b. Read the comment docs.

@@ -111,6 +115,6 @@ def truncate_paths(paths, max_samples):
             truncated_last_path[k] = tensor_utils.truncate_tensor_dict(
                 v, truncated_len)
         else:
-            raise NotImplementedError
+            raise NotImplementedError()
Member (review comment on the diff above):

Can you raise a ValueError instead, and include an error message which specifies the valid keys versus the ones found?
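
A minimal sketch of that suggestion, inside the truncate_paths loop (the set of valid keys shown here is illustrative):

    VALID_KEYS = {'observations', 'actions', 'rewards',
                  'env_infos', 'agent_infos'}
    if k not in VALID_KEYS:
        # Name the offending key and list the valid ones in the message.
        raise ValueError('Unexpected key {!r} in path; valid keys are {}'
                         .format(k, sorted(VALID_KEYS)))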

@naeioi (Member) commented Oct 29, 2019

Not directly related to this PR, but there is also an inconsistency in how the primitives handle flattening. For example, continuous_mlp_policy_with_model.py flattens observations:

    flat_obs = self.observation_space.flatten_n(observations)

while categorical_cnn_policy.py does not:

    probs = self._f_prob(observations)

To make things even worse, there is a flatten_input argument in batch_polopt.py:

    if self.flatten_input:

We should consolidate all of these into a single switch for flattening.
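
A hypothetical sketch of such a switch, if it lived in the algorithm's sample processing (flatten_input is the real batch_polopt.py argument; the method body is illustrative, not garage's actual code):

    def process_samples(self, itr, paths):
        for path in paths:
            if self.flatten_input:
                # The single place where observations get flattened;
                # primitives always receive whatever the algorithm hands them.
                path['observations'] = (
                    self.env_spec.observation_space.flatten_n(
                        path['observations']))
        # ...rest of sample processing...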

@ryanjulian (Member) commented

@naeioi thanks for pointing this out.

I think that's a bug in ContinuousMLPPolicy then -- for now we have decided not to flatten in the primitives.

@ahtsan what do you think?

@ahtsan (Contributor, Author) commented Oct 30, 2019

@naeioi @ryanjulian Yes, I think flattening is currently duplicated in both the primitives and the algorithms -- we should only do it in one place (which I think we chose to be the algorithm). That means we should remove the flattening from all primitives and add flattening support to all the other algorithms (batch_polopt already has it).

Regarding categorical_cnn_policy.py, we don't want to flatten the observations, since we want to keep spatial information for CNN primitives. In that case, I think we could achieve the same goal by passing flatten_input=False to the algorithm when using CNN primitives. Does that sound right?
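
A hedged usage sketch of that idea (only flatten_input is taken from batch_polopt.py; the constructor calls are abbreviated and the exact argument lists may differ):

    from garage.tf.algos import TRPO
    from garage.tf.policies import CategoricalCNNPolicy

    # env and baseline constructed elsewhere.
    policy = CategoricalCNNPolicy(env_spec=env.spec)  # plus conv arguments
    algo = TRPO(env_spec=env.spec,
                policy=policy,
                baseline=baseline,
                flatten_input=False)  # CNN consumes unflattened observations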

@ryanjulian (Member) commented

@ahtsan That sounds right to me. For now let's just make it consistent, and we can consider more automatic designs later.

@ryanjulian (Member) commented

I added some tests to try to merge this for the release, but it turns out something is still quite broken (or my tests are just wrong).

In any case, this won't make it and we will have to backport the fix I think.

@ahtsan @krzentner should be interested.

@ryanjulian added the backport-to-2019.10 (Backport this PR to release-2019.10) label on Nov 1, 2019
@ahtsan (Contributor, Author) commented Nov 1, 2019

> I added some tests to try to merge this for the release, but it turns out something is still quite broken (or my tests are just wrong).
>
> In any case, this won't make it and we will have to backport the fix I think.
>
> @ahtsan @krzentner should be interested.

I think the issue is that the dummy policy doesn't have dist_info_keys from which to extract agent_info, while the test assumes that env_info is non-empty (which is not true in reality) and that agent_info is non-empty. We should fix the dummy policies and make them inherit from the Policy base class, to make sure they comply with its API. We might also want to introduce env.env_info_keys so we can extract env_info the same way we use dist_info_keys.

I am working on a fix.
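
A hypothetical sketch of that direction (class and key names are illustrative, not the final fix; the exact shape of the Policy base API is simplified here):

    from garage.np.policies.base import Policy

    class DummyPolicy(Policy):
        """Test stand-in that complies with the real Policy API."""

        @property
        def dist_info_keys(self):
            # Non-empty, so the sampler can extract agent_info.
            return ['dummy']

        def get_action(self, observation):
            # action_space is assumed to come from the base class / env_spec.
            action = self.action_space.sample()
            return action, dict(dummy=0.0)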

@ryanjulian (Member) commented

If you look at the broken test, I think the shape of path['observations'] is wrong for a non-flat observation space.

@ahtsan (Contributor, Author) commented Nov 1, 2019

> If you look at the broken test, I think the shape of path['observations'] is wrong for a non-flat observation space.

Do you mean this test doesn't pass?

    def test_does_not_flatten(self):
        path = utils.rollout(self.env, self.policy, max_path_length=5)
        assert path['observations'][0].shape == (4, 4)
        assert path['actions'][0].shape == (2, 2)

@ryanjulian (Member) commented

Yes

@ryanjulian removed the backport-to-2019.10 (Backport this PR to release-2019.10) label on Nov 1, 2019
@ahtsan merged commit 93b1a48 into master on Nov 2, 2019
@ahtsan deleted the fix_batch_sampler branch on November 2, 2019 04:44