
The problem
There are two types of batching rules:
- decompositions for CompositeImplicitAutograd operations
- everything else
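
To make the distinction concrete, here is a hypothetical Python sketch of the two kinds. functorch's actual batching rules are registered in C++, so the names and signatures below are illustrative only:

```python
import torch

# Type 1 (decomposition): a CompositeImplicitAutograd op gets vmap support
# "for free" by being expressed in terms of ops that already have rules.
def silu_decomposition(x):
    return x * torch.sigmoid(x)

# Type 2 (a real batching rule): explicitly describes how the op acts on
# inputs carrying an extra batch dimension. The signature is hypothetical,
# loosely modeled on functorch's C++ batch rules.
def mv_batch_rule(mat, mat_bdim, vec, vec_bdim):
    # Move each batch dimension to the front (assuming both inputs are
    # batched; real rules also handle unbatched arguments).
    mat = mat.movedim(mat_bdim, 0)
    vec = vec.movedim(vec_bdim, 0)
    # The output carries its batch dimension at position 0.
    return torch.einsum('bij,bj->bi', mat, vec), 0
```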
Registering a type-2 batching rule for a CompositeImplicitAutograd operator `op` is a trap. When we do something like `vmap(grad(op))(x)`, `op` gets decomposed, and the batching rule never gets hit. Worse, `op`'s decomposition may decompose into other ops that don't have batching rules, or that straight up don't work with vmap because they do `.data_ptr` accesses.
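
A minimal sketch of the failure mode, assuming the functorch-era API (`vmap`/`grad` also live under `torch.func` in newer releases):

```python
import torch
from functorch import vmap, grad  # torch.func.{vmap, grad} in newer releases

def f(x):
    # reshape is CompositeImplicitAutograd, so on this path it decomposes
    # (e.g. into view/clone); a batching rule registered directly for
    # aten::reshape would never fire here.
    return x.reshape(-1).sum()

x = torch.randn(5, 3, 2)
# vmap only sees the ops that reshape decomposed into, not reshape itself.
print(vmap(grad(f))(x).shape)  # torch.Size([5, 3, 2])
```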
We should either make registering a batching rule for a CompositeImplicitAutograd operator throw an error, or add a test that checks for this.
The solution
@albanD tells me we have a way to tell if an operator is CompositeImplicitAutograd: `torch._C._dispatch_has_kernel_for_dispatch_key("aten::reshape", "CompositeImplicitAutograd")`.
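
For reference, the query can be run directly; the expected outputs below are my assumption based on which ops have real backend kernels:

```python
import torch

# reshape is implemented in terms of other ops, so this should print True.
print(torch._C._dispatch_has_kernel_for_dispatch_key(
    "aten::reshape", "CompositeImplicitAutograd"))

# mm has real backend kernels rather than a decomposition, so this should
# print False (my assumption).
print(torch._C._dispatch_has_kernel_for_dispatch_key(
    "aten::mm", "CompositeImplicitAutograd"))
```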
I'm not completely sure this is sufficient, because of the order in which kernels get registered:
- Is it guaranteed that, by the time we register a batching rule for `aten::reshape`, a CompositeImplicitAutograd kernel for `aten::reshape` has already been registered, so that we can check for it and error out? (A test-based alternative that sidesteps this is sketched below.)
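
One way around the registration-order question is a test that runs after everything has been registered. A minimal sketch, where `OPS_WITH_BATCHING_RULES` is a hypothetical list that would need to be collected from functorch's registration machinery:

```python
import torch

# Hypothetical: the set of ops functorch registered type-2 batching rules
# for; collecting this would require cooperation from the registration code.
OPS_WITH_BATCHING_RULES = ["aten::mv", "aten::mm"]

def test_no_batching_rules_for_composite_ops():
    # By the time a test runs, all CompositeImplicitAutograd kernels have
    # been registered, so the ordering concern above doesn't apply.
    offenders = [
        op for op in OPS_WITH_BATCHING_RULES
        if torch._C._dispatch_has_kernel_for_dispatch_key(
            op, "CompositeImplicitAutograd")
    ]
    assert not offenders, (
        f"batching rules registered for CompositeImplicitAutograd ops: "
        f"{offenders}; these will never be hit under vmap(grad(...))"
    )
```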
Thoughts? cc @albanD @ezyang @samdow @Chillee