Skip to content

Commit

Permalink
implemented batching rule for matrix_exp function
Browse files Browse the repository at this point in the history
  • Loading branch information
samhu1 committed Apr 27, 2024
1 parent 3f14759 commit cb7630e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
11 changes: 11 additions & 0 deletions aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp
Expand Up @@ -36,6 +36,15 @@ Tensor vdot_decomp(const Tensor& A, const Tensor& B) {
return at::dot(A.is_complex() ? A.conj() : A, B);
}


oneOutput matrix_exp_batch_rule(const Tensor& self, c10::optional<int64_t> self_bdim) {
TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2,
"torch.matrix_exp: The input tensor must have at least 2 dimensions.");

auto self_ = moveBatchDimToFront(self, self_bdim);
return std::make_tuple(at::matrix_exp(self_), 0);
}

// NB: I wrote this like this because we *might* want its for a future matmul
// batch rule that isn't decomposed...
// "tv" = tensor @ vector
Expand Down Expand Up @@ -593,6 +602,8 @@ LINALG_CHECK_MATRIX_UNARY_FOUR_OUT(_linalg_slogdet, linalg.slogdet);
LINALG_CHECK_MATRIX_UNARY_THREE_OUT(_linalg_svd, linalg.svd);
// NOLINTEND(*array*)

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {m.impl("matrix_exp", matrix_exp_batch_rule);}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(bmm, bmm_batch_rule);
m.impl("addmv", addmv_decomp);
Expand Down
21 changes: 21 additions & 0 deletions aten/src/ATen/test/basic.cpp
Expand Up @@ -365,6 +365,27 @@ void test(DeprecatedTypeProperties& type) {
TestIntArrayRefExpansion(type);
}

void TestMatrixExpBatching(DeprecatedTypeProperties& type) {
auto matrices = at::randn({10, 3, 3}, type.options());
std::vector<Tensor> expected_results;

for (const auto& mat : matrices) {
expected_results.push_back(at::matrix_exp(mat));
}
auto expected = at::stack(expected_results);

Tensor actual = at::matrix_exp(matrices);

// Compare actual results to expected results
ASSERT_TRUE(actual.allclose(expected, 1e-5, 1e-8));
}

TEST(MatrixExpTest, BatchedMatrixExp) {
manual_seed(42); // Set a manual seed for reproducibility
DeprecatedTypeProperties type = CPU(at::kDouble);
TestMatrixExpBatching(type);
}

TEST(BasicTest, BasicTestCPU) {
manual_seed(123);

Expand Down

0 comments on commit cb7630e

Please sign in to comment.