[Algorithm] Online Decision transformer #1149
Conversation
class ModifiedGPT2Model(GPT2Model):
Wrapper class to remove the wpe layer from the transformers GPT2Model. Maybe we can compress this even more?
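For readers without the diff at hand, here is a minimal sketch of the idea this comment describes, assuming the standard transformers GPT2Model; it is not the exact code from this PR. The wrapper zeroes and freezes the wpe weights so positions contribute nothing, since the Decision Transformer adds its own timestep embedding.

```python
import torch.nn as nn
from transformers import GPT2Config, GPT2Model


class ModifiedGPT2Model(GPT2Model):
    """GPT2Model with the learned positional embedding (wpe) disabled."""

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        # wpe is nn.Embedding(config.max_position_embeddings, config.n_embd).
        # Zero and freeze its weights so position embeddings contribute nothing.
        nn.init.zeros_(self.wpe.weight)
        self.wpe.weight.requires_grad = False
```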
- This should run even if transformers isn't installed.
- Do we have dedicated tests?
- Is it integrated in the doc?
- The docstring is a bit cryptic for someone who doesn't know what it is all about.
I wish transformers had more modular code... What is the signature of wpe? In some cases we can simply replace the layer with nn.Identity()...
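To answer the signature question concretely, a quick sketch with standard GPT-2 defaults (1024 positions, 768-dim embeddings): wpe is an nn.Embedding that maps integer position ids of shape [batch, seq] to float embeddings of shape [batch, seq, n_embd], so nn.Identity() changes the output shape instead of passing embeddings through.

```python
import torch
import torch.nn as nn

wpe = nn.Embedding(1024, 768)  # GPT-2 defaults: n_positions=1024, n_embd=768

position_ids = torch.arange(20).unsqueeze(0)  # shape [1, 20], dtype long
print(wpe(position_ids).shape)                # torch.Size([1, 20, 768])

# nn.Identity() returns the integer ids unchanged ([1, 20]), which cannot
# be added to the [1, 20, 768] token embeddings; this matches the shape
# issues reported below.
print(nn.Identity()(position_ids).shape)      # torch.Size([1, 20])
```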
I tried using nn.Identity() but ran into shape issues. However, with all the fixes I made, training now converges even with the wpe layer in place. For comparison, I also ran a test where I replaced the wpe layer with a custom ZeroPosEmbeddingLayer that returns only zeros. In the graph, you can see the runs with wpe and with zero wpe.
Let me know what you think. For now, I left the ZeroPosEmbeddingLayer out since training converges without it, but I can add it back if needed.
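Since that layer was taken out of the PR, here is a sketch of what such a ZeroPosEmbeddingLayer might look like, reconstructed from the description above rather than taken from the actual code:

```python
import torch
import torch.nn as nn


class ZeroPosEmbeddingLayer(nn.Module):
    """Drop-in replacement for wpe that always returns zeros."""

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        # Same output shape as nn.Embedding: [*position_ids.shape, embedding_dim].
        return torch.zeros(
            *position_ids.shape, self.embedding_dim, device=position_ids.device
        )
```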
Oh, but at this point let's get rid of that class altogether, no?
Yes, I already removed it all. If you can have a final look, I think it should be ready now.
LGTM! A couple of last edits and we can ship this!! 🚀💪🏻
A million thanks for this feature @BY571!
Co-authored-by: vmoens <vincentmoens@gmail.com>
Co-authored-by: Mateusz Guzek <matguzek@meta.com>
Description
Implements the Online Decision Transformer paper (Zheng et al., 2022).
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax `close #15213` if this solves the issue #15213.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!