Skip to content

Commit

Permalink
[pt2] add SymInt support for column_stack
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
nkaretnikov committed Apr 16, 2023
1 parent e2923b5 commit 7d12ea9
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 3 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/TensorShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2807,7 +2807,7 @@ static std::vector<Tensor> reshape_input_for_column_stack(TensorList tensors) {
auto transform_lambda = [](const Tensor& input) -> Tensor {
// reshape 0D or 1D tensor t into (t.numel(), 1)
if (input.dim() <= 1) {
return input.reshape({input.numel(), 1});
return input.reshape_symint({input.sym_numel(), 1});
}
return input;
};
Expand Down
1 change: 0 additions & 1 deletion test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2509,7 +2509,6 @@ def forward(self, x):
xfail('cdist', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('cholesky_inverse', ''), # could not find kernel
xfail('cholesky_solve', ''), # could not find kernel
xfail('column_stack', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('combinations', ''), # aten.masked_select.default
xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition
Expand Down
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,7 +1368,6 @@ def f(a, b, c, d, e):
xfail('linalg.eig'),
xfail('linalg.eigvals'),
xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
xfail('column_stack', ''), # Tensors of type TensorImpl do not have numel
xfail('combinations', ''),
xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
xfail('cumulative_trapezoid', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition
Expand Down

0 comments on commit 7d12ea9

Please sign in to comment.