optim.apply_optimizer_in_backward does not account for gradient accumulation #124523
Comments
We could do something like this (don't persecute me for the code, I just asked llama3 to solve this GitHub issue), which just creates a new tensor; the issue with this is that we are now increasing the memory footprint (a rough sketch of the idea is below). For AdamW this is still a net win, since we went from two copies of the parameters to one (for the accumulated gradient).
|
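The snippet from that comment is not reproduced here, but a minimal sketch of the "separate accumulation buffer" idea might look like the following. `attach_accumulating_optimizer` is a hypothetical helper, and it relies on `Tensor.register_post_accumulate_grad_hook` (available in recent PyTorch):

```python
import torch

def attach_accumulating_optimizer(params, optimizer_cls, accum_steps, **optim_kwargs):
    """Accumulate each parameter's gradient into a dedicated buffer and only
    step its optimizer once every `accum_steps` micro-batches."""
    for p in params:
        opt = optimizer_cls([p], **optim_kwargs)
        state = {"buf": torch.zeros_like(p), "count": 0}

        def hook(param, opt=opt, state=state):
            state["buf"] += param.grad          # the extra memory: one grad-sized buffer
            state["count"] += 1
            param.grad = None                   # free the per-micro-batch gradient early
            if state["count"] == accum_steps:   # last micro-batch: actually step
                param.grad = state["buf"] / accum_steps
                opt.step()
                opt.zero_grad()                 # sets param.grad back to None
                state["buf"].zero_()
                state["count"] = 0

        p.register_post_accumulate_grad_hook(hook)

# e.g.: attach_accumulating_optimizer(model.parameters(), torch.optim.AdamW,
#                                     accum_steps=8, lr=1e-3)
```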
@rohan-varma thoughts? |
I would suggest using the implementation from https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html, which should be more reliable.
Or keep a counter to detect when you reach the last micro-batch (sketched below). |
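A sketch of that counter idea on top of the tutorial's per-parameter-optimizer pattern; the `ACCUM_STEPS` constant and the toy model are placeholders:

```python
import torch

model = torch.nn.Linear(16, 16)   # placeholder model
ACCUM_STEPS = 8                   # micro-batches per logical batch (assumed known up front)

# Per-parameter optimizers, as in the optimizer-in-backward tutorial.
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}
step_count = {p: 0 for p in model.parameters()}

def optimizer_hook(param):
    step_count[param] += 1
    if step_count[param] < ACCUM_STEPS:
        # Not the last micro-batch yet: leave param.grad in place so the next
        # backward() keeps accumulating into it.
        return
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()
    step_count[param] = 0

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)
```

Note that leaving `param.grad` in place across micro-batches gives up the memory saving of freeing gradients early, which is part of why optimizer-in-backward and gradient accumulation are at odds.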
I think I asked @rohan-varma about the above implementation and he said it's not tested with FSDP, whereas the one from the dist api is? I agree that this implementation would make things easier, but do you know whether you can use this with FSDP? |
@rohan-varma any ideas on how to make this API compatible with gradient clipping? This API is actually really good: the implementation that Huggingface uses for this functionality is incredibly scuffed by comparison. I think all we would need to do is add some kind of conditional logging in fsdp_runtime_utils if we detect optimizers_in_backward? But messing around with dist.set_debug_level() has me kind of confused about how to use it. |
I was curious how logging is related to gradient clipping? My understanding is that applying optimizer in backward is algorithmically incompatible with gradient clipping, where there is a dependency on the total norm computed over all gradients before running the optimizer step. |
I would also be curious how we can tell if the user intends to use gradient accumulation or not. |
So for gradient clipping: for big models we're not going to do the all-gather anyway, and we can just apply per-parameter gradient clipping. That is fine and can already be implemented by the user with the provided API. However, I'm not sure how to get the gradient norm (a common metric) back out of the API; as far as I can tell, FSDP's runtime utils are doing the actual update step. |
@kiddyboots216 Could you explain what per-parameter gradient clipping is? Does it mean that the total norm is only computed per gradient rather than over all gradients of the model? I think that PyTorch is moving toward register_post_accumulate_grad_hook as the way to implement optimizer in backward. For that hook signature, there is no return value. If you want the return value, then maybe you need to write it to some global data structure. |
Yeah, so we iterate over each parameter in the model and clip its gradient to something. If my desired global gradient-clipping norm is 10.0, then I could just clip each parameter to 10.0/sqrt(len(parameters)) (sketched below).
Yeah, I did see that register_post_accumulate_grad_hook exists, but I saw that FSDP is explicitly looking for this attr:
https://github.com/pytorch/pytorch/blob/e3b9b71684ae4f81d6128854b12771f83064c5ce/torch/distributed/fsdp/_runtime_utils.py#L186
so I thought maybe we could do some kind of logging in that function? In particular, when I say "conditional logging" I mean something like an `if rank == 0` guard around the logging. But if the solution ends up being register_post_accumulate_grad_hook, then perhaps that's not the way to go.
|
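A minimal sketch of the per-parameter clipping described above, using a hypothetical `clip_grad_per_param_` helper that is assumed to be called on parameters whose gradients are already populated (e.g. from a gradient hook or just before the per-parameter optimizer steps):

```python
import math
import torch

def clip_grad_per_param_(params, max_norm: float = 10.0) -> None:
    """Clip each parameter's gradient individually to max_norm / sqrt(n),
    instead of one global clip over the total norm."""
    params = [p for p in params if p.grad is not None]
    if not params:
        return
    per_param_cap = max_norm / math.sqrt(len(params))
    for p in params:
        torch.nn.utils.clip_grad_norm_([p], per_param_cap)
```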
Assuming that you meant clipping each gradient instead of each parameter, then I think this makes sense to me now!
It might be hard to add that kind of logging directly into the PyTorch code since it may not be what all users want. I think the right direction is to move toward supporting register_post_accumulate_grad_hook. By the way, what are you looking to log? |
Yeah, that's reasonable. So how would that be supported in FSDP? I guess that function just has to append to the specific parameter attribute that FSDP is looking for?
Well, typically if we do optimizer.step() we would want to log the gradient norm (I think that's basically a universal use case). In this case, something like stacking the per-parameter grad norms and taking the norm of that stack would give us the overall grad norm. |
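A small helper version of that computation, with illustrative names; `params` is assumed to have `.grad` populated when it is called:

```python
import torch

def total_grad_norm(params) -> torch.Tensor:
    """L2 norm over all gradients, built from per-parameter norms."""
    per_param = [p.grad.detach().norm(2) for p in params if p.grad is not None]
    return torch.stack(per_param).norm(2)
```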
We are not actively developing FSDP1 anymore and have instead been focusing on FSDP2 (#114299). The current optimizer-in-backward support in FSDP1 is pretty much hard-coded to support optimizer-in-backward as is, which makes these kinds of customizations hard. As mentioned above, the ideal path is to support the post-accumulate-grad hook instead, and users can use that to implement optimizer in backward.
If we were to support the post-accumulate-grad hook, then you should be able to maintain your own online gradient-norm accumulator. |
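A rough sketch of what such an online accumulator could look like with `register_post_accumulate_grad_hook`; the `GradNormTracker` class and its names are made up for illustration, and the optimizer step itself is omitted:

```python
import torch

class GradNormTracker:
    """Accumulates squared per-parameter grad norms as hooks fire; once every
    parameter has reported, exposes the total norm for logging."""

    def __init__(self, num_params: int):
        self.num_params = num_params
        self.sq_sum = 0.0
        self.count = 0
        self.last_total_norm = None   # read this after backward() for logging

    def __call__(self, param):
        self.sq_sum += float(param.grad.detach().norm(2)) ** 2
        self.count += 1
        if self.count == self.num_params:        # every gradient has arrived
            self.last_total_norm = self.sq_sum ** 0.5
            self.sq_sum, self.count = 0.0, 0

model = torch.nn.Linear(4, 4)                    # placeholder model
tracker = GradNormTracker(num_params=sum(1 for _ in model.parameters()))
for p in model.parameters():
    p.register_post_accumulate_grad_hook(tracker)
```

In practice this would be combined with whatever hook performs the optimizer step, e.g. by logging `tracker.last_total_norm` (on rank 0 only) after each backward pass.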
cc: @janeyx99 |
Would agree with @awgu that using the hook would be the way to go for the future. That said, support for the hook together with FSDP2 is planned but not yet complete, so I'm not sure if there is a workaround for the time being. |
You can just write an Optimizer locally that does whatever you want it to do, define a local "apply_mask_in_backward" that instantiates that Optimizer and appends it to param._in_backward_optimizers, and pass model.parameters() to that function (rough sketch below). That's what I'm doing for the time being and it works fine. |
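A rough, assumption-heavy sketch of that workaround. The names `AccumulatingAdamW` and `apply_accumulating_optimizer_in_backward` are made up here, and it assumes whatever drives the update (FSDP's runtime or a hook you register yourself) calls `step()` on each optimizer in `param._in_backward_optimizers` once per micro-batch and may free the gradient afterwards:

```python
import torch
from typing import Iterable

class AccumulatingAdamW(torch.optim.AdamW):
    """AdamW that stashes grads internally and only really steps every
    `accum_steps` calls to .step()."""

    def __init__(self, params, accum_steps: int = 8, **kwargs):
        super().__init__(params, **kwargs)
        self.accum_steps = accum_steps
        self._calls = 0
        self._grad_buffers = {p: torch.zeros_like(p)
                              for group in self.param_groups
                              for p in group["params"]}

    def step(self, closure=None):
        self._calls += 1
        for p, buf in self._grad_buffers.items():
            if p.grad is not None:
                buf += p.grad                    # stash this micro-batch's grad
        if self._calls % self.accum_steps != 0:
            return None                          # not the last micro-batch yet
        for p, buf in self._grad_buffers.items():
            p.grad = buf / self.accum_steps      # average over micro-batches
        result = super().step(closure)
        for buf in self._grad_buffers.values():
            buf.zero_()
        return result

def apply_accumulating_optimizer_in_backward(
    params: Iterable[torch.nn.Parameter], accum_steps: int = 8, **kwargs
) -> None:
    for param in params:
        optim = AccumulatingAdamW([param], accum_steps=accum_steps, **kwargs)
        # The attribute FSDP's runtime checks for (see the _runtime_utils link above).
        if not hasattr(param, "_in_backward_optimizers"):
            param._in_backward_optimizers = []
        param._in_backward_optimizers.append(optim)
```

Outside FSDP you would also need to register a post-accumulate-grad hook that calls `optim.step()` on each backward, which is what the stock apply_optimizer_in_backward does.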
🐛 Describe the bug
https://github.com/pytorch/pytorch/blob/main/torch/distributed/optim/apply_optimizer_in_backward.py
This fires the optimizer every time a gradient gets accumulated. However, in practice we might do gradient accumulation: if we have a logical batch size of 64 and a physical batch size of 8, apply_optimizer_in_backward will not distinguish between the micro-batches and will step once per micro-batch.
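A toy, single-process illustration of the behavior (no FSDP), using `_apply_optimizer_in_backward` as defined in the file linked above; the model and batch sizes are placeholders:

```python
import torch
from torch.distributed.optim.apply_optimizer_in_backward import (
    _apply_optimizer_in_backward,
)

# Logical batch of 64 split into 8 micro-batches of 8: every backward() below
# triggers a parameter update, so we get 8 optimizer steps instead of the
# single step that gradient accumulation expects.
model = torch.nn.Linear(16, 1)
_apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 1e-2})

logical_batch = torch.randn(64, 16)
for micro_batch in logical_batch.split(8):
    model(micro_batch).sum().backward()   # optimizer fires here, once per micro-batch
```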
My suggested fix: just raise a warning for the user if we detect that gradient accumulation is being used with apply_optimizer_in_backward. I don't think it's possible to actually make this function work with gradient accumulation.
Versions
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.2
[pip3] triton==2.2.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.2.2 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar