remove flatten.using_ints, linalg_*, linear, log_softmax.int, logdet, special_* from xfail list

ghstack-source-id: b9a62a39de789dee7d92de3b808431a9c284b073
Pull Request resolved: #110985
guilhermeleobas committed Oct 10, 2023
1 parent 733368a · commit ebcb9f3
Showing 6 changed files with 11 additions and 32 deletions.
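
The net effect: flatten.using_ints, the listed linalg_* ops, linear, log_softmax.int, logdet, and the special_xlogy overloads now run under torch.func.vmap through their composite decompositions instead of hand-written batch rules or xfail entries. A minimal smoke test of that effect (a sketch, not part of the commit; only public torch.func APIs assumed):

import torch
from torch.func import vmap

x = torch.randn(8, 3, 5)

# flatten.using_ints: previously a custom m.impl in BatchRulesViews.cpp,
# now covered by OP_DECOMPOSE2(flatten, using_ints)
assert vmap(lambda t: t.flatten(0, 1))(x).shape == (8, 15)

# logdet: previously a LINALG_CHECK registration, now OP_DECOMPOSE(logdet)
m = torch.randn(8, 4, 4) + 4 * torch.eye(4)  # nudge away from singular
assert vmap(torch.logdet)(m).shape == (8,)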
1 change: 0 additions & 1 deletion aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
@@ -430,7 +430,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   BINARY_SCALAR_2(rsub, Tensor, Scalar);
 
   BINARY_SCALAR_3_Tensor(special_xlog1py, other_scalar, self_scalar);
-  BINARY_SCALAR_3_Tensor(special_xlogy, other_scalar, self_scalar);
   BINARY_SCALAR_3_Tensor(special_zeta, other_scalar, self_scalar);
 
   VMAP_SUPPORT2(where, self, where_self_batch_rule);
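
The hand-written special_xlogy batch rule goes away because the op is re-registered as a decomposition in BatchRulesDecompositions.cpp below; behavior under vmap should be unchanged. A hedged equivalence check (my sketch, not from the commit):

import torch
from torch.func import vmap

x, y = torch.rand(8, 16), torch.rand(8, 16)

# vmap'd xlogy (now via decomposition) vs. an explicit per-sample loop
batched = vmap(torch.special.xlogy)(x, y)
looped = torch.stack([torch.special.xlogy(a, b) for a, b in zip(x, y)])
torch.testing.assert_close(batched, looped)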
11 changes: 11 additions & 0 deletions aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -126,6 +126,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
   OP_DECOMPOSE(fix);
   OP_DECOMPOSE(fliplr);
   OP_DECOMPOSE(flipud);
+  OP_DECOMPOSE2(flatten, using_ints);
   OP_DECOMPOSE2(float_power, Tensor_Tensor);
   OP_DECOMPOSE2(float_power, Tensor_Scalar);
   OP_DECOMPOSE2(float_power, Scalar);
@@ -159,28 +160,35 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
   OP_DECOMPOSE2(ldexp, Tensor);
   OP_DECOMPOSE2(less_equal, Tensor );
   OP_DECOMPOSE2(less, Tensor );
+  OP_DECOMPOSE(linear);
   OP_DECOMPOSE(linalg_cond);
   OP_DECOMPOSE(linalg_cholesky);
   OP_DECOMPOSE(linalg_det);
   OP_DECOMPOSE(linalg_eigvalsh);
   OP_DECOMPOSE(linalg_eigvals);
   OP_DECOMPOSE(linalg_inv);
+  OP_DECOMPOSE(linalg_lu_factor);
   OP_DECOMPOSE(linalg_matmul);
   OP_DECOMPOSE(linalg_matrix_norm);
   OP_DECOMPOSE2(linalg_matrix_norm, str_ord);
   OP_DECOMPOSE(linalg_multi_dot);
   OP_DECOMPOSE(linalg_norm);
   OP_DECOMPOSE2(linalg_norm, ord_str);
   OP_DECOMPOSE(linalg_eigh);
+  OP_DECOMPOSE(linalg_slogdet);
   OP_DECOMPOSE(linalg_solve);
   OP_DECOMPOSE(linalg_solve_ex);
   OP_DECOMPOSE(linalg_svd);
   OP_DECOMPOSE(linalg_svdvals);
+  OP_DECOMPOSE(linalg_pinv);
   OP_DECOMPOSE(linalg_tensorinv);
+  OP_DECOMPOSE2(linalg_pinv, atol_rtol_float);
   m.impl("linalg_vander", native::linalg_vander_symint);
   OP_DECOMPOSE(cumprod_backward);
   OP_DECOMPOSE(linalg_matrix_power);
   OP_DECOMPOSE(linalg_vecdot);
+  OP_DECOMPOSE(logdet);
+  OP_DECOMPOSE2(log_softmax, int);
   OP_DECOMPOSE(_lu_with_info);
   OP_DECOMPOSE(matmul);
   OP_DECOMPOSE(matrix_H);
@@ -250,6 +258,9 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
   OP_DECOMPOSE(special_psi);
   OP_DECOMPOSE(special_round);
   OP_DECOMPOSE(special_sinc);
+  OP_DECOMPOSE(special_xlogy);
+  OP_DECOMPOSE2(special_xlogy, other_scalar);
+  OP_DECOMPOSE2(special_xlogy, self_scalar);
 
 
   m.impl("split.sizes", native::split_symint);
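
OP_DECOMPOSE (and OP_DECOMPOSE2 for named overloads) registers an op's existing composite implementation under the FuncTorchBatchedDecomposition key, so vmap unfolds the op into primitives that already have batch rules rather than needing a dedicated rule. Observable behavior for two of the overloads added above, sketched with public APIs only:

import torch
from torch.func import vmap

M = torch.randn(6, 3, 3)

# linalg_slogdet now decomposes under vmap; tuple() keeps the output pytree plain
sign, logabsdet = vmap(lambda m: tuple(torch.linalg.slogdet(m)))(M)
assert sign.shape == (6,) and logabsdet.shape == (6,)

# log_softmax.int (the dim-as-int overload) decomposes as well
out = vmap(lambda m: torch.log_softmax(m, dim=-1))(M)
torch.testing.assert_close(out, torch.log_softmax(M, dim=-1))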
7 changes: 0 additions & 7 deletions aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp
@@ -582,18 +582,13 @@ LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_cholesky_ex, linalg.cholesky);
 LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_eig, linalg.eig);
 LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_inv_ex, linalg.inv_ex);
 LINALG_CHECK_MATRIX_UNARY_THREE_OUT(linalg_ldl_factor_ex, torch.linalg.ldl_factor_ex);
-LINALG_CHECK_MATRIX_UNARY_ONE_OUT(linalg_pinv, linalg.pinv);
-LINALG_CHECK_MATRIX_UNARY_ONE_OUT2(linalg_pinv, atol_rtol_float, linalg.pinv);
 LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_qr, linalg.qr);
-LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_slogdet, linalg.slogdet);
 LINALG_CHECK_MATRIX_BINARY_ONE_OUT(linalg_solve_triangular, linalg.solve_triangular);
 
 LINALG_CHECK_MATRIX_UNARY_TWO_OUT(geqrf, geqrf);
-LINALG_CHECK_MATRIX_UNARY_ONE_OUT(logdet, logdet);
 LINALG_CHECK_MATRIX_BINARY_TWO_OUT(triangular_solve, triangular_solve);
 LINALG_CHECK_MATRIX_UNARY_THREE_OUT(_linalg_det, linalg.det);
 LINALG_CHECK_MATRIX_UNARY_TWO_OUT(_linalg_eigh, linalg.eigh);
 LINALG_CHECK_MATRIX_UNARY_FOUR_OUT(_linalg_slogdet, linalg.slogdet);
 LINALG_CHECK_MATRIX_UNARY_THREE_OUT(_linalg_svd, linalg.svd);
 
 TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
@@ -605,13 +600,11 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT(dot, dot_batch_rule);
   VMAP_SUPPORT(mv, mv_batch_rule);
   VMAP_SUPPORT(mm, mm_batch_rule);
-  m.impl("linear", linear_decomp);
   VMAP_SUPPORT(linalg_lu_solve, linalg_lu_solve_batch_rule);
   VMAP_SUPPORT(linalg_householder_product, householder_product_batch_rule);
   VMAP_SUPPORT(cholesky_solve, cholesky_solve_batch_rule); // custom dim error
   VMAP_SUPPORT(linalg_lstsq, linalg_lstsq_batch_rule); // custom errors and sometimes empty return
   VMAP_SUPPORT(linalg_lu_factor_ex, linalg_lu_factor_ex_batch_rule);
-  VMAP_SUPPORT(linalg_lu_factor, linalg_lu_factor_batch_rule);
   VMAP_SUPPORT(linalg_matrix_exp, matrix_exp_batch_rule);
   VMAP_SUPPORT(_linalg_solve_ex, solve_ex_batch_rule);
   VMAP_SUPPORT(linalg_cross, cross_batch_rule);
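
With OP_DECOMPOSE(linear) and OP_DECOMPOSE(linalg_lu_factor) registered above, the dedicated linear_decomp impl and the linalg_lu_factor batch rule become redundant. A quick check of both paths (a sketch, assuming nothing beyond torch.func):

import torch
import torch.nn.functional as F
from torch.func import vmap

W, b = torch.randn(4, 3), torch.randn(4)
xs = torch.randn(10, 3)

# linear: weight and bias are closed over, the inputs are mapped
out = vmap(lambda x: F.linear(x, W, b))(xs)
torch.testing.assert_close(out, F.linear(xs, W, b))

# linalg_lu_factor now routes through its decomposition (to
# linalg_lu_factor_ex, whose batch rule is kept above)
A = torch.randn(10, 5, 5)
LU, pivots = vmap(lambda a: tuple(torch.linalg.lu_factor(a)))(A)
assert LU.shape == (10, 5, 5) and pivots.shape == (10, 5)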
1 change: 0 additions & 1 deletion aten/src/ATen/functorch/BatchRulesReduceOps.cpp
@@ -472,7 +472,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   m.impl("dist", dist_decomp);
   REDUCTION_BOXED_ARGS(kthvalue, 2, KEEPDIM_CASE_VARIABLE, 3);
   REDUCTION_BOXED_ARGS(linalg_vector_norm, 2, KEEPDIM_CASE_VARIABLE, 3);
-  REDUCTION_NO_KEEPDIM_ARG(log_softmax.int);
   REDUCTION_NO_KEEPDIM_ARG(logcumsumexp);
   REDUCTION_WITH_KEEPDIM_ARG(logsumexp);
   m.impl("max", max_decomp);
1 change: 0 additions & 1 deletion aten/src/ATen/functorch/BatchRulesViews.cpp
@@ -563,7 +563,6 @@ std::tuple<Tensor,optional<int64_t>> triu_batch_rule(
 }
 
 TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
-  m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
   VMAP_SUPPORT(flip, flip_batch_rule);
   m.impl("trace", trace_decomp);
   VMAP_SUPPORT(tril, tril_batch_rule);
22 changes: 0 additions & 22 deletions test/functorch/test_vmap_registrations.py
@@ -15,26 +15,15 @@
 )
 
 xfail_functorch_batched = {
-    "aten::flatten.using_ints",
     "aten::is_nonzero",
     "aten::isfinite",
     "aten::isreal",
     "aten::item",
-    "aten::linalg_pinv",
-    "aten::linalg_pinv.atol_rtol_float",
-    "aten::linalg_slogdet",
-    "aten::linalg_lu_factor",
-    "aten::linear",
     "aten::log_sigmoid",
-    "aten::log_softmax.int",
-    "aten::logdet",
     "aten::masked_select_backward",
     "aten::movedim.intlist",
     "aten::one_hot",
     "aten::silu_backward",
-    "aten::special_xlogy",
-    "aten::special_xlogy.other_scalar",
-    "aten::special_xlogy.self_scalar",
     "aten::tensor_split.indices",
     "aten::tensor_split.sections",
     "aten::to.device",
@@ -95,7 +84,6 @@
     "aten::fill_diagonal_",
     "aten::fix_",
     "aten::flatten.named_out_dim",
-    "aten::flatten.using_ints",
     "aten::flatten.using_names",
     "aten::flatten_dense_tensors",
     "aten::float_power_.Scalar",
@@ -128,21 +116,14 @@
     "aten::less_equal_.Tensor",
     "aten::linalg_cond.p_str",
     "aten::linalg_eigh.eigvals",
-    "aten::linalg_lu_factor",
     "aten::linalg_matrix_rank",
     "aten::linalg_matrix_rank.out_tol_tensor",
     "aten::linalg_matrix_rank.tol_tensor",
-    "aten::linalg_pinv",
-    "aten::linalg_pinv.atol_rtol_float",
     "aten::linalg_pinv.out_rcond_tensor",
     "aten::linalg_pinv.rcond_tensor",
-    "aten::linalg_slogdet",
     "aten::linalg_svd.U",
     "aten::linalg_tensorsolve",
-    "aten::linear",
     "aten::log_sigmoid",
-    "aten::log_softmax.int",
-    "aten::logdet",
     "aten::logsumexp.names",
     "aten::lstm.data",
     "aten::lstm.input",
@@ -227,9 +208,6 @@
     "aten::special_shifted_chebyshev_polynomial_v.x_scalar",
     "aten::special_shifted_chebyshev_polynomial_w.n_scalar",
     "aten::special_shifted_chebyshev_polynomial_w.x_scalar",
-    "aten::special_xlogy",
-    "aten::special_xlogy.other_scalar",
-    "aten::special_xlogy.self_scalar",
     "aten::square_",
     "aten::sspaddmm",
     "aten::std.correction_names",
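
Every deleted entry above corresponds to a registration change earlier in the commit: a direct FuncTorchBatched kernel removed, an OP_DECOMPOSE added, or both, which is what lets the xfail sets in test_vmap_registrations.py shrink. A spot-check for linalg_pinv, one of the removed entries (a sketch outside the test harness):

import torch
from torch.func import vmap

A = torch.randn(5, 4, 3)  # batch of five 4x3 matrices

# pinv per sample without an explicit loop
P = vmap(torch.linalg.pinv)(A)
assert P.shape == (5, 3, 4)

# agrees with a per-sample loop, up to SVD-level float noise
ref = torch.stack([torch.linalg.pinv(a) for a in A])
torch.testing.assert_close(P, ref, rtol=1e-4, atol=1e-5)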
