primTorch: support refs and decompositions when ATen and Python disagree #83931
Labels
module: primTorch
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 Describe the bug
Policy (related): we need to agree on how to handle refs/decomps like `binary_cross_entropy`, where the ATen op and the Python frontend disagree on the number and types of arguments. Right now, `binary_cross_entropy` just defines a decomp. Do we want to split the ref and decomp implementations for these ops (with shared core/conversion logic), or handle this in some other way?

Bug: if `register_decomposition` is used with ops like these to define one Python function that is registered as both a ref and a decomp, there are no type signature checks in `register_decomposition` to catch the mismatch. It will just run, and we might not notice unless it breaks for some other reason.

I have a demo branch reproducing this issue; see the commit message and notes in the topmost commit:
https://github.com/nkaretnikov/pytorch/commits/primtorch-l1-loss-decomp-ref-compat-issue
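To make the problem concrete, here is a minimal, self-contained sketch. The two function definitions are simplified stand-ins (not the real PyTorch declarations): the ATen overload takes `reduction` as an integer enum, while the Python frontend takes a string and carries extra deprecated arguments. `check_signature_matches` is a hypothetical helper of the kind `register_decomposition` could run to reject a registration whose signature does not match the target:

```python
import inspect

# Simplified stand-ins for the two frontends (not the real PyTorch
# definitions). The ATen op takes `reduction` as an integer enum; the
# Python frontend takes a string and has extra deprecated arguments.
def aten_binary_cross_entropy(self, target, weight=None, reduction=1):
    ...

def python_binary_cross_entropy(input, target, weight=None,
                                size_average=None, reduce=None,
                                reduction="mean"):
    ...

def check_signature_matches(fn, target_fn):
    # Hypothetical check of the kind register_decomposition could run:
    # reject a registration whose parameter list does not match the
    # target overload's.
    fn_params = list(inspect.signature(fn).parameters)
    target_params = list(inspect.signature(target_fn).parameters)
    if fn_params != target_params:
        raise TypeError(
            f"signature mismatch: {fn_params} != {target_params}"
        )

# A single Python function cannot satisfy both signatures at once, so a
# shared ref/decomp registration would be rejected against one of them:
try:
    check_signature_matches(python_binary_cross_entropy,
                            aten_binary_cross_entropy)
except TypeError as e:
    print("rejected:", e)
```

Today no such check exists, so a mismatched registration passes silently; the question above is whether to add one, split the implementations, or something else.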
Versions
master (e0f2eba)
cc @ezyang @mruberry @ngimel