
optim.apply_optimizer_in_backward does not account for gradient accumulation #124523

Open
kiddyboots216 opened this issue Apr 19, 2024 · 16 comments
Labels
module: optimizer (Related to torch.optim), oncall: distributed (Add this issue/PR to distributed oncall triage queue), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@kiddyboots216
Contributor

kiddyboots216 commented Apr 19, 2024

🐛 Describe the bug

https://github.com/pytorch/pytorch/blob/main/torch/distributed/optim/apply_optimizer_in_backward.py

This fires the optimizer every time a parameter's gradient is accumulated into its .grad field. However, in practice we might do gradient accumulation across microbatches. So if we have a logical batch size of 64 and a physical batch size of 8, apply_optimizer_in_backward will not distinguish between the microbatches: it steps the optimizer on every microbatch instead of once per logical batch.
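For illustration, a minimal single-process sketch of the behavior being described (the model and batch sizes are made up; the entry point is the private _apply_optimizer_in_backward from the linked file, so treat the exact import and signature as an assumption):

import torch
from torch.distributed.optim import _apply_optimizer_in_backward

model = torch.nn.Linear(16, 4)
_apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=model.parameters(),
    optimizer_kwargs={"lr": 0.1},
)

# A logical batch of 64 split into 8 microbatches of 8: every backward() below
# fires the per-parameter optimizer hooks, so we get 8 optimizer steps per
# logical batch instead of the single step that gradient accumulation intends.
for _ in range(8):
    loss = model(torch.randn(8, 16)).sum()
    loss.backward()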

My suggested fix: just raise a warning for the user if we detect that gradient accumulation is being used with apply_optimizer_in_backward. I don't think it's possible to actually make this function work with gradient accumulation.

Versions

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.2
[pip3] triton==2.2.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.2.2 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar

@kiddyboots216
Copy link
Contributor Author

We could do something like this (don't judge the code too harshly, I just asked llama3 to solve this GitHub issue), which creates a new accumulation tensor per parameter. The issue is that we are now increasing the memory footprint. For AdamW this is still a net win, as we went from two extra copies of the parameters down to one (for the accumulated gradient).

def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter, accumulation_steps: int) -> None:
    # view_as creates a node in the autograd graph that gives us access to the
    # parameter's AccumulateGrad autograd function object. We register a
    # hook on this object to fire the optimizer when the gradient for
    # this parameter is ready (has been accumulated into the .grad field).

    # Don't create a new acc_grad if we already have one,
    # i.e. for shared parameters or attaching multiple optimizers to a param.
    if param not in param_to_acc_grad_map:
        param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0]

    optimizer = optimizer_class([param], **optimizer_kwargs)

    if not hasattr(param, "_in_backward_optimizers"):
        param._in_backward_optimizers = []  # type: ignore[attr-defined]
        # TODO: Remove these attributes once we have a better way of accessing
        # optimizer classes and kwargs for a parameter.
        param._optimizer_classes = []  # type: ignore[attr-defined]
        param._optimizer_kwargs = []  # type: ignore[attr-defined]

    param._in_backward_optimizers.append(optimizer)  # type: ignore[attr-defined]
    param._optimizer_classes.append(optimizer_class)  # type: ignore[attr-defined]
    param._optimizer_kwargs.append(optimizer_kwargs)  # type: ignore[attr-defined]

    if not register_hook:
        return

    # Buffer that accumulates microbatch gradients. Note that .grad is still
    # None at registration time, so size the buffer from the parameter itself.
    grad_accumulator = torch.zeros_like(param)
    microbatch_count = 0

    def optimizer_hook(*_unused) -> None:
        nonlocal grad_accumulator, microbatch_count
        grad_accumulator += param.grad
        param.grad = None
        microbatch_count += 1

        # Only step once all microbatches of the logical batch have been seen.
        if microbatch_count >= accumulation_steps:
            param.grad = grad_accumulator
            for opt in param._in_backward_optimizers:  # type: ignore[attr-defined]
                opt.step()
            param.grad = None
            grad_accumulator = torch.zeros_like(param)
            microbatch_count = 0

    handle = param_to_acc_grad_map[param].register_hook(optimizer_hook)  # type: ignore[attr-defined]
    if param not in param_to_optim_hook_handle_map:
        param_to_optim_hook_handle_map[param] = []
    param_to_optim_hook_handle_map[param].append(handle)
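If something like this were adopted, the public entry point would presumably need to grow an accumulation_steps argument as well (hypothetical signature, shown only for illustration):

# Hypothetical: thread accumulation_steps through the existing entry point.
_apply_optimizer_in_backward(
    optimizer_class=torch.optim.AdamW,
    params=model.parameters(),
    optimizer_kwargs={"lr": 1e-3},
    accumulation_steps=8,  # not supported today; illustration only
)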

@kiddyboots216
Contributor Author

@rohan-varma thoughts?

msaroufim added the module: optimizer and triaged labels on Apr 20, 2024
@albanD
Collaborator

albanD commented Apr 22, 2024

I would suggest using the implementation from https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html, which will be more reliable.
It should also make it significantly easier to have something like this in the hook:

def optimizer_hook(parameter) -> None:
  if update_step:
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

Or have a counter to detect when you reach the last microbatch.
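A counter-based variant of the tutorial's hook might look like this (a sketch: optimizer_dict and model are as in the tutorial, and accumulation_steps is a hypothetical constant):

import torch

accumulation_steps = 8  # hypothetical: microbatches per logical batch
step_counts = {}        # per-parameter microbatch counters

def optimizer_hook(parameter: torch.Tensor) -> None:
    # Runs after this parameter's gradient has been accumulated for the
    # current backward pass; .grad keeps accumulating across microbatches
    # because we only zero it when we actually step.
    step_counts[parameter] = step_counts.get(parameter, 0) + 1
    if step_counts[parameter] == accumulation_steps:
        optimizer_dict[parameter].step()
        optimizer_dict[parameter].zero_grad()
        step_counts[parameter] = 0

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)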

albanD added the oncall: distributed label on Apr 22, 2024
@kiddyboots216
Contributor Author

I think I asked @rohan-varma about the above implementation and he said it's not tested with FSDP, whereas the one from the dist API is? I agree that this implementation would make things easier, but do you know whether it can be used with FSDP?

@kiddyboots216
Contributor Author

@rohan-varma any ideas on how to make this API compatible with gradient clipping? This API is actually really good; the implementation that Hugging Face uses for this functionality is incredibly scuffed by comparison. I think all we would need to do is add some kind of conditional logging in fsdp_runtime_utils if we detect optimizers_in_backward? But messing around with dist.set_debug_level() has me kind of confused about how to use this.

@awgu
Contributor

awgu commented Apr 29, 2024

@rohan-varma any ideas on how to make this API compatible with gradient clipping? This API is actually really good; the implementation that Hugging Face uses for this functionality is incredibly scuffed by comparison. I think all we would need to do is add some kind of conditional logging in fsdp_runtime_utils if we detect optimizers_in_backward? But messing around with dist.set_debug_level() has me kind of confused about how to use this.

I was curious how logging is related to gradient clipping?

My understanding is that applying optimizer in backward is algorithmically incompatible with gradient clipping, where there is a dependency on the total norm computed over all gradients before running the optimizer step.
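For concreteness, global-norm clipping (what torch.nn.utils.clip_grad_norm_ implements) rescales every gradient $g_i$ by $\min(1, c / \lVert g \rVert_2)$ with $\lVert g \rVert_2 = \sqrt{\sum_i \lVert g_i \rVert_2^2}$, where $c$ is the max norm, so no single parameter's step can run until all gradients of the backward pass exist.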

@awgu
Contributor

awgu commented Apr 29, 2024

My suggested fix: just raise a warning for the user if we detect that gradient accumulation is being used with apply_optimizer_in_backward. I don't think it's possible to actually make this function work with gradient accumulation.

I would also be curious how we can tell if the user intends to use gradient accumulation or not.

@kiddyboots216
Contributor Author

@rohan-varma any ideas on how to make this API compatible with gradient clipping? This API is actually really good; the implementation that Hugging Face uses for this functionality is incredibly scuffed by comparison. I think all we would need to do is add some kind of conditional logging in fsdp_runtime_utils if we detect optimizers_in_backward? But messing around with dist.set_debug_level() has me kind of confused about how to use this.

I was curious how logging is related to gradient clipping?

My understanding is that applying optimizer in backward is algorithmically incompatible with gradient clipping, where there is a dependency on the total norm computed over all gradients before running the optimizer step.

So for gradient clipping: for big models we're not going to do the all-gather anyway, and we can just apply per-parameter gradient clipping. This is fine and can already be implemented by the user with the provided API. However, I'm not sure how to get the gradient norm (a common metric) back out of the API. As far as I can tell, FSDP's runtime utils are doing the actual update step.
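Per-parameter clipping in this setting could look roughly like the following (a sketch built on register_post_accumulate_grad_hook; optimizer_dict, model, and the way the global budget is split across parameters are assumptions, not an existing API):

import math
import torch

# Hypothetical: split a desired global max-norm budget evenly across parameters.
global_max_norm = 10.0
params = list(model.parameters())
per_param_max_norm = global_max_norm / math.sqrt(len(params))

def clip_then_step_hook(parameter: torch.Tensor) -> None:
    # Clip this parameter's gradient in isolation, then step its own optimizer.
    torch.nn.utils.clip_grad_norm_([parameter], per_param_max_norm)
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

for p in params:
    p.register_post_accumulate_grad_hook(clip_then_step_hook)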

@awgu
Contributor

awgu commented Apr 29, 2024

@kiddyboots216 Could you explain what per-parameter gradient clipping is? Does it mean that the total norm is only computed per gradient rather than over all gradients of the model?

I think that PyTorch is moving toward register_post_accumulate_grad_hook as the way to implement optimizer in backward. For that hook signature, there is no return value. If you want the return value, then maybe you need to write it to some global data structure.

@kiddyboots216
Contributor Author

kiddyboots216 commented Apr 29, 2024 via email

@awgu
Contributor

awgu commented Apr 29, 2024

Yeah so we iterate over each parameter in the model and clip that to
something. If my desired global gradient clipping norm is 10.0 then I could
just clip each parameter to 10.0/sqrt(len(parameters)).

Assuming that you meant clipping each gradient instead of each parameter, then I think this makes sense to me now!

so I thought maybe in this function we can do some kind of logging there? In particular, when I say "conditional logging" I mean something like: if rank == 0, then do some logging.

It might be hard to add that kind of logging directly into the PyTorch code since it may not be what all users want. I think the right direction is to move toward supporting register_post_accumulate_grad_hook so that the user can customize as desired.

By the way, what are you looking to log?

@kiddyboots216
Contributor Author

It might be hard to add that kind of logging directly into the PyTorch code since it may not be what all users want. I think the right direction is to move toward supporting register_post_accumulate_grad_hook so that the user can customize as desired.

Yeah, that's reasonable. So how would that be supported in FSDP? I guess that function just has to append to the specific parameter attribute that FSDP is looking for?

By the way, what are you looking to log?

Well, typically if we do optimizer.step() we would want to log the gradient norm (I think that's basically a universal use case). In this case, something like torch.stack([g.norm(2) for g in param_grads]).norm(2) would be our overall grad norm.

@awgu
Contributor

awgu commented Apr 29, 2024

So how would that be supported in FSDP? I guess that function just has to append to the specific parameter attribute that FSDP is looking for?

We are not actively developing this FSDP (i.e., FSDP1) anymore and have instead been focusing on FSDP2 (#114299). The current optimizer-in-backward support in FSDP1 is mostly hard-coded to support optimizer-in-backward as-is, which makes these kinds of customizations hard. As mentioned above, the ideal path is to support the post-accumulate-grad hook instead, and users can use that to implement optimizer in backward.

In this case, something like torch.stack([g.norm(2) for g in param_grads]).norm(2) would be our overall grad norm.

If we were to support the post-accumulate-grad hook, then you should be able to maintain your own online $\ell_2$-norm squared as you process gradients through backward and then compute a square root at the end of backward to get your overall gradient norm.
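A sketch of that running-norm idea (assuming the hooks are registered with register_post_accumulate_grad_hook and optimizer_dict is the per-parameter optimizer map from the tutorial):

import torch

grad_norm_sq = 0.0  # running sum of squared per-parameter gradient norms

def optimizer_hook(parameter: torch.Tensor) -> None:
    global grad_norm_sq
    # Record this parameter's contribution before stepping, since the step
    # and zero_grad() below consume the gradient.
    grad_norm_sq += parameter.grad.detach().norm(2).item() ** 2
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

# In the training loop, after loss.backward():
#     total_grad_norm = grad_norm_sq ** 0.5
#     grad_norm_sq = 0.0

Under FSDP the hook would only see the local shard's gradient, so the squared norms would additionally need to be all-reduced across ranks before taking the square root.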

@awgu
Contributor

awgu commented Apr 29, 2024

cc: @janeyx99

@janeyx99
Contributor

I would agree with @awgu that using the hook would be the way to go for the future. That said, support for the hook together with FSDP2 is planned and not yet complete, so I'm not sure if there is a workaround for the time being.

@kiddyboots216
Contributor Author

You can just make an Optimizer locally that does whatever you want it to do, and define a local "apply_mask_in_backward" that instantiates that Optimizer and appends it to param._in_backward_optimizers, and pass model.parameters() to this function. That's what I'm doing for the time being and it works fine.
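That workaround might look roughly like this (a sketch of what the comment describes; MaskedSGD, apply_mask_in_backward, the model variable, and the reliance on FSDP1's private param._in_backward_optimizers attribute are all assumptions, not a supported API):

import torch


class MaskedSGD(torch.optim.SGD):
    # Hypothetical local optimizer; put whatever per-parameter logic you want in step().

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    # e.g. per-parameter masking/clipping before the update
                    p.grad.clamp_(-1.0, 1.0)
        return super().step(closure)


def apply_mask_in_backward(params, **optimizer_kwargs):
    # Mirror what _apply_optimizer_in_backward does: attach one optimizer per
    # parameter on the private attribute that FSDP1's runtime looks for.
    for param in params:
        if not hasattr(param, "_in_backward_optimizers"):
            param._in_backward_optimizers = []
        param._in_backward_optimizers.append(MaskedSGD([param], **optimizer_kwargs))


apply_mask_in_backward(model.parameters(), lr=0.1)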
