[WIP]Add gradient checkpointing support for AutoencoderKLWan #11105
base: main
Conversation
Thanks, very nice!
@bot /style
My PR failed the test suite. Could you take a look at this link? Failing test: test_effective_gradient_checkpointing
@victolee0 I implemented it the same way you did. It works fine for the forward pass, but during the backward pass the cache index gets mixed up and goes out of bounds. I might need to use a dictionary for the cache mechanism instead of an index-based list.
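To illustrate the idea being discussed: with gradient checkpointing, the forward function is re-executed during the backward pass, so a cache driven by a shared running index advances twice and can go out of bounds, while a cache keyed by a stable name does not depend on call order. The sketch below is a hypothetical minimal module (not the actual AutoencoderKLWan code); the name `CachedConv3d` and its structure are illustrative assumptions.

```python
import torch
import torch.nn as nn

class CachedConv3d(nn.Module):
    """Hypothetical sketch: a dict-keyed temporal feature cache instead of
    an index-based list, so re-running the forward pass (as gradient
    checkpointing does) cannot desynchronize a shared cursor."""

    def __init__(self, name, in_ch, out_ch):
        super().__init__()
        self.name = name  # stable key, not a running index
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x, cache):
        prev = cache.get(self.name)  # None on the first chunk
        if prev is not None:
            # prepend the cached frame along the temporal dimension
            x = torch.cat([prev, x], dim=2)
        # keep only the last frame for the next chunk
        cache[self.name] = x[:, :, -1:].detach()
        return self.conv(x)
```

Because lookups are by `self.name`, running the same forward twice (forward pass plus checkpoint recomputation) reads and overwrites the same slot instead of walking past the end of a list.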
@quickdahuk (code and test-case error output were attached as collapsed sections in the original comment)
@victolee0 I've implemented gradient checkpointing for the decoder; I don't need it for the encoder right now. Training works fine for me. I applied checkpointing per frame rather than inside the decoder's internal operations.
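The per-frame approach described above can be sketched as follows. This is a minimal illustration, not the PR's implementation: `FrameDecoder` and `decode_frame` are hypothetical names, and the network is a stand-in single convolution. The checkpoint wraps each frame's decode step, so only one frame's activations are recomputed at a time during backward.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class FrameDecoder(nn.Module):
    """Hypothetical sketch: checkpoint each frame's decode, rather than
    the blocks inside the decoder."""

    def __init__(self, channels=4):
        super().__init__()
        self.net = nn.Conv2d(channels, channels, 3, padding=1)

    def decode_frame(self, z_t):
        return self.net(z_t)

    def forward(self, z, use_checkpoint=True):
        # z: (batch, channels, frames, height, width)
        frames = []
        for t in range(z.shape[2]):
            z_t = z[:, :, t]
            if use_checkpoint and self.training:
                # activations are discarded and recomputed in backward
                x_t = checkpoint(self.decode_frame, z_t, use_reentrant=False)
            else:
                x_t = self.decode_frame(z_t)
            frames.append(x_t)
        return torch.stack(frames, dim=2)
```

Keeping the checkpoint boundary at the frame level sidesteps the per-block cache-indexing problem discussed earlier, since each checkpointed call is self-contained.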
@quickdahuk
@victolee0 The calculated gradients differ slightly (max difference <= 1.1367e-04). This difference may not be significant, but a much lower difference would be better. @a-r-r-o-w, do you see anything suspicious in our implementation?
@victolee0 I just found that if I use `use_reentrant=False`, the gradients match perfectly. I've updated the code.
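The kind of check the failing test performs can be reproduced in miniature: run the same model with and without checkpointing from identical weights and inputs, and compare the parameter gradients. This is a rough sketch under stated assumptions, not the diffusers test itself; the toy model is illustrative. With `use_reentrant=False`, the recomputed forward is the ordinary non-reentrant autograd path, so gradients should match closely.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
x = torch.randn(2, 8)

def grads(use_ckpt):
    # clear grads, run forward (optionally checkpointed), backward
    model.zero_grad()
    out = checkpoint(model, x, use_reentrant=False) if use_ckpt else model(x)
    out.sum().backward()
    return [p.grad.clone() for p in model.parameters()]

g_plain = grads(False)
g_ckpt = grads(True)
max_diff = max((a - b).abs().max().item() for a, b in zip(g_plain, g_ckpt))
```

A max difference on the order of 1e-04, as reported above, suggests the checkpointed path is recomputing something differently (e.g. stale cache state) rather than ordinary floating-point noise.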
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
The differing result is definitely suspicious. I would actually prefer a refactor of the VAE so that we don't have to work with cache indexing the way it's done here, and instead have it behave similarly to what's done in CogVideoX and Mochi. If that's not possible without indexing, we could consider removing the cache completely. In my past benchmarks with the CogVideoX VAE, the speed difference from using the cache versus not was minimal; given that the cache is complicated to maintain and probably doesn't save much time here, removing it is a tradeoff we could make. cc @hlky @yiyixuxu Would you be able to take a look?
What does this PR do?
Fixes #11071 (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@a-r-r-o-w