[pt2] add SymInt support for tensordot and inner (#100356)
Pull Request resolved: #100356
Approved by: https://github.com/ezyang
nkaretnikov authored and pytorchmergebot committed May 2, 2023
1 parent 4582ceb commit 4136153
Showing 4 changed files with 17 additions and 20 deletions.
aten/src/ATen/native/Linear.cpp (27 changes: 14 additions & 13 deletions)

@@ -727,12 +727,12 @@ Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight
 Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2) {
   TORCH_CHECK(dims1.size() == dims2.size(), "both dimension lists should have same length");
   TORCH_CHECK(input1.scalar_type() == input2.scalar_type(), "both inputs should have same dtype");
-  int64_t csize = 1; // total size of the contracted dimensions
+  SymInt csize = 1; // total size of the contracted dimensions
   Tensor t1 = input1;
   Tensor t2 = input2;
   for (const auto i : c10::irange(dims1.size())) {
-    int s1 = input1.size(dims1[i]);
-    int s2 = input2.size(dims2[i]);
+    SymInt s1 = input1.sym_size(dims1[i]);
+    SymInt s2 = input2.sym_size(dims2[i]);
     if (s2 == 1) { // broadcasted dimensions can be summed right away
       t1 = t1.sum(dims1[i], true);
     } else if (s1 == 1) {
@@ -746,19 +746,20 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,

   auto cdims1 = at::dim_list_to_bitset(dims1, input1.dim());
   auto cdims2 = at::dim_list_to_bitset(dims2, input2.dim());
-  std::vector<int64_t> p1, p2, rsizes; // p1, p2: input permutations, rsizes: sizes of the result
+  std::vector<int64_t> p1, p2; // p1, p2: input permutations
+  std::vector<SymInt> rsizes; // rsizes: sizes of the result
   p1.reserve(input1.dim());
   p2.reserve(input2.dim());
   rsizes.reserve(input1.dim() + input2.dim() - (int64_t) dims1.size());
-  int64_t size1 = 1; // number of non-contracted elements in input1
-  int64_t size2 = 1; // number of non-contracted elements in input2
+  SymInt size1 = 1; // number of non-contracted elements in input1
+  SymInt size2 = 1; // number of non-contracted elements in input2

   // fill the permutations and compute sizes
   for (const auto i : c10::irange(input1.dim())) {
     if (! cdims1[i]) {
       p1.emplace_back(i);
-      size1 *= t1.size(i);
-      rsizes.emplace_back(t1.size(i));
+      size1 *= t1.sym_size(i);
+      rsizes.emplace_back(t1.sym_size(i));
     }
   }
   for (const auto x : dims1) {
@@ -770,15 +771,15 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
   for (const auto i : c10::irange(input2.dim())) {
     if (! cdims2[i]) {
       p2.emplace_back(i);
-      size2 *= t2.size(i);
-      rsizes.emplace_back(t2.size(i));
+      size2 *= t2.sym_size(i);
+      rsizes.emplace_back(t2.sym_size(i));
     }
   }
   // permut and reshape for matrix multiplication
-  t1 = t1.permute(p1).reshape({size1, csize});
-  t2 = t2.permute(p2).reshape({csize, size2});
+  t1 = t1.permute(p1).reshape_symint({size1, csize});
+  t2 = t2.permute(p2).reshape_symint({csize, size2});
   // multiply and reshape to target size
-  return at::mm(t1, t2).reshape(rsizes);
+  return at::mm(t1, t2).reshape_symint(rsizes);
 }

 Tensor &tensordot_out(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {
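To make the diff easier to follow, here is a minimal Python sketch of the same contraction strategy: keep the non-contracted dimensions, move the contracted ones to the inside, flatten both operands to 2-D, do a single matrix multiply, and restore the result shape. This is an illustration only, not the ATen implementation; the name `tensordot_sketch` is made up here, negative dims are not handled, and the size-1 broadcast fast path at the top of the C++ function is omitted.

```python
import torch

def tensordot_sketch(t1, t2, dims1, dims2):
    # Non-contracted dims of t1 go first, contracted dims last;
    # for t2 the contracted dims go first, non-contracted dims last.
    keep1 = [d for d in range(t1.dim()) if d not in dims1]
    keep2 = [d for d in range(t2.dim()) if d not in dims2]

    csize = 1  # total size of the contracted dimensions
    for d in dims1:
        csize *= t1.size(d)
    size1 = 1  # number of non-contracted elements in t1
    for d in keep1:
        size1 *= t1.size(d)
    size2 = 1  # number of non-contracted elements in t2
    for d in keep2:
        size2 *= t2.size(d)
    rsizes = [t1.size(d) for d in keep1] + [t2.size(d) for d in keep2]

    # permute and reshape for matrix multiplication, then restore the result shape
    a = t1.permute(keep1 + list(dims1)).reshape(size1, csize)
    b = t2.permute(list(dims2) + keep2).reshape(csize, size2)
    return torch.mm(a, b).reshape(rsizes)

a, b = torch.randn(2, 3, 4), torch.randn(4, 3, 5)
expected = torch.tensordot(a, b, dims=([1, 2], [1, 0]))
assert torch.allclose(tensordot_sketch(a, b, [1, 2], [1, 0]), expected)
```

The patch itself does not change this strategy; it only swaps `size()`/`reshape()` for `sym_size()`/`reshape_symint()` so the intermediate sizes can stay symbolic.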
aten/src/ATen/native/LinearAlgebra.cpp (6 changes: 3 additions & 3 deletions)

@@ -1284,11 +1284,11 @@ Tensor inner(const Tensor& self, const Tensor& other) {

   // Last dimension should match (tensordot does not enforce this)
   TORCH_CHECK(
-      self.size(-1) == other.size(-1),
+      self.sym_size(-1) == other.sym_size(-1),
       "inner() the last dimension must match on both input tensors but got shapes ",
-      self.sizes(),
+      self.sym_sizes(),
       " and ",
-      other.sizes());
+      other.sym_sizes());

   return at::tensordot(self, other, -1, -1);
 }
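As the unchanged return statement shows, `inner` is just `tensordot` over the last dimension of each operand, plus a shape check that now compares `sym_size(-1)` so it works on tensors with symbolic sizes instead of calling `sizes()` on them. A small eager-mode sanity check of that relationship (the shapes below are arbitrary):

```python
import torch

a, b = torch.randn(2, 3, 4), torch.randn(5, 4)
out = torch.inner(a, b)
# Same contraction inner() performs: contract the last dim of each input.
ref = torch.tensordot(a, b, dims=([2], [1]))
assert out.shape == (2, 3, 5)
torch.testing.assert_close(out, ref)

# The TORCH_CHECK above still rejects mismatched last dimensions eagerly.
try:
    torch.inner(torch.randn(2, 3), torch.randn(4, 5))
except RuntimeError as err:
    print(err)  # inner() the last dimension must match on both input tensors ...
```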
test/functorch/test_aotdispatch.py (2 changes: 0 additions & 2 deletions)

@@ -2502,7 +2502,6 @@ def forward(self, x):
     xfail('hsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
     xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('inner', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('linalg.cholesky_ex', ''), # could not find kernel for aten.linalg_solve_triangular.default
@@ -2612,7 +2611,6 @@ def forward(self, x):
     xfail('svd', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('svd_lowrank', ''), # could not find kernel
     xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('tensordot', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de...
     xfail('_upsample_bilinear2d_aa'), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList
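Dropping the `inner` and `tensordot` xfails means these ops are now expected to trace through AOTAutograd with symbolic shapes. A rough standalone analogue of what that covers (not the suite's OpInfo harness; the function `f` and the shapes are made up, and the `aot_eager` backend with `dynamic=True` is used here only to drive the same AOTAutograd-with-SymInt path):

```python
import torch

def f(a, b):
    return torch.tensordot(a, b, dims=([1], [0])) + torch.inner(a, a).sum()

# "aot_eager" runs the captured graph through AOTAutograd without further compilation;
# dynamic=True traces sizes as SymInts rather than baking in the constants.
compiled = torch.compile(f, backend="aot_eager", dynamic=True)

for n in (3, 5, 8):  # different sizes should reuse the symbolically traced graph
    a, b = torch.randn(2, n), torch.randn(n, 4)
    torch.testing.assert_close(compiled(a, b), f(a, b))
```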
test/test_proxy_tensor.py (2 changes: 0 additions & 2 deletions)

@@ -1455,7 +1455,6 @@ def f(a, b, c, d, e):
     xfail('histogramdd', ''), # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
     xfail('hsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('index_reduce', ''), # Float
-    xfail('inner', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
     xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
@@ -1559,7 +1558,6 @@ def f(a, b, c, d, e):
     xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at...
     xfail('svd_lowrank', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition
     xfail('take_along_dim', ''), # dtype of indices should be Long but got Float
-    xfail('tensordot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition
     xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
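The proxy-tensor xfails covered the same two ops under `make_fx` symbolic tracing, where the old `size()` calls had no symbolic meta function. A minimal repro-style sketch of that tracing mode (shapes chosen arbitrarily):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, b):
    return torch.inner(a, b)

# tracing_mode="symbolic" traces with SymInt sizes; with this patch the call
# goes through the tensordot path above instead of failing on aten.size.default.
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4), torch.randn(5, 4))
print(gm.graph)
```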
