Skip to content

Commit

Permalink
[pt2] add SymInt support for linalg.matrix_power
Browse files Browse the repository at this point in the history
ghstack-source-id: 6c13bc9e6ce0faa2671cad5fab9e2d2d7857f1d6
Pull Request resolved: #101940
  • Loading branch information
nkaretnikov committed May 21, 2023
1 parent 0edc3d0 commit e5c530e
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 7 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ Tensor linalg_matrix_power_impl(
// Clone input to include result in the autograd graph
out = self.clone(at::MemoryFormat::Contiguous);
}
return out.copy_(at::eye(self.size(-2), self.options()));
return out.copy_(at::eye_symint(self.sym_size(-2), self.options()));
}
if (n == 1) {
return _out.has_value() ? out.copy_(self)
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/LinearAlgebraUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ static inline void checkIsMatrix(const Tensor& A, const char* const f_name, cons
}
static inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
checkIsMatrix(self, f_name, arg_name);
TORCH_CHECK(self.size(-1) == self.size(-2),
TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2),
f_name,
": ", arg_name, " must be batches of square matrices, "
"but they are ", self.size(-2), " by ", self.size(-1), " matrices");
"but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices");
}

static inline void checkInputsSolver(const Tensor& A,
Expand Down
1 change: 0 additions & 1 deletion test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2775,7 +2775,6 @@ def forward(self, x):
xfail('linalg.lu_factor', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function...
xfail('linalg.lu_factor_ex', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta funct...
xfail('linalg.lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/deco...
xfail('linalg.matrix_power', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('linalg.multi_dot', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('linalg.pinv', ''), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/dec...
xfail('linalg.pinv', 'hermitian'), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta fu...
Expand Down
2 changes: 0 additions & 2 deletions test/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,6 @@ def run_meta_crossref(
torch.functional.tensordot : {bf16, i8, i64, u8, c128, f64, i16, f32, i32, c64},
torch.inner : {bf16, i8, i64, u8, c128, f64, i16, f32, i32, c64},
torch.linalg.matrix_norm : {c128, f32, c64, f64},
torch.linalg.matrix_power : {c128, c64},
torch.linalg.matrix_rank : {c128, c64},
torch.linalg.svd : {c128, c64},
torch.matmul : {bf16, c128, f64, f32, f16, c64},
Expand Down Expand Up @@ -736,7 +735,6 @@ def run_meta_crossref(
meta_function_device_skips['cuda'] = {
torch.functional.tensordot: {f16},
torch.inner: {f16},
torch.linalg.matrix_power: {f32, f64},
torch.linalg.matrix_rank: {f32, f64},
torch.linalg.svd: {f32, f64},
torch.nn.functional.cross_entropy: {f16},
Expand Down
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,7 +1443,6 @@ def f(a, b, c, d, e):
xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
xfail('linalg.matrix_power'), # RuntimeError: Trying to call aten.size on a tensor with symbolic shape
xfail('linalg.multi_dot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.pinv', ''), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition
xfail('linalg.pinv', 'singular'), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition
Expand Down

0 comments on commit e5c530e

Please sign in to comment.