Context
Gradient norm clipping is a popular technique for stabilizing training, which requires computing the total norm of the model's gradients. This involves a norm reduction across all of the gradients down to a single scalar.
PyTorch's `clip_grad_norm_` offers both a single-tensor (`foreach=False`) and a multi-tensor (`foreach=True`) implementation. However, even the multi-tensor `foreach` implementation incurs high CPU overhead when computing the total norm over a large number of tensors.
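For reference, the public entry point is called roughly like this (a minimal usage sketch with a placeholder model and max-norm value, not the real workload); the relevant excerpt of its foreach path follows:

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024)  # placeholder model
model(torch.randn(8, 1024)).sum().backward()

# Clip in place and return the total gradient norm; foreach=True selects the
# multi-tensor path whose internals are excerpted below.
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
```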
pytorch/torch/nn/utils/clip_grad.py
Lines 81 to 96 in 3434a54
```python
norms: List[Tensor] = []
for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
    if (foreach is None and _has_foreach_support(device_grads, device)) or (
        foreach and _device_has_foreach_support(device)
    ):
        norms.extend(torch._foreach_norm(device_grads, norm_type))
    elif foreach:
        raise RuntimeError(
            f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
        )
    else:
        norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

total_norm = torch.linalg.vector_norm(
    torch.stack([norm.to(first_device) for norm in norms]), norm_type
)
```
The foreach implementation involves:
1. `torch._foreach_norm(tensors)` to return a list of 0D scalars, representing the norm of each gradient
2. `torch.stack` to stack the 0D scalars into a 1D tensor
3. `torch.linalg.vector_norm()` to compute the norm of the norms, representing the final total norm
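In isolation, that three-step sequence looks roughly like the following (dummy gradient tensors for illustration, not the excerpted helper code itself):

```python
import torch

# Dummy stand-ins for N per-parameter gradients.
grads = [torch.randn(1024, 1024) for _ in range(8)]
norm_type = 2.0

per_grad_norms = torch._foreach_norm(grads, norm_type)     # N 0D tensors
stacked = torch.stack(per_grad_norms)                      # one 1D tensor of length N
total_norm = torch.linalg.vector_norm(stacked, norm_type)  # final 0D total norm
```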
Issue
The foreach implementation incurs substantial unnecessary CPU overhead, making the clipping heavily CPU bound when operating on a large number of tensors. For example, with ~2000 tensors, the total norm calculation takes >18 ms on CPU while only 1.3 ms on GPU (for a particular real workload; larger tensors would make this slower).
Assuming `N` tensors, some inefficiencies in the existing implementation arise from:
- `_foreach_norm` must call `aten::empty({})` `N` times to construct the `N` 0D scalar outputs (code).
- The `N` 0D scalars need to be `stack`ed for the final norm reduction. `stack` requires 1D tensors, so each of the `N` 0D scalars gets unsqueezed again inside `stack`.
Together, this leads to an extra `2N` dispatcher calls to handle the `N` intermediate scalars. Ideally, we could avoid materializing these `N` intermediates, especially as `torch.Tensor`s; one option is a fused kernel.
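One way to see this CPU cost is a quick profiler run over just the total-norm computation. This is only a rough stand-in for the linked traces; the tensor count and shapes are illustrative:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Illustrative: ~2000 small gradients to make the reduction dispatcher-bound.
grads = [torch.randn(256, 256) for _ in range(2000)]
norm_type = 2.0

with profile(activities=[ProfilerActivity.CPU]) as prof:
    norms = torch._foreach_norm(grads, norm_type)
    total_norm = torch.linalg.vector_norm(torch.stack(norms), norm_type)

# Expect roughly N aten::empty calls for the 0D outputs plus the per-scalar
# handling inside stack to show up prominently in the CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```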
Today, `torch.compile` cannot address this issue in a satisfying way. Default `torch.compile` cannot achieve horizontal fusion, leading to slower performance than eager mode. `torch.compile(mode="reduce-overhead")` does reduce overhead more than eager but results in 2x memory usage, likely due to copying the gradients into CUDA graph addresses. Note that we likely cannot mark the inputs as static because gradients are computed anew every iteration at possibly different addresses, and for my use case, we cannot compile the gradient allocation with `torch.compile` (due to FSDP).
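For concreteness, the kind of compiled wrapper being compared is roughly the following sketch (the function name is illustrative, and the comments restate the observations above rather than guaranteed behavior):

```python
import torch

def total_norm(grads, norm_type=2.0):
    # Same three-step pattern as clip_grad_norm_'s foreach path.
    norms = torch._foreach_norm(grads, norm_type)
    return torch.linalg.vector_norm(torch.stack(norms), norm_type)

# Default compile: the per-tensor norms are not horizontally fused.
total_norm_default = torch.compile(total_norm)

# CUDA-graph-based mode: lower CPU overhead, but gradients are copied into
# static CUDA graph input addresses, roughly doubling memory for this workload.
total_norm_cudagraphs = torch.compile(total_norm, mode="reduce-overhead")
```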
Here is an example script for getting profiler traces of various implementations: P1529496302
The ask is to provide some native way to make this total norm calculation not CPU-overhead bound, e.g. via a fused op.
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @crcrpar @mcarilli @janeyx99 @ezyang @chauhang @penguinwu