
No way for low-overhead total norm in native PyTorch with large number of tensors #133586

@awgu

Description


Context
Gradient norm clipping is a popular technique for stabilizing training; it requires computing the total norm of the model's gradients, i.e. a norm reduction over all of the gradients down to a single scalar.
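For reference, gradient clipping is typically applied between the backward pass and the optimizer step. A minimal sketch (the model, optimizer, and data below are illustrative placeholders):

```python
import torch
import torch.nn as nn

# Illustrative placeholders for a real model and optimizer.
model = nn.Linear(1024, 1024)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(32, 1024)
loss = model(x).sum()
loss.backward()

# Computes the total norm over all gradients and rescales them in place
# if the total norm exceeds max_norm; returns the (pre-clip) total norm.
# foreach=True requests the multi-tensor implementation.
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)

optimizer.step()
optimizer.zero_grad()
```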

PyTorch's clip_grad_norm_ offers both single-tensor (foreach=False) and multi-tensor (foreach=True) implementations. However, even the multi-tensor foreach implementation incurs high CPU overhead when computing the total norm over a large number of tensors.

```python
norms: List[Tensor] = []
for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
    if (foreach is None and _has_foreach_support(device_grads, device)) or (
        foreach and _device_has_foreach_support(device)
    ):
        norms.extend(torch._foreach_norm(device_grads, norm_type))
    elif foreach:
        raise RuntimeError(
            f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
        )
    else:
        norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

total_norm = torch.linalg.vector_norm(
    torch.stack([norm.to(first_device) for norm in norms]), norm_type
)
```

The foreach implementation involves:

  1. torch._foreach_norm(tensors) to return a list of 0D scalar tensors, one representing the norm of each gradient
  2. torch.stack to concatenate the 0D scalars into a single 1D tensor
  3. torch.linalg.vector_norm() to compute the norm of the per-tensor norms, giving the final total norm (a standalone sketch of these steps follows this list)
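Outside of clip_grad_norm_'s device grouping, these three steps look roughly like the following standalone sketch (grads stands in for one device's gradient list; sizes are illustrative):

```python
import torch

# Stand-in for a model's gradients on a single device.
grads = [torch.randn(256, 256) for _ in range(8)]
norm_type = 2.0

# 1. One 0D norm tensor per gradient.
per_tensor_norms = torch._foreach_norm(grads, norm_type)

# 2. Stack the N 0D scalars into a single 1D tensor of length N.
stacked = torch.stack(per_tensor_norms)

# 3. Norm of the per-tensor norms gives the final total norm.
total_norm = torch.linalg.vector_norm(stacked, norm_type)
```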

Issue
The foreach implementation incurs significant unnecessary CPU overhead, making the clipping heavily CPU-bound when operating on a large number of tensors. For example, for ~2000 tensors, the total norm calculation takes >18 ms on CPU but only 1.3 ms on GPU (for a particular real workload; larger tensors would make this slower).
[Screenshot: profiler trace of the total norm computation, captured 2024-08-15]

Assuming N tensors, some inefficiencies in the existing implementation arise from:

  • _foreach_norm must call aten::empty({}) N times to construct the N 0D scalar outputs (code).
  • The N 0D scalars then need to be stacked for the final norm reduction. stack requires 1D tensors, so each of the N 0D scalars gets unsqueezed again inside stack.

Together, this leads to an extra 2N dispatcher calls just to handle the N intermediate scalars (the tiny sketch below illustrates those 0D intermediates). Ideally, we would avoid materializing these N intermediates at all, especially as torch.Tensors; one option is a fused kernel.
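To make the intermediates concrete, here is a tiny sketch showing that the per-tensor norms are 0D tensors that only become a 1D tensor via stack (shapes and counts are illustrative):

```python
import torch

# Each per-tensor norm is a 0D scalar tensor...
scalar_norms = [torch.linalg.vector_norm(torch.randn(16)) for _ in range(4)]
print(scalar_norms[0].shape)  # torch.Size([])

# ...and stack builds the 1D result by (effectively) unsqueezing each one.
stacked = torch.stack(scalar_norms)
print(stacked.shape)          # torch.Size([4])
```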

Today, torch.compile cannot address this issue in a satisfying way. Default torch.compile cannot achieve horizontal fusion here, leading to slower performance than eager mode. torch.compile(mode="reduce-overhead") does reduce overhead below eager but results in 2x memory usage, likely due to copying the gradients into CUDA graph addresses. Note that we likely cannot mark the inputs as static because the gradients are computed anew every iteration at possibly different addresses, and for my use case we cannot compile the gradient allocation with torch.compile (due to FSDP).
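For context, the torch.compile variants referenced above were along these lines (a sketch; total_norm_fn is an illustrative stand-in for the total-norm computation):

```python
import torch

def total_norm_fn(grads, norm_type=2.0):
    norms = torch._foreach_norm(grads, norm_type)
    return torch.linalg.vector_norm(torch.stack(norms), norm_type)

# Default mode: did not achieve horizontal fusion in my runs and was slower
# than eager.
compiled_default = torch.compile(total_norm_fn)

# CUDA-graph-based mode: lower CPU overhead than eager, but ~2x memory,
# likely from copying gradients into static CUDA graph addresses.
compiled_reduce_overhead = torch.compile(total_norm_fn, mode="reduce-overhead")
```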

Here is an example script for getting profiler traces of various implementations: P1529496302
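(That paste is internal; a rough sketch of such a script, using torch.profiler to trace the foreach implementation on ~2000 tensors, might look like the following. Tensor sizes and counts are illustrative.)

```python
import torch
from torch.profiler import profile, ProfilerActivity

def total_norm_foreach(grads, norm_type=2.0):
    norms = torch._foreach_norm(grads, norm_type)
    return torch.linalg.vector_norm(torch.stack(norms), norm_type)

# ~2000 gradient-sized tensors to make the CPU overhead visible.
grads = [torch.randn(256, 256, device="cuda") for _ in range(2000)]

# Warm up, then trace one iteration.
for _ in range(3):
    total_norm_foreach(grads)
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    total_norm_foreach(grads)
    torch.cuda.synchronize()

prof.export_chrome_trace("total_norm_foreach_trace.json")
```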

The ask is to provide some native way to make this total norm calculation not CPU-overhead bound, e.g. via a fused op.

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @crcrpar @mcarilli @janeyx99 @ezyang @chauhang @penguinwu

Metadata

Assignees

Labels

module: mta (Issues related to multi-tensor apply kernels and foreach functions), module: nn (Related to torch.nn), oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
