
Non-tensor auxiliary types don't work in torch.func.{grad, value_and_grad} #159667

@hexane360

Description

🐛 Describe the bug

Currently, auxiliary values returned with has_aux=True from torch.func.grad and torch.func.value_and_grad must be tensors (or pytrees containing only tensors). This is a significant limitation compared to JAX, where has_aux accepts arbitrary Python objects:

>>> import jax
>>> jax.grad(lambda arr: (arr**2, 'aux'), has_aux=True)(2.)
(Array(4., dtype=float32, weak_type=True), 'aux')

>>> import torch
>>> torch.func.grad(lambda arr: (arr**2, 'aux'), has_aux=True)(torch.tensor(2.))

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 3
      1 import torch
----> 3 torch.func.grad(lambda arr: (arr**2, 'aux'), has_aux=True)(torch.tensor(2.))

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/_functorch/apis.py:398, in grad.<locals>.wrapper(*args, **kwargs)
    397 def wrapper(*args, **kwargs):
--> 398     return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py:1406, in grad_impl(func, argnums, has_aux, args, kwargs)
   1405 def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs):
-> 1406     results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
   1407     if has_aux:
   1408         grad, (_, aux) = results

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/_functorch/vmap.py:48, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     45 @functools.wraps(f)
     46 def fn(*args, **kwargs):
     47     with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 48         return f(*args, **kwargs)

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py:1398, in grad_and_value_impl(func, argnums, has_aux, args, kwargs)
   1396     output = _undo_create_differentiable(output, level)
   1397     if has_aux:
-> 1398         aux = _undo_create_differentiable(aux, level)
   1400 if has_aux:
   1401     return grad_input, (output, aux)

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py:93, in _undo_create_differentiable(inps, level)
     89         return tree_map(unwrap_tensors, tuple(x))
     91     raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")
---> 93     return tree_map(unwrap_tensors, inps)

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/utils/_pytree.py:1145, in tree_map(func, tree, is_leaf, *rests)
   1143 leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
   1144 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
-> 1145 return treespec.unflatten(map(func, *flat_args))

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/utils/_pytree.py:982, in TreeSpec.unflatten(self, leaves)
    980 def unflatten(self, leaves: Iterable[Any]) -> PyTree:
    981     if not isinstance(leaves, (list, tuple)):
--> 982         leaves = list(leaves)
    983     if len(leaves) != self.num_leaves:
    984         raise ValueError(
    985             f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
    986             f"but the spec refers to a pytree that holds {self.num_leaves} "
    987             f"items ({self}).",
    988         )

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py:91, in _undo_create_differentiable.<locals>.unwrap_tensors(x)
     88 if isinstance(x, tuple):
     89     return tree_map(unwrap_tensors, tuple(x))
---> 91 raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")

RuntimeError: Expected tensors, got unsupported type <class 'str'>
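
One eager-mode workaround is to carry the non-tensor auxiliary out through a closure instead of returning it, so that grad only ever sees tensor outputs. A minimal sketch (grad_with_any_aux is just an illustrative helper, not part of torch.func):

import torch

def grad_with_any_aux(func):
    """Differentiate func, which returns (loss_tensor, aux), where aux may be any object."""
    def wrapper(*args, **kwargs):
        stash = []

        def tensor_only(*a, **kw):
            loss, aux = func(*a, **kw)
            stash.append(aux)   # side-channel for the non-tensor aux
            return loss         # grad only has to unwrap a plain tensor

        grads = torch.func.grad(tensor_only)(*args, **kwargs)
        return grads, stash[-1]

    return wrapper

>>> grad_with_any_aux(lambda arr: (arr**2, 'aux'))(torch.tensor(2.))
(tensor(4.), 'aux')

This obviously doesn't compose with vmap or other transforms, which is part of why first-class has_aux support for arbitrary objects matters.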

It looks like this was attempted before in functorch#423 but later reverted. Is there anything structurally blocking this feature?
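
For what it's worth, the failure comes from _undo_create_differentiable rejecting any pytree leaf that isn't a tensor. Naively I'd expect the unwrapping step to touch only tensor leaves and pass everything else through unchanged, roughly along these lines (a sketch only, with unwrap standing in for whatever the real per-tensor unwrapping does):

import torch
from torch.utils._pytree import tree_map

def undo_create_differentiable_sketch(inps, unwrap):
    """Unwrap tensor leaves of a pytree; pass any other Python object through."""
    def unwrap_if_tensor(x):
        return unwrap(x) if isinstance(x, torch.Tensor) else x
    return tree_map(unwrap_if_tensor, inps)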

Versions

Collecting environment information...
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: version 3.31.7
Libc version: N/A

Python version: 3.12.11 (main, Jun 6 2025, 23:18:08) [Clang 16.0.0 (clang-1600.0.26.6)] (64-bit runtime)
Python platform: macOS-15.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.3.2
[pip3] optree==0.17.0
[pip3] torch==2.7.0
[pip3] torchvision==0.22.0
[conda] Could not collect

cc @ezyang @gqchen @nikitaved @soulitzer @Varal7 @xmfan @jbschlosser @mruberry @zou3519 @Chillee @samdow @kshitij12345

Labels

module: autograd, module: functional UX, module: functorch, triaged
