🐛 Describe the bug
Currently, auxiliary values returned from `torch.func.grad` and `torch.func.value_and_grad` (with `has_aux=True`) must be tensors or (nested) tuples of tensors. This is a significant limitation compared to JAX, where any Python object can be returned as an auxiliary value:
```
>>> import jax
>>> jax.grad(lambda arr: (arr**2, 'aux'), has_aux=True)(2.)
(Array(4., dtype=float32, weak_type=True), 'aux')

>>> import torch
>>> torch.func.grad(lambda arr: (arr**2, 'aux'), has_aux=True)(torch.tensor(2.))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 3
      1 import torch
----> 3 torch.func.grad(lambda arr: (arr**2, 'aux'), has_aux=True)(torch.tensor(2.))

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/_functorch/apis.py:398, in grad.<locals>.wrapper(*args, **kwargs)
    397 def wrapper(*args, **kwargs):
--> 398     return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py:1406, in grad_impl(func, argnums, has_aux, args, kwargs)
   1405 def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs):
-> 1406     results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
   1407     if has_aux:
   1408         grad, (_, aux) = results

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/_functorch/vmap.py:48, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     45 @functools.wraps(f)
     46 def fn(*args, **kwargs):
     47     with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 48         return f(*args, **kwargs)

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py:1398, in grad_and_value_impl(func, argnums, has_aux, args, kwargs)
   1396 output = _undo_create_differentiable(output, level)
   1397 if has_aux:
-> 1398     aux = _undo_create_differentiable(aux, level)
   1400 if has_aux:
   1401     return grad_input, (output, aux)

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py:93, in _undo_create_differentiable(inps, level)
     89         return tree_map(unwrap_tensors, tuple(x))
     91     raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")
---> 93 return tree_map(unwrap_tensors, inps)

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/utils/_pytree.py:1145, in tree_map(func, tree, is_leaf, *rests)
   1143 leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
   1144 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
-> 1145 return treespec.unflatten(map(func, *flat_args))

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/utils/_pytree.py:982, in TreeSpec.unflatten(self, leaves)
    980 def unflatten(self, leaves: Iterable[Any]) -> PyTree:
    981     if not isinstance(leaves, (list, tuple)):
--> 982         leaves = list(leaves)
    983     if len(leaves) != self.num_leaves:
    984         raise ValueError(
    985             f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
    986             f"but the spec refers to a pytree that holds {self.num_leaves} "
    987             f"items ({self}).",
    988         )

File ~/Documents/code/python/phaser/venv/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py:91, in _undo_create_differentiable.<locals>.unwrap_tensors(x)
     88     if isinstance(x, tuple):
     89         return tree_map(unwrap_tensors, tuple(x))
---> 91     raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")

RuntimeError: Expected tensors, got unsupported type <class 'str'>
```

It looks like this was discussed before in functorch#423, but later reverted. Is there anything structurally blocking this feature?
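In the meantime, the auxiliary value can be routed around the transform instead of through it. Below is a minimal sketch of a closure-based workaround, assuming the non-tensor aux does not itself need to participate in any transform; `grad_with_nontensor_aux` is a made-up helper name for illustration, not a torch API:

```python
import torch

def grad_with_nontensor_aux(func):
    # Hypothetical helper: stash the (possibly non-tensor) aux value in a
    # Python list instead of returning it through torch.func.grad, so only
    # tensor outputs ever pass through the transform.
    def wrapper(*args, **kwargs):
        aux_box = []  # mutable cell capturing the aux value

        def loss_only(*a, **kw):
            loss, aux = func(*a, **kw)
            aux_box.append(aux)  # smuggle the aux out via the closure
            return loss          # return only the differentiable tensor

        grads = torch.func.grad(loss_only)(*args, **kwargs)
        return grads, aux_box[-1]

    return wrapper

grads, aux = grad_with_nontensor_aux(lambda arr: (arr**2, 'aux'))(torch.tensor(2.))
print(grads, aux)  # tensor(4.) aux
```

This sidesteps the pytree restriction, but it is clearly less ergonomic than JAX's `has_aux=True` behavior and breaks down if the wrapped function is itself composed with other transforms.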
Versions
Collecting environment information...
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: version 3.31.7
Libc version: N/A
Python version: 3.12.11 (main, Jun 6 2025, 23:18:08) [Clang 16.0.0 (clang-1600.0.26.6)] (64-bit runtime)
Python platform: macOS-15.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.3.2
[pip3] optree==0.17.0
[pip3] torch==2.7.0
[pip3] torchvision==0.22.0
[conda] Could not collect
cc @ezyang @gqchen @nikitaved @soulitzer @Varal7 @xmfan @jbschlosser @mruberry @zou3519 @Chillee @samdow @kshitij12345