Add SymInt support to torch.take_along_dim
ghstack-source-id: c85e4cd0309c248d89d41be098e919c388f2be62
Pull Request resolved: #108879
guilhermeleobas committed Sep 8, 2023
1 parent cadd97f commit 64800dc
Showing 2 changed files with 11 additions and 10 deletions.
20 changes: 11 additions & 9 deletions aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -2065,19 +2065,21 @@ inline std::tuple<Tensor, Tensor, int64_t> _take_along_dim_helper(
 
   dim = at::maybe_wrap_dim(dim, self.dim());
 
-  DimVector self_sizes{self.sizes()};
+  SymDimVector self_sizes{self.sym_sizes()};
   // update number of elements at dim as per indices
-  self_sizes[dim] = indices.size(dim);
-  auto broadcast_shape = infer_size(self_sizes, indices.sizes());
-  auto indices_broadcasted = at::broadcast_to(indices, broadcast_shape);
+  self_sizes[dim] = indices.sym_size(dim);
+  auto broadcast_shape = infer_size_symint(self_sizes, indices.sym_sizes());
+  auto indices_broadcasted = at::broadcast_to_symint(indices, broadcast_shape);
 
-  DimVector indices_sizes{indices.sizes()};
+  SymDimVector indices_sizes{indices.sym_sizes()};
   // update number of elements at dim as per self
-  indices_sizes[dim] = self.size(dim);
-  broadcast_shape = infer_size(indices_sizes, self.sizes());
-  auto self_broadcasted = at::broadcast_to(self, broadcast_shape);
+  indices_sizes[dim] = self.sym_size(dim);
+  broadcast_shape = infer_size_symint(indices_sizes, self.sym_sizes());
+  auto self_broadcasted = at::broadcast_to_symint(self, broadcast_shape);
 
-  return std::make_tuple(self_broadcasted, indices_broadcasted, dim);
+  return std::make_tuple(std::move(self_broadcasted),
+                         std::move(indices_broadcasted),
+                         std::move(dim));
 }
 
 static inline void checkDevice(CheckedFrom c, const Tensor& t, Device device) {
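For context, `_take_along_dim_helper` mutually broadcasts `self` and `indices` everywhere except along `dim`, where the output length is taken from `indices`; the change above routes those shape computations through the SymInt-aware variants (`SymDimVector`, `infer_size_symint`, `broadcast_to_symint`) so symbolic sizes are not forced down to concrete integers. A minimal Python sketch of that broadcasting behavior (illustrative only, not part of this commit):

```python
import torch

# self and indices broadcast against each other except along `dim`;
# along `dim` the output takes its length from `indices`.
t = torch.randn(2, 3, 5)
idx = torch.zeros(1, 3, 7, dtype=torch.long)  # broadcastable with t outside dim=2
out = torch.take_along_dim(t, idx, dim=2)
print(out.shape)  # torch.Size([2, 3, 7])
```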
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
@@ -1574,7 +1574,6 @@ def f(t):
     xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
     xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
     xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition
-    xfail('take_along_dim', ''), # dtype of indices should be Long but got Float
     xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
     xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition

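With this xfail removed, `take_along_dim` is expected to trace under symbolic shapes in test_proxy_tensor.py. A rough sketch of the kind of call the suite exercises (assuming the public `make_fx` API; not code from this PR):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(t, idx):
    return torch.take_along_dim(t, idx, dim=1)

t = torch.randn(3, 4)
idx = torch.argsort(t, dim=1)

# tracing_mode="symbolic" records sizes as SymInts instead of baking in 3 and 4,
# which exercises the SymInt-aware helper changed in this commit.
gm = make_fx(f, tracing_mode="symbolic")(t, idx)
print(gm.graph)
```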
