From e12008d110acae507573b4f5feb02f94125a4463 Mon Sep 17 00:00:00 2001 From: Zafar Date: Thu, 7 Jan 2021 13:53:51 -0800 Subject: [PATCH 01/54] [quant] Mapping for the `_LinearWithBias` (#49964) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49964 `torch.nn.modules.linear._LinearWithBias` is only used in the transformers, and is completely identical to the `torch.nn.Linear`. This PR creates a mapping so that this module would be treated the same as the Linear. Test Plan: ``` python test/test_quantization.py TestDynamicQuantizedModule TestStaticQuantizedModule ``` Differential Revision: D25731589 Reviewed By: jerryzh168 Pulled By: z-a-f fbshipit-source-id: 1b2697014e250e97d3010cdb542f9d130b71fbc3 --- test/quantization/test_quantized_module.py | 50 +++++++++++--------- torch/nn/quantized/dynamic/modules/linear.py | 6 ++- torch/nn/quantized/modules/linear.py | 13 +++-- torch/quantization/quantization_mappings.py | 3 ++ 4 files changed, 45 insertions(+), 27 deletions(-) diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py index 66210685b562..28867e8260b6 100644 --- a/test/quantization/test_quantized_module.py +++ b/test/quantization/test_quantized_module.py @@ -194,20 +194,24 @@ def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, # Test JIT self.checkScriptable(qlinear, [[X_q]], check_save_load=True) - # Test from_float. - float_linear = torch.nn.Linear(in_features, out_features).float() - float_linear.qconfig = torch.quantization.default_qconfig - torch.quantization.prepare(float_linear, inplace=True) - float_linear(X.float()) - # Sequential allows swapping using "convert". - quantized_float_linear = torch.nn.Sequential(float_linear) - quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True) + # Make sure `from_float` works for all linear variants + modules_under_test = [torch.nn.Linear, torch.nn.modules.linear._LinearWithBias] - # Smoke test to make sure the module actually runs - quantized_float_linear(X_q) + for mut in modules_under_test: + # Test from_float. + float_linear = mut(in_features, out_features).float() + float_linear.qconfig = torch.quantization.default_qconfig + torch.quantization.prepare(float_linear, inplace=True) + float_linear(X.float()) + # Sequential allows swapping using "convert". + quantized_float_linear = torch.nn.Sequential(float_linear) + quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True) - # Smoke test extra_repr - self.assertTrue('QuantizedLinear' in str(quantized_float_linear)) + # Smoke test to make sure the module actually runs + quantized_float_linear(X_q) + + # Smoke test extra_repr + self.assertTrue('QuantizedLinear' in str(quantized_float_linear)) def test_quant_dequant_api(self): r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float) @@ -928,16 +932,18 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_d # Test JIT self.checkScriptable(qlinear, [[X]], check_save_load=True) - # Test from_float - float_linear = torch.nn.Linear(in_features, out_features).float() - if use_default_observer: - float_linear.qconfig = torch.quantization.default_dynamic_qconfig - prepare_dynamic(float_linear) - float_linear(X.float()) - quantized_float_linear = nnqd.Linear.from_float(float_linear) - - # Smoke test to make sure the module actually runs - quantized_float_linear(X) + modules_under_test = [torch.nn.Linear, torch.nn.modules.linear._LinearWithBias] + for mut in modules_under_test: + # Test from_float + float_linear = mut(in_features, out_features).float() + if use_default_observer: + float_linear.qconfig = torch.quantization.default_dynamic_qconfig + prepare_dynamic(float_linear) + float_linear(X.float()) + quantized_float_linear = nnqd.Linear.from_float(float_linear) + + # Smoke test to make sure the module actually runs + quantized_float_linear(X) # Smoke test extra_repr self.assertTrue('QuantizedLinear' in str(quantized_float_linear)) diff --git a/torch/nn/quantized/dynamic/modules/linear.py b/torch/nn/quantized/dynamic/modules/linear.py index f220aed075c1..527ee76fdc76 100644 --- a/torch/nn/quantized/dynamic/modules/linear.py +++ b/torch/nn/quantized/dynamic/modules/linear.py @@ -1,5 +1,4 @@ import torch -from ....modules.linear import Linear as NNLinear import torch.nn.quantized as nnq from torch.nn.quantized.modules.utils import _quantize_weight @@ -80,7 +79,10 @@ def from_float(cls, mod): mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ - assert type(mod) == NNLinear, 'nn.quantized.dynamic.Linear.from_float only works for nn.Linear' + float_modules = [torch.nn.Linear, torch.nn.modules.linear._LinearWithBias] + assert type(mod) in float_modules, \ + 'nn.quantized.dynamic.Linear.from_float only works for one of' + \ + str([float_mod.__name__ for float_mod in float_modules]) assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' if mod.qconfig is not None and mod.qconfig.weight is not None: weight_observer = mod.qconfig.weight() diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py index c5a26c777776..d7f86ccb6216 100644 --- a/torch/nn/quantized/modules/linear.py +++ b/torch/nn/quantized/modules/linear.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable import torch import torch.nn as nn @@ -124,7 +125,7 @@ class Linear(torch.nn.Module): torch.Size([128, 30]) """ _version = 3 - _FLOAT_MODULE = nn.Linear + _FLOAT_MODULE = (nn.Linear, nn.modules.linear._LinearWithBias) def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8): @@ -252,8 +253,14 @@ def from_float(cls, mod): weight_post_process = mod.weight_fake_quant activation_post_process = mod.activation_post_process else: - assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \ - cls._FLOAT_MODULE.__name__ + # This function does not participate in JIT, so it is OK to ignore + # the type mismatch in assignment. Also, mypy has an issue with + # iterables not being implemented, so we are ignoring those too. + if not isinstance(cls._FLOAT_MODULE, Iterable): + cls._FLOAT_MODULE = [cls._FLOAT_MODULE] # type: ignore + supported_modules = ', '.join([float_mod.__name__ for float_mod in cls._FLOAT_MODULE]) # type: ignore + error_msg = 'nnq.{}.from_float only works for {}'.format(cls.__name__, supported_modules) + assert type(mod) in cls._FLOAT_MODULE, error_msg.format() # type: ignore assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' activation_post_process = mod.activation_post_process if type(mod) == nni.LinearReLU: diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 6c8fc9defc00..d9800552d8a7 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -39,6 +39,7 @@ nn.InstanceNorm3d: nnq.InstanceNorm3d, nn.LayerNorm: nnq.LayerNorm, nn.LeakyReLU: nnq.LeakyReLU, + nn.modules.linear._LinearWithBias: nnq.Linear, nn.Linear: nnq.Linear, nn.ReLU6: nnq.ReLU6, # Wrapper Modules: @@ -65,6 +66,7 @@ DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = { nn.Conv2d: nnqat.Conv2d, nn.Linear: nnqat.Linear, + nn.modules.linear._LinearWithBias: nnqat.Linear, # Intrinsic modules: nni.ConvBn1d: nniqat.ConvBn1d, nni.ConvBn2d: nniqat.ConvBn2d, @@ -78,6 +80,7 @@ DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = { nn.GRUCell: nnqd.GRUCell, nn.Linear: nnqd.Linear, + nn.modules.linear._LinearWithBias: nnqd.Linear, nn.LSTM: nnqd.LSTM, nn.GRU: nnqd.GRU, nn.LSTMCell: nnqd.LSTMCell, From 0c3bae6a89303284d83e5eac0f1ab947a7213f16 Mon Sep 17 00:00:00 2001 From: Felix Abecassis Date: Thu, 7 Jan 2021 14:11:01 -0800 Subject: [PATCH 02/54] docker: add environment variable PYTORCH_VERSION (#50154) Summary: The aim is being able to inspect a container image and determine immediately which version of pytorch it contains. Closes https://github.com/pytorch/pytorch/issues/48324 Signed-off-by: Felix Abecassis seemethere PTAL. As you requested in https://github.com/pytorch/pytorch/issues/48324#issuecomment-754237156, I'm submitting the patch. But I could only do limited testing as I'm not sure these Makefile/Dockerfile are used for pushing the Docker Hub images (since the Makefile tags the image with a `v` prefix for the version, as in: `pytorch:v1.7.1`, but Docker Hub images don't have this prefix). Also on the master branch we currently have the following: ``` $ git describe --tags v1.4.0a0-11171-g68a6e46379 ``` So it's a little off, but it behaves as expected on the `release/1.7` branch. Pull Request resolved: https://github.com/pytorch/pytorch/pull/50154 Reviewed By: walterddr Differential Revision: D25828491 Pulled By: seemethere fbshipit-source-id: 500ec96cb5f5da1321610002d5e3678f4b0b94b5 --- Dockerfile | 2 ++ docker.Makefile | 10 ++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index 70da5dbaf424..cbaa85597ad9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -59,6 +59,7 @@ RUN /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -y python=${PYTHON_VERS RUN /opt/conda/bin/pip install torchelastic FROM ${BASE_IMAGE} as official +ARG PYTORCH_VERSION LABEL com.nvidia.volumes.needed="nvidia_driver" RUN --mount=type=cache,id=apt-final,target=/var/cache/apt \ apt-get update && apt-get install -y --no-install-recommends \ @@ -71,6 +72,7 @@ ENV PATH /opt/conda/bin:$PATH ENV NVIDIA_VISIBLE_DEVICES all ENV NVIDIA_DRIVER_CAPABILITIES compute,utility ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 +ENV PYTORCH_VERSION ${PYTORCH_VERSION} WORKDIR /workspace FROM official as dev diff --git a/docker.Makefile b/docker.Makefile index 3af77ab9c7d1..6b843fa9c1b3 100644 --- a/docker.Makefile +++ b/docker.Makefile @@ -17,12 +17,14 @@ BASE_DEVEL = nvidia/cuda:$(CUDA_VERSION)-cudnn$(CUDNN_VERSION)-de INSTALL_CHANNEL = pytorch PYTHON_VERSION = 3.7 +PYTORCH_VERSION = $(shell git describe --tags) # Can be either official / dev BUILD_TYPE = dev BUILD_PROGRESS = auto BUILD_ARGS = --build-arg BASE_IMAGE=$(BASE_IMAGE) \ --build-arg PYTHON_VERSION=$(PYTHON_VERSION) \ --build-arg CUDA_VERSION=$(CUDA_VERSION) \ + --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) \ --build-arg INSTALL_CHANNEL=$(INSTALL_CHANNEL) EXTRA_DOCKER_BUILD_FLAGS ?= DOCKER_BUILD = DOCKER_BUILDKIT=1 \ @@ -39,26 +41,26 @@ all: devel-image .PHONY: devel-image devel-image: BASE_IMAGE := $(BASE_DEVEL) -devel-image: DOCKER_TAG := $(shell git describe --tags)-devel +devel-image: DOCKER_TAG := $(PYTORCH_VERSION)-devel devel-image: $(DOCKER_BUILD) .PHONY: devel-image devel-push: BASE_IMAGE := $(BASE_DEVEL) -devel-push: DOCKER_TAG := $(shell git describe --tags)-devel +devel-push: DOCKER_TAG := $(PYTORCH_VERSION)-devel devel-push: $(DOCKER_PUSH) .PHONY: runtime-image runtime-image: BASE_IMAGE := $(BASE_RUNTIME) -runtime-image: DOCKER_TAG := $(shell git describe --tags)-runtime +runtime-image: DOCKER_TAG := $(PYTORCH_VERSION)-runtime runtime-image: $(DOCKER_BUILD) docker tag $(DOCKER_FULL_NAME):$(DOCKER_TAG) $(DOCKER_FULL_NAME):latest .PHONY: runtime-image runtime-push: BASE_IMAGE := $(BASE_RUNTIME) -runtime-push: DOCKER_TAG := $(shell git describe --tags)-runtime +runtime-push: DOCKER_TAG := $(PYTORCH_VERSION)-runtime runtime-push: $(DOCKER_PUSH) From 7ce8f7e4884c4ab9d01941f4fcb8326471c30e8b Mon Sep 17 00:00:00 2001 From: Zafar Date: Thu, 7 Jan 2021 14:55:26 -0800 Subject: [PATCH 03/54] [quant] Backend string for the quantized types (#49965) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49965 Without this checking the type of the quantized tensor using `type` would throw an error. After this PR running the `type(qx)`, where `qx` is a quantized tensor would show something like `torch.quantized.QUInt8`. Test Plan: Not needed -- this is just a string description for the quantized tensors Differential Revision: D25731594 Reviewed By: ezyang Pulled By: z-a-f fbshipit-source-id: 942fdf89a1c50895249989c7203f2e7cc00df4c6 --- torch/csrc/utils/tensor_types.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index e6b851a3a74c..5d60f8b07c64 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -21,6 +21,7 @@ static const char* backend_to_string(const at::Backend& backend) { case at::Backend::CUDA: return "torch.cuda"; case at::Backend::SparseCPU: return "torch.sparse"; case at::Backend::SparseCUDA: return "torch.cuda.sparse"; + case at::Backend::QuantizedCPU: return "torch.quantized"; default: AT_ERROR("Unimplemented backend ", backend); } } From 23cadb5d7b4ceecad3e06eb33343c775a86540f2 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 7 Jan 2021 15:16:00 -0800 Subject: [PATCH 04/54] [PyTorch] Specialize `list_element_from` for `IValue` to avoid extra move/copy (#50124) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50124 This patch makes `list_element_from` avoid extra `IValue` move/copies for `List` by just forwarding the reference argument. We take advantage of this in `listConstruct` by using `push_back` (which hits the `ListElementFrom` path) instead of ` ghstack-source-id: 119478962 Test Plan: Inspected generated assembly for vararg_functions.cpp in optimized build. Rather than a call to `vector::emplace_back` and an extra move, `vector::push_back` gets inlined. Reviewed By: ezyang Differential Revision: D25794277 fbshipit-source-id: 2354d8c08e0a0d6be2db3f0d0d6c90c3a455d8bd --- aten/src/ATen/core/List_inl.h | 42 +++++++++++++-------- torch/csrc/jit/runtime/vararg_functions.cpp | 2 +- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/core/List_inl.h b/aten/src/ATen/core/List_inl.h index ab3ddae55770..c517dac20542 100644 --- a/aten/src/ATen/core/List_inl.h +++ b/aten/src/ATen/core/List_inl.h @@ -101,13 +101,23 @@ namespace detail { return std::move(element).template to(); } template - IValue list_element_from(const T& element) { - return element; - } - template - IValue list_element_from(T&& element) { - return std::move(element); - } + struct ListElementFrom { + static IValue from(const T& element) { + return element; + } + static IValue from(T&& element) { + return std::move(element); + } + }; + template<> + struct ListElementFrom { + static const IValue& from(const IValue& element) { + return element; + } + static IValue&& from(IValue&& element) { + return std::move(element); + } + }; } namespace impl { @@ -119,13 +129,13 @@ ListElementReference::operator T() const { template ListElementReference& ListElementReference::operator=(T&& new_value) && { - *iterator_ = c10::detail::list_element_from(std::move(new_value)); + *iterator_ = c10::detail::ListElementFrom::from(std::move(new_value)); return *this; } template ListElementReference& ListElementReference::operator=(const T& new_value) && { - *iterator_ = c10::detail::list_element_from(std::move(new_value)); + *iterator_ = c10::detail::ListElementFrom::from(std::move(new_value)); return *this; } @@ -154,12 +164,12 @@ inline bool operator==(const T& lhs, const ListElementReference& rh template void List::set(size_type pos, const value_type& value) const { - impl_->list.at(pos) = c10::detail::list_element_from(value); + impl_->list.at(pos) = c10::detail::ListElementFrom::from(value); } template void List::set(size_type pos, value_type&& value) const { - impl_->list.at(pos) = c10::detail::list_element_from(std::move(value)); + impl_->list.at(pos) = c10::detail::ListElementFrom::from(std::move(value)); } template @@ -178,7 +188,7 @@ typename List::value_type List::extract(size_type pos) const { auto& elem = impl_->list.at(pos); auto result = c10::detail::list_element_to(std::move(elem)); // Reset the list element to a T() instead of None to keep it correctly typed - elem = c10::detail::list_element_from(T{}); + elem = c10::detail::ListElementFrom::from(T{}); return result; } @@ -214,12 +224,12 @@ void List::clear() const { template typename List::iterator List::insert(iterator pos, const T& value) const { - return iterator { impl_->list.insert(pos.iterator_, c10::detail::list_element_from(value)) }; + return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(value)) }; } template typename List::iterator List::insert(iterator pos, T&& value) const { - return iterator { impl_->list.insert(pos.iterator_, c10::detail::list_element_from(std::move(value))) }; + return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(std::move(value))) }; } template @@ -231,12 +241,12 @@ typename List::iterator List::emplace(iterator pos, Args&&... value) const template void List::push_back(const T& value) const { - impl_->list.push_back(c10::detail::list_element_from(value)); + impl_->list.push_back(c10::detail::ListElementFrom::from(value)); } template void List::push_back(T&& value) const { - impl_->list.push_back(c10::detail::list_element_from(std::move(value))); + impl_->list.push_back(c10::detail::ListElementFrom::from(std::move(value))); } template diff --git a/torch/csrc/jit/runtime/vararg_functions.cpp b/torch/csrc/jit/runtime/vararg_functions.cpp index 220a5e67f723..930ea5eb8f98 100644 --- a/torch/csrc/jit/runtime/vararg_functions.cpp +++ b/torch/csrc/jit/runtime/vararg_functions.cpp @@ -213,7 +213,7 @@ void listConstruct(Stack& stack, const at::ListType& type, size_t num_inputs) { c10::List vals(type.getElementType()); vals.reserve(num_inputs); for (size_t i = stack.size() - num_inputs; i < stack.size(); ++i) { - vals.emplace_back(std::move(stack[i])); + vals.push_back(std::move(stack[i])); } drop(stack, num_inputs); return vals; From 3b56e9d0efafb8ddebda6d9fe511323df593fbbe Mon Sep 17 00:00:00 2001 From: Karthik Prasad Date: Thu, 7 Jan 2021 15:19:57 -0800 Subject: [PATCH 05/54] [pytorch] prune based on custom importance scores (#48378) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48378 This commit adds support for accepting custom importance scores to use for pruning mask computation, rather than only using the parameter. This is useful if one wants to prune based on scores from different technique such as activations, gradients, weighted scoring of parameters, etc. An alternative to the above approach would be pass the custom mask to the already available interface. However, the ability to accept importance scores is easier it can leverage the mask computation logic that has already been baked in. In addition, the commit also makes some minor lint fixes. Test Plan: * Unit tests * Circle CI Differential Revision: D24997355 fbshipit-source-id: 30797897977b57d3e3bc197987da20e88febb1fa --- test/test_nn.py | 136 +++++++++++++++++++++++++++-- torch/nn/utils/prune.py | 184 ++++++++++++++++++++++++++-------------- 2 files changed, 251 insertions(+), 69 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 386ba369dca6..83a98e360857 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2405,7 +2405,6 @@ def test_pruning_container_compute_mask(self): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType(expected_mask, computed_mask) - def test_l1_unstructured_pruning(self): r"""Test that l1 unstructured pruning actually removes the lowest entries by l1 norm (by hand). It also checks that applying l1 @@ -2430,6 +2429,35 @@ def test_l1_unstructured_pruning(self): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType(expected_weight, m.weight) + def test_l1_unstructured_pruning_with_importance_scores(self): + r"""Test that l1 unstructured pruning actually removes the lowest + entries of importance scores and not the parameter by l1 norm (by hand). + It also checks that applying l1 unstructured pruning more than once + respects the previous mask. + """ + m = nn.Linear(4, 2) + # modify its weight matrix by hand + m.weight = torch.nn.Parameter( + torch.tensor( + [[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32 + ) + ) + importance_scores = torch.tensor( + [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 + ) + + prune.l1_unstructured(m, 'weight', amount=2, importance_scores=importance_scores) + expected_weight = torch.tensor([[1, 2, 0, 4], [-4, 0, -2, -1]]) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(expected_weight, m.weight) + + # check that pruning again removes two entries of m.weight that are colocated with + # the next two smallest absolute values of importance scores. + prune.l1_unstructured(m, 'weight', amount=2, importance_scores=importance_scores) + expected_weight = torch.tensor([[1, 0, 0, 4], [-4, 0, 0, -1]]) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(expected_weight, m.weight) + def test_unstructured_pruning_same_magnitude(self): r"""Since it may happen that the tensor to prune has entries with the same exact magnitude, it is important to check that pruning happens @@ -2447,7 +2475,6 @@ def test_unstructured_pruning_same_magnitude(self): self.assertEqual(nparams_toprune, nparams_pruned) def test_random_structured_pruning_amount(self): - AMOUNT = 0.6 AXIS = 2 p = prune.RandomStructured(amount=AMOUNT, dim=AXIS) @@ -2464,7 +2491,6 @@ def test_random_structured_pruning_amount(self): ) assert per_column_sums == [0, 20] - def test_ln_structured_pruning(self): r"""Check Ln structured pruning by hand. """ @@ -2488,6 +2514,33 @@ def test_ln_structured_pruning(self): prune.ln_structured(m, 'weight', amount=1, n=1, dim=-1) self.assertEqual(expected_mask_axis3, m.weight_mask) + def test_ln_structured_pruning_importance_scores(self): + r"""Check Ln structured pruning by hand. + """ + m = nn.Conv2d(3, 1, 2) + m.weight.data = torch.Tensor( + [[[[1., 2.], [1., 2.5]], + [[0.5, 1.], [0.1, 0.1]], + [[-3., -5.], [0.1, -1.]]]] + ) + importance_scores = torch.Tensor( + [[[[10., 1.], [10., 1.]], + [[30., 3.], [30., 3.]], + [[-20., -2.], [-20., -2.]]]] + ) + # expected effect of pruning 1 of the 3 channels by L2-norm + expected_mask_axis1 = torch.ones_like(m.weight) + expected_mask_axis1[:, 0] = 0. + + prune.ln_structured(m, 'weight', amount=1, n=2, dim=1, importance_scores=importance_scores) + self.assertEqual(expected_mask_axis1, m.weight_mask) + + # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm + expected_mask_axis3 = expected_mask_axis1 + expected_mask_axis3[:, :, :, 1] = 0. + + prune.ln_structured(m, 'weight', amount=1, n=1, dim=-1, importance_scores=importance_scores) + self.assertEqual(expected_mask_axis3, m.weight_mask) def test_remove_pruning(self): r"""`prune.remove` removes the hook and the reparametrization @@ -2567,6 +2620,49 @@ def test_global_pruning(self): expected_nweight = torch.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype) self.assertEqual(expected_nweight, n.weight) + def test_global_pruning_importance_scores(self): + r"""Test that global l1 unstructured pruning over 2 parameters removes + the `amount=4` smallest global weights across the 2 parameters. + """ + m = nn.Linear(4, 2) + n = nn.Linear(3, 1) + # modify the weight matrices by hand + m.weight = torch.nn.Parameter( + torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to( + dtype=torch.float32) + ) + m_importance_scores = torch.tensor( + [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 + ) + n.weight = torch.nn.Parameter( + torch.tensor([[0, 0.1, -2]]).to( + dtype=torch.float32) + ) + n_importance_scores = torch.tensor([[0, 10., -0.2]]).to(dtype=torch.float32) + + params_to_prune = ( + (m, 'weight'), + (n, 'weight'), + ) + importance_scores = { + (m, 'weight'): m_importance_scores, + (n, 'weight'): n_importance_scores, + } + + # prune the 4 smallest weights globally by L1 magnitude + prune.global_unstructured( + params_to_prune, + pruning_method=prune.L1Unstructured, + amount=4, + importance_scores=importance_scores, + ) + + expected_m_weight = torch.tensor([[1, 2, 0, 4], [-4, 0, -2, -1]]) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(expected_m_weight, m.weight) + + expected_n_weight = torch.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype) + self.assertEqual(expected_n_weight, n.weight) def test_custom_from_mask_pruning(self): r"""Test that the CustomFromMask is capable of receiving @@ -2656,7 +2752,6 @@ def test_pruning_serialization_model(self): self.assertEqual(pruned_weight, new_model[0].weight) - def test_pruning_serialization_state_dict(self): # create a model model = torch.nn.Sequential( @@ -2707,7 +2802,6 @@ def test_pruning_serialization_state_dict(self): self.assertEqual(pruned_weight, new_model[0].weight) - def test_prune(self): # create a new pruning method p = prune.L1Unstructured(amount=2) @@ -2721,6 +2815,37 @@ def test_prune(self): pruned_tensor = p.prune(t, default_mask) self.assertEqual(t * expected_mask, pruned_tensor) + def test_prune_importance_scores(self): + # create a new pruning method + p = prune.L1Unstructured(amount=2) + # create tensor to be pruned + t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) + importance_scores = torch.tensor( + [[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]] + ).to(dtype=torch.float32) + # create prior mask by hand + default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) + # since we are pruning the two lowest magnitude units, the outcome of + # the calculation should be this: + expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]]) + pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores) + self.assertEqual(t * expected_mask, pruned_tensor) + + def test_prune_importance_scores_mimic_default(self): + # create a new pruning method + p = prune.L1Unstructured(amount=2) + # create tensor to be pruned + t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) + # create prior mask by hand + default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) + # since we are pruning the two lowest magnitude units, the outcome of + # the calculation should be this: + expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) + pruned_tensor_without_importance_scores = p.prune(t, default_mask) + pruned_tensor_with_importance_scores = p.prune(t, default_mask, importance_scores=t) + self.assertEqual(pruned_tensor_without_importance_scores, pruned_tensor_with_importance_scores) + self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores) + def test_rnn_pruning(self): l = torch.nn.LSTM(32, 32) # This Module has 4 parameters called: @@ -2752,6 +2877,7 @@ def test_rnn_pruning(self): assert dict(l.named_parameters())['weight_ih_l0'] is not None assert 'weight_ih_l0_orig' not in dict(l.named_parameters()) + def test_rnn_weight_norm(self): def check_weight_norm(l, name, num_params): # This Module has 4 or 5 parameters called: diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index 851a551da0d8..c21940689a77 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -1,13 +1,14 @@ r""" Pruning methods """ -from abc import abstractmethod import numbers -import torch -from abc import ABC +from abc import ABC, abstractmethod from collections.abc import Iterable from typing import Tuple +import torch + + class BasePruningMethod(ABC): r"""Abstract base class for creation of new pruning techniques. @@ -40,7 +41,8 @@ def compute_mask(self, t, default_mask): method recipe. Args: - t (torch.Tensor): tensor representing the parameter to prune + t (torch.Tensor): tensor representing the importance scores of the + parameter to prune. default_mask (torch.Tensor): Base mask from previous pruning iterations, that need to be respected after the new mask is applied. Same dims as ``t``. @@ -64,9 +66,7 @@ def apply_mask(self, module): """ # to carry out the multiplication, the mask needs to have been computed, # so the pruning method must know what tensor it's operating on - assert ( - self._tensor_name is not None - ), "Module {} has to be pruned".format( + assert self._tensor_name is not None, "Module {} has to be pruned".format( module ) # this gets set in apply() mask = getattr(module, self._tensor_name + "_mask") @@ -75,7 +75,7 @@ def apply_mask(self, module): return pruned_tensor @classmethod - def apply(cls, module, name, *args, **kwargs): + def apply(cls, module, name, *args, importance_scores=None, **kwargs): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. @@ -86,6 +86,11 @@ def apply(cls, module, name, *args, **kwargs): will act. args: arguments passed on to a subclass of :class:`BasePruningMethod` + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the + corresponding elements in the parameter being pruned. + If unspecified or None, the parameter will be used in its place. kwargs: keyword arguments passed on to a subclass of a :class:`BasePruningMethod` """ @@ -101,10 +106,7 @@ def _get_composite_method(cls, module, name, *args, **kwargs): for k, hook in module._forward_pre_hooks.items(): # if it exists, take existing thing, remove hook, then # go through normal thing - if ( - isinstance(hook, BasePruningMethod) - and hook._tensor_name == name - ): + if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: old_method = hook hooks_to_remove.append(k) found += 1 @@ -150,8 +152,18 @@ def _get_composite_method(cls, module, name, *args, **kwargs): # Pruning is to be applied to the module's tensor named `name`, # starting from the state it is found in prior to this iteration of - # pruning + # pruning. The pruning mask is calculated based on importances scores. + orig = getattr(module, name) + if importance_scores is not None: + assert ( + importance_scores.shape == orig.shape + ), "importance_scores should have the same shape as parameter \ + {} of {}".format( + name, module + ) + else: + importance_scores = orig # If this is the first time pruning is applied, take care of moving # the original tensor to a new parameter called name + '_orig' and @@ -166,13 +178,17 @@ def _get_composite_method(cls, module, name, *args, **kwargs): # has been done before in a previous pruning iteration, so we're good # to go else: - default_mask = getattr(module, name + "_mask").detach().clone(memory_format=torch.contiguous_format) + default_mask = ( + getattr(module, name + "_mask") + .detach() + .clone(memory_format=torch.contiguous_format) + ) # Use try/except because if anything goes wrong with the mask # computation etc., you'd want to roll back. try: # get the final mask, computed according to the specific method - mask = method.compute_mask(orig, default_mask=default_mask) + mask = method.compute_mask(importance_scores, default_mask=default_mask) # reparametrize by saving mask to `module[name + '_mask']`... module.register_buffer(name + "_mask", mask) # ... and the new pruned tensor to `module[name]` @@ -190,13 +206,18 @@ def _get_composite_method(cls, module, name, *args, **kwargs): return method - def prune(self, t, default_mask=None): + def prune(self, t, default_mask=None, importance_scores=None): r"""Computes and returns a pruned version of input tensor ``t`` according to the pruning rule specified in :meth:`compute_mask`. Args: t (torch.Tensor): tensor to prune (of same dimensions as ``default_mask``). + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as ``t``) used to compute mask for pruning ``t``. + The values in this tensor indicate the importance of the + corresponding elements in the ``t`` that is being pruned. + If unspecified or None, the tensor ``t`` will be used in its place. default_mask (torch.Tensor, optional): mask from previous pruning iteration, if any. To be considered when determining what portion of the tensor that pruning should act on. If None, @@ -205,9 +226,14 @@ def prune(self, t, default_mask=None): Returns: pruned version of tensor ``t``. """ - if default_mask is None: - default_mask = torch.ones_like(t) - return t * self.compute_mask(t, default_mask=default_mask) + if importance_scores is not None: + assert ( + importance_scores.shape == t.shape + ), "importance_scores should have the same shape as tensor t" + else: + importance_scores = t + default_mask = default_mask if default_mask is not None else torch.ones_like(t) + return t * self.compute_mask(importance_scores, default_mask=default_mask) def remove(self, module): r"""Removes the pruning reparameterization from a module. The pruned @@ -249,7 +275,7 @@ class PruningContainer(BasePruningMethod): """ def __init__(self, *args): - self._pruning_methods: Tuple['BasePruningMethod', ...] = tuple() + self._pruning_methods: Tuple["BasePruningMethod", ...] = tuple() if not isinstance(args, Iterable): # only 1 item self._tensor_name = args._tensor_name self.add_pruning_method(args) @@ -319,6 +345,7 @@ def compute_mask(self, t, default_mask): pruning ``method`` (of same dimensions as ``default_mask`` and ``t``). """ + def _combine_masks(method, t, mask): r""" Args: @@ -360,13 +387,12 @@ def _combine_masks(method, t, mask): # if dim is still negative after subtracting it from n_dims if dim < 0: raise IndexError( - 'Index is out of bounds for tensor with dimensions {}' - .format(n_dims) + "Index is out of bounds for tensor with dimensions {}".format( + n_dims + ) ) # find channels along dim = dim that aren't already tots 0ed out - keep_channel = ( - mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 - ) + keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 # create slice to identify what to prune slc = [slice(None)] * n_dims slc[dim] = keep_channel @@ -470,9 +496,7 @@ def apply(cls, module, name, amount): fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. """ - return super(RandomUnstructured, cls).apply( - module, name, amount=amount - ) + return super(RandomUnstructured, cls).apply(module, name, amount=amount) class L1Unstructured(BasePruningMethod): @@ -509,16 +533,14 @@ def compute_mask(self, t, default_mask): if nparams_toprune != 0: # k=0 not supported by torch.kthvalue # largest=True --> top k; largest=False --> bottom k # Prune the smallest k - topk = torch.topk( - torch.abs(t).view(-1), k=nparams_toprune, largest=False - ) + topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False) # topk will have .indices and .values mask.view(-1)[topk.indices] = 0 return mask @classmethod - def apply(cls, module, name, amount): + def apply(cls, module, name, amount, importance_scores=None): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. @@ -531,8 +553,15 @@ def apply(cls, module, name, amount): If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. """ - return super(L1Unstructured, cls).apply(module, name, amount=amount) + return super(L1Unstructured, cls).apply( + module, name, amount=amount, importance_scores=importance_scores + ) class RandomStructured(BasePruningMethod): @@ -634,9 +663,7 @@ def apply(cls, module, name, amount, dim=-1): dim (int, optional): index of the dim along which we define channels to prune. Default: -1. """ - return super(RandomStructured, cls).apply( - module, name, amount=amount, dim=dim - ) + return super(RandomStructured, cls).apply(module, name, amount=amount, dim=dim) class LnStructured(BasePruningMethod): @@ -705,11 +732,7 @@ def compute_mask(self, t, default_mask): norm = _compute_norm(t, self.n, self.dim) # largest=True --> top k; largest=False --> bottom k # Keep the largest k channels along dim=self.dim - topk = torch.topk( - norm, - k=nparams_tokeep, - largest=True, - ) + topk = torch.topk(norm, k=nparams_tokeep, largest=True) # topk will have .indices and .values # Compute binary mask by initializing it to all 0s and then filling in @@ -737,7 +760,7 @@ def make_mask(t, dim, indices): return mask @classmethod - def apply(cls, module, name, amount, n, dim): + def apply(cls, module, name, amount, n, dim, importance_scores=None): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. @@ -754,9 +777,19 @@ def apply(cls, module, name, amount, n, dim): entries for argument ``p`` in :func:`torch.norm`. dim (int): index of the dim along which we define channels to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. """ return super(LnStructured, cls).apply( - module, name, amount=amount, n=n, dim=dim + module, + name, + amount=amount, + n=n, + dim=dim, + importance_scores=importance_scores, ) @@ -783,9 +816,7 @@ def apply(cls, module, name, mask): name (str): parameter name within ``module`` on which pruning will act. """ - return super(CustomFromMask, cls).apply( - module, name, mask - ) + return super(CustomFromMask, cls).apply(module, name, mask=mask) def identity(module, name): @@ -852,7 +883,7 @@ def random_unstructured(module, name, amount): return module -def l1_unstructured(module, name, amount): +def l1_unstructured(module, name, amount, importance_scores=None): r"""Prunes tensor corresponding to parameter called ``name`` in ``module`` by removing the specified `amount` of (currently unpruned) units with the lowest L1-norm. @@ -872,6 +903,11 @@ def l1_unstructured(module, name, amount): If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. Returns: module (nn.Module): modified (i.e. pruned) version of the input module @@ -881,7 +917,9 @@ def l1_unstructured(module, name, amount): >>> m.state_dict().keys() odict_keys(['bias', 'weight_orig', 'weight_mask']) """ - L1Unstructured.apply(module, name, amount) + L1Unstructured.apply( + module, name, amount=amount, importance_scores=importance_scores + ) return module @@ -922,7 +960,7 @@ def random_structured(module, name, amount, dim): return module -def ln_structured(module, name, amount, n, dim): +def ln_structured(module, name, amount, n, dim, importance_scores=None): r"""Prunes tensor corresponding to parameter called ``name`` in ``module`` by removing the specified ``amount`` of (currently unpruned) channels along the specified ``dim`` with the lowest L``n``-norm. @@ -945,6 +983,11 @@ def ln_structured(module, name, amount, n, dim): n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid entries for argument ``p`` in :func:`torch.norm`. dim (int): index of the dim along which we define channels to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. Returns: module (nn.Module): modified (i.e. pruned) version of the input module @@ -954,11 +997,13 @@ def ln_structured(module, name, amount, n, dim): nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf') ) """ - LnStructured.apply(module, name, amount, n, dim) + LnStructured.apply( + module, name, amount, n, dim, importance_scores=importance_scores + ) return module -def global_unstructured(parameters, pruning_method, **kwargs): +def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs): r""" Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``. @@ -977,6 +1022,12 @@ def global_unstructured(parameters, pruning_method, **kwargs): pruning_method (function): a valid pruning function from this module, or a custom one implemented by the user that satisfies the implementation guidelines and has ``PRUNING_TYPE='unstructured'``. + importance_scores (dict): a dictionary mapping (module, name) tuples to + the corresponding parameter's importance scores tensor. The tensor + should be the same shape as the parameter, and is used for computing + mask for pruning. + If unspecified or None, the parameter will be used in place of its + importance scores. kwargs: other keyword arguments such as: amount (int or float): quantity of parameters to prune across the specified parameters. @@ -1011,17 +1062,25 @@ def global_unstructured(parameters, pruning_method, **kwargs): """ # ensure parameters is a list or generator of tuples - assert isinstance(parameters, Iterable) + if not isinstance(parameters, Iterable): + raise TypeError("global_unstructured(): parameters is not an Iterable") - # flatten parameter values to consider them all at once in global pruning - t = torch.nn.utils.parameters_to_vector([getattr(*p) for p in parameters]) + importance_scores = importance_scores if importance_scores is not None else {} + if not isinstance(importance_scores, dict): + raise TypeError("global_unstructured(): importance_scores must be of type dict") + + # flatten importance scores to consider them all at once in global pruning + relevant_importance_scores = torch.nn.utils.parameters_to_vector( + [ + importance_scores.get((module, name), getattr(module, name)) + for (module, name) in parameters + ] + ) # similarly, flatten the masks (if they exist), or use a flattened vector # of 1s of the same dimensions as t default_mask = torch.nn.utils.parameters_to_vector( [ - getattr( - module, name + "_mask", torch.ones_like(getattr(module, name)) - ) + getattr(module, name + "_mask", torch.ones_like(getattr(module, name))) for (module, name) in parameters ] ) @@ -1044,7 +1103,7 @@ def global_unstructured(parameters, pruning_method, **kwargs): # use the `compute_mask` method from `PruningContainer` to combine the # mask computed by the new method with the pre-existing mask - final_mask = container.compute_mask(t, default_mask) + final_mask = container.compute_mask(relevant_importance_scores, default_mask) # Pointer for slicing the mask to match the shape of each parameter pointer = 0 @@ -1057,7 +1116,7 @@ def global_unstructured(parameters, pruning_method, **kwargs): param_mask = final_mask[pointer : pointer + num_param].view_as(param) # Assign the correct pre-computed mask to each parameter and add it # to the forward_pre_hooks like any other pruning method - custom_from_mask(module, name, param_mask) + custom_from_mask(module, name, mask=param_mask) # Increment the pointer to continue slicing the final_mask pointer += num_param @@ -1172,8 +1231,7 @@ def _validate_pruning_amount_init(amount): """ if not isinstance(amount, numbers.Real): raise TypeError( - "Invalid type for amount: {}. Must be int or float." - "".format(amount) + "Invalid type for amount: {}. Must be int or float." "".format(amount) ) if (isinstance(amount, numbers.Integral) and amount < 0) or ( @@ -1260,9 +1318,7 @@ def _validate_pruning_dim(t, dim): dim (int): index of the dim along which we define channels to prune """ if dim >= t.dim(): - raise IndexError( - "Invalid index {} for tensor of size {}".format(dim, t.shape) - ) + raise IndexError("Invalid index {} for tensor of size {}".format(dim, t.shape)) def _compute_norm(t, n, dim): From 2bceee785fcfed60b67d252238d6be79f758968a Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Thu, 7 Jan 2021 15:32:01 -0800 Subject: [PATCH 06/54] Clean up simple type annotations in nn/functional.py (#50106) Summary: Also reformats code to pass linters. Pull Request resolved: https://github.com/pytorch/pytorch/pull/50106 Test Plan: Sandcastle tests Reviewed By: xush6528 Differential Revision: D25787566 fbshipit-source-id: 39c86b4021e279f92f8ccf30252a6cfae1063c3c --- torch/nn/functional.py | 2110 +++++++++++++++++++++++++--------------- 1 file changed, 1319 insertions(+), 791 deletions(-) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index fc71e4a0c449..162c22ea236a 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1,22 +1,25 @@ r"""Functional interface""" -import warnings +from typing import Callable, List, Optional, Tuple import math +import warnings import torch +from torch import _VF from torch._C import _infer_size, _add_docstr +from torch._torch_docs import reproducibility_notes, tf32_notes + +from .._jit_internal import boolean_dispatch, _overload +from ..overrides import has_torch_function, handle_torch_function from . import _reduction as _Reduction +from . import grad # noqa: F401 from .modules import utils from .modules.utils import _single, _pair, _triple, _list_with_default -from . import grad # noqa: F401 -from torch import _VF -from .._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple -from ..overrides import has_torch_function, handle_torch_function -from torch._torch_docs import reproducibility_notes, tf32_notes - Tensor = torch.Tensor -conv1d = _add_docstr(torch.conv1d, r""" +conv1d = _add_docstr( + torch.conv1d, + r""" conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Applies a 1D convolution over an input signal composed of several input @@ -28,7 +31,10 @@ Note: {cudnn_reproducibility_note} -""".format(**reproducibility_notes, **tf32_notes) + r""" +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` @@ -48,9 +54,12 @@ >>> filters = torch.randn(33, 16, 3) >>> inputs = torch.randn(20, 16, 50) >>> F.conv1d(inputs, filters) -""") +""", +) -conv2d = _add_docstr(torch.conv2d, r""" +conv2d = _add_docstr( + torch.conv2d, + r""" conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Applies a 2D convolution over an input image composed of several input @@ -62,7 +71,10 @@ Note: {cudnn_reproducibility_note} -""".format(**reproducibility_notes, **tf32_notes) + r""" +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` @@ -82,9 +94,12 @@ >>> filters = torch.randn(8,4,3,3) >>> inputs = torch.randn(1,4,5,5) >>> F.conv2d(inputs, filters, padding=1) -""") # noqa: E501 +""", +) # noqa: E501 -conv3d = _add_docstr(torch.conv3d, r""" +conv3d = _add_docstr( + torch.conv3d, + r""" conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Applies a 3D convolution over an input image composed of several input @@ -96,7 +111,10 @@ Note: {cudnn_reproducibility_note} -""".format(**reproducibility_notes, **tf32_notes) + r""" +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` @@ -116,9 +134,12 @@ >>> filters = torch.randn(33, 16, 3, 3, 3) >>> inputs = torch.randn(20, 16, 50, 10, 20) >>> F.conv3d(inputs, filters) -""") # noqa: E501 +""", +) # noqa: E501 -conv_transpose1d = _add_docstr(torch.conv_transpose1d, r""" +conv_transpose1d = _add_docstr( + torch.conv_transpose1d, + r""" conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor Applies a 1D transposed convolution operator over an input signal @@ -130,7 +151,10 @@ Note: {cudnn_reproducibility_note} -""".format(**reproducibility_notes, **tf32_notes) + r""" +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` @@ -153,9 +177,12 @@ >>> inputs = torch.randn(20, 16, 50) >>> weights = torch.randn(16, 33, 5) >>> F.conv_transpose1d(inputs, weights) -""") +""", +) -conv_transpose2d = _add_docstr(torch.conv_transpose2d, r""" +conv_transpose2d = _add_docstr( + torch.conv_transpose2d, + r""" conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor Applies a 2D transposed convolution operator over an input image @@ -167,7 +194,10 @@ Note: {cudnn_reproducibility_note} -""".format(**reproducibility_notes, **tf32_notes) + r""" +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` @@ -192,9 +222,12 @@ >>> inputs = torch.randn(1, 4, 5, 5) >>> weights = torch.randn(4, 8, 3, 3) >>> F.conv_transpose2d(inputs, weights, padding=1) -""") # noqa: E501 +""", +) # noqa: E501 -conv_transpose3d = _add_docstr(torch.conv_transpose3d, r""" +conv_transpose3d = _add_docstr( + torch.conv_transpose3d, + r""" conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor Applies a 3D transposed convolution operator over an input image @@ -206,7 +239,10 @@ Note: {cudnn_reproducibility_note} -""".format(**reproducibility_notes, **tf32_notes) + r""" +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` @@ -230,9 +266,12 @@ >>> inputs = torch.randn(20, 16, 50, 10, 20) >>> weights = torch.randn(16, 33, 3, 3, 3) >>> F.conv_transpose3d(inputs, weights) -""") # noqa: E501 +""", +) # noqa: E501 -conv_tbc = _add_docstr(torch.conv_tbc, r""" +conv_tbc = _add_docstr( + torch.conv_tbc, + r""" Applies a 1-dimensional sequence convolution over an input sequence. Input and output dimensions are (Time, Batch, Channels) - hence TBC. @@ -241,11 +280,14 @@ weight: filter of shape (:math:`\text{kernel width} \times \text{in\_channels} \times \text{out\_channels}`) bias: bias of shape (:math:`\text{out\_channels}`) pad: number of timesteps to pad. Default: 0 -""") +""", +) # Pooling -avg_pool1d = _add_docstr(torch.avg_pool1d, r""" +avg_pool1d = _add_docstr( + torch.avg_pool1d, + r""" avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor Applies a 1D average pooling over an input signal composed of several @@ -273,10 +315,13 @@ >>> F.avg_pool1d(input, kernel_size=3, stride=2) tensor([[[ 2., 4., 6.]]]) -""") +""", +) -avg_pool2d = _add_docstr(torch._C._nn.avg_pool2d, r""" +avg_pool2d = _add_docstr( + torch._C._nn.avg_pool2d, + r""" avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size @@ -299,9 +344,12 @@ averaging calculation. Default: ``True`` divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. Default: None -""") +""", +) -avg_pool3d = _add_docstr(torch._C._nn.avg_pool3d, r""" +avg_pool3d = _add_docstr( + torch._C._nn.avg_pool3d, + r""" avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor Applies 3D average-pooling operation in :math:`kT \times kH \times kW` regions by step @@ -324,12 +372,13 @@ averaging calculation divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. Default: None -""") +""", +) -def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None, - output_ratio=None, return_indices=False, - _random_samples=None): +def fractional_max_pool2d_with_indices( + input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None +): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa r"""Applies 2D fractional max pooling over an input signal composed of several input planes. @@ -363,50 +412,62 @@ def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - fractional_max_pool2d_with_indices, (input,), input, kernel_size, - output_size=output_size, output_ratio=output_ratio, - return_indices=return_indices, _random_samples=_random_samples) + fractional_max_pool2d_with_indices, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) if output_size is None and output_ratio is None: - raise ValueError("fractional_max_pool2d requires specifying either " - "an output_size or an output_ratio") + raise ValueError("fractional_max_pool2d requires specifying either " "an output_size or an output_ratio") if output_size is None: assert output_ratio is not None _output_ratio = _pair(output_ratio) - output_size = [int(input.size(2) * _output_ratio[0]), - int(input.size(3) * _output_ratio[1])] + output_size = [int(input.size(2) * _output_ratio[0]), int(input.size(3) * _output_ratio[1])] if _random_samples is None: _random_samples = torch.rand(input.size(0), input.size(1), 2, dtype=input.dtype, device=input.device) return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples) -def _fractional_max_pool2d(input, kernel_size, output_size=None, - output_ratio=None, return_indices=False, - _random_samples=None): +def _fractional_max_pool2d( + input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None +): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tensor # noqa if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - fractional_max_pool2d, (input,), input, kernel_size, - output_size=output_size, output_ratio=output_ratio, - return_indices=return_indices, _random_samples=_random_samples) - return fractional_max_pool2d_with_indices(input, kernel_size, output_size, - output_ratio, return_indices, - _random_samples)[0] + fractional_max_pool2d, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + return fractional_max_pool2d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + fractional_max_pool2d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=4, default=False, if_true=fractional_max_pool2d_with_indices, if_false=_fractional_max_pool2d, module_name=__name__, - func_name='fractional_max_pool2d') + func_name="fractional_max_pool2d", +) -def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None, - output_ratio=None, return_indices=False, - _random_samples=None): +def fractional_max_pool3d_with_indices( + input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None +): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa r"""Applies 3D fractional max pooling over an input signal composed of several input planes. @@ -441,50 +502,66 @@ def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - fractional_max_pool3d_with_indices, (input,), input, kernel_size, - output_size=output_size, output_ratio=output_ratio, - return_indices=return_indices, _random_samples=_random_samples) + fractional_max_pool3d_with_indices, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) if output_size is None and output_ratio is None: - raise ValueError("fractional_max_pool3d requires specifying either " - "an output_size or an output_ratio") + raise ValueError("fractional_max_pool3d requires specifying either " "an output_size or an output_ratio") if output_size is None: assert output_ratio is not None _output_ratio = _triple(output_ratio) - output_size = [int(input.size(2) * _output_ratio[0]), - int(input.size(3) * _output_ratio[1]), - int(input.size(4) * _output_ratio[2])] + output_size = [ + int(input.size(2) * _output_ratio[0]), + int(input.size(3) * _output_ratio[1]), + int(input.size(4) * _output_ratio[2]), + ] if _random_samples is None: _random_samples = torch.rand(input.size(0), input.size(1), 3, dtype=input.dtype, device=input.device) return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples) -def _fractional_max_pool3d(input, kernel_size, output_size=None, - output_ratio=None, return_indices=False, - _random_samples=None): +def _fractional_max_pool3d( + input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None +): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tensor # noqa if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - fractional_max_pool3d, (input,), input, kernel_size, - output_size=output_size, output_ratio=output_ratio, - return_indices=return_indices, _random_samples=_random_samples) - return fractional_max_pool3d_with_indices(input, kernel_size, output_size, - output_ratio, return_indices, - _random_samples)[0] + fractional_max_pool3d, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + return fractional_max_pool3d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + fractional_max_pool3d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=4, default=False, if_true=fractional_max_pool3d_with_indices, if_false=_fractional_max_pool3d, module_name=__name__, - func_name='fractional_max_pool3d') + func_name="fractional_max_pool3d", +) -def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0, - dilation=1, ceil_mode=False, return_indices=False): +def max_pool1d_with_indices( + input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False +): # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa r"""Applies a 1D max pooling over an input signal composed of several input planes. @@ -494,41 +571,55 @@ def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_pool1d_with_indices, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + max_pool1d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool1d_with_indices( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool1d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) -def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): +def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_pool1d, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + max_pool1d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool1d( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode) + max_pool1d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=6, default=False, if_true=max_pool1d_with_indices, if_false=_max_pool1d, module_name=__name__, - func_name='max_pool1d') + func_name="max_pool1d", +) -def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): +def max_pool2d_with_indices( + input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False +): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa r"""Applies a 2D max pooling over an input signal composed of several input planes. @@ -538,40 +629,55 @@ def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_pool2d_with_indices, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + max_pool2d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) -def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): +def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_pool2d, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + max_pool2d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool2d( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) + max_pool2d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=6, default=False, if_true=max_pool2d_with_indices, if_false=_max_pool2d, module_name=__name__, - func_name='max_pool2d') + func_name="max_pool2d", +) -def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0, - dilation=1, ceil_mode=False, return_indices=False): +def max_pool3d_with_indices( + input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False +): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa r"""Applies a 3D max pooling over an input signal composed of several input planes. @@ -581,69 +687,86 @@ def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_pool3d_with_indices, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + max_pool3d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch._C._nn.max_pool3d_with_indices( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch._C._nn.max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) -def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): +def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_pool3d, (input,), input, kernel_size, stride=stride, padding=padding, - dilation=dilation, ceil_mode=ceil_mode, return_indices=return_indices) + max_pool3d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool3d( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode) + max_pool3d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=6, default=False, if_true=max_pool3d_with_indices, if_false=_max_pool3d, module_name=__name__, - func_name='max_pool3d') + func_name="max_pool3d", +) -def _unpool_output_size(input, kernel_size, stride, padding, output_size): - # type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int] +def _unpool_output_size( + input: Tensor, kernel_size: List[int], stride: List[int], padding: List[int], output_size: Optional[List[int]] +) -> List[int]: input_size = input.size() default_size = torch.jit.annotate(List[int], []) for d in range(len(kernel_size)): - default_size.append((input_size[d + 2] - 1) * stride[d] + - kernel_size[d] - 2 * padding[d]) + default_size.append((input_size[d + 2] - 1) * stride[d] + kernel_size[d] - 2 * padding[d]) if output_size is None: ret = default_size else: if len(output_size) == len(kernel_size) + 2: output_size = output_size[2:] if len(output_size) != len(kernel_size): - raise ValueError("output_size should be a sequence containing " - "{} or {} elements, but it has a length of '{}'" - .format(len(kernel_size), len(kernel_size) + 2, - len(output_size))) + raise ValueError( + "output_size should be a sequence containing " + "{} or {} elements, but it has a length of '{}'".format( + len(kernel_size), len(kernel_size) + 2, len(output_size) + ) + ) for d in range(len(kernel_size)): min_size = default_size[d] - stride[d] max_size = default_size[d] + stride[d] if not (min_size < output_size[d] < max_size): raise ValueError( - 'invalid output_size "{}" (dim {} must be between {} and {})' - .format(output_size, d, min_size, max_size)) + 'invalid output_size "{}" (dim {} must be between {} and {})'.format( + output_size, d, min_size, max_size + ) + ) ret = output_size return ret -def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, - output_size=None): +def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, output_size=None): # type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor # noqa r"""Computes a partial inverse of :class:`MaxPool1d`. @@ -652,26 +775,30 @@ def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_unpool1d, (input,), input, indices, kernel_size, - stride=stride, padding=padding, output_size=output_size) + max_unpool1d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _single(kernel_size) if stride is not None: _stride = _single(stride) else: _stride = kernel_size padding = _single(padding) - output_size = _unpool_output_size(input, kernel_size, _stride, padding, - output_size) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) if isinstance(output_size, list): output_size = output_size + [1] else: output_size = output_size + (1,) - return torch._C._nn.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), - output_size).squeeze(3) + return torch._C._nn.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), output_size).squeeze(3) -def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, - output_size=None): +def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, output_size=None): # type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor # noqa r"""Computes a partial inverse of :class:`MaxPool2d`. @@ -680,21 +807,26 @@ def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_unpool2d, (input,), input, indices, kernel_size, - stride=stride, padding=padding, output_size=output_size) + max_unpool2d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _pair(kernel_size) if stride is not None: _stride = _pair(stride) else: _stride = kernel_size padding = _pair(padding) - output_size = _unpool_output_size(input, kernel_size, _stride, padding, - output_size) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) return torch._C._nn.max_unpool2d(input, indices, output_size) -def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, - output_size=None): +def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, output_size=None): # type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor # noqa r"""Computes a partial inverse of :class:`MaxPool3d`. @@ -703,18 +835,23 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_unpool3d, (input,), input, indices, kernel_size, - stride=stride, padding=padding, output_size=output_size) + max_unpool3d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _triple(kernel_size) if stride is not None: _stride = _triple(stride) else: _stride = kernel_size padding = _triple(padding) - output_size = _unpool_output_size(input, kernel_size, _stride, padding, - output_size) - return torch._C._nn.max_unpool3d( - input, indices, output_size, _stride, padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding) def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False): @@ -728,15 +865,15 @@ def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, - ceil_mode=ceil_mode) + lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) kw, kh = utils._pair(kernel_size) if stride is not None: out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) else: out = avg_pool2d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) - return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type) + return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1.0 / norm_type) def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): @@ -750,14 +887,14 @@ def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, - ceil_mode=ceil_mode) + lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) if stride is not None: out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) else: out = avg_pool1d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) - return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type) + return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1.0 / norm_type) def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): @@ -774,8 +911,8 @@ def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - adaptive_max_pool1d_with_indices, (input,), input, output_size, - return_indices=return_indices) + adaptive_max_pool1d_with_indices, (input,), input, output_size, return_indices=return_indices + ) return torch.adaptive_max_pool1d(input, output_size) @@ -784,18 +921,20 @@ def _adaptive_max_pool1d(input, output_size, return_indices=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - adaptive_max_pool1d, (input,), input, output_size, - return_indices=return_indices) + adaptive_max_pool1d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool1d_with_indices(input, output_size)[0] + adaptive_max_pool1d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=2, default=False, if_true=adaptive_max_pool1d_with_indices, if_false=_adaptive_max_pool1d, module_name=__name__, - func_name='adaptive_max_pool1d') + func_name="adaptive_max_pool1d", +) def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): @@ -813,8 +952,8 @@ def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - adaptive_max_pool2d_with_indices, (input,), input, output_size, - return_indices=return_indices) + adaptive_max_pool2d_with_indices, (input,), input, output_size, return_indices=return_indices + ) output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool2d(input, output_size) @@ -824,18 +963,20 @@ def _adaptive_max_pool2d(input, output_size, return_indices=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - adaptive_max_pool2d, (input,), input, output_size, - return_indices=return_indices) + adaptive_max_pool2d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool2d_with_indices(input, output_size)[0] + adaptive_max_pool2d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=2, default=False, if_true=adaptive_max_pool2d_with_indices, if_false=_adaptive_max_pool2d, module_name=__name__, - func_name='adaptive_max_pool2d') + func_name="adaptive_max_pool2d", +) def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): @@ -853,8 +994,8 @@ def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - adaptive_max_pool3d_with_indices, (input,), input, output_size, - return_indices=return_indices) + adaptive_max_pool3d_with_indices, (input,), input, output_size, return_indices=return_indices + ) output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool3d(input, output_size) @@ -864,21 +1005,25 @@ def _adaptive_max_pool3d(input, output_size, return_indices=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - adaptive_max_pool3d, (input,), input, output_size, - return_indices=return_indices) + adaptive_max_pool3d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool3d_with_indices(input, output_size)[0] + adaptive_max_pool3d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=2, default=False, if_true=adaptive_max_pool3d_with_indices, if_false=_adaptive_max_pool3d, module_name=__name__, - func_name='adaptive_max_pool3d') + func_name="adaptive_max_pool3d", +) -adaptive_avg_pool1d = _add_docstr(torch.adaptive_avg_pool1d, r""" +adaptive_avg_pool1d = _add_docstr( + torch.adaptive_avg_pool1d, + r""" adaptive_avg_pool1d(input, output_size) -> Tensor Applies a 1D adaptive average pooling over an input signal composed of @@ -888,7 +1033,8 @@ def _adaptive_max_pool3d(input, output_size, return_indices=False): Args: output_size: the target output size (single integer) -""") +""", +) def adaptive_avg_pool2d(input, output_size): @@ -905,8 +1051,7 @@ def adaptive_avg_pool2d(input, output_size): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - adaptive_avg_pool2d, (input,), input, output_size) + return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool2d(input, _output_size) @@ -925,15 +1070,13 @@ def adaptive_avg_pool3d(input, output_size): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - adaptive_avg_pool3d, (input,), input, output_size) + return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool3d(input, _output_size) # Activation functions -def dropout(input, p=0.5, training=True, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor +def dropout(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: r""" During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p` using samples from a Bernoulli @@ -948,36 +1091,26 @@ def dropout(input, p=0.5, training=True, inplace=False): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - dropout, (input,), input, p=p, training=training, inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.dropout_(input, p, training) - if inplace - else _VF.dropout(input, p, training)) + return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training) -def alpha_dropout(input, p=0.5, training=False, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor +def alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace: bool = False) -> Tensor: r"""Applies alpha dropout to the input. See :class:`~torch.nn.AlphaDropout` for details. """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.alpha_dropout_(input, p, training) - if inplace - else _VF.alpha_dropout(input, p, training)) + return handle_torch_function(alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.alpha_dropout_(input, p, training) if inplace else _VF.alpha_dropout(input, p, training) -def dropout2d(input, p=0.5, training=True, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor +def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: r""" Randomly zero out entire channels (a channel is a 2D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the @@ -994,18 +1127,13 @@ def dropout2d(input, p=0.5, training=True, inplace=False): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - dropout2d, (input,), input, p=p, training=training, inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.feature_dropout_(input, p, training) - if inplace - else _VF.feature_dropout(input, p, training)) + return handle_torch_function(dropout2d, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) -def dropout3d(input, p=0.5, training=True, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor +def dropout3d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: r""" Randomly zero out entire channels (a channel is a 3D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the @@ -1024,18 +1152,13 @@ def dropout3d(input, p=0.5, training=True, inplace=False): # stack traces are not confusing. if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - dropout3d, (input,), input, p=p, training=training, inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.feature_dropout_(input, p, training) - if inplace - else _VF.feature_dropout(input, p, training)) + return handle_torch_function(dropout3d, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) -def feature_alpha_dropout(input, p=0.5, training=False, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor +def feature_alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace: bool = False) -> Tensor: r""" Randomly masks out entire channels (a channel is a feature map, e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input @@ -1058,42 +1181,41 @@ def feature_alpha_dropout(input, p=0.5, training=False, inplace=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - feature_alpha_dropout, (input,), input, p=p, training=training, - inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.feature_alpha_dropout_(input, p, training) - if inplace - else _VF.feature_alpha_dropout(input, p, training)) - - -def _threshold(input, threshold, value, inplace=False): - # type: (Tensor, float, float, bool) -> Tensor + feature_alpha_dropout, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.feature_alpha_dropout_(input, p, training) if inplace else _VF.feature_alpha_dropout(input, p, training) + + +def _threshold(input: Tensor, threshold: float, value: float, inplace: bool = False) -> Tensor: r"""Thresholds each element of the input Tensor. See :class:`~torch.nn.Threshold` for more details. """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - _threshold, (input,), input, threshold, value, inplace=inplace) + return handle_torch_function(_threshold, (input,), input, threshold, value, inplace=inplace) if inplace: result = _VF.threshold_(input, threshold, value) else: result = _VF.threshold(input, threshold, value) return result + # We define this function as _threshold because it takes an argument # named threshold, which clobbers the recursive reference to the # function needed for __torch_function__ support threshold = _threshold -threshold_ = _add_docstr(_VF.threshold_, r""" +threshold_ = _add_docstr( + _VF.threshold_, + r""" threshold_(input, threshold, value) -> Tensor In-place version of :func:`~threshold`. -""") +""", +) def relu(input: Tensor, inplace: bool = False) -> Tensor: @@ -1112,11 +1234,14 @@ def relu(input: Tensor, inplace: bool = False) -> Tensor: return result -relu_ = _add_docstr(torch.relu_, r""" +relu_ = _add_docstr( + torch.relu_, + r""" relu_(input) -> Tensor In-place version of :func:`~relu`. -""") +""", +) def glu(input: Tensor, dim: int = -1) -> Tensor: @@ -1145,7 +1270,7 @@ def glu(input: Tensor, dim: int = -1) -> Tensor: return torch._C._nn.glu(input, dim) -def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor: +def hardtanh(input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False) -> Tensor: r""" hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor @@ -1154,9 +1279,7 @@ def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - hardtanh, (input,), input, min_val=min_val, max_val=max_val, - inplace=inplace) + return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace) if inplace: result = torch._C._nn.hardtanh_(input, min_val, max_val) else: @@ -1164,15 +1287,17 @@ def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: return result -hardtanh_ = _add_docstr(torch._C._nn.hardtanh_, r""" +hardtanh_ = _add_docstr( + torch._C._nn.hardtanh_, + r""" hardtanh_(input, min_val=-1., max_val=1.) -> Tensor In-place version of :func:`~hardtanh`. -""") +""", +) -def relu6(input, inplace=False): - # type: (Tensor, bool) -> Tensor +def relu6(input: Tensor, inplace: bool = False) -> Tensor: r"""relu6(input, inplace=False) -> Tensor Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`. @@ -1182,11 +1307,10 @@ def relu6(input, inplace=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function(relu6, (input,), input, inplace=inplace) - return hardtanh(input, 0., 6., inplace) + return hardtanh(input, 0.0, 6.0, inplace) -def elu(input, alpha=1., inplace=False): - # type: (Tensor, float, bool) -> Tensor +def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: r"""Applies element-wise, :math:`\text{ELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))`. @@ -1194,8 +1318,7 @@ def elu(input, alpha=1., inplace=False): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(elu, (input,), input, alpha=alpha, - inplace=inplace) + return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace) if inplace: result = torch._C._nn.elu_(input, alpha) else: @@ -1203,15 +1326,17 @@ def elu(input, alpha=1., inplace=False): return result -elu_ = _add_docstr(torch._C._nn.elu_, r""" +elu_ = _add_docstr( + torch._C._nn.elu_, + r""" elu_(input, alpha=1.) -> Tensor In-place version of :func:`~elu`. -""") +""", +) -def selu(input, inplace=False): - # type: (Tensor, bool) -> Tensor +def selu(input: Tensor, inplace: bool = False) -> Tensor: r"""selu(input, inplace=False) -> Tensor Applies element-wise, @@ -1231,15 +1356,17 @@ def selu(input, inplace=False): return result -selu_ = _add_docstr(torch.selu_, r""" +selu_ = _add_docstr( + torch.selu_, + r""" selu_(input) -> Tensor In-place version of :func:`~selu`. -""") +""", +) -def celu(input, alpha=1., inplace=False): - # type: (Tensor, float, bool) -> Tensor +def celu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: r"""celu(input, alpha=1., inplace=False) -> Tensor Applies element-wise, @@ -1249,19 +1376,22 @@ def celu(input, alpha=1., inplace=False): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(celu, (input,), input, alpha=alpha, - inplace=inplace) + return handle_torch_function(celu, (input,), input, alpha=alpha, inplace=inplace) if inplace: result = torch.celu_(input, alpha) else: result = torch.celu(input, alpha) return result -celu_ = _add_docstr(torch.celu_, r""" + +celu_ = _add_docstr( + torch.celu_, + r""" celu_(input, alpha=1.) -> Tensor In-place version of :func:`~celu`. -""") +""", +) def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False) -> Tensor: @@ -1275,9 +1405,7 @@ def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = Fals """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - leaky_relu, (input,), input, negative_slope=negative_slope, - inplace=inplace) + return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace) if inplace: result = torch._C._nn.leaky_relu_(input, negative_slope) else: @@ -1285,15 +1413,17 @@ def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = Fals return result -leaky_relu_ = _add_docstr(torch._C._nn.leaky_relu_, r""" +leaky_relu_ = _add_docstr( + torch._C._nn.leaky_relu_, + r""" leaky_relu_(input, negative_slope=0.01) -> Tensor In-place version of :func:`~leaky_relu`. -""") +""", +) -def prelu(input, weight): - # type: (Tensor, Tensor) -> Tensor +def prelu(input: Tensor, weight: Tensor) -> Tensor: r"""prelu(input, weight) -> Tensor Applies element-wise the function @@ -1308,8 +1438,9 @@ def prelu(input, weight): return torch.prelu(input, weight) -def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False): - # type: (Tensor, float, float, bool, bool) -> Tensor +def rrelu( + input: Tensor, lower: float = 1.0 / 8, upper: float = 1.0 / 3, training: bool = False, inplace: bool = False +) -> Tensor: r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor Randomized leaky ReLU. @@ -1319,8 +1450,8 @@ def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - rrelu, (input,), input, lower=lower, upper=upper, - training=training, inplace=inplace) + rrelu, (input,), input, lower=lower, upper=upper, training=training, inplace=inplace + ) if inplace: result = torch.rrelu_(input, lower, upper, training) else: @@ -1328,19 +1459,26 @@ def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False): return result -rrelu_ = _add_docstr(torch.rrelu_, r""" +rrelu_ = _add_docstr( + torch.rrelu_, + r""" rrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor In-place version of :func:`~rrelu`. -""") +""", +) -logsigmoid = _add_docstr(torch._C._nn.log_sigmoid, r""" +logsigmoid = _add_docstr( + torch._C._nn.log_sigmoid, + r""" logsigmoid(input) -> Tensor Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)` See :class:`~torch.nn.LogSigmoid` for more details. -""") +""", +) + def gelu(input): r"""gelu(input) -> Tensor @@ -1358,8 +1496,7 @@ def gelu(input): return torch._C._nn.gelu(input) -def hardshrink(input, lambd=0.5): - # type: (Tensor, float) -> Tensor +def hardshrink(input: Tensor, lambd: float = 0.5) -> Tensor: r""" hardshrink(input, lambd=0.5) -> Tensor @@ -1399,7 +1536,9 @@ def softsign(input): return input / (input.abs() + 1) -softplus = _add_docstr(torch._C._nn.softplus, r""" +softplus = _add_docstr( + torch._C._nn.softplus, + r""" softplus(input, beta=1, threshold=20) -> Tensor Applies element-wise, the function :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`. @@ -1408,13 +1547,16 @@ def softsign(input): when :math:`input \times \beta > threshold`. See :class:`~torch.nn.Softplus` for more details. -""") +""", +) -def _get_softmax_dim(name, ndim, stacklevel): - # type: (str, int, int) -> int - warnings.warn("Implicit dimension choice for {} has been deprecated. " - "Change the call to include dim=X as an argument.".format(name), stacklevel=stacklevel) +def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int: + warnings.warn( + "Implicit dimension choice for {} has been deprecated. " + "Change the call to include dim=X as an argument.".format(name), + stacklevel=stacklevel, + ) if ndim == 0 or ndim == 1 or ndim == 3: ret = 0 else: @@ -1422,8 +1564,7 @@ def _get_softmax_dim(name, ndim, stacklevel): return ret -def softmin(input, dim=None, _stacklevel=3, dtype=None): - # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor +def softmin(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor: r"""Applies a softmin function. Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula. @@ -1440,10 +1581,9 @@ def softmin(input, dim=None, _stacklevel=3, dtype=None): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + return handle_torch_function(softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: - dim = _get_softmax_dim('softmin', input.dim(), _stacklevel) + dim = _get_softmax_dim("softmin", input.dim(), _stacklevel) if dtype is None: ret = (-input).softmax(dim) else: @@ -1451,8 +1591,7 @@ def softmin(input, dim=None, _stacklevel=3, dtype=None): return ret -def softmax(input, dim=None, _stacklevel=3, dtype=None): - # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor +def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor: r"""Applies a softmax function. Softmax is defined as: @@ -1479,10 +1618,9 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: - dim = _get_softmax_dim('softmax', input.dim(), _stacklevel) + dim = _get_softmax_dim("softmax", input.dim(), _stacklevel) if dtype is None: ret = input.softmax(dim) else: @@ -1490,8 +1628,7 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None): return ret -def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): - # type: (Tensor, float, bool, float, int) -> Tensor +def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor: r""" Samples from the Gumbel-Softmax distribution (`Link 1`_ `Link 2`_) and optionally discretizes. @@ -1533,12 +1670,13 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): """ if not torch.jit.is_scripting(): if type(logits) is not Tensor and has_torch_function((logits,)): - return handle_torch_function( - gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) + return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) if eps != 1e-10: warnings.warn("`eps` parameter is deprecated and has no effect.") - gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() # ~Gumbel(0,1) + gumbels = ( + -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + ) # ~Gumbel(0,1) gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) y_soft = gumbels.softmax(dim) @@ -1553,8 +1691,7 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): return ret -def log_softmax(input, dim=None, _stacklevel=3, dtype=None): - # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor +def log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor: r"""Applies a softmax followed by a logarithm. While mathematically equivalent to log(softmax(x)), doing these two @@ -1572,10 +1709,9 @@ def log_softmax(input, dim=None, _stacklevel=3, dtype=None): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: - dim = _get_softmax_dim('log_softmax', input.dim(), _stacklevel) + dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel) if dtype is None: ret = input.log_softmax(dim) else: @@ -1583,13 +1719,16 @@ def log_softmax(input, dim=None, _stacklevel=3, dtype=None): return ret -softshrink = _add_docstr(torch._C._nn.softshrink, r""" +softshrink = _add_docstr( + torch._C._nn.softshrink, + r""" softshrink(input, lambd=0.5) -> Tensor Applies the soft shrinkage function elementwise See :class:`~torch.nn.Softshrink` for more details. -""") +""", +) def tanh(input): @@ -1615,8 +1754,7 @@ def sigmoid(input): return input.sigmoid() -def hardsigmoid(input, inplace=False): - # type: (Tensor, bool) -> Tensor +def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: r"""hardsigmoid(input) -> Tensor Applies the element-wise function @@ -1641,8 +1779,7 @@ def hardsigmoid(input, inplace=False): return torch._C._nn.hardsigmoid(input) -def linear(input, weight, bias=None): - # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor +def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: r""" Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. @@ -1671,8 +1808,7 @@ def linear(input, weight, bias=None): return ret -def bilinear(input1, input2, weight, bias=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tensor +def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: r""" Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b` @@ -1691,8 +1827,8 @@ def bilinear(input1, input2, weight, bias=None): """ return torch.bilinear(input1, input2, weight, bias) -def silu(input, inplace=False): - # type: (Tensor, bool) -> Tensor + +def silu(input: Tensor, inplace: bool = False) -> Tensor: r"""Applies the silu function, element-wise. .. math:: @@ -1715,6 +1851,7 @@ def silu(input, inplace=False): return torch._C._nn.silu_(input) return torch._C._nn.silu(input) + def hardswish(input: Tensor, inplace: bool = False) -> Tensor: r"""Applies the hardswish function, element-wise, as described in the paper: @@ -1740,15 +1877,20 @@ def hardswish(input: Tensor, inplace: bool = False) -> Tensor: return torch._C._nn.hardswish(input) -def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type): - # type: (Tensor, Tensor, float, float) -> Tensor +def _no_grad_embedding_renorm_(weight: Tensor, input: Tensor, max_norm: float, norm_type: float) -> Tensor: with torch.no_grad(): torch.embedding_renorm_(weight, input, max_norm, norm_type) -def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., - scale_grad_by_freq=False, sparse=False): - # type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor +def embedding( + input: Tensor, + weight: Tensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> Tensor: r"""A simple lookup table that looks up embeddings in a fixed dictionary and size. This module is often used to retrieve word embeddings using indices. @@ -1809,9 +1951,9 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., if padding_idx is not None: if padding_idx > 0: - assert padding_idx < weight.size(0), 'Padding_idx must be within num_embeddings' + assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -weight.size(0), 'Padding_idx must be within num_embeddings' + assert padding_idx >= -weight.size(0), "Padding_idx must be within num_embeddings" padding_idx = weight.size(0) + padding_idx else: padding_idx = -1 @@ -1828,10 +1970,18 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) -def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, - scale_grad_by_freq=False, mode='mean', sparse=False, - per_sample_weights=None, include_last_offset=False): - # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool, Optional[Tensor], bool) -> Tensor +def embedding_bag( + input: Tensor, + weight: Tensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, +) -> Tensor: r"""Computes sums, means or maxes of `bags` of embeddings, without instantiating the intermediate embeddings. @@ -1911,23 +2061,35 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, tens_ops = (input, weight) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - embedding_bag, tens_ops, input, weight, offsets=offsets, max_norm=max_norm, - norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, mode=mode, - sparse=sparse, per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset) + embedding_bag, + tens_ops, + input, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + ) # Check for backward compatibility. # Used to be embedding_bag(weight, input, ...) # Now is embedding_bag(input, weight, ...) if weight.dtype == torch.long and input.is_floating_point(): - warnings.warn("Argument order of nn.functional.embedding_bag was changed. " - "Usage `embedding_bag(weight, input, ...)` is deprecated, " - "and should now be `embedding_bag(input, weight, ...)`.") + warnings.warn( + "Argument order of nn.functional.embedding_bag was changed. " + "Usage `embedding_bag(weight, input, ...)` is deprecated, " + "and should now be `embedding_bag(input, weight, ...)`." + ) weight, input = input, weight if per_sample_weights is not None and input.size() != per_sample_weights.size(): - raise ValueError("embedding_bag: If per_sample_weights ({}) is not None, " - "then it must have the same shape as the input ({})" - .format(per_sample_weights.shape, input.shape)) + raise ValueError( + "embedding_bag: If per_sample_weights ({}) is not None, " + "then it must have the same shape as the input ({})".format(per_sample_weights.shape, input.shape) + ) if input.dim() == 2: if offsets is not None: @@ -1935,12 +2097,13 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, # TODO: Remove this once script supports type() calls if not torch.jit.is_scripting(): type_str = str(type(offsets)) - raise ValueError("if input is 2D, then offsets has to be None" - ", as input is treated is a mini-batch of" - " fixed length sequences. However, found " - "offsets of type {}".format(type_str)) - offsets = torch.arange(0, input.numel(), input.size(1), - dtype=input.dtype, device=input.device) + raise ValueError( + "if input is 2D, then offsets has to be None" + ", as input is treated is a mini-batch of" + " fixed length sequences. However, found " + "offsets of type {}".format(type_str) + ) + offsets = torch.arange(0, input.numel(), input.size(1), dtype=input.dtype, device=input.device) input = input.reshape(-1) if per_sample_weights is not None: @@ -1951,13 +2114,12 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, if offsets.dim() != 1: raise ValueError("offsets has to be a 1D Tensor") else: - raise ValueError("input has to be 1D or 2D Tensor," - " but got Tensor of dimension {}".format(input.dim())) - if mode == 'sum': + raise ValueError("input has to be 1D or 2D Tensor," " but got Tensor of dimension {}".format(input.dim())) + if mode == "sum": mode_enum = 0 - elif mode == 'mean': + elif mode == "mean": mode_enum = 1 - elif mode == 'max': + elif mode == "max": mode_enum = 2 if scale_grad_by_freq: @@ -1976,28 +2138,23 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, # remove once script supports set_grad_enabled _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) - if per_sample_weights is not None and mode != 'sum': - raise NotImplementedError("embedding_bag: per_sample_weights was not None. " - "per_sample_weights is only supported for mode='sum' " - "(got mode='{}'). Please open a feature request on GitHub." - .format(mode)) + if per_sample_weights is not None and mode != "sum": + raise NotImplementedError( + "embedding_bag: per_sample_weights was not None. " + "per_sample_weights is only supported for mode='sum' " + "(got mode='{}'). Please open a feature request on GitHub.".format(mode) + ) ret, _, _, _ = torch.embedding_bag( - weight, - input, - offsets, - scale_grad_by_freq, - mode_enum, - sparse, - per_sample_weights, - include_last_offset) + weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights, include_last_offset + ) return ret + embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes) -def _verify_batch_size(size): - # type: (List[int]) -> None +def _verify_batch_size(size: List[int]) -> None: # XXX: JIT script does not support the reduce from functools, and mul op is a # builtin, which cannot be used as a value to a func yet, so rewrite this size # check to a simple equivalent for loop @@ -2011,12 +2168,20 @@ def _verify_batch_size(size): for i in range(len(size) - 2): size_prods *= size[i + 2] if size_prods == 1: - raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size)) - - -def batch_norm(input, running_mean, running_var, weight=None, bias=None, - training=False, momentum=0.1, eps=1e-5): - # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa + raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size)) + + +def batch_norm( + input: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + training: bool = False, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + # noqa r"""Applies Batch Normalization for each channel across a batch of data. See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`, @@ -2025,20 +2190,36 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - batch_norm, (input,), input, running_mean, running_var, weight=weight, - bias=bias, training=training, momentum=momentum, eps=eps) + batch_norm, + (input,), + input, + running_mean, + running_var, + weight=weight, + bias=bias, + training=training, + momentum=momentum, + eps=eps, + ) if training: _verify_batch_size(input.size()) return torch.batch_norm( - input, weight, bias, running_mean, running_var, - training, momentum, eps, torch.backends.cudnn.enabled + input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled ) -def instance_norm(input, running_mean=None, running_var=None, weight=None, - bias=None, use_input_stats=True, momentum=0.1, eps=1e-5): - # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa +def instance_norm( + input: Tensor, + running_mean: Optional[Tensor] = None, + running_var: Optional[Tensor] = None, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + use_input_stats: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + # noqa r"""Applies Instance Normalization for each channel in each data sample in a batch. @@ -2048,18 +2229,30 @@ def instance_norm(input, running_mean=None, running_var=None, weight=None, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - instance_norm, (input,), input, running_mean=running_mean, - running_var=running_var, weight=weight, bias=bias, - use_input_stats=use_input_stats, momentum=momentum, eps=eps) + instance_norm, + (input,), + input, + running_mean=running_mean, + running_var=running_var, + weight=weight, + bias=bias, + use_input_stats=use_input_stats, + momentum=momentum, + eps=eps, + ) _verify_batch_size(input.size()) return torch.instance_norm( - input, weight, bias, running_mean, running_var, - use_input_stats, momentum, eps, torch.backends.cudnn.enabled + input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, torch.backends.cudnn.enabled ) -def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): - # type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor +def layer_norm( + input: Tensor, + normalized_shape: List[int], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: r"""Applies Layer Normalization for last certain number of dimensions. See :class:`~torch.nn.LayerNorm` for details. @@ -2067,30 +2260,26 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps) - return torch.layer_norm(input, normalized_shape, weight, bias, eps, - torch.backends.cudnn.enabled) + layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps + ) + return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled) -def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): - # type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor +def group_norm( + input: Tensor, num_groups: int, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1e-5 +) -> Tensor: r"""Applies Group Normalization for last certain number of dimensions. See :class:`~torch.nn.GroupNorm` for details. """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps) - _verify_batch_size([ - input.size(0) * input.size(1) // num_groups, num_groups] - + list(input.size()[2:])) - return torch.group_norm(input, num_groups, weight, bias, eps, - torch.backends.cudnn.enabled) + return handle_torch_function(group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps) + _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:])) + return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) -def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): - # type: (Tensor, int, float, float, float) -> Tensor +def local_response_norm(input: Tensor, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0) -> Tensor: r"""Applies local response normalization over an input signal composed of several input planes, where channels occupy the second dimension. Applies normalization across channels. @@ -2099,12 +2288,15 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k) + return handle_torch_function(local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k) dim = input.dim() if dim < 3: - raise ValueError('Expected 3D or higher dimensionality \ - input (got {} dimensions)'.format(dim)) + raise ValueError( + "Expected 3D or higher dimensionality \ + input (got {} dimensions)".format( + dim + ) + ) div = input.mul(input).unsqueeze(1) if dim == 3: div = pad(div, (0, 0, size // 2, (size - 1) // 2)) @@ -2121,9 +2313,16 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): # loss -def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, - reduction='mean', zero_infinity=False): - # type: (Tensor, Tensor, Tensor, Tensor, int, str, bool) -> Tensor + +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: int = 0, + reduction: str = "mean", + zero_infinity: bool = False, +) -> Tensor: r"""The Connectionist Temporal Classification loss. See :class:`~torch.nn.CTCLoss` for details. @@ -2167,14 +2366,23 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) >>> loss.backward() """ - return torch.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), - zero_infinity) + return torch.ctc_loss( + log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), zero_infinity + ) + + ctc_loss.__doc__ = ctc_loss.__doc__.format(**reproducibility_notes) -def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor +def nll_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""The negative log likelihood loss. See :class:`~torch.nn.NLLLoss` for details. @@ -2220,17 +2428,26 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - nll_loss, tens_ops, input, target, weight=weight, size_average=size_average, - ignore_index=ignore_index, reduce=reduce, reduction=reduction) + nll_loss, + tens_ops, + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) dim = input.dim() if dim < 2: - raise ValueError('Expected 2 or more dimensions (got {})'.format(dim)) + raise ValueError("Expected 2 or more dimensions (got {})".format(dim)) if input.size(0) != target.size(0): - raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).' - .format(input.size(0), target.size(0))) + raise ValueError( + "Expected input batch_size ({}) to match target batch_size ({}).".format(input.size(0), target.size(0)) + ) if dim == 2: ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index) elif dim == 4: @@ -2241,8 +2458,7 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, c = input.size(1) out_size = (n,) + input.size()[2:] if target.size()[1:] != input.size()[2:]: - raise ValueError('Expected target size {}, got {}'.format( - out_size, target.size())) + raise ValueError("Expected target size {}, got {}".format(out_size, target.size())) input = input.contiguous() target = target.contiguous() # support empty batches, see #15870 @@ -2255,19 +2471,24 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, else: target = target.view(n, 0, 0) reduction_enum = _Reduction.get_enum(reduction) - if reduction != 'none': - ret = torch._C._nn.nll_loss2d( - input, target, weight, reduction_enum, ignore_index) + if reduction != "none": + ret = torch._C._nn.nll_loss2d(input, target, weight, reduction_enum, ignore_index) else: - out = torch._C._nn.nll_loss2d( - input, target, weight, reduction_enum, ignore_index) + out = torch._C._nn.nll_loss2d(input, target, weight, reduction_enum, ignore_index) ret = out.view(out_size) return ret -def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, bool, bool, Optional[bool], float, Optional[bool], str) -> Tensor +def poisson_nll_loss( + input: Tensor, + target: Tensor, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""Poisson negative log likelihood loss. See :class:`~torch.nn.PoissonNLLLoss` for details. @@ -2304,11 +2525,20 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - poisson_nll_loss, tens_ops, input, target, log_input=log_input, full=full, - size_average=size_average, eps=eps, reduce=reduce, reduction=reduction) + poisson_nll_loss, + tens_ops, + input, + target, + log_input=log_input, + full=full, + size_average=size_average, + eps=eps, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) - if reduction != 'none' and reduction != 'mean' and reduction != 'sum': + if reduction != "none" and reduction != "mean" and reduction != "sum": ret = input raise ValueError(reduction + " is not valid") @@ -2316,8 +2546,14 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non return ret -def kl_div(input, target, size_average=None, reduce=None, reduction='mean', log_target=False): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str, bool) -> Tensor +def kl_div( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + log_target: bool = False, +) -> Tensor: r"""The `Kullback-Leibler divergence Loss `__ @@ -2360,33 +2596,48 @@ def kl_div(input, target, size_average=None, reduce=None, reduction='mean', log_ tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - kl_div, tens_ops, input, target, size_average=size_average, - reduce=reduce, reduction=reduction, log_target=log_target) + kl_div, + tens_ops, + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + log_target=log_target, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: - if reduction == 'mean': - warnings.warn("reduction: 'mean' divides the total loss by both the batch size and the support size." - "'batchmean' divides only by the batch size, and aligns with the KL div math definition." - "'mean' will be changed to behave the same as 'batchmean' in the next major release.") + if reduction == "mean": + warnings.warn( + "reduction: 'mean' divides the total loss by both the batch size and the support size." + "'batchmean' divides only by the batch size, and aligns with the KL div math definition." + "'mean' will be changed to behave the same as 'batchmean' in the next major release." + ) # special case for batchmean - if reduction == 'batchmean': - reduction_enum = _Reduction.get_enum('sum') + if reduction == "batchmean": + reduction_enum = _Reduction.get_enum("sum") else: reduction_enum = _Reduction.get_enum(reduction) reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target) - if reduction == 'batchmean' and input.dim() != 0: + if reduction == "batchmean" and input.dim() != 0: reduced = reduced / input.size()[0] return reduced -def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor +def cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""This criterion combines `log_softmax` and `nll_loss` in a single function. @@ -2431,17 +2682,29 @@ def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-1 tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - cross_entropy, tens_ops, input, target, weight=weight, - size_average=size_average, ignore_index=ignore_index, reduce=reduce, - reduction=reduction) + cross_entropy, + tens_ops, + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction) -def binary_cross_entropy(input, target, weight=None, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor +def binary_cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""Function that measures the Binary Cross Entropy between the target and the output. @@ -2479,27 +2742,41 @@ def binary_cross_entropy(input, target, weight=None, size_average=None, tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - binary_cross_entropy, tens_ops, input, target, weight=weight, - size_average=size_average, reduce=reduce, reduction=reduction) + binary_cross_entropy, + tens_ops, + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) if target.size() != input.size(): - raise ValueError("Using a target size ({}) that is different to the input size ({}) is deprecated. " - "Please ensure they have the same size.".format(target.size(), input.size())) + raise ValueError( + "Using a target size ({}) that is different to the input size ({}) is deprecated. " + "Please ensure they have the same size.".format(target.size(), input.size()) + ) if weight is not None: new_size = _infer_size(target.size(), weight.size()) weight = weight.expand(new_size) - return torch._C._nn.binary_cross_entropy( - input, target, weight, reduction_enum) + return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) -def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None, - reduce=None, reduction='mean', pos_weight=None): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor +def binary_cross_entropy_with_logits( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + pos_weight: Optional[Tensor] = None, +) -> Tensor: r"""Function that measures Binary Cross Entropy between target and output logits. @@ -2539,9 +2816,16 @@ def binary_cross_entropy_with_logits(input, target, weight=None, size_average=No tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - binary_cross_entropy_with_logits, tens_ops, input, target, weight=weight, - size_average=size_average, reduce=reduce, reduction=reduction, - pos_weight=pos_weight) + binary_cross_entropy_with_logits, + tens_ops, + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + pos_weight=pos_weight, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2553,8 +2837,14 @@ def binary_cross_entropy_with_logits(input, target, weight=None, size_average=No return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum) -def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean', beta=1.0): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str, float) -> Tensor +def smooth_l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> Tensor: r"""Function that uses a squared term if the absolute element-wise error falls below beta and an L1 term otherwise. @@ -2564,13 +2854,22 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - smooth_l1_loss, tens_ops, input, target, size_average=size_average, - reduce=reduce, reduction=reduction, beta=beta) + smooth_l1_loss, + tens_ops, + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + beta=beta, + ) if not (target.size() == input.size()): - warnings.warn("Using a target size ({}) that is different to the input size ({}). " - "This will likely lead to incorrect results due to broadcasting. " - "Please ensure they have the same size.".format(target.size(), input.size()), - stacklevel=2) + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) @@ -2578,8 +2877,13 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea return torch._C._nn.smooth_l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction), beta) -def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor Function that takes the mean element-wise absolute value difference. @@ -2590,23 +2894,29 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'): tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - l1_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, - reduction=reduction) + l1_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) if not (target.size() == input.size()): - warnings.warn("Using a target size ({}) that is different to the input size ({}). " - "This will likely lead to incorrect results due to broadcasting. " - "Please ensure they have the same size.".format(target.size(), input.size()), - stacklevel=2) + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) - expanded_input, expanded_target = torch.broadcast_tensors(input, target) return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) -def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def mse_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor Measures the element-wise mean squared error. @@ -2617,13 +2927,15 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - mse_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, - reduction=reduction) + mse_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) if not (target.size() == input.size()): - warnings.warn("Using a target size ({}) that is different to the input size ({}). " - "This will likely lead to incorrect results due to broadcasting. " - "Please ensure they have the same size.".format(target.size(), input.size()), - stacklevel=2) + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) @@ -2631,9 +2943,15 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) -def margin_ranking_loss(input1, input2, target, margin=0, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor +def margin_ranking_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.MarginRankingLoss` for details. @@ -2642,21 +2960,38 @@ def margin_ranking_loss(input1, input2, target, margin=0, size_average=None, tens_ops = (input1, input2, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - margin_ranking_loss, tens_ops, input1, input2, target, margin=margin, - size_average=size_average, reduce=reduce, reduction=reduction) + margin_ranking_loss, + tens_ops, + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) if input1.dim() == 0 or input2.dim() == 0 or target.dim() == 0: - raise RuntimeError(("margin_ranking_loss does not support scalars, got sizes: " - "input1: {}, input2: {}, target: {} ".format(input1.size(), input2.size(), target.size()))) + raise RuntimeError( + ( + "margin_ranking_loss does not support scalars, got sizes: " + "input1: {}, input2: {}, target: {} ".format(input1.size(), input2.size(), target.size()) + ) + ) return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum) -def hinge_embedding_loss(input, target, margin=1.0, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor +def hinge_embedding_loss( + input: Tensor, + target: Tensor, + margin: float = 1.0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.HingeEmbeddingLoss` for details. @@ -2665,8 +3000,15 @@ def hinge_embedding_loss(input, target, margin=1.0, size_average=None, tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - hinge_embedding_loss, tens_ops, input, target, margin=margin, - size_average=size_average, reduce=reduce, reduction=reduction) + hinge_embedding_loss, + tens_ops, + input, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2674,8 +3016,13 @@ def hinge_embedding_loss(input, target, margin=1.0, size_average=None, return torch.hinge_embedding_loss(input, target, margin, reduction_enum) -def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def multilabel_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.MultiLabelMarginLoss` for details. @@ -2684,8 +3031,14 @@ def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduct tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - multilabel_margin_loss, tens_ops, input, target, size_average=size_average, - reduce=reduce, reduction=reduction) + multilabel_margin_loss, + tens_ops, + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2693,8 +3046,13 @@ def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduct return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) -def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def soft_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.SoftMarginLoss` for details. @@ -2703,8 +3061,8 @@ def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='m tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - soft_margin_loss, tens_ops, input, target, size_average=size_average, - reduce=reduce, reduction=reduction) + soft_margin_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2712,9 +3070,14 @@ def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='m return torch._C._nn.soft_margin_loss(input, target, reduction_enum) -def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor +def multilabel_soft_margin_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""multilabel_soft_margin_loss(input, target, weight=None, size_average=None) -> Tensor See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details. @@ -2723,8 +3086,15 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - multilabel_soft_margin_loss, tens_ops, input, target, weight=weight, - size_average=size_average, reduce=reduce, reduction=reduction) + multilabel_soft_margin_loss, + tens_ops, + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) @@ -2735,11 +3105,11 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, loss = loss.sum(dim=1) / input.size(1) # only return N loss values - if reduction == 'none': + if reduction == "none": ret = loss - elif reduction == 'mean': + elif reduction == "mean": ret = loss.mean() - elif reduction == 'sum': + elif reduction == "sum": ret = loss.sum() else: ret = input @@ -2747,9 +3117,15 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, return ret -def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor +def cosine_embedding_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.CosineEmbeddingLoss` for details. @@ -2758,8 +3134,16 @@ def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, tens_ops = (input1, input2, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - cosine_embedding_loss, tens_ops, input1, input2, target, margin=margin, - size_average=size_average, reduce=reduce, reduction=reduction) + cosine_embedding_loss, + tens_ops, + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2767,9 +3151,16 @@ def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum) -def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, int, float, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: int = 1, + margin: float = 1.0, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor @@ -2779,23 +3170,33 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - multi_margin_loss, tens_ops, input, target, p=p, margin=margin, - weight=weight, size_average=size_average, reduce=reduce, - reduction=reduction) + multi_margin_loss, + tens_ops, + input, + target, + p=p, + margin=margin, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) if p != 1 and p != 2: - raise ValueError('only p == 1 and p == 2 supported') + raise ValueError("only p == 1 and p == 2 supported") if weight is not None: if weight.dim() != 1: - raise ValueError('weight must be one-dimensional') + raise ValueError("weight must be one-dimensional") return torch._C._nn.multi_margin_loss(input, target, p, margin, weight, reduction_enum) -pixel_shuffle = _add_docstr(torch.pixel_shuffle, r""" +pixel_shuffle = _add_docstr( + torch.pixel_shuffle, + r""" pixel_shuffle(input, upscale_factor) -> Tensor Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a @@ -2813,9 +3214,12 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N >>> output = torch.nn.functional.pixel_shuffle(input, 3) >>> print(output.size()) torch.Size([1, 1, 12, 12]) -""") +""", +) -pixel_unshuffle = _add_docstr(torch.pixel_unshuffle, r""" +pixel_unshuffle = _add_docstr( + torch.pixel_unshuffle, + r""" pixel_unshuffle(input, downscale_factor) -> Tensor Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a @@ -2834,9 +3238,12 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N >>> output = torch.nn.functional.pixel_unshuffle(input, 3) >>> print(output.size()) torch.Size([1, 9, 4, 4]) -""") +""", +) -channel_shuffle = _add_docstr(torch.channel_shuffle, r""" +channel_shuffle = _add_docstr( + torch.channel_shuffle, + r""" channel_shuffle(input, groups) -> Tensor Divide the channels in a tensor of shape :math:`(*, C , H, W)` @@ -2873,20 +3280,23 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N [[13, 14], [15, 16]], ]] -""") +""", +) + @_overload # noqa: F811 -def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None): # noqa: F811 +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # noqa: F811 # type: (Tensor, Optional[int], Optional[float], str, Optional[bool]) -> Tensor pass + @_overload # noqa: F811 -def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None): # noqa: F811 +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # noqa: F811 # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor pass -def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None): # noqa: F811 +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # noqa: F811 r"""Upsamples the input to either the given :attr:`size` or the given :attr:`scale_factor` @@ -2945,26 +3355,38 @@ def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners= """ warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.") return interpolate(input, size, scale_factor, mode, align_corners) + + upsample.__doc__ = upsample.__doc__.format(**reproducibility_notes) + @_overload # noqa: F811 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 # type: (Tensor, Optional[int], Optional[List[float]], str, Optional[bool], Optional[bool]) -> Tensor pass + @_overload # noqa: F811 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 # type: (Tensor, Optional[List[int]], Optional[List[float]], str, Optional[bool], Optional[bool]) -> Tensor pass + @_overload # noqa: F811 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 # type: (Tensor, Optional[int], Optional[float], str, Optional[bool], Optional[bool]) -> Tensor pass + @_overload # noqa: F811 -def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool], Optional[bool]) -> Tensor +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, +) -> Tensor: # noqa: F811 pass def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 @@ -3040,20 +3462,30 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - interpolate, (input,), input, size=size, scale_factor=scale_factor, - mode=mode, align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor) - - if mode in ('nearest', 'area'): + interpolate, + (input,), + input, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + ) + + if mode in ("nearest", "area"): if align_corners is not None: - raise ValueError("align_corners option can only be set with the " - "interpolating modes: linear | bilinear | bicubic | trilinear") + raise ValueError( + "align_corners option can only be set with the " + "interpolating modes: linear | bilinear | bicubic | trilinear" + ) else: if align_corners is None: - warnings.warn("Default upsampling behavior when mode={} is changed " - "to align_corners=False since 0.4.0. Please specify " - "align_corners=True if the old behavior is desired. " - "See the documentation of nn.Upsample for details.".format(mode)) + warnings.warn( + "Default upsampling behavior when mode={} is changed " + "to align_corners=False since 0.4.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of nn.Upsample for details.".format(mode) + ) align_corners = False dim = input.dim() - 2 # Number of spatial dimensions. @@ -3063,14 +3495,15 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne # After this block, exactly one of output_size and scale_factors will # be non-None, and it will be a list (or tuple). if size is not None and scale_factor is not None: - raise ValueError('only one of size or scale_factor should be defined') + raise ValueError("only one of size or scale_factor should be defined") elif size is not None: assert scale_factor is None scale_factors = None if isinstance(size, (list, tuple)): if len(size) != dim: - raise ValueError('size shape must match input shape. ' - 'Input is {}D, size is {}'.format(dim, len(size))) + raise ValueError( + "size shape must match input shape. " "Input is {}D, size is {}".format(dim, len(size)) + ) output_size = size else: output_size = [size for _ in range(dim)] @@ -3079,13 +3512,15 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne output_size = None if isinstance(scale_factor, (list, tuple)): if len(scale_factor) != dim: - raise ValueError('scale_factor shape must match input shape. ' - 'Input is {}D, scale_factor is {}'.format(dim, len(scale_factor))) + raise ValueError( + "scale_factor shape must match input shape. " + "Input is {}D, scale_factor is {}".format(dim, len(scale_factor)) + ) scale_factors = scale_factor else: scale_factors = [scale_factor for _ in range(dim)] else: - raise ValueError('either size or scale_factor should be defined') + raise ValueError("either size or scale_factor should be defined") if recompute_scale_factor is None: # only warn when the scales have floating values since @@ -3093,11 +3528,13 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne if scale_factors is not None: for scale in scale_factors: if math.floor(scale) != scale: - warnings.warn("The default behavior for interpolate/upsample with float scale_factor changed " - "in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, " - "instead of relying on the computed output size. " - "If you wish to restore the old behavior, please set recompute_scale_factor=True. " - "See the documentation of nn.Upsample for details. ") + warnings.warn( + "The default behavior for interpolate/upsample with float scale_factor changed " + "in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, " + "instead of relying on the computed output size. " + "If you wish to restore the old behavior, please set recompute_scale_factor=True. " + "See the documentation of nn.Upsample for details. " + ) break elif recompute_scale_factor and size is not None: raise ValueError("recompute_scale_factor is not meaningful with an explicit size.") @@ -3112,71 +3549,80 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne # The C++ code will recompute it based on the (integer) output size. if not torch.jit.is_scripting() and torch._C._get_tracing_state(): # make scale_factor a tensor in tracing so constant doesn't get baked in - output_size = [(torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], - dtype=torch.float32)).float())) for i in range(dim)] + output_size = [ + (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float())) + for i in range(dim) + ] else: assert scale_factors is not None output_size = [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) for i in range(dim)] scale_factors = None - if input.dim() == 3 and mode == 'nearest': + if input.dim() == 3 and mode == "nearest": return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) - if input.dim() == 4 and mode == 'nearest': + if input.dim() == 4 and mode == "nearest": return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) - if input.dim() == 5 and mode == 'nearest': + if input.dim() == 5 and mode == "nearest": return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors) - if input.dim() == 3 and mode == 'area': + if input.dim() == 3 and mode == "area": assert output_size is not None return adaptive_avg_pool1d(input, output_size) - if input.dim() == 4 and mode == 'area': + if input.dim() == 4 and mode == "area": assert output_size is not None return adaptive_avg_pool2d(input, output_size) - if input.dim() == 5 and mode == 'area': + if input.dim() == 5 and mode == "area": assert output_size is not None return adaptive_avg_pool3d(input, output_size) - if input.dim() == 3 and mode == 'linear': + if input.dim() == 3 and mode == "linear": assert align_corners is not None return torch._C._nn.upsample_linear1d(input, output_size, align_corners, scale_factors) - if input.dim() == 4 and mode == 'bilinear': + if input.dim() == 4 and mode == "bilinear": assert align_corners is not None return torch._C._nn.upsample_bilinear2d(input, output_size, align_corners, scale_factors) - if input.dim() == 5 and mode == 'trilinear': + if input.dim() == 5 and mode == "trilinear": assert align_corners is not None return torch._C._nn.upsample_trilinear3d(input, output_size, align_corners, scale_factors) - if input.dim() == 4 and mode == 'bicubic': + if input.dim() == 4 and mode == "bicubic": assert align_corners is not None return torch._C._nn.upsample_bicubic2d(input, output_size, align_corners, scale_factors) - if input.dim() == 3 and mode == 'bilinear': + if input.dim() == 3 and mode == "bilinear": raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") - if input.dim() == 3 and mode == 'trilinear': + if input.dim() == 3 and mode == "trilinear": raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input") - if input.dim() == 4 and mode == 'linear': + if input.dim() == 4 and mode == "linear": raise NotImplementedError("Got 4D input, but linear mode needs 3D input") - if input.dim() == 4 and mode == 'trilinear': + if input.dim() == 4 and mode == "trilinear": raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input") - if input.dim() == 5 and mode == 'linear': + if input.dim() == 5 and mode == "linear": raise NotImplementedError("Got 5D input, but linear mode needs 3D input") - if input.dim() == 5 and mode == 'bilinear': + if input.dim() == 5 and mode == "bilinear": raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input") - raise NotImplementedError("Input Error: Only 3D, 4D and 5D input Tensors supported" - " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear" - " (got {})".format(input.dim(), mode)) + raise NotImplementedError( + "Input Error: Only 3D, 4D and 5D input Tensors supported" + " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear" + " (got {})".format(input.dim(), mode) + ) + + interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes) + @_overload # noqa: F811 def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 # type: (Tensor, Optional[int], Optional[float]) -> Tensor pass + @_overload # noqa: F811 def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 # type: (Tensor, Optional[List[int]], Optional[float]) -> Tensor pass + def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 r"""Upsamples the input, using nearest neighbours' pixel values. @@ -3198,29 +3644,40 @@ def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 """ # DeprecationWarning is ignored by default warnings.warn("nn.functional.upsample_nearest is deprecated. Use nn.functional.interpolate instead.") - return interpolate(input, size, scale_factor, mode='nearest') + return interpolate(input, size, scale_factor, mode="nearest") + + upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes) + @_overload # noqa: F811 -def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[int], Optional[float]) -> Tensor +def upsample_bilinear( + input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None +) -> Tensor: # noqa: F811 pass + @_overload # noqa: F811 -def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[List[int]], Optional[float]) -> Tensor +def upsample_bilinear( # noqa: F811 + input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[float] = None +) -> Tensor: # noqa: F811 pass + @_overload # noqa: F811 -def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[int], Optional[List[float]]) -> Tensor +def upsample_bilinear( # noqa: F811 + input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None +) -> Tensor: # noqa: F811 pass + @_overload # noqa: F811 -def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[List[int]], Optional[List[float]]) -> Tensor +def upsample_bilinear( # noqa: F811 + input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None +) -> Tensor: # noqa: F811 pass + def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 r"""Upsamples the input, using bilinear upsampling. @@ -3242,24 +3699,31 @@ def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 """ # DeprecationWarning is ignored by default warnings.warn("nn.functional.upsample_bilinear is deprecated. Use nn.functional.interpolate instead.") - return interpolate(input, size, scale_factor, mode='bilinear', align_corners=True) + return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) + + upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format(**reproducibility_notes) GRID_SAMPLE_INTERPOLATION_MODES = { - 'bilinear': 0, - 'nearest': 1, - 'bicubic': 2, + "bilinear": 0, + "nearest": 1, + "bicubic": 2, } GRID_SAMPLE_PADDING_MODES = { - 'zeros': 0, - 'border': 1, - 'reflection': 2, + "zeros": 0, + "border": 1, + "reflection": 2, } -def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None): - # type: (Tensor, Tensor, str, str, Optional[bool]) -> Tensor +def grid_sample( + input: Tensor, + grid: Tensor, + mode: str = "bilinear", + padding_mode: str = "zeros", + align_corners: Optional[bool] = None, +) -> Tensor: r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the ``output`` using :attr:`input` values and pixel locations from :attr:`grid`. @@ -3316,7 +3780,7 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case) mode (str): interpolation mode to calculate output values ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'`` - Note: ``mode='bicubic'`` supports only 4-D input. + Note: ``mode='bicubic'`` supports only 4-D input. When ``mode='bilinear'`` and the input is 5-D, the interpolation mode used internally will actually be trilinear. However, when the input is 4-D, the interpolation mode will legitimately be bilinear. @@ -3349,11 +3813,11 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner in order to bring it in line with the default for :func:`interpolate`. .. note:: - ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`. - The constant :math:`\alpha` might be different from packages to packages. - For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively. - This algorithm may "overshoot" the range of values it's interpolating. - For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. + ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`. + The constant :math:`\alpha` might be different from packages to packages. + For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively. + This algorithm may "overshoot" the range of values it's interpolating. + For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. Clamp the results with :func: `torch.clamp` to ensure they are within the valid range. .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51 @@ -3363,42 +3827,47 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner tens_ops = (input, grid) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - grid_sample, tens_ops, input, grid, mode=mode, padding_mode=padding_mode, - align_corners=align_corners) - if mode != 'bilinear' and mode != 'nearest' and mode != 'bicubic': - raise ValueError("nn.functional.grid_sample(): expected mode to be " - "'bilinear', 'nearest' or 'bicubic', but got: '{}'".format(mode)) - if padding_mode != 'zeros' and padding_mode != 'border' and padding_mode != 'reflection': - raise ValueError("nn.functional.grid_sample(): expected padding_mode " - "to be 'zeros', 'border', or 'reflection', " - "but got: '{}'".format(padding_mode)) - - if mode == 'bilinear': + grid_sample, tens_ops, input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners + ) + if mode != "bilinear" and mode != "nearest" and mode != "bicubic": + raise ValueError( + "nn.functional.grid_sample(): expected mode to be " + "'bilinear', 'nearest' or 'bicubic', but got: '{}'".format(mode) + ) + if padding_mode != "zeros" and padding_mode != "border" and padding_mode != "reflection": + raise ValueError( + "nn.functional.grid_sample(): expected padding_mode " + "to be 'zeros', 'border', or 'reflection', " + "but got: '{}'".format(padding_mode) + ) + + if mode == "bilinear": mode_enum = 0 - elif mode == 'nearest': + elif mode == "nearest": mode_enum = 1 else: # mode == 'bicubic' mode_enum = 2 - if padding_mode == 'zeros': + if padding_mode == "zeros": padding_mode_enum = 0 - elif padding_mode == 'border': + elif padding_mode == "border": padding_mode_enum = 1 else: # padding_mode == 'reflection' padding_mode_enum = 2 if align_corners is None: - warnings.warn("Default grid_sample and affine_grid behavior has changed " - "to align_corners=False since 1.3.0. Please specify " - "align_corners=True if the old behavior is desired. " - "See the documentation of grid_sample for details.") + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) align_corners = False return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners) -def affine_grid(theta, size, align_corners=None): - # type: (Tensor, List[int], Optional[bool]) -> Tensor +def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] = None) -> Tensor: r"""Generates a 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`. @@ -3448,49 +3917,55 @@ def affine_grid(theta, size, align_corners=None): """ if not torch.jit.is_scripting(): if type(theta) is not Tensor and has_torch_function((theta,)): - return handle_torch_function( - affine_grid, (theta,), theta, size, align_corners=align_corners) + return handle_torch_function(affine_grid, (theta,), theta, size, align_corners=align_corners) if align_corners is None: - warnings.warn("Default grid_sample and affine_grid behavior has changed " - "to align_corners=False since 1.3.0. Please specify " - "align_corners=True if the old behavior is desired. " - "See the documentation of grid_sample for details.") + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) align_corners = False # enforce floating point dtype on theta if not theta.is_floating_point(): - raise ValueError("Expected theta to have floating point type, but got {}" - .format(theta.dtype)) + raise ValueError("Expected theta to have floating point type, but got {}".format(theta.dtype)) # check that shapes and sizes match if len(size) == 4: if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3: - raise ValueError("Expected a batch of 2D affine matrices of shape Nx2x3 " - "for size {}. Got {}.".format(size, theta.shape)) + raise ValueError( + "Expected a batch of 2D affine matrices of shape Nx2x3 " + "for size {}. Got {}.".format(size, theta.shape) + ) spatial_size = size[-2:] # spatial dimension sizes elif len(size) == 5: if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4: - raise ValueError("Expected a batch of 3D affine matrices of shape Nx3x4 " - "for size {}. Got {}.".format(size, theta.shape)) + raise ValueError( + "Expected a batch of 3D affine matrices of shape Nx3x4 " + "for size {}. Got {}.".format(size, theta.shape) + ) spatial_size = size[-3:] # spatial dimension sizes else: - raise NotImplementedError("affine_grid only supports 4D and 5D sizes, " - "for 2D and 3D affine transforms, respectively. " - "Got size {}.".format(size)) + raise NotImplementedError( + "affine_grid only supports 4D and 5D sizes, " + "for 2D and 3D affine transforms, respectively. " + "Got size {}.".format(size) + ) # check for empty span if align_corners and min(spatial_size) == 1: - warnings.warn("Since version 1.3.0, affine_grid behavior has changed " - "for unit-size grids when align_corners=True. " - "This is not an intended use case of affine_grid. " - "See the documentation of affine_grid for details.") + warnings.warn( + "Since version 1.3.0, affine_grid behavior has changed " + "for unit-size grids when align_corners=True. " + "This is not an intended use case of affine_grid. " + "See the documentation of affine_grid for details." + ) elif min(size) <= 0: - raise ValueError("Expected non-zero, positive output size. Got {}" - .format(size)) + raise ValueError("Expected non-zero, positive output size. Got {}".format(size)) return torch.affine_grid_generator(theta, size, align_corners) -def _pad(input, pad, mode='constant', value=0): - # type: (Tensor, List[int], str, float) -> Tensor +def _pad(input: Tensor, pad: List[int], mode: str = "constant", value: float = 0) -> Tensor: r"""Pads tensor. Padding size: @@ -3551,49 +4026,49 @@ def _pad(input, pad, mode='constant', value=0): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - _pad, (input,), input, pad, mode=mode, value=value) - assert len(pad) % 2 == 0, 'Padding length must be divisible by 2' - assert len(pad) // 2 <= input.dim(), 'Padding length too large' - if mode == 'constant': + return handle_torch_function(_pad, (input,), input, pad, mode=mode, value=value) + assert len(pad) % 2 == 0, "Padding length must be divisible by 2" + assert len(pad) // 2 <= input.dim(), "Padding length too large" + if mode == "constant": return _VF.constant_pad_nd(input, pad, value) else: assert value == 0, 'Padding mode "{}"" doesn\'t take in value argument'.format(mode) if input.dim() == 3: - assert len(pad) == 2, '3D tensors expect 2 values for padding' - if mode == 'reflect': + assert len(pad) == 2, "3D tensors expect 2 values for padding" + if mode == "reflect": return torch._C._nn.reflection_pad1d(input, pad) - elif mode == 'replicate': + elif mode == "replicate": return torch._C._nn.replication_pad1d(input, pad) - elif mode == 'circular': + elif mode == "circular": return _pad_circular(input, pad) else: raise NotImplementedError elif input.dim() == 4: - assert len(pad) == 4, '4D tensors expect 4 values for padding' - if mode == 'reflect': + assert len(pad) == 4, "4D tensors expect 4 values for padding" + if mode == "reflect": return torch._C._nn.reflection_pad2d(input, pad) - elif mode == 'replicate': + elif mode == "replicate": return torch._C._nn.replication_pad2d(input, pad) - elif mode == 'circular': + elif mode == "circular": return _pad_circular(input, pad) else: raise NotImplementedError elif input.dim() == 5: - assert len(pad) == 6, '5D tensors expect 6 values for padding' - if mode == 'reflect': + assert len(pad) == 6, "5D tensors expect 6 values for padding" + if mode == "reflect": raise NotImplementedError - elif mode == 'replicate': + elif mode == "replicate": return torch._C._nn.replication_pad3d(input, pad) - elif mode == 'circular': + elif mode == "circular": return _pad_circular(input, pad) else: raise NotImplementedError else: raise NotImplementedError("Only 3D, 4D, 5D padding with non-constant padding are supported for now") + # We define this function as _pad because it takes an argument # named pad, which clobbers the recursive reference to the pad # function needed for __torch_function__ support @@ -3602,15 +4077,16 @@ def _pad(input, pad, mode='constant', value=0): # distance -def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): - # type: (Tensor, Tensor, float, float, bool) -> Tensor +def pairwise_distance(x1: Tensor, x2: Tensor, p: float = 2.0, eps: float = 1e-6, keepdim: bool = False) -> Tensor: r""" See :class:`torch.nn.PairwiseDistance` for details """ return torch.pairwise_distance(x1, x2, p, eps, keepdim) -pdist = _add_docstr(torch.pdist, r""" +pdist = _add_docstr( + torch.pdist, + r""" pdist(input, p=2) -> Tensor Computes the p-norm distance between every pair of row vectors in the input. @@ -3631,10 +4107,13 @@ def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): input: input tensor of shape :math:`N \times M`. p: p value for the p-norm distance to calculate between each vector pair :math:`\in [0, \infty]`. -""") +""", +) -cosine_similarity = _add_docstr(torch.cosine_similarity, r""" +cosine_similarity = _add_docstr( + torch.cosine_similarity, + r""" cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor Returns cosine similarity between x1 and x2, computed along dim. @@ -3659,10 +4138,13 @@ def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): >>> input2 = torch.randn(100, 128) >>> output = F.cosine_similarity(input1, input2) >>> print(output) -""") +""", +) -one_hot = _add_docstr(torch._C._nn.one_hot, r""" +one_hot = _add_docstr( + torch._C._nn.one_hot, + r""" one_hot(tensor, num_classes=-1) -> LongTensor Takes LongTensor with index values of shape ``(*)`` and returns a tensor @@ -3706,12 +4188,22 @@ def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): [1, 0, 0]], [[0, 1, 0], [0, 0, 1]]]) -""") - - -def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None, - reduce=None, reduction="mean"): - # type: (Tensor, Tensor, Tensor, float, float, float, bool, Optional[bool], Optional[bool], str) -> Tensor +""", +) + + +def triplet_margin_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r""" See :class:`~torch.nn.TripletMarginLoss` for details """ @@ -3719,32 +4211,58 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s tens_ops = (anchor, positive, negative) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - triplet_margin_loss, tens_ops, anchor, positive, negative, margin=margin, - p=p, eps=eps, swap=swap, size_average=size_average, reduce=reduce, - reduction=reduction) + triplet_margin_loss, + tens_ops, + anchor, + positive, + negative, + margin=margin, + p=p, + eps=eps, + swap=swap, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) - return torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, - swap, reduction_enum) - - -def triplet_margin_with_distance_loss(anchor, positive, negative, *, distance_function=None, - margin=1.0, swap=False, reduction="mean"): - # type: (Tensor, Tensor, Tensor, Optional[Callable[[Tensor, Tensor], Tensor]], float, bool, str) -> Tensor + return torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction_enum) + + +def triplet_margin_with_distance_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + *, + distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean" +) -> Tensor: r""" See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details. """ if torch.jit.is_scripting(): - raise NotImplementedError("F.triplet_margin_with_distance_loss does not support JIT scripting: " - "functions requiring Callables cannot be scripted.") + raise NotImplementedError( + "F.triplet_margin_with_distance_loss does not support JIT scripting: " + "functions requiring Callables cannot be scripted." + ) tens_ops = (anchor, positive, negative) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - triplet_margin_with_distance_loss, tens_ops, anchor, positive, negative, - distance_function=distance_function, margin=margin, swap=swap, reduction=reduction) + triplet_margin_with_distance_loss, + tens_ops, + anchor, + positive, + negative, + distance_function=distance_function, + margin=margin, + swap=swap, + reduction=reduction, + ) distance_function = distance_function if distance_function is not None else pairwise_distance @@ -3766,8 +4284,7 @@ def triplet_margin_with_distance_loss(anchor, positive, negative, *, distance_fu return output -def normalize(input, p=2, dim=1, eps=1e-12, out=None): - # type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor +def normalize(input: Tensor, p: float = 2, dim: int = 1, eps: float = 1e-12, out: Optional[Tensor] = None) -> Tensor: r"""Performs :math:`L_p` normalization of inputs over specified dimension. For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each @@ -3788,8 +4305,7 @@ def normalize(input, p=2, dim=1, eps=1e-12, out=None): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - normalize, (input,), input, p=p, dim=dim, eps=eps, out=out) + return handle_torch_function(normalize, (input,), input, p=p, dim=dim, eps=eps, out=out) if out is None: denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input) return input / denom @@ -3798,8 +4314,7 @@ def normalize(input, p=2, dim=1, eps=1e-12, out=None): return torch.div(input, denom, out=out) -def assert_int_or_pair(arg, arg_name, message): - # type: (List[int], str, str) -> None +def assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None: assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name) @@ -3824,17 +4339,16 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - unfold, (input,), input, kernel_size, dilation=dilation, - padding=padding, stride=stride) + unfold, (input,), input, kernel_size, dilation=dilation, padding=padding, stride=stride + ) if input.dim() == 4: - msg = '{} must be int or 2-tuple for 4D input' - assert_int_or_pair(kernel_size, 'kernel_size', msg) - assert_int_or_pair(dilation, 'dilation', msg) - assert_int_or_pair(padding, 'padding', msg) - assert_int_or_pair(stride, 'stride', msg) - - return torch._C._nn.im2col(input, _pair(kernel_size), - _pair(dilation), _pair(padding), _pair(stride)) + msg = "{} must be int or 2-tuple for 4D input" + assert_int_or_pair(kernel_size, "kernel_size", msg) + assert_int_or_pair(dilation, "dilation", msg) + assert_int_or_pair(padding, "padding", msg) + assert_int_or_pair(stride, "stride", msg) + + return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride)) else: raise NotImplementedError("Input Error: Only 4D input Tensors are supported (got {}D)".format(input.dim())) @@ -3853,24 +4367,24 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - fold, (input,), input, output_size, kernel_size, dilation=dilation, - padding=padding, stride=stride) + fold, (input,), input, output_size, kernel_size, dilation=dilation, padding=padding, stride=stride + ) if input.dim() == 3: - msg = '{} must be int or 2-tuple for 3D input' - assert_int_or_pair(output_size, 'output_size', msg) - assert_int_or_pair(kernel_size, 'kernel_size', msg) - assert_int_or_pair(dilation, 'dilation', msg) - assert_int_or_pair(padding, 'padding', msg) - assert_int_or_pair(stride, 'stride', msg) - - return torch._C._nn.col2im(input, _pair(output_size), _pair(kernel_size), - _pair(dilation), _pair(padding), _pair(stride)) + msg = "{} must be int or 2-tuple for 3D input" + assert_int_or_pair(output_size, "output_size", msg) + assert_int_or_pair(kernel_size, "kernel_size", msg) + assert_int_or_pair(dilation, "dilation", msg) + assert_int_or_pair(padding, "padding", msg) + assert_int_or_pair(stride, "stride", msg) + + return torch._C._nn.col2im( + input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride) + ) else: raise NotImplementedError("Input Error: Only 3D input Tensors are supported (got {}D)".format(input.dim())) -def _pad_circular(input, padding): - # type: (Tensor, List[int]) -> Tensor +def _pad_circular(input: Tensor, padding: List[int]) -> Tensor: """Circularly pads tensor. Tensor values at the beginning are used to pad the end, and values at the @@ -3924,21 +4438,19 @@ def _pad_circular(input, padding): for idx, size in enumerate(paddable_shape): # Only supports wrapping around once - assert padding[-(idx * 2 + 1)] <= size, \ - "Padding value causes wrapping around more than once." - assert padding[-(idx * 2 + 2)] <= size, \ - "Padding value causes wrapping around more than once." + assert padding[-(idx * 2 + 1)] <= size, "Padding value causes wrapping around more than once." + assert padding[-(idx * 2 + 2)] <= size, "Padding value causes wrapping around more than once." # Negative padding should not result in negative sizes - assert padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)] + size >= 0, \ - "Negative padding value is resulting in an empty dimension." + assert ( + padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)] + size >= 0 + ), "Negative padding value is resulting in an empty dimension." # Get shape of padded tensor out_shape = in_shape[:2] for idx, size in enumerate(paddable_shape): out_shape += (size + padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)],) - out = torch.empty(out_shape, dtype=input.dtype, layout=input.layout, - device=input.device) + out = torch.empty(out_shape, dtype=input.dtype, layout=input.layout, device=input.device) # Put original array in padded array if ndim == 1: @@ -3962,8 +4474,7 @@ def _pad_circular(input, padding): in_h0 = max(-padding[-4], 0) in_h1 = in_shape[3] - max(-padding[-3], 0) - out[..., out_d0:out_d1, out_h0:out_h1] = \ - input[..., in_d0:in_d1, in_h0:in_h1] + out[..., out_d0:out_d1, out_h0:out_h1] = input[..., in_d0:in_d1, in_h0:in_h1] elif ndim == 3: out_d0 = max(padding[-2], 0) out_d1 = out_shape[2] - max(padding[-1], 0) @@ -3983,8 +4494,7 @@ def _pad_circular(input, padding): in_w0 = max(-padding[-6], 0) in_w1 = in_shape[4] - max(-padding[-5], 0) - out[..., out_d0:out_d1, out_h0:out_h1, out_w0:out_w1] = \ - input[..., in_d0:in_d1, in_h0:in_h1, in_w0:in_w1] + out[..., out_d0:out_d1, out_h0:out_h1, out_w0:out_w1] = input[..., in_d0:in_d1, in_h0:in_h1, in_w0:in_w1] # The following steps first pad the beginning of the tensor (left side), # and then pad the end of the tensor (right side). @@ -4014,15 +4524,13 @@ def _pad_circular(input, padding): i1 = out_shape[3] - max(padding[-3], 0) o0 = 0 o1 = padding[-4] - out[:, :, :, o0:o1] = \ - out[:, :, :, i0:i1] + out[:, :, :, o0:o1] = out[:, :, :, i0:i1] if padding[-3] > 0: i0 = max(padding[-4], 0) i1 = max(padding[-4], 0) + padding[-3] o0 = out_shape[3] - padding[-3] o1 = out_shape[3] - out[:, :, :, o0:o1] = \ - out[:, :, :, i0:i1] + out[:, :, :, o0:o1] = out[:, :, :, i0:i1] # Pad third dimension (width) if len(padding) > 4: @@ -4031,43 +4539,42 @@ def _pad_circular(input, padding): i1 = out_shape[4] - max(padding[-5], 0) o0 = 0 o1 = padding[-6] - out[:, :, :, :, o0:o1] = \ - out[:, :, :, :, i0:i1] + out[:, :, :, :, o0:o1] = out[:, :, :, :, i0:i1] if padding[-5] > 0: i0 = max(padding[-6], 0) i1 = max(padding[-6], 0) + padding[-5] o0 = out_shape[4] - padding[-5] o1 = out_shape[4] - out[:, :, :, :, o0:o1] = \ - out[:, :, :, :, i0:i1] + out[:, :, :, :, o0:o1] = out[:, :, :, :, i0:i1] return out -def multi_head_attention_forward(query: Tensor, - key: Tensor, - value: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - bias_k: Optional[Tensor], - bias_v: Optional[Tensor], - add_zero_attn: bool, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - use_separate_proj_weight: bool = False, - q_proj_weight: Optional[Tensor] = None, - k_proj_weight: Optional[Tensor] = None, - v_proj_weight: Optional[Tensor] = None, - static_k: Optional[Tensor] = None, - static_v: Optional[Tensor] = None - ) -> Tuple[Tensor, Optional[Tensor]]: +def multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, +) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query, key, value: map a query and a set of key-value pairs to an output. @@ -4125,18 +4632,35 @@ def multi_head_attention_forward(query: Tensor, L is the target sequence length, S is the source sequence length. """ if not torch.jit.is_scripting(): - tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, - out_proj_weight, out_proj_bias) + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - multi_head_attention_forward, tens_ops, query, key, value, - embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, - bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, - out_proj_bias, training=training, key_padding_mask=key_padding_mask, - need_weights=need_weights, attn_mask=attn_mask, + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, use_separate_proj_weight=use_separate_proj_weight, - q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, - v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + ) tgt_len, bsz, embed_dim = query.size() assert embed_dim == embed_dim_to_check # allow MHA to have different sizes for the feature dimension @@ -4151,7 +4675,7 @@ def multi_head_attention_forward(query: Tensor, # self-attention q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) - elif (key is value or torch.equal(key, value)): + elif key is value or torch.equal(key, value): # encoder-decoder attention # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -4219,8 +4743,8 @@ def multi_head_attention_forward(query: Tensor, if in_proj_bias is not None: q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) - k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) - v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) + k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)]) + v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :]) else: q = linear(query, q_proj_weight_non_opt, in_proj_bias) k = linear(key, k_proj_weight_non_opt, in_proj_bias) @@ -4228,9 +4752,13 @@ def multi_head_attention_forward(query: Tensor, q = q * scaling if attn_mask is not None: - assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ - attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \ - 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype) if attn_mask.dtype == torch.uint8: warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") attn_mask = attn_mask.to(torch.bool) @@ -4238,17 +4766,19 @@ def multi_head_attention_forward(query: Tensor, if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError('The size of the 2D attn_mask is not correct.') + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: - raise RuntimeError('The size of the 3D attn_mask is not correct.') + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + warnings.warn( + "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) key_padding_mask = key_padding_mask.to(torch.bool) if bias_k is not None and bias_v is not None: @@ -4302,21 +4832,19 @@ def multi_head_attention_forward(query: Tensor, if attn_mask is not None: if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float('-inf')) + attn_output_weights.masked_fill_(attn_mask, float("-inf")) else: attn_output_weights += attn_mask - if key_padding_mask is not None: attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) attn_output_weights = attn_output_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), - float('-inf'), + float("-inf"), ) attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) - attn_output_weights = softmax( - attn_output_weights, dim=-1) + attn_output_weights = softmax(attn_output_weights, dim=-1) attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) attn_output = torch.bmm(attn_output_weights, v) From f83d57f99e07ae14722e47f481995e37a975f172 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Thu, 7 Jan 2021 15:34:53 -0800 Subject: [PATCH 07/54] [Don't review] Clean up type annotations in caffe2/torch/nn (#50079) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50079 Test Plan: Sandcastle tests Reviewed By: xush6528 Differential Revision: D25718694 fbshipit-source-id: f535fb879bcd4cb4ea715adfd90bbffa3fcc1150 --- torch/nn/_reduction.py | 9 +++------ torch/nn/init.py | 24 ++++++++---------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/torch/nn/_reduction.py b/torch/nn/_reduction.py index 025ef157b958..2b5a91276bd9 100644 --- a/torch/nn/_reduction.py +++ b/torch/nn/_reduction.py @@ -4,8 +4,7 @@ # NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h -def get_enum(reduction): - # type: (str) -> int +def get_enum(reduction: str) -> int: if reduction == 'none': ret = 0 elif reduction == 'mean': @@ -25,8 +24,7 @@ def get_enum(reduction): # We use these functions in torch/legacy as well, in which case we'll silence the warning -def legacy_get_string(size_average, reduce, emit_warning=True): - # type: (Optional[bool], Optional[bool], bool) -> str +def legacy_get_string(size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True) -> str: warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." if size_average is None: @@ -45,6 +43,5 @@ def legacy_get_string(size_average, reduce, emit_warning=True): return ret -def legacy_get_enum(size_average, reduce, emit_warning=True): - # type: (Optional[bool], Optional[bool], bool) -> int +def legacy_get_enum(size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True) -> int: return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/torch/nn/init.py b/torch/nn/init.py index c11dba648c5a..3c4149ff8e81 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -106,8 +106,7 @@ def calculate_gain(nonlinearity, param=None): raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) -def uniform_(tensor, a=0., b=1.): - # type: (Tensor, float, float) -> Tensor +def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor: r"""Fills the input Tensor with values drawn from the uniform distribution :math:`\mathcal{U}(a, b)`. @@ -123,8 +122,7 @@ def uniform_(tensor, a=0., b=1.): return _no_grad_uniform_(tensor, a, b) -def normal_(tensor, mean=0., std=1.): - # type: (Tensor, float, float) -> Tensor +def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor: r"""Fills the input Tensor with values drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. @@ -139,8 +137,7 @@ def normal_(tensor, mean=0., std=1.): """ return _no_grad_normal_(tensor, mean, std) -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): - # type: (Tensor, float, float, float, float) -> Tensor +def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor: r"""Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` @@ -162,8 +159,7 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): return _no_grad_trunc_normal_(tensor, mean, std, a, b) -def constant_(tensor, val): - # type: (Tensor, float) -> Tensor +def constant_(tensor: Tensor, val: float) -> Tensor: r"""Fills the input Tensor with the value :math:`\text{val}`. Args: @@ -177,8 +173,7 @@ def constant_(tensor, val): return _no_grad_fill_(tensor, val) -def ones_(tensor): - # type: (Tensor) -> Tensor +def ones_(tensor: Tensor) -> Tensor: r"""Fills the input Tensor with the scalar value `1`. Args: @@ -191,8 +186,7 @@ def ones_(tensor): return _no_grad_fill_(tensor, 1.) -def zeros_(tensor): - # type: (Tensor) -> Tensor +def zeros_(tensor: Tensor) -> Tensor: r"""Fills the input Tensor with the scalar value `0`. Args: @@ -284,8 +278,7 @@ def _calculate_fan_in_and_fan_out(tensor): return fan_in, fan_out -def xavier_uniform_(tensor, gain=1.): - # type: (Tensor, float) -> Tensor +def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor: r"""Fills the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform @@ -312,8 +305,7 @@ def xavier_uniform_(tensor, gain=1.): return _no_grad_uniform_(tensor, -a, a) -def xavier_normal_(tensor, gain=1.): - # type: (Tensor, float) -> Tensor +def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor: r"""Fills the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal From 09eefec627d7b54a3c4906dd9f16f8927e8fdbba Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Thu, 7 Jan 2021 15:36:48 -0800 Subject: [PATCH 08/54] Clean up some type annotations in android (#49944) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49944 Upgrades type annotations from Python2 to Python3 Test Plan: Sandcastle tests Reviewed By: xush6528 Differential Revision: D25717539 fbshipit-source-id: c621e2712e87eaed08cda48eb0fb224f6b0570c9 --- .../generate_test_torchscripts.py | 60 +++++++------------ android/pytorch_android/test_asset.jit | 60 +++++++------------ 2 files changed, 40 insertions(+), 80 deletions(-) diff --git a/android/pytorch_android/generate_test_torchscripts.py b/android/pytorch_android/generate_test_torchscripts.py index 6384d588e9aa..8b41fefc246e 100644 --- a/android/pytorch_android/generate_test_torchscripts.py +++ b/android/pytorch_android/generate_test_torchscripts.py @@ -20,92 +20,77 @@ def forward(self, input): return None @torch.jit.script_method - def eqBool(self, input): - # type: (bool) -> bool + def eqBool(self, input: bool) -> bool: return input @torch.jit.script_method - def eqInt(self, input): - # type: (int) -> int + def eqInt(self, input: int) -> int: return input @torch.jit.script_method - def eqFloat(self, input): - # type: (float) -> float + def eqFloat(self, input: float) -> float: return input @torch.jit.script_method - def eqStr(self, input): - # type: (str) -> str + def eqStr(self, input: str) -> str: return input @torch.jit.script_method - def eqTensor(self, input): - # type: (Tensor) -> Tensor + def eqTensor(self, input: Tensor) -> Tensor: return input @torch.jit.script_method - def eqDictStrKeyIntValue(self, input): - # type: (Dict[str, int]) -> Dict[str, int] + def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]: return input @torch.jit.script_method - def eqDictIntKeyIntValue(self, input): - # type: (Dict[int, int]) -> Dict[int, int] + def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]: return input @torch.jit.script_method - def eqDictFloatKeyIntValue(self, input): - # type: (Dict[float, int]) -> Dict[float, int] + def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]: return input @torch.jit.script_method - def listIntSumReturnTuple(self, input): - # type: (List[int]) -> Tuple[List[int], int] + def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]: sum = 0 for x in input: sum += x return (input, sum) @torch.jit.script_method - def listBoolConjunction(self, input): - # type: (List[bool]) -> bool + def listBoolConjunction(self, input: List[bool]) -> bool: res = True for x in input: res = res and x return res @torch.jit.script_method - def listBoolDisjunction(self, input): - # type: (List[bool]) -> bool + def listBoolDisjunction(self, input: List[bool]) -> bool: res = False for x in input: res = res or x return res @torch.jit.script_method - def tupleIntSumReturnTuple(self, input): - # type: (Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int] + def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]: sum = 0 for x in input: sum += x return (input, sum) @torch.jit.script_method - def optionalIntIsNone(self, input): - # type: (Optional[int]) -> bool + def optionalIntIsNone(self, input: Optional[int]) -> bool: return input is None @torch.jit.script_method - def intEq0None(self, input): - # type: (int) -> Optional[int] + def intEq0None(self, input: int) -> Optional[int]: if input == 0: return None return input @torch.jit.script_method - def str3Concat(self, input): - # type: (str) -> str + def str3Concat(self, input: str) -> str: return input + input + input @torch.jit.script_method @@ -113,8 +98,7 @@ def newEmptyShapeWithItem(self, input): return torch.tensor([int(input.item())])[0] @torch.jit.script_method - def testAliasWithOffset(self): - # type: () -> List[Tensor] + def testAliasWithOffset(self) -> List[Tensor]: x = torch.tensor([100, 200]) a = [x[0], x[1]] return a @@ -128,8 +112,7 @@ def testNonContiguous(self): return x @torch.jit.script_method - def conv2d(self, x, w, toChannelsLast): - # type: (Tensor, Tensor, bool) -> Tensor + def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: r = torch.nn.functional.conv2d(x, w) if (toChannelsLast): r = r.contiguous(memory_format=torch.channels_last) @@ -138,18 +121,15 @@ def conv2d(self, x, w, toChannelsLast): return r @torch.jit.script_method - def contiguous(self, x): - # type: (Tensor) -> Tensor + def contiguous(self, x: Tensor) -> Tensor: return x.contiguous() @torch.jit.script_method - def contiguousChannelsLast(self, x): - # type: (Tensor) -> Tensor + def contiguousChannelsLast(self, x: Tensor) -> Tensor: return x.contiguous(memory_format=torch.channels_last) @torch.jit.script_method - def contiguousChannelsLast3d(self, x): - # type: (Tensor) -> Tensor + def contiguousChannelsLast3d(self, x: Tensor) -> Tensor: return x.contiguous(memory_format=torch.channels_last_3d) scriptAndSave(Test(), "test.pt") diff --git a/android/pytorch_android/test_asset.jit b/android/pytorch_android/test_asset.jit index 49a41eff36a6..3bd9037da4ee 100644 --- a/android/pytorch_android/test_asset.jit +++ b/android/pytorch_android/test_asset.jit @@ -1,85 +1,69 @@ def forward(self, input): return None -def eqBool(self, input): - # type: (bool) -> bool +def eqBool(self, input: bool) -> bool: return input -def eqInt(self, input): - # type: (int) -> int +def eqInt(self, input: int) -> int: return input -def eqFloat(self, input): - # type: (float) -> float +def eqFloat(self, input: float) -> float: return input -def eqStr(self, input): - # type: (str) -> str +def eqStr(self, input: str) -> str: return input -def eqTensor(self, input): - # type: (Tensor) -> Tensor +def eqTensor(self, input: Tensor) -> Tensor: return input -def eqDictStrKeyIntValue(self, input): - # type: (Dict[str, int]) -> Dict[str, int] +def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]: return input -def eqDictIntKeyIntValue(self, input): - # type: (Dict[int, int]) -> Dict[int, int] +def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]: return input -def eqDictFloatKeyIntValue(self, input): - # type: (Dict[float, int]) -> Dict[float, int] +def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]: return input -def listIntSumReturnTuple(self, input): - # type: (List[int]) -> Tuple[List[int], int] +def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]: sum = 0 for x in input: sum += x return (input, sum) -def listBoolConjunction(self, input): - # type: (List[bool]) -> bool +def listBoolConjunction(self, input: List[bool]) -> bool: res = True for x in input: res = res and x return res -def listBoolDisjunction(self, input): - # type: (List[bool]) -> bool +def listBoolDisjunction(self, input: List[bool]) -> bool: res = False for x in input: res = res or x return res -def tupleIntSumReturnTuple(self, input): - # type: (Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int] +def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]: sum = 0 for x in input: sum += x return (input, sum) -def optionalIntIsNone(self, input): - # type: (Optional[int]) -> bool +def optionalIntIsNone(self, input: Optional[int]) -> bool: return input is None -def intEq0None(self, input): - # type: (int) -> Optional[int] +def intEq0None(self, input: int) -> Optional[int]: if input == 0: return None return input -def str3Concat(self, input): - # type: (str) -> str +def str3Concat(self, input: str) -> str: return input + input + input def newEmptyShapeWithItem(self, input): return torch.tensor([int(input.item())])[0] -def testAliasWithOffset(self): - # type: () -> List[Tensor] +def testAliasWithOffset(self) -> List[Tensor]: x = torch.tensor([100, 200]) a = [x[0], x[1]] return a @@ -91,8 +75,7 @@ def testNonContiguous(self): assert x[1] == 300 return x -def conv2d(self, x, w, toChannelsLast): - # type: (Tensor, Tensor, bool) -> Tensor +def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: r = torch.conv2d(x, w) if (toChannelsLast): # memory_format=torch.channels_last @@ -101,16 +84,13 @@ def conv2d(self, x, w, toChannelsLast): r = r.contiguous() return r -def contiguous(self, x): - # type: (Tensor) -> Tensor +def contiguous(self, x: Tensor) -> Tensor: return x.contiguous() -def contiguousChannelsLast(self, x): - # type: (Tensor) -> Tensor +def contiguousChannelsLast(self, x: Tensor) -> Tensor: # memory_format=torch.channels_last return x.contiguous(memory_format=2) -def contiguousChannelsLast3d(self, x): - # type: (Tensor) -> Tensor +def contiguousChannelsLast3d(self, x: Tensor) -> Tensor: # memory_format=torch.channels_last_3d return x.contiguous(memory_format=3) From ce370398cc19ab587d4ad1971f361227ccd75a6b Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Thu, 7 Jan 2021 15:54:25 -0800 Subject: [PATCH 09/54] [Gradient Compression] Remove the extra comma after "bucket" in PowerSGD hook signatures (#50197) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50197 Remove the extra comma after "bucket". ghstack-source-id: 119513484 Test Plan: waitforbuildbot Reviewed By: rohan-varma Differential Revision: D25823117 fbshipit-source-id: acf048f7cb732c23cba3a81ccce1e70f6b9f4299 --- .../algorithms/ddp_comm_hooks/powerSGD_hook.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index b1be0b52c9e6..7183aa1a82a3 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -82,10 +82,7 @@ def __init__( self.q_memory_dict = {} -def powerSGD_hook( - state: PowerSGDState, - bucket, -) -> torch.futures.Future: +def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: """ This DDP communication hook implements the original PowerSGD gradient compression algorithm described in https://arxiv.org/abs/1905.13727. @@ -322,10 +319,7 @@ def decompress(fut): ) -def batched_powerSGD_hook( - state: PowerSGDState, - bucket, -) -> torch.futures.Future: +def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: """ This DDP communication hook implements a simplified PowerSGD gradient compression algorithm described in https://arxiv.org/abs/1905.13727. From 870ab04b64dcc88b716756d96ac37b7a6c0592fe Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Thu, 7 Jan 2021 16:18:01 -0800 Subject: [PATCH 10/54] add type annotations to torch._utils (#49705) Summary: closes gh-49704 Pull Request resolved: https://github.com/pytorch/pytorch/pull/49705 Reviewed By: mruberry Differential Revision: D25725352 Pulled By: malfet fbshipit-source-id: 05a7041c9caffde4a5c1eb8af0d13697075103af --- mypy.ini | 17 ++++++++++++++++- tools/pyi/gen_pyi.py | 3 +++ torch/_utils.py | 30 +++++++++++++++--------------- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/mypy.ini b/mypy.ini index bab4ce5dfd42..6c579ee9399d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -106,7 +106,22 @@ ignore_errors = True [mypy-torch._appdirs] ignore_errors = True -[mypy-torch._utils] +[mypy-torch._overrides] +ignore_errors = True + +[mypy-torch.utils.tensorboard._caffe2_graph] +ignore_errors = True + +[mypy-torch.contrib._tensorboard_vis] +ignore_errors = True + +[mypy-torch.nn.utils.prune] +ignore_errors = True + +[mypy-torch.utils.show_pickle] +ignore_errors = True + +[mypy-torch.utils.hipify.hipify_python] ignore_errors = True [mypy-torch.utils.benchmark.examples.*] diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 7ad514d5d067..c5a9209d36b7 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -302,6 +302,9 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, out: str) -> None: 'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],' ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,' ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'], + '_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],' + ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,' + ' requires_grad: bool = False) -> Tensor: ...'], 'range': ['def range(start: Number, end: Number,' ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...' .format(FACTORY_PARAMS)], diff --git a/torch/_utils.py b/torch/_utils.py index 796e88a3cc2d..75eadd4a990e 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -1,6 +1,6 @@ import torch import torch._six -from typing import Optional +from typing import Optional, List, DefaultDict import warnings from collections import defaultdict import sys @@ -37,9 +37,9 @@ def _type(self, dtype=None, non_blocking=False, **kwargs): raise RuntimeError("Cannot cast sparse tensor to dense tensor") new_module_name = dtype.__module__.replace('.sparse', '') new_values_type_name = new_module_name + '.' + dtype.__name__ - new_values = torch._values(self).type(new_values_type_name, non_blocking) + new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking) new_indices_type_name = new_module_name + '.LongTensor' - new_indices = torch._indices(self).type(new_indices_type_name, non_blocking) + new_indices = torch.Tensor._indices(self).type(new_indices_type_name, non_blocking) return dtype(new_indices, new_values, self.size()) if dtype.is_sparse: raise RuntimeError("Cannot cast dense tensor to sparse tensor") @@ -72,8 +72,8 @@ def _cuda(self, device=None, non_blocking=False, **kwargs): with torch.cuda.device(device): if self.is_sparse: new_type = getattr(torch.cuda.sparse, self.__class__.__name__) - indices = torch._indices(self).cuda(device, non_blocking) - values = torch._values(self).cuda(device, non_blocking) + indices = torch.Tensor._indices(self).cuda(device, non_blocking) + values = torch.Tensor._values(self).cuda(device, non_blocking) return new_type(indices, values, self.size()) else: new_type = getattr(torch.cuda, self.__class__.__name__) @@ -144,7 +144,7 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac return tensor -_sparse_tensors_to_validate = [] +_sparse_tensors_to_validate: List["torch.Tensor"] = [] # In _legacy_load() in serialization.py we unpickle storages after the sparse # tensors have been already unpickled. Those storages contain data necessary for @@ -271,8 +271,8 @@ def _flatten_sparse_tensors(tensors): A tuple of two contiguous 1D buffers, one containing input tensors' indices and the other containing the values. """ - flat_indices = _flatten_dense_tensors([torch._indices(t) for t in tensors]) - flat_values = _flatten_dense_tensors([torch._values(t) for t in tensors]) + flat_indices = _flatten_dense_tensors([torch.Tensor._indices(t) for t in tensors]) + flat_values = _flatten_dense_tensors([torch.Tensor._values(t) for t in tensors]) return flat_indices, flat_values @@ -314,8 +314,8 @@ def _unflatten_sparse_tensors(flat, tensors): flat. """ flat_indices, flat_values = flat - indices = _unflatten_dense_tensors(flat_indices, [torch._indices(t) for t in tensors]) - values = _unflatten_dense_tensors(flat_values, [torch._values(t) for t in tensors]) + indices = _unflatten_dense_tensors(flat_indices, [torch.Tensor._indices(t) for t in tensors]) + values = _unflatten_dense_tensors(flat_values, [torch.Tensor._values(t) for t in tensors]) outputs = [] for t, i, v in zip(tensors, indices, values): outputs.append(t.new(i, v, t.size())) @@ -340,8 +340,8 @@ def _reorder_tensors_as(tensors, ordered_tensors): type_dict = defaultdict(list) for tensor in tensors: type_dict[tensor.type()].append(tensor) - type_dict = {t: iter(coll) for t, coll in type_dict.items()} - return tuple(next(type_dict[tensor.type()]) for tensor in ordered_tensors) + type_dict_ = {t: iter(coll) for t, coll in type_dict.items()} + return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors) def _take_tensors(tensors, size_limit): @@ -356,12 +356,12 @@ def _take_tensors(tensors, size_limit): Blocks of tensors of same type and within size_limit. The yielded tensors are only ordered as the original sequence within its types. """ - buf_dict = defaultdict(lambda: [[], 0]) + buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0]) for tensor in tensors: t = tensor.type() if tensor.is_sparse: - indices = torch._indices(tensor) - values = torch._values(tensor) + indices = torch.Tensor._indices(tensor) + values = torch.Tensor._values(tensor) size = indices.numel() * indices.element_size() + values.numel() * values.element_size() else: size = tensor.numel() * tensor.element_size() From bf4fcab68104b333c81e39d142dccc7ae7a6f1d7 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 7 Jan 2021 18:29:00 -0800 Subject: [PATCH 11/54] Fix SyncBatchNorm usage without stats tracking (#50126) Summary: In `batch_norm_gather_stats_with_counts_cuda` use `input.scalar_type()` if `running_mean` is not defined In `SyncBatchNorm` forward function create count tensor with `torch.float32` type if `running_mean` is None Fix a few typos Pull Request resolved: https://github.com/pytorch/pytorch/pull/50126 Test Plan: ``` python -c "import torch;print(torch.batch_norm_gather_stats_with_counts( torch.randn(1, 3, 3, 3, device='cuda'), mean = torch.ones(2, 3, device='cuda'), invstd = torch.ones(2, 3, device='cuda'), running_mean = None, running_var = None , momentum = .1, eps = 1e-5, counts = torch.ones(2, device='cuda')))" ``` Fixes https://github.com/pytorch/pytorch/issues/49730 Reviewed By: ngimel Differential Revision: D25797930 Pulled By: malfet fbshipit-source-id: 22a91e3969b5e9bbb7969d9cc70b45013a42fe83 --- aten/src/ATen/native/Normalization.cpp | 2 +- aten/src/ATen/native/cuda/Normalization.cu | 12 +++++++----- aten/src/ATen/native/cuda/Normalization.cuh | 2 +- test/test_cuda.py | 8 ++++++++ torch/nn/modules/_functions.py | 9 +++++---- torch/nn/modules/batchnorm.py | 2 +- 6 files changed, 23 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 1ac4250a9d54..ea4a54c13196 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -118,7 +118,7 @@ void batch_norm_cpu_inference_channels_last(Tensor& output, const Tensor& input, // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c) // No need to use parallel_for as this function is supposed to be // memory-limited. - // Keep the loop struture simple to make sure compiler vectorization kicks in. + // Keep the loop structure simple to make sure compiler vectorization kicks in. if (n_channel != 1) { for (int64_t n = 0; n < n_batch; ++n) { for (int64_t i = 0; i < image_size; ++i) { diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index 186099dfde50..6bf4e0f32f13 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -120,15 +120,17 @@ std::tuple batch_norm_gather_stats_cuda(const Tensor& self, cons const Tensor& running_var, double momentum, double epsilon, int64_t count) { std::vector counts(mean.size(0), count); Tensor counts_ = at::from_blob((void*)counts.data(), {(int64_t)counts.size()}, self.options().dtype(at::kLong).device(at::kCPU)); - counts_ = counts_.to(self.device()).to(running_mean.dtype()); + counts_ = counts_.to(self.device()).to(running_mean.defined() ? running_mean.dtype() : self.dtype()); return batch_norm_gather_stats_with_counts_cuda(self, mean, invstd, running_mean, running_var, momentum, epsilon, counts_); } -std::tuple batch_norm_gather_stats_with_counts_cuda(const Tensor& self, const Tensor& mean, const Tensor& invstd, const Tensor& running_mean, - const Tensor& running_var, double momentum, double epsilon, const Tensor& counts) { - - return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, running_mean.scalar_type(), "batch_norm_update_stats_cuda", [&] { +std::tuple batch_norm_gather_stats_with_counts_cuda( + const Tensor& self, const Tensor& mean, const Tensor& invstd, const Tensor& running_mean /* optional */, + const Tensor& running_var /* optional */, double momentum, double epsilon, const Tensor& counts) { + + auto scalar_type = running_mean.defined() ? running_mean.scalar_type() : self.scalar_type(); + return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "batch_norm_update_stats_cuda", [&] { using accscalar_t = at::acc_type; if (cuda::detail::canUse32BitIndexMath(self)) { return batch_norm_gather_stats_cuda_template(mean, invstd, running_mean, running_var, momentum, epsilon, counts); diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index 8355ac004308..a0445f129192 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -104,7 +104,7 @@ static __device__ __forceinline__ Float2 warpSum(Float2 (2C + 1) combined = torch.cat([mean, invstd, count], dim=0) diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 48e58d637ea6..64417069e2b7 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -54,7 +54,7 @@ def __init__( def reset_running_stats(self) -> None: if self.track_running_stats: - # running_mean/running_var/num_batches... are registerd at runtime depending + # running_mean/running_var/num_batches... are registered at runtime depending # if self.track_running_stats is on self.running_mean.zero_() # type: ignore[operator] self.running_var.fill_(1) # type: ignore[operator] From 2e7c6cc9df3bff4f22fc8a63bd97e99c2029b751 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 7 Jan 2021 18:58:29 -0800 Subject: [PATCH 12/54] [PyTorch] Devirtualize TensorImpl::numel() with macro (#49766) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49766 Devirtualizing this seems like a decent performance improvement on internal benchmarks. The *reason* this is a performance improvement is twofold: 1) virtual calls are a bit slower than regular calls 2) virtual functions in `TensorImpl` can't be inlined Test Plan: internal benchmark Reviewed By: hlu1 Differential Revision: D25602321 fbshipit-source-id: d61556456ccfd7f10c6ebdc3a52263b438a2aef1 --- c10/core/TensorImpl.h | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index e7f9c1260263..dbe3ed8944c6 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -244,6 +244,21 @@ struct C10_API VariableVersion { } }; +/** + * NOTE: Some TensorImpl methods are small and not overridden in the + * PyTorch codebase itself, but may theoretically need to be + * overridden by third-party TensorImpl subclasses. This macro allows + * users that need maximum performance and don't need these extension + * points to disable them with a build-time flag. (In particular, + * XLA's XLATensorImpl currently overrides these methods, so we can't + * enable this flag by default.) + */ +#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY +#define TENSORIMPL_MAYBE_VIRTUAL +#else +#define TENSORIMPL_MAYBE_VIRTUAL virtual +#endif + /** * The low-level representation of a tensor, which contains a pointer * to a storage (which contains the actual data) and metadata (e.g., sizes and @@ -412,7 +427,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * is no longer true; numel always accurately reports the product * of sizes of a tensor. */ - virtual int64_t numel() const { + TENSORIMPL_MAYBE_VIRTUAL int64_t numel() const { #ifdef DEBUG TORCH_INTERNAL_ASSERT(compute_numel() == numel_); #endif From 1a1b6658279b06c2a2023da3a01c1ba92e95fb08 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 7 Jan 2021 18:58:29 -0800 Subject: [PATCH 13/54] [PyTorch] validate that SparseTensorImpl::dim needn't be overridden (#49767) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49767 I'm told that the base implementation should work fine. Let's validate that in an intermediate diff before removing it. ghstack-source-id: 119528066 Test Plan: CI Reviewed By: ezyang, bhosmer Differential Revision: D25686830 fbshipit-source-id: f931394d3de6df7f6c5c68fe8ab711d90d3b12fd --- aten/src/ATen/SparseTensorImpl.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 8d7d4b2ce0f8..98670db11e86 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -70,6 +70,7 @@ void SparseTensorImpl::set_storage_offset(int64_t storage_offset) { } int64_t SparseTensorImpl::dim() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sparse_dim_ + dense_dim_ == TensorImpl::dim()); return sparse_dim_ + dense_dim_; } bool SparseTensorImpl::has_storage() const { From 4de6b279c8975409b554ecdb0fcf611839e55f9b Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 7 Jan 2021 18:58:29 -0800 Subject: [PATCH 14/54] [PyTorch] Devirtualize TensorImpl::dim() with macro (#49770) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49770 Seems like the performance cost of making this commonly-called method virtual isn't worth having use of undefined tensors crash a bit earlier (they'll still fail to dispatch). ghstack-source-id: 119528065 Test Plan: framework overhead benchmarks Reviewed By: ezyang Differential Revision: D25687465 fbshipit-source-id: 89aabce165a594be401979c04236114a6f527b59 --- aten/src/ATen/SparseTensorImpl.cpp | 4 ---- aten/src/ATen/SparseTensorImpl.h | 1 - aten/src/ATen/test/undefined_tensor_test.cpp | 2 +- c10/core/TensorImpl.cpp | 2 ++ c10/core/TensorImpl.h | 9 ++++++++- c10/core/UndefinedTensorImpl.cpp | 4 ---- c10/core/UndefinedTensorImpl.h | 1 - 7 files changed, 11 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 98670db11e86..0e18dca131a2 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -69,10 +69,6 @@ void SparseTensorImpl::set_storage_offset(int64_t storage_offset) { AT_ERROR("sparse tensors do not have set_storage_offset"); } -int64_t SparseTensorImpl::dim() const { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sparse_dim_ + dense_dim_ == TensorImpl::dim()); - return sparse_dim_ + dense_dim_; -} bool SparseTensorImpl::has_storage() const { return false; } diff --git a/aten/src/ATen/SparseTensorImpl.h b/aten/src/ATen/SparseTensorImpl.h index db97b2202a5f..c80f5de18824 100644 --- a/aten/src/ATen/SparseTensorImpl.h +++ b/aten/src/ATen/SparseTensorImpl.h @@ -47,7 +47,6 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { void set_stride(int64_t dim, int64_t new_stride) override; void set_storage_offset(int64_t storage_offset) override; - int64_t dim() const override; bool has_storage() const override; const Storage& storage() const override; int64_t storage_offset() const override; diff --git a/aten/src/ATen/test/undefined_tensor_test.cpp b/aten/src/ATen/test/undefined_tensor_test.cpp index 7a6dd50f9163..0102a8cf4f49 100644 --- a/aten/src/ATen/test/undefined_tensor_test.cpp +++ b/aten/src/ATen/test/undefined_tensor_test.cpp @@ -19,7 +19,7 @@ TEST(TestUndefined, UndefinedTest) { ASSERT_EQ(std::string("UndefinedType"), und.toString()); ASSERT_ANY_THROW(und.strides()); - ASSERT_ANY_THROW(und.dim()); + ASSERT_EQ(und.dim(), 1); ASSERT_ANY_THROW([]() { return Tensor(); }() = Scalar(5)); ASSERT_ANY_THROW(und.add(und)); ASSERT_ANY_THROW(und.add(ft)); diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 5b4a4f3b83e6..f17ba95a754b 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -216,9 +216,11 @@ void TensorImpl::release_resources() { } } +#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY int64_t TensorImpl::dim() const { return sizes_.size(); } +#endif int64_t TensorImpl::size(int64_t d) const { d = at::maybe_wrap_dim(d, dim(), false); diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index dbe3ed8944c6..d5c5208472d3 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -402,7 +402,14 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * Return the number of dimensions of this tensor. Note that 0-dimension * represents a Tensor that is a Scalar, e.g., one that has a single element. */ - virtual int64_t dim() const; + TENSORIMPL_MAYBE_VIRTUAL int64_t dim() const +#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY + { + return sizes_.size(); + } +#else + ; +#endif /** * True if this tensor has storage. See storage() for details. diff --git a/c10/core/UndefinedTensorImpl.cpp b/c10/core/UndefinedTensorImpl.cpp index f79897c72e04..ab7f9d42c5c4 100644 --- a/c10/core/UndefinedTensorImpl.cpp +++ b/c10/core/UndefinedTensorImpl.cpp @@ -20,10 +20,6 @@ int64_t UndefinedTensorImpl::stride(int64_t d) const { AT_ERROR("stride(dim) called on an undefined Tensor"); } -int64_t UndefinedTensorImpl::dim() const { - AT_ERROR("dim() called on undefined Tensor"); -} - bool UndefinedTensorImpl::has_storage() const { AT_ERROR("has_storage() called on undefined Tensor"); } diff --git a/c10/core/UndefinedTensorImpl.h b/c10/core/UndefinedTensorImpl.h index 26122ed305e2..efe1ed6a463e 100644 --- a/c10/core/UndefinedTensorImpl.h +++ b/c10/core/UndefinedTensorImpl.h @@ -21,7 +21,6 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl { IntArrayRef strides() const override; int64_t size(int64_t d) const override; int64_t stride(int64_t d) const override; - int64_t dim() const override; bool has_storage() const override; const Storage& storage() const override; int64_t storage_offset() const override; From 84e3237a530efe82e5e3625c4fbef2d0bd995c41 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 15/54] Let RpcAgent::send() return JitFuture (#49906) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49906 This commit modifies RPC Message to inherit from `torch::CustomClassHolder`, and wraps a Message in an IValue in `RpcAgent::send()`. Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25719518 Pulled By: mrshenli fbshipit-source-id: 694e40021e49e396da1620a2f81226522341550b --- .../autograd/functions/recvrpc_backward.cpp | 4 +- torch/csrc/distributed/autograd/utils.cpp | 6 ++- torch/csrc/distributed/rpc/message.cpp | 12 ++++++ torch/csrc/distributed/rpc/message.h | 2 +- .../distributed/rpc/process_group_agent.cpp | 8 ++-- .../distributed/rpc/process_group_agent.h | 2 +- torch/csrc/distributed/rpc/py_rref.cpp | 6 +-- torch/csrc/distributed/rpc/rpc_agent.cpp | 4 +- torch/csrc/distributed/rpc/rpc_agent.h | 42 ++++++++++++++++++- .../csrc/distributed/rpc/tensorpipe_agent.cpp | 6 +-- torch/csrc/distributed/rpc/tensorpipe_agent.h | 2 +- .../testing/faulty_process_group_agent.cpp | 4 +- .../rpc/testing/faulty_process_group_agent.h | 2 +- 13 files changed, 78 insertions(+), 22 deletions(-) diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp index 99951f098a22..08f18da081a5 100644 --- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp @@ -47,8 +47,8 @@ variable_list RecvRpcBackward::apply(variable_list&& grads) { // Send the gradients over to the appropriate node. auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent(); - auto futureMessage = rpcAgent->send( - rpcAgent->getWorkerInfo(fromWorkerId_), std::move(gradCall).toMessage()); + auto futureMessage = rpc::RpcAgent::toFutureMessage(rpcAgent->send( + rpcAgent->getWorkerInfo(fromWorkerId_), std::move(gradCall).toMessage())); // Record the future in the context. sharedContext->addOutstandingRpc(futureMessage); diff --git a/torch/csrc/distributed/autograd/utils.cpp b/torch/csrc/distributed/autograd/utils.cpp index 464d8248d8a4..4802a9047884 100644 --- a/torch/csrc/distributed/autograd/utils.cpp +++ b/torch/csrc/distributed/autograd/utils.cpp @@ -160,9 +160,11 @@ std::shared_ptr sendMessageWithAutograd( std::move(msg), rpc::MessageType::RUN_WITH_PROFILING_REQ, std::move(profilerConfig)); - fut = agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds); + fut = rpc::RpcAgent::toFutureMessage( + agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds)); } else { - fut = agent.send(dst, std::move(msg), rpcTimeoutSeconds); + fut = rpc::RpcAgent::toFutureMessage( + agent.send(dst, std::move(msg), rpcTimeoutSeconds)); } return fut; diff --git a/torch/csrc/distributed/rpc/message.cpp b/torch/csrc/distributed/rpc/message.cpp index 14c267fd0699..b35e9149d1e6 100644 --- a/torch/csrc/distributed/rpc/message.cpp +++ b/torch/csrc/distributed/rpc/message.cpp @@ -104,6 +104,18 @@ Message createExceptionResponse(const std::string& exceptionStr, int64_t id) { id); } +namespace { + +// NB: need to call torch::class_ to register Message in the map returned by +// c10::getCustomClassTypeMap(). Otherwise, Message cannot be wrapped within +// an IValue. +// NB: add this line here instead of in rpc/init.cpp because 1) we have C++ +// only tests that won't run rpc/init.cpp; 2) Message is not meant to be +// visible from Python. +static const auto message = torch::class_("rpc", "_Message"); + +} // namespace + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/message.h b/torch/csrc/distributed/rpc/message.h index 7e458810db69..bad9569cfff8 100644 --- a/torch/csrc/distributed/rpc/message.h +++ b/torch/csrc/distributed/rpc/message.h @@ -93,7 +93,7 @@ enum MessageType { // Layers above ``RpcAgent`` only converts ScriptCall, ScriptResp, PythonCall, // and PythonResp into a Message, and it is up to the RpcAgent // implementation to determine how to serialize a message. -class TORCH_API Message final { +class TORCH_API Message final : public torch::CustomClassHolder { public: Message(); diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index b106f1442d31..71691d9dfbb1 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -287,7 +287,7 @@ void ProcessGroupAgent::shutdownImpl() { threadPool_.waitWorkComplete(); } -std::shared_ptr ProcessGroupAgent::send( +std::shared_ptr ProcessGroupAgent::send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds) { @@ -369,7 +369,7 @@ std::shared_ptr ProcessGroupAgent::send( // to our receiving queue. if (to.id_ == (worker_id_t)pg_->getRank()) { sendToSelf(std::move(message)); - return future; + return toJitFuture(std::move(future)); } // NB: cannot directly pass ``to`` to the ``SendWork``, because it might no @@ -382,7 +382,9 @@ std::shared_ptr ProcessGroupAgent::send( // the C++ land. Hence, we have to explicitly use the ``WorkerInfo`` in the // C++ land. enqueueSend(SendWork(allWorkerInfo_[to.id_], std::move(message))); - return future; + + auto jitFuture = toJitFuture(std::move(future)); + return jitFuture; } void ProcessGroupAgent::handleSend(const SendWork& work) { diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index 61d17f03e623..3d4393475f31 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -88,7 +88,7 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { // This method wraps the destination information and the message into a // SendWork object, and put the SendWork into a queue. Another thread will // consume SendWork from the queue and send it out. - std::shared_ptr send( + std::shared_ptr send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = kUnsetRpcTimeout) override; diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index 1c955a6baefb..f1ff4f29c206 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -331,10 +331,10 @@ void PyRRef::backward( // Invoke distributed backward remotely. auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent(); - rpcAgent - ->send( + rpc::RpcAgent::toFutureMessage( + rpcAgent->send( rpcAgent->getWorkerInfo(rref->owner()), - std::move(rrefBackwardReq).toMessage()) + std::move(rrefBackwardReq).toMessage())) ->wait(); } } diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp index 4d9f6db39220..549408387bc0 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.cpp +++ b/torch/csrc/distributed/rpc/rpc_agent.cpp @@ -62,7 +62,7 @@ std::shared_ptr RpcAgent::sendWithRetries( computeNewRpcRetryTime(retryOptions, /* retryCount */ 0); // Making a copy of the message so it can be retried after the first send. Message msgCopy = message; - auto fm = send(to, std::move(message)); + auto fm = toFutureMessage(send(to, std::move(message))); auto firstRetryRpc = std::make_shared( to, std::move(msgCopy), @@ -133,7 +133,7 @@ void RpcAgent::retryExpiredRpcs() { // with an error, since this RPC never succeeded and can no longer be // retried. try { - fm = send(earliestRpc->to_, std::move(msgCopy)); + fm = toFutureMessage(send(earliestRpc->to_, std::move(msgCopy))); futures.emplace_back(fm, earliestRpc); } catch (std::exception& e) { // We must store the futures and exception messages here and only mark diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index 34b77a085510..64592e6d98f7 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -157,7 +157,7 @@ class TORCH_API RpcAgent { // If ``message.isRequest()`` is true, the ``FutureMessage`` will be // completed when the response arrives. For other message types, the Future // should be ignored by the caller. - virtual std::shared_ptr send( + virtual std::shared_ptr send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = kUnsetRpcTimeout) = 0; @@ -259,6 +259,46 @@ class TORCH_API RpcAgent { // Get the type resolver std::shared_ptr getTypeResolver(); + static std::shared_ptr toJitFuture( + std::shared_ptr&& fm) { + auto jitFuture = std::make_shared(at::AnyClassType::get()); + + std::weak_ptr wp = fm; + fm->addCallback( + [jitFuture, wp]() mutable { + auto future = wp.lock(); + TORCH_INTERNAL_ASSERT(future); + if (future->hasError()) { + jitFuture->setError(std::make_exception_ptr(*(future->error()))); + } else { + jitFuture->markCompleted(IValue( + c10::make_intrusive(std::move(*future).moveValue()))); + } + } + ); + return jitFuture; + } + + static std::shared_ptr toFutureMessage( + std::shared_ptr&& jitFuture) { + auto fm = std::make_shared(); + + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( + [fm, wp]() mutable { + auto future = wp.lock(); + TORCH_INTERNAL_ASSERT(future); + if (future->hasError()) { + fm->setError(future->tryRetrieveErrorMessage()); + } else { + fm->markCompleted( + std::move(*future->value().toCustomClass())); + } + } + ); + return fm; + } + protected: const WorkerInfo workerInfo_; const std::unique_ptr cb_; diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 4a39979e6e1b..8e5eb15057b0 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -604,7 +604,7 @@ void TensorPipeAgent::respond(std::shared_ptr& pipe) { }); } -std::shared_ptr TensorPipeAgent::send( +std::shared_ptr TensorPipeAgent::send( const WorkerInfo& toWorkerInfo, Message&& requestMessage, const float rpcTimeoutSeconds) { @@ -778,8 +778,8 @@ std::shared_ptr TensorPipeAgent::send( }); }); - return std::shared_ptr( - futureResponseMessage, &futureResponseMessage->futMsg); + return toJitFuture(std::shared_ptr( + futureResponseMessage, &futureResponseMessage->futMsg)); } void TensorPipeAgent::pollTimeoutRpcs() { diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index 3eb7cdc6ec7e..83ade1cf9203 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -181,7 +181,7 @@ class TensorPipeAgent : public RpcAgent { TensorPipeAgent(const TensorPipeAgent&) = delete; TensorPipeAgent& operator=(const TensorPipeAgent&) = delete; - std::shared_ptr send( + std::shared_ptr send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = kUnsetRpcTimeout) override; diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp index dccf1abb6d3e..ba2229932145 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp @@ -56,7 +56,7 @@ std::unordered_map> FaultyProcessGroupAgent:: return delayMessages; } -std::shared_ptr FaultyProcessGroupAgent::send( +std::shared_ptr FaultyProcessGroupAgent::send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds) { @@ -82,7 +82,7 @@ std::shared_ptr FaultyProcessGroupAgent::send( fm->setError(makeRPCError( c10::str("Send attempt failed intentionally for ", key), RPCErrorType::INTENTIONAL_FAILURE)); - return fm; + return toJitFuture(std::move(fm)); } else { lock.unlock(); return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds); diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h index 25a162bbd559..8cbe4c9a137d 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h @@ -43,7 +43,7 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent { int failNumSends = 0); // Faulty send function for this class. - std::shared_ptr send( + std::shared_ptr send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = From 25ef605132bfc2c1e2a84b9b542ec7bca688f399 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 16/54] Replace FutureMessage with ivalue::Future in distributed/autograd/utils.* (#49927) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49927 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25724241 Pulled By: mrshenli fbshipit-source-id: d608e448f5224e41fbb0b5be6b9ac51a587f25b4 --- test/cpp/rpc/e2e_test_base.h | 16 ++-- torch/csrc/distributed/autograd/utils.cpp | 12 ++- torch/csrc/distributed/autograd/utils.h | 2 +- torch/csrc/distributed/rpc/py_rref.cpp | 2 +- .../csrc/distributed/rpc/python_functions.cpp | 74 ++++++++++--------- torch/csrc/distributed/rpc/python_functions.h | 16 ++-- torch/csrc/distributed/rpc/rref_impl.cpp | 15 ++-- torch/csrc/distributed/rpc/rref_impl.h | 6 +- .../distributed/rpc/torchscript_functions.cpp | 24 +++--- 9 files changed, 88 insertions(+), 79 deletions(-) diff --git a/test/cpp/rpc/e2e_test_base.h b/test/cpp/rpc/e2e_test_base.h index 114284839858..f9a2fa20e910 100644 --- a/test/cpp/rpc/e2e_test_base.h +++ b/test/cpp/rpc/e2e_test_base.h @@ -64,13 +64,14 @@ class TestE2EBase : public ::testing::Test { ScriptRemoteCall scriptRemoteCall( op, {t1, t2, 1}, ownerRRef->rrefId(), ownerRRef->rrefId()); - auto fm = autograd::sendMessageWithAutograd( + auto jitFuture = autograd::sendMessageWithAutograd( *rpcAgent, rpcAgent->getWorkerInfo("worker"), std::move(scriptRemoteCall).toMessage(), false); - ownerRRef->registerOwnerCreationFuture(fm); + ownerRRef->registerOwnerCreationFuture(jitFuture); + auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); // Builtin operators does not return py::object, and hence does not require // GIL for destructing the potentially deleted OwerRRef. @@ -88,11 +89,12 @@ class TestE2EBase : public ::testing::Test { ScriptCall scriptCall(op, {t1, t2, /* alpha */ 1}); // Send the RPC and return result. - auto response = autograd::sendMessageWithAutograd( - *rpcAgent, - rpcAgent->getWorkerInfo("worker"), - std::move(scriptCall).toMessage()) - ->wait(); + auto response = RpcAgent::toFutureMessage( + autograd::sendMessageWithAutograd( + *rpcAgent, + rpcAgent->getWorkerInfo("worker"), + std::move(scriptCall).toMessage())) + ->wait(); MessageType messageType = MessageType::FORWARD_AUTOGRAD_RESP; auto wrappedResponse = deserializeResponse(response, messageType); return static_cast(*wrappedResponse).value().toTensor(); diff --git a/torch/csrc/distributed/autograd/utils.cpp b/torch/csrc/distributed/autograd/utils.cpp index 4802a9047884..08bb99471686 100644 --- a/torch/csrc/distributed/autograd/utils.cpp +++ b/torch/csrc/distributed/autograd/utils.cpp @@ -16,7 +16,7 @@ namespace autograd { using torch::distributed::autograd::AutogradMetadata; using torch::distributed::autograd::RpcWithAutograd; -using torch::distributed::rpc::FutureMessage; +using torch::distributed::rpc::JitFuture; using torch::distributed::rpc::Message; using torch::distributed::rpc::MessageType; using torch::distributed::rpc::RpcAgent; @@ -138,7 +138,7 @@ Message getMessageWithAutograd( return std::move(*rpcWithAutograd).toMessage(); } -std::shared_ptr sendMessageWithAutograd( +std::shared_ptr sendMessageWithAutograd( RpcAgent& agent, const WorkerInfo& dst, torch::distributed::rpc::Message&& wrappedRpcMsg, @@ -151,7 +151,7 @@ std::shared_ptr sendMessageWithAutograd( MessageType::FORWARD_AUTOGRAD_REQ, forceGradRecording); - std::shared_ptr fut; + std::shared_ptr fut; // If profiler is enabled, wrap this message with profiling metadata that will // tell the remote end to process this request with the profiler enabled. if (!forceDisableProfiling && torch::autograd::profiler::profilerEnabled()) { @@ -160,11 +160,9 @@ std::shared_ptr sendMessageWithAutograd( std::move(msg), rpc::MessageType::RUN_WITH_PROFILING_REQ, std::move(profilerConfig)); - fut = rpc::RpcAgent::toFutureMessage( - agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds)); + fut = agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds); } else { - fut = rpc::RpcAgent::toFutureMessage( - agent.send(dst, std::move(msg), rpcTimeoutSeconds)); + fut = agent.send(dst, std::move(msg), rpcTimeoutSeconds); } return fut; diff --git a/torch/csrc/distributed/autograd/utils.h b/torch/csrc/distributed/autograd/utils.h index 2a0a066e1a95..07ba45ed60d7 100644 --- a/torch/csrc/distributed/autograd/utils.h +++ b/torch/csrc/distributed/autograd/utils.h @@ -45,7 +45,7 @@ TORCH_API rpc::Message getMessageWithAutograd( bool forceGradRecording = false); // Send message after autograd checking -TORCH_API std::shared_ptr +TORCH_API std::shared_ptr sendMessageWithAutograd( rpc::RpcAgent& agent, const rpc::WorkerInfo& dst, diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index f1ff4f29c206..4830cd5424b1 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -137,7 +137,7 @@ c10::intrusive_ptr PyRRef::getFuture() const { // Marking hasValue to false, as this Future is only used for signaling // profiler to update profiling result and the profiler does not retrieve // any value from it. - return wrapFutureMessageInJitFuture( + return toPyJitFuture( rref_->getOwnerCreationFuture(), false /* hasValue */); } diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp index 5e2e8304b7bd..08e90f8f4d91 100644 --- a/torch/csrc/distributed/rpc/python_functions.cpp +++ b/torch/csrc/distributed/rpc/python_functions.cpp @@ -24,7 +24,7 @@ namespace rpc { namespace { -IValue toIValue(const Message& message) { +IValue toPyIValue(const Message& message) { MessageType msgType = message.type(); auto response = deserializeResponse(message, msgType); switch (msgType) { @@ -109,7 +109,7 @@ std::shared_ptr matchBuiltinOp( return matchedOperator; } -std::shared_ptr sendPythonRemoteCall( +std::shared_ptr sendPythonRemoteCall( const WorkerInfo& dst, SerializedPyObj serializedPyObj, const IValue& rrefId, @@ -134,42 +134,40 @@ std::shared_ptr sendPythonRemoteCall( using namespace torch::distributed::autograd; -c10::intrusive_ptr wrapFutureMessageInJitFuture( - const std::shared_ptr& futureResponseMessage, +c10::intrusive_ptr toPyJitFuture( + const std::shared_ptr& messageJitFuture, bool hasValue) { if (hasValue) { - c10::intrusive_ptr jitFuture = + c10::intrusive_ptr pyJitFuture = c10::make_intrusive(PyObjectType::get()); - std::weak_ptr wp = futureResponseMessage; - futureResponseMessage->addCallback( - at::wrapPropagateTLSState([jitFuture, wp]() { - auto futureResponseMessage = wp.lock(); - if (futureResponseMessage->hasError()) { - jitFuture->setError( - std::make_exception_ptr(*futureResponseMessage->error())); + std::weak_ptr wp = messageJitFuture; + messageJitFuture->addCallback( + at::wrapPropagateTLSState([pyJitFuture, wp]() { + auto messageJitFuture = wp.lock(); + if (messageJitFuture->hasError()) { + pyJitFuture->setError(messageJitFuture->exception_ptr()); } else { - jitFuture->markCompleted( - toIValue(futureResponseMessage->constValue())); + pyJitFuture->markCompleted(toPyIValue( + *messageJitFuture->value().toCustomClass())); } })); - return jitFuture; + return pyJitFuture; } else { - c10::intrusive_ptr jitFuture = + c10::intrusive_ptr pyJitFuture = c10::make_intrusive(NoneType::get()); - std::weak_ptr wp = futureResponseMessage; - futureResponseMessage->addCallback( - at::wrapPropagateTLSState([wp, jitFuture]() { - auto futureResponseMessage = wp.lock(); - if (futureResponseMessage->hasError()) { - jitFuture->setError( - std::make_exception_ptr(*futureResponseMessage->error())); + std::weak_ptr wp = messageJitFuture; + messageJitFuture->addCallback( + at::wrapPropagateTLSState([wp, pyJitFuture]() { + auto messageJitFuture = wp.lock(); + if (messageJitFuture->hasError()) { + pyJitFuture->setError(messageJitFuture->exception_ptr()); } else { - jitFuture->markCompleted(IValue()); + pyJitFuture->markCompleted(IValue()); } })); - return jitFuture; + return pyJitFuture; } } @@ -186,7 +184,7 @@ c10::intrusive_ptr pyRpcBuiltin( py::gil_scoped_release release; auto scriptCall = std::make_unique(op, std::move(stack)); auto agent = RpcAgent::getCurrentRpcAgent(); - return wrapFutureMessageInJitFuture(sendMessageWithAutograd( + return toPyJitFuture(sendMessageWithAutograd( *agent, dst, std::move(*scriptCall).toMessage(), @@ -207,7 +205,7 @@ c10::intrusive_ptr pyRpcPythonUdf( std::move(serializedPyObj), isAsyncExecution); auto agent = RpcAgent::getCurrentRpcAgent(); - return wrapFutureMessageInJitFuture(sendMessageWithAutograd( + return toPyJitFuture(sendMessageWithAutograd( *agent, dst, std::move(*pythonCall).toMessage(), @@ -275,14 +273,16 @@ PyRRef pyRemoteBuiltin( auto scriptRemoteCall = std::make_unique( op, std::move(stack), userRRef->rrefId(), userRRef->forkId()); - auto fm = sendMessageWithAutograd( + auto jitFuture = sendMessageWithAutograd( *agent, dst, std::move(*scriptRemoteCall).toMessage(), /*forceGradRecord */ false, /* timeout */ rpcTimeoutSeconds); - userRRef->registerOwnerCreationFuture(fm); + userRRef->registerOwnerCreationFuture(jitFuture); + auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); + ctx.addPendingUser(userRRef->forkId(), userRRef); std::weak_ptr wp = fm; fm->addCallback( @@ -298,14 +298,15 @@ PyRRef pyRemoteBuiltin( auto scriptRemoteCall = std::make_unique( op, std::move(stack), ownerRRef->rrefId(), ownerRRef->rrefId()); - auto fm = sendMessageWithAutograd( + auto jitFuture = sendMessageWithAutograd( *agent, dst, std::move(*scriptRemoteCall).toMessage(), /* forceGradRecord */ false, /* timeout */ rpcTimeoutSeconds); - ownerRRef->registerOwnerCreationFuture(fm); + ownerRRef->registerOwnerCreationFuture(jitFuture); + auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); // Builtin operators does not return py::object, and hence does not require // GIL for destructing the potentially deleted OwerRRef. @@ -332,7 +333,7 @@ PyRRef pyRemotePythonUdf( if (ctx.getWorkerId() != dst.id_) { auto userRRef = ctx.createUserRRef(dst.id_, PyObjectType::get()); - auto fm = sendPythonRemoteCall( + auto jitFuture = sendPythonRemoteCall( dst, std::move(serializedPyObj), userRRef->rrefId().toIValue(), @@ -340,7 +341,8 @@ PyRRef pyRemotePythonUdf( rpcTimeoutSeconds, isAsyncExecution); - userRRef->registerOwnerCreationFuture(fm); + userRRef->registerOwnerCreationFuture(jitFuture); + auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); ctx.addPendingUser(userRRef->forkId(), userRRef); std::weak_ptr wp = fm; @@ -355,7 +357,7 @@ PyRRef pyRemotePythonUdf( auto ownerRRef = ctx.createOwnerRRef(PyObjectType::get()); // prevent this owner RRef being deleted due to other forks ctx.addSelfAsFork(ownerRRef); - auto fm = sendPythonRemoteCall( + auto jitFuture = sendPythonRemoteCall( dst, std::move(serializedPyObj), ownerRRef->rrefId().toIValue(), @@ -363,7 +365,9 @@ PyRRef pyRemotePythonUdf( rpcTimeoutSeconds, isAsyncExecution); - ownerRRef->registerOwnerCreationFuture(fm); + ownerRRef->registerOwnerCreationFuture(jitFuture); + auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); + std::weak_ptr wp = fm; fm->addCallback(at::wrapPropagateTLSState( [wp, ownerRRefId = ownerRRef->rrefId()]() { diff --git a/torch/csrc/distributed/rpc/python_functions.h b/torch/csrc/distributed/rpc/python_functions.h index 56c091096828..15bc0b2af8a0 100644 --- a/torch/csrc/distributed/rpc/python_functions.h +++ b/torch/csrc/distributed/rpc/python_functions.h @@ -9,16 +9,16 @@ namespace torch { namespace distributed { namespace rpc { -// Converts an internal FutureMessage type into a user-facing FutureIValue type -// by creating a new FutureIValue and call its markCompleted as a callback in -// the given FutureMessage. +// Converts an internal ivalue::Future of Message into a user-facing +// ivalue::Future of py::object type by creating a new ivalue::Future and call +// its markCompleted as a callback in the given ivalue::Future. // If hasValue is true, the Message will be converted into a py::object and then -// wrap it with an IValue. If hasValue is false, this FutureIValue is only used -// for signaling and launching callbacks. In this case, the message will be -// discarded and then set the FutureIValue using an empty IValue or the given +// wrap it with an IValue. If hasValue is false, this ivalue::Future is only +// used for signaling and launching callbacks. In this case, the message will be +// discarded and then set the ivalue::Future using an empty IValue or the given // FutureError if there is an error. -c10::intrusive_ptr wrapFutureMessageInJitFuture( - const std::shared_ptr& futureResponseMessage, +c10::intrusive_ptr toPyJitFuture( + const std::shared_ptr& messageJitFuture, bool hasValue = true); c10::intrusive_ptr pyRpcBuiltin( diff --git a/torch/csrc/distributed/rpc/rref_impl.cpp b/torch/csrc/distributed/rpc/rref_impl.cpp index 6c6a377a4652..72fed660e200 100644 --- a/torch/csrc/distributed/rpc/rref_impl.cpp +++ b/torch/csrc/distributed/rpc/rref_impl.cpp @@ -170,13 +170,14 @@ IValue UserRRef::toHere(const float timeoutSeconds) const { // toHere is profiled as a blocking call, and does not execute operations on // the remote node. Hence, don't wrap it with a profiling message since we // don't need the profiler to be enabled remotely. - auto futureResponse = autograd::sendMessageWithAutograd( - *agent, - agent->getWorkerInfo(ownerId_), - std::move(msgToSend), - true /* forceGradRecording */, - timeoutSeconds, - true /* forceDisableProfiling */); + auto futureResponse = RpcAgent::toFutureMessage( + autograd::sendMessageWithAutograd( + *agent, + agent->getWorkerInfo(ownerId_), + std::move(msgToSend), + true /* forceGradRecording */, + timeoutSeconds, + true /* forceDisableProfiling */)); // TODO: we should ideally be able to interrupt this blocking wait if we check // getTimedOut() and it is true diff --git a/torch/csrc/distributed/rpc/rref_impl.h b/torch/csrc/distributed/rpc/rref_impl.h index 29aa355908fa..e18c1fe3b9d7 100644 --- a/torch/csrc/distributed/rpc/rref_impl.h +++ b/torch/csrc/distributed/rpc/rref_impl.h @@ -228,12 +228,12 @@ class TORCH_API RRef : public RRefInterface { // node. Note that this is only set when processing requests invoked with // rpc.remote. This is only used to get the future corresponding to the rref // for profiling use cases. - inline void registerOwnerCreationFuture(std::shared_ptr fut) { + inline void registerOwnerCreationFuture(std::shared_ptr fut) { ownerCreationFuture_ = std::move(fut); } // Get the future corresponding to the creation of this rref. - inline std::shared_ptr getOwnerCreationFuture() const { + inline std::shared_ptr getOwnerCreationFuture() const { return ownerCreationFuture_; } @@ -272,7 +272,7 @@ class TORCH_API RRef : public RRefInterface { // it could be any TypePtr that JIT support, including PyObjectType const TypePtr type_; // Future corresponding to request to create RRef on remote node. - std::shared_ptr ownerCreationFuture_; + std::shared_ptr ownerCreationFuture_; }; // ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user diff --git a/torch/csrc/distributed/rpc/torchscript_functions.cpp b/torch/csrc/distributed/rpc/torchscript_functions.cpp index a9cd006439e8..22ce8cbd6237 100644 --- a/torch/csrc/distributed/rpc/torchscript_functions.cpp +++ b/torch/csrc/distributed/rpc/torchscript_functions.cpp @@ -43,12 +43,13 @@ c10::intrusive_ptr rpcTorchscript( auto scriptCall = std::make_unique( qualifiedName, std::move(stack), isAsyncExecution); auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent(); - auto futMessage = autograd::sendMessageWithAutograd( - *rpcAgentPtr, - rpcAgentPtr->getWorkerInfo(dstWorkerName), - std::move(*scriptCall).toMessage(), - true /*forceGradRecording*/, - rpcTimeoutSeconds); + auto futMessage = RpcAgent::toFutureMessage( + autograd::sendMessageWithAutograd( + *rpcAgentPtr, + rpcAgentPtr->getWorkerInfo(dstWorkerName), + std::move(*scriptCall).toMessage(), + true /*forceGradRecording*/, + rpcTimeoutSeconds)); // Get function return type to construct c10::ivalue::Future. auto returns = functionSchema.returns(); @@ -112,14 +113,15 @@ c10::intrusive_ptr remoteTorchscript( userRRefPtr->forkId(), isAsyncExecution); - auto fm = torch::distributed::autograd::sendMessageWithAutograd( + auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd( *rpcAgentPtr, dstWorkerInfo, std::move(*scriptRemoteCall).toMessage(), true /*forceGradRecording*/, rpcTimeoutSeconds /* timeout */); - userRRefPtr->registerOwnerCreationFuture(fm); + userRRefPtr->registerOwnerCreationFuture(jitFuture); + auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); ctx.addPendingUser(userRRefPtr->forkId(), userRRefPtr); std::weak_ptr wp = fm; @@ -142,14 +144,16 @@ c10::intrusive_ptr remoteTorchscript( ownerRRefPtr->rrefId(), isAsyncExecution); - auto fm = torch::distributed::autograd::sendMessageWithAutograd( + auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd( *rpcAgentPtr, dstWorkerInfo, std::move(*scriptRemoteCall).toMessage(), true /*forceGradRecording*/, rpcTimeoutSeconds /* timeout */); - ownerRRefPtr->registerOwnerCreationFuture(fm); + ownerRRefPtr->registerOwnerCreationFuture(jitFuture); + auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); + std::weak_ptr wp = fm; fm->addCallback(at::wrapPropagateTLSState( [wp, ownerRRefId = ownerRRefPtr->rrefId()]() { From 008206decca0c40a125b0a9143d1c57a53ed0a1a Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 17/54] Replace FutureMessage with ivalue::Future in RRefContext (#49960) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49960 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25730530 Pulled By: mrshenli fbshipit-source-id: 5d54572c653592d79c40aed616266c87307a1ad8 --- aten/src/ATen/core/ivalue_inl.h | 6 +-- test/cpp/rpc/e2e_test_base.h | 9 ++-- .../csrc/distributed/rpc/python_functions.cpp | 52 +++++++------------ torch/csrc/distributed/rpc/rref_context.cpp | 34 +++++++----- torch/csrc/distributed/rpc/rref_context.h | 5 +- torch/csrc/distributed/rpc/rref_impl.cpp | 14 ++--- torch/csrc/distributed/rpc/rref_impl.h | 2 +- .../distributed/rpc/torchscript_functions.cpp | 18 +++---- torch/csrc/distributed/rpc/utils.cpp | 8 +-- torch/csrc/distributed/rpc/utils.h | 2 +- 10 files changed, 72 insertions(+), 78 deletions(-) diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index b96f4b834989..c14b86f995e5 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -406,7 +406,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { // This accessor should only be used if we know that the future is // completed() with no error. - const IValue& constValue() { + const IValue& constValue() const { std::unique_lock lock(mutex_); AT_ASSERT(completed()); AT_ASSERT(!eptr_); @@ -451,7 +451,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { } // Tries to retrieve the error message from std::exception_ptr. - std::string tryRetrieveErrorMessage() { + std::string tryRetrieveErrorMessage() const { TORCH_CHECK(hasError(), "No error present on the future."); std::unique_lock lock(mutex_); return tryRetrieveErrorMessageInternal(eptr_); @@ -543,7 +543,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { } // Tries to retrieve the error message from std::exception_ptr. - std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) { + std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const { try { std::rethrow_exception(eptr); } catch (const std::exception& e) { diff --git a/test/cpp/rpc/e2e_test_base.h b/test/cpp/rpc/e2e_test_base.h index f9a2fa20e910..61d9a73e1077 100644 --- a/test/cpp/rpc/e2e_test_base.h +++ b/test/cpp/rpc/e2e_test_base.h @@ -71,13 +71,14 @@ class TestE2EBase : public ::testing::Test { false); ownerRRef->registerOwnerCreationFuture(jitFuture); - auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); // Builtin operators does not return py::object, and hence does not require // GIL for destructing the potentially deleted OwerRRef. - fm->addCallback( - [ownerRRefId = ownerRRef->rrefId()](const FutureMessage& fm) { - callback::finishCreatingOwnerRRef(fm, ownerRRefId); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( + [wp, ownerRRefId = ownerRRef->rrefId()]() { + auto jitFuture = wp.lock(); + callback::finishCreatingOwnerRRef(*jitFuture, ownerRRefId); }); return ownerRRef; } diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp index 08e90f8f4d91..383d6df9cee0 100644 --- a/torch/csrc/distributed/rpc/python_functions.cpp +++ b/torch/csrc/distributed/rpc/python_functions.cpp @@ -143,12 +143,12 @@ c10::intrusive_ptr toPyJitFuture( std::weak_ptr wp = messageJitFuture; messageJitFuture->addCallback( at::wrapPropagateTLSState([pyJitFuture, wp]() { - auto messageJitFuture = wp.lock(); - if (messageJitFuture->hasError()) { - pyJitFuture->setError(messageJitFuture->exception_ptr()); + auto future = wp.lock(); + if (future->hasError()) { + pyJitFuture->setError(future->exception_ptr()); } else { - pyJitFuture->markCompleted(toPyIValue( - *messageJitFuture->value().toCustomClass())); + pyJitFuture->markCompleted( + toPyIValue(*future->value().toCustomClass())); } })); @@ -159,9 +159,9 @@ c10::intrusive_ptr toPyJitFuture( std::weak_ptr wp = messageJitFuture; messageJitFuture->addCallback( at::wrapPropagateTLSState([wp, pyJitFuture]() { - auto messageJitFuture = wp.lock(); - if (messageJitFuture->hasError()) { - pyJitFuture->setError(messageJitFuture->exception_ptr()); + auto future = wp.lock(); + if (future->hasError()) { + pyJitFuture->setError(future->exception_ptr()); } else { pyJitFuture->markCompleted(IValue()); } @@ -281,14 +281,11 @@ PyRRef pyRemoteBuiltin( /* timeout */ rpcTimeoutSeconds); userRRef->registerOwnerCreationFuture(jitFuture); - auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); - ctx.addPendingUser(userRRef->forkId(), userRRef); - std::weak_ptr wp = fm; - fm->addCallback( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( at::wrapPropagateTLSState([wp, forkId{userRRef->forkId()}]() { - auto fm = wp.lock(); - callback::confirmPendingUser(*fm, forkId); + callback::confirmPendingUser(*wp.lock(), forkId); })); return PyRRef(userRRef); } else { @@ -306,15 +303,12 @@ PyRRef pyRemoteBuiltin( /* timeout */ rpcTimeoutSeconds); ownerRRef->registerOwnerCreationFuture(jitFuture); - auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); - // Builtin operators does not return py::object, and hence does not require // GIL for destructing the potentially deleted OwerRRef. - std::weak_ptr wp = fm; - fm->addCallback(at::wrapPropagateTLSState( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback(at::wrapPropagateTLSState( [wp, ownerRRefId = ownerRRef->rrefId()]() { - auto fm = wp.lock(); - callback::finishCreatingOwnerRRef(*fm, ownerRRefId); + callback::finishCreatingOwnerRRef(*wp.lock(), ownerRRefId); })); return PyRRef(ownerRRef); } @@ -342,14 +336,11 @@ PyRRef pyRemotePythonUdf( isAsyncExecution); userRRef->registerOwnerCreationFuture(jitFuture); - auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); - ctx.addPendingUser(userRRef->forkId(), userRRef); - std::weak_ptr wp = fm; - fm->addCallback( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( at::wrapPropagateTLSState([wp, forkId{userRRef->forkId()}]() { - auto fm = wp.lock(); - callback::confirmPendingUser(*fm, forkId); + callback::confirmPendingUser(*wp.lock(), forkId); })); return PyRRef(userRRef); } else { @@ -366,14 +357,11 @@ PyRRef pyRemotePythonUdf( isAsyncExecution); ownerRRef->registerOwnerCreationFuture(jitFuture); - auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); - - std::weak_ptr wp = fm; - fm->addCallback(at::wrapPropagateTLSState( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback(at::wrapPropagateTLSState( [wp, ownerRRefId = ownerRRef->rrefId()]() { - auto fm = wp.lock(); auto deletedRRef = - callback::finishCreatingOwnerRRef(*fm, ownerRRefId); + callback::finishCreatingOwnerRRef(*wp.lock(), ownerRRefId); if (deletedRRef && deletedRRef->isPyObj()) { py::gil_scoped_acquire ag; deletedRRef.reset(); diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp index dd64ee5c9445..c608be534561 100644 --- a/torch/csrc/distributed/rpc/rref_context.cpp +++ b/torch/csrc/distributed/rpc/rref_context.cpp @@ -14,11 +14,12 @@ thread_local bool RRefContext::recording_ = false; namespace callback { void confirmPendingUser( - const FutureMessage& futureMessage, + const JitFuture& jitFuture, const ForkId& expectedForkId) { - if (!futureMessage.hasError()) { - auto msgType = futureMessage.constValue().type(); - auto rpc = deserializeResponse(futureMessage.constValue(), msgType); + if (!jitFuture.hasError()) { + auto msgPtr = jitFuture.constValue().toCustomClass(); + auto msgType = msgPtr->type(); + auto rpc = deserializeResponse(*msgPtr, msgType); auto rr = dynamic_cast(rpc.get()); TORCH_INTERNAL_ASSERT(rr->forkId() == expectedForkId); } else { @@ -34,30 +35,31 @@ void confirmPendingUser( // the user application will use the RRef before the errors are handled. In // this case, errors may not be raised as they have not yet been handled. auto rref_ptr = RRefContext::getInstance().getPendingUser(expectedForkId); - auto errorType = getRPCErrorType(futureMessage); - rref_ptr->handleError(errorType, futureMessage); + auto errorType = getRPCErrorType(jitFuture); + rref_ptr->handleError(errorType, jitFuture); } RRefContext::getInstance().delPendingUser(expectedForkId); } c10::intrusive_ptr finishCreatingOwnerRRef( - const FutureMessage& futureMessage, + const JitFuture& jitFuture, const RRefId& rrefId) { - if (futureMessage.hasError()) { + if (jitFuture.hasError()) { auto& ctx = RRefContext::getInstance(); // We expect to run this callback only after the OwnerRRef has been created, // since this is only invoked when sending to self. auto rref_ptr = ctx.getOwnerRRef(rrefId, /* ensure created */ true)->constValue(); - auto errorType = getRPCErrorType(futureMessage); - rref_ptr->handleError(errorType, futureMessage); + auto errorType = getRPCErrorType(jitFuture); + rref_ptr->handleError(errorType, jitFuture); // OwnerRRefs do not have a forkId, so don't need to assert here. auto deletedRRef = ctx.delForkOfOwner(rref_ptr->rrefId(), rref_ptr->rrefId()); return deletedRRef; } else { - auto msgType = futureMessage.constValue().type(); - auto rpc = deserializeResponse(futureMessage.constValue(), msgType); + auto msgPtr = jitFuture.constValue().toCustomClass(); + auto msgType = msgPtr->type(); + auto rpc = deserializeResponse(*msgPtr, msgType); auto rr = dynamic_cast(rpc.get()); TORCH_INTERNAL_ASSERT( rr->rrefId() == rr->forkId(), @@ -102,6 +104,14 @@ std::vector> RRefContext::destroyInstance( return deletedRRefs; } +void RRefContext::handleException(const JitFuture& jitFuture) { + if (jitFuture.hasError()) { + auto errMsg = jitFuture.tryRetrieveErrorMessage(); + VLOG(1) << "Got exception: " << errMsg; + throw std::runtime_error(errMsg); + } +} + void RRefContext::handleException(const FutureMessage& fm) { if (fm.hasError()) { VLOG(1) << "Got exception: " << fm.error()->what(); diff --git a/torch/csrc/distributed/rpc/rref_context.h b/torch/csrc/distributed/rpc/rref_context.h index cf89980e7f71..48985fd42798 100644 --- a/torch/csrc/distributed/rpc/rref_context.h +++ b/torch/csrc/distributed/rpc/rref_context.h @@ -17,14 +17,14 @@ namespace rpc { namespace callback { // It's the callback for RemoteCall. void TORCH_API confirmPendingUser( - const FutureMessage& futureMessage, + const JitFuture& jitFuture, const ForkId& expectedForkId); // It's the callback for finishing creating owner rref, it returned deletedRRef, // so that the deletedRRef can be handled under GIL in python_functions.cpp if // deletedRRef contains python object. c10::intrusive_ptr TORCH_API finishCreatingOwnerRRef( - const FutureMessage& futureMessage, + const JitFuture& jitFuture, const RRefId& rrefId); } // namespace callback @@ -42,6 +42,7 @@ class TORCH_API RRefContext { static std::vector> destroyInstance( bool ignoreRRefLeak = true); + static void handleException(const JitFuture& jitFuture); static void handleException(const FutureMessage& fm); RRefContext(const RRefContext&) = delete; diff --git a/torch/csrc/distributed/rpc/rref_impl.cpp b/torch/csrc/distributed/rpc/rref_impl.cpp index 72fed660e200..54f7320567a9 100644 --- a/torch/csrc/distributed/rpc/rref_impl.cpp +++ b/torch/csrc/distributed/rpc/rref_impl.cpp @@ -67,21 +67,21 @@ RRefForkData RRef::fork() const { void RRef::handleError( RPCErrorType errorType, - const FutureMessage& futMessage) { + const JitFuture& jitFuture) { static std::unordered_map< RPCErrorType, - std::function, + std::function, std::hash> errorHandlers = { {RPCErrorType::TIMEOUT, - [this](const FutureMessage& /* unused */) { setTimedOut(); }}, + [this](const JitFuture& /* unused */) { setTimedOut(); }}, {RPCErrorType::INTENTIONAL_FAILURE, - [this](const FutureMessage& /* unused */) { setTimedOut(); }}, - {RPCErrorType::UNKNOWN_ERROR, [](const FutureMessage& fm) { + [this](const JitFuture& /* unused */) { setTimedOut(); }}, + {RPCErrorType::UNKNOWN_ERROR, [](const JitFuture& jitFuture) { // Default error handler - RRefContext::handleException(fm); + RRefContext::handleException(jitFuture); }}}; - errorHandlers.find(errorType)->second(futMessage); + errorHandlers.find(errorType)->second(jitFuture); } ////////////////////////// UserRRef ///////////////////////////////////// diff --git a/torch/csrc/distributed/rpc/rref_impl.h b/torch/csrc/distributed/rpc/rref_impl.h index e18c1fe3b9d7..c7f812271468 100644 --- a/torch/csrc/distributed/rpc/rref_impl.h +++ b/torch/csrc/distributed/rpc/rref_impl.h @@ -243,7 +243,7 @@ class TORCH_API RRef : public RRefInterface { } // Dispatches an error to the correct handler based on its RPCErrorType. - void handleError(RPCErrorType errorType, const FutureMessage& futMessage); + void handleError(RPCErrorType errorType, const JitFuture& JitFuture); // Send delete UserRRef request to Owner, // if the request hasn't been sent yet. diff --git a/torch/csrc/distributed/rpc/torchscript_functions.cpp b/torch/csrc/distributed/rpc/torchscript_functions.cpp index 22ce8cbd6237..c6efd8eaba7d 100644 --- a/torch/csrc/distributed/rpc/torchscript_functions.cpp +++ b/torch/csrc/distributed/rpc/torchscript_functions.cpp @@ -121,14 +121,11 @@ c10::intrusive_ptr remoteTorchscript( rpcTimeoutSeconds /* timeout */); userRRefPtr->registerOwnerCreationFuture(jitFuture); - auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); - ctx.addPendingUser(userRRefPtr->forkId(), userRRefPtr); - std::weak_ptr wp = fm; - fm->addCallback( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( at::wrapPropagateTLSState([wp, forkId{userRRefPtr->forkId()}]() { - auto fm = wp.lock(); - callback::confirmPendingUser(*fm, forkId); + callback::confirmPendingUser(*wp.lock(), forkId); })); return userRRefPtr; @@ -152,13 +149,10 @@ c10::intrusive_ptr remoteTorchscript( rpcTimeoutSeconds /* timeout */); ownerRRefPtr->registerOwnerCreationFuture(jitFuture); - auto fm = RpcAgent::toFutureMessage(std::move(jitFuture)); - - std::weak_ptr wp = fm; - fm->addCallback(at::wrapPropagateTLSState( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback(at::wrapPropagateTLSState( [wp, ownerRRefId = ownerRRefPtr->rrefId()]() { - auto fm = wp.lock(); - callback::finishCreatingOwnerRRef(*fm, ownerRRefId); + callback::finishCreatingOwnerRRef(*wp.lock(), ownerRRefId); })); return ownerRRefPtr; } diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index 79c197505bbb..8c39cc3e3e97 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -58,15 +58,15 @@ void processRemoteProfiledEvents( const std::string kRPCErrorPrefix = std::string("RPCErr"); -RPCErrorType getRPCErrorType(const FutureMessage& fm) { +RPCErrorType getRPCErrorType(const JitFuture& jitFuture) { TORCH_INTERNAL_ASSERT( - fm.hasError(), - "FutureMessage passed to getRPCErrorType does not have an error."); + jitFuture.hasError(), + "JitFuture of Message passed to getRPCErrorType does not have an error."); // Attempt to parse for error string given by makeRPCError, otherwise return // unknown error. // Note that this function expects errors formatted with makeRPCError(). - auto err = std::string(fm.error()->what()); + auto err = jitFuture.tryRetrieveErrorMessage(); size_t pos = err.find(kRPCErrorPrefix); if (pos != std::string::npos) { // Parse the RPCErrorType. diff --git a/torch/csrc/distributed/rpc/utils.h b/torch/csrc/distributed/rpc/utils.h index aa920d06cae8..4d27daff6ffe 100644 --- a/torch/csrc/distributed/rpc/utils.h +++ b/torch/csrc/distributed/rpc/utils.h @@ -15,7 +15,7 @@ namespace distributed { namespace rpc { // Parse error message and return RPCErrorType based on the message. -TORCH_API RPCErrorType getRPCErrorType(const FutureMessage& fm); +TORCH_API RPCErrorType getRPCErrorType(const JitFuture& jitFuture); // Create an error string given the error description and error type TORCH_API std::string makeRPCError( const std::string& rpcErrorStr, From d730c7e261acdfd3521338252e45abb7bdac53cc Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 18/54] Replace FutureMessage with ivalue::Future in RpcAgent retry logic (#49995) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49995 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25745301 Pulled By: mrshenli fbshipit-source-id: b5e3a7e0b377496924847d8d70d61de32e2d87f4 --- .../autograd/context/container.cpp | 9 ++-- torch/csrc/distributed/rpc/rpc_agent.cpp | 52 +++++++++---------- torch/csrc/distributed/rpc/rpc_agent.h | 8 +-- torch/csrc/distributed/rpc/rref_context.cpp | 36 ++++++------- torch/csrc/distributed/rpc/rref_context.h | 1 - 5 files changed, 53 insertions(+), 53 deletions(-) diff --git a/torch/csrc/distributed/autograd/context/container.cpp b/torch/csrc/distributed/autograd/context/container.cpp index ee3939010e8a..6948de958b84 100644 --- a/torch/csrc/distributed/autograd/context/container.cpp +++ b/torch/csrc/distributed/autograd/context/container.cpp @@ -245,14 +245,17 @@ void DistAutogradContainer::sendReleaseContextRpc( CleanupAutogradContextReq(context_id).toMessage(), options); + std::weak_ptr wp = cleanupFuture; cleanupFuture->addCallback( - [worker_id](const rpc::FutureMessage& cleanupFuture) { - if (cleanupFuture.hasError()) { + [worker_id, wp]() { + auto future = wp.lock(); + TORCH_INTERNAL_ASSERT(future); + if (future->hasError()) { std::string errorMsg = c10::str( "Could not release Dist Autograd Context on node ", worker_id, ": ", - cleanupFuture.error()->what()); + future->tryRetrieveErrorMessage()); LOG(ERROR) << errorMsg; return; } diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp index 549408387bc0..afcb0226b411 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.cpp +++ b/torch/csrc/distributed/rpc/rpc_agent.cpp @@ -45,7 +45,7 @@ void RpcAgent::shutdown() { shutdownImpl(); } -std::shared_ptr RpcAgent::sendWithRetries( +std::shared_ptr RpcAgent::sendWithRetries( const WorkerInfo& to, Message&& message, RpcRetryOptions retryOptions) { @@ -57,12 +57,12 @@ std::shared_ptr RpcAgent::sendWithRetries( retryOptions.rpcRetryDuration.count() >= 0, "rpcRetryDuration cannot be negative."); - auto originalFuture = std::make_shared(); + auto originalFuture = std::make_shared(at::AnyClassType::get()); steady_clock_time_point newTime = computeNewRpcRetryTime(retryOptions, /* retryCount */ 0); // Making a copy of the message so it can be retried after the first send. Message msgCopy = message; - auto fm = toFutureMessage(send(to, std::move(message))); + auto jitFuture = send(to, std::move(message)); auto firstRetryRpc = std::make_shared( to, std::move(msgCopy), @@ -70,13 +70,13 @@ std::shared_ptr RpcAgent::sendWithRetries( /* retryCount */ 0, retryOptions); // Use weak_ptr so that the value can be std::moved in rpcRetryCallback. - fm->addCallback([this, + jitFuture->addCallback([this, newTime, firstRetryRpc, - weak = std::weak_ptr(fm)]() { - auto fm = weak.lock(); - TORCH_INTERNAL_ASSERT(fm); - rpcRetryCallback(fm, newTime, firstRetryRpc); + wp = std::weak_ptr(jitFuture)]() { + auto future = wp.lock(); + TORCH_INTERNAL_ASSERT(future); + rpcRetryCallback(future, newTime, firstRetryRpc); }); return originalFuture; @@ -85,10 +85,10 @@ std::shared_ptr RpcAgent::sendWithRetries( void RpcAgent::retryExpiredRpcs() { // Stores the retried futures so callbacks can be added outside the lock. std::vector< - std::pair, std::shared_ptr>> + std::pair, std::shared_ptr>> futures; // Stores futures and exception messages for non-retriable error-ed futures. - std::vector, std::string>> + std::vector, std::string>> errorFutures; while (rpcAgentRunning_.load()) { @@ -126,15 +126,15 @@ void RpcAgent::retryExpiredRpcs() { auto& earliestRpc = *it; // Making a copy of the message so it can be retried in the future. Message msgCopy = earliestRpc->message_; - std::shared_ptr fm; + std::shared_ptr jitFuture; // send() will throw an exception if an RPC is retried while the agent is // shutdown. We must catch this exception and mark the original future // with an error, since this RPC never succeeded and can no longer be // retried. try { - fm = toFutureMessage(send(earliestRpc->to_, std::move(msgCopy))); - futures.emplace_back(fm, earliestRpc); + jitFuture = send(earliestRpc->to_, std::move(msgCopy)); + futures.emplace_back(jitFuture, earliestRpc); } catch (std::exception& e) { // We must store the futures and exception messages here and only mark // the futures with an error after releasing the lock. @@ -158,20 +158,20 @@ void RpcAgent::retryExpiredRpcs() { // We attach callbacks to the futures outside of the lock to prevent // potential deadlocks. for (const auto& it : futures) { - auto fm = it.first; + auto jitFuture = it.first; auto earliestRpc = it.second; steady_clock_time_point newTime = computeNewRpcRetryTime( earliestRpc->options_, earliestRpc->retryCount_); earliestRpc->retryCount_++; // Use weak_ptr so that the value can be std::moved in rpcRetryCallback. - fm->addCallback([this, + jitFuture->addCallback([this, newTime, earliestRpc, - weak = std::weak_ptr(fm)]() { - auto fm = weak.lock(); - TORCH_INTERNAL_ASSERT(fm); - rpcRetryCallback(fm, newTime, earliestRpc); + wp = std::weak_ptr(jitFuture)]() { + auto future = wp.lock(); + TORCH_INTERNAL_ASSERT(future); + rpcRetryCallback(future, newTime, earliestRpc); }); } futures.clear(); @@ -181,17 +181,17 @@ void RpcAgent::retryExpiredRpcs() { for (const auto& it : errorFutures) { auto errorFuture = it.first; auto errorMsg = it.second; - errorFuture->setError(errorMsg); + errorFuture->setError(std::make_exception_ptr(std::runtime_error(errorMsg))); } errorFutures.clear(); } } void RpcAgent::rpcRetryCallback( - const std::shared_ptr& futureMessage, + const std::shared_ptr& jitFuture, steady_clock_time_point newTime, std::shared_ptr earliestRpc) { - if (futureMessage->hasError()) { + if (jitFuture->hasError()) { // Adding one since we want to include the original send as well and not // just the retry count. LOG(INFO) << "Send try " << (earliestRpc->retryCount_ + 1) << " failed"; @@ -203,7 +203,7 @@ void RpcAgent::rpcRetryCallback( "RPC Agent is no longer running on Node ", RpcAgent::getWorkerInfo().id_, ". Cannot retry message."); - earliestRpc->originalFuture_->setError(*futureMessage->error()); + earliestRpc->originalFuture_->setError(jitFuture->exception_ptr()); } else if (earliestRpc->retryCount_ < earliestRpc->options_.maxRetries) { // If the previous future completed with an error and we haven't // completed maxRetries send attempts, we move the earliestRpc @@ -223,12 +223,12 @@ void RpcAgent::rpcRetryCallback( "The RPC has not succeeded after the specified number of max retries (", earliestRpc->options_.maxRetries, ")."); - earliestRpc->originalFuture_->setError(errorMessage); + earliestRpc->originalFuture_->setError( + std::make_exception_ptr(std::runtime_error(errorMessage))); } } else { // This try succeeded, so we can make the original future as complete. - earliestRpc->originalFuture_->markCompleted( - std::move(*futureMessage).moveValue()); + earliestRpc->originalFuture_->markCompleted(jitFuture->value()); } } diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index 64592e6d98f7..c8763c678922 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -105,7 +105,7 @@ struct TORCH_API RpcRetryInfo { RpcRetryInfo( const WorkerInfo& to, Message&& message, - std::shared_ptr originalFuture, + std::shared_ptr originalFuture, int retryCount, RpcRetryOptions options) : to_(to), @@ -117,7 +117,7 @@ struct TORCH_API RpcRetryInfo { const WorkerInfo& to_; Message message_; // Future that is returned to the caller of sendWithRetries(). - std::shared_ptr originalFuture_; + std::shared_ptr originalFuture_; // Number of send attempts completed so far. int retryCount_; RpcRetryOptions options_; @@ -174,7 +174,7 @@ class TORCH_API RpcAgent { // executing a method twice on the remote end (it does not guarantee // exactly-once semantics). Therefore, the user must ensure their requests // are idempotent. - std::shared_ptr sendWithRetries( + std::shared_ptr sendWithRetries( const WorkerInfo& to, Message&& message, RpcRetryOptions retryOptions = RpcRetryOptions()); @@ -339,7 +339,7 @@ class TORCH_API RpcAgent { // error and do not retry again. In case 3, we move the RpcRetryInfo struct // to another time point in the map to schedule the RPC for a future send. void rpcRetryCallback( - const std::shared_ptr& message, + const std::shared_ptr& message, steady_clock_time_point newTime, std::shared_ptr earliestRpc); diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp index c608be534561..ce257c50a7a4 100644 --- a/torch/csrc/distributed/rpc/rref_context.cpp +++ b/torch/csrc/distributed/rpc/rref_context.cpp @@ -112,13 +112,6 @@ void RRefContext::handleException(const JitFuture& jitFuture) { } } -void RRefContext::handleException(const FutureMessage& fm) { - if (fm.hasError()) { - VLOG(1) << "Got exception: " << fm.error()->what(); - throw std::runtime_error(fm.error()->what()); - } -} - RRefContext::RRefContext(std::shared_ptr agent) : agent_(std::move(agent)), destroyed_(false) {} @@ -219,12 +212,13 @@ void RRefContext::delUser( // which is now idempotent. See the comment at RRefContext::delForkOfOwner // for more details. ++numPendingFutures_; - auto fm = agent_->sendWithRetries( + auto jitFuture = agent_->sendWithRetries( agent_->getWorkerInfo(owner), RRefUserDelete(rrefId, forkId).toMessage()); - fm->addCallback([this](const FutureMessage& fm) { - handleException(fm); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, wp]() { + handleException(*wp.lock()); --numPendingFutures_; }); } @@ -493,21 +487,24 @@ void RRefContext::notifyOwnerAndParentOfFork( // into forks_. Because, there will be no real `UserRRef` associated // with this fork ID. ++numPendingFutures_; - auto fm = agent_->sendWithRetries( + auto jitFuture = agent_->sendWithRetries( agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage()); - fm->addCallback([this](const FutureMessage& fm) { - handleException(fm); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, wp]() { + handleException(*wp.lock()); --numPendingFutures_; }); } else { ++numPendingFutures_; - auto fm = agent_->sendWithRetries( + auto jitFuture = agent_->sendWithRetries( agent_->getWorkerInfo(rref->owner()), RRefForkRequest(rref->rrefId(), forkId).toMessage()); addPendingUser(forkId, rref); - fm->addCallback([this, forkId, parent](const FutureMessage& fm) { - handleException(fm); + + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, forkId, parent, wp]() { + handleException(*wp.lock()); this->finishForkRequest(forkId, parent); // Decrease after calling finishForkRequest because, as that creates a new // future, it might otherwise cause the count to briefly go to zero. @@ -686,11 +683,12 @@ void RRefContext::clearRecordedPendingRRefsOnError() { void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) { delPendingUser(forkId); ++numPendingFutures_; - auto fm = agent_->sendWithRetries( + auto jitFuture = agent_->sendWithRetries( agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage()); - fm->addCallback([this](const FutureMessage& fm) { - handleException(fm); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, wp]() { + handleException(*wp.lock()); --numPendingFutures_; }); } diff --git a/torch/csrc/distributed/rpc/rref_context.h b/torch/csrc/distributed/rpc/rref_context.h index 48985fd42798..c8ef52f7379a 100644 --- a/torch/csrc/distributed/rpc/rref_context.h +++ b/torch/csrc/distributed/rpc/rref_context.h @@ -43,7 +43,6 @@ class TORCH_API RRefContext { bool ignoreRRefLeak = true); static void handleException(const JitFuture& jitFuture); - static void handleException(const FutureMessage& fm); RRefContext(const RRefContext&) = delete; RRefContext(RRefContext&& other) = delete; From 2d5f57cf3bc227bf8d5e19c809372004802a67ab Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 19/54] Completely remove FutureMessage from RRef Implementations (#50004) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50004 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25750602 Pulled By: mrshenli fbshipit-source-id: 06854a77f4fb5cc4c34a1ede843301157ebf7309 --- torch/csrc/distributed/rpc/py_rref.cpp | 8 ++++---- torch/csrc/distributed/rpc/rref_impl.cpp | 22 +++++++++++----------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index 4830cd5424b1..7c39d1d812a7 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -331,11 +331,11 @@ void PyRRef::backward( // Invoke distributed backward remotely. auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent(); - rpc::RpcAgent::toFutureMessage( - rpcAgent->send( + rpcAgent + ->send( rpcAgent->getWorkerInfo(rref->owner()), - std::move(rrefBackwardReq).toMessage())) - ->wait(); + std::move(rrefBackwardReq).toMessage()) + ->waitAndThrow(); } } diff --git a/torch/csrc/distributed/rpc/rref_impl.cpp b/torch/csrc/distributed/rpc/rref_impl.cpp index 54f7320567a9..f5bcf39400fc 100644 --- a/torch/csrc/distributed/rpc/rref_impl.cpp +++ b/torch/csrc/distributed/rpc/rref_impl.cpp @@ -170,21 +170,21 @@ IValue UserRRef::toHere(const float timeoutSeconds) const { // toHere is profiled as a blocking call, and does not execute operations on // the remote node. Hence, don't wrap it with a profiling message since we // don't need the profiler to be enabled remotely. - auto futureResponse = RpcAgent::toFutureMessage( - autograd::sendMessageWithAutograd( - *agent, - agent->getWorkerInfo(ownerId_), - std::move(msgToSend), - true /* forceGradRecording */, - timeoutSeconds, - true /* forceDisableProfiling */)); + auto jitFuture = autograd::sendMessageWithAutograd( + *agent, + agent->getWorkerInfo(ownerId_), + std::move(msgToSend), + true /* forceGradRecording */, + timeoutSeconds, + true /* forceDisableProfiling */); // TODO: we should ideally be able to interrupt this blocking wait if we check // getTimedOut() and it is true // (https://github.com/pytorch/pytorch/issues/39411). - const Message& message = futureResponse->wait(); - MessageType msgType = message.type(); - auto response = deserializeResponse(message, msgType); + jitFuture->waitAndThrow(); + auto messagePtr = jitFuture->constValue().toCustomClass(); + MessageType msgType = messagePtr->type(); + auto response = deserializeResponse(*messagePtr, msgType); TORCH_INTERNAL_ASSERT( msgType == MessageType::SCRIPT_RREF_FETCH_RET || msgType == MessageType::PYTHON_RREF_FETCH_RET, From b2da0b5afed2d661b18cd42b5cd7f8c93129e29f Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 20/54] Completely remove FutureMessage from RPC TorchScript implementations (#50005) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50005 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25750663 Pulled By: mrshenli fbshipit-source-id: 6d97156b61d82aa19dd0567ca72fe04bd7b5d1e7 --- .../distributed/rpc/torchscript_functions.cpp | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/torch/csrc/distributed/rpc/torchscript_functions.cpp b/torch/csrc/distributed/rpc/torchscript_functions.cpp index c6efd8eaba7d..eceeb3401376 100644 --- a/torch/csrc/distributed/rpc/torchscript_functions.cpp +++ b/torch/csrc/distributed/rpc/torchscript_functions.cpp @@ -14,7 +14,7 @@ namespace torch { namespace distributed { namespace rpc { -c10::intrusive_ptr rpcTorchscript( +c10::intrusive_ptr rpcTorchscript( const std::string& dstWorkerName, const c10::QualifiedName& qualifiedName, const c10::FunctionSchema& functionSchema, @@ -43,15 +43,14 @@ c10::intrusive_ptr rpcTorchscript( auto scriptCall = std::make_unique( qualifiedName, std::move(stack), isAsyncExecution); auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent(); - auto futMessage = RpcAgent::toFutureMessage( - autograd::sendMessageWithAutograd( - *rpcAgentPtr, - rpcAgentPtr->getWorkerInfo(dstWorkerName), - std::move(*scriptCall).toMessage(), - true /*forceGradRecording*/, - rpcTimeoutSeconds)); + auto jitFuture = autograd::sendMessageWithAutograd( + *rpcAgentPtr, + rpcAgentPtr->getWorkerInfo(dstWorkerName), + std::move(*scriptCall).toMessage(), + true /*forceGradRecording*/, + rpcTimeoutSeconds); - // Get function return type to construct c10::ivalue::Future. + // Get function return type to construct JitFuture. auto returns = functionSchema.returns(); // Script call only allows single IValue returned. TORCH_INTERNAL_ASSERT( @@ -63,15 +62,15 @@ c10::intrusive_ptr rpcTorchscript( // Create a JIT future and pass it to futMessage's callback to set state // of the JIT future. - auto futPtr = c10::make_intrusive(returnType); - std::weak_ptr wp = futMessage; - futMessage->addCallback(at::wrapPropagateTLSState([futPtr, wp]() { - auto futMessage = wp.lock(); - if (futMessage->hasError()) { - c10::ivalue::Future::FutureError jitFutErr(futMessage->error()->what()); - futPtr->setError(std::make_exception_ptr(jitFutErr)); + auto futPtr = c10::make_intrusive(returnType); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback(at::wrapPropagateTLSState([futPtr, wp]() { + auto future = wp.lock(); + if (future->hasError()) { + futPtr->setError(future->exception_ptr()); } else { - futPtr->markCompleted(deserializeRespToIValue(futMessage->constValue())); + futPtr->markCompleted(deserializeRespToIValue( + *future->constValue().toCustomClass())); } })); if (shouldProfile) { From 0c943931aad432f54fd995bb02ff092e126e8577 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 21/54] Completely remove FutureMessage from distributed autograd (#50020) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50020 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25752968 Pulled By: mrshenli fbshipit-source-id: 138d37e204b6f9a584633cfc79fd44c8c9c00f41 --- .../distributed/autograd/context/context.cpp | 23 +++++++++++-------- .../distributed/autograd/context/context.h | 4 ++-- .../autograd/functions/recvrpc_backward.cpp | 7 +++--- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/torch/csrc/distributed/autograd/context/context.cpp b/torch/csrc/distributed/autograd/context/context.cpp index 6527fc25b92b..526ca053dd40 100644 --- a/torch/csrc/distributed/autograd/context/context.cpp +++ b/torch/csrc/distributed/autograd/context/context.cpp @@ -123,26 +123,27 @@ void DistAutogradContext::resetGraphTask() { } void DistAutogradContext::addOutstandingRpc( - const std::shared_ptr& futureMessage) { - futureMessage->addCallback([this](const rpc::FutureMessage& futureMessage) { - if (futureMessage.hasError()) { + const std::shared_ptr& jitFuture) { + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, wp]() { + auto future = wp.lock(); + if (future->hasError()) { // If we have an error, let the local autograd engine know about it. std::unique_lock lock(lock_); if (graphTask_) { graphTask_->set_exception_without_signal(nullptr); lock.unlock(); if (!graphTask_->future_completed_.exchange(true)) { - graphTask_->future_result_->setErrorIfNeeded( - std::make_exception_ptr(*futureMessage.error())); + graphTask_->future_result_->setErrorIfNeeded(future->exception_ptr()); } } else { LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: " - << (*futureMessage.error()).what(); + << future->tryRetrieveErrorMessage(); } } }); std::lock_guard guard(lock_); - outStandingRpcs_.push_back(futureMessage); + outStandingRpcs_.push_back(jitFuture); } void DistAutogradContext::clearOutstandingRpcs() { @@ -170,8 +171,10 @@ std::shared_ptr DistAutogradContext:: state->future->markCompleted(c10::IValue()); } else { for (auto& rpc : outStandingRpcs) { - rpc->addCallback([state](const rpc::FutureMessage& rpc) { - if (rpc.hasError()) { + std::weak_ptr wp = rpc; + rpc->addCallback([state, wp]() { + auto future = wp.lock(); + if (future->hasError()) { // If there's an error, we want to setError() on the future, // unless another error has already been sent - use a CAS to // guard. @@ -183,7 +186,7 @@ std::shared_ptr DistAutogradContext:: bool expectedAlreadySent = false; if (state->alreadySentError.compare_exchange_strong( expectedAlreadySent, true)) { - state->future->setError(std::make_exception_ptr(*rpc.error())); + state->future->setError(future->exception_ptr()); } return; } diff --git a/torch/csrc/distributed/autograd/context/context.h b/torch/csrc/distributed/autograd/context/context.h index e7d73962634b..b611040af448 100644 --- a/torch/csrc/distributed/autograd/context/context.h +++ b/torch/csrc/distributed/autograd/context/context.h @@ -52,7 +52,7 @@ class TORCH_API DistAutogradContext { // Adds a future message recording an outstanding RPC. void addOutstandingRpc( - const std::shared_ptr& futureMessage); + const std::shared_ptr& jitFuture); // Returns all gradients. const c10::Dict getGradients() const; @@ -134,7 +134,7 @@ class TORCH_API DistAutogradContext { // List of futures for RPCs initiated by this node to propagate gradients to // other nodes. The distributed autograd engine on this node can return // successfully only if all these futures are done and are successful. - std::vector> outStandingRpcs_; + std::vector> outStandingRpcs_; // Lock to protect concurrent modification of the context. mutable std::mutex lock_; diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp index 08f18da081a5..509c5c6cbd08 100644 --- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp @@ -47,11 +47,12 @@ variable_list RecvRpcBackward::apply(variable_list&& grads) { // Send the gradients over to the appropriate node. auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent(); - auto futureMessage = rpc::RpcAgent::toFutureMessage(rpcAgent->send( - rpcAgent->getWorkerInfo(fromWorkerId_), std::move(gradCall).toMessage())); + auto jitFuture = rpcAgent->send( + rpcAgent->getWorkerInfo(fromWorkerId_), + std::move(gradCall).toMessage()); // Record the future in the context. - sharedContext->addOutstandingRpc(futureMessage); + sharedContext->addOutstandingRpc(jitFuture); // 'recv' function sends the gradients over the wire using RPC, it doesn't // need to return anything for any downstream autograd function. From 1deb89507454c7e57ccd453f75e4396082d73980 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 22/54] Remove FutureMessage from sender ProcessGroupAgent (#50023) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50023 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25753217 Pulled By: mrshenli fbshipit-source-id: 5a98473c17535c8f92043abe143064e7fca4413b --- .../distributed/rpc/process_group_agent.cpp | 29 ++++++++++--------- .../distributed/rpc/process_group_agent.h | 6 ++-- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index 71691d9dfbb1..9f37535c485a 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -319,7 +319,7 @@ std::shared_ptr ProcessGroupAgent::send( pg_->getRank()); auto requestId = nextId(); - auto future = std::make_shared(); + auto future = std::make_shared(at::AnyClassType::get()); if (message.isRequest()) { // millisecond level precision of when request started. auto futureStartTime = std::chrono::steady_clock::now(); @@ -362,14 +362,14 @@ std::shared_ptr ProcessGroupAgent::send( message.setId(requestId); ++clientActiveCalls_; } else { - future->markCompleted(Message()); + future->markCompleted(IValue()); } // Sending to ourselves: bypass the send logic and enqueue directly // to our receiving queue. if (to.id_ == (worker_id_t)pg_->getRank()) { sendToSelf(std::move(message)); - return toJitFuture(std::move(future)); + return future; } // NB: cannot directly pass ``to`` to the ``SendWork``, because it might no @@ -383,8 +383,7 @@ std::shared_ptr ProcessGroupAgent::send( // C++ land. enqueueSend(SendWork(allWorkerInfo_[to.id_], std::move(message))); - auto jitFuture = toJitFuture(std::move(future)); - return jitFuture; + return future; } void ProcessGroupAgent::handleSend(const SendWork& work) { @@ -560,7 +559,7 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { } } else if (message.isResponse()) { auto id = message.id(); - std::shared_ptr fm = nullptr; + std::shared_ptr jitFuture = nullptr; { std::lock_guard lock{futureMutex_}; const auto& futureInfo = futures_.find(id); @@ -572,7 +571,7 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { return false; } // Use futureInfo before destructing it. - fm = futureInfo->second.future_; + jitFuture = futureInfo->second.future_; auto endTime = futureInfo->second.endTime_; futures_.erase(id); // look up the corresponding future by its time out and request @@ -591,10 +590,11 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { futureCV_.notify_all(); --clientActiveCalls_; if (message.type() == MessageType::EXCEPTION) { - fm->setError( - std::string(message.payload().begin(), message.payload().end())); + jitFuture->setError(std::make_exception_ptr(std::runtime_error( + std::string(message.payload().begin(), message.payload().end())))); } else { - fm->markCompleted(std::move(message)); + jitFuture->markCompleted( + IValue(c10::make_intrusive(std::move(message)))); } } else { // TODO: pass the error back to the caller instead of crashing here. @@ -645,7 +645,7 @@ void ProcessGroupAgent::markFutureWithError(Message& message) { } void ProcessGroupAgent::markFutureWithError(int64_t id, std::string errorMsg) { - std::shared_ptr fm = nullptr; + std::shared_ptr jitFuture = nullptr; { std::lock_guard lock{futureMutex_}; const auto& futureInfo = futures_.find(id); @@ -655,7 +655,7 @@ void ProcessGroupAgent::markFutureWithError(int64_t id, std::string errorMsg) { // out and been processed accordingly. return; } - fm = futureInfo->second.future_; + jitFuture = futureInfo->second.future_; auto rpcEndTime = futureInfo->second.endTime_; futures_.erase(id); // look up the corresponding future by its time out and request ID, @@ -673,7 +673,7 @@ void ProcessGroupAgent::markFutureWithError(int64_t id, std::string errorMsg) { } --clientActiveCalls_; - fm->setError(std::move(errorMsg)); + jitFuture->setError(std::make_exception_ptr(std::runtime_error(errorMsg))); futureCV_.notify_all(); } @@ -805,7 +805,8 @@ void ProcessGroupAgent::pollTimedOutRPCs() { if (!timedOutFuture.future_->hasError()) { --clientActiveCalls_; - timedOutFuture.future_->setError(std::move(err)); + timedOutFuture.future_->setError( + std::make_exception_ptr(std::runtime_error(err))); // The future timed out and will not be processed by handleRecv(), even // if we eventually get a response. In order to keep track of all // send/recv pairs, we increment the count here. diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index 3d4393475f31..8d2471a7d113 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -130,16 +130,16 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { // additional information to manage timeouts and destination information, // which is needed for termination detection. struct FutureInfo { - std::shared_ptr future_; + std::shared_ptr future_; steady_clock_time_point endTime_; int dstRank_; std::chrono::milliseconds timeout_; FutureInfo( - const std::shared_ptr& future, + std::shared_ptr future, const steady_clock_time_point& endTime, int dstRank, const std::chrono::milliseconds timeout) - : future_(future), + : future_(std::move(future)), endTime_(endTime), dstRank_(dstRank), timeout_(timeout) {} From 0684d074256bac5fe2ed0c4180625d5d78c65a75 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 23/54] Remove FutureMessage from sender TensorPipeAgent (#50024) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50024 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25753386 Pulled By: mrshenli fbshipit-source-id: fdca051b805762a2c88f965ceb3edf1c25d40a56 --- .../csrc/distributed/rpc/tensorpipe_agent.cpp | 29 ++++++++++--------- torch/csrc/distributed/rpc/tensorpipe_agent.h | 13 +++++---- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 8e5eb15057b0..3050a40ec608 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -637,7 +637,7 @@ std::shared_ptr TensorPipeAgent::send( ClientPipe& clientPipe = it->second; auto& pendingResponseMessage = clientPipe.pendingResponseMessage_; - auto futureResponseMessage = std::make_shared(); + auto futureResponseMessage = std::make_shared(); uint64_t messageId = nextMessageID_++; requestMessage.setId(messageId); pendingResponseMessage[messageId] = futureResponseMessage; @@ -649,7 +649,7 @@ std::shared_ptr TensorPipeAgent::send( auto devices = getDevicesForTensors(clientPipe.pipe_->getRemoteName(), requestMessage); - futureResponseMessage->futMsg.addCallback([this]() { + futureResponseMessage->jitFuture->addCallback([this]() { TORCH_INTERNAL_ASSERT( this->threadPool_.inThreadPool(), "Future marked complete from outside the thread pool"); @@ -747,7 +747,7 @@ std::shared_ptr TensorPipeAgent::send( << " received response #" << messageId << " from " << clientPipe.pipe_->getRemoteName(); - std::shared_ptr futureResponseMessage; + std::shared_ptr futureResponseMessage; { std::lock_guard lock(mutex_); // A read error will lead all following callbacks to be @@ -778,8 +778,7 @@ std::shared_ptr TensorPipeAgent::send( }); }); - return toJitFuture(std::shared_ptr( - futureResponseMessage, &futureResponseMessage->futMsg)); + return futureResponseMessage->jitFuture; } void TensorPipeAgent::pollTimeoutRpcs() { @@ -808,7 +807,7 @@ void TensorPipeAgent::pollTimeoutRpcs() { // Move all these futures to a separate vector so we can process them // outside the lock. std::vector, + std::shared_ptr, std::chrono::milliseconds>> timedOutFutures = std::move(timeoutMap_.begin()->second); // We can safely remove this key from the timeoutMap_ since all these @@ -1026,16 +1025,17 @@ void TensorPipeAgent::decreaseCallCount(int32_t& count) { } void TensorPipeAgent::markFutureAsComplete( - std::shared_ptr futureMessage, + std::shared_ptr atomicFuture, Message message) { - if (!futureMessage->isComplete.test_and_set()) { + if (!atomicFuture->isComplete.test_and_set()) { // Completing the future will run its callbacks, which could execute // arbitrary user code. To prevent blocking or stalling the TensorPipe event // loops, we defer this to a worker thread. threadPool_.run([this, - futureMessage{std::move(futureMessage)}, + atomicFuture{std::move(atomicFuture)}, message{std::move(message)}]() mutable { - futureMessage->futMsg.markCompleted(std::move(message)); + atomicFuture->jitFuture->markCompleted( + IValue(c10::make_intrusive(std::move(message)))); // The future's callbacks may schedule further RPCs, increasing the count. // Thus we must decrease it after completing the future, otherwise it may // briefly dip to zero and trick join into thinking all work is done. @@ -1045,16 +1045,17 @@ void TensorPipeAgent::markFutureAsComplete( } void TensorPipeAgent::markFutureWithError( - std::shared_ptr futureMessage, + std::shared_ptr atomicFuture, std::string errorMsg) { - if (!futureMessage->isComplete.test_and_set()) { + if (!atomicFuture->isComplete.test_and_set()) { // Completing the future will run its callbacks, which could execute // arbitrary user code. To prevent blocking or stalling the TensorPipe event // loops, we defer this to a worker thread. threadPool_.run([this, - futureMessage{std::move(futureMessage)}, + atomicFuture{std::move(atomicFuture)}, errorMsg{std::move(errorMsg)}]() mutable { - futureMessage->futMsg.setError(std::move(errorMsg)); + atomicFuture->jitFuture->setError( + std::make_exception_ptr(std::runtime_error(errorMsg))); // The future's callbacks may schedule further RPCs, increasing the count. // Thus we must decrease it after completing the future, otherwise it may // briefly dip to zero and trick join into thinking all work is done. diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index 83ade1cf9203..022f92241044 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -273,8 +273,9 @@ class TensorPipeAgent : public RpcAgent { // only if it isn't yet. It does exist for errors (setErrorIfNeeded) but, even // then, it ends up printing a log message, which may worry the user. To solve // both issues we use a separate atomic flag to know the status of the future. - struct AtomicFutureMessage { - FutureMessage futMsg; + struct AtomicJitFuture { + std::shared_ptr jitFuture = + std::make_shared(at::AnyClassType::get()); std::atomic_flag isComplete = ATOMIC_FLAG_INIT; }; @@ -288,7 +289,7 @@ class TensorPipeAgent : public RpcAgent { std::shared_ptr pipe_; bool readError_{false}; // Map from Message Request ID's to corresponding futures. - std::unordered_map> + std::unordered_map> pendingResponseMessage_; }; @@ -321,7 +322,7 @@ class TensorPipeAgent : public RpcAgent { std::map< steady_clock_time_point, std::vector, + std::shared_ptr, std::chrono::milliseconds>>> timeoutMap_; @@ -394,10 +395,10 @@ class TensorPipeAgent : public RpcAgent { // Helpers to set the state of the requests. void markFutureAsComplete( - std::shared_ptr futureMessage, + std::shared_ptr atomicFuture, Message message); void markFutureWithError( - std::shared_ptr futureMessage, + std::shared_ptr atomicFuture, std::string errorMsg); }; From 2831af983799ed668e2a90d65050618fd1cdbfbc Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 24/54] Completely remove FutureMessage from FaultyProcessGroupAgent (#50025) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50025 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25753587 Pulled By: mrshenli fbshipit-source-id: a5d4106a10d1b0d3e4c406751795f19af8afd120 --- .../rpc/testing/faulty_process_group_agent.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp index ba2229932145..7f6c3015f544 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp @@ -78,11 +78,12 @@ std::shared_ptr FaultyProcessGroupAgent::send( if (failMessageCountMap_[key] < failNumSends_) { failMessageCountMap_[key]++; lock.unlock(); - auto fm = std::make_shared(); - fm->setError(makeRPCError( - c10::str("Send attempt failed intentionally for ", key), - RPCErrorType::INTENTIONAL_FAILURE)); - return toJitFuture(std::move(fm)); + auto jitFuture = std::make_shared(at::AnyClassType::get()); + jitFuture->setError( + std::make_exception_ptr(std::runtime_error(makeRPCError( + c10::str("Send attempt failed intentionally for ", key), + RPCErrorType::INTENTIONAL_FAILURE)))); + return jitFuture; } else { lock.unlock(); return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds); From 1f795e1a9b937e5ca83e589030f69092d33ce4dc Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 25/54] Remove FutureMessage from RPC request callback logic (#50026) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50026 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25753588 Pulled By: mrshenli fbshipit-source-id: a6fcda7830901dd812fbf0489b001e6bd9673780 --- .../distributed/rpc/request_callback_impl.cpp | 72 +++++++++-------- .../distributed/rpc/request_callback_impl.h | 12 +-- .../rpc/request_callback_no_python.cpp | 77 +++++++++++-------- .../rpc/request_callback_no_python.h | 26 +++---- 4 files changed, 102 insertions(+), 85 deletions(-) diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp index 2004178565ea..3fa6b63953f0 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.cpp +++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp @@ -88,12 +88,12 @@ std::unique_ptr deserializePythonRpcCommandReference( void processAsyncExecution( const py::object& pyFn, const int64_t messageId, - const std::shared_ptr& responseFuture, + const std::shared_ptr& responseFuture, std::function&)> postProcessing) { + const std::shared_ptr&)> postProcessing) { std::shared_ptr pyFuture; auto& pythonRpcHandler = PythonRpcHandler::getInstance(); { @@ -151,7 +151,7 @@ void RequestCallbackImpl::processScriptCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& scriptCall = static_cast(rpc); auto& stack = scriptCall.stackRef(); if (processScriptCallOp(scriptCall, markComplete, stack)) { @@ -176,13 +176,14 @@ void RequestCallbackImpl::processScriptCall( try { Message m = ScriptResp(valueJitFuture->value()).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); } else { @@ -195,9 +196,10 @@ void RequestCallbackImpl::processScriptCall( try { Message m = ScriptResp(jitFuture->value()).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); } @@ -207,7 +209,7 @@ void RequestCallbackImpl::processPythonCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& upc = static_cast(rpc); if (upc.isAsyncExecution()) { try { @@ -218,17 +220,18 @@ void RequestCallbackImpl::processPythonCall( [](const py::object& result, const int64_t messageId, PythonRpcHandler& pythonRpcHandler, - const std::shared_ptr& responseFuture) { + const std::shared_ptr& responseFuture) { auto serializedPyObj = pythonRpcHandler.serialize(result); py::gil_scoped_release release; auto m = std::move(PythonResp(std::move(serializedPyObj))).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); }); } catch (std::exception& e) { - responseFuture->markCompleted( - createExceptionResponse(e.what(), messageId)); + responseFuture->markCompleted(IValue(c10::make_intrusive( + createExceptionResponse(e.what(), messageId)))); } } else { auto& pythonRpcHandler = PythonRpcHandler::getInstance(); @@ -335,7 +338,7 @@ void RequestCallbackImpl::processPythonRemoteCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& uprc = static_cast(rpc); const auto& rrefId = uprc.rrefId(); @@ -373,20 +376,22 @@ void RequestCallbackImpl::processPythonRemoteCall( const py::object& result, const int64_t messageId, PythonRpcHandler& /* unused */, - const std::shared_ptr& responseFuture) { + const std::shared_ptr& responseFuture) { IValue py_ivalue = jit::toIValue(result, PyObjectType::get()); py::gil_scoped_release release; ownerRRef->setValue(std::move(py_ivalue)); auto m = RemoteRet(rrefId, forkId).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); }); } catch (std::exception& e) { ownerRRef->setError(std::current_exception()); auto m = RemoteRet(rrefId, forkId).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); } } else { IValue py_ivalue; @@ -414,14 +419,14 @@ void RequestCallbackImpl::processPythonRemoteCall( void RequestCallbackImpl::processPythonRRefFetchCall( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { // Making this lambda mutable to allow move-capture it in callbacks auto postProcessing = [responseFuture]( const c10::intrusive_ptr& rref, int64_t messageId) mutable { auto whenValueSet = rref->getFuture(); if (whenValueSet->hasError()) { - responseFuture->setError(whenValueSet->tryRetrieveErrorMessage()); + responseFuture->setError(whenValueSet->exception_ptr()); return; } try { @@ -437,15 +442,17 @@ void RequestCallbackImpl::processPythonRRefFetchCall( Message m = PythonRRefFetchRet(std::move(*result).toIValues()).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); } catch (py::error_already_set& e) { // py::error_already_set requires GIL to destruct, take special care. - responseFuture->setError(e.what()); + responseFuture->setError( + std::make_exception_ptr(std::runtime_error(e.what()))); py::gil_scoped_acquire acquire; e.restore(); PyErr_Clear(); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }; @@ -487,7 +494,7 @@ void RequestCallbackImpl::processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { try { processRpc(rpc, messageType, messageId, responseFuture); } catch (py::error_already_set& e) { @@ -516,7 +523,7 @@ bool RequestCallbackImpl::cudaAvailable() const { void RequestCallbackImpl::processRRefBackward( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& rrefBackwardReq = static_cast(rpc); // Get all fields @@ -540,7 +547,7 @@ void RequestCallbackImpl::processRRefBackward( autogradContextId, retainGraph]() { if (whenValueSet->hasError()) { - responseFuture->setError(whenValueSet->tryRetrieveErrorMessage()); + responseFuture->setError(whenValueSet->exception_ptr()); return; } @@ -551,9 +558,10 @@ void RequestCallbackImpl::processRRefBackward( // Return the response. Message m = RRefBackwardResp().toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); }); diff --git a/torch/csrc/distributed/rpc/request_callback_impl.h b/torch/csrc/distributed/rpc/request_callback_impl.h index bf43ef867cff..2883359af303 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.h +++ b/torch/csrc/distributed/rpc/request_callback_impl.h @@ -18,13 +18,13 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; void processScriptCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; TypePtr getScriptRemoteCallType( ScriptRemoteCall& scriptRemoteCall) const override; @@ -39,12 +39,12 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; void processPythonRRefFetchCall( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; void handleRRefDelete(c10::intrusive_ptr& rref) const override; @@ -52,14 +52,14 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; bool cudaAvailable() const override; void processRRefBackward( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index 7f8db89f55bf..5f2765b32a01 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -53,7 +54,7 @@ std::shared_ptr RequestCallbackNoPython::processMessage( // RPC message: // 1) waiting for all RRefs in the arguments to become confirmed; // 2) waiting for processRpc to finish. - auto retFuture = std::make_shared(); + auto retFuture = std::make_shared(at::AnyClassType::get()); auto& rrefContext = RRefContext::getInstance(); try { rrefContext.recordThreadLocalPendingRRefs(); @@ -104,14 +105,14 @@ std::shared_ptr RequestCallbackNoPython::processMessage( retFuture->markCompleted(handleError(e, request.type(), request.id())); rrefContext.clearRecordedPendingRRefsOnError(); } - return retFuture; + return RpcAgent::toFutureMessage(std::move(retFuture)); } void RequestCallbackNoPython::processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { try { processRpc(rpc, messageType, messageId, responseFuture); } catch (std::exception& e) { @@ -123,7 +124,7 @@ void RequestCallbackNoPython::processScriptCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { auto& scriptCall = static_cast(rpc); auto& stack = scriptCall.stackRef(); TORCH_CHECK( @@ -161,7 +162,7 @@ void RequestCallbackNoPython::processPythonCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } @@ -169,7 +170,7 @@ void RequestCallbackNoPython::processPythonRemoteCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } @@ -187,7 +188,7 @@ void RequestCallbackNoPython::processBaseScriptRemoteCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& scriptRemoteCall = static_cast(rpc); auto rrefId = scriptRemoteCall.retRRefId(); auto forkId = scriptRemoteCall.retForkId(); @@ -208,7 +209,8 @@ void RequestCallbackNoPython::processBaseScriptRemoteCall( } Message m = RemoteRet(rrefId, forkId).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); }; // scriptRemoteCall is only alive within this block, use reference to @@ -259,7 +261,7 @@ void RequestCallbackNoPython::processScriptRRefFetchCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& srf = static_cast(rpc); auto& ctx = RRefContext::getInstance(); @@ -283,15 +285,16 @@ void RequestCallbackNoPython::processScriptRRefFetchCall( whenValueSet->addCallback( [responseFuture, messageId, rref, whenValueSet]() { if (whenValueSet->hasError()) { - responseFuture->setError(whenValueSet->tryRetrieveErrorMessage()); + responseFuture->setError(whenValueSet->exception_ptr()); return; } try { Message m = ScriptRRefFetchRet({rref->getValue()}).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); }); @@ -300,7 +303,7 @@ void RequestCallbackNoPython::processScriptRRefFetchCall( void RequestCallbackNoPython::processPythonRRefFetchCall( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } @@ -340,7 +343,7 @@ void RequestCallbackNoPython::processRRefForkRequest( void RequestCallbackNoPython::processForwardAutogradReq( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& rpcWithAutograd = static_cast(rpc); // Attach 'recv' autograd function. @@ -362,7 +365,8 @@ void RequestCallbackNoPython::processForwardAutogradReq( // Process the original RPC. auto wrappedMessageType = rpcWithAutograd.wrappedMessageType(); // Make an overall future for the wrapped response. - auto wrappedRpcResponseFuture = std::make_shared(); + auto wrappedRpcResponseFuture = + std::make_shared(at::AnyClassType::get()); // Kick off processing for the nested RPC command. // wrappedRpcResponseFuture will be a Future to the result. processRpc( @@ -379,7 +383,7 @@ void RequestCallbackNoPython::processForwardAutogradReq( responseFuture, messageId, fromWorkerId, - weak = std::weak_ptr(wrappedRpcResponseFuture), + weak = std::weak_ptr(wrappedRpcResponseFuture), ctxId = autogradContext->contextId() ]() { // As this callback can be invoked by a different thread, we have to @@ -398,14 +402,15 @@ void RequestCallbackNoPython::processForwardAutogradReq( TORCH_INTERNAL_ASSERT(wrappedRpcResponseFuture); if (wrappedRpcResponseFuture->hasError()) { // Propagate error to responseFuture if we had one. - responseFuture->setError(wrappedRpcResponseFuture->error()->what()); + responseFuture->setError(wrappedRpcResponseFuture->exception_ptr()); } else { auto msg = getMessageWithAutograd( fromWorkerId, - std::move(*wrappedRpcResponseFuture).moveValue(), + std::move(*wrappedRpcResponseFuture->value().toCustomClass()), MessageType::FORWARD_AUTOGRAD_RESP); msg.setId(messageId); - responseFuture->markCompleted(std::move(msg)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(msg)))); } }); } @@ -413,7 +418,7 @@ void RequestCallbackNoPython::processForwardAutogradReq( void RequestCallbackNoPython::processBackwardAutogradReq( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& gradientsCall = static_cast(rpc); const auto& autogradMetadata = gradientsCall.getAutogradMetadata(); @@ -437,9 +442,10 @@ void RequestCallbackNoPython::processBackwardAutogradReq( if (!execFuture->hasError()) { Message m = std::move(PropagateGradientsResp()).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); } else { - responseFuture->setError(execFuture->tryRetrieveErrorMessage()); + responseFuture->setError(execFuture->exception_ptr()); } }); } @@ -461,7 +467,7 @@ void RequestCallbackNoPython::processCleanupAutogradContextReq( void RequestCallbackNoPython::processRunWithProfilingReq( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& rpcWithProfilingReq = static_cast(rpc); auto wrappedMsgType = rpcWithProfilingReq.wrappedMessageType(); auto profilingConfig = rpcWithProfilingReq.getProfilingConfig(); @@ -483,7 +489,8 @@ void RequestCallbackNoPython::processRunWithProfilingReq( this->cudaAvailable(), "Profiler state set to CUDA but CUDA not available."); const auto profilingKeyId = rpcWithProfilingReq.getProfilingId(); - auto wrappedRpcResponseFuture = std::make_shared(); + auto wrappedRpcResponseFuture = + std::make_shared(at::AnyClassType::get()); // Enable the profiler with the config from the sender. // When enabling on the main thread, ensure profiler states are cleaned // up, but defer consolidation of all profiled events to the continuation @@ -525,17 +532,17 @@ void RequestCallbackNoPython::processRunWithProfilingReq( if (wrappedRpcResponseFuture->hasError()) { // Propagate error // No need to propagate remote events in the case of an error. - responseFuture->setError(wrappedRpcResponseFuture->error()->what()); + responseFuture->setError(wrappedRpcResponseFuture->exception_ptr()); } else { populateRemoteProfiledEvents( profiledEvents, profilingConfig, event_lists); auto rpcWithProfilingResp = std::make_unique( MessageType::RUN_WITH_PROFILING_RESP, - std::move(*wrappedRpcResponseFuture).moveValue(), + std::move(*wrappedRpcResponseFuture->value().toCustomClass()), profiledEvents, profilingKeyId); - responseFuture->markCompleted( - std::move(*rpcWithProfilingResp).toMessage()); + responseFuture->markCompleted(IValue(c10::make_intrusive( + std::move(*rpcWithProfilingResp).toMessage()))); } })); // Exiting the scope will disable the profiler on this thread with the @@ -546,7 +553,7 @@ void RequestCallbackNoPython::processRunWithProfilingReq( void RequestCallbackNoPython::processRRefBackward( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } @@ -554,10 +561,11 @@ void RequestCallbackNoPython::processRpc( RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto markComplete = [messageId, &responseFuture](Message m) { m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); }; // TODO: RpcCommandBase should have an abstract execute() method that we can // call here instead of having another switch statement here. Even better we @@ -629,7 +637,7 @@ void RequestCallbackNoPython::processRpc( } } -Message RequestCallbackNoPython::handleError( +IValue RequestCallbackNoPython::handleError( const std::exception& e, const MessageType messageType, int64_t messageId) const { @@ -642,7 +650,8 @@ Message RequestCallbackNoPython::handleError( DistAutogradContainer::getInstance().getWorkerId(), ": ", e.what()); - return createExceptionResponse(errorMsg, messageId); + return IValue(c10::make_intrusive( + createExceptionResponse(errorMsg, messageId))); } bool RequestCallbackNoPython::cudaAvailable() const { diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.h b/torch/csrc/distributed/rpc/request_callback_no_python.h index 7d9dac0a6635..269bd5ac33ff 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.h +++ b/torch/csrc/distributed/rpc/request_callback_no_python.h @@ -26,7 +26,7 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; bool processScriptCallOp( ScriptCall& scriptCall, @@ -37,7 +37,7 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; virtual TypePtr getScriptRemoteCallType( ScriptRemoteCall& scriptRemoteCall) const; @@ -52,7 +52,7 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; bool processScriptRemoteCallOp( ScriptRemoteCall& scriptRemoteCall, @@ -64,18 +64,18 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; void processScriptRRefFetchCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; virtual void processPythonRRefFetchCall( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; void processRRefUserDelete( RpcCommandBase& rpc, @@ -92,12 +92,12 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { void processForwardAutogradReq( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; void processBackwardAutogradReq( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; void processCleanupAutogradContextReq( RpcCommandBase& rpc, @@ -106,7 +106,7 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { void processRunWithProfilingReq( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; virtual void handleRRefDelete(c10::intrusive_ptr& rref) const; @@ -114,15 +114,15 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; virtual void processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; - Message handleError( + IValue handleError( const std::exception& e, const MessageType messageType, int64_t messageId) const; @@ -132,7 +132,7 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { virtual void processRRefBackward( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; }; } // namespace rpc From 098751016e93bf10c3dfcb8dadfdb15c7d15ad44 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 26/54] Completely Remove FutureMessage from RPC cpp tests (#50027) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50027 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25753815 Pulled By: mrshenli fbshipit-source-id: 85b9b03fec52b4175288ac3a401285607744b451 --- test/cpp/rpc/e2e_test_base.h | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/test/cpp/rpc/e2e_test_base.h b/test/cpp/rpc/e2e_test_base.h index 61d9a73e1077..cea5079b1a4e 100644 --- a/test/cpp/rpc/e2e_test_base.h +++ b/test/cpp/rpc/e2e_test_base.h @@ -90,14 +90,15 @@ class TestE2EBase : public ::testing::Test { ScriptCall scriptCall(op, {t1, t2, /* alpha */ 1}); // Send the RPC and return result. - auto response = RpcAgent::toFutureMessage( - autograd::sendMessageWithAutograd( - *rpcAgent, - rpcAgent->getWorkerInfo("worker"), - std::move(scriptCall).toMessage())) - ->wait(); + auto response = autograd::sendMessageWithAutograd( + *rpcAgent, + rpcAgent->getWorkerInfo("worker"), + std::move(scriptCall).toMessage()); + response->waitAndThrow(); + MessageType messageType = MessageType::FORWARD_AUTOGRAD_RESP; - auto wrappedResponse = deserializeResponse(response, messageType); + auto wrappedResponse = deserializeResponse( + std::move(*response->value().toCustomClass()), messageType); return static_cast(*wrappedResponse).value().toTensor(); } From 171648edaa50cea3ef82305576a0b881154b712f Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 27/54] Completely Remove FutureMessage from RPC agents (#50028) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50028 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25753887 Pulled By: mrshenli fbshipit-source-id: 40718349c2def262a16aaa24c167c0b540cddcb1 --- .../distributed/rpc/process_group_agent.cpp | 20 ++++++++------ .../csrc/distributed/rpc/request_callback.cpp | 2 +- torch/csrc/distributed/rpc/request_callback.h | 4 +-- .../rpc/request_callback_no_python.cpp | 4 +-- .../rpc/request_callback_no_python.h | 2 +- .../csrc/distributed/rpc/tensorpipe_agent.cpp | 26 +++++++++---------- torch/csrc/distributed/rpc/tensorpipe_agent.h | 2 +- 7 files changed, 32 insertions(+), 28 deletions(-) diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index 9f37535c485a..be3b1803b8d6 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -514,22 +514,24 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { std::move(data.first), std::move(data.second), work.type_, work.id_); if (message.isRequest()) { ++serverActiveCalls_; - std::shared_ptr futureResponse; + std::shared_ptr futureResponse; try { futureResponse = cb_->operator()(message); } catch (const std::exception& e) { - futureResponse = std::make_shared(); - futureResponse->setError(e.what()); + futureResponse = std::make_shared(at::AnyClassType::get()); + futureResponse->setError(std::current_exception()); } if (futureResponse->completed()) { --serverActiveCalls_; if (!futureResponse->hasError()) { - send(work.from_, std::move(*futureResponse).moveValue()); + send( + work.from_, + std::move(*futureResponse->value().toCustomClass())); } else { send( work.from_, createExceptionResponse( - futureResponse->error()->what(), message.id())); + futureResponse->tryRetrieveErrorMessage(), message.id())); } } else { ++serverActiveAsyncCalls_; @@ -541,19 +543,21 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { futureResponse->addCallback([this, fromId, requestId, - weak = std::weak_ptr( + weak = std::weak_ptr( futureResponse)]() { auto futureResponse = weak.lock(); TORCH_INTERNAL_ASSERT(futureResponse); --serverActiveCalls_; --serverActiveAsyncCalls_; if (!futureResponse->hasError()) { - send(getWorkerInfo(fromId), std::move(*futureResponse).moveValue()); + send( + getWorkerInfo(fromId), + std::move(*futureResponse->value().toCustomClass())); } else { send( getWorkerInfo(fromId), createExceptionResponse( - futureResponse->error()->what(), requestId)); + futureResponse->tryRetrieveErrorMessage(), requestId)); } }); } diff --git a/torch/csrc/distributed/rpc/request_callback.cpp b/torch/csrc/distributed/rpc/request_callback.cpp index 44b7cb6eb2e5..e690e9ac90dd 100644 --- a/torch/csrc/distributed/rpc/request_callback.cpp +++ b/torch/csrc/distributed/rpc/request_callback.cpp @@ -9,7 +9,7 @@ namespace rpc { using namespace torch::distributed::autograd; -std::shared_ptr RequestCallback::operator()( +std::shared_ptr RequestCallback::operator()( Message& request) const { // NB: cannot clear autograd context id here because the processMessage method // might pause waiting for all RRefs in the arguments to be confirmed by their diff --git a/torch/csrc/distributed/rpc/request_callback.h b/torch/csrc/distributed/rpc/request_callback.h index 95847eb6153a..bf1231eefd7f 100644 --- a/torch/csrc/distributed/rpc/request_callback.h +++ b/torch/csrc/distributed/rpc/request_callback.h @@ -12,7 +12,7 @@ namespace rpc { class TORCH_API RequestCallback { public: // Invoke the callback. - std::shared_ptr operator()(Message& request) const; + std::shared_ptr operator()(Message& request) const; virtual ~RequestCallback() {} @@ -24,7 +24,7 @@ class TORCH_API RequestCallback { // message containing an exception. Different rpc agent implementations are // expected to ensure delivery of the response/exception based on their // implementation specific mechanisms. - virtual std::shared_ptr processMessage( + virtual std::shared_ptr processMessage( Message& request) const = 0; }; diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index 5f2765b32a01..db455e11710f 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -48,7 +48,7 @@ std::unique_ptr RequestCallbackNoPython:: return rpc; } -std::shared_ptr RequestCallbackNoPython::processMessage( +std::shared_ptr RequestCallbackNoPython::processMessage( Message& request) const { // We need two futures here because it could pause twice when processing a // RPC message: @@ -105,7 +105,7 @@ std::shared_ptr RequestCallbackNoPython::processMessage( retFuture->markCompleted(handleError(e, request.type(), request.id())); rrefContext.clearRecordedPendingRRefsOnError(); } - return RpcAgent::toFutureMessage(std::move(retFuture)); + return retFuture; } void RequestCallbackNoPython::processRpcWithErrors( diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.h b/torch/csrc/distributed/rpc/request_callback_no_python.h index 269bd5ac33ff..09c2e48479ab 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.h +++ b/torch/csrc/distributed/rpc/request_callback_no_python.h @@ -14,7 +14,7 @@ namespace rpc { // RequestCallback implementation with no Python dependencies. class TORCH_API RequestCallbackNoPython : public RequestCallback { public: - std::shared_ptr processMessage( + std::shared_ptr processMessage( Message& request) const override; protected: diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 3050a40ec608..d05426f6160d 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -464,7 +464,7 @@ void TensorPipeAgent::pipeWrite( void TensorPipeAgent::sendCompletedResponseMessage( std::shared_ptr& pipe, - std::shared_ptr& futureResponseMessage, + std::shared_ptr& futureResponseMessage, uint64_t messageId) { if (!rpcAgentRunning_.load()) { LOG(WARNING) << "RPC agent for " << workerInfo_.name_ @@ -477,17 +477,15 @@ void TensorPipeAgent::sendCompletedResponseMessage( << " is sending response to request #" << messageId << " to " << pipe->getRemoteName(); - const c10::optional error = - futureResponseMessage->error(); - Message&& responseMessage = std::move(*futureResponseMessage).moveValue(); - responseMessage.setId(messageId); - if (!error) { + if (!futureResponseMessage->hasError()) { + Message&& responseMessage = + std::move(*futureResponseMessage->value().toCustomClass()); + responseMessage.setId(messageId); std::vector devices; - try { devices = getDevicesForTensors(pipe->getRemoteName(), responseMessage); } catch (const std::exception& e) { - responseMessage = createExceptionResponse(e.what(), responseMessage.id()); + responseMessage = createExceptionResponse(e.what(), messageId); } pipeWrite( @@ -511,7 +509,8 @@ void TensorPipeAgent::sendCompletedResponseMessage( } else { pipeWrite( pipe, - createExceptionResponse(error->what(), responseMessage.id()), + createExceptionResponse( + futureResponseMessage->tryRetrieveErrorMessage(), messageId), {}, [this, pipe, messageId](const tensorpipe::Error& error) { if (error) { @@ -572,12 +571,13 @@ void TensorPipeAgent::respond(std::shared_ptr& pipe) { << " is running request #" << messageId << " from " << pipe->getRemoteName() << " in thread pool"; - std::shared_ptr futureResponseMessage; + std::shared_ptr futureResponseMessage; try { futureResponseMessage = cb_->operator()(requestMessage); - } catch (const std::exception& e) { - futureResponseMessage = std::make_shared(); - futureResponseMessage->setError(e.what()); + } catch (const std::exception& /* unused */) { + futureResponseMessage = + std::make_shared(at::AnyClassType::get()); + futureResponseMessage->setError(std::current_exception()); } // Shortcut if immediately done diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index 022f92241044..ef72e5468a5d 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -249,7 +249,7 @@ class TensorPipeAgent : public RpcAgent { void sendCompletedResponseMessage( std::shared_ptr& pipe, - std::shared_ptr& futureResponseMessage, + std::shared_ptr& futureResponseMessage, uint64_t messageId); // Collects metrics from successful RPC calls From c480eebf958a306afdb4bcdf15afe89ce6a38731 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 7 Jan 2021 19:43:44 -0800 Subject: [PATCH 28/54] Completely remove FutureMessage type (#50029) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50029 Test Plan: buck run mode/opt -c=python.package_style=inplace //caffe2/torch/fb/training_toolkit/examples:ctr_mbl_feed_april_2020 -- local-preset --flow-entitlement pytorch_ftw_gpu --secure-group oncall_pytorch_distributed Before: ``` ... I0107 11:03:10.434000 3831111 print_publisher.py:23 master ] Publishing batch metrics: qps-qps|total_examples 14000.0 I0107 11:03:10.434000 3831111 print_publisher.py:23 master ] Publishing batch metrics: qps-qps|window_qps 74.60101318359375 I0107 11:03:10.434000 3831111 print_publisher.py:23 master ] Publishing batch metrics: qps-qps|lifetime_qps 74.60101318359375 ... I0107 11:05:12.132000 3831111 print_publisher.py:23 master ] Publishing batch metrics: qps-qps|total_examples 20000.0 I0107 11:05:12.132000 3831111 print_publisher.py:23 master ] Publishing batch metrics: qps-qps|window_qps 64.0 I0107 11:05:12.132000 3831111 print_publisher.py:23 master ] Publishing batch metrics: qps-qps|lifetime_qps 64.64917755126953 ... ``` After: ``` ... I0107 11:53:03.858000 53693 print_publisher.py:23 master ] Publishing batch metrics: qps-qps|total_examples 14000.0 I0107 11:53:03.858000 53693 print_publisher.py:23 master ] Publishing batch metrics: qps-qps|window_qps 72.56404876708984 I0107 11:53:03.858000 53693 print_publisher.py:23 master ] Publishing batch metrics: qps-qps|lifetime_qps 72.56404876708984 ... I0107 11:54:24.612000 53693 print_publisher.py:23 master ] Publishing batch metrics: qps-qps|total_examples 20000.0 I0107 11:54:24.612000 53693 print_publisher.py:23 master ] Publishing batch metrics: qps-qps|window_qps 73.07617950439453 I0107 11:54:24.612000 53693 print_publisher.py:23 master ] Publishing batch metrics: qps-qps|lifetime_qps 73.07617950439453 ... ``` Reviewed By: lw Differential Revision: D25774915 Pulled By: mrshenli fbshipit-source-id: 1128c3c2df9d76e36beaf171557da86e82043eb9 --- torch/csrc/distributed/rpc/message.h | 3 -- torch/csrc/distributed/rpc/rpc_agent.h | 46 ++------------------------ torch/csrc/jit/python/pybind_utils.h | 2 +- 3 files changed, 4 insertions(+), 47 deletions(-) diff --git a/torch/csrc/distributed/rpc/message.h b/torch/csrc/distributed/rpc/message.h index bad9569cfff8..7ff487b9b337 100644 --- a/torch/csrc/distributed/rpc/message.h +++ b/torch/csrc/distributed/rpc/message.h @@ -154,9 +154,6 @@ TORCH_API Message createExceptionResponse(const std::exception& e, int64_t id); TORCH_API Message createExceptionResponse(const std::string& exceptionStr, int64_t id); -// FutureMessage is an internal type used in the communication layer. All -// user-facing surface APIs should use JitFuture instead. -using FutureMessage = torch::utils::Future; using JitFuture = c10::ivalue::Future; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index c8763c678922..bfc6c38c07a1 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -151,10 +151,10 @@ class TORCH_API RpcAgent { virtual ~RpcAgent(); // Send a message to the ``RpcAgent`` of id ``to`` and returns a - // ``FutureMessage`` ptr. The implementation must be asynchronous, i.e., it + // ``JitFuture`` ptr. The implementation must be asynchronous, i.e., it // cannot block until it receives the response. // - // If ``message.isRequest()`` is true, the ``FutureMessage`` will be + // If ``message.isRequest()`` is true, the ``JitFuture`` will be // completed when the response arrives. For other message types, the Future // should be ignored by the caller. virtual std::shared_ptr send( @@ -167,7 +167,7 @@ class TORCH_API RpcAgent { // time using an exponential backoff algorithm. // // Sends ``message`` to the ``RpcAgent`` of id ``to`` and returns a - // ``FutureMessage`` ptr, just like send(). Caller can specify the maximum + // ``JitFuture`` ptr, just like send(). Caller can specify the maximum // number of retries for this RPC (default is 5), initial duration between // sends (default is 1000ms), and backoff constant (default is 1.5) by // passing in the RpcRetryOptions struct. This API might end up @@ -259,46 +259,6 @@ class TORCH_API RpcAgent { // Get the type resolver std::shared_ptr getTypeResolver(); - static std::shared_ptr toJitFuture( - std::shared_ptr&& fm) { - auto jitFuture = std::make_shared(at::AnyClassType::get()); - - std::weak_ptr wp = fm; - fm->addCallback( - [jitFuture, wp]() mutable { - auto future = wp.lock(); - TORCH_INTERNAL_ASSERT(future); - if (future->hasError()) { - jitFuture->setError(std::make_exception_ptr(*(future->error()))); - } else { - jitFuture->markCompleted(IValue( - c10::make_intrusive(std::move(*future).moveValue()))); - } - } - ); - return jitFuture; - } - - static std::shared_ptr toFutureMessage( - std::shared_ptr&& jitFuture) { - auto fm = std::make_shared(); - - std::weak_ptr wp = jitFuture; - jitFuture->addCallback( - [fm, wp]() mutable { - auto future = wp.lock(); - TORCH_INTERNAL_ASSERT(future); - if (future->hasError()) { - fm->setError(future->tryRetrieveErrorMessage()); - } else { - fm->markCompleted( - std::move(*future->value().toCustomClass())); - } - } - ); - return fm; - } - protected: const WorkerInfo workerInfo_; const std::unique_ptr cb_; diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index bfe160b2647f..b14f4ddc37fd 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -122,7 +122,7 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper // Capture a copy of the ivalue::Future instead of the `this` pointer // because the PythonFutureWrapper object could have been deleted // when the callbacks are fired. For example, RPC only captures the - // ivalue::Future instead of PythonFutureWrapper in FutureMessage's + // ivalue::Future instead of PythonFutureWrapper in JitFuture's // callback functions. Hence, if user code does not hold a reference to // this PythonFutureWrapper object, there is no guarantee that the // PythonFutureWrapper is still valid when running the callback. From 882ddb2f2dc1e96127d105e41c6f421742c3659e Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 7 Jan 2021 20:54:20 -0800 Subject: [PATCH 29/54] [PyTorch] Introduce packed SizesAndStrides abstraction (#47507) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47507 This introduces a new SizesAndStrides class as a helper for TensorImpl, in preparation for changing its representation. ghstack-source-id: 119313559 Test Plan: Added new automated tests as well. Run framework overhead benchmarks. Results seem to be neutral-ish. Reviewed By: ezyang Differential Revision: D24762557 fbshipit-source-id: 6cc0ede52d0a126549fb51eecef92af41c3e1a98 --- aten/src/ATen/BatchedTensorImpl.cpp | 11 +- aten/src/ATen/OpaqueTensorImpl.h | 6 +- aten/src/ATen/SparseTensorImpl.h | 11 +- c10/core/TensorImpl.cpp | 59 +-- c10/core/TensorImpl.h | 69 ++-- c10/core/impl/SizesAndStrides.h | 134 +++++++ c10/test/core/impl/SizesAndStrides_test.cpp | 385 ++++++++++++++++++++ 7 files changed, 594 insertions(+), 81 deletions(-) create mode 100644 c10/core/impl/SizesAndStrides.h create mode 100644 c10/test/core/impl/SizesAndStrides_test.cpp diff --git a/aten/src/ATen/BatchedTensorImpl.cpp b/aten/src/ATen/BatchedTensorImpl.cpp index f295d70c31fd..9dbf9ea78f4b 100644 --- a/aten/src/ATen/BatchedTensorImpl.cpp +++ b/aten/src/ATen/BatchedTensorImpl.cpp @@ -20,14 +20,11 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims) const auto public_dims = value_.dim() - bdims_.size(); const auto value_sizes = value_.sizes(); const auto value_strides = value_.strides(); - sizes_.clear(); - sizes_.reserve(public_dims); - strides_.clear(); - strides_.reserve(public_dims); + sizes_and_strides_.resize(public_dims); for (int64_t dim = 0; dim < public_dims; dim++) { auto actual_dim = actualDim(dim, /*wrap_dim=*/false); - sizes_.push_back(value_sizes.at(actual_dim)); - strides_.push_back(value_strides.at(actual_dim)); + sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim); + sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim); } refresh_numel(); refresh_contiguous(); @@ -35,7 +32,7 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims) int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const { if (wrap_dim) { - const auto ndim = sizes_.size(); + const auto ndim = sizes_and_strides_.size(); dim = maybe_wrap_dim(dim, ndim); } auto is_bdim = createBatchDimBitset(bdims_); diff --git a/aten/src/ATen/OpaqueTensorImpl.h b/aten/src/ATen/OpaqueTensorImpl.h index f23a097aabb8..2072f549d011 100644 --- a/aten/src/ATen/OpaqueTensorImpl.h +++ b/aten/src/ATen/OpaqueTensorImpl.h @@ -28,7 +28,7 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl { bool is_non_overlapping_and_dense = true) : TensorImpl(key_set, data_type, device), opaque_handle_(std::move(opaque_handle)) { - sizes_ = sizes.vec(); + sizes_and_strides_.set_sizes(sizes); refresh_numel(); is_non_overlapping_and_dense_ = is_non_overlapping_and_dense; } @@ -86,7 +86,7 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl { const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const override { auto impl = c10::make_intrusive>( - key_set(), dtype(), device(), opaque_handle_, sizes_); + key_set(), dtype(), device(), opaque_handle_, sizes_and_strides_.sizes_arrayref()); copy_tensor_metadata( /*src_opaque_impl=*/this, /*dest_opaque_impl=*/impl.get(), @@ -106,7 +106,7 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl { c10::VariableVersion&& version_counter, bool allow_tensor_metadata_change) const override { auto impl = c10::make_intrusive>( - key_set(), dtype(), device(), opaque_handle_, sizes_); + key_set(), dtype(), device(), opaque_handle_, sizes_and_strides_.sizes_arrayref()); copy_tensor_metadata( /*src_opaque_impl=*/this, /*dest_opaque_impl=*/impl.get(), diff --git a/aten/src/ATen/SparseTensorImpl.h b/aten/src/ATen/SparseTensorImpl.h index c80f5de18824..9daf21c15e56 100644 --- a/aten/src/ATen/SparseTensorImpl.h +++ b/aten/src/ATen/SparseTensorImpl.h @@ -55,7 +55,7 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { // respect to indices and values void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) { TORCH_CHECK(allow_tensor_metadata_change(), "raw_resize_ ", err_msg_tensor_metadata_change_not_allowed); - sizes_ = size.vec(); + sizes_and_strides_.set_sizes(size); sparse_dim_ = sparse_dim; dense_dim_ = dense_dim; refresh_numel(); @@ -125,7 +125,8 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { "shrinking the size of dense dimensions (from ", dense_size_original, " to ", dense_size_new, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg); } - if ((!size.equals(sizes_)) || (sparse_dim != sparse_dim_) || (dense_dim != dense_dim_)) { + const bool size_equals_sizes = std::equal(size.begin(), size.end(), sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); + if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) || (dense_dim != dense_dim_)) { auto nnz = values().size(0); std::vector values_size = {nnz}; auto dense_size = size.slice(sparse_dim); @@ -134,7 +135,9 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { indices_.resize_({sparse_dim, nnz}); } - sizes_ = size.vec(); + if (!size_equals_sizes) { + sizes_and_strides_.set_sizes(size); + } sparse_dim_ = sparse_dim; dense_dim_ = dense_dim; refresh_numel(); @@ -145,7 +148,7 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { TORCH_CHECK(allow_tensor_metadata_change(), "resize_and_clear_ ", err_msg_tensor_metadata_change_not_allowed); TORCH_CHECK(sparse_dim + dense_dim == static_cast(size.size()), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size()); - sizes_ = size.vec(); + sizes_and_strides_.set_sizes(size); sparse_dim_ = sparse_dim; dense_dim_ = dense_dim; diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index f17ba95a754b..d79d38427d5a 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -67,7 +67,6 @@ TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta data_type, TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta data_type, c10::optional device_opt) : storage_(std::move(storage)), - sizes_{0}, storage_offset_(0), numel_(0), data_type_(data_type), @@ -91,15 +90,14 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2:: // we would also like to check that non-cpu devices have an index, but some Caffe2 operators create // Storages with default devices. - strides_.push_back(1); } IntArrayRef TensorImpl::sizes() const { - return sizes_; + return sizes_and_strides_.sizes_arrayref(); } IntArrayRef TensorImpl::strides() const { - return strides_; + return sizes_and_strides_.strides_arrayref(); } bool TensorImpl::compute_contiguous() const { @@ -108,9 +106,10 @@ bool TensorImpl::compute_contiguous() const { return is_contiguous; int64_t z = 1; for (int64_t d = dim() - 1; d >= 0; d--) { - if (sizes_[d] != 1) { - if (strides_[d] == z) { - z *= sizes_[d]; + const auto size_d = sizes_and_strides_.size_at_unchecked(d); + if (size_d != 1) { + if (sizes_and_strides_.stride_at_unchecked(d) == z) { + z *= size_d; } else { is_contiguous = false; break; @@ -123,16 +122,17 @@ bool TensorImpl::compute_contiguous() const { bool TensorImpl::compute_channels_last_contiguous_2d() const { // Please don't combine these code, constant array is used here to let // compiler fully unroll the loop to get better performance - switch (sizes_.size()) { + switch (sizes_and_strides_.size()) { case 4: { int64_t expected = 1; for (auto& d : {1, 3, 2, 0}) { - if (sizes_[d] != 1) { - if (strides_[d] != expected) { + const auto size_d = sizes_and_strides_.size_at_unchecked(d); + if (size_d != 1) { + if (sizes_and_strides_.stride_at_unchecked(d) != expected) { return false; } - expected *= sizes_[d]; + expected *= size_d; } } return true; @@ -148,16 +148,17 @@ bool TensorImpl::compute_channels_last_contiguous_2d() const { bool TensorImpl::compute_channels_last_contiguous_3d() const { // Please don't combine these code, constant array is used here to let // compiler fully unroll the loop to get better performance - switch (sizes_.size()) { + switch (sizes_and_strides_.size()) { case 5: { int64_t expected = 1; for (auto& d : {1, 4, 3, 2, 0}) { - if (sizes_[d] != 1) { - if (strides_[d] != expected) { + const auto size_d = sizes_and_strides_.size_at_unchecked(d); + if (size_d != 1) { + if (sizes_and_strides_.stride_at_unchecked(d) != expected) { return false; } - expected *= sizes_[d]; + expected *= size_d; } } return true; @@ -171,16 +172,16 @@ bool TensorImpl::compute_channels_last_contiguous_3d() const { } bool TensorImpl::compute_strides_like_channels_last_2d() const { - return is_channels_last_strides_2d(sizes_, strides_); + return is_channels_last_strides_2d(TensorImpl::sizes(), TensorImpl::strides()); } bool TensorImpl::compute_strides_like_channels_last_3d() const { - return is_channels_last_strides_3d(sizes_, strides_); + return is_channels_last_strides_3d(TensorImpl::sizes(), TensorImpl::strides()); } bool TensorImpl::compute_non_overlapping_and_dense() const { if (dim() == 1) { - return sizes_[0] < 2 || strides_[0] == 1; + return sizes_and_strides_.size_at_unchecked(0) < 2 || sizes_and_strides_.stride_at_unchecked(0) == 1; } SmallVector perm; perm.resize(dim()); @@ -189,22 +190,23 @@ bool TensorImpl::compute_non_overlapping_and_dense() const { } // Sort by strides, leaving 0 and 1 sized dims at the end of the array std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) { - if (sizes_[a] < 2) { + if (sizes_and_strides_.size_at_unchecked(a) < 2) { return false; - } else if (sizes_[b] < 2) { + } else if (sizes_and_strides_.size_at_unchecked(b) < 2) { return true; } - return strides_[a] < strides_[b]; + return sizes_and_strides_.stride_at_unchecked(a) < sizes_and_strides_.stride_at_unchecked(b); }); auto require_stride = 1; for (int64_t i = 0; i < dim(); i ++) { - if (sizes_[perm[i]] < 2) { + const auto size_perm_i = sizes_and_strides_.size_at_unchecked(perm[i]); + if (size_perm_i < 2) { return true; } - if (strides_[perm[i]] != require_stride) { + if (sizes_and_strides_.stride_at_unchecked(perm[i]) != require_stride) { return false; } - require_stride *= sizes_[perm[i]]; + require_stride *= size_perm_i; } return true; } @@ -218,18 +220,18 @@ void TensorImpl::release_resources() { #ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY int64_t TensorImpl::dim() const { - return sizes_.size(); + return sizes_and_strides_.size(); } #endif int64_t TensorImpl::size(int64_t d) const { d = at::maybe_wrap_dim(d, dim(), false); - return sizes_[d]; + return sizes_and_strides_.size_at_unchecked(d); } int64_t TensorImpl::stride(int64_t d) const { d = at::maybe_wrap_dim(d, dim(), false); - return strides_[d]; + return sizes_and_strides_.stride_at_unchecked(d); } bool TensorImpl::has_storage() const { @@ -339,8 +341,7 @@ void TensorImpl::copy_tensor_metadata_except_version_counter( TensorImpl* dest_impl, bool allow_tensor_metadata_change) { dest_impl->storage_ = src_impl->storage_; - dest_impl->sizes_ = src_impl->sizes_; - dest_impl->strides_ = src_impl->strides_; + dest_impl->sizes_and_strides_ = src_impl->sizes_and_strides_; dest_impl->storage_offset_ = src_impl->storage_offset_; dest_impl->data_type_ = src_impl->data_type_; dest_impl->device_opt_ = src_impl->device_opt_; diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index d5c5208472d3..41d4ff4c7f08 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -10,8 +11,10 @@ #include #include #include +#include #include + #include #include #include @@ -769,7 +772,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ virtual void set_size(int64_t dim, int64_t new_size) { TORCH_CHECK(allow_tensor_metadata_change(), "set_size ", err_msg_tensor_metadata_change_not_allowed); - sizes_.at(dim) = new_size; + sizes_and_strides_.size_at(dim) = new_size; refresh_numel(); refresh_contiguous(); } @@ -782,7 +785,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ virtual void set_stride(int64_t dim, int64_t new_stride) { TORCH_CHECK(allow_tensor_metadata_change(), "set_stride ", err_msg_tensor_metadata_change_not_allowed); - strides_[dim] = new_stride; + sizes_and_strides_.stride_at_unchecked(dim) = new_stride; refresh_contiguous(); } @@ -807,12 +810,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ void set_sizes_contiguous(IntArrayRef new_size) { TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_contiguous ", err_msg_tensor_metadata_change_not_allowed); - auto new_dim = new_size.size(); - sizes_.resize(new_dim); - for (size_t dim = 0; dim < new_dim; ++dim) { - sizes_[dim] = new_size[dim]; - } + sizes_and_strides_.set_sizes(new_size); refresh_numel(); empty_tensor_restride(MemoryFormat::Contiguous); @@ -834,27 +833,25 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { ") must match dimensionality of strides (", new_stride.size(), ")"); - auto new_dim = new_size.size(); + const auto new_dim = new_size.size(); - sizes_.resize(new_dim); - for (size_t dim = 0; dim < new_dim; ++dim) { - sizes_[dim] = new_size[dim]; - } + sizes_and_strides_.set_sizes(new_size); - strides_.resize(new_dim); if (new_dim > 0) { for (size_t dim = new_dim - 1; ; dim--) { if (new_stride[dim] >= 0) { - strides_[dim] = new_stride[dim]; + sizes_and_strides_.stride_at_unchecked(dim) = new_stride[dim]; } else { // XXX: This behavior is surprising and may need to be removed to // support negative strides. Some pytorch functions rely on it: // for example, torch.cat (run TestTorch.test_cat_empty). if (dim == new_dim - 1) { - strides_[dim] = 1; + sizes_and_strides_.stride_at_unchecked(dim) = 1; } else { // Keep stride monotonically increasing to match NumPy. - strides_[dim] = std::max(sizes_[dim + 1], 1) * strides_[dim + 1]; + sizes_and_strides_.stride_at_unchecked(dim) = + std::max(sizes_and_strides_.size_at_unchecked(dim + 1), 1) * + sizes_and_strides_.stride_at_unchecked(dim + 1); } } if (dim == 0) break; @@ -1082,12 +1079,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * This op is auto-asynchronous if the underlying device (CUDA) supports it. */ void Extend(int64_t num, float growthPct) { - TORCH_CHECK(sizes_.size() >= 1u); + TORCH_CHECK(sizes_and_strides_.size() >= 1u); TORCH_CHECK(num >= 0, "`num` must be non-negative for Extend"); TORCH_CHECK( is_contiguous_, "Right now Extend is only supported for contiguous Tensor."); - auto newDims = sizes_; + using SizesVector = SmallVector; + SizesVector newDims(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); newDims[0] += num; if (!storage_.data()) { Resize(newDims); @@ -1099,16 +1097,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { static_cast(1), std::multiplies()); if (newNumel * data_type_.itemsize() <= storage_.nbytes()) { - sizes_ = newDims; + sizes_and_strides_.set_sizes(newDims); numel_ = newNumel; return; } - auto newCapacity = sizes_; + SizesVector newCapacity(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); newCapacity[0] = std::max( - newDims[0], static_cast(std::ceil(sizes_[0] * (1 + growthPct / 100)))); + newDims[0], static_cast(std::ceil(sizes_and_strides_.size_at_unchecked(0) * (1 + growthPct / 100)))); auto oldData = std::move(storage_.data_ptr()); auto oldSize = numel_; - auto oldDims = sizes_; Resize(newCapacity); auto* newData = raw_mutable_data(data_type_); if (data_type_.copy()) { @@ -1135,7 +1132,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { true); // non-blocking } reserved_ = true; - sizes_ = newDims; + sizes_and_strides_.set_sizes(newDims); numel_ = newNumel; } @@ -1152,7 +1149,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { "Right now ReserveSpace is only supported for contiguous Tensor."); TORCH_CHECK( storage_.unique(), "Can't call ReserveSpace on shared storage."); - auto newCapacity = sizes_; + // TODO: eliminate newCapacity. + SmallVector newCapacity(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); newCapacity[0] = outer_dim; auto newNumel = std::accumulate( newCapacity.begin(), @@ -1165,11 +1163,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // Old data is discarded storage_.data_ptr().clear(); auto oldSize = numel_; - auto oldDims = sizes_; + SmallVector oldDims(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); Resize(newCapacity); // Allocate new memory but don't copy over the data raw_mutable_data(data_type_); - sizes_ = oldDims; + sizes_and_strides_.set_sizes(oldDims); numel_ = oldSize; reserved_ = true; } @@ -1239,7 +1237,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { " The old caffe2 mixes Reshape and Resize but this behavior has " "been changed. If you find this error, most likely you will need " "to change corresponding code from Reshape to Resize."); - sizes_ = dims; + sizes_and_strides_.set_sizes(dims); empty_tensor_restride(MemoryFormat::Contiguous); } @@ -1454,12 +1452,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { case MemoryFormat::Contiguous: { // dim_ is a virtual call, don't repeat it const auto dim_ = dim(); - strides_.resize(dim_); + sizes_and_strides_.resize(dim_); if (dim_ > 0) { const auto last_idx = dim_ - 1; - strides_[last_idx] = 1; + sizes_and_strides_.stride_at_unchecked(last_idx) = 1; for (auto i = last_idx - 1; i >= 0; --i) { - strides_[i] = strides_[i + 1] * std::max(sizes_[i + 1], 1); + sizes_and_strides_.stride_at_unchecked(i) = sizes_and_strides_.stride_at_unchecked(i + 1) * std::max(sizes_and_strides_.size_at_unchecked(i + 1), 1); } } break; @@ -1517,11 +1515,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { typename = typename std::enable_if::value>::type> bool SetDimsTemplate(ArrayRef src) { auto old_numel = numel_; - sizes_.resize(src.size()); + sizes_and_strides_.resize(src.size()); int64_t new_numel = 1; for (size_t i = 0; i < src.size(); ++i) { new_numel *= src[i]; - sizes_[i] = src[i]; + sizes_and_strides_.size_at_unchecked(i) = src[i]; } numel_ = new_numel; empty_tensor_restride(MemoryFormat::Contiguous); @@ -1718,12 +1716,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp PyObject* pyobj_ = nullptr; - // We could save a word or two by combining the SmallVector structs, - // since their size is redundant, and if we need to overflow the buffer space - // we could keep the two pointers together. However, that would require - // implementing another struct from scratch, so only do this if we're desperate. - SmallVector sizes_; - SmallVector strides_; + c10::impl::SizesAndStrides sizes_and_strides_; int64_t storage_offset_ = 0; // If sizes and strides are empty, the numel is 1!! However, most of the diff --git a/c10/core/impl/SizesAndStrides.h b/c10/core/impl/SizesAndStrides.h new file mode 100644 index 000000000000..b50af160d29c --- /dev/null +++ b/c10/core/impl/SizesAndStrides.h @@ -0,0 +1,134 @@ +#pragma once + +#include + +#include +#include +#include + +namespace c10 { +namespace impl { + +// Packed container for TensorImpl sizes and strides. +class C10_API SizesAndStrides { + public: + // TODO: different iterator types for sizes & strides to prevent + // mixing the two accidentally. + using sizes_iterator = int64_t*; + using sizes_const_iterator = const int64_t*; + using strides_iterator = int64_t*; + using strides_const_iterator = const int64_t*; + + SizesAndStrides() : sizes_{0}, strides_{1} {} + + size_t size() const { + return sizes_.size(); + } + + const int64_t* sizes_data() const { + return sizes_.data(); + } + + int64_t* sizes_data() { + return sizes_.data(); + } + + sizes_const_iterator sizes_begin() const { + return sizes_data(); + } + + sizes_iterator sizes_begin() { + return sizes_data(); + } + + sizes_const_iterator sizes_end() const { + return sizes_begin() + size(); + } + + sizes_iterator sizes_end() { + return sizes_begin() + size(); + } + + IntArrayRef sizes_arrayref() const { + return IntArrayRef{sizes_data(), size()}; + } + + void set_sizes(IntArrayRef newSizes) { + resize(newSizes.size()); + std::copy(newSizes.begin(), newSizes.end(), sizes_begin()); + } + + const int64_t* strides_data() const { + return strides_.data(); + } + + int64_t* strides_data() { + return strides_.data(); + } + + strides_const_iterator strides_begin() const { + return strides_data(); + } + + strides_iterator strides_begin() { + return strides_data(); + } + + strides_const_iterator strides_end() const { + return strides_begin() + size(); + } + + strides_iterator strides_end() { + return strides_begin() + size(); + } + + IntArrayRef strides_arrayref() const { + return IntArrayRef{strides_data(), size()}; + } + + // Size accessors. + int64_t size_at(size_t idx) const { + return sizes_.at(idx); + } + + int64_t& size_at(size_t idx) { + return sizes_.at(idx); + } + + int64_t size_at_unchecked(size_t idx) const { + return sizes_[idx]; + } + + int64_t& size_at_unchecked(size_t idx) { + return sizes_[idx]; + } + + // Size accessors. + int64_t stride_at(size_t idx) const { + return strides_.at(idx); + } + + int64_t& stride_at(size_t idx) { + return strides_.at(idx); + } + + int64_t stride_at_unchecked(size_t idx) const { + return strides_[idx]; + } + + int64_t& stride_at_unchecked(size_t idx) { + return strides_[idx]; + } + + void resize(size_t sz) { + sizes_.resize(sz); + strides_.resize(sz); + } + + private: + SmallVector sizes_; + SmallVector strides_; +}; + +} // namespace impl +} // namespace c10 diff --git a/c10/test/core/impl/SizesAndStrides_test.cpp b/c10/test/core/impl/SizesAndStrides_test.cpp new file mode 100644 index 000000000000..8863442e81cb --- /dev/null +++ b/c10/test/core/impl/SizesAndStrides_test.cpp @@ -0,0 +1,385 @@ +#include + +#include + +using namespace c10; +using namespace c10::impl; + +static void checkData(const SizesAndStrides& sz, IntArrayRef sizes, IntArrayRef strides) { + EXPECT_EQ(sizes.size(), strides.size()) << "bad test case: size() of sizes and strides don't match"; + EXPECT_EQ(sz.size(), sizes.size()); + + int idx = 0; + for (auto x: sizes) { + EXPECT_EQ(sz.size_at_unchecked(idx), x) << "index: " << idx; + EXPECT_EQ(sz.size_at(idx), x) << "index: " << idx; + EXPECT_EQ(sz.sizes_data()[idx], x) << "index: " << idx; + EXPECT_EQ(*(sz.sizes_begin() + idx), x) << "index: " << idx; + idx++; + } + EXPECT_EQ(sz.sizes_arrayref(), sizes); + + idx = 0; + for (auto x: strides) { + EXPECT_EQ(sz.stride_at_unchecked(idx), x) << "index: " << idx; + EXPECT_EQ(sz.stride_at(idx), x) << "index: " << idx; + EXPECT_EQ(sz.strides_data()[idx], x) << "index: " << idx; + EXPECT_EQ(*(sz.strides_begin() + idx), x) << "index: " << idx; + + idx++; + } + EXPECT_EQ(sz.strides_arrayref(), strides); +} + +TEST(SizesAndStridesTest, DefaultConstructor) { + SizesAndStrides sz; + checkData(sz, {0}, {1}); + // Can't test size_at() out of bounds because it just asserts for now. +} + +TEST(SizesAndStridesTest, SetSizes) { + SizesAndStrides sz; + sz.set_sizes({5, 6, 7, 8}); + checkData(sz, {5, 6, 7, 8}, {1, 0, 0, 0}); +} + +TEST(SizesAndStridesTest, Resize) { + SizesAndStrides sz; + + sz.resize(2); + + // Small to small growing. + checkData(sz, {0, 0}, {1, 0}); + + // Small to small growing, again. + sz.resize(5); + checkData(sz, {0, 0, 0, 0, 0}, {1, 0, 0, 0, 0}); + + for (int ii = 0; ii < sz.size(); ++ii) { + sz.size_at_unchecked(ii) = ii + 1; + sz.stride_at_unchecked(ii) = 2 * (ii + 1); + } + + checkData(sz, {1, 2, 3, 4, 5}, {2, 4, 6, 8, 10}); + + // Small to small, shrinking. + sz.resize(4); + checkData(sz, {1, 2, 3, 4}, {2, 4, 6, 8}); + + // Small to small with no size change. + sz.resize(4); + checkData(sz, {1, 2, 3, 4}, {2, 4, 6, 8}); + + // Small to big. + sz.resize(6); + + checkData(sz, {1, 2, 3, 4, 0, 0}, {2, 4, 6, 8, 0, 0}); + + sz.size_at_unchecked(5) = 6; + sz.stride_at_unchecked(5) = 12; + + checkData(sz, {1, 2, 3, 4, 0, 6}, {2, 4, 6, 8, 0, 12}); + + // Big to big, growing. + sz.resize(7); + + checkData(sz, {1, 2, 3, 4, 0, 6, 0}, {2, 4, 6, 8, 0, 12, 0}); + + // Big to big with no size change. + sz.resize(7); + + checkData(sz, {1, 2, 3, 4, 0, 6, 0}, {2, 4, 6, 8, 0, 12, 0}); + + // Big to big, shrinking. + sz.resize(6); + checkData(sz, {1, 2, 3, 4, 0, 6}, {2, 4, 6, 8, 0, 12}); + + // Finally, big to small. + + // Give it different data than it had when it was small to avoid + // getting it right by accident (i.e., because of leftover inline + // storage when going small to big). + for (int ii = 0; ii < sz.size(); ++ii) { + sz.size_at_unchecked(ii) = ii - 1; + sz.stride_at_unchecked(ii) = 2 * (ii - 1); + } + + checkData(sz, {-1, 0, 1, 2, 3, 4}, {-2, 0, 2, 4, 6, 8}); + + sz.resize(5); + checkData(sz, {-1, 0, 1, 2, 3}, {-2, 0, 2, 4, 6}); +} + +TEST(SizesAndStridesTest, SetAtIndex) { + SizesAndStrides sz; + + sz.resize(5); + sz.size_at(4) = 42; + sz.stride_at(4) = 23; + + checkData(sz, {0, 0, 0, 0, 42}, {1, 0, 0, 0, 23}); + + sz.resize(6); + sz.size_at(5) = 43; + sz.stride_at(5) = 24; + + checkData(sz, {0, 0, 0, 0, 42, 43}, {1, 0, 0, 0, 23, 24}); +} + +TEST(SizesAndStridesTest, SetAtIterator) { + SizesAndStrides sz; + + sz.resize(5); + *(sz.sizes_begin() + 4) = 42; + *(sz.strides_begin() + 4) = 23; + + checkData(sz, {0, 0, 0, 0, 42}, {1, 0, 0, 0, 23}); + + sz.resize(6); + *(sz.sizes_begin() + 5) = 43; + *(sz.strides_begin() + 5) = 24; + + checkData(sz, {0, 0, 0, 0, 42, 43}, {1, 0, 0, 0, 23, 24}); +} + +TEST(SizesAndStridesTest, SetViaData) { + SizesAndStrides sz; + + sz.resize(5); + *(sz.sizes_data() + 4) = 42; + *(sz.strides_data() + 4) = 23; + + checkData(sz, {0, 0, 0, 0, 42}, {1, 0, 0, 0, 23}); + + sz.resize(6); + *(sz.sizes_data() + 5) = 43; + *(sz.strides_data() + 5) = 24; + + checkData(sz, {0, 0, 0, 0, 42, 43}, {1, 0, 0, 0, 23, 24}); +} + +static SizesAndStrides makeSmall(int offset = 0) { + SizesAndStrides small; + small.resize(3); + for (int ii = 0; ii < small.size(); ++ii) { + small.size_at_unchecked(ii) = ii + 1 + offset; + small.stride_at_unchecked(ii) = 2 * (ii + 1 + offset); + } + + return small; +} + +static SizesAndStrides makeBig(int offset = 0) { + SizesAndStrides big; + big.resize(8); + for (int ii = 0; ii < big.size(); ++ii) { + big.size_at_unchecked(ii) = ii - 1 + offset; + big.stride_at_unchecked(ii) = 2 * (ii - 1 + offset); + } + + return big; +} + +static void checkSmall(const SizesAndStrides& sm, int offset = 0) { + std::vector sizes(3), strides(3); + for (int ii = 0; ii < 3; ++ii) { + sizes[ii] = ii + 1 + offset; + strides[ii] = 2 * (ii + 1 + offset); + } + checkData(sm, sizes, strides); +} + +static void checkBig(const SizesAndStrides& big, int offset = 0) { + std::vector sizes(8), strides(8); + for (int ii = 0; ii < 8; ++ii) { + sizes[ii] = ii - 1 + offset; + strides[ii] = 2 * (ii - 1 + offset); + } + checkData(big, sizes, strides); +} + +TEST(SizesAndStridesTest, MoveConstructor) { + SizesAndStrides empty; + + SizesAndStrides movedEmpty(std::move(empty)); + + EXPECT_EQ(empty.size(), 0); + EXPECT_EQ(movedEmpty.size(), 1); + checkData(movedEmpty, {0}, {1}); + + SizesAndStrides small = makeSmall(); + checkSmall(small); + + SizesAndStrides movedSmall(std::move(small)); + checkSmall(movedSmall); + EXPECT_EQ(small.size(), 0); + + SizesAndStrides big = makeBig(); + checkBig(big); + + SizesAndStrides movedBig(std::move(big)); + checkBig(movedBig); + EXPECT_EQ(big.size(), 0); +} + +TEST(SizesAndStridesTest, CopyConstructor) { + SizesAndStrides empty; + + SizesAndStrides copiedEmpty(empty); + + EXPECT_EQ(empty.size(), 1); + EXPECT_EQ(copiedEmpty.size(), 1); + checkData(empty, {0}, {1}); + checkData(copiedEmpty, {0}, {1}); + + SizesAndStrides small = makeSmall(); + checkSmall(small); + + SizesAndStrides copiedSmall(small); + checkSmall(copiedSmall); + checkSmall(small); + + SizesAndStrides big = makeBig(); + checkBig(big); + + SizesAndStrides copiedBig(big); + checkBig(big); + checkBig(copiedBig); +} + +TEST(SizesAndStridesTest, CopyAssignmentSmallToSmall) { + SizesAndStrides smallTarget = makeSmall(); + SizesAndStrides smallCopyFrom = makeSmall(1); + + checkSmall(smallTarget); + checkSmall(smallCopyFrom, 1); + + smallTarget = smallCopyFrom; + + checkSmall(smallTarget, 1); + checkSmall(smallCopyFrom, 1); +} + +TEST(SizesAndStridesTest, MoveAssignmentSmallToSmall) { + SizesAndStrides smallTarget = makeSmall(); + SizesAndStrides smallMoveFrom = makeSmall(1); + + checkSmall(smallTarget); + checkSmall(smallMoveFrom, 1); + + smallTarget = std::move(smallMoveFrom); + + checkSmall(smallTarget, 1); + EXPECT_EQ(smallMoveFrom.size(), 0); +} + +TEST(SizesAndStridesTest, CopyAssignmentSmallToBig) { + SizesAndStrides bigTarget = makeBig(); + SizesAndStrides smallCopyFrom = makeSmall(); + + checkBig(bigTarget); + checkSmall(smallCopyFrom); + + bigTarget = smallCopyFrom; + + checkSmall(bigTarget); + checkSmall(smallCopyFrom); +} + +TEST(SizesAndStridesTest, MoveAssignmentSmallToBig) { + SizesAndStrides bigTarget = makeBig(); + SizesAndStrides smallMoveFrom = makeSmall(); + + checkBig(bigTarget); + checkSmall(smallMoveFrom); + + bigTarget = std::move(smallMoveFrom); + + checkSmall(bigTarget); + EXPECT_EQ(smallMoveFrom.size(), 0); +} + +TEST(SizesAndStridesTest, CopyAssignmentBigToBig) { + SizesAndStrides bigTarget = makeBig(); + SizesAndStrides bigCopyFrom = makeBig(1); + + checkBig(bigTarget); + checkBig(bigCopyFrom, 1); + + bigTarget = bigCopyFrom; + + checkBig(bigTarget, 1); + checkBig(bigCopyFrom, 1); +} + +TEST(SizesAndStridesTest, MoveAssignmentBigToBig) { + SizesAndStrides bigTarget = makeBig(); + SizesAndStrides bigMoveFrom = makeBig(1); + + checkBig(bigTarget); + checkBig(bigMoveFrom, 1); + + bigTarget = std::move(bigMoveFrom); + + checkBig(bigTarget, 1); + EXPECT_EQ(bigMoveFrom.size(), 0); +} + +TEST(SizesAndStridesTest, CopyAssignmentBigToSmall) { + SizesAndStrides smallTarget = makeSmall(); + SizesAndStrides bigCopyFrom = makeBig(); + + checkSmall(smallTarget); + checkBig(bigCopyFrom); + + smallTarget = bigCopyFrom; + + checkBig(smallTarget); + checkBig(bigCopyFrom); +} + +TEST(SizesAndStridesTest, MoveAssignmentBigToSmall) { + SizesAndStrides smallTarget = makeSmall(); + SizesAndStrides bigMoveFrom = makeBig(); + + checkSmall(smallTarget); + checkBig(bigMoveFrom); + + smallTarget = std::move(bigMoveFrom); + + checkBig(smallTarget); + EXPECT_EQ(bigMoveFrom.size(), 0); +} + +TEST(SizesAndStridesTest, CopyAssignmentSelf) { + SizesAndStrides small = makeSmall(); + SizesAndStrides big = makeBig(); + + checkSmall(small); + checkBig(big); + + small = small; + checkSmall(small); + + big = big; + checkBig(big); +} + +// Avoid failures due to -Wall -Wself-move. +static void selfMove(SizesAndStrides& x, SizesAndStrides& y) { + x = std::move(y); +} + +TEST(SizesAndStridesTest, MoveAssignmentSelf) { + SizesAndStrides small = makeSmall(); + SizesAndStrides big = makeBig(); + + checkSmall(small); + checkBig(big); + + selfMove(small, small); + checkSmall(small); + + selfMove(big, big); + checkBig(big); +} From b73c01859863892476b65aef071064633a97ed5d Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 7 Jan 2021 20:54:20 -0800 Subject: [PATCH 30/54] [PyTorch] Change representation of SizesAndStrides (#47508) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47508 This moves SizesAndStrides to a specialized representation that is 5 words smaller in the common case of tensor rank 5 or less. ghstack-source-id: 119313560 Test Plan: SizesAndStridesTest added in previous diff passes under ASAN + UBSAN. Run framework overhead benchmarks. Looks more or less neutral. Reviewed By: ezyang Differential Revision: D24772023 fbshipit-source-id: 0a75fd6c2daabb0769e2f803e80e2d6831871316 --- c10/core/TensorImpl.h | 29 +-- c10/core/impl/SizesAndStrides.cpp | 66 ++++++ c10/core/impl/SizesAndStrides.h | 247 ++++++++++++++++---- c10/test/core/impl/SizesAndStrides_test.cpp | 16 +- 4 files changed, 296 insertions(+), 62 deletions(-) create mode 100644 c10/core/impl/SizesAndStrides.cpp diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 41d4ff4c7f08..16f7b1d90ec4 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1848,22 +1848,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // autograd metadata pointer // version counter pointer // PyObject pointer -// sizes SmallVector (begin) -// sizes SmallVector (end) -// sizes SmallVector (capacity) -// sizes SmallVector (pre-allocated 0) -// sizes SmallVector (pre-allocated 1) -// sizes SmallVector (pre-allocated 2) -// sizes SmallVector (pre-allocated 3) -// sizes SmallVector (pre-allocated 4) -// strides SmallVector (begin) -// strides SmallVector (end) -// strides SmallVector (capacity) -// strides SmallVector (pre-allocated 0) -// strides SmallVector (pre-allocated 1) -// strides SmallVector (pre-allocated 2) -// strides SmallVector (pre-allocated 3) -// strides SmallVector (pre-allocated 4) +// SizesAndStrides size/pointer +// SizesAndStrides sizes (pre-allocated 0) +// SizesAndStrides sizes (pre-allocated 1) +// SizesAndStrides sizes (pre-allocated 2) +// SizesAndStrides sizes (pre-allocated 3) +// SizesAndStrides sizes (pre-allocated 4) +// SizesAndStrides strides (pre-allocated 0) +// SizesAndStrides strides (pre-allocated 1) +// SizesAndStrides strides (pre-allocated 2) +// SizesAndStrides strides (pre-allocated 3) +// SizesAndStrides strides (pre-allocated 4) // storage offset // numel // data type @@ -1872,7 +1867,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // miscellaneous bitfield // static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit... - sizeof(TensorImpl) == sizeof(int64_t) * 29, + sizeof(TensorImpl) == sizeof(int64_t) * 24, "You changed the size of TensorImpl on 64-bit arch." "See Note [TensorImpl size constraints] on how to proceed."); } // namespace c10 diff --git a/c10/core/impl/SizesAndStrides.cpp b/c10/core/impl/SizesAndStrides.cpp new file mode 100644 index 000000000000..bf7ec3ff887d --- /dev/null +++ b/c10/core/impl/SizesAndStrides.cpp @@ -0,0 +1,66 @@ +#include + +namespace c10 { +namespace impl { + +void SizesAndStrides::resizeSlowPath(const size_t newSize, const size_t oldSize) { + if (newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline(), "resizeSlowPath called when fast path should have been hit!"); + int64_t* tempStorage = outOfLineStorage_; + memcpy( + &inlineStorage_[0], + &tempStorage[0], + C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0])); + memcpy( + &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], + &tempStorage[oldSize], + C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0])); + // CANNOT USE freeOutOfLineStorage() HERE! outOfLineStorage_ + // HAS BEEN OVERWRITTEN! + free(tempStorage); + } else { + if (isInline()) { + // CANNOT USE allocateOutOfLineStorage(newSize) HERE! WOULD + // OVERWRITE inlineStorage_! + int64_t* tempStorage = static_cast(malloc(storageBytes(newSize))); + TORCH_CHECK(tempStorage, "Could not allocate memory to change Tensor SizesAndStrides!"); + const auto bytesToCopy = oldSize * sizeof(inlineStorage_[0]); + const auto bytesToZero = (newSize > oldSize) ? (newSize - oldSize) * sizeof(tempStorage[0]) : 0; + memcpy(&tempStorage[0], &inlineStorage_[0], bytesToCopy); + if (bytesToZero) { + memset(&tempStorage[oldSize], 0, bytesToZero); + } + memcpy(&tempStorage[newSize], &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], bytesToCopy); + if (bytesToZero) { + memset(&tempStorage[newSize + oldSize], 0, bytesToZero); + } + outOfLineStorage_ = tempStorage; + } else { + const bool isGrowing = oldSize < newSize; + if (isGrowing) { + // Resize before shifting so that we have room. + resizeOutOfLineStorage(newSize); + } + // Shift the old strides to their new starting point. Note + // that this does not occur in the inline path above because + // the stride starting point is not moving. + memmove( + outOfLineStorage_ + newSize, + outOfLineStorage_ + oldSize, + std::min(oldSize, newSize) * sizeof(outOfLineStorage_[0])); + if (!isGrowing) { + // Resize after shifting so that we don't lose data. + resizeOutOfLineStorage(newSize); + } else { + // Zero the end of the sizes portion. + const auto bytesToZero = (newSize - oldSize) * sizeof(outOfLineStorage_[0]); + memset(&outOfLineStorage_[oldSize], 0, bytesToZero); + memset(&outOfLineStorage_[newSize + oldSize], 0, bytesToZero); + } + } + } + size_ = newSize; +} + +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/SizesAndStrides.h b/c10/core/impl/SizesAndStrides.h index b50af160d29c..4f7e19330aca 100644 --- a/c10/core/impl/SizesAndStrides.h +++ b/c10/core/impl/SizesAndStrides.h @@ -1,15 +1,25 @@ #pragma once #include +#include #include #include #include +#define C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE 5 + namespace c10 { namespace impl { // Packed container for TensorImpl sizes and strides. +// This design improves on the previous approach of using a pair of +// c10::SmallVector by specializing for the operations we +// actually use and enforcing that the number of sizes is the same as +// the number of strides. The memory layout is as follows: +// +// 1 size_t for the size +// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer to out-of-line array class C10_API SizesAndStrides { public: // TODO: different iterator types for sizes & strides to prevent @@ -19,37 +29,120 @@ class C10_API SizesAndStrides { using strides_iterator = int64_t*; using strides_const_iterator = const int64_t*; - SizesAndStrides() : sizes_{0}, strides_{1} {} + SizesAndStrides() : size_(1) { + size_at_unchecked(0) = 0; + stride_at_unchecked(0) = 1; + } + + ~SizesAndStrides() { + if (C10_UNLIKELY(!isInline())) { + free(outOfLineStorage_); + } + } + + SizesAndStrides(const SizesAndStrides& rhs) : size_(rhs.size_) { + if (C10_LIKELY(rhs.isInline())) { + copyDataInline(rhs); + } else { + allocateOutOfLineStorage(size_); + copyDataOutline(rhs); + } + } + + SizesAndStrides& operator=(const SizesAndStrides& rhs) { + if (this == &rhs) { + return *this; + } + if (C10_LIKELY(rhs.isInline())) { + if (C10_UNLIKELY(!isInline())) { + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + if (isInline()) { + allocateOutOfLineStorage(rhs.size_); + } else { + resizeOutOfLineStorage(rhs.size_); + } + copyDataOutline(rhs); + } + size_ = rhs.size_; + return *this; + } + + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides(SizesAndStrides&& rhs) noexcept : size_(rhs.size_) { + if (C10_LIKELY(isInline())) { + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } else { + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + + rhs.size_ = 0; + } - size_t size() const { - return sizes_.size(); + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides& operator=(SizesAndStrides&& rhs) noexcept { + if (this == &rhs) { + return *this; + } + if (C10_LIKELY(rhs.isInline())) { + if (C10_UNLIKELY(!isInline())) { + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + // They're outline. We're going to steal their vector. + if (!isInline()) { + free(outOfLineStorage_); + } + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + size_ = rhs.size_; + rhs.size_ = 0; + + return *this; } - const int64_t* sizes_data() const { - return sizes_.data(); + size_t size() const noexcept { + return size_; } - int64_t* sizes_data() { - return sizes_.data(); + const int64_t* sizes_data() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } } - sizes_const_iterator sizes_begin() const { + int64_t* sizes_data() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + sizes_const_iterator sizes_begin() const noexcept { return sizes_data(); } - sizes_iterator sizes_begin() { + sizes_iterator sizes_begin() noexcept { return sizes_data(); } - sizes_const_iterator sizes_end() const { + sizes_const_iterator sizes_end() const noexcept { return sizes_begin() + size(); } - sizes_iterator sizes_end() { + sizes_iterator sizes_end() noexcept { return sizes_begin() + size(); } - IntArrayRef sizes_arrayref() const { + IntArrayRef sizes_arrayref() const noexcept { return IntArrayRef{sizes_data(), size()}; } @@ -58,76 +151,142 @@ class C10_API SizesAndStrides { std::copy(newSizes.begin(), newSizes.end(), sizes_begin()); } - const int64_t* strides_data() const { - return strides_.data(); + const int64_t* strides_data() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } } - int64_t* strides_data() { - return strides_.data(); + int64_t* strides_data() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } } - strides_const_iterator strides_begin() const { - return strides_data(); + strides_const_iterator strides_begin() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } } - strides_iterator strides_begin() { - return strides_data(); + strides_iterator strides_begin() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } } - strides_const_iterator strides_end() const { + strides_const_iterator strides_end() const noexcept { return strides_begin() + size(); } - strides_iterator strides_end() { + strides_iterator strides_end() noexcept { return strides_begin() + size(); } - IntArrayRef strides_arrayref() const { + IntArrayRef strides_arrayref() const noexcept { return IntArrayRef{strides_data(), size()}; } // Size accessors. - int64_t size_at(size_t idx) const { - return sizes_.at(idx); + int64_t size_at(size_t idx) const noexcept { + assert(idx < size()); + return sizes_data()[idx]; } - int64_t& size_at(size_t idx) { - return sizes_.at(idx); + int64_t& size_at(size_t idx) noexcept { + assert(idx < size()); + return sizes_data()[idx]; } - int64_t size_at_unchecked(size_t idx) const { - return sizes_[idx]; + int64_t size_at_unchecked(size_t idx) const noexcept { + return sizes_data()[idx]; } - int64_t& size_at_unchecked(size_t idx) { - return sizes_[idx]; + int64_t& size_at_unchecked(size_t idx) noexcept { + return sizes_data()[idx]; } // Size accessors. - int64_t stride_at(size_t idx) const { - return strides_.at(idx); + int64_t stride_at(size_t idx) const noexcept { + assert(idx < size()); + return strides_data()[idx]; } - int64_t& stride_at(size_t idx) { - return strides_.at(idx); + int64_t& stride_at(size_t idx) noexcept { + assert(idx < size()); + return strides_data()[idx]; } - int64_t stride_at_unchecked(size_t idx) const { - return strides_[idx]; + int64_t stride_at_unchecked(size_t idx) const noexcept { + return strides_data()[idx]; } - int64_t& stride_at_unchecked(size_t idx) { - return strides_[idx]; + int64_t& stride_at_unchecked(size_t idx) noexcept { + return strides_data()[idx]; } - void resize(size_t sz) { - sizes_.resize(sz); - strides_.resize(sz); + void resize(size_t newSize) { + const auto oldSize = size(); + if (newSize == oldSize) { + return; + } + if (C10_LIKELY(newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) { + if (oldSize < newSize) { + const auto bytesToZero = (newSize - oldSize) * sizeof(inlineStorage_[0]); + memset(&inlineStorage_[oldSize], 0, bytesToZero); + memset(&inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], 0, bytesToZero); + } + size_ = newSize; + } else { + resizeSlowPath(newSize, oldSize); + } } + void resizeSlowPath(size_t newSize, size_t oldSize); + private: - SmallVector sizes_; - SmallVector strides_; + bool isInline() const noexcept { + return size_ <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE; + } + + void copyDataInline(const SizesAndStrides& rhs) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.isInline()); + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } + + static size_t storageBytes(size_t size) noexcept { + return size * 2 * sizeof(int64_t); + } + + void allocateOutOfLineStorage(size_t size) { + outOfLineStorage_ = static_cast(malloc(storageBytes(size))); + TORCH_CHECK(outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!"); + } + + void resizeOutOfLineStorage(size_t newSize) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline()); + outOfLineStorage_ = static_cast(realloc(outOfLineStorage_, storageBytes(newSize))); + TORCH_CHECK(outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!"); + } + + void copyDataOutline(const SizesAndStrides& rhs) noexcept { + memcpy(outOfLineStorage_, rhs.outOfLineStorage_, storageBytes(rhs.size_)); + } + + size_t size_; + union { + int64_t *outOfLineStorage_; + int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{}; + }; + }; } // namespace impl diff --git a/c10/test/core/impl/SizesAndStrides_test.cpp b/c10/test/core/impl/SizesAndStrides_test.cpp index 8863442e81cb..94e90c42feff 100644 --- a/c10/test/core/impl/SizesAndStrides_test.cpp +++ b/c10/test/core/impl/SizesAndStrides_test.cpp @@ -70,6 +70,11 @@ TEST(SizesAndStridesTest, Resize) { sz.resize(4); checkData(sz, {1, 2, 3, 4}, {2, 4, 6, 8}); + // Small to small, growing back so that we can confirm that our "new" + // data really does get zeroed. + sz.resize(5); + checkData(sz, {1, 2, 3, 4, 0}, {2, 4, 6, 8, 0}); + // Small to big. sz.resize(6); @@ -90,10 +95,19 @@ TEST(SizesAndStridesTest, Resize) { checkData(sz, {1, 2, 3, 4, 0, 6, 0}, {2, 4, 6, 8, 0, 12, 0}); + sz.size_at_unchecked(6) = 11; + sz.stride_at_unchecked(6) = 22; + + checkData(sz, {1, 2, 3, 4, 0, 6, 11}, {2, 4, 6, 8, 0, 12, 22}); + // Big to big, shrinking. sz.resize(6); checkData(sz, {1, 2, 3, 4, 0, 6}, {2, 4, 6, 8, 0, 12}); + // Grow back to make sure "new" elements get zeroed in big mode too. + sz.resize(7); + checkData(sz, {1, 2, 3, 4, 0, 6, 0}, {2, 4, 6, 8, 0, 12, 0}); + // Finally, big to small. // Give it different data than it had when it was small to avoid @@ -104,7 +118,7 @@ TEST(SizesAndStridesTest, Resize) { sz.stride_at_unchecked(ii) = 2 * (ii - 1); } - checkData(sz, {-1, 0, 1, 2, 3, 4}, {-2, 0, 2, 4, 6, 8}); + checkData(sz, {-1, 0, 1, 2, 3, 4, 5}, {-2, 0, 2, 4, 6, 8, 10}); sz.resize(5); checkData(sz, {-1, 0, 1, 2, 3}, {-2, 0, 2, 4, 6}); From 5a63c452e638dad8e077887ad8d2c94ff0e23917 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 7 Jan 2021 21:15:59 -0800 Subject: [PATCH 31/54] Disable cuDNN persistent RNN on sm_86 devices (#49534) Summary: Excludes sm_86 GPU devices from using cuDNN persistent RNN. This is because there are some hard-to-detect edge cases that will throw exceptions with cudnn 8.0.5 on Nvidia A40 GPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/49534 Reviewed By: mruberry Differential Revision: D25632378 Pulled By: mrshenli fbshipit-source-id: cbe78236d85d4d0c2e4ca63a3fc2c4e2de662d9e --- aten/src/ATen/native/cudnn/RNN.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 1dc9d5cba945..8e1f254da9f8 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -722,6 +722,11 @@ namespace { (tensors.seq_length >=10 && bsize <=32)); } } else if (prop->major >= 8) { + if (prop->minor == 6) { + // Excludes sm_86 GPU devices from using persistent rnn. + // This is because there are some edge cases that will throw exceptions with cudnn 8.0.5 on Nvidia A40 GPU. + return false; + } // Based on tests by Vasily Volkov and xwang233. Vasily only tried bsize <= 128, // so conservatively enable persistence for bsize <= 128 only. // TODO: Run more tests for bsize > 128. From 294b7867eb7f015de6219e03ad201583c1efb42a Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Thu, 7 Jan 2021 21:25:40 -0800 Subject: [PATCH 32/54] Address clang-tidy warnings in ProcessGroupNCCL (#50131) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50131 Noticed that in the internal diff for https://github.com/pytorch/pytorch/pull/49069 there was a clang-tidy warning to use emplace instead of push_back. This can save us a copy as it eliminates the unnecessary in-place construction ghstack-source-id: 119560979 Test Plan: CI Reviewed By: pritamdamania87 Differential Revision: D25800134 fbshipit-source-id: 243e57318f5d6e43de524d4e5409893febe6164c --- torch/lib/c10d/ProcessGroupNCCL.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index b9ac5aa77150..f7c6658d715e 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -1413,7 +1413,7 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier( // Use user defined GPU device ids if provided if (!opts.device_ids.empty()) { for (auto device : opts.device_ids) { - devices.push_back(at::Device(at::DeviceType::CUDA, device)); + devices.emplace_back(at::DeviceType::CUDA, device); } } else if (usedDeviceIdxs_.empty()) { // This means there is not yet a NCCL collective being called @@ -1423,10 +1423,10 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier( // ensure that each process is on a different GPU auto numGPUs = at::cuda::getNumGPUs(); int16_t deviceIdx = static_cast(rank_ % numGPUs); - devices.push_back(at::Device(at::DeviceType::CUDA, deviceIdx)); + devices.emplace_back(at::DeviceType::CUDA, deviceIdx); } else { for (auto usedDeviceIdx : usedDeviceIdxs_) { - devices.push_back(at::Device(at::DeviceType::CUDA, usedDeviceIdx)); + devices.emplace_back(at::DeviceType::CUDA, usedDeviceIdx); } } From c215ffb6a2de1c85a62a146a9e282908335ba3ed Mon Sep 17 00:00:00 2001 From: Lucian Grijincu Date: Thu, 7 Jan 2021 22:05:30 -0800 Subject: [PATCH 33/54] Revert D25687465: [PyTorch] Devirtualize TensorImpl::dim() with macro Test Plan: revert-hammer Differential Revision: D25687465 (https://github.com/pytorch/pytorch/commit/4de6b279c8975409b554ecdb0fcf611839e55f9b) Original commit changeset: 89aabce165a5 fbshipit-source-id: fa5def17209d1691e68b1245fa0873fd03e88eaa --- aten/src/ATen/SparseTensorImpl.cpp | 4 ++++ aten/src/ATen/SparseTensorImpl.h | 1 + aten/src/ATen/test/undefined_tensor_test.cpp | 2 +- c10/core/TensorImpl.cpp | 2 -- c10/core/TensorImpl.h | 9 +-------- c10/core/UndefinedTensorImpl.cpp | 4 ++++ c10/core/UndefinedTensorImpl.h | 1 + 7 files changed, 12 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 0e18dca131a2..98670db11e86 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -69,6 +69,10 @@ void SparseTensorImpl::set_storage_offset(int64_t storage_offset) { AT_ERROR("sparse tensors do not have set_storage_offset"); } +int64_t SparseTensorImpl::dim() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sparse_dim_ + dense_dim_ == TensorImpl::dim()); + return sparse_dim_ + dense_dim_; +} bool SparseTensorImpl::has_storage() const { return false; } diff --git a/aten/src/ATen/SparseTensorImpl.h b/aten/src/ATen/SparseTensorImpl.h index 9daf21c15e56..5f502a6eaa54 100644 --- a/aten/src/ATen/SparseTensorImpl.h +++ b/aten/src/ATen/SparseTensorImpl.h @@ -47,6 +47,7 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { void set_stride(int64_t dim, int64_t new_stride) override; void set_storage_offset(int64_t storage_offset) override; + int64_t dim() const override; bool has_storage() const override; const Storage& storage() const override; int64_t storage_offset() const override; diff --git a/aten/src/ATen/test/undefined_tensor_test.cpp b/aten/src/ATen/test/undefined_tensor_test.cpp index 0102a8cf4f49..7a6dd50f9163 100644 --- a/aten/src/ATen/test/undefined_tensor_test.cpp +++ b/aten/src/ATen/test/undefined_tensor_test.cpp @@ -19,7 +19,7 @@ TEST(TestUndefined, UndefinedTest) { ASSERT_EQ(std::string("UndefinedType"), und.toString()); ASSERT_ANY_THROW(und.strides()); - ASSERT_EQ(und.dim(), 1); + ASSERT_ANY_THROW(und.dim()); ASSERT_ANY_THROW([]() { return Tensor(); }() = Scalar(5)); ASSERT_ANY_THROW(und.add(und)); ASSERT_ANY_THROW(und.add(ft)); diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index d79d38427d5a..98a2f74e9399 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -218,11 +218,9 @@ void TensorImpl::release_resources() { } } -#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY int64_t TensorImpl::dim() const { return sizes_and_strides_.size(); } -#endif int64_t TensorImpl::size(int64_t d) const { d = at::maybe_wrap_dim(d, dim(), false); diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 16f7b1d90ec4..47e1c865f999 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -405,14 +405,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * Return the number of dimensions of this tensor. Note that 0-dimension * represents a Tensor that is a Scalar, e.g., one that has a single element. */ - TENSORIMPL_MAYBE_VIRTUAL int64_t dim() const -#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY - { - return sizes_.size(); - } -#else - ; -#endif + virtual int64_t dim() const; /** * True if this tensor has storage. See storage() for details. diff --git a/c10/core/UndefinedTensorImpl.cpp b/c10/core/UndefinedTensorImpl.cpp index ab7f9d42c5c4..f79897c72e04 100644 --- a/c10/core/UndefinedTensorImpl.cpp +++ b/c10/core/UndefinedTensorImpl.cpp @@ -20,6 +20,10 @@ int64_t UndefinedTensorImpl::stride(int64_t d) const { AT_ERROR("stride(dim) called on an undefined Tensor"); } +int64_t UndefinedTensorImpl::dim() const { + AT_ERROR("dim() called on undefined Tensor"); +} + bool UndefinedTensorImpl::has_storage() const { AT_ERROR("has_storage() called on undefined Tensor"); } diff --git a/c10/core/UndefinedTensorImpl.h b/c10/core/UndefinedTensorImpl.h index efe1ed6a463e..26122ed305e2 100644 --- a/c10/core/UndefinedTensorImpl.h +++ b/c10/core/UndefinedTensorImpl.h @@ -21,6 +21,7 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl { IntArrayRef strides() const override; int64_t size(int64_t d) const override; int64_t stride(int64_t d) const override; + int64_t dim() const override; bool has_storage() const override; const Storage& storage() const override; int64_t storage_offset() const override; From fc2ead094471dad946f9f44e903b02aea5324021 Mon Sep 17 00:00:00 2001 From: Alban Desmaison Date: Fri, 8 Jan 2021 05:28:32 -0800 Subject: [PATCH 34/54] Autograd engine, only enqueue task when it is fully initialized (#50164) Summary: This solves a race condition where the worker thread might see a partially initialized graph_task Fixes https://github.com/pytorch/pytorch/issues/49652 I don't know how to reliably trigger the race so I didn't add any test. But the rocm build flakyness (it just happens to race more often on rocm builds) should disappear after this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/50164 Reviewed By: zou3519 Differential Revision: D25824954 Pulled By: albanD fbshipit-source-id: 6a3391753cb2afd2ab415d3fb2071a837cc565bb --- torch/csrc/autograd/engine.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index f906efb187ef..ad217c2924ad 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -916,7 +916,6 @@ std::shared_ptr Engine::execute_with_graph_task( std::unique_lock lock(graph_task->mutex_); auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device()); - queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); // worker_device == NO_DEVICE it's a CPU thread and it's trying to drive the // autograd engine with corresponding GraphTask, and its NOT a re-entrant call @@ -929,8 +928,12 @@ std::shared_ptr Engine::execute_with_graph_task( // set the graph_task owner to the current device graph_task->owner_ = worker_device; - // The owning thread start to drive the engine execution with the GraphTask - // that has already been pushed to the current CPU thread's ready_queue + // Now that all the non-thread safe fields of the graph_task have been populated, + // we can enqueue it. + queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); + + // The owning thread start to drive the engine execution for any CPU task that + // was just pushed or will be added later from other worker threads lock.unlock(); thread_main(graph_task); TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed()); @@ -943,6 +946,11 @@ std::shared_ptr Engine::execute_with_graph_task( // If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant // backward call from that device. graph_task->owner_ = worker_device; + + // Now that all the non-thread safe fields of the graph_task have been populated, + // we can enqueue it. + queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); + if (current_depth >= max_recursion_depth_) { // See Note [Reentrant backwards] // If reached the max depth, switch to a different thread From 9f832c8d3ea31666d21c87b5fe4792a68a545036 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 8 Jan 2021 06:28:05 -0800 Subject: [PATCH 35/54] [numpy] torch.exp: promote integer inputs to float (#50093) Summary: Reference: https://github.com/pytorch/pytorch/issues/42515 Pull Request resolved: https://github.com/pytorch/pytorch/pull/50093 Reviewed By: H-Huang Differential Revision: D25803549 Pulled By: mruberry fbshipit-source-id: e6f245b5e728f2dca6072f8c359f03dff63aa14d --- aten/src/ATen/native/UnaryOps.cpp | 4 ++-- aten/src/ATen/native/cuda/UnaryOpsKernel.cu | 2 +- test/test_torch.py | 3 --- test/test_unary_ufuncs.py | 1 - torch/csrc/jit/tensorexpr/kernel.cpp | 5 +++-- .../_internal/common_methods_invocations.py | 16 ++++++++++++++-- 6 files changed, 20 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 0f6da7e4292a..c636d5d94a2f 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -256,8 +256,8 @@ Tensor& ceil_out(Tensor& result, const Tensor& self) { Tensor ceil(const Tensor& self) { return unary_op_impl(self, at::ceil_out); } Tensor& ceil_(Tensor& self) { return unary_op_impl_(self, at::ceil_out); } -Tensor& exp_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, exp_stub); } -Tensor exp(const Tensor& self) { return unary_op_impl(self, at::exp_out); } +Tensor& exp_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, exp_stub); } +Tensor exp(const Tensor& self) { return unary_op_impl_float(self, exp_stub); } Tensor& exp_(Tensor& self) { return unary_op_impl_(self, at::exp_out); } Tensor& exp2_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, exp2_stub); } diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index f5e1a4e85a04..e727335aaf17 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -33,7 +33,7 @@ void bitwise_not_kernel_cuda(TensorIterator& iter) { } void exp_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exp_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "exp_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::exp(a); }); diff --git a/test/test_torch.py b/test/test_torch.py index 72fa853e2e7c..874a8a6ac9f6 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6909,9 +6909,6 @@ def inner(self, device, dtype): ('atanh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()), ('erf', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), ('erfc', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]), - ('exp', '', _small_3d, lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes()), - ('exp', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1), - lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), ('rad2deg', '', _small_3d, lambda t, d: [], 1e-1, 1e-0, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), ('deg2rad', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), ('reciprocal', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 960991a4820b..3c6a4f0e7b0a 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -1702,7 +1702,6 @@ def _medium_2d(dtype, device): # TODO: all these should be replaced with OpInfos torch_op_tests = [ - _TorchMathTestMeta('exp'), _TorchMathTestMeta('floor'), _TorchMathTestMeta('ceil'), _TorchMathTestMeta('rad2deg'), diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index e6e31ba4d96c..92cec735b3bb 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1081,8 +1081,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } break; case aten::exp: { - return computeOneOperand( - "aten_exp", v, [](const ExprHandle& a) { return exp(a); }); + return computeOneOperand("aten_exp", v, [](const ExprHandle& a) { + return exp(promoteIntegerToDefaultType(a)); + }); } break; case aten::expm1: { diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 119750396f1e..d1a62cafe25c 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -806,6 +806,20 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): SkipInfo('TestCommon', 'test_variant_consistency_jit', device_type='cuda', dtypes=[torch.float16]), )), + UnaryUfuncInfo('exp', + ref=np_unary_ufunc_integer_promotion_wrapper(np.exp), + dtypes=all_types_and_complex_and(torch.bool, torch.half), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/50093#pullrequestreview-561791547 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', dtypes=[torch.bfloat16]), + # Reference: https://github.com/pytorch/pytorch/issues/48010 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + ), + assert_autodiffed=True, + promotes_integers_to_float=True), SpectralFuncInfo('fft.fft', aten_name='fft_fft', ref=np.fft.fft, @@ -1602,8 +1616,6 @@ def method_tests(): ('expand', (), (dont_convert(()),), 'scalar_to_scalar'), ('expand', (), (1, 3, 2), 'scalar_to_dims', (False,)), ('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (False,)), - ('exp', (S, S, S), NO_ARGS, '', (True,)), - ('exp', (), NO_ARGS, 'scalar', (True,)), ('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, ''), ('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), (0.2,), 'eps'), ('logit', uniform_scalar().clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, 'scalar'), From 006cfebf3dcc4bc0cba344c25243baebaeeefbb7 Mon Sep 17 00:00:00 2001 From: Alban Desmaison Date: Fri, 8 Jan 2021 06:36:01 -0800 Subject: [PATCH 36/54] Update autograd related comments (#50166) Summary: Remove outdated comment and update to use new paths. Pull Request resolved: https://github.com/pytorch/pytorch/pull/50166 Reviewed By: zou3519 Differential Revision: D25824942 Pulled By: albanD fbshipit-source-id: 7dc694891409e80e1804eddcdcc50cc21b60f822 --- c10/core/TensorImpl.h | 12 ------------ tools/autograd/derivatives.yaml | 2 +- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 47e1c865f999..d8b803f906f7 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -582,9 +582,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * Set whether or not a tensor requires gradient. - * - * It is only valid to call this method on a Variable. - * See Note [Tensor versus Variable in C++]. */ void set_requires_grad(bool requires_grad); @@ -594,27 +591,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * we can automatically differentiate back to them. A tensor that * requires gradient and has no history is a "leaf" tensor, which we * accumulate gradients into. - * - * It is only valid to call this method on a Variable. - * See Note [Tensor versus Variable in C++]. */ bool requires_grad() const; /** * Return a mutable reference to the gradient. This is conventionally * used as `t.grad() = x` to set a gradient to a completely new tensor. - * - * It is only valid to call this method on a Variable. - * See Note [Tensor versus Variable in C++]. */ at::Tensor& mutable_grad(); /** * Return the accumulated gradient of a tensor. This gradient is written * into when performing backwards, when this tensor is a leaf tensor. - * - * It is only valid to call this method on a Variable. - * See Note [Tensor versus Variable in C++]. */ const at::Tensor& grad() const; diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 9bf266da394d..5bd0451545c9 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -86,7 +86,7 @@ # e.g., it is used by _cudnn_rnn # # If you need a complex expression, e.g., with local variables, -# write a _backward function in tools/autograd/templates/Functions.cpp +# write a _backward function in torch/csrc/autograd/FunctionsManual.cpp # and invoke it from here. By the way, go read # https://github.com/zdevito/ATen/issues/163; this describes an # important hazard that occurs when porting backwards from Python to C++ From 5c5abd591d0b452ab8d4d37115fbc37cc8c984c7 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 8 Jan 2021 06:40:29 -0800 Subject: [PATCH 37/54] Implement torch.linalg.svd (#45562) Summary: This is related to https://github.com/pytorch/pytorch/issues/42666 . I am opening this PR to have the opportunity to discuss things. First, we need to consider the differences between `torch.svd` and `numpy.linalg.svd`: 1. `torch.svd` takes `some=True`, while `numpy.linalg.svd` takes `full_matrices=True`, which is effectively the opposite (and with the opposite default, too!) 2. `torch.svd` returns `(U, S, V)`, while `numpy.linalg.svd` returns `(U, S, VT)` (i.e., V transposed). 3. `torch.svd` always returns a 3-tuple; `numpy.linalg.svd` returns only `S` in case `compute_uv==False` 4. `numpy.linalg.svd` also takes an optional `hermitian=False` argument. I think that the plan is to eventually deprecate `torch.svd` in favor of `torch.linalg.svd`, so this PR does the following: 1. Rename/adapt the old `svd` C++ functions into `linalg_svd`: in particular, now `linalg_svd` takes `full_matrices` and returns `VT` 2. Re-implement the old C++ interface on top of the new (by negating `full_matrices` and transposing `VT`). 3. The C++ version of `linalg_svd` *always* returns a 3-tuple (we can't do anything else). So, there is a python wrapper which manually calls `torch._C._linalg.linalg_svd` to tweak the return value in case `compute_uv==False`. Currently, `linalg_svd_backward` is broken because it has not been adapted yet after the `V ==> VT` change, but before continuing and spending more time on it I wanted to make sure that the general approach is fine. Pull Request resolved: https://github.com/pytorch/pytorch/pull/45562 Reviewed By: H-Huang Differential Revision: D25803557 Pulled By: mruberry fbshipit-source-id: 4966f314a0ba2ee391bab5cda4563e16275ce91f --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 61 ++- aten/src/ATen/native/LinearAlgebraUtils.h | 9 +- .../ATen/native/cuda/BatchLinearAlgebra.cu | 4 +- aten/src/ATen/native/native_functions.yaml | 15 +- docs/source/linalg.rst | 1 + .../check_backward_compatibility.py | 1 + test/test_linalg.py | 465 ++++++++++-------- test/test_namedtuple_return_api.py | 47 +- tools/autograd/derivatives.yaml | 2 +- tools/autograd/gen_variable_type.py | 2 +- torch/_torch_docs.py | 72 ++- torch/linalg/__init__.py | 87 ++++ torch/overrides.py | 1 + .../_internal/common_methods_invocations.py | 46 +- 14 files changed, 558 insertions(+), 255 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index bd59fe7d28b9..f0b36d0fdbac 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1411,7 +1411,7 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some if (compute_uv) { if (some) { - VT_working_copy = VT_working_copy.narrow(-1, 0, k); + VT_working_copy = VT_working_copy.narrow(-2, 0, k); } } else { VT_working_copy.zero_(); @@ -1421,24 +1421,71 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some U_working_copy.zero_(); VT_working_copy.zero_(); } + // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. + VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } std::tuple svd(const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); return at::_svd_helper(self, some, compute_uv); } -std::tuple svd_out(Tensor& U, Tensor& S, Tensor& VT, +std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); - Tensor U_tmp, S_tmp, VT_tmp; - std::tie(U_tmp, S_tmp, VT_tmp) = at::_svd_helper(self, some, compute_uv); + "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + Tensor U_tmp, S_tmp, V_tmp; + std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv); U.resize_as_(U_tmp).copy_(U_tmp); S.resize_as_(S_tmp).copy_(S_tmp); - VT.resize_as_(VT_tmp).copy_(VT_tmp); + V.resize_as_(V_tmp).copy_(V_tmp); + return std::tuple(U, S, V); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/* torch.linalg.svd, implemented in terms of torch.svd. There are two main + differences: + + 1. the 2nd parameter is bool some=True, which if effectively the opposite + of full_matrices=True + + 2. svd returns V, while linalg.svd returns VT. To accommodate the + difference, we transpose() V upon return +*/ + +std::tuple linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) { + TORCH_CHECK(self.dim() >= 2, + "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + + bool some = !full_matrices; + Tensor U, S, V; + std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv); + if (compute_uv) { + Tensor VT = V.transpose(-2, -1); + return std::make_tuple(U, S, VT); + } else { + Tensor empty_U = at::empty({0}, self.options()); + Tensor empty_VT = at::empty({0}, self.options()); + return std::make_tuple(empty_U, S, empty_VT); + } +} + +static void svd_resize_and_copy(const char *name, const Tensor& src, Tensor &dst) { + TORCH_CHECK(src.device() == dst.device(), "svd output tensor ", name, " is on the wrong device: expected ", src.device(), " got ", dst.device()); + at::native::resize_output(dst, src.sizes()); + dst.copy_(src); +} + +std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tensor& VT, + const Tensor& self, bool full_matrices, bool compute_uv) { + Tensor U_tmp, S_tmp, VT_tmp; + std::tie(U_tmp, S_tmp, VT_tmp) = at::linalg_svd(self, full_matrices, compute_uv); + svd_resize_and_copy("U", U_tmp, U); + svd_resize_and_copy("S", S_tmp, S); + svd_resize_and_copy("V", VT_tmp, VT); return std::tuple(U, S, VT); } diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 607e201ebe8d..e97637d4c5e4 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -261,18 +261,21 @@ static inline std::tuple _create_U_S_VT(const Tensor& in U_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU)); } + // VT should be a column-major or a batch of column-major matrices sizes[input.dim() - 2] = n; sizes[input.dim() - 1] = n; - // VT should be a row-major or a batch of row-major matrices + strides = at::detail::defaultStrides(sizes); + strides[input.dim() - 1] = n; + strides[input.dim() - 2] = 1; Tensor VT_empty; if (!input.is_cuda()) { - VT_empty = at::empty(sizes, input.options()); + VT_empty = at::empty_strided(sizes, strides, input.options()); } else { // NB: VT_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd // (which is the driver routine for the divide and conquer SVD operation) // takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that // moves the inputs between devices internally. - VT_empty = at::empty(sizes, input.options().device(at::kCPU)); + VT_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU)); } sizes.pop_back(); diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 3fbd693d17b1..379847d76ff4 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -2194,7 +2194,7 @@ std::tuple _svd_helper_cuda(const Tensor& self, bool som if (compute_uv) { if (some) { - VT_working_copy = VT_working_copy.narrow(-1, 0, k); + VT_working_copy = VT_working_copy.narrow(-2, 0, k); } } else { VT_working_copy.zero_(); @@ -2205,6 +2205,8 @@ std::tuple _svd_helper_cuda(const Tensor& self, bool som S_working_copy = same_stride_to(S_working_copy, S_working_copy.options().device(self.device())); VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_(); } + // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. + VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b474d435398c..e8e3efa307f8 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5820,14 +5820,14 @@ - func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: - DefaultBackend: svd_out + Math: svd_out - func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) variants: method, function dispatch: - DefaultBackend: svd + Math: svd -- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor, Tensor, Tensor) +- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V) variants: function dispatch: CPU: _svd_helper_cpu @@ -8962,6 +8962,15 @@ python_module: linalg variants: function +- func: linalg_svd.U(Tensor self, bool full_matrices=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + +- func: linalg_svd(Tensor self, bool full_matrices=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) + python_module: linalg + use_c10_dispatcher: full + variants: function + - func: linalg_cond(Tensor self, Scalar? p=None) -> Tensor python_module: linalg variants: function diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index d6de2373ad57..991205688df3 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -19,6 +19,7 @@ Functions .. autofunction:: eigvalsh .. autofunction:: matrix_rank .. autofunction:: norm +.. autofunction:: svd .. autofunction:: solve .. autofunction:: tensorinv .. autofunction:: tensorsolve diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index 2d5d50096c81..4332916fef6b 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -37,6 +37,7 @@ ("aten::ifft", datetime.date(2021, 1, 31)), ("aten::irfft", datetime.date(2021, 1, 31)), ("aten::rfft", datetime.date(2021, 1, 31)), + ("aten::_svd_helper", datetime.date(2021, 1, 31)), ("aten::_cudnn_rnn_flatten_weight", datetime.date(2020, 12, 31)), ("aten::_cudnn_rnn", datetime.date(2020, 12, 31)), ("aten::_cudnn_rnn_backward", datetime.date(2020, 12, 31)), diff --git a/test/test_linalg.py b/test/test_linalg.py index bab73f987905..5f3a12844ef8 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -11,6 +11,7 @@ from math import inf, nan, isnan import random from random import randrange +from itertools import product from functools import reduce from torch.testing._internal.common_utils import \ @@ -1864,6 +1865,277 @@ def test_nuclear_norm_exceptions_old(self, device): self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) + # ~~~ tests for torch.svd ~~~ + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_svd(self, device, dtype): + def run_test(dims, some, compute_uv): + x = torch.randn(*dims, dtype=dtype, device=device) + outu = torch.empty(0, dtype=dtype, device=device) + outs = torch.empty(0, dtype=dtype, device=device) + outv = torch.empty(0, dtype=dtype, device=device) + torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) + + if compute_uv: + if some: + x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + narrow_u = outu[..., :min(*dims[-2:])] + narrow_v = outv[..., :min(*dims[-2:])] + x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + _, singvals, _ = torch.svd(x, compute_uv=True) + self.assertEqual(singvals, outs, msg='Singular values mismatch') + self.assertEqual(outu, torch.zeros_like(outu), msg='U not zero') + self.assertEqual(outv, torch.zeros_like(outv), msg='V not zero') + + resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) + self.assertEqual(resu, outu, msg='outputs of svd and svd with out differ') + self.assertEqual(ress, outs, msg='outputs of svd and svd with out differ') + self.assertEqual(resv, outv, msg='outputs of svd and svd with out differ') + + # test non-contiguous + x = torch.randn(*dims, dtype=dtype, device=device) + n_dim = len(dims) + # Reverse the batch dimensions and the matrix dimensions and then concat them + x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) + assert not x.is_contiguous(), "x is intentionally non-contiguous" + resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) + if compute_uv: + if some: + x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + narrow_u = resu[..., :min(*dims[-2:])] + narrow_v = resv[..., :min(*dims[-2:])] + x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + _, singvals, _ = torch.svd(x, compute_uv=True) + self.assertEqual(singvals, ress, msg='Singular values mismatch') + self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero') + self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero') + + shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices + (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices + (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices + for dims, some, compute_uv in product(shapes, [True, False], [True, False]): + run_test(dims, some, compute_uv) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float) + def test_svd_no_singularvectors(self, device, dtype): + for size in [(5, 5), (5, 20), (20, 5)]: + a = torch.randn(*size, device=device, dtype=dtype) + u, s_expect, v = torch.svd(a) + u, s_actual, v = torch.svd(a, compute_uv=False) + self.assertEqual(s_expect, s_actual, msg="Singular values don't match") + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_svd_lowrank(self, device, dtype): + from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix + + def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): + density = options.pop('density', 1) + if isinstance(matrix_size, int): + rows = columns = matrix_size + else: + rows, columns = matrix_size + if density == 1: + a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) + a = a_input + else: + assert batches == () + a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) + a = a_input.to_dense() + + q = min(*size) + u, s, v = svd_lowrank(a_input, q=q, **options) + + # check if u, s, v is a SVD + u, s, v = u[..., :q], s[..., :q], v[..., :q] + A = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) + self.assertEqual(A, a) + + # check if svd_lowrank produces same singular values as torch.svd + U, S, V = torch.svd(a) + self.assertEqual(s.shape, S.shape) + self.assertEqual(u.shape, U.shape) + self.assertEqual(v.shape, V.shape) + self.assertEqual(s, S) + + if density == 1: + # actual_rank is known only for dense inputs + # + # check if pairs (u, U) and (v, V) span the same + # subspaces, respectively + u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] + U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] + self.assertEqual(u.transpose(-2, -1).matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) + self.assertEqual(v.transpose(-2, -1).matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) + + all_batches = [(), (1,), (3,), (2, 3)] + for actual_rank, size, all_batches in [ + (2, (17, 4), all_batches), + (4, (17, 4), all_batches), + (4, (17, 17), all_batches), + (10, (100, 40), all_batches), + (7, (1000, 1000), [()]), + ]: + # dense input + for batches in all_batches: + run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) + if size != size[::-1]: + run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) + + # sparse input + for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]: + for density in [0.005, 0.1]: + run_subtest(None, size, (), device, torch.svd_lowrank, density=density) + + # jitting support + jitted = torch.jit.script(torch.svd_lowrank) + actual_rank, size, batches = 2, (17, 4), () + run_subtest(actual_rank, size, batches, device, jitted) + + @onlyCPU + @skipCPUIfNoLapack + @dtypes(torch.cfloat) + def test_svd_complex(self, device, dtype): + t = torch.randn((10, 10), dtype=dtype, device=device) + U, S, V = torch.svd(t, some=False) + # note: from the math point of view, it is weird that we need to use + # V.T instead of V.T.conj(): torch.svd has a buggy behavior for + # complex numbers and it's deprecated. You should use torch.linalg.svd + # instead. + t2 = U @ torch.diag(S).type(dtype) @ V.T + self.assertEqual(t, t2) + + def _test_svd_helper(self, shape, some, col_maj, device, dtype): + cpu_tensor = torch.randn(shape, device='cpu').to(dtype) + device_tensor = cpu_tensor.to(device=device) + if col_maj: + cpu_tensor = cpu_tensor.t() + device_tensor = device_tensor.t() + cpu_result = torch.svd(cpu_tensor, some=some) + device_result = torch.svd(device_tensor, some=some) + m = min(cpu_tensor.shape[-2:]) + # torch.svd returns torch.return_types.svd which is a tuple of (U, V, S). + # - When some==False, U[..., m:] can be arbitrary. + # - When some==True, U shape: [..., m], V shape: [m, m] + # - Signs are not deterministic. If the sign of a column of U is changed + # then the corresponding column of the V has to be changed. + # Thus here we only compare result[..., :m].abs() from CPU and device. + for x, y in zip(cpu_result, device_result): + self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_and_complex_types()) + def test_svd_square(self, device, dtype): + self._test_svd_helper((10, 10), True, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_square_col_maj(self, device, dtype): + self._test_svd_helper((10, 10), True, True, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_some(self, device, dtype): + self._test_svd_helper((20, 5), True, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_all(self, device, dtype): + self._test_svd_helper((20, 5), False, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_some_col_maj(self, device, dtype): + self._test_svd_helper((5, 20), True, True, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_all_col_maj(self, device, dtype): + self._test_svd_helper((5, 20), False, True, device, dtype) + + # ~~~ tests for torch.linalg.svd ~~~ + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_linalg_svd_compute_uv(self, device, dtype): + """ + Test the default case, compute_uv=True. Here we have the very same behavior as + numpy + """ + t = torch.randn((10, 11), device=device, dtype=dtype) + np_t = t.cpu().numpy() + for full_matrices in (True, False): + # check linalg.svd vs numpy + expected = np.linalg.svd(np_t, full_matrices, compute_uv=True) + actual = torch.linalg.svd(t, full_matrices, compute_uv=True) + self.assertEqual(actual, expected) + # check linalg.svd vs linalg.svd(out=...) + out = (torch.empty_like(actual[0]), + torch.empty_like(actual[1]), + torch.empty_like(actual[2])) + out2 = torch.linalg.svd(t, full_matrices, compute_uv=True, out=out) + self.assertEqual(actual, out) + self.assertEqual(actual, out2) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_linalg_svd_no_compute_uv(self, device, dtype): + """ + Test the compute_uv=False case. Here we have a different return type than + numpy: numpy returns S, we return (empty, S, empty) + """ + t = torch.randn((10, 11), device=device, dtype=dtype) + np_t = t.cpu().numpy() + + def is_empty(x): + return x.numel() == 0 and x.dtype == t.dtype and x.device == t.device + + for full_matrices in (True, False): + # check linalg.svd vs numpy + np_s = np.linalg.svd(np_t, full_matrices, compute_uv=False) + USV = torch.linalg.svd(t, full_matrices, compute_uv=False) + assert is_empty(USV.U) + self.assertEqual(USV.S, np_s) + assert is_empty(USV.V) + # check linalg.svd vs linalg.svd(out=...) + out = (torch.empty_like(USV.U), torch.empty_like(USV.S), torch.empty_like(USV.V)) + USV = torch.linalg.svd(t, full_matrices, compute_uv=False, out=out) + assert USV.U is out[0] + assert USV.S is out[1] + assert USV.V is out[2] + self.assertEqual(USV.S, np_s) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @onlyCUDA + @dtypes(torch.float) + def test_linalg_svd_out_different_device(self, device, dtype): + t = torch.randn(5, 7, device=device, dtype=dtype) # this is on cuda + u = torch.empty((5, 5), device='cpu', dtype=dtype) + s = torch.empty((5,), device='cpu', dtype=dtype) + v = torch.empty((7, 7), device='cpu', dtype=dtype) + with self.assertRaisesRegex(RuntimeError, 'svd output tensor U is on the wrong device: expected cuda:.* got cpu'): + torch.linalg.svd(t, out=(u, s, v)) + def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype): from torch.testing._internal.common_utils import random_hermitian_pd_matrix @@ -4601,60 +4873,6 @@ def test_solve_methods_arg_device(self, device): "Expected LU_pivots and LU_data to be on the same device"): torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=b_device).int()) - def _test_svd_helper(self, shape, some, col_maj, device, dtype): - cpu_tensor = torch.randn(shape, device='cpu').to(dtype) - device_tensor = cpu_tensor.to(device=device) - if col_maj: - cpu_tensor = cpu_tensor.t() - device_tensor = device_tensor.t() - cpu_result = torch.svd(cpu_tensor, some=some) - device_result = torch.svd(device_tensor, some=some) - m = min(cpu_tensor.shape[-2:]) - # torch.svd returns torch.return_types.svd which is a tuple of (U, V, S). - # - When some==False, U[..., m:] can be arbitrary. - # - When some==True, U shape: [..., m], V shape: [m, m] - # - Signs are not deterministic. If the sign of a column of U is changed - # then the corresponding column of the V has to be changed. - # Thus here we only compare result[..., :m].abs() from CPU and device. - for x, y in zip(cpu_result, device_result): - self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_and_complex_types()) - def test_svd_square(self, device, dtype): - self._test_svd_helper((10, 10), True, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_square_col_maj(self, device, dtype): - self._test_svd_helper((10, 10), True, True, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_some(self, device, dtype): - self._test_svd_helper((20, 5), True, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_all(self, device, dtype): - self._test_svd_helper((20, 5), False, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_some_col_maj(self, device, dtype): - self._test_svd_helper((5, 20), True, True, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_all_col_maj(self, device, dtype): - self._test_svd_helper((5, 20), False, True, device, dtype) - @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3}) @skipCUDAIfNoMagma @skipCPUIfNoLapack @@ -5596,145 +5814,6 @@ def run_test(dims, eigenvectors, upper): for batch_dims, eigenvectors, upper in itertools.product(batch_dims_set, (True, False), (True, False)): run_test((5,) + batch_dims, eigenvectors, upper) - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_svd(self, device, dtype): - def run_test(dims, some, compute_uv): - x = torch.randn(*dims, dtype=dtype, device=device) - outu = torch.tensor((), dtype=dtype, device=device) - outs = torch.tensor((), dtype=dtype, device=device) - outv = torch.tensor((), dtype=dtype, device=device) - torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) - - if compute_uv: - if some: - x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = outu[..., :min(*dims[-2:])] - narrow_v = outv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, outs, msg='Singular values mismatch') - self.assertEqual(outu, torch.zeros_like(outu), msg='U not zero') - self.assertEqual(outv, torch.zeros_like(outv), msg='V not zero') - - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - self.assertEqual(resu, outu, msg='outputs of svd and svd with out differ') - self.assertEqual(ress, outs, msg='outputs of svd and svd with out differ') - self.assertEqual(resv, outv, msg='outputs of svd and svd with out differ') - - # test non-contiguous - x = torch.randn(*dims, dtype=dtype, device=device) - n_dim = len(dims) - # Reverse the batch dimensions and the matrix dimensions and then concat them - x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) - assert not x.is_contiguous(), "x is intentionally non-contiguous" - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - if compute_uv: - if some: - x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = resu[..., :min(*dims[-2:])] - narrow_v = resv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, ress, msg='Singular values mismatch') - self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero') - self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero') - - shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices - (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices - (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices - for dims, some, compute_uv in itertools.product(shapes, [True, False], [True, False]): - run_test(dims, some, compute_uv) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_svd_no_singularvectors(self, device): - for size in [(5, 5), (5, 20), (20, 5)]: - a = torch.randn(*size, device=device) - u, s_expect, v = torch.svd(a) - u, s_actual, v = torch.svd(a, compute_uv=False) - self.assertEqual(s_expect, s_actual, msg="Singular values don't match") - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_svd_lowrank(self, device): - import torch - from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix - - dtype = torch.double - - def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): - density = options.pop('density', 1) - if isinstance(matrix_size, int): - rows = columns = matrix_size - else: - rows, columns = matrix_size - if density == 1: - a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) - a = a_input - else: - assert batches == () - a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) - a = a_input.to_dense() - - q = min(*size) - u, s, v = svd_lowrank(a_input, q=q, **options) - - # check if u, s, v is a SVD - u, s, v = u[..., :q], s[..., :q], v[..., :q] - A = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) - self.assertEqual(A, a) - - # check if svd_lowrank produces same singular values as torch.svd - U, S, V = torch.svd(a) - self.assertEqual(s.shape, S.shape) - self.assertEqual(u.shape, U.shape) - self.assertEqual(v.shape, V.shape) - self.assertEqual(s, S) - - if density == 1: - # actual_rank is known only for dense inputs - # - # check if pairs (u, U) and (v, V) span the same - # subspaces, respectively - u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] - U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] - self.assertEqual(u.transpose(-2, -1).matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - self.assertEqual(v.transpose(-2, -1).matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - - all_batches = [(), (1,), (3,), (2, 3)] - for actual_rank, size, all_batches in [ - (2, (17, 4), all_batches), - (4, (17, 4), all_batches), - (4, (17, 17), all_batches), - (10, (100, 40), all_batches), - (7, (1000, 1000), [()]), - ]: - # dense input - for batches in all_batches: - run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) - if size != size[::-1]: - run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) - - # sparse input - for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]: - for density in [0.005, 0.1]: - run_subtest(None, size, (), device, torch.svd_lowrank, density=density) - - # jitting support - jitted = torch.jit.script(torch.svd_lowrank) - actual_rank, size, batches = 2, (17, 4), () - run_subtest(actual_rank, size, batches, device, jitted) - @skipCUDAIfNoMagma @skipCPUIfNoLapack def test_pca_lowrank(self, device): diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index 1906b83ca8d6..a5d8b8179207 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -13,6 +13,7 @@ 'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', 'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq', 'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "unpack_dual", 'linalg_qr', + '_svd_helper', 'linalg_svd', } @@ -56,7 +57,7 @@ def test_namedtuple_return(self): names=('values', 'indices'), hasout=True), op(operators=['kthvalue'], input=(1, 0), names=('values', 'indices'), hasout=True), - op(operators=['svd'], input=(), names=('U', 'S', 'V'), hasout=True), + op(operators=['svd', '_svd_helper', 'linalg_svd'], input=(), names=('U', 'S', 'V'), hasout=True), op(operators=['slogdet'], input=(), names=('sign', 'logabsdet'), hasout=False), op(operators=['qr', 'linalg_qr'], input=(), names=('Q', 'R'), hasout=True), op(operators=['solve'], input=(a,), names=('solution', 'LU'), hasout=True), @@ -65,26 +66,38 @@ def test_namedtuple_return(self): op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True), op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True), op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True), - op(operators=['unpack_dual'], input=(a, 0), names=('primal', 'tangent'), hasout=False), + op(operators=['unpack_dual'], input=(0,), names=('primal', 'tangent'), hasout=False), ] + def get_func(f): + "Return either torch.f or torch.linalg.f, where 'f' is a string" + if f.startswith('linalg_'): + return getattr(torch.linalg, f[7:]) + return getattr(torch, f, None) + + def check_namedtuple(tup, names): + "Check that the namedtuple 'tup' has the given names" + for i, name in enumerate(names): + self.assertIs(getattr(tup, name), tup[i]) + for op in operators: for f in op.operators: - if 'linalg_' in f: - ret = getattr(torch.linalg, f[7:])(a, *op.input) - ret1 = getattr(torch.linalg, f[7:])(a, *op.input, out=tuple(ret)) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) - else: - # Handle op that are not methods - func = getattr(a, f) if hasattr(a, f) else getattr(torch, f) - ret = func(*op.input) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) - if op.hasout: - ret1 = getattr(torch, f)(a, *op.input, out=tuple(ret)) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) + # 1. check the namedtuple returned by calling torch.f + func = get_func(f) + if func: + ret1 = func(a, *op.input) + check_namedtuple(ret1, op.names) + # + # 2. check the out= variant, if it exists + if func and op.hasout: + ret2 = func(a, *op.input, out=tuple(ret1)) + check_namedtuple(ret2, op.names) + # + # 3. check the Tensor.f method, if it exists + meth = getattr(a, f, None) + if meth: + ret3 = meth(*op.input) + check_namedtuple(ret3, op.names) all_covered_operators = set([x for y in operators for x in y.operators]) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 5bd0451545c9..fbeec2a5eaba 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1055,7 +1055,7 @@ - name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) -- name: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) +- name: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V) self: svd_backward(grads, self, some, compute_uv, U, S, V) - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index c78e1e5f66cc..e4337e9de855 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -82,7 +82,7 @@ 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', 'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv', - 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', 'svd', '_fft_c2c', '_fft_r2c', + 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', '_svd_helper', '_fft_c2c', '_fft_r2c', 'linalg_solve', 'sqrt', 'stack', 'gather', 'index_select', 'index_add_' } diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 4a1c36df7497..d204afdb286e 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -8142,18 +8142,45 @@ def merge_dicts(*dicts): r""" svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) -This function returns a namedtuple ``(U, S, V)`` which is the singular value -decomposition of a input matrix or batches of matrices :attr:`input` such that -:math:`input = U \times diag(S) \times V^T`. - -If :attr:`some` is ``True`` (default), the method returns the reduced -singular value decomposition i.e., if the last two dimensions of -:attr:`input` are ``m`` and ``n``, then the returned `U` matrix will -contain only :math:`min(n, m)` orthonormal columns and the size of `V` -will be :math:`(*, n, n)`. - -If :attr:`compute_uv` is ``False``, the returned `U` and `V` matrices will be zero matrices -of shape :math:`(m \times m)` and :math:`(n \times n)` respectively. :attr:`some` will be ignored here. +Computes the singular value decomposition of either a matrix or batch of +matrices :attr:`input`." The singular value decomposition is represented as a +namedtuple ``(U, S, V)``, such that :math:`input = U \mathbin{@} diag(S) \times +V^T`, where :math:`V^T` is the transpose of ``V``. If :attr:`input` is a batch +of tensors, then ``U``, ``S``, and ``V`` are also batched with the same batch +dimensions as :attr:`input`. + +If :attr:`some` is ``True`` (default), the method returns the reduced singular +value decomposition i.e., if the last two dimensions of :attr:`input` are +``m`` and ``n``, then the returned `U` and `V` matrices will contain only +:math:`min(n, m)` orthonormal columns. + +If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be +zero-filled matrices of shape :math:`(m \times m)` and :math:`(n \times n)` +respectively, and the same device as :attr:`input`. The :attr:`some` +argument has no effect when :attr:`compute_uv` is False. + +The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will +always be real-valued, even if :attr:`input` is complex. + +.. warning:: ``torch.svd`` is deprecated. Please use ``torch.linalg.`` + :func:`~torch.linalg.svd` instead, which is similar to NumPy's + ``numpy.linalg.svd``. + +.. note:: **Differences with** ``torch.linalg.`` :func:`~torch.linalg.svd`: + + * :attr:`some` is the opposite of ``torch.linalg.`` + :func:`~torch.linalg.svd`'s :attr:`full_matricies`. Note that + default value for both is ``True``, so the default behavior is + effectively the opposite. + + * it returns ``V``, whereas ``torch.linalg.`` + :func:`~torch.linalg.svd` returns ``Vh``. The result is that + when using ``svd`` you need to manually transpose + ``V`` in order to reconstruct the original matrix. + + * If :attr:`compute_uv=False`, it returns zero-filled tensors for + ``U`` and ``Vh``, whereas :meth:`~torch.linalg.svd` returns + empty tensors. Supports real-valued and complex-valued input. @@ -8164,22 +8191,18 @@ def merge_dicts(*dicts): algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine `gesdd` as well. -.. note:: Irrespective of the original strides, the returned matrix `U` - will be transposed, i.e. with strides :code:`U.contiguous().transpose(-2, -1).stride()` +.. note:: The returned matrix `U` will be transposed, i.e. with strides + :code:`U.contiguous().transpose(-2, -1).stride()`. -.. note:: Extra care needs to be taken when backward through `U` and `V` - outputs. Such operation is really only stable when :attr:`input` is - full rank with all distinct singular values. Otherwise, ``NaN`` can - appear as the gradients are not properly defined. Also, notice that - double backward will usually do an additional backward through `U` and - `V` even if the original backward is only on `S`. +.. note:: Gradients computed using `U` and `V` may be unstable if + :attr:`input` is not full rank or has non-unique singular values. .. note:: When :attr:`some` = ``False``, the gradients on :code:`U[..., :, min(m, n):]` and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors can be arbitrary bases of the subspaces. -.. note:: When :attr:`compute_uv` = ``False``, backward cannot be performed since `U` and `V` - from the forward pass is required for the backward operation. +.. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True. + .. note:: With the complex-valued input the backward operation works correctly only for gauge invariant loss functions. Please look at `Gauge problem in AD`_ for more details. @@ -8187,8 +8210,9 @@ def merge_dicts(*dicts): Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more batch dimensions consisting of :math:`m \times n` matrices. - some (bool, optional): controls the shape of returned `U` and `V` - compute_uv (bool, optional): option whether to compute `U` and `V` or not + some (bool, optional): controls whether to compute the reduced or full decomposition, and + consequently the shape of returned ``U`` and ``V``. Defaults to True. + compute_uv (bool, optional): option whether to compute `U` and `V` or not. Defaults to True. Keyword args: out (tuple, optional): the output tuple of tensors diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 0f99def6c7fe..4c724b0b7e4c 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -403,6 +403,93 @@ (tensor(3.7417), tensor(11.2250)) """) +svd = _add_docstr(_linalg.linalg_svd, r""" +linalg.svd(input, full_matrices=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) + +Computes the singular value decomposition of either a matrix or batch of +matrices :attr:`input`." The singular value decomposition is represented as a +namedtuple ``(U, S, Vh)``, such that :math:`input = U \mathbin{@} diag(S) \times +Vh`. If :attr:`input` is a batch of tensors, then ``U``, ``S``, and ``Vh`` are +also batched with the same batch dimensions as :attr:`input`. + +If :attr:`full_matrices` is ``False`` (default), the method returns the reduced singular +value decomposition i.e., if the last two dimensions of :attr:`input` are +``m`` and ``n``, then the returned `U` and `V` matrices will contain only +:math:`min(n, m)` orthonormal columns. + +If :attr:`compute_uv` is ``False``, the returned `U` and `Vh` will be empy +tensors with no elements and the same device as :attr:`input`. The +:attr:`full_matrices` argument has no effect when :attr:`compute_uv` is False. + +The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will +always be real-valued, even if :attr:`input` is complex. + +.. note:: Unlike NumPy's ``linalg.svd``, this always returns a namedtuple of + three tensors, even when :attr:`compute_uv=False`. + +.. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, + then the singular values of each matrix in the batch is returned in descending order. + +.. note:: The implementation of SVD on CPU uses the LAPACK routine `?gesdd` (a divide-and-conquer + algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine + `gesdd` as well. + +.. note:: The returned matrix `U` will be transposed, i.e. with strides + :code:`U.contiguous().transpose(-2, -1).stride()`. + +.. note:: Gradients computed using `U` and `Vh` may be unstable if + :attr:`input` is not full rank or has non-unique singular values. + +.. note:: When :attr:`full_matrices` = ``True``, the gradients on :code:`U[..., :, min(m, n):]` + and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors + can be arbitrary bases of the subspaces. + +.. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True. + + +Args: + input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more + batch dimensions consisting of :math:`m \times n` matrices. + full_matrices (bool, optional): controls whether to compute the full or reduced decomposition, and + consequently the shape of returned ``U`` and ``V``. Defaults to True. + compute_uv (bool, optional): whether to compute `U` and `V` or not. Defaults to True. + out (tuple, optional): a tuple of three tensors to use for the outputs. If compute_uv=False, + the 1st and 3rd arguments must be tensors, but they are ignored. E.g. you can + pass `(torch.Tensor(), out_S, torch.Tensor())` + +Example:: + + >>> import torch + >>> a = torch.randn(5, 3) + >>> a + tensor([[-0.3357, -0.2987, -1.1096], + [ 1.4894, 1.0016, -0.4572], + [-1.9401, 0.7437, 2.0968], + [ 0.1515, 1.3812, 1.5491], + [-1.8489, -0.5907, -2.5673]]) + >>> + >>> # reconstruction in the full_matrices=False case + >>> u, s, vh = torch.linalg.svd(a, full_matrices=False) + >>> u.shape, s.shape, vh.shape + (torch.Size([5, 3]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(a, u @ torch.diag(s) @ vh) + tensor(1.0486e-06) + >>> + >>> # reconstruction in the full_matrices=True case + >>> u, s, vh = torch.linalg.svd(a) + >>> u.shape, s.shape, vh.shape + (torch.Size([5, 5]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(a, u[:, :3] @ torch.diag(s) @ vh) + >>> torch.dist(a, u[:, :3] @ torch.diag(s) @ vh) + tensor(1.0486e-06) + >>> + >>> # extra dimensions + >>> a_big = torch.randn(7, 5, 3) + >>> u, s, vh = torch.linalg.svd(a_big, full_matrices=False) + >>> torch.dist(a_big, u @ torch.diag_embed(s) @ vh) + tensor(3.0957e-06) +""") + cond = _add_docstr(_linalg.linalg_cond, r""" linalg.cond(input, p=None, *, out=None) -> Tensor diff --git a/torch/overrides.py b/torch/overrides.py index f8d9f2e152f6..a6ec0de5ffb5 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -803,6 +803,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.nansum: lambda input, dim=None: -1, torch.svd: lambda input, some=True, compute_uv=True, out=None: -1, torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1, + torch.linalg.svd: lambda input, full_matrices=True, compute_uv=True, out=None: -1, torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1, torch.swapaxes: lambda input, dim0, dim1: -1, torch.swapdims: lambda input, axis0, axis1: -1, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index d1a62cafe25c..b01deacc8ad1 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -533,7 +533,7 @@ def sample_inputs_linalg_solve(op_info, device, dtype, requires_grad=False): return out -def sample_inputs_svd(op_info, device, dtype, requires_grad=False): +def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_svd=False): """ This function generates input for torch.svd with distinct singular values so that autograd is always stable. Matrices of different size: @@ -546,6 +546,16 @@ def sample_inputs_svd(op_info, device, dtype, requires_grad=False): """ from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + # svd and linalg.svd returns V and V.T, respectively. So we need to slice + # along different dimensions when needed (this is used by + # test_cases2:wide_all and wide_all_batched below) + if is_linalg_svd: + def slice_V(v): + return v[..., :(S - 2), :] + else: + def slice_V(v): + return v[..., :, :(S - 2)] + test_cases1 = ( # some=True (default) # loss functions for complex-valued svd have to be "gauge invariant", # i.e. loss functions shouldn't change when sigh of the singular vectors change. @@ -575,11 +585,11 @@ def sample_inputs_svd(op_info, device, dtype, requires_grad=False): ) test_cases2 = ( # some=False (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)], - lambda usv: (abs(usv[0]), usv[1], abs(usv[2][:, :(S - 2)]))), # 'wide_all' + lambda usv: (abs(usv[0]), usv[1], abs(slice_V(usv[2])))), # 'wide_all' (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)], lambda usv: (abs(usv[0][:, :(S - 2)]), usv[1], abs(usv[2]))), # 'tall_all' (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :(S - 2), :], - lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))), # 'wide_all_batched' + lambda usv: (abs(usv[0]), usv[1], abs(slice_V(usv[2])))), # 'wide_all_batched' (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :, :(S - 2)], lambda usv: (abs(usv[0][..., :, :(S - 2)]), usv[1], abs(usv[2]))), # 'tall_all_batched' ) @@ -587,15 +597,27 @@ def sample_inputs_svd(op_info, device, dtype, requires_grad=False): out = [] for a, out_fn in test_cases1: a.requires_grad = requires_grad - out.append(SampleInput(a, output_process_fn_grad=out_fn)) + if is_linalg_svd: + kwargs = {'full_matrices': False} + else: + kwargs = {'some': True} + out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn)) for a, out_fn in test_cases2: a.requires_grad = requires_grad - kwargs = {'some': False} + if is_linalg_svd: + kwargs = {'full_matrices': True} + else: + kwargs = {'some': False} out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn)) return out +def sample_inputs_svd(op_info, device, dtype, requires_grad=False): + return _sample_inputs_svd(op_info, device, dtype, requires_grad, is_linalg_svd=False) + +def sample_inputs_linalg_svd(op_info, device, dtype, requires_grad=False): + return _sample_inputs_svd(op_info, device, dtype, requires_grad, is_linalg_svd=True) def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False): """ @@ -1189,6 +1211,20 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): # cuda gradchecks are very slow # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), + OpInfo('linalg.svd', + op=torch.linalg.svd, + aten_name='linalg_svd', + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_inputs_linalg_svd, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + # gradgrad checks are slow + SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)), + # cuda gradchecks are very slow + # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 + SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), OpInfo('pinverse', op=torch.pinverse, dtypes=floating_and_complex_types(), From d00acebd14d5ca8d79092f52662f79d0d219909c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 8 Jan 2021 06:53:24 -0800 Subject: [PATCH 38/54] Add tensor.view(dtype) (#47951) Summary: Fixes https://github.com/pytorch/pytorch/issues/42571 Note that this functionality is a subset of [`numpy.ndarray.view`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.view.html): - this only supports viewing a tensor as a dtype with the same number of bytes - this does not support viewing a tensor as a subclass of `torch.Tensor` Pull Request resolved: https://github.com/pytorch/pytorch/pull/47951 Reviewed By: ngimel Differential Revision: D25062301 Pulled By: mruberry fbshipit-source-id: 9fefaaef77f15d5b863ccd12d836932983794475 --- aten/src/ATen/native/TensorConversions.cpp | 16 +++++++ aten/src/ATen/native/native_functions.yaml | 13 ++++++ test/test_view_ops.py | 50 +++++++++++++++++++++ tools/autograd/derivatives.yaml | 3 ++ torch/_tensor_docs.py | 45 +++++++++++++++++++ torch/csrc/jit/frontend/schema_matching.cpp | 17 +++++++ torch/csrc/jit/ir/ir.cpp | 2 +- torch/csrc/jit/python/python_ir.cpp | 6 ++- 8 files changed, 149 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index 355fcad9010d..d773de927efb 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -117,4 +117,20 @@ Tensor to_mkldnn_backward(const Tensor& grad, const Tensor& input_) { return grad.to_dense(input_.scalar_type()); } +Tensor view_dtype(const Tensor& self, ScalarType dtype) { + if (self.scalar_type() == dtype) { + return self; + } + auto type_meta = c10::scalarTypeToTypeMeta(dtype); + TORCH_CHECK(self.element_size() == type_meta.itemsize(), + "Viewing a tensor as a new dtype with a different number of bytes per element is not supported."); + Storage storage = self.storage(); + auto new_tensor = detail::make_tensor( + std::move(storage), self.key_set(), type_meta); + auto* impl = new_tensor.unsafeGetTensorImpl(); + impl->set_storage_offset(self.storage_offset()); + impl->set_sizes_and_strides(self.sizes(), self.strides()); + return new_tensor; +} + }} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e8e3efa307f8..52bee793b4c8 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4899,6 +4899,19 @@ CPU, CUDA, QuantizedCPU, QuantizedCUDA: view MkldnnCPU: mkldnn_view +# Warning: If you want to change the name or overload name of this +# operator, you might also want to change the `isBlockListedSchema` +# function in `torch/csrc/jit/frontend/schema_catching.cpp`. +# The name and overload name of this operator is hardcoded in that +# function in order to workaround a bug: +# https://github.com/pytorch/pytorch/issues/47964 +- func: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) + use_c10_dispatcher: full + variants: method + device_guard: False + dispatch: + DefaultBackend: view_dtype + - func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) variants: method dispatch: diff --git a/test/test_view_ops.py b/test/test_view_ops.py index be33aa1ab44a..17d04c35d8ca 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -119,6 +119,56 @@ def test_conj_self(self, device, dtype): s = t.conj() self.assertTrue(s is t) + @onlyOnCPUAndCUDA + @dtypes(*torch.testing.get_all_fp_dtypes(include_bfloat16=False), torch.complex64) + def test_view_dtype(self, device, dtype): + int_dtype = { + torch.half: torch.int16, + torch.bfloat16: torch.int16, + torch.float: torch.int, + torch.double: torch.long, + torch.complex64: torch.long, + }[dtype] + numpy_dtype = { + torch.half: np.int16, + torch.bfloat16: np.int16, + torch.float: np.int32, + torch.double: np.int64, + torch.complex64: np.int64, + }[dtype] + + def generate_inputs(): + yield make_tensor((5, 5, 5), device, dtype, low=-5, high=5) + yield make_tensor((5, 5, 5), device, dtype, low=-5, high=5).permute(2, 0, 1) + yield make_tensor((1, 5, 1), device, dtype, low=-5, high=5).expand(5, 5, 5) + yield make_tensor((10, 5, 10), device, dtype, low=-5, high=5)[::2, :, ::2] + yield make_tensor((0, 5, 10), device, dtype, low=-5, high=5) + yield make_tensor((), device, dtype, low=-5, high=5) + + def run_test(fp_tensor): + self.assertRaises(RuntimeError, lambda: fp_tensor.view(torch.complex128)) + self.assertRaises(RuntimeError, lambda: fp_tensor.view(torch.int8)) + + int_tensor = fp_tensor.view(int_dtype) + self.assertEqual(int_tensor.dtype, int_dtype) + self.assertEqual(int_tensor.shape, fp_tensor.shape) + self.assertEqual(int_tensor.stride(), fp_tensor.stride()) + + self.assertEqual(fp_tensor, int_tensor.view(dtype), rtol=0, atol=0) + self.assertEqual(fp_tensor.cpu().numpy().view(numpy_dtype), int_tensor, rtol=0, atol=0) + + fp_tensor.zero_() + self.assertEqual(fp_tensor, torch.zeros_like(fp_tensor), rtol=0, atol=0) + + for fp_tensor in generate_inputs(): + run_test(fp_tensor) + + # Test that requires_grad is dropped, because view(dtype) does not support backward + if dtype is torch.double: + t = make_tensor((5, 5, 5), device, torch.double, low=-5, high=5, requires_grad=True) + self.assertFalse(t.view(torch.complex64).requires_grad) + + @onlyOnCPUAndCUDA def test_view_as_complex(self, device): def fn(contiguous_input=True, dim0=0, dim1=1): diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index fbeec2a5eaba..9c9719be1ef0 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1173,6 +1173,9 @@ - name: view(Tensor(a) self, int[] size) -> Tensor(a) self: grad.reshape(self.sizes()) +- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) + output_differentiability: [False] + - name: view_as_real(Tensor(a) self) -> Tensor(a) self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1 diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 588c59ef98a6..fd879427b56e 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -4254,6 +4254,51 @@ def callable(a, b) -> number >>> torch.equal(b, c) False + +.. function:: view(dtype) -> Tensor + +Returns a new tensor with the same data as the :attr:`self` tensor but of a +different :attr:`dtype`. :attr:`dtype` must have the same number of bytes per +element as :attr:`self`'s dtype. + +.. warning:: + + This overload is not supported by TorchScript, and using it in a Torchscript + program will cause undefined behavior. + + +Args: + dtype (:class:`torch.dtype`): the desired dtype + +Example:: + + >>> x = torch.randn(4, 4) + >>> x + tensor([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + torch.float32 + + >>> y = x.view(torch.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=torch.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(torch.int16) + Traceback (most recent call last): + File "", line 1, in + RuntimeError: Viewing a tensor as a new dtype with a different number of bytes per element is not supported. """) add_docstr_all('view_as', diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index d544df0d5535..f431dd3d4a0f 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -284,6 +284,19 @@ static bool varargsCanBeUsedAsList( !typevar_list; } +// Note (@zasdfgbnm): +// This is a workaround for https://github.com/pytorch/pytorch/issues/47964 +// Currently JIT does not distinguish ScalarType vs int, so there is really +// no way to distinguish x.view(1) vs x.view(torch.int8). So we have to hardcode +// the aten::view.dtype here to block this overload. This blocklist should be +// removed when JIT fully suports ScalarType as its own type. +bool isBlockListedSchema(const FunctionSchema& schema) { + if (schema.name() == "aten::view" && schema.overload_name() == "dtype") { + return true; + } + return false; +} + static c10::optional tryMatchSchema( const FunctionSchema& schema, const SourceRange& loc, @@ -293,6 +306,10 @@ static c10::optional tryMatchSchema( c10::optional self, std::ostream* failure_messages, bool allow_conversions) { + if (isBlockListedSchema(schema)) { + return c10::nullopt; + } + auto err = [&]() -> std::ostream& { *failure_messages << "\n" << schema << ":\n"; return *failure_messages; diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index eb75928e5952..7f58aca25271 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -104,7 +104,7 @@ std::ostream& operator<<( static void printAttribute(std::ostream& out, const at::Tensor& tensor) { // 1-elem tensors are usually boxed scalars, so print them like it if (tensor.numel() == 1) { - auto scalar_tensor = tensor.view({}).item(); + auto scalar_tensor = tensor.view(std::vector{}).item(); out << "{"; if (scalar_tensor.isFloatingPoint()) { out << scalar_tensor.toDouble(); diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index 6e68fe9ebec3..b6f6169ede6c 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -625,7 +625,8 @@ void initPythonIRBindings(PyObject* module_) { [](Node& n, const char* name, const at::Tensor& v) { return n.t_( Symbol::attr(name), - autograd::Variable(v.view({})).set_requires_grad(false)); + autograd::Variable(v.view(std::vector{})) + .set_requires_grad(false)); }) .def( "z", @@ -634,7 +635,8 @@ void initPythonIRBindings(PyObject* module_) { "zs_", [](Node& n, const char* name, TensorsAttr::ValueType v) { for (auto& i : v) { - i = autograd::Variable(i.view({})).set_requires_grad(false); + i = autograd::Variable(i.view(std::vector{})) + .set_requires_grad(false); } return n.ts_(Symbol::attr(name), std::move(v)); }) From 54ce171f16c8859f829dde09f87c364c8a6b4130 Mon Sep 17 00:00:00 2001 From: Tongzhou Wang Date: Fri, 8 Jan 2021 06:56:10 -0800 Subject: [PATCH 39/54] Fix persistent_workers + pin_memory (#48543) Summary: Fixes https://github.com/pytorch/pytorch/issues/48370 https://github.com/pytorch/pytorch/issues/47445 cc emcastillo who authored the original functionality. Pull Request resolved: https://github.com/pytorch/pytorch/pull/48543 Reviewed By: bdhirsh Differential Revision: D25277474 Pulled By: ejguan fbshipit-source-id: 1967002124fb0fff57caca8982bc7df359a059a2 --- test/test_dataloader.py | 22 +++++++++++++--------- torch/utils/data/_utils/worker.py | 2 +- torch/utils/data/dataloader.py | 5 +++-- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/test/test_dataloader.py b/test/test_dataloader.py index c257dd8a2fd7..edc31b75485e 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -2031,15 +2031,19 @@ def __next__(self): def test_dataset_not_reset(self): dataset = DummyDataset() - dataloader = self._get_data_loader(dataset, num_workers=2) - dataset.start = 0 - for i in range(10): - for x in dataloader: - pass - # Changing the start value here doesn't have any effect in the dataset - # cached by the workers. since they are not recreated between epochs - # and can cache values safely - dataset.start = i + pin_memory_configs = [False] + if TEST_CUDA: + pin_memory_configs.append(True) + for pin_memory in pin_memory_configs: + dataloader = self._get_data_loader(dataset, num_workers=2, pin_memory=pin_memory) + dataset.start = 0 + for i in range(10): + for x in dataloader: + pass + # Changing the start value here doesn't have any effect in the dataset + # cached by the workers. since they are not recreated between epochs + # and can cache values safely + dataset.start = i diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py index 7a53d61feae5..a8ca66057b6b 100644 --- a/torch/utils/data/_utils/worker.py +++ b/torch/utils/data/_utils/worker.py @@ -177,7 +177,7 @@ def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event, continue if isinstance(r, _ResumeIteration): # Acknowledge the main process - data_queue.put(r) + data_queue.put((r, None)) iteration_end = False # Recreate the fetcher for worker-reuse policy fetcher = _DatasetKind.create_fetcher( diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index d4ef1a99a2df..a5eeeec671e3 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -962,8 +962,9 @@ def _reset(self, loader, first_iter=False): self._index_queues[idx].put(_utils.worker._ResumeIteration()) resume_iteration_cnt = self._num_workers while resume_iteration_cnt > 0: - data = self._get_data() - if isinstance(data, _utils.worker._ResumeIteration): + return_idx, return_data = self._get_data() + if isinstance(return_idx, _utils.worker._ResumeIteration): + assert return_data is None resume_iteration_cnt -= 1 # prime the prefetch loop for _ in range(self._prefetch_factor * self._num_workers): From 55919a47588630d27333df37a94753b8d1f14eb8 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 8 Jan 2021 07:29:43 -0800 Subject: [PATCH 40/54] add type annotations to torch.nn.quantized.modules.conv (#49702) Summary: closes gh-49700 No mypy issues were found in the first three entries deleted from `mypy.ini`: ``` [mypy-torch.nn.qat.modules.activations] ignore_errors = True [mypy-torch.nn.qat.modules.conv] ignore_errors = True [mypy-torch.nn.quantized.dynamic.modules.linear] ignore_errors = True ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/49702 Reviewed By: walterddr, zou3519 Differential Revision: D25767119 Pulled By: ezyang fbshipit-source-id: cb83e53549a299538e1b154cf8b79e3280f7392a --- mypy.ini | 11 +----- torch/nn/quantized/modules/conv.py | 59 ++++++++++++++++++++++-------- 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/mypy.ini b/mypy.ini index 6c579ee9399d..0c99a9c62d18 100644 --- a/mypy.ini +++ b/mypy.ini @@ -91,16 +91,7 @@ ignore_errors = True [mypy-torch.nn.modules.pooling] ignore_errors = True -[mypy-torch.nn.qat.modules.activations] -ignore_errors = True - -[mypy-torch.nn.qat.modules.conv] -ignore_errors = True - -[mypy-torch.nn.quantized.dynamic.modules.linear] -ignore_errors = True - -[mypy-torch.nn.quantized.modules.conv] +[mypy-torch.nn.parallel._functions] ignore_errors = True [mypy-torch._appdirs] diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py index 00ceba7ab367..b3bc78ff6941 100644 --- a/torch/nn/quantized/modules/conv.py +++ b/torch/nn/quantized/modules/conv.py @@ -1,7 +1,7 @@ # coding=utf-8 r"""Quantized convolution modules.""" -from typing import Optional, List +from typing import Optional, List, TypeVar import torch import torch.nn as nn @@ -16,11 +16,17 @@ class _ConvNd(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, - padding, dilation, - transposed, output_padding, - groups, bias, + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'): + # All subclasses have this signature - See PR #49702s + raise NotImplementedError + + def _init(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, + transposed, output_padding, + groups, bias, + padding_mode='zeros'): super(_ConvNd, self).__init__() if padding_mode != 'zeros': raise NotImplementedError( @@ -54,6 +60,15 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, self.scale = 1.0 self.zero_point = 0 + def set_weight_bias(self, qweight, bias_float): + raise NotImplementedError + + def bias(self): + raise NotImplementedError + + def _weight_bias(self): + raise NotImplementedError + def extra_repr(self): s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}, scale={scale}, zero_point={zero_point}') @@ -155,7 +170,8 @@ def get_qconv(cls, mod, activation_post_process, weight_post_process=None): assert weight_post_process.dtype == torch.qint8, \ 'Weight observer must have a dtype of qint8' qweight = _quantize_weight(mod.weight.float(), weight_post_process) - qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, + # the __init__ call used is the one from derived classes and not the one from _ConvNd + qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg] mod.stride, mod.padding, mod.dilation, mod.groups, mod.bias is not None, mod.padding_mode) qconv.set_weight_bias(qweight, mod.bias) @@ -233,7 +249,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding = _pair_from_first(padding) dilation = _pair_from_first(dilation) - super(Conv1d, self).__init__( + # Subclasses of _ConvNd needs to call _init rather than __init__. See + # discussion on PR #49702 + super(Conv1d, self)._init( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _single(0), groups, bias, padding_mode) @@ -319,7 +337,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, stride = _pair(stride) padding = _pair(padding) dilation = _pair(dilation) - super(Conv2d, self).__init__( + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super(Conv2d, self)._init( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias, padding_mode) @@ -403,7 +423,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, stride = _triple(stride) padding = _triple(padding) dilation = _triple(dilation) - super(Conv3d, self).__init__( + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super(Conv3d, self)._init( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _triple(0), groups, bias, padding_mode) @@ -450,15 +472,20 @@ def from_float(cls, mod): return cls.get_qconv(mod, activation_post_process) # === Transposed Convolutions === +MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd) class _ConvTransposeNd(_ConvNd): + + _FLOAT_MODULE = MOD + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode): if padding_mode != 'zeros': raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__)) - - super(_ConvTransposeNd, self).__init__( + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super(_ConvTransposeNd, self)._init( in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode) @@ -477,9 +504,10 @@ def from_float(cls, mod): mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ - assert type(mod) == cls._FLOAT_MODULE, \ - ' nnq.' + cls.__name__ + '.from_float only works for ' + \ - cls._FLOAT_MODULE.__name__ + # derived classes override cls._FLOAT_MODULE attribute + msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \ + cls._FLOAT_MODULE.__name__ + assert type(mod) == cls._FLOAT_MODULE, msg assert hasattr(mod, 'qconfig'), \ 'Input float module must have qconfig defined.' weight_post_process = mod.qconfig.weight() @@ -488,7 +516,8 @@ def from_float(cls, mod): assert weight_post_process.dtype == torch.qint8, \ 'Weight observer must have a dtype of qint8' qweight = _quantize_weight(mod.weight.float(), weight_post_process) - qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, + # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd + qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg] mod.stride, mod.padding, mod.output_padding, mod.groups, mod.bias is not None, mod.dilation, mod.padding_mode) qconv.set_weight_bias(qweight, mod.bias) From 88bd69b4881d1f9cde078d610b588172b32d4a65 Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Fri, 8 Jan 2021 09:42:19 -0800 Subject: [PATCH 41/54] Stop using c10::scalar_to_tensor in float_power. (#50105) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50105 There should be no functional change here. A couple of reasons here: 1) This function is generally an anti-pattern (https://github.com/pytorch/pytorch/issues/49758) and it is good to minimize its usage in the code base. 2) pow itself has a fair amount of smarts like not broadcasting scalar/tensor combinations and we should defer to it. Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D25786172 Pulled By: gchanan fbshipit-source-id: 89de03aa0b900ce011a62911224a5441f15e331a --- aten/src/ATen/native/Pow.cpp | 50 +++++++++++++++++++++++++++--------- test/test_binary_ufuncs.py | 8 +++--- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp index 4d1601d3e6a0..1ed105fd0175 100644 --- a/aten/src/ATen/native/Pow.cpp +++ b/aten/src/ATen/native/Pow.cpp @@ -28,7 +28,7 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) { auto common_dtype = at::result_type(base, exp); TORCH_CHECK(at::can_cast(common_dtype, result.scalar_type()), - "result type ", common_dtype, "can't be cast to the desired output type ", + "result type ", common_dtype, " can't be cast to the desired output type ", result.scalar_type()); if (exp.equal(0.0)) { @@ -83,42 +83,68 @@ Tensor& float_power_out(Tensor& result, const Tensor& base, const Tensor& exp) { auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble; TORCH_CHECK(result.scalar_type() == dtype, - "output type ", result.scalar_type(), "is not the desired output type ", dtype); + "the output given to float_power has dtype ", result.scalar_type(), + " but the operation's result requires dtype ", dtype); return at::pow_out(result, base.to(dtype), exp.to(dtype)); } Tensor& float_power_out(Tensor& result, const Tensor& base, Scalar exp) { - return at::float_power_out(result, base, c10::scalar_to_tensor(exp, base.device())); + auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble; + TORCH_CHECK(result.scalar_type() == dtype, + "the output given to float_power has dtype ", result.scalar_type(), + " but the operation's result requires dtype ", dtype); + + // Note: need the casts inside the ternary because conversion functions return e.g. c10::complex, + // which causes a complex scalar to always be returned. + exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble()); + return at::pow_out(result, base.to(dtype), exp); } Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) { - return at::float_power_out(result, c10::scalar_to_tensor(base, exp.device()), exp); -} + auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble; + TORCH_CHECK(result.scalar_type() == dtype, + "the output given to float_power has dtype ", result.scalar_type(), + " but the operation's result requires dtype ", dtype); -Tensor float_power(const Tensor& base, const Tensor& exp) { - auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble; - return at::pow(base.to(dtype), exp.to(dtype)); + base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble()); + return at::pow_out(result, base, exp.to(dtype)); } Tensor float_power(const Tensor& base, Scalar exp) { - return at::float_power(base, c10::scalar_to_tensor(exp, base.device())); + auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble; + exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble()); + return at::pow(base.to(dtype), exp); } Tensor float_power(Scalar base, const Tensor& exp) { - return at::float_power(c10::scalar_to_tensor(base, exp.device()), exp); + auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble; + base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble()); + return at::pow(base, exp.to(dtype)); +} + +Tensor float_power(const Tensor& base, const Tensor& exp) { + auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble; + return at::pow(base.to(dtype), exp.to(dtype)); } Tensor& float_power_(Tensor& base, const Tensor& exp) { auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble; TORCH_CHECK(base.scalar_type() == dtype, - "self tensor type ", base.scalar_type(), "is not the desired type ", dtype); + "the base given to float_power_ has dtype ", base.scalar_type(), + " but the operation's result requires dtype ", dtype); return base.pow_(exp.to(dtype)); } Tensor& float_power_(Tensor& base, Scalar exp) { - return base.float_power_(c10::scalar_to_tensor(exp, base.device())); + auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble; + TORCH_CHECK(base.scalar_type() == dtype, + "the base given to float_power_ has dtype ", base.scalar_type(), + " but the operation's result requires dtype ", dtype); + + exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble()); + return base.pow_(exp); } } // namespace native diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 2ff12396701e..785f6999f570 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -2424,7 +2424,7 @@ def to_np(value): # Case of Tensor x Tensor if op is torch.Tensor.float_power_ and base_dtype != out_dtype: - with self.assertRaisesRegex(RuntimeError, "is not the desired type"): + with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): op(base.clone(), exp) else: result = op(base.clone(), exp) @@ -2441,7 +2441,7 @@ def to_np(value): expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i)) if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp: - with self.assertRaisesRegex(RuntimeError, "is not the desired type"): + with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): op(base.clone(), i) else: result = op(base.clone(), i) @@ -2483,13 +2483,13 @@ def _promo_helper(x, y): if out.dtype == required_dtype: torch.float_power(base, exp, out=out) else: - with self.assertRaisesRegex(RuntimeError, "is not the desired output type"): + with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): torch.float_power(base, exp, out=out) if base.dtype == required_dtype: torch.Tensor.float_power_(base.clone(), exp) else: - with self.assertRaisesRegex(RuntimeError, "is not the desired type"): + with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): torch.Tensor.float_power_(base.clone(), exp) @skipIf(not TEST_SCIPY, "Scipy required for the test.") From b5ab0a7f781411878b30b0d83cda47f9c364dc03 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 8 Jan 2021 09:45:49 -0800 Subject: [PATCH 42/54] Improve torch.linalg.qr (#50046) Summary: This is a follow up of PR https://github.com/pytorch/pytorch/issues/47764 to fix the remaining details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/50046 Reviewed By: zou3519 Differential Revision: D25825557 Pulled By: mruberry fbshipit-source-id: b8e335e02265e73484a99b0189e4cc042828e0a9 --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 4 +- aten/src/ATen/native/LinearAlgebraUtils.h | 3 +- test/test_autograd.py | 11 ---- test/test_linalg.py | 67 ++++++++++++++++++--- torch/_torch_docs.py | 23 ++++--- torch/csrc/autograd/FunctionsManual.cpp | 4 +- torch/linalg/__init__.py | 24 ++++---- 7 files changed, 88 insertions(+), 48 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index f0b36d0fdbac..146275e1faa6 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1011,13 +1011,13 @@ std::tuple _linalg_qr_helper_cpu(const Tensor& self, std::string std::tuple linalg_qr(const Tensor& self, std::string mode) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + "qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); return at::_linalg_qr_helper(self, mode); } std::tuple linalg_qr_out(Tensor& Q, Tensor& R, const Tensor& self, std::string mode) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + "qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); Tensor Q_tmp, R_tmp; std::tie(Q_tmp, R_tmp) = at::_linalg_qr_helper(self, mode); at::native::resize_output(Q, Q_tmp.sizes()); diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index e97637d4c5e4..4322c4c79222 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -206,7 +206,8 @@ static inline std::tuple _parse_qr_mode(std::string mode) { compute_q = false; reduced = true; // this is actually irrelevant in this mode } else { - TORCH_CHECK(false, "Unrecognized mode '", mode, "'"); + TORCH_CHECK(false, "qr received unrecognized mode '", mode, + "' but expected one of 'reduced' (default), 'r', or 'complete'"); } return std::make_tuple(compute_q, reduced); } diff --git a/test/test_autograd.py b/test/test_autograd.py index 9f5925212757..fb4394b3a62b 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -4941,17 +4941,6 @@ def assert_only_first_requires_grad(res): return_counts=return_counts) assert_only_first_requires_grad(res) - def test_linalg_qr_r(self): - # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but - # without 'q' you cannot compute the backward pass. Check that - # linalg_qr_backward complains cleanly in that case. - inp = torch.randn((5, 7), requires_grad=True) - q, r = torch.linalg.qr(inp, mode='r') - assert q.shape == (0,) # empty tensor - b = torch.sum(r) - with self.assertRaisesRegex(RuntimeError, - "linalg_qr_backward: cannot compute backward"): - b.backward() def index_perm_variable(shape, max_indices): if not isinstance(shape, tuple): diff --git a/test/test_linalg.py b/test/test_linalg.py index 5f3a12844ef8..b39d64e4bca1 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -3059,12 +3059,34 @@ def test_qr_vs_numpy(self, device, dtype): exp_r = np.linalg.qr(np_t, mode='r') q, r = torch.linalg.qr(t, mode='r') # check that q is empty - assert q.shape == (0,) - assert q.dtype == t.dtype - assert q.device == t.device + self.assertEqual(q.shape, (0,)) + self.assertEqual(q.dtype, t.dtype) + self.assertEqual(q.device, t.device) # check r self.assertEqual(r, exp_r) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float) + def test_linalg_qr_autograd_errors(self, device, dtype): + # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but + # without 'q' you cannot compute the backward pass. Check that + # linalg_qr_backward complains cleanly in that case. + inp = torch.randn((5, 7), device=device, dtype=dtype, requires_grad=True) + q, r = torch.linalg.qr(inp, mode='r') + self.assertEqual(q.shape, (0,)) # empty tensor + b = torch.sum(r) + with self.assertRaisesRegex(RuntimeError, + "The derivative of qr is not implemented when mode='r'"): + b.backward() + # + inp = torch.randn((7, 5), device=device, dtype=dtype, requires_grad=True) + q, r = torch.linalg.qr(inp, mode='complete') + b = torch.sum(r) + with self.assertRaisesRegex(RuntimeError, + "The derivative of qr is not implemented when mode='complete' and nrows > ncols"): + b.backward() + @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) @@ -3078,10 +3100,17 @@ def np_qr_batched(a, mode): all_q = [] all_r = [] for matrix in a: - q, r = np.linalg.qr(matrix, mode=mode) - all_q.append(q) - all_r.append(r) - return np.array(all_q), np.array(all_r) + result = np.linalg.qr(matrix, mode=mode) + if mode == 'r': + all_r.append(result) + else: + q, r = result + all_q.append(q) + all_r.append(r) + if mode == 'r': + return np.array(all_r) + else: + return np.array(all_q), np.array(all_r) t = torch.randn((3, 7, 5), device=device, dtype=dtype) np_t = t.cpu().numpy() @@ -3090,6 +3119,15 @@ def np_qr_batched(a, mode): q, r = torch.linalg.qr(t, mode=mode) self.assertEqual(q, exp_q) self.assertEqual(r, exp_r) + # for mode='r' we need a special logic because numpy returns only r + exp_r = np_qr_batched(np_t, mode='r') + q, r = torch.linalg.qr(t, mode='r') + # check that q is empty + self.assertEqual(q.shape, (0,)) + self.assertEqual(q.dtype, t.dtype) + self.assertEqual(q.device, t.device) + # check r + self.assertEqual(r, exp_r) @skipCUDAIfNoMagma @skipCPUIfNoLapack @@ -3112,11 +3150,22 @@ def test_qr_out(self, device, dtype): out = (torch.empty((0), dtype=dtype, device=device), torch.empty((0), dtype=dtype, device=device)) q2, r2 = torch.linalg.qr(t, mode=mode, out=out) - assert q2 is out[0] - assert r2 is out[1] + self.assertIs(q2, out[0]) + self.assertIs(r2, out[1]) self.assertEqual(q2, q) self.assertEqual(r2, r) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float) + def test_qr_error_cases(self, device, dtype): + t1 = torch.randn(5, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, 'qr input should have at least 2 dimensions, but has 1 dimensions instead'): + torch.linalg.qr(t1) + t2 = torch.randn((5, 7), device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "qr received unrecognized mode 'hello'"): + torch.linalg.qr(t2, mode='hello') + @dtypes(torch.double, torch.cdouble) def test_einsum(self, device, dtype): def check(equation, *operands): diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index d204afdb286e..b08db84d9459 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -6676,11 +6676,10 @@ def merge_dicts(*dicts): If :attr:`some` is ``True``, then this function returns the thin (reduced) QR factorization. Otherwise, if :attr:`some` is ``False``, this function returns the complete QR factorization. -.. warning:: ``torch.qr`` is deprecated. Please use ``torch.linalg.`` :meth:`~torch.linalg.qr` - instead, which provides a better compatibility with - ``numpy.linalg.qr``. +.. warning:: ``torch.qr`` is deprecated. Please use ``torch.linalg.`` :func:`~torch.linalg.qr` + instead. - **Differences with** ``torch.linalg.`` :meth:`~torch.linalg.qr`: + **Differences with** ``torch.linalg.qr``: * ``torch.linalg.qr`` takes a string parameter ``mode`` instead of ``some``: @@ -6698,21 +6697,21 @@ def merge_dicts(*dicts): .. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, and may produce different (valid) decompositions on different device types - and different platforms, depending on the precise version of the - underlying library. + or different platforms. Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more batch dimensions consisting of matrices of dimension :math:`m \times n`. some (bool, optional): Set to ``True`` for reduced QR decomposition and ``False`` for - complete QR decomposition. + complete QR decomposition. If `k = min(m, n)` then: + + * ``some=True`` : returns `(Q, R)` with dimensions (m, k), (k, n) (default) + + * ``'some=False'``: returns `(Q, R)` with dimensions (m, m), (m, n) Keyword args: - out (tuple, optional): tuple of `Q` and `R` tensors - satisfying :code:`input = torch.matmul(Q, R)`. - The dimensions of `Q` and `R` are :math:`(*, m, k)` and :math:`(*, k, n)` - respectively, where :math:`k = \min(m, n)` if :attr:`some:` is ``True`` and - :math:`k = m` otherwise. + out (tuple, optional): tuple of `Q` and `R` tensors. + The dimensions of `Q` and `R` are detailed in the description of :attr:`some` above. Example:: diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 6558295d58cb..79d195d73a74 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -2078,7 +2078,7 @@ Tensor linalg_qr_backward(const std::vector &grads, c std::string mode, const Tensor& q, const Tensor& r){ bool compute_q, reduced; std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode); - TORCH_CHECK(compute_q, "linalg_qr_backward: cannot compute backward if mode='r'. " + TORCH_CHECK(compute_q, "The derivative of qr is not implemented when mode='r'. " "Please use torch.linalg.qr(..., mode='reduced')"); auto square_deep_case_backward = [](const Tensor& grad_Q, @@ -2145,7 +2145,7 @@ Tensor linalg_qr_backward(const std::vector &grads, c TORCH_CHECK( ((m <= n && (!reduced)) || reduced), - "The derivative is not implemented when nrows > ncols and complete QR. "); + "The derivative of qr is not implemented when mode='complete' and nrows > ncols."); auto grad_Q = grads[0]; auto grad_R = grads[1]; diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 4c724b0b7e4c..de5fcb54ddb7 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -731,15 +731,15 @@ .. note:: Backpropagation is not supported for ``mode='r'``. Use ``mode='reduced'`` instead. - If you plan to backpropagate through QR, note that the current backward implementation - is only well-defined when the first :math:`\min(input.size(-1), input.size(-2))` - columns of :attr:`input` are linearly independent. - This behavior may change in the future. + Backpropagation is also not supported if the first + :math:`\min(input.size(-1), input.size(-2))` columns of any matrix + in :attr:`input` are not linearly independent. While no error will + be thrown when this occurs the values of the "gradient" produced may + be anything. This behavior may change in the future. .. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, and may produce different (valid) decompositions on different device types - and different platforms, depending on the precise version of the - underlying library. + or different platforms. Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more @@ -753,11 +753,8 @@ * ``'r'``: computes only `R`; returns `(Q, R)` where `Q` is empty and `R` has dimensions (k, n) Keyword args: - out (tuple, optional): tuple of `Q` and `R` tensors - satisfying :code:`input = torch.matmul(Q, R)`. - The dimensions of `Q` and `R` are :math:`(*, m, k)` and :math:`(*, k, n)` - respectively, where :math:`k = \min(m, n)` if :attr:`mode` is `'reduced'` and - :math:`k = m` if :attr:`mode` is `'complete'`. + out (tuple, optional): tuple of `Q` and `R` tensors. + The dimensions of `Q` and `R` are detailed in the description of :attr:`mode` above. Example:: @@ -779,6 +776,11 @@ tensor([[ 1., 0., 0.], [ 0., 1., -0.], [ 0., -0., 1.]]) + >>> q2, r2 = torch.linalg.qr(a, mode='r') + >>> q2 + tensor([]) + >>> torch.equal(r, r2) + True >>> a = torch.randn(3, 4, 5) >>> q, r = torch.linalg.qr(a, mode='complete') >>> torch.allclose(torch.matmul(q, r), a) From 81778e28119271001cc75dc8ca74cffc6031ced8 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 8 Jan 2021 10:05:24 -0800 Subject: [PATCH 43/54] [onnx] Do not deref nullptr in scalar type analysis (#50237) Summary: Apply a little bit of defensive programming: `type->cast()` returns an optional pointer so dereferencing it can lead to a hard crash. Fixes SIGSEGV reported in https://github.com/pytorch/pytorch/issues/49959 Pull Request resolved: https://github.com/pytorch/pytorch/pull/50237 Reviewed By: walterddr Differential Revision: D25839675 Pulled By: malfet fbshipit-source-id: 403d6df5e2392dd6adc308b1de48057f2f9d77ab --- .../jit/passes/onnx/scalar_type_analysis.cpp | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 1d38e1e1f4cd..1a63edf0fc15 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -71,6 +71,7 @@ static bool IsComparisonOp(const NodeKind& nkind) { static TensorTypePtr CreateProfiledTensorTypeWithScalarType( const TensorTypePtr& typePtr, const c10::ScalarType& scalar_type) { + AT_ASSERT(typePtr != nullptr); return typePtr->withScalarType({scalar_type}); } @@ -132,6 +133,15 @@ static c10::optional PromoteScalarTypesWithCategory( static c10::optional InferExpectedScalarType(const Node* n) { std::vector typesFromTensors; std::vector typesFromScalars; + + auto get_scalar_type = + [](const Value* input) -> c10::optional { + if (auto tensor_type = input->type()->cast()) { + return tensor_type->scalarType(); + } + return c10::nullopt; + }; + std::for_each( n->inputs().begin(), n->inputs().end(), [&](const Value* input) { auto nkind = input->node()->kind(); @@ -180,16 +190,13 @@ static c10::optional InferExpectedScalarType(const Node* n) { } else { typesFromTensors.emplace_back(scalar_type); } - } else if ( - auto scalar_type = - input->type()->cast()->scalarType()) { + } else if (auto scalar_type = get_scalar_type(input)) { typesFromTensors.emplace_back(*scalar_type); } }); c10::optional st = c10::nullopt; - const c10::optional output_st = - n->output()->type()->cast()->scalarType(); + const auto output_st = get_scalar_type(n->output()); if (IsComparisonOp(n->kind())) { // For comparison ops, always promote scalar type to highest among inputs, @@ -236,7 +243,8 @@ static void UpdateScalarTypeForInputs( for (auto input : n->inputs()) { auto input_tensor_type = input->type()->cast(); - auto input_scalar_type = input_tensor_type->scalarType(); + auto input_scalar_type = + input_tensor_type ? input_tensor_type->scalarType() : c10::nullopt; if ((input->node()->kind() == onnx::Constant) || (input_scalar_type && (*input_scalar_type != scalar_type))) { From a4f30d48d85b934253bef011d01af7dda3468750 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Fri, 8 Jan 2021 10:54:58 -0800 Subject: [PATCH 44/54] Clean up some type annotations in test/jit (#50158) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50158 Upgrades type annotations from Python2 to Python3 Test Plan: Sandcastle tests Reviewed By: xush6528 Differential Revision: D25717504 fbshipit-source-id: 9a83c44db02ec79f353862255732873f6d7f885e --- test/jit/test_async.py | 6 +- test/jit/test_builtins.py | 14 +- test/jit/test_list_dict.py | 252 ++++++++++-------------------- test/jit/test_module_interface.py | 178 +++++++-------------- test/jit/test_recursive_script.py | 12 +- test/jit/test_save_load.py | 12 +- test/jit/test_tracer.py | 14 +- test/jit/test_with.py | 98 ++++-------- 8 files changed, 197 insertions(+), 389 deletions(-) diff --git a/test/jit/test_async.py b/test/jit/test_async.py index 7a70a4c5a655..37fabcfc438a 100644 --- a/test/jit/test_async.py +++ b/test/jit/test_async.py @@ -41,8 +41,7 @@ def foo(inp): def test_async_parsing(self): @torch.jit.script - def foo(x): - # type: (Tensor) -> List[Tensor] + def foo(x: Tensor) -> List[Tensor]: return [torch.neg(x), x.t()] @torch.jit.script @@ -257,8 +256,7 @@ def __init__(self): self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True) @torch.jit.script_method - def forward(self, x): - # type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor] + def forward(self, x: Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]: future1 = torch.jit._fork(self.traced, x) future2 = torch.jit._fork(torch.neg, x) diff --git a/test/jit/test_builtins.py b/test/jit/test_builtins.py index 04991f72c352..b5a0dd8599a6 100644 --- a/test/jit/test_builtins.py +++ b/test/jit/test_builtins.py @@ -2,7 +2,7 @@ import sys import inspect import unittest -from typing import List +from typing import Dict, List import torch @@ -78,8 +78,7 @@ def forward(self, name): torch.jit.script(Mod()) def test_del(self): - def fn(x): - # type: (List[int]) -> List[int] + def fn(x: List[int]) -> List[int]: a = x * 2 del a return x @@ -109,16 +108,14 @@ def fn(x): return a def test_del_multiple_operands(self): - def fn(x): - # type: (List[int]) -> List[int] + def fn(x: List[int]) -> List[int]: a, b, c = x[0], x[1], x[2] del a, b, c return x self.checkScript(fn, ([1, 2, 3],)) - def del_list_multiple_operands(x): - # type: (List[int]) -> List[int] + def del_list_multiple_operands(x: List[int]) -> List[int]: del x[0], x[1] return x @@ -126,8 +123,7 @@ def del_list_multiple_operands(x): jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2]) self.assertEquals(py_out, jit_out) - def del_dict_multiple_operands(x): - # type: (Dict[str, int]) -> Dict[str, int] + def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]: del x['hi'], x['there'] return x diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 2bc24a57751d..63d3aa3938e9 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -20,22 +20,19 @@ class TestList(JitTestCase): def test_in_check(self): - def int_in(x): - # type: (List[int]) -> bool + def int_in(x: List[int]) -> bool: return 2 in x self.checkScript(int_in, ([1, 2, 3],)) self.checkScript(int_in, ([1, 3, 3],)) - def float_in(x): - # type: (List[float]) -> bool + def float_in(x: List[float]) -> bool: return 2. in x self.checkScript(float_in, ([1., 2., 3.],)) self.checkScript(float_in, ([1., 3., 3.],)) - def str_in(x): - # type: (List[str]) -> bool + def str_in(x: List[str]) -> bool: return 'hi' in x self.checkScript(str_in, (['not', 'here'],)) @@ -100,8 +97,7 @@ def test_del(self): def inputs(): return [1, 2, 3, 4] - def fn(x): - # type: (List[int]) -> List[int] + def fn(x: List[int]) -> List[int]: del x[1] return x @@ -114,8 +110,7 @@ def fn(x): self.assertEqual(torch.jit.script(fn)(inputs()), python_out) @torch.jit.script - def fn2(x): - # type: (List[int]) -> List[int] + def fn2(x: List[int]) -> List[int]: del x[100] return x @@ -124,8 +119,7 @@ def fn2(x): with self.assertRaisesRegex(RuntimeError, "deletion at a single index"): @torch.jit.script - def fn(x): - # type: (List[int]) -> List[int] + def fn(x: List[int]) -> List[int]: del x[1:3] return x @@ -149,23 +143,19 @@ def foo3(): FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph) def test_min_bool_list(self): - def jit_min_list(a, b): - # type: (List[bool], List[bool]) -> List[bool] + def jit_min_list(a: List[bool], b: List[bool]) -> List[bool]: return min(a, b) self.checkScript(jit_min_list, ([True, False], [False, True])) def test_min_max_list(self): - def jit_min_list(a, b): - # type: (List[int], List[int]) -> List[int] + def jit_min_list(a: List[int], b: List[int]) -> List[int]: return min(a, b) - def jit_min_list_float(a, b): - # type: (List[float], List[float]) -> List[float] + def jit_min_list_float(a: List[float], b: List[float]) -> List[float]: return min(a, b) - def jit_min_list_bool(a, b): - # type: (List[bool], List[bool]) -> List[bool] + def jit_min_list_bool(a: List[bool], b: List[bool]) -> List[bool]: return min(a, b) def run_tests(func, a, b): @@ -186,16 +176,13 @@ def run_tests(func, a, b): [False, True], [False, False, True], [False, False, False]] run_tests(jit_min_list_bool, args_left_bool, args_right_bool) - def jit_max_list(a, b): - # type: (List[int], List[int]) -> List[int] + def jit_max_list(a: List[int], b: List[int]) -> List[int]: return max(a, b) - def jit_max_list_float(a, b): - # type: (List[float], List[float]) -> List[float] + def jit_max_list_float(a: List[float], b: List[float]) -> List[float]: return max(a, b) - def jit_max_list_bool(a, b): - # type: (List[bool], List[bool]) -> List[bool] + def jit_max_list_bool(a: List[bool], b: List[bool]) -> List[bool]: return max(a, b) args_left_int = [[1, 8, 8], [8, 1, 1], [], [1], [], [1, 2]] @@ -365,8 +352,7 @@ def func(): t2 = scope['func']() self.assertEqual(t1, t2) - def test_fail(x): - # type: (List[Tensor]) -> List[Tensor] + def test_fail(x: List[Tensor]) -> List[Tensor]: x.sort() return x @@ -472,8 +458,7 @@ def test_append(): self.checkScript(test_append, ()) def test_comprehensions_basic(self): - def comp(l): - # type: (List[int]) -> List[int] + def comp(l: List[int]) -> List[int]: n = [x * 3 for x in l] return n @@ -482,8 +467,7 @@ def comp(l): self.checkScript(comp, ([1, 2, 3],)) def test_comprehensions_basic_float(self): - def comp(l): - # type: (List[float]) -> List[float] + def comp(l: List[float]) -> List[float]: n = [x * 3 for x in l] return n @@ -492,8 +476,7 @@ def comp(l): def test_comprehensions_two_comps(self): @torch.jit.script - def comp(l1, l2): - # type: (List[int], List[int]) -> List[int] + def comp(l1: List[int], l2: List[int]) -> List[int]: n = [x * 3 for x in l1] n2 = [x + 2 for x in l2] @@ -502,8 +485,7 @@ def comp(l1, l2): self.assertEqual(comp([1, 2, 3], [4, 5]), [3, 6, 9, 6, 7]) def test_comprehension_out_type_not_in_type(self): - def list_cast(): - # type: () -> int + def list_cast() -> int: li = [int(i) for i in [torch.tensor(0), torch.tensor(1), torch.tensor(2)]] return li[0] + li[1] + li[2] @@ -513,15 +495,13 @@ def test_comprehension_iterable(self): def test_func(fn, inputs): self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs)) - def foo(names, results): - # type: (List[int], List[int]) -> List[Tuple[int, int]] + def foo(names: List[int], results: List[int]) -> List[Tuple[int, int]]: return [(k + 5, v - 2) for k, v in zip(names, results)] test_func(foo, ([1, 2, 4], [4, 7, 9])) test_func(foo, ([5], [4, 7, 9])) - def fn(x): - # type: (int) -> List[int] + def fn(x: int) -> List[int]: return [i for i in range(x)] # noqa: C416 test_func(fn, (9,)) @@ -601,8 +581,7 @@ def test_nested_loop(): def test_mutable_list_function_inline(self): @torch.jit.script - def bar(y): - # type: (List[int]) -> None + def bar(y: List[int]) -> None: y.append(4) @torch.jit.script @@ -888,8 +867,7 @@ def test_list_remove2(): def test_extend_list_mutable(self): @torch.jit.script - def extend_list(a, b): - # type: (List[Tensor], List[Tensor]) -> List[Tensor] + def extend_list(a: List[Tensor], b: List[Tensor]) -> List[Tensor]: a.extend(b) return a @@ -900,8 +878,7 @@ def extend_list(a, b): def test_extend_list_immutable(self): @torch.jit.script - def extend_list(a, b): - # type: (List[int], List[int]) -> List[int] + def extend_list(a: List[int], b: List[int]) -> List[int]: a.extend(b) return a @@ -912,8 +889,7 @@ def extend_list(a, b): def test_copy_list_mutable(self): @torch.jit.script - def copy_list(a): - # type: (List[Tensor]) -> List[Tensor] + def copy_list(a: List[Tensor]) -> List[Tensor]: return a.copy() for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]: @@ -921,36 +897,29 @@ def copy_list(a): def test_copy_list_immutable(self): @torch.jit.script - def copy_list(a): - # type: (List[int]) -> List[int] + def copy_list(a: List[int]) -> List[int]: return a.copy() for l in [[], [1], [1, 2, 3]]: self.assertEqual(copy_list(l), l) def test_min_max_single_list(self): - def min_intlist(li): - # type: (List[int]) -> int + def min_intlist(li: List[int]) -> int: return min(li) - def max_intlist(li): - # type: (List[int]) -> int + def max_intlist(li: List[int]) -> int: return max(li) - def min_boollist(li): - # type: (List[bool]) -> bool + def min_boollist(li: List[bool]) -> bool: return min(li) - def max_boollist(li): - # type: (List[bool]) -> bool + def max_boollist(li: List[bool]) -> bool: return max(li) - def min_floatlist(li): - # type: (List[float]) -> float + def min_floatlist(li: List[float]) -> float: return min(li) - def max_floatlist(li): - # type: (List[float]) -> float + def max_floatlist(li: List[float]) -> float: return max(li) @@ -980,23 +949,19 @@ def test_to_list(self): """ Boolean dtype unit tests. """ - def to_list_bool_0D(x): - # type: (torch.Tensor) -> bool + def to_list_bool_0D(x: torch.Tensor) -> bool: li = torch.jit.annotate(bool, x.tolist()) return li - def to_list_bool_1D(x): - # type: (torch.Tensor) -> List[bool] + def to_list_bool_1D(x: torch.Tensor) -> List[bool]: li = torch.jit.annotate(List[bool], x.tolist()) return li - def to_list_bool_2D(x): - # type: (torch.Tensor) -> List[List[bool]] + def to_list_bool_2D(x: torch.Tensor) -> List[List[bool]]: li = torch.jit.annotate(List[List[bool]], x.tolist()) return li - def to_list_bool_3D(x): - # type: (torch.Tensor) -> List[List[List[bool]]] + def to_list_bool_3D(x: torch.Tensor) -> List[List[List[bool]]]: li = torch.jit.annotate(List[List[List[bool]]], x.tolist()) return li @@ -1021,23 +986,19 @@ def to_list_bool_3D(x): """ Int dtype unit tests. """ - def to_list_int_0D(x): - # type: (torch.Tensor) -> int + def to_list_int_0D(x: torch.Tensor) -> int: li = torch.jit.annotate(int, x.tolist()) return li - def to_list_int_1D(x): - # type: (torch.Tensor) -> List[int] + def to_list_int_1D(x: torch.Tensor) -> List[int]: li = torch.jit.annotate(List[int], x.tolist()) return li - def to_list_int_2D(x): - # type: (torch.Tensor) -> List[List[int]] + def to_list_int_2D(x: torch.Tensor) -> List[List[int]]: li = torch.jit.annotate(List[List[int]], x.tolist()) return li - def to_list_int_3D(x): - # type: (torch.Tensor) -> List[List[List[int]]] + def to_list_int_3D(x: torch.Tensor) -> List[List[List[int]]]: li = torch.jit.annotate(List[List[List[int]]], x.tolist()) return li @@ -1058,23 +1019,19 @@ def to_list_int_3D(x): """ Float dtype unit tests. """ - def to_list_float_0D(x): - # type: (torch.Tensor) -> float + def to_list_float_0D(x: torch.Tensor) -> float: li = torch.jit.annotate(float, x.tolist()) return li - def to_list_float_1D(x): - # type: (torch.Tensor) -> List[float] + def to_list_float_1D(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(List[float], x.tolist()) return li - def to_list_float_2D(x): - # type: (torch.Tensor) -> List[List[float]] + def to_list_float_2D(x: torch.Tensor) -> List[List[float]]: li = torch.jit.annotate(List[List[float]], x.tolist()) return li - def to_list_float_3D(x): - # type: (torch.Tensor) -> List[List[List[float]]] + def to_list_float_3D(x: torch.Tensor) -> List[List[List[float]]]: li = torch.jit.annotate(List[List[List[float]]], x.tolist()) return li @@ -1099,28 +1056,23 @@ def to_list_float_3D(x): - type annotation with the wrong dimension - type annotation with scalar type that doesn't match the input scalar type """ - def to_list_missing_type_annotation(x): - # type: (torch.Tensor) -> List[float] + def to_list_missing_type_annotation(x: torch.Tensor) -> List[float]: li = x.tolist() return li - def to_list_incorrect_type_annotation(x): - # type: (torch.Tensor) -> List[float] + def to_list_incorrect_type_annotation(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(float, x.tolist()) return li - def to_list_unsupported_type_annotation(x): - # type: (torch.Tensor) -> List[float] + def to_list_unsupported_type_annotation(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(List[str], x.tolist()) return li - def to_list_type_annotation_wrong_dim(x): - # type: (torch.Tensor) -> List[List[float]] + def to_list_type_annotation_wrong_dim(x: torch.Tensor) -> List[List[float]]: li = torch.jit.annotate(List[List[float]], x.tolist()) return li - def to_list_type_annotation_incorrect_scalar_type(x): - # type: (torch.Tensor) -> List[float] + def to_list_type_annotation_incorrect_scalar_type(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(List[float], x.tolist()) return li @@ -1164,18 +1116,15 @@ def test_to_list_gpu(self): if not torch.cuda.is_available() or torch.cuda.device_count() == 0: self.skipTest("CUDA is not available") - def to_list_bool_1D(x): - # type: (torch.Tensor) -> List[bool] + def to_list_bool_1D(x: torch.Tensor) -> List[bool]: li = torch.jit.annotate(List[bool], x.tolist()) return li - def to_list_int_1D(x): - # type: (torch.Tensor) -> List[int] + def to_list_int_1D(x: torch.Tensor) -> List[int]: li = torch.jit.annotate(List[int], x.tolist()) return li - def to_list_float_1D(x): - # type: (torch.Tensor) -> List[float] + def to_list_float_1D(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(List[float], x.tolist()) return li @@ -1187,8 +1136,7 @@ def to_list_float_1D(x): 5, dtype=torch.double).cuda(),)) def test_no_element_type_annotation(self): - def fn_with_comment(x): - # type: (torch.Tensor) -> List + def fn_with_comment(x: torch.Tensor) -> List: a: List = x.tolist() return a @@ -1230,8 +1178,7 @@ def test_del(self): def inputs(): return {'hi': 2, 'bye': 3} - def fn(x): - # type: (Dict[str, int]) -> Dict[str, int] + def fn(x: Dict[str, int]) -> Dict[str, int]: del x['hi'] return x @@ -1247,8 +1194,7 @@ def fn(x): def test_keys(self): @torch.jit.script - def keys(x): - # type: (Dict[str, Tensor]) -> List[str] + def keys(x: Dict[str, Tensor]) -> List[str]: return list(x.keys()) self.assertEqual(set(keys(self.dict())), set(self.dict().keys())) @@ -1263,30 +1209,26 @@ def specialized_list(): def test_values(self): @torch.jit.script - def values(x): - # type: (Dict[str, Tensor]) -> List[Tensor] + def values(x: Dict[str, Tensor]) -> List[Tensor]: return list(x.values()) the_dict = self.dict() self.assertEqual(set(values(the_dict)), set(the_dict.values())) def test_len(self): - def length(x): - # type: (Dict[str, Tensor]) -> int + def length(x: Dict[str, Tensor]) -> int: return len(x) self.checkScript(length, (self.dict(),)) def test_copy(self): - def func(x): - # type: (Dict[str, Tensor]) -> Dict[str, Tensor] + def func(x: Dict[str, Tensor]) -> Dict[str, Tensor]: return x.copy() self.checkScript(func, (self.dict(),)) def test_items(self): - def func(x): - # type: (Dict[str, Tensor]) -> List[Tuple[str, Tensor]] + def func(x: Dict[str, Tensor]) -> List[Tuple[str, Tensor]]: return x.items() # The value returned by Python is in arbitrary order, so we can't use @@ -1301,8 +1243,7 @@ def func(x): self.assertTrue(item in script_out) def test_pop(self): - def pop(x, key): - # type: (Dict[str, Tensor], str) -> Tuple[Tensor, Dict[str, Tensor]] + def pop(x: Dict[str, Tensor], key: str) -> Tuple[Tensor, Dict[str, Tensor]]: return x.pop(key), x # checkScript doesn't copy the inputs, so we can't use it since this mutates @@ -1318,16 +1259,14 @@ def tester(fn, *args): torch.jit.script(pop)(self.dict(), 'x') - def default_pop(x, key, default): - # type: (Dict[str, Tensor], str, Tensor) -> Tuple[Tensor, Dict[str, Tensor]] + def default_pop(x: Dict[str, Tensor], key: str, default: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: return x.pop(key, default), x tester(default_pop, 'a', torch.randn(2, 2)) tester(default_pop, 'x', torch.randn(2, 2)) def test_setdefault(self): - def setdefault(x, key, default): - # type: (Dict[str, Tensor], str, Tensor) -> Dict[str, Tensor] + def setdefault(x: Dict[str, Tensor], key: str, default: Tensor) -> Dict[str, Tensor]: x.setdefault(key, default) return x @@ -1335,8 +1274,7 @@ def setdefault(x, key, default): self.checkScript(setdefault, (self.dict(), 'nonexistant', torch.randn(2, 2))) def test_update(self): - def update(a, b): - # type: (Dict[str, Tensor], Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]] + def update(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]: a.update(b) return a, b @@ -1353,8 +1291,7 @@ def foo() -> Dict[str, int]: self.checkScript(foo, ()) def test_aug_assign(self): - def aug_assign_dict_tensor(a): - # type: (Dict[str, Tensor]) -> Dict[str, Tensor] + def aug_assign_dict_tensor(a: Dict[str, Tensor]) -> Dict[str, Tensor]: a['a'] += 1 a['b'] -= 12 a['c'] *= 122 @@ -1362,8 +1299,7 @@ def aug_assign_dict_tensor(a): a['c'] %= 2 return a - def aug_assign_dict_prim(a): - # type: (Dict[str, float]) -> Dict[str, float] + def aug_assign_dict_prim(a: Dict[str, float]) -> Dict[str, float]: a['a'] += 3.4 a['b'] -= 2.4 a['c'] *= 3.0 @@ -1376,8 +1312,7 @@ def aug_assign_dict_prim(a): def test_popitem(self): @torch.jit.script - def popitem(x): - # type: (Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]] + def popitem(x: Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]: item = x.popitem() return item, x @@ -1395,65 +1330,56 @@ def popitem(x): self.assertTrue(isinstance(script_out[0][1], torch.Tensor)) def test_clear(self): - def clear(x): - # type: (Dict[str, Tensor]) -> Dict[str, Tensor] + def clear(x: Dict[str, Tensor]) -> Dict[str, Tensor]: x.clear() return x self.checkScript(clear, (self.dict(),)) def test_get(self): - def get(x, key): - # type: (Dict[str, Tensor], str) -> Optional[Tensor] + def get(x: Dict[str, Tensor], key: str) -> Optional[Tensor]: return x.get(key) self.checkScript(get, (self.dict(), 'a')) self.checkScript(get, (self.dict(), "doesn't exist")) - def get_default(x, key): - # type: (Dict[str, Tensor], str) -> Optional[Tensor] + def get_default(x: Dict[str, Tensor], key: str) -> Optional[Tensor]: return x.get(key, torch.randn(2, 2)) self.checkScript(get, (self.dict(), 'a')) self.checkScript(get, (self.dict(), "doesn't exist")) def test_get_boolkey(self): - def get(x, key): - # type: (Dict[bool, int], bool) -> Optional[int] + def get(x: Dict[bool, int], key: bool) -> Optional[int]: return x.get(key) self.checkScript(get, (self.dict_bool(), True)) self.checkScript(get, (self.dict_bool(), False)) - def get_default(x, key): - # type: (Dict[bool, int], bool) -> int + def get_default(x: Dict[bool, int], key: bool) -> int: return x.get(key, 42) self.checkScript(get_default, (self.dict_bool(), True)) self.checkScript(get_default, (self.dict_bool(), False)) def test_basic(self): - def simple(x): - # type: (Dict[str, int]) -> Dict[str, int] + def simple(x: Dict[str, int]) -> Dict[str, int]: return x self.checkScript(simple, ({'item': 20, 'other_item': 120},)) - def index(x): - # type: (Dict[str, int]) -> int + def index(x: Dict[str, int]) -> int: return x['item'] self.checkScript(index, ({'item': 20, 'other_item': 120},)) - def type_default(): - # type: () -> Dict[str, Tensor] + def type_default() -> Dict[str, Tensor]: return {} self.checkScript(type_default, ()) @torch.jit.script - def missing_index(x): - # type: (Dict[str, int]) -> int + def missing_index(x: Dict[str, int]) -> int: return x['dne'] with self.assertRaisesRegex(RuntimeError, "KeyError"): @@ -1475,16 +1401,14 @@ def literal3(): ''')) self.assertEqual({10: 1.2, 11: 1.3}, cu.literal3()) - def list_of_dicts(): - # type: () -> List[Dict[str, Tensor]] + def list_of_dicts() -> List[Dict[str, Tensor]]: return [{'word': torch.ones(2) + 3}, {'other word': torch.ones(1) + 2}] self.checkScript(list_of_dicts, ()) def test_mutability(self): @torch.jit.script - def fn(): - # type: () -> Dict[str, int] + def fn() -> Dict[str, int]: a = torch.jit.annotate(Dict[str, int], {}) a['ok'] = 10 return a @@ -1494,14 +1418,12 @@ def fn(): def test_key_type(self): with self.assertRaisesRegex(RuntimeError, "but instead found type"): @torch.jit.script - def fn(a): - # type: (Dict[str, int]) -> int + def fn(a: Dict[str, int]) -> int: return a[None] def test_loop(self): @torch.jit.script - def fn(x): - # type: (int) -> Dict[str, int] + def fn(x: int) -> Dict[str, int]: a = torch.jit.annotate(Dict[str, int], {}) for i in range(x): a['ok'] = i @@ -1520,16 +1442,14 @@ def fn(x, y): self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3))) def test_membership(self): - def fn(x, y): - # type: (Dict[int, int], int) -> int + def fn(x: Dict[int, int], y: int) -> int: return x.get(y, 3) d = {1: 2, 3: 4} self.checkScript(fn, (d, 3)) self.checkScript(fn, (d, 2)) - def optional(x, y): - # type: (Dict[int, int], int) -> bool + def optional(x: Dict[int, int], y: int) -> bool: res = x.get(y) return res is None @@ -1538,18 +1458,15 @@ def optional(x, y): with self.assertRaisesRegex(RuntimeError, "is actually of type Optional"): @torch.jit.script - def bad_types(x, y): - # type: (Dict[int, int], int) -> int + def bad_types(x: Dict[int, int], y: int) -> int: return x.get(y) # noqa: T484 def test_dict_to_python(self): @torch.jit.ignore - def python_lookup(my_dict, keys): - # type: (Dict[str, int], List[str]) -> List[int] + def python_lookup(my_dict: Dict[str, int], keys: List[str]) -> List[int]: return [my_dict[k] for k in keys] - def fn(my_dict, keys): - # type: (Dict[str, int], List[str]) -> List[int] + def fn(my_dict: Dict[str, int], keys: List[str]) -> List[int]: return python_lookup(my_dict, keys) a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2} @@ -1601,8 +1518,7 @@ def test_type_annotation_missing_contained_type(self): key and value types produces an error. """ # This function uses a type comment. - def fn_with_comment(input): - # type: (Dict) -> Any + def fn_with_comment(input: Dict) -> Any: return input # This function uses Python3 style type annotations. diff --git a/test/jit/test_module_interface.py b/test/jit/test_module_interface.py index 70071cc56028..d32537eb2ac5 100644 --- a/test/jit/test_module_interface.py +++ b/test/jit/test_module_interface.py @@ -6,6 +6,7 @@ import torch.nn as nn import os import sys +from torch import Tensor from torch.testing._internal.jit_utils import JitTestCase # Make the helper files in test/ importable @@ -22,36 +23,30 @@ class OrigModule(nn.Module): def __init__(self): super(OrigModule, self).__init__() - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 + inp2 + 1 - def two(self, input): - # type: (Tensor) -> Tensor + def two(self, input: Tensor) -> Tensor: return input + 2 - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return input + self.one(input, input) + 1 class NewModule(nn.Module): def __init__(self): super(NewModule, self).__init__() - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 * inp2 + 1 - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.one(input, input + 1) class TestModuleInterface(JitTestCase): def test_not_submodule_interface_call(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass class TestNotModuleInterfaceCall(nn.Module): @@ -61,8 +56,7 @@ def __init__(self): super(TestNotModuleInterfaceCall, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.two(input) with self.assertRaisesRegex(RuntimeError, "Tried to access nonexistent attribute"): @@ -72,64 +66,51 @@ def test_module_interface(self): global OneTwoModule, OneTwoClass @torch.jit.interface class OneTwoModule(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: pass - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: pass - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass @torch.jit.interface class OneTwoClass(object): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: pass - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: pass class FooMod(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: return x + y - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: return 2 * x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: return self.one(self.two(x), x) class BarMod(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: return x * y - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: return 2 / x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: return self.two(self.one(x, x)) @torch.jit.export - def forward2(self, x): - # type: (Tensor) -> Tensor + def forward2(self, x: Tensor) -> Tensor: return self.two(self.one(x, x)) + 1 def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor): return mod_list[0].forward(x) + mod_list[1].forward(x) - def use_class_interface(mod_list, x): - # type: (List[OneTwoClass], Tensor) -> Tensor + def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor: return mod_list[0].two(x) + mod_list[1].one(x, x) scripted_foo_mod = torch.jit.script(FooMod()) @@ -139,8 +120,7 @@ def use_class_interface(mod_list, x): self.checkScript(use_class_interface, ([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),)) - def call_module_interface_on_other_method(mod_interface, x): - # type: (OneTwoModule, Tensor) -> Tensor + def call_module_interface_on_other_method(mod_interface: OneTwoModule, x: Tensor) -> Tensor: return mod_interface.forward2(x) # ensure error out when we call the module on the method other than the interface specified. @@ -152,35 +132,28 @@ def test_module_interface_subtype(self): global OneTwoModule @torch.jit.interface class OneTwoModule(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: pass - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: pass - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass @torch.jit.script - def as_module_interface(x): - # type: (OneTwoModule) -> OneTwoModule + def as_module_interface(x: OneTwoModule) -> OneTwoModule: return x @torch.jit.script class Foo(object): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: return x + y - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: return 2 * x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: return self.one(self.two(x), x) # check class object is not a subtype of module interface @@ -188,12 +161,10 @@ def forward(self, x): as_module_interface(Foo()) class WrongMod(nn.Module): - def two(self, x): - # type: (int) -> int + def two(self, x: int) -> int: return 2 * x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: return x + torch.randn(3, self.two(3)) scripted_wrong_mod = torch.jit.script(WrongMod()) @@ -244,19 +215,16 @@ def test_module_interface_inheritance(self): with self.assertRaisesRegex(RuntimeError, "does not support inheritance yet. Please directly"): @torch.jit.interface class InheritMod(nn.ReLU): - def three(self, x): - # type: (Tensor) -> Tensor + def three(self, x: Tensor) -> Tensor: return 3 * x def test_module_swap(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: pass class TestModule(nn.Module): @@ -266,8 +234,7 @@ def __init__(self): super(TestModule, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.forward(input) scripted_mod = torch.jit.script(TestModule()) @@ -285,20 +252,17 @@ def forward(self, input): def test_module_swap_wrong_module(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: pass class NewModuleWrong(nn.Module): def __init__(self): super(NewModuleWrong, self).__init__() - def forward(self, input): - # type: (int) -> int + def forward(self, input: int) -> int: return input + 1 class TestModule(nn.Module): @@ -308,8 +272,7 @@ def __init__(self): super(TestModule, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.forward(input) scripted_mod = torch.jit.script(TestModule()) @@ -320,12 +283,10 @@ def forward(self, input): def test_module_swap_no_lazy_compile(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: pass class TestModule(nn.Module): @@ -335,20 +296,17 @@ def __init__(self): super(TestModule, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.forward(input) class NewModuleMethodNotLazyCompile(nn.Module): def __init__(self): super(NewModuleMethodNotLazyCompile, self).__init__() - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 * inp2 + 1 - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return input + 1 scripted_mod = torch.jit.script(TestModule()) @@ -362,12 +320,10 @@ def __init__(self): super(NewModuleMethodManualExport, self).__init__() @torch.jit.export - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 * inp2 + 1 - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return input + 1 scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport()) @@ -381,8 +337,7 @@ def __init__(self): super(TestNoModuleInterface, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod(input) scripted_no_module_interface = torch.jit.script(TestNoModuleInterface()) @@ -397,12 +352,10 @@ def forward(self, input): def test_script_module_as_interface_swap(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: pass class OrigScriptModule(torch.jit.ScriptModule): @@ -410,13 +363,11 @@ def __init__(self): super(OrigScriptModule, self).__init__() @torch.jit.script_method - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 + inp2 + 1 @torch.jit.script_method - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return input + self.one(input, input) + 1 class NewScriptModule(torch.jit.ScriptModule): @@ -424,13 +375,11 @@ def __init__(self): super(NewScriptModule, self).__init__() @torch.jit.script_method - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 * inp2 + 1 @torch.jit.script_method - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.one(input, input + 1) class TestNNModuleWithScriptModule(nn.Module): @@ -440,8 +389,7 @@ def __init__(self): super(TestNNModuleWithScriptModule, self).__init__() self.proxy_mod = OrigScriptModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.forward(input) input = torch.randn(3, 4) @@ -472,8 +420,7 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> int + def forward(self, x: Tensor) -> int: pass class TestModule(torch.nn.Module): @@ -520,8 +467,7 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> int + def forward(self, x: Tensor) -> int: pass class TestModule(torch.nn.Module): @@ -564,8 +510,7 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass class TestModule(torch.nn.Module): @@ -610,8 +555,7 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass class TestModule(torch.nn.Module): @@ -653,8 +597,7 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass class TestModule(torch.nn.Module): @@ -688,8 +631,7 @@ def forward(self, x): def test_module_apis_interface(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass class TestModule(nn.Module): diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py index a84d9d7256b8..3532e62bfe04 100644 --- a/test/jit/test_recursive_script.py +++ b/test/jit/test_recursive_script.py @@ -284,8 +284,7 @@ def forward(self, x): test_module_dir(nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))) def test_class_compile(self): - def other_fn(a, b): - # type: (int, Tensor) -> Tensor + def other_fn(a: int, b: Tensor) -> Tensor: return a * b class B(object): @@ -307,8 +306,7 @@ def forward(self, x): self.checkModule(N(), (torch.randn(2, 2),)) def test_error_stack(self): - def d(x): - # type: (int) -> int + def d(x: int) -> int: return x + 10 def c(x): @@ -331,8 +329,7 @@ def a(x): checker.run(str(e)) def test_error_stack_module(self): - def d(x): - # type: (int) -> int + def d(x: int) -> int: return x + 10 def c(x): @@ -565,8 +562,7 @@ def __init__(self): self.a = 4 self.inner = Inner2() - def __setstate__(self, obj): - # type: (Tuple[int, Inner2]) -> None + def __setstate__(self, obj: Tuple[int, Inner2]) -> None: a, inner = obj self.a = a self.inner = inner diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index eb14ca8350af..12d42a5150cd 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -680,8 +680,7 @@ def test_different_interfaces(self): """ @torch.jit.interface class MyInterface(object): - def bar(self, x): - # type: (Tensor) -> Tensor + def bar(self, x: Tensor) -> Tensor: pass @torch.jit.script @@ -711,8 +710,7 @@ def forward(self, x): @torch.jit.interface class MyInterface(object): - def not_bar(self, x): - # type: (Tensor) -> Tensor + def not_bar(self, x: Tensor) -> Tensor: pass @torch.jit.script # noqa: F811 @@ -767,8 +765,7 @@ class MyCoolNamedTuple(NamedTuple): @torch.jit.interface class MyInterface(object): - def bar(self, x): - # type: (Tensor) -> Tensor + def bar(self, x: Tensor) -> Tensor: pass @torch.jit.script @@ -809,8 +806,7 @@ def forward(self, x): @torch.jit.interface class MyInterface(object): - def not_bar(self, x): - # type: (Tensor) -> Tensor + def not_bar(self, x: Tensor) -> Tensor: pass @torch.jit.script # noqa F811 diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 366ca1af69e6..841682666964 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -25,7 +25,7 @@ # Standard library from collections import namedtuple from itertools import chain -from typing import Dict +from typing import Dict, Optional import warnings if __name__ == '__main__': @@ -1862,13 +1862,11 @@ def forward(self, inputs): class TestMixTracingScripting(JitTestCase): def test_trace_script(self): @torch.jit.script - def func1(x): - # type: (Tuple[Tensor, Tensor]) -> Tensor + def func1(x: Tuple[Tensor, Tensor]) -> Tensor: return x[0] + x[1] @torch.jit.script - def func2(x): - # type: (List[Tensor]) -> Tensor + def func2(x: List[Tensor]) -> Tensor: return x[0] + x[1] a = torch.randn(5) @@ -1878,8 +1876,7 @@ def func2(x): self.checkTrace(func2, ((a, b),)) @torch.jit.script - def func3(x, method='bilinear', align_corners=True): - # type: (Tensor, str, bool) -> Tensor + def func3(x: Tensor, method: str = 'bilinear', align_corners: bool = True) -> Tensor: hw = x.shape[2:4] return F.interpolate(x, hw, mode=method, align_corners=align_corners) @@ -1887,8 +1884,7 @@ def func3(x, method='bilinear', align_corners=True): self.checkTrace(func3, (inp,)) @torch.jit.script - def func4(x, a): - # type: (Tensor, List[Optional[str]]) -> Tensor + def func4(x: Tensor, a: List[Optional[str]]) -> Tensor: if len(a) == 2: return x + 2 else: diff --git a/test/jit/test_with.py b/test/jit/test_with.py index ffd0631639f6..f958dc46c39a 100644 --- a/test/jit/test_with.py +++ b/test/jit/test_with.py @@ -1,7 +1,7 @@ import os import sys -from typing import Any +from typing import Any, List import torch from torch.testing._internal.jit_utils import JitTestCase @@ -50,8 +50,7 @@ def __enter__(self): def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) - def test_basic(x): - # type: (Tensor) -> Tensor + def test_basic(x: torch.Tensor) -> torch.Tensor: """Basic test with one with-statement.""" c = Context(1) @@ -62,8 +61,7 @@ def test_basic(x): y *= c.count return y - def test_pass(x): - # type: (Tensor) -> Tensor + def test_pass(x: torch.Tensor) -> torch.Tensor: """ Test with a pass statement inside a with-statement. Although the body of the with is empty, __enter__ and __exit__ should @@ -77,8 +75,7 @@ def test_pass(x): x *= c.count return x - def test_early_return(x, c): - # type: (Tensor, Context) -> Tensor + def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that returning early from inside a with-statement works as expected. @@ -90,8 +87,7 @@ def test_early_return(x, c): x = y + y return x - def test_conditional_early_return(x, c): - # type: (Tensor, Context) -> Tensor + def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that conditionally returning early from inside a with-statement works as expected. @@ -104,8 +100,7 @@ def test_conditional_early_return(x, c): x = y + y return x - def test_break(x, c, l): - # type: (Tensor, Context, List[int]) -> Tensor + def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that breaking early from inside a with-statement works as expected. @@ -118,8 +113,7 @@ def test_break(x, c, l): return x - def test_continue(x, c, l): - # type: (Tensor, Context, List[int]) -> Tensor + def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that using continue inside a with-statement works as expected. @@ -132,8 +126,7 @@ def test_continue(x, c, l): return x - def test_serial(x): - # type: (Tensor) -> Tensor + def test_serial(x: torch.Tensor) -> torch.Tensor: """ Test two with-statements in a row. """ @@ -147,8 +140,7 @@ def test_serial(x): return y - def test_nested(x): - # type: (Tensor) -> Tensor + def test_nested(x: torch.Tensor) -> torch.Tensor: """ Test nested with-statements. """ @@ -162,8 +154,7 @@ def test_nested(x): return y - def test_combined(x): - # type: (Tensor) -> Tensor + def test_combined(x: torch.Tensor) -> torch.Tensor: """ Test a with-statement with multiple with items. """ @@ -215,8 +206,7 @@ def __enter__(self): def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) - def test_basic(x): - # type: (Tensor) -> Tensor + def test_basic(x: torch.Tensor) -> torch.Tensor: """Basic test with one with-statement.""" c = Context(1) @@ -227,8 +217,7 @@ def test_basic(x): y *= c.count return y - def test_pass(x): - # type: (Tensor) -> Tensor + def test_pass(x: torch.Tensor) -> torch.Tensor: """ Test with a pass statement inside a with-statement. Although the body of the with is empty, __enter__ and __exit__ should @@ -242,8 +231,7 @@ def test_pass(x): x *= c.count return x - def test_early_return(x, c): - # type: (Tensor, Context) -> Tensor + def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that returning early from inside a with-statement works as expected. @@ -255,8 +243,7 @@ def test_early_return(x, c): x = y + y return x - def test_conditional_early_return(x, c): - # type: (Tensor, Context) -> Tensor + def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that conditionally returning early from inside a with-statement works as expected. @@ -269,8 +256,7 @@ def test_conditional_early_return(x, c): x = y + y return x - def test_break(x, c, l): - # type: (Tensor, Context, List[int]) -> Tensor + def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that breaking early from inside a with-statement works as expected. @@ -283,8 +269,7 @@ def test_break(x, c, l): return x - def test_continue(x, c, l): - # type: (Tensor, Context, List[int]) -> Tensor + def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that using continue inside a with-statement works as expected. @@ -297,8 +282,7 @@ def test_continue(x, c, l): return x - def test_serial(x): - # type: (Tensor) -> Tensor + def test_serial(x: torch.Tensor) -> torch.Tensor: """ Test two with-statements in a row. """ @@ -312,8 +296,7 @@ def test_serial(x): return y - def test_nested(x): - # type: (Tensor) -> Tensor + def test_nested(x: torch.Tensor) -> torch.Tensor: """ Test nested with-statements. """ @@ -327,8 +310,7 @@ def test_nested(x): return y - def test_combined(x): - # type: (Tensor) -> Tensor + def test_combined(x: torch.Tensor) -> torch.Tensor: """ Test a with-statement with multiple with items. """ @@ -381,13 +363,11 @@ def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) @torch.jit.script - def method_that_raises(): - # type: () -> Tensor + def method_that_raises() -> torch.Tensor: raise Exception("raised exception") @torch.jit.script - def test_exception(x, c): - # type: (Tensor, Context) -> Tensor + def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test the case in which an exception is thrown while executing the body of a with-statement. """ @@ -397,8 +377,7 @@ def test_exception(x, c): return x @torch.jit.script - def test_exception_nested(x, c): - # type: (Tensor, Context) -> Tensor + def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test the case in which an exception is thrown while executing the body of a nested with-statement. """ @@ -409,8 +388,7 @@ def test_exception_nested(x, c): return x @torch.jit.script - def with_that_raises(c): - # type: (Context) -> Tensor + def with_that_raises(c: Context) -> torch.Tensor: a = torch.tensor([1]) with c as _: @@ -419,8 +397,7 @@ def with_that_raises(c): return a @torch.jit.script - def test_exception_fn_call(x, c): - # type: (Tensor, Context) -> Tensor + def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test the case in which an exception is thrown while there are active with-statements in two different frames. @@ -506,29 +483,25 @@ def __enter__(self): def __exit__(self, type: Any, value: int, tb: int): pass - def test_no_enter_no_exit(x, c): - # type: (Tensor, NoEnterNoExit) -> Tensor + def test_no_enter_no_exit(x: torch.Tensor, c: NoEnterNoExit) -> torch.Tensor: with c as _: pass return x - def test_bad_enter(x, c): - # type: (Tensor, BadEnter) -> Tensor + def test_bad_enter(x: torch.Tensor, c: BadEnter) -> torch.Tensor: with c as _: pass return x - def test_bad_exit(x, c): - # type: (Tensor, BadExit) -> Tensor + def test_bad_exit(x: torch.Tensor, c: BadExit) -> torch.Tensor: with c as _: pass return x - def test_exit_incorrect_types(x, c): - # type: (Tensor, ExitIncorrectTypes) -> Tensor + def test_exit_incorrect_types(x: torch.Tensor, c: ExitIncorrectTypes) -> torch.Tensor: with c as _: pass @@ -565,8 +538,7 @@ def test_with_no_grad(self): """ # Basic no_grad test. - def test_no_grad(x, y): - # type: (Tensor, Tensor) -> Tensor + def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.no_grad(): w = x + y @@ -583,8 +555,7 @@ def test_no_grad(x, y): # Test assignment of a grad-less Tensor to a Tensor with gradients # in a no_grad block. - def test_no_grad_assignment(x, y): - # type: (Tensor, Tensor) -> Tensor + def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.no_grad(): x[0] = y @@ -603,13 +574,11 @@ def __init__(self): super().__init__() @torch.jit.ignore - def adder(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: w = x + y return w - def forward(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.no_grad(): w = self.adder(x, y) @@ -625,8 +594,7 @@ def test_with_record_function(self): Check that torch.autograd.profiler.record_function context manager is torchscriptable. """ - def with_rf(x, y): - # type: (Tensor, Tensor) -> Tensor + def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.autograd.profiler.record_function("foo"): # Nested record_function. with torch.autograd.profiler.record_function("nested"): From 5d45140d6874be04c22c8abba55e4438c25d2fdb Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 8 Jan 2021 11:01:52 -0800 Subject: [PATCH 45/54] [numpy] torch.{all/any} : output dtype is always bool (#47878) Summary: BC-breaking note: This PR changes the behavior of the any and all functions to always return a bool tensor. Previously these functions were only defined on bool and uint8 tensors, and when called on uint8 tensors they would also return a uint8 tensor. (When called on a bool tensor they would return a bool tensor.) PR summary: https://github.com/pytorch/pytorch/pull/44790#issuecomment-725596687 Fixes 2 and 3 Also Fixes https://github.com/pytorch/pytorch/issues/48352 Changes * Output dtype is always `bool` (consistent with numpy) **BC Breaking (Previously used to match the input dtype**) * Uses vectorized version for all dtypes on CPU * Enables test for complex * Update doc for `torch.all` and `torch.any` TODO * [x] Update docs * [x] Benchmark * [x] Raise issue on XLA Pull Request resolved: https://github.com/pytorch/pytorch/pull/47878 Reviewed By: albanD Differential Revision: D25714324 Pulled By: mruberry fbshipit-source-id: a87345f725297524242d69402dfe53060521ea5d --- aten/src/ATen/native/ReduceOps.cpp | 108 ++++++++-- aten/src/ATen/native/SharedReduceOps.h | 50 ----- aten/src/ATen/native/cpu/ReduceOpsKernel.cpp | 50 +++-- .../src/ATen/native/cuda/ReduceLogicKernel.cu | 35 ++-- docs/source/tensors.rst | 9 +- docs/source/torch.rst | 2 + test/test_reductions.py | 191 +++++++++--------- torch/_tensor_docs.py | 80 +------- torch/_torch_docs.py | 107 ++++++++++ 9 files changed, 361 insertions(+), 271 deletions(-) diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index e4b0a1cb19b7..fd27b3e7efe5 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -740,6 +740,12 @@ Tensor norm(const Tensor& self, Scalar p) { return at::native::_norm(self, p); } +// Note [all, any : uint8 compatibility]: +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// For NumPy comptability, `all` and `any` return +// Tensor of dtype `bool`. However for compatibility reason, +// for `uint8`, they return Tensor of same dtype `uint8`. +// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561 inline Tensor & _all(Tensor & result, TensorIterator & iter) { if (iter.numel() == 0) { result.fill_(1); @@ -756,14 +762,40 @@ Tensor all(const Tensor& self) { TORCH_CHECK(self.layout() == Layout::Strided, "all only supports strided layout, got: ", self.layout()); - Tensor result = at::empty({0}, self.options()); - auto iter = make_reduction( - "all", result, self, {}, false, self.scalar_type()); + // Refer [all, any : uint8 compatibility] + Tensor result; + ScalarType out_dtype; + if (self.scalar_type() == ScalarType::Byte){ + result = at::empty({0}, self.options()); + out_dtype = self.scalar_type(); + } else { + result = at::empty({0}, self.options().dtype(kBool)); + out_dtype = ScalarType::Bool; + } + + if (self.is_cuda()) { + // As CUDA supports dynamic type casting, we use this overload of + // `make_reduction`, which doesn't cast input to the result type i.e. kBool., + // otherwise we use the overload below which casts the input to kBool (which is + // an extra operation). + auto iter = make_reduction( + "all", result, self, {}, false, self.scalar_type(), out_dtype); + return _all(result, iter); + } + auto iter = + make_reduction("all", result, self, {}, false, /*out_dtype=*/out_dtype); return _all(result, iter); } Tensor all(const Tensor& self, int64_t dim, bool keepdim) { - Tensor result = at::empty({0}, self.options()); + // Refer [all, any : uint8 compatibility] + Tensor result; + if (self.scalar_type() == ScalarType::Byte){ + result = at::empty({0}, self.options()); + } else { + result = at::empty({0}, self.options().dtype(kBool)); + } + return at::native::all_out(result, self, dim, keepdim); } @@ -772,13 +804,26 @@ Tensor &all_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) { "all only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "all only supports strided layout, got: ", self.layout()); + // Refer [all, any : uint8 compatibility] + TORCH_CHECK(result.scalar_type() == ScalarType::Bool || result.scalar_type() == ScalarType::Byte, + "all only supports bool tensor for result, got: ", result.scalar_type()); + auto out_dtype = result.scalar_type(); dim = maybe_wrap_dim(dim, self.dim()); if (_dimreduce_return_trivial(result, self, 1, dim, keepdim)) { return result; } else { - auto iter = make_reduction( - "all", result, self, dim, keepdim, self.scalar_type()); + if (self.is_cuda()) { + // As CUDA supports dynamic type casting, we use this overload of + // `make_reduction`, which doesn't cast input to the result type i.e. kBool., + // otherwise we use the overload below which casts the input to kBool (which is + // an extra operation). + auto iter = make_reduction( + "all", result, self, dim, keepdim, self.scalar_type(), out_dtype); + return _all(result, iter); + } + auto iter = + make_reduction("all", result, self, dim, keepdim, /*out_dtype=*/out_dtype); return _all(result, iter); } } @@ -798,15 +843,41 @@ Tensor any(const Tensor& self) { "any only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided || self.layout() == Layout::Sparse, "any only supports strided AND sparse layout, got: ", self.layout()); + + // Refer [all, any : uint8 compatibility] + Tensor result; + ScalarType out_dtype; + if (self.scalar_type() == ScalarType::Byte){ + result = at::empty({0}, self.options()); + out_dtype = self.scalar_type(); + } else { + result = at::empty({0}, self.options().dtype(kBool)); + out_dtype = ScalarType::Bool; + } - Tensor result = at::empty({0}, self.options()); - auto iter = make_reduction( - "any", result, self, {}, false, self.scalar_type()); + if (self.is_cuda()) { + // As CUDA supports dynamic type casting, we use this overload of + // `make_reduction`, which doesn't cast input to the result type i.e. kBool., + // otherwise we use the overload below which casts the input to kBool (which is + // an extra operation). + auto iter = make_reduction( + "any", result, self, {}, false, self.scalar_type(), out_dtype); + return _any(result, iter); + } + auto iter = + make_reduction("any", result, self, {}, false, /*out_dtype=*/out_dtype); return _any(result, iter); } Tensor any(const Tensor& self, int64_t dim, bool keepdim) { - Tensor result = at::empty({0}, self.options()); + // Refer [all, any : uint8 compatibility] + Tensor result; + if (self.scalar_type() == ScalarType::Byte){ + result = at::empty({0}, self.options()); + } else { + result = at::empty({0}, self.options().dtype(kBool)); + } + return at::native::any_out(result, self, dim, keepdim); } @@ -815,13 +886,26 @@ Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) { "any only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "any only supports strided layout, got: ", self.layout()); + // Refer [all, any : uint8 compatibility] + TORCH_CHECK(result.scalar_type() == ScalarType::Bool || result.scalar_type() == ScalarType::Byte, + "any only supports bool tensor for result, got: ", result.scalar_type()); + auto out_dtype = result.scalar_type(); dim = maybe_wrap_dim(dim, self.dim()); if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) { return result; } else { - auto iter = make_reduction( - "any", result, self, dim, keepdim, self.scalar_type()); + if (self.is_cuda()) { + // As CUDA supports dynamic type casting, we use this overload of + // `make_reduction`, which doesn't cast input to the result type i.e. kBool., + // otherwise we use the overload below which casts the input to kBool (which is + // an extra operation). + auto iter = make_reduction( + "any", result, self, dim, keepdim, self.scalar_type(), out_dtype); + return _any(result, iter); + } + auto iter = + make_reduction("any", result, self, dim, keepdim, /*out_dtype=*/out_dtype); return _any(result, iter); } } diff --git a/aten/src/ATen/native/SharedReduceOps.h b/aten/src/ATen/native/SharedReduceOps.h index 4106a90c0729..e25b943d13a8 100644 --- a/aten/src/ATen/native/SharedReduceOps.h +++ b/aten/src/ATen/native/SharedReduceOps.h @@ -386,56 +386,6 @@ struct NanSumOps { #endif }; -template -struct AndOps { - inline C10_DEVICE acc_t reduce(acc_t a, acc_t b, int64_t /*idx*/) const { - return static_cast(a) && static_cast(b); - } - - inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { - return static_cast(a) && static_cast(b); - } - - inline C10_DEVICE acc_t project(acc_t a) const { - return a; - } - - static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { - return acc; - } - -#if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); - } -#endif -}; - -template -struct OrOps { - inline C10_DEVICE acc_t reduce(acc_t a, acc_t b, int64_t /*idx*/) const { - return static_cast(a) || static_cast(b); - } - - inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { - return static_cast(a) || static_cast(b); - } - - inline C10_DEVICE acc_t project(acc_t a) const { - return a; - } - - static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { - return acc; - } - -#if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); - } -#endif -}; - namespace detail { template diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 32033abcd4e2..14f3d4a1fc21 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -256,11 +256,25 @@ static void norm_kernel_tensor_iterator_impl( } static void and_kernel_impl(TensorIterator& iter) { - if (c10::isIntegralType(iter.dtype(), /*includeBool=*/true)) { + if (iter.dtype() == ScalarType::Byte) { + // Refer [all, any : uint8 compatibility] binary_kernel_reduce_vec( iter, [=](uint8_t a, uint8_t b) -> uint8_t { return (a && b) ? 1 : 0; }, [=](Vec256 a, Vec256 b) { + Vec256 c = Vec256(); + + for (decltype(c.size()) i = 0; i != Vec256::size(); i++) { + c[i] = (a[i] && b[i]) ? 1 : 0; + } + return c; + }, + /*ident=*/true); + } else { + binary_kernel_reduce_vec( + iter, + [=](bool a, bool b) -> bool { return a && b; }, + [=](Vec256 a, Vec256 b) { // Adding the implementation here instead of in vec256_base to avoid // return value inconsistency. Other comparison operators in // vec256_base return -1/0 (all bit 1 / all bit 0) as true/false to @@ -271,39 +285,45 @@ static void and_kernel_impl(TensorIterator& iter) { // // In this method, users would expect, e.g., all(), to return 1/0 as // true/false. - Vec256 c = Vec256(); - for (int i = 0; i != Vec256::size(); i++) { - c[i] = (a[i] && b[i]) ? 1 : 0; + Vec256 c = Vec256(); + + for (decltype(c.size()) i = 0; i != Vec256::size(); i++) { + c[i] = a[i] && b[i]; } return c; }, /*ident=*/true); - } else { - AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "and_kernel", [&]() { - binary_kernel_reduce( - iter, AndOps(), static_cast(true)); - }); } } static void or_kernel_impl(TensorIterator& iter) { - if (c10::isIntegralType(iter.dtype(), /*includeBool=*/true)) { + if (iter.dtype() == ScalarType::Byte) { + // Refer [all, any : uint8 compatibility] binary_kernel_reduce_vec( iter, [=](uint8_t a, uint8_t b) -> uint8_t { return (a || b) ? 1 : 0; }, [=](Vec256 a, Vec256 b) { Vec256 c = Vec256(); - for (int i = 0; i != Vec256::size(); i++) { + + for (decltype(c.size()) i = 0; i != Vec256::size(); i++) { c[i] = (a[i] || b[i]) ? 1 : 0; } return c; }, /*ident=*/false); } else { - AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "or_kernel", [&]() { - binary_kernel_reduce( - iter, OrOps(), static_cast(false)); - }); + binary_kernel_reduce_vec( + iter, + [=](bool a, bool b) -> bool { return a || b; }, + [=](Vec256 a, Vec256 b) { + Vec256 c = Vec256(); + + for (decltype(c.size()) i = 0; i != Vec256::size(); i++) { + c[i] = a[i] || b[i]; + } + return c; + }, + /*ident=*/false); } } diff --git a/aten/src/ATen/native/cuda/ReduceLogicKernel.cu b/aten/src/ATen/native/cuda/ReduceLogicKernel.cu index a29a926ef257..fcf60678929e 100644 --- a/aten/src/ATen/native/cuda/ReduceLogicKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceLogicKernel.cu @@ -3,30 +3,33 @@ #include #include #include +#include namespace at { namespace native { void and_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "and_kernel", [&]() { - gpu_reduce_kernel( - iter, - func_wrapper([] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return static_cast(static_cast(a) && static_cast(b)); - }), - static_cast(true)); - }); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBFloat16, kBool, iter.common_dtype(), "and_cuda", [&]() { + gpu_reduce_kernel( + iter, + func_wrapper([] GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return (static_cast(a) && static_cast(b)); + }), + true); + }); } void or_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "or_kernel", [&]() { - gpu_reduce_kernel( - iter, - func_wrapper([] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return static_cast(static_cast(a) || static_cast(b)); - }), - static_cast(false)); - }); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBFloat16, kBool, iter.common_dtype(), "or_cuda", [&]() { + gpu_reduce_kernel( + iter, + func_wrapper([] GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return (static_cast(a) || static_cast(b)); + }), + false); + }); } REGISTER_DISPATCH(and_stub, &and_kernel_cuda); diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 315cc9dc5309..1baf34dd955e 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -214,6 +214,8 @@ view of a storage and defines numeric operations on it. .. automethod:: arctan_ .. automethod:: atan2 .. automethod:: atan2_ + .. automethod:: all + .. automethod:: any .. automethod:: backward :noindex: .. automethod:: baddbmm @@ -648,10 +650,3 @@ view of a storage and defines numeric operations on it. .. automethod:: xlogy .. automethod:: xlogy_ .. automethod:: zero_ - -.. class:: BoolTensor() - - The following methods are unique to :class:`torch.BoolTensor`. - - .. automethod:: all - .. automethod:: any diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 46960ecdb1b4..922e1434bae1 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -364,6 +364,8 @@ Reduction Ops argmin amax amin + all + any max min dist diff --git a/test/test_reductions.py b/test/test_reductions.py index 917d469c5ee6..b08cebf7947b 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -1133,8 +1133,13 @@ def verify_against_numpy(t): verify_against_numpy(t) @dtypes(*(torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False, - include_bool=True, include_complex=False))) + include_bool=True, include_complex=True))) def test_all_any_vs_numpy(self, device, dtype): + # Note [all, any uint8 compatibility]: However for compatibility reason, + # for `uint8`, they return Tensor of same dtype `uint8`. + # Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561 + exact_dtype = True if dtype != torch.uint8 else False + def _test_all_any(x): self.compare_with_numpy(torch.all, np.all, x) self.compare_with_numpy(torch.any, np.any, x) @@ -1142,38 +1147,102 @@ def _test_all_any(x): def _test_all_any_with_dim(x, dim): torch_fn = partial(torch.all, dim=dim) np_fn = partial(np.all, axis=dim) - self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=False) + self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) torch_fn = partial(torch.any, dim=dim) np_fn = partial(np.any, axis=dim) - self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=False) + self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) + + def _test_out_variant(x, dim): + out = torch.empty_like(x) + if dtype == torch.bool or dtype == torch.uint8: + expected = torch.all(x, dim) + torch.all(x, dim, out=out) + self.assertEqual(expected, out) + + expected = torch.any(x, dim) + torch.any(x, dim, out=out) + self.assertEqual(expected, out) + else: + with self.assertRaisesRegex(RuntimeError, "all only supports bool tensor for result, got"): + torch.all(x, dim, out=out) + + with self.assertRaisesRegex(RuntimeError, "any only supports bool tensor for result, got"): + torch.any(x, dim, out=out) + + def _test_all_any_with_dim_keepdim(x, dim, keepdim): + torch_fn = partial(torch.all, dim=dim, keepdim=keepdim) + np_fn = partial(np.all, axis=dim, keepdims=keepdim) + self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) + + torch_fn = partial(torch.any, dim=dim, keepdim=keepdim) + np_fn = partial(np.any, axis=dim, keepdims=keepdim) + self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) + + def _test_output_dtype(x): + # This test will fail once the functions return bool output + # for uint8 input. + expected_dtype = torch.uint8 if dtype == torch.uint8 else torch.bool + self.assertEqual(torch.all(x).dtype, expected_dtype) + self.assertEqual(torch.any(x).dtype, expected_dtype) + + self.assertEqual(torch.all(x, dim=0).dtype, expected_dtype) + self.assertEqual(torch.any(x, dim=0).dtype, expected_dtype) for ndim in range(5): shape = _rand_shape(ndim, 1, 5) x = _generate_input(shape, dtype, device, with_extremal=False) _test_all_any(x) + _test_all_any(x.T) + _test_all_any(x[..., ::2]) x = _generate_input(shape, dtype, device, with_extremal=True) _test_all_any(x) + _test_all_any(x.T) + _test_all_any(x[..., ::2]) x = torch.zeros_like(x) _test_all_any(x) + _test_all_any(x.T) + _test_all_any(x[..., ::2]) x = torch.ones_like(x) _test_all_any(x) - + _test_all_any(x.T) + _test_all_any(x[..., ::2]) + _test_output_dtype(x) for dim in range(ndim): x = _generate_input(shape, dtype, device, with_extremal=False) _test_all_any_with_dim(x, dim) + _test_all_any_with_dim(x.T, dim) + _test_all_any_with_dim(x[..., ::2], dim) + _test_out_variant(x, dim) + _test_all_any_with_dim_keepdim(x, dim, keepdim=True) + _test_all_any_with_dim_keepdim(x, dim, keepdim=False) x = _generate_input(shape, dtype, device, with_extremal=True) _test_all_any_with_dim(x, dim) + _test_all_any_with_dim(x.T, dim) + _test_all_any_with_dim(x[..., ::2], dim) + _test_out_variant(x, dim) + _test_all_any_with_dim_keepdim(x, dim, keepdim=True) + _test_all_any_with_dim_keepdim(x, dim, keepdim=False) x = torch.zeros_like(x) _test_all_any_with_dim(x, dim) + _test_all_any_with_dim(x.T, dim) + _test_all_any_with_dim(x[..., ::2], dim) + _test_out_variant(x, dim) + _test_all_any_with_dim_keepdim(x, dim, keepdim=True) + _test_all_any_with_dim_keepdim(x, dim, keepdim=False) x = torch.ones_like(x) _test_all_any_with_dim(x, dim) + _test_all_any_with_dim(x.T, dim) + _test_all_any_with_dim(x[..., ::2], dim) + _test_out_variant(x, dim) + _test_all_any_with_dim_keepdim(x, dim, keepdim=True) + _test_all_any_with_dim_keepdim(x, dim, keepdim=False) # TODO: part of this test covers torch.norm, with should be covered by test_linalg @onlyOnCPUAndCUDA @@ -1851,82 +1920,6 @@ def check(a, q, args, kwargs, message): RuntimeError, r'quantile\(\) out tensor must be on the same device as the input tensor'): torch.quantile(torch.randn(1, device=device), 0.5, out=torch.scalar_tensor(1)) - def test_logical_any(self, device): - x = torch.zeros([2, 3, 400], dtype=torch.uint8, device=device) - - self.assertEqual( - torch.tensor(0, dtype=torch.uint8, device=device), - x.any()) - - self.assertEqual( - torch.zeros([1, 3, 400], dtype=torch.uint8, device=device), - x.any(0, keepdim=True)) - - self.assertEqual( - torch.zeros([2, 1, 400], dtype=torch.uint8, device=device), - x.any(1, keepdim=True)) - - self.assertEqual( - torch.zeros([2, 3, 1], dtype=torch.uint8, device=device), - x.any(2, keepdim=True)) - - # set the last element to 0 - x[-1][-1][-1] = 1 - - self.assertEqual( - torch.tensor(1, dtype=torch.uint8, device=device), - x.any()) - - y = torch.zeros([1, 3, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 1 - self.assertEqual(y, x.any(0, keepdim=True)) - - y = torch.zeros([2, 1, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 1 - self.assertEqual(y, x.any(1, keepdim=True)) - - y = torch.zeros([2, 3, 1], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 1 - self.assertEqual(y, x.any(2, keepdim=True)) - - def test_logical_all(self, device): - x = torch.ones([2, 3, 400], dtype=torch.uint8, device=device) - - self.assertEqual( - torch.tensor(1, dtype=torch.uint8, device=device), - x.all()) - - self.assertEqual( - torch.ones([1, 3, 400], dtype=torch.uint8, device=device), - x.all(0, keepdim=True)) - - self.assertEqual( - torch.ones([2, 1, 400], dtype=torch.uint8, device=device), - x.all(1, keepdim=True)) - - self.assertEqual( - torch.ones([2, 3, 1], dtype=torch.uint8, device=device), - x.all(2, keepdim=True)) - - # set the last element to 0 - x[-1][-1][-1] = 0 - - self.assertEqual( - torch.tensor(0, dtype=torch.uint8, device=device), - x.all()) - - y = torch.ones([1, 3, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 0 - self.assertEqual(y, x.all(0, keepdim=True)) - - y = torch.ones([2, 1, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 0 - self.assertEqual(y, x.all(1, keepdim=True)) - - y = torch.ones([2, 3, 1], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 0 - self.assertEqual(y, x.all(2, keepdim=True)) - def test_std_mean(self, device): x = torch.rand(100, 50, 20, device=device) for dim in range(x.dim()): @@ -2251,21 +2244,29 @@ def test_reduction_empty(self, device): # ignore if there is no allreduce. self.assertTrue('dim' in str(err)) - # any - xb = x.to(torch.uint8) - yb = x.to(torch.uint8) - self.assertEqual((2, 0), xb.any(2).shape) - self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape) - self.assertEqual(torch.zeros((2, 4), device=device, dtype=torch.uint8), xb.any(1)) - self.assertEqual(torch.zeros((2, 1, 4), device=device, dtype=torch.uint8), xb.any(1, keepdim=True)) - self.assertEqual(torch.zeros((), device=device, dtype=torch.uint8), xb.any()) - - # all - self.assertEqual((2, 0), xb.all(2).shape) - self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape) - self.assertEqual(torch.ones((2, 4), device=device, dtype=torch.uint8), xb.all(1)) - self.assertEqual(torch.ones((2, 1, 4), device=device, dtype=torch.uint8), xb.all(1, keepdim=True)) - self.assertEqual(torch.ones((), device=device, dtype=torch.uint8), xb.all()) + for dtype in torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False, + include_bool=True, include_complex=True): + # Refer: [all, any uint8 compatibility] + if dtype == torch.uint8: + out_dtype = torch.uint8 + else: + out_dtype = torch.bool # output of all/any is bool irrespective of input dtype + + # any + xb = x.to(dtype) + yb = x.to(dtype) + self.assertEqual((2, 0), xb.any(2).shape) + self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape) + self.assertEqual(torch.zeros((2, 4), device=device, dtype=out_dtype), xb.any(1)) + self.assertEqual(torch.zeros((2, 1, 4), device=device, dtype=out_dtype), xb.any(1, keepdim=True)) + self.assertEqual(torch.zeros((), device=device, dtype=out_dtype), xb.any()) + + # all + self.assertEqual((2, 0), xb.all(2).shape) + self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape) + self.assertEqual(torch.ones((2, 4), device=device, dtype=out_dtype), xb.all(1)) + self.assertEqual(torch.ones((2, 1, 4), device=device, dtype=out_dtype), xb.all(1, keepdim=True)) + self.assertEqual(torch.ones((), device=device, dtype=out_dtype), xb.all()) instantiate_device_type_tests(TestReductions, globals()) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index fd879427b56e..83bc04113672 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -432,45 +432,9 @@ def add_docstr_all(method, docstr): add_docstr_all('all', r""" -.. function:: all() -> bool +all(dim=None, keepdim=False) -> Tensor -Returns True if all elements in the tensor are True, False otherwise. - -Example:: - - >>> a = torch.rand(1, 2).bool() - >>> a - tensor([[False, True]], dtype=torch.bool) - >>> a.all() - tensor(False, dtype=torch.bool) - -.. function:: all(dim, keepdim=False, out=None) -> Tensor - -Returns True if all elements in each row of the tensor in the given -dimension :attr:`dim` are True, False otherwise. - -If :attr:`keepdim` is ``True``, the output tensor is of the same size as -:attr:`input` except in the dimension :attr:`dim` where it is of size 1. -Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting -in the output tensor having 1 fewer dimension than :attr:`input`. - -Args: - dim (int): the dimension to reduce - keepdim (bool): whether the output tensor has :attr:`dim` retained or not - out (Tensor, optional): the output tensor - -Example:: - - >>> a = torch.rand(4, 2).bool() - >>> a - tensor([[True, True], - [True, False], - [True, True], - [True, True]], dtype=torch.bool) - >>> a.all(dim=1) - tensor([ True, False, True, True], dtype=torch.bool) - >>> a.all(dim=0) - tensor([ True, False], dtype=torch.bool) +See :func:`torch.all` """) add_docstr_all('allclose', @@ -489,45 +453,9 @@ def add_docstr_all(method, docstr): add_docstr_all('any', r""" -.. function:: any() -> bool - -Returns True if any elements in the tensor are True, False otherwise. - -Example:: - - >>> a = torch.rand(1, 2).bool() - >>> a - tensor([[False, True]], dtype=torch.bool) - >>> a.any() - tensor(True, dtype=torch.bool) - -.. function:: any(dim, keepdim=False, out=None) -> Tensor - -Returns True if any elements in each row of the tensor in the given -dimension :attr:`dim` are True, False otherwise. - -If :attr:`keepdim` is ``True``, the output tensor is of the same size as -:attr:`input` except in the dimension :attr:`dim` where it is of size 1. -Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting -in the output tensor having 1 fewer dimension than :attr:`input`. - -Args: - dim (int): the dimension to reduce - keepdim (bool): whether the output tensor has :attr:`dim` retained or not - out (Tensor, optional): the output tensor - -Example:: +any(dim=None, keepdim=False) -> Tensor - >>> a = torch.randn(4, 2) < 0 - >>> a - tensor([[ True, True], - [False, True], - [ True, True], - [False, False]]) - >>> a.any(1) - tensor([ True, True, True, False]) - >>> a.any(0) - tensor([True, True]) +See :func:`torch.any` """) add_docstr_all('apply_', diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index b08db84d9459..d4b377a23750 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -610,6 +610,113 @@ def merge_dicts(*dicts): True """) +add_docstr(torch.all, + r""" +all(input) -> Tensor + +Tests if all elements in :attr:`input` evaluate to `True`. + +.. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + +Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + +.. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + +For each row of :attr:`input` in the given dimension :attr:`dim`, +returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + +{keepdim_details} + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) +""".format(**single_dim_common)) + +add_docstr(torch.any, + r""" +any(input) -> Tensor + +Args: + {input} + +Tests if any element in :attr:`input` evaluates to `True`. + +.. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + +Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + +.. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + +For each row of :attr:`input` in the given dimension :attr:`dim`, +returns `True` if any element in the row evaluate to `True` and `False` otherwise. + +{keepdim_details} + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) +""".format(**single_dim_common)) + add_docstr(torch.angle, r""" angle(input, *, out=None) -> Tensor From d78b638a3110dbd6cac8dd7688542b4cc2ac53f6 Mon Sep 17 00:00:00 2001 From: Thomas Zhang Date: Fri, 8 Jan 2021 11:13:20 -0800 Subject: [PATCH 46/54] Convert string => raw strings so char classes can be represented in Python regex (#50239) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50239 Convert regex strings that have character classes (e.g. \d, \s, \w, \b, etc) into raw strings so they won't be interpreted as escape characters. References: Python RegEx - https://www.w3schools.com/python/python_regex.asp Python Escape Chars - https://www.w3schools.com/python/gloss_python_escape_characters.asp Python Raw String - https://www.journaldev.com/23598/python-raw-string Python RegEx Docs - https://docs.python.org/3/library/re.html Python String Tester - https://www.w3schools.com/python/trypython.asp?filename=demo_string_escape Python Regex Tester - https://regex101.com/ Test Plan: To find occurrences of regex strings with the above issue in VS Code, search using the regex \bre\.[a-z]+\(['"], and under 'files to include', use /data/users/your_username/fbsource/fbcode/caffe2. Reviewed By: r-barnes Differential Revision: D25813302 fbshipit-source-id: df9e23c0a84c49175eaef399ca6d091bfbeed936 --- scripts/release_notes/commitlist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/release_notes/commitlist.py b/scripts/release_notes/commitlist.py index 0a76f896f217..552641f54674 100644 --- a/scripts/release_notes/commitlist.py +++ b/scripts/release_notes/commitlist.py @@ -97,7 +97,7 @@ def filter(self, *, category=None, topic=None): if topic is not None: commits = [commit for commit in commits if commit.topic == topic] return commits - + def update_to(self, new_version): last_hash = self.commits[-1].commit_hash new_commits = CommitList.get_commits_between(last_hash, new_version) @@ -121,7 +121,7 @@ def update_existing(path, new_version): def to_markdown(commit_list, category): def cleanup_title(commit): - match = re.match('(.*) \(#\d+\)', commit.title) + match = re.match(r'(.*) \(#\d+\)', commit.title) if match is None: return commit.title return match.group(1) From 0bb341daaadee7b93d1290db0305843508702128 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Fri, 8 Jan 2021 11:30:15 -0800 Subject: [PATCH 47/54] Dump state when hitting ambiguous_autogradother_kernel. (#50246) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50246 Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D25843205 Pulled By: ailzhang fbshipit-source-id: 66916ae477a4ae97e1695227fc6af78c4f328ea3 --- aten/src/ATen/core/boxing/KernelFunction.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/core/boxing/KernelFunction.cpp b/aten/src/ATen/core/boxing/KernelFunction.cpp index 58c35557018c..260343ac3180 100644 --- a/aten/src/ATen/core/boxing/KernelFunction.cpp +++ b/aten/src/ATen/core/boxing/KernelFunction.cpp @@ -24,7 +24,7 @@ void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle& op, S op.operator_name(), " has kernels registered to both Math and a backend mapped to AutogradOther. " "This makes the backend kernel unreachable (see Note [Ambiguity in AutogradOther kernel]). " "If it's intended to override Math kernel behavior, please open an issue to request a dedicated " - "Autograd dispatch key for the backend."); + "Autograd dispatch key for the backend.", "\nCanonical state\n~~~~~~~~~~~\n", op.dumpState(), "\n\n"); } void named_not_supported_kernel(OperatorKernel*, const OperatorHandle& op, Stack*) { From f9f758e3497b357668c2a259ea070bb33c591a0b Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 8 Jan 2021 11:45:56 -0800 Subject: [PATCH 48/54] Apply clang-format to rpc cpp files (#50236) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50236 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25847892 Pulled By: mrshenli fbshipit-source-id: b4af1221acfcaba8903c629869943abbf877e04e --- torch/csrc/distributed/rpc/init.cpp | 9 +- torch/csrc/distributed/rpc/message.h | 38 +++-- .../distributed/rpc/process_group_agent.cpp | 40 ++--- torch/csrc/distributed/rpc/py_rref.cpp | 3 +- torch/csrc/distributed/rpc/py_rref.h | 5 +- .../csrc/distributed/rpc/request_callback.cpp | 3 +- torch/csrc/distributed/rpc/request_callback.h | 3 +- .../distributed/rpc/request_callback_impl.cpp | 2 +- .../rpc/request_callback_no_python.cpp | 161 +++++++++--------- .../rpc/request_callback_no_python.h | 3 +- torch/csrc/distributed/rpc/rpc_agent.cpp | 18 +- torch/csrc/distributed/rpc/rref_context.h | 10 +- torch/csrc/distributed/rpc/rref_impl.cpp | 4 +- .../csrc/distributed/rpc/tensorpipe_agent.cpp | 23 +-- torch/csrc/distributed/rpc/tensorpipe_agent.h | 2 - 15 files changed, 169 insertions(+), 155 deletions(-) diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index 6204f6e343ed..f1f8e39cd7e3 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -44,7 +44,8 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { } auto torch_C_m = py::handle(torch_C_module).cast(); - auto m = torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings"); + auto m = + torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings"); auto module = py::handle(m).cast(); @@ -129,12 +130,12 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { .def( "get_worker_info", (const WorkerInfo& (RpcAgent::*)(void)const) & - RpcAgent::getWorkerInfo, + RpcAgent::getWorkerInfo, py::call_guard()) - .def( + .def( "get_worker_info", (const WorkerInfo& (RpcAgent::*)(const std::string&)const) & - RpcAgent::getWorkerInfo, + RpcAgent::getWorkerInfo, py::call_guard()) .def( "get_worker_infos", diff --git a/torch/csrc/distributed/rpc/message.h b/torch/csrc/distributed/rpc/message.h index 7ff487b9b337..3d2d623e821f 100644 --- a/torch/csrc/distributed/rpc/message.h +++ b/torch/csrc/distributed/rpc/message.h @@ -35,19 +35,35 @@ enum MessageType { PYTHON_RET = 3 | MessageTypeFlags::RESPONSE_TYPE, // messages for dist.remote on builtin operators and Python UDF - SCRIPT_REMOTE_CALL = 4 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a builtin operator - PYTHON_REMOTE_CALL = 5 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a Python UDF - REMOTE_RET = 6 | MessageTypeFlags::RESPONSE_TYPE, // Response for remote calls for UDF, builtin, or script + SCRIPT_REMOTE_CALL = + 4 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a builtin operator + PYTHON_REMOTE_CALL = + 5 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a Python UDF + REMOTE_RET = + 6 | MessageTypeFlags::RESPONSE_TYPE, // Response for remote calls for UDF, + // builtin, or script // RRef related internal messages - SCRIPT_RREF_FETCH_CALL = 7 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef fetches value from owner - PYTHON_RREF_FETCH_CALL = 8 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef fetches value from owner - SCRIPT_RREF_FETCH_RET = 9 | MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends ivalue to user - PYTHON_RREF_FETCH_RET = 10 | MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends py::object to user - RREF_USER_DELETE = 11 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef tells the owner to deref - RREF_FORK_REQUEST = 12 | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells the owner about itself - RREF_CHILD_ACCEPT = 13 | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells parent that owner knows it - RREF_ACK = 14 | MessageTypeFlags::RESPONSE_TYPE, // ACK to internal RRef messages + SCRIPT_RREF_FETCH_CALL = + 7 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef fetches value + // from owner + PYTHON_RREF_FETCH_CALL = + 8 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef fetches + // value from owner + SCRIPT_RREF_FETCH_RET = + 9 | MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends ivalue to user + PYTHON_RREF_FETCH_RET = 10 | + MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends py::object to user + RREF_USER_DELETE = 11 | + MessageTypeFlags::REQUEST_TYPE, // A UserRRef tells the owner to deref + RREF_FORK_REQUEST = + 12 | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells the owner + // about itself + RREF_CHILD_ACCEPT = + 13 | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells parent that + // owner knows it + RREF_ACK = + 14 | MessageTypeFlags::RESPONSE_TYPE, // ACK to internal RRef messages // Messages with autograd info FORWARD_AUTOGRAD_REQ = 15 | MessageTypeFlags::REQUEST_TYPE, diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index be3b1803b8d6..9c1a703cfa6d 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -540,26 +540,26 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { // Use a weak_ptr, so we can std::move the future's value. auto fromId = work.from_.id_; auto requestId = work.id_; - futureResponse->addCallback([this, - fromId, - requestId, - weak = std::weak_ptr( - futureResponse)]() { - auto futureResponse = weak.lock(); - TORCH_INTERNAL_ASSERT(futureResponse); - --serverActiveCalls_; - --serverActiveAsyncCalls_; - if (!futureResponse->hasError()) { - send( - getWorkerInfo(fromId), - std::move(*futureResponse->value().toCustomClass())); - } else { - send( - getWorkerInfo(fromId), - createExceptionResponse( - futureResponse->tryRetrieveErrorMessage(), requestId)); - } - }); + futureResponse->addCallback( + [this, + fromId, + requestId, + weak = std::weak_ptr(futureResponse)]() { + auto futureResponse = weak.lock(); + TORCH_INTERNAL_ASSERT(futureResponse); + --serverActiveCalls_; + --serverActiveAsyncCalls_; + if (!futureResponse->hasError()) { + send( + getWorkerInfo(fromId), + std::move(*futureResponse->value().toCustomClass())); + } else { + send( + getWorkerInfo(fromId), + createExceptionResponse( + futureResponse->tryRetrieveErrorMessage(), requestId)); + } + }); } } else if (message.isResponse()) { auto id = message.id(); diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index 7c39d1d812a7..cface0b88551 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -137,8 +137,7 @@ c10::intrusive_ptr PyRRef::getFuture() const { // Marking hasValue to false, as this Future is only used for signaling // profiler to update profiling result and the profiler does not retrieve // any value from it. - return toPyJitFuture( - rref_->getOwnerCreationFuture(), false /* hasValue */); + return toPyJitFuture(rref_->getOwnerCreationFuture(), false /* hasValue */); } c10::intrusive_ptr PyRRef::getProfilingFuture() const { diff --git a/torch/csrc/distributed/rpc/py_rref.h b/torch/csrc/distributed/rpc/py_rref.h index 1b61c8c2fbf0..e3ae22fb8320 100644 --- a/torch/csrc/distributed/rpc/py_rref.h +++ b/torch/csrc/distributed/rpc/py_rref.h @@ -55,7 +55,10 @@ class PYBIND11_EXPORT PyRRef { void backward(int64_t autogradContextId, bool retainGraph); // Helper static function to run backward on a given rref. - static void backward(int64_t autogradContextId, bool retainGraph, const c10::intrusive_ptr& rref); + static void backward( + int64_t autogradContextId, + bool retainGraph, + const c10::intrusive_ptr& rref); private: c10::intrusive_ptr rref_; diff --git a/torch/csrc/distributed/rpc/request_callback.cpp b/torch/csrc/distributed/rpc/request_callback.cpp index e690e9ac90dd..33703c923523 100644 --- a/torch/csrc/distributed/rpc/request_callback.cpp +++ b/torch/csrc/distributed/rpc/request_callback.cpp @@ -9,8 +9,7 @@ namespace rpc { using namespace torch::distributed::autograd; -std::shared_ptr RequestCallback::operator()( - Message& request) const { +std::shared_ptr RequestCallback::operator()(Message& request) const { // NB: cannot clear autograd context id here because the processMessage method // might pause waiting for all RRefs in the arguments to be confirmed by their // owners and resumne processing in a different thread. Hence, the diff --git a/torch/csrc/distributed/rpc/request_callback.h b/torch/csrc/distributed/rpc/request_callback.h index bf1231eefd7f..128cf9590034 100644 --- a/torch/csrc/distributed/rpc/request_callback.h +++ b/torch/csrc/distributed/rpc/request_callback.h @@ -24,8 +24,7 @@ class TORCH_API RequestCallback { // message containing an exception. Different rpc agent implementations are // expected to ensure delivery of the response/exception based on their // implementation specific mechanisms. - virtual std::shared_ptr processMessage( - Message& request) const = 0; + virtual std::shared_ptr processMessage(Message& request) const = 0; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp index 3fa6b63953f0..684ca5576a56 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.cpp +++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp @@ -559,7 +559,7 @@ void RequestCallbackImpl::processRRefBackward( Message m = RRefBackwardResp().toMessage(); m.setId(messageId); responseFuture->markCompleted( - IValue(c10::make_intrusive(std::move(m)))); + IValue(c10::make_intrusive(std::move(m)))); } catch (const std::exception& /* unused */) { responseFuture->setError(std::current_exception()); } diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index db455e11710f..597edb68862d 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -7,7 +8,6 @@ #include #include #include -#include #include #include #include @@ -63,44 +63,43 @@ std::shared_ptr RequestCallbackNoPython::processMessage( deserializeRequest(request), request.type()); auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs(); - rrefsReadyFuture->addCallback([ - this, - retFuture, - // std::function must be copyable, hence hae to cast the unique_ptr to - // a shared_ptr here. - rpc = (std::shared_ptr)std::move(rpc), - messageType = request.type(), - id = request.id() - ]() { - // The cost of pre-request check is minimal thanks to - // std::shared_lock. The cost is in magnitude - // of 10us. - auto serverProcessGlobalProfilerStateStackEntryPtr = - profiler::processglobal::StateStackEntry::current(); - // If server global profiler is enabled, we futher pay the - // cost of thread local profiler state initialization. - if (serverProcessGlobalProfilerStateStackEntryPtr) { - // Initialize thread-local profiler state from process-global - // profiler state. - ::torch::autograd::profiler::enableProfilerLegacy( - serverProcessGlobalProfilerStateStackEntryPtr->statePtr() - ->config()); - } - - processRpcWithErrors(*rpc, messageType, id, retFuture); - - // Response message has been sent at this moment, this post-response - // work doesn't affect RPC trip time. - if (serverProcessGlobalProfilerStateStackEntryPtr) { - // Restore thread-local profiler state. - ::torch::autograd::profiler::thread_event_lists event_lists = - ::torch::autograd::profiler::disableProfilerLegacy(); - // Put thread_local event_lists into the process-global profiler - // state. - profiler::processglobal::pushResultRecursive( - serverProcessGlobalProfilerStateStackEntryPtr, event_lists); - } - }); + rrefsReadyFuture->addCallback( + [this, + retFuture, + // std::function must be copyable, hence hae to cast the unique_ptr to + // a shared_ptr here. + rpc = (std::shared_ptr)std::move(rpc), + messageType = request.type(), + id = request.id()]() { + // The cost of pre-request check is minimal thanks to + // std::shared_lock. The cost is in magnitude + // of 10us. + auto serverProcessGlobalProfilerStateStackEntryPtr = + profiler::processglobal::StateStackEntry::current(); + // If server global profiler is enabled, we futher pay the + // cost of thread local profiler state initialization. + if (serverProcessGlobalProfilerStateStackEntryPtr) { + // Initialize thread-local profiler state from process-global + // profiler state. + ::torch::autograd::profiler::enableProfilerLegacy( + serverProcessGlobalProfilerStateStackEntryPtr->statePtr() + ->config()); + } + + processRpcWithErrors(*rpc, messageType, id, retFuture); + + // Response message has been sent at this moment, this post-response + // work doesn't affect RPC trip time. + if (serverProcessGlobalProfilerStateStackEntryPtr) { + // Restore thread-local profiler state. + ::torch::autograd::profiler::thread_event_lists event_lists = + ::torch::autograd::profiler::disableProfilerLegacy(); + // Put thread_local event_lists into the process-global profiler + // state. + profiler::processglobal::pushResultRecursive( + serverProcessGlobalProfilerStateStackEntryPtr, event_lists); + } + }); } catch (std::exception& e) { retFuture->markCompleted(handleError(e, request.type(), request.id())); rrefContext.clearRecordedPendingRRefsOnError(); @@ -379,40 +378,40 @@ void RequestCallbackNoPython::processForwardAutogradReq( // The original future needs to be marked as completed when the wrapped // one completes, with the autograd context information wrapped. // Uses weak_ptr so we can std::move the value. - wrappedRpcResponseFuture->addCallback([ - responseFuture, - messageId, - fromWorkerId, - weak = std::weak_ptr(wrappedRpcResponseFuture), - ctxId = autogradContext->contextId() - ]() { - // As this callback can be invoked by a different thread, we have to - // make sure that the thread_local states in the previous thread is - // correctly propagated. - // NB: The execution of TorchScript functions can also run on a - // different thread, which is addressed by - // https://github.com/pytorch/pytorch/pull/36395 - // NB: when adding async UDF support, we should also propagate - // thread_local states there. - // TODO: Land on a general solution for RPC ThreadLocalState. See - // https://github.com/pytorch/pytorch/issues/38510 - DistAutogradContextGuard cbCtxGuard(ctxId); - - auto wrappedRpcResponseFuture = weak.lock(); - TORCH_INTERNAL_ASSERT(wrappedRpcResponseFuture); - if (wrappedRpcResponseFuture->hasError()) { - // Propagate error to responseFuture if we had one. - responseFuture->setError(wrappedRpcResponseFuture->exception_ptr()); - } else { - auto msg = getMessageWithAutograd( - fromWorkerId, - std::move(*wrappedRpcResponseFuture->value().toCustomClass()), - MessageType::FORWARD_AUTOGRAD_RESP); - msg.setId(messageId); - responseFuture->markCompleted( - IValue(c10::make_intrusive(std::move(msg)))); - } - }); + wrappedRpcResponseFuture->addCallback( + [responseFuture, + messageId, + fromWorkerId, + weak = std::weak_ptr(wrappedRpcResponseFuture), + ctxId = autogradContext->contextId()]() { + // As this callback can be invoked by a different thread, we have to + // make sure that the thread_local states in the previous thread is + // correctly propagated. + // NB: The execution of TorchScript functions can also run on a + // different thread, which is addressed by + // https://github.com/pytorch/pytorch/pull/36395 + // NB: when adding async UDF support, we should also propagate + // thread_local states there. + // TODO: Land on a general solution for RPC ThreadLocalState. See + // https://github.com/pytorch/pytorch/issues/38510 + DistAutogradContextGuard cbCtxGuard(ctxId); + + auto wrappedRpcResponseFuture = weak.lock(); + TORCH_INTERNAL_ASSERT(wrappedRpcResponseFuture); + if (wrappedRpcResponseFuture->hasError()) { + // Propagate error to responseFuture if we had one. + responseFuture->setError(wrappedRpcResponseFuture->exception_ptr()); + } else { + auto msg = getMessageWithAutograd( + fromWorkerId, + std::move( + *wrappedRpcResponseFuture->value().toCustomClass()), + MessageType::FORWARD_AUTOGRAD_RESP); + msg.setId(messageId); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(msg)))); + } + }); } void RequestCallbackNoPython::processBackwardAutogradReq( @@ -511,11 +510,11 @@ void RequestCallbackNoPython::processRunWithProfilingReq( messageId, wrappedRpcResponseFuture); - wrappedRpcResponseFuture->addCallback(at::wrapPropagateTLSState( - [wrappedRpcResponseFuture, - responseFuture, - profilingKeyId, - profilingConfig] { + wrappedRpcResponseFuture->addCallback( + at::wrapPropagateTLSState([wrappedRpcResponseFuture, + responseFuture, + profilingKeyId, + profilingConfig] { std::vector profiledEvents; // Defer consolidation of profiler events until async work has // completed (such as async UDF) @@ -528,7 +527,8 @@ void RequestCallbackNoPython::processRunWithProfilingReq( // they will be cleaned up by main thread, and consolidate all // events so we obtain asynchronously run events. torch::autograd::profiler::ProfilerDisableOptions opts(false, true); - auto event_lists = torch::autograd::profiler::disableProfilerLegacy(opts); + auto event_lists = + torch::autograd::profiler::disableProfilerLegacy(opts); if (wrappedRpcResponseFuture->hasError()) { // Propagate error // No need to propagate remote events in the case of an error. @@ -538,7 +538,8 @@ void RequestCallbackNoPython::processRunWithProfilingReq( profiledEvents, profilingConfig, event_lists); auto rpcWithProfilingResp = std::make_unique( MessageType::RUN_WITH_PROFILING_RESP, - std::move(*wrappedRpcResponseFuture->value().toCustomClass()), + std::move(*wrappedRpcResponseFuture->value() + .toCustomClass()), profiledEvents, profilingKeyId); responseFuture->markCompleted(IValue(c10::make_intrusive( diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.h b/torch/csrc/distributed/rpc/request_callback_no_python.h index 09c2e48479ab..9932c4744900 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.h +++ b/torch/csrc/distributed/rpc/request_callback_no_python.h @@ -14,8 +14,7 @@ namespace rpc { // RequestCallback implementation with no Python dependencies. class TORCH_API RequestCallbackNoPython : public RequestCallback { public: - std::shared_ptr processMessage( - Message& request) const override; + std::shared_ptr processMessage(Message& request) const override; protected: virtual std::unique_ptr deserializePythonRpcCommand( diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp index afcb0226b411..2033b2b771e2 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.cpp +++ b/torch/csrc/distributed/rpc/rpc_agent.cpp @@ -71,9 +71,9 @@ std::shared_ptr RpcAgent::sendWithRetries( retryOptions); // Use weak_ptr so that the value can be std::moved in rpcRetryCallback. jitFuture->addCallback([this, - newTime, - firstRetryRpc, - wp = std::weak_ptr(jitFuture)]() { + newTime, + firstRetryRpc, + wp = std::weak_ptr(jitFuture)]() { auto future = wp.lock(); TORCH_INTERNAL_ASSERT(future); rpcRetryCallback(future, newTime, firstRetryRpc); @@ -88,8 +88,7 @@ void RpcAgent::retryExpiredRpcs() { std::pair, std::shared_ptr>> futures; // Stores futures and exception messages for non-retriable error-ed futures. - std::vector, std::string>> - errorFutures; + std::vector, std::string>> errorFutures; while (rpcAgentRunning_.load()) { std::unique_lock lock(rpcRetryMutex_); @@ -166,9 +165,9 @@ void RpcAgent::retryExpiredRpcs() { // Use weak_ptr so that the value can be std::moved in rpcRetryCallback. jitFuture->addCallback([this, - newTime, - earliestRpc, - wp = std::weak_ptr(jitFuture)]() { + newTime, + earliestRpc, + wp = std::weak_ptr(jitFuture)]() { auto future = wp.lock(); TORCH_INTERNAL_ASSERT(future); rpcRetryCallback(future, newTime, earliestRpc); @@ -181,7 +180,8 @@ void RpcAgent::retryExpiredRpcs() { for (const auto& it : errorFutures) { auto errorFuture = it.first; auto errorMsg = it.second; - errorFuture->setError(std::make_exception_ptr(std::runtime_error(errorMsg))); + errorFuture->setError( + std::make_exception_ptr(std::runtime_error(errorMsg))); } errorFutures.clear(); } diff --git a/torch/csrc/distributed/rpc/rref_context.h b/torch/csrc/distributed/rpc/rref_context.h index c8ef52f7379a..1e3537a6dfd3 100644 --- a/torch/csrc/distributed/rpc/rref_context.h +++ b/torch/csrc/distributed/rpc/rref_context.h @@ -16,16 +16,14 @@ namespace rpc { namespace callback { // It's the callback for RemoteCall. -void TORCH_API confirmPendingUser( - const JitFuture& jitFuture, - const ForkId& expectedForkId); +void TORCH_API +confirmPendingUser(const JitFuture& jitFuture, const ForkId& expectedForkId); // It's the callback for finishing creating owner rref, it returned deletedRRef, // so that the deletedRRef can be handled under GIL in python_functions.cpp if // deletedRRef contains python object. -c10::intrusive_ptr TORCH_API finishCreatingOwnerRRef( - const JitFuture& jitFuture, - const RRefId& rrefId); +c10::intrusive_ptr TORCH_API +finishCreatingOwnerRRef(const JitFuture& jitFuture, const RRefId& rrefId); } // namespace callback using torch::utils::Future; diff --git a/torch/csrc/distributed/rpc/rref_impl.cpp b/torch/csrc/distributed/rpc/rref_impl.cpp index f5bcf39400fc..085b65bfc0fb 100644 --- a/torch/csrc/distributed/rpc/rref_impl.cpp +++ b/torch/csrc/distributed/rpc/rref_impl.cpp @@ -65,9 +65,7 @@ RRefForkData RRef::fork() const { getTypeStr(type_)); } -void RRef::handleError( - RPCErrorType errorType, - const JitFuture& jitFuture) { +void RRef::handleError(RPCErrorType errorType, const JitFuture& jitFuture) { static std::unordered_map< RPCErrorType, std::function, diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index d05426f6160d..00c7567c6d43 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -166,8 +166,8 @@ std::unique_ptr makeMultiplexedUvChannel() { } auto context = std::make_shared( std::move(contexts), std::move(listeners)); - return std::make_unique( - CpuChannelRegistration{std::move(context), kMultiplexedUvChannelPriority}); + return std::make_unique(CpuChannelRegistration{ + std::move(context), kMultiplexedUvChannelPriority}); } // The multiplexed UV channel encapsulates multiple UV transports (each with its @@ -192,7 +192,10 @@ std::unique_ptr makeCudaIpcChannel() { // The cuda_ipc channels use cudaMemcpy to transmit CUDA tensor across processes // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -C10_REGISTER_CREATOR(TensorPipeCudaChannelRegistry, cuda_ipc, makeCudaIpcChannel); +C10_REGISTER_CREATOR( + TensorPipeCudaChannelRegistry, + cuda_ipc, + makeCudaIpcChannel); #endif @@ -324,9 +327,10 @@ void TensorPipeAgent::startImpl() { if (iter == opts_.channels->end()) { continue; } - // Assign priorities in reverse order of occurrence in the vector, so that - // a channel that comes before another receives a higher priority. - priority = opts_.channels->size() - 1 - (iter - opts_.channels->begin()); + // Assign priorities in reverse order of occurrence in the vector, so + // that a channel that comes before another receives a higher priority. + priority = + opts_.channels->size() - 1 - (iter - opts_.channels->begin()); } // The reg var is either a std::unique_ptr or a @@ -337,7 +341,7 @@ void TensorPipeAgent::startImpl() { priority = reg->priority; } context_->registerChannel( - priority, std::move(key), std::move(reg->channel)); + priority, std::move(key), std::move(reg->channel)); } }; @@ -806,9 +810,8 @@ void TensorPipeAgent::pollTimeoutRpcs() { // Move all these futures to a separate vector so we can process them // outside the lock. - std::vector, - std::chrono::milliseconds>> + std::vector< + std::pair, std::chrono::milliseconds>> timedOutFutures = std::move(timeoutMap_.begin()->second); // We can safely remove this key from the timeoutMap_ since all these // futures will be processed. diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index ef72e5468a5d..9f75eaf4e0af 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -11,7 +11,6 @@ #include #include - // Forward-declare the TensorPipe classes we need, to avoid including its // headers in PyTorch's ones and thus have it become a public dependency. @@ -89,7 +88,6 @@ struct CudaChannelRegistration { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) C10_DECLARE_REGISTRY(TensorPipeCudaChannelRegistry, CudaChannelRegistration); - constexpr auto kDefaultNumWorkerThreads = 16; struct TensorPipeRpcBackendOptions : public RpcBackendOptions { From 1bb7d8ff93c58ed89bbd0adbc12d1c7bae0d2ca6 Mon Sep 17 00:00:00 2001 From: Heitor Schueroff Date: Fri, 8 Jan 2021 12:13:00 -0800 Subject: [PATCH 49/54] Revert D25717504: Clean up some type annotations in test/jit Test Plan: revert-hammer Differential Revision: D25717504 (https://github.com/pytorch/pytorch/commit/a4f30d48d85b934253bef011d01af7dda3468750) Original commit changeset: 9a83c44db02e fbshipit-source-id: e6e3a83bed22701d8125f5a293dfcd5093c1a2cd --- test/jit/test_async.py | 6 +- test/jit/test_builtins.py | 14 +- test/jit/test_list_dict.py | 252 ++++++++++++++++++++---------- test/jit/test_module_interface.py | 178 ++++++++++++++------- test/jit/test_recursive_script.py | 12 +- test/jit/test_save_load.py | 12 +- test/jit/test_tracer.py | 14 +- test/jit/test_with.py | 98 ++++++++---- 8 files changed, 389 insertions(+), 197 deletions(-) diff --git a/test/jit/test_async.py b/test/jit/test_async.py index 37fabcfc438a..7a70a4c5a655 100644 --- a/test/jit/test_async.py +++ b/test/jit/test_async.py @@ -41,7 +41,8 @@ def foo(inp): def test_async_parsing(self): @torch.jit.script - def foo(x: Tensor) -> List[Tensor]: + def foo(x): + # type: (Tensor) -> List[Tensor] return [torch.neg(x), x.t()] @torch.jit.script @@ -256,7 +257,8 @@ def __init__(self): self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True) @torch.jit.script_method - def forward(self, x: Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]: + def forward(self, x): + # type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor] future1 = torch.jit._fork(self.traced, x) future2 = torch.jit._fork(torch.neg, x) diff --git a/test/jit/test_builtins.py b/test/jit/test_builtins.py index b5a0dd8599a6..04991f72c352 100644 --- a/test/jit/test_builtins.py +++ b/test/jit/test_builtins.py @@ -2,7 +2,7 @@ import sys import inspect import unittest -from typing import Dict, List +from typing import List import torch @@ -78,7 +78,8 @@ def forward(self, name): torch.jit.script(Mod()) def test_del(self): - def fn(x: List[int]) -> List[int]: + def fn(x): + # type: (List[int]) -> List[int] a = x * 2 del a return x @@ -108,14 +109,16 @@ def fn(x): return a def test_del_multiple_operands(self): - def fn(x: List[int]) -> List[int]: + def fn(x): + # type: (List[int]) -> List[int] a, b, c = x[0], x[1], x[2] del a, b, c return x self.checkScript(fn, ([1, 2, 3],)) - def del_list_multiple_operands(x: List[int]) -> List[int]: + def del_list_multiple_operands(x): + # type: (List[int]) -> List[int] del x[0], x[1] return x @@ -123,7 +126,8 @@ def del_list_multiple_operands(x: List[int]) -> List[int]: jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2]) self.assertEquals(py_out, jit_out) - def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]: + def del_dict_multiple_operands(x): + # type: (Dict[str, int]) -> Dict[str, int] del x['hi'], x['there'] return x diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 63d3aa3938e9..2bc24a57751d 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -20,19 +20,22 @@ class TestList(JitTestCase): def test_in_check(self): - def int_in(x: List[int]) -> bool: + def int_in(x): + # type: (List[int]) -> bool return 2 in x self.checkScript(int_in, ([1, 2, 3],)) self.checkScript(int_in, ([1, 3, 3],)) - def float_in(x: List[float]) -> bool: + def float_in(x): + # type: (List[float]) -> bool return 2. in x self.checkScript(float_in, ([1., 2., 3.],)) self.checkScript(float_in, ([1., 3., 3.],)) - def str_in(x: List[str]) -> bool: + def str_in(x): + # type: (List[str]) -> bool return 'hi' in x self.checkScript(str_in, (['not', 'here'],)) @@ -97,7 +100,8 @@ def test_del(self): def inputs(): return [1, 2, 3, 4] - def fn(x: List[int]) -> List[int]: + def fn(x): + # type: (List[int]) -> List[int] del x[1] return x @@ -110,7 +114,8 @@ def fn(x: List[int]) -> List[int]: self.assertEqual(torch.jit.script(fn)(inputs()), python_out) @torch.jit.script - def fn2(x: List[int]) -> List[int]: + def fn2(x): + # type: (List[int]) -> List[int] del x[100] return x @@ -119,7 +124,8 @@ def fn2(x: List[int]) -> List[int]: with self.assertRaisesRegex(RuntimeError, "deletion at a single index"): @torch.jit.script - def fn(x: List[int]) -> List[int]: + def fn(x): + # type: (List[int]) -> List[int] del x[1:3] return x @@ -143,19 +149,23 @@ def foo3(): FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph) def test_min_bool_list(self): - def jit_min_list(a: List[bool], b: List[bool]) -> List[bool]: + def jit_min_list(a, b): + # type: (List[bool], List[bool]) -> List[bool] return min(a, b) self.checkScript(jit_min_list, ([True, False], [False, True])) def test_min_max_list(self): - def jit_min_list(a: List[int], b: List[int]) -> List[int]: + def jit_min_list(a, b): + # type: (List[int], List[int]) -> List[int] return min(a, b) - def jit_min_list_float(a: List[float], b: List[float]) -> List[float]: + def jit_min_list_float(a, b): + # type: (List[float], List[float]) -> List[float] return min(a, b) - def jit_min_list_bool(a: List[bool], b: List[bool]) -> List[bool]: + def jit_min_list_bool(a, b): + # type: (List[bool], List[bool]) -> List[bool] return min(a, b) def run_tests(func, a, b): @@ -176,13 +186,16 @@ def run_tests(func, a, b): [False, True], [False, False, True], [False, False, False]] run_tests(jit_min_list_bool, args_left_bool, args_right_bool) - def jit_max_list(a: List[int], b: List[int]) -> List[int]: + def jit_max_list(a, b): + # type: (List[int], List[int]) -> List[int] return max(a, b) - def jit_max_list_float(a: List[float], b: List[float]) -> List[float]: + def jit_max_list_float(a, b): + # type: (List[float], List[float]) -> List[float] return max(a, b) - def jit_max_list_bool(a: List[bool], b: List[bool]) -> List[bool]: + def jit_max_list_bool(a, b): + # type: (List[bool], List[bool]) -> List[bool] return max(a, b) args_left_int = [[1, 8, 8], [8, 1, 1], [], [1], [], [1, 2]] @@ -352,7 +365,8 @@ def func(): t2 = scope['func']() self.assertEqual(t1, t2) - def test_fail(x: List[Tensor]) -> List[Tensor]: + def test_fail(x): + # type: (List[Tensor]) -> List[Tensor] x.sort() return x @@ -458,7 +472,8 @@ def test_append(): self.checkScript(test_append, ()) def test_comprehensions_basic(self): - def comp(l: List[int]) -> List[int]: + def comp(l): + # type: (List[int]) -> List[int] n = [x * 3 for x in l] return n @@ -467,7 +482,8 @@ def comp(l: List[int]) -> List[int]: self.checkScript(comp, ([1, 2, 3],)) def test_comprehensions_basic_float(self): - def comp(l: List[float]) -> List[float]: + def comp(l): + # type: (List[float]) -> List[float] n = [x * 3 for x in l] return n @@ -476,7 +492,8 @@ def comp(l: List[float]) -> List[float]: def test_comprehensions_two_comps(self): @torch.jit.script - def comp(l1: List[int], l2: List[int]) -> List[int]: + def comp(l1, l2): + # type: (List[int], List[int]) -> List[int] n = [x * 3 for x in l1] n2 = [x + 2 for x in l2] @@ -485,7 +502,8 @@ def comp(l1: List[int], l2: List[int]) -> List[int]: self.assertEqual(comp([1, 2, 3], [4, 5]), [3, 6, 9, 6, 7]) def test_comprehension_out_type_not_in_type(self): - def list_cast() -> int: + def list_cast(): + # type: () -> int li = [int(i) for i in [torch.tensor(0), torch.tensor(1), torch.tensor(2)]] return li[0] + li[1] + li[2] @@ -495,13 +513,15 @@ def test_comprehension_iterable(self): def test_func(fn, inputs): self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs)) - def foo(names: List[int], results: List[int]) -> List[Tuple[int, int]]: + def foo(names, results): + # type: (List[int], List[int]) -> List[Tuple[int, int]] return [(k + 5, v - 2) for k, v in zip(names, results)] test_func(foo, ([1, 2, 4], [4, 7, 9])) test_func(foo, ([5], [4, 7, 9])) - def fn(x: int) -> List[int]: + def fn(x): + # type: (int) -> List[int] return [i for i in range(x)] # noqa: C416 test_func(fn, (9,)) @@ -581,7 +601,8 @@ def test_nested_loop(): def test_mutable_list_function_inline(self): @torch.jit.script - def bar(y: List[int]) -> None: + def bar(y): + # type: (List[int]) -> None y.append(4) @torch.jit.script @@ -867,7 +888,8 @@ def test_list_remove2(): def test_extend_list_mutable(self): @torch.jit.script - def extend_list(a: List[Tensor], b: List[Tensor]) -> List[Tensor]: + def extend_list(a, b): + # type: (List[Tensor], List[Tensor]) -> List[Tensor] a.extend(b) return a @@ -878,7 +900,8 @@ def extend_list(a: List[Tensor], b: List[Tensor]) -> List[Tensor]: def test_extend_list_immutable(self): @torch.jit.script - def extend_list(a: List[int], b: List[int]) -> List[int]: + def extend_list(a, b): + # type: (List[int], List[int]) -> List[int] a.extend(b) return a @@ -889,7 +912,8 @@ def extend_list(a: List[int], b: List[int]) -> List[int]: def test_copy_list_mutable(self): @torch.jit.script - def copy_list(a: List[Tensor]) -> List[Tensor]: + def copy_list(a): + # type: (List[Tensor]) -> List[Tensor] return a.copy() for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]: @@ -897,29 +921,36 @@ def copy_list(a: List[Tensor]) -> List[Tensor]: def test_copy_list_immutable(self): @torch.jit.script - def copy_list(a: List[int]) -> List[int]: + def copy_list(a): + # type: (List[int]) -> List[int] return a.copy() for l in [[], [1], [1, 2, 3]]: self.assertEqual(copy_list(l), l) def test_min_max_single_list(self): - def min_intlist(li: List[int]) -> int: + def min_intlist(li): + # type: (List[int]) -> int return min(li) - def max_intlist(li: List[int]) -> int: + def max_intlist(li): + # type: (List[int]) -> int return max(li) - def min_boollist(li: List[bool]) -> bool: + def min_boollist(li): + # type: (List[bool]) -> bool return min(li) - def max_boollist(li: List[bool]) -> bool: + def max_boollist(li): + # type: (List[bool]) -> bool return max(li) - def min_floatlist(li: List[float]) -> float: + def min_floatlist(li): + # type: (List[float]) -> float return min(li) - def max_floatlist(li: List[float]) -> float: + def max_floatlist(li): + # type: (List[float]) -> float return max(li) @@ -949,19 +980,23 @@ def test_to_list(self): """ Boolean dtype unit tests. """ - def to_list_bool_0D(x: torch.Tensor) -> bool: + def to_list_bool_0D(x): + # type: (torch.Tensor) -> bool li = torch.jit.annotate(bool, x.tolist()) return li - def to_list_bool_1D(x: torch.Tensor) -> List[bool]: + def to_list_bool_1D(x): + # type: (torch.Tensor) -> List[bool] li = torch.jit.annotate(List[bool], x.tolist()) return li - def to_list_bool_2D(x: torch.Tensor) -> List[List[bool]]: + def to_list_bool_2D(x): + # type: (torch.Tensor) -> List[List[bool]] li = torch.jit.annotate(List[List[bool]], x.tolist()) return li - def to_list_bool_3D(x: torch.Tensor) -> List[List[List[bool]]]: + def to_list_bool_3D(x): + # type: (torch.Tensor) -> List[List[List[bool]]] li = torch.jit.annotate(List[List[List[bool]]], x.tolist()) return li @@ -986,19 +1021,23 @@ def to_list_bool_3D(x: torch.Tensor) -> List[List[List[bool]]]: """ Int dtype unit tests. """ - def to_list_int_0D(x: torch.Tensor) -> int: + def to_list_int_0D(x): + # type: (torch.Tensor) -> int li = torch.jit.annotate(int, x.tolist()) return li - def to_list_int_1D(x: torch.Tensor) -> List[int]: + def to_list_int_1D(x): + # type: (torch.Tensor) -> List[int] li = torch.jit.annotate(List[int], x.tolist()) return li - def to_list_int_2D(x: torch.Tensor) -> List[List[int]]: + def to_list_int_2D(x): + # type: (torch.Tensor) -> List[List[int]] li = torch.jit.annotate(List[List[int]], x.tolist()) return li - def to_list_int_3D(x: torch.Tensor) -> List[List[List[int]]]: + def to_list_int_3D(x): + # type: (torch.Tensor) -> List[List[List[int]]] li = torch.jit.annotate(List[List[List[int]]], x.tolist()) return li @@ -1019,19 +1058,23 @@ def to_list_int_3D(x: torch.Tensor) -> List[List[List[int]]]: """ Float dtype unit tests. """ - def to_list_float_0D(x: torch.Tensor) -> float: + def to_list_float_0D(x): + # type: (torch.Tensor) -> float li = torch.jit.annotate(float, x.tolist()) return li - def to_list_float_1D(x: torch.Tensor) -> List[float]: + def to_list_float_1D(x): + # type: (torch.Tensor) -> List[float] li = torch.jit.annotate(List[float], x.tolist()) return li - def to_list_float_2D(x: torch.Tensor) -> List[List[float]]: + def to_list_float_2D(x): + # type: (torch.Tensor) -> List[List[float]] li = torch.jit.annotate(List[List[float]], x.tolist()) return li - def to_list_float_3D(x: torch.Tensor) -> List[List[List[float]]]: + def to_list_float_3D(x): + # type: (torch.Tensor) -> List[List[List[float]]] li = torch.jit.annotate(List[List[List[float]]], x.tolist()) return li @@ -1056,23 +1099,28 @@ def to_list_float_3D(x: torch.Tensor) -> List[List[List[float]]]: - type annotation with the wrong dimension - type annotation with scalar type that doesn't match the input scalar type """ - def to_list_missing_type_annotation(x: torch.Tensor) -> List[float]: + def to_list_missing_type_annotation(x): + # type: (torch.Tensor) -> List[float] li = x.tolist() return li - def to_list_incorrect_type_annotation(x: torch.Tensor) -> List[float]: + def to_list_incorrect_type_annotation(x): + # type: (torch.Tensor) -> List[float] li = torch.jit.annotate(float, x.tolist()) return li - def to_list_unsupported_type_annotation(x: torch.Tensor) -> List[float]: + def to_list_unsupported_type_annotation(x): + # type: (torch.Tensor) -> List[float] li = torch.jit.annotate(List[str], x.tolist()) return li - def to_list_type_annotation_wrong_dim(x: torch.Tensor) -> List[List[float]]: + def to_list_type_annotation_wrong_dim(x): + # type: (torch.Tensor) -> List[List[float]] li = torch.jit.annotate(List[List[float]], x.tolist()) return li - def to_list_type_annotation_incorrect_scalar_type(x: torch.Tensor) -> List[float]: + def to_list_type_annotation_incorrect_scalar_type(x): + # type: (torch.Tensor) -> List[float] li = torch.jit.annotate(List[float], x.tolist()) return li @@ -1116,15 +1164,18 @@ def test_to_list_gpu(self): if not torch.cuda.is_available() or torch.cuda.device_count() == 0: self.skipTest("CUDA is not available") - def to_list_bool_1D(x: torch.Tensor) -> List[bool]: + def to_list_bool_1D(x): + # type: (torch.Tensor) -> List[bool] li = torch.jit.annotate(List[bool], x.tolist()) return li - def to_list_int_1D(x: torch.Tensor) -> List[int]: + def to_list_int_1D(x): + # type: (torch.Tensor) -> List[int] li = torch.jit.annotate(List[int], x.tolist()) return li - def to_list_float_1D(x: torch.Tensor) -> List[float]: + def to_list_float_1D(x): + # type: (torch.Tensor) -> List[float] li = torch.jit.annotate(List[float], x.tolist()) return li @@ -1136,7 +1187,8 @@ def to_list_float_1D(x: torch.Tensor) -> List[float]: 5, dtype=torch.double).cuda(),)) def test_no_element_type_annotation(self): - def fn_with_comment(x: torch.Tensor) -> List: + def fn_with_comment(x): + # type: (torch.Tensor) -> List a: List = x.tolist() return a @@ -1178,7 +1230,8 @@ def test_del(self): def inputs(): return {'hi': 2, 'bye': 3} - def fn(x: Dict[str, int]) -> Dict[str, int]: + def fn(x): + # type: (Dict[str, int]) -> Dict[str, int] del x['hi'] return x @@ -1194,7 +1247,8 @@ def fn(x: Dict[str, int]) -> Dict[str, int]: def test_keys(self): @torch.jit.script - def keys(x: Dict[str, Tensor]) -> List[str]: + def keys(x): + # type: (Dict[str, Tensor]) -> List[str] return list(x.keys()) self.assertEqual(set(keys(self.dict())), set(self.dict().keys())) @@ -1209,26 +1263,30 @@ def specialized_list(): def test_values(self): @torch.jit.script - def values(x: Dict[str, Tensor]) -> List[Tensor]: + def values(x): + # type: (Dict[str, Tensor]) -> List[Tensor] return list(x.values()) the_dict = self.dict() self.assertEqual(set(values(the_dict)), set(the_dict.values())) def test_len(self): - def length(x: Dict[str, Tensor]) -> int: + def length(x): + # type: (Dict[str, Tensor]) -> int return len(x) self.checkScript(length, (self.dict(),)) def test_copy(self): - def func(x: Dict[str, Tensor]) -> Dict[str, Tensor]: + def func(x): + # type: (Dict[str, Tensor]) -> Dict[str, Tensor] return x.copy() self.checkScript(func, (self.dict(),)) def test_items(self): - def func(x: Dict[str, Tensor]) -> List[Tuple[str, Tensor]]: + def func(x): + # type: (Dict[str, Tensor]) -> List[Tuple[str, Tensor]] return x.items() # The value returned by Python is in arbitrary order, so we can't use @@ -1243,7 +1301,8 @@ def func(x: Dict[str, Tensor]) -> List[Tuple[str, Tensor]]: self.assertTrue(item in script_out) def test_pop(self): - def pop(x: Dict[str, Tensor], key: str) -> Tuple[Tensor, Dict[str, Tensor]]: + def pop(x, key): + # type: (Dict[str, Tensor], str) -> Tuple[Tensor, Dict[str, Tensor]] return x.pop(key), x # checkScript doesn't copy the inputs, so we can't use it since this mutates @@ -1259,14 +1318,16 @@ def tester(fn, *args): torch.jit.script(pop)(self.dict(), 'x') - def default_pop(x: Dict[str, Tensor], key: str, default: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: + def default_pop(x, key, default): + # type: (Dict[str, Tensor], str, Tensor) -> Tuple[Tensor, Dict[str, Tensor]] return x.pop(key, default), x tester(default_pop, 'a', torch.randn(2, 2)) tester(default_pop, 'x', torch.randn(2, 2)) def test_setdefault(self): - def setdefault(x: Dict[str, Tensor], key: str, default: Tensor) -> Dict[str, Tensor]: + def setdefault(x, key, default): + # type: (Dict[str, Tensor], str, Tensor) -> Dict[str, Tensor] x.setdefault(key, default) return x @@ -1274,7 +1335,8 @@ def setdefault(x: Dict[str, Tensor], key: str, default: Tensor) -> Dict[str, Ten self.checkScript(setdefault, (self.dict(), 'nonexistant', torch.randn(2, 2))) def test_update(self): - def update(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]: + def update(a, b): + # type: (Dict[str, Tensor], Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]] a.update(b) return a, b @@ -1291,7 +1353,8 @@ def foo() -> Dict[str, int]: self.checkScript(foo, ()) def test_aug_assign(self): - def aug_assign_dict_tensor(a: Dict[str, Tensor]) -> Dict[str, Tensor]: + def aug_assign_dict_tensor(a): + # type: (Dict[str, Tensor]) -> Dict[str, Tensor] a['a'] += 1 a['b'] -= 12 a['c'] *= 122 @@ -1299,7 +1362,8 @@ def aug_assign_dict_tensor(a: Dict[str, Tensor]) -> Dict[str, Tensor]: a['c'] %= 2 return a - def aug_assign_dict_prim(a: Dict[str, float]) -> Dict[str, float]: + def aug_assign_dict_prim(a): + # type: (Dict[str, float]) -> Dict[str, float] a['a'] += 3.4 a['b'] -= 2.4 a['c'] *= 3.0 @@ -1312,7 +1376,8 @@ def aug_assign_dict_prim(a: Dict[str, float]) -> Dict[str, float]: def test_popitem(self): @torch.jit.script - def popitem(x: Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]: + def popitem(x): + # type: (Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]] item = x.popitem() return item, x @@ -1330,56 +1395,65 @@ def popitem(x: Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor] self.assertTrue(isinstance(script_out[0][1], torch.Tensor)) def test_clear(self): - def clear(x: Dict[str, Tensor]) -> Dict[str, Tensor]: + def clear(x): + # type: (Dict[str, Tensor]) -> Dict[str, Tensor] x.clear() return x self.checkScript(clear, (self.dict(),)) def test_get(self): - def get(x: Dict[str, Tensor], key: str) -> Optional[Tensor]: + def get(x, key): + # type: (Dict[str, Tensor], str) -> Optional[Tensor] return x.get(key) self.checkScript(get, (self.dict(), 'a')) self.checkScript(get, (self.dict(), "doesn't exist")) - def get_default(x: Dict[str, Tensor], key: str) -> Optional[Tensor]: + def get_default(x, key): + # type: (Dict[str, Tensor], str) -> Optional[Tensor] return x.get(key, torch.randn(2, 2)) self.checkScript(get, (self.dict(), 'a')) self.checkScript(get, (self.dict(), "doesn't exist")) def test_get_boolkey(self): - def get(x: Dict[bool, int], key: bool) -> Optional[int]: + def get(x, key): + # type: (Dict[bool, int], bool) -> Optional[int] return x.get(key) self.checkScript(get, (self.dict_bool(), True)) self.checkScript(get, (self.dict_bool(), False)) - def get_default(x: Dict[bool, int], key: bool) -> int: + def get_default(x, key): + # type: (Dict[bool, int], bool) -> int return x.get(key, 42) self.checkScript(get_default, (self.dict_bool(), True)) self.checkScript(get_default, (self.dict_bool(), False)) def test_basic(self): - def simple(x: Dict[str, int]) -> Dict[str, int]: + def simple(x): + # type: (Dict[str, int]) -> Dict[str, int] return x self.checkScript(simple, ({'item': 20, 'other_item': 120},)) - def index(x: Dict[str, int]) -> int: + def index(x): + # type: (Dict[str, int]) -> int return x['item'] self.checkScript(index, ({'item': 20, 'other_item': 120},)) - def type_default() -> Dict[str, Tensor]: + def type_default(): + # type: () -> Dict[str, Tensor] return {} self.checkScript(type_default, ()) @torch.jit.script - def missing_index(x: Dict[str, int]) -> int: + def missing_index(x): + # type: (Dict[str, int]) -> int return x['dne'] with self.assertRaisesRegex(RuntimeError, "KeyError"): @@ -1401,14 +1475,16 @@ def literal3(): ''')) self.assertEqual({10: 1.2, 11: 1.3}, cu.literal3()) - def list_of_dicts() -> List[Dict[str, Tensor]]: + def list_of_dicts(): + # type: () -> List[Dict[str, Tensor]] return [{'word': torch.ones(2) + 3}, {'other word': torch.ones(1) + 2}] self.checkScript(list_of_dicts, ()) def test_mutability(self): @torch.jit.script - def fn() -> Dict[str, int]: + def fn(): + # type: () -> Dict[str, int] a = torch.jit.annotate(Dict[str, int], {}) a['ok'] = 10 return a @@ -1418,12 +1494,14 @@ def fn() -> Dict[str, int]: def test_key_type(self): with self.assertRaisesRegex(RuntimeError, "but instead found type"): @torch.jit.script - def fn(a: Dict[str, int]) -> int: + def fn(a): + # type: (Dict[str, int]) -> int return a[None] def test_loop(self): @torch.jit.script - def fn(x: int) -> Dict[str, int]: + def fn(x): + # type: (int) -> Dict[str, int] a = torch.jit.annotate(Dict[str, int], {}) for i in range(x): a['ok'] = i @@ -1442,14 +1520,16 @@ def fn(x, y): self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3))) def test_membership(self): - def fn(x: Dict[int, int], y: int) -> int: + def fn(x, y): + # type: (Dict[int, int], int) -> int return x.get(y, 3) d = {1: 2, 3: 4} self.checkScript(fn, (d, 3)) self.checkScript(fn, (d, 2)) - def optional(x: Dict[int, int], y: int) -> bool: + def optional(x, y): + # type: (Dict[int, int], int) -> bool res = x.get(y) return res is None @@ -1458,15 +1538,18 @@ def optional(x: Dict[int, int], y: int) -> bool: with self.assertRaisesRegex(RuntimeError, "is actually of type Optional"): @torch.jit.script - def bad_types(x: Dict[int, int], y: int) -> int: + def bad_types(x, y): + # type: (Dict[int, int], int) -> int return x.get(y) # noqa: T484 def test_dict_to_python(self): @torch.jit.ignore - def python_lookup(my_dict: Dict[str, int], keys: List[str]) -> List[int]: + def python_lookup(my_dict, keys): + # type: (Dict[str, int], List[str]) -> List[int] return [my_dict[k] for k in keys] - def fn(my_dict: Dict[str, int], keys: List[str]) -> List[int]: + def fn(my_dict, keys): + # type: (Dict[str, int], List[str]) -> List[int] return python_lookup(my_dict, keys) a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2} @@ -1518,7 +1601,8 @@ def test_type_annotation_missing_contained_type(self): key and value types produces an error. """ # This function uses a type comment. - def fn_with_comment(input: Dict) -> Any: + def fn_with_comment(input): + # type: (Dict) -> Any return input # This function uses Python3 style type annotations. diff --git a/test/jit/test_module_interface.py b/test/jit/test_module_interface.py index d32537eb2ac5..70071cc56028 100644 --- a/test/jit/test_module_interface.py +++ b/test/jit/test_module_interface.py @@ -6,7 +6,6 @@ import torch.nn as nn import os import sys -from torch import Tensor from torch.testing._internal.jit_utils import JitTestCase # Make the helper files in test/ importable @@ -23,30 +22,36 @@ class OrigModule(nn.Module): def __init__(self): super(OrigModule, self).__init__() - def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor return inp1 + inp2 + 1 - def two(self, input: Tensor) -> Tensor: + def two(self, input): + # type: (Tensor) -> Tensor return input + 2 - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor return input + self.one(input, input) + 1 class NewModule(nn.Module): def __init__(self): super(NewModule, self).__init__() - def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor return inp1 * inp2 + 1 - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor return self.one(input, input + 1) class TestModuleInterface(JitTestCase): def test_not_submodule_interface_call(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor pass class TestNotModuleInterfaceCall(nn.Module): @@ -56,7 +61,8 @@ def __init__(self): super(TestNotModuleInterfaceCall, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor return self.proxy_mod.two(input) with self.assertRaisesRegex(RuntimeError, "Tried to access nonexistent attribute"): @@ -66,51 +72,64 @@ def test_module_interface(self): global OneTwoModule, OneTwoClass @torch.jit.interface class OneTwoModule(nn.Module): - def one(self, x: Tensor, y: Tensor) -> Tensor: + def one(self, x, y): + # type: (Tensor, Tensor) -> Tensor pass - def two(self, x: Tensor) -> Tensor: + def two(self, x): + # type: (Tensor) -> Tensor pass - def forward(self, x: Tensor) -> Tensor: + def forward(self, x): + # type: (Tensor) -> Tensor pass @torch.jit.interface class OneTwoClass(object): - def one(self, x: Tensor, y: Tensor) -> Tensor: + def one(self, x, y): + # type: (Tensor, Tensor) -> Tensor pass - def two(self, x: Tensor) -> Tensor: + def two(self, x): + # type: (Tensor) -> Tensor pass class FooMod(nn.Module): - def one(self, x: Tensor, y: Tensor) -> Tensor: + def one(self, x, y): + # type: (Tensor, Tensor) -> Tensor return x + y - def two(self, x: Tensor) -> Tensor: + def two(self, x): + # type: (Tensor) -> Tensor return 2 * x - def forward(self, x: Tensor) -> Tensor: + def forward(self, x): + # type: (Tensor) -> Tensor return self.one(self.two(x), x) class BarMod(nn.Module): - def one(self, x: Tensor, y: Tensor) -> Tensor: + def one(self, x, y): + # type: (Tensor, Tensor) -> Tensor return x * y - def two(self, x: Tensor) -> Tensor: + def two(self, x): + # type: (Tensor) -> Tensor return 2 / x - def forward(self, x: Tensor) -> Tensor: + def forward(self, x): + # type: (Tensor) -> Tensor return self.two(self.one(x, x)) @torch.jit.export - def forward2(self, x: Tensor) -> Tensor: + def forward2(self, x): + # type: (Tensor) -> Tensor return self.two(self.one(x, x)) + 1 def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor): return mod_list[0].forward(x) + mod_list[1].forward(x) - def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor: + def use_class_interface(mod_list, x): + # type: (List[OneTwoClass], Tensor) -> Tensor return mod_list[0].two(x) + mod_list[1].one(x, x) scripted_foo_mod = torch.jit.script(FooMod()) @@ -120,7 +139,8 @@ def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor: self.checkScript(use_class_interface, ([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),)) - def call_module_interface_on_other_method(mod_interface: OneTwoModule, x: Tensor) -> Tensor: + def call_module_interface_on_other_method(mod_interface, x): + # type: (OneTwoModule, Tensor) -> Tensor return mod_interface.forward2(x) # ensure error out when we call the module on the method other than the interface specified. @@ -132,28 +152,35 @@ def test_module_interface_subtype(self): global OneTwoModule @torch.jit.interface class OneTwoModule(nn.Module): - def one(self, x: Tensor, y: Tensor) -> Tensor: + def one(self, x, y): + # type: (Tensor, Tensor) -> Tensor pass - def two(self, x: Tensor) -> Tensor: + def two(self, x): + # type: (Tensor) -> Tensor pass - def forward(self, x: Tensor) -> Tensor: + def forward(self, x): + # type: (Tensor) -> Tensor pass @torch.jit.script - def as_module_interface(x: OneTwoModule) -> OneTwoModule: + def as_module_interface(x): + # type: (OneTwoModule) -> OneTwoModule return x @torch.jit.script class Foo(object): - def one(self, x: Tensor, y: Tensor) -> Tensor: + def one(self, x, y): + # type: (Tensor, Tensor) -> Tensor return x + y - def two(self, x: Tensor) -> Tensor: + def two(self, x): + # type: (Tensor) -> Tensor return 2 * x - def forward(self, x: Tensor) -> Tensor: + def forward(self, x): + # type: (Tensor) -> Tensor return self.one(self.two(x), x) # check class object is not a subtype of module interface @@ -161,10 +188,12 @@ def forward(self, x: Tensor) -> Tensor: as_module_interface(Foo()) class WrongMod(nn.Module): - def two(self, x: int) -> int: + def two(self, x): + # type: (int) -> int return 2 * x - def forward(self, x: Tensor) -> Tensor: + def forward(self, x): + # type: (Tensor) -> Tensor return x + torch.randn(3, self.two(3)) scripted_wrong_mod = torch.jit.script(WrongMod()) @@ -215,16 +244,19 @@ def test_module_interface_inheritance(self): with self.assertRaisesRegex(RuntimeError, "does not support inheritance yet. Please directly"): @torch.jit.interface class InheritMod(nn.ReLU): - def three(self, x: Tensor) -> Tensor: + def three(self, x): + # type: (Tensor) -> Tensor return 3 * x def test_module_swap(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor pass - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor pass class TestModule(nn.Module): @@ -234,7 +266,8 @@ def __init__(self): super(TestModule, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor return self.proxy_mod.forward(input) scripted_mod = torch.jit.script(TestModule()) @@ -252,17 +285,20 @@ def forward(self, input: Tensor) -> Tensor: def test_module_swap_wrong_module(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor pass - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor pass class NewModuleWrong(nn.Module): def __init__(self): super(NewModuleWrong, self).__init__() - def forward(self, input: int) -> int: + def forward(self, input): + # type: (int) -> int return input + 1 class TestModule(nn.Module): @@ -272,7 +308,8 @@ def __init__(self): super(TestModule, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor return self.proxy_mod.forward(input) scripted_mod = torch.jit.script(TestModule()) @@ -283,10 +320,12 @@ def forward(self, input: Tensor) -> Tensor: def test_module_swap_no_lazy_compile(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor pass - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor pass class TestModule(nn.Module): @@ -296,17 +335,20 @@ def __init__(self): super(TestModule, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor return self.proxy_mod.forward(input) class NewModuleMethodNotLazyCompile(nn.Module): def __init__(self): super(NewModuleMethodNotLazyCompile, self).__init__() - def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor return inp1 * inp2 + 1 - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor return input + 1 scripted_mod = torch.jit.script(TestModule()) @@ -320,10 +362,12 @@ def __init__(self): super(NewModuleMethodManualExport, self).__init__() @torch.jit.export - def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor return inp1 * inp2 + 1 - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor return input + 1 scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport()) @@ -337,7 +381,8 @@ def __init__(self): super(TestNoModuleInterface, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor return self.proxy_mod(input) scripted_no_module_interface = torch.jit.script(TestNoModuleInterface()) @@ -352,10 +397,12 @@ def forward(self, input: Tensor) -> Tensor: def test_script_module_as_interface_swap(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor pass - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor pass class OrigScriptModule(torch.jit.ScriptModule): @@ -363,11 +410,13 @@ def __init__(self): super(OrigScriptModule, self).__init__() @torch.jit.script_method - def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor return inp1 + inp2 + 1 @torch.jit.script_method - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor return input + self.one(input, input) + 1 class NewScriptModule(torch.jit.ScriptModule): @@ -375,11 +424,13 @@ def __init__(self): super(NewScriptModule, self).__init__() @torch.jit.script_method - def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor return inp1 * inp2 + 1 @torch.jit.script_method - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor return self.one(input, input + 1) class TestNNModuleWithScriptModule(nn.Module): @@ -389,7 +440,8 @@ def __init__(self): super(TestNNModuleWithScriptModule, self).__init__() self.proxy_mod = OrigScriptModule() - def forward(self, input: Tensor) -> Tensor: + def forward(self, input): + # type: (Tensor) -> Tensor return self.proxy_mod.forward(input) input = torch.randn(3, 4) @@ -420,7 +472,8 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x: Tensor) -> int: + def forward(self, x): + # type: (Tensor) -> int pass class TestModule(torch.nn.Module): @@ -467,7 +520,8 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x: Tensor) -> int: + def forward(self, x): + # type: (Tensor) -> int pass class TestModule(torch.nn.Module): @@ -510,7 +564,8 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: + def forward(self, x): + # type: (Tensor) -> Tensor pass class TestModule(torch.nn.Module): @@ -555,7 +610,8 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: + def forward(self, x): + # type: (Tensor) -> Tensor pass class TestModule(torch.nn.Module): @@ -597,7 +653,8 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: + def forward(self, x): + # type: (Tensor) -> Tensor pass class TestModule(torch.nn.Module): @@ -631,7 +688,8 @@ def forward(self, x): def test_module_apis_interface(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor pass class TestModule(nn.Module): diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py index 3532e62bfe04..a84d9d7256b8 100644 --- a/test/jit/test_recursive_script.py +++ b/test/jit/test_recursive_script.py @@ -284,7 +284,8 @@ def forward(self, x): test_module_dir(nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))) def test_class_compile(self): - def other_fn(a: int, b: Tensor) -> Tensor: + def other_fn(a, b): + # type: (int, Tensor) -> Tensor return a * b class B(object): @@ -306,7 +307,8 @@ def forward(self, x): self.checkModule(N(), (torch.randn(2, 2),)) def test_error_stack(self): - def d(x: int) -> int: + def d(x): + # type: (int) -> int return x + 10 def c(x): @@ -329,7 +331,8 @@ def a(x): checker.run(str(e)) def test_error_stack_module(self): - def d(x: int) -> int: + def d(x): + # type: (int) -> int return x + 10 def c(x): @@ -562,7 +565,8 @@ def __init__(self): self.a = 4 self.inner = Inner2() - def __setstate__(self, obj: Tuple[int, Inner2]) -> None: + def __setstate__(self, obj): + # type: (Tuple[int, Inner2]) -> None a, inner = obj self.a = a self.inner = inner diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index 12d42a5150cd..eb14ca8350af 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -680,7 +680,8 @@ def test_different_interfaces(self): """ @torch.jit.interface class MyInterface(object): - def bar(self, x: Tensor) -> Tensor: + def bar(self, x): + # type: (Tensor) -> Tensor pass @torch.jit.script @@ -710,7 +711,8 @@ def forward(self, x): @torch.jit.interface class MyInterface(object): - def not_bar(self, x: Tensor) -> Tensor: + def not_bar(self, x): + # type: (Tensor) -> Tensor pass @torch.jit.script # noqa: F811 @@ -765,7 +767,8 @@ class MyCoolNamedTuple(NamedTuple): @torch.jit.interface class MyInterface(object): - def bar(self, x: Tensor) -> Tensor: + def bar(self, x): + # type: (Tensor) -> Tensor pass @torch.jit.script @@ -806,7 +809,8 @@ def forward(self, x): @torch.jit.interface class MyInterface(object): - def not_bar(self, x: Tensor) -> Tensor: + def not_bar(self, x): + # type: (Tensor) -> Tensor pass @torch.jit.script # noqa F811 diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 841682666964..366ca1af69e6 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -25,7 +25,7 @@ # Standard library from collections import namedtuple from itertools import chain -from typing import Dict, Optional +from typing import Dict import warnings if __name__ == '__main__': @@ -1862,11 +1862,13 @@ def forward(self, inputs): class TestMixTracingScripting(JitTestCase): def test_trace_script(self): @torch.jit.script - def func1(x: Tuple[Tensor, Tensor]) -> Tensor: + def func1(x): + # type: (Tuple[Tensor, Tensor]) -> Tensor return x[0] + x[1] @torch.jit.script - def func2(x: List[Tensor]) -> Tensor: + def func2(x): + # type: (List[Tensor]) -> Tensor return x[0] + x[1] a = torch.randn(5) @@ -1876,7 +1878,8 @@ def func2(x: List[Tensor]) -> Tensor: self.checkTrace(func2, ((a, b),)) @torch.jit.script - def func3(x: Tensor, method: str = 'bilinear', align_corners: bool = True) -> Tensor: + def func3(x, method='bilinear', align_corners=True): + # type: (Tensor, str, bool) -> Tensor hw = x.shape[2:4] return F.interpolate(x, hw, mode=method, align_corners=align_corners) @@ -1884,7 +1887,8 @@ def func3(x: Tensor, method: str = 'bilinear', align_corners: bool = True) -> Te self.checkTrace(func3, (inp,)) @torch.jit.script - def func4(x: Tensor, a: List[Optional[str]]) -> Tensor: + def func4(x, a): + # type: (Tensor, List[Optional[str]]) -> Tensor if len(a) == 2: return x + 2 else: diff --git a/test/jit/test_with.py b/test/jit/test_with.py index f958dc46c39a..ffd0631639f6 100644 --- a/test/jit/test_with.py +++ b/test/jit/test_with.py @@ -1,7 +1,7 @@ import os import sys -from typing import Any, List +from typing import Any import torch from torch.testing._internal.jit_utils import JitTestCase @@ -50,7 +50,8 @@ def __enter__(self): def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) - def test_basic(x: torch.Tensor) -> torch.Tensor: + def test_basic(x): + # type: (Tensor) -> Tensor """Basic test with one with-statement.""" c = Context(1) @@ -61,7 +62,8 @@ def test_basic(x: torch.Tensor) -> torch.Tensor: y *= c.count return y - def test_pass(x: torch.Tensor) -> torch.Tensor: + def test_pass(x): + # type: (Tensor) -> Tensor """ Test with a pass statement inside a with-statement. Although the body of the with is empty, __enter__ and __exit__ should @@ -75,7 +77,8 @@ def test_pass(x: torch.Tensor) -> torch.Tensor: x *= c.count return x - def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: + def test_early_return(x, c): + # type: (Tensor, Context) -> Tensor """ Test that returning early from inside a with-statement works as expected. @@ -87,7 +90,8 @@ def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: x = y + y return x - def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: + def test_conditional_early_return(x, c): + # type: (Tensor, Context) -> Tensor """ Test that conditionally returning early from inside a with-statement works as expected. @@ -100,7 +104,8 @@ def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: x = y + y return x - def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: + def test_break(x, c, l): + # type: (Tensor, Context, List[int]) -> Tensor """ Test that breaking early from inside a with-statement works as expected. @@ -113,7 +118,8 @@ def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: return x - def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: + def test_continue(x, c, l): + # type: (Tensor, Context, List[int]) -> Tensor """ Test that using continue inside a with-statement works as expected. @@ -126,7 +132,8 @@ def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: return x - def test_serial(x: torch.Tensor) -> torch.Tensor: + def test_serial(x): + # type: (Tensor) -> Tensor """ Test two with-statements in a row. """ @@ -140,7 +147,8 @@ def test_serial(x: torch.Tensor) -> torch.Tensor: return y - def test_nested(x: torch.Tensor) -> torch.Tensor: + def test_nested(x): + # type: (Tensor) -> Tensor """ Test nested with-statements. """ @@ -154,7 +162,8 @@ def test_nested(x: torch.Tensor) -> torch.Tensor: return y - def test_combined(x: torch.Tensor) -> torch.Tensor: + def test_combined(x): + # type: (Tensor) -> Tensor """ Test a with-statement with multiple with items. """ @@ -206,7 +215,8 @@ def __enter__(self): def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) - def test_basic(x: torch.Tensor) -> torch.Tensor: + def test_basic(x): + # type: (Tensor) -> Tensor """Basic test with one with-statement.""" c = Context(1) @@ -217,7 +227,8 @@ def test_basic(x: torch.Tensor) -> torch.Tensor: y *= c.count return y - def test_pass(x: torch.Tensor) -> torch.Tensor: + def test_pass(x): + # type: (Tensor) -> Tensor """ Test with a pass statement inside a with-statement. Although the body of the with is empty, __enter__ and __exit__ should @@ -231,7 +242,8 @@ def test_pass(x: torch.Tensor) -> torch.Tensor: x *= c.count return x - def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: + def test_early_return(x, c): + # type: (Tensor, Context) -> Tensor """ Test that returning early from inside a with-statement works as expected. @@ -243,7 +255,8 @@ def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: x = y + y return x - def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: + def test_conditional_early_return(x, c): + # type: (Tensor, Context) -> Tensor """ Test that conditionally returning early from inside a with-statement works as expected. @@ -256,7 +269,8 @@ def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: x = y + y return x - def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: + def test_break(x, c, l): + # type: (Tensor, Context, List[int]) -> Tensor """ Test that breaking early from inside a with-statement works as expected. @@ -269,7 +283,8 @@ def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: return x - def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: + def test_continue(x, c, l): + # type: (Tensor, Context, List[int]) -> Tensor """ Test that using continue inside a with-statement works as expected. @@ -282,7 +297,8 @@ def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: return x - def test_serial(x: torch.Tensor) -> torch.Tensor: + def test_serial(x): + # type: (Tensor) -> Tensor """ Test two with-statements in a row. """ @@ -296,7 +312,8 @@ def test_serial(x: torch.Tensor) -> torch.Tensor: return y - def test_nested(x: torch.Tensor) -> torch.Tensor: + def test_nested(x): + # type: (Tensor) -> Tensor """ Test nested with-statements. """ @@ -310,7 +327,8 @@ def test_nested(x: torch.Tensor) -> torch.Tensor: return y - def test_combined(x: torch.Tensor) -> torch.Tensor: + def test_combined(x): + # type: (Tensor) -> Tensor """ Test a with-statement with multiple with items. """ @@ -363,11 +381,13 @@ def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) @torch.jit.script - def method_that_raises() -> torch.Tensor: + def method_that_raises(): + # type: () -> Tensor raise Exception("raised exception") @torch.jit.script - def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor: + def test_exception(x, c): + # type: (Tensor, Context) -> Tensor """ Test the case in which an exception is thrown while executing the body of a with-statement. """ @@ -377,7 +397,8 @@ def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor: return x @torch.jit.script - def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor: + def test_exception_nested(x, c): + # type: (Tensor, Context) -> Tensor """ Test the case in which an exception is thrown while executing the body of a nested with-statement. """ @@ -388,7 +409,8 @@ def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor: return x @torch.jit.script - def with_that_raises(c: Context) -> torch.Tensor: + def with_that_raises(c): + # type: (Context) -> Tensor a = torch.tensor([1]) with c as _: @@ -397,7 +419,8 @@ def with_that_raises(c: Context) -> torch.Tensor: return a @torch.jit.script - def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor: + def test_exception_fn_call(x, c): + # type: (Tensor, Context) -> Tensor """ Test the case in which an exception is thrown while there are active with-statements in two different frames. @@ -483,25 +506,29 @@ def __enter__(self): def __exit__(self, type: Any, value: int, tb: int): pass - def test_no_enter_no_exit(x: torch.Tensor, c: NoEnterNoExit) -> torch.Tensor: + def test_no_enter_no_exit(x, c): + # type: (Tensor, NoEnterNoExit) -> Tensor with c as _: pass return x - def test_bad_enter(x: torch.Tensor, c: BadEnter) -> torch.Tensor: + def test_bad_enter(x, c): + # type: (Tensor, BadEnter) -> Tensor with c as _: pass return x - def test_bad_exit(x: torch.Tensor, c: BadExit) -> torch.Tensor: + def test_bad_exit(x, c): + # type: (Tensor, BadExit) -> Tensor with c as _: pass return x - def test_exit_incorrect_types(x: torch.Tensor, c: ExitIncorrectTypes) -> torch.Tensor: + def test_exit_incorrect_types(x, c): + # type: (Tensor, ExitIncorrectTypes) -> Tensor with c as _: pass @@ -538,7 +565,8 @@ def test_with_no_grad(self): """ # Basic no_grad test. - def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def test_no_grad(x, y): + # type: (Tensor, Tensor) -> Tensor with torch.no_grad(): w = x + y @@ -555,7 +583,8 @@ def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # Test assignment of a grad-less Tensor to a Tensor with gradients # in a no_grad block. - def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def test_no_grad_assignment(x, y): + # type: (Tensor, Tensor) -> Tensor with torch.no_grad(): x[0] = y @@ -574,11 +603,13 @@ def __init__(self): super().__init__() @torch.jit.ignore - def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def adder(self, x, y): + # type: (Tensor, Tensor) -> Tensor w = x + y return w - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def forward(self, x, y): + # type: (Tensor, Tensor) -> Tensor with torch.no_grad(): w = self.adder(x, y) @@ -594,7 +625,8 @@ def test_with_record_function(self): Check that torch.autograd.profiler.record_function context manager is torchscriptable. """ - def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def with_rf(x, y): + # type: (Tensor, Tensor) -> Tensor with torch.autograd.profiler.record_function("foo"): # Nested record_function. with torch.autograd.profiler.record_function("nested"): From 8f31621f78789e3179007dfde5087bda85c3bb29 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 8 Jan 2021 13:14:33 -0800 Subject: [PATCH 50/54] Fix MKL builds on Ubuntu (#50212) Summary: This fixes https://github.com/pytorch/pytorch/issues/50211 Pull Request resolved: https://github.com/pytorch/pytorch/pull/50212 Reviewed By: janeyx99 Differential Revision: D25850876 Pulled By: walterddr fbshipit-source-id: be138db3ae370c45f5fbf3af486cf8b32518df87 --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3df73f8a3041..415e2676a14a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -318,7 +318,7 @@ set(OP_DEPENDENCY "" CACHE STRING # symbol lookup error: miniconda3/envs/pytorch-py3.7/lib/libmkl_intel_lp64.so: undefined symbol: mkl_blas_dsyrk # https://software.intel.com/en-us/articles/symbol-lookup-error-when-linking-intel-mkl-with-gcc-on-ubuntu if(LINUX) - set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-as-needed ${CMAKE_SHARED_LINKER_FLAGS}") + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--no-as-needed") endif() if(MSVC) From 2c4b6ec45723ebc184d958c12f56961ee2fa6ca4 Mon Sep 17 00:00:00 2001 From: Alex Henrie Date: Fri, 8 Jan 2021 13:29:39 -0800 Subject: [PATCH 51/54] Unused exception variables (#50181) Summary: These unused variables were identified by [pyflakes](https://pypi.org/project/pyflakes/). They can be safely removed to simplify the code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/50181 Reviewed By: gchanan Differential Revision: D25844270 fbshipit-source-id: 0e648ffe8c6db6daf56788a13ba89806923cbb76 --- torch/_jit_internal.py | 2 +- torch/distributed/launch.py | 2 +- torch/distributed/rpc/api.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index fca1c45377e7..4ccffc2c8362 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -80,7 +80,7 @@ def parseExpr(expr, module): value, len_parsed = parseNestedExpr(expr, module) assert len_parsed == len(expr), "whole expression was not parsed, falling back to c++ parser" return value - except Exception as e: + except Exception: """ The python resolver fails in several cases in known unit tests, and is intended to fall back gracefully to the c++ resolver in general. For example, python 2 style diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py index 58fc42b33dbf..a6a5b26e6d40 100644 --- a/torch/distributed/launch.py +++ b/torch/distributed/launch.py @@ -295,7 +295,7 @@ def sigkill_handler(signum, frame): print(f"Killing subprocess {process.pid}") try: process.kill() - except Exception as e: + except Exception: pass if last_return_code is not None: raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd) diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index 0a99df67269e..ce200dba392f 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -381,7 +381,7 @@ def _rref_typeof_on_user(rref): # Combine the implementation class and the type class. class RRef(PyRRef, Generic[T]): pass -except TypeError as exc: +except TypeError: # TypeError: metaclass conflict: the metaclass of a derived class # must be a (non-strict) subclass of the metaclasses of all its bases # Mypy doesn't understand __class__ (mypy bug #4177) From aa18d174553e1c2ab4d9e09ae45c6ac323f04af5 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 8 Jan 2021 13:47:41 -0800 Subject: [PATCH 52/54] add type annotations to torch.nn.modules.fold (#49479) Summary: closes gh-49478 Fixes https://github.com/pytorch/pytorch/issues/49478 Pull Request resolved: https://github.com/pytorch/pytorch/pull/49479 Reviewed By: mruberry Differential Revision: D25723838 Pulled By: walterddr fbshipit-source-id: 45c4cbd6f147b6dc4a5f5419c17578c49c201022 --- mypy.ini | 3 --- torch/nn/functional.pyi.in | 10 +++++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/mypy.ini b/mypy.ini index 0c99a9c62d18..8639ad1b44ad 100644 --- a/mypy.ini +++ b/mypy.ini @@ -76,9 +76,6 @@ ignore_errors = True [mypy-torch.nn.modules.conv] ignore_errors = True -[mypy-torch.nn.modules.fold] -ignore_errors = True - [mypy-torch.nn.modules.module] ignore_errors = True diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index 208dc7c2df40..11eb4c404dc6 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -1,7 +1,7 @@ from torch import Tensor from torch.types import _size from typing import Any, Optional, Tuple, Dict, List, Callable, Sequence, Union -from .common_types import _ratio_any_t, _size_1_t, _size_2_t, _size_3_t, _size_2_opt_t, _size_3_opt_t +from .common_types import _ratio_any_t, _size_any_t, _size_1_t, _size_2_t, _size_3_t, _size_2_opt_t, _size_3_opt_t # 'TypedDict' is a new accepted type that represents a dictionary with a fixed set of allowed keys. # It is standards-track but not in `typing` yet. We leave this hear to be uncommented once the feature @@ -335,12 +335,12 @@ def normalize(input: Tensor, p: float = ..., dim: int = ..., eps: float = ..., def assert_int_or_pair(arg: Any, arg_name: Any, message: Any) -> None: ... -def unfold(input: Tensor, kernel_size: _size, dilation: _size = ..., padding: _size = ..., - stride: _size = ...) -> Tensor: ... +def unfold(input: Tensor, kernel_size: _size_any_t, dilation: _size_any_t = ..., padding: _size_any_t = ..., + stride: _size_any_t = ...) -> Tensor: ... -def fold(input: Tensor, output_size: _size, kernel_size: _size, dilation: _size = ..., padding: _size = ..., - stride: _size = ...) -> Tensor: ... +def fold(input: Tensor, output_size: _size_any_t, kernel_size: _size_any_t, dilation: _size_any_t = ..., padding: _size_any_t = ..., + stride: _size_any_t = ...) -> Tensor: ... def multi_head_attention_forward(query: Tensor, From 1c12cbea90db347cdbd5b82f8e80e6eafba5019b Mon Sep 17 00:00:00 2001 From: Ashkan Aliabadi Date: Fri, 8 Jan 2021 16:36:28 -0800 Subject: [PATCH 53/54] Optimize Vulkan command buffer submission rate. (#49112) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49112 Differential Revision: D25729889 Test Plan: Imported from OSS Reviewed By: SS-JIA Pulled By: AshkanAliabadi fbshipit-source-id: c4ab470fdcf3f83745971986f3a44a3dff69287f --- CMakeLists.txt | 2 +- aten/src/ATen/native/vulkan/api/Cache.h | 4 + aten/src/ATen/native/vulkan/api/Command.cpp | 273 ++++++--- aten/src/ATen/native/vulkan/api/Command.h | 65 +- aten/src/ATen/native/vulkan/api/Common.h | 11 +- aten/src/ATen/native/vulkan/api/Context.cpp | 37 +- .../src/ATen/native/vulkan/api/Descriptor.cpp | 76 ++- aten/src/ATen/native/vulkan/api/Descriptor.h | 33 +- aten/src/ATen/native/vulkan/api/Pipeline.cpp | 4 + aten/src/ATen/native/vulkan/api/Pipeline.h | 18 +- aten/src/ATen/native/vulkan/api/Resource.cpp | 55 +- aten/src/ATen/native/vulkan/api/Resource.h | 41 +- aten/src/ATen/native/vulkan/api/Runtime.cpp | 2 + aten/src/ATen/native/vulkan/api/Runtime.h | 14 +- aten/src/ATen/native/vulkan/api/Shader.cpp | 4 + aten/src/ATen/native/vulkan/api/Shader.h | 30 +- aten/src/ATen/native/vulkan/ops/Add.cpp | 44 +- aten/src/ATen/native/vulkan/ops/Clamp.cpp | 22 +- aten/src/ATen/native/vulkan/ops/Common.h | 8 - .../ATen/native/vulkan/ops/Convolution.cpp | 578 ++++++++++-------- aten/src/ATen/native/vulkan/ops/Copy.cpp | 252 ++++---- aten/src/ATen/native/vulkan/ops/Mean.cpp | 84 +-- aten/src/ATen/native/vulkan/ops/Mm.cpp | 51 +- aten/src/ATen/native/vulkan/ops/Mul.cpp | 22 +- aten/src/ATen/native/vulkan/ops/Pool.cpp | 24 +- aten/src/ATen/native/vulkan/ops/Shape.cpp | 7 +- aten/src/ATen/native/vulkan/ops/Tensor.cpp | 200 ++---- aten/src/ATen/native/vulkan/ops/Tensor.h | 41 +- aten/src/ATen/native/vulkan/ops/Upsample.cpp | 11 +- aten/src/ATen/native/vulkan/ops/Utils.h | 2 +- 30 files changed, 1057 insertions(+), 958 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 415e2676a14a..abef53630b7a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -207,7 +207,7 @@ cmake_dependent_option( USE_VALGRIND "Use Valgrind. Only available on Linux." ON "LINUX" OFF) option(USE_VULKAN "Use Vulkan GPU backend" OFF) -option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference even on fp32 tensors" ON) +option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference even on fp32 tensors" OFF) option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF) option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation (needs libshaderc)" OFF) option(USE_VULKAN_WRAPPER "Vulkan - Dynamically load Vulkan functions" ON) diff --git a/aten/src/ATen/native/vulkan/api/Cache.h b/aten/src/ATen/native/vulkan/api/Cache.h index b224adbbeeda..a93385088277 100644 --- a/aten/src/ATen/native/vulkan/api/Cache.h +++ b/aten/src/ATen/native/vulkan/api/Cache.h @@ -62,6 +62,10 @@ class Cache final { Factory factory_; }; +// +// Impl +// + template inline Cache::Cache(Factory factory) : factory_(std::move(factory)) { diff --git a/aten/src/ATen/native/vulkan/api/Command.cpp b/aten/src/ATen/native/vulkan/api/Command.cpp index 5aa3586d4683..247b51fa5395 100644 --- a/aten/src/ATen/native/vulkan/api/Command.cpp +++ b/aten/src/ATen/native/vulkan/api/Command.cpp @@ -76,6 +76,25 @@ Command::Buffer::Buffer(const VkCommandBuffer command_buffer) "Invalid Vulkan command buffer!"); } +Command::Buffer::Buffer(Buffer&& buffer) + : command_buffer_(std::move(buffer.command_buffer_)), + bound_(std::move(buffer.bound_)), + barriers_(std::move(buffer.barriers_)) { + buffer.invalidate(); +} + +Command::Buffer& Command::Buffer::operator=(Buffer&& buffer) { + if (&buffer != this) { + command_buffer_ = std::move(buffer.command_buffer_); + bound_ = std::move(buffer.bound_); + barriers_ = std::move(buffer.barriers_); + + buffer.invalidate(); + }; + + return *this; +} + void Command::Buffer::Buffer::begin() { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( command_buffer_, @@ -107,69 +126,6 @@ void Command::Buffer::Buffer::end() { VK_CHECK(vkEndCommandBuffer(command_buffer_)); } -void Command::Buffer::barrier() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_buffer_, - "This command buffer is in an invalid state! " - "Potential reason: This command buffer is moved from."); - - if (barriers_.stage) { - c10::SmallVector buffer_memory_barriers; - - for (const Resource::Buffer::Barrier& barrier : barriers_.buffers) { - buffer_memory_barriers.push_back({ - VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER, - nullptr, - barrier.memory.src, - barrier.memory.dst, - VK_QUEUE_FAMILY_IGNORED, - VK_QUEUE_FAMILY_IGNORED, - barrier.object.handle, - barrier.object.offset, - barrier.object.range, - }); - } - - c10::SmallVector image_memory_barriers; - - for (const Resource::Image::Barrier& barrier : barriers_.images) { - image_memory_barriers.push_back({ - VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER, - nullptr, - barrier.memory.src, - barrier.memory.dst, - barrier.layout.src, - barrier.layout.dst, - VK_QUEUE_FAMILY_IGNORED, - VK_QUEUE_FAMILY_IGNORED, - barrier.object.handle, - { - VK_IMAGE_ASPECT_COLOR_BIT, - 0u, - VK_REMAINING_MIP_LEVELS, - 0u, - VK_REMAINING_ARRAY_LAYERS, - }, - }); - } - - vkCmdPipelineBarrier( - command_buffer_, - barriers_.stage.src, - barriers_.stage.dst, - 0u, - 0u, - nullptr, - buffer_memory_barriers.size(), - buffer_memory_barriers.data(), - image_memory_barriers.size(), - image_memory_barriers.data()); - } - - // Reset - barriers_.reset(); -} - void Command::Buffer::barrier(const Pipeline::Barrier& barrier) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( command_buffer_, @@ -291,31 +247,86 @@ void Command::Buffer::dispatch( bound_.pipeline.local_work_group.data[2u])); } -void Command::Buffer::submit( - const VkQueue queue, - const Resource::Fence fence) { +void Command::Buffer::barrier() { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( command_buffer_, "This command buffer is in an invalid state! " "Potential reason: This command buffer is moved from."); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - queue, - "Invalid Vulkan queue!"); + if (barriers_.stage) { + c10::SmallVector buffer_memory_barriers; - const VkSubmitInfo submit_info{ - VK_STRUCTURE_TYPE_SUBMIT_INFO, - nullptr, - 0u, - nullptr, - nullptr, - 1u, - &command_buffer_, - 0u, - nullptr, - }; + for (const Resource::Buffer::Barrier& barrier : barriers_.buffers) { + buffer_memory_barriers.push_back({ + VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER, + nullptr, + barrier.memory.src, + barrier.memory.dst, + VK_QUEUE_FAMILY_IGNORED, + VK_QUEUE_FAMILY_IGNORED, + barrier.object.handle, + barrier.object.offset, + barrier.object.range, + }); + } + + c10::SmallVector image_memory_barriers; - VK_CHECK(vkQueueSubmit(queue, 1u, &submit_info, fence.handle())); + for (const Resource::Image::Barrier& barrier : barriers_.images) { + image_memory_barriers.push_back({ + VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER, + nullptr, + barrier.memory.src, + barrier.memory.dst, + barrier.layout.src, + barrier.layout.dst, + VK_QUEUE_FAMILY_IGNORED, + VK_QUEUE_FAMILY_IGNORED, + barrier.object.handle, + { + VK_IMAGE_ASPECT_COLOR_BIT, + 0u, + VK_REMAINING_MIP_LEVELS, + 0u, + VK_REMAINING_ARRAY_LAYERS, + }, + }); + } + + vkCmdPipelineBarrier( + command_buffer_, + barriers_.stage.src, + barriers_.stage.dst, + 0u, + 0u, + nullptr, + buffer_memory_barriers.size(), + buffer_memory_barriers.data(), + image_memory_barriers.size(), + image_memory_barriers.data()); + } + + // Reset + barriers_.reset(); +} + +void Command::Buffer::invalidate() { + command_buffer_ = VK_NULL_HANDLE; +} + +inline void Command::Buffer::Bound::reset() { + pipeline = {}; + descriptor_set = VK_NULL_HANDLE; +} + +inline Command::Buffer::Barrier::Stage::operator bool() const { + return (0u != src) || (0u != dst); +} + +inline void Command::Buffer::Barrier::reset() { + stage = {}; + buffers.clear(); + images.clear(); } Command::Pool::Pool(const GPU& gpu) @@ -338,8 +349,9 @@ Command::Pool::Pool(const GPU& gpu) Command::Pool::Pool(Pool&& pool) : device_(std::move(pool.device_)), command_pool_(std::move(pool.command_pool_)), - buffer_(std::move(pool.buffer_)) { - pool.device_ = VK_NULL_HANDLE; + buffer_(std::move(pool.buffer_)), + stream_(std::move(pool.stream_)) { + pool.invalidate(); } Command::Pool& Command::Pool::operator=(Pool&& pool) { @@ -347,8 +359,9 @@ Command::Pool& Command::Pool::operator=(Pool&& pool) { device_ = std::move(pool.device_); command_pool_ = std::move(pool.command_pool_); buffer_ = std::move(pool.buffer_); + stream_ = std::move(pool.stream_); - pool.device_ = VK_NULL_HANDLE; + pool.invalidate(); }; return *this; @@ -383,25 +396,109 @@ Command::Buffer Command::Pool::allocate() { Configuration::kQuantum); allocate_command_buffers( - device_, - command_pool_.get(), - buffer_.pool.data() + buffer_.in_use, - Configuration::kQuantum); + device_, + command_pool_.get(), + buffer_.pool.data() + buffer_.in_use, + Configuration::kQuantum); } return Buffer(buffer_.pool[buffer_.in_use++]); } +Command::Buffer& Command::Pool::stream() { + if (!stream_.buffer) { + stream_.buffer = allocate(); + stream_.buffer.begin(); + stream_.counter = 0u; + } + + return stream_.buffer; +} + void Command::Pool::purge() { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( device_ && command_pool_, "This command pool is in an invalid state! " "Potential reason: This command pool is moved from."); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !stream_.buffer, + "Pending command buffer detected. Make sure all command buffers are " + "submitted to the queue for execution prior to reclaiming pool memory."); + buffer_.in_use = 0u; VK_CHECK(vkResetCommandPool(device_, command_pool_.get(), 0u)); } +void Command::Pool::submit( + const VkQueue queue, + const c10::ArrayRef buffers, + const Resource::Fence fence) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && command_pool_, + "This command pool is in an invalid state! " + "Potential reason: This command pool is moved from."); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + queue, + "Invalid Vulkan queue!"); + + c10::SmallVector command_buffers; + command_buffers.reserve(buffers.size()); + + for (const Buffer& buffer : buffers) { + VkCommandBuffer command_buffer = buffer.handle(); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + command_buffer, + "Invalid Vulkan command buffer!"); + + // Are we submitting our one and only command stream, or a regular command + // buffer whose scope is manually maintained by the user? Automatically + // maintain state and submission rate if the former. + + if (stream_.buffer.handle() == command_buffer) { + // Hand the stream off to the driver if: + // - The user has implictly signaled interest in the results via a fence. + // - We are over the submission cutoff. We don't want to starve the GPU. + + if (fence || (stream_.counter++ > Configuration::kSubmit)) { + stream_.buffer.end(); + stream_.buffer.invalidate(); + } + // Skip - Accumulate more calls prior to submission. + else { + command_buffer = VK_NULL_HANDLE; + } + } + + if (command_buffer) { + command_buffers.push_back(command_buffer); + } + } + + if (!command_buffers.empty()) { + const VkSubmitInfo submit_info{ + VK_STRUCTURE_TYPE_SUBMIT_INFO, + nullptr, + 0u, + nullptr, + nullptr, + command_buffers.size(), + command_buffers.data(), + 0u, + nullptr, + }; + + VK_CHECK(vkQueueSubmit(queue, 1u, &submit_info, fence.handle())); + } +} + +void Command::Pool::invalidate() { + device_ = VK_NULL_HANDLE; + command_pool_.reset(); +} + } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Command.h b/aten/src/ATen/native/vulkan/api/Command.h index 42f073674be5..8b60d8bcd8f5 100644 --- a/aten/src/ATen/native/vulkan/api/Command.h +++ b/aten/src/ATen/native/vulkan/api/Command.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace at { namespace native { @@ -14,13 +15,15 @@ namespace vulkan { namespace api { struct Command final { + class Pool; + // // Buffer // class Buffer final { public: - Buffer(VkCommandBuffer command_buffer = VK_NULL_HANDLE); + explicit Buffer(VkCommandBuffer command_buffer = VK_NULL_HANDLE); Buffer(const Buffer&) = delete; Buffer& operator=(const Buffer&) = delete; Buffer(Buffer&&); @@ -28,18 +31,22 @@ struct Command final { ~Buffer() = default; operator bool() const; + VkCommandBuffer handle() const; void begin(); void end(); + void barrier(const Pipeline::Barrier& barrier); void bind(const Pipeline::Object& pipeline); void bind(const Descriptor::Set& set); void copy(Resource::Buffer::Object source, Resource::Buffer::Object destination); void dispatch(const Shader::WorkGroup& global_work_group); - void submit(VkQueue queue, Resource::Fence fence = {}); private: + friend class Pool; + void barrier(); + void invalidate(); private: VkCommandBuffer command_buffer_; @@ -80,12 +87,22 @@ struct Command final { ~Pool(); Buffer allocate(); + Buffer& stream(); void purge(); + void submit( + VkQueue queue, + c10::ArrayRef buffers, + Resource::Fence fence = {}); + + private: + void invalidate(); + private: struct Configuration final { - static constexpr uint32_t kQuantum = 64u; - static constexpr uint32_t kReserve = 1024u; + static constexpr uint32_t kQuantum = 4u; + static constexpr uint32_t kReserve = 16u; + static constexpr uint32_t kSubmit = 10u; }; VkDevice device_; @@ -95,6 +112,11 @@ struct Command final { std::vector pool; size_t in_use; } buffer_; + + struct { + Buffer buffer; + uint32_t counter; + } stream_; } pool /* [thread_count] */; explicit Command(const GPU& gpu) @@ -106,43 +128,12 @@ struct Command final { // Impl // -inline Command::Buffer::Buffer(Buffer&& buffer) - : command_buffer_(std::move(buffer.command_buffer_)), - bound_(std::move(buffer.bound_)), - barriers_(std::move(buffer.barriers_)) { - buffer.command_buffer_ = VK_NULL_HANDLE; -} - -inline Command::Buffer& Command::Buffer::operator=(Buffer&& buffer) { - if (&buffer != this) { - command_buffer_ = std::move(buffer.command_buffer_); - bound_ = std::move(buffer.bound_); - barriers_ = std::move(buffer.barriers_); - - buffer.command_buffer_ = VK_NULL_HANDLE; - }; - - return *this; -} - inline Command::Buffer::operator bool() const { return VK_NULL_HANDLE != command_buffer_; } -inline void Command::Buffer::Bound::reset() { - pipeline = {}; - descriptor_set = VK_NULL_HANDLE; -} - -inline Command::Buffer::Barrier::Stage::operator bool() const { - return (0u != src) || - (0u != dst); -} - -inline void Command::Buffer::Barrier::reset() { - stage = {}; - buffers.clear(); - images.clear(); +inline VkCommandBuffer Command::Buffer::handle() const { + return command_buffer_; } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Common.h b/aten/src/ATen/native/vulkan/api/Common.h index d606f1d859a9..49f9ffa21a22 100644 --- a/aten/src/ATen/native/vulkan/api/Common.h +++ b/aten/src/ATen/native/vulkan/api/Common.h @@ -6,10 +6,17 @@ #ifdef USE_VULKAN_SHADERC_RUNTIME #include -#define VK_KERNEL(name) { name##_glsl, } +#define VK_KERNEL(name) \ + ::at::native::vulkan::api::Shader::Descriptor{ \ + name##_glsl, \ + } #else #include -#define VK_KERNEL(name) { name##_spv, name##_spv_len, } +#define VK_KERNEL(name) \ + ::at::native::vulkan::api::Shader::Descriptor{ \ + name##_spv, \ + name##_spv_len, \ + } #endif /* USE_VULKAN_SHADERC_RUNTIME */ #ifdef USE_VULKAN_WRAPPER diff --git a/aten/src/ATen/native/vulkan/api/Context.cpp b/aten/src/ATen/native/vulkan/api/Context.cpp index 09dfa8fc1d77..0a9a6e130f4f 100644 --- a/aten/src/ATen/native/vulkan/api/Context.cpp +++ b/aten/src/ATen/native/vulkan/api/Context.cpp @@ -43,6 +43,40 @@ VkDevice create_device( &queue_priorities, }; + uint32_t device_extension_properties_count = 0; + VK_CHECK(vkEnumerateDeviceExtensionProperties( + physical_device, + nullptr, + &device_extension_properties_count, + nullptr)); + + std::vector device_extension_properties( + device_extension_properties_count); + + VK_CHECK(vkEnumerateDeviceExtensionProperties( + physical_device, + nullptr, + &device_extension_properties_count, + device_extension_properties.data())); + + constexpr const char* const requested_device_extensions[]{ + #ifdef VK_KHR_portability_subset + // https://vulkan.lunarg.com/doc/view/1.2.162.0/mac/1.2-extensions/vkspec.html#VUID-VkDeviceCreateInfo-pProperties-04451 + VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME, + #endif + }; + + std::vector enabled_device_extensions; + + for (const auto& requested_device_extension : requested_device_extensions) { + for (const auto& extension : device_extension_properties) { + if (strcmp(requested_device_extension, extension.extensionName) == 0) { + enabled_device_extensions.push_back(requested_device_extension); + break; + } + } + } + const VkDeviceCreateInfo device_create_info{ VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, nullptr, @@ -51,7 +85,8 @@ VkDevice create_device( &device_queue_create_info, 0u, nullptr, - 0u, + static_cast(enabled_device_extensions.size()), + enabled_device_extensions.data(), nullptr, }; diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.cpp b/aten/src/ATen/native/vulkan/api/Descriptor.cpp index 317536248987..5bdcb0b7fd02 100644 --- a/aten/src/ATen/native/vulkan/api/Descriptor.cpp +++ b/aten/src/ATen/native/vulkan/api/Descriptor.cpp @@ -128,27 +128,25 @@ Descriptor::Set::Set( "Invalid Vulkan descriptor set!"); } -void Descriptor::Set::update(const Item& item) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && descriptor_set_, - "This descriptor set is in an invalid state! " - "Potential reason: This descriptor set is moved from."); +Descriptor::Set::Set(Set&& set) + : device_(std::move(set.device_)), + descriptor_set_(std::move(set.descriptor_set_)), + shader_layout_signature_(std::move(set.shader_layout_signature_)), + bindings_(std::move(set.bindings_)) { + set.invalidate(); +} - const auto items_itr = std::find_if( - bindings_.items.begin(), - bindings_.items.end(), - [binding = item.binding](const Item& other) { - return other.binding == binding; - }); +Descriptor::Set& Descriptor::Set::operator=(Set&& set) { + if (&set != this) { + device_ = std::move(set.device_); + descriptor_set_ = std::move(set.descriptor_set_); + shader_layout_signature_ = std::move(set.shader_layout_signature_); + bindings_ = std::move(set.bindings_); - if (bindings_.items.end() == items_itr) { - bindings_.items.emplace_back(item); - } - else { - *items_itr = item; - } + set.invalidate(); + }; - bindings_.dirty = true; + return *this; } Descriptor::Set& Descriptor::Set::bind( @@ -276,12 +274,39 @@ VkDescriptorSet Descriptor::Set::handle() const { return descriptor_set_; } +void Descriptor::Set::invalidate() { + device_ = VK_NULL_HANDLE; + descriptor_set_ = VK_NULL_HANDLE; +} + +void Descriptor::Set::update(const Item& item) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && descriptor_set_, + "This descriptor set is in an invalid state! " + "Potential reason: This descriptor set is moved from."); + + const auto items_itr = std::find_if( + bindings_.items.begin(), + bindings_.items.end(), + [binding = item.binding](const Item& other) { + return other.binding == binding; + }); + + if (bindings_.items.end() == items_itr) { + bindings_.items.emplace_back(item); + } + else { + *items_itr = item; + } + + bindings_.dirty = true; +} + Descriptor::Pool::Pool(const GPU& gpu) : device_(gpu.device), descriptor_pool_( create_descriptor_pool(gpu.device), - VK_DELETER(DescriptorPool)(device_)), - set_{} { + VK_DELETER(DescriptorPool)(device_)) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( device_, "Invalid Vulkan device!"); @@ -295,7 +320,7 @@ Descriptor::Pool::Pool(Pool&& pool) : device_(std::move(pool.device_)), descriptor_pool_(std::move(pool.descriptor_pool_)), set_(std::move(pool.set_)) { - pool.device_ = VK_NULL_HANDLE; + pool.invalidate(); } Descriptor::Pool& Descriptor::Pool::operator=(Pool&& pool) { @@ -304,7 +329,7 @@ Descriptor::Pool& Descriptor::Pool::operator=(Pool&& pool) { descriptor_pool_ = std::move(pool.descriptor_pool_); set_ = std::move(pool.set_); - pool.device_ = VK_NULL_HANDLE; + pool.invalidate(); }; return *this; @@ -371,8 +396,13 @@ void Descriptor::Pool::purge() { "This descriptor pool is in an invalid state! " "Potential reason: This descriptor pool is moved from."); - set_.layouts.clear(); VK_CHECK(vkResetDescriptorPool(device_, descriptor_pool_.get(), 0u)); + set_.layouts.clear(); +} + +void Descriptor::Pool::invalidate() { + device_ = VK_NULL_HANDLE; + descriptor_pool_.reset(); } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.h b/aten/src/ATen/native/vulkan/api/Descriptor.h index 440bb9aa4097..6c50a350d7f3 100644 --- a/aten/src/ATen/native/vulkan/api/Descriptor.h +++ b/aten/src/ATen/native/vulkan/api/Descriptor.h @@ -73,6 +73,9 @@ struct Descriptor final { VkDescriptorSet handle() const; + private: + void invalidate(); + private: struct Item final { uint32_t binding; @@ -113,6 +116,9 @@ struct Descriptor final { Set allocate(const Shader::Layout::Object& shader_layout); void purge(); + private: + void invalidate(); + private: struct Configuration final { static constexpr uint32_t kQuantum = 16u; @@ -137,33 +143,6 @@ struct Descriptor final { } }; -// -// Impl -// - -inline Descriptor::Set::Set(Set&& set) - : device_(std::move(set.device_)), - descriptor_set_(std::move(set.descriptor_set_)), - shader_layout_signature_(std::move(set.shader_layout_signature_)), - bindings_(std::move(set.bindings_)) { - set.device_ = VK_NULL_HANDLE; - set.descriptor_set_ = VK_NULL_HANDLE; -} - -inline Descriptor::Set& Descriptor::Set::operator=(Set&& set) { - if (&set != this) { - device_ = std::move(set.device_); - descriptor_set_ = std::move(set.descriptor_set_); - shader_layout_signature_ = std::move(set.shader_layout_signature_); - bindings_ = std::move(set.bindings_); - - set.device_ = VK_NULL_HANDLE; - set.descriptor_set_ = VK_NULL_HANDLE; - }; - - return *this; -} - } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.cpp b/aten/src/ATen/native/vulkan/api/Pipeline.cpp index 4b15203892ed..89e85892ee0c 100644 --- a/aten/src/ATen/native/vulkan/api/Pipeline.cpp +++ b/aten/src/ATen/native/vulkan/api/Pipeline.cpp @@ -169,6 +169,10 @@ Pipeline::Cache::Cache(Factory factory) : cache_(std::move(factory)) { } +void Pipeline::Cache::purge() { + cache_.purge(); +} + } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.h b/aten/src/ATen/native/vulkan/api/Pipeline.h index 1d1966790dbf..794193d8a161 100644 --- a/aten/src/ATen/native/vulkan/api/Pipeline.h +++ b/aten/src/ATen/native/vulkan/api/Pipeline.h @@ -196,7 +196,11 @@ inline Pipeline::Barrier::operator bool() const { inline bool operator==( const Pipeline::Layout::Descriptor& _1, const Pipeline::Layout::Descriptor& _2) { - return (_1.descriptor_set_layout == _2.descriptor_set_layout); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(Pipeline::Layout::Descriptor))); } inline size_t Pipeline::Layout::Factory::Hasher::operator()( @@ -207,9 +211,11 @@ inline size_t Pipeline::Layout::Factory::Hasher::operator()( inline bool operator==( const Pipeline::Descriptor& _1, const Pipeline::Descriptor& _2) { - return (_1.pipeline_layout == _2.pipeline_layout) && - (_1.shader_module == _2.shader_module) && - (_1.local_work_group == _2.local_work_group); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(Pipeline::Descriptor))); } inline size_t Pipeline::Factory::Hasher::operator()( @@ -236,10 +242,6 @@ inline Pipeline::Object Pipeline::Cache::retrieve( }; } -inline void Pipeline::Cache::purge() { - cache_.purge(); -} - } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Resource.cpp b/aten/src/ATen/native/vulkan/api/Resource.cpp index a491ed4dd6e0..adda610fb90c 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.cpp +++ b/aten/src/ATen/native/vulkan/api/Resource.cpp @@ -192,6 +192,11 @@ VkFence Resource::Fence::handle(const bool add_to_waitlist) const { "Invalid Vulkan fence!"); const VkFence fence = pool->fence_.pool[id].get(); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + fence, + "Invalid Vulkan fence!"); + if (add_to_waitlist) { pool->fence_.waitlist.push_back(fence); } @@ -360,14 +365,13 @@ Resource::Pool::Pool( : device_(gpu.device), allocator_( create_allocator( - gpu.adapter->runtime->instance(), - gpu.adapter->handle, - device_), + gpu.adapter->runtime->instance(), + gpu.adapter->handle, + device_), vmaDestroyAllocator), memory_{ std::move(policy), }, - buffer_{}, image_{ .sampler = Image::Sampler{gpu}, }, @@ -377,23 +381,6 @@ Resource::Pool::Pool( fence_.pool.reserve(Configuration::kReserve); } -Resource::Pool::~Pool() { - try { - if (device_ && allocator_) { - purge(); - } - } - catch (const std::exception& e) { - LOG(WARNING) - << "Vulkan: Resource pool destructor raised an exception! Error: " - << e.what(); - } - catch (...) { - LOG(WARNING) - << "Vulkan: Resource pool destructor raised an unknown exception!"; - } -} - Resource::Pool::Pool(Pool&& pool) : device_(std::move(pool.device_)), allocator_(std::move(pool.allocator_)), @@ -401,7 +388,7 @@ Resource::Pool::Pool(Pool&& pool) buffer_(std::move(pool.buffer_)), image_(std::move(pool.image_)), fence_(std::move(pool.fence_)) { - pool.device_ = VK_NULL_HANDLE; + pool.invalidate(); } Resource::Pool& Resource::Pool::operator=(Pool&& pool) { @@ -413,12 +400,29 @@ Resource::Pool& Resource::Pool::operator=(Pool&& pool) { image_ = std::move(pool.image_); fence_ = std::move(pool.fence_); - pool.device_ = VK_NULL_HANDLE; + pool.invalidate(); }; return *this; } +Resource::Pool::~Pool() { + try { + if (device_ && allocator_) { + purge(); + } + } + catch (const std::exception& e) { + LOG(WARNING) + << "Vulkan: Resource pool destructor raised an exception! Error: " + << e.what(); + } + catch (...) { + LOG(WARNING) + << "Vulkan: Resource pool destructor raised an unknown exception!"; + } +} + Resource::Buffer Resource::Pool::buffer( const Buffer::Descriptor& descriptor) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( @@ -678,6 +682,11 @@ void Resource::Pool::purge() { buffer_.pool.clear(); } +void Resource::Pool::invalidate() { + device_ = VK_NULL_HANDLE; + allocator_.reset(); +} + } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Resource.h b/aten/src/ATen/native/vulkan/api/Resource.h index 61a2094dc692..19a7df3d04d2 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.h +++ b/aten/src/ATen/native/vulkan/api/Resource.h @@ -20,6 +20,16 @@ struct Resource final { // struct Memory final { + /* + Descriptor + */ + + struct Descriptor final { + VmaMemoryUsage usage; + VkMemoryPropertyFlags /* optional */ required; + VkMemoryPropertyFlags /* optional */ preferred; + }; + /* Barrier */ @@ -30,18 +40,9 @@ struct Resource final { }; /* - Descriptor + Access */ - struct Descriptor final { - VmaMemoryUsage usage; - VkMemoryPropertyFlags /* optional */ required; - VkMemoryPropertyFlags /* optional */ preferred; - }; - - VmaAllocator allocator; - VmaAllocation allocation; - struct Access final { typedef uint8_t Flags; @@ -74,6 +75,9 @@ struct Resource final { typename Pointer = Access::Pointer> Handle map() &; + VmaAllocator allocator; + VmaAllocation allocation; + private: // Intentionally disabed to ensure memory access is always properly // encapsualted in a scoped map-unmap region. Allowing below overloads @@ -299,6 +303,8 @@ struct Resource final { private: friend struct Fence; + void invalidate(); + private: struct Configuration final { static constexpr uint32_t kReserve = 256u; @@ -353,7 +359,8 @@ class Resource::Memory::Scope final { template inline Resource::Memory::Handle Resource::Memory::map() const & { - void* map(const Memory& memory, Access::Flags); + // Forward declaration + void* map(const Memory&, Access::Flags); return Handle{ reinterpret_cast(map(*this, Access::Read)), @@ -363,7 +370,8 @@ inline Resource::Memory::Handle Resource::Memory::map() const & { template inline Resource::Memory::Handle Resource::Memory::map() & { - void* map(const Memory& memory, Access::Flags); + // Forward declaration + void* map(const Memory&, Access::Flags); static_assert( (kAccess == Access::Read) || @@ -388,10 +396,11 @@ inline Resource::Buffer::operator bool() const { inline bool operator==( const Resource::Image::Sampler::Descriptor& _1, const Resource::Image::Sampler::Descriptor& _2) { - return (_1.filter == _2.filter) && - (_1.mipmap_mode == _2.mipmap_mode) && - (_1.address_mode == _2.address_mode) && - (_1.border == _2.border); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(Resource::Image::Sampler::Descriptor))); } inline size_t Resource::Image::Sampler::Factory::Hasher::operator()( diff --git a/aten/src/ATen/native/vulkan/api/Runtime.cpp b/aten/src/ATen/native/vulkan/api/Runtime.cpp index b63ded444887..c3ad6ebddb45 100644 --- a/aten/src/ATen/native/vulkan/api/Runtime.cpp +++ b/aten/src/ATen/native/vulkan/api/Runtime.cpp @@ -86,7 +86,9 @@ VkInstance create_instance(const Runtime::Type type) { nullptr, &instance_extension_count, instance_extension_properties.data())); constexpr const char* const requested_instance_extensions[]{ + #ifdef VK_EXT_debug_report VK_EXT_DEBUG_REPORT_EXTENSION_NAME, + #endif }; for (const auto& requested_instance_extension : requested_instance_extensions) { diff --git a/aten/src/ATen/native/vulkan/api/Runtime.h b/aten/src/ATen/native/vulkan/api/Runtime.h index 675f35cd0789..55eae70f8723 100644 --- a/aten/src/ATen/native/vulkan/api/Runtime.h +++ b/aten/src/ATen/native/vulkan/api/Runtime.h @@ -33,10 +33,7 @@ class Runtime final { Runtime& operator=(Runtime&&) = default; ~Runtime() = default; - inline VkInstance instance() const { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(instance_); - return instance_.get(); - } + VkInstance instance() const; typedef std::function Selector; Adapter select(const Selector& selector); @@ -59,6 +56,15 @@ class Runtime final { Runtime* runtime(); +// +// Impl +// + +inline VkInstance Runtime::instance() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(instance_); + return instance_.get(); +} + } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Shader.cpp b/aten/src/ATen/native/vulkan/api/Shader.cpp index 43d1a62ac201..7995dd160c35 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.cpp +++ b/aten/src/ATen/native/vulkan/api/Shader.cpp @@ -60,6 +60,10 @@ Shader::Layout::Cache::Cache(Factory factory) : cache_(std::move(factory)) { } +void Shader::Layout::Cache::purge() { + cache_.purge(); +} + #ifdef USE_VULKAN_SHADERC_RUNTIME struct Shader::Factory::Compiler final { diff --git a/aten/src/ATen/native/vulkan/api/Shader.h b/aten/src/ATen/native/vulkan/api/Shader.h index 718504e69bd4..f005eb1c11e9 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.h +++ b/aten/src/ATen/native/vulkan/api/Shader.h @@ -218,16 +218,14 @@ inline Shader::Layout::Object Shader::Layout::Cache::retrieve( }; } -inline void Shader::Layout::Cache::purge() { - cache_.purge(); -} - inline bool operator==( const Shader::WorkGroup& _1, const Shader::WorkGroup& _2) { - return (_1.data[0u] == _2.data[0u]) && - (_1.data[1u] == _2.data[1u]) && - (_1.data[2u] == _2.data[2u]); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(Shader::WorkGroup))); } inline Shader::Descriptor::Descriptor(const char* const glsl) @@ -258,12 +256,10 @@ inline bool operator==( const Shader::Descriptor& _1, const Shader::Descriptor& _2) { static_assert( - sizeof(Shader::Descriptor::shader.source) == sizeof(Shader::Descriptor::shader.binary), - "This implementation requires sizeof(Source) to be equal to sizeof(Binary)."); + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); - return (_1.type == _2.type) && - (_1.shader.binary.spirv == _2.shader.binary.spirv) && - (_1.shader.binary.size == _2.shader.binary.size); + return (0 == memcmp(&_1, &_2, sizeof(Shader::Descriptor))); } inline size_t Shader::Factory::Hasher::operator()( @@ -286,11 +282,11 @@ inline size_t Shader::Factory::Hasher::operator()( inline bool operator==( const VkDescriptorSetLayoutBinding& _1, const VkDescriptorSetLayoutBinding& _2) { - return (_1.binding == _2.binding) && - (_1.descriptorType == _2.descriptorType) && - (_1.descriptorCount == _2.descriptorCount) && - (_1.stageFlags == _2.stageFlags) && - (_1.pImmutableSamplers == _2.pImmutableSamplers); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(VkDescriptorSetLayoutBinding))); } #endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Add.cpp b/aten/src/ATen/native/vulkan/ops/Add.cpp index 270a1d5f8168..95b7fd67c095 100644 --- a/aten/src/ATen/native/vulkan/ops/Add.cpp +++ b/aten/src/ATen/native/vulkan/ops/Add.cpp @@ -24,11 +24,11 @@ Tensor add_scalar( v_self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_output.has_image() && v_self.has_image()) { - const struct { + if C10_LIKELY(v_output.has_image() && v_self.has_image()) { + const struct Block final { uvec3 extents; float other; } block { @@ -64,8 +64,7 @@ Tensor add_scalar( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -82,11 +81,11 @@ Tensor& add_scalar_( vTensor& v_self = convert(self); - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_self.has_image()) { - const struct { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { uvec3 extents; float other; } block { @@ -116,8 +115,7 @@ Tensor& add_scalar_( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return self; } @@ -140,11 +138,11 @@ Tensor add_tensor( v_self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_self.has_image() && v_other.has_image()) { - const struct { + if C10_LIKELY(v_self.has_image() && v_other.has_image()) { + const struct Block final { uvec3 extents; float alpha; } block { @@ -186,8 +184,7 @@ Tensor add_tensor( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -207,11 +204,11 @@ Tensor& add_tensor_( const Tensor other = other_arg.is_vulkan() ? other_arg : other_arg.vulkan(); const vTensor& v_other = convert(other); - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_self.has_image() && v_other.has_image() && !self.is_same(other)) { - const struct { + if C10_LIKELY(v_self.has_image() && v_other.has_image() && !self.is_same(other)) { + const struct Block final { uvec3 extents; float alpha; } block { @@ -247,8 +244,7 @@ Tensor& add_tensor_( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return self; } diff --git a/aten/src/ATen/native/vulkan/ops/Clamp.cpp b/aten/src/ATen/native/vulkan/ops/Clamp.cpp index 9f25d89bca9b..75e9a1bb0fff 100644 --- a/aten/src/ATen/native/vulkan/ops/Clamp.cpp +++ b/aten/src/ATen/native/vulkan/ops/Clamp.cpp @@ -28,11 +28,11 @@ Tensor clamp( v_self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_output.has_image() && v_self.has_image()) { - const struct { + if C10_LIKELY(v_output.has_image() && v_self.has_image()) { + const struct Block final { uvec3 extents; uint32_t _; vec2 clamp; @@ -73,8 +73,7 @@ Tensor clamp( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -95,11 +94,11 @@ Tensor& clamp_( vTensor& v_self = convert(self); - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_self.has_image()) { - const struct { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { uvec3 extents; uint32_t _; vec2 clamp; @@ -134,8 +133,7 @@ Tensor& clamp_( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return self; } diff --git a/aten/src/ATen/native/vulkan/ops/Common.h b/aten/src/ATen/native/vulkan/ops/Common.h index b0bbeeaf34f1..3c9b2e8b3b9f 100644 --- a/aten/src/ATen/native/vulkan/ops/Common.h +++ b/aten/src/ATen/native/vulkan/ops/Common.h @@ -35,14 +35,6 @@ struct Layout final { }; }; -struct Experimentation { - static constexpr bool kUseConv2dOldApi = false; -}; - -struct ConvPrepackLimits final { - static constexpr int64_t maxStackDepth = 2048*4; -}; - } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp index d88545e3a25a..8991475b7d15 100644 --- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp +++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp @@ -1,8 +1,8 @@ #include -#include #include #include #include +#include namespace at { namespace native { @@ -12,6 +12,10 @@ namespace { using namespace api::utils; +struct Experimentation final { + static constexpr bool kUseConv2dOldApi = false; +}; + inline bool is_depthwise( const IntArrayRef filter, const int64_t groups) { @@ -26,47 +30,103 @@ inline bool is_pointwise(const IntArrayRef filter) { } vTensor pack_weights_dw( + api::Context* const context, + api::Command::Buffer& command_buffer, api::Resource::Pool& pool, - const Tensor& weight_arg, - const int64_t groups) { - if (weight_arg.is_vulkan()) { - return convert(weight_arg); - } - + const Tensor& weight) { /* Source */ - - const Tensor weight = weight_arg.contiguous(); const IntArrayRef src_filter = weight.sizes(); const float* const src_weight_ptr = weight.data_ptr(); const int64_t src_kw_sz = src_filter[Layout::Filter::width]; const int64_t src_kh_sz = src_filter[Layout::Filter::height]; + const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; + const int64_t src_block_sz = src_kernel_sz * src_filter[Layout::Filter::input]; const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4)); + + /* Destination */ + const int64_t dst_kw_sz = src_kernel_sz; + const int64_t dst_kh_sz = num_stacks; + const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz; + vTensor v_weight{ - api::context(), + context, &pool, { 4, - num_stacks, - src_kw_sz * src_kh_sz, + dst_kh_sz, + dst_kw_sz, }, weight.options(), }; using Future = vTensor::Future; - Future v_weight_future = v_weight.host(); + Future v_weight_future = v_weight.host(command_buffer); Future::Payload v_weight_payload = v_weight_future.wait(); + float* const dst_weight_ptr = v_weight_payload.get(); + memset(dst_weight_ptr, 0, v_weight.nbytes()); + + for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) { + /* Source */ + const float* const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block_sz; + + /* Destination */ + const int64_t dst_oh = src_oc / 4; + const int64_t dst_c = src_oc % 4; + + float* const dst_weight_c_ptr = dst_weight_ptr + + dst_c * dst_kernel_sz + + dst_oh * dst_kw_sz; + + for (int64_t src_ih = 0; src_ih < src_filter[Layout::Filter::height]; ++src_ih) { + memcpy( + dst_weight_c_ptr + src_ih * src_kw_sz, + src_weight_oc_ptr + src_ih * src_kw_sz, + sizeof(float) * src_kw_sz); + } + } + + return v_weight; +} + +vTensor pack_weights_2d( + api::Context* const context, + api::Command::Buffer& command_buffer, + api::Resource::Pool& pool, + const Tensor& weight) { /* Source */ + const IntArrayRef src_filter = weight.sizes(); + const float* const src_weight_ptr = weight.data_ptr(); + + const int64_t src_kw_sz = src_filter[Layout::Filter::width]; + const int64_t src_kh_sz = src_filter[Layout::Filter::height]; const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; - const int64_t src_block_sz = - src_kernel_sz * src_filter[Layout::Filter::input]; + const int64_t src_block_sz = src_kernel_sz * src_filter[Layout::Filter::input]; + + const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4)); + const int64_t stack_depth = api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4)); /* Destination */ - const int64_t dst_kw_sz = src_kw_sz * src_kh_sz; - const int64_t dst_kh_sz = num_stacks; + const int64_t dst_kw_sz = src_kw_sz * stack_depth; + const int64_t dst_kh_sz = src_kh_sz * num_stacks; const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz; + vTensor v_weight{ + context, + &pool, + { + 4, + dst_kh_sz, + dst_kw_sz, + }, + weight.options(), + }; + + using Future = vTensor::Future; + Future v_weight_future = v_weight.host(command_buffer); + Future::Payload v_weight_payload = v_weight_future.wait(); + float* const dst_weight_ptr = v_weight_payload.get(); memset(dst_weight_ptr, 0, v_weight.nbytes()); @@ -80,26 +140,29 @@ vTensor pack_weights_dw( float* const dst_weight_c_ptr = dst_weight_ptr + dst_c * dst_kernel_sz; - for (int64_t src_ih = 0; src_ih < src_filter[Layout::Filter::height]; ++src_ih) { - memcpy( - dst_weight_c_ptr + dst_oh * dst_kw_sz + src_ih * src_kw_sz, - src_weight_oc_ptr + src_ih * src_kw_sz, - sizeof(float) * src_kw_sz); + for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input]; ++src_ic) { + const int64_t dst_ic4 = src_ic / 4; + + for (int64_t src_ih = 0; src_ih < src_kh_sz; ++src_ih) { + for (int64_t src_iw = 0; src_iw < src_kw_sz; ++src_iw) { + memcpy( + dst_weight_c_ptr + (dst_oh * src_kh_sz + src_ih) * dst_kw_sz + + dst_ic4 * src_kw_sz * 4 + src_iw * 4 + src_ic % 4, + src_weight_oc_ptr + src_ic * src_kernel_sz + src_ih * src_kw_sz + src_iw, + sizeof(float)); + } + } } } return v_weight; } -vTensor pack_weights_old( +vTensor pack_weights_2d_old( + api::Context* const context, + api::Command::Buffer& command_buffer, api::Resource::Pool& pool, - const Tensor& weight_arg, - const int64_t groups) { - if (weight_arg.is_vulkan()) { - return convert(weight_arg); - } - - const Tensor weight = weight_arg.contiguous(); + const Tensor& weight) { const IntArrayRef src_filter = weight.sizes(); const float* const src_weight_ptr = weight.data_ptr(); @@ -111,7 +174,7 @@ vTensor pack_weights_old( const uint32_t KW = src_filter[Layout::Filter::width]; vTensor v_weight{ - api::context(), + context, &pool, { 1, @@ -123,13 +186,13 @@ vTensor pack_weights_old( }; using Future = vTensor::Future; - Future v_weight_future = v_weight.host(); + Future v_weight_future = v_weight.host(command_buffer); Future::Payload v_weight_payload = v_weight_future.wait(); float* const dst_weight_ptr = v_weight_payload.get(); memset(dst_weight_ptr, 0, v_weight.nbytes()); - const float* src = src_weight_ptr; + const float* const src = src_weight_ptr; float* const dst = dst_weight_ptr; { @@ -162,7 +225,7 @@ vTensor pack_weights_old( dim0_ = dim0; dim1_ = dim1; dim2_ = dim2; - data_ = new float[dim0 * dim1 * dim2 * 4]; + data_ = new float[dim0 * dim1 * dim2 * 4]; // TODO: memory leak memset(data_, 0.f, dim0 * dim1 * dim2 * 4 * sizeof(float)); } @@ -211,7 +274,7 @@ vTensor pack_weights_old( return v_weight; } -vTensor pack_weights_2d( +vTensor pack_weights( api::Resource::Pool& pool, const Tensor& weight_arg, const int64_t groups) { @@ -219,81 +282,32 @@ vTensor pack_weights_2d( return convert(weight_arg); } - const Tensor weight = weight_arg.contiguous(); - const IntArrayRef src_filter = weight.sizes(); - const float* const src_weight_ptr = weight.data_ptr(); - - const int64_t src_kw_sz = src_filter[Layout::Filter::width]; - const int64_t src_kh_sz = src_filter[Layout::Filter::height]; - const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4)); - const int64_t stack_depth = api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4)); - vTensor v_weight{ - api::context(), - &pool, - { - 4, - src_kh_sz * num_stacks, - src_kw_sz * stack_depth, - }, - weight.options(), - }; - - using Future = vTensor::Future; - Future v_weight_future = v_weight.host(); - Future::Payload v_weight_payload = v_weight_future.wait(); - - /* Source */ - const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; - const int64_t src_block_sz = - src_kernel_sz * src_filter[Layout::Filter::input]; - - /* Destination */ - const int64_t dst_kw_sz = src_kw_sz * stack_depth; - const int64_t dst_kh_sz = src_kh_sz * num_stacks; - const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz; - - float* const dst_weight_ptr = v_weight_payload.get(); - memset(dst_weight_ptr, 0, v_weight.nbytes()); - - for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) { - /* Source */ - const float* const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block_sz; - - /* Destination */ - const int64_t dst_oh = src_oc / 4; - const int64_t dst_c = src_oc % 4; - - float* const dst_weight_c_ptr = dst_weight_ptr + dst_c * dst_kernel_sz; - - for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input]; ++src_ic) { - const int64_t dst_ic4 = src_ic/4; - for (int64_t src_ih = 0; src_ih < src_kh_sz; ++src_ih) { - for (int64_t src_iw = 0; src_iw < src_kw_sz; ++src_iw) { - memcpy( - dst_weight_c_ptr + (dst_oh * src_kh_sz + src_ih) * dst_kw_sz + - dst_ic4 * src_kw_sz * 4 + src_iw * 4 + src_ic % 4, - src_weight_oc_ptr + src_ic * src_kernel_sz + src_ih * src_kw_sz + src_iw, - sizeof(float)); - } - } - } - } + api::Context* const context = api::context(); + api::Command::Buffer& command_buffer = context->command().pool.stream(); - return v_weight; -} + const Tensor weight = weight_arg.contiguous(); -vTensor pack_weights( - api::Resource::Pool& pool, - const Tensor& weight_arg, - const int64_t groups) { - if (is_depthwise(weight_arg.sizes(), groups)) { - return pack_weights_dw(pool, weight_arg, groups); + if (is_depthwise(weight.sizes(), groups)) { + return pack_weights_dw( + context, + command_buffer, + pool, + weight); } if (Experimentation::kUseConv2dOldApi) { - return pack_weights_old(pool, weight_arg, groups); + return pack_weights_2d_old( + context, + command_buffer, + pool, + weight); } - return pack_weights_2d(pool, weight_arg, groups); + + return pack_weights_2d( + context, + command_buffer, + pool, + weight); } vTensor pack_biases( @@ -304,8 +318,11 @@ vTensor pack_biases( return convert(*bias); } + api::Context* const context = api::context(); + api::Command::Buffer& command_buffer = context->command().pool.stream(); + vTensor v_bias{ - api::context(), + context, &pool, { // 1D @@ -316,7 +333,7 @@ vTensor pack_biases( { using Future = vTensor::Future; - Future v_bias_future = v_bias.host(); + Future v_bias_future = v_bias.host(command_buffer); Future::Payload v_bias_payload = v_bias_future.wait(); if (bias) { @@ -394,7 +411,8 @@ bool available( (c10::DeviceType::Vulkan == bias->device().type())) && (kFloat == bias->scalar_type()) && (transposed ? false /* to be addded in the future */ - : (weight.size(Layout::Filter::output) == bias->size(Layout::Filter::output)))) + : (weight.size(Layout::Filter::output) == + bias->size(Layout::Filter::output)))) : true) && // Stride (stride[Layout::Parameter::height] > 0) && @@ -432,7 +450,7 @@ bool usable(const Tensor& input) { true; } -void conv2d_depthwise( +void conv2d_dw( api::Context* const context, api::Command::Buffer& command_buffer, vTensor& v_output, @@ -446,27 +464,39 @@ void conv2d_depthwise( const IntArrayRef dilation, const float output_min, const float output_max) { - if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) { - const struct { - int32_t kernel_x, kernel_y; - int32_t stride_x, stride_y; - int32_t padding_x, padding_y; - int32_t dilate_x, dilate_y; - float clamp_x, clamp_y; - int32_t src_filter_w, src_filter_h; + if C10_LIKELY(v_output.has_image() && v_input.has_image() && v_weight.has_image()) { + const struct Block final { + ivec2 kernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; + ivec2 src_filter; } block { - safe_downcast(filter[Layout::Filter::width]), - safe_downcast(filter[Layout::Filter::height]), - safe_downcast(stride[Layout::Parameter::width]), - safe_downcast(stride[Layout::Parameter::height]), - safe_downcast(padding[Layout::Parameter::width]), - safe_downcast(padding[Layout::Parameter::height]), - safe_downcast(dilation[Layout::Parameter::width]), - safe_downcast(dilation[Layout::Parameter::height]), - output_min, - output_max, - safe_downcast(src_filter[Layout::Filter::width]), - safe_downcast(src_filter[Layout::Filter::height]), + { + safe_downcast(filter[Layout::Filter::width]), + safe_downcast(filter[Layout::Filter::height]), + }, + { + safe_downcast(stride[Layout::Parameter::width]), + safe_downcast(stride[Layout::Parameter::height]), + }, + { + safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), + }, + { + safe_downcast(dilation[Layout::Parameter::width]), + safe_downcast(dilation[Layout::Parameter::height]), + }, + { + output_min, + output_max, + }, + { + safe_downcast(src_filter[Layout::Filter::width]), + safe_downcast(src_filter[Layout::Filter::height]), + }, }; context->dispatch( @@ -510,7 +540,7 @@ void conv2d_depthwise( } } -void conv2d_pointwise( +void conv2d_pw( api::Context* const context, api::Command::Buffer& command_buffer, vTensor& v_output, @@ -522,22 +552,29 @@ void conv2d_pointwise( const IntArrayRef padding, const float output_min, const float output_max) { - if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) { - - const struct { - int32_t kernel_ic, kernel_oc; - int32_t stride_x, stride_y; - int32_t padding_x, padding_y; - float clamp_x, clamp_y; + if C10_LIKELY(v_output.has_image() && v_input.has_image() && v_weight.has_image()) { + const struct Block final { + ivec2 kernel; + ivec2 stride; + ivec2 padding; + vec2 clamp; } block { - safe_downcast(filter[Layout::Filter::input]), - safe_downcast(filter[Layout::Filter::output]), - safe_downcast(stride[Layout::Parameter::width]), - safe_downcast(stride[Layout::Parameter::height]), - safe_downcast(padding[Layout::Parameter::width]), - safe_downcast(padding[Layout::Parameter::height]), - output_min, - output_max, + { + safe_downcast(filter[Layout::Filter::input]), + safe_downcast(filter[Layout::Filter::output]), + }, + { + safe_downcast(stride[Layout::Parameter::width]), + safe_downcast(stride[Layout::Parameter::height]), + }, + { + safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), + }, + { + output_min, + output_max, + }, }; context->dispatch( @@ -595,30 +632,134 @@ void conv2d( const IntArrayRef dilation, const float output_min, const float output_max) { + if C10_LIKELY(v_output.has_image() && v_input.has_image() && v_weight.has_image()) { + const struct Block final { + ivec4 kernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; + ivec4 src_filter; + } block { + { + safe_downcast(filter[Layout::Filter::width]), + safe_downcast(filter[Layout::Filter::height]), + safe_downcast(filter[Layout::Filter::input]), + safe_downcast(filter[Layout::Filter::output]), + }, + { + safe_downcast(stride[Layout::Parameter::width]), + safe_downcast(stride[Layout::Parameter::height]), + }, + { + safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), + }, + { + safe_downcast(dilation[Layout::Parameter::width]), + safe_downcast(dilation[Layout::Parameter::height]), + }, + { + output_min, + output_max, + }, + { + safe_downcast(src_filter[Layout::Filter::width]), + safe_downcast(src_filter[Layout::Filter::height]), + safe_downcast(src_filter[Layout::Filter::width] * 4), + 0, + }, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(conv2d), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_input.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_weight.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_bias.buffer( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } +} + +void conv2d_old( + api::Context* const context, + api::Command::Buffer& command_buffer, + vTensor& v_output, + const vTensor& v_input, + const vTensor& v_weight, + const vTensor& v_bias, + const IntArrayRef filter, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const float output_min, + const float output_max) { + using namespace api::utils; + if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) { - const struct { - int32_t kernel_x, kernel_y, kernel_ic, kernel_oc; - int32_t stride_x, stride_y; + const int32_t W = v_input.extents().data[0]; + const int32_t H = v_input.extents().data[1]; + const int32_t C_4 = v_input.extents().data[2]; + const int32_t C = 4 * C_4; + + const int32_t OW = v_output.extents().data[0]; + const int32_t OH = v_output.extents().data[1]; + const int32_t OC_4 = v_output.extents().data[2]; + const int32_t OC = 4 * OC_4; + + const struct Block final { int32_t padding_x, padding_y; + int32_t kernel_x, kernel_y; + int32_t stride_x, stride_y; int32_t dilate_x, dilate_y; - float clamp_x, clamp_y; - int32_t src_filter_w, src_filter_h, src_filter_w4; + int32_t outputSize[4]; + int32_t inputSize[4]; + float outputMin; + float outputMax; } block { + safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), safe_downcast(filter[Layout::Filter::width]), safe_downcast(filter[Layout::Filter::height]), - safe_downcast(filter[Layout::Filter::input]), - safe_downcast(filter[Layout::Filter::output]), safe_downcast(stride[Layout::Parameter::width]), safe_downcast(stride[Layout::Parameter::height]), - safe_downcast(padding[Layout::Parameter::width]), - safe_downcast(padding[Layout::Parameter::height]), safe_downcast(dilation[Layout::Parameter::width]), safe_downcast(dilation[Layout::Parameter::height]), + { OW, OH, OC_4, OC }, + { W, H, C_4, C }, output_min, output_max, - safe_downcast(src_filter[Layout::Filter::width]), - safe_downcast(src_filter[Layout::Filter::height]), - safe_downcast(src_filter[Layout::Filter::width]*4), }; context->dispatch( @@ -630,29 +771,30 @@ void conv2d( VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, }, - VK_KERNEL(conv2d), + VK_KERNEL(conv2d_nogroup_clamp), + //VK_KERNEL(conv2d_nogroup_clamp_1x), v_output.extents(), // Write-only access bypasses synchronization but inserts appropriate // barriers if necessary. v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), // Read-only access is implied on const tensors and triggers an async // synchronization if necessary. v_input.image( - command_buffer, - vTensor::Stage::Compute), + command_buffer, + vTensor::Stage::Compute), // Read-only access is implied on const tensors and triggers an async // synchronization if necessary. v_weight.image( - command_buffer, - vTensor::Stage::Compute), + command_buffer, + vTensor::Stage::Compute), // Read-only access is implied on const tensors and triggers an async // synchronization if necessary. v_bias.buffer( - command_buffer, - vTensor::Stage::Compute), + command_buffer, + vTensor::Stage::Compute), // Object lifetime is managed by the resource pool. // It is OK not to keep track of the handle. context->resource().pool.uniform(block).object); @@ -781,99 +923,6 @@ Conv2dOpContext Conv2dOpContext::create( }; } -void conv2d_old( - api::Context* const context, - api::Command::Buffer& command_buffer, - vTensor& v_output, - const vTensor& v_input, - const vTensor& v_weight, - const vTensor& v_bias, - const IntArrayRef filter, - const IntArrayRef stride, - const IntArrayRef padding, - const IntArrayRef dilation, - const float output_min, - const float output_max) { - - using namespace api::utils; - - if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) { - const int32_t W = v_input.extents().data[0]; - const int32_t H = v_input.extents().data[1]; - const int32_t C_4 = v_input.extents().data[2]; - const int32_t C = 4 * C_4; - - const int32_t OW = v_output.extents().data[0]; - const int32_t OH = v_output.extents().data[1]; - const int32_t OC_4 = v_output.extents().data[2]; - const int32_t OC = 4 * OC_4; - - const struct { - int32_t padding_x, padding_y; - int32_t kernel_x, kernel_y; - int32_t stride_x, stride_y; - int32_t dilate_x, dilate_y; - int32_t outputSize[4]; - int32_t inputSize[4]; - float outputMin; - float outputMax; - } block { - safe_downcast(padding[Layout::Parameter::width]), - safe_downcast(padding[Layout::Parameter::height]), - safe_downcast(filter[Layout::Filter::width]), - safe_downcast(filter[Layout::Filter::height]), - safe_downcast(stride[Layout::Parameter::width]), - safe_downcast(stride[Layout::Parameter::height]), - safe_downcast(dilation[Layout::Parameter::width]), - safe_downcast(dilation[Layout::Parameter::height]), - { OW, OH, OC_4, OC }, - { W, H, C_4, C }, - output_min, - output_max, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(conv2d_nogroup_clamp), - //VK_KERNEL(conv2d_nogroup_clamp_1x), - v_output.extents(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_weight.image( - command_buffer, - vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_bias.buffer( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } -} - Tensor Conv2dOpContext::run(const Tensor& input_arg) const { api::Context* const context = api::context(); @@ -896,11 +945,11 @@ Tensor Conv2dOpContext::run(const Tensor& input_arg) const { input.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { if (is_depthwise(unpacked_.filter, unpacked_.groups)) { - conv2d_depthwise( + conv2d_dw( context, command_buffer, v_output, @@ -932,7 +981,7 @@ Tensor Conv2dOpContext::run(const Tensor& input_arg) const { packed_.output_max); } else { if (is_pointwise(unpacked_.filter)) { - conv2d_pointwise( + conv2d_pw( context, command_buffer, v_output, @@ -964,8 +1013,7 @@ Tensor Conv2dOpContext::run(const Tensor& input_arg) const { } } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Copy.cpp b/aten/src/ATen/native/vulkan/ops/Copy.cpp index bbd326b42ace..1cf6b1ad6aa9 100644 --- a/aten/src/ATen/native/vulkan/ops/Copy.cpp +++ b/aten/src/ATen/native/vulkan/ops/Copy.cpp @@ -6,87 +6,96 @@ namespace vulkan { namespace ops { Tensor& copy_(Tensor& self, const Tensor& src) { - // X -> Vulkan - if (at::kVulkan == self.device().type()) { - vTensor& v_self = convert(self); - - // CPU -> Vulkan - if (at::kCPU == src.device().type()) { - // Requesting write-only host access to the tensor never triggers a sync - // as the contents will be overwritten regardless. Having said that, - // appropriate barriers are inserted automatically if WAR or WAW hazards - // are detected. Examples of such scenario for instance are if any of - // these async operations are on going in the background on 'self': - // - On discrete systems: - // * buffer-to-staging transfers - // * staging-to-buffer transfers - // - On UMA buffer is an alias for staging and accessible both on host - // and device. Consequently: - // * buffer-to-image NHWC -> NC4HW packing - // * image-to-buffer NC4HW -> NHWC unpacking - - using Future = vTensor::Future; - Future v_self_future = v_self.host(); - - // This wait() will be a no-op if no hazards are detected, including the - // obvious, yet important, special case of 'self' being an empty tensor. - - Future::Payload v_self_payload = v_self_future.wait(); - - memcpy( - v_self_payload.get(), - src.contiguous().data_ptr(), - std::min(src.nbytes(), self.nbytes())); + api::Context* const context = api::context(); + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + // X -> Vulkan + if (at::kVulkan == self.device().type()) { + vTensor& v_self = convert(self); + + // Vulkan -> Vulkan + if (at::kVulkan == src.device().type()) { + command_buffer.copy( + // - Read-only access is implied on const tensors. Memory barriers + // are automatically inserted if a RAW hazard is detected. + // - Recording any potential pending sync operations into the same + // command buffer prevents an expensive queue submission. + convert(src).buffer( + command_buffer, + vTensor::Stage::Transfer), + // - Write-only access never triggers a sync as the contents will be + // overwritten regardless. Having said that, appropriate barriers + // are inserted automatically if WAR or WAW hazards are detected. + // - Recording pending sync operations into the same command buffer + // prevents an expensive queue submission. + v_self.buffer( + command_buffer, + vTensor::Stage::Transfer, + vTensor::Access::Write)); + + command_pool.submit(context->gpu().queue, command_buffer); + } + // CPU -> Vulkan + else { + const Tensor cpu_src = src.device().is_cpu() ? src : src.cpu(); + + // Requesting write-only host access to the tensor never triggers a sync + // as the contents will be overwritten regardless. Having said that, + // appropriate barriers are inserted automatically if WAR or WAW hazards + // are detected. Examples of such scenario for instance are if any of + // these async operations are on going in the background on 'self': + // - On discrete systems: + // * buffer-to-staging transfers + // * staging-to-buffer transfers + // - On UMA buffer is an alias for staging and accessible both on host + // and device. Consequently: + // * buffer-to-image NHWC -> NC4HW packing + // * image-to-buffer NC4HW -> NHWC unpacking + + using Future = vTensor::Future; + Future v_self_future = v_self.host(command_buffer); + + // Ideally we would have been able to put as much distance between + // requesting the data - a call to host() - and accessing the data + // - a call to wait() - but a local view of the computation graph + // in eager mode makes that optimization non-trivial. + + // This wait() will be a no-op if no hazards are detected, including the + // obvious, yet important, special case of 'self' being an empty tensor. + + Future::Payload v_self_payload = v_self_future.wait(); + + memcpy( + v_self_payload.get(), + cpu_src.contiguous().data_ptr(), + std::min(src.nbytes(), self.nbytes())); + } } - // Vulkan -> Vulkan + // Vulkan -> X else if (at::kVulkan == src.device().type()) { - api::Command::Buffer command_buffer = api::context()->command().pool.allocate(); - command_buffer.begin(); - - command_buffer.copy( - // - Read-only access is implied on const tensors. Memory barriers - // are automatically inserted if a RAW hazard is detected. - // - Recording any potential pending sync operations into the same - // command buffer prevents an expensive queue submission. - convert(src).buffer( - command_buffer, - vTensor::Stage::Transfer), - // - Write-only access never triggers a sync as the contents will be - // overwritten regardless. Having said that, appropriate barriers - // are inserted automatically if WAR or WAW hazards are detected. - // - Recording pending sync operations into the same command buffer - // prevents an expensive queue submission. - v_self.buffer( - command_buffer, - vTensor::Stage::Transfer, - vTensor::Access::Write)); - - command_buffer.end(); - command_buffer.submit(api::context()->gpu().queue); - } - else { - TORCH_INTERNAL_ASSERT(false, "Unsupported!"); - } - } - // Vulkan -> X - else if (at::kVulkan == src.device().type()) { - const vTensor& v_src = convert(src); - - { - // Similar notes as above applies, with the additional consideration of - // potential syncs on read accesses. Namely, - // - on discrete systems, if the (staging, buffer, image) trio, or - // - on UMA, if the (buffer, image) duo - // have gone out of sync as a result of one processor writing to one - // resource which is then either accessed as an another resource type on - // the same or another processor. Same considerations regarding hazard - // avoidance as above applies. - - using Future = vTensor::Future; - const Future v_src_future = v_src.host(); + const vTensor& v_src = convert(src); // Vulkan -> CPU - if (at::kCPU == self.device().type()) { + if (self.device().is_cpu()) { + // Similar notes as above applies, with the additional consideration of + // potential syncs on read accesses. Namely, + // - on discrete systems, if the (staging, buffer, image) trio, or + // - on UMA, if the (buffer, image) duo + // have gone out of sync as a result of one processor writing to one + // resource which is then either accessed as an another resource type on + // the same or another processor. Same considerations regarding hazard + // avoidance as above applies. + + using Future = vTensor::Future; + const Future v_src_future = v_src.host(command_buffer); + + // Ideally we would have been able to put as much distance between + // requesting the data - a call to host() - and accessing the data + // - a call to wait() - but a local view of the computation graph + // in eager mode makes that optimization non-trivial. + // This wait() is a no-op if data is not out of sync. More often than // not though, waits here are expected as the GPU catches up with // compute submitted from CPU. @@ -99,51 +108,56 @@ Tensor& copy_(Tensor& self, const Tensor& src) { std::min(src.nbytes(), self.nbytes())); } else { - TORCH_INTERNAL_ASSERT(false, "Unsupported!"); + TORCH_CHECK(false, "Unsupported!"); } - } - // - // WARNING - // - - // This is not great. We almost never want to flush the GPU pipeline as - // that has far reaching consequences, especially if PyTorch is not the only - // process accessing the GPU. If we have done our job properly, above - // synchronization mechanisms should be enough to ensure correctness at a more - // modest cost, as there is no need to flush the entirety of jobs in flight - // if one is only interested on waiting on computation affecting one single - // tensor to finish. - // - // Having said that, we still do need to release all pool resources at one - // point per inference run or we will run out of memory otherwise. There is - // no perfect answer to this problem that checks all boxes, which leaves us - // with one of several design decisions: - // - // 1) Use graph mode to gain an understanding of the computation graph, - // itself allowing us to place pool purges intelligently. Best option - // for performance and memory consumption. Not without its downsides if - // flexibility is a top priority. - // 2) If on eager mode, and hence are seeing operations one at a time, expose - // this release of resources to the user as a Python / C++ function. This - // makes for suboptimal user experience but is efficient in terms of - // performance. - // 3) If on eager mode, and interested in keeping this bookkeeping transparent - // to the user, release all resources somewhere ... like here. This is - // not ideal since it requires a pipeline flush to make sure these objects - // are not already in use by a workload in flight. Cannot do much better - // within the constraints of this approach. Good for user experience, - // suboptimal for performance. - // 4) If on eager mode, and interested in keeping this bookkeeping transparent - // to the user, and performance does not matter, make CPU and GPU run in - // lockstep. Obviously this is just bad. Mentioned for the sake of - // completeness. - - api::context()->flush(); - } - else { - TORCH_INTERNAL_ASSERT(false, "Unsupported!"); + // + // WARNING + // + + // This is not great. We almost never want to flush the GPU pipeline as + // that has far reaching consequences, especially if PyTorch is not the only + // process accessing the GPU. If we have done our job properly, above + // synchronization mechanisms should be enough to ensure correctness at a more + // modest cost, as there is no need to flush the entirety of jobs in flight + // if one is only interested on waiting on computation affecting one single + // tensor to finish. + // + // Having said that, we still do need to release all pool resources at one + // point per inference run or we will run out of memory otherwise. There is + // no perfect answer to this problem that checks all boxes, which leaves us + // with one of several design decisions: + // + // 1) Use graph mode to gain an understanding of the computation graph, + // itself allowing us to place pool purges intelligently. Best option + // for performance and memory consumption. Not without its downsides if + // flexibility is a top priority. + // 2) If on eager mode, and hence are seeing operations one at a time, expose + // this release of resources to the user as a Python / C++ function. This + // makes for suboptimal user experience but is efficient in terms of + // performance. + // 3) If on eager mode, and interested in keeping this bookkeeping transparent + // to the user, release all resources somewhere ... like here. This is + // not ideal since it requires a pipeline flush to make sure these objects + // are not already in use by a workload in flight. Cannot do much better + // within the constraints of this approach. Good for user experience, + // suboptimal for performance. + // 4) If on eager mode, and interested in keeping this bookkeeping transparent + // to the user, and performance does not matter, make CPU and GPU run in + // lockstep. Obviously this is just bad. Mentioned for the sake of + // completeness. + + context->flush(); + } + else { + TORCH_INTERNAL_ASSERT( + false, + "Invalid code path taken! Either the source or the destination tensor " + "was expected to be Vulkan a tensor! Incorrect dispatch?"); + } } + // No queue submission here. All queue submissions must have been handled + // above either explicitly or as a result of calling tensor.host(). return self; } diff --git a/aten/src/ATen/native/vulkan/ops/Mean.cpp b/aten/src/ATen/native/vulkan/ops/Mean.cpp index f6d63c14f381..6a413f55ded5 100644 --- a/aten/src/ATen/native/vulkan/ops/Mean.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mean.cpp @@ -52,11 +52,11 @@ Tensor mean( v_input.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_input.has_image()) { - const struct { + if C10_LIKELY(v_input.has_image()) { + const struct Block final { uvec3 extents; int32_t range; ivec2 iextents; @@ -71,63 +71,35 @@ Tensor mean( }, }; - if (keepdim) { - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(mean), - v_output.extents(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(mean2d), - v_output.extents(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + keepdim ? VK_KERNEL(mean) : VK_KERNEL(mean2d), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_input.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); } else { TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Mm.cpp b/aten/src/ATen/native/vulkan/ops/Mm.cpp index fa1f7d45048c..e50f9724e665 100644 --- a/aten/src/ATen/native/vulkan/ops/Mm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mm.cpp @@ -10,18 +10,21 @@ namespace { using namespace api::utils; vTensor pack_weights( - api::Resource::Pool& pool, - const Tensor& weight_arg) { + api::Resource::Pool& pool, + const Tensor& weight_arg) { if (weight_arg.is_vulkan()) { return convert(weight_arg); } + api::Context* const context = api::context(); + api::Command::Buffer& command_buffer = context->command().pool.stream(); + const Tensor weight = weight_arg.contiguous(); const IntArrayRef w_sizes = weight.sizes(); const float* const src_weight_ptr = weight.data_ptr(); vTensor v_weight{ - api::context(), + context, &pool, w_sizes, weight.options(), @@ -29,7 +32,7 @@ vTensor pack_weights( { using Future = vTensor::Future; - Future v_weight_future = v_weight.host(); + Future v_weight_future = v_weight.host(command_buffer); Future::Payload v_weight_payload = v_weight_future.wait(); memcpy( @@ -49,16 +52,21 @@ vTensor pack_biases( return convert(*bias_arg); } + api::Context* const context = api::context(); + api::Command::Buffer& command_buffer = context->command().pool.stream(); + vTensor v_bias{ - api::context(), + context, &pool, - {weight_arg.sizes()[Layout::Parameter::width]}, + { + weight_arg.size(Layout::Parameter::width), + }, weight_arg.options(), }; { using Future = vTensor::Future; - Future v_bias_future = v_bias.host(); + Future v_bias_future = v_bias.host(command_buffer); Future::Payload v_bias_payload = v_bias_future.wait(); if (bias_arg) { @@ -66,7 +74,8 @@ vTensor pack_biases( v_bias_payload.get(), bias_arg->contiguous().data_ptr(), std::min(bias_arg->nbytes(), v_bias.nbytes())); - } else { + } + else { memset( v_bias_payload.get(), // 2's complement integers and IEEE-754 floating point numbers both @@ -162,11 +171,11 @@ Tensor mm( mat1.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_mat1.has_image() && v_mat2.has_image()) { - const struct { + if C10_LIKELY(v_mat1.has_image() && v_mat2.has_image()) { + const struct Block final { uvec3 size; int32_t K; } block { @@ -203,12 +212,12 @@ Tensor mm( // Object lifetime is managed by the resource pool. // It is OK not to keep track of the handle. context->resource().pool.uniform(block).object); - } else { + } + else { TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -281,14 +290,15 @@ Tensor LinearOpContext::run( input.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_output.has_image() && + if C10_LIKELY( + v_output.has_image() && v_input.has_image() && packed_.v_weight.has_image() && packed_.v_bias.has_image()) { - const struct { + const struct Block final { uvec3 size; int32_t K; vec2 multiplier; @@ -341,8 +351,7 @@ Tensor LinearOpContext::run( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Mul.cpp b/aten/src/ATen/native/vulkan/ops/Mul.cpp index 84226135929a..1e494287a5ae 100644 --- a/aten/src/ATen/native/vulkan/ops/Mul.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mul.cpp @@ -23,11 +23,11 @@ Tensor mul_scalar( v_self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_output.has_image() && v_self.has_image()) { - const struct { + if C10_LIKELY(v_output.has_image() && v_self.has_image()) { + const struct Block final { uvec3 extents; float other; } block { @@ -63,8 +63,7 @@ Tensor mul_scalar( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -80,11 +79,11 @@ Tensor& mul_scalar_( vTensor& v_self = convert(self); - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_self.has_image()) { - const struct { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { uvec3 extents; float other; } block { @@ -114,8 +113,7 @@ Tensor& mul_scalar_( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return self; } diff --git a/aten/src/ATen/native/vulkan/ops/Pool.cpp b/aten/src/ATen/native/vulkan/ops/Pool.cpp index 0bc97d6741bc..5eaaf9d04171 100644 --- a/aten/src/ATen/native/vulkan/ops/Pool.cpp +++ b/aten/src/ATen/native/vulkan/ops/Pool.cpp @@ -33,10 +33,10 @@ Tensor adaptive_avg_pool2d( v_self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_self.has_image()) { + if C10_LIKELY(v_self.has_image()) { const uvec3 v_output_size = v_output.extents(); const uvec3 v_self_size = v_self.extents(); @@ -45,7 +45,7 @@ Tensor adaptive_avg_pool2d( static_cast(v_self_size.data[1u]) / v_output_size.data[1u], }; - const struct { + const struct Block final { uvec3 size; uint32_t _; vec2 stride; @@ -88,8 +88,7 @@ Tensor adaptive_avg_pool2d( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -171,13 +170,11 @@ Tensor avg_pool2d( v_self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - using namespace utils; - - if (v_self.has_image()) { - const struct { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { uvec3 extents; int32_t range; ivec2 iextents; @@ -235,8 +232,7 @@ Tensor avg_pool2d( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Shape.cpp b/aten/src/ATen/native/vulkan/ops/Shape.cpp index 8edfda60b76f..9d2a248f0707 100644 --- a/aten/src/ATen/native/vulkan/ops/Shape.cpp +++ b/aten/src/ATen/native/vulkan/ops/Shape.cpp @@ -21,8 +21,8 @@ Tensor view( self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { command_buffer.copy( // Read-only access is implied on const tensors and triggers an async @@ -37,8 +37,7 @@ Tensor view( vTensor::Stage::Transfer, vTensor::Access::Write)); } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.cpp b/aten/src/ATen/native/vulkan/ops/Tensor.cpp index 3570834b8dd4..0bf7acbe7dee 100644 --- a/aten/src/ATen/native/vulkan/ops/Tensor.cpp +++ b/aten/src/ATen/native/vulkan/ops/Tensor.cpp @@ -419,31 +419,19 @@ vTensor::vTensor( }) { } -const vTensor* vTensor::host() const { - view_->staging(Stage::Host, Access::Read); +const vTensor* vTensor::host( + api::Command::Buffer& command_buffer) const { + view_->staging(command_buffer, Stage::Host, Access::Read); return this; } -vTensor* vTensor::host(const Access::Flags access) { - view_->staging(Stage::Host, access); +vTensor* vTensor::host( + api::Command::Buffer& command_buffer, + const Access::Flags access) { + view_->staging(command_buffer, Stage::Host, access); return this; } -vTensor::Buffer::Object vTensor::buffer( - const Stage::Flags stage) const & { - return view_->buffer( - stage, - Access::Read).object; -} - -vTensor::Buffer::Object vTensor::buffer( - const Stage::Flags stage, - const Access::Flags access) & { - return view_->buffer( - stage, - access).object; -} - vTensor::Buffer::Object vTensor::buffer( api::Command::Buffer& command_buffer, const Stage::Flags stage) const & { @@ -463,21 +451,6 @@ vTensor::Buffer::Object vTensor::buffer( access).object; } -vTensor::Image::Object vTensor::image( - const Stage::Flags stage) const & { - return view_->image( - stage, - Access::Read).object; -} - -vTensor::Image::Object vTensor::image( - const Stage::Flags stage, - const Access::Flags access) & { - return view_->image( - stage, - access).object; -} - vTensor::Image::Object vTensor::image( api::Command::Buffer& command_buffer, const Stage::Flags stage) const & { @@ -535,16 +508,8 @@ vTensor::View::View( ops::verify(options); } -// We typically do not know whether we need a command buffer to service a request -// until we have perfomed a bunch of checks in nested logic, and even then we -// may end up with the always issued state transition optimized away under -// certain conditions, which makes a policy of always allocating a command buffer -// up front, only to end up using it at times, a wasteful approach. This class -// answers that need. - class vTensor::View::CMD final { public: - explicit CMD(const View&); CMD(const View&, api::Command::Buffer&); CMD(const CMD&) = delete; CMD& operator=(const CMD&) = delete; @@ -578,60 +543,18 @@ class vTensor::View::CMD final { const Image::Object& image, Buffer::Object& buffer); - void submit(Fence fence = {}); - - private: - api::Command::Buffer& command_buffer(); + void submit(Fence fence); private: const View& view_; - - enum class Type { - Internal, - External, - } type; - - union _ final { - api::Command::Buffer internal; - api::Command::Buffer* external; - ~_() {} - } command_buffer_; + api::Command::Buffer& command_buffer_; }; -vTensor::View::CMD::CMD( - const View& view) - : view_(view), - type(Type::Internal), - command_buffer_{} { -} - vTensor::View::CMD::CMD( const View& view, - api::Command::Buffer& external) + api::Command::Buffer& command_buffer) : view_(view), - type(Type::External), - command_buffer_{ - .external = &external, - } { -} - -api::Command::Buffer& vTensor::View::CMD::command_buffer() { - switch (type) { - case Type::Internal: - if (!command_buffer_.internal) { - command_buffer_.internal = view_.context_->command().pool.allocate(); - command_buffer_.internal.begin(); - } - - return command_buffer_.internal; - - case Type::External: - return *(command_buffer_.external); - - default: - TORCH_INTERNAL_ASSERT(false, "Unknown command buffer type!"); - break; - } + command_buffer_(command_buffer) { } void vTensor::View::CMD::barrier(State::Transition transition) { @@ -761,7 +684,7 @@ void vTensor::View::CMD::barrier(State::Transition transition) { barrier.stage.src = VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT; } - command_buffer().barrier(barrier); + command_buffer_.barrier(barrier); } } @@ -789,7 +712,7 @@ void vTensor::View::CMD::copy_buffer_to_staging( {}, })); - command_buffer().copy(buffer, staging); + command_buffer_.copy(buffer, staging); } void vTensor::View::CMD::copy_staging_to_buffer( @@ -816,7 +739,7 @@ void vTensor::View::CMD::copy_staging_to_buffer( {}, })); - command_buffer().copy(staging, buffer); + command_buffer_.copy(staging, buffer); } void vTensor::View::CMD::copy_buffer_to_image( @@ -847,7 +770,7 @@ void vTensor::View::CMD::copy_buffer_to_image( const uvec3 extents = view_.extents(); const uint32_t plane = extents.data[0u] * extents.data[1u]; - const struct { + const struct Block final { uvec3 extents; uint32_t block; uvec4 offset; @@ -863,7 +786,7 @@ void vTensor::View::CMD::copy_buffer_to_image( }; view_.context_->dispatch( - command_buffer(), + command_buffer_, { VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, @@ -904,7 +827,7 @@ void vTensor::View::CMD::copy_image_to_buffer( const uvec3 extents = view_.extents(); const uint32_t plane = extents.data[0u] * extents.data[1u]; - const struct { + const struct Block final { uvec3 extents; uint32_t block; uvec4 offset; @@ -920,7 +843,7 @@ void vTensor::View::CMD::copy_image_to_buffer( }; view_.context_->dispatch( - command_buffer(), + command_buffer_, { VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, @@ -934,10 +857,10 @@ void vTensor::View::CMD::copy_image_to_buffer( } void vTensor::View::CMD::submit(const api::Resource::Fence fence) { - if ((Type::Internal == type) && command_buffer_.internal) { - command_buffer_.internal.end(); - command_buffer_.internal.submit(view_.context_->gpu().queue, fence); - } + view_.context_->command().pool.submit( + view_.context_->gpu().queue, + command_buffer_, + fence); } vTensor::Buffer& vTensor::View::buffer() const { @@ -953,38 +876,28 @@ vTensor::Buffer& vTensor::View::buffer() const { } vTensor::Buffer& vTensor::View::buffer( + api::Command::Buffer& command_buffer, const Stage::Flags stage, const Access::Flags access) const { - CMD command_buffer(*this); - Buffer& buffer = this->buffer(command_buffer, stage, access); - command_buffer.submit(); - - return buffer; -} - -vTensor::Buffer& vTensor::View::buffer( - api::Command::Buffer& command_buffer_, - const Stage::Flags stage, - const Access::Flags access) const { - CMD command_buffer(*this, command_buffer_); - return buffer(command_buffer, stage, access); + CMD cmd(*this, command_buffer); + return buffer(cmd, stage, access); } vTensor::Buffer& vTensor::View::buffer( - CMD& command_buffer, + CMD& cmd, const Stage::Flags stage, const Access::Flags access) const { if ((access & Access::Read) && state_.is_dirty(Component::Buffer)) { if (state_.is_clean(Component::Staging)) { - command_buffer.copy_staging_to_buffer( + cmd.copy_staging_to_buffer( state_, - staging(command_buffer, Stage::Transfer, Access::Read).object, + staging(cmd, Stage::Transfer, Access::Read).object, buffer().object); } else if (state_.is_clean(Component::Image)) { - command_buffer.copy_image_to_buffer( + cmd.copy_image_to_buffer( state_, - image(command_buffer, Stage::Compute, Access::Read).object, + image(cmd, Stage::Compute, Access::Read).object, buffer().object); } else { @@ -994,7 +907,7 @@ vTensor::Buffer& vTensor::View::buffer( } } - command_buffer.barrier( + cmd.barrier( state_.transition({ // Staging {}, @@ -1028,35 +941,25 @@ vTensor::Image& vTensor::View::image() const { } vTensor::Image& vTensor::View::image( + api::Command::Buffer& command_buffer, const Stage::Flags stage, const Access::Flags access) const { - CMD command_buffer(*this); - Image& image = this->image(command_buffer, stage, access); - command_buffer.submit(); - - return image; -} - -vTensor::Image& vTensor::View::image( - api::Command::Buffer& command_buffer_, - const Stage::Flags stage, - const Access::Flags access) const { - CMD command_buffer(*this, command_buffer_); - return image(command_buffer, stage, access); + CMD cmd(*this, command_buffer); + return image(cmd, stage, access); } vTensor::Image& vTensor::View::image( - CMD& command_buffer, + CMD& cmd, const Stage::Flags stage, const Access::Flags access) const { if ((access & Access::Read) && state_.is_dirty(Component::Image)) { - command_buffer.copy_buffer_to_image( + cmd.copy_buffer_to_image( state_, - buffer(command_buffer, stage, Access::Read).object, + buffer(cmd, stage, Access::Read).object, image().object); } - command_buffer.barrier( + cmd.barrier( state_.transition({ // Staging {}, @@ -1096,27 +999,28 @@ vTensor::Buffer& vTensor::View::staging() const { } vTensor::Buffer& vTensor::View::staging( + api::Command::Buffer& command_buffer, const Stage::Flags stage, const Access::Flags access) const { - CMD command_buffer(*this); - Buffer& staging = this->staging(command_buffer, stage, access); - command_buffer.submit(fence()); + CMD cmd(*this, command_buffer); + Buffer& staging = this->staging(cmd, stage, access); + cmd.submit(fence(access)); return staging; } vTensor::Buffer& vTensor::View::staging( - CMD& command_buffer, + CMD& cmd, const Stage::Flags stage, const Access::Flags access) const { if ((access & Access::Read) && state_.is_dirty(Component::Staging)) { - command_buffer.copy_buffer_to_staging( + cmd.copy_buffer_to_staging( state_, - buffer(command_buffer, Stage::Transfer, Access::Read).object, + buffer(cmd, Stage::Transfer, Access::Read).object, staging().object); } - command_buffer.barrier( + cmd.barrier( state_.transition({ // Staging { @@ -1138,6 +1042,14 @@ vTensor::Buffer& vTensor::View::staging( return staging(); } +vTensor::Fence& vTensor::View::fence(const Access::Flags access) const { + if (access & Access::Read) { + fence_ = allocate_fence(&context_->resource().pool); + } + + return fence_; +} + vTensor::Memory& vTensor::View::wait() const { if (fence_) { fence_.wait(); @@ -1146,10 +1058,6 @@ vTensor::Memory& vTensor::View::wait() const { return staging().memory; } -vTensor::Fence& vTensor::View::fence() const { - return (fence_ = allocate_fence(pool_)); -} - void vTensor::View::verify() const { TORCH_INTERNAL_ASSERT(!image_ || state_.is_available(Component::Image)); TORCH_INTERNAL_ASSERT(!staging_ || state_.is_discrete()); diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.h b/aten/src/ATen/native/vulkan/ops/Tensor.h index 48d4cca84dd4..f404988b420b 100644 --- a/aten/src/ATen/native/vulkan/ops/Tensor.h +++ b/aten/src/ATen/native/vulkan/ops/Tensor.h @@ -157,10 +157,10 @@ class vTensor final { */ template - Future host() const &; + Future host(api::Command::Buffer&) const &; template - Future host() &; + Future host(api::Command::Buffer&) &; /* Device access - these functions will be expensive if they trigger a buffer @@ -178,14 +178,10 @@ class vTensor final { predictability of usage and efficiency. */ - Buffer::Object buffer(Stage::Flags) const &; - Buffer::Object buffer(Stage::Flags, Access::Flags) &; Buffer::Object buffer(api::Command::Buffer&, Stage::Flags) const &; Buffer::Object buffer(api::Command::Buffer&, Stage::Flags, Access::Flags) &; bool has_image() const; - Image::Object image(Stage::Flags) const &; - Image::Object image(Stage::Flags, Access::Flags) &; Image::Object image(api::Command::Buffer&, Stage::Flags) const &; Image::Object image(api::Command::Buffer&, Stage::Flags, Access::Flags) &; @@ -210,26 +206,22 @@ class vTensor final { Host */ - const vTensor* host() const; - vTensor* host(Access::Flags access); + const vTensor* host(api::Command::Buffer&) const; + vTensor* host(api::Command::Buffer&, Access::Flags); template - Future host() const && = delete; + Future host(api::Command::Buffer&) const && = delete; template - Future host() && = delete; + Future host(api::Command::Buffer&) && = delete; /* Device */ - Buffer::Object buffer(Stage::Flags) const && = delete; - Buffer::Object buffer(Stage::Flags, Access::Flags) && = delete; Buffer::Object buffer(api::Command::Buffer&, Stage::Flags) const && = delete; Buffer::Object buffer(api::Command::Buffer&, Stage::Flags, Access::Flags) && = delete; - Image::Object image(Stage::Flags) const && = delete; - Image::Object image(Stage::Flags, Access::Flags) && = delete; Image::Object image(api::Command::Buffer&, Stage::Flags) const && = delete; Image::Object image(api::Command::Buffer&, Stage::Flags, Access::Flags) && = delete; @@ -249,21 +241,22 @@ class vTensor final { ~View() = default; /* - Device + Buffer */ - Buffer& buffer(Stage::Flags, Access::Flags) const; Buffer& buffer(api::Command::Buffer&, Stage::Flags, Access::Flags) const; + /* + Image + */ + bool has_image() const; - Image& image(Stage::Flags, Access::Flags) const; Image& image(api::Command::Buffer&, Stage::Flags, Access::Flags) const; /* Host */ - Buffer& staging(Stage::Flags, Access::Flags) const; Buffer& staging(api::Command::Buffer&, Stage::Flags, Access::Flags) const; vTensor::Memory& wait() const; @@ -343,7 +336,7 @@ class vTensor final { Image& image(CMD&, Stage::Flags, Access::Flags) const; Buffer& staging() const; Buffer& staging(CMD&, Stage::Flags, Access::Flags) const; - Fence& fence() const; + Fence& fence(Access::Flags) const; // Validation void verify() const; @@ -485,13 +478,15 @@ vTensor::Future::wait() const & { } template -inline vTensor::Future vTensor::host() const & { - return Future(host()); +inline vTensor::Future +vTensor::host(api::Command::Buffer& command_buffer) const & { + return Future(host(command_buffer)); } template -inline vTensor::Future vTensor::host() & { - return Future(host(kAccess)); +inline vTensor::Future +vTensor::host(api::Command::Buffer& command_buffer) & { + return Future(host(command_buffer, kAccess)); } inline bool vTensor::has_image() const { diff --git a/aten/src/ATen/native/vulkan/ops/Upsample.cpp b/aten/src/ATen/native/vulkan/ops/Upsample.cpp index 32508c01eec1..00cefc1bdf53 100644 --- a/aten/src/ATen/native/vulkan/ops/Upsample.cpp +++ b/aten/src/ATen/native/vulkan/ops/Upsample.cpp @@ -36,11 +36,11 @@ Tensor upsample_nearest2d( input.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_input.has_image()) { - const struct { + if C10_LIKELY(v_input.has_image()) { + const struct Block final { uvec3 extents; uint32_t _; ivec2 iextents; @@ -92,8 +92,7 @@ Tensor upsample_nearest2d( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Utils.h b/aten/src/ATen/native/vulkan/ops/Utils.h index ffdc2b6e94eb..de218cfc472a 100644 --- a/aten/src/ATen/native/vulkan/ops/Utils.h +++ b/aten/src/ATen/native/vulkan/ops/Utils.h @@ -10,7 +10,7 @@ namespace vulkan { namespace ops { namespace utils { -int64_t normalize( +inline int64_t normalize( const int64_t dimension, const int64_t n) { return (dimension % n + n) % n; From 49bb0a30e81b91647a42382973c4826408019f87 Mon Sep 17 00:00:00 2001 From: Yuxin Wu Date: Fri, 8 Jan 2021 16:53:11 -0800 Subject: [PATCH 54/54] Support scripting classmethod called with object instances (#49967) Summary: Currentlt classmethods are compiled the same way as methods - the first argument is self. Adding a fake statement to assign the first argument to the class. This is kind of hacky, but that's all it takes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/49967 Reviewed By: gchanan Differential Revision: D25841378 Pulled By: ppwwyyxx fbshipit-source-id: 0f3657b4c9d5d2181d658f9bade9bafc72de33d8 --- test/jit/test_class_type.py | 26 ++++++++++++++++++++++++++ torch/jit/frontend.py | 15 +++++++++++++-- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/test/jit/test_class_type.py b/test/jit/test_class_type.py index a80670f0d22b..1c0eed071dfd 100644 --- a/test/jit/test_class_type.py +++ b/test/jit/test_class_type.py @@ -1212,6 +1212,32 @@ def test_function(a: int, b: int) -> 'ClassWithStaticMethod': self.checkScript(test_function, (1, 2)) + def test_classmethod(self): + """ + Test classmethods on class types. + """ + global ClassWithClassMethod + + @torch.jit.script + class ClassWithClassMethod: + def __init__(self, a: int): + self.a: int = a + + def __eq__(self, other: 'ClassWithClassMethod'): + return self.a == other.a + + @classmethod + def create(cls, a: int) -> 'ClassWithClassMethod': + return cls(a) + + def test_function(a: int) -> 'ClassWithClassMethod': + x = ClassWithClassMethod(a) + # Support calling classmethod with an instance + # Calling with the class is not supported. + return x.create(a) + + self.checkScript(test_function, (1,)) + def test_properties(self): """ Test that a scripted class can make use of the @property decorator. diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index da3b3c2301a6..c6e5bd9a7870 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -166,9 +166,14 @@ def get_jit_class_def(cls, self_name): and not is_static_fn(cls, m.__name__) and m.__name__ in cls.__dict__ ) + + def is_classmethod(fn): + return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls + methods = [get_jit_def(method[1], method[0], - self_name=self_name) for method in methods] + self_name=self_name, + is_classmethod=is_classmethod(method[1])) for method in methods] properties = get_class_properties(cls, self_name) @@ -217,7 +222,7 @@ def remove_prefix(text, prefix): return aligned_prefix + aligned_suffix -def get_jit_def(fn, def_name, self_name=None): +def get_jit_def(fn, def_name, self_name=None, is_classmethod=False): """ Build a JIT AST (TreeView) from the given function. @@ -244,6 +249,12 @@ def _forward(self): ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True) fn_def = py_ast.body[0] + if is_classmethod: + arg_name = fn_def.args.args[0].arg + # Insert a statement that assigns the first argument to the class + assign_stmt = ast.parse(f"{arg_name} = {self_name}").body[0] + fn_def.body.insert(0, assign_stmt) + # Swap out the function signature and body if it is unused if should_drop(fn): unused_fn_def = ast.parse("def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")")