Skip to content

Commit

Permalink
Add itertools.{prod, combinations, combinations_with_replacement} l…
Browse files Browse the repository at this point in the history
…ike op to pytorch (#9393)

Summary:
closes #7580
Pull Request resolved: #9393

Differential Revision: D13659628

Pulled By: zou3519

fbshipit-source-id: 3a233befa785709395a793ba8833413be394a6fd
  • Loading branch information
zasdfgbnm authored and facebook-github-bot committed Jan 15, 2019
1 parent 964732f commit 1065e7c
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 1 deletion.
2 changes: 2 additions & 0 deletions aten/src/ATen/core/aten_interned_strings.h
Expand Up @@ -236,6 +236,7 @@ _(aten, broadcast_tensors) \
_(aten, btrifact) \
_(aten, btrifact_with_info) \
_(aten, btrisolve) \
_(aten, cartesian_prod) \
_(aten, cat) \
_(aten, cauchy) \
_(aten, ceil) \
Expand All @@ -249,6 +250,7 @@ _(aten, clamp_max) \
_(aten, clamp_min) \
_(aten, clone) \
_(aten, coalesce) \
_(aten, combinations) \
_(aten, constant_pad_nd) \
_(aten, contiguous) \
_(aten, conv1d) \
Expand Down
60 changes: 60 additions & 0 deletions aten/src/ATen/native/Itertools.cpp
@@ -0,0 +1,60 @@
#include "ATen/ATen.h"
#include "ATen/Dispatch.h"

#include <vector>

namespace {

using namespace at;

Tensor _triu_mask(int64_t n, int64_t dims, bool diagonal, TensorOptions opt) {
// get a mask that has value 1 whose indices satisfies i < j < k < ...
// or i <= j <= k <= ... (depending on diagonal)
Tensor range = at::arange(n, opt.dtype(kLong));
std::vector<Tensor> index_grids = at::meshgrid(std::vector<Tensor>(dims, range));
Tensor mask = at::ones(index_grids[0].sizes(), opt.dtype(kByte));
if(diagonal) {
for(int64_t i = 0; i < dims - 1; i++) {
mask *= index_grids[i] <= index_grids[i+1];
}
} else {
for(int64_t i = 0; i < dims - 1; i++) {
mask *= index_grids[i] < index_grids[i+1];
}
}
return mask;
}

} // namespace

namespace at {
namespace native{

Tensor cartesian_prod(TensorList tensors) {
for(const Tensor &t : tensors) {
AT_CHECK(t.dim() == 1, "Expect a 1D vector, but got shape ", t.sizes());
}
if (tensors.size() == 1) {
return tensors[0];
}
std::vector<Tensor> grids = at::meshgrid(tensors);
for(Tensor &t : grids) {
t = t.flatten();
}
return at::stack(grids, 1);
}

Tensor combinations(const Tensor& self, int64_t r, bool with_replacement) {
AT_CHECK(self.dim() == 1, "Expect a 1D vector, but got shape ", self.sizes());
AT_CHECK(r > 0, "Expect a positive number, but got ", r);
int64_t num_elements = self.numel();
std::vector<Tensor> grids = at::meshgrid(std::vector<Tensor>(r, self));
Tensor mask = _triu_mask(num_elements, r, with_replacement, self.options());
for(Tensor &t : grids) {
t = t.masked_select(mask);
}
return at::stack(grids, 1);
}

} // namespace native
} // namespace at
6 changes: 6 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Expand Up @@ -2340,6 +2340,12 @@

- func: meshgrid(TensorList tensors) -> TensorList

- func: cartesian_prod(TensorList tensors) -> Tensor
variants: function

- func: combinations(Tensor self, int64_t r=2, bool with_replacement=false) -> Tensor
variants: function

- func: item(Tensor self) -> Scalar
variants: method

Expand Down
59 changes: 58 additions & 1 deletion test/test_torch.py
Expand Up @@ -19,7 +19,7 @@
from torch.utils.dlpack import from_dlpack, to_dlpack
from torch._utils import _rebuild_tensor
from torch._six import inf, nan, string_classes
from itertools import product, combinations
from itertools import product, combinations, combinations_with_replacement
from functools import reduce
from torch import multiprocessing as mp
from common_methods_invocations import tri_tests_args, run_additional_tri_tests, \
Expand Down Expand Up @@ -9688,6 +9688,63 @@ def test_cast_binary_op(self):
self.assertEqual(b.type(), b_copy.type())
self.assertEqual(b.data.type(), b_copy.type())

def test_cartesian_prod(self):
a = torch.tensor([1])
b = torch.tensor([1, 2, 3])
c = torch.tensor([1, 2])
prod = torch.cartesian_prod(a, b, c)
expected = torch.tensor(list(product([a], b, c)))
self.assertEqual(expected, prod)

# test 0 size input
d = torch.empty(0, dtype=b.dtype)
prod = torch.cartesian_prod(a, b, c, d)
expected = torch.empty(0, 4, dtype=b.dtype)
self.assertEqual(expected, prod)

# test single input
prod = torch.cartesian_prod(b)
self.assertEqual(b, prod)

def test_combinations(self):
a = torch.tensor([1, 2, 3])

c = torch.combinations(a, r=1)
expected = torch.tensor(list(combinations(a, r=1)))
self.assertEqual(c, expected)

c = torch.combinations(a, r=1, with_replacement=True)
expected = torch.tensor(list(combinations_with_replacement(a, r=1)))
self.assertEqual(c, expected)

c = torch.combinations(a)
expected = torch.tensor(list(combinations(a, r=2)))
self.assertEqual(c, expected)

c = torch.combinations(a, with_replacement=True)
expected = torch.tensor(list(combinations_with_replacement(a, r=2)))
self.assertEqual(c, expected)

c = torch.combinations(a, r=3)
expected = torch.tensor(list(combinations(a, r=3)))
self.assertEqual(c, expected)

c = torch.combinations(a, r=4)
expected = torch.empty(0, 4, dtype=a.dtype)
self.assertEqual(c, expected)

c = torch.combinations(a, r=5)
expected = torch.empty(0, 5, dtype=a.dtype)
self.assertEqual(c, expected)

# test empty imput
a = torch.empty(0)
c1 = torch.combinations(a)
c2 = torch.combinations(a, with_replacement=True)
expected = torch.empty(0, 2, dtype=a.dtype)
self.assertEqual(c1, expected)
self.assertEqual(c2, expected)

@unittest.skipIf(torch.cuda.device_count() < 2, 'only one GPU detected')
def test_reverse_binary_ops_multiple_device(self):
self.assertEqual(2 + torch.tensor(3), 2 + torch.tensor(3).to("cuda:1")) # __radd__
Expand Down
44 changes: 44 additions & 0 deletions torch/_torch_docs.py
Expand Up @@ -6267,3 +6267,47 @@ def parse_kwargs(desc):
>>> [7, 8, 9]]))
(tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9]))
""")


add_docstr(torch.combinations,
r"""
combinations(tensor, r=2, with_replacement=False) -> seq
Compute combinations of length :math:`r` of the given tensor. The behavior is similar to
python's `itertools.combinations` when `with_replacement` is set to `False`, and
`itertools.combinations_with_replacement` when `with_replacement` is set to `True`.
Arguments:
tensor (Tensor): 1D vector.
r (int, optional): number of elements to combine
with_replacement (boolean, optional): whether to allow duplication in combination
Returns:
Tensor: A tensor equivalent to converting all the input tensors into lists, do
`itertools.combinations` or `itertools.combinations_with_replacement` on these
lists, and finally convert the resulting list into tensor.
Example::
>>> a = [1, 2, 3]
>>> list(itertools.combinations(a, r=2))
[(1, 2), (1, 3), (2, 3)]
>>> list(itertools.combinations(a, r=3))
[(1, 2, 3)]
>>> list(itertools.combinations_with_replacement(a, r=2))
[(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
>>> tensor_a = torch.tensor(a)
>>> torch.combinations(tensor_a)
tensor([[1, 2],
[1, 3],
[2, 3]])
>>> torch.combinations(tensor_a, r=3)
tensor([[1, 2, 3]])
>>> torch.combinations(tensor_a, with_replacement=True)
tensor([[1, 1],
[1, 2],
[1, 3],
[2, 2],
[2, 3],
[3, 3]])
""")
32 changes: 32 additions & 0 deletions torch/functional.py
Expand Up @@ -27,6 +27,7 @@
'stft',
'tensordot',
'unique',
'cartesian_prod',
]


Expand Down Expand Up @@ -601,6 +602,37 @@ def argsort(input, dim=None, descending=False):
return torch.sort(input, dim, descending)[1]


def cartesian_prod(*tensors):
"""Do cartesian product of the given sequence of tensors. The behavior is similar to
python's `itertools.product`.
Arguments:
*tensors: any number of 1 dimensional tensors.
Returns:
Tensor: A tensor equivalent to converting all the input tensors into lists,
do `itertools.product` on these lists, and finally convert the resulting list
into tensor.
Example::
>>> a = [1, 2, 3]
>>> b = [4, 5]
>>> list(itertools.product(a, b))
[(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
>>> tensor_a = torch.tensor(a)
>>> tensor_b = torch.tensor(b)
>>> torch.cartesian_prod(tensor_a, tensor_b)
tensor([[1, 4],
[1, 5],
[2, 4],
[2, 5],
[3, 4],
[3, 5]])
"""
return torch._C._VariableFunctions.cartesian_prod(tensors)


def norm(input, p="fro", dim=None, keepdim=False, out=None):
r"""Returns the matrix norm or vector norm of a given tensor.
Expand Down

0 comments on commit 1065e7c

Please sign in to comment.