From 5bac6c010fbdf09b0c49e640a25d199229d1db27 Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Mon, 18 Jun 2018 15:24:44 -0700 Subject: [PATCH 01/18] added torch.rot90() to ATen --- .../src/ATen/native/TensorTransformations.cpp | 36 ++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 2 + test/test_autograd.py | 2 + test/test_cuda.py | 5 +++ test/test_torch.py | 30 +++++++++++++++ tools/autograd/derivatives.yaml | 3 ++ torch/_torch_docs.py | 37 +++++++++++++++++++ 7 files changed, 115 insertions(+) diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index 8bce12cac2a69..4b6902609f37a 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -56,4 +56,40 @@ Tensor flip_cpu(const Tensor& self, IntList dims) { return out_tensor; } +Tensor rot90(const Tensor& self, int64_t k, IntList dims) { + const int64_t total_dims = self.dim(), total_rot_dims = dims.size(); + + AT_CHECK(total_rot_dims == 2, + "expected total rotation dims == 2, but got dims = ", total_rot_dims); + + AT_CHECK(dims[0] != dims[1], + "expected rotation dims to be different, but got both dims = ", dims[0]); + + // check range of dims + auto rot_dims_v = std::vector(dims); + std::sort(rot_dims_v.begin(), rot_dims_v.end()); + + AT_CHECK(rot_dims_v[0] >= 0, + "expected rotation dims >= 0, but got dims=", rot_dims_v[0]); + + AT_CHECK(rot_dims_v[1] <= total_dims - 1, + "expected rotation dims <= total_dims - 1, but got dims = ", rot_dims_v[1], + ", where total dims - 1 = ", total_dims - 1); + + // handle modulo with negative k + k = (4 + (k % 4)) % 4; + + switch(k) { + case 1: + return self.flip({dims[1]}).transpose_(dims[0], dims[1]); + case 2: + return self.flip({dims[0]}).flip({dims[1]}); + case 3: + return self.transpose(dims[0], dims[1]).flip({dims[1]}); + default: + return self.clone(); + } + +} + }} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 89f2771b8dadf..a701c18dc30c7 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1445,6 +1445,8 @@ CPU: flip_cpu CUDA: flip_cuda +- func: rot90(Tensor self, int64_t k, IntList dims) -> Tensor + - func: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntList expand1, IntList expand2, IntList expand3, IntList sumdim, int64_t unroll_dim=1) -> Tensor variants: function diff --git a/test/test_autograd.py b/test/test_autograd.py index 8559fc9284a3f..60c0feb831d9b 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2648,6 +2648,8 @@ class dont_convert(tuple): ('flip', (S, S, S), ([0, 1, 2],), 'd012'), ('flip', (S, S, S), ([0, 2],), 'd02'), ('flip', (S, S, S), ([2, 0],), 'd20'), + ('rot90', (S, S, S), (1, [0, 1],), 'k1_d01'), + ('rot90', (S, S, S), (1, [1, 2],), 'k1_d12'), ('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)), ('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'), ('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'), diff --git a/test/test_cuda.py b/test/test_cuda.py index 9d56682f3cf06..090e288d89b22 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -420,6 +420,8 @@ def tmp(t): ('flip', small_3d, lambda t: [0, 1, 2], 'd012', types, True), ('flip', small_3d, lambda t: [0, 2], 'd02', types, True), ('flip', small_3d, lambda t: [2, 0], 'd20', types, True), + ('rot90', small_2d, lambda t: [1, [0, 1]], 'k1_d01', types, True), + ('rot90', small_3d, lambda 
t: [1, [1, 2]], 'k1_d12', types, True), ('rsqrt', lambda t: constant_tensor_add(1, small_3d(t)), lambda t: [], None, float_types), ('sinh', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types), ('tan', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types), @@ -1415,6 +1417,9 @@ def test_view(self): def test_flip(self): TestTorch._test_flip(self, use_cuda=True) + def test_rot90(self): + TestTorch._test_rot90(self, use_cuda=True) + def test_signal_window_functions(self): TestTorch._test_signal_window_functions(self, device=torch.device('cuda')) diff --git a/test/test_torch.py b/test/test_torch.py index 4d3d443db39bc..e27af9d622ef1 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6735,6 +6735,36 @@ def test_reversed(self): val = torch.tensor(42) self.assertEqual(reversed(val), torch.tensor(42)) + @staticmethod + def _test_rot90(self, use_cuda=False): + device = torch.device("cuda" if use_cuda else "cpu") + data = torch.arange(1, 5, device=device).view(2, 2) + self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1])) + self.assertEqual(torch.tensor([2, 4, 1, 3]).view(2, 2), data.rot90(1, [0, 1])) + self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1])) + self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1])) + + # test for reversed order of dims + self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0])) + + # test for modulo of k + self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1])) + self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1])) + self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1])) + + # test tensor with more than 2D + data = torch.arange(1, 9, device=device).view(2, 2, 2) + self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])) + + # test for errors + self.assertRaises(RuntimeError, lambda: data.rot90(1, [-1, 0])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2])) + + def test_rot90(self): + self._test_rot90(self, use_cuda=False) + def test_storage(self): v = torch.randn(3, 5) self.assertEqual(v.storage()[0], v.data[0][0]) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 3a396d84b66e4..25fb8ee5e2477 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -685,6 +685,9 @@ - name: flip(Tensor self, IntList dims) self: grad.flip(dims) +- name: rot90(Tensor self, int64_t k, IntList dims) + self: grad.rot90(-k, dims) + - name: take(Tensor self, Tensor index) self: zeros_like(self).put_(index, grad, true) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 84a08a155ce97..663ded865d2c9 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -4498,6 +4498,43 @@ def parse_kwargs(desc): [ 0, 1]]]) """) +add_docstr(torch.rot90, + r""" +rot90(input, k, dims) -> Tensor + +Rotate a n-D tensor by 90 degrees in the plane specified by dims axis. +Rotation direction is from the first towards the second axis if k > 0, and from the second towards the first for k < 0. 
+ +Args: + input (Tensor): the input tensor + k (int): number of times to rotate + dims (a list or tuple): axis to rotate + +Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.flip(x, 1, [0, 1]) + tensor([[1, 3], + [0, 2]]) + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.flip(x, 1, [1, 2]) + tensor([[[1, 3], + [0, 2]], + + [[5, 7], + [4, 6]]]) +""") + add_docstr(torch.take, r""" take(input, indices) -> Tensor From 5e61901b125ef303344eb26f395d343ea31280ae Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Mon, 18 Jun 2018 15:55:46 -0700 Subject: [PATCH 02/18] nits --- torch/_torch_docs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 663ded865d2c9..42208cfc6725e 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -4516,7 +4516,7 @@ def parse_kwargs(desc): >>> x tensor([[0, 1], [2, 3]]) - >>> torch.flip(x, 1, [0, 1]) + >>> torch.rot90(x, 1, [0, 1]) tensor([[1, 3], [0, 2]]) @@ -4527,7 +4527,7 @@ def parse_kwargs(desc): [[4, 5], [6, 7]]]) - >>> torch.flip(x, 1, [1, 2]) + >>> torch.rot90(x, 1, [1, 2]) tensor([[[1, 3], [0, 2]], From 4d914d628b320619f5620c7d61da57a1257fe74d Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Wed, 20 Jun 2018 00:33:10 -0700 Subject: [PATCH 03/18] support wrap_dims for negative dims in flip() and rot90() --- .../src/ATen/native/TensorTransformations.cpp | 23 +++++------ aten/src/ATen/native/TensorTransformations.h | 41 +++++++++++-------- .../ATen/native/cuda/TensorTransformations.cu | 12 +++--- test/test_autograd.py | 2 + test/test_cuda.py | 2 + test/test_torch.py | 6 +-- 6 files changed, 47 insertions(+), 39 deletions(-) diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index 4b6902609f37a..badf5e99afd48 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -11,9 +11,10 @@ namespace native { Tensor flip_cpu(const Tensor& self, IntList dims) { const int64_t total_dims = self.dim(), flip_dims_size = dims.size(); - check_errors(total_dims, flip_dims_size, dims); + flip_check_errors(total_dims, flip_dims_size, dims); auto flip_dims_v = std::vector(dims); + wrap_dims(flip_dims_v, total_dims); std::sort(flip_dims_v.begin(), flip_dims_v.end()); auto final_indices = std::vector(total_dims); @@ -62,19 +63,16 @@ Tensor rot90(const Tensor& self, int64_t k, IntList dims) { AT_CHECK(total_rot_dims == 2, "expected total rotation dims == 2, but got dims = ", total_rot_dims); - AT_CHECK(dims[0] != dims[1], - "expected rotation dims to be different, but got both dims = ", dims[0]); + AT_CHECK(dims[0] != dims[1] && std::abs(dims[0] - dims[1]) != total_dims, + "expected rotation dims to be different, but got dim0 = ", dims[0], + " and dim1 = ", dims[1]); // check range of dims - auto rot_dims_v = std::vector(dims); - std::sort(rot_dims_v.begin(), rot_dims_v.end()); + AT_CHECK(dims[0] < total_dims && dims[0] >= -total_dims, + "Rotation dim0 out of range, dim0 = ", dims[0]); - AT_CHECK(rot_dims_v[0] >= 0, - "expected rotation dims >= 0, but got dims=", rot_dims_v[0]); - - AT_CHECK(rot_dims_v[1] <= total_dims - 1, - "expected rotation dims <= total_dims - 1, but got dims = ", rot_dims_v[1], - ", where total dims - 1 = ", total_dims - 1); + AT_CHECK(dims[1] < total_dims && dims[1] >= -total_dims, + "Rotation dim1 out of range, dim1 = ", dims[1]); // handle modulo with negative k 
k = (4 + (k % 4)) % 4; @@ -83,13 +81,12 @@ Tensor rot90(const Tensor& self, int64_t k, IntList dims) { case 1: return self.flip({dims[1]}).transpose_(dims[0], dims[1]); case 2: - return self.flip({dims[0]}).flip({dims[1]}); + return self.flip(dims); case 3: return self.transpose(dims[0], dims[1]).flip({dims[1]}); default: return self.clone(); } - } }} // namespace at::native diff --git a/aten/src/ATen/native/TensorTransformations.h b/aten/src/ATen/native/TensorTransformations.h index 554a46f281e73..bb07766b77eb5 100644 --- a/aten/src/ATen/native/TensorTransformations.h +++ b/aten/src/ATen/native/TensorTransformations.h @@ -8,32 +8,37 @@ namespace at { namespace native { -static inline void check_errors(int64_t total_dims, int64_t flip_dims_size, IntList dims) { +// wrap negative dims +static inline void wrap_dims(std::vector& v, int64_t n) { + for (int64_t i = 0; i < v.size(); i++) { + if (v[i] < 0) { + v[i] = (n + (v[i] % n)) % n; + } + } +} + +static inline void flip_check_errors(int64_t total_dims, int64_t flip_dims_size, IntList dims) { // check if number of axis in dim is valid - AT_CHECK(flip_dims_size > 0, - "expected input tensor dims > 0, but got tensor dims size=", flip_dims_size); + AT_CHECK(flip_dims_size > 0 && flip_dims_size <= total_dims, + "flip dims size out of range, got flip dims size=", flip_dims_size); - // check duplicates in dims auto flip_dims_v = std::vector(dims); - flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end()); - AT_CHECK((int64_t)flip_dims_v.size() == flip_dims_size, - "dims has duplicates, original flip dims size=", flip_dims_size, - ", but unique flip dims size=", flip_dims_v.size()); - - // check len of dims - AT_CHECK(flip_dims_size <= total_dims, - "expected flip dims size <= tensor total dims, but got flip dims size=", - flip_dims_size, " and tensor total dim=", total_dims); // check if dims axis within range auto min_max_d = std::minmax_element(flip_dims_v.begin(), flip_dims_v.end()); - AT_CHECK(*min_max_d.first >= 0, - "expected flip dims axis >= 0, but got min flip dims=", *min_max_d.first); + AT_CHECK(*min_max_d.first < total_dims && *min_max_d.first >= -total_dims, + "The min flip dims out of range, got min flip dims=", *min_max_d.first); + + AT_CHECK(*min_max_d.second < total_dims && *min_max_d.second >= -total_dims, + "The max flip dims out of range, got max flip dims=", *min_max_d.second); - AT_CHECK(*min_max_d.second < total_dims, - "expected flip dims axis < tensor total dims, but got max flip dims=", - *min_max_d.second, " and tensor total dim=", total_dims); + // check duplicates in dims + wrap_dims(flip_dims_v, total_dims); + flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end()); + AT_CHECK((int64_t)flip_dims_v.size() == flip_dims_size, + "dims has duplicates, original flip dims size=", flip_dims_size, + ", but unique flip dims size=", flip_dims_v.size()); } }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/TensorTransformations.cu b/aten/src/ATen/native/cuda/TensorTransformations.cu index ee4d030f775e7..800b5a07a662e 100644 --- a/aten/src/ATen/native/cuda/TensorTransformations.cu +++ b/aten/src/ATen/native/cuda/TensorTransformations.cu @@ -69,7 +69,7 @@ void flip_cuda_kernel(scalar_t* in_tensor, scalar_t* out_tensor, int64_t N, int6 Tensor flip_cuda(const Tensor& self, IntList dims) { auto in_tensor = self; const int64_t flip_dims_size = dims.size(), total_dims = in_tensor.dim(), N = in_tensor.numel(); - check_errors(total_dims, flip_dims_size, 
dims); + flip_check_errors(total_dims, flip_dims_size, dims); int64_t block_size = 512; dim3 dim_block(block_size); @@ -80,13 +80,16 @@ Tensor flip_cuda(const Tensor& self, IntList dims) { return out_tensor; } + auto flip_dims = std::vector(dims); + wrap_dims(flip_dims, total_dims); + // use kernel_pointwise_flip_apply2 only when to-flip dim is the 1st or last dim, where collapseDims can reduce the amount of work - if (flip_dims_size == 1 && in_tensor.is_contiguous() && (dims[0] == 0 || dims[0] == total_dims - 1)) { + if (flip_dims_size == 1 && in_tensor.is_contiguous() && (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) { AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "flip_cuda", [&] { auto in_tensor_info = cuda::detail::getTensorInfo(in_tensor); auto out_tensor_info = cuda::detail::getTensorInfo(out_tensor); - int flip_dim = in_tensor_info.collapseDims(dims[0]); - out_tensor_info.collapseDims(dims[0]); + int flip_dim = in_tensor_info.collapseDims(flip_dims[0]); + out_tensor_info.collapseDims(flip_dims[0]); kernel_pointwise_flip_apply2 <<>>( in_tensor_info, out_tensor_info, N, flip_dim, total_dims); @@ -94,7 +97,6 @@ Tensor flip_cuda(const Tensor& self, IntList dims) { return out_tensor; } - auto flip_dims = std::vector(dims); auto flip_dims_t = at::CPU(kLong).tensorFromBlob(flip_dims.data(), {static_cast(flip_dims.size())}); auto shape = std::vector(in_tensor.sizes()); diff --git a/test/test_autograd.py b/test/test_autograd.py index 60c0feb831d9b..be88f7be07fd9 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2648,8 +2648,10 @@ class dont_convert(tuple): ('flip', (S, S, S), ([0, 1, 2],), 'd012'), ('flip', (S, S, S), ([0, 2],), 'd02'), ('flip', (S, S, S), ([2, 0],), 'd20'), + ('flip', (S, S, S), ([-1],), 'neg_d'), ('rot90', (S, S, S), (1, [0, 1],), 'k1_d01'), ('rot90', (S, S, S), (1, [1, 2],), 'k1_d12'), + ('rot90', (S, S, S), (1, [1, -1],), 'k1_neg_d'), ('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)), ('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'), ('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'), diff --git a/test/test_cuda.py b/test/test_cuda.py index 090e288d89b22..ebf0badf05af1 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -420,8 +420,10 @@ def tmp(t): ('flip', small_3d, lambda t: [0, 1, 2], 'd012', types, True), ('flip', small_3d, lambda t: [0, 2], 'd02', types, True), ('flip', small_3d, lambda t: [2, 0], 'd20', types, True), + ('flip', small_3d, lambda t: [-1], 'neg_d', types, True), ('rot90', small_2d, lambda t: [1, [0, 1]], 'k1_d01', types, True), ('rot90', small_3d, lambda t: [1, [1, 2]], 'k1_d12', types, True), + ('rot90', small_3d, lambda t: [1, [1, -1]], 'k1_neg_d', types, True), ('rsqrt', lambda t: constant_tensor_add(1, small_3d(t)), lambda t: [], None, float_types), ('sinh', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types), ('tan', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types), diff --git a/test/test_torch.py b/test/test_torch.py index e27af9d622ef1..4c910d2a71c1e 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6686,6 +6686,8 @@ def _test_flip(self, use_cuda=False): self.assertEqual(torch.tensor([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2), data.flip(0, 1)) self.assertEqual(torch.tensor([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2), data.flip(0, 1, 2)) + # check for wrap dim + self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(-1)) # check for permute 
self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(0, 2)) self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(2, 0)) @@ -6696,8 +6698,6 @@ def _test_flip(self, use_cuda=False): self.assertRaises(TypeError, lambda: data.flip()) # not allow size of flip dim > total dims self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 2, 3)) - # not allow dim < 0 - self.assertRaises(RuntimeError, lambda: data.flip(-1)) # not allow dim > max dim self.assertRaises(RuntimeError, lambda: data.flip(3)) @@ -6755,9 +6755,9 @@ def _test_rot90(self, use_cuda=False): # test tensor with more than 2D data = torch.arange(1, 9, device=device).view(2, 2, 2) self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])) + self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2])) # test for errors - self.assertRaises(RuntimeError, lambda: data.rot90(1, [-1, 0])) self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3])) self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1])) self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2])) From d6ce84d7fe400fdbc730fd98782ded4e477c058f Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Wed, 20 Jun 2018 13:39:48 -0700 Subject: [PATCH 04/18] 1. moved wrap_dims() to WrapDimUtils.h; 2. re-arange ops order for rot90 k=3; 3. TODO: add default values to args k and dims --- aten/src/ATen/WrapDimUtils.h | 9 +++++++++ aten/src/ATen/native/TensorTransformations.cpp | 2 +- aten/src/ATen/native/TensorTransformations.h | 10 +--------- test/test_torch.py | 3 +++ 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h index a07efa24d21a2..8bd7239da776c 100644 --- a/aten/src/ATen/WrapDimUtils.h +++ b/aten/src/ATen/WrapDimUtils.h @@ -86,4 +86,13 @@ static inline int64_t legacy_cat_wrap_dim(int64_t dim, TensorList tensors) { return dim; } +// wrap negative dims in v, where total dim is n +static inline void wrap_dims(std::vector& v, int64_t n) { + for (auto i = 0; i < v.size(); i++) { + if (v[i] < 0) { + v[i] = (n + (v[i] % n)) % n; + } + } +} + } diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index badf5e99afd48..6431a45bcba1a 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -83,7 +83,7 @@ Tensor rot90(const Tensor& self, int64_t k, IntList dims) { case 2: return self.flip(dims); case 3: - return self.transpose(dims[0], dims[1]).flip({dims[1]}); + return self.flip({dims[0]}).transpose_(dims[0], dims[1]); default: return self.clone(); } diff --git a/aten/src/ATen/native/TensorTransformations.h b/aten/src/ATen/native/TensorTransformations.h index bb07766b77eb5..e000b13f6a379 100644 --- a/aten/src/ATen/native/TensorTransformations.h +++ b/aten/src/ATen/native/TensorTransformations.h @@ -1,6 +1,7 @@ #include "ATen/ATen.h" #include +#include #include #include @@ -8,15 +9,6 @@ namespace at { namespace native { -// wrap negative dims -static inline void wrap_dims(std::vector& v, int64_t n) { - for (int64_t i = 0; i < v.size(); i++) { - if (v[i] < 0) { - v[i] = (n + (v[i] % n)) % n; - } - } -} - static inline void flip_check_errors(int64_t total_dims, int64_t flip_dims_size, IntList dims) { // check if number of axis in dim is valid AT_CHECK(flip_dims_size > 0 && flip_dims_size <= total_dims, diff --git a/test/test_torch.py b/test/test_torch.py index 4c910d2a71c1e..180c2c5fc9589 100644 --- 
a/test/test_torch.py +++ b/test/test_torch.py @@ -6744,6 +6744,9 @@ def _test_rot90(self, use_cuda=False): self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1])) self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1])) + # test for default args k=1, dims=[0, 1] + # self.assertEqual(data.rot90(), data.rot90(1, [0, 1])) + # test for reversed order of dims self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0])) From b95f07f90d97ba056367028a48cbafe438b90829 Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Wed, 20 Jun 2018 23:56:59 -0700 Subject: [PATCH 05/18] 1. addressed comments; 2. hacked an IntList parser at python_arg_parser.cpp --- aten/src/ATen/WrapDimUtils.h | 2 +- aten/src/ATen/native/native_functions.yaml | 3 ++- aten/src/ATen/native_parse.py | 2 ++ test/test_autograd.py | 1 + test/test_cuda.py | 1 + test/test_torch.py | 2 +- torch/csrc/utils/python_arg_parser.cpp | 28 +++++++++++++++++++++- 7 files changed, 35 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h index 8bd7239da776c..0f3179f97bf75 100644 --- a/aten/src/ATen/WrapDimUtils.h +++ b/aten/src/ATen/WrapDimUtils.h @@ -88,7 +88,7 @@ static inline int64_t legacy_cat_wrap_dim(int64_t dim, TensorList tensors) { // wrap negative dims in v, where total dim is n static inline void wrap_dims(std::vector& v, int64_t n) { - for (auto i = 0; i < v.size(); i++) { + for (size_t i = 0; i < v.size(); i++) { if (v[i] < 0) { v[i] = (n + (v[i] % n)) % n; } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a701c18dc30c7..a326d56f4aa79 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1445,7 +1445,8 @@ CPU: flip_cpu CUDA: flip_cuda -- func: rot90(Tensor self, int64_t k, IntList dims) -> Tensor +# default IntList value {0,1} should not add space after comma, since native_parse.py uses ', ' to split args +- func: rot90(Tensor self, int64_t k=1, IntList dims={0,1}) -> Tensor - func: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntList expand1, IntList expand2, IntList expand3, IntList sumdim, int64_t unroll_dim=1) -> Tensor variants: function diff --git a/aten/src/ATen/native_parse.py b/aten/src/ATen/native_parse.py index 13d852d2f1e14..c3a0068df7ea2 100644 --- a/aten/src/ATen/native_parse.py +++ b/aten/src/ATen/native_parse.py @@ -20,6 +20,8 @@ def parse_default(s): return s elif s == '{}': return '{}' + elif re.match(r'{.*}', s): + return s elif s == 'nullopt': return s try: diff --git a/test/test_autograd.py b/test/test_autograd.py index be88f7be07fd9..9849bcbcaf3c6 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2652,6 +2652,7 @@ class dont_convert(tuple): ('rot90', (S, S, S), (1, [0, 1],), 'k1_d01'), ('rot90', (S, S, S), (1, [1, 2],), 'k1_d12'), ('rot90', (S, S, S), (1, [1, -1],), 'k1_neg_d'), + ('rot90', (S, S, S), (), 'default'), ('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)), ('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'), ('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'), diff --git a/test/test_cuda.py b/test/test_cuda.py index ebf0badf05af1..b391b73c451ef 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -424,6 +424,7 @@ def tmp(t): ('rot90', small_2d, lambda t: [1, [0, 1]], 'k1_d01', types, True), ('rot90', small_3d, lambda t: [1, [1, 2]], 'k1_d12', types, True), ('rot90', small_3d, lambda t: [1, [1, -1]], 'k1_neg_d', types, True), + 
('rot90', small_3d, lambda t: [], 'default', types, True), ('rsqrt', lambda t: constant_tensor_add(1, small_3d(t)), lambda t: [], None, float_types), ('sinh', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types), ('tan', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types), diff --git a/test/test_torch.py b/test/test_torch.py index 180c2c5fc9589..5595035507c01 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6745,7 +6745,7 @@ def _test_rot90(self, use_cuda=False): self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1])) # test for default args k=1, dims=[0, 1] - # self.assertEqual(data.rot90(), data.rot90(1, [0, 1])) + self.assertEqual(data.rot90(), data.rot90(1, [0, 1])) # test for reversed order of dims self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0])) diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index c6d2315a6fe4e..97f17f3b50415 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -161,6 +161,32 @@ static inline at::optional parse_as_integer(const std::string& s) { return (*str_end == 0) ? at::optional(ans) : at::nullopt; } +/* +Parse default value of IntList declared at native_functions.yaml +There are two kinds of default values: +1. IntList[2] x=1 (where size=2, value={1,1} +2. IntList x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' +to split args) +*/ +static inline std::vector parse_intlist_args(const std::string& s, int64_t size) { + if (s[0] != '{') { + return std::vector(size, std::stoi(s)); + } + + auto args = std::vector(); + int64_t x = 0; + for (size_t i = 0; i < s.size();) { + if ('0' <= s[i] && s[i] <= '9') { + while ('0' <= s[i] && s[i] <= '9') { + x = x * 10 + (s[i++] - '0'); + } + args.emplace_back(x); + x = 0; + } + i += 1; + } + return args; +} void FunctionParameter::set_default_str(const std::string& str) { if (str == "None") { @@ -189,7 +215,7 @@ void FunctionParameter::set_default_str(const std::string& str) { } } else if (type_ == ParameterType::INT_LIST) { if (str != "None") { - default_intlist.assign(size, std::stoi(str)); + default_intlist = parse_intlist_args(str, size); } } else if (type_ == ParameterType::SCALARTYPE) { if (str == "None") { From e79638a0ebdc621169467e5c7f1c19e4471ece5f Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Mon, 25 Jun 2018 13:58:31 -0700 Subject: [PATCH 06/18] wip --- torch/csrc/utils/python_arg_parser.cpp | 41 +++++++++++++++++++++----- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 97f17f3b50415..f573e022a84b2 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -168,22 +168,47 @@ There are two kinds of default values: 2. 
IntList x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' to split args) */ +// static inline std::vector parse_intlist_args(const std::string& s, int64_t size) { +// if (s[0] != '{') { +// return std::vector(size, std::stoi(s)); +// } +// +// auto args = std::vector(); +// int64_t x = 0; +// for (size_t i = 0; i < s.size();) { +// if ('0' <= s[i] && s[i] <= '9') { +// while ('0' <= s[i] && s[i] <= '9') { +// x = x * 10 + (s[i++] - '0'); +// } +// args.emplace_back(x); +// x = 0; +// } +// i += 1; +// } +// return args; +// } + static inline std::vector parse_intlist_args(const std::string& s, int64_t size) { if (s[0] != '{') { return std::vector(size, std::stoi(s)); } auto args = std::vector(); - int64_t x = 0; - for (size_t i = 0; i < s.size();) { - if ('0' <= s[i] && s[i] <= '9') { - while ('0' <= s[i] && s[i] <= '9') { - x = x * 10 + (s[i++] - '0'); - } - args.emplace_back(x); + int64_t x = 0, sign = 1; + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '-') { + sign *= -1; + } + else if ('0' <= s[i] && s[i] <= '9') { + x = x * 10 + (s[i] - '0'); + } + else if (s[i] == ',' || s[i] == '}') { + args.emplace_back(sign * x); + sign = 1; x = 0; } - i += 1; + else { // '{' + } } return args; } From f0c861279b9d4bedda764cf0dba459204bcb7a0a Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Mon, 25 Jun 2018 18:37:26 -0700 Subject: [PATCH 07/18] supports negative default IntList values, and addresses comments --- aten/src/ATen/WrapDimUtils.h | 8 ++++---- torch/csrc/utils/python_arg_parser.cpp | 24 ++---------------------- 2 files changed, 6 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h index 0f3179f97bf75..207dc46466155 100644 --- a/aten/src/ATen/WrapDimUtils.h +++ b/aten/src/ATen/WrapDimUtils.h @@ -87,10 +87,10 @@ static inline int64_t legacy_cat_wrap_dim(int64_t dim, TensorList tensors) { } // wrap negative dims in v, where total dim is n -static inline void wrap_dims(std::vector& v, int64_t n) { - for (size_t i = 0; i < v.size(); i++) { - if (v[i] < 0) { - v[i] = (n + (v[i] % n)) % n; +static inline void wrap_dims(std::vector& to_transform_dims, int64_t tensor_total_dims) { + for (size_t i = 0; i < to_transform_dims.size(); i++) { + if (to_transform_dims[i] < 0) { + to_transform_dims[i] = (tensor_total_dims + (to_transform_dims[i] % tensor_total_dims)) % tensor_total_dims; } } } diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index f573e022a84b2..33dc66eb5b12c 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -163,31 +163,11 @@ static inline at::optional parse_as_integer(const std::string& s) { /* Parse default value of IntList declared at native_functions.yaml + There are two kinds of default values: 1. IntList[2] x=1 (where size=2, value={1,1} -2. IntList x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' -to split args) +2. 
IntList x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' to split args) */ -// static inline std::vector parse_intlist_args(const std::string& s, int64_t size) { -// if (s[0] != '{') { -// return std::vector(size, std::stoi(s)); -// } -// -// auto args = std::vector(); -// int64_t x = 0; -// for (size_t i = 0; i < s.size();) { -// if ('0' <= s[i] && s[i] <= '9') { -// while ('0' <= s[i] && s[i] <= '9') { -// x = x * 10 + (s[i++] - '0'); -// } -// args.emplace_back(x); -// x = 0; -// } -// i += 1; -// } -// return args; -// } - static inline std::vector parse_intlist_args(const std::string& s, int64_t size) { if (s[0] != '{') { return std::vector(size, std::stoi(s)); From 57e309aa7ca41ffd609b1d8a3903049f221fcf4e Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Mon, 25 Jun 2018 20:06:11 -0700 Subject: [PATCH 08/18] add more tests in parse_intlist_args --- aten/src/ATen/WrapDimUtils.h | 2 +- torch/csrc/utils/python_arg_parser.cpp | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h index 207dc46466155..35feba7410627 100644 --- a/aten/src/ATen/WrapDimUtils.h +++ b/aten/src/ATen/WrapDimUtils.h @@ -86,7 +86,7 @@ static inline int64_t legacy_cat_wrap_dim(int64_t dim, TensorList tensors) { return dim; } -// wrap negative dims in v, where total dim is n +// wrap negative dims in to_transform_dims static inline void wrap_dims(std::vector& to_transform_dims, int64_t tensor_total_dims) { for (size_t i = 0; i < to_transform_dims.size(); i++) { if (to_transform_dims[i] < 0) { diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 33dc66eb5b12c..c6f3fe85fa4f5 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -173,23 +173,29 @@ static inline std::vector parse_intlist_args(const std::string& s, int6 return std::vector(size, std::stoi(s)); } + size_t n = s.size(); + AT_CHECK(s[0] == '{', "Default value of IntList is missing left brace '{', found ", s[0]); + AT_CHECK(s[n - 1] == '}', "Default value of IntList is missing right brace '}', found ", s[n - 1]); + auto args = std::vector(); int64_t x = 0, sign = 1; - for (size_t i = 0; i < s.size(); i++) { + for (size_t i = 1; i < n - 1; i++) { if (s[i] == '-') { sign *= -1; } else if ('0' <= s[i] && s[i] <= '9') { x = x * 10 + (s[i] - '0'); } - else if (s[i] == ',' || s[i] == '}') { + else if (s[i] == ',') { args.emplace_back(sign * x); sign = 1; x = 0; } - else { // '{' + else { + AT_ERROR("Illegal char in IntList default value: ", s[i]); } } + args.emplace_back(sign * x); return args; } From b3af2e9d75e85a4717cd5c43183cc25e38991d6a Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Thu, 28 Jun 2018 19:40:27 -0700 Subject: [PATCH 09/18] addresses comments --- torch/csrc/utils/python_arg_parser.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index c6f3fe85fa4f5..b1ae6804d029d 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -174,7 +174,7 @@ static inline std::vector parse_intlist_args(const std::string& s, int6 } size_t n = s.size(); - AT_CHECK(s[0] == '{', "Default value of IntList is missing left brace '{', found ", s[0]); + // since already checked left brace '{' above, here only checks right brace '}' AT_CHECK(s[n - 1] == '}', "Default value of IntList is missing right 
brace '}', found ", s[n - 1]); auto args = std::vector(); From 3fcb2c138bc66c5b641a1f9870477461301c2609 Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Fri, 29 Jun 2018 09:38:45 -0700 Subject: [PATCH 10/18] 1. disallow dims to < -total_dim in wrap_dims(); 2. return empty vector for x={} default args --- aten/src/ATen/WrapDimUtils.h | 1 + torch/csrc/utils/python_arg_parser.cpp | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h index 35feba7410627..9a3e5d3fbf9d3 100644 --- a/aten/src/ATen/WrapDimUtils.h +++ b/aten/src/ATen/WrapDimUtils.h @@ -90,6 +90,7 @@ static inline int64_t legacy_cat_wrap_dim(int64_t dim, TensorList tensors) { static inline void wrap_dims(std::vector& to_transform_dims, int64_t tensor_total_dims) { for (size_t i = 0; i < to_transform_dims.size(); i++) { if (to_transform_dims[i] < 0) { + AT_CHECK(to_transform_dims[i] >= -tensor_total_dims, "to_transform_dims = ", to_transform_dims[i], " less than lower bound ", -tensor_total_dims); to_transform_dims[i] = (tensor_total_dims + (to_transform_dims[i] % tensor_total_dims)) % tensor_total_dims; } } diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index b1ae6804d029d..0f3d75238d9cc 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -178,6 +178,11 @@ static inline std::vector parse_intlist_args(const std::string& s, int6 AT_CHECK(s[n - 1] == '}', "Default value of IntList is missing right brace '}', found ", s[n - 1]); auto args = std::vector(); + // for case IntList x={}, return an empty vector + if (s.size() == 2) { + return args; + } + int64_t x = 0, sign = 1; for (size_t i = 1; i < n - 1; i++) { if (s[i] == '-') { From d7a9a6ee0e94212d261a7c615af5723963300d6a Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Mon, 2 Jul 2018 10:12:43 -0700 Subject: [PATCH 11/18] [wip] use strtol in parse_intlist_args --- aten/src/ATen/WrapDimUtils.h | 9 +-- .../src/ATen/native/TensorTransformations.cpp | 9 ++- aten/src/ATen/native/TensorTransformations.h | 2 +- .../ATen/native/cuda/TensorTransformations.cu | 2 +- aten/src/ATen/native/native_functions.yaml | 2 +- torch/csrc/utils/python_arg_parser.cpp | 68 +++++++++++++++---- 6 files changed, 66 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h index 9a3e5d3fbf9d3..aa4f7c158bd1d 100644 --- a/aten/src/ATen/WrapDimUtils.h +++ b/aten/src/ATen/WrapDimUtils.h @@ -87,12 +87,9 @@ static inline int64_t legacy_cat_wrap_dim(int64_t dim, TensorList tensors) { } // wrap negative dims in to_transform_dims -static inline void wrap_dims(std::vector& to_transform_dims, int64_t tensor_total_dims) { - for (size_t i = 0; i < to_transform_dims.size(); i++) { - if (to_transform_dims[i] < 0) { - AT_CHECK(to_transform_dims[i] >= -tensor_total_dims, "to_transform_dims = ", to_transform_dims[i], " less than lower bound ", -tensor_total_dims); - to_transform_dims[i] = (tensor_total_dims + (to_transform_dims[i] % tensor_total_dims)) % tensor_total_dims; - } +static inline void wrap_all_dims(std::vector& dims_to_wrap, int64_t tensor_total_dims) { + for (size_t i = 0; i < dims_to_wrap.size(); i++) { + dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims) % tensor_total_dims; } } diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index 6431a45bcba1a..31359fdda65d9 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ 
b/aten/src/ATen/native/TensorTransformations.cpp @@ -14,7 +14,7 @@ Tensor flip_cpu(const Tensor& self, IntList dims) { flip_check_errors(total_dims, flip_dims_size, dims); auto flip_dims_v = std::vector(dims); - wrap_dims(flip_dims_v, total_dims); + wrap_all_dims(flip_dims_v, total_dims); std::sort(flip_dims_v.begin(), flip_dims_v.end()); auto final_indices = std::vector(total_dims); @@ -58,7 +58,12 @@ Tensor flip_cpu(const Tensor& self, IntList dims) { } Tensor rot90(const Tensor& self, int64_t k, IntList dims) { - const int64_t total_dims = self.dim(), total_rot_dims = dims.size(); + const int64_t total_dims = self.dim(); + int64_t total_rot_dims = dims.size(); + if (total_rot_dims == 0) { + dims = IntList({0,1}); + total_rot_dims = 2; + } AT_CHECK(total_rot_dims == 2, "expected total rotation dims == 2, but got dims = ", total_rot_dims); diff --git a/aten/src/ATen/native/TensorTransformations.h b/aten/src/ATen/native/TensorTransformations.h index e000b13f6a379..2504a2c3f201b 100644 --- a/aten/src/ATen/native/TensorTransformations.h +++ b/aten/src/ATen/native/TensorTransformations.h @@ -26,7 +26,7 @@ static inline void flip_check_errors(int64_t total_dims, int64_t flip_dims_size, "The max flip dims out of range, got max flip dims=", *min_max_d.second); // check duplicates in dims - wrap_dims(flip_dims_v, total_dims); + wrap_all_dims(flip_dims_v, total_dims); flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end()); AT_CHECK((int64_t)flip_dims_v.size() == flip_dims_size, "dims has duplicates, original flip dims size=", flip_dims_size, diff --git a/aten/src/ATen/native/cuda/TensorTransformations.cu b/aten/src/ATen/native/cuda/TensorTransformations.cu index 800b5a07a662e..7fa1fe64f28d6 100644 --- a/aten/src/ATen/native/cuda/TensorTransformations.cu +++ b/aten/src/ATen/native/cuda/TensorTransformations.cu @@ -81,7 +81,7 @@ Tensor flip_cuda(const Tensor& self, IntList dims) { } auto flip_dims = std::vector(dims); - wrap_dims(flip_dims, total_dims); + wrap_all_dims(flip_dims, total_dims); // use kernel_pointwise_flip_apply2 only when to-flip dim is the 1st or last dim, where collapseDims can reduce the amount of work if (flip_dims_size == 1 && in_tensor.is_contiguous() && (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a326d56f4aa79..f8fbe0071c7df 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1446,7 +1446,7 @@ CUDA: flip_cuda # default IntList value {0,1} should not add space after comma, since native_parse.py uses ', ' to split args -- func: rot90(Tensor self, int64_t k=1, IntList dims={0,1}) -> Tensor +- func: rot90(Tensor self, int64_t k=1, IntList dims={}) -> Tensor - func: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntList expand1, IntList expand2, IntList expand3, IntList sumdim, int64_t unroll_dim=1) -> Tensor variants: function diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 0f3d75238d9cc..8bf67151b4f17 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -168,30 +168,68 @@ There are two kinds of default values: 1. IntList[2] x=1 (where size=2, value={1,1} 2. 
IntList x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' to split args) */ +// static inline std::vector parse_intlist_args(const std::string& s, int64_t size) { +// printf("s = %s\n", s.c_str()); +// if (s[0] != '{') { +// return std::vector(size, std::stol(s)); +// } +// +// size_t n = s.size(); +// // since already checked left brace '{' above, here only checks right brace '}' +// AT_CHECK(s[n - 1] == '}', "Default value of IntList is missing right brace '}', found ", s[n - 1]); +// +// auto args = std::vector(); +// // for case IntList x={}, return an empty vector +// if (s.size() == 2) { +// return args; +// } +// +// int64_t x = 0, sign = 1; +// for (size_t i = 1; i < n - 1; i++) { +// if (s[i] == '-') { +// sign *= -1; +// } +// else if ('0' <= s[i] && s[i] <= '9') { +// x = x * 10 + (s[i] - '0'); +// } +// else if (s[i] == ',') { +// args.emplace_back(sign * x); +// sign = 1; +// x = 0; +// } +// else { +// AT_ERROR("Illegal char in IntList default value: ", s[i]); +// } +// } +// args.emplace_back(sign * x); +// return args; +// } + static inline std::vector parse_intlist_args(const std::string& s, int64_t size) { + printf("s = %s\n", s.c_str()); + char *str_end; + int64_t x = 0; + size_t n = s.size(); + auto args = std::vector(); + + if (s.empty()) return args; + if (s[0] != '{') { - return std::vector(size, std::stoi(s)); + x = strtol(s.c_str(), &str_end, 10); + if (*str_end == 0) { + args = std::vector(size, x); + } + return args; } - size_t n = s.size(); // since already checked left brace '{' above, here only checks right brace '}' AT_CHECK(s[n - 1] == '}', "Default value of IntList is missing right brace '}', found ", s[n - 1]); - auto args = std::vector(); - // for case IntList x={}, return an empty vector - if (s.size() == 2) { - return args; - } - - int64_t x = 0, sign = 1; + int64_t s = 1, e = 1; for (size_t i = 1; i < n - 1; i++) { - if (s[i] == '-') { - sign *= -1; - } - else if ('0' <= s[i] && s[i] <= '9') { - x = x * 10 + (s[i] - '0'); - } else if (s[i] == ',') { + e = i; + auto x = std::atol(s.substr(s, e - s)); args.emplace_back(sign * x); sign = 1; x = 0; From 03a27cbd258e4d3d117730a6c420698d018a3e46 Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Mon, 2 Jul 2018 20:00:28 -0700 Subject: [PATCH 12/18] addresses comments --- .../src/ATen/native/TensorTransformations.cpp | 7 +- aten/src/ATen/native/native_functions.yaml | 2 +- test/test_torch.py | 4 ++ torch/csrc/utils/python_arg_parser.cpp | 72 ++++--------------- 4 files changed, 19 insertions(+), 66 deletions(-) diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index 31359fdda65d9..9653abef26f63 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -58,12 +58,7 @@ Tensor flip_cpu(const Tensor& self, IntList dims) { } Tensor rot90(const Tensor& self, int64_t k, IntList dims) { - const int64_t total_dims = self.dim(); - int64_t total_rot_dims = dims.size(); - if (total_rot_dims == 0) { - dims = IntList({0,1}); - total_rot_dims = 2; - } + const int64_t total_dims = self.dim(), total_rot_dims = dims.size(); AT_CHECK(total_rot_dims == 2, "expected total rotation dims == 2, but got dims = ", total_rot_dims); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index f8fbe0071c7df..a326d56f4aa79 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ 
b/aten/src/ATen/native/native_functions.yaml @@ -1446,7 +1446,7 @@ CUDA: flip_cuda # default IntList value {0,1} should not add space after comma, since native_parse.py uses ', ' to split args -- func: rot90(Tensor self, int64_t k=1, IntList dims={}) -> Tensor +- func: rot90(Tensor self, int64_t k=1, IntList dims={0,1}) -> Tensor - func: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntList expand1, IntList expand2, IntList expand3, IntList sumdim, int64_t unroll_dim=1) -> Tensor variants: function diff --git a/test/test_torch.py b/test/test_torch.py index 5595035507c01..2cf8633f5f9a9 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6755,6 +6755,10 @@ def _test_rot90(self, use_cuda=False): self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1])) self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1])) + # test for dims out-of-range error + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2])) + # test tensor with more than 2D data = torch.arange(1, 9, device=device).view(2, 2, 2) self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])) diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 8bf67151b4f17..23dd65618c86e 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -168,77 +168,31 @@ There are two kinds of default values: 1. IntList[2] x=1 (where size=2, value={1,1} 2. IntList x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' to split args) */ -// static inline std::vector parse_intlist_args(const std::string& s, int64_t size) { -// printf("s = %s\n", s.c_str()); -// if (s[0] != '{') { -// return std::vector(size, std::stol(s)); -// } -// -// size_t n = s.size(); -// // since already checked left brace '{' above, here only checks right brace '}' -// AT_CHECK(s[n - 1] == '}', "Default value of IntList is missing right brace '}', found ", s[n - 1]); -// -// auto args = std::vector(); -// // for case IntList x={}, return an empty vector -// if (s.size() == 2) { -// return args; -// } -// -// int64_t x = 0, sign = 1; -// for (size_t i = 1; i < n - 1; i++) { -// if (s[i] == '-') { -// sign *= -1; -// } -// else if ('0' <= s[i] && s[i] <= '9') { -// x = x * 10 + (s[i] - '0'); -// } -// else if (s[i] == ',') { -// args.emplace_back(sign * x); -// sign = 1; -// x = 0; -// } -// else { -// AT_ERROR("Illegal char in IntList default value: ", s[i]); -// } -// } -// args.emplace_back(sign * x); -// return args; -// } - static inline std::vector parse_intlist_args(const std::string& s, int64_t size) { printf("s = %s\n", s.c_str()); - char *str_end; - int64_t x = 0; size_t n = s.size(); - auto args = std::vector(); - if (s.empty()) return args; + // case 1. s = "" + if (s.empty()) return std::vector(); + // case 2. s is an int (e.g., s = "12") if (s[0] != '{') { - x = strtol(s.c_str(), &str_end, 10); - if (*str_end == 0) { - args = std::vector(size, x); - } - return args; + return std::vector(size, std::stol(s)); } + // case 3. 
s is a list of dims (e.g., s = {1,2}) + // since already checked left brace '{' above, here only checks right brace '}' AT_CHECK(s[n - 1] == '}', "Default value of IntList is missing right brace '}', found ", s[n - 1]); - int64_t s = 1, e = 1; - for (size_t i = 1; i < n - 1; i++) { - else if (s[i] == ',') { - e = i; - auto x = std::atol(s.substr(s, e - s)); - args.emplace_back(sign * x); - sign = 1; - x = 0; - } - else { - AT_ERROR("Illegal char in IntList default value: ", s[i]); - } + auto args = std::vector(); + const char* del = ",{}"; + char* s_p = strdup(s.c_str()); + char *token = strtok(s_p, del); + while (token != NULL) { + args.emplace_back(std::atol(token)); + token = strtok(NULL, del); } - args.emplace_back(sign * x); return args; } From 7686d974825ff7c9f46730beabd3e638303e84cb Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Mon, 2 Jul 2018 20:10:26 -0700 Subject: [PATCH 13/18] nits --- torch/csrc/utils/python_arg_parser.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 23dd65618c86e..cea5ff019fff0 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -169,18 +169,16 @@ There are two kinds of default values: 2. IntList x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' to split args) */ static inline std::vector parse_intlist_args(const std::string& s, int64_t size) { - printf("s = %s\n", s.c_str()); size_t n = s.size(); - // case 1. s = "" if (s.empty()) return std::vector(); - // case 2. s is an int (e.g., s = "12") + // case 1. s is an int (e.g., s=2) if (s[0] != '{') { return std::vector(size, std::stol(s)); } - // case 3. s is a list of dims (e.g., s = {1,2}) + // case 2. 
s is a list of dims (e.g., s={1,2}) // since already checked left brace '{' above, here only checks right brace '}' AT_CHECK(s[n - 1] == '}', "Default value of IntList is missing right brace '}', found ", s[n - 1]); From e5c46bf29360dc19a517951f8eda50bc64ed5982 Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Wed, 18 Jul 2018 00:03:26 -0700 Subject: [PATCH 14/18] Using istringstream rather than strtok, addressed comments --- aten/src/ATen/WrapDimUtils.h | 4 ++-- torch/csrc/utils/python_arg_parser.cpp | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h index aa4f7c158bd1d..09d4e2a53fae1 100644 --- a/aten/src/ATen/WrapDimUtils.h +++ b/aten/src/ATen/WrapDimUtils.h @@ -86,10 +86,10 @@ static inline int64_t legacy_cat_wrap_dim(int64_t dim, TensorList tensors) { return dim; } -// wrap negative dims in to_transform_dims +// wrap negative dims in a vector static inline void wrap_all_dims(std::vector& dims_to_wrap, int64_t tensor_total_dims) { for (size_t i = 0; i < dims_to_wrap.size(); i++) { - dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims) % tensor_total_dims; + dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims); } } diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index cea5ff019fff0..019ea9fda5dbe 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -184,12 +184,11 @@ static inline std::vector parse_intlist_args(const std::string& s, int6 AT_CHECK(s[n - 1] == '}', "Default value of IntList is missing right brace '}', found ", s[n - 1]); auto args = std::vector(); - const char* del = ",{}"; - char* s_p = strdup(s.c_str()); - char *token = strtok(s_p, del); - while (token != NULL) { - args.emplace_back(std::atol(token)); - token = strtok(NULL, del); + std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}' + std::string tok; + + while(std::getline(ss, tok, ',')) { + args.emplace_back(std::stol(tok)); } return args; } From fd1eff5bf5bdd71612d7125b36112c37a014abdf Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Mon, 23 Jul 2018 19:49:16 -0700 Subject: [PATCH 15/18] 1. modified jit default args parser; 2. 
added test for flip() on empty tensor --- test/test_torch.py | 4 ++++ tools/jit/gen_jit_dispatch.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_torch.py b/test/test_torch.py index 2cf8633f5f9a9..6a3479c000920 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6722,6 +6722,10 @@ def _test_flip(self, use_cuda=False): self.assertEqual(flip0_result, data.flip(0)) self.assertEqual(flip1_result, data.flip(1)) + # test empty tensor, should just return an empty tensor of the same shape + data = torch.tensor([]) + self.assertEqual(data, data.flip(0)) + def test_flip(self): self._test_flip(self, use_cuda=False) diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py index 49948de7180a0..18c043a6c1061 100644 --- a/tools/jit/gen_jit_dispatch.py +++ b/tools/jit/gen_jit_dispatch.py @@ -419,7 +419,9 @@ def format_arg(arg): .replace('false', 'False') \ .replace('nullptr', 'None') \ .replace('Reduction::ElementwiseMean', 'ElementwiseMean') \ - .replace('{}', 'None' if is_tensor_arg(arg) else '[]') + .replace('{}', 'None' if is_tensor_arg(arg) else '[]') \ + .replace('{', '[') \ + .replace('}', ']') default = default_map.get(default, default) decl = '{}={}'.format(decl, default) From 64fd306237ac0132221d01d5840d5b70b82d5f25 Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Tue, 24 Jul 2018 15:10:46 -0700 Subject: [PATCH 16/18] jit debug --- aten/src/ATen/native/TensorTransformations.cpp | 9 +++++++++ aten/src/ATen/native/native_functions.yaml | 2 +- test/test_jit.py | 2 ++ tools/jit/gen_jit_dispatch.py | 3 +++ 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index 9653abef26f63..e084c03d0e9c7 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -60,9 +60,18 @@ Tensor flip_cpu(const Tensor& self, IntList dims) { Tensor rot90(const Tensor& self, int64_t k, IntList dims) { const int64_t total_dims = self.dim(), total_rot_dims = dims.size(); + printf("total_rot_dims %ld\n", total_rot_dims); + + for (int64_t i = 0; i < total_rot_dims; i++) { + printf("dim[%ld] = %ld\n", i, dims[i]); + } + AT_CHECK(total_rot_dims == 2, "expected total rotation dims == 2, but got dims = ", total_rot_dims); + AT_CHECK(total_dims >= 2, + "expected total dims >= 2, but got total dims = ", total_dims); + AT_CHECK(dims[0] != dims[1] && std::abs(dims[0] - dims[1]) != total_dims, "expected rotation dims to be different, but got dim0 = ", dims[0], " and dim1 = ", dims[1]); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a326d56f4aa79..1bd0dbf4a0b28 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1446,7 +1446,7 @@ CUDA: flip_cuda # default IntList value {0,1} should not add space after comma, since native_parse.py uses ', ' to split args -- func: rot90(Tensor self, int64_t k=1, IntList dims={0,1}) -> Tensor +- func: rot90(Tensor self, int64_t k=1, IntList[2] dims=0) -> Tensor - func: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntList expand1, IntList expand2, IntList expand3, IntList sumdim, int64_t unroll_dim=1) -> Tensor variants: function diff --git a/test/test_jit.py b/test/test_jit.py index bd04f63729b28..b5b26f69a28df 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4861,7 +4861,9 @@ def script_fn(*args, **kwargs): else: call = '{}.{}({}{})'.format(actuals[0], method_name, ', 
'.join(actuals[1:]), kwargs_str) script = script_template.format(', '.join(formals), call) + # print("script: ", script); CU = torch.jit.CompilationUnit(script) + # print("create_script_fn: tensors = %s" % tensors) return output_process_fn(CU.the_method(*tensors)) return script_fn diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py index 18c043a6c1061..b985855b008a0 100644 --- a/tools/jit/gen_jit_dispatch.py +++ b/tools/jit/gen_jit_dispatch.py @@ -425,6 +425,7 @@ def format_arg(arg): default = default_map.get(default, default) decl = '{}={}'.format(decl, default) + # print("decl", decl) return decl args = [] @@ -440,6 +441,8 @@ def format_arg(arg): ret_list = jit_type_of(decl['returns'][0]) else: ret_list = '({})'.format(', '.join(jit_type_of(r) for r in decl['returns'])) + + print("signature return: %s" % 'aten::{}({}) -> {}'.format(decl['name'], arg_list, ret_list)) return 'aten::{}({}) -> {}'.format(decl['name'], arg_list, ret_list) From 427d4b96043e6c36d7ba8896530a83bf5c865d3b Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Tue, 24 Jul 2018 16:57:03 -0700 Subject: [PATCH 17/18] [wip] debugging jit --- aten/src/ATen/native/native_functions.yaml | 2 +- test/test_jit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 1bd0dbf4a0b28..a326d56f4aa79 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1446,7 +1446,7 @@ CUDA: flip_cuda # default IntList value {0,1} should not add space after comma, since native_parse.py uses ', ' to split args -- func: rot90(Tensor self, int64_t k=1, IntList[2] dims=0) -> Tensor +- func: rot90(Tensor self, int64_t k=1, IntList dims={0,1}) -> Tensor - func: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntList expand1, IntList expand2, IntList expand3, IntList sumdim, int64_t unroll_dim=1) -> Tensor variants: function diff --git a/test/test_jit.py b/test/test_jit.py index b5b26f69a28df..25d262b6c8ad1 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4863,7 +4863,7 @@ def script_fn(*args, **kwargs): script = script_template.format(', '.join(formals), call) # print("script: ", script); CU = torch.jit.CompilationUnit(script) - # print("create_script_fn: tensors = %s" % tensors) + print(CU.the_method.graph) return output_process_fn(CU.the_method(*tensors)) return script_fn From 0a5d9e04bf57a9a4e5c91ee9f5c0bb1b44802d28 Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Tue, 24 Jul 2018 17:29:56 -0700 Subject: [PATCH 18/18] Zachary DeVito fixed the bug in jit --- aten/src/ATen/native/TensorTransformations.cpp | 6 ------ test/test_jit.py | 2 -- test/test_torch.py | 1 + tools/jit/gen_jit_dispatch.py | 3 --- torch/csrc/jit/script/compiler.cpp | 2 +- 5 files changed, 2 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index e084c03d0e9c7..84759874ef535 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -60,12 +60,6 @@ Tensor flip_cpu(const Tensor& self, IntList dims) { Tensor rot90(const Tensor& self, int64_t k, IntList dims) { const int64_t total_dims = self.dim(), total_rot_dims = dims.size(); - printf("total_rot_dims %ld\n", total_rot_dims); - - for (int64_t i = 0; i < total_rot_dims; i++) { - printf("dim[%ld] = %ld\n", i, dims[i]); - } - AT_CHECK(total_rot_dims == 2, "expected total rotation dims == 2, but got dims = ", 
total_rot_dims); diff --git a/test/test_jit.py b/test/test_jit.py index 25d262b6c8ad1..bd04f63729b28 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4861,9 +4861,7 @@ def script_fn(*args, **kwargs): else: call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str) script = script_template.format(', '.join(formals), call) - # print("script: ", script); CU = torch.jit.CompilationUnit(script) - print(CU.the_method.graph) return output_process_fn(CU.the_method(*tensors)) return script_fn diff --git a/test/test_torch.py b/test/test_torch.py index 6a3479c000920..98e48d474fb6f 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6772,6 +6772,7 @@ def _test_rot90(self, use_cuda=False): self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3])) self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1])) self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0])) def test_rot90(self): self._test_rot90(self, use_cuda=False) diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py index b985855b008a0..18c043a6c1061 100644 --- a/tools/jit/gen_jit_dispatch.py +++ b/tools/jit/gen_jit_dispatch.py @@ -425,7 +425,6 @@ def format_arg(arg): default = default_map.get(default, default) decl = '{}={}'.format(decl, default) - # print("decl", decl) return decl args = [] @@ -441,8 +440,6 @@ def format_arg(arg): ret_list = jit_type_of(decl['returns'][0]) else: ret_list = '({})'.format(', '.join(jit_type_of(r) for r in decl['returns'])) - - print("signature return: %s" % 'aten::{}({}) -> {}'.format(decl['name'], arg_list, ret_list)) return 'aten::{}({}) -> {}'.format(decl['name'], arg_list, ret_list) diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index b12bb7b58d016..c38eeadb33737 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -373,7 +373,7 @@ static bool isTensorSubtype(Value* v) { at::optional> getIntListAttribute(at::optional N, Value* input) { auto list = constant_as>(input); if(list) - return std::vector(*list); + return std::vector(list.value()->elements()); // broadcast IntList[3] with value 4 -> {4, 4, 4} if(!N) return at::nullopt;
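
For reference, a minimal Python sketch (not part of any patch above) of the rotation dispatch added in aten/src/ATen/native/TensorTransformations.cpp: rot90 is composed from flip and transpose, with k first reduced modulo 4 so that negative k rotates in the opposite direction. It assumes torch.rot90, torch.flip, and Tensor.transpose behave as exercised by _test_rot90 in test/test_torch.py; rot90_sketch is a hypothetical helper name, not an API from the patches.

import torch

def rot90_sketch(t, k=1, dims=(0, 1)):
    # Mirrors the C++ switch: k is first mapped into {0, 1, 2, 3}.
    d0, d1 = dims
    k = (4 + (k % 4)) % 4
    if k == 1:
        return t.flip(d1).transpose(d0, d1)   # quarter turn, from dim d0 towards d1
    if k == 2:
        return t.flip(d0).flip(d1)            # half turn: flip both rotation dims
    if k == 3:
        return t.flip(d0).transpose(d0, d1)   # operation order used from PATCH 04 onwards
    return t.clone()                          # k == 0: just a copy of the input

data = torch.arange(1, 5).view(2, 2)          # [[1, 2], [3, 4]], as in _test_rot90
for k in (-5, -1, 0, 1, 2, 3, 5):
    assert torch.equal(rot90_sketch(data, k), torch.rot90(data, k, [0, 1]))
print(rot90_sketch(data, 1))                  # tensor([[2, 4], [1, 3]])

The loop over k also covers the modulo cases checked in _test_rot90 (rot90(5) == rot90(1), rot90(-5) == rot90(-1)); the sketch deliberately omits the dimension-validation AT_CHECKs that the C++ implementation performs before dispatching.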