vmap support for torch.tril, torch.triu #91403
Labels
good first issue
module: functorch
Pertaining to torch.func or pytorch/functorch
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Tasks
Add vmap support for the following PyTorch operations: torch.tril and torch.triu. That is, each one needs a correct batching rule.
Both already have batching rules implemented, but those rules are incorrect.
Expected behavior
Currently, when one does vmap over them using a 2D input, it succeeds:
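A minimal sketch of such a call, assuming an illustrative input of shape `(3, 5)` and a build where `vmap` is importable from `torch.func` (older releases expose it from `functorch`). The helper `vmap_outcome` is a hypothetical name used only for this example. Whether the 2D call succeeds or raises depends on whether the installed build already has a corrected batching rule, so the snippet reports the outcome rather than assuming one:

```python
import torch

try:
    from torch.func import vmap  # newer PyTorch releases
except ImportError:
    from functorch import vmap   # older releases (assumes functorch is installed)

# Illustrative 2D input (the shape is an assumption). vmap maps over dim 0,
# so each per-sample slice seen by the op is 1D.
x = torch.randn(3, 5)


def vmap_outcome(op, batch):
    """Return 'succeeded' or 'raised' for vmap(op)(batch)."""
    try:
        vmap(op)(batch)
        return "succeeded"
    except RuntimeError:
        return "raised"


for op in (torch.tril, torch.triu):
    print(op.__name__, vmap_outcome(op, x))
```

With a 3D input such as `torch.randn(3, 4, 5)`, each per-sample slice is 2D, so the same call is valid.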
This should actually raise an exception. The mental model here is that vmap should be equivalent to performing a for-loop over the arguments. The following expression does raise an exception, so vmap should as well:
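The for-loop reference can be sketched as follows, again assuming an illustrative `(3, 5)` input. The helper `loop_raises` is a hypothetical name used only for this example. Each slice `xi` is 1D, and `torch.tril` and `torch.triu` reject inputs with fewer than two dimensions:

```python
import torch

# Illustrative 2D input (the shape is an assumption).
x = torch.randn(3, 5)


def loop_raises(op, batch):
    """Apply op to each slice along dim 0; report whether any call raises."""
    try:
        for xi in batch:
            op(xi)  # xi has one fewer dimension than batch
    except RuntimeError:
        return True
    return False


print("tril loop raises:", loop_raises(torch.tril, x))
print("triu loop raises:", loop_raises(torch.triu, x))
```

A correct batching rule should make `vmap(op)(x)` fail in the same cases where this loop fails.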
Read this!
See this note for more context https://github.com/pytorch/pytorch/blob/master/functorch/writing_batching_rules.md
If you're new to developing PyTorch and/or function transforms, we would recommend reading through https://github.com/pytorch/pytorch/wiki/Core-Frontend-Onboarding#unit-8-function-transforms-optional
For this issue, you'll likely need to implement something similar to the vmap support for torch.clone:
pytorch/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp
Line 95 in 06bdd49
Testing
pytest test/functorch/test_vmap.py -v -k "test_op_has_batch_rule_tril"
should pass (currently, it is an expected failure). Ditto for triu.
cc @Chillee @samdow @soumith