
[WIP]Add gradient checkpointing support for AutoencoderKLWan #11105

Open · victolee0 wants to merge 2 commits into base: main
Conversation

@victolee0 (Contributor)

What does this PR do?

Fixes #11071 (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@a-r-r-o-w

victolee0 changed the title from "Add gradient checkpointing support for AutoencoderKLWan" to "[WIP]Add gradient checkpointing support for AutoencoderKLWan" on Mar 18, 2025
@a-r-r-o-w (Member) left a comment

Thanks, very nice!

@a-r-r-o-w (Member)

@bot /style

@victolee0 (Contributor, Author) commented Mar 19, 2025

> Thanks, very nice!

My PR fails one of the tests; could you take a look?

Failing test: test_effective_gradient_checkpointing
Test file: test_models_autoencoder_wan.py

@quickdahuk

@victolee0 I implemented it the same way you did. It works fine for the forward pass, but during the backward pass, the cache index gets mixed up and goes out of bounds. I might need to use a dictionary for the cache mechanism instead of an index-based list.
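
One way to make the cache robust to the double execution that checkpointing introduces is to key it by a stable layer name instead of a running index. A minimal hypothetical sketch (not code from this PR; the class and names are illustrative):

import torch

class NamedFeatCache:
    # Cache keyed by layer name, so lookups do not depend on call order
    # and cannot go out of bounds when the forward is re-run in backward.
    def __init__(self):
        self._store = {}

    def get(self, name):
        return self._store.get(name)

    def put(self, name, tensor):
        self._store[name] = tensor

cache = NamedFeatCache()
cache.put("decoder.block0.conv1", torch.randn(1, 4, 2, 8, 8))
print(cache.get("decoder.block0.conv1").shape)  # torch.Size([1, 4, 2, 8, 8])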

@victolee0 (Contributor, Author) commented Mar 29, 2025

> @victolee0 I implemented it the same way you did. It works fine for the forward pass, but during the backward pass, the cache index gets mixed up and goes out of bounds. I might need to use a dictionary for the cache mechanism instead of an index-based list.

@quickdahuk
When I add exception handling as shown below, I get a different error.

Code

def forward(self, x, feat_cache=None, feat_idx=[0]):
    # Apply shortcut connection
    h = self.conv_shortcut(x)

    # First normalization and activation
    x = self.norm1(x)
    x = self.nonlinearity(x)

    # exception handling
    if feat_cache is not None and len(feat_cache) < feat_idx[0]:
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

        x = self.conv1(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
    else:
        x = self.conv1(x)

    # Second normalization and activation
    x = self.norm2(x)
    x = self.nonlinearity(x)

    # Dropout
    x = self.dropout(x)

    # exception handling
    if feat_cache is not None and len(feat_cache) < feat_idx[0]:
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

        x = self.conv2(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
    else:
        x = self.conv2(x)

    # Add residual connection
    return x + h

test case error

E           torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: A different number of tensors was saved during the original forward and recomputation.
E           Number of tensors saved during forward: 8
E           Number of tensors saved during recomputation: 6
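
This error usually means the checkpointed function did not behave identically when it was re-run during backward: because feat_idx and feat_cache are mutated in the first forward, the recomputation sees different state, takes different branches, and autograd ends up saving a different number of tensors. A minimal illustration of that failure mode, unrelated to the PR's actual code:

import torch
from torch.utils.checkpoint import checkpoint

state = {"calls": 0}

def fn(x):
    # The branch depends on external state mutated by the first forward,
    # so the recomputation during backward follows a different code path.
    y = x * x if state["calls"] == 0 else x + 1.0
    state["calls"] += 1
    return y

x = torch.randn(3, requires_grad=True)
out = checkpoint(fn, x, use_reentrant=False)
out.sum().backward()  # raises CheckpointError: tensor counts differ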

@quickdahuk

@victolee0 I've implemented gradient checkpointing for the decoder; I don't need it for the encoder right now. Training works fine for me. I applied checkpointing per frame rather than inside the decoder's internal operations. I have the code here.
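
The linked code isn't reproduced in this thread; below is a rough sketch of what per-frame checkpointing around a decoder can look like (decode_frame, the shapes, and the chunking are assumptions, not the commenter's actual implementation):

import torch
from torch.utils.checkpoint import checkpoint

def decode_with_checkpointing(decode_frame, z):
    # z: latent video of shape (batch, channels, frames, height, width).
    # Each temporal slice is decoded under checkpoint, so its activations
    # are recomputed during backward instead of being stored.
    frames = []
    for t in range(z.shape[2]):
        chunk = z[:, :, t : t + 1]
        frames.append(checkpoint(decode_frame, chunk, use_reentrant=False))
    return torch.cat(frames, dim=2)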

@victolee0 (Contributor, Author)

@quickdahuk
After copying in the code you provided, I still get an error on the same test. The error message:

a = tensor([-4.6806e-05, -2.9469e-05,  1.2341e-04]), b = tensor([-3.7164e-06, -4.1561e-06,  9.7381e-06]), args = (), kwargs = {'atol': 5e-05}

    def torch_all_close(a, b, *args, **kwargs):
        if not is_torch_available():
            raise ValueError("PyTorch needs to be installed to use this function.")
        if not torch.allclose(a, b, *args, **kwargs):
>           assert False, f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}."
E           AssertionError: Max diff is absolute 0.00011366714170435444. Diff tensor is tensor([4.3089e-05, 2.5313e-05, 1.1367e-04]).

src/diffusers/utils/testing_utils.py:111: AssertionError

@quickdahuk

@victolee0 The calculated gradients differ slightly (max diff <= 1.1367e-04). The difference may not be significant, but a much smaller one would be better. @a-r-r-o-w, do you see anything suspicious in our implementation?

@quickdahuk

@victolee0 I just found that with use_reentrant=False the gradients match perfectly. I've updated the code.
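
For reference, a self-contained sketch of the non-reentrant variant (the block and input here are placeholders):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# use_reentrant=False recomputes activations via saved-tensor hooks and is
# the variant the PyTorch docs now recommend over the legacy reentrant one.
block = nn.Sequential(nn.Linear(8, 8), nn.GELU())
x = torch.randn(2, 8, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()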

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w (Member)

The differing result is definitely suspicious. I would actually prefer a refactor of the VAE so that we don't have to work with cache indexing the way it's done here, and instead have it behave similarly to what's done in CogVideoX and Mochi. If that's not possible without indexing, we could consider removing the cache completely.

From my past benchmarks with the CogVideoX VAE, the speed difference between using the cache and not using it is minimal. Removing it, given how complicated it is to support and how little time it probably saves here, is a tradeoff we could make. cc @hlky @yiyixuxu Would you be able to take a look?
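
For context, the CogVideoX-style pattern mentioned above threads a dictionary of per-layer caches through the forward and returns an updated copy instead of relying on a shared index counter. A loose sketch (names and shapes are illustrative, not the actual diffusers code; a real causal conv would consume the prepended frames as temporal padding):

import torch
import torch.nn as nn

class CachedBlock(nn.Module):
    # Toy block that keeps the trailing frames of its input for the next chunk.
    def __init__(self, channels, cache_t=2):
        super().__init__()
        self.conv = nn.Conv3d(channels, channels, kernel_size=1)
        self.cache_t = cache_t

    def forward(self, x, conv_cache=None, name="block"):
        conv_cache = conv_cache or {}
        new_cache = dict(conv_cache)
        if conv_cache.get(name) is not None:
            # Prepend the frames cached from the previous chunk.
            x = torch.cat([conv_cache[name].to(x.device), x], dim=2)
        new_cache[name] = x[:, :, -self.cache_t:].clone()
        return self.conv(x), new_cache

block = CachedBlock(channels=4)
z = torch.randn(1, 4, 3, 8, 8)
out, cache = block(z)                     # first chunk: nothing cached yet
out, cache = block(z, conv_cache=cache)   # next chunk reuses cached frames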

Linked issue: AutoencoderKLWan - support gradient_checkpointing (#11071)