This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Add mechanism to error out when registering a batching rule for a CompositeImplicitAutograd operation #1009

@zou3519

Description

The problem

There are two types of batching rules:

  1. decompositions for CompositeImplicitAutograd operations
  2. everything else

Registering a type-2 batching rule for a CompositeImplicitAutograd operator "op" is a trap. When we do something like vmap(grad(op))(x), op gets decomposed before the vmap layer ever sees it, so the batching rule is never hit. It's possible that op's decomposition produces other ops that don't have batching rules, or that the decomposition simply doesn't work with vmap at all (e.g., because it does .data_ptr() accesses).
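
For concreteness, here is a minimal sketch of the trap, using torch.matmul (which is CompositeImplicitAutograd) as the example op; nothing here is functorch-internal:

```python
import torch
from functorch import grad, vmap  # torch.func in newer PyTorch

# torch.matmul is CompositeImplicitAutograd: under grad() it decomposes into
# primitive ops (mm/bmm/...) before the vmap layer ever sees "matmul", so a
# type-2 batching rule registered directly for aten::matmul would never fire.
def f(x, w):
    return torch.matmul(x, w).sum()

x = torch.randn(5, 3)  # batch of 5 samples
w = torch.randn(3, 3)
per_sample_grads = vmap(grad(f), in_dims=(0, None))(x, w)
print(per_sample_grads.shape)  # torch.Size([5, 3])
```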

We should change this so that registering a batching rule for a CompositeImplicitAutograd operator throws an error, or at least add a test that checks for this.

The solution

@albanD tells me we have a way to tell whether an operator is CompositeImplicitAutograd:

print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::reshape", "CompositeImplicitAutograd"))
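
If that check works at registration time, the error could live in a small guard; a minimal sketch, where assert_not_composite_implicit is a hypothetical helper name (not an existing functorch API):

```python
import torch

def assert_not_composite_implicit(qualname: str) -> None:
    # Hypothetical guard: refuse a type-2 batching-rule registration for an
    # op that already has a CompositeImplicitAutograd kernel, since that
    # kernel would decompose the op before any batching rule could run.
    if torch._C._dispatch_has_kernel_for_dispatch_key(
        qualname, "CompositeImplicitAutograd"
    ):
        raise RuntimeError(
            f"{qualname} is CompositeImplicitAutograd; register a "
            "decomposition (type 1) instead of a batching rule (type 2)."
        )

assert_not_composite_implicit("aten::mm")       # fine: mm has real kernels
assert_not_composite_implicit("aten::reshape")  # raises: reshape is composite
```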

I'm not completely sure this is sufficient, because of the order in which kernels get registered:

  • Is it guaranteed that, when we're registering a batching rule for aten::reshape, the CompositeImplicitAutograd kernel for aten::reshape has already been registered, so that we can check for it and error out?
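
If registration order makes the eager check unreliable, the test route mentioned above sidesteps it, since tests run after all kernels are registered. A sketch, where BATCHED_OPS stands in for the (hypothetical) list of qualnames functorch registered batching rules for:

```python
import torch

BATCHED_OPS = ["aten::mm", "aten::cumsum"]  # illustrative stand-in only

def test_no_batching_rules_for_composite_ops():
    # Runs after all registrations, so kernel-registration order can't
    # produce false negatives the way a check at registration time might.
    for qualname in BATCHED_OPS:
        assert not torch._C._dispatch_has_kernel_for_dispatch_key(
            qualname, "CompositeImplicitAutograd"
        ), f"{qualname} is CompositeImplicitAutograd but has a batching rule"
```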

Thoughts? cc @albanD @ezyang @samdow @Chillee

Labels: actionable, small
