Skip to content

Commit

Permalink
vmap support for torch.tril and torch.triu (#94287)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #94287

Add vmap support for torch.tril and torch.triu.

Issue: #91403

Test Plan: GitHub pipeline

Reviewed By: zou3519

Differential Revision: D43016624

fbshipit-source-id: 6e73ce6f1c83be8bafc70a039c06404b60180a87
  • Loading branch information
isdanni authored and facebook-github-bot committed Mar 15, 2023
1 parent 11e708d commit e962dc9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
24 changes: 22 additions & 2 deletions aten/src/ATen/functorch/BatchRulesViews.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,12 +563,32 @@ Tensor trace_decomp(const Tensor& tensor) {
return tensor.diagonal().sum();
}

std::tuple<Tensor,optional<int64_t>> tril_batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
int64_t diagonal = 0) {
TORCH_CHECK(self.dim() >= 2, "tril: The input tensor must have at least 2 dimensions.");
auto self_ = moveBatchDimToFront(self, self_bdim);
auto result = at::tril(self_, diagonal);
return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor,optional<int64_t>> triu_batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
int64_t diagonal = 0) {
TORCH_CHECK(self.dim() >= 2, "triu: The input tensor must have at least 2 dimensions.");
auto self_ = moveBatchDimToFront(self, self_bdim);
auto result = at::triu(self_, diagonal);
return std::make_tuple(std::move(result), 0);
}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
VMAP_SUPPORT(flip, flip_batch_rule);
m.impl("trace", trace_decomp);
VMAP_SUPPORT(tril, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(tril)));
VMAP_SUPPORT(triu, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(triu)));
VMAP_SUPPORT(tril, tril_batch_rule);
VMAP_SUPPORT(triu, triu_batch_rule);
VMAP_SUPPORT(repeat, repeat_batch_rule);
VMAP_SUPPORT(_unsafe_view, _unsafe_view_batch_rule);
VMAP_SUPPORT(unsqueeze, unsqueeze_batch_rule);
Expand Down
2 changes: 0 additions & 2 deletions test/functorch/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3802,8 +3802,6 @@ def test_op_has_batch_rule(self, device, dtype, op):
'scatter',
'square',
'sub',
'tril',
'triu',
'trunc',
'xlogy',
)
Expand Down

0 comments on commit e962dc9

Please sign in to comment.