Skip to content

Commit

Permalink
Add batching rule for torch.matrix_exp.
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiao215 committed Jan 19, 2024
1 parent 8524fa5 commit ea41f87
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
7 changes: 5 additions & 2 deletions aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,11 @@ threeOutputs linalg_lu_factor_ex_batch_rule(
}

oneOutput matrix_exp_batch_rule(const Tensor& self, c10::optional<int64_t> self_bdim) {
TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2, "torch.matrix_exp: The input tensor A must have at least 2 dimensions.");
TORCH_CHECK(
rankWithoutBatchDim(self, self_bdim) >= 2,
"torch.linalg.matrix_exp: The input tensor A must have at least 2 dimensions.");
const auto self_ = moveBatchDimToFront(self, self_bdim).contiguous(); // seems to be a bug
return std::make_tuple(at::matrix_exp(self_), 0);
return std::make_tuple(at::linalg_matrix_exp(self_), 0);
}

fourOutputs solve_ex_batch_rule(
Expand Down Expand Up @@ -585,6 +587,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(linalg_lstsq, linalg_lstsq_batch_rule); // custom errors and sometimes empty return
VMAP_SUPPORT(linalg_lu_factor_ex, linalg_lu_factor_ex_batch_rule);
VMAP_SUPPORT(linalg_matrix_exp, matrix_exp_batch_rule);
VMAP_SUPPORT(matrix_exp, matrix_exp_batch_rule);
VMAP_SUPPORT(_linalg_solve_ex, solve_ex_batch_rule);
VMAP_SUPPORT(linalg_cross, cross_batch_rule);
VMAP_SUPPORT2(linalg_pinv, atol_rtol_tensor, pinv_batch_rule);
Expand Down
1 change: 0 additions & 1 deletion test/functorch/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3701,7 +3701,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
xfail('resize_'),
xfail('view_as_complex'),
xfail('matrix_exp'),
xfail('fft.ihfft2'),
xfail('fft.ihfftn'),
xfail('allclose'),
Expand Down

0 comments on commit ea41f87

Please sign in to comment.