-
Notifications
You must be signed in to change notification settings - Fork 306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Fix DType casting lazy init #1589
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
facebook-github-bot
added
the
CLA Signed
This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
label
Oct 2, 2023
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_single | 0.1020s | 0.1012s | 9.8772 Ops/s | 9.7235 Ops/s | |
test_sync | 71.3060ms | 55.2376ms | 18.1036 Ops/s | 18.5475 Ops/s | |
test_async | 0.1099s | 53.0560ms | 18.8480 Ops/s | 18.9476 Ops/s | |
test_simple | 0.9034s | 0.8375s | 1.1940 Ops/s | 1.2088 Ops/s | |
test_transformed | 1.1081s | 1.0355s | 0.9657 Ops/s | 0.9396 Ops/s | |
test_serial | 2.2578s | 2.1880s | 0.4570 Ops/s | 0.4498 Ops/s | |
test_parallel | 1.9228s | 1.8388s | 0.5438 Ops/s | 0.5403 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1172ms | 44.1052μs | 22.6731 KOps/s | 22.3370 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 50.8010μs | 25.0407μs | 39.9350 KOps/s | 39.1018 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 99.1010μs | 30.9689μs | 32.2905 KOps/s | 31.5152 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 45.7010μs | 17.4843μs | 57.1942 KOps/s | 55.3396 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 72.6000μs | 45.4201μs | 22.0167 KOps/s | 21.2773 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 0.2172ms | 26.8009μs | 37.3121 KOps/s | 36.5678 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.1509ms | 33.4134μs | 29.9281 KOps/s | 29.4044 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 41.9010μs | 19.3499μs | 51.6799 KOps/s | 50.6732 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 94.6010μs | 47.5697μs | 21.0218 KOps/s | 20.4034 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 50.4010μs | 28.5705μs | 35.0011 KOps/s | 34.4857 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 82.5010μs | 33.0465μs | 30.2604 KOps/s | 28.9168 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 69.2010μs | 19.3080μs | 51.7921 KOps/s | 50.6332 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 72.3010μs | 48.9085μs | 20.4463 KOps/s | 19.6030 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 96.7020μs | 30.3080μs | 32.9946 KOps/s | 32.5247 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 60.0000μs | 34.5472μs | 28.9459 KOps/s | 28.0584 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 48.5000μs | 20.8510μs | 47.9592 KOps/s | 47.1766 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 76.5010μs | 47.4274μs | 21.0848 KOps/s | 20.5057 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 97.7010μs | 28.9193μs | 34.5789 KOps/s | 34.4202 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 62.6000μs | 36.8010μs | 27.1732 KOps/s | 26.6355 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 95.2010μs | 21.5680μs | 46.3650 KOps/s | 45.9046 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 71.0010μs | 49.1076μs | 20.3635 KOps/s | 20.0593 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 0.1030ms | 30.3878μs | 32.9079 KOps/s | 32.4323 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 68.0010μs | 38.3092μs | 26.1034 KOps/s | 25.6120 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 78.0000μs | 23.1456μs | 43.2047 KOps/s | 42.1841 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 75.9010μs | 50.8511μs | 19.6653 KOps/s | 19.3066 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 58.0010μs | 32.1092μs | 31.1437 KOps/s | 30.4425 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 60.3000μs | 38.2437μs | 26.1481 KOps/s | 25.4646 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 85.7010μs | 22.9867μs | 43.5034 KOps/s | 41.6587 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 81.8010μs | 52.1988μs | 19.1575 KOps/s | 18.8277 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 57.1000μs | 33.5501μs | 29.8062 KOps/s | 29.0832 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 89.5010μs | 39.7300μs | 25.1699 KOps/s | 24.9614 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 48.5000μs | 24.2076μs | 41.3094 KOps/s | 40.3569 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 17.9546ms | 13.5865ms | 73.6025 Ops/s | 72.8320 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 55.7943ms | 47.1087ms | 21.2275 Ops/s | 22.0185 Ops/s | |
test_values[td0_return_estimate-False-False] | 1.4800ms | 0.5030ms | 1.9879 KOps/s | 2.4628 KOps/s | |
test_values[td1_return_estimate-False-False] | 13.5768ms | 12.8526ms | 77.8051 Ops/s | 77.4109 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 57.4910ms | 46.7409ms | 21.3945 Ops/s | 22.1987 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 32.7305ms | 31.1147ms | 32.1392 Ops/s | 31.9401 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 57.7460ms | 45.8122ms | 21.8283 Ops/s | 22.1009 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 11.4268ms | 11.3130ms | 88.3938 Ops/s | 87.1885 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 11.8742ms | 3.9360ms | 254.0669 Ops/s | 262.2090 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 2.2153ms | 0.5596ms | 1.7869 KOps/s | 1.7872 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 67.3170ms | 60.7516ms | 16.4605 Ops/s | 16.2907 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 12.0237ms | 3.3281ms | 300.4745 Ops/s | 280.8198 Ops/s | |
test_dqn_speed | 5.7761ms | 2.4214ms | 412.9878 Ops/s | 413.5649 Ops/s | |
test_ddpg_speed | 12.4854ms | 4.4146ms | 226.5200 Ops/s | 226.3251 Ops/s | |
test_sac_speed | 17.9294ms | 11.8382ms | 84.4720 Ops/s | 82.9609 Ops/s | |
test_redq_speed | 29.0052ms | 20.1602ms | 49.6027 Ops/s | 49.0197 Ops/s | |
test_redq_deprec_speed | 25.6553ms | 17.7443ms | 56.3560 Ops/s | 56.5359 Ops/s | |
test_td3_speed | 13.3291ms | 12.5128ms | 79.9183 Ops/s | 79.3599 Ops/s | |
test_cql_speed | 38.1271ms | 34.2003ms | 29.2395 Ops/s | 26.5101 Ops/s | |
test_a2c_speed | 10.8097ms | 7.4903ms | 133.5064 Ops/s | 126.9756 Ops/s | |
test_ppo_speed | 15.0940ms | 8.2247ms | 121.5854 Ops/s | 129.6301 Ops/s | |
test_reinforce_speed | 10.2680ms | 6.3614ms | 157.1978 Ops/s | 171.5532 Ops/s | |
test_iql_speed | 41.7187ms | 29.5321ms | 33.8615 Ops/s | 32.7840 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.8977ms | 2.6027ms | 384.2174 Ops/s | 377.4746 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 4.9115ms | 2.8116ms | 355.6651 Ops/s | 349.9659 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.6001ms | 2.8349ms | 352.7477 Ops/s | 350.2921 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.6175ms | 2.6648ms | 375.2649 Ops/s | 380.4885 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 5.6863ms | 2.8518ms | 350.6601 Ops/s | 359.9382 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.5353ms | 2.8291ms | 353.4714 Ops/s | 350.9852 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.6150ms | 2.7522ms | 363.3513 Ops/s | 382.8796 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.1656s | 3.3561ms | 297.9649 Ops/s | 356.8357 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.1565s | 3.3736ms | 296.4170 Ops/s | 350.8716 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.2399ms | 2.6401ms | 378.7740 Ops/s | 380.2489 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 4.5655ms | 2.8829ms | 346.8774 Ops/s | 355.4661 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 5.6847ms | 2.8860ms | 346.5031 Ops/s | 356.4826 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.4813ms | 2.6328ms | 379.8291 Ops/s | 381.9431 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 5.5208ms | 2.8810ms | 347.1064 Ops/s | 351.5149 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.1851s | 3.4173ms | 292.6320 Ops/s | 347.3600 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.4054ms | 2.6425ms | 378.4364 Ops/s | 378.8293 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 4.9741ms | 2.8455ms | 351.4339 Ops/s | 350.4280 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 4.5815ms | 2.8350ms | 352.7297 Ops/s | 303.7465 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.3014s | 31.4351ms | 31.8116 Ops/s | 31.4770 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 0.1638s | 31.3760ms | 31.8715 Ops/s | 35.4739 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 0.1568s | 27.9721ms | 35.7499 Ops/s | 32.4081 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1566s | 30.7057ms | 32.5672 Ops/s | 35.7892 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1582s | 27.9205ms | 35.8160 Ops/s | 32.4310 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.1670s | 30.9185ms | 32.3431 Ops/s | 32.3124 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1538s | 27.7947ms | 35.9781 Ops/s | 35.5205 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1562s | 30.6340ms | 32.6435 Ops/s | 32.2745 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1578s | 30.7295ms | 32.5421 Ops/s | 35.6273 Ops/s |
matteobettini
approved these changes
Oct 2, 2023
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
bug
Something isn't working
CLA Signed
This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
The init of DType casting is broken when the first access to the transform is made through reset. This PR fixes that.
This PR also uncouples
in_keys
andout_keys
in the transform base class. This is needed to avoid having to patch transforms with a logic that is, actually, not as generic as expected. From a user perspective nothing changes so this isn't bc-breaking.The checks on dtypes for the DTypeTransform are not a bit more stringent.