Skip to content

Commit

Permalink
Add batching rule for torch.matrix_exp.
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiao215 committed Jan 19, 2024
1 parent ea41f87 commit 612114a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ threeOutputs linalg_lu_factor_ex_batch_rule(
oneOutput matrix_exp_batch_rule(const Tensor& self, c10::optional<int64_t> self_bdim) {
TORCH_CHECK(
rankWithoutBatchDim(self, self_bdim) >= 2,
"torch.linalg.matrix_exp: The input tensor A must have at least 2 dimensions.");
"torch.matrix_exp: The input tensor A must have at least 2 dimensions.");
const auto self_ = moveBatchDimToFront(self, self_bdim).contiguous(); // seems to be a bug
return std::make_tuple(at::linalg_matrix_exp(self_), 0);
}
Expand Down

0 comments on commit 612114a

Please sign in to comment.