-
Notifications
You must be signed in to change notification settings - Fork 306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Refactor] Faster envs (2) #1457
Conversation
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_single | 0.1393s | 0.1384s | 7.2279 Ops/s | 6.7754 Ops/s | |
test_sync | 0.1479s | 77.3141ms | 12.9342 Ops/s | 12.0750 Ops/s | |
test_async | 0.1938s | 73.0239ms | 13.6942 Ops/s | 13.0088 Ops/s | |
test_simple | 0.6867s | 0.6170s | 1.6207 Ops/s | 1.6093 Ops/s | |
test_transformed | 1.6663s | 1.6043s | 0.6233 Ops/s | 0.5920 Ops/s | |
test_serial | 1.7591s | 1.7011s | 0.5879 Ops/s | 0.5093 Ops/s | |
test_parallel | 1.5634s | 1.4942s | 0.6693 Ops/s | 0.6204 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1606ms | 44.6101μs | 22.4164 KOps/s | 22.6276 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 93.2020μs | 25.3293μs | 39.4799 KOps/s | 39.8819 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 0.1092ms | 31.1811μs | 32.0708 KOps/s | 32.1355 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 45.2010μs | 17.4488μs | 57.3104 KOps/s | 58.2652 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.1340ms | 45.9581μs | 21.7590 KOps/s | 21.7292 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 63.6010μs | 27.0889μs | 36.9154 KOps/s | 37.3445 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.1086ms | 33.4016μs | 29.9387 KOps/s | 29.9290 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 44.2000μs | 19.3657μs | 51.6378 KOps/s | 52.4679 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 0.1252ms | 48.3980μs | 20.6620 KOps/s | 20.9121 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 0.1022ms | 28.8974μs | 34.6052 KOps/s | 34.5226 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 0.1592ms | 33.8083μs | 29.5785 KOps/s | 30.1931 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 97.5020μs | 19.2759μs | 51.8783 KOps/s | 52.0894 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 97.9020μs | 50.0849μs | 19.9661 KOps/s | 20.2904 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 0.1024ms | 30.3109μs | 32.9914 KOps/s | 32.9354 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 0.1616ms | 35.3969μs | 28.2511 KOps/s | 28.9924 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 48.3000μs | 20.8032μs | 48.0696 KOps/s | 48.2027 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 0.1137ms | 48.1004μs | 20.7898 KOps/s | 20.9114 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 50.5010μs | 28.7913μs | 34.7327 KOps/s | 34.0017 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 0.1201ms | 37.4881μs | 26.6752 KOps/s | 25.6100 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 3.2953ms | 21.6042μs | 46.2874 KOps/s | 45.4794 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 0.1248ms | 49.8007μs | 20.0800 KOps/s | 20.3685 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 97.4020μs | 30.5748μs | 32.7067 KOps/s | 32.7694 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 0.1186ms | 38.8381μs | 25.7479 KOps/s | 24.6716 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 97.2020μs | 23.3945μs | 42.7452 KOps/s | 42.1579 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 78.2020μs | 51.0446μs | 19.5907 KOps/s | 19.5684 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 0.1019ms | 32.6807μs | 30.5991 KOps/s | 30.9838 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 0.1152ms | 38.7226μs | 25.8247 KOps/s | 24.7414 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 56.5010μs | 23.3576μs | 42.8126 KOps/s | 42.8149 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 0.1438ms | 52.1779μs | 19.1652 KOps/s | 19.1596 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 89.2020μs | 33.8109μs | 29.5762 KOps/s | 29.6812 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 70.9010μs | 39.8541μs | 25.0915 KOps/s | 23.1099 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 96.5010μs | 24.6764μs | 40.5246 KOps/s | 40.2648 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 15.7383ms | 13.5918ms | 73.5737 Ops/s | 73.3193 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 45.9083ms | 40.8992ms | 24.4503 Ops/s | 23.9171 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.3003ms | 0.1860ms | 5.3768 KOps/s | 4.7050 KOps/s | |
test_values[td1_return_estimate-False-False] | 13.5835ms | 13.2667ms | 75.3767 Ops/s | 76.3167 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 46.7013ms | 41.3187ms | 24.2021 Ops/s | 24.5061 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 32.9919ms | 31.5605ms | 31.6852 Ops/s | 31.5362 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 46.5164ms | 41.4118ms | 24.1477 Ops/s | 24.3243 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 13.1433ms | 11.9557ms | 83.6420 Ops/s | 84.0546 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 8.4312ms | 3.3242ms | 300.8235 Ops/s | 288.3139 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.5901ms | 0.4541ms | 2.2021 KOps/s | 2.1204 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 62.5153ms | 56.2641ms | 17.7733 Ops/s | 17.8985 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 7.3074ms | 2.7554ms | 362.9238 Ops/s | 359.5459 Ops/s | |
test_dqn_speed | 8.8302ms | 1.8392ms | 543.7140 Ops/s | 547.8821 Ops/s | |
test_ddpg_speed | 9.6086ms | 2.6534ms | 376.8743 Ops/s | 367.0492 Ops/s | |
test_sac_speed | 17.2164ms | 7.8280ms | 127.7463 Ops/s | 123.3549 Ops/s | |
test_redq_speed | 22.7394ms | 15.9568ms | 62.6694 Ops/s | 62.4717 Ops/s | |
test_redq_deprec_speed | 16.8215ms | 12.2293ms | 81.7706 Ops/s | 76.8589 Ops/s | |
test_td3_speed | 10.8532ms | 9.8453ms | 101.5710 Ops/s | 96.0183 Ops/s | |
test_cql_speed | 42.3939ms | 36.6898ms | 27.2555 Ops/s | 38.0057 Ops/s | |
test_a2c_speed | 11.3330ms | 5.2170ms | 191.6802 Ops/s | 189.3897 Ops/s | |
test_ppo_speed | 41.4040ms | 5.7131ms | 175.0376 Ops/s | 182.8101 Ops/s | |
test_reinforce_speed | 11.0398ms | 4.0813ms | 245.0214 Ops/s | 245.0349 Ops/s | |
test_iql_speed | 32.7330ms | 23.3990ms | 42.7369 Ops/s | 46.0300 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.3918ms | 2.6365ms | 379.2877 Ops/s | 378.9104 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 5.1394ms | 2.8096ms | 355.9191 Ops/s | 360.8460 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.7708ms | 2.8414ms | 351.9367 Ops/s | 354.6221 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.6404ms | 2.6794ms | 373.2245 Ops/s | 378.2628 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 5.3296ms | 2.8291ms | 353.4745 Ops/s | 355.9805 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.3288ms | 2.8099ms | 355.8840 Ops/s | 358.6406 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.6365ms | 2.6348ms | 379.5307 Ops/s | 387.2584 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.1417s | 3.2116ms | 311.3740 Ops/s | 359.8914 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 5.2880ms | 2.8446ms | 351.5433 Ops/s | 353.8041 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.5652ms | 2.7108ms | 368.8987 Ops/s | 380.0087 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 5.1808ms | 2.9410ms | 340.0224 Ops/s | 358.1901 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 5.3371ms | 2.9584ms | 338.0169 Ops/s | 353.2290 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.4765ms | 2.6998ms | 370.3962 Ops/s | 378.8107 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 4.7352ms | 2.8686ms | 348.5967 Ops/s | 359.3995 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.9174ms | 2.8607ms | 349.5603 Ops/s | 355.2383 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.5962ms | 2.6742ms | 373.9450 Ops/s | 377.4031 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 4.7858ms | 2.8520ms | 350.6280 Ops/s | 359.8586 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 5.9294ms | 2.9027ms | 344.5106 Ops/s | 356.3901 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.2903s | 30.6421ms | 32.6348 Ops/s | 31.9249 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 0.1471s | 29.6074ms | 33.7753 Ops/s | 34.8381 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 0.1423s | 27.0494ms | 36.9695 Ops/s | 37.5991 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1435s | 29.5092ms | 33.8878 Ops/s | 33.4265 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1433s | 27.1496ms | 36.8330 Ops/s | 37.9529 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.1416s | 26.8441ms | 37.2521 Ops/s | 34.1288 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1420s | 27.1229ms | 36.8692 Ops/s | 37.9302 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1462s | 29.7785ms | 33.5813 Ops/s | 34.4476 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1445s | 26.9509ms | 37.1046 Ops/s | 37.9909 Ops/s |
# Conflicts: # torchrl/collectors/collectors.py # torchrl/envs/common.py # torchrl/envs/transforms/transforms.py # torchrl/envs/utils.py # torchrl/envs/vec_env.py
torchrl/envs/vec_env.py
Outdated
# output keys after reset | ||
self._selected_reset_keys = { | ||
_unravel_key_to_tuple(key) | ||
for key in self._env_obs_keys + self.done_keys + ["_reset"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reset keys should not include "_reset" right?
if so there will be multiple _reset keys (one per done)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need the "_reset" keys in the reset_keys as these need to be synced across processes.
I edited the logic a bit, you can have a look
for key in self._selected_reset_keys: | ||
_set_single_key(self.shared_tensordict_parent, out, key, clone=True) | ||
if _unravel_key_to_tuple(key)[-1] != "_reset": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
again, no need to check here if you just exclude the reset key from selected reset keys
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but if you wanna check, there could be multiple reset keys
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same response as above: we need to sync them across processes (line 627)
for key in self._selected_reset_keys: | ||
_set_single_key(self.shared_tensordict_parent, out, key, clone=True) | ||
if _unravel_key_to_tuple(key)[-1] != "_reset": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above
self._tensordict[traj_done_or_terminated] = td_reset[ | ||
traj_done_or_terminated | ||
] | ||
self._tensordict = torch.where( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this compatible with lazy stacks?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Let's remember about the multiple reset keys across vec_env.
And it seems like some "next": {}
might still be there in some steps of mocking envs (for example MultiKeyConutngEnv)
torchrl/envs/utils.py
Outdated
@@ -562,6 +568,101 @@ def make_composite_from_td(data): | |||
return composite | |||
|
|||
|
|||
def _fuse_tensordicts(*tds, excluded, selected=None, total=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this used anywhere?
What do you mean? |
i think you updated the pr, i was just saying: let’s remeber there are multiple reset keys |
# Deprecated reset_when_done | ||
# @pytest.mark.parametrize("num_env", [1, 2]) | ||
# @pytest.mark.parametrize("env_name", ["vec"]) | ||
# def test_collector_done_persist(num_env, env_name, seed=5): | ||
# if num_env == 1: | ||
# | ||
# def env_fn(seed): | ||
# env = MockSerialEnv(device="cpu") | ||
# env.set_seed(seed) | ||
# return env | ||
# | ||
# else: | ||
# | ||
# def env_fn(seed): | ||
# def make_env(seed): | ||
# env = MockSerialEnv(device="cpu") | ||
# env.set_seed(seed) | ||
# return env | ||
# | ||
# env = ParallelEnv( | ||
# num_workers=num_env, | ||
# create_env_fn=make_env, | ||
# create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], | ||
# ) | ||
# env.set_seed(seed) | ||
# return env | ||
# | ||
# policy = make_policy(env_name) | ||
# | ||
# collector = SyncDataCollector( | ||
# create_env_fn=env_fn, | ||
# create_env_kwargs={"seed": seed}, | ||
# policy=policy, | ||
# frames_per_batch=200 * num_env, | ||
# max_frames_per_traj=2000, | ||
# total_frames=20000, | ||
# device="cpu", | ||
# reset_when_done=False, | ||
# ) | ||
# for _, d in enumerate(collector): # noqa | ||
# break | ||
# | ||
# assert (d["done"].sum(-2) >= 1).all() | ||
# assert torch.unique(d["collector", "traj_ids"], dim=-1).shape[-1] == 1 | ||
# | ||
# del collector |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left out comments?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We deprecated that feature but I wanted to make sure that we were not going to integrate it back so I kept the commented code.
It'll be removed from the code base before the next release.
Thanks for spotting this!
A minimalist version of #1448