From 01da0fe5fffefe9e76e0975b899db226cf92e855 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 4 Nov 2020 09:34:49 -0800 Subject: [PATCH 1/6] Including generator param in randperm documentation (#47231) Summary: The `randperm` documentation is outdated: it does not include the optional `generator` parameter. This PR adds it, along with the `pin_memory` parameter. This PR was brought up in [PR 47022](https://github.com/pytorch/pytorch/pull/47022), but is now rebased onto master. New docs look like: ![image](https://user-images.githubusercontent.com/31798555/97923963-e6084400-1d2c-11eb-9d46-573ba3189ad6.png) Pull Request resolved: https://github.com/pytorch/pytorch/pull/47231 Reviewed By: mruberry Differential Revision: D24711960 Pulled By: janeyx99 fbshipit-source-id: 3ff8be62ec33e34ef87d017ea97bb950621a3064 --- torch/_torch_docs.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 0f8292de3ad3..f11ceade91b1 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -6566,7 +6566,8 @@ def merge_dicts(*dicts): add_docstr(torch.randperm, r""" -randperm(n, *, out=None, dtype=torch.int64, layout=torch.strided, device=None, requires_grad=False) -> LongTensor +randperm(n, \*, generator=None, out=None, dtype=torch.int64, layout=torch.strided, device=None, requires_grad=False, + pin_memory=False) -> LongTensor Returns a random permutation of integers from ``0`` to ``n - 1``. @@ -6574,12 +6575,14 @@ def merge_dicts(*dicts): n (int): the upper bound (exclusive) Keyword args: + {generator} {out} dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. Default: ``torch.int64``. {layout} {device} {requires_grad} + {pin_memory} Example:: From 4189c3ca76b1e60f921d08034928af48fbd67e81 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 4 Nov 2020 10:29:12 -0800 Subject: [PATCH 2/6] Fix onnx test-reports path in CI (#47315) Summary: Currently, no test reports are uploaded to CI because the paths for the `onnx` runs are incorrect. This PR fixes those paths. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47315 Reviewed By: malfet Differential Revision: D24727607 Pulled By: janeyx99 fbshipit-source-id: f6d91698fdb15a39e01ef812032d4cd30621f864 --- .jenkins/caffe2/common.sh | 6 +++--- scripts/onnx/test.sh | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.jenkins/caffe2/common.sh b/.jenkins/caffe2/common.sh index 49f115419e3c..026cb8349d3d 100644 --- a/.jenkins/caffe2/common.sh +++ b/.jenkins/caffe2/common.sh @@ -2,9 +2,9 @@ set -ex LOCAL_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) ROOT_DIR=$(cd "$LOCAL_DIR"/../.. 
&& pwd) -TEST_DIR="$ROOT_DIR/caffe2_tests" -gtest_reports_dir="${TEST_DIR}/cpp" -pytest_reports_dir="${TEST_DIR}/python" +TEST_DIR="$ROOT_DIR/test" +gtest_reports_dir="${TEST_DIR}/test-reports/cpp" +pytest_reports_dir="${TEST_DIR}/test-reports/python" # Figure out which Python to use PYTHON="$(which python)" diff --git a/scripts/onnx/test.sh b/scripts/onnx/test.sh index 8b6fc6c4cf63..1f8ae54e1b4a 100755 --- a/scripts/onnx/test.sh +++ b/scripts/onnx/test.sh @@ -23,7 +23,7 @@ do done set -- "${UNKNOWN[@]}" # leave UNKNOWN -pip install pytest scipy hypothesis +pip install pytest scipy hypothesis # these are all already satisfied in CI if [[ $PARALLEL == 1 ]]; then pip install pytest-xdist From bba5a31176148dc2e2d1a71da53e5651fe79667b Mon Sep 17 00:00:00 2001 From: Erjia Guan Date: Wed, 4 Nov 2020 10:42:40 -0800 Subject: [PATCH 3/6] Revert D24481801: Optimize backward for torch.repeat Test Plan: revert-hammer Differential Revision: D24481801 (https://github.com/pytorch/pytorch/commit/4e6f2440d84a77bd251b7d7e573b0b33e4f3fe5c) Original commit changeset: 95c155e0de83 fbshipit-source-id: 0fb0afde760b0f5e17bd75df950a5d76aee5370b --- torch/csrc/autograd/FunctionsManual.cpp | 77 ++++++------------------- torch/csrc/autograd/FunctionsManual.h | 1 + 2 files changed, 19 insertions(+), 59 deletions(-) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index bc4c78cf28ad..aed5dd52999d 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include @@ -648,6 +647,17 @@ Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64 return at::where(mask, grad, grad_norm); } +Tensor sum_tensorlist(TensorList tl) { + if (tl.size() == 0) { + throw std::runtime_error("Can't sum tensorlist of size 0"); + } + Tensor sum = tl[0]; + for(size_t i = 1; i < tl.size(); ++i) { + sum = sum + tl[i]; + } + return sum; +} + Tensor repeat_backward(Tensor grad, IntArrayRef repeats, IntArrayRef input_shape) { auto find_iter = std::find(repeats.cbegin(), repeats.cend(), 0); if (find_iter != repeats.cend()) { @@ -658,65 +668,14 @@ Tensor repeat_backward(Tensor grad, IntArrayRef repeats, IntArrayRef input_shape for (int64_t i = 0; i < num_unsqueezed; ++i) { grad = grad.sum(0, false); } - - at::DimVector grad_size, sum_dims; - for (size_t dim = 0; dim < input_dims; ++dim) { - int64_t repeat = repeats[dim + num_unsqueezed]; - // Reshape gradient (repeat > 1) - // Index: [..., dim , ...] [..., dim , dim+1 , ...] - // Shape: From [..., dimsize, ...] to [..., repeat, dimsize/repeat, ...] - // The gradient tensor at 'dim' is reshaped to 'repeat' times of input tensor. - // Then, sum up gradients over repeated tensors along 'dim', and reduce shape - // from 'repeat * dimsize/repeat' to 'dimsize/repeat' ('input_dimsize'). - // Example: - // Size(3, 2) Size(6, 2) - // [[v1_0, v1_1], - // [v1_2, v1_3], - // [[v0, v1], repeat(2, 1) [v1_4, v1_5], - // [v2, v3], -------------> [v2_0, v2_1], - // [v4, v5]] [v2_2, v2_3], - // [v2_4, v2_5]] - // - // input grad (3, 2) reshape (2, 3, 2) output grad (6, 2) - // [[[g1_0, g1_1], [[g1_0, g1_1], - // [g1_2, g1_3], [g1_2, g1_3], - // [[g1_0+g2_0, g1_1+g2_1], [g1_4, g1_5]], [g1_4, g1_5], - // [g1_0+g2_0, g1_1+g2_1], [g2_0, g2_1], - // [g1_0+g2_0, g1_1+g2_1]] [[g2_0, g2_1], [g2_2, g2_3], - // [g2_2, g2_3], [g2_4, g2_5]] - // [g2_4, g2_5]]] - // If gradient tensor is reshaped to [..., dimsize/repeat, repeat, ...] 
and then // sum over 'dim+1'. The gradient for input is not correctly aligned with input. // Example: // input grad (3, 2) reshape (3, 2, 2) output grad (6, 2) // [[[g1_0, g1_1], // [g1_2, g1_3]], [[g1_0, g1_1], // [g1_2, g1_3], // [[g1_0+g1_2, g1_1+g1_3], [[g1_4, g1_5], [g1_4, g1_5], // [g1_4+g2_0, g1_5+g2_1], [g2_0, g2_1]], [g2_0, g2_1], // [g2_2+g2_4, g2_3+g2_5]] [g2_2, g2_3], // [[g2_2, g2_3], [g2_4, g2_5]] // [g2_4, g2_5]]] - if (repeat != 1) { - grad_size.push_back(repeat); - sum_dims.push_back(grad_size.size() - 1); + for (size_t j = num_unsqueezed; j < repeats.size(); ++j) { + int64_t repeat = repeats[j]; + if (repeat == 1) { + continue; } - // Don't need to reshape gradient into (repeat, input_shape[dim]) (repeat == 1) - grad_size.push_back(input_shape[dim]); - } - // One-time Reshape & Sum - // Reshape gradient to grad_size: - // 1. If repeat equals to 1, append input size at that dimension, - // 2. If repeat is larger than 1, append both repeat and input size at that dimension. - // Sum over all "repeat" dimensions from sum_dims: - // Example: - // Input Size (2, 3, 4, 5) - // repeat [4, 1, 9, 3] - // output/grad Size (8, 3, 36, 15) - // grad_size [4, 2, 3, 9, 4, 3, 5] - // sum_dims [0, 3, 5] - grad = grad.reshape(grad_size); - grad = grad.sum(sum_dims); + int64_t dim = j - num_unsqueezed; + grad = sum_tensorlist(grad.chunk(repeat, dim)); + } return grad; } diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 0fba31bdd894..6fd6d6bc418b 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -76,6 +76,7 @@ at::Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, at::IntArr at::Tensor mm_mat2_backward(const at::Tensor & grad, const at::Tensor & mat1, at::IntArrayRef sizes, at::IntArrayRef strides, const at::Scalar & alpha); at::Tensor _sparse_addmm_sparse_backward(const at::Tensor& grad, const at::Tensor& sparse_, const at::Tensor& dense, const at::Scalar& alpha); at::Tensor renorm_backward(const at::Tensor & grad, const at::Tensor & self, at::Scalar p, int64_t dim, at::Scalar maxnorm); +at::Tensor sum_tensorlist(at::TensorList tl); at::Tensor repeat_backward(at::Tensor grad, at::IntArrayRef repeats, at::IntArrayRef input_shape); at::Tensor _fused_dropout_backward(at::Tensor grad, at::Tensor mask, double p1m); at::Tensor evenly_distribute_backward(at::Tensor grad, const at::Tensor & input, const at::Tensor & value); From f588ad6a35c3f52da8e8180c7b51de954fce5fd1 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 4 Nov 2020 11:09:57 -0800 Subject: [PATCH 4/6] [quant][graphmode][fx] Test to make sure dequantize nodes are placed properly (#47332) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47332 Test Plan: Imported from OSS Reviewed By: supriyar Differential Revision: D24719736 fbshipit-source-id: 51b1f14b479edbc5d7f28d85920faf5fee8dd5ea --- test/quantization/test_quantize_fx.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index adc31ac7d6ce..73afec11ee38 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -1035,6 +1035,30 @@ def forward(self, x): # quantize, should run with no errors quantized = convert_fx(prepared_copy) + def test_dequantize(self): + r""" Test to make sure dequantize nodes are placed before + non-quantizable nodes + """ + class M(torch.nn.Module): + def 
__init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + self.act = torch.nn.GELU() + + def forward(self, x): + x = self.conv(x) + return self.act(x) + + data = torch.rand(5, 1, 3, 3, dtype=torch.float) + for quant_type in self.static_quant_types: + node_list = [ + ns.call_module(nnq.Conv2d), + ns.call_method("dequantize"), + ns.call_module(nn.GELU), + ] + self.checkGraphModeFxOp( + M().eval(), (data,), quant_type, expected_node_list=node_list) + @skipIfNoFBGEMM class TestQuantizeFxOps(QuantizationTestCase): """Unit tests for individual ops From 2c554266108f1b556dd49f7c3c06c08f2bbd3cbe Mon Sep 17 00:00:00 2001 From: Iurii Zdebskyi Date: Wed, 4 Nov 2020 11:59:26 -0800 Subject: [PATCH 5/6] Renamed a TensorListMetaData property. Cleaned up a test (#46662) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46662 Test Plan: Imported from OSS Reviewed By: ansley Differential Revision: D24453346 Pulled By: izdeby fbshipit-source-id: f88ac21708befa2e8f3edeffe5805b69a4634d12 --- aten/src/ATen/native/cuda/ForeachFunctors.cuh | 14 +- .../src/ATen/native/cuda/MultiTensorApply.cuh | 12 +- test/test_foreach.py | 128 +++++++----------- 3 files changed, 64 insertions(+), 90 deletions(-) diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh index e58188f846f6..ffe4a5865835 100644 --- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh +++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh @@ -163,7 +163,7 @@ struct BinaryOpScalarFunctor { opmath_t scalar) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + int n = tl.numel_for_tensor[tensor_loc]; T* args[depth]; bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); @@ -183,7 +183,7 @@ struct BinaryOpScalarListFunctor { Op op) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + int n = tl.numel_for_tensor[tensor_loc]; T* args[depth]; bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); @@ -205,7 +205,7 @@ struct BinaryOpListAlphaFunctor { opmath_t alpha) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + int n = tl.numel_for_tensor[tensor_loc]; T* args[depth]; bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); @@ -254,7 +254,7 @@ struct UnaryOpFunctor { Op op) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + int n = tl.numel_for_tensor[tensor_loc]; T* args[depth]; bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); @@ -301,7 +301,7 @@ struct PointwiseOpScalarFunctor { opmath_t scalar) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + int n = tl.numel_for_tensor[tensor_loc]; T* args[depth]; bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); @@ -321,7 +321,7 @@ struct PointwiseOpScalarListFunctor { Op op) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + int n = tl.numel_for_tensor[tensor_loc]; T* args[depth]; bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); @@ -342,7 +342,7 @@ struct PointwiseOpListFunctor { Op op) { int tensor_loc = 
tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + int n = tl.numel_for_tensor[tensor_loc]; T* args[depth]; bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cuh b/aten/src/ATen/native/cuda/MultiTensorApply.cuh index f0f8f97fabb1..67758aab2991 100644 --- a/aten/src/ATen/native/cuda/MultiTensorApply.cuh +++ b/aten/src/ATen/native/cuda/MultiTensorApply.cuh @@ -31,7 +31,7 @@ static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30}; template struct TensorListMetadata { void* addresses[n][depth_to_max_tensors[n-1]]; - int sizes[depth_to_max_tensors[n-1]]; + int numel_for_tensor[depth_to_max_tensors[n-1]]; unsigned char block_to_tensor[depth_to_max_blocks[n-1]]; int block_to_chunk[depth_to_max_blocks[n-1]]; }; @@ -39,7 +39,7 @@ template struct TensorListMetadata template struct TensorListScalarListMetadata { void* addresses[n][depth_to_max_tensors_scalarlist[n-1]]; - int sizes[depth_to_max_tensors_scalarlist[n-1]]; + int numel_for_tensor[depth_to_max_tensors_scalarlist[n-1]]; scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n-1]]; unsigned char block_to_tensor[depth_to_max_blocks[n-1]]; int block_to_chunk[depth_to_max_blocks[n-1]]; @@ -74,7 +74,7 @@ void multi_tensor_apply( tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t]; - tensorListMeta.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + tensorListMeta.numel_for_tensor[loc_tensor_info] = tensor_lists[0][t].numel(); for (int d = 0; d < depth; d++) { tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); } @@ -105,7 +105,7 @@ void multi_tensor_apply( loc_tensor_info = 0; } else { - tensorListMeta.sizes[0] = tensorListMeta.sizes[loc_tensor_info-1]; + tensorListMeta.numel_for_tensor[0] = tensorListMeta.numel_for_tensor[loc_tensor_info-1]; tensorListMeta.scalar_vals[0] = tensorListMeta.scalar_vals[loc_tensor_info-1]; for(int d = 0; d < depth; d++) { tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info-1]; @@ -131,7 +131,7 @@ void multi_tensor_apply( int loc_block_info = 0; int loc_tensor_info = 0; for(size_t t = 0; t < n_tensors; t++) { - tensorListMeta.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + tensorListMeta.numel_for_tensor[loc_tensor_info] = tensor_lists[0][t].numel(); for (int d = 0; d < depth; d++) { tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); } @@ -162,7 +162,7 @@ void multi_tensor_apply( loc_tensor_info = 0; } else { - tensorListMeta.sizes[0] = tensorListMeta.sizes[loc_tensor_info-1]; + tensorListMeta.numel_for_tensor[0] = tensorListMeta.numel_for_tensor[loc_tensor_info-1]; for(int d = 0; d < depth; d++) { tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info-1]; } diff --git a/test/test_foreach.py b/test/test_foreach.py index 28c493579201..a723efa04684 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -703,56 +703,57 @@ def test_bin_op_scalar_with_different_tensor_dtypes(self, device): # # Ops with list # - def test_add_list_error_cases(self, device): - tensors1 = [] - tensors2 = [] - - # Empty lists - with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"): - torch._foreach_add(tensors1, tensors2) - with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"): - torch._foreach_add_(tensors1, tensors2) - - # One empty list - tensors1.append(torch.tensor([1], 
device=device)) - with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."): - torch._foreach_add(tensors1, tensors2) - with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."): - torch._foreach_add_(tensors1, tensors2) - - # Lists have different amount of tensors - tensors2.append(torch.tensor([1], device=device)) - tensors2.append(torch.tensor([1], device=device)) - with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"): - torch._foreach_add(tensors1, tensors2) - with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"): - torch._foreach_add_(tensors1, tensors2) - - # Different dtypes - tensors1 = [torch.zeros(10, 10, device=device, dtype=torch.float) for _ in range(10)] - tensors2 = [torch.ones(10, 10, device=device, dtype=torch.int) for _ in range(10)] - - with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."): - torch._foreach_add(tensors1, tensors2) - with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."): - torch._foreach_add_(tensors1, tensors2) - - # different devices - if torch.cuda.is_available() and torch.cuda.device_count() > 1: - tensor1 = torch.zeros(10, 10, device="cuda:0") - tensor2 = torch.ones(10, 10, device="cuda:1") - with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): - torch._foreach_add([tensor1], [tensor2]) - with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): - torch._foreach_add_([tensor1], [tensor2]) - - # Coresponding tensors with different sizes - tensors1 = [torch.zeros(10, 10, device=device) for _ in range(10)] - tensors2 = [torch.ones(11, 11, device=device) for _ in range(10)] - with self.assertRaisesRegex(RuntimeError, "Corresponding tensors in lists must have the same size"): - torch._foreach_add(tensors1, tensors2) - with self.assertRaisesRegex(RuntimeError, r", got \[10, 10\] and \[11, 11\]"): - torch._foreach_add_(tensors1, tensors2) + def test_bin_op_list_error_cases(self, device): + for bin_op, bin_op_ in zip(self.foreach_bin_ops, self.foreach_bin_ops_): + tensors1 = [] + tensors2 = [] + + # Empty lists + with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"): + bin_op(tensors1, tensors2) + with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"): + bin_op_(tensors1, tensors2) + + # One empty list + tensors1.append(torch.tensor([1], device=device)) + with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."): + bin_op(tensors1, tensors2) + with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."): + bin_op_(tensors1, tensors2) + + # Lists have different amount of tensors + tensors2.append(torch.tensor([1], device=device)) + tensors2.append(torch.tensor([1], device=device)) + with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"): + bin_op(tensors1, tensors2) + with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"): + bin_op_(tensors1, tensors2) + + # Different dtypes + tensors1 = [torch.zeros(10, 10, device=device, dtype=torch.float) for _ in range(10)] + tensors2 = [torch.ones(10, 10, device=device, 
dtype=torch.int) for _ in range(10)] + + with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."): + bin_op(tensors1, tensors2) + with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."): + bin_op_(tensors1, tensors2) + + # different devices + if torch.cuda.is_available() and torch.cuda.device_count() > 1: + tensor1 = torch.zeros(10, 10, device="cuda:0") + tensor2 = torch.ones(10, 10, device="cuda:1") + with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): + bin_op([tensor1], [tensor2]) + with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): + bin_op_([tensor1], [tensor2]) + + # Corresponding tensors with different sizes + tensors1 = [torch.zeros(10, 10, device=device) for _ in range(10)] + tensors2 = [torch.ones(11, 11, device=device) for _ in range(10)] + with self.assertRaisesRegex(RuntimeError, "Corresponding tensors in lists must have the same size"): + bin_op(tensors1, tensors2) + with self.assertRaisesRegex(RuntimeError, r", got \[10, 10\] and \[11, 11\]"): + bin_op_(tensors1, tensors2) @dtypes(*torch.testing.get_all_dtypes()) def test_add_list(self, device, dtype): @@ -799,33 +800,6 @@ def test_div_list(self, device, dtype): self.assertEqual(res, tensors1) self.assertEqual(tensors1, res) - def test_bin_op_list_error_cases(self, device): - tensors1 = [] - tensors2 = [] - - for bin_op in self.foreach_bin_ops + self.foreach_bin_ops_: - # Empty lists - with self.assertRaises(RuntimeError): - bin_op(tensors1, tensors2) - - # One empty list - tensors1.append(torch.tensor([1], device=device)) - with self.assertRaises(RuntimeError): - bin_op(tensors1, tensors2) - - # Lists have different amount of tensors - tensors2.append(torch.tensor([1], device=device)) - tensors2.append(torch.tensor([1], device=device)) - with self.assertRaises(RuntimeError): - bin_op(tensors1, tensors2) - - # Different dtypes - tensors1 = [torch.zeros(2, 2, device=device, dtype=torch.float) for _ in range(2)] - tensors2 = [torch.ones(2, 2, device=device, dtype=torch.int) for _ in range(2)] - - with self.assertRaises(RuntimeError): - bin_op(tensors1, tensors2) - @dtypes(*torch.testing.get_all_dtypes()) def test_add_list_different_sizes(self, device, dtype): tensors1 = [torch.zeros(10 + n, 10 + n, device=device, dtype=dtype) for n in range(10)] From 6b3802a7113b646e04c5d23394af89f0ec4bcda9 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Wed, 4 Nov 2020 12:32:42 -0800 Subject: [PATCH 6/6] [Gradient Compression] Export sizes, along with the length and offset of each variable, to GradBucket for PowerSGD (#47203) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47203 1. Create a new field in BucketReplica to store the size info of each variable. 2. Export the sizes list, along with lengths and offsets, to GradBucket. These fields are needed for PowerSGD. Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202 ghstack-source-id: 115875194 Test Plan: Checked the field values from the log. 
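For illustration, here is a minimal sketch of how a Python communication hook could use the new per-variable metadata to recover each gradient from the flattened bucket. Only the getter names (`get_tensors`, `get_offsets`, `get_lengths`, `get_sizes_list`) come from this PR; the hook itself (`example_hook`, its `state` argument, and the compression step) is hypothetical:

```python
import torch

def example_hook(state, bucket):
    # Single-process single-device mode: the bucket holds one replica,
    # i.e. one flattened gradient tensor.
    flat = bucket.get_tensors()[0]
    # Rebuild per-variable views using the new offset/length/shape metadata.
    grads = [flat[o:o + l].view(s)
             for o, l, s in zip(bucket.get_offsets(),
                                bucket.get_lengths(),
                                bucket.get_sizes_list())]
    # A PowerSGD-style hook would compress `grads` here (e.g., into low-rank
    # factors) before communicating; this sketch just returns the bucket as-is.
    fut = torch.futures.Future()
    fut.set_result([flat])
    return fut
```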
Reviewed By: rohan-varma Differential Revision: D24644137 fbshipit-source-id: bcec0daf0d02cbf25389bfd9be90df1e6fd8fc56 --- torch/csrc/distributed/c10d/init.cpp | 25 ++++++++++++++-- torch/lib/c10d/comm.hpp | 43 +++++++++++++++++++++++----- torch/lib/c10d/reducer.cpp | 18 +++++++++++- torch/lib/c10d/reducer.hpp | 5 +++- 4 files changed, 80 insertions(+), 11 deletions(-) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index b04b5495c9ec..14c5096d36ef 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -168,7 +168,16 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) { py::arg("comm_hook_type")); shared_ptr_class_<::c10d::GradBucket>(module, "_GradBucket") - .def(py::init<const std::vector<at::Tensor>&>(), py::arg("tensors")) + .def( + py::init< + const std::vector<at::Tensor>&, + const std::vector<size_t>&, + const std::vector<size_t>&, + const std::vector<c10::IntArrayRef>&>(), + py::arg("tensors"), + py::arg("offsets"), + py::arg("lengths"), + py::arg("sizes_list")) .def( "get_tensors", &::c10d::GradBucket::getTensors, replicas only in the case of single process multiple device mode. In the single process single device mode, this list would consist of only a single tensor. - )"); + )") + .def( + "get_offsets", + &::c10d::GradBucket::getOffsets, + py::call_guard<py::gil_scoped_release>()) + .def( + "get_lengths", + &::c10d::GradBucket::getLengths, + py::call_guard<py::gil_scoped_release>()) + .def( + "get_sizes_list", + &::c10d::GradBucket::getSizesVec, + py::call_guard<py::gil_scoped_release>()); py::enum_<::c10d::BuiltinCommHookType>(module, "BuiltinCommHookType", R"( An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.)") diff --git a/torch/lib/c10d/comm.hpp b/torch/lib/c10d/comm.hpp index 58f40f81ccb3..0dd5815a5b8b 100644 --- a/torch/lib/c10d/comm.hpp +++ b/torch/lib/c10d/comm.hpp @@ -19,8 +19,16 @@ void broadcast_coalesced( // mappings as well. class GradBucket { public: - explicit GradBucket(const std::vector<at::Tensor>& tensors) - : tensors_(tensors) {} + explicit GradBucket( + const std::vector<at::Tensor>& tensors, + const std::vector<size_t>& offsets = {}, + const std::vector<size_t>& lengths = {}, + const std::vector<c10::IntArrayRef>& sizes_vec = {}) + : tensors_(tensors), + offsets_(offsets), + lengths_(lengths), + sizes_vec_(sizes_vec) {} + // Each tensor in the list that getTensors returns refers to the replica on // each device. There will be multiple replicas only in the case of single // process multiple device mode. In the single process single device mode, // @@ -33,14 +41,35 @@ class GradBucket { return tensors_; } + // Returns the start index of each variable in tensors_[0]. + const std::vector<size_t>& getOffsets() const { + return offsets_; + } + + // Returns the total (i.e., flattened) length of each variable in + // tensors_[0]. + const std::vector<size_t>& getLengths() const { + return lengths_; + } + + // Returns the multi-dimensional sizes/shape of each variable in tensors_[0]. + const std::vector<c10::IntArrayRef>& getSizesVec() const { + return sizes_vec_; + } + private: std::vector<at::Tensor> tensors_; + + // Per-variable info in tensors_[0]. + std::vector<size_t> offsets_; + std::vector<size_t> lengths_; + std::vector<c10::IntArrayRef> sizes_vec_; }; // Base class of both `PythonCommHook` and `CppCommHook`. // Requires implementing 1) `runHook` method that communicates gradients -// asynchronously, and 2) `parseHookResult` method that converts the hook result -// into a tensor vector. +// asynchronously, and 2) `parseHookResult` method that converts the hook +// result into a tensor vector. 
class TORCH_PYTHON_API CommHookInterface { public: virtual ~CommHookInterface() {} @@ -51,9 +80,9 @@ class TORCH_PYTHON_API CommHookInterface { virtual c10::intrusive_ptr runHook( GradBucket& bucket) = 0; - // Returns the resulting tensors once the communication hook result is ready. - // The resulting tensors will then be copied to the grads of individual - // parameters. + // Returns the resulting tensors once the communication hook result is + // ready. The resulting tensors will then be copied to the grads of + // individual parameters. virtual std::vector<at::Tensor> parseHookResult( const c10::IValue& result) = 0; }; diff --git a/torch/lib/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp index 8bce7a3102fc..c05ce685bb7d 100644 --- a/torch/lib/c10d/reducer.cpp +++ b/torch/lib/c10d/reducer.cpp @@ -712,7 +712,13 @@ void Reducer::mark_bucket_ready(size_t bucket_index) { if (comm_hook_ == nullptr) { bucket.work = process_group_->allreduce(tensors); } else { - GradBucket grad_bucket(tensors); + GradBucket grad_bucket( + tensors, + // Since currently we do not support single-process multiple-device + // mode, we can assume only one replica in the bucket. + bucket.replicas[0].offsets, + bucket.replicas[0].lengths, + bucket.replicas[0].sizes_vec); bucket.future_work = comm_hook_->runHook(grad_bucket); } } @@ -783,8 +789,17 @@ void Reducer::initialize_buckets( replica.variables = {variable}; } else { at::TensorOptions options; + // The start index of the variable in the flattened tensor. size_t offset = 0; + // Reserve enough space for the per-variable fields stored in bucket + // replica for efficiency. + const size_t num_variables = bucket_indices[bucket_index].size(); + replica.variables.reserve(num_variables); + replica.offsets.reserve(num_variables); + replica.lengths.reserve(num_variables); + replica.sizes_vec.reserve(num_variables); + // Iterate over bucket variables. for (const auto variable_index : bucket_indices[bucket_index]) { TORCH_CHECK( @@ -810,6 +825,7 @@ void Reducer::initialize_buckets( replica.variables.push_back(variable); replica.offsets.push_back(offset); replica.lengths.push_back(length); + replica.sizes_vec.push_back(variable.sizes()); offset += length; } diff --git a/torch/lib/c10d/reducer.hpp b/torch/lib/c10d/reducer.hpp index d667899e05d0..4874f0dd8703 100644 --- a/torch/lib/c10d/reducer.hpp +++ b/torch/lib/c10d/reducer.hpp @@ -225,10 +225,13 @@ class Reducer { // participating variables after reduction has completed. std::vector<at::Tensor> variables; - // Per-variable offset/length into the flat bucket contents tensor. + // Per-variable offset/length into the flat bucket contents tensor and grad bucket. std::vector<size_t> offsets; std::vector<size_t> lengths; + // Per-variable sizes in the grad bucket. + std::vector<c10::IntArrayRef> sizes_vec; + // Number of tensors to be added before this bucket is complete. // This is reset to `variables.size()` every iteration. size_t pending;