
Option to let DistributedDataParallel know in advance unused parameters at each forward pass #90171

Open
netw0rkf10w opened this issue Dec 5, 2022 · 5 comments
Labels: enhancement, oncall: distributed

Comments

netw0rkf10w commented Dec 5, 2022

🚀 The feature, motivation and pitch

Motivation: In models with stochastic depth and the like, some layers (or parts of them) are skipped at each forward pass, so one needs to set find_unused_parameters=True, which in general makes training much slower. Yet these models can be implemented in such a way that the unused parameters at each step are known in advance (e.g., the layer sampling is done before the model's forward pass rather than on the fly). It would then be great if we could feed this information to the DDP model so that it doesn't need to find the unused parameters itself.

The usage could be something like the following:

model = DistributedDataParallel(model, find_unused_parameters=False)
for x in dataloader:
    # random_layer_sampling should return the same layers across GPUs
    layers_to_skip = random_layer_sampling(model)
    p = get_parameters(layers_to_skip)
    # unused_parameters is the proposed new argument
    output = model(x, unused_parameters=p)
    ...
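
For concreteness, here is a minimal sketch (my own illustration, with a slight variation of the hypothetical random_layer_sampling helper above) of a stochastic-depth-style model where the skipped blocks, and hence the unused parameters, are chosen before the forward pass:

import torch
from torch import nn


class StochasticDepthNet(nn.Module):
    """Toy model whose blocks can be skipped for an entire forward pass."""

    def __init__(self, num_blocks=4, dim=16):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))

    def forward(self, x, skip_indices=()):
        # skip_indices is decided *before* calling forward, so the unused
        # parameters (the skipped blocks) are known in advance.
        for i, block in enumerate(self.blocks):
            if i not in skip_indices:
                x = torch.relu(block(x))
        return x


def random_layer_sampling(num_blocks, p_skip=0.5, generator=None):
    # Must return the same indices on every rank, e.g. by seeding `generator`
    # identically across processes.
    return [i for i in range(num_blocks)
            if torch.rand(1, generator=generator).item() < p_skip]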

Alternatives

Currently there is no option other than setting find_unused_parameters=True.

Additional context

No response

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu

jbschlosser added the oncall: distributed and enhancement labels on Dec 6, 2022
aazzolini (Contributor) commented Dec 10, 2022

Thanks for the suggestion, @netw0rkf10w! This makes sense.
I wonder if there would be a way to implement it with existing DDP. Maybe you can force all parameters to be used by, e.g., implementing an autograd.Function that takes the parameters as inputs and always returns zeros for them in backward?

Something like this (it looks a bit ugly, but could be made a little prettier with a wrapper, etc.):
Basically, we pass the unused parameters to "UseParameters", which will create zero gradients for them at every iteration.


import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


class UseParameters(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, *params):
        """
        Simply returns the input, but save parameter shapes
        """
        ctx.shapes = [param.shape for param in params]
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """
        - Just forward the gradient of the input.
        - Return zeros for the gradient of the parameters.
        """
        return (grad_output, ) + tuple(torch.zeros(shape) for shape in ctx.shapes)


class MyModel(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.linear = nn.Linear(2,4)
        self.another = nn.Linear(2,4)

    def forward(self, x, cond):
        if cond:
            return UseParameters.apply(self.linear(x), *self.another.parameters())
        else:
            return UseParameters.apply(self.another(x), *self.linear.parameters())


if __name__ == '__main__':
    torch.distributed.init_process_group(backend='gloo')
    x = DDP(MyModel(), find_unused_parameters=False)
    x(torch.randn(3,2), True).sum().backward()

    x(torch.randn(3,2), False).sum().backward()

    x(torch.randn(3,2), True).sum().backward()
    print([p.grad for p in x.parameters()])

    x(torch.randn(3,2), False).sum().backward()
    print([p.grad for p in x.parameters()])
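
One caveat to note (my addition, not part of the original suggestion): torch.zeros(shape) allocates the zero gradients on the CPU with the default dtype, so for a CUDA model they may not match the parameters. A variant that also records dtype and device could look like this:

class UseParametersDeviceAware(torch.autograd.Function):
    """
    Variant of UseParameters that returns zero gradients matching each
    parameter's dtype and device (a sketch, not part of the original snippet).
    """

    @staticmethod
    def forward(ctx, input, *params):
        # Record what is needed to build matching zero gradients in backward.
        ctx.metas = [(p.shape, p.dtype, p.device) for p in params]
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return (grad_output,) + tuple(
            torch.zeros(shape, dtype=dtype, device=device)
            for shape, dtype, device in ctx.metas
        )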

netw0rkf10w commented Jan 27, 2023

@aazzolini I'm really sorry for my late response. Unfortunately I was unable to reply earlier :(

I have to say that what you have proposed is very clever! However, it seems to me that it's still suboptimal: GPU memory is still allocated for the gradients of the unused parameters, and gradient reduction still happens for them.

Let me summarize the different solutions below in terms of running time and memory footprint. Please correct me if I'm wrong, as I don't know the internals of DDP very well.

1. Optimal solution: DDP knows in advance which parameters (or layers) are unused

If DDP knows in advance which layers are unused at each forward pass, then it should be able to ignore them during both the forward and backward passes.

✅ Running time: good, no forward and no reduction for unused layers.
✅ Memory: good, no memory allocation for unused gradients.

2. First naive solution: Setting find_unused_parameters=True

❌ Running time: bad, due to the extra autograd-graph traversal at every iteration to find the unused parameters.
✅ Memory: good, no memory allocation for unused gradients.

3. Second naive solution: Multiplying by zero to ignore unused layers

class MyModel(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.linear = nn.Linear(2,4)
        self.another = nn.Linear(2,4)

    def forward(self, x, cond):
        w = 1 if cond else 0
        # Both branches are always computed; the unwanted one is zeroed out.
        return w * self.linear(x) + (1 - w) * self.another(x)

❌ Running time: bad (both branches are always computed), but in some cases still better than the previous solution.
❌ Memory: bad, since activations and gradients are allocated for both branches.

4. aazzolini's solution: Use a custom autograd function

✅ Running time: good, no forward pass for the unused layers (though gradient reduction still happens for them).
❌ Memory: bad, zero gradients are still allocated for the unused parameters (see the quick check below).
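
As a quick sanity check of the memory point for solutions 3 and 4 (a small sketch, not a full benchmark; MyModel here is the UseParameters-based model from the snippet above, run without DDP):

import torch  # and MyModel from aazzolini's snippet above

m = MyModel()
m(torch.randn(3, 2), True).sum().backward()  # the 'another' branch is skipped
print(m.another.weight.grad)                 # a tensor of zeros, not None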

Please let me know what you think. Thanks a lot!

@netw0rkf10w

cc as well @mrshenli, who used to answer my questions on DDP.

@netw0rkf10w

@aazzolini I finally had the time to implement and benchmark your solution. Unfortunately it's even slower than simply setting find_unused_parameters=True (solution 2 above) :(
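
A rough sketch of how such a comparison could be set up (not the exact benchmark script I ran; it reuses MyModel from the earlier snippet, adds an equivalent model without the UseParameters wrapper, and assumes a torchrun launch):

import time

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


class PlainModel(nn.Module):
    """Same branching as MyModel, but without the UseParameters wrapper,
    so DDP has to discover the unused branch on its own."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 4)
        self.another = nn.Linear(2, 4)

    def forward(self, x, cond):
        return self.linear(x) if cond else self.another(x)


def time_model(model, iters=500):
    x = torch.randn(3, 2)
    # Alternate the branch so both parameter sets are exercised.
    for i in range(10):  # warm-up
        model(x, i % 2 == 0).sum().backward()
    start = time.perf_counter()
    for i in range(iters):
        model(x, i % 2 == 0).sum().backward()
    return time.perf_counter() - start


if __name__ == '__main__':
    torch.distributed.init_process_group(backend='gloo')
    trick = DDP(MyModel(), find_unused_parameters=False)  # MyModel from the earlier snippet
    plain = DDP(PlainModel(), find_unused_parameters=True)
    print('UseParameters trick:        ', time_model(trick))
    print('find_unused_parameters=True:', time_model(plain))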

@netw0rkf10w

I've just realised that this was discussed in the PyTorch DDP paper by @mrshenli et al:

[Screenshot: the relevant passage from the PyTorch DDP paper]

I believe the solution I proposed above is much simpler than any of the ones proposed in the paper. What do you think, @mrshenli?
