-
Notifications
You must be signed in to change notification settings - Fork 306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Benchmark] Fix RB benchmarks #1760
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1760
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (7 Unrelated Failures)As of commit 9e2ee73 with merge base 6d217c6 (): FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Hi, just a suggestion. The Somehow pytest do not recognize the class name anymore after This is because _TensorDictPrioritizedReplayBuffer = partial(TensorDictPrioritizedReplayBuffer, alpha=1, beta=0.9)
# preserve the name of the class even after partial
_TensorDictPrioritizedReplayBuffer.__name__ = TensorDictPrioritizedReplayBuffer.__name__ and then change the |
@harnvo edited following your suggestion, would mean a lot if you could submit your review! |
@@ -17,6 +18,12 @@ | |||
) | |||
from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement | |||
|
|||
_TensorDictPrioritizedReplayBuffer = partial( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: functools.partial
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
populated=True, | ||
size=size, | ||
)() | ||
benchmark(sample, rb) | ||
|
||
|
||
def infinite_iter(obj): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems there exists a infinite loop here and the test_rb_iterate
could not end the 4th test. (Possibly due to the circular sampling in SamplerWithoutReplacement
)
The old iterate
works fine and I think you should change this back.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, I wanted to avoid the cost of entering __iter__
but it's still better than an infinite loop!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, the following would work:
def iterate(rb_iter):
try:
next(rb_iter)
except StopIteration:
pass
and inside test_rb_iterate
:
benchmark(iterate, iter(rb))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem is that I was call iter on iter, not solving the problem that the first iter was empty. I will push a fix in a sec
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Works in my PC
@harnvo it should run now! |
cc @harnvo
I changed the names so that we can have a fresh look at the benchmarks on the benchmark tool.
I will try to remove the old traces that are wrong (need to look at how I can modify the xml files on gh-pages)