diff --git a/CMakeLists.txt b/CMakeLists.txt index 3df73f8a3041..abef53630b7a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -207,7 +207,7 @@ cmake_dependent_option( USE_VALGRIND "Use Valgrind. Only available on Linux." ON "LINUX" OFF) option(USE_VULKAN "Use Vulkan GPU backend" OFF) -option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference even on fp32 tensors" ON) +option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference even on fp32 tensors" OFF) option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF) option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation (needs libshaderc)" OFF) option(USE_VULKAN_WRAPPER "Vulkan - Dynamically load Vulkan functions" ON) @@ -318,7 +318,7 @@ set(OP_DEPENDENCY "" CACHE STRING # symbol lookup error: miniconda3/envs/pytorch-py3.7/lib/libmkl_intel_lp64.so: undefined symbol: mkl_blas_dsyrk # https://software.intel.com/en-us/articles/symbol-lookup-error-when-linking-intel-mkl-with-gcc-on-ubuntu if(LINUX) - set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-as-needed ${CMAKE_SHARED_LINKER_FLAGS}") + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--no-as-needed") endif() if(MSVC) diff --git a/Dockerfile b/Dockerfile index 70da5dbaf424..cbaa85597ad9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -59,6 +59,7 @@ RUN /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -y python=${PYTHON_VERS RUN /opt/conda/bin/pip install torchelastic FROM ${BASE_IMAGE} as official +ARG PYTORCH_VERSION LABEL com.nvidia.volumes.needed="nvidia_driver" RUN --mount=type=cache,id=apt-final,target=/var/cache/apt \ apt-get update && apt-get install -y --no-install-recommends \ @@ -71,6 +72,7 @@ ENV PATH /opt/conda/bin:$PATH ENV NVIDIA_VISIBLE_DEVICES all ENV NVIDIA_DRIVER_CAPABILITIES compute,utility ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 +ENV PYTORCH_VERSION ${PYTORCH_VERSION} WORKDIR /workspace FROM official as dev diff --git a/android/pytorch_android/generate_test_torchscripts.py b/android/pytorch_android/generate_test_torchscripts.py index 6384d588e9aa..8b41fefc246e 100644 --- a/android/pytorch_android/generate_test_torchscripts.py +++ b/android/pytorch_android/generate_test_torchscripts.py @@ -20,92 +20,77 @@ def forward(self, input): return None @torch.jit.script_method - def eqBool(self, input): - # type: (bool) -> bool + def eqBool(self, input: bool) -> bool: return input @torch.jit.script_method - def eqInt(self, input): - # type: (int) -> int + def eqInt(self, input: int) -> int: return input @torch.jit.script_method - def eqFloat(self, input): - # type: (float) -> float + def eqFloat(self, input: float) -> float: return input @torch.jit.script_method - def eqStr(self, input): - # type: (str) -> str + def eqStr(self, input: str) -> str: return input @torch.jit.script_method - def eqTensor(self, input): - # type: (Tensor) -> Tensor + def eqTensor(self, input: Tensor) -> Tensor: return input @torch.jit.script_method - def eqDictStrKeyIntValue(self, input): - # type: (Dict[str, int]) -> Dict[str, int] + def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]: return input @torch.jit.script_method - def eqDictIntKeyIntValue(self, input): - # type: (Dict[int, int]) -> Dict[int, int] + def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]: return input @torch.jit.script_method - def eqDictFloatKeyIntValue(self, input): - # type: (Dict[float, int]) -> Dict[float, int] + def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, 
int]: return input @torch.jit.script_method - def listIntSumReturnTuple(self, input): - # type: (List[int]) -> Tuple[List[int], int] + def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]: sum = 0 for x in input: sum += x return (input, sum) @torch.jit.script_method - def listBoolConjunction(self, input): - # type: (List[bool]) -> bool + def listBoolConjunction(self, input: List[bool]) -> bool: res = True for x in input: res = res and x return res @torch.jit.script_method - def listBoolDisjunction(self, input): - # type: (List[bool]) -> bool + def listBoolDisjunction(self, input: List[bool]) -> bool: res = False for x in input: res = res or x return res @torch.jit.script_method - def tupleIntSumReturnTuple(self, input): - # type: (Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int] + def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]: sum = 0 for x in input: sum += x return (input, sum) @torch.jit.script_method - def optionalIntIsNone(self, input): - # type: (Optional[int]) -> bool + def optionalIntIsNone(self, input: Optional[int]) -> bool: return input is None @torch.jit.script_method - def intEq0None(self, input): - # type: (int) -> Optional[int] + def intEq0None(self, input: int) -> Optional[int]: if input == 0: return None return input @torch.jit.script_method - def str3Concat(self, input): - # type: (str) -> str + def str3Concat(self, input: str) -> str: return input + input + input @torch.jit.script_method @@ -113,8 +98,7 @@ def newEmptyShapeWithItem(self, input): return torch.tensor([int(input.item())])[0] @torch.jit.script_method - def testAliasWithOffset(self): - # type: () -> List[Tensor] + def testAliasWithOffset(self) -> List[Tensor]: x = torch.tensor([100, 200]) a = [x[0], x[1]] return a @@ -128,8 +112,7 @@ def testNonContiguous(self): return x @torch.jit.script_method - def conv2d(self, x, w, toChannelsLast): - # type: (Tensor, Tensor, bool) -> Tensor + def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: r = torch.nn.functional.conv2d(x, w) if (toChannelsLast): r = r.contiguous(memory_format=torch.channels_last) @@ -138,18 +121,15 @@ def conv2d(self, x, w, toChannelsLast): return r @torch.jit.script_method - def contiguous(self, x): - # type: (Tensor) -> Tensor + def contiguous(self, x: Tensor) -> Tensor: return x.contiguous() @torch.jit.script_method - def contiguousChannelsLast(self, x): - # type: (Tensor) -> Tensor + def contiguousChannelsLast(self, x: Tensor) -> Tensor: return x.contiguous(memory_format=torch.channels_last) @torch.jit.script_method - def contiguousChannelsLast3d(self, x): - # type: (Tensor) -> Tensor + def contiguousChannelsLast3d(self, x: Tensor) -> Tensor: return x.contiguous(memory_format=torch.channels_last_3d) scriptAndSave(Test(), "test.pt") diff --git a/android/pytorch_android/test_asset.jit b/android/pytorch_android/test_asset.jit index 49a41eff36a6..3bd9037da4ee 100644 --- a/android/pytorch_android/test_asset.jit +++ b/android/pytorch_android/test_asset.jit @@ -1,85 +1,69 @@ def forward(self, input): return None -def eqBool(self, input): - # type: (bool) -> bool +def eqBool(self, input: bool) -> bool: return input -def eqInt(self, input): - # type: (int) -> int +def eqInt(self, input: int) -> int: return input -def eqFloat(self, input): - # type: (float) -> float +def eqFloat(self, input: float) -> float: return input -def eqStr(self, input): - # type: (str) -> str +def eqStr(self, input: str) -> str: return input -def eqTensor(self, input): - # type: 
(Tensor) -> Tensor +def eqTensor(self, input: Tensor) -> Tensor: return input -def eqDictStrKeyIntValue(self, input): - # type: (Dict[str, int]) -> Dict[str, int] +def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]: return input -def eqDictIntKeyIntValue(self, input): - # type: (Dict[int, int]) -> Dict[int, int] +def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]: return input -def eqDictFloatKeyIntValue(self, input): - # type: (Dict[float, int]) -> Dict[float, int] +def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]: return input -def listIntSumReturnTuple(self, input): - # type: (List[int]) -> Tuple[List[int], int] +def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]: sum = 0 for x in input: sum += x return (input, sum) -def listBoolConjunction(self, input): - # type: (List[bool]) -> bool +def listBoolConjunction(self, input: List[bool]) -> bool: res = True for x in input: res = res and x return res -def listBoolDisjunction(self, input): - # type: (List[bool]) -> bool +def listBoolDisjunction(self, input: List[bool]) -> bool: res = False for x in input: res = res or x return res -def tupleIntSumReturnTuple(self, input): - # type: (Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int] +def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]: sum = 0 for x in input: sum += x return (input, sum) -def optionalIntIsNone(self, input): - # type: (Optional[int]) -> bool +def optionalIntIsNone(self, input: Optional[int]) -> bool: return input is None -def intEq0None(self, input): - # type: (int) -> Optional[int] +def intEq0None(self, input: int) -> Optional[int]: if input == 0: return None return input -def str3Concat(self, input): - # type: (str) -> str +def str3Concat(self, input: str) -> str: return input + input + input def newEmptyShapeWithItem(self, input): return torch.tensor([int(input.item())])[0] -def testAliasWithOffset(self): - # type: () -> List[Tensor] +def testAliasWithOffset(self) -> List[Tensor]: x = torch.tensor([100, 200]) a = [x[0], x[1]] return a @@ -91,8 +75,7 @@ def testNonContiguous(self): assert x[1] == 300 return x -def conv2d(self, x, w, toChannelsLast): - # type: (Tensor, Tensor, bool) -> Tensor +def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: r = torch.conv2d(x, w) if (toChannelsLast): # memory_format=torch.channels_last @@ -101,16 +84,13 @@ def conv2d(self, x, w, toChannelsLast): r = r.contiguous() return r -def contiguous(self, x): - # type: (Tensor) -> Tensor +def contiguous(self, x: Tensor) -> Tensor: return x.contiguous() -def contiguousChannelsLast(self, x): - # type: (Tensor) -> Tensor +def contiguousChannelsLast(self, x: Tensor) -> Tensor: # memory_format=torch.channels_last return x.contiguous(memory_format=2) -def contiguousChannelsLast3d(self, x): - # type: (Tensor) -> Tensor +def contiguousChannelsLast3d(self, x: Tensor) -> Tensor: # memory_format=torch.channels_last_3d return x.contiguous(memory_format=3) diff --git a/aten/src/ATen/BatchedTensorImpl.cpp b/aten/src/ATen/BatchedTensorImpl.cpp index f295d70c31fd..9dbf9ea78f4b 100644 --- a/aten/src/ATen/BatchedTensorImpl.cpp +++ b/aten/src/ATen/BatchedTensorImpl.cpp @@ -20,14 +20,11 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims) const auto public_dims = value_.dim() - bdims_.size(); const auto value_sizes = value_.sizes(); const auto value_strides = value_.strides(); - sizes_.clear(); - sizes_.reserve(public_dims); - 
strides_.clear(); - strides_.reserve(public_dims); + sizes_and_strides_.resize(public_dims); for (int64_t dim = 0; dim < public_dims; dim++) { auto actual_dim = actualDim(dim, /*wrap_dim=*/false); - sizes_.push_back(value_sizes.at(actual_dim)); - strides_.push_back(value_strides.at(actual_dim)); + sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim); + sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim); } refresh_numel(); refresh_contiguous(); @@ -35,7 +32,7 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims) int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const { if (wrap_dim) { - const auto ndim = sizes_.size(); + const auto ndim = sizes_and_strides_.size(); dim = maybe_wrap_dim(dim, ndim); } auto is_bdim = createBatchDimBitset(bdims_); diff --git a/aten/src/ATen/OpaqueTensorImpl.h b/aten/src/ATen/OpaqueTensorImpl.h index f23a097aabb8..2072f549d011 100644 --- a/aten/src/ATen/OpaqueTensorImpl.h +++ b/aten/src/ATen/OpaqueTensorImpl.h @@ -28,7 +28,7 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl { bool is_non_overlapping_and_dense = true) : TensorImpl(key_set, data_type, device), opaque_handle_(std::move(opaque_handle)) { - sizes_ = sizes.vec(); + sizes_and_strides_.set_sizes(sizes); refresh_numel(); is_non_overlapping_and_dense_ = is_non_overlapping_and_dense; } @@ -86,7 +86,7 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl { const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const override { auto impl = c10::make_intrusive>( - key_set(), dtype(), device(), opaque_handle_, sizes_); + key_set(), dtype(), device(), opaque_handle_, sizes_and_strides_.sizes_arrayref()); copy_tensor_metadata( /*src_opaque_impl=*/this, /*dest_opaque_impl=*/impl.get(), @@ -106,7 +106,7 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl { c10::VariableVersion&& version_counter, bool allow_tensor_metadata_change) const override { auto impl = c10::make_intrusive>( - key_set(), dtype(), device(), opaque_handle_, sizes_); + key_set(), dtype(), device(), opaque_handle_, sizes_and_strides_.sizes_arrayref()); copy_tensor_metadata( /*src_opaque_impl=*/this, /*dest_opaque_impl=*/impl.get(), diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 8d7d4b2ce0f8..98670db11e86 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -70,6 +70,7 @@ void SparseTensorImpl::set_storage_offset(int64_t storage_offset) { } int64_t SparseTensorImpl::dim() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sparse_dim_ + dense_dim_ == TensorImpl::dim()); return sparse_dim_ + dense_dim_; } bool SparseTensorImpl::has_storage() const { diff --git a/aten/src/ATen/SparseTensorImpl.h b/aten/src/ATen/SparseTensorImpl.h index db97b2202a5f..5f502a6eaa54 100644 --- a/aten/src/ATen/SparseTensorImpl.h +++ b/aten/src/ATen/SparseTensorImpl.h @@ -56,7 +56,7 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { // respect to indices and values void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) { TORCH_CHECK(allow_tensor_metadata_change(), "raw_resize_ ", err_msg_tensor_metadata_change_not_allowed); - sizes_ = size.vec(); + sizes_and_strides_.set_sizes(size); sparse_dim_ = sparse_dim; dense_dim_ = dense_dim; refresh_numel(); @@ -126,7 +126,8 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { "shrinking the size of dense dimensions (from ", dense_size_original, " to ", dense_size_new, ") on a non-empty sparse tensor is not 
supported.\n", alt_options_msg); } - if ((!size.equals(sizes_)) || (sparse_dim != sparse_dim_) || (dense_dim != dense_dim_)) { + const bool size_equals_sizes = std::equal(size.begin(), size.end(), sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); + if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) || (dense_dim != dense_dim_)) { auto nnz = values().size(0); std::vector values_size = {nnz}; auto dense_size = size.slice(sparse_dim); @@ -135,7 +136,9 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { indices_.resize_({sparse_dim, nnz}); } - sizes_ = size.vec(); + if (!size_equals_sizes) { + sizes_and_strides_.set_sizes(size); + } sparse_dim_ = sparse_dim; dense_dim_ = dense_dim; refresh_numel(); @@ -146,7 +149,7 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { TORCH_CHECK(allow_tensor_metadata_change(), "resize_and_clear_ ", err_msg_tensor_metadata_change_not_allowed); TORCH_CHECK(sparse_dim + dense_dim == static_cast(size.size()), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size()); - sizes_ = size.vec(); + sizes_and_strides_.set_sizes(size); sparse_dim_ = sparse_dim; dense_dim_ = dense_dim; diff --git a/aten/src/ATen/core/List_inl.h b/aten/src/ATen/core/List_inl.h index ab3ddae55770..c517dac20542 100644 --- a/aten/src/ATen/core/List_inl.h +++ b/aten/src/ATen/core/List_inl.h @@ -101,13 +101,23 @@ namespace detail { return std::move(element).template to(); } template - IValue list_element_from(const T& element) { - return element; - } - template - IValue list_element_from(T&& element) { - return std::move(element); - } + struct ListElementFrom { + static IValue from(const T& element) { + return element; + } + static IValue from(T&& element) { + return std::move(element); + } + }; + template<> + struct ListElementFrom { + static const IValue& from(const IValue& element) { + return element; + } + static IValue&& from(IValue&& element) { + return std::move(element); + } + }; } namespace impl { @@ -119,13 +129,13 @@ ListElementReference::operator T() const { template ListElementReference& ListElementReference::operator=(T&& new_value) && { - *iterator_ = c10::detail::list_element_from(std::move(new_value)); + *iterator_ = c10::detail::ListElementFrom::from(std::move(new_value)); return *this; } template ListElementReference& ListElementReference::operator=(const T& new_value) && { - *iterator_ = c10::detail::list_element_from(std::move(new_value)); + *iterator_ = c10::detail::ListElementFrom::from(std::move(new_value)); return *this; } @@ -154,12 +164,12 @@ inline bool operator==(const T& lhs, const ListElementReference& rh template void List::set(size_type pos, const value_type& value) const { - impl_->list.at(pos) = c10::detail::list_element_from(value); + impl_->list.at(pos) = c10::detail::ListElementFrom::from(value); } template void List::set(size_type pos, value_type&& value) const { - impl_->list.at(pos) = c10::detail::list_element_from(std::move(value)); + impl_->list.at(pos) = c10::detail::ListElementFrom::from(std::move(value)); } template @@ -178,7 +188,7 @@ typename List::value_type List::extract(size_type pos) const { auto& elem = impl_->list.at(pos); auto result = c10::detail::list_element_to(std::move(elem)); // Reset the list element to a T() instead of None to keep it correctly typed - elem = c10::detail::list_element_from(T{}); + elem = c10::detail::ListElementFrom::from(T{}); return result; } @@ -214,12 +224,12 @@ void List::clear() const { template typename List::iterator 
List::insert(iterator pos, const T& value) const { - return iterator { impl_->list.insert(pos.iterator_, c10::detail::list_element_from(value)) }; + return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(value)) }; } template typename List::iterator List::insert(iterator pos, T&& value) const { - return iterator { impl_->list.insert(pos.iterator_, c10::detail::list_element_from(std::move(value))) }; + return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(std::move(value))) }; } template @@ -231,12 +241,12 @@ typename List::iterator List::emplace(iterator pos, Args&&... value) const template void List::push_back(const T& value) const { - impl_->list.push_back(c10::detail::list_element_from(value)); + impl_->list.push_back(c10::detail::ListElementFrom::from(value)); } template void List::push_back(T&& value) const { - impl_->list.push_back(c10::detail::list_element_from(std::move(value))); + impl_->list.push_back(c10::detail::ListElementFrom::from(std::move(value))); } template diff --git a/aten/src/ATen/core/boxing/KernelFunction.cpp b/aten/src/ATen/core/boxing/KernelFunction.cpp index 58c35557018c..260343ac3180 100644 --- a/aten/src/ATen/core/boxing/KernelFunction.cpp +++ b/aten/src/ATen/core/boxing/KernelFunction.cpp @@ -24,7 +24,7 @@ void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle& op, S op.operator_name(), " has kernels registered to both Math and a backend mapped to AutogradOther. " "This makes the backend kernel unreachable (see Note [Ambiguity in AutogradOther kernel]). " "If it's intended to override Math kernel behavior, please open an issue to request a dedicated " - "Autograd dispatch key for the backend."); + "Autograd dispatch key for the backend.", "\nCanonical state\n~~~~~~~~~~~\n", op.dumpState(), "\n\n"); } void named_not_supported_kernel(OperatorKernel*, const OperatorHandle& op, Stack*) { diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index b96f4b834989..c14b86f995e5 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -406,7 +406,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { // This accessor should only be used if we know that the future is // completed() with no error. - const IValue& constValue() { + const IValue& constValue() const { std::unique_lock lock(mutex_); AT_ASSERT(completed()); AT_ASSERT(!eptr_); @@ -451,7 +451,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { } // Tries to retrieve the error message from std::exception_ptr. - std::string tryRetrieveErrorMessage() { + std::string tryRetrieveErrorMessage() const { TORCH_CHECK(hasError(), "No error present on the future."); std::unique_lock lock(mutex_); return tryRetrieveErrorMessageInternal(eptr_); @@ -543,7 +543,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { } // Tries to retrieve the error message from std::exception_ptr. 
- std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) { + std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const { try { std::rethrow_exception(eptr); } catch (const std::exception& e) { diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index bd59fe7d28b9..146275e1faa6 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1011,13 +1011,13 @@ std::tuple _linalg_qr_helper_cpu(const Tensor& self, std::string std::tuple linalg_qr(const Tensor& self, std::string mode) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + "qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); return at::_linalg_qr_helper(self, mode); } std::tuple linalg_qr_out(Tensor& Q, Tensor& R, const Tensor& self, std::string mode) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + "qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); Tensor Q_tmp, R_tmp; std::tie(Q_tmp, R_tmp) = at::_linalg_qr_helper(self, mode); at::native::resize_output(Q, Q_tmp.sizes()); @@ -1411,7 +1411,7 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some if (compute_uv) { if (some) { - VT_working_copy = VT_working_copy.narrow(-1, 0, k); + VT_working_copy = VT_working_copy.narrow(-2, 0, k); } } else { VT_working_copy.zero_(); @@ -1421,24 +1421,71 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some U_working_copy.zero_(); VT_working_copy.zero_(); } + // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. + VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } std::tuple svd(const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); return at::_svd_helper(self, some, compute_uv); } -std::tuple svd_out(Tensor& U, Tensor& S, Tensor& VT, +std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); - Tensor U_tmp, S_tmp, VT_tmp; - std::tie(U_tmp, S_tmp, VT_tmp) = at::_svd_helper(self, some, compute_uv); + "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + Tensor U_tmp, S_tmp, V_tmp; + std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv); U.resize_as_(U_tmp).copy_(U_tmp); S.resize_as_(S_tmp).copy_(S_tmp); - VT.resize_as_(VT_tmp).copy_(VT_tmp); + V.resize_as_(V_tmp).copy_(V_tmp); + return std::tuple(U, S, V); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/* torch.linalg.svd, implemented in terms of torch.svd. There are two main + differences: + + 1. the 2nd parameter is bool some=True, which is effectively the opposite + of full_matrices=True + + 2. svd returns V, while linalg.svd returns VT.
To accommodate the + difference, we transpose() V upon return +*/ + +std::tuple linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) { + TORCH_CHECK(self.dim() >= 2, + "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + + bool some = !full_matrices; + Tensor U, S, V; + std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv); + if (compute_uv) { + Tensor VT = V.transpose(-2, -1); + return std::make_tuple(U, S, VT); + } else { + Tensor empty_U = at::empty({0}, self.options()); + Tensor empty_VT = at::empty({0}, self.options()); + return std::make_tuple(empty_U, S, empty_VT); + } +} + +static void svd_resize_and_copy(const char *name, const Tensor& src, Tensor &dst) { + TORCH_CHECK(src.device() == dst.device(), "svd output tensor ", name, " is on the wrong device: expected ", src.device(), " got ", dst.device()); + at::native::resize_output(dst, src.sizes()); + dst.copy_(src); +} + +std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tensor& VT, + const Tensor& self, bool full_matrices, bool compute_uv) { + Tensor U_tmp, S_tmp, VT_tmp; + std::tie(U_tmp, S_tmp, VT_tmp) = at::linalg_svd(self, full_matrices, compute_uv); + svd_resize_and_copy("U", U_tmp, U); + svd_resize_and_copy("S", S_tmp, S); + svd_resize_and_copy("V", VT_tmp, VT); return std::tuple(U, S, VT); } diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 607e201ebe8d..4322c4c79222 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -206,7 +206,8 @@ static inline std::tuple _parse_qr_mode(std::string mode) { compute_q = false; reduced = true; // this is actually irrelevant in this mode } else { - TORCH_CHECK(false, "Unrecognized mode '", mode, "'"); + TORCH_CHECK(false, "qr received unrecognized mode '", mode, + "' but expected one of 'reduced' (default), 'r', or 'complete'"); } return std::make_tuple(compute_q, reduced); } @@ -261,18 +262,21 @@ static inline std::tuple _create_U_S_VT(const Tensor& in U_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU)); } + // VT should be a column-major or a batch of column-major matrices sizes[input.dim() - 2] = n; sizes[input.dim() - 1] = n; - // VT should be a row-major or a batch of row-major matrices + strides = at::detail::defaultStrides(sizes); + strides[input.dim() - 1] = n; + strides[input.dim() - 2] = 1; Tensor VT_empty; if (!input.is_cuda()) { - VT_empty = at::empty(sizes, input.options()); + VT_empty = at::empty_strided(sizes, strides, input.options()); } else { // NB: VT_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd // (which is the driver routine for the divide and conquer SVD operation) // takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that // moves the inputs between devices internally. - VT_empty = at::empty(sizes, input.options().device(at::kCPU)); + VT_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU)); } sizes.pop_back(); diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 1ac4250a9d54..ea4a54c13196 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -118,7 +118,7 @@ void batch_norm_cpu_inference_channels_last(Tensor& output, const Tensor& input, // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c) // No need to use parallel_for as this function is supposed to be // memory-limited. 
- // Keep the loop struture simple to make sure compiler vectorization kicks in. + // Keep the loop structure simple to make sure compiler vectorization kicks in. if (n_channel != 1) { for (int64_t n = 0; n < n_batch; ++n) { for (int64_t i = 0; i < image_size; ++i) { diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp index 4d1601d3e6a0..1ed105fd0175 100644 --- a/aten/src/ATen/native/Pow.cpp +++ b/aten/src/ATen/native/Pow.cpp @@ -28,7 +28,7 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) { auto common_dtype = at::result_type(base, exp); TORCH_CHECK(at::can_cast(common_dtype, result.scalar_type()), - "result type ", common_dtype, "can't be cast to the desired output type ", + "result type ", common_dtype, " can't be cast to the desired output type ", result.scalar_type()); if (exp.equal(0.0)) { @@ -83,42 +83,68 @@ Tensor& float_power_out(Tensor& result, const Tensor& base, const Tensor& exp) { auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble; TORCH_CHECK(result.scalar_type() == dtype, - "output type ", result.scalar_type(), "is not the desired output type ", dtype); + "the output given to float_power has dtype ", result.scalar_type(), + " but the operation's result requires dtype ", dtype); return at::pow_out(result, base.to(dtype), exp.to(dtype)); } Tensor& float_power_out(Tensor& result, const Tensor& base, Scalar exp) { - return at::float_power_out(result, base, c10::scalar_to_tensor(exp, base.device())); + auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble; + TORCH_CHECK(result.scalar_type() == dtype, + "the output given to float_power has dtype ", result.scalar_type(), + " but the operation's result requires dtype ", dtype); + + // Note: need the casts inside the ternary because conversion functions return e.g. c10::complex, + // which causes a complex scalar to always be returned. + exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble()); + return at::pow_out(result, base.to(dtype), exp); } Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) { - return at::float_power_out(result, c10::scalar_to_tensor(base, exp.device()), exp); -} + auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble; + TORCH_CHECK(result.scalar_type() == dtype, + "the output given to float_power has dtype ", result.scalar_type(), + " but the operation's result requires dtype ", dtype); -Tensor float_power(const Tensor& base, const Tensor& exp) { - auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble; - return at::pow(base.to(dtype), exp.to(dtype)); + base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble()); + return at::pow_out(result, base, exp.to(dtype)); } Tensor float_power(const Tensor& base, Scalar exp) { - return at::float_power(base, c10::scalar_to_tensor(exp, base.device())); + auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble; + exp = (dtype == at::kComplexDouble) ? 
Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble()); + return at::pow(base.to(dtype), exp); } Tensor float_power(Scalar base, const Tensor& exp) { - return at::float_power(c10::scalar_to_tensor(base, exp.device()), exp); + auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble; + base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble()); + return at::pow(base, exp.to(dtype)); +} + +Tensor float_power(const Tensor& base, const Tensor& exp) { + auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble; + return at::pow(base.to(dtype), exp.to(dtype)); } Tensor& float_power_(Tensor& base, const Tensor& exp) { auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble; TORCH_CHECK(base.scalar_type() == dtype, - "self tensor type ", base.scalar_type(), "is not the desired type ", dtype); + "the base given to float_power_ has dtype ", base.scalar_type(), + " but the operation's result requires dtype ", dtype); return base.pow_(exp.to(dtype)); } Tensor& float_power_(Tensor& base, Scalar exp) { - return base.float_power_(c10::scalar_to_tensor(exp, base.device())); + auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble; + TORCH_CHECK(base.scalar_type() == dtype, + "the base given to float_power_ has dtype ", base.scalar_type(), + " but the operation's result requires dtype ", dtype); + + exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble()); + return base.pow_(exp); } } // namespace native diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index e4b0a1cb19b7..fd27b3e7efe5 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -740,6 +740,12 @@ Tensor norm(const Tensor& self, Scalar p) { return at::native::_norm(self, p); } +// Note [all, any : uint8 compatibility]: +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// For NumPy compatibility, `all` and `any` return +// a Tensor of dtype `bool`. However, for compatibility reasons, +// for `uint8` they return a Tensor of the same dtype `uint8`. +// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561 inline Tensor & _all(Tensor & result, TensorIterator & iter) { if (iter.numel() == 0) { result.fill_(1); @@ -756,14 +762,40 @@ Tensor all(const Tensor& self) { TORCH_CHECK(self.layout() == Layout::Strided, "all only supports strided layout, got: ", self.layout()); - Tensor result = at::empty({0}, self.options()); - auto iter = make_reduction( - "all", result, self, {}, false, self.scalar_type()); + // Refer [all, any : uint8 compatibility] + Tensor result; + ScalarType out_dtype; + if (self.scalar_type() == ScalarType::Byte){ + result = at::empty({0}, self.options()); + out_dtype = self.scalar_type(); + } else { + result = at::empty({0}, self.options().dtype(kBool)); + out_dtype = ScalarType::Bool; + } + + if (self.is_cuda()) { + // As CUDA supports dynamic type casting, we use this overload of + // `make_reduction`, which doesn't cast input to the result type i.e. kBool., + // otherwise we use the overload below which casts the input to kBool (which is + // an extra operation).
+ auto iter = make_reduction( + "all", result, self, {}, false, self.scalar_type(), out_dtype); + return _all(result, iter); + } + auto iter = + make_reduction("all", result, self, {}, false, /*out_dtype=*/out_dtype); return _all(result, iter); } Tensor all(const Tensor& self, int64_t dim, bool keepdim) { - Tensor result = at::empty({0}, self.options()); + // Refer [all, any : uint8 compatibility] + Tensor result; + if (self.scalar_type() == ScalarType::Byte){ + result = at::empty({0}, self.options()); + } else { + result = at::empty({0}, self.options().dtype(kBool)); + } + return at::native::all_out(result, self, dim, keepdim); } @@ -772,13 +804,26 @@ Tensor &all_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) { "all only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "all only supports strided layout, got: ", self.layout()); + // Refer [all, any : uint8 compatibility] + TORCH_CHECK(result.scalar_type() == ScalarType::Bool || result.scalar_type() == ScalarType::Byte, + "all only supports bool tensor for result, got: ", result.scalar_type()); + auto out_dtype = result.scalar_type(); dim = maybe_wrap_dim(dim, self.dim()); if (_dimreduce_return_trivial(result, self, 1, dim, keepdim)) { return result; } else { - auto iter = make_reduction( - "all", result, self, dim, keepdim, self.scalar_type()); + if (self.is_cuda()) { + // As CUDA supports dynamic type casting, we use this overload of + // `make_reduction`, which doesn't cast input to the result type i.e. kBool., + // otherwise we use the overload below which casts the input to kBool (which is + // an extra operation). + auto iter = make_reduction( + "all", result, self, dim, keepdim, self.scalar_type(), out_dtype); + return _all(result, iter); + } + auto iter = + make_reduction("all", result, self, dim, keepdim, /*out_dtype=*/out_dtype); return _all(result, iter); } } @@ -798,15 +843,41 @@ Tensor any(const Tensor& self) { "any only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided || self.layout() == Layout::Sparse, "any only supports strided AND sparse layout, got: ", self.layout()); + + // Refer [all, any : uint8 compatibility] + Tensor result; + ScalarType out_dtype; + if (self.scalar_type() == ScalarType::Byte){ + result = at::empty({0}, self.options()); + out_dtype = self.scalar_type(); + } else { + result = at::empty({0}, self.options().dtype(kBool)); + out_dtype = ScalarType::Bool; + } - Tensor result = at::empty({0}, self.options()); - auto iter = make_reduction( - "any", result, self, {}, false, self.scalar_type()); + if (self.is_cuda()) { + // As CUDA supports dynamic type casting, we use this overload of + // `make_reduction`, which doesn't cast input to the result type i.e. kBool., + // otherwise we use the overload below which casts the input to kBool (which is + // an extra operation). 
+ auto iter = make_reduction( + "any", result, self, {}, false, self.scalar_type(), out_dtype); + return _any(result, iter); + } + auto iter = + make_reduction("any", result, self, {}, false, /*out_dtype=*/out_dtype); return _any(result, iter); } Tensor any(const Tensor& self, int64_t dim, bool keepdim) { - Tensor result = at::empty({0}, self.options()); + // Refer [all, any : uint8 compatibility] + Tensor result; + if (self.scalar_type() == ScalarType::Byte){ + result = at::empty({0}, self.options()); + } else { + result = at::empty({0}, self.options().dtype(kBool)); + } + return at::native::any_out(result, self, dim, keepdim); } @@ -815,13 +886,26 @@ Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) { "any only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "any only supports strided layout, got: ", self.layout()); + // Refer [all, any : uint8 compatibility] + TORCH_CHECK(result.scalar_type() == ScalarType::Bool || result.scalar_type() == ScalarType::Byte, + "any only supports bool tensor for result, got: ", result.scalar_type()); + auto out_dtype = result.scalar_type(); dim = maybe_wrap_dim(dim, self.dim()); if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) { return result; } else { - auto iter = make_reduction( - "any", result, self, dim, keepdim, self.scalar_type()); + if (self.is_cuda()) { + // As CUDA supports dynamic type casting, we use this overload of + // `make_reduction`, which doesn't cast input to the result type i.e. kBool., + // otherwise we use the overload below which casts the input to kBool (which is + // an extra operation). + auto iter = make_reduction( + "any", result, self, dim, keepdim, self.scalar_type(), out_dtype); + return _any(result, iter); + } + auto iter = + make_reduction("any", result, self, dim, keepdim, /*out_dtype=*/out_dtype); return _any(result, iter); } } diff --git a/aten/src/ATen/native/SharedReduceOps.h b/aten/src/ATen/native/SharedReduceOps.h index 4106a90c0729..e25b943d13a8 100644 --- a/aten/src/ATen/native/SharedReduceOps.h +++ b/aten/src/ATen/native/SharedReduceOps.h @@ -386,56 +386,6 @@ struct NanSumOps { #endif }; -template -struct AndOps { - inline C10_DEVICE acc_t reduce(acc_t a, acc_t b, int64_t /*idx*/) const { - return static_cast(a) && static_cast(b); - } - - inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { - return static_cast(a) && static_cast(b); - } - - inline C10_DEVICE acc_t project(acc_t a) const { - return a; - } - - static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { - return acc; - } - -#if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); - } -#endif -}; - -template -struct OrOps { - inline C10_DEVICE acc_t reduce(acc_t a, acc_t b, int64_t /*idx*/) const { - return static_cast(a) || static_cast(b); - } - - inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { - return static_cast(a) || static_cast(b); - } - - inline C10_DEVICE acc_t project(acc_t a) const { - return a; - } - - static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { - return acc; - } - -#if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); - } -#endif -}; - namespace detail { template diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index 
355fcad9010d..d773de927efb 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -117,4 +117,20 @@ Tensor to_mkldnn_backward(const Tensor& grad, const Tensor& input_) { return grad.to_dense(input_.scalar_type()); } +Tensor view_dtype(const Tensor& self, ScalarType dtype) { + if (self.scalar_type() == dtype) { + return self; + } + auto type_meta = c10::scalarTypeToTypeMeta(dtype); + TORCH_CHECK(self.element_size() == type_meta.itemsize(), + "Viewing a tensor as a new dtype with a different number of bytes per element is not supported."); + Storage storage = self.storage(); + auto new_tensor = detail::make_tensor( + std::move(storage), self.key_set(), type_meta); + auto* impl = new_tensor.unsafeGetTensorImpl(); + impl->set_storage_offset(self.storage_offset()); + impl->set_sizes_and_strides(self.sizes(), self.strides()); + return new_tensor; +} + }} // namespace at::native diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 0f6da7e4292a..c636d5d94a2f 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -256,8 +256,8 @@ Tensor& ceil_out(Tensor& result, const Tensor& self) { Tensor ceil(const Tensor& self) { return unary_op_impl(self, at::ceil_out); } Tensor& ceil_(Tensor& self) { return unary_op_impl_(self, at::ceil_out); } -Tensor& exp_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, exp_stub); } -Tensor exp(const Tensor& self) { return unary_op_impl(self, at::exp_out); } +Tensor& exp_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, exp_stub); } +Tensor exp(const Tensor& self) { return unary_op_impl_float(self, exp_stub); } Tensor& exp_(Tensor& self) { return unary_op_impl_(self, at::exp_out); } Tensor& exp2_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, exp2_stub); } diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 32033abcd4e2..14f3d4a1fc21 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -256,11 +256,25 @@ static void norm_kernel_tensor_iterator_impl( } static void and_kernel_impl(TensorIterator& iter) { - if (c10::isIntegralType(iter.dtype(), /*includeBool=*/true)) { + if (iter.dtype() == ScalarType::Byte) { + // Refer [all, any : uint8 compatibility] binary_kernel_reduce_vec( iter, [=](uint8_t a, uint8_t b) -> uint8_t { return (a && b) ? 1 : 0; }, [=](Vec256 a, Vec256 b) { + Vec256 c = Vec256(); + + for (decltype(c.size()) i = 0; i != Vec256::size(); i++) { + c[i] = (a[i] && b[i]) ? 1 : 0; + } + return c; + }, + /*ident=*/true); + } else { + binary_kernel_reduce_vec( + iter, + [=](bool a, bool b) -> bool { return a && b; }, + [=](Vec256 a, Vec256 b) { // Adding the implementation here instead of in vec256_base to avoid // return value inconsistency. Other comparison operators in // vec256_base return -1/0 (all bit 1 / all bit 0) as true/false to @@ -271,39 +285,45 @@ static void and_kernel_impl(TensorIterator& iter) { // // In this method, users would expect, e.g., all(), to return 1/0 as // true/false. - Vec256 c = Vec256(); - for (int i = 0; i != Vec256::size(); i++) { - c[i] = (a[i] && b[i]) ? 
1 : 0; + Vec256 c = Vec256(); + + for (decltype(c.size()) i = 0; i != Vec256::size(); i++) { + c[i] = a[i] && b[i]; } return c; }, /*ident=*/true); - } else { - AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "and_kernel", [&]() { - binary_kernel_reduce( - iter, AndOps(), static_cast(true)); - }); } } static void or_kernel_impl(TensorIterator& iter) { - if (c10::isIntegralType(iter.dtype(), /*includeBool=*/true)) { + if (iter.dtype() == ScalarType::Byte) { + // Refer [all, any : uint8 compatibility] binary_kernel_reduce_vec( iter, [=](uint8_t a, uint8_t b) -> uint8_t { return (a || b) ? 1 : 0; }, [=](Vec256 a, Vec256 b) { Vec256 c = Vec256(); - for (int i = 0; i != Vec256::size(); i++) { + + for (decltype(c.size()) i = 0; i != Vec256::size(); i++) { c[i] = (a[i] || b[i]) ? 1 : 0; } return c; }, /*ident=*/false); } else { - AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "or_kernel", [&]() { - binary_kernel_reduce( - iter, OrOps(), static_cast(false)); - }); + binary_kernel_reduce_vec( + iter, + [=](bool a, bool b) -> bool { return a || b; }, + [=](Vec256 a, Vec256 b) { + Vec256 c = Vec256(); + + for (decltype(c.size()) i = 0; i != Vec256::size(); i++) { + c[i] = a[i] || b[i]; + } + return c; + }, + /*ident=*/false); } } diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 3fbd693d17b1..379847d76ff4 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -2194,7 +2194,7 @@ std::tuple _svd_helper_cuda(const Tensor& self, bool som if (compute_uv) { if (some) { - VT_working_copy = VT_working_copy.narrow(-1, 0, k); + VT_working_copy = VT_working_copy.narrow(-2, 0, k); } } else { VT_working_copy.zero_(); @@ -2205,6 +2205,8 @@ std::tuple _svd_helper_cuda(const Tensor& self, bool som S_working_copy = same_stride_to(S_working_copy, S_working_copy.options().device(self.device())); VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_(); } + // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. + VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index 186099dfde50..6bf4e0f32f13 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -120,15 +120,17 @@ std::tuple batch_norm_gather_stats_cuda(const Tensor& self, cons const Tensor& running_var, double momentum, double epsilon, int64_t count) { std::vector counts(mean.size(0), count); Tensor counts_ = at::from_blob((void*)counts.data(), {(int64_t)counts.size()}, self.options().dtype(at::kLong).device(at::kCPU)); - counts_ = counts_.to(self.device()).to(running_mean.dtype()); + counts_ = counts_.to(self.device()).to(running_mean.defined() ? 
running_mean.dtype() : self.dtype()); return batch_norm_gather_stats_with_counts_cuda(self, mean, invstd, running_mean, running_var, momentum, epsilon, counts_); } -std::tuple batch_norm_gather_stats_with_counts_cuda(const Tensor& self, const Tensor& mean, const Tensor& invstd, const Tensor& running_mean, - const Tensor& running_var, double momentum, double epsilon, const Tensor& counts) { - - return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, running_mean.scalar_type(), "batch_norm_update_stats_cuda", [&] { +std::tuple batch_norm_gather_stats_with_counts_cuda( + const Tensor& self, const Tensor& mean, const Tensor& invstd, const Tensor& running_mean /* optional */, + const Tensor& running_var /* optional */, double momentum, double epsilon, const Tensor& counts) { + + auto scalar_type = running_mean.defined() ? running_mean.scalar_type() : self.scalar_type(); + return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "batch_norm_update_stats_cuda", [&] { using accscalar_t = at::acc_type; if (cuda::detail::canUse32BitIndexMath(self)) { return batch_norm_gather_stats_cuda_template(mean, invstd, running_mean, running_var, momentum, epsilon, counts); diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index 8355ac004308..a0445f129192 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -104,7 +104,7 @@ static __device__ __forceinline__ Float2 warpSum(Float2 #include #include +#include namespace at { namespace native { void and_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "and_kernel", [&]() { - gpu_reduce_kernel( - iter, - func_wrapper([] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return static_cast(static_cast(a) && static_cast(b)); - }), - static_cast(true)); - }); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBFloat16, kBool, iter.common_dtype(), "and_cuda", [&]() { + gpu_reduce_kernel( + iter, + func_wrapper([] GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return (static_cast(a) && static_cast(b)); + }), + true); + }); } void or_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "or_kernel", [&]() { - gpu_reduce_kernel( - iter, - func_wrapper([] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return static_cast(static_cast(a) || static_cast(b)); - }), - static_cast(false)); - }); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBFloat16, kBool, iter.common_dtype(), "or_cuda", [&]() { + gpu_reduce_kernel( + iter, + func_wrapper([] GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return (static_cast(a) || static_cast(b)); + }), + false); + }); } REGISTER_DISPATCH(and_stub, &and_kernel_cuda); diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index f5e1a4e85a04..e727335aaf17 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -33,7 +33,7 @@ void bitwise_not_kernel_cuda(TensorIterator& iter) { } void exp_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exp_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "exp_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::exp(a); }); diff --git 
a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 1dc9d5cba945..8e1f254da9f8 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -722,6 +722,11 @@ namespace { (tensors.seq_length >=10 && bsize <=32)); } } else if (prop->major >= 8) { + if (prop->minor == 6) { + // Excludes sm_86 GPU devices from using persistent RNN. + // This is because there are some edge cases that will throw exceptions with cudnn 8.0.5 on Nvidia A40 GPUs. + return false; + } // Based on tests by Vasily Volkov and xwang233. Vasily only tried bsize <= 128, // so conservatively enable persistence for bsize <= 128 only. // TODO: Run more tests for bsize > 128. diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b474d435398c..52bee793b4c8 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4899,6 +4899,19 @@ CPU, CUDA, QuantizedCPU, QuantizedCUDA: view MkldnnCPU: mkldnn_view +# Warning: If you want to change the name or overload name of this +# operator, you might also want to change the `isBlockListedSchema` +# function in `torch/csrc/jit/frontend/schema_matching.cpp`. +# The name and overload name of this operator are hardcoded in that +# function in order to work around a bug: +# https://github.com/pytorch/pytorch/issues/47964 +- func: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) + use_c10_dispatcher: full + variants: method + device_guard: False + dispatch: + DefaultBackend: view_dtype + - func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) variants: method dispatch: @@ -5820,14 +5833,14 @@ - func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: - DefaultBackend: svd_out + Math: svd_out - func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) variants: method, function dispatch: - DefaultBackend: svd + Math: svd -- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor, Tensor, Tensor) +- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V) variants: function dispatch: CPU: _svd_helper_cpu @@ -8962,6 +8975,15 @@ python_module: linalg variants: function +- func: linalg_svd.U(Tensor self, bool full_matrices=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + +- func: linalg_svd(Tensor self, bool full_matrices=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) + python_module: linalg + use_c10_dispatcher: full + variants: function + - func: linalg_cond(Tensor self, Scalar?
p=None) -> Tensor python_module: linalg variants: function diff --git a/aten/src/ATen/native/vulkan/api/Cache.h b/aten/src/ATen/native/vulkan/api/Cache.h index b224adbbeeda..a93385088277 100644 --- a/aten/src/ATen/native/vulkan/api/Cache.h +++ b/aten/src/ATen/native/vulkan/api/Cache.h @@ -62,6 +62,10 @@ class Cache final { Factory factory_; }; +// +// Impl +// + template inline Cache::Cache(Factory factory) : factory_(std::move(factory)) { diff --git a/aten/src/ATen/native/vulkan/api/Command.cpp b/aten/src/ATen/native/vulkan/api/Command.cpp index 5aa3586d4683..247b51fa5395 100644 --- a/aten/src/ATen/native/vulkan/api/Command.cpp +++ b/aten/src/ATen/native/vulkan/api/Command.cpp @@ -76,6 +76,25 @@ Command::Buffer::Buffer(const VkCommandBuffer command_buffer) "Invalid Vulkan command buffer!"); } +Command::Buffer::Buffer(Buffer&& buffer) + : command_buffer_(std::move(buffer.command_buffer_)), + bound_(std::move(buffer.bound_)), + barriers_(std::move(buffer.barriers_)) { + buffer.invalidate(); +} + +Command::Buffer& Command::Buffer::operator=(Buffer&& buffer) { + if (&buffer != this) { + command_buffer_ = std::move(buffer.command_buffer_); + bound_ = std::move(buffer.bound_); + barriers_ = std::move(buffer.barriers_); + + buffer.invalidate(); + }; + + return *this; +} + void Command::Buffer::Buffer::begin() { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( command_buffer_, @@ -107,69 +126,6 @@ void Command::Buffer::Buffer::end() { VK_CHECK(vkEndCommandBuffer(command_buffer_)); } -void Command::Buffer::barrier() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_buffer_, - "This command buffer is in an invalid state! " - "Potential reason: This command buffer is moved from."); - - if (barriers_.stage) { - c10::SmallVector buffer_memory_barriers; - - for (const Resource::Buffer::Barrier& barrier : barriers_.buffers) { - buffer_memory_barriers.push_back({ - VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER, - nullptr, - barrier.memory.src, - barrier.memory.dst, - VK_QUEUE_FAMILY_IGNORED, - VK_QUEUE_FAMILY_IGNORED, - barrier.object.handle, - barrier.object.offset, - barrier.object.range, - }); - } - - c10::SmallVector image_memory_barriers; - - for (const Resource::Image::Barrier& barrier : barriers_.images) { - image_memory_barriers.push_back({ - VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER, - nullptr, - barrier.memory.src, - barrier.memory.dst, - barrier.layout.src, - barrier.layout.dst, - VK_QUEUE_FAMILY_IGNORED, - VK_QUEUE_FAMILY_IGNORED, - barrier.object.handle, - { - VK_IMAGE_ASPECT_COLOR_BIT, - 0u, - VK_REMAINING_MIP_LEVELS, - 0u, - VK_REMAINING_ARRAY_LAYERS, - }, - }); - } - - vkCmdPipelineBarrier( - command_buffer_, - barriers_.stage.src, - barriers_.stage.dst, - 0u, - 0u, - nullptr, - buffer_memory_barriers.size(), - buffer_memory_barriers.data(), - image_memory_barriers.size(), - image_memory_barriers.data()); - } - - // Reset - barriers_.reset(); -} - void Command::Buffer::barrier(const Pipeline::Barrier& barrier) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( command_buffer_, @@ -291,31 +247,86 @@ void Command::Buffer::dispatch( bound_.pipeline.local_work_group.data[2u])); } -void Command::Buffer::submit( - const VkQueue queue, - const Resource::Fence fence) { +void Command::Buffer::barrier() { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( command_buffer_, "This command buffer is in an invalid state! 
" "Potential reason: This command buffer is moved from."); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - queue, - "Invalid Vulkan queue!"); + if (barriers_.stage) { + c10::SmallVector buffer_memory_barriers; - const VkSubmitInfo submit_info{ - VK_STRUCTURE_TYPE_SUBMIT_INFO, - nullptr, - 0u, - nullptr, - nullptr, - 1u, - &command_buffer_, - 0u, - nullptr, - }; + for (const Resource::Buffer::Barrier& barrier : barriers_.buffers) { + buffer_memory_barriers.push_back({ + VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER, + nullptr, + barrier.memory.src, + barrier.memory.dst, + VK_QUEUE_FAMILY_IGNORED, + VK_QUEUE_FAMILY_IGNORED, + barrier.object.handle, + barrier.object.offset, + barrier.object.range, + }); + } + + c10::SmallVector image_memory_barriers; - VK_CHECK(vkQueueSubmit(queue, 1u, &submit_info, fence.handle())); + for (const Resource::Image::Barrier& barrier : barriers_.images) { + image_memory_barriers.push_back({ + VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER, + nullptr, + barrier.memory.src, + barrier.memory.dst, + barrier.layout.src, + barrier.layout.dst, + VK_QUEUE_FAMILY_IGNORED, + VK_QUEUE_FAMILY_IGNORED, + barrier.object.handle, + { + VK_IMAGE_ASPECT_COLOR_BIT, + 0u, + VK_REMAINING_MIP_LEVELS, + 0u, + VK_REMAINING_ARRAY_LAYERS, + }, + }); + } + + vkCmdPipelineBarrier( + command_buffer_, + barriers_.stage.src, + barriers_.stage.dst, + 0u, + 0u, + nullptr, + buffer_memory_barriers.size(), + buffer_memory_barriers.data(), + image_memory_barriers.size(), + image_memory_barriers.data()); + } + + // Reset + barriers_.reset(); +} + +void Command::Buffer::invalidate() { + command_buffer_ = VK_NULL_HANDLE; +} + +inline void Command::Buffer::Bound::reset() { + pipeline = {}; + descriptor_set = VK_NULL_HANDLE; +} + +inline Command::Buffer::Barrier::Stage::operator bool() const { + return (0u != src) || (0u != dst); +} + +inline void Command::Buffer::Barrier::reset() { + stage = {}; + buffers.clear(); + images.clear(); } Command::Pool::Pool(const GPU& gpu) @@ -338,8 +349,9 @@ Command::Pool::Pool(const GPU& gpu) Command::Pool::Pool(Pool&& pool) : device_(std::move(pool.device_)), command_pool_(std::move(pool.command_pool_)), - buffer_(std::move(pool.buffer_)) { - pool.device_ = VK_NULL_HANDLE; + buffer_(std::move(pool.buffer_)), + stream_(std::move(pool.stream_)) { + pool.invalidate(); } Command::Pool& Command::Pool::operator=(Pool&& pool) { @@ -347,8 +359,9 @@ Command::Pool& Command::Pool::operator=(Pool&& pool) { device_ = std::move(pool.device_); command_pool_ = std::move(pool.command_pool_); buffer_ = std::move(pool.buffer_); + stream_ = std::move(pool.stream_); - pool.device_ = VK_NULL_HANDLE; + pool.invalidate(); }; return *this; @@ -383,25 +396,109 @@ Command::Buffer Command::Pool::allocate() { Configuration::kQuantum); allocate_command_buffers( - device_, - command_pool_.get(), - buffer_.pool.data() + buffer_.in_use, - Configuration::kQuantum); + device_, + command_pool_.get(), + buffer_.pool.data() + buffer_.in_use, + Configuration::kQuantum); } return Buffer(buffer_.pool[buffer_.in_use++]); } +Command::Buffer& Command::Pool::stream() { + if (!stream_.buffer) { + stream_.buffer = allocate(); + stream_.buffer.begin(); + stream_.counter = 0u; + } + + return stream_.buffer; +} + void Command::Pool::purge() { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( device_ && command_pool_, "This command pool is in an invalid state! " "Potential reason: This command pool is moved from."); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !stream_.buffer, + "Pending command buffer detected. 
Make sure all command buffers are " + "submitted to the queue for execution prior to reclaiming pool memory."); + buffer_.in_use = 0u; VK_CHECK(vkResetCommandPool(device_, command_pool_.get(), 0u)); } +void Command::Pool::submit( + const VkQueue queue, + const c10::ArrayRef buffers, + const Resource::Fence fence) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && command_pool_, + "This command pool is in an invalid state! " + "Potential reason: This command pool is moved from."); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + queue, + "Invalid Vulkan queue!"); + + c10::SmallVector command_buffers; + command_buffers.reserve(buffers.size()); + + for (const Buffer& buffer : buffers) { + VkCommandBuffer command_buffer = buffer.handle(); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + command_buffer, + "Invalid Vulkan command buffer!"); + + // Are we submitting our one and only command stream, or a regular command + // buffer whose scope is manually maintained by the user? Automatically + // maintain state and submission rate if the former. + + if (stream_.buffer.handle() == command_buffer) { + // Hand the stream off to the driver if: + // - The user has implictly signaled interest in the results via a fence. + // - We are over the submission cutoff. We don't want to starve the GPU. + + if (fence || (stream_.counter++ > Configuration::kSubmit)) { + stream_.buffer.end(); + stream_.buffer.invalidate(); + } + // Skip - Accumulate more calls prior to submission. + else { + command_buffer = VK_NULL_HANDLE; + } + } + + if (command_buffer) { + command_buffers.push_back(command_buffer); + } + } + + if (!command_buffers.empty()) { + const VkSubmitInfo submit_info{ + VK_STRUCTURE_TYPE_SUBMIT_INFO, + nullptr, + 0u, + nullptr, + nullptr, + command_buffers.size(), + command_buffers.data(), + 0u, + nullptr, + }; + + VK_CHECK(vkQueueSubmit(queue, 1u, &submit_info, fence.handle())); + } +} + +void Command::Pool::invalidate() { + device_ = VK_NULL_HANDLE; + command_pool_.reset(); +} + } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Command.h b/aten/src/ATen/native/vulkan/api/Command.h index 42f073674be5..8b60d8bcd8f5 100644 --- a/aten/src/ATen/native/vulkan/api/Command.h +++ b/aten/src/ATen/native/vulkan/api/Command.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace at { namespace native { @@ -14,13 +15,15 @@ namespace vulkan { namespace api { struct Command final { + class Pool; + // // Buffer // class Buffer final { public: - Buffer(VkCommandBuffer command_buffer = VK_NULL_HANDLE); + explicit Buffer(VkCommandBuffer command_buffer = VK_NULL_HANDLE); Buffer(const Buffer&) = delete; Buffer& operator=(const Buffer&) = delete; Buffer(Buffer&&); @@ -28,18 +31,22 @@ struct Command final { ~Buffer() = default; operator bool() const; + VkCommandBuffer handle() const; void begin(); void end(); + void barrier(const Pipeline::Barrier& barrier); void bind(const Pipeline::Object& pipeline); void bind(const Descriptor::Set& set); void copy(Resource::Buffer::Object source, Resource::Buffer::Object destination); void dispatch(const Shader::WorkGroup& global_work_group); - void submit(VkQueue queue, Resource::Fence fence = {}); private: + friend class Pool; + void barrier(); + void invalidate(); private: VkCommandBuffer command_buffer_; @@ -80,12 +87,22 @@ struct Command final { ~Pool(); Buffer allocate(); + Buffer& stream(); void purge(); + void submit( + VkQueue queue, + c10::ArrayRef buffers, + Resource::Fence fence = {}); + + private: + void invalidate(); + 
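A minimal sketch of the deferred-submission policy that the new stream() and Command::Pool::submit() implement above, assuming Configuration::kSubmit plays the role of the cutoff counter; FakeDriver and record_work() are illustrative stand-ins, not part of the PyTorch Vulkan API:

// Sketch only: models the "accumulate, then flush on fence or cutoff" policy.
#include <cstdint>

struct FakeDriver {
  // Stands in for vkQueueSubmit(); counts how often work actually reaches the GPU.
  uint32_t submissions = 0u;
  void submit() { ++submissions; }
};

class CommandStreamSketch {
 public:
  explicit CommandStreamSketch(const uint32_t cutoff) : cutoff_(cutoff) {}

  // Recording work only bumps a counter; nothing is handed to the driver yet.
  void record_work() { ++counter_; }

  // Mirrors the branch in Command::Pool::submit(): flush when the caller asks
  // for a fence (they need the results) or when enough work has accumulated,
  // so the GPU is neither starved nor flooded with tiny submissions.
  void maybe_submit(FakeDriver& driver, const bool fence_requested) {
    if (fence_requested || (counter_ > cutoff_)) {
      driver.submit();
      counter_ = 0u;
    }
  }

 private:
  const uint32_t cutoff_;
  uint32_t counter_ = 0u;
};

With a cutoff of 10u, as in Configuration::kSubmit above, roughly ten recorded calls are batched before work reaches the driver unless a fence forces an earlier flush.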
private: struct Configuration final { - static constexpr uint32_t kQuantum = 64u; - static constexpr uint32_t kReserve = 1024u; + static constexpr uint32_t kQuantum = 4u; + static constexpr uint32_t kReserve = 16u; + static constexpr uint32_t kSubmit = 10u; }; VkDevice device_; @@ -95,6 +112,11 @@ struct Command final { std::vector pool; size_t in_use; } buffer_; + + struct { + Buffer buffer; + uint32_t counter; + } stream_; } pool /* [thread_count] */; explicit Command(const GPU& gpu) @@ -106,43 +128,12 @@ struct Command final { // Impl // -inline Command::Buffer::Buffer(Buffer&& buffer) - : command_buffer_(std::move(buffer.command_buffer_)), - bound_(std::move(buffer.bound_)), - barriers_(std::move(buffer.barriers_)) { - buffer.command_buffer_ = VK_NULL_HANDLE; -} - -inline Command::Buffer& Command::Buffer::operator=(Buffer&& buffer) { - if (&buffer != this) { - command_buffer_ = std::move(buffer.command_buffer_); - bound_ = std::move(buffer.bound_); - barriers_ = std::move(buffer.barriers_); - - buffer.command_buffer_ = VK_NULL_HANDLE; - }; - - return *this; -} - inline Command::Buffer::operator bool() const { return VK_NULL_HANDLE != command_buffer_; } -inline void Command::Buffer::Bound::reset() { - pipeline = {}; - descriptor_set = VK_NULL_HANDLE; -} - -inline Command::Buffer::Barrier::Stage::operator bool() const { - return (0u != src) || - (0u != dst); -} - -inline void Command::Buffer::Barrier::reset() { - stage = {}; - buffers.clear(); - images.clear(); +inline VkCommandBuffer Command::Buffer::handle() const { + return command_buffer_; } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Common.h b/aten/src/ATen/native/vulkan/api/Common.h index d606f1d859a9..49f9ffa21a22 100644 --- a/aten/src/ATen/native/vulkan/api/Common.h +++ b/aten/src/ATen/native/vulkan/api/Common.h @@ -6,10 +6,17 @@ #ifdef USE_VULKAN_SHADERC_RUNTIME #include -#define VK_KERNEL(name) { name##_glsl, } +#define VK_KERNEL(name) \ + ::at::native::vulkan::api::Shader::Descriptor{ \ + name##_glsl, \ + } #else #include -#define VK_KERNEL(name) { name##_spv, name##_spv_len, } +#define VK_KERNEL(name) \ + ::at::native::vulkan::api::Shader::Descriptor{ \ + name##_spv, \ + name##_spv_len, \ + } #endif /* USE_VULKAN_SHADERC_RUNTIME */ #ifdef USE_VULKAN_WRAPPER diff --git a/aten/src/ATen/native/vulkan/api/Context.cpp b/aten/src/ATen/native/vulkan/api/Context.cpp index 09dfa8fc1d77..0a9a6e130f4f 100644 --- a/aten/src/ATen/native/vulkan/api/Context.cpp +++ b/aten/src/ATen/native/vulkan/api/Context.cpp @@ -43,6 +43,40 @@ VkDevice create_device( &queue_priorities, }; + uint32_t device_extension_properties_count = 0; + VK_CHECK(vkEnumerateDeviceExtensionProperties( + physical_device, + nullptr, + &device_extension_properties_count, + nullptr)); + + std::vector device_extension_properties( + device_extension_properties_count); + + VK_CHECK(vkEnumerateDeviceExtensionProperties( + physical_device, + nullptr, + &device_extension_properties_count, + device_extension_properties.data())); + + constexpr const char* const requested_device_extensions[]{ + #ifdef VK_KHR_portability_subset + // https://vulkan.lunarg.com/doc/view/1.2.162.0/mac/1.2-extensions/vkspec.html#VUID-VkDeviceCreateInfo-pProperties-04451 + VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME, + #endif + }; + + std::vector enabled_device_extensions; + + for (const auto& requested_device_extension : requested_device_extensions) { + for (const auto& extension : device_extension_properties) { + if (strcmp(requested_device_extension, extension.extensionName) == 0) { + 
enabled_device_extensions.push_back(requested_device_extension); + break; + } + } + } + const VkDeviceCreateInfo device_create_info{ VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, nullptr, @@ -51,7 +85,8 @@ VkDevice create_device( &device_queue_create_info, 0u, nullptr, - 0u, + static_cast(enabled_device_extensions.size()), + enabled_device_extensions.data(), nullptr, }; diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.cpp b/aten/src/ATen/native/vulkan/api/Descriptor.cpp index 317536248987..5bdcb0b7fd02 100644 --- a/aten/src/ATen/native/vulkan/api/Descriptor.cpp +++ b/aten/src/ATen/native/vulkan/api/Descriptor.cpp @@ -128,27 +128,25 @@ Descriptor::Set::Set( "Invalid Vulkan descriptor set!"); } -void Descriptor::Set::update(const Item& item) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && descriptor_set_, - "This descriptor set is in an invalid state! " - "Potential reason: This descriptor set is moved from."); +Descriptor::Set::Set(Set&& set) + : device_(std::move(set.device_)), + descriptor_set_(std::move(set.descriptor_set_)), + shader_layout_signature_(std::move(set.shader_layout_signature_)), + bindings_(std::move(set.bindings_)) { + set.invalidate(); +} - const auto items_itr = std::find_if( - bindings_.items.begin(), - bindings_.items.end(), - [binding = item.binding](const Item& other) { - return other.binding == binding; - }); +Descriptor::Set& Descriptor::Set::operator=(Set&& set) { + if (&set != this) { + device_ = std::move(set.device_); + descriptor_set_ = std::move(set.descriptor_set_); + shader_layout_signature_ = std::move(set.shader_layout_signature_); + bindings_ = std::move(set.bindings_); - if (bindings_.items.end() == items_itr) { - bindings_.items.emplace_back(item); - } - else { - *items_itr = item; - } + set.invalidate(); + }; - bindings_.dirty = true; + return *this; } Descriptor::Set& Descriptor::Set::bind( @@ -276,12 +274,39 @@ VkDescriptorSet Descriptor::Set::handle() const { return descriptor_set_; } +void Descriptor::Set::invalidate() { + device_ = VK_NULL_HANDLE; + descriptor_set_ = VK_NULL_HANDLE; +} + +void Descriptor::Set::update(const Item& item) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && descriptor_set_, + "This descriptor set is in an invalid state! " + "Potential reason: This descriptor set is moved from."); + + const auto items_itr = std::find_if( + bindings_.items.begin(), + bindings_.items.end(), + [binding = item.binding](const Item& other) { + return other.binding == binding; + }); + + if (bindings_.items.end() == items_itr) { + bindings_.items.emplace_back(item); + } + else { + *items_itr = item; + } + + bindings_.dirty = true; +} + Descriptor::Pool::Pool(const GPU& gpu) : device_(gpu.device), descriptor_pool_( create_descriptor_pool(gpu.device), - VK_DELETER(DescriptorPool)(device_)), - set_{} { + VK_DELETER(DescriptorPool)(device_)) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( device_, "Invalid Vulkan device!"); @@ -295,7 +320,7 @@ Descriptor::Pool::Pool(Pool&& pool) : device_(std::move(pool.device_)), descriptor_pool_(std::move(pool.descriptor_pool_)), set_(std::move(pool.set_)) { - pool.device_ = VK_NULL_HANDLE; + pool.invalidate(); } Descriptor::Pool& Descriptor::Pool::operator=(Pool&& pool) { @@ -304,7 +329,7 @@ Descriptor::Pool& Descriptor::Pool::operator=(Pool&& pool) { descriptor_pool_ = std::move(pool.descriptor_pool_); set_ = std::move(pool.set_); - pool.device_ = VK_NULL_HANDLE; + pool.invalidate(); }; return *this; @@ -371,8 +396,13 @@ void Descriptor::Pool::purge() { "This descriptor pool is in an invalid state! 
" "Potential reason: This descriptor pool is moved from."); - set_.layouts.clear(); VK_CHECK(vkResetDescriptorPool(device_, descriptor_pool_.get(), 0u)); + set_.layouts.clear(); +} + +void Descriptor::Pool::invalidate() { + device_ = VK_NULL_HANDLE; + descriptor_pool_.reset(); } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.h b/aten/src/ATen/native/vulkan/api/Descriptor.h index 440bb9aa4097..6c50a350d7f3 100644 --- a/aten/src/ATen/native/vulkan/api/Descriptor.h +++ b/aten/src/ATen/native/vulkan/api/Descriptor.h @@ -73,6 +73,9 @@ struct Descriptor final { VkDescriptorSet handle() const; + private: + void invalidate(); + private: struct Item final { uint32_t binding; @@ -113,6 +116,9 @@ struct Descriptor final { Set allocate(const Shader::Layout::Object& shader_layout); void purge(); + private: + void invalidate(); + private: struct Configuration final { static constexpr uint32_t kQuantum = 16u; @@ -137,33 +143,6 @@ struct Descriptor final { } }; -// -// Impl -// - -inline Descriptor::Set::Set(Set&& set) - : device_(std::move(set.device_)), - descriptor_set_(std::move(set.descriptor_set_)), - shader_layout_signature_(std::move(set.shader_layout_signature_)), - bindings_(std::move(set.bindings_)) { - set.device_ = VK_NULL_HANDLE; - set.descriptor_set_ = VK_NULL_HANDLE; -} - -inline Descriptor::Set& Descriptor::Set::operator=(Set&& set) { - if (&set != this) { - device_ = std::move(set.device_); - descriptor_set_ = std::move(set.descriptor_set_); - shader_layout_signature_ = std::move(set.shader_layout_signature_); - bindings_ = std::move(set.bindings_); - - set.device_ = VK_NULL_HANDLE; - set.descriptor_set_ = VK_NULL_HANDLE; - }; - - return *this; -} - } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.cpp b/aten/src/ATen/native/vulkan/api/Pipeline.cpp index 4b15203892ed..89e85892ee0c 100644 --- a/aten/src/ATen/native/vulkan/api/Pipeline.cpp +++ b/aten/src/ATen/native/vulkan/api/Pipeline.cpp @@ -169,6 +169,10 @@ Pipeline::Cache::Cache(Factory factory) : cache_(std::move(factory)) { } +void Pipeline::Cache::purge() { + cache_.purge(); +} + } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.h b/aten/src/ATen/native/vulkan/api/Pipeline.h index 1d1966790dbf..794193d8a161 100644 --- a/aten/src/ATen/native/vulkan/api/Pipeline.h +++ b/aten/src/ATen/native/vulkan/api/Pipeline.h @@ -196,7 +196,11 @@ inline Pipeline::Barrier::operator bool() const { inline bool operator==( const Pipeline::Layout::Descriptor& _1, const Pipeline::Layout::Descriptor& _2) { - return (_1.descriptor_set_layout == _2.descriptor_set_layout); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(Pipeline::Layout::Descriptor))); } inline size_t Pipeline::Layout::Factory::Hasher::operator()( @@ -207,9 +211,11 @@ inline size_t Pipeline::Layout::Factory::Hasher::operator()( inline bool operator==( const Pipeline::Descriptor& _1, const Pipeline::Descriptor& _2) { - return (_1.pipeline_layout == _2.pipeline_layout) && - (_1.shader_module == _2.shader_module) && - (_1.local_work_group == _2.local_work_group); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(Pipeline::Descriptor))); } inline size_t Pipeline::Factory::Hasher::operator()( @@ -236,10 +242,6 @@ inline Pipeline::Object 
Pipeline::Cache::retrieve( }; } -inline void Pipeline::Cache::purge() { - cache_.purge(); -} - } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Resource.cpp b/aten/src/ATen/native/vulkan/api/Resource.cpp index a491ed4dd6e0..adda610fb90c 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.cpp +++ b/aten/src/ATen/native/vulkan/api/Resource.cpp @@ -192,6 +192,11 @@ VkFence Resource::Fence::handle(const bool add_to_waitlist) const { "Invalid Vulkan fence!"); const VkFence fence = pool->fence_.pool[id].get(); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + fence, + "Invalid Vulkan fence!"); + if (add_to_waitlist) { pool->fence_.waitlist.push_back(fence); } @@ -360,14 +365,13 @@ Resource::Pool::Pool( : device_(gpu.device), allocator_( create_allocator( - gpu.adapter->runtime->instance(), - gpu.adapter->handle, - device_), + gpu.adapter->runtime->instance(), + gpu.adapter->handle, + device_), vmaDestroyAllocator), memory_{ std::move(policy), }, - buffer_{}, image_{ .sampler = Image::Sampler{gpu}, }, @@ -377,23 +381,6 @@ Resource::Pool::Pool( fence_.pool.reserve(Configuration::kReserve); } -Resource::Pool::~Pool() { - try { - if (device_ && allocator_) { - purge(); - } - } - catch (const std::exception& e) { - LOG(WARNING) - << "Vulkan: Resource pool destructor raised an exception! Error: " - << e.what(); - } - catch (...) { - LOG(WARNING) - << "Vulkan: Resource pool destructor raised an unknown exception!"; - } -} - Resource::Pool::Pool(Pool&& pool) : device_(std::move(pool.device_)), allocator_(std::move(pool.allocator_)), @@ -401,7 +388,7 @@ Resource::Pool::Pool(Pool&& pool) buffer_(std::move(pool.buffer_)), image_(std::move(pool.image_)), fence_(std::move(pool.fence_)) { - pool.device_ = VK_NULL_HANDLE; + pool.invalidate(); } Resource::Pool& Resource::Pool::operator=(Pool&& pool) { @@ -413,12 +400,29 @@ Resource::Pool& Resource::Pool::operator=(Pool&& pool) { image_ = std::move(pool.image_); fence_ = std::move(pool.fence_); - pool.device_ = VK_NULL_HANDLE; + pool.invalidate(); }; return *this; } +Resource::Pool::~Pool() { + try { + if (device_ && allocator_) { + purge(); + } + } + catch (const std::exception& e) { + LOG(WARNING) + << "Vulkan: Resource pool destructor raised an exception! Error: " + << e.what(); + } + catch (...) 
{ + LOG(WARNING) + << "Vulkan: Resource pool destructor raised an unknown exception!"; + } +} + Resource::Buffer Resource::Pool::buffer( const Buffer::Descriptor& descriptor) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( @@ -678,6 +682,11 @@ void Resource::Pool::purge() { buffer_.pool.clear(); } +void Resource::Pool::invalidate() { + device_ = VK_NULL_HANDLE; + allocator_.reset(); +} + } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Resource.h b/aten/src/ATen/native/vulkan/api/Resource.h index 61a2094dc692..19a7df3d04d2 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.h +++ b/aten/src/ATen/native/vulkan/api/Resource.h @@ -20,6 +20,16 @@ struct Resource final { // struct Memory final { + /* + Descriptor + */ + + struct Descriptor final { + VmaMemoryUsage usage; + VkMemoryPropertyFlags /* optional */ required; + VkMemoryPropertyFlags /* optional */ preferred; + }; + /* Barrier */ @@ -30,18 +40,9 @@ struct Resource final { }; /* - Descriptor + Access */ - struct Descriptor final { - VmaMemoryUsage usage; - VkMemoryPropertyFlags /* optional */ required; - VkMemoryPropertyFlags /* optional */ preferred; - }; - - VmaAllocator allocator; - VmaAllocation allocation; - struct Access final { typedef uint8_t Flags; @@ -74,6 +75,9 @@ struct Resource final { typename Pointer = Access::Pointer> Handle map() &; + VmaAllocator allocator; + VmaAllocation allocation; + private: // Intentionally disabed to ensure memory access is always properly // encapsualted in a scoped map-unmap region. Allowing below overloads @@ -299,6 +303,8 @@ struct Resource final { private: friend struct Fence; + void invalidate(); + private: struct Configuration final { static constexpr uint32_t kReserve = 256u; @@ -353,7 +359,8 @@ class Resource::Memory::Scope final { template inline Resource::Memory::Handle Resource::Memory::map() const & { - void* map(const Memory& memory, Access::Flags); + // Forward declaration + void* map(const Memory&, Access::Flags); return Handle{ reinterpret_cast(map(*this, Access::Read)), @@ -363,7 +370,8 @@ inline Resource::Memory::Handle Resource::Memory::map() const & { template inline Resource::Memory::Handle Resource::Memory::map() & { - void* map(const Memory& memory, Access::Flags); + // Forward declaration + void* map(const Memory&, Access::Flags); static_assert( (kAccess == Access::Read) || @@ -388,10 +396,11 @@ inline Resource::Buffer::operator bool() const { inline bool operator==( const Resource::Image::Sampler::Descriptor& _1, const Resource::Image::Sampler::Descriptor& _2) { - return (_1.filter == _2.filter) && - (_1.mipmap_mode == _2.mipmap_mode) && - (_1.address_mode == _2.address_mode) && - (_1.border == _2.border); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(Resource::Image::Sampler::Descriptor))); } inline size_t Resource::Image::Sampler::Factory::Hasher::operator()( diff --git a/aten/src/ATen/native/vulkan/api/Runtime.cpp b/aten/src/ATen/native/vulkan/api/Runtime.cpp index b63ded444887..c3ad6ebddb45 100644 --- a/aten/src/ATen/native/vulkan/api/Runtime.cpp +++ b/aten/src/ATen/native/vulkan/api/Runtime.cpp @@ -86,7 +86,9 @@ VkInstance create_instance(const Runtime::Type type) { nullptr, &instance_extension_count, instance_extension_properties.data())); constexpr const char* const requested_instance_extensions[]{ + #ifdef VK_EXT_debug_report VK_EXT_DEBUG_REPORT_EXTENSION_NAME, + #endif }; for (const auto& requested_instance_extension : 
requested_instance_extensions) { diff --git a/aten/src/ATen/native/vulkan/api/Runtime.h b/aten/src/ATen/native/vulkan/api/Runtime.h index 675f35cd0789..55eae70f8723 100644 --- a/aten/src/ATen/native/vulkan/api/Runtime.h +++ b/aten/src/ATen/native/vulkan/api/Runtime.h @@ -33,10 +33,7 @@ class Runtime final { Runtime& operator=(Runtime&&) = default; ~Runtime() = default; - inline VkInstance instance() const { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(instance_); - return instance_.get(); - } + VkInstance instance() const; typedef std::function Selector; Adapter select(const Selector& selector); @@ -59,6 +56,15 @@ class Runtime final { Runtime* runtime(); +// +// Impl +// + +inline VkInstance Runtime::instance() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(instance_); + return instance_.get(); +} + } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Shader.cpp b/aten/src/ATen/native/vulkan/api/Shader.cpp index 43d1a62ac201..7995dd160c35 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.cpp +++ b/aten/src/ATen/native/vulkan/api/Shader.cpp @@ -60,6 +60,10 @@ Shader::Layout::Cache::Cache(Factory factory) : cache_(std::move(factory)) { } +void Shader::Layout::Cache::purge() { + cache_.purge(); +} + #ifdef USE_VULKAN_SHADERC_RUNTIME struct Shader::Factory::Compiler final { diff --git a/aten/src/ATen/native/vulkan/api/Shader.h b/aten/src/ATen/native/vulkan/api/Shader.h index 718504e69bd4..f005eb1c11e9 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.h +++ b/aten/src/ATen/native/vulkan/api/Shader.h @@ -218,16 +218,14 @@ inline Shader::Layout::Object Shader::Layout::Cache::retrieve( }; } -inline void Shader::Layout::Cache::purge() { - cache_.purge(); -} - inline bool operator==( const Shader::WorkGroup& _1, const Shader::WorkGroup& _2) { - return (_1.data[0u] == _2.data[0u]) && - (_1.data[1u] == _2.data[1u]) && - (_1.data[2u] == _2.data[2u]); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(Shader::WorkGroup))); } inline Shader::Descriptor::Descriptor(const char* const glsl) @@ -258,12 +256,10 @@ inline bool operator==( const Shader::Descriptor& _1, const Shader::Descriptor& _2) { static_assert( - sizeof(Shader::Descriptor::shader.source) == sizeof(Shader::Descriptor::shader.binary), - "This implementation requires sizeof(Source) to be equal to sizeof(Binary)."); + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); - return (_1.type == _2.type) && - (_1.shader.binary.spirv == _2.shader.binary.spirv) && - (_1.shader.binary.size == _2.shader.binary.size); + return (0 == memcmp(&_1, &_2, sizeof(Shader::Descriptor))); } inline size_t Shader::Factory::Hasher::operator()( @@ -286,11 +282,11 @@ inline size_t Shader::Factory::Hasher::operator()( inline bool operator==( const VkDescriptorSetLayoutBinding& _1, const VkDescriptorSetLayoutBinding& _2) { - return (_1.binding == _2.binding) && - (_1.descriptorType == _2.descriptorType) && - (_1.descriptorCount == _2.descriptorCount) && - (_1.stageFlags == _2.stageFlags) && - (_1.pImmutableSamplers == _2.pImmutableSamplers); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(VkDescriptorSetLayoutBinding))); } #endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Add.cpp b/aten/src/ATen/native/vulkan/ops/Add.cpp index 270a1d5f8168..95b7fd67c095 100644 --- 
a/aten/src/ATen/native/vulkan/ops/Add.cpp +++ b/aten/src/ATen/native/vulkan/ops/Add.cpp @@ -24,11 +24,11 @@ Tensor add_scalar( v_self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_output.has_image() && v_self.has_image()) { - const struct { + if C10_LIKELY(v_output.has_image() && v_self.has_image()) { + const struct Block final { uvec3 extents; float other; } block { @@ -64,8 +64,7 @@ Tensor add_scalar( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -82,11 +81,11 @@ Tensor& add_scalar_( vTensor& v_self = convert(self); - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_self.has_image()) { - const struct { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { uvec3 extents; float other; } block { @@ -116,8 +115,7 @@ Tensor& add_scalar_( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return self; } @@ -140,11 +138,11 @@ Tensor add_tensor( v_self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_self.has_image() && v_other.has_image()) { - const struct { + if C10_LIKELY(v_self.has_image() && v_other.has_image()) { + const struct Block final { uvec3 extents; float alpha; } block { @@ -186,8 +184,7 @@ Tensor add_tensor( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -207,11 +204,11 @@ Tensor& add_tensor_( const Tensor other = other_arg.is_vulkan() ? 
other_arg : other_arg.vulkan(); const vTensor& v_other = convert(other); - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_self.has_image() && v_other.has_image() && !self.is_same(other)) { - const struct { + if C10_LIKELY(v_self.has_image() && v_other.has_image() && !self.is_same(other)) { + const struct Block final { uvec3 extents; float alpha; } block { @@ -247,8 +244,7 @@ Tensor& add_tensor_( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return self; } diff --git a/aten/src/ATen/native/vulkan/ops/Clamp.cpp b/aten/src/ATen/native/vulkan/ops/Clamp.cpp index 9f25d89bca9b..75e9a1bb0fff 100644 --- a/aten/src/ATen/native/vulkan/ops/Clamp.cpp +++ b/aten/src/ATen/native/vulkan/ops/Clamp.cpp @@ -28,11 +28,11 @@ Tensor clamp( v_self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_output.has_image() && v_self.has_image()) { - const struct { + if C10_LIKELY(v_output.has_image() && v_self.has_image()) { + const struct Block final { uvec3 extents; uint32_t _; vec2 clamp; @@ -73,8 +73,7 @@ Tensor clamp( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -95,11 +94,11 @@ Tensor& clamp_( vTensor& v_self = convert(self); - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_self.has_image()) { - const struct { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { uvec3 extents; uint32_t _; vec2 clamp; @@ -134,8 +133,7 @@ Tensor& clamp_( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return self; } diff --git a/aten/src/ATen/native/vulkan/ops/Common.h b/aten/src/ATen/native/vulkan/ops/Common.h index b0bbeeaf34f1..3c9b2e8b3b9f 100644 --- a/aten/src/ATen/native/vulkan/ops/Common.h +++ b/aten/src/ATen/native/vulkan/ops/Common.h @@ -35,14 +35,6 @@ struct Layout final { }; }; -struct Experimentation { - static constexpr bool kUseConv2dOldApi = false; -}; - -struct ConvPrepackLimits final { - static constexpr int64_t maxStackDepth = 2048*4; -}; - } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp index d88545e3a25a..8991475b7d15 100644 --- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp +++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp @@ -1,8 +1,8 @@ #include -#include #include #include #include +#include namespace at { namespace native { @@ -12,6 +12,10 @@ namespace { using namespace api::utils; +struct Experimentation final { + static constexpr bool kUseConv2dOldApi = false; +}; + inline bool is_depthwise( const IntArrayRef filter, const int64_t groups) { @@ -26,47 +30,103 @@ inline bool 
is_pointwise(const IntArrayRef filter) { } vTensor pack_weights_dw( + api::Context* const context, + api::Command::Buffer& command_buffer, api::Resource::Pool& pool, - const Tensor& weight_arg, - const int64_t groups) { - if (weight_arg.is_vulkan()) { - return convert(weight_arg); - } - + const Tensor& weight) { /* Source */ - - const Tensor weight = weight_arg.contiguous(); const IntArrayRef src_filter = weight.sizes(); const float* const src_weight_ptr = weight.data_ptr(); const int64_t src_kw_sz = src_filter[Layout::Filter::width]; const int64_t src_kh_sz = src_filter[Layout::Filter::height]; + const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; + const int64_t src_block_sz = src_kernel_sz * src_filter[Layout::Filter::input]; const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4)); + + /* Destination */ + const int64_t dst_kw_sz = src_kernel_sz; + const int64_t dst_kh_sz = num_stacks; + const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz; + vTensor v_weight{ - api::context(), + context, &pool, { 4, - num_stacks, - src_kw_sz * src_kh_sz, + dst_kh_sz, + dst_kw_sz, }, weight.options(), }; using Future = vTensor::Future; - Future v_weight_future = v_weight.host(); + Future v_weight_future = v_weight.host(command_buffer); Future::Payload v_weight_payload = v_weight_future.wait(); + float* const dst_weight_ptr = v_weight_payload.get(); + memset(dst_weight_ptr, 0, v_weight.nbytes()); + + for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) { + /* Source */ + const float* const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block_sz; + + /* Destination */ + const int64_t dst_oh = src_oc / 4; + const int64_t dst_c = src_oc % 4; + + float* const dst_weight_c_ptr = dst_weight_ptr + + dst_c * dst_kernel_sz + + dst_oh * dst_kw_sz; + + for (int64_t src_ih = 0; src_ih < src_filter[Layout::Filter::height]; ++src_ih) { + memcpy( + dst_weight_c_ptr + src_ih * src_kw_sz, + src_weight_oc_ptr + src_ih * src_kw_sz, + sizeof(float) * src_kw_sz); + } + } + + return v_weight; +} + +vTensor pack_weights_2d( + api::Context* const context, + api::Command::Buffer& command_buffer, + api::Resource::Pool& pool, + const Tensor& weight) { /* Source */ + const IntArrayRef src_filter = weight.sizes(); + const float* const src_weight_ptr = weight.data_ptr(); + + const int64_t src_kw_sz = src_filter[Layout::Filter::width]; + const int64_t src_kh_sz = src_filter[Layout::Filter::height]; const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; - const int64_t src_block_sz = - src_kernel_sz * src_filter[Layout::Filter::input]; + const int64_t src_block_sz = src_kernel_sz * src_filter[Layout::Filter::input]; + + const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4)); + const int64_t stack_depth = api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4)); /* Destination */ - const int64_t dst_kw_sz = src_kw_sz * src_kh_sz; - const int64_t dst_kh_sz = num_stacks; + const int64_t dst_kw_sz = src_kw_sz * stack_depth; + const int64_t dst_kh_sz = src_kh_sz * num_stacks; const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz; + vTensor v_weight{ + context, + &pool, + { + 4, + dst_kh_sz, + dst_kw_sz, + }, + weight.options(), + }; + + using Future = vTensor::Future; + Future v_weight_future = v_weight.host(command_buffer); + Future::Payload v_weight_payload = v_weight_future.wait(); + float* const dst_weight_ptr = v_weight_payload.get(); memset(dst_weight_ptr, 0, v_weight.nbytes()); @@ -80,26 +140,29 @@ vTensor pack_weights_dw( float* const 
dst_weight_c_ptr = dst_weight_ptr + dst_c * dst_kernel_sz; - for (int64_t src_ih = 0; src_ih < src_filter[Layout::Filter::height]; ++src_ih) { - memcpy( - dst_weight_c_ptr + dst_oh * dst_kw_sz + src_ih * src_kw_sz, - src_weight_oc_ptr + src_ih * src_kw_sz, - sizeof(float) * src_kw_sz); + for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input]; ++src_ic) { + const int64_t dst_ic4 = src_ic / 4; + + for (int64_t src_ih = 0; src_ih < src_kh_sz; ++src_ih) { + for (int64_t src_iw = 0; src_iw < src_kw_sz; ++src_iw) { + memcpy( + dst_weight_c_ptr + (dst_oh * src_kh_sz + src_ih) * dst_kw_sz + + dst_ic4 * src_kw_sz * 4 + src_iw * 4 + src_ic % 4, + src_weight_oc_ptr + src_ic * src_kernel_sz + src_ih * src_kw_sz + src_iw, + sizeof(float)); + } + } } } return v_weight; } -vTensor pack_weights_old( +vTensor pack_weights_2d_old( + api::Context* const context, + api::Command::Buffer& command_buffer, api::Resource::Pool& pool, - const Tensor& weight_arg, - const int64_t groups) { - if (weight_arg.is_vulkan()) { - return convert(weight_arg); - } - - const Tensor weight = weight_arg.contiguous(); + const Tensor& weight) { const IntArrayRef src_filter = weight.sizes(); const float* const src_weight_ptr = weight.data_ptr(); @@ -111,7 +174,7 @@ vTensor pack_weights_old( const uint32_t KW = src_filter[Layout::Filter::width]; vTensor v_weight{ - api::context(), + context, &pool, { 1, @@ -123,13 +186,13 @@ vTensor pack_weights_old( }; using Future = vTensor::Future; - Future v_weight_future = v_weight.host(); + Future v_weight_future = v_weight.host(command_buffer); Future::Payload v_weight_payload = v_weight_future.wait(); float* const dst_weight_ptr = v_weight_payload.get(); memset(dst_weight_ptr, 0, v_weight.nbytes()); - const float* src = src_weight_ptr; + const float* const src = src_weight_ptr; float* const dst = dst_weight_ptr; { @@ -162,7 +225,7 @@ vTensor pack_weights_old( dim0_ = dim0; dim1_ = dim1; dim2_ = dim2; - data_ = new float[dim0 * dim1 * dim2 * 4]; + data_ = new float[dim0 * dim1 * dim2 * 4]; // TODO: memory leak memset(data_, 0.f, dim0 * dim1 * dim2 * 4 * sizeof(float)); } @@ -211,7 +274,7 @@ vTensor pack_weights_old( return v_weight; } -vTensor pack_weights_2d( +vTensor pack_weights( api::Resource::Pool& pool, const Tensor& weight_arg, const int64_t groups) { @@ -219,81 +282,32 @@ vTensor pack_weights_2d( return convert(weight_arg); } - const Tensor weight = weight_arg.contiguous(); - const IntArrayRef src_filter = weight.sizes(); - const float* const src_weight_ptr = weight.data_ptr(); - - const int64_t src_kw_sz = src_filter[Layout::Filter::width]; - const int64_t src_kh_sz = src_filter[Layout::Filter::height]; - const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4)); - const int64_t stack_depth = api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4)); - vTensor v_weight{ - api::context(), - &pool, - { - 4, - src_kh_sz * num_stacks, - src_kw_sz * stack_depth, - }, - weight.options(), - }; - - using Future = vTensor::Future; - Future v_weight_future = v_weight.host(); - Future::Payload v_weight_payload = v_weight_future.wait(); - - /* Source */ - const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; - const int64_t src_block_sz = - src_kernel_sz * src_filter[Layout::Filter::input]; - - /* Destination */ - const int64_t dst_kw_sz = src_kw_sz * stack_depth; - const int64_t dst_kh_sz = src_kh_sz * num_stacks; - const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz; - - float* const dst_weight_ptr = v_weight_payload.get(); - memset(dst_weight_ptr, 
0, v_weight.nbytes()); - - for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) { - /* Source */ - const float* const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block_sz; - - /* Destination */ - const int64_t dst_oh = src_oc / 4; - const int64_t dst_c = src_oc % 4; - - float* const dst_weight_c_ptr = dst_weight_ptr + dst_c * dst_kernel_sz; - - for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input]; ++src_ic) { - const int64_t dst_ic4 = src_ic/4; - for (int64_t src_ih = 0; src_ih < src_kh_sz; ++src_ih) { - for (int64_t src_iw = 0; src_iw < src_kw_sz; ++src_iw) { - memcpy( - dst_weight_c_ptr + (dst_oh * src_kh_sz + src_ih) * dst_kw_sz + - dst_ic4 * src_kw_sz * 4 + src_iw * 4 + src_ic % 4, - src_weight_oc_ptr + src_ic * src_kernel_sz + src_ih * src_kw_sz + src_iw, - sizeof(float)); - } - } - } - } + api::Context* const context = api::context(); + api::Command::Buffer& command_buffer = context->command().pool.stream(); - return v_weight; -} + const Tensor weight = weight_arg.contiguous(); -vTensor pack_weights( - api::Resource::Pool& pool, - const Tensor& weight_arg, - const int64_t groups) { - if (is_depthwise(weight_arg.sizes(), groups)) { - return pack_weights_dw(pool, weight_arg, groups); + if (is_depthwise(weight.sizes(), groups)) { + return pack_weights_dw( + context, + command_buffer, + pool, + weight); } if (Experimentation::kUseConv2dOldApi) { - return pack_weights_old(pool, weight_arg, groups); + return pack_weights_2d_old( + context, + command_buffer, + pool, + weight); } - return pack_weights_2d(pool, weight_arg, groups); + + return pack_weights_2d( + context, + command_buffer, + pool, + weight); } vTensor pack_biases( @@ -304,8 +318,11 @@ vTensor pack_biases( return convert(*bias); } + api::Context* const context = api::context(); + api::Command::Buffer& command_buffer = context->command().pool.stream(); + vTensor v_bias{ - api::context(), + context, &pool, { // 1D @@ -316,7 +333,7 @@ vTensor pack_biases( { using Future = vTensor::Future; - Future v_bias_future = v_bias.host(); + Future v_bias_future = v_bias.host(command_buffer); Future::Payload v_bias_payload = v_bias_future.wait(); if (bias) { @@ -394,7 +411,8 @@ bool available( (c10::DeviceType::Vulkan == bias->device().type())) && (kFloat == bias->scalar_type()) && (transposed ? 
false /* to be addded in the future */ - : (weight.size(Layout::Filter::output) == bias->size(Layout::Filter::output)))) + : (weight.size(Layout::Filter::output) == + bias->size(Layout::Filter::output)))) : true) && // Stride (stride[Layout::Parameter::height] > 0) && @@ -432,7 +450,7 @@ bool usable(const Tensor& input) { true; } -void conv2d_depthwise( +void conv2d_dw( api::Context* const context, api::Command::Buffer& command_buffer, vTensor& v_output, @@ -446,27 +464,39 @@ void conv2d_depthwise( const IntArrayRef dilation, const float output_min, const float output_max) { - if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) { - const struct { - int32_t kernel_x, kernel_y; - int32_t stride_x, stride_y; - int32_t padding_x, padding_y; - int32_t dilate_x, dilate_y; - float clamp_x, clamp_y; - int32_t src_filter_w, src_filter_h; + if C10_LIKELY(v_output.has_image() && v_input.has_image() && v_weight.has_image()) { + const struct Block final { + ivec2 kernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; + ivec2 src_filter; } block { - safe_downcast(filter[Layout::Filter::width]), - safe_downcast(filter[Layout::Filter::height]), - safe_downcast(stride[Layout::Parameter::width]), - safe_downcast(stride[Layout::Parameter::height]), - safe_downcast(padding[Layout::Parameter::width]), - safe_downcast(padding[Layout::Parameter::height]), - safe_downcast(dilation[Layout::Parameter::width]), - safe_downcast(dilation[Layout::Parameter::height]), - output_min, - output_max, - safe_downcast(src_filter[Layout::Filter::width]), - safe_downcast(src_filter[Layout::Filter::height]), + { + safe_downcast(filter[Layout::Filter::width]), + safe_downcast(filter[Layout::Filter::height]), + }, + { + safe_downcast(stride[Layout::Parameter::width]), + safe_downcast(stride[Layout::Parameter::height]), + }, + { + safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), + }, + { + safe_downcast(dilation[Layout::Parameter::width]), + safe_downcast(dilation[Layout::Parameter::height]), + }, + { + output_min, + output_max, + }, + { + safe_downcast(src_filter[Layout::Filter::width]), + safe_downcast(src_filter[Layout::Filter::height]), + }, }; context->dispatch( @@ -510,7 +540,7 @@ void conv2d_depthwise( } } -void conv2d_pointwise( +void conv2d_pw( api::Context* const context, api::Command::Buffer& command_buffer, vTensor& v_output, @@ -522,22 +552,29 @@ void conv2d_pointwise( const IntArrayRef padding, const float output_min, const float output_max) { - if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) { - - const struct { - int32_t kernel_ic, kernel_oc; - int32_t stride_x, stride_y; - int32_t padding_x, padding_y; - float clamp_x, clamp_y; + if C10_LIKELY(v_output.has_image() && v_input.has_image() && v_weight.has_image()) { + const struct Block final { + ivec2 kernel; + ivec2 stride; + ivec2 padding; + vec2 clamp; } block { - safe_downcast(filter[Layout::Filter::input]), - safe_downcast(filter[Layout::Filter::output]), - safe_downcast(stride[Layout::Parameter::width]), - safe_downcast(stride[Layout::Parameter::height]), - safe_downcast(padding[Layout::Parameter::width]), - safe_downcast(padding[Layout::Parameter::height]), - output_min, - output_max, + { + safe_downcast(filter[Layout::Filter::input]), + safe_downcast(filter[Layout::Filter::output]), + }, + { + safe_downcast(stride[Layout::Parameter::width]), + safe_downcast(stride[Layout::Parameter::height]), + }, + { + 
safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), + }, + { + output_min, + output_max, + }, }; context->dispatch( @@ -595,30 +632,134 @@ void conv2d( const IntArrayRef dilation, const float output_min, const float output_max) { + if C10_LIKELY(v_output.has_image() && v_input.has_image() && v_weight.has_image()) { + const struct Block final { + ivec4 kernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; + ivec4 src_filter; + } block { + { + safe_downcast(filter[Layout::Filter::width]), + safe_downcast(filter[Layout::Filter::height]), + safe_downcast(filter[Layout::Filter::input]), + safe_downcast(filter[Layout::Filter::output]), + }, + { + safe_downcast(stride[Layout::Parameter::width]), + safe_downcast(stride[Layout::Parameter::height]), + }, + { + safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), + }, + { + safe_downcast(dilation[Layout::Parameter::width]), + safe_downcast(dilation[Layout::Parameter::height]), + }, + { + output_min, + output_max, + }, + { + safe_downcast(src_filter[Layout::Filter::width]), + safe_downcast(src_filter[Layout::Filter::height]), + safe_downcast(src_filter[Layout::Filter::width] * 4), + 0, + }, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(conv2d), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_input.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_weight.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_bias.buffer( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. 
+ context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } +} + +void conv2d_old( + api::Context* const context, + api::Command::Buffer& command_buffer, + vTensor& v_output, + const vTensor& v_input, + const vTensor& v_weight, + const vTensor& v_bias, + const IntArrayRef filter, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const float output_min, + const float output_max) { + using namespace api::utils; + if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) { - const struct { - int32_t kernel_x, kernel_y, kernel_ic, kernel_oc; - int32_t stride_x, stride_y; + const int32_t W = v_input.extents().data[0]; + const int32_t H = v_input.extents().data[1]; + const int32_t C_4 = v_input.extents().data[2]; + const int32_t C = 4 * C_4; + + const int32_t OW = v_output.extents().data[0]; + const int32_t OH = v_output.extents().data[1]; + const int32_t OC_4 = v_output.extents().data[2]; + const int32_t OC = 4 * OC_4; + + const struct Block final { int32_t padding_x, padding_y; + int32_t kernel_x, kernel_y; + int32_t stride_x, stride_y; int32_t dilate_x, dilate_y; - float clamp_x, clamp_y; - int32_t src_filter_w, src_filter_h, src_filter_w4; + int32_t outputSize[4]; + int32_t inputSize[4]; + float outputMin; + float outputMax; } block { + safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), safe_downcast(filter[Layout::Filter::width]), safe_downcast(filter[Layout::Filter::height]), - safe_downcast(filter[Layout::Filter::input]), - safe_downcast(filter[Layout::Filter::output]), safe_downcast(stride[Layout::Parameter::width]), safe_downcast(stride[Layout::Parameter::height]), - safe_downcast(padding[Layout::Parameter::width]), - safe_downcast(padding[Layout::Parameter::height]), safe_downcast(dilation[Layout::Parameter::width]), safe_downcast(dilation[Layout::Parameter::height]), + { OW, OH, OC_4, OC }, + { W, H, C_4, C }, output_min, output_max, - safe_downcast(src_filter[Layout::Filter::width]), - safe_downcast(src_filter[Layout::Filter::height]), - safe_downcast(src_filter[Layout::Filter::width]*4), }; context->dispatch( @@ -630,29 +771,30 @@ void conv2d( VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, }, - VK_KERNEL(conv2d), + VK_KERNEL(conv2d_nogroup_clamp), + //VK_KERNEL(conv2d_nogroup_clamp_1x), v_output.extents(), // Write-only access bypasses synchronization but inserts appropriate // barriers if necessary. v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), // Read-only access is implied on const tensors and triggers an async // synchronization if necessary. v_input.image( - command_buffer, - vTensor::Stage::Compute), + command_buffer, + vTensor::Stage::Compute), // Read-only access is implied on const tensors and triggers an async // synchronization if necessary. v_weight.image( - command_buffer, - vTensor::Stage::Compute), + command_buffer, + vTensor::Stage::Compute), // Read-only access is implied on const tensors and triggers an async // synchronization if necessary. v_bias.buffer( - command_buffer, - vTensor::Stage::Compute), + command_buffer, + vTensor::Stage::Compute), // Object lifetime is managed by the resource pool. // It is OK not to keep track of the handle. 
context->resource().pool.uniform(block).object); @@ -781,99 +923,6 @@ Conv2dOpContext Conv2dOpContext::create( }; } -void conv2d_old( - api::Context* const context, - api::Command::Buffer& command_buffer, - vTensor& v_output, - const vTensor& v_input, - const vTensor& v_weight, - const vTensor& v_bias, - const IntArrayRef filter, - const IntArrayRef stride, - const IntArrayRef padding, - const IntArrayRef dilation, - const float output_min, - const float output_max) { - - using namespace api::utils; - - if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) { - const int32_t W = v_input.extents().data[0]; - const int32_t H = v_input.extents().data[1]; - const int32_t C_4 = v_input.extents().data[2]; - const int32_t C = 4 * C_4; - - const int32_t OW = v_output.extents().data[0]; - const int32_t OH = v_output.extents().data[1]; - const int32_t OC_4 = v_output.extents().data[2]; - const int32_t OC = 4 * OC_4; - - const struct { - int32_t padding_x, padding_y; - int32_t kernel_x, kernel_y; - int32_t stride_x, stride_y; - int32_t dilate_x, dilate_y; - int32_t outputSize[4]; - int32_t inputSize[4]; - float outputMin; - float outputMax; - } block { - safe_downcast(padding[Layout::Parameter::width]), - safe_downcast(padding[Layout::Parameter::height]), - safe_downcast(filter[Layout::Filter::width]), - safe_downcast(filter[Layout::Filter::height]), - safe_downcast(stride[Layout::Parameter::width]), - safe_downcast(stride[Layout::Parameter::height]), - safe_downcast(dilation[Layout::Parameter::width]), - safe_downcast(dilation[Layout::Parameter::height]), - { OW, OH, OC_4, OC }, - { W, H, C_4, C }, - output_min, - output_max, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(conv2d_nogroup_clamp), - //VK_KERNEL(conv2d_nogroup_clamp_1x), - v_output.extents(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_weight.image( - command_buffer, - vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_bias.buffer( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. 
- context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } -} - Tensor Conv2dOpContext::run(const Tensor& input_arg) const { api::Context* const context = api::context(); @@ -896,11 +945,11 @@ Tensor Conv2dOpContext::run(const Tensor& input_arg) const { input.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { if (is_depthwise(unpacked_.filter, unpacked_.groups)) { - conv2d_depthwise( + conv2d_dw( context, command_buffer, v_output, @@ -932,7 +981,7 @@ Tensor Conv2dOpContext::run(const Tensor& input_arg) const { packed_.output_max); } else { if (is_pointwise(unpacked_.filter)) { - conv2d_pointwise( + conv2d_pw( context, command_buffer, v_output, @@ -964,8 +1013,7 @@ Tensor Conv2dOpContext::run(const Tensor& input_arg) const { } } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Copy.cpp b/aten/src/ATen/native/vulkan/ops/Copy.cpp index bbd326b42ace..1cf6b1ad6aa9 100644 --- a/aten/src/ATen/native/vulkan/ops/Copy.cpp +++ b/aten/src/ATen/native/vulkan/ops/Copy.cpp @@ -6,87 +6,96 @@ namespace vulkan { namespace ops { Tensor& copy_(Tensor& self, const Tensor& src) { - // X -> Vulkan - if (at::kVulkan == self.device().type()) { - vTensor& v_self = convert(self); - - // CPU -> Vulkan - if (at::kCPU == src.device().type()) { - // Requesting write-only host access to the tensor never triggers a sync - // as the contents will be overwritten regardless. Having said that, - // appropriate barriers are inserted automatically if WAR or WAW hazards - // are detected. Examples of such scenario for instance are if any of - // these async operations are on going in the background on 'self': - // - On discrete systems: - // * buffer-to-staging transfers - // * staging-to-buffer transfers - // - On UMA buffer is an alias for staging and accessible both on host - // and device. Consequently: - // * buffer-to-image NHWC -> NC4HW packing - // * image-to-buffer NC4HW -> NHWC unpacking - - using Future = vTensor::Future; - Future v_self_future = v_self.host(); - - // This wait() will be a no-op if no hazards are detected, including the - // obvious, yet important, special case of 'self' being an empty tensor. - - Future::Payload v_self_payload = v_self_future.wait(); - - memcpy( - v_self_payload.get(), - src.contiguous().data_ptr(), - std::min(src.nbytes(), self.nbytes())); + api::Context* const context = api::context(); + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + // X -> Vulkan + if (at::kVulkan == self.device().type()) { + vTensor& v_self = convert(self); + + // Vulkan -> Vulkan + if (at::kVulkan == src.device().type()) { + command_buffer.copy( + // - Read-only access is implied on const tensors. Memory barriers + // are automatically inserted if a RAW hazard is detected. + // - Recording any potential pending sync operations into the same + // command buffer prevents an expensive queue submission. + convert(src).buffer( + command_buffer, + vTensor::Stage::Transfer), + // - Write-only access never triggers a sync as the contents will be + // overwritten regardless. 
Having said that, appropriate barriers + // are inserted automatically if WAR or WAW hazards are detected. + // - Recording pending sync operations into the same command buffer + // prevents an expensive queue submission. + v_self.buffer( + command_buffer, + vTensor::Stage::Transfer, + vTensor::Access::Write)); + + command_pool.submit(context->gpu().queue, command_buffer); + } + // CPU -> Vulkan + else { + const Tensor cpu_src = src.device().is_cpu() ? src : src.cpu(); + + // Requesting write-only host access to the tensor never triggers a sync + // as the contents will be overwritten regardless. Having said that, + // appropriate barriers are inserted automatically if WAR or WAW hazards + // are detected. Examples of such scenario for instance are if any of + // these async operations are on going in the background on 'self': + // - On discrete systems: + // * buffer-to-staging transfers + // * staging-to-buffer transfers + // - On UMA buffer is an alias for staging and accessible both on host + // and device. Consequently: + // * buffer-to-image NHWC -> NC4HW packing + // * image-to-buffer NC4HW -> NHWC unpacking + + using Future = vTensor::Future; + Future v_self_future = v_self.host(command_buffer); + + // Ideally we would have been able to put as much distance between + // requesting the data - a call to host() - and accessing the data + // - a call to wait() - but a local view of the computation graph + // in eager mode makes that optimization non-trivial. + + // This wait() will be a no-op if no hazards are detected, including the + // obvious, yet important, special case of 'self' being an empty tensor. + + Future::Payload v_self_payload = v_self_future.wait(); + + memcpy( + v_self_payload.get(), + cpu_src.contiguous().data_ptr(), + std::min(src.nbytes(), self.nbytes())); + } } - // Vulkan -> Vulkan + // Vulkan -> X else if (at::kVulkan == src.device().type()) { - api::Command::Buffer command_buffer = api::context()->command().pool.allocate(); - command_buffer.begin(); - - command_buffer.copy( - // - Read-only access is implied on const tensors. Memory barriers - // are automatically inserted if a RAW hazard is detected. - // - Recording any potential pending sync operations into the same - // command buffer prevents an expensive queue submission. - convert(src).buffer( - command_buffer, - vTensor::Stage::Transfer), - // - Write-only access never triggers a sync as the contents will be - // overwritten regardless. Having said that, appropriate barriers - // are inserted automatically if WAR or WAW hazards are detected. - // - Recording pending sync operations into the same command buffer - // prevents an expensive queue submission. - v_self.buffer( - command_buffer, - vTensor::Stage::Transfer, - vTensor::Access::Write)); - - command_buffer.end(); - command_buffer.submit(api::context()->gpu().queue); - } - else { - TORCH_INTERNAL_ASSERT(false, "Unsupported!"); - } - } - // Vulkan -> X - else if (at::kVulkan == src.device().type()) { - const vTensor& v_src = convert(src); - - { - // Similar notes as above applies, with the additional consideration of - // potential syncs on read accesses. Namely, - // - on discrete systems, if the (staging, buffer, image) trio, or - // - on UMA, if the (buffer, image) duo - // have gone out of sync as a result of one processor writing to one - // resource which is then either accessed as an another resource type on - // the same or another processor. Same considerations regarding hazard - // avoidance as above applies. 
- - using Future = vTensor::Future; - const Future v_src_future = v_src.host(); + const vTensor& v_src = convert(src); // Vulkan -> CPU - if (at::kCPU == self.device().type()) { + if (self.device().is_cpu()) { + // Similar notes as above applies, with the additional consideration of + // potential syncs on read accesses. Namely, + // - on discrete systems, if the (staging, buffer, image) trio, or + // - on UMA, if the (buffer, image) duo + // have gone out of sync as a result of one processor writing to one + // resource which is then either accessed as an another resource type on + // the same or another processor. Same considerations regarding hazard + // avoidance as above applies. + + using Future = vTensor::Future; + const Future v_src_future = v_src.host(command_buffer); + + // Ideally we would have been able to put as much distance between + // requesting the data - a call to host() - and accessing the data + // - a call to wait() - but a local view of the computation graph + // in eager mode makes that optimization non-trivial. + // This wait() is a no-op if data is not out of sync. More often than // not though, waits here are expected as the GPU catches up with // compute submitted from CPU. @@ -99,51 +108,56 @@ Tensor& copy_(Tensor& self, const Tensor& src) { std::min(src.nbytes(), self.nbytes())); } else { - TORCH_INTERNAL_ASSERT(false, "Unsupported!"); + TORCH_CHECK(false, "Unsupported!"); } - } - // - // WARNING - // - - // This is not great. We almost never want to flush the GPU pipeline as - // that has far reaching consequences, especially if PyTorch is not the only - // process accessing the GPU. If we have done our job properly, above - // synchronization mechanisms should be enough to ensure correctness at a more - // modest cost, as there is no need to flush the entirety of jobs in flight - // if one is only interested on waiting on computation affecting one single - // tensor to finish. - // - // Having said that, we still do need to release all pool resources at one - // point per inference run or we will run out of memory otherwise. There is - // no perfect answer to this problem that checks all boxes, which leaves us - // with one of several design decisions: - // - // 1) Use graph mode to gain an understanding of the computation graph, - // itself allowing us to place pool purges intelligently. Best option - // for performance and memory consumption. Not without its downsides if - // flexibility is a top priority. - // 2) If on eager mode, and hence are seeing operations one at a time, expose - // this release of resources to the user as a Python / C++ function. This - // makes for suboptimal user experience but is efficient in terms of - // performance. - // 3) If on eager mode, and interested in keeping this bookkeeping transparent - // to the user, release all resources somewhere ... like here. This is - // not ideal since it requires a pipeline flush to make sure these objects - // are not already in use by a workload in flight. Cannot do much better - // within the constraints of this approach. Good for user experience, - // suboptimal for performance. - // 4) If on eager mode, and interested in keeping this bookkeeping transparent - // to the user, and performance does not matter, make CPU and GPU run in - // lockstep. Obviously this is just bad. Mentioned for the sake of - // completeness. - - api::context()->flush(); - } - else { - TORCH_INTERNAL_ASSERT(false, "Unsupported!"); + // + // WARNING + // + + // This is not great. 
We almost never want to flush the GPU pipeline as + // that has far reaching consequences, especially if PyTorch is not the only + // process accessing the GPU. If we have done our job properly, above + // synchronization mechanisms should be enough to ensure correctness at a more + // modest cost, as there is no need to flush the entirety of jobs in flight + // if one is only interested on waiting on computation affecting one single + // tensor to finish. + // + // Having said that, we still do need to release all pool resources at one + // point per inference run or we will run out of memory otherwise. There is + // no perfect answer to this problem that checks all boxes, which leaves us + // with one of several design decisions: + // + // 1) Use graph mode to gain an understanding of the computation graph, + // itself allowing us to place pool purges intelligently. Best option + // for performance and memory consumption. Not without its downsides if + // flexibility is a top priority. + // 2) If on eager mode, and hence are seeing operations one at a time, expose + // this release of resources to the user as a Python / C++ function. This + // makes for suboptimal user experience but is efficient in terms of + // performance. + // 3) If on eager mode, and interested in keeping this bookkeeping transparent + // to the user, release all resources somewhere ... like here. This is + // not ideal since it requires a pipeline flush to make sure these objects + // are not already in use by a workload in flight. Cannot do much better + // within the constraints of this approach. Good for user experience, + // suboptimal for performance. + // 4) If on eager mode, and interested in keeping this bookkeeping transparent + // to the user, and performance does not matter, make CPU and GPU run in + // lockstep. Obviously this is just bad. Mentioned for the sake of + // completeness. + + context->flush(); + } + else { + TORCH_INTERNAL_ASSERT( + false, + "Invalid code path taken! Either the source or the destination tensor " + "was expected to be Vulkan a tensor! Incorrect dispatch?"); + } } + // No queue submission here. All queue submissions must have been handled + // above either explicitly or as a result of calling tensor.host(). return self; } diff --git a/aten/src/ATen/native/vulkan/ops/Mean.cpp b/aten/src/ATen/native/vulkan/ops/Mean.cpp index f6d63c14f381..6a413f55ded5 100644 --- a/aten/src/ATen/native/vulkan/ops/Mean.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mean.cpp @@ -52,11 +52,11 @@ Tensor mean( v_input.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_input.has_image()) { - const struct { + if C10_LIKELY(v_input.has_image()) { + const struct Block final { uvec3 extents; int32_t range; ivec2 iextents; @@ -71,63 +71,35 @@ Tensor mean( }, }; - if (keepdim) { - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(mean), - v_output.extents(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. 
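A stylistic change repeated throughout these op files is wrapping the fast-path checks in C10_LIKELY. For reference, the macro's usual definition in c10/macros/Macros.h is essentially the following (reproduced from memory, so treat the exact spelling as approximate). Because the expansion carries its own parentheses, the bare form if C10_LIKELY(cond) used in these hunks parses as an ordinary if statement:

    #if defined(__GNUC__) || defined(__clang__)
    #define C10_LIKELY(expr)   (__builtin_expect(static_cast<bool>(expr), 1))
    #define C10_UNLIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 0))
    #else
    #define C10_LIKELY(expr)   (expr)
    #define C10_UNLIKELY(expr) (expr)
    #endif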
- v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(mean2d), - v_output.extents(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + keepdim ? VK_KERNEL(mean) : VK_KERNEL(mean2d), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_input.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); } else { TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Mm.cpp b/aten/src/ATen/native/vulkan/ops/Mm.cpp index fa1f7d45048c..e50f9724e665 100644 --- a/aten/src/ATen/native/vulkan/ops/Mm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mm.cpp @@ -10,18 +10,21 @@ namespace { using namespace api::utils; vTensor pack_weights( - api::Resource::Pool& pool, - const Tensor& weight_arg) { + api::Resource::Pool& pool, + const Tensor& weight_arg) { if (weight_arg.is_vulkan()) { return convert(weight_arg); } + api::Context* const context = api::context(); + api::Command::Buffer& command_buffer = context->command().pool.stream(); + const Tensor weight = weight_arg.contiguous(); const IntArrayRef w_sizes = weight.sizes(); const float* const src_weight_ptr = weight.data_ptr(); vTensor v_weight{ - api::context(), + context, &pool, w_sizes, weight.options(), @@ -29,7 +32,7 @@ vTensor pack_weights( { using Future = vTensor::Future; - Future v_weight_future = v_weight.host(); + Future v_weight_future = v_weight.host(command_buffer); Future::Payload v_weight_payload = v_weight_future.wait(); memcpy( @@ -49,16 +52,21 @@ vTensor pack_biases( return convert(*bias_arg); } + api::Context* const context = api::context(); + api::Command::Buffer& command_buffer = context->command().pool.stream(); + vTensor v_bias{ - api::context(), + context, &pool, - {weight_arg.sizes()[Layout::Parameter::width]}, + { + weight_arg.size(Layout::Parameter::width), + }, weight_arg.options(), }; { using Future = vTensor::Future; - Future v_bias_future = v_bias.host(); + Future v_bias_future = v_bias.host(command_buffer); Future::Payload v_bias_payload = v_bias_future.wait(); if 
(bias_arg) { @@ -66,7 +74,8 @@ vTensor pack_biases( v_bias_payload.get(), bias_arg->contiguous().data_ptr(), std::min(bias_arg->nbytes(), v_bias.nbytes())); - } else { + } + else { memset( v_bias_payload.get(), // 2's complement integers and IEEE-754 floating point numbers both @@ -162,11 +171,11 @@ Tensor mm( mat1.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_mat1.has_image() && v_mat2.has_image()) { - const struct { + if C10_LIKELY(v_mat1.has_image() && v_mat2.has_image()) { + const struct Block final { uvec3 size; int32_t K; } block { @@ -203,12 +212,12 @@ Tensor mm( // Object lifetime is managed by the resource pool. // It is OK not to keep track of the handle. context->resource().pool.uniform(block).object); - } else { + } + else { TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -281,14 +290,15 @@ Tensor LinearOpContext::run( input.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_output.has_image() && + if C10_LIKELY( + v_output.has_image() && v_input.has_image() && packed_.v_weight.has_image() && packed_.v_bias.has_image()) { - const struct { + const struct Block final { uvec3 size; int32_t K; vec2 multiplier; @@ -341,8 +351,7 @@ Tensor LinearOpContext::run( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Mul.cpp b/aten/src/ATen/native/vulkan/ops/Mul.cpp index 84226135929a..1e494287a5ae 100644 --- a/aten/src/ATen/native/vulkan/ops/Mul.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mul.cpp @@ -23,11 +23,11 @@ Tensor mul_scalar( v_self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_output.has_image() && v_self.has_image()) { - const struct { + if C10_LIKELY(v_output.has_image() && v_self.has_image()) { + const struct Block final { uvec3 extents; float other; } block { @@ -63,8 +63,7 @@ Tensor mul_scalar( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -80,11 +79,11 @@ Tensor& mul_scalar_( vTensor& v_self = convert(self); - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_self.has_image()) { - const struct { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { uvec3 extents; float other; } block { @@ -114,8 +113,7 @@ Tensor& mul_scalar_( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + 
command_pool.submit(context->gpu().queue, command_buffer); return self; } diff --git a/aten/src/ATen/native/vulkan/ops/Pool.cpp b/aten/src/ATen/native/vulkan/ops/Pool.cpp index 0bc97d6741bc..5eaaf9d04171 100644 --- a/aten/src/ATen/native/vulkan/ops/Pool.cpp +++ b/aten/src/ATen/native/vulkan/ops/Pool.cpp @@ -33,10 +33,10 @@ Tensor adaptive_avg_pool2d( v_self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_self.has_image()) { + if C10_LIKELY(v_self.has_image()) { const uvec3 v_output_size = v_output.extents(); const uvec3 v_self_size = v_self.extents(); @@ -45,7 +45,7 @@ Tensor adaptive_avg_pool2d( static_cast(v_self_size.data[1u]) / v_output_size.data[1u], }; - const struct { + const struct Block final { uvec3 size; uint32_t _; vec2 stride; @@ -88,8 +88,7 @@ Tensor adaptive_avg_pool2d( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -171,13 +170,11 @@ Tensor avg_pool2d( v_self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - using namespace utils; - - if (v_self.has_image()) { - const struct { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { uvec3 extents; int32_t range; ivec2 iextents; @@ -235,8 +232,7 @@ Tensor avg_pool2d( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Shape.cpp b/aten/src/ATen/native/vulkan/ops/Shape.cpp index 8edfda60b76f..9d2a248f0707 100644 --- a/aten/src/ATen/native/vulkan/ops/Shape.cpp +++ b/aten/src/ATen/native/vulkan/ops/Shape.cpp @@ -21,8 +21,8 @@ Tensor view( self.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { command_buffer.copy( // Read-only access is implied on const tensors and triggers an async @@ -37,8 +37,7 @@ Tensor view( vTensor::Stage::Transfer, vTensor::Access::Write)); } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.cpp b/aten/src/ATen/native/vulkan/ops/Tensor.cpp index 3570834b8dd4..0bf7acbe7dee 100644 --- a/aten/src/ATen/native/vulkan/ops/Tensor.cpp +++ b/aten/src/ATen/native/vulkan/ops/Tensor.cpp @@ -419,31 +419,19 @@ vTensor::vTensor( }) { } -const vTensor* vTensor::host() const { - view_->staging(Stage::Host, Access::Read); +const vTensor* vTensor::host( + api::Command::Buffer& command_buffer) const { + view_->staging(command_buffer, Stage::Host, Access::Read); return this; } -vTensor* vTensor::host(const Access::Flags access) { - view_->staging(Stage::Host, access); +vTensor* vTensor::host( + api::Command::Buffer& command_buffer, + const Access::Flags access) { + view_->staging(command_buffer, Stage::Host, access); return 
this; } -vTensor::Buffer::Object vTensor::buffer( - const Stage::Flags stage) const & { - return view_->buffer( - stage, - Access::Read).object; -} - -vTensor::Buffer::Object vTensor::buffer( - const Stage::Flags stage, - const Access::Flags access) & { - return view_->buffer( - stage, - access).object; -} - vTensor::Buffer::Object vTensor::buffer( api::Command::Buffer& command_buffer, const Stage::Flags stage) const & { @@ -463,21 +451,6 @@ vTensor::Buffer::Object vTensor::buffer( access).object; } -vTensor::Image::Object vTensor::image( - const Stage::Flags stage) const & { - return view_->image( - stage, - Access::Read).object; -} - -vTensor::Image::Object vTensor::image( - const Stage::Flags stage, - const Access::Flags access) & { - return view_->image( - stage, - access).object; -} - vTensor::Image::Object vTensor::image( api::Command::Buffer& command_buffer, const Stage::Flags stage) const & { @@ -535,16 +508,8 @@ vTensor::View::View( ops::verify(options); } -// We typically do not know whether we need a command buffer to service a request -// until we have perfomed a bunch of checks in nested logic, and even then we -// may end up with the always issued state transition optimized away under -// certain conditions, which makes a policy of always allocating a command buffer -// up front, only to end up using it at times, a wasteful approach. This class -// answers that need. - class vTensor::View::CMD final { public: - explicit CMD(const View&); CMD(const View&, api::Command::Buffer&); CMD(const CMD&) = delete; CMD& operator=(const CMD&) = delete; @@ -578,60 +543,18 @@ class vTensor::View::CMD final { const Image::Object& image, Buffer::Object& buffer); - void submit(Fence fence = {}); - - private: - api::Command::Buffer& command_buffer(); + void submit(Fence fence); private: const View& view_; - - enum class Type { - Internal, - External, - } type; - - union _ final { - api::Command::Buffer internal; - api::Command::Buffer* external; - ~_() {} - } command_buffer_; + api::Command::Buffer& command_buffer_; }; -vTensor::View::CMD::CMD( - const View& view) - : view_(view), - type(Type::Internal), - command_buffer_{} { -} - vTensor::View::CMD::CMD( const View& view, - api::Command::Buffer& external) + api::Command::Buffer& command_buffer) : view_(view), - type(Type::External), - command_buffer_{ - .external = &external, - } { -} - -api::Command::Buffer& vTensor::View::CMD::command_buffer() { - switch (type) { - case Type::Internal: - if (!command_buffer_.internal) { - command_buffer_.internal = view_.context_->command().pool.allocate(); - command_buffer_.internal.begin(); - } - - return command_buffer_.internal; - - case Type::External: - return *(command_buffer_.external); - - default: - TORCH_INTERNAL_ASSERT(false, "Unknown command buffer type!"); - break; - } + command_buffer_(command_buffer) { } void vTensor::View::CMD::barrier(State::Transition transition) { @@ -761,7 +684,7 @@ void vTensor::View::CMD::barrier(State::Transition transition) { barrier.stage.src = VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT; } - command_buffer().barrier(barrier); + command_buffer_.barrier(barrier); } } @@ -789,7 +712,7 @@ void vTensor::View::CMD::copy_buffer_to_staging( {}, })); - command_buffer().copy(buffer, staging); + command_buffer_.copy(buffer, staging); } void vTensor::View::CMD::copy_staging_to_buffer( @@ -816,7 +739,7 @@ void vTensor::View::CMD::copy_staging_to_buffer( {}, })); - command_buffer().copy(staging, buffer); + command_buffer_.copy(staging, buffer); } void 
vTensor::View::CMD::copy_buffer_to_image( @@ -847,7 +770,7 @@ void vTensor::View::CMD::copy_buffer_to_image( const uvec3 extents = view_.extents(); const uint32_t plane = extents.data[0u] * extents.data[1u]; - const struct { + const struct Block final { uvec3 extents; uint32_t block; uvec4 offset; @@ -863,7 +786,7 @@ void vTensor::View::CMD::copy_buffer_to_image( }; view_.context_->dispatch( - command_buffer(), + command_buffer_, { VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, @@ -904,7 +827,7 @@ void vTensor::View::CMD::copy_image_to_buffer( const uvec3 extents = view_.extents(); const uint32_t plane = extents.data[0u] * extents.data[1u]; - const struct { + const struct Block final { uvec3 extents; uint32_t block; uvec4 offset; @@ -920,7 +843,7 @@ void vTensor::View::CMD::copy_image_to_buffer( }; view_.context_->dispatch( - command_buffer(), + command_buffer_, { VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, @@ -934,10 +857,10 @@ void vTensor::View::CMD::copy_image_to_buffer( } void vTensor::View::CMD::submit(const api::Resource::Fence fence) { - if ((Type::Internal == type) && command_buffer_.internal) { - command_buffer_.internal.end(); - command_buffer_.internal.submit(view_.context_->gpu().queue, fence); - } + view_.context_->command().pool.submit( + view_.context_->gpu().queue, + command_buffer_, + fence); } vTensor::Buffer& vTensor::View::buffer() const { @@ -953,38 +876,28 @@ vTensor::Buffer& vTensor::View::buffer() const { } vTensor::Buffer& vTensor::View::buffer( + api::Command::Buffer& command_buffer, const Stage::Flags stage, const Access::Flags access) const { - CMD command_buffer(*this); - Buffer& buffer = this->buffer(command_buffer, stage, access); - command_buffer.submit(); - - return buffer; -} - -vTensor::Buffer& vTensor::View::buffer( - api::Command::Buffer& command_buffer_, - const Stage::Flags stage, - const Access::Flags access) const { - CMD command_buffer(*this, command_buffer_); - return buffer(command_buffer, stage, access); + CMD cmd(*this, command_buffer); + return buffer(cmd, stage, access); } vTensor::Buffer& vTensor::View::buffer( - CMD& command_buffer, + CMD& cmd, const Stage::Flags stage, const Access::Flags access) const { if ((access & Access::Read) && state_.is_dirty(Component::Buffer)) { if (state_.is_clean(Component::Staging)) { - command_buffer.copy_staging_to_buffer( + cmd.copy_staging_to_buffer( state_, - staging(command_buffer, Stage::Transfer, Access::Read).object, + staging(cmd, Stage::Transfer, Access::Read).object, buffer().object); } else if (state_.is_clean(Component::Image)) { - command_buffer.copy_image_to_buffer( + cmd.copy_image_to_buffer( state_, - image(command_buffer, Stage::Compute, Access::Read).object, + image(cmd, Stage::Compute, Access::Read).object, buffer().object); } else { @@ -994,7 +907,7 @@ vTensor::Buffer& vTensor::View::buffer( } } - command_buffer.barrier( + cmd.barrier( state_.transition({ // Staging {}, @@ -1028,35 +941,25 @@ vTensor::Image& vTensor::View::image() const { } vTensor::Image& vTensor::View::image( + api::Command::Buffer& command_buffer, const Stage::Flags stage, const Access::Flags access) const { - CMD command_buffer(*this); - Image& image = this->image(command_buffer, stage, access); - command_buffer.submit(); - - return image; -} - -vTensor::Image& vTensor::View::image( - api::Command::Buffer& command_buffer_, - const Stage::Flags stage, - const Access::Flags access) const { - CMD command_buffer(*this, command_buffer_); - return 
image(command_buffer, stage, access); + CMD cmd(*this, command_buffer); + return image(cmd, stage, access); } vTensor::Image& vTensor::View::image( - CMD& command_buffer, + CMD& cmd, const Stage::Flags stage, const Access::Flags access) const { if ((access & Access::Read) && state_.is_dirty(Component::Image)) { - command_buffer.copy_buffer_to_image( + cmd.copy_buffer_to_image( state_, - buffer(command_buffer, stage, Access::Read).object, + buffer(cmd, stage, Access::Read).object, image().object); } - command_buffer.barrier( + cmd.barrier( state_.transition({ // Staging {}, @@ -1096,27 +999,28 @@ vTensor::Buffer& vTensor::View::staging() const { } vTensor::Buffer& vTensor::View::staging( + api::Command::Buffer& command_buffer, const Stage::Flags stage, const Access::Flags access) const { - CMD command_buffer(*this); - Buffer& staging = this->staging(command_buffer, stage, access); - command_buffer.submit(fence()); + CMD cmd(*this, command_buffer); + Buffer& staging = this->staging(cmd, stage, access); + cmd.submit(fence(access)); return staging; } vTensor::Buffer& vTensor::View::staging( - CMD& command_buffer, + CMD& cmd, const Stage::Flags stage, const Access::Flags access) const { if ((access & Access::Read) && state_.is_dirty(Component::Staging)) { - command_buffer.copy_buffer_to_staging( + cmd.copy_buffer_to_staging( state_, - buffer(command_buffer, Stage::Transfer, Access::Read).object, + buffer(cmd, Stage::Transfer, Access::Read).object, staging().object); } - command_buffer.barrier( + cmd.barrier( state_.transition({ // Staging { @@ -1138,6 +1042,14 @@ vTensor::Buffer& vTensor::View::staging( return staging(); } +vTensor::Fence& vTensor::View::fence(const Access::Flags access) const { + if (access & Access::Read) { + fence_ = allocate_fence(&context_->resource().pool); + } + + return fence_; +} + vTensor::Memory& vTensor::View::wait() const { if (fence_) { fence_.wait(); @@ -1146,10 +1058,6 @@ vTensor::Memory& vTensor::View::wait() const { return staging().memory; } -vTensor::Fence& vTensor::View::fence() const { - return (fence_ = allocate_fence(pool_)); -} - void vTensor::View::verify() const { TORCH_INTERNAL_ASSERT(!image_ || state_.is_available(Component::Image)); TORCH_INTERNAL_ASSERT(!staging_ || state_.is_discrete()); diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.h b/aten/src/ATen/native/vulkan/ops/Tensor.h index 48d4cca84dd4..f404988b420b 100644 --- a/aten/src/ATen/native/vulkan/ops/Tensor.h +++ b/aten/src/ATen/native/vulkan/ops/Tensor.h @@ -157,10 +157,10 @@ class vTensor final { */ template - Future host() const &; + Future host(api::Command::Buffer&) const &; template - Future host() &; + Future host(api::Command::Buffer&) &; /* Device access - these functions will be expensive if they trigger a buffer @@ -178,14 +178,10 @@ class vTensor final { predictability of usage and efficiency. 
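    For orientation, every device access after this change names three things
    explicitly: the command buffer it is recorded into, the pipeline stage,
    and (for mutable tensors) the access mode. The call shapes, taken from the
    ops in this diff and shown here only as a sketch, look like:

      v_output.image(command_buffer, Stage::Compute, Access::Write)
      v_input.image(command_buffer, Stage::Compute)   // read-only is implied
      v_self.buffer(command_buffer, Stage::Transfer, Access::Write)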
*/ - Buffer::Object buffer(Stage::Flags) const &; - Buffer::Object buffer(Stage::Flags, Access::Flags) &; Buffer::Object buffer(api::Command::Buffer&, Stage::Flags) const &; Buffer::Object buffer(api::Command::Buffer&, Stage::Flags, Access::Flags) &; bool has_image() const; - Image::Object image(Stage::Flags) const &; - Image::Object image(Stage::Flags, Access::Flags) &; Image::Object image(api::Command::Buffer&, Stage::Flags) const &; Image::Object image(api::Command::Buffer&, Stage::Flags, Access::Flags) &; @@ -210,26 +206,22 @@ class vTensor final { Host */ - const vTensor* host() const; - vTensor* host(Access::Flags access); + const vTensor* host(api::Command::Buffer&) const; + vTensor* host(api::Command::Buffer&, Access::Flags); template - Future host() const && = delete; + Future host(api::Command::Buffer&) const && = delete; template - Future host() && = delete; + Future host(api::Command::Buffer&) && = delete; /* Device */ - Buffer::Object buffer(Stage::Flags) const && = delete; - Buffer::Object buffer(Stage::Flags, Access::Flags) && = delete; Buffer::Object buffer(api::Command::Buffer&, Stage::Flags) const && = delete; Buffer::Object buffer(api::Command::Buffer&, Stage::Flags, Access::Flags) && = delete; - Image::Object image(Stage::Flags) const && = delete; - Image::Object image(Stage::Flags, Access::Flags) && = delete; Image::Object image(api::Command::Buffer&, Stage::Flags) const && = delete; Image::Object image(api::Command::Buffer&, Stage::Flags, Access::Flags) && = delete; @@ -249,21 +241,22 @@ class vTensor final { ~View() = default; /* - Device + Buffer */ - Buffer& buffer(Stage::Flags, Access::Flags) const; Buffer& buffer(api::Command::Buffer&, Stage::Flags, Access::Flags) const; + /* + Image + */ + bool has_image() const; - Image& image(Stage::Flags, Access::Flags) const; Image& image(api::Command::Buffer&, Stage::Flags, Access::Flags) const; /* Host */ - Buffer& staging(Stage::Flags, Access::Flags) const; Buffer& staging(api::Command::Buffer&, Stage::Flags, Access::Flags) const; vTensor::Memory& wait() const; @@ -343,7 +336,7 @@ class vTensor final { Image& image(CMD&, Stage::Flags, Access::Flags) const; Buffer& staging() const; Buffer& staging(CMD&, Stage::Flags, Access::Flags) const; - Fence& fence() const; + Fence& fence(Access::Flags) const; // Validation void verify() const; @@ -485,13 +478,15 @@ vTensor::Future::wait() const & { } template -inline vTensor::Future vTensor::host() const & { - return Future(host()); +inline vTensor::Future +vTensor::host(api::Command::Buffer& command_buffer) const & { + return Future(host(command_buffer)); } template -inline vTensor::Future vTensor::host() & { - return Future(host(kAccess)); +inline vTensor::Future +vTensor::host(api::Command::Buffer& command_buffer) & { + return Future(host(command_buffer, kAccess)); } inline bool vTensor::has_image() const { diff --git a/aten/src/ATen/native/vulkan/ops/Upsample.cpp b/aten/src/ATen/native/vulkan/ops/Upsample.cpp index 32508c01eec1..00cefc1bdf53 100644 --- a/aten/src/ATen/native/vulkan/ops/Upsample.cpp +++ b/aten/src/ATen/native/vulkan/ops/Upsample.cpp @@ -36,11 +36,11 @@ Tensor upsample_nearest2d( input.options(), }; - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); { - if (v_input.has_image()) { - const struct { + if C10_LIKELY(v_input.has_image()) { + const struct Block final { uvec3 extents; uint32_t _; 
ivec2 iextents; @@ -92,8 +92,7 @@ Tensor upsample_nearest2d( TORCH_CHECK(false, "Not implemented!"); } } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); + command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Utils.h b/aten/src/ATen/native/vulkan/ops/Utils.h index ffdc2b6e94eb..de218cfc472a 100644 --- a/aten/src/ATen/native/vulkan/ops/Utils.h +++ b/aten/src/ATen/native/vulkan/ops/Utils.h @@ -10,7 +10,7 @@ namespace vulkan { namespace ops { namespace utils { -int64_t normalize( +inline int64_t normalize( const int64_t dimension, const int64_t n) { return (dimension % n + n) % n; diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 5b4a4f3b83e6..98a2f74e9399 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -67,7 +67,6 @@ TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta data_type, TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta data_type, c10::optional device_opt) : storage_(std::move(storage)), - sizes_{0}, storage_offset_(0), numel_(0), data_type_(data_type), @@ -91,15 +90,14 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2:: // we would also like to check that non-cpu devices have an index, but some Caffe2 operators create // Storages with default devices. - strides_.push_back(1); } IntArrayRef TensorImpl::sizes() const { - return sizes_; + return sizes_and_strides_.sizes_arrayref(); } IntArrayRef TensorImpl::strides() const { - return strides_; + return sizes_and_strides_.strides_arrayref(); } bool TensorImpl::compute_contiguous() const { @@ -108,9 +106,10 @@ bool TensorImpl::compute_contiguous() const { return is_contiguous; int64_t z = 1; for (int64_t d = dim() - 1; d >= 0; d--) { - if (sizes_[d] != 1) { - if (strides_[d] == z) { - z *= sizes_[d]; + const auto size_d = sizes_and_strides_.size_at_unchecked(d); + if (size_d != 1) { + if (sizes_and_strides_.stride_at_unchecked(d) == z) { + z *= size_d; } else { is_contiguous = false; break; @@ -123,16 +122,17 @@ bool TensorImpl::compute_contiguous() const { bool TensorImpl::compute_channels_last_contiguous_2d() const { // Please don't combine these code, constant array is used here to let // compiler fully unroll the loop to get better performance - switch (sizes_.size()) { + switch (sizes_and_strides_.size()) { case 4: { int64_t expected = 1; for (auto& d : {1, 3, 2, 0}) { - if (sizes_[d] != 1) { - if (strides_[d] != expected) { + const auto size_d = sizes_and_strides_.size_at_unchecked(d); + if (size_d != 1) { + if (sizes_and_strides_.stride_at_unchecked(d) != expected) { return false; } - expected *= sizes_[d]; + expected *= size_d; } } return true; @@ -148,16 +148,17 @@ bool TensorImpl::compute_channels_last_contiguous_2d() const { bool TensorImpl::compute_channels_last_contiguous_3d() const { // Please don't combine these code, constant array is used here to let // compiler fully unroll the loop to get better performance - switch (sizes_.size()) { + switch (sizes_and_strides_.size()) { case 5: { int64_t expected = 1; for (auto& d : {1, 4, 3, 2, 0}) { - if (sizes_[d] != 1) { - if (strides_[d] != expected) { + const auto size_d = sizes_and_strides_.size_at_unchecked(d); + if (size_d != 1) { + if (sizes_and_strides_.stride_at_unchecked(d) != expected) { return false; } - expected *= sizes_[d]; + expected *= size_d; } } return true; @@ -171,16 +172,16 @@ bool TensorImpl::compute_channels_last_contiguous_3d() 
const { } bool TensorImpl::compute_strides_like_channels_last_2d() const { - return is_channels_last_strides_2d(sizes_, strides_); + return is_channels_last_strides_2d(TensorImpl::sizes(), TensorImpl::strides()); } bool TensorImpl::compute_strides_like_channels_last_3d() const { - return is_channels_last_strides_3d(sizes_, strides_); + return is_channels_last_strides_3d(TensorImpl::sizes(), TensorImpl::strides()); } bool TensorImpl::compute_non_overlapping_and_dense() const { if (dim() == 1) { - return sizes_[0] < 2 || strides_[0] == 1; + return sizes_and_strides_.size_at_unchecked(0) < 2 || sizes_and_strides_.stride_at_unchecked(0) == 1; } SmallVector perm; perm.resize(dim()); @@ -189,22 +190,23 @@ bool TensorImpl::compute_non_overlapping_and_dense() const { } // Sort by strides, leaving 0 and 1 sized dims at the end of the array std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) { - if (sizes_[a] < 2) { + if (sizes_and_strides_.size_at_unchecked(a) < 2) { return false; - } else if (sizes_[b] < 2) { + } else if (sizes_and_strides_.size_at_unchecked(b) < 2) { return true; } - return strides_[a] < strides_[b]; + return sizes_and_strides_.stride_at_unchecked(a) < sizes_and_strides_.stride_at_unchecked(b); }); auto require_stride = 1; for (int64_t i = 0; i < dim(); i ++) { - if (sizes_[perm[i]] < 2) { + const auto size_perm_i = sizes_and_strides_.size_at_unchecked(perm[i]); + if (size_perm_i < 2) { return true; } - if (strides_[perm[i]] != require_stride) { + if (sizes_and_strides_.stride_at_unchecked(perm[i]) != require_stride) { return false; } - require_stride *= sizes_[perm[i]]; + require_stride *= size_perm_i; } return true; } @@ -217,17 +219,17 @@ void TensorImpl::release_resources() { } int64_t TensorImpl::dim() const { - return sizes_.size(); + return sizes_and_strides_.size(); } int64_t TensorImpl::size(int64_t d) const { d = at::maybe_wrap_dim(d, dim(), false); - return sizes_[d]; + return sizes_and_strides_.size_at_unchecked(d); } int64_t TensorImpl::stride(int64_t d) const { d = at::maybe_wrap_dim(d, dim(), false); - return strides_[d]; + return sizes_and_strides_.stride_at_unchecked(d); } bool TensorImpl::has_storage() const { @@ -337,8 +339,7 @@ void TensorImpl::copy_tensor_metadata_except_version_counter( TensorImpl* dest_impl, bool allow_tensor_metadata_change) { dest_impl->storage_ = src_impl->storage_; - dest_impl->sizes_ = src_impl->sizes_; - dest_impl->strides_ = src_impl->strides_; + dest_impl->sizes_and_strides_ = src_impl->sizes_and_strides_; dest_impl->storage_offset_ = src_impl->storage_offset_; dest_impl->data_type_ = src_impl->data_type_; dest_impl->device_opt_ = src_impl->device_opt_; diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index e7f9c1260263..d8b803f906f7 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -10,8 +11,10 @@ #include #include #include +#include #include + #include #include #include @@ -244,6 +247,21 @@ struct C10_API VariableVersion { } }; +/** + * NOTE: Some TensorImpl methods are small and not overridden in the + * PyTorch codebase itself, but may theoretically need to be + * overridden by third-party TensorImpl subclasses. This macro allows + * users that need maximum performance and don't need these extension + * points to disable them with a build-time flag. (In particular, + * XLA's XLATensorImpl currently overrides these methods, so we can't + * enable this flag by default.) 
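 *
 * Illustratively, with the macro defined below, a declaration such as
 *
 *   TENSORIMPL_MAYBE_VIRTUAL int64_t numel() const;
 *
 * (see numel() further down in this header) is virtual in the default build,
 * and becomes a plain non-virtual, inlineable method when the build defines
 * C10_DISABLE_TENSORIMPL_EXTENSIBILITY, e.g. via a -D compile definition.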
+ */ +#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY +#define TENSORIMPL_MAYBE_VIRTUAL +#else +#define TENSORIMPL_MAYBE_VIRTUAL virtual +#endif + /** * The low-level representation of a tensor, which contains a pointer * to a storage (which contains the actual data) and metadata (e.g., sizes and @@ -412,7 +430,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * is no longer true; numel always accurately reports the product * of sizes of a tensor. */ - virtual int64_t numel() const { + TENSORIMPL_MAYBE_VIRTUAL int64_t numel() const { #ifdef DEBUG TORCH_INTERNAL_ASSERT(compute_numel() == numel_); #endif @@ -564,9 +582,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * Set whether or not a tensor requires gradient. - * - * It is only valid to call this method on a Variable. - * See Note [Tensor versus Variable in C++]. */ void set_requires_grad(bool requires_grad); @@ -576,27 +591,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * we can automatically differentiate back to them. A tensor that * requires gradient and has no history is a "leaf" tensor, which we * accumulate gradients into. - * - * It is only valid to call this method on a Variable. - * See Note [Tensor versus Variable in C++]. */ bool requires_grad() const; /** * Return a mutable reference to the gradient. This is conventionally * used as `t.grad() = x` to set a gradient to a completely new tensor. - * - * It is only valid to call this method on a Variable. - * See Note [Tensor versus Variable in C++]. */ at::Tensor& mutable_grad(); /** * Return the accumulated gradient of a tensor. This gradient is written * into when performing backwards, when this tensor is a leaf tensor. - * - * It is only valid to call this method on a Variable. - * See Note [Tensor versus Variable in C++]. 
*/ const at::Tensor& grad() const; @@ -747,7 +753,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ virtual void set_size(int64_t dim, int64_t new_size) { TORCH_CHECK(allow_tensor_metadata_change(), "set_size ", err_msg_tensor_metadata_change_not_allowed); - sizes_.at(dim) = new_size; + sizes_and_strides_.size_at(dim) = new_size; refresh_numel(); refresh_contiguous(); } @@ -760,7 +766,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ virtual void set_stride(int64_t dim, int64_t new_stride) { TORCH_CHECK(allow_tensor_metadata_change(), "set_stride ", err_msg_tensor_metadata_change_not_allowed); - strides_[dim] = new_stride; + sizes_and_strides_.stride_at_unchecked(dim) = new_stride; refresh_contiguous(); } @@ -785,12 +791,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ void set_sizes_contiguous(IntArrayRef new_size) { TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_contiguous ", err_msg_tensor_metadata_change_not_allowed); - auto new_dim = new_size.size(); - sizes_.resize(new_dim); - for (size_t dim = 0; dim < new_dim; ++dim) { - sizes_[dim] = new_size[dim]; - } + sizes_and_strides_.set_sizes(new_size); refresh_numel(); empty_tensor_restride(MemoryFormat::Contiguous); @@ -812,27 +814,25 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { ") must match dimensionality of strides (", new_stride.size(), ")"); - auto new_dim = new_size.size(); + const auto new_dim = new_size.size(); - sizes_.resize(new_dim); - for (size_t dim = 0; dim < new_dim; ++dim) { - sizes_[dim] = new_size[dim]; - } + sizes_and_strides_.set_sizes(new_size); - strides_.resize(new_dim); if (new_dim > 0) { for (size_t dim = new_dim - 1; ; dim--) { if (new_stride[dim] >= 0) { - strides_[dim] = new_stride[dim]; + sizes_and_strides_.stride_at_unchecked(dim) = new_stride[dim]; } else { // XXX: This behavior is surprising and may need to be removed to // support negative strides. Some pytorch functions rely on it: // for example, torch.cat (run TestTorch.test_cat_empty). if (dim == new_dim - 1) { - strides_[dim] = 1; + sizes_and_strides_.stride_at_unchecked(dim) = 1; } else { // Keep stride monotonically increasing to match NumPy. - strides_[dim] = std::max(sizes_[dim + 1], 1) * strides_[dim + 1]; + sizes_and_strides_.stride_at_unchecked(dim) = + std::max(sizes_and_strides_.size_at_unchecked(dim + 1), 1) * + sizes_and_strides_.stride_at_unchecked(dim + 1); } } if (dim == 0) break; @@ -1060,12 +1060,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * This op is auto-asynchronous if the underlying device (CUDA) supports it. 
*/ void Extend(int64_t num, float growthPct) { - TORCH_CHECK(sizes_.size() >= 1u); + TORCH_CHECK(sizes_and_strides_.size() >= 1u); TORCH_CHECK(num >= 0, "`num` must be non-negative for Extend"); TORCH_CHECK( is_contiguous_, "Right now Extend is only supported for contiguous Tensor."); - auto newDims = sizes_; + using SizesVector = SmallVector; + SizesVector newDims(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); newDims[0] += num; if (!storage_.data()) { Resize(newDims); @@ -1077,16 +1078,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { static_cast(1), std::multiplies()); if (newNumel * data_type_.itemsize() <= storage_.nbytes()) { - sizes_ = newDims; + sizes_and_strides_.set_sizes(newDims); numel_ = newNumel; return; } - auto newCapacity = sizes_; + SizesVector newCapacity(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); newCapacity[0] = std::max( - newDims[0], static_cast(std::ceil(sizes_[0] * (1 + growthPct / 100)))); + newDims[0], static_cast(std::ceil(sizes_and_strides_.size_at_unchecked(0) * (1 + growthPct / 100)))); auto oldData = std::move(storage_.data_ptr()); auto oldSize = numel_; - auto oldDims = sizes_; Resize(newCapacity); auto* newData = raw_mutable_data(data_type_); if (data_type_.copy()) { @@ -1113,7 +1113,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { true); // non-blocking } reserved_ = true; - sizes_ = newDims; + sizes_and_strides_.set_sizes(newDims); numel_ = newNumel; } @@ -1130,7 +1130,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { "Right now ReserveSpace is only supported for contiguous Tensor."); TORCH_CHECK( storage_.unique(), "Can't call ReserveSpace on shared storage."); - auto newCapacity = sizes_; + // TODO: eliminate newCapacity. + SmallVector newCapacity(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); newCapacity[0] = outer_dim; auto newNumel = std::accumulate( newCapacity.begin(), @@ -1143,11 +1144,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // Old data is discarded storage_.data_ptr().clear(); auto oldSize = numel_; - auto oldDims = sizes_; + SmallVector oldDims(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); Resize(newCapacity); // Allocate new memory but don't copy over the data raw_mutable_data(data_type_); - sizes_ = oldDims; + sizes_and_strides_.set_sizes(oldDims); numel_ = oldSize; reserved_ = true; } @@ -1217,7 +1218,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { " The old caffe2 mixes Reshape and Resize but this behavior has " "been changed. 
If you find this error, most likely you will need " "to change corresponding code from Reshape to Resize."); - sizes_ = dims; + sizes_and_strides_.set_sizes(dims); empty_tensor_restride(MemoryFormat::Contiguous); } @@ -1432,12 +1433,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { case MemoryFormat::Contiguous: { // dim_ is a virtual call, don't repeat it const auto dim_ = dim(); - strides_.resize(dim_); + sizes_and_strides_.resize(dim_); if (dim_ > 0) { const auto last_idx = dim_ - 1; - strides_[last_idx] = 1; + sizes_and_strides_.stride_at_unchecked(last_idx) = 1; for (auto i = last_idx - 1; i >= 0; --i) { - strides_[i] = strides_[i + 1] * std::max(sizes_[i + 1], 1); + sizes_and_strides_.stride_at_unchecked(i) = sizes_and_strides_.stride_at_unchecked(i + 1) * std::max(sizes_and_strides_.size_at_unchecked(i + 1), 1); } } break; @@ -1495,11 +1496,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { typename = typename std::enable_if::value>::type> bool SetDimsTemplate(ArrayRef src) { auto old_numel = numel_; - sizes_.resize(src.size()); + sizes_and_strides_.resize(src.size()); int64_t new_numel = 1; for (size_t i = 0; i < src.size(); ++i) { new_numel *= src[i]; - sizes_[i] = src[i]; + sizes_and_strides_.size_at_unchecked(i) = src[i]; } numel_ = new_numel; empty_tensor_restride(MemoryFormat::Contiguous); @@ -1696,12 +1697,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp PyObject* pyobj_ = nullptr; - // We could save a word or two by combining the SmallVector structs, - // since their size is redundant, and if we need to overflow the buffer space - // we could keep the two pointers together. However, that would require - // implementing another struct from scratch, so only do this if we're desperate. - SmallVector sizes_; - SmallVector strides_; + c10::impl::SizesAndStrides sizes_and_strides_; int64_t storage_offset_ = 0; // If sizes and strides are empty, the numel is 1!! However, most of the @@ -1833,22 +1829,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // autograd metadata pointer // version counter pointer // PyObject pointer -// sizes SmallVector (begin) -// sizes SmallVector (end) -// sizes SmallVector (capacity) -// sizes SmallVector (pre-allocated 0) -// sizes SmallVector (pre-allocated 1) -// sizes SmallVector (pre-allocated 2) -// sizes SmallVector (pre-allocated 3) -// sizes SmallVector (pre-allocated 4) -// strides SmallVector (begin) -// strides SmallVector (end) -// strides SmallVector (capacity) -// strides SmallVector (pre-allocated 0) -// strides SmallVector (pre-allocated 1) -// strides SmallVector (pre-allocated 2) -// strides SmallVector (pre-allocated 3) -// strides SmallVector (pre-allocated 4) +// SizesAndStrides size/pointer +// SizesAndStrides sizes (pre-allocated 0) +// SizesAndStrides sizes (pre-allocated 1) +// SizesAndStrides sizes (pre-allocated 2) +// SizesAndStrides sizes (pre-allocated 3) +// SizesAndStrides sizes (pre-allocated 4) +// SizesAndStrides strides (pre-allocated 0) +// SizesAndStrides strides (pre-allocated 1) +// SizesAndStrides strides (pre-allocated 2) +// SizesAndStrides strides (pre-allocated 3) +// SizesAndStrides strides (pre-allocated 4) // storage offset // numel // data type @@ -1857,7 +1848,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // miscellaneous bitfield // static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit... 
- sizeof(TensorImpl) == sizeof(int64_t) * 29, + sizeof(TensorImpl) == sizeof(int64_t) * 24, "You changed the size of TensorImpl on 64-bit arch." "See Note [TensorImpl size constraints] on how to proceed."); } // namespace c10 diff --git a/c10/core/impl/SizesAndStrides.cpp b/c10/core/impl/SizesAndStrides.cpp new file mode 100644 index 000000000000..bf7ec3ff887d --- /dev/null +++ b/c10/core/impl/SizesAndStrides.cpp @@ -0,0 +1,66 @@ +#include + +namespace c10 { +namespace impl { + +void SizesAndStrides::resizeSlowPath(const size_t newSize, const size_t oldSize) { + if (newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline(), "resizeSlowPath called when fast path should have been hit!"); + int64_t* tempStorage = outOfLineStorage_; + memcpy( + &inlineStorage_[0], + &tempStorage[0], + C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0])); + memcpy( + &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], + &tempStorage[oldSize], + C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0])); + // CANNOT USE freeOutOfLineStorage() HERE! outOfLineStorage_ + // HAS BEEN OVERWRITTEN! + free(tempStorage); + } else { + if (isInline()) { + // CANNOT USE allocateOutOfLineStorage(newSize) HERE! WOULD + // OVERWRITE inlineStorage_! + int64_t* tempStorage = static_cast(malloc(storageBytes(newSize))); + TORCH_CHECK(tempStorage, "Could not allocate memory to change Tensor SizesAndStrides!"); + const auto bytesToCopy = oldSize * sizeof(inlineStorage_[0]); + const auto bytesToZero = (newSize > oldSize) ? (newSize - oldSize) * sizeof(tempStorage[0]) : 0; + memcpy(&tempStorage[0], &inlineStorage_[0], bytesToCopy); + if (bytesToZero) { + memset(&tempStorage[oldSize], 0, bytesToZero); + } + memcpy(&tempStorage[newSize], &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], bytesToCopy); + if (bytesToZero) { + memset(&tempStorage[newSize + oldSize], 0, bytesToZero); + } + outOfLineStorage_ = tempStorage; + } else { + const bool isGrowing = oldSize < newSize; + if (isGrowing) { + // Resize before shifting so that we have room. + resizeOutOfLineStorage(newSize); + } + // Shift the old strides to their new starting point. Note + // that this does not occur in the inline path above because + // the stride starting point is not moving. + memmove( + outOfLineStorage_ + newSize, + outOfLineStorage_ + oldSize, + std::min(oldSize, newSize) * sizeof(outOfLineStorage_[0])); + if (!isGrowing) { + // Resize after shifting so that we don't lose data. + resizeOutOfLineStorage(newSize); + } else { + // Zero the end of the sizes portion. + const auto bytesToZero = (newSize - oldSize) * sizeof(outOfLineStorage_[0]); + memset(&outOfLineStorage_[oldSize], 0, bytesToZero); + memset(&outOfLineStorage_[newSize + oldSize], 0, bytesToZero); + } + } + } + size_ = newSize; +} + +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/SizesAndStrides.h b/c10/core/impl/SizesAndStrides.h new file mode 100644 index 000000000000..4f7e19330aca --- /dev/null +++ b/c10/core/impl/SizesAndStrides.h @@ -0,0 +1,293 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#define C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE 5 + +namespace c10 { +namespace impl { + +// Packed container for TensorImpl sizes and strides. +// This design improves on the previous approach of using a pair of +// c10::SmallVector by specializing for the operations we +// actually use and enforcing that the number of sizes is the same as +// the number of strides. 
The memory layout is as follows: +// +// 1 size_t for the size +// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer to out-of-line array +class C10_API SizesAndStrides { + public: + // TODO: different iterator types for sizes & strides to prevent + // mixing the two accidentally. + using sizes_iterator = int64_t*; + using sizes_const_iterator = const int64_t*; + using strides_iterator = int64_t*; + using strides_const_iterator = const int64_t*; + + SizesAndStrides() : size_(1) { + size_at_unchecked(0) = 0; + stride_at_unchecked(0) = 1; + } + + ~SizesAndStrides() { + if (C10_UNLIKELY(!isInline())) { + free(outOfLineStorage_); + } + } + + SizesAndStrides(const SizesAndStrides& rhs) : size_(rhs.size_) { + if (C10_LIKELY(rhs.isInline())) { + copyDataInline(rhs); + } else { + allocateOutOfLineStorage(size_); + copyDataOutline(rhs); + } + } + + SizesAndStrides& operator=(const SizesAndStrides& rhs) { + if (this == &rhs) { + return *this; + } + if (C10_LIKELY(rhs.isInline())) { + if (C10_UNLIKELY(!isInline())) { + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + if (isInline()) { + allocateOutOfLineStorage(rhs.size_); + } else { + resizeOutOfLineStorage(rhs.size_); + } + copyDataOutline(rhs); + } + size_ = rhs.size_; + return *this; + } + + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides(SizesAndStrides&& rhs) noexcept : size_(rhs.size_) { + if (C10_LIKELY(isInline())) { + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } else { + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + + rhs.size_ = 0; + } + + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides& operator=(SizesAndStrides&& rhs) noexcept { + if (this == &rhs) { + return *this; + } + if (C10_LIKELY(rhs.isInline())) { + if (C10_UNLIKELY(!isInline())) { + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + // They're outline. We're going to steal their vector. 
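      // (Sketch of the layout being taken over here, per the class comment
      // above: when size_ <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE the data
      // lives in inlineStorage_, sizes in the first five slots and strides in
      // the next five; otherwise outOfLineStorage_ points at one heap
      // allocation of storageBytes(size_) = size_ * 2 * sizeof(int64_t),
      // holding all sizes followed by all strides. Stealing that single
      // pointer therefore transfers both arrays at once.)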
+ if (!isInline()) { + free(outOfLineStorage_); + } + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + size_ = rhs.size_; + rhs.size_ = 0; + + return *this; + } + + size_t size() const noexcept { + return size_; + } + + const int64_t* sizes_data() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + int64_t* sizes_data() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + sizes_const_iterator sizes_begin() const noexcept { + return sizes_data(); + } + + sizes_iterator sizes_begin() noexcept { + return sizes_data(); + } + + sizes_const_iterator sizes_end() const noexcept { + return sizes_begin() + size(); + } + + sizes_iterator sizes_end() noexcept { + return sizes_begin() + size(); + } + + IntArrayRef sizes_arrayref() const noexcept { + return IntArrayRef{sizes_data(), size()}; + } + + void set_sizes(IntArrayRef newSizes) { + resize(newSizes.size()); + std::copy(newSizes.begin(), newSizes.end(), sizes_begin()); + } + + const int64_t* strides_data() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + int64_t* strides_data() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_const_iterator strides_begin() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_iterator strides_begin() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_const_iterator strides_end() const noexcept { + return strides_begin() + size(); + } + + strides_iterator strides_end() noexcept { + return strides_begin() + size(); + } + + IntArrayRef strides_arrayref() const noexcept { + return IntArrayRef{strides_data(), size()}; + } + + // Size accessors. + int64_t size_at(size_t idx) const noexcept { + assert(idx < size()); + return sizes_data()[idx]; + } + + int64_t& size_at(size_t idx) noexcept { + assert(idx < size()); + return sizes_data()[idx]; + } + + int64_t size_at_unchecked(size_t idx) const noexcept { + return sizes_data()[idx]; + } + + int64_t& size_at_unchecked(size_t idx) noexcept { + return sizes_data()[idx]; + } + + // Size accessors. 
+  int64_t stride_at(size_t idx) const noexcept {
+    assert(idx < size());
+    return strides_data()[idx];
+  }
+
+  int64_t& stride_at(size_t idx) noexcept {
+    assert(idx < size());
+    return strides_data()[idx];
+  }
+
+  int64_t stride_at_unchecked(size_t idx) const noexcept {
+    return strides_data()[idx];
+  }
+
+  int64_t& stride_at_unchecked(size_t idx) noexcept {
+    return strides_data()[idx];
+  }
+
+  void resize(size_t newSize) {
+    const auto oldSize = size();
+    if (newSize == oldSize) {
+      return;
+    }
+    if (C10_LIKELY(newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) {
+      if (oldSize < newSize) {
+        const auto bytesToZero = (newSize - oldSize) * sizeof(inlineStorage_[0]);
+        memset(&inlineStorage_[oldSize], 0, bytesToZero);
+        memset(&inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], 0, bytesToZero);
+      }
+      size_ = newSize;
+    } else {
+      resizeSlowPath(newSize, oldSize);
+    }
+  }
+
+  void resizeSlowPath(size_t newSize, size_t oldSize);
+
+ private:
+  bool isInline() const noexcept {
+    return size_ <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE;
+  }
+
+  void copyDataInline(const SizesAndStrides& rhs) {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.isInline());
+    memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_));
+  }
+
+  static size_t storageBytes(size_t size) noexcept {
+    return size * 2 * sizeof(int64_t);
+  }
+
+  void allocateOutOfLineStorage(size_t size) {
+    outOfLineStorage_ = static_cast<int64_t*>(malloc(storageBytes(size)));
+    TORCH_CHECK(outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!");
+  }
+
+  void resizeOutOfLineStorage(size_t newSize) {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline());
+    outOfLineStorage_ = static_cast<int64_t*>(realloc(outOfLineStorage_, storageBytes(newSize)));
+    TORCH_CHECK(outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!");
+  }
+
+  void copyDataOutline(const SizesAndStrides& rhs) noexcept {
+    memcpy(outOfLineStorage_, rhs.outOfLineStorage_, storageBytes(rhs.size_));
+  }
+
+  size_t size_;
+  union {
+    int64_t *outOfLineStorage_;
+    int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{};
+  };
+
+};
+
+} // namespace impl
+} // namespace c10
diff --git a/c10/test/core/impl/SizesAndStrides_test.cpp b/c10/test/core/impl/SizesAndStrides_test.cpp
new file mode 100644
index 000000000000..94e90c42feff
--- /dev/null
+++ b/c10/test/core/impl/SizesAndStrides_test.cpp
@@ -0,0 +1,399 @@
+#include <c10/core/impl/SizesAndStrides.h>
+
+#include <gtest/gtest.h>
+
+using namespace c10;
+using namespace c10::impl;
+
+static void checkData(const SizesAndStrides& sz, IntArrayRef sizes, IntArrayRef strides) {
+  EXPECT_EQ(sizes.size(), strides.size()) << "bad test case: size() of sizes and strides don't match";
+  EXPECT_EQ(sz.size(), sizes.size());
+
+  int idx = 0;
+  for (auto x: sizes) {
+    EXPECT_EQ(sz.size_at_unchecked(idx), x) << "index: " << idx;
+    EXPECT_EQ(sz.size_at(idx), x) << "index: " << idx;
+    EXPECT_EQ(sz.sizes_data()[idx], x) << "index: " << idx;
+    EXPECT_EQ(*(sz.sizes_begin() + idx), x) << "index: " << idx;
+    idx++;
+  }
+  EXPECT_EQ(sz.sizes_arrayref(), sizes);
+
+  idx = 0;
+  for (auto x: strides) {
+    EXPECT_EQ(sz.stride_at_unchecked(idx), x) << "index: " << idx;
+    EXPECT_EQ(sz.stride_at(idx), x) << "index: " << idx;
+    EXPECT_EQ(sz.strides_data()[idx], x) << "index: " << idx;
+    EXPECT_EQ(*(sz.strides_begin() + idx), x) << "index: " << idx;
+
+    idx++;
+  }
+  EXPECT_EQ(sz.strides_arrayref(), strides);
+}
+
+TEST(SizesAndStridesTest, DefaultConstructor) {
+  SizesAndStrides sz;
+  checkData(sz, {0}, {1});
+  // Can't test size_at() out of bounds because it just
asserts for now. +} + +TEST(SizesAndStridesTest, SetSizes) { + SizesAndStrides sz; + sz.set_sizes({5, 6, 7, 8}); + checkData(sz, {5, 6, 7, 8}, {1, 0, 0, 0}); +} + +TEST(SizesAndStridesTest, Resize) { + SizesAndStrides sz; + + sz.resize(2); + + // Small to small growing. + checkData(sz, {0, 0}, {1, 0}); + + // Small to small growing, again. + sz.resize(5); + checkData(sz, {0, 0, 0, 0, 0}, {1, 0, 0, 0, 0}); + + for (int ii = 0; ii < sz.size(); ++ii) { + sz.size_at_unchecked(ii) = ii + 1; + sz.stride_at_unchecked(ii) = 2 * (ii + 1); + } + + checkData(sz, {1, 2, 3, 4, 5}, {2, 4, 6, 8, 10}); + + // Small to small, shrinking. + sz.resize(4); + checkData(sz, {1, 2, 3, 4}, {2, 4, 6, 8}); + + // Small to small with no size change. + sz.resize(4); + checkData(sz, {1, 2, 3, 4}, {2, 4, 6, 8}); + + // Small to small, growing back so that we can confirm that our "new" + // data really does get zeroed. + sz.resize(5); + checkData(sz, {1, 2, 3, 4, 0}, {2, 4, 6, 8, 0}); + + // Small to big. + sz.resize(6); + + checkData(sz, {1, 2, 3, 4, 0, 0}, {2, 4, 6, 8, 0, 0}); + + sz.size_at_unchecked(5) = 6; + sz.stride_at_unchecked(5) = 12; + + checkData(sz, {1, 2, 3, 4, 0, 6}, {2, 4, 6, 8, 0, 12}); + + // Big to big, growing. + sz.resize(7); + + checkData(sz, {1, 2, 3, 4, 0, 6, 0}, {2, 4, 6, 8, 0, 12, 0}); + + // Big to big with no size change. + sz.resize(7); + + checkData(sz, {1, 2, 3, 4, 0, 6, 0}, {2, 4, 6, 8, 0, 12, 0}); + + sz.size_at_unchecked(6) = 11; + sz.stride_at_unchecked(6) = 22; + + checkData(sz, {1, 2, 3, 4, 0, 6, 11}, {2, 4, 6, 8, 0, 12, 22}); + + // Big to big, shrinking. + sz.resize(6); + checkData(sz, {1, 2, 3, 4, 0, 6}, {2, 4, 6, 8, 0, 12}); + + // Grow back to make sure "new" elements get zeroed in big mode too. + sz.resize(7); + checkData(sz, {1, 2, 3, 4, 0, 6, 0}, {2, 4, 6, 8, 0, 12, 0}); + + // Finally, big to small. + + // Give it different data than it had when it was small to avoid + // getting it right by accident (i.e., because of leftover inline + // storage when going small to big). 
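+  // (size() is 7 at this point, above the inline capacity of 5, so the data
+  //  currently lives out of line; the resize(5) at the end of this test is
+  //  what exercises the big-to-small transition back into inline storage.)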
+ for (int ii = 0; ii < sz.size(); ++ii) { + sz.size_at_unchecked(ii) = ii - 1; + sz.stride_at_unchecked(ii) = 2 * (ii - 1); + } + + checkData(sz, {-1, 0, 1, 2, 3, 4, 5}, {-2, 0, 2, 4, 6, 8, 10}); + + sz.resize(5); + checkData(sz, {-1, 0, 1, 2, 3}, {-2, 0, 2, 4, 6}); +} + +TEST(SizesAndStridesTest, SetAtIndex) { + SizesAndStrides sz; + + sz.resize(5); + sz.size_at(4) = 42; + sz.stride_at(4) = 23; + + checkData(sz, {0, 0, 0, 0, 42}, {1, 0, 0, 0, 23}); + + sz.resize(6); + sz.size_at(5) = 43; + sz.stride_at(5) = 24; + + checkData(sz, {0, 0, 0, 0, 42, 43}, {1, 0, 0, 0, 23, 24}); +} + +TEST(SizesAndStridesTest, SetAtIterator) { + SizesAndStrides sz; + + sz.resize(5); + *(sz.sizes_begin() + 4) = 42; + *(sz.strides_begin() + 4) = 23; + + checkData(sz, {0, 0, 0, 0, 42}, {1, 0, 0, 0, 23}); + + sz.resize(6); + *(sz.sizes_begin() + 5) = 43; + *(sz.strides_begin() + 5) = 24; + + checkData(sz, {0, 0, 0, 0, 42, 43}, {1, 0, 0, 0, 23, 24}); +} + +TEST(SizesAndStridesTest, SetViaData) { + SizesAndStrides sz; + + sz.resize(5); + *(sz.sizes_data() + 4) = 42; + *(sz.strides_data() + 4) = 23; + + checkData(sz, {0, 0, 0, 0, 42}, {1, 0, 0, 0, 23}); + + sz.resize(6); + *(sz.sizes_data() + 5) = 43; + *(sz.strides_data() + 5) = 24; + + checkData(sz, {0, 0, 0, 0, 42, 43}, {1, 0, 0, 0, 23, 24}); +} + +static SizesAndStrides makeSmall(int offset = 0) { + SizesAndStrides small; + small.resize(3); + for (int ii = 0; ii < small.size(); ++ii) { + small.size_at_unchecked(ii) = ii + 1 + offset; + small.stride_at_unchecked(ii) = 2 * (ii + 1 + offset); + } + + return small; +} + +static SizesAndStrides makeBig(int offset = 0) { + SizesAndStrides big; + big.resize(8); + for (int ii = 0; ii < big.size(); ++ii) { + big.size_at_unchecked(ii) = ii - 1 + offset; + big.stride_at_unchecked(ii) = 2 * (ii - 1 + offset); + } + + return big; +} + +static void checkSmall(const SizesAndStrides& sm, int offset = 0) { + std::vector sizes(3), strides(3); + for (int ii = 0; ii < 3; ++ii) { + sizes[ii] = ii + 1 + offset; + strides[ii] = 2 * (ii + 1 + offset); + } + checkData(sm, sizes, strides); +} + +static void checkBig(const SizesAndStrides& big, int offset = 0) { + std::vector sizes(8), strides(8); + for (int ii = 0; ii < 8; ++ii) { + sizes[ii] = ii - 1 + offset; + strides[ii] = 2 * (ii - 1 + offset); + } + checkData(big, sizes, strides); +} + +TEST(SizesAndStridesTest, MoveConstructor) { + SizesAndStrides empty; + + SizesAndStrides movedEmpty(std::move(empty)); + + EXPECT_EQ(empty.size(), 0); + EXPECT_EQ(movedEmpty.size(), 1); + checkData(movedEmpty, {0}, {1}); + + SizesAndStrides small = makeSmall(); + checkSmall(small); + + SizesAndStrides movedSmall(std::move(small)); + checkSmall(movedSmall); + EXPECT_EQ(small.size(), 0); + + SizesAndStrides big = makeBig(); + checkBig(big); + + SizesAndStrides movedBig(std::move(big)); + checkBig(movedBig); + EXPECT_EQ(big.size(), 0); +} + +TEST(SizesAndStridesTest, CopyConstructor) { + SizesAndStrides empty; + + SizesAndStrides copiedEmpty(empty); + + EXPECT_EQ(empty.size(), 1); + EXPECT_EQ(copiedEmpty.size(), 1); + checkData(empty, {0}, {1}); + checkData(copiedEmpty, {0}, {1}); + + SizesAndStrides small = makeSmall(); + checkSmall(small); + + SizesAndStrides copiedSmall(small); + checkSmall(copiedSmall); + checkSmall(small); + + SizesAndStrides big = makeBig(); + checkBig(big); + + SizesAndStrides copiedBig(big); + checkBig(big); + checkBig(copiedBig); +} + +TEST(SizesAndStridesTest, CopyAssignmentSmallToSmall) { + SizesAndStrides smallTarget = makeSmall(); + SizesAndStrides smallCopyFrom = 
makeSmall(1); + + checkSmall(smallTarget); + checkSmall(smallCopyFrom, 1); + + smallTarget = smallCopyFrom; + + checkSmall(smallTarget, 1); + checkSmall(smallCopyFrom, 1); +} + +TEST(SizesAndStridesTest, MoveAssignmentSmallToSmall) { + SizesAndStrides smallTarget = makeSmall(); + SizesAndStrides smallMoveFrom = makeSmall(1); + + checkSmall(smallTarget); + checkSmall(smallMoveFrom, 1); + + smallTarget = std::move(smallMoveFrom); + + checkSmall(smallTarget, 1); + EXPECT_EQ(smallMoveFrom.size(), 0); +} + +TEST(SizesAndStridesTest, CopyAssignmentSmallToBig) { + SizesAndStrides bigTarget = makeBig(); + SizesAndStrides smallCopyFrom = makeSmall(); + + checkBig(bigTarget); + checkSmall(smallCopyFrom); + + bigTarget = smallCopyFrom; + + checkSmall(bigTarget); + checkSmall(smallCopyFrom); +} + +TEST(SizesAndStridesTest, MoveAssignmentSmallToBig) { + SizesAndStrides bigTarget = makeBig(); + SizesAndStrides smallMoveFrom = makeSmall(); + + checkBig(bigTarget); + checkSmall(smallMoveFrom); + + bigTarget = std::move(smallMoveFrom); + + checkSmall(bigTarget); + EXPECT_EQ(smallMoveFrom.size(), 0); +} + +TEST(SizesAndStridesTest, CopyAssignmentBigToBig) { + SizesAndStrides bigTarget = makeBig(); + SizesAndStrides bigCopyFrom = makeBig(1); + + checkBig(bigTarget); + checkBig(bigCopyFrom, 1); + + bigTarget = bigCopyFrom; + + checkBig(bigTarget, 1); + checkBig(bigCopyFrom, 1); +} + +TEST(SizesAndStridesTest, MoveAssignmentBigToBig) { + SizesAndStrides bigTarget = makeBig(); + SizesAndStrides bigMoveFrom = makeBig(1); + + checkBig(bigTarget); + checkBig(bigMoveFrom, 1); + + bigTarget = std::move(bigMoveFrom); + + checkBig(bigTarget, 1); + EXPECT_EQ(bigMoveFrom.size(), 0); +} + +TEST(SizesAndStridesTest, CopyAssignmentBigToSmall) { + SizesAndStrides smallTarget = makeSmall(); + SizesAndStrides bigCopyFrom = makeBig(); + + checkSmall(smallTarget); + checkBig(bigCopyFrom); + + smallTarget = bigCopyFrom; + + checkBig(smallTarget); + checkBig(bigCopyFrom); +} + +TEST(SizesAndStridesTest, MoveAssignmentBigToSmall) { + SizesAndStrides smallTarget = makeSmall(); + SizesAndStrides bigMoveFrom = makeBig(); + + checkSmall(smallTarget); + checkBig(bigMoveFrom); + + smallTarget = std::move(bigMoveFrom); + + checkBig(smallTarget); + EXPECT_EQ(bigMoveFrom.size(), 0); +} + +TEST(SizesAndStridesTest, CopyAssignmentSelf) { + SizesAndStrides small = makeSmall(); + SizesAndStrides big = makeBig(); + + checkSmall(small); + checkBig(big); + + small = small; + checkSmall(small); + + big = big; + checkBig(big); +} + +// Avoid failures due to -Wall -Wself-move. 
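+// (Assigning `x = std::move(x)` directly in the test body would trip the
+// warning at compile time; routing the move through this helper keeps it
+// opaque to the compiler while still exercising self-move-assignment.)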
+static void selfMove(SizesAndStrides& x, SizesAndStrides& y) { + x = std::move(y); +} + +TEST(SizesAndStridesTest, MoveAssignmentSelf) { + SizesAndStrides small = makeSmall(); + SizesAndStrides big = makeBig(); + + checkSmall(small); + checkBig(big); + + selfMove(small, small); + checkSmall(small); + + selfMove(big, big); + checkBig(big); +} diff --git a/docker.Makefile b/docker.Makefile index 3af77ab9c7d1..6b843fa9c1b3 100644 --- a/docker.Makefile +++ b/docker.Makefile @@ -17,12 +17,14 @@ BASE_DEVEL = nvidia/cuda:$(CUDA_VERSION)-cudnn$(CUDNN_VERSION)-de INSTALL_CHANNEL = pytorch PYTHON_VERSION = 3.7 +PYTORCH_VERSION = $(shell git describe --tags) # Can be either official / dev BUILD_TYPE = dev BUILD_PROGRESS = auto BUILD_ARGS = --build-arg BASE_IMAGE=$(BASE_IMAGE) \ --build-arg PYTHON_VERSION=$(PYTHON_VERSION) \ --build-arg CUDA_VERSION=$(CUDA_VERSION) \ + --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) \ --build-arg INSTALL_CHANNEL=$(INSTALL_CHANNEL) EXTRA_DOCKER_BUILD_FLAGS ?= DOCKER_BUILD = DOCKER_BUILDKIT=1 \ @@ -39,26 +41,26 @@ all: devel-image .PHONY: devel-image devel-image: BASE_IMAGE := $(BASE_DEVEL) -devel-image: DOCKER_TAG := $(shell git describe --tags)-devel +devel-image: DOCKER_TAG := $(PYTORCH_VERSION)-devel devel-image: $(DOCKER_BUILD) .PHONY: devel-image devel-push: BASE_IMAGE := $(BASE_DEVEL) -devel-push: DOCKER_TAG := $(shell git describe --tags)-devel +devel-push: DOCKER_TAG := $(PYTORCH_VERSION)-devel devel-push: $(DOCKER_PUSH) .PHONY: runtime-image runtime-image: BASE_IMAGE := $(BASE_RUNTIME) -runtime-image: DOCKER_TAG := $(shell git describe --tags)-runtime +runtime-image: DOCKER_TAG := $(PYTORCH_VERSION)-runtime runtime-image: $(DOCKER_BUILD) docker tag $(DOCKER_FULL_NAME):$(DOCKER_TAG) $(DOCKER_FULL_NAME):latest .PHONY: runtime-image runtime-push: BASE_IMAGE := $(BASE_RUNTIME) -runtime-push: DOCKER_TAG := $(shell git describe --tags)-runtime +runtime-push: DOCKER_TAG := $(PYTORCH_VERSION)-runtime runtime-push: $(DOCKER_PUSH) diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index d6de2373ad57..991205688df3 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -19,6 +19,7 @@ Functions .. autofunction:: eigvalsh .. autofunction:: matrix_rank .. autofunction:: norm +.. autofunction:: svd .. autofunction:: solve .. autofunction:: tensorinv .. autofunction:: tensorsolve diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 315cc9dc5309..1baf34dd955e 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -214,6 +214,8 @@ view of a storage and defines numeric operations on it. .. automethod:: arctan_ .. automethod:: atan2 .. automethod:: atan2_ + .. automethod:: all + .. automethod:: any .. automethod:: backward :noindex: .. automethod:: baddbmm @@ -648,10 +650,3 @@ view of a storage and defines numeric operations on it. .. automethod:: xlogy .. automethod:: xlogy_ .. automethod:: zero_ - -.. class:: BoolTensor() - - The following methods are unique to :class:`torch.BoolTensor`. - - .. automethod:: all - .. 
automethod:: any diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 46960ecdb1b4..922e1434bae1 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -364,6 +364,8 @@ Reduction Ops argmin amax amin + all + any max min dist diff --git a/mypy.ini b/mypy.ini index bab4ce5dfd42..8639ad1b44ad 100644 --- a/mypy.ini +++ b/mypy.ini @@ -76,9 +76,6 @@ ignore_errors = True [mypy-torch.nn.modules.conv] ignore_errors = True -[mypy-torch.nn.modules.fold] -ignore_errors = True - [mypy-torch.nn.modules.module] ignore_errors = True @@ -91,22 +88,28 @@ ignore_errors = True [mypy-torch.nn.modules.pooling] ignore_errors = True -[mypy-torch.nn.qat.modules.activations] +[mypy-torch.nn.parallel._functions] ignore_errors = True -[mypy-torch.nn.qat.modules.conv] +[mypy-torch._appdirs] ignore_errors = True -[mypy-torch.nn.quantized.dynamic.modules.linear] +[mypy-torch._overrides] ignore_errors = True -[mypy-torch.nn.quantized.modules.conv] +[mypy-torch.utils.tensorboard._caffe2_graph] ignore_errors = True -[mypy-torch._appdirs] +[mypy-torch.contrib._tensorboard_vis] +ignore_errors = True + +[mypy-torch.nn.utils.prune] +ignore_errors = True + +[mypy-torch.utils.show_pickle] ignore_errors = True -[mypy-torch._utils] +[mypy-torch.utils.hipify.hipify_python] ignore_errors = True [mypy-torch.utils.benchmark.examples.*] diff --git a/scripts/release_notes/commitlist.py b/scripts/release_notes/commitlist.py index 0a76f896f217..552641f54674 100644 --- a/scripts/release_notes/commitlist.py +++ b/scripts/release_notes/commitlist.py @@ -97,7 +97,7 @@ def filter(self, *, category=None, topic=None): if topic is not None: commits = [commit for commit in commits if commit.topic == topic] return commits - + def update_to(self, new_version): last_hash = self.commits[-1].commit_hash new_commits = CommitList.get_commits_between(last_hash, new_version) @@ -121,7 +121,7 @@ def update_existing(path, new_version): def to_markdown(commit_list, category): def cleanup_title(commit): - match = re.match('(.*) \(#\d+\)', commit.title) + match = re.match(r'(.*) \(#\d+\)', commit.title) if match is None: return commit.title return match.group(1) diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index 2d5d50096c81..4332916fef6b 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -37,6 +37,7 @@ ("aten::ifft", datetime.date(2021, 1, 31)), ("aten::irfft", datetime.date(2021, 1, 31)), ("aten::rfft", datetime.date(2021, 1, 31)), + ("aten::_svd_helper", datetime.date(2021, 1, 31)), ("aten::_cudnn_rnn_flatten_weight", datetime.date(2020, 12, 31)), ("aten::_cudnn_rnn", datetime.date(2020, 12, 31)), ("aten::_cudnn_rnn_backward", datetime.date(2020, 12, 31)), diff --git a/test/cpp/rpc/e2e_test_base.h b/test/cpp/rpc/e2e_test_base.h index 114284839858..cea5079b1a4e 100644 --- a/test/cpp/rpc/e2e_test_base.h +++ b/test/cpp/rpc/e2e_test_base.h @@ -64,19 +64,21 @@ class TestE2EBase : public ::testing::Test { ScriptRemoteCall scriptRemoteCall( op, {t1, t2, 1}, ownerRRef->rrefId(), ownerRRef->rrefId()); - auto fm = autograd::sendMessageWithAutograd( + auto jitFuture = autograd::sendMessageWithAutograd( *rpcAgent, rpcAgent->getWorkerInfo("worker"), std::move(scriptRemoteCall).toMessage(), false); - ownerRRef->registerOwnerCreationFuture(fm); + ownerRRef->registerOwnerCreationFuture(jitFuture); // Builtin operators does not return py::object, and hence does not require 
// GIL for destructing the potentially deleted OwerRRef. - fm->addCallback( - [ownerRRefId = ownerRRef->rrefId()](const FutureMessage& fm) { - callback::finishCreatingOwnerRRef(fm, ownerRRefId); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( + [wp, ownerRRefId = ownerRRef->rrefId()]() { + auto jitFuture = wp.lock(); + callback::finishCreatingOwnerRRef(*jitFuture, ownerRRefId); }); return ownerRRef; } @@ -89,12 +91,14 @@ class TestE2EBase : public ::testing::Test { // Send the RPC and return result. auto response = autograd::sendMessageWithAutograd( - *rpcAgent, - rpcAgent->getWorkerInfo("worker"), - std::move(scriptCall).toMessage()) - ->wait(); + *rpcAgent, + rpcAgent->getWorkerInfo("worker"), + std::move(scriptCall).toMessage()); + response->waitAndThrow(); + MessageType messageType = MessageType::FORWARD_AUTOGRAD_RESP; - auto wrappedResponse = deserializeResponse(response, messageType); + auto wrappedResponse = deserializeResponse( + std::move(*response->value().toCustomClass()), messageType); return static_cast(*wrappedResponse).value().toTensor(); } diff --git a/test/jit/test_class_type.py b/test/jit/test_class_type.py index a80670f0d22b..1c0eed071dfd 100644 --- a/test/jit/test_class_type.py +++ b/test/jit/test_class_type.py @@ -1212,6 +1212,32 @@ def test_function(a: int, b: int) -> 'ClassWithStaticMethod': self.checkScript(test_function, (1, 2)) + def test_classmethod(self): + """ + Test classmethods on class types. + """ + global ClassWithClassMethod + + @torch.jit.script + class ClassWithClassMethod: + def __init__(self, a: int): + self.a: int = a + + def __eq__(self, other: 'ClassWithClassMethod'): + return self.a == other.a + + @classmethod + def create(cls, a: int) -> 'ClassWithClassMethod': + return cls(a) + + def test_function(a: int) -> 'ClassWithClassMethod': + x = ClassWithClassMethod(a) + # Support calling classmethod with an instance + # Calling with the class is not supported. + return x.create(a) + + self.checkScript(test_function, (1,)) + def test_properties(self): """ Test that a scripted class can make use of the @property decorator. diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index a46cbcf738b6..dd13eb120a2e 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -958,7 +958,7 @@ def forward(self, x): ] qconfig_dict2 = {"": None, "module_name": [("sub", default_qconfig)]} - # Only nodes in Sub module instance is quantized + # Only nodes in Sub module instance are quantized # the first transpose is not quantized because the input is not quantized node_list2 = [ ns.call_module(nn.Conv2d), diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py index 66210685b562..28867e8260b6 100644 --- a/test/quantization/test_quantized_module.py +++ b/test/quantization/test_quantized_module.py @@ -194,20 +194,24 @@ def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, # Test JIT self.checkScriptable(qlinear, [[X_q]], check_save_load=True) - # Test from_float. - float_linear = torch.nn.Linear(in_features, out_features).float() - float_linear.qconfig = torch.quantization.default_qconfig - torch.quantization.prepare(float_linear, inplace=True) - float_linear(X.float()) - # Sequential allows swapping using "convert". 
- quantized_float_linear = torch.nn.Sequential(float_linear) - quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True) + # Make sure `from_float` works for all linear variants + modules_under_test = [torch.nn.Linear, torch.nn.modules.linear._LinearWithBias] - # Smoke test to make sure the module actually runs - quantized_float_linear(X_q) + for mut in modules_under_test: + # Test from_float. + float_linear = mut(in_features, out_features).float() + float_linear.qconfig = torch.quantization.default_qconfig + torch.quantization.prepare(float_linear, inplace=True) + float_linear(X.float()) + # Sequential allows swapping using "convert". + quantized_float_linear = torch.nn.Sequential(float_linear) + quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True) - # Smoke test extra_repr - self.assertTrue('QuantizedLinear' in str(quantized_float_linear)) + # Smoke test to make sure the module actually runs + quantized_float_linear(X_q) + + # Smoke test extra_repr + self.assertTrue('QuantizedLinear' in str(quantized_float_linear)) def test_quant_dequant_api(self): r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float) @@ -928,16 +932,18 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_d # Test JIT self.checkScriptable(qlinear, [[X]], check_save_load=True) - # Test from_float - float_linear = torch.nn.Linear(in_features, out_features).float() - if use_default_observer: - float_linear.qconfig = torch.quantization.default_dynamic_qconfig - prepare_dynamic(float_linear) - float_linear(X.float()) - quantized_float_linear = nnqd.Linear.from_float(float_linear) - - # Smoke test to make sure the module actually runs - quantized_float_linear(X) + modules_under_test = [torch.nn.Linear, torch.nn.modules.linear._LinearWithBias] + for mut in modules_under_test: + # Test from_float + float_linear = mut(in_features, out_features).float() + if use_default_observer: + float_linear.qconfig = torch.quantization.default_dynamic_qconfig + prepare_dynamic(float_linear) + float_linear(X.float()) + quantized_float_linear = nnqd.Linear.from_float(float_linear) + + # Smoke test to make sure the module actually runs + quantized_float_linear(X) # Smoke test extra_repr self.assertTrue('QuantizedLinear' in str(quantized_float_linear)) diff --git a/test/test_autograd.py b/test/test_autograd.py index 9f5925212757..fb4394b3a62b 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -4941,17 +4941,6 @@ def assert_only_first_requires_grad(res): return_counts=return_counts) assert_only_first_requires_grad(res) - def test_linalg_qr_r(self): - # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but - # without 'q' you cannot compute the backward pass. Check that - # linalg_qr_backward complains cleanly in that case. 
- inp = torch.randn((5, 7), requires_grad=True) - q, r = torch.linalg.qr(inp, mode='r') - assert q.shape == (0,) # empty tensor - b = torch.sum(r) - with self.assertRaisesRegex(RuntimeError, - "linalg_qr_backward: cannot compute backward"): - b.backward() def index_perm_variable(shape, max_indices): if not isinstance(shape, tuple): diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 2ff12396701e..785f6999f570 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -2424,7 +2424,7 @@ def to_np(value): # Case of Tensor x Tensor if op is torch.Tensor.float_power_ and base_dtype != out_dtype: - with self.assertRaisesRegex(RuntimeError, "is not the desired type"): + with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): op(base.clone(), exp) else: result = op(base.clone(), exp) @@ -2441,7 +2441,7 @@ def to_np(value): expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i)) if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp: - with self.assertRaisesRegex(RuntimeError, "is not the desired type"): + with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): op(base.clone(), i) else: result = op(base.clone(), i) @@ -2483,13 +2483,13 @@ def _promo_helper(x, y): if out.dtype == required_dtype: torch.float_power(base, exp, out=out) else: - with self.assertRaisesRegex(RuntimeError, "is not the desired output type"): + with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): torch.float_power(base, exp, out=out) if base.dtype == required_dtype: torch.Tensor.float_power_(base.clone(), exp) else: - with self.assertRaisesRegex(RuntimeError, "is not the desired type"): + with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): torch.Tensor.float_power_(base.clone(), exp) @skipIf(not TEST_SCIPY, "Scipy required for the test.") diff --git a/test/test_cuda.py b/test/test_cuda.py index cef6d689343a..f932a5158267 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -3110,6 +3110,14 @@ def run(module, op, args, kwargs): # Adds an empty dict for kwargs, which none of the Tensor methods use run("Tensor", *(meth_with_args + ({},))) + def test_batch_norm_gather_stats(self): + input = torch.randn(1, 3, 3, 3, device='cuda') + mean, invstd = torch.batch_norm_gather_stats( + input, mean=torch.ones(2, 3, device='cuda'), invstd=torch.ones(2, 3, device='cuda'), + running_mean=None, running_var=None , momentum=.1, eps=1e-5, count=2 + ) + self.assertEqual(mean, torch.ones(3, device='cuda')) + self.assertEqual(invstd, torch.ones(3, device='cuda')) class TestCudaComm(TestCase): def _test_broadcast(self, input): diff --git a/test/test_dataloader.py b/test/test_dataloader.py index c257dd8a2fd7..edc31b75485e 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -2031,15 +2031,19 @@ def __next__(self): def test_dataset_not_reset(self): dataset = DummyDataset() - dataloader = self._get_data_loader(dataset, num_workers=2) - dataset.start = 0 - for i in range(10): - for x in dataloader: - pass - # Changing the start value here doesn't have any effect in the dataset - # cached by the workers. 
since they are not recreated between epochs - # and can cache values safely - dataset.start = i + pin_memory_configs = [False] + if TEST_CUDA: + pin_memory_configs.append(True) + for pin_memory in pin_memory_configs: + dataloader = self._get_data_loader(dataset, num_workers=2, pin_memory=pin_memory) + dataset.start = 0 + for i in range(10): + for x in dataloader: + pass + # Changing the start value here doesn't have any effect in the dataset + # cached by the workers. since they are not recreated between epochs + # and can cache values safely + dataset.start = i diff --git a/test/test_linalg.py b/test/test_linalg.py index bab73f987905..b39d64e4bca1 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -11,6 +11,7 @@ from math import inf, nan, isnan import random from random import randrange +from itertools import product from functools import reduce from torch.testing._internal.common_utils import \ @@ -1864,6 +1865,277 @@ def test_nuclear_norm_exceptions_old(self, device): self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) + # ~~~ tests for torch.svd ~~~ + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_svd(self, device, dtype): + def run_test(dims, some, compute_uv): + x = torch.randn(*dims, dtype=dtype, device=device) + outu = torch.empty(0, dtype=dtype, device=device) + outs = torch.empty(0, dtype=dtype, device=device) + outv = torch.empty(0, dtype=dtype, device=device) + torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) + + if compute_uv: + if some: + x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + narrow_u = outu[..., :min(*dims[-2:])] + narrow_v = outv[..., :min(*dims[-2:])] + x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + _, singvals, _ = torch.svd(x, compute_uv=True) + self.assertEqual(singvals, outs, msg='Singular values mismatch') + self.assertEqual(outu, torch.zeros_like(outu), msg='U not zero') + self.assertEqual(outv, torch.zeros_like(outv), msg='V not zero') + + resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) + self.assertEqual(resu, outu, msg='outputs of svd and svd with out differ') + self.assertEqual(ress, outs, msg='outputs of svd and svd with out differ') + self.assertEqual(resv, outv, msg='outputs of svd and svd with out differ') + + # test non-contiguous + x = torch.randn(*dims, dtype=dtype, device=device) + n_dim = len(dims) + # Reverse the batch dimensions and the matrix dimensions and then concat them + x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) + assert not x.is_contiguous(), "x is intentionally non-contiguous" + resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) + if compute_uv: + if some: + x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + narrow_u = resu[..., :min(*dims[-2:])] + narrow_v = resv[..., :min(*dims[-2:])] + x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, 
rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + _, singvals, _ = torch.svd(x, compute_uv=True) + self.assertEqual(singvals, ress, msg='Singular values mismatch') + self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero') + self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero') + + shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices + (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices + (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices + for dims, some, compute_uv in product(shapes, [True, False], [True, False]): + run_test(dims, some, compute_uv) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float) + def test_svd_no_singularvectors(self, device, dtype): + for size in [(5, 5), (5, 20), (20, 5)]: + a = torch.randn(*size, device=device, dtype=dtype) + u, s_expect, v = torch.svd(a) + u, s_actual, v = torch.svd(a, compute_uv=False) + self.assertEqual(s_expect, s_actual, msg="Singular values don't match") + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_svd_lowrank(self, device, dtype): + from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix + + def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): + density = options.pop('density', 1) + if isinstance(matrix_size, int): + rows = columns = matrix_size + else: + rows, columns = matrix_size + if density == 1: + a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) + a = a_input + else: + assert batches == () + a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) + a = a_input.to_dense() + + q = min(*size) + u, s, v = svd_lowrank(a_input, q=q, **options) + + # check if u, s, v is a SVD + u, s, v = u[..., :q], s[..., :q], v[..., :q] + A = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) + self.assertEqual(A, a) + + # check if svd_lowrank produces same singular values as torch.svd + U, S, V = torch.svd(a) + self.assertEqual(s.shape, S.shape) + self.assertEqual(u.shape, U.shape) + self.assertEqual(v.shape, V.shape) + self.assertEqual(s, S) + + if density == 1: + # actual_rank is known only for dense inputs + # + # check if pairs (u, U) and (v, V) span the same + # subspaces, respectively + u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] + U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] + self.assertEqual(u.transpose(-2, -1).matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) + self.assertEqual(v.transpose(-2, -1).matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) + + all_batches = [(), (1,), (3,), (2, 3)] + for actual_rank, size, all_batches in [ + (2, (17, 4), all_batches), + (4, (17, 4), all_batches), + (4, (17, 17), all_batches), + (10, (100, 40), all_batches), + (7, (1000, 1000), [()]), + ]: + # dense input + for batches in all_batches: + run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) + if size != size[::-1]: + run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) + + # sparse input + for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]: + for density in [0.005, 0.1]: + run_subtest(None, size, (), device, torch.svd_lowrank, density=density) + + # jitting support + jitted = torch.jit.script(torch.svd_lowrank) + actual_rank, size, batches = 2, (17, 4), () + run_subtest(actual_rank, size, batches, device, jitted) + + @onlyCPU + @skipCPUIfNoLapack + 
@dtypes(torch.cfloat) + def test_svd_complex(self, device, dtype): + t = torch.randn((10, 10), dtype=dtype, device=device) + U, S, V = torch.svd(t, some=False) + # note: from the math point of view, it is weird that we need to use + # V.T instead of V.T.conj(): torch.svd has a buggy behavior for + # complex numbers and it's deprecated. You should use torch.linalg.svd + # instead. + t2 = U @ torch.diag(S).type(dtype) @ V.T + self.assertEqual(t, t2) + + def _test_svd_helper(self, shape, some, col_maj, device, dtype): + cpu_tensor = torch.randn(shape, device='cpu').to(dtype) + device_tensor = cpu_tensor.to(device=device) + if col_maj: + cpu_tensor = cpu_tensor.t() + device_tensor = device_tensor.t() + cpu_result = torch.svd(cpu_tensor, some=some) + device_result = torch.svd(device_tensor, some=some) + m = min(cpu_tensor.shape[-2:]) + # torch.svd returns torch.return_types.svd which is a tuple of (U, V, S). + # - When some==False, U[..., m:] can be arbitrary. + # - When some==True, U shape: [..., m], V shape: [m, m] + # - Signs are not deterministic. If the sign of a column of U is changed + # then the corresponding column of the V has to be changed. + # Thus here we only compare result[..., :m].abs() from CPU and device. + for x, y in zip(cpu_result, device_result): + self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_and_complex_types()) + def test_svd_square(self, device, dtype): + self._test_svd_helper((10, 10), True, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_square_col_maj(self, device, dtype): + self._test_svd_helper((10, 10), True, True, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_some(self, device, dtype): + self._test_svd_helper((20, 5), True, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_all(self, device, dtype): + self._test_svd_helper((20, 5), False, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_some_col_maj(self, device, dtype): + self._test_svd_helper((5, 20), True, True, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_all_col_maj(self, device, dtype): + self._test_svd_helper((5, 20), False, True, device, dtype) + + # ~~~ tests for torch.linalg.svd ~~~ + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_linalg_svd_compute_uv(self, device, dtype): + """ + Test the default case, compute_uv=True. Here we have the very same behavior as + numpy + """ + t = torch.randn((10, 11), device=device, dtype=dtype) + np_t = t.cpu().numpy() + for full_matrices in (True, False): + # check linalg.svd vs numpy + expected = np.linalg.svd(np_t, full_matrices, compute_uv=True) + actual = torch.linalg.svd(t, full_matrices, compute_uv=True) + self.assertEqual(actual, expected) + # check linalg.svd vs linalg.svd(out=...) 
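
As a rough illustration of the relationship these tests check (a sketch only; the input `a` and its shape are arbitrary, and it assumes a torch build where both `torch.svd` and `torch.linalg.svd` are available):

```python
import torch

a = torch.randn(5, 3, dtype=torch.double)

# torch.svd: some=True gives the reduced factorization, with V returned directly.
u, s, v = torch.svd(a, some=True)
assert torch.allclose(u @ torch.diag(s) @ v.t(), a)

# torch.linalg.svd follows NumPy: full_matrices=False is the reduced form, and the
# third output is already the transposed factor, so no extra .t() is needed.
u2, s2, vh = torch.linalg.svd(a, full_matrices=False)
assert torch.allclose(u2 @ torch.diag(s2) @ vh, a)
```
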
+ out = (torch.empty_like(actual[0]), + torch.empty_like(actual[1]), + torch.empty_like(actual[2])) + out2 = torch.linalg.svd(t, full_matrices, compute_uv=True, out=out) + self.assertEqual(actual, out) + self.assertEqual(actual, out2) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_linalg_svd_no_compute_uv(self, device, dtype): + """ + Test the compute_uv=False case. Here we have a different return type than + numpy: numpy returns S, we return (empty, S, empty) + """ + t = torch.randn((10, 11), device=device, dtype=dtype) + np_t = t.cpu().numpy() + + def is_empty(x): + return x.numel() == 0 and x.dtype == t.dtype and x.device == t.device + + for full_matrices in (True, False): + # check linalg.svd vs numpy + np_s = np.linalg.svd(np_t, full_matrices, compute_uv=False) + USV = torch.linalg.svd(t, full_matrices, compute_uv=False) + assert is_empty(USV.U) + self.assertEqual(USV.S, np_s) + assert is_empty(USV.V) + # check linalg.svd vs linalg.svd(out=...) + out = (torch.empty_like(USV.U), torch.empty_like(USV.S), torch.empty_like(USV.V)) + USV = torch.linalg.svd(t, full_matrices, compute_uv=False, out=out) + assert USV.U is out[0] + assert USV.S is out[1] + assert USV.V is out[2] + self.assertEqual(USV.S, np_s) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @onlyCUDA + @dtypes(torch.float) + def test_linalg_svd_out_different_device(self, device, dtype): + t = torch.randn(5, 7, device=device, dtype=dtype) # this is on cuda + u = torch.empty((5, 5), device='cpu', dtype=dtype) + s = torch.empty((5,), device='cpu', dtype=dtype) + v = torch.empty((7, 7), device='cpu', dtype=dtype) + with self.assertRaisesRegex(RuntimeError, 'svd output tensor U is on the wrong device: expected cuda:.* got cpu'): + torch.linalg.svd(t, out=(u, s, v)) + def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype): from torch.testing._internal.common_utils import random_hermitian_pd_matrix @@ -2787,12 +3059,34 @@ def test_qr_vs_numpy(self, device, dtype): exp_r = np.linalg.qr(np_t, mode='r') q, r = torch.linalg.qr(t, mode='r') # check that q is empty - assert q.shape == (0,) - assert q.dtype == t.dtype - assert q.device == t.device + self.assertEqual(q.shape, (0,)) + self.assertEqual(q.dtype, t.dtype) + self.assertEqual(q.device, t.device) # check r self.assertEqual(r, exp_r) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float) + def test_linalg_qr_autograd_errors(self, device, dtype): + # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but + # without 'q' you cannot compute the backward pass. Check that + # linalg_qr_backward complains cleanly in that case. 
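
For orientation, a small sketch of the qr modes these tests exercise (illustrative only; the input `a` and its 5x3 shape are arbitrary):

```python
import torch

a = torch.randn(5, 3)

q, r = torch.linalg.qr(a, mode='reduced')             # q: (5, 3), r: (3, 3)
assert torch.allclose(q @ r, a, atol=1e-5)

q_full, r_full = torch.linalg.qr(a, mode='complete')  # q_full: (5, 5), r_full: (5, 3)

# mode='r' computes only R; Q comes back as an empty tensor, which is why the
# backward pass cannot be defined for this mode.
q_empty, r_only = torch.linalg.qr(a, mode='r')
assert q_empty.numel() == 0
```
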
+ inp = torch.randn((5, 7), device=device, dtype=dtype, requires_grad=True) + q, r = torch.linalg.qr(inp, mode='r') + self.assertEqual(q.shape, (0,)) # empty tensor + b = torch.sum(r) + with self.assertRaisesRegex(RuntimeError, + "The derivative of qr is not implemented when mode='r'"): + b.backward() + # + inp = torch.randn((7, 5), device=device, dtype=dtype, requires_grad=True) + q, r = torch.linalg.qr(inp, mode='complete') + b = torch.sum(r) + with self.assertRaisesRegex(RuntimeError, + "The derivative of qr is not implemented when mode='complete' and nrows > ncols"): + b.backward() + @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) @@ -2806,10 +3100,17 @@ def np_qr_batched(a, mode): all_q = [] all_r = [] for matrix in a: - q, r = np.linalg.qr(matrix, mode=mode) - all_q.append(q) - all_r.append(r) - return np.array(all_q), np.array(all_r) + result = np.linalg.qr(matrix, mode=mode) + if mode == 'r': + all_r.append(result) + else: + q, r = result + all_q.append(q) + all_r.append(r) + if mode == 'r': + return np.array(all_r) + else: + return np.array(all_q), np.array(all_r) t = torch.randn((3, 7, 5), device=device, dtype=dtype) np_t = t.cpu().numpy() @@ -2818,6 +3119,15 @@ def np_qr_batched(a, mode): q, r = torch.linalg.qr(t, mode=mode) self.assertEqual(q, exp_q) self.assertEqual(r, exp_r) + # for mode='r' we need a special logic because numpy returns only r + exp_r = np_qr_batched(np_t, mode='r') + q, r = torch.linalg.qr(t, mode='r') + # check that q is empty + self.assertEqual(q.shape, (0,)) + self.assertEqual(q.dtype, t.dtype) + self.assertEqual(q.device, t.device) + # check r + self.assertEqual(r, exp_r) @skipCUDAIfNoMagma @skipCPUIfNoLapack @@ -2840,11 +3150,22 @@ def test_qr_out(self, device, dtype): out = (torch.empty((0), dtype=dtype, device=device), torch.empty((0), dtype=dtype, device=device)) q2, r2 = torch.linalg.qr(t, mode=mode, out=out) - assert q2 is out[0] - assert r2 is out[1] + self.assertIs(q2, out[0]) + self.assertIs(r2, out[1]) self.assertEqual(q2, q) self.assertEqual(r2, r) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float) + def test_qr_error_cases(self, device, dtype): + t1 = torch.randn(5, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, 'qr input should have at least 2 dimensions, but has 1 dimensions instead'): + torch.linalg.qr(t1) + t2 = torch.randn((5, 7), device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "qr received unrecognized mode 'hello'"): + torch.linalg.qr(t2, mode='hello') + @dtypes(torch.double, torch.cdouble) def test_einsum(self, device, dtype): def check(equation, *operands): @@ -4601,60 +4922,6 @@ def test_solve_methods_arg_device(self, device): "Expected LU_pivots and LU_data to be on the same device"): torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=b_device).int()) - def _test_svd_helper(self, shape, some, col_maj, device, dtype): - cpu_tensor = torch.randn(shape, device='cpu').to(dtype) - device_tensor = cpu_tensor.to(device=device) - if col_maj: - cpu_tensor = cpu_tensor.t() - device_tensor = device_tensor.t() - cpu_result = torch.svd(cpu_tensor, some=some) - device_result = torch.svd(device_tensor, some=some) - m = min(cpu_tensor.shape[-2:]) - # torch.svd returns torch.return_types.svd which is a tuple of (U, V, S). - # - When some==False, U[..., m:] can be arbitrary. - # - When some==True, U shape: [..., m], V shape: [m, m] - # - Signs are not deterministic. 
If the sign of a column of U is changed - # then the corresponding column of the V has to be changed. - # Thus here we only compare result[..., :m].abs() from CPU and device. - for x, y in zip(cpu_result, device_result): - self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_and_complex_types()) - def test_svd_square(self, device, dtype): - self._test_svd_helper((10, 10), True, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_square_col_maj(self, device, dtype): - self._test_svd_helper((10, 10), True, True, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_some(self, device, dtype): - self._test_svd_helper((20, 5), True, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_all(self, device, dtype): - self._test_svd_helper((20, 5), False, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_some_col_maj(self, device, dtype): - self._test_svd_helper((5, 20), True, True, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_all_col_maj(self, device, dtype): - self._test_svd_helper((5, 20), False, True, device, dtype) - @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3}) @skipCUDAIfNoMagma @skipCPUIfNoLapack @@ -5596,145 +5863,6 @@ def run_test(dims, eigenvectors, upper): for batch_dims, eigenvectors, upper in itertools.product(batch_dims_set, (True, False), (True, False)): run_test((5,) + batch_dims, eigenvectors, upper) - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_svd(self, device, dtype): - def run_test(dims, some, compute_uv): - x = torch.randn(*dims, dtype=dtype, device=device) - outu = torch.tensor((), dtype=dtype, device=device) - outs = torch.tensor((), dtype=dtype, device=device) - outv = torch.tensor((), dtype=dtype, device=device) - torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) - - if compute_uv: - if some: - x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = outu[..., :min(*dims[-2:])] - narrow_v = outv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, outs, msg='Singular values mismatch') - self.assertEqual(outu, torch.zeros_like(outu), msg='U not zero') - self.assertEqual(outv, torch.zeros_like(outv), msg='V not zero') - - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - self.assertEqual(resu, outu, msg='outputs of svd and svd with out differ') - self.assertEqual(ress, outs, msg='outputs of svd and svd with out differ') - self.assertEqual(resv, outv, msg='outputs of svd and svd with out differ') - - # test non-contiguous - x = torch.randn(*dims, dtype=dtype, device=device) - n_dim = len(dims) - # Reverse the batch dimensions and the matrix dimensions and then concat them - x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) - assert not x.is_contiguous(), "x is intentionally 
non-contiguous" - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - if compute_uv: - if some: - x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = resu[..., :min(*dims[-2:])] - narrow_v = resv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, ress, msg='Singular values mismatch') - self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero') - self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero') - - shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices - (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices - (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices - for dims, some, compute_uv in itertools.product(shapes, [True, False], [True, False]): - run_test(dims, some, compute_uv) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_svd_no_singularvectors(self, device): - for size in [(5, 5), (5, 20), (20, 5)]: - a = torch.randn(*size, device=device) - u, s_expect, v = torch.svd(a) - u, s_actual, v = torch.svd(a, compute_uv=False) - self.assertEqual(s_expect, s_actual, msg="Singular values don't match") - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_svd_lowrank(self, device): - import torch - from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix - - dtype = torch.double - - def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): - density = options.pop('density', 1) - if isinstance(matrix_size, int): - rows = columns = matrix_size - else: - rows, columns = matrix_size - if density == 1: - a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) - a = a_input - else: - assert batches == () - a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) - a = a_input.to_dense() - - q = min(*size) - u, s, v = svd_lowrank(a_input, q=q, **options) - - # check if u, s, v is a SVD - u, s, v = u[..., :q], s[..., :q], v[..., :q] - A = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) - self.assertEqual(A, a) - - # check if svd_lowrank produces same singular values as torch.svd - U, S, V = torch.svd(a) - self.assertEqual(s.shape, S.shape) - self.assertEqual(u.shape, U.shape) - self.assertEqual(v.shape, V.shape) - self.assertEqual(s, S) - - if density == 1: - # actual_rank is known only for dense inputs - # - # check if pairs (u, U) and (v, V) span the same - # subspaces, respectively - u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] - U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] - self.assertEqual(u.transpose(-2, -1).matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - self.assertEqual(v.transpose(-2, -1).matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - - all_batches = [(), (1,), (3,), (2, 3)] - for actual_rank, size, all_batches in [ - (2, (17, 4), all_batches), - (4, (17, 4), all_batches), - (4, (17, 17), all_batches), - (10, (100, 40), all_batches), - (7, (1000, 1000), [()]), - ]: - # dense input - for batches in all_batches: - run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) - if size != 
size[::-1]: - run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) - - # sparse input - for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]: - for density in [0.005, 0.1]: - run_subtest(None, size, (), device, torch.svd_lowrank, density=density) - - # jitting support - jitted = torch.jit.script(torch.svd_lowrank) - actual_rank, size, batches = 2, (17, 4), () - run_subtest(actual_rank, size, batches, device, jitted) - @skipCUDAIfNoMagma @skipCPUIfNoLapack def test_pca_lowrank(self, device): diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index 1906b83ca8d6..a5d8b8179207 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -13,6 +13,7 @@ 'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', 'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq', 'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "unpack_dual", 'linalg_qr', + '_svd_helper', 'linalg_svd', } @@ -56,7 +57,7 @@ def test_namedtuple_return(self): names=('values', 'indices'), hasout=True), op(operators=['kthvalue'], input=(1, 0), names=('values', 'indices'), hasout=True), - op(operators=['svd'], input=(), names=('U', 'S', 'V'), hasout=True), + op(operators=['svd', '_svd_helper', 'linalg_svd'], input=(), names=('U', 'S', 'V'), hasout=True), op(operators=['slogdet'], input=(), names=('sign', 'logabsdet'), hasout=False), op(operators=['qr', 'linalg_qr'], input=(), names=('Q', 'R'), hasout=True), op(operators=['solve'], input=(a,), names=('solution', 'LU'), hasout=True), @@ -65,26 +66,38 @@ def test_namedtuple_return(self): op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True), op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True), op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True), - op(operators=['unpack_dual'], input=(a, 0), names=('primal', 'tangent'), hasout=False), + op(operators=['unpack_dual'], input=(0,), names=('primal', 'tangent'), hasout=False), ] + def get_func(f): + "Return either torch.f or torch.linalg.f, where 'f' is a string" + if f.startswith('linalg_'): + return getattr(torch.linalg, f[7:]) + return getattr(torch, f, None) + + def check_namedtuple(tup, names): + "Check that the namedtuple 'tup' has the given names" + for i, name in enumerate(names): + self.assertIs(getattr(tup, name), tup[i]) + for op in operators: for f in op.operators: - if 'linalg_' in f: - ret = getattr(torch.linalg, f[7:])(a, *op.input) - ret1 = getattr(torch.linalg, f[7:])(a, *op.input, out=tuple(ret)) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) - else: - # Handle op that are not methods - func = getattr(a, f) if hasattr(a, f) else getattr(torch, f) - ret = func(*op.input) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) - if op.hasout: - ret1 = getattr(torch, f)(a, *op.input, out=tuple(ret)) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) + # 1. check the namedtuple returned by calling torch.f + func = get_func(f) + if func: + ret1 = func(a, *op.input) + check_namedtuple(ret1, op.names) + # + # 2. check the out= variant, if it exists + if func and op.hasout: + ret2 = func(a, *op.input, out=tuple(ret1)) + check_namedtuple(ret2, op.names) + # + # 3. 
check the Tensor.f method, if it exists + meth = getattr(a, f, None) + if meth: + ret3 = meth(*op.input) + check_namedtuple(ret3, op.names) all_covered_operators = set([x for y in operators for x in y.operators]) diff --git a/test/test_nn.py b/test/test_nn.py index 386ba369dca6..83a98e360857 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2405,7 +2405,6 @@ def test_pruning_container_compute_mask(self): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType(expected_mask, computed_mask) - def test_l1_unstructured_pruning(self): r"""Test that l1 unstructured pruning actually removes the lowest entries by l1 norm (by hand). It also checks that applying l1 @@ -2430,6 +2429,35 @@ def test_l1_unstructured_pruning(self): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType(expected_weight, m.weight) + def test_l1_unstructured_pruning_with_importance_scores(self): + r"""Test that l1 unstructured pruning actually removes the lowest + entries of importance scores and not the parameter by l1 norm (by hand). + It also checks that applying l1 unstructured pruning more than once + respects the previous mask. + """ + m = nn.Linear(4, 2) + # modify its weight matrix by hand + m.weight = torch.nn.Parameter( + torch.tensor( + [[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32 + ) + ) + importance_scores = torch.tensor( + [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 + ) + + prune.l1_unstructured(m, 'weight', amount=2, importance_scores=importance_scores) + expected_weight = torch.tensor([[1, 2, 0, 4], [-4, 0, -2, -1]]) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(expected_weight, m.weight) + + # check that pruning again removes two entries of m.weight that are colocated with + # the next two smallest absolute values of importance scores. + prune.l1_unstructured(m, 'weight', amount=2, importance_scores=importance_scores) + expected_weight = torch.tensor([[1, 0, 0, 4], [-4, 0, 0, -1]]) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(expected_weight, m.weight) + def test_unstructured_pruning_same_magnitude(self): r"""Since it may happen that the tensor to prune has entries with the same exact magnitude, it is important to check that pruning happens @@ -2447,7 +2475,6 @@ def test_unstructured_pruning_same_magnitude(self): self.assertEqual(nparams_toprune, nparams_pruned) def test_random_structured_pruning_amount(self): - AMOUNT = 0.6 AXIS = 2 p = prune.RandomStructured(amount=AMOUNT, dim=AXIS) @@ -2464,7 +2491,6 @@ def test_random_structured_pruning_amount(self): ) assert per_column_sums == [0, 20] - def test_ln_structured_pruning(self): r"""Check Ln structured pruning by hand. """ @@ -2488,6 +2514,33 @@ def test_ln_structured_pruning(self): prune.ln_structured(m, 'weight', amount=1, n=1, dim=-1) self.assertEqual(expected_mask_axis3, m.weight_mask) + def test_ln_structured_pruning_importance_scores(self): + r"""Check Ln structured pruning by hand. + """ + m = nn.Conv2d(3, 1, 2) + m.weight.data = torch.Tensor( + [[[[1., 2.], [1., 2.5]], + [[0.5, 1.], [0.1, 0.1]], + [[-3., -5.], [0.1, -1.]]]] + ) + importance_scores = torch.Tensor( + [[[[10., 1.], [10., 1.]], + [[30., 3.], [30., 3.]], + [[-20., -2.], [-20., -2.]]]] + ) + # expected effect of pruning 1 of the 3 channels by L2-norm + expected_mask_axis1 = torch.ones_like(m.weight) + expected_mask_axis1[:, 0] = 0. 
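
For orientation, a minimal sketch of the `importance_scores` argument these pruning tests exercise (the module `m` and the `scores` tensor are made up for illustration):

```python
import torch
from torch import nn
from torch.nn.utils import prune

m = nn.Linear(4, 2)
scores = torch.tensor([[4., 2., 1., 3.],
                       [-3., -1., -2., -4.]])

# Prune the two entries whose *scores* have the smallest absolute value,
# regardless of the magnitude of the weights themselves.
prune.l1_unstructured(m, name='weight', amount=2, importance_scores=scores)
print(m.weight_mask)  # zeros at the positions of scores 1. and -1.
```
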
+ + prune.ln_structured(m, 'weight', amount=1, n=2, dim=1, importance_scores=importance_scores) + self.assertEqual(expected_mask_axis1, m.weight_mask) + + # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm + expected_mask_axis3 = expected_mask_axis1 + expected_mask_axis3[:, :, :, 1] = 0. + + prune.ln_structured(m, 'weight', amount=1, n=1, dim=-1, importance_scores=importance_scores) + self.assertEqual(expected_mask_axis3, m.weight_mask) def test_remove_pruning(self): r"""`prune.remove` removes the hook and the reparametrization @@ -2567,6 +2620,49 @@ def test_global_pruning(self): expected_nweight = torch.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype) self.assertEqual(expected_nweight, n.weight) + def test_global_pruning_importance_scores(self): + r"""Test that global l1 unstructured pruning over 2 parameters removes + the `amount=4` smallest global weights across the 2 parameters. + """ + m = nn.Linear(4, 2) + n = nn.Linear(3, 1) + # modify the weight matrices by hand + m.weight = torch.nn.Parameter( + torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to( + dtype=torch.float32) + ) + m_importance_scores = torch.tensor( + [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 + ) + n.weight = torch.nn.Parameter( + torch.tensor([[0, 0.1, -2]]).to( + dtype=torch.float32) + ) + n_importance_scores = torch.tensor([[0, 10., -0.2]]).to(dtype=torch.float32) + + params_to_prune = ( + (m, 'weight'), + (n, 'weight'), + ) + importance_scores = { + (m, 'weight'): m_importance_scores, + (n, 'weight'): n_importance_scores, + } + + # prune the 4 smallest weights globally by L1 magnitude + prune.global_unstructured( + params_to_prune, + pruning_method=prune.L1Unstructured, + amount=4, + importance_scores=importance_scores, + ) + + expected_m_weight = torch.tensor([[1, 2, 0, 4], [-4, 0, -2, -1]]) + # TODO(#38095): Replace assertEqualIgnoreType. 
See issue #38095 + self.assertEqualIgnoreType(expected_m_weight, m.weight) + + expected_n_weight = torch.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype) + self.assertEqual(expected_n_weight, n.weight) def test_custom_from_mask_pruning(self): r"""Test that the CustomFromMask is capable of receiving @@ -2656,7 +2752,6 @@ def test_pruning_serialization_model(self): self.assertEqual(pruned_weight, new_model[0].weight) - def test_pruning_serialization_state_dict(self): # create a model model = torch.nn.Sequential( @@ -2707,7 +2802,6 @@ def test_pruning_serialization_state_dict(self): self.assertEqual(pruned_weight, new_model[0].weight) - def test_prune(self): # create a new pruning method p = prune.L1Unstructured(amount=2) @@ -2721,6 +2815,37 @@ def test_prune(self): pruned_tensor = p.prune(t, default_mask) self.assertEqual(t * expected_mask, pruned_tensor) + def test_prune_importance_scores(self): + # create a new pruning method + p = prune.L1Unstructured(amount=2) + # create tensor to be pruned + t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) + importance_scores = torch.tensor( + [[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]] + ).to(dtype=torch.float32) + # create prior mask by hand + default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) + # since we are pruning the two lowest magnitude units, the outcome of + # the calculation should be this: + expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]]) + pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores) + self.assertEqual(t * expected_mask, pruned_tensor) + + def test_prune_importance_scores_mimic_default(self): + # create a new pruning method + p = prune.L1Unstructured(amount=2) + # create tensor to be pruned + t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) + # create prior mask by hand + default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) + # since we are pruning the two lowest magnitude units, the outcome of + # the calculation should be this: + expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) + pruned_tensor_without_importance_scores = p.prune(t, default_mask) + pruned_tensor_with_importance_scores = p.prune(t, default_mask, importance_scores=t) + self.assertEqual(pruned_tensor_without_importance_scores, pruned_tensor_with_importance_scores) + self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores) + def test_rnn_pruning(self): l = torch.nn.LSTM(32, 32) # This Module has 4 parameters called: @@ -2752,6 +2877,7 @@ def test_rnn_pruning(self): assert dict(l.named_parameters())['weight_ih_l0'] is not None assert 'weight_ih_l0_orig' not in dict(l.named_parameters()) + def test_rnn_weight_norm(self): def check_weight_norm(l, name, num_params): # This Module has 4 or 5 parameters called: diff --git a/test/test_reductions.py b/test/test_reductions.py index 917d469c5ee6..b08cebf7947b 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -1133,8 +1133,13 @@ def verify_against_numpy(t): verify_against_numpy(t) @dtypes(*(torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False, - include_bool=True, include_complex=False))) + include_bool=True, include_complex=True))) def test_all_any_vs_numpy(self, device, dtype): + # Note [all, any uint8 compatibility]: However for compatibility reason, + # for `uint8`, they return Tensor of same dtype `uint8`. 
+ # Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561 + exact_dtype = True if dtype != torch.uint8 else False + def _test_all_any(x): self.compare_with_numpy(torch.all, np.all, x) self.compare_with_numpy(torch.any, np.any, x) @@ -1142,38 +1147,102 @@ def _test_all_any(x): def _test_all_any_with_dim(x, dim): torch_fn = partial(torch.all, dim=dim) np_fn = partial(np.all, axis=dim) - self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=False) + self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) torch_fn = partial(torch.any, dim=dim) np_fn = partial(np.any, axis=dim) - self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=False) + self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) + + def _test_out_variant(x, dim): + out = torch.empty_like(x) + if dtype == torch.bool or dtype == torch.uint8: + expected = torch.all(x, dim) + torch.all(x, dim, out=out) + self.assertEqual(expected, out) + + expected = torch.any(x, dim) + torch.any(x, dim, out=out) + self.assertEqual(expected, out) + else: + with self.assertRaisesRegex(RuntimeError, "all only supports bool tensor for result, got"): + torch.all(x, dim, out=out) + + with self.assertRaisesRegex(RuntimeError, "any only supports bool tensor for result, got"): + torch.any(x, dim, out=out) + + def _test_all_any_with_dim_keepdim(x, dim, keepdim): + torch_fn = partial(torch.all, dim=dim, keepdim=keepdim) + np_fn = partial(np.all, axis=dim, keepdims=keepdim) + self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) + + torch_fn = partial(torch.any, dim=dim, keepdim=keepdim) + np_fn = partial(np.any, axis=dim, keepdims=keepdim) + self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) + + def _test_output_dtype(x): + # This test will fail once the functions return bool output + # for uint8 input. 
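# Illustration only (not part of this patch): the dtype contract behind
# Note [all, any uint8 compatibility] above. With this change, all/any return
# a bool tensor for every supported input dtype except uint8, which keeps
# returning uint8 for backward compatibility -- hence exact_dtype is relaxed
# only for uint8.
import torch

x = torch.tensor([[0., 1.], [2., 3.]])
print(torch.all(x).dtype, torch.any(x, dim=0).dtype)   # torch.bool torch.bool

u = torch.tensor([0, 1, 2], dtype=torch.uint8)
print(torch.all(u).dtype)   # torch.uint8 (at the time of this change)

# The out= variant is stricter: it only accepts a bool (or, for uint8 input,
# uint8) result tensor, which is the RuntimeError that _test_out_variant checks.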
+ expected_dtype = torch.uint8 if dtype == torch.uint8 else torch.bool + self.assertEqual(torch.all(x).dtype, expected_dtype) + self.assertEqual(torch.any(x).dtype, expected_dtype) + + self.assertEqual(torch.all(x, dim=0).dtype, expected_dtype) + self.assertEqual(torch.any(x, dim=0).dtype, expected_dtype) for ndim in range(5): shape = _rand_shape(ndim, 1, 5) x = _generate_input(shape, dtype, device, with_extremal=False) _test_all_any(x) + _test_all_any(x.T) + _test_all_any(x[..., ::2]) x = _generate_input(shape, dtype, device, with_extremal=True) _test_all_any(x) + _test_all_any(x.T) + _test_all_any(x[..., ::2]) x = torch.zeros_like(x) _test_all_any(x) + _test_all_any(x.T) + _test_all_any(x[..., ::2]) x = torch.ones_like(x) _test_all_any(x) - + _test_all_any(x.T) + _test_all_any(x[..., ::2]) + _test_output_dtype(x) for dim in range(ndim): x = _generate_input(shape, dtype, device, with_extremal=False) _test_all_any_with_dim(x, dim) + _test_all_any_with_dim(x.T, dim) + _test_all_any_with_dim(x[..., ::2], dim) + _test_out_variant(x, dim) + _test_all_any_with_dim_keepdim(x, dim, keepdim=True) + _test_all_any_with_dim_keepdim(x, dim, keepdim=False) x = _generate_input(shape, dtype, device, with_extremal=True) _test_all_any_with_dim(x, dim) + _test_all_any_with_dim(x.T, dim) + _test_all_any_with_dim(x[..., ::2], dim) + _test_out_variant(x, dim) + _test_all_any_with_dim_keepdim(x, dim, keepdim=True) + _test_all_any_with_dim_keepdim(x, dim, keepdim=False) x = torch.zeros_like(x) _test_all_any_with_dim(x, dim) + _test_all_any_with_dim(x.T, dim) + _test_all_any_with_dim(x[..., ::2], dim) + _test_out_variant(x, dim) + _test_all_any_with_dim_keepdim(x, dim, keepdim=True) + _test_all_any_with_dim_keepdim(x, dim, keepdim=False) x = torch.ones_like(x) _test_all_any_with_dim(x, dim) + _test_all_any_with_dim(x.T, dim) + _test_all_any_with_dim(x[..., ::2], dim) + _test_out_variant(x, dim) + _test_all_any_with_dim_keepdim(x, dim, keepdim=True) + _test_all_any_with_dim_keepdim(x, dim, keepdim=False) # TODO: part of this test covers torch.norm, with should be covered by test_linalg @onlyOnCPUAndCUDA @@ -1851,82 +1920,6 @@ def check(a, q, args, kwargs, message): RuntimeError, r'quantile\(\) out tensor must be on the same device as the input tensor'): torch.quantile(torch.randn(1, device=device), 0.5, out=torch.scalar_tensor(1)) - def test_logical_any(self, device): - x = torch.zeros([2, 3, 400], dtype=torch.uint8, device=device) - - self.assertEqual( - torch.tensor(0, dtype=torch.uint8, device=device), - x.any()) - - self.assertEqual( - torch.zeros([1, 3, 400], dtype=torch.uint8, device=device), - x.any(0, keepdim=True)) - - self.assertEqual( - torch.zeros([2, 1, 400], dtype=torch.uint8, device=device), - x.any(1, keepdim=True)) - - self.assertEqual( - torch.zeros([2, 3, 1], dtype=torch.uint8, device=device), - x.any(2, keepdim=True)) - - # set the last element to 0 - x[-1][-1][-1] = 1 - - self.assertEqual( - torch.tensor(1, dtype=torch.uint8, device=device), - x.any()) - - y = torch.zeros([1, 3, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 1 - self.assertEqual(y, x.any(0, keepdim=True)) - - y = torch.zeros([2, 1, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 1 - self.assertEqual(y, x.any(1, keepdim=True)) - - y = torch.zeros([2, 3, 1], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 1 - self.assertEqual(y, x.any(2, keepdim=True)) - - def test_logical_all(self, device): - x = torch.ones([2, 3, 400], dtype=torch.uint8, device=device) - - self.assertEqual( - torch.tensor(1, 
dtype=torch.uint8, device=device), - x.all()) - - self.assertEqual( - torch.ones([1, 3, 400], dtype=torch.uint8, device=device), - x.all(0, keepdim=True)) - - self.assertEqual( - torch.ones([2, 1, 400], dtype=torch.uint8, device=device), - x.all(1, keepdim=True)) - - self.assertEqual( - torch.ones([2, 3, 1], dtype=torch.uint8, device=device), - x.all(2, keepdim=True)) - - # set the last element to 0 - x[-1][-1][-1] = 0 - - self.assertEqual( - torch.tensor(0, dtype=torch.uint8, device=device), - x.all()) - - y = torch.ones([1, 3, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 0 - self.assertEqual(y, x.all(0, keepdim=True)) - - y = torch.ones([2, 1, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 0 - self.assertEqual(y, x.all(1, keepdim=True)) - - y = torch.ones([2, 3, 1], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 0 - self.assertEqual(y, x.all(2, keepdim=True)) - def test_std_mean(self, device): x = torch.rand(100, 50, 20, device=device) for dim in range(x.dim()): @@ -2251,21 +2244,29 @@ def test_reduction_empty(self, device): # ignore if there is no allreduce. self.assertTrue('dim' in str(err)) - # any - xb = x.to(torch.uint8) - yb = x.to(torch.uint8) - self.assertEqual((2, 0), xb.any(2).shape) - self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape) - self.assertEqual(torch.zeros((2, 4), device=device, dtype=torch.uint8), xb.any(1)) - self.assertEqual(torch.zeros((2, 1, 4), device=device, dtype=torch.uint8), xb.any(1, keepdim=True)) - self.assertEqual(torch.zeros((), device=device, dtype=torch.uint8), xb.any()) - - # all - self.assertEqual((2, 0), xb.all(2).shape) - self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape) - self.assertEqual(torch.ones((2, 4), device=device, dtype=torch.uint8), xb.all(1)) - self.assertEqual(torch.ones((2, 1, 4), device=device, dtype=torch.uint8), xb.all(1, keepdim=True)) - self.assertEqual(torch.ones((), device=device, dtype=torch.uint8), xb.all()) + for dtype in torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False, + include_bool=True, include_complex=True): + # Refer: [all, any uint8 compatibility] + if dtype == torch.uint8: + out_dtype = torch.uint8 + else: + out_dtype = torch.bool # output of all/any is bool irrespective of input dtype + + # any + xb = x.to(dtype) + yb = x.to(dtype) + self.assertEqual((2, 0), xb.any(2).shape) + self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape) + self.assertEqual(torch.zeros((2, 4), device=device, dtype=out_dtype), xb.any(1)) + self.assertEqual(torch.zeros((2, 1, 4), device=device, dtype=out_dtype), xb.any(1, keepdim=True)) + self.assertEqual(torch.zeros((), device=device, dtype=out_dtype), xb.any()) + + # all + self.assertEqual((2, 0), xb.all(2).shape) + self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape) + self.assertEqual(torch.ones((2, 4), device=device, dtype=out_dtype), xb.all(1)) + self.assertEqual(torch.ones((2, 1, 4), device=device, dtype=out_dtype), xb.all(1, keepdim=True)) + self.assertEqual(torch.ones((), device=device, dtype=out_dtype), xb.all()) instantiate_device_type_tests(TestReductions, globals()) diff --git a/test/test_torch.py b/test/test_torch.py index 72fa853e2e7c..874a8a6ac9f6 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6909,9 +6909,6 @@ def inner(self, device, dtype): ('atanh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()), ('erf', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), ('erfc', '', _small_3d, lambda t, d: 
[], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]), - ('exp', '', _small_3d, lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes()), - ('exp', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1), - lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), ('rad2deg', '', _small_3d, lambda t, d: [], 1e-1, 1e-0, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), ('deg2rad', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), ('reciprocal', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 960991a4820b..3c6a4f0e7b0a 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -1702,7 +1702,6 @@ def _medium_2d(dtype, device): # TODO: all these should be replaced with OpInfos torch_op_tests = [ - _TorchMathTestMeta('exp'), _TorchMathTestMeta('floor'), _TorchMathTestMeta('ceil'), _TorchMathTestMeta('rad2deg'), diff --git a/test/test_view_ops.py b/test/test_view_ops.py index be33aa1ab44a..17d04c35d8ca 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -119,6 +119,56 @@ def test_conj_self(self, device, dtype): s = t.conj() self.assertTrue(s is t) + @onlyOnCPUAndCUDA + @dtypes(*torch.testing.get_all_fp_dtypes(include_bfloat16=False), torch.complex64) + def test_view_dtype(self, device, dtype): + int_dtype = { + torch.half: torch.int16, + torch.bfloat16: torch.int16, + torch.float: torch.int, + torch.double: torch.long, + torch.complex64: torch.long, + }[dtype] + numpy_dtype = { + torch.half: np.int16, + torch.bfloat16: np.int16, + torch.float: np.int32, + torch.double: np.int64, + torch.complex64: np.int64, + }[dtype] + + def generate_inputs(): + yield make_tensor((5, 5, 5), device, dtype, low=-5, high=5) + yield make_tensor((5, 5, 5), device, dtype, low=-5, high=5).permute(2, 0, 1) + yield make_tensor((1, 5, 1), device, dtype, low=-5, high=5).expand(5, 5, 5) + yield make_tensor((10, 5, 10), device, dtype, low=-5, high=5)[::2, :, ::2] + yield make_tensor((0, 5, 10), device, dtype, low=-5, high=5) + yield make_tensor((), device, dtype, low=-5, high=5) + + def run_test(fp_tensor): + self.assertRaises(RuntimeError, lambda: fp_tensor.view(torch.complex128)) + self.assertRaises(RuntimeError, lambda: fp_tensor.view(torch.int8)) + + int_tensor = fp_tensor.view(int_dtype) + self.assertEqual(int_tensor.dtype, int_dtype) + self.assertEqual(int_tensor.shape, fp_tensor.shape) + self.assertEqual(int_tensor.stride(), fp_tensor.stride()) + + self.assertEqual(fp_tensor, int_tensor.view(dtype), rtol=0, atol=0) + self.assertEqual(fp_tensor.cpu().numpy().view(numpy_dtype), int_tensor, rtol=0, atol=0) + + fp_tensor.zero_() + self.assertEqual(fp_tensor, torch.zeros_like(fp_tensor), rtol=0, atol=0) + + for fp_tensor in generate_inputs(): + run_test(fp_tensor) + + # Test that requires_grad is dropped, because view(dtype) does not support backward + if dtype is torch.double: + t = make_tensor((5, 5, 5), device, torch.double, low=-5, high=5, requires_grad=True) + self.assertFalse(t.view(torch.complex64).requires_grad) + + @onlyOnCPUAndCUDA def test_view_as_complex(self, device): def fn(contiguous_input=True, dim0=0, dim1=1): diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 9bf266da394d..9c9719be1ef0 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -86,7 +86,7 @@ # e.g., it is used by 
_cudnn_rnn # # If you need a complex expression, e.g., with local variables, -# write a _backward function in tools/autograd/templates/Functions.cpp +# write a _backward function in torch/csrc/autograd/FunctionsManual.cpp # and invoke it from here. By the way, go read # https://github.com/zdevito/ATen/issues/163; this describes an # important hazard that occurs when porting backwards from Python to C++ @@ -1055,7 +1055,7 @@ - name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) -- name: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) +- name: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V) self: svd_backward(grads, self, some, compute_uv, U, S, V) - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) @@ -1173,6 +1173,9 @@ - name: view(Tensor(a) self, int[] size) -> Tensor(a) self: grad.reshape(self.sizes()) +- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) + output_differentiability: [False] + - name: view_as_real(Tensor(a) self) -> Tensor(a) self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1 diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index c78e1e5f66cc..e4337e9de855 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -82,7 +82,7 @@ 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', 'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv', - 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', 'svd', '_fft_c2c', '_fft_r2c', + 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', '_svd_helper', '_fft_c2c', '_fft_r2c', 'linalg_solve', 'sqrt', 'stack', 'gather', 'index_select', 'index_add_' } diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 7ad514d5d067..c5a9209d36b7 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -302,6 +302,9 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, out: str) -> None: 'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],' ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,' ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'], + '_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],' + ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,' + ' requires_grad: bool = False) -> Tensor: ...'], 'range': ['def range(start: Number, end: Number,' ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...' .format(FACTORY_PARAMS)], diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index fca1c45377e7..4ccffc2c8362 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -80,7 +80,7 @@ def parseExpr(expr, module): value, len_parsed = parseNestedExpr(expr, module) assert len_parsed == len(expr), "whole expression was not parsed, falling back to c++ parser" return value - except Exception as e: + except Exception: """ The python resolver fails in several cases in known unit tests, and is intended to fall back gracefully to the c++ resolver in general. 
For example, python 2 style diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 588c59ef98a6..83bc04113672 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -432,45 +432,9 @@ def add_docstr_all(method, docstr): add_docstr_all('all', r""" -.. function:: all() -> bool +all(dim=None, keepdim=False) -> Tensor -Returns True if all elements in the tensor are True, False otherwise. - -Example:: - - >>> a = torch.rand(1, 2).bool() - >>> a - tensor([[False, True]], dtype=torch.bool) - >>> a.all() - tensor(False, dtype=torch.bool) - -.. function:: all(dim, keepdim=False, out=None) -> Tensor - -Returns True if all elements in each row of the tensor in the given -dimension :attr:`dim` are True, False otherwise. - -If :attr:`keepdim` is ``True``, the output tensor is of the same size as -:attr:`input` except in the dimension :attr:`dim` where it is of size 1. -Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting -in the output tensor having 1 fewer dimension than :attr:`input`. - -Args: - dim (int): the dimension to reduce - keepdim (bool): whether the output tensor has :attr:`dim` retained or not - out (Tensor, optional): the output tensor - -Example:: - - >>> a = torch.rand(4, 2).bool() - >>> a - tensor([[True, True], - [True, False], - [True, True], - [True, True]], dtype=torch.bool) - >>> a.all(dim=1) - tensor([ True, False, True, True], dtype=torch.bool) - >>> a.all(dim=0) - tensor([ True, False], dtype=torch.bool) +See :func:`torch.all` """) add_docstr_all('allclose', @@ -489,45 +453,9 @@ def add_docstr_all(method, docstr): add_docstr_all('any', r""" -.. function:: any() -> bool - -Returns True if any elements in the tensor are True, False otherwise. - -Example:: - - >>> a = torch.rand(1, 2).bool() - >>> a - tensor([[False, True]], dtype=torch.bool) - >>> a.any() - tensor(True, dtype=torch.bool) +any(dim=None, keepdim=False) -> Tensor -.. function:: any(dim, keepdim=False, out=None) -> Tensor - -Returns True if any elements in each row of the tensor in the given -dimension :attr:`dim` are True, False otherwise. - -If :attr:`keepdim` is ``True``, the output tensor is of the same size as -:attr:`input` except in the dimension :attr:`dim` where it is of size 1. -Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting -in the output tensor having 1 fewer dimension than :attr:`input`. - -Args: - dim (int): the dimension to reduce - keepdim (bool): whether the output tensor has :attr:`dim` retained or not - out (Tensor, optional): the output tensor - -Example:: - - >>> a = torch.randn(4, 2) < 0 - >>> a - tensor([[ True, True], - [False, True], - [ True, True], - [False, False]]) - >>> a.any(1) - tensor([ True, True, True, False]) - >>> a.any(0) - tensor([True, True]) +See :func:`torch.any` """) add_docstr_all('apply_', @@ -4254,6 +4182,51 @@ def callable(a, b) -> number >>> torch.equal(b, c) False + +.. function:: view(dtype) -> Tensor + +Returns a new tensor with the same data as the :attr:`self` tensor but of a +different :attr:`dtype`. :attr:`dtype` must have the same number of bytes per +element as :attr:`self`'s dtype. + +.. warning:: + + This overload is not supported by TorchScript, and using it in a Torchscript + program will cause undefined behavior. 
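As a rough sketch (not part of the documented example; the dtype pairings mirror
``test_view_dtype`` earlier in this patch), a dtype of the same element width
reinterprets the bits in place, while a different width is rejected at the time
of this change::

    >>> import torch
    >>> t = torch.randn(2, 3, dtype=torch.float64)
    >>> t.view(torch.int64).shape    # 8-byte float -> 8-byte int
    torch.Size([2, 3])
    >>> t.view(torch.int64).stride() == t.stride()
    True
    >>> # float64 -> int32 changes the bytes per element and raises the
    >>> # RuntimeError shown in the example below.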
+ + +Args: + dtype (:class:`torch.dtype`): the desired dtype + +Example:: + + >>> x = torch.randn(4, 4) + >>> x + tensor([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + torch.float32 + + >>> y = x.view(torch.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=torch.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(torch.int16) + Traceback (most recent call last): + File "", line 1, in + RuntimeError: Viewing a tensor as a new dtype with a different number of bytes per element is not supported. """) add_docstr_all('view_as', diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 4a1c36df7497..d4b377a23750 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -610,6 +610,113 @@ def merge_dicts(*dicts): True """) +add_docstr(torch.all, + r""" +all(input) -> Tensor + +Tests if all elements in :attr:`input` evaluate to `True`. + +.. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + +Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + +.. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + +For each row of :attr:`input` in the given dimension :attr:`dim`, +returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + +{keepdim_details} + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) +""".format(**single_dim_common)) + +add_docstr(torch.any, + r""" +any(input) -> Tensor + +Args: + {input} + +Tests if any element in :attr:`input` evaluates to `True`. + +.. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + +Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + +.. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + +For each row of :attr:`input` in the given dimension :attr:`dim`, +returns `True` if any element in the row evaluate to `True` and `False` otherwise. 
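A small, hedged sketch of the ``keepdim`` behaviour covered by the new
``_test_all_any_with_dim_keepdim`` helper earlier in this patch (shapes only)::

    >>> import torch
    >>> a = torch.randn(4, 2) < 0
    >>> torch.any(a, dim=1).shape
    torch.Size([4])
    >>> torch.any(a, dim=1, keepdim=True).shape
    torch.Size([4, 1])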
+ +{keepdim_details} + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) +""".format(**single_dim_common)) + add_docstr(torch.angle, r""" angle(input, *, out=None) -> Tensor @@ -6676,11 +6783,10 @@ def merge_dicts(*dicts): If :attr:`some` is ``True``, then this function returns the thin (reduced) QR factorization. Otherwise, if :attr:`some` is ``False``, this function returns the complete QR factorization. -.. warning:: ``torch.qr`` is deprecated. Please use ``torch.linalg.`` :meth:`~torch.linalg.qr` - instead, which provides a better compatibility with - ``numpy.linalg.qr``. +.. warning:: ``torch.qr`` is deprecated. Please use ``torch.linalg.`` :func:`~torch.linalg.qr` + instead. - **Differences with** ``torch.linalg.`` :meth:`~torch.linalg.qr`: + **Differences with** ``torch.linalg.qr``: * ``torch.linalg.qr`` takes a string parameter ``mode`` instead of ``some``: @@ -6698,21 +6804,21 @@ def merge_dicts(*dicts): .. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, and may produce different (valid) decompositions on different device types - and different platforms, depending on the precise version of the - underlying library. + or different platforms. Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more batch dimensions consisting of matrices of dimension :math:`m \times n`. some (bool, optional): Set to ``True`` for reduced QR decomposition and ``False`` for - complete QR decomposition. + complete QR decomposition. If `k = min(m, n)` then: + + * ``some=True`` : returns `(Q, R)` with dimensions (m, k), (k, n) (default) + + * ``'some=False'``: returns `(Q, R)` with dimensions (m, m), (m, n) Keyword args: - out (tuple, optional): tuple of `Q` and `R` tensors - satisfying :code:`input = torch.matmul(Q, R)`. - The dimensions of `Q` and `R` are :math:`(*, m, k)` and :math:`(*, k, n)` - respectively, where :math:`k = \min(m, n)` if :attr:`some:` is ``True`` and - :math:`k = m` otherwise. + out (tuple, optional): tuple of `Q` and `R` tensors. + The dimensions of `Q` and `R` are detailed in the description of :attr:`some` above. Example:: @@ -8142,18 +8248,45 @@ def merge_dicts(*dicts): r""" svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) -This function returns a namedtuple ``(U, S, V)`` which is the singular value -decomposition of a input matrix or batches of matrices :attr:`input` such that -:math:`input = U \times diag(S) \times V^T`. +Computes the singular value decomposition of either a matrix or batch of +matrices :attr:`input`." The singular value decomposition is represented as a +namedtuple ``(U, S, V)``, such that :math:`input = U \mathbin{@} diag(S) \times +V^T`, where :math:`V^T` is the transpose of ``V``. If :attr:`input` is a batch +of tensors, then ``U``, ``S``, and ``V`` are also batched with the same batch +dimensions as :attr:`input`. + +If :attr:`some` is ``True`` (default), the method returns the reduced singular +value decomposition i.e., if the last two dimensions of :attr:`input` are +``m`` and ``n``, then the returned `U` and `V` matrices will contain only +:math:`min(n, m)` orthonormal columns. 
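For concreteness, a hedged sketch of the shapes just described (here :math:`m = 5`,
:math:`n = 3`)::

    >>> import torch
    >>> a = torch.randn(5, 3)                  # m = 5, n = 3
    >>> U, S, V = torch.svd(a, some=True)      # U: (5, 3), S: (3,), V: (3, 3)
    >>> U, S, V = torch.svd(a, some=False)     # U: (5, 5), S: (3,), V: (3, 3)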
+ +If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be +zero-filled matrices of shape :math:`(m \times m)` and :math:`(n \times n)` +respectively, and the same device as :attr:`input`. The :attr:`some` +argument has no effect when :attr:`compute_uv` is False. + +The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will +always be real-valued, even if :attr:`input` is complex. -If :attr:`some` is ``True`` (default), the method returns the reduced -singular value decomposition i.e., if the last two dimensions of -:attr:`input` are ``m`` and ``n``, then the returned `U` matrix will -contain only :math:`min(n, m)` orthonormal columns and the size of `V` -will be :math:`(*, n, n)`. +.. warning:: ``torch.svd`` is deprecated. Please use ``torch.linalg.`` + :func:`~torch.linalg.svd` instead, which is similar to NumPy's + ``numpy.linalg.svd``. -If :attr:`compute_uv` is ``False``, the returned `U` and `V` matrices will be zero matrices -of shape :math:`(m \times m)` and :math:`(n \times n)` respectively. :attr:`some` will be ignored here. +.. note:: **Differences with** ``torch.linalg.`` :func:`~torch.linalg.svd`: + + * :attr:`some` is the opposite of ``torch.linalg.`` + :func:`~torch.linalg.svd`'s :attr:`full_matricies`. Note that + default value for both is ``True``, so the default behavior is + effectively the opposite. + + * it returns ``V``, whereas ``torch.linalg.`` + :func:`~torch.linalg.svd` returns ``Vh``. The result is that + when using ``svd`` you need to manually transpose + ``V`` in order to reconstruct the original matrix. + + * If :attr:`compute_uv=False`, it returns zero-filled tensors for + ``U`` and ``Vh``, whereas :meth:`~torch.linalg.svd` returns + empty tensors. Supports real-valued and complex-valued input. @@ -8164,22 +8297,18 @@ def merge_dicts(*dicts): algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine `gesdd` as well. -.. note:: Irrespective of the original strides, the returned matrix `U` - will be transposed, i.e. with strides :code:`U.contiguous().transpose(-2, -1).stride()` +.. note:: The returned matrix `U` will be transposed, i.e. with strides + :code:`U.contiguous().transpose(-2, -1).stride()`. -.. note:: Extra care needs to be taken when backward through `U` and `V` - outputs. Such operation is really only stable when :attr:`input` is - full rank with all distinct singular values. Otherwise, ``NaN`` can - appear as the gradients are not properly defined. Also, notice that - double backward will usually do an additional backward through `U` and - `V` even if the original backward is only on `S`. +.. note:: Gradients computed using `U` and `V` may be unstable if + :attr:`input` is not full rank or has non-unique singular values. .. note:: When :attr:`some` = ``False``, the gradients on :code:`U[..., :, min(m, n):]` and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors can be arbitrary bases of the subspaces. -.. note:: When :attr:`compute_uv` = ``False``, backward cannot be performed since `U` and `V` - from the forward pass is required for the backward operation. +.. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True. + .. note:: With the complex-valued input the backward operation works correctly only for gauge invariant loss functions. Please look at `Gauge problem in AD`_ for more details. 
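A short, hedged sketch of the migration notes above (assuming a build where both
APIs are present): ``torch.svd`` returns ``V`` while ``torch.linalg.svd`` returns
``Vh``, and ``some=True`` corresponds to ``full_matrices=False``::

    >>> import torch
    >>> a = torch.randn(5, 3)
    >>> U, S, V = torch.svd(a)                                 # reduced (some=True)
    >>> U2, S2, Vh = torch.linalg.svd(a, full_matrices=False)
    >>> torch.allclose(a, U @ torch.diag(S) @ V.transpose(-2, -1), atol=1e-5)
    True
    >>> torch.allclose(a, U2 @ torch.diag(S2) @ Vh, atol=1e-5)
    True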
@@ -8187,8 +8316,9 @@ def merge_dicts(*dicts): Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more batch dimensions consisting of :math:`m \times n` matrices. - some (bool, optional): controls the shape of returned `U` and `V` - compute_uv (bool, optional): option whether to compute `U` and `V` or not + some (bool, optional): controls whether to compute the reduced or full decomposition, and + consequently the shape of returned ``U`` and ``V``. Defaults to True. + compute_uv (bool, optional): option whether to compute `U` and `V` or not. Defaults to True. Keyword args: out (tuple, optional): the output tuple of tensors diff --git a/torch/_utils.py b/torch/_utils.py index 796e88a3cc2d..75eadd4a990e 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -1,6 +1,6 @@ import torch import torch._six -from typing import Optional +from typing import Optional, List, DefaultDict import warnings from collections import defaultdict import sys @@ -37,9 +37,9 @@ def _type(self, dtype=None, non_blocking=False, **kwargs): raise RuntimeError("Cannot cast sparse tensor to dense tensor") new_module_name = dtype.__module__.replace('.sparse', '') new_values_type_name = new_module_name + '.' + dtype.__name__ - new_values = torch._values(self).type(new_values_type_name, non_blocking) + new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking) new_indices_type_name = new_module_name + '.LongTensor' - new_indices = torch._indices(self).type(new_indices_type_name, non_blocking) + new_indices = torch.Tensor._indices(self).type(new_indices_type_name, non_blocking) return dtype(new_indices, new_values, self.size()) if dtype.is_sparse: raise RuntimeError("Cannot cast dense tensor to sparse tensor") @@ -72,8 +72,8 @@ def _cuda(self, device=None, non_blocking=False, **kwargs): with torch.cuda.device(device): if self.is_sparse: new_type = getattr(torch.cuda.sparse, self.__class__.__name__) - indices = torch._indices(self).cuda(device, non_blocking) - values = torch._values(self).cuda(device, non_blocking) + indices = torch.Tensor._indices(self).cuda(device, non_blocking) + values = torch.Tensor._values(self).cuda(device, non_blocking) return new_type(indices, values, self.size()) else: new_type = getattr(torch.cuda, self.__class__.__name__) @@ -144,7 +144,7 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac return tensor -_sparse_tensors_to_validate = [] +_sparse_tensors_to_validate: List["torch.Tensor"] = [] # In _legacy_load() in serialization.py we unpickle storages after the sparse # tensors have been already unpickled. Those storages contain data necessary for @@ -271,8 +271,8 @@ def _flatten_sparse_tensors(tensors): A tuple of two contiguous 1D buffers, one containing input tensors' indices and the other containing the values. """ - flat_indices = _flatten_dense_tensors([torch._indices(t) for t in tensors]) - flat_values = _flatten_dense_tensors([torch._values(t) for t in tensors]) + flat_indices = _flatten_dense_tensors([torch.Tensor._indices(t) for t in tensors]) + flat_values = _flatten_dense_tensors([torch.Tensor._values(t) for t in tensors]) return flat_indices, flat_values @@ -314,8 +314,8 @@ def _unflatten_sparse_tensors(flat, tensors): flat. 
""" flat_indices, flat_values = flat - indices = _unflatten_dense_tensors(flat_indices, [torch._indices(t) for t in tensors]) - values = _unflatten_dense_tensors(flat_values, [torch._values(t) for t in tensors]) + indices = _unflatten_dense_tensors(flat_indices, [torch.Tensor._indices(t) for t in tensors]) + values = _unflatten_dense_tensors(flat_values, [torch.Tensor._values(t) for t in tensors]) outputs = [] for t, i, v in zip(tensors, indices, values): outputs.append(t.new(i, v, t.size())) @@ -340,8 +340,8 @@ def _reorder_tensors_as(tensors, ordered_tensors): type_dict = defaultdict(list) for tensor in tensors: type_dict[tensor.type()].append(tensor) - type_dict = {t: iter(coll) for t, coll in type_dict.items()} - return tuple(next(type_dict[tensor.type()]) for tensor in ordered_tensors) + type_dict_ = {t: iter(coll) for t, coll in type_dict.items()} + return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors) def _take_tensors(tensors, size_limit): @@ -356,12 +356,12 @@ def _take_tensors(tensors, size_limit): Blocks of tensors of same type and within size_limit. The yielded tensors are only ordered as the original sequence within its types. """ - buf_dict = defaultdict(lambda: [[], 0]) + buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0]) for tensor in tensors: t = tensor.type() if tensor.is_sparse: - indices = torch._indices(tensor) - values = torch._values(tensor) + indices = torch.Tensor._indices(tensor) + values = torch.Tensor._values(tensor) size = indices.numel() * indices.element_size() + values.numel() * values.element_size() else: size = tensor.numel() * tensor.element_size() diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 6558295d58cb..79d195d73a74 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -2078,7 +2078,7 @@ Tensor linalg_qr_backward(const std::vector &grads, c std::string mode, const Tensor& q, const Tensor& r){ bool compute_q, reduced; std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode); - TORCH_CHECK(compute_q, "linalg_qr_backward: cannot compute backward if mode='r'. " + TORCH_CHECK(compute_q, "The derivative of qr is not implemented when mode='r'. " "Please use torch.linalg.qr(..., mode='reduced')"); auto square_deep_case_backward = [](const Tensor& grad_Q, @@ -2145,7 +2145,7 @@ Tensor linalg_qr_backward(const std::vector &grads, c TORCH_CHECK( ((m <= n && (!reduced)) || reduced), - "The derivative is not implemented when nrows > ncols and complete QR. 
"); + "The derivative of qr is not implemented when mode='complete' and nrows > ncols."); auto grad_Q = grads[0]; auto grad_R = grads[1]; diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index f906efb187ef..ad217c2924ad 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -916,7 +916,6 @@ std::shared_ptr Engine::execute_with_graph_task( std::unique_lock lock(graph_task->mutex_); auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device()); - queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); // worker_device == NO_DEVICE it's a CPU thread and it's trying to drive the // autograd engine with corresponding GraphTask, and its NOT a re-entrant call @@ -929,8 +928,12 @@ std::shared_ptr Engine::execute_with_graph_task( // set the graph_task owner to the current device graph_task->owner_ = worker_device; - // The owning thread start to drive the engine execution with the GraphTask - // that has already been pushed to the current CPU thread's ready_queue + // Now that all the non-thread safe fields of the graph_task have been populated, + // we can enqueue it. + queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); + + // The owning thread start to drive the engine execution for any CPU task that + // was just pushed or will be added later from other worker threads lock.unlock(); thread_main(graph_task); TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed()); @@ -943,6 +946,11 @@ std::shared_ptr Engine::execute_with_graph_task( // If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant // backward call from that device. graph_task->owner_ = worker_device; + + // Now that all the non-thread safe fields of the graph_task have been populated, + // we can enqueue it. 
+ queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); + if (current_depth >= max_recursion_depth_) { // See Note [Reentrant backwards] // If reached the max depth, switch to a different thread diff --git a/torch/csrc/distributed/autograd/context/container.cpp b/torch/csrc/distributed/autograd/context/container.cpp index ee3939010e8a..6948de958b84 100644 --- a/torch/csrc/distributed/autograd/context/container.cpp +++ b/torch/csrc/distributed/autograd/context/container.cpp @@ -245,14 +245,17 @@ void DistAutogradContainer::sendReleaseContextRpc( CleanupAutogradContextReq(context_id).toMessage(), options); + std::weak_ptr wp = cleanupFuture; cleanupFuture->addCallback( - [worker_id](const rpc::FutureMessage& cleanupFuture) { - if (cleanupFuture.hasError()) { + [worker_id, wp]() { + auto future = wp.lock(); + TORCH_INTERNAL_ASSERT(future); + if (future->hasError()) { std::string errorMsg = c10::str( "Could not release Dist Autograd Context on node ", worker_id, ": ", - cleanupFuture.error()->what()); + future->tryRetrieveErrorMessage()); LOG(ERROR) << errorMsg; return; } diff --git a/torch/csrc/distributed/autograd/context/context.cpp b/torch/csrc/distributed/autograd/context/context.cpp index 6527fc25b92b..526ca053dd40 100644 --- a/torch/csrc/distributed/autograd/context/context.cpp +++ b/torch/csrc/distributed/autograd/context/context.cpp @@ -123,26 +123,27 @@ void DistAutogradContext::resetGraphTask() { } void DistAutogradContext::addOutstandingRpc( - const std::shared_ptr& futureMessage) { - futureMessage->addCallback([this](const rpc::FutureMessage& futureMessage) { - if (futureMessage.hasError()) { + const std::shared_ptr& jitFuture) { + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, wp]() { + auto future = wp.lock(); + if (future->hasError()) { // If we have an error, let the local autograd engine know about it. std::unique_lock lock(lock_); if (graphTask_) { graphTask_->set_exception_without_signal(nullptr); lock.unlock(); if (!graphTask_->future_completed_.exchange(true)) { - graphTask_->future_result_->setErrorIfNeeded( - std::make_exception_ptr(*futureMessage.error())); + graphTask_->future_result_->setErrorIfNeeded(future->exception_ptr()); } } else { LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: " - << (*futureMessage.error()).what(); + << future->tryRetrieveErrorMessage(); } } }); std::lock_guard guard(lock_); - outStandingRpcs_.push_back(futureMessage); + outStandingRpcs_.push_back(jitFuture); } void DistAutogradContext::clearOutstandingRpcs() { @@ -170,8 +171,10 @@ std::shared_ptr DistAutogradContext:: state->future->markCompleted(c10::IValue()); } else { for (auto& rpc : outStandingRpcs) { - rpc->addCallback([state](const rpc::FutureMessage& rpc) { - if (rpc.hasError()) { + std::weak_ptr wp = rpc; + rpc->addCallback([state, wp]() { + auto future = wp.lock(); + if (future->hasError()) { // If there's an error, we want to setError() on the future, // unless another error has already been sent - use a CAS to // guard. 
@@ -183,7 +186,7 @@ std::shared_ptr DistAutogradContext:: bool expectedAlreadySent = false; if (state->alreadySentError.compare_exchange_strong( expectedAlreadySent, true)) { - state->future->setError(std::make_exception_ptr(*rpc.error())); + state->future->setError(future->exception_ptr()); } return; } diff --git a/torch/csrc/distributed/autograd/context/context.h b/torch/csrc/distributed/autograd/context/context.h index e7d73962634b..b611040af448 100644 --- a/torch/csrc/distributed/autograd/context/context.h +++ b/torch/csrc/distributed/autograd/context/context.h @@ -52,7 +52,7 @@ class TORCH_API DistAutogradContext { // Adds a future message recording an outstanding RPC. void addOutstandingRpc( - const std::shared_ptr& futureMessage); + const std::shared_ptr& jitFuture); // Returns all gradients. const c10::Dict getGradients() const; @@ -134,7 +134,7 @@ class TORCH_API DistAutogradContext { // List of futures for RPCs initiated by this node to propagate gradients to // other nodes. The distributed autograd engine on this node can return // successfully only if all these futures are done and are successful. - std::vector> outStandingRpcs_; + std::vector> outStandingRpcs_; // Lock to protect concurrent modification of the context. mutable std::mutex lock_; diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp index 99951f098a22..509c5c6cbd08 100644 --- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp @@ -47,11 +47,12 @@ variable_list RecvRpcBackward::apply(variable_list&& grads) { // Send the gradients over to the appropriate node. auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent(); - auto futureMessage = rpcAgent->send( - rpcAgent->getWorkerInfo(fromWorkerId_), std::move(gradCall).toMessage()); + auto jitFuture = rpcAgent->send( + rpcAgent->getWorkerInfo(fromWorkerId_), + std::move(gradCall).toMessage()); // Record the future in the context. - sharedContext->addOutstandingRpc(futureMessage); + sharedContext->addOutstandingRpc(jitFuture); // 'recv' function sends the gradients over the wire using RPC, it doesn't // need to return anything for any downstream autograd function. diff --git a/torch/csrc/distributed/autograd/utils.cpp b/torch/csrc/distributed/autograd/utils.cpp index 464d8248d8a4..08bb99471686 100644 --- a/torch/csrc/distributed/autograd/utils.cpp +++ b/torch/csrc/distributed/autograd/utils.cpp @@ -16,7 +16,7 @@ namespace autograd { using torch::distributed::autograd::AutogradMetadata; using torch::distributed::autograd::RpcWithAutograd; -using torch::distributed::rpc::FutureMessage; +using torch::distributed::rpc::JitFuture; using torch::distributed::rpc::Message; using torch::distributed::rpc::MessageType; using torch::distributed::rpc::RpcAgent; @@ -138,7 +138,7 @@ Message getMessageWithAutograd( return std::move(*rpcWithAutograd).toMessage(); } -std::shared_ptr sendMessageWithAutograd( +std::shared_ptr sendMessageWithAutograd( RpcAgent& agent, const WorkerInfo& dst, torch::distributed::rpc::Message&& wrappedRpcMsg, @@ -151,7 +151,7 @@ std::shared_ptr sendMessageWithAutograd( MessageType::FORWARD_AUTOGRAD_REQ, forceGradRecording); - std::shared_ptr fut; + std::shared_ptr fut; // If profiler is enabled, wrap this message with profiling metadata that will // tell the remote end to process this request with the profiler enabled. 
if (!forceDisableProfiling && torch::autograd::profiler::profilerEnabled()) { diff --git a/torch/csrc/distributed/autograd/utils.h b/torch/csrc/distributed/autograd/utils.h index 2a0a066e1a95..07ba45ed60d7 100644 --- a/torch/csrc/distributed/autograd/utils.h +++ b/torch/csrc/distributed/autograd/utils.h @@ -45,7 +45,7 @@ TORCH_API rpc::Message getMessageWithAutograd( bool forceGradRecording = false); // Send message after autograd checking -TORCH_API std::shared_ptr +TORCH_API std::shared_ptr sendMessageWithAutograd( rpc::RpcAgent& agent, const rpc::WorkerInfo& dst, diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index 6204f6e343ed..f1f8e39cd7e3 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -44,7 +44,8 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { } auto torch_C_m = py::handle(torch_C_module).cast(); - auto m = torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings"); + auto m = + torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings"); auto module = py::handle(m).cast(); @@ -129,12 +130,12 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { .def( "get_worker_info", (const WorkerInfo& (RpcAgent::*)(void)const) & - RpcAgent::getWorkerInfo, + RpcAgent::getWorkerInfo, py::call_guard()) - .def( + .def( "get_worker_info", (const WorkerInfo& (RpcAgent::*)(const std::string&)const) & - RpcAgent::getWorkerInfo, + RpcAgent::getWorkerInfo, py::call_guard()) .def( "get_worker_infos", diff --git a/torch/csrc/distributed/rpc/message.cpp b/torch/csrc/distributed/rpc/message.cpp index 14c267fd0699..b35e9149d1e6 100644 --- a/torch/csrc/distributed/rpc/message.cpp +++ b/torch/csrc/distributed/rpc/message.cpp @@ -104,6 +104,18 @@ Message createExceptionResponse(const std::string& exceptionStr, int64_t id) { id); } +namespace { + +// NB: need to call torch::class_ to register Message in the map returned by +// c10::getCustomClassTypeMap(). Otherwise, Message cannot be wrapped within +// an IValue. +// NB: add this line here instead of in rpc/init.cpp because 1) we have C++ +// only tests that won't run rpc/init.cpp; 2) Message is not meant to be +// visible from Python. 
+static const auto message = torch::class_("rpc", "_Message"); + +} // namespace + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/message.h b/torch/csrc/distributed/rpc/message.h index 7e458810db69..3d2d623e821f 100644 --- a/torch/csrc/distributed/rpc/message.h +++ b/torch/csrc/distributed/rpc/message.h @@ -35,19 +35,35 @@ enum MessageType { PYTHON_RET = 3 | MessageTypeFlags::RESPONSE_TYPE, // messages for dist.remote on builtin operators and Python UDF - SCRIPT_REMOTE_CALL = 4 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a builtin operator - PYTHON_REMOTE_CALL = 5 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a Python UDF - REMOTE_RET = 6 | MessageTypeFlags::RESPONSE_TYPE, // Response for remote calls for UDF, builtin, or script + SCRIPT_REMOTE_CALL = + 4 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a builtin operator + PYTHON_REMOTE_CALL = + 5 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a Python UDF + REMOTE_RET = + 6 | MessageTypeFlags::RESPONSE_TYPE, // Response for remote calls for UDF, + // builtin, or script // RRef related internal messages - SCRIPT_RREF_FETCH_CALL = 7 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef fetches value from owner - PYTHON_RREF_FETCH_CALL = 8 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef fetches value from owner - SCRIPT_RREF_FETCH_RET = 9 | MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends ivalue to user - PYTHON_RREF_FETCH_RET = 10 | MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends py::object to user - RREF_USER_DELETE = 11 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef tells the owner to deref - RREF_FORK_REQUEST = 12 | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells the owner about itself - RREF_CHILD_ACCEPT = 13 | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells parent that owner knows it - RREF_ACK = 14 | MessageTypeFlags::RESPONSE_TYPE, // ACK to internal RRef messages + SCRIPT_RREF_FETCH_CALL = + 7 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef fetches value + // from owner + PYTHON_RREF_FETCH_CALL = + 8 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef fetches + // value from owner + SCRIPT_RREF_FETCH_RET = + 9 | MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends ivalue to user + PYTHON_RREF_FETCH_RET = 10 | + MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends py::object to user + RREF_USER_DELETE = 11 | + MessageTypeFlags::REQUEST_TYPE, // A UserRRef tells the owner to deref + RREF_FORK_REQUEST = + 12 | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells the owner + // about itself + RREF_CHILD_ACCEPT = + 13 | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells parent that + // owner knows it + RREF_ACK = + 14 | MessageTypeFlags::RESPONSE_TYPE, // ACK to internal RRef messages // Messages with autograd info FORWARD_AUTOGRAD_REQ = 15 | MessageTypeFlags::REQUEST_TYPE, @@ -93,7 +109,7 @@ enum MessageType { // Layers above ``RpcAgent`` only converts ScriptCall, ScriptResp, PythonCall, // and PythonResp into a Message, and it is up to the RpcAgent // implementation to determine how to serialize a message. -class TORCH_API Message final { +class TORCH_API Message final : public torch::CustomClassHolder { public: Message(); @@ -154,9 +170,6 @@ TORCH_API Message createExceptionResponse(const std::exception& e, int64_t id); TORCH_API Message createExceptionResponse(const std::string& exceptionStr, int64_t id); -// FutureMessage is an internal type used in the communication layer. 
All -// user-facing surface APIs should use JitFuture instead. -using FutureMessage = torch::utils::Future; using JitFuture = c10::ivalue::Future; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index b106f1442d31..9c1a703cfa6d 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -287,7 +287,7 @@ void ProcessGroupAgent::shutdownImpl() { threadPool_.waitWorkComplete(); } -std::shared_ptr ProcessGroupAgent::send( +std::shared_ptr ProcessGroupAgent::send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds) { @@ -319,7 +319,7 @@ std::shared_ptr ProcessGroupAgent::send( pg_->getRank()); auto requestId = nextId(); - auto future = std::make_shared(); + auto future = std::make_shared(at::AnyClassType::get()); if (message.isRequest()) { // millisecond level precision of when request started. auto futureStartTime = std::chrono::steady_clock::now(); @@ -362,7 +362,7 @@ std::shared_ptr ProcessGroupAgent::send( message.setId(requestId); ++clientActiveCalls_; } else { - future->markCompleted(Message()); + future->markCompleted(IValue()); } // Sending to ourselves: bypass the send logic and enqueue directly @@ -382,6 +382,7 @@ std::shared_ptr ProcessGroupAgent::send( // the C++ land. Hence, we have to explicitly use the ``WorkerInfo`` in the // C++ land. enqueueSend(SendWork(allWorkerInfo_[to.id_], std::move(message))); + return future; } @@ -513,22 +514,24 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { std::move(data.first), std::move(data.second), work.type_, work.id_); if (message.isRequest()) { ++serverActiveCalls_; - std::shared_ptr futureResponse; + std::shared_ptr futureResponse; try { futureResponse = cb_->operator()(message); } catch (const std::exception& e) { - futureResponse = std::make_shared(); - futureResponse->setError(e.what()); + futureResponse = std::make_shared(at::AnyClassType::get()); + futureResponse->setError(std::current_exception()); } if (futureResponse->completed()) { --serverActiveCalls_; if (!futureResponse->hasError()) { - send(work.from_, std::move(*futureResponse).moveValue()); + send( + work.from_, + std::move(*futureResponse->value().toCustomClass())); } else { send( work.from_, createExceptionResponse( - futureResponse->error()->what(), message.id())); + futureResponse->tryRetrieveErrorMessage(), message.id())); } } else { ++serverActiveAsyncCalls_; @@ -537,28 +540,30 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { // Use a weak_ptr, so we can std::move the future's value. 
auto fromId = work.from_.id_; auto requestId = work.id_; - futureResponse->addCallback([this, - fromId, - requestId, - weak = std::weak_ptr( - futureResponse)]() { - auto futureResponse = weak.lock(); - TORCH_INTERNAL_ASSERT(futureResponse); - --serverActiveCalls_; - --serverActiveAsyncCalls_; - if (!futureResponse->hasError()) { - send(getWorkerInfo(fromId), std::move(*futureResponse).moveValue()); - } else { - send( - getWorkerInfo(fromId), - createExceptionResponse( - futureResponse->error()->what(), requestId)); - } - }); + futureResponse->addCallback( + [this, + fromId, + requestId, + weak = std::weak_ptr(futureResponse)]() { + auto futureResponse = weak.lock(); + TORCH_INTERNAL_ASSERT(futureResponse); + --serverActiveCalls_; + --serverActiveAsyncCalls_; + if (!futureResponse->hasError()) { + send( + getWorkerInfo(fromId), + std::move(*futureResponse->value().toCustomClass())); + } else { + send( + getWorkerInfo(fromId), + createExceptionResponse( + futureResponse->tryRetrieveErrorMessage(), requestId)); + } + }); } } else if (message.isResponse()) { auto id = message.id(); - std::shared_ptr fm = nullptr; + std::shared_ptr jitFuture = nullptr; { std::lock_guard lock{futureMutex_}; const auto& futureInfo = futures_.find(id); @@ -570,7 +575,7 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { return false; } // Use futureInfo before destructing it. - fm = futureInfo->second.future_; + jitFuture = futureInfo->second.future_; auto endTime = futureInfo->second.endTime_; futures_.erase(id); // look up the corresponding future by its time out and request @@ -589,10 +594,11 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { futureCV_.notify_all(); --clientActiveCalls_; if (message.type() == MessageType::EXCEPTION) { - fm->setError( - std::string(message.payload().begin(), message.payload().end())); + jitFuture->setError(std::make_exception_ptr(std::runtime_error( + std::string(message.payload().begin(), message.payload().end())))); } else { - fm->markCompleted(std::move(message)); + jitFuture->markCompleted( + IValue(c10::make_intrusive(std::move(message)))); } } else { // TODO: pass the error back to the caller instead of crashing here. @@ -643,7 +649,7 @@ void ProcessGroupAgent::markFutureWithError(Message& message) { } void ProcessGroupAgent::markFutureWithError(int64_t id, std::string errorMsg) { - std::shared_ptr fm = nullptr; + std::shared_ptr jitFuture = nullptr; { std::lock_guard lock{futureMutex_}; const auto& futureInfo = futures_.find(id); @@ -653,7 +659,7 @@ void ProcessGroupAgent::markFutureWithError(int64_t id, std::string errorMsg) { // out and been processed accordingly. return; } - fm = futureInfo->second.future_; + jitFuture = futureInfo->second.future_; auto rpcEndTime = futureInfo->second.endTime_; futures_.erase(id); // look up the corresponding future by its time out and request ID, @@ -671,7 +677,7 @@ void ProcessGroupAgent::markFutureWithError(int64_t id, std::string errorMsg) { } --clientActiveCalls_; - fm->setError(std::move(errorMsg)); + jitFuture->setError(std::make_exception_ptr(std::runtime_error(errorMsg))); futureCV_.notify_all(); } @@ -803,7 +809,8 @@ void ProcessGroupAgent::pollTimedOutRPCs() { if (!timedOutFuture.future_->hasError()) { --clientActiveCalls_; - timedOutFuture.future_->setError(std::move(err)); + timedOutFuture.future_->setError( + std::make_exception_ptr(std::runtime_error(err))); // The future timed out and will not be processed by handleRecv(), even // if we eventually get a response. 
In order to keep track of all // send/recv pairs, we increment the count here. diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index 61d17f03e623..8d2471a7d113 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -88,7 +88,7 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { // This method wraps the destination information and the message into a // SendWork object, and put the SendWork into a queue. Another thread will // consume SendWork from the queue and send it out. - std::shared_ptr send( + std::shared_ptr send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = kUnsetRpcTimeout) override; @@ -130,16 +130,16 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { // additional information to manage timeouts and destination information, // which is needed for termination detection. struct FutureInfo { - std::shared_ptr future_; + std::shared_ptr future_; steady_clock_time_point endTime_; int dstRank_; std::chrono::milliseconds timeout_; FutureInfo( - const std::shared_ptr& future, + std::shared_ptr future, const steady_clock_time_point& endTime, int dstRank, const std::chrono::milliseconds timeout) - : future_(future), + : future_(std::move(future)), endTime_(endTime), dstRank_(dstRank), timeout_(timeout) {} diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index 1c955a6baefb..cface0b88551 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -137,8 +137,7 @@ c10::intrusive_ptr PyRRef::getFuture() const { // Marking hasValue to false, as this Future is only used for signaling // profiler to update profiling result and the profiler does not retrieve // any value from it. - return wrapFutureMessageInJitFuture( - rref_->getOwnerCreationFuture(), false /* hasValue */); + return toPyJitFuture(rref_->getOwnerCreationFuture(), false /* hasValue */); } c10::intrusive_ptr PyRRef::getProfilingFuture() const { @@ -335,7 +334,7 @@ void PyRRef::backward( ->send( rpcAgent->getWorkerInfo(rref->owner()), std::move(rrefBackwardReq).toMessage()) - ->wait(); + ->waitAndThrow(); } } diff --git a/torch/csrc/distributed/rpc/py_rref.h b/torch/csrc/distributed/rpc/py_rref.h index 1b61c8c2fbf0..e3ae22fb8320 100644 --- a/torch/csrc/distributed/rpc/py_rref.h +++ b/torch/csrc/distributed/rpc/py_rref.h @@ -55,7 +55,10 @@ class PYBIND11_EXPORT PyRRef { void backward(int64_t autogradContextId, bool retainGraph); // Helper static function to run backward on a given rref. 
- static void backward(int64_t autogradContextId, bool retainGraph, const c10::intrusive_ptr& rref); + static void backward( + int64_t autogradContextId, + bool retainGraph, + const c10::intrusive_ptr& rref); private: c10::intrusive_ptr rref_; diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp index 5e2e8304b7bd..383d6df9cee0 100644 --- a/torch/csrc/distributed/rpc/python_functions.cpp +++ b/torch/csrc/distributed/rpc/python_functions.cpp @@ -24,7 +24,7 @@ namespace rpc { namespace { -IValue toIValue(const Message& message) { +IValue toPyIValue(const Message& message) { MessageType msgType = message.type(); auto response = deserializeResponse(message, msgType); switch (msgType) { @@ -109,7 +109,7 @@ std::shared_ptr matchBuiltinOp( return matchedOperator; } -std::shared_ptr sendPythonRemoteCall( +std::shared_ptr sendPythonRemoteCall( const WorkerInfo& dst, SerializedPyObj serializedPyObj, const IValue& rrefId, @@ -134,42 +134,40 @@ std::shared_ptr sendPythonRemoteCall( using namespace torch::distributed::autograd; -c10::intrusive_ptr wrapFutureMessageInJitFuture( - const std::shared_ptr& futureResponseMessage, +c10::intrusive_ptr toPyJitFuture( + const std::shared_ptr& messageJitFuture, bool hasValue) { if (hasValue) { - c10::intrusive_ptr jitFuture = + c10::intrusive_ptr pyJitFuture = c10::make_intrusive(PyObjectType::get()); - std::weak_ptr wp = futureResponseMessage; - futureResponseMessage->addCallback( - at::wrapPropagateTLSState([jitFuture, wp]() { - auto futureResponseMessage = wp.lock(); - if (futureResponseMessage->hasError()) { - jitFuture->setError( - std::make_exception_ptr(*futureResponseMessage->error())); + std::weak_ptr wp = messageJitFuture; + messageJitFuture->addCallback( + at::wrapPropagateTLSState([pyJitFuture, wp]() { + auto future = wp.lock(); + if (future->hasError()) { + pyJitFuture->setError(future->exception_ptr()); } else { - jitFuture->markCompleted( - toIValue(futureResponseMessage->constValue())); + pyJitFuture->markCompleted( + toPyIValue(*future->value().toCustomClass())); } })); - return jitFuture; + return pyJitFuture; } else { - c10::intrusive_ptr jitFuture = + c10::intrusive_ptr pyJitFuture = c10::make_intrusive(NoneType::get()); - std::weak_ptr wp = futureResponseMessage; - futureResponseMessage->addCallback( - at::wrapPropagateTLSState([wp, jitFuture]() { - auto futureResponseMessage = wp.lock(); - if (futureResponseMessage->hasError()) { - jitFuture->setError( - std::make_exception_ptr(*futureResponseMessage->error())); + std::weak_ptr wp = messageJitFuture; + messageJitFuture->addCallback( + at::wrapPropagateTLSState([wp, pyJitFuture]() { + auto future = wp.lock(); + if (future->hasError()) { + pyJitFuture->setError(future->exception_ptr()); } else { - jitFuture->markCompleted(IValue()); + pyJitFuture->markCompleted(IValue()); } })); - return jitFuture; + return pyJitFuture; } } @@ -186,7 +184,7 @@ c10::intrusive_ptr pyRpcBuiltin( py::gil_scoped_release release; auto scriptCall = std::make_unique(op, std::move(stack)); auto agent = RpcAgent::getCurrentRpcAgent(); - return wrapFutureMessageInJitFuture(sendMessageWithAutograd( + return toPyJitFuture(sendMessageWithAutograd( *agent, dst, std::move(*scriptCall).toMessage(), @@ -207,7 +205,7 @@ c10::intrusive_ptr pyRpcPythonUdf( std::move(serializedPyObj), isAsyncExecution); auto agent = RpcAgent::getCurrentRpcAgent(); - return wrapFutureMessageInJitFuture(sendMessageWithAutograd( + return toPyJitFuture(sendMessageWithAutograd( *agent, dst, 
std::move(*pythonCall).toMessage(), @@ -275,20 +273,19 @@ PyRRef pyRemoteBuiltin( auto scriptRemoteCall = std::make_unique( op, std::move(stack), userRRef->rrefId(), userRRef->forkId()); - auto fm = sendMessageWithAutograd( + auto jitFuture = sendMessageWithAutograd( *agent, dst, std::move(*scriptRemoteCall).toMessage(), /*forceGradRecord */ false, /* timeout */ rpcTimeoutSeconds); - userRRef->registerOwnerCreationFuture(fm); + userRRef->registerOwnerCreationFuture(jitFuture); ctx.addPendingUser(userRRef->forkId(), userRRef); - std::weak_ptr wp = fm; - fm->addCallback( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( at::wrapPropagateTLSState([wp, forkId{userRRef->forkId()}]() { - auto fm = wp.lock(); - callback::confirmPendingUser(*fm, forkId); + callback::confirmPendingUser(*wp.lock(), forkId); })); return PyRRef(userRRef); } else { @@ -298,22 +295,20 @@ PyRRef pyRemoteBuiltin( auto scriptRemoteCall = std::make_unique( op, std::move(stack), ownerRRef->rrefId(), ownerRRef->rrefId()); - auto fm = sendMessageWithAutograd( + auto jitFuture = sendMessageWithAutograd( *agent, dst, std::move(*scriptRemoteCall).toMessage(), /* forceGradRecord */ false, /* timeout */ rpcTimeoutSeconds); - ownerRRef->registerOwnerCreationFuture(fm); - + ownerRRef->registerOwnerCreationFuture(jitFuture); // Builtin operators does not return py::object, and hence does not require // GIL for destructing the potentially deleted OwerRRef. - std::weak_ptr wp = fm; - fm->addCallback(at::wrapPropagateTLSState( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback(at::wrapPropagateTLSState( [wp, ownerRRefId = ownerRRef->rrefId()]() { - auto fm = wp.lock(); - callback::finishCreatingOwnerRRef(*fm, ownerRRefId); + callback::finishCreatingOwnerRRef(*wp.lock(), ownerRRefId); })); return PyRRef(ownerRRef); } @@ -332,7 +327,7 @@ PyRRef pyRemotePythonUdf( if (ctx.getWorkerId() != dst.id_) { auto userRRef = ctx.createUserRRef(dst.id_, PyObjectType::get()); - auto fm = sendPythonRemoteCall( + auto jitFuture = sendPythonRemoteCall( dst, std::move(serializedPyObj), userRRef->rrefId().toIValue(), @@ -340,14 +335,12 @@ PyRRef pyRemotePythonUdf( rpcTimeoutSeconds, isAsyncExecution); - userRRef->registerOwnerCreationFuture(fm); - + userRRef->registerOwnerCreationFuture(jitFuture); ctx.addPendingUser(userRRef->forkId(), userRRef); - std::weak_ptr wp = fm; - fm->addCallback( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( at::wrapPropagateTLSState([wp, forkId{userRRef->forkId()}]() { - auto fm = wp.lock(); - callback::confirmPendingUser(*fm, forkId); + callback::confirmPendingUser(*wp.lock(), forkId); })); return PyRRef(userRRef); } else { @@ -355,7 +348,7 @@ PyRRef pyRemotePythonUdf( auto ownerRRef = ctx.createOwnerRRef(PyObjectType::get()); // prevent this owner RRef being deleted due to other forks ctx.addSelfAsFork(ownerRRef); - auto fm = sendPythonRemoteCall( + auto jitFuture = sendPythonRemoteCall( dst, std::move(serializedPyObj), ownerRRef->rrefId().toIValue(), @@ -363,13 +356,12 @@ PyRRef pyRemotePythonUdf( rpcTimeoutSeconds, isAsyncExecution); - ownerRRef->registerOwnerCreationFuture(fm); - std::weak_ptr wp = fm; - fm->addCallback(at::wrapPropagateTLSState( + ownerRRef->registerOwnerCreationFuture(jitFuture); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback(at::wrapPropagateTLSState( [wp, ownerRRefId = ownerRRef->rrefId()]() { - auto fm = wp.lock(); auto deletedRRef = - callback::finishCreatingOwnerRRef(*fm, ownerRRefId); + callback::finishCreatingOwnerRRef(*wp.lock(), ownerRRefId); if (deletedRRef && 
deletedRRef->isPyObj()) { py::gil_scoped_acquire ag; deletedRRef.reset(); diff --git a/torch/csrc/distributed/rpc/python_functions.h b/torch/csrc/distributed/rpc/python_functions.h index 56c091096828..15bc0b2af8a0 100644 --- a/torch/csrc/distributed/rpc/python_functions.h +++ b/torch/csrc/distributed/rpc/python_functions.h @@ -9,16 +9,16 @@ namespace torch { namespace distributed { namespace rpc { -// Converts an internal FutureMessage type into a user-facing FutureIValue type -// by creating a new FutureIValue and call its markCompleted as a callback in -// the given FutureMessage. +// Converts an internal ivalue::Future of Message into a user-facing +// ivalue::Future of py::object type by creating a new ivalue::Future and call +// its markCompleted as a callback in the given ivalue::Future. // If hasValue is true, the Message will be converted into a py::object and then -// wrap it with an IValue. If hasValue is false, this FutureIValue is only used -// for signaling and launching callbacks. In this case, the message will be -// discarded and then set the FutureIValue using an empty IValue or the given +// wrap it with an IValue. If hasValue is false, this ivalue::Future is only +// used for signaling and launching callbacks. In this case, the message will be +// discarded and then set the ivalue::Future using an empty IValue or the given // FutureError if there is an error. -c10::intrusive_ptr wrapFutureMessageInJitFuture( - const std::shared_ptr& futureResponseMessage, +c10::intrusive_ptr toPyJitFuture( + const std::shared_ptr& messageJitFuture, bool hasValue = true); c10::intrusive_ptr pyRpcBuiltin( diff --git a/torch/csrc/distributed/rpc/request_callback.cpp b/torch/csrc/distributed/rpc/request_callback.cpp index 44b7cb6eb2e5..33703c923523 100644 --- a/torch/csrc/distributed/rpc/request_callback.cpp +++ b/torch/csrc/distributed/rpc/request_callback.cpp @@ -9,8 +9,7 @@ namespace rpc { using namespace torch::distributed::autograd; -std::shared_ptr RequestCallback::operator()( - Message& request) const { +std::shared_ptr RequestCallback::operator()(Message& request) const { // NB: cannot clear autograd context id here because the processMessage method // might pause waiting for all RRefs in the arguments to be confirmed by their // owners and resumne processing in a different thread. Hence, the diff --git a/torch/csrc/distributed/rpc/request_callback.h b/torch/csrc/distributed/rpc/request_callback.h index 95847eb6153a..128cf9590034 100644 --- a/torch/csrc/distributed/rpc/request_callback.h +++ b/torch/csrc/distributed/rpc/request_callback.h @@ -12,7 +12,7 @@ namespace rpc { class TORCH_API RequestCallback { public: // Invoke the callback. - std::shared_ptr operator()(Message& request) const; + std::shared_ptr operator()(Message& request) const; virtual ~RequestCallback() {} @@ -24,8 +24,7 @@ class TORCH_API RequestCallback { // message containing an exception. Different rpc agent implementations are // expected to ensure delivery of the response/exception based on their // implementation specific mechanisms. 
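Note (illustrative only, not part of the patch): throughout these files the old string-based setError is replaced by an std::exception_ptr. The patch uses two conventions, sketched below: forward the active exception with std::current_exception() when inside a catch block, and wrap a plain message in std::runtime_error otherwise (timeouts, agent shutdown).

#include <ATen/core/ivalue.h>

#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

using JitFuture = c10::ivalue::Future;

// Inside a catch block: forward the active exception unchanged, so the
// original exception type reaches whoever reads the future.
void failFromCatch(const std::shared_ptr<JitFuture>& fut) {
  try {
    throw std::runtime_error("boom");
  } catch (const std::exception& /* unused */) {
    fut->setError(std::current_exception());
  }
}

// Outside a catch block only a string is available, so it is wrapped in a
// std::runtime_error before being attached to the future.
void failFromString(const std::shared_ptr<JitFuture>& fut, std::string msg) {
  fut->setError(std::make_exception_ptr(std::runtime_error(std::move(msg))));
}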
- virtual std::shared_ptr processMessage( - Message& request) const = 0; + virtual std::shared_ptr processMessage(Message& request) const = 0; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp index 2004178565ea..684ca5576a56 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.cpp +++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp @@ -88,12 +88,12 @@ std::unique_ptr deserializePythonRpcCommandReference( void processAsyncExecution( const py::object& pyFn, const int64_t messageId, - const std::shared_ptr& responseFuture, + const std::shared_ptr& responseFuture, std::function&)> postProcessing) { + const std::shared_ptr&)> postProcessing) { std::shared_ptr pyFuture; auto& pythonRpcHandler = PythonRpcHandler::getInstance(); { @@ -151,7 +151,7 @@ void RequestCallbackImpl::processScriptCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& scriptCall = static_cast(rpc); auto& stack = scriptCall.stackRef(); if (processScriptCallOp(scriptCall, markComplete, stack)) { @@ -176,13 +176,14 @@ void RequestCallbackImpl::processScriptCall( try { Message m = ScriptResp(valueJitFuture->value()).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); } else { @@ -195,9 +196,10 @@ void RequestCallbackImpl::processScriptCall( try { Message m = ScriptResp(jitFuture->value()).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); } @@ -207,7 +209,7 @@ void RequestCallbackImpl::processPythonCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& upc = static_cast(rpc); if (upc.isAsyncExecution()) { try { @@ -218,17 +220,18 @@ void RequestCallbackImpl::processPythonCall( [](const py::object& result, const int64_t messageId, PythonRpcHandler& pythonRpcHandler, - const std::shared_ptr& responseFuture) { + const std::shared_ptr& responseFuture) { auto serializedPyObj = pythonRpcHandler.serialize(result); py::gil_scoped_release release; auto m = std::move(PythonResp(std::move(serializedPyObj))).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); }); } catch (std::exception& e) { - responseFuture->markCompleted( - createExceptionResponse(e.what(), messageId)); + responseFuture->markCompleted(IValue(c10::make_intrusive( + createExceptionResponse(e.what(), messageId)))); } } else { auto& pythonRpcHandler = PythonRpcHandler::getInstance(); @@ -335,7 +338,7 @@ void 
RequestCallbackImpl::processPythonRemoteCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& uprc = static_cast(rpc); const auto& rrefId = uprc.rrefId(); @@ -373,20 +376,22 @@ void RequestCallbackImpl::processPythonRemoteCall( const py::object& result, const int64_t messageId, PythonRpcHandler& /* unused */, - const std::shared_ptr& responseFuture) { + const std::shared_ptr& responseFuture) { IValue py_ivalue = jit::toIValue(result, PyObjectType::get()); py::gil_scoped_release release; ownerRRef->setValue(std::move(py_ivalue)); auto m = RemoteRet(rrefId, forkId).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); }); } catch (std::exception& e) { ownerRRef->setError(std::current_exception()); auto m = RemoteRet(rrefId, forkId).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); } } else { IValue py_ivalue; @@ -414,14 +419,14 @@ void RequestCallbackImpl::processPythonRemoteCall( void RequestCallbackImpl::processPythonRRefFetchCall( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { // Making this lambda mutable to allow move-capture it in callbacks auto postProcessing = [responseFuture]( const c10::intrusive_ptr& rref, int64_t messageId) mutable { auto whenValueSet = rref->getFuture(); if (whenValueSet->hasError()) { - responseFuture->setError(whenValueSet->tryRetrieveErrorMessage()); + responseFuture->setError(whenValueSet->exception_ptr()); return; } try { @@ -437,15 +442,17 @@ void RequestCallbackImpl::processPythonRRefFetchCall( Message m = PythonRRefFetchRet(std::move(*result).toIValues()).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); } catch (py::error_already_set& e) { // py::error_already_set requires GIL to destruct, take special care. 
- responseFuture->setError(e.what()); + responseFuture->setError( + std::make_exception_ptr(std::runtime_error(e.what()))); py::gil_scoped_acquire acquire; e.restore(); PyErr_Clear(); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }; @@ -487,7 +494,7 @@ void RequestCallbackImpl::processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { try { processRpc(rpc, messageType, messageId, responseFuture); } catch (py::error_already_set& e) { @@ -516,7 +523,7 @@ bool RequestCallbackImpl::cudaAvailable() const { void RequestCallbackImpl::processRRefBackward( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& rrefBackwardReq = static_cast(rpc); // Get all fields @@ -540,7 +547,7 @@ void RequestCallbackImpl::processRRefBackward( autogradContextId, retainGraph]() { if (whenValueSet->hasError()) { - responseFuture->setError(whenValueSet->tryRetrieveErrorMessage()); + responseFuture->setError(whenValueSet->exception_ptr()); return; } @@ -551,9 +558,10 @@ void RequestCallbackImpl::processRRefBackward( // Return the response. Message m = RRefBackwardResp().toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); }); diff --git a/torch/csrc/distributed/rpc/request_callback_impl.h b/torch/csrc/distributed/rpc/request_callback_impl.h index bf43ef867cff..2883359af303 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.h +++ b/torch/csrc/distributed/rpc/request_callback_impl.h @@ -18,13 +18,13 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; void processScriptCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; TypePtr getScriptRemoteCallType( ScriptRemoteCall& scriptRemoteCall) const override; @@ -39,12 +39,12 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; void processPythonRRefFetchCall( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; void handleRRefDelete(c10::intrusive_ptr& rref) const override; @@ -52,14 +52,14 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; bool cudaAvailable() const override; void processRRefBackward( RpcCommandBase& rpc, const 
int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index 7f8db89f55bf..597edb68862d 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -7,7 +8,7 @@ #include #include #include -#include +#include #include #include #include @@ -47,13 +48,13 @@ std::unique_ptr RequestCallbackNoPython:: return rpc; } -std::shared_ptr RequestCallbackNoPython::processMessage( +std::shared_ptr RequestCallbackNoPython::processMessage( Message& request) const { // We need two futures here because it could pause twice when processing a // RPC message: // 1) waiting for all RRefs in the arguments to become confirmed; // 2) waiting for processRpc to finish. - auto retFuture = std::make_shared(); + auto retFuture = std::make_shared(at::AnyClassType::get()); auto& rrefContext = RRefContext::getInstance(); try { rrefContext.recordThreadLocalPendingRRefs(); @@ -62,44 +63,43 @@ std::shared_ptr RequestCallbackNoPython::processMessage( deserializeRequest(request), request.type()); auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs(); - rrefsReadyFuture->addCallback([ - this, - retFuture, - // std::function must be copyable, hence hae to cast the unique_ptr to - // a shared_ptr here. - rpc = (std::shared_ptr)std::move(rpc), - messageType = request.type(), - id = request.id() - ]() { - // The cost of pre-request check is minimal thanks to - // std::shared_lock. The cost is in magnitude - // of 10us. - auto serverProcessGlobalProfilerStateStackEntryPtr = - profiler::processglobal::StateStackEntry::current(); - // If server global profiler is enabled, we futher pay the - // cost of thread local profiler state initialization. - if (serverProcessGlobalProfilerStateStackEntryPtr) { - // Initialize thread-local profiler state from process-global - // profiler state. - ::torch::autograd::profiler::enableProfilerLegacy( - serverProcessGlobalProfilerStateStackEntryPtr->statePtr() - ->config()); - } - - processRpcWithErrors(*rpc, messageType, id, retFuture); - - // Response message has been sent at this moment, this post-response - // work doesn't affect RPC trip time. - if (serverProcessGlobalProfilerStateStackEntryPtr) { - // Restore thread-local profiler state. - ::torch::autograd::profiler::thread_event_lists event_lists = - ::torch::autograd::profiler::disableProfilerLegacy(); - // Put thread_local event_lists into the process-global profiler - // state. - profiler::processglobal::pushResultRecursive( - serverProcessGlobalProfilerStateStackEntryPtr, event_lists); - } - }); + rrefsReadyFuture->addCallback( + [this, + retFuture, + // std::function must be copyable, hence hae to cast the unique_ptr to + // a shared_ptr here. + rpc = (std::shared_ptr)std::move(rpc), + messageType = request.type(), + id = request.id()]() { + // The cost of pre-request check is minimal thanks to + // std::shared_lock. The cost is in magnitude + // of 10us. + auto serverProcessGlobalProfilerStateStackEntryPtr = + profiler::processglobal::StateStackEntry::current(); + // If server global profiler is enabled, we futher pay the + // cost of thread local profiler state initialization. 
+ if (serverProcessGlobalProfilerStateStackEntryPtr) { + // Initialize thread-local profiler state from process-global + // profiler state. + ::torch::autograd::profiler::enableProfilerLegacy( + serverProcessGlobalProfilerStateStackEntryPtr->statePtr() + ->config()); + } + + processRpcWithErrors(*rpc, messageType, id, retFuture); + + // Response message has been sent at this moment, this post-response + // work doesn't affect RPC trip time. + if (serverProcessGlobalProfilerStateStackEntryPtr) { + // Restore thread-local profiler state. + ::torch::autograd::profiler::thread_event_lists event_lists = + ::torch::autograd::profiler::disableProfilerLegacy(); + // Put thread_local event_lists into the process-global profiler + // state. + profiler::processglobal::pushResultRecursive( + serverProcessGlobalProfilerStateStackEntryPtr, event_lists); + } + }); } catch (std::exception& e) { retFuture->markCompleted(handleError(e, request.type(), request.id())); rrefContext.clearRecordedPendingRRefsOnError(); @@ -111,7 +111,7 @@ void RequestCallbackNoPython::processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { try { processRpc(rpc, messageType, messageId, responseFuture); } catch (std::exception& e) { @@ -123,7 +123,7 @@ void RequestCallbackNoPython::processScriptCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { auto& scriptCall = static_cast(rpc); auto& stack = scriptCall.stackRef(); TORCH_CHECK( @@ -161,7 +161,7 @@ void RequestCallbackNoPython::processPythonCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } @@ -169,7 +169,7 @@ void RequestCallbackNoPython::processPythonRemoteCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } @@ -187,7 +187,7 @@ void RequestCallbackNoPython::processBaseScriptRemoteCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& scriptRemoteCall = static_cast(rpc); auto rrefId = scriptRemoteCall.retRRefId(); auto forkId = scriptRemoteCall.retForkId(); @@ -208,7 +208,8 @@ void RequestCallbackNoPython::processBaseScriptRemoteCall( } Message m = RemoteRet(rrefId, forkId).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); }; // scriptRemoteCall is only alive within this block, use reference to @@ -259,7 +260,7 @@ void RequestCallbackNoPython::processScriptRRefFetchCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& srf = static_cast(rpc); auto& ctx = RRefContext::getInstance(); @@ -283,15 +284,16 @@ void RequestCallbackNoPython::processScriptRRefFetchCall( whenValueSet->addCallback( [responseFuture, messageId, rref, whenValueSet]() { if (whenValueSet->hasError()) { 
- responseFuture->setError(whenValueSet->tryRetrieveErrorMessage()); + responseFuture->setError(whenValueSet->exception_ptr()); return; } try { Message m = ScriptRRefFetchRet({rref->getValue()}).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); }); @@ -300,7 +302,7 @@ void RequestCallbackNoPython::processScriptRRefFetchCall( void RequestCallbackNoPython::processPythonRRefFetchCall( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } @@ -340,7 +342,7 @@ void RequestCallbackNoPython::processRRefForkRequest( void RequestCallbackNoPython::processForwardAutogradReq( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& rpcWithAutograd = static_cast(rpc); // Attach 'recv' autograd function. @@ -362,7 +364,8 @@ void RequestCallbackNoPython::processForwardAutogradReq( // Process the original RPC. auto wrappedMessageType = rpcWithAutograd.wrappedMessageType(); // Make an overall future for the wrapped response. - auto wrappedRpcResponseFuture = std::make_shared(); + auto wrappedRpcResponseFuture = + std::make_shared(at::AnyClassType::get()); // Kick off processing for the nested RPC command. // wrappedRpcResponseFuture will be a Future to the result. processRpc( @@ -375,45 +378,46 @@ void RequestCallbackNoPython::processForwardAutogradReq( // The original future needs to be marked as completed when the wrapped // one completes, with the autograd context information wrapped. // Uses weak_ptr so we can std::move the value. - wrappedRpcResponseFuture->addCallback([ - responseFuture, - messageId, - fromWorkerId, - weak = std::weak_ptr(wrappedRpcResponseFuture), - ctxId = autogradContext->contextId() - ]() { - // As this callback can be invoked by a different thread, we have to - // make sure that the thread_local states in the previous thread is - // correctly propagated. - // NB: The execution of TorchScript functions can also run on a - // different thread, which is addressed by - // https://github.com/pytorch/pytorch/pull/36395 - // NB: when adding async UDF support, we should also propagate - // thread_local states there. - // TODO: Land on a general solution for RPC ThreadLocalState. See - // https://github.com/pytorch/pytorch/issues/38510 - DistAutogradContextGuard cbCtxGuard(ctxId); - - auto wrappedRpcResponseFuture = weak.lock(); - TORCH_INTERNAL_ASSERT(wrappedRpcResponseFuture); - if (wrappedRpcResponseFuture->hasError()) { - // Propagate error to responseFuture if we had one. 
- responseFuture->setError(wrappedRpcResponseFuture->error()->what()); - } else { - auto msg = getMessageWithAutograd( - fromWorkerId, - std::move(*wrappedRpcResponseFuture).moveValue(), - MessageType::FORWARD_AUTOGRAD_RESP); - msg.setId(messageId); - responseFuture->markCompleted(std::move(msg)); - } - }); + wrappedRpcResponseFuture->addCallback( + [responseFuture, + messageId, + fromWorkerId, + weak = std::weak_ptr(wrappedRpcResponseFuture), + ctxId = autogradContext->contextId()]() { + // As this callback can be invoked by a different thread, we have to + // make sure that the thread_local states in the previous thread is + // correctly propagated. + // NB: The execution of TorchScript functions can also run on a + // different thread, which is addressed by + // https://github.com/pytorch/pytorch/pull/36395 + // NB: when adding async UDF support, we should also propagate + // thread_local states there. + // TODO: Land on a general solution for RPC ThreadLocalState. See + // https://github.com/pytorch/pytorch/issues/38510 + DistAutogradContextGuard cbCtxGuard(ctxId); + + auto wrappedRpcResponseFuture = weak.lock(); + TORCH_INTERNAL_ASSERT(wrappedRpcResponseFuture); + if (wrappedRpcResponseFuture->hasError()) { + // Propagate error to responseFuture if we had one. + responseFuture->setError(wrappedRpcResponseFuture->exception_ptr()); + } else { + auto msg = getMessageWithAutograd( + fromWorkerId, + std::move( + *wrappedRpcResponseFuture->value().toCustomClass()), + MessageType::FORWARD_AUTOGRAD_RESP); + msg.setId(messageId); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(msg)))); + } + }); } void RequestCallbackNoPython::processBackwardAutogradReq( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& gradientsCall = static_cast(rpc); const auto& autogradMetadata = gradientsCall.getAutogradMetadata(); @@ -437,9 +441,10 @@ void RequestCallbackNoPython::processBackwardAutogradReq( if (!execFuture->hasError()) { Message m = std::move(PropagateGradientsResp()).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); } else { - responseFuture->setError(execFuture->tryRetrieveErrorMessage()); + responseFuture->setError(execFuture->exception_ptr()); } }); } @@ -461,7 +466,7 @@ void RequestCallbackNoPython::processCleanupAutogradContextReq( void RequestCallbackNoPython::processRunWithProfilingReq( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& rpcWithProfilingReq = static_cast(rpc); auto wrappedMsgType = rpcWithProfilingReq.wrappedMessageType(); auto profilingConfig = rpcWithProfilingReq.getProfilingConfig(); @@ -483,7 +488,8 @@ void RequestCallbackNoPython::processRunWithProfilingReq( this->cudaAvailable(), "Profiler state set to CUDA but CUDA not available."); const auto profilingKeyId = rpcWithProfilingReq.getProfilingId(); - auto wrappedRpcResponseFuture = std::make_shared(); + auto wrappedRpcResponseFuture = + std::make_shared(at::AnyClassType::get()); // Enable the profiler with the config from the sender. 
// When enabling on the main thread, ensure profiler states are cleaned // up, but defer consolidation of all profiled events to the continuation @@ -504,11 +510,11 @@ void RequestCallbackNoPython::processRunWithProfilingReq( messageId, wrappedRpcResponseFuture); - wrappedRpcResponseFuture->addCallback(at::wrapPropagateTLSState( - [wrappedRpcResponseFuture, - responseFuture, - profilingKeyId, - profilingConfig] { + wrappedRpcResponseFuture->addCallback( + at::wrapPropagateTLSState([wrappedRpcResponseFuture, + responseFuture, + profilingKeyId, + profilingConfig] { std::vector profiledEvents; // Defer consolidation of profiler events until async work has // completed (such as async UDF) @@ -521,21 +527,23 @@ void RequestCallbackNoPython::processRunWithProfilingReq( // they will be cleaned up by main thread, and consolidate all // events so we obtain asynchronously run events. torch::autograd::profiler::ProfilerDisableOptions opts(false, true); - auto event_lists = torch::autograd::profiler::disableProfilerLegacy(opts); + auto event_lists = + torch::autograd::profiler::disableProfilerLegacy(opts); if (wrappedRpcResponseFuture->hasError()) { // Propagate error // No need to propagate remote events in the case of an error. - responseFuture->setError(wrappedRpcResponseFuture->error()->what()); + responseFuture->setError(wrappedRpcResponseFuture->exception_ptr()); } else { populateRemoteProfiledEvents( profiledEvents, profilingConfig, event_lists); auto rpcWithProfilingResp = std::make_unique( MessageType::RUN_WITH_PROFILING_RESP, - std::move(*wrappedRpcResponseFuture).moveValue(), + std::move(*wrappedRpcResponseFuture->value() + .toCustomClass()), profiledEvents, profilingKeyId); - responseFuture->markCompleted( - std::move(*rpcWithProfilingResp).toMessage()); + responseFuture->markCompleted(IValue(c10::make_intrusive( + std::move(*rpcWithProfilingResp).toMessage()))); } })); // Exiting the scope will disable the profiler on this thread with the @@ -546,7 +554,7 @@ void RequestCallbackNoPython::processRunWithProfilingReq( void RequestCallbackNoPython::processRRefBackward( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } @@ -554,10 +562,11 @@ void RequestCallbackNoPython::processRpc( RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto markComplete = [messageId, &responseFuture](Message m) { m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); }; // TODO: RpcCommandBase should have an abstract execute() method that we can // call here instead of having another switch statement here. 
Even better we @@ -629,7 +638,7 @@ void RequestCallbackNoPython::processRpc( } } -Message RequestCallbackNoPython::handleError( +IValue RequestCallbackNoPython::handleError( const std::exception& e, const MessageType messageType, int64_t messageId) const { @@ -642,7 +651,8 @@ Message RequestCallbackNoPython::handleError( DistAutogradContainer::getInstance().getWorkerId(), ": ", e.what()); - return createExceptionResponse(errorMsg, messageId); + return IValue(c10::make_intrusive( + createExceptionResponse(errorMsg, messageId))); } bool RequestCallbackNoPython::cudaAvailable() const { diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.h b/torch/csrc/distributed/rpc/request_callback_no_python.h index 7d9dac0a6635..9932c4744900 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.h +++ b/torch/csrc/distributed/rpc/request_callback_no_python.h @@ -14,8 +14,7 @@ namespace rpc { // RequestCallback implementation with no Python dependencies. class TORCH_API RequestCallbackNoPython : public RequestCallback { public: - std::shared_ptr processMessage( - Message& request) const override; + std::shared_ptr processMessage(Message& request) const override; protected: virtual std::unique_ptr deserializePythonRpcCommand( @@ -26,7 +25,7 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; bool processScriptCallOp( ScriptCall& scriptCall, @@ -37,7 +36,7 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; virtual TypePtr getScriptRemoteCallType( ScriptRemoteCall& scriptRemoteCall) const; @@ -52,7 +51,7 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; bool processScriptRemoteCallOp( ScriptRemoteCall& scriptRemoteCall, @@ -64,18 +63,18 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; void processScriptRRefFetchCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; virtual void processPythonRRefFetchCall( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; void processRRefUserDelete( RpcCommandBase& rpc, @@ -92,12 +91,12 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { void processForwardAutogradReq( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; void processBackwardAutogradReq( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; void processCleanupAutogradContextReq( RpcCommandBase& rpc, @@ -106,7 +105,7 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { void 
processRunWithProfilingReq( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; virtual void handleRRefDelete(c10::intrusive_ptr& rref) const; @@ -114,15 +113,15 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; virtual void processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; - Message handleError( + IValue handleError( const std::exception& e, const MessageType messageType, int64_t messageId) const; @@ -132,7 +131,7 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { virtual void processRRefBackward( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp index 4d9f6db39220..2033b2b771e2 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.cpp +++ b/torch/csrc/distributed/rpc/rpc_agent.cpp @@ -45,7 +45,7 @@ void RpcAgent::shutdown() { shutdownImpl(); } -std::shared_ptr RpcAgent::sendWithRetries( +std::shared_ptr RpcAgent::sendWithRetries( const WorkerInfo& to, Message&& message, RpcRetryOptions retryOptions) { @@ -57,12 +57,12 @@ std::shared_ptr RpcAgent::sendWithRetries( retryOptions.rpcRetryDuration.count() >= 0, "rpcRetryDuration cannot be negative."); - auto originalFuture = std::make_shared(); + auto originalFuture = std::make_shared(at::AnyClassType::get()); steady_clock_time_point newTime = computeNewRpcRetryTime(retryOptions, /* retryCount */ 0); // Making a copy of the message so it can be retried after the first send. Message msgCopy = message; - auto fm = send(to, std::move(message)); + auto jitFuture = send(to, std::move(message)); auto firstRetryRpc = std::make_shared( to, std::move(msgCopy), @@ -70,13 +70,13 @@ std::shared_ptr RpcAgent::sendWithRetries( /* retryCount */ 0, retryOptions); // Use weak_ptr so that the value can be std::moved in rpcRetryCallback. - fm->addCallback([this, - newTime, - firstRetryRpc, - weak = std::weak_ptr(fm)]() { - auto fm = weak.lock(); - TORCH_INTERNAL_ASSERT(fm); - rpcRetryCallback(fm, newTime, firstRetryRpc); + jitFuture->addCallback([this, + newTime, + firstRetryRpc, + wp = std::weak_ptr(jitFuture)]() { + auto future = wp.lock(); + TORCH_INTERNAL_ASSERT(future); + rpcRetryCallback(future, newTime, firstRetryRpc); }); return originalFuture; @@ -85,11 +85,10 @@ std::shared_ptr RpcAgent::sendWithRetries( void RpcAgent::retryExpiredRpcs() { // Stores the retried futures so callbacks can be added outside the lock. std::vector< - std::pair, std::shared_ptr>> + std::pair, std::shared_ptr>> futures; // Stores futures and exception messages for non-retriable error-ed futures. - std::vector, std::string>> - errorFutures; + std::vector, std::string>> errorFutures; while (rpcAgentRunning_.load()) { std::unique_lock lock(rpcRetryMutex_); @@ -126,15 +125,15 @@ void RpcAgent::retryExpiredRpcs() { auto& earliestRpc = *it; // Making a copy of the message so it can be retried in the future. 
Message msgCopy = earliestRpc->message_; - std::shared_ptr fm; + std::shared_ptr jitFuture; // send() will throw an exception if an RPC is retried while the agent is // shutdown. We must catch this exception and mark the original future // with an error, since this RPC never succeeded and can no longer be // retried. try { - fm = send(earliestRpc->to_, std::move(msgCopy)); - futures.emplace_back(fm, earliestRpc); + jitFuture = send(earliestRpc->to_, std::move(msgCopy)); + futures.emplace_back(jitFuture, earliestRpc); } catch (std::exception& e) { // We must store the futures and exception messages here and only mark // the futures with an error after releasing the lock. @@ -158,20 +157,20 @@ void RpcAgent::retryExpiredRpcs() { // We attach callbacks to the futures outside of the lock to prevent // potential deadlocks. for (const auto& it : futures) { - auto fm = it.first; + auto jitFuture = it.first; auto earliestRpc = it.second; steady_clock_time_point newTime = computeNewRpcRetryTime( earliestRpc->options_, earliestRpc->retryCount_); earliestRpc->retryCount_++; // Use weak_ptr so that the value can be std::moved in rpcRetryCallback. - fm->addCallback([this, - newTime, - earliestRpc, - weak = std::weak_ptr(fm)]() { - auto fm = weak.lock(); - TORCH_INTERNAL_ASSERT(fm); - rpcRetryCallback(fm, newTime, earliestRpc); + jitFuture->addCallback([this, + newTime, + earliestRpc, + wp = std::weak_ptr(jitFuture)]() { + auto future = wp.lock(); + TORCH_INTERNAL_ASSERT(future); + rpcRetryCallback(future, newTime, earliestRpc); }); } futures.clear(); @@ -181,17 +180,18 @@ void RpcAgent::retryExpiredRpcs() { for (const auto& it : errorFutures) { auto errorFuture = it.first; auto errorMsg = it.second; - errorFuture->setError(errorMsg); + errorFuture->setError( + std::make_exception_ptr(std::runtime_error(errorMsg))); } errorFutures.clear(); } } void RpcAgent::rpcRetryCallback( - const std::shared_ptr& futureMessage, + const std::shared_ptr& jitFuture, steady_clock_time_point newTime, std::shared_ptr earliestRpc) { - if (futureMessage->hasError()) { + if (jitFuture->hasError()) { // Adding one since we want to include the original send as well and not // just the retry count. LOG(INFO) << "Send try " << (earliestRpc->retryCount_ + 1) << " failed"; @@ -203,7 +203,7 @@ void RpcAgent::rpcRetryCallback( "RPC Agent is no longer running on Node ", RpcAgent::getWorkerInfo().id_, ". Cannot retry message."); - earliestRpc->originalFuture_->setError(*futureMessage->error()); + earliestRpc->originalFuture_->setError(jitFuture->exception_ptr()); } else if (earliestRpc->retryCount_ < earliestRpc->options_.maxRetries) { // If the previous future completed with an error and we haven't // completed maxRetries send attempts, we move the earliestRpc @@ -223,12 +223,12 @@ void RpcAgent::rpcRetryCallback( "The RPC has not succeeded after the specified number of max retries (", earliestRpc->options_.maxRetries, ")."); - earliestRpc->originalFuture_->setError(errorMessage); + earliestRpc->originalFuture_->setError( + std::make_exception_ptr(std::runtime_error(errorMessage))); } } else { // This try succeeded, so we can make the original future as complete. 
- earliestRpc->originalFuture_->markCompleted( - std::move(*futureMessage).moveValue()); + earliestRpc->originalFuture_->markCompleted(jitFuture->value()); } } diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index 34b77a085510..bfc6c38c07a1 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -105,7 +105,7 @@ struct TORCH_API RpcRetryInfo { RpcRetryInfo( const WorkerInfo& to, Message&& message, - std::shared_ptr originalFuture, + std::shared_ptr originalFuture, int retryCount, RpcRetryOptions options) : to_(to), @@ -117,7 +117,7 @@ struct TORCH_API RpcRetryInfo { const WorkerInfo& to_; Message message_; // Future that is returned to the caller of sendWithRetries(). - std::shared_ptr originalFuture_; + std::shared_ptr originalFuture_; // Number of send attempts completed so far. int retryCount_; RpcRetryOptions options_; @@ -151,13 +151,13 @@ class TORCH_API RpcAgent { virtual ~RpcAgent(); // Send a message to the ``RpcAgent`` of id ``to`` and returns a - // ``FutureMessage`` ptr. The implementation must be asynchronous, i.e., it + // ``JitFuture`` ptr. The implementation must be asynchronous, i.e., it // cannot block until it receives the response. // - // If ``message.isRequest()`` is true, the ``FutureMessage`` will be + // If ``message.isRequest()`` is true, the ``JitFuture`` will be // completed when the response arrives. For other message types, the Future // should be ignored by the caller. - virtual std::shared_ptr send( + virtual std::shared_ptr send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = kUnsetRpcTimeout) = 0; @@ -167,14 +167,14 @@ class TORCH_API RpcAgent { // time using an exponential backoff algorithm. // // Sends ``message`` to the ``RpcAgent`` of id ``to`` and returns a - // ``FutureMessage`` ptr, just like send(). Caller can specify the maximum + // ``JitFuture`` ptr, just like send(). Caller can specify the maximum // number of retries for this RPC (default is 5), initial duration between // sends (default is 1000ms), and backoff constant (default is 1.5) by // passing in the RpcRetryOptions struct. This API might end up // executing a method twice on the remote end (it does not guarantee // exactly-once semantics). Therefore, the user must ensure their requests // are idempotent. - std::shared_ptr sendWithRetries( + std::shared_ptr sendWithRetries( const WorkerInfo& to, Message&& message, RpcRetryOptions retryOptions = RpcRetryOptions()); @@ -299,7 +299,7 @@ class TORCH_API RpcAgent { // error and do not retry again. In case 3, we move the RpcRetryInfo struct // to another time point in the map to schedule the RPC for a future send. 
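Note (illustrative only, not part of the patch): the retry path above reschedules a failed send using exponential backoff with the documented defaults (5 retries, 1000ms initial duration, 1.5 backoff constant). The sketch below is a simplified stand-in for RpcRetryOptions and computeNewRpcRetryTime; the field names are assumptions, not the authoritative definitions.

#include <chrono>
#include <cmath>

using steady_clock_time_point =
    std::chrono::time_point<std::chrono::steady_clock>;

// Simplified stand-in for rpc::RpcRetryOptions with the documented defaults.
struct RetryOptionsSketch {
  int maxRetries{5};
  std::chrono::milliseconds rpcRetryDuration{1000};
  float retryBackoff{1.5f};
};

// Next send is scheduled at now + duration * backoff^retryCount, the
// exponential backoff described in the comments above.
steady_clock_time_point computeNextRetryTimeSketch(
    const RetryOptionsSketch& options,
    int retryCount) {
  auto delay = std::chrono::duration_cast<std::chrono::milliseconds>(
      options.rpcRetryDuration *
      std::pow(options.retryBackoff, static_cast<float>(retryCount)));
  return std::chrono::steady_clock::now() + delay;
}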
void rpcRetryCallback( - const std::shared_ptr& message, + const std::shared_ptr& message, steady_clock_time_point newTime, std::shared_ptr earliestRpc); diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp index dd64ee5c9445..ce257c50a7a4 100644 --- a/torch/csrc/distributed/rpc/rref_context.cpp +++ b/torch/csrc/distributed/rpc/rref_context.cpp @@ -14,11 +14,12 @@ thread_local bool RRefContext::recording_ = false; namespace callback { void confirmPendingUser( - const FutureMessage& futureMessage, + const JitFuture& jitFuture, const ForkId& expectedForkId) { - if (!futureMessage.hasError()) { - auto msgType = futureMessage.constValue().type(); - auto rpc = deserializeResponse(futureMessage.constValue(), msgType); + if (!jitFuture.hasError()) { + auto msgPtr = jitFuture.constValue().toCustomClass(); + auto msgType = msgPtr->type(); + auto rpc = deserializeResponse(*msgPtr, msgType); auto rr = dynamic_cast(rpc.get()); TORCH_INTERNAL_ASSERT(rr->forkId() == expectedForkId); } else { @@ -34,30 +35,31 @@ void confirmPendingUser( // the user application will use the RRef before the errors are handled. In // this case, errors may not be raised as they have not yet been handled. auto rref_ptr = RRefContext::getInstance().getPendingUser(expectedForkId); - auto errorType = getRPCErrorType(futureMessage); - rref_ptr->handleError(errorType, futureMessage); + auto errorType = getRPCErrorType(jitFuture); + rref_ptr->handleError(errorType, jitFuture); } RRefContext::getInstance().delPendingUser(expectedForkId); } c10::intrusive_ptr finishCreatingOwnerRRef( - const FutureMessage& futureMessage, + const JitFuture& jitFuture, const RRefId& rrefId) { - if (futureMessage.hasError()) { + if (jitFuture.hasError()) { auto& ctx = RRefContext::getInstance(); // We expect to run this callback only after the OwnerRRef has been created, // since this is only invoked when sending to self. auto rref_ptr = ctx.getOwnerRRef(rrefId, /* ensure created */ true)->constValue(); - auto errorType = getRPCErrorType(futureMessage); - rref_ptr->handleError(errorType, futureMessage); + auto errorType = getRPCErrorType(jitFuture); + rref_ptr->handleError(errorType, jitFuture); // OwnerRRefs do not have a forkId, so don't need to assert here. auto deletedRRef = ctx.delForkOfOwner(rref_ptr->rrefId(), rref_ptr->rrefId()); return deletedRRef; } else { - auto msgType = futureMessage.constValue().type(); - auto rpc = deserializeResponse(futureMessage.constValue(), msgType); + auto msgPtr = jitFuture.constValue().toCustomClass(); + auto msgType = msgPtr->type(); + auto rpc = deserializeResponse(*msgPtr, msgType); auto rr = dynamic_cast(rpc.get()); TORCH_INTERNAL_ASSERT( rr->rrefId() == rr->forkId(), @@ -102,10 +104,11 @@ std::vector> RRefContext::destroyInstance( return deletedRRefs; } -void RRefContext::handleException(const FutureMessage& fm) { - if (fm.hasError()) { - VLOG(1) << "Got exception: " << fm.error()->what(); - throw std::runtime_error(fm.error()->what()); +void RRefContext::handleException(const JitFuture& jitFuture) { + if (jitFuture.hasError()) { + auto errMsg = jitFuture.tryRetrieveErrorMessage(); + VLOG(1) << "Got exception: " << errMsg; + throw std::runtime_error(errMsg); } } @@ -209,12 +212,13 @@ void RRefContext::delUser( // which is now idempotent. See the comment at RRefContext::delForkOfOwner // for more details. 
++numPendingFutures_; - auto fm = agent_->sendWithRetries( + auto jitFuture = agent_->sendWithRetries( agent_->getWorkerInfo(owner), RRefUserDelete(rrefId, forkId).toMessage()); - fm->addCallback([this](const FutureMessage& fm) { - handleException(fm); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, wp]() { + handleException(*wp.lock()); --numPendingFutures_; }); } @@ -483,21 +487,24 @@ void RRefContext::notifyOwnerAndParentOfFork( // into forks_. Because, there will be no real `UserRRef` associated // with this fork ID. ++numPendingFutures_; - auto fm = agent_->sendWithRetries( + auto jitFuture = agent_->sendWithRetries( agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage()); - fm->addCallback([this](const FutureMessage& fm) { - handleException(fm); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, wp]() { + handleException(*wp.lock()); --numPendingFutures_; }); } else { ++numPendingFutures_; - auto fm = agent_->sendWithRetries( + auto jitFuture = agent_->sendWithRetries( agent_->getWorkerInfo(rref->owner()), RRefForkRequest(rref->rrefId(), forkId).toMessage()); addPendingUser(forkId, rref); - fm->addCallback([this, forkId, parent](const FutureMessage& fm) { - handleException(fm); + + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, forkId, parent, wp]() { + handleException(*wp.lock()); this->finishForkRequest(forkId, parent); // Decrease after calling finishForkRequest because, as that creates a new // future, it might otherwise cause the count to briefly go to zero. @@ -676,11 +683,12 @@ void RRefContext::clearRecordedPendingRRefsOnError() { void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) { delPendingUser(forkId); ++numPendingFutures_; - auto fm = agent_->sendWithRetries( + auto jitFuture = agent_->sendWithRetries( agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage()); - fm->addCallback([this](const FutureMessage& fm) { - handleException(fm); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, wp]() { + handleException(*wp.lock()); --numPendingFutures_; }); } diff --git a/torch/csrc/distributed/rpc/rref_context.h b/torch/csrc/distributed/rpc/rref_context.h index cf89980e7f71..1e3537a6dfd3 100644 --- a/torch/csrc/distributed/rpc/rref_context.h +++ b/torch/csrc/distributed/rpc/rref_context.h @@ -16,16 +16,14 @@ namespace rpc { namespace callback { // It's the callback for RemoteCall. -void TORCH_API confirmPendingUser( - const FutureMessage& futureMessage, - const ForkId& expectedForkId); +void TORCH_API +confirmPendingUser(const JitFuture& jitFuture, const ForkId& expectedForkId); // It's the callback for finishing creating owner rref, it returned deletedRRef, // so that the deletedRRef can be handled under GIL in python_functions.cpp if // deletedRRef contains python object. 
-c10::intrusive_ptr TORCH_API finishCreatingOwnerRRef( - const FutureMessage& futureMessage, - const RRefId& rrefId); +c10::intrusive_ptr TORCH_API +finishCreatingOwnerRRef(const JitFuture& jitFuture, const RRefId& rrefId); } // namespace callback using torch::utils::Future; @@ -42,7 +40,7 @@ class TORCH_API RRefContext { static std::vector> destroyInstance( bool ignoreRRefLeak = true); - static void handleException(const FutureMessage& fm); + static void handleException(const JitFuture& jitFuture); RRefContext(const RRefContext&) = delete; RRefContext(RRefContext&& other) = delete; diff --git a/torch/csrc/distributed/rpc/rref_impl.cpp b/torch/csrc/distributed/rpc/rref_impl.cpp index 6c6a377a4652..085b65bfc0fb 100644 --- a/torch/csrc/distributed/rpc/rref_impl.cpp +++ b/torch/csrc/distributed/rpc/rref_impl.cpp @@ -65,23 +65,21 @@ RRefForkData RRef::fork() const { getTypeStr(type_)); } -void RRef::handleError( - RPCErrorType errorType, - const FutureMessage& futMessage) { +void RRef::handleError(RPCErrorType errorType, const JitFuture& jitFuture) { static std::unordered_map< RPCErrorType, - std::function, + std::function, std::hash> errorHandlers = { {RPCErrorType::TIMEOUT, - [this](const FutureMessage& /* unused */) { setTimedOut(); }}, + [this](const JitFuture& /* unused */) { setTimedOut(); }}, {RPCErrorType::INTENTIONAL_FAILURE, - [this](const FutureMessage& /* unused */) { setTimedOut(); }}, - {RPCErrorType::UNKNOWN_ERROR, [](const FutureMessage& fm) { + [this](const JitFuture& /* unused */) { setTimedOut(); }}, + {RPCErrorType::UNKNOWN_ERROR, [](const JitFuture& jitFuture) { // Default error handler - RRefContext::handleException(fm); + RRefContext::handleException(jitFuture); }}}; - errorHandlers.find(errorType)->second(futMessage); + errorHandlers.find(errorType)->second(jitFuture); } ////////////////////////// UserRRef ///////////////////////////////////// @@ -170,7 +168,7 @@ IValue UserRRef::toHere(const float timeoutSeconds) const { // toHere is profiled as a blocking call, and does not execute operations on // the remote node. Hence, don't wrap it with a profiling message since we // don't need the profiler to be enabled remotely. - auto futureResponse = autograd::sendMessageWithAutograd( + auto jitFuture = autograd::sendMessageWithAutograd( *agent, agent->getWorkerInfo(ownerId_), std::move(msgToSend), @@ -181,9 +179,10 @@ IValue UserRRef::toHere(const float timeoutSeconds) const { // TODO: we should ideally be able to interrupt this blocking wait if we check // getTimedOut() and it is true // (https://github.com/pytorch/pytorch/issues/39411). - const Message& message = futureResponse->wait(); - MessageType msgType = message.type(); - auto response = deserializeResponse(message, msgType); + jitFuture->waitAndThrow(); + auto messagePtr = jitFuture->constValue().toCustomClass(); + MessageType msgType = messagePtr->type(); + auto response = deserializeResponse(*messagePtr, msgType); TORCH_INTERNAL_ASSERT( msgType == MessageType::SCRIPT_RREF_FETCH_RET || msgType == MessageType::PYTHON_RREF_FETCH_RET, diff --git a/torch/csrc/distributed/rpc/rref_impl.h b/torch/csrc/distributed/rpc/rref_impl.h index 29aa355908fa..c7f812271468 100644 --- a/torch/csrc/distributed/rpc/rref_impl.h +++ b/torch/csrc/distributed/rpc/rref_impl.h @@ -228,12 +228,12 @@ class TORCH_API RRef : public RRefInterface { // node. Note that this is only set when processing requests invoked with // rpc.remote. This is only used to get the future corresponding to the rref // for profiling use cases. 
- inline void registerOwnerCreationFuture(std::shared_ptr fut) { + inline void registerOwnerCreationFuture(std::shared_ptr fut) { ownerCreationFuture_ = std::move(fut); } // Get the future corresponding to the creation of this rref. - inline std::shared_ptr getOwnerCreationFuture() const { + inline std::shared_ptr getOwnerCreationFuture() const { return ownerCreationFuture_; } @@ -243,7 +243,7 @@ class TORCH_API RRef : public RRefInterface { } // Dispatches an error to the correct handler based on its RPCErrorType. - void handleError(RPCErrorType errorType, const FutureMessage& futMessage); + void handleError(RPCErrorType errorType, const JitFuture& JitFuture); // Send delete UserRRef request to Owner, // if the request hasn't been sent yet. @@ -272,7 +272,7 @@ class TORCH_API RRef : public RRefInterface { // it could be any TypePtr that JIT support, including PyObjectType const TypePtr type_; // Future corresponding to request to create RRef on remote node. - std::shared_ptr ownerCreationFuture_; + std::shared_ptr ownerCreationFuture_; }; // ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 4a39979e6e1b..00c7567c6d43 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -166,8 +166,8 @@ std::unique_ptr makeMultiplexedUvChannel() { } auto context = std::make_shared( std::move(contexts), std::move(listeners)); - return std::make_unique( - CpuChannelRegistration{std::move(context), kMultiplexedUvChannelPriority}); + return std::make_unique(CpuChannelRegistration{ + std::move(context), kMultiplexedUvChannelPriority}); } // The multiplexed UV channel encapsulates multiple UV transports (each with its @@ -192,7 +192,10 @@ std::unique_ptr makeCudaIpcChannel() { // The cuda_ipc channels use cudaMemcpy to transmit CUDA tensor across processes // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -C10_REGISTER_CREATOR(TensorPipeCudaChannelRegistry, cuda_ipc, makeCudaIpcChannel); +C10_REGISTER_CREATOR( + TensorPipeCudaChannelRegistry, + cuda_ipc, + makeCudaIpcChannel); #endif @@ -324,9 +327,10 @@ void TensorPipeAgent::startImpl() { if (iter == opts_.channels->end()) { continue; } - // Assign priorities in reverse order of occurrence in the vector, so that - // a channel that comes before another receives a higher priority. - priority = opts_.channels->size() - 1 - (iter - opts_.channels->begin()); + // Assign priorities in reverse order of occurrence in the vector, so + // that a channel that comes before another receives a higher priority. 
+ priority = + opts_.channels->size() - 1 - (iter - opts_.channels->begin()); } // The reg var is either a std::unique_ptr or a @@ -337,7 +341,7 @@ void TensorPipeAgent::startImpl() { priority = reg->priority; } context_->registerChannel( - priority, std::move(key), std::move(reg->channel)); + priority, std::move(key), std::move(reg->channel)); } }; @@ -464,7 +468,7 @@ void TensorPipeAgent::pipeWrite( void TensorPipeAgent::sendCompletedResponseMessage( std::shared_ptr& pipe, - std::shared_ptr& futureResponseMessage, + std::shared_ptr& futureResponseMessage, uint64_t messageId) { if (!rpcAgentRunning_.load()) { LOG(WARNING) << "RPC agent for " << workerInfo_.name_ @@ -477,17 +481,15 @@ void TensorPipeAgent::sendCompletedResponseMessage( << " is sending response to request #" << messageId << " to " << pipe->getRemoteName(); - const c10::optional error = - futureResponseMessage->error(); - Message&& responseMessage = std::move(*futureResponseMessage).moveValue(); - responseMessage.setId(messageId); - if (!error) { + if (!futureResponseMessage->hasError()) { + Message&& responseMessage = + std::move(*futureResponseMessage->value().toCustomClass()); + responseMessage.setId(messageId); std::vector devices; - try { devices = getDevicesForTensors(pipe->getRemoteName(), responseMessage); } catch (const std::exception& e) { - responseMessage = createExceptionResponse(e.what(), responseMessage.id()); + responseMessage = createExceptionResponse(e.what(), messageId); } pipeWrite( @@ -511,7 +513,8 @@ void TensorPipeAgent::sendCompletedResponseMessage( } else { pipeWrite( pipe, - createExceptionResponse(error->what(), responseMessage.id()), + createExceptionResponse( + futureResponseMessage->tryRetrieveErrorMessage(), messageId), {}, [this, pipe, messageId](const tensorpipe::Error& error) { if (error) { @@ -572,12 +575,13 @@ void TensorPipeAgent::respond(std::shared_ptr& pipe) { << " is running request #" << messageId << " from " << pipe->getRemoteName() << " in thread pool"; - std::shared_ptr futureResponseMessage; + std::shared_ptr futureResponseMessage; try { futureResponseMessage = cb_->operator()(requestMessage); - } catch (const std::exception& e) { - futureResponseMessage = std::make_shared(); - futureResponseMessage->setError(e.what()); + } catch (const std::exception& /* unused */) { + futureResponseMessage = + std::make_shared(at::AnyClassType::get()); + futureResponseMessage->setError(std::current_exception()); } // Shortcut if immediately done @@ -604,7 +608,7 @@ void TensorPipeAgent::respond(std::shared_ptr& pipe) { }); } -std::shared_ptr TensorPipeAgent::send( +std::shared_ptr TensorPipeAgent::send( const WorkerInfo& toWorkerInfo, Message&& requestMessage, const float rpcTimeoutSeconds) { @@ -637,7 +641,7 @@ std::shared_ptr TensorPipeAgent::send( ClientPipe& clientPipe = it->second; auto& pendingResponseMessage = clientPipe.pendingResponseMessage_; - auto futureResponseMessage = std::make_shared(); + auto futureResponseMessage = std::make_shared(); uint64_t messageId = nextMessageID_++; requestMessage.setId(messageId); pendingResponseMessage[messageId] = futureResponseMessage; @@ -649,7 +653,7 @@ std::shared_ptr TensorPipeAgent::send( auto devices = getDevicesForTensors(clientPipe.pipe_->getRemoteName(), requestMessage); - futureResponseMessage->futMsg.addCallback([this]() { + futureResponseMessage->jitFuture->addCallback([this]() { TORCH_INTERNAL_ASSERT( this->threadPool_.inThreadPool(), "Future marked complete from outside the thread pool"); @@ -747,7 +751,7 @@ std::shared_ptr 
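The TensorPipeAgent hunks in this file replace FutureMessage with a JitFuture (c10::ivalue::Future) whose payload is the RPC Message boxed as a custom-class IValue. The sketch below restates that round trip in one place, using only calls that appear in this diff (markCompleted, setError, constValue().toCustomClass, tryRetrieveErrorMessage); it assumes a libtorch checkout of this vintage, and the helper names are illustrative rather than part of the patch.

// Illustrative helpers (not from the patch) showing how a Message-carrying
// JitFuture is created, completed, and consumed, mirroring the
// sendCompletedResponseMessage change above and markFutureAsComplete below.
#include <exception>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/distributed/rpc/message.h>

using torch::distributed::rpc::Message;
using JitFuture = c10::ivalue::Future;

std::shared_ptr<JitFuture> makeResponseFuture() {
  // The payload is an opaque custom class, hence AnyClassType.
  return std::make_shared<JitFuture>(at::AnyClassType::get());
}

void completeWithMessage(const std::shared_ptr<JitFuture>& fut, Message m) {
  // Success path: box the Message into an IValue before completing.
  fut->markCompleted(c10::IValue(c10::make_intrusive<Message>(std::move(m))));
}

void completeWithError(const std::shared_ptr<JitFuture>& fut, std::string err) {
  // Error path: a JitFuture stores a std::exception_ptr, not an error string.
  fut->setError(std::make_exception_ptr(std::runtime_error(std::move(err))));
}

void consume(const JitFuture& fut) {
  if (fut.hasError()) {
    std::cerr << fut.tryRetrieveErrorMessage() << std::endl;
    return;
  }
  // Readers unbox the payload with toCustomClass<Message>().
  auto msgPtr = fut.constValue().toCustomClass<Message>();
  std::cout << "message type: " << static_cast<int>(msgPtr->type()) << std::endl;
}

Keeping the element type as AnyClassType lets a single future type carry any custom-class payload; the concrete Message is only recovered at the call sites that unbox it.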
TensorPipeAgent::send( << " received response #" << messageId << " from " << clientPipe.pipe_->getRemoteName(); - std::shared_ptr futureResponseMessage; + std::shared_ptr futureResponseMessage; { std::lock_guard lock(mutex_); // A read error will lead all following callbacks to be @@ -778,8 +782,7 @@ std::shared_ptr TensorPipeAgent::send( }); }); - return std::shared_ptr( - futureResponseMessage, &futureResponseMessage->futMsg); + return futureResponseMessage->jitFuture; } void TensorPipeAgent::pollTimeoutRpcs() { @@ -807,9 +810,8 @@ void TensorPipeAgent::pollTimeoutRpcs() { // Move all these futures to a separate vector so we can process them // outside the lock. - std::vector, - std::chrono::milliseconds>> + std::vector< + std::pair, std::chrono::milliseconds>> timedOutFutures = std::move(timeoutMap_.begin()->second); // We can safely remove this key from the timeoutMap_ since all these // futures will be processed. @@ -1026,16 +1028,17 @@ void TensorPipeAgent::decreaseCallCount(int32_t& count) { } void TensorPipeAgent::markFutureAsComplete( - std::shared_ptr futureMessage, + std::shared_ptr atomicFuture, Message message) { - if (!futureMessage->isComplete.test_and_set()) { + if (!atomicFuture->isComplete.test_and_set()) { // Completing the future will run its callbacks, which could execute // arbitrary user code. To prevent blocking or stalling the TensorPipe event // loops, we defer this to a worker thread. threadPool_.run([this, - futureMessage{std::move(futureMessage)}, + atomicFuture{std::move(atomicFuture)}, message{std::move(message)}]() mutable { - futureMessage->futMsg.markCompleted(std::move(message)); + atomicFuture->jitFuture->markCompleted( + IValue(c10::make_intrusive(std::move(message)))); // The future's callbacks may schedule further RPCs, increasing the count. // Thus we must decrease it after completing the future, otherwise it may // briefly dip to zero and trick join into thinking all work is done. @@ -1045,16 +1048,17 @@ void TensorPipeAgent::markFutureAsComplete( } void TensorPipeAgent::markFutureWithError( - std::shared_ptr futureMessage, + std::shared_ptr atomicFuture, std::string errorMsg) { - if (!futureMessage->isComplete.test_and_set()) { + if (!atomicFuture->isComplete.test_and_set()) { // Completing the future will run its callbacks, which could execute // arbitrary user code. To prevent blocking or stalling the TensorPipe event // loops, we defer this to a worker thread. threadPool_.run([this, - futureMessage{std::move(futureMessage)}, + atomicFuture{std::move(atomicFuture)}, errorMsg{std::move(errorMsg)}]() mutable { - futureMessage->futMsg.setError(std::move(errorMsg)); + atomicFuture->jitFuture->setError( + std::make_exception_ptr(std::runtime_error(errorMsg))); // The future's callbacks may schedule further RPCs, increasing the count. // Thus we must decrease it after completing the future, otherwise it may // briefly dip to zero and trick join into thinking all work is done. diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index 3eb7cdc6ec7e..9f75eaf4e0af 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -11,7 +11,6 @@ #include #include - // Forward-declare the TensorPipe classes we need, to avoid including its // headers in PyTorch's ones and thus have it become a public dependency. 
@@ -89,7 +88,6 @@ struct CudaChannelRegistration { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) C10_DECLARE_REGISTRY(TensorPipeCudaChannelRegistry, CudaChannelRegistration); - constexpr auto kDefaultNumWorkerThreads = 16; struct TensorPipeRpcBackendOptions : public RpcBackendOptions { @@ -181,7 +179,7 @@ class TensorPipeAgent : public RpcAgent { TensorPipeAgent(const TensorPipeAgent&) = delete; TensorPipeAgent& operator=(const TensorPipeAgent&) = delete; - std::shared_ptr send( + std::shared_ptr send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = kUnsetRpcTimeout) override; @@ -249,7 +247,7 @@ class TensorPipeAgent : public RpcAgent { void sendCompletedResponseMessage( std::shared_ptr& pipe, - std::shared_ptr& futureResponseMessage, + std::shared_ptr& futureResponseMessage, uint64_t messageId); // Collects metrics from successful RPC calls @@ -273,8 +271,9 @@ class TensorPipeAgent : public RpcAgent { // only if it isn't yet. It does exist for errors (setErrorIfNeeded) but, even // then, it ends up printing a log message, which may worry the user. To solve // both issues we use a separate atomic flag to know the status of the future. - struct AtomicFutureMessage { - FutureMessage futMsg; + struct AtomicJitFuture { + std::shared_ptr jitFuture = + std::make_shared(at::AnyClassType::get()); std::atomic_flag isComplete = ATOMIC_FLAG_INIT; }; @@ -288,7 +287,7 @@ class TensorPipeAgent : public RpcAgent { std::shared_ptr pipe_; bool readError_{false}; // Map from Message Request ID's to corresponding futures. - std::unordered_map> + std::unordered_map> pendingResponseMessage_; }; @@ -321,7 +320,7 @@ class TensorPipeAgent : public RpcAgent { std::map< steady_clock_time_point, std::vector, + std::shared_ptr, std::chrono::milliseconds>>> timeoutMap_; @@ -394,10 +393,10 @@ class TensorPipeAgent : public RpcAgent { // Helpers to set the state of the requests. 
void markFutureAsComplete( - std::shared_ptr futureMessage, + std::shared_ptr atomicFuture, Message message); void markFutureWithError( - std::shared_ptr futureMessage, + std::shared_ptr atomicFuture, std::string errorMsg); }; diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp index dccf1abb6d3e..7f6c3015f544 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp @@ -56,7 +56,7 @@ std::unordered_map> FaultyProcessGroupAgent:: return delayMessages; } -std::shared_ptr FaultyProcessGroupAgent::send( +std::shared_ptr FaultyProcessGroupAgent::send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds) { @@ -78,11 +78,12 @@ std::shared_ptr FaultyProcessGroupAgent::send( if (failMessageCountMap_[key] < failNumSends_) { failMessageCountMap_[key]++; lock.unlock(); - auto fm = std::make_shared(); - fm->setError(makeRPCError( - c10::str("Send attempt failed intentionally for ", key), - RPCErrorType::INTENTIONAL_FAILURE)); - return fm; + auto jitFuture = std::make_shared(at::AnyClassType::get()); + jitFuture->setError( + std::make_exception_ptr(std::runtime_error(makeRPCError( + c10::str("Send attempt failed intentionally for ", key), + RPCErrorType::INTENTIONAL_FAILURE)))); + return jitFuture; } else { lock.unlock(); return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds); diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h index 25a162bbd559..8cbe4c9a137d 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h @@ -43,7 +43,7 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent { int failNumSends = 0); // Faulty send function for this class. - std::shared_ptr send( + std::shared_ptr send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = diff --git a/torch/csrc/distributed/rpc/torchscript_functions.cpp b/torch/csrc/distributed/rpc/torchscript_functions.cpp index a9cd006439e8..eceeb3401376 100644 --- a/torch/csrc/distributed/rpc/torchscript_functions.cpp +++ b/torch/csrc/distributed/rpc/torchscript_functions.cpp @@ -14,7 +14,7 @@ namespace torch { namespace distributed { namespace rpc { -c10::intrusive_ptr rpcTorchscript( +c10::intrusive_ptr rpcTorchscript( const std::string& dstWorkerName, const c10::QualifiedName& qualifiedName, const c10::FunctionSchema& functionSchema, @@ -43,14 +43,14 @@ c10::intrusive_ptr rpcTorchscript( auto scriptCall = std::make_unique( qualifiedName, std::move(stack), isAsyncExecution); auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent(); - auto futMessage = autograd::sendMessageWithAutograd( + auto jitFuture = autograd::sendMessageWithAutograd( *rpcAgentPtr, rpcAgentPtr->getWorkerInfo(dstWorkerName), std::move(*scriptCall).toMessage(), true /*forceGradRecording*/, rpcTimeoutSeconds); - // Get function return type to construct c10::ivalue::Future. + // Get function return type to construct JitFuture. auto returns = functionSchema.returns(); // Script call only allows single IValue returned. TORCH_INTERNAL_ASSERT( @@ -62,15 +62,15 @@ c10::intrusive_ptr rpcTorchscript( // Create a JIT future and pass it to futMessage's callback to set state // of the JIT future. 
- auto futPtr = c10::make_intrusive(returnType); - std::weak_ptr wp = futMessage; - futMessage->addCallback(at::wrapPropagateTLSState([futPtr, wp]() { - auto futMessage = wp.lock(); - if (futMessage->hasError()) { - c10::ivalue::Future::FutureError jitFutErr(futMessage->error()->what()); - futPtr->setError(std::make_exception_ptr(jitFutErr)); + auto futPtr = c10::make_intrusive(returnType); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback(at::wrapPropagateTLSState([futPtr, wp]() { + auto future = wp.lock(); + if (future->hasError()) { + futPtr->setError(future->exception_ptr()); } else { - futPtr->markCompleted(deserializeRespToIValue(futMessage->constValue())); + futPtr->markCompleted(deserializeRespToIValue( + *future->constValue().toCustomClass())); } })); if (shouldProfile) { @@ -112,21 +112,19 @@ c10::intrusive_ptr remoteTorchscript( userRRefPtr->forkId(), isAsyncExecution); - auto fm = torch::distributed::autograd::sendMessageWithAutograd( + auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd( *rpcAgentPtr, dstWorkerInfo, std::move(*scriptRemoteCall).toMessage(), true /*forceGradRecording*/, rpcTimeoutSeconds /* timeout */); - userRRefPtr->registerOwnerCreationFuture(fm); - + userRRefPtr->registerOwnerCreationFuture(jitFuture); ctx.addPendingUser(userRRefPtr->forkId(), userRRefPtr); - std::weak_ptr wp = fm; - fm->addCallback( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( at::wrapPropagateTLSState([wp, forkId{userRRefPtr->forkId()}]() { - auto fm = wp.lock(); - callback::confirmPendingUser(*fm, forkId); + callback::confirmPendingUser(*wp.lock(), forkId); })); return userRRefPtr; @@ -142,19 +140,18 @@ c10::intrusive_ptr remoteTorchscript( ownerRRefPtr->rrefId(), isAsyncExecution); - auto fm = torch::distributed::autograd::sendMessageWithAutograd( + auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd( *rpcAgentPtr, dstWorkerInfo, std::move(*scriptRemoteCall).toMessage(), true /*forceGradRecording*/, rpcTimeoutSeconds /* timeout */); - ownerRRefPtr->registerOwnerCreationFuture(fm); - std::weak_ptr wp = fm; - fm->addCallback(at::wrapPropagateTLSState( + ownerRRefPtr->registerOwnerCreationFuture(jitFuture); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback(at::wrapPropagateTLSState( [wp, ownerRRefId = ownerRRefPtr->rrefId()]() { - auto fm = wp.lock(); - callback::finishCreatingOwnerRRef(*fm, ownerRRefId); + callback::finishCreatingOwnerRRef(*wp.lock(), ownerRRefId); })); return ownerRRefPtr; } diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index 79c197505bbb..8c39cc3e3e97 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -58,15 +58,15 @@ void processRemoteProfiledEvents( const std::string kRPCErrorPrefix = std::string("RPCErr"); -RPCErrorType getRPCErrorType(const FutureMessage& fm) { +RPCErrorType getRPCErrorType(const JitFuture& jitFuture) { TORCH_INTERNAL_ASSERT( - fm.hasError(), - "FutureMessage passed to getRPCErrorType does not have an error."); + jitFuture.hasError(), + "JitFuture of Message passed to getRPCErrorType does not have an error."); // Attempt to parse for error string given by makeRPCError, otherwise return // unknown error. // Note that this function expects errors formatted with makeRPCError(). - auto err = std::string(fm.error()->what()); + auto err = jitFuture.tryRetrieveErrorMessage(); size_t pos = err.find(kRPCErrorPrefix); if (pos != std::string::npos) { // Parse the RPCErrorType. 
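A pattern that recurs in the callback rewrites above (confirmPendingUser, finishCreatingOwnerRRef, and the rpcTorchscript path): the JitFuture callbacks take no arguments, so a callback that needs to inspect the future it is attached to captures a std::weak_ptr rather than the shared_ptr, which would otherwise create a future-owns-callback-owns-future reference cycle. Below is a minimal sketch of just that pattern; the function name is illustrative and the at::wrapPropagateTLSState wrapper used by the real call sites is omitted.

#include <iostream>
#include <memory>

#include <ATen/core/ivalue.h>

using JitFuture = c10::ivalue::Future;

// Attach a callback that reads the future it belongs to without keeping it
// alive; the future owns its callbacks, so capturing the shared_ptr here
// would keep both objects alive forever.
void attachSelfReadingCallback(const std::shared_ptr<JitFuture>& jitFuture) {
  std::weak_ptr<JitFuture> wp = jitFuture;
  jitFuture->addCallback([wp]() {
    // The future is the object invoking this callback, so it is still alive
    // and lock() succeeds here.
    auto fut = wp.lock();
    if (fut->hasError()) {
      // Errors travel as std::exception_ptr; the message is recovered the
      // same way handleException does it.
      std::cerr << fut->tryRetrieveErrorMessage() << std::endl;
      return;
    }
    // On success the boxed payload is available via fut->constValue().
  });
}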
diff --git a/torch/csrc/distributed/rpc/utils.h b/torch/csrc/distributed/rpc/utils.h index aa920d06cae8..4d27daff6ffe 100644 --- a/torch/csrc/distributed/rpc/utils.h +++ b/torch/csrc/distributed/rpc/utils.h @@ -15,7 +15,7 @@ namespace distributed { namespace rpc { // Parse error message and return RPCErrorType based on the message. -TORCH_API RPCErrorType getRPCErrorType(const FutureMessage& fm); +TORCH_API RPCErrorType getRPCErrorType(const JitFuture& jitFuture); // Create an error string given the error description and error type TORCH_API std::string makeRPCError( const std::string& rpcErrorStr, diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index d544df0d5535..f431dd3d4a0f 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -284,6 +284,19 @@ static bool varargsCanBeUsedAsList( !typevar_list; } +// Note (@zasdfgbnm): +// This is a workaround for https://github.com/pytorch/pytorch/issues/47964 +// Currently JIT does not distinguish ScalarType vs int, so there is really +// no way to distinguish x.view(1) vs x.view(torch.int8). So we have to hardcode +// the aten::view.dtype here to block this overload. This blocklist should be +// removed when JIT fully suports ScalarType as its own type. +bool isBlockListedSchema(const FunctionSchema& schema) { + if (schema.name() == "aten::view" && schema.overload_name() == "dtype") { + return true; + } + return false; +} + static c10::optional tryMatchSchema( const FunctionSchema& schema, const SourceRange& loc, @@ -293,6 +306,10 @@ static c10::optional tryMatchSchema( c10::optional self, std::ostream* failure_messages, bool allow_conversions) { + if (isBlockListedSchema(schema)) { + return c10::nullopt; + } + auto err = [&]() -> std::ostream& { *failure_messages << "\n" << schema << ":\n"; return *failure_messages; diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index eb75928e5952..7f58aca25271 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -104,7 +104,7 @@ std::ostream& operator<<( static void printAttribute(std::ostream& out, const at::Tensor& tensor) { // 1-elem tensors are usually boxed scalars, so print them like it if (tensor.numel() == 1) { - auto scalar_tensor = tensor.view({}).item(); + auto scalar_tensor = tensor.view(std::vector{}).item(); out << "{"; if (scalar_tensor.isFloatingPoint()) { out << scalar_tensor.toDouble(); diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 1d38e1e1f4cd..1a63edf0fc15 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -71,6 +71,7 @@ static bool IsComparisonOp(const NodeKind& nkind) { static TensorTypePtr CreateProfiledTensorTypeWithScalarType( const TensorTypePtr& typePtr, const c10::ScalarType& scalar_type) { + AT_ASSERT(typePtr != nullptr); return typePtr->withScalarType({scalar_type}); } @@ -132,6 +133,15 @@ static c10::optional PromoteScalarTypesWithCategory( static c10::optional InferExpectedScalarType(const Node* n) { std::vector typesFromTensors; std::vector typesFromScalars; + + auto get_scalar_type = + [](const Value* input) -> c10::optional { + if (auto tensor_type = input->type()->cast()) { + return tensor_type->scalarType(); + } + return c10::nullopt; + }; + std::for_each( n->inputs().begin(), n->inputs().end(), [&](const Value* input) { auto nkind = input->node()->kind(); @@ -180,16 +190,13 @@ 
static c10::optional InferExpectedScalarType(const Node* n) { } else { typesFromTensors.emplace_back(scalar_type); } - } else if ( - auto scalar_type = - input->type()->cast()->scalarType()) { + } else if (auto scalar_type = get_scalar_type(input)) { typesFromTensors.emplace_back(*scalar_type); } }); c10::optional st = c10::nullopt; - const c10::optional output_st = - n->output()->type()->cast()->scalarType(); + const auto output_st = get_scalar_type(n->output()); if (IsComparisonOp(n->kind())) { // For comparison ops, always promote scalar type to highest among inputs, @@ -236,7 +243,8 @@ static void UpdateScalarTypeForInputs( for (auto input : n->inputs()) { auto input_tensor_type = input->type()->cast(); - auto input_scalar_type = input_tensor_type->scalarType(); + auto input_scalar_type = + input_tensor_type ? input_tensor_type->scalarType() : c10::nullopt; if ((input->node()->kind() == onnx::Constant) || (input_scalar_type && (*input_scalar_type != scalar_type))) { diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index bfe160b2647f..b14f4ddc37fd 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -122,7 +122,7 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper // Capture a copy of the ivalue::Future instead of the `this` pointer // because the PythonFutureWrapper object could have been deleted // when the callbacks are fired. For example, RPC only captures the - // ivalue::Future instead of PythonFutureWrapper in FutureMessage's + // ivalue::Future instead of PythonFutureWrapper in JitFuture's // callback functions. Hence, if user code does not hold a reference to // this PythonFutureWrapper object, there is no guarantee that the // PythonFutureWrapper is still valid when running the callback. 
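The scalar_type_analysis.cpp hunks above add a null guard because a Value's type() is not necessarily a TensorType: cast<TensorType>() returns nullptr for non-tensor values, so scalarType() may only be called once the cast has succeeded. The standalone sketch below shows the same guard; the free-function form and its name are illustrative, not taken from the patch.

#include <ATen/core/jit_type.h>
#include <c10/util/Optional.h>

// Return the scalar type only when `type` is a TensorType with a known dtype;
// otherwise c10::nullopt. This mirrors the get_scalar_type lambda added above.
c10::optional<c10::ScalarType> scalarTypeOf(const c10::TypePtr& type) {
  if (auto tensorType = type->cast<c10::TensorType>()) {
    return tensorType->scalarType();
  }
  return c10::nullopt;
}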
diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index 6e68fe9ebec3..b6f6169ede6c 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -625,7 +625,8 @@ void initPythonIRBindings(PyObject* module_) { [](Node& n, const char* name, const at::Tensor& v) { return n.t_( Symbol::attr(name), - autograd::Variable(v.view({})).set_requires_grad(false)); + autograd::Variable(v.view(std::vector{})) + .set_requires_grad(false)); }) .def( "z", @@ -634,7 +635,8 @@ void initPythonIRBindings(PyObject* module_) { "zs_", [](Node& n, const char* name, TensorsAttr::ValueType v) { for (auto& i : v) { - i = autograd::Variable(i.view({})).set_requires_grad(false); + i = autograd::Variable(i.view(std::vector{})) + .set_requires_grad(false); } return n.ts_(Symbol::attr(name), std::move(v)); }) diff --git a/torch/csrc/jit/runtime/vararg_functions.cpp b/torch/csrc/jit/runtime/vararg_functions.cpp index 220a5e67f723..930ea5eb8f98 100644 --- a/torch/csrc/jit/runtime/vararg_functions.cpp +++ b/torch/csrc/jit/runtime/vararg_functions.cpp @@ -213,7 +213,7 @@ void listConstruct(Stack& stack, const at::ListType& type, size_t num_inputs) { c10::List vals(type.getElementType()); vals.reserve(num_inputs); for (size_t i = stack.size() - num_inputs; i < stack.size(); ++i) { - vals.emplace_back(std::move(stack[i])); + vals.push_back(std::move(stack[i])); } drop(stack, num_inputs); return vals; diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index e6e31ba4d96c..92cec735b3bb 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1081,8 +1081,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } break; case aten::exp: { - return computeOneOperand( - "aten_exp", v, [](const ExprHandle& a) { return exp(a); }); + return computeOneOperand("aten_exp", v, [](const ExprHandle& a) { + return exp(promoteIntegerToDefaultType(a)); + }); } break; case aten::expm1: { diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index e6b851a3a74c..5d60f8b07c64 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -21,6 +21,7 @@ static const char* backend_to_string(const at::Backend& backend) { case at::Backend::CUDA: return "torch.cuda"; case at::Backend::SparseCPU: return "torch.sparse"; case at::Backend::SparseCUDA: return "torch.cuda.sparse"; + case at::Backend::QuantizedCPU: return "torch.quantized"; default: AT_ERROR("Unimplemented backend ", backend); } } diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index b1be0b52c9e6..7183aa1a82a3 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -82,10 +82,7 @@ def __init__( self.q_memory_dict = {} -def powerSGD_hook( - state: PowerSGDState, - bucket, -) -> torch.futures.Future: +def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: """ This DDP communication hook implements the original PowerSGD gradient compression algorithm described in https://arxiv.org/abs/1905.13727. 
@@ -322,10 +319,7 @@ def decompress(fut): ) -def batched_powerSGD_hook( - state: PowerSGDState, - bucket, -) -> torch.futures.Future: +def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: """ This DDP communication hook implements a simplified PowerSGD gradient compression algorithm described in https://arxiv.org/abs/1905.13727. diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py index 58fc42b33dbf..a6a5b26e6d40 100644 --- a/torch/distributed/launch.py +++ b/torch/distributed/launch.py @@ -295,7 +295,7 @@ def sigkill_handler(signum, frame): print(f"Killing subprocess {process.pid}") try: process.kill() - except Exception as e: + except Exception: pass if last_return_code is not None: raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd) diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index 0a99df67269e..ce200dba392f 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -381,7 +381,7 @@ def _rref_typeof_on_user(rref): # Combine the implementation class and the type class. class RRef(PyRRef, Generic[T]): pass -except TypeError as exc: +except TypeError: # TypeError: metaclass conflict: the metaclass of a derived class # must be a (non-strict) subclass of the metaclasses of all its bases # Mypy doesn't understand __class__ (mypy bug #4177) diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index da3b3c2301a6..c6e5bd9a7870 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -166,9 +166,14 @@ def get_jit_class_def(cls, self_name): and not is_static_fn(cls, m.__name__) and m.__name__ in cls.__dict__ ) + + def is_classmethod(fn): + return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls + methods = [get_jit_def(method[1], method[0], - self_name=self_name) for method in methods] + self_name=self_name, + is_classmethod=is_classmethod(method[1])) for method in methods] properties = get_class_properties(cls, self_name) @@ -217,7 +222,7 @@ def remove_prefix(text, prefix): return aligned_prefix + aligned_suffix -def get_jit_def(fn, def_name, self_name=None): +def get_jit_def(fn, def_name, self_name=None, is_classmethod=False): """ Build a JIT AST (TreeView) from the given function. 
@@ -244,6 +249,12 @@ def _forward(self): ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True) fn_def = py_ast.body[0] + if is_classmethod: + arg_name = fn_def.args.args[0].arg + # Insert a statement that assigns the first argument to the class + assign_stmt = ast.parse(f"{arg_name} = {self_name}").body[0] + fn_def.body.insert(0, assign_stmt) + # Swap out the function signature and body if it is unused if should_drop(fn): unused_fn_def = ast.parse("def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")") diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index b9ac5aa77150..f7c6658d715e 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -1413,7 +1413,7 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier( // Use user defined GPU device ids if provided if (!opts.device_ids.empty()) { for (auto device : opts.device_ids) { - devices.push_back(at::Device(at::DeviceType::CUDA, device)); + devices.emplace_back(at::DeviceType::CUDA, device); } } else if (usedDeviceIdxs_.empty()) { // This means there is not yet a NCCL collective being called @@ -1423,10 +1423,10 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier( // ensure that each process is on a different GPU auto numGPUs = at::cuda::getNumGPUs(); int16_t deviceIdx = static_cast(rank_ % numGPUs); - devices.push_back(at::Device(at::DeviceType::CUDA, deviceIdx)); + devices.emplace_back(at::DeviceType::CUDA, deviceIdx); } else { for (auto usedDeviceIdx : usedDeviceIdxs_) { - devices.push_back(at::Device(at::DeviceType::CUDA, usedDeviceIdx)); + devices.emplace_back(at::DeviceType::CUDA, usedDeviceIdx); } } diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 0f99def6c7fe..de5fcb54ddb7 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -403,6 +403,93 @@ (tensor(3.7417), tensor(11.2250)) """) +svd = _add_docstr(_linalg.linalg_svd, r""" +linalg.svd(input, full_matrices=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) + +Computes the singular value decomposition of either a matrix or batch of +matrices :attr:`input`. The singular value decomposition is represented as a +namedtuple ``(U, S, Vh)``, such that :math:`input = U \mathbin{@} diag(S) \times +Vh`. If :attr:`input` is a batch of tensors, then ``U``, ``S``, and ``Vh`` are +also batched with the same batch dimensions as :attr:`input`. + +If :attr:`full_matrices` is ``False``, the method returns the reduced singular +value decomposition, i.e., if the last two dimensions of :attr:`input` are +``m`` and ``n``, then the returned `U` and `V` matrices will contain only +:math:`min(n, m)` orthonormal columns. + +If :attr:`compute_uv` is ``False``, the returned `U` and `Vh` will be empty +tensors with no elements and the same device as :attr:`input`. The +:attr:`full_matrices` argument has no effect when :attr:`compute_uv` is False. + +The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will +always be real-valued, even if :attr:`input` is complex. + +.. note:: Unlike NumPy's ``linalg.svd``, this always returns a namedtuple of + three tensors, even when :attr:`compute_uv=False`. + +.. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, + then the singular values of each matrix in the batch are returned in descending order. + +.. 
note:: The implementation of SVD on CPU uses the LAPACK routine `?gesdd` (a divide-and-conquer + algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine + `gesdd` as well. + +.. note:: The returned matrix `U` will be transposed, i.e. with strides + :code:`U.contiguous().transpose(-2, -1).stride()`. + +.. note:: Gradients computed using `U` and `Vh` may be unstable if + :attr:`input` is not full rank or has non-unique singular values. + +.. note:: When :attr:`full_matrices` = ``True``, the gradients on :code:`U[..., :, min(m, n):]` + and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors + can be arbitrary bases of the subspaces. + +.. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True. + + +Args: + input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more + batch dimensions consisting of :math:`m \times n` matrices. + full_matrices (bool, optional): controls whether to compute the full or reduced decomposition, and + consequently the shape of returned ``U`` and ``V``. Defaults to True. + compute_uv (bool, optional): whether to compute `U` and `V` or not. Defaults to True. + out (tuple, optional): a tuple of three tensors to use for the outputs. If compute_uv=False, + the 1st and 3rd arguments must be tensors, but they are ignored. E.g. you can + pass `(torch.Tensor(), out_S, torch.Tensor())` + +Example:: + + >>> import torch + >>> a = torch.randn(5, 3) + >>> a + tensor([[-0.3357, -0.2987, -1.1096], + [ 1.4894, 1.0016, -0.4572], + [-1.9401, 0.7437, 2.0968], + [ 0.1515, 1.3812, 1.5491], + [-1.8489, -0.5907, -2.5673]]) + >>> + >>> # reconstruction in the full_matrices=False case + >>> u, s, vh = torch.linalg.svd(a, full_matrices=False) + >>> u.shape, s.shape, vh.shape + (torch.Size([5, 3]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(a, u @ torch.diag(s) @ vh) + tensor(1.0486e-06) + >>> + >>> # reconstruction in the full_matrices=True case + >>> u, s, vh = torch.linalg.svd(a) + >>> u.shape, s.shape, vh.shape + (torch.Size([5, 5]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(a, u[:, :3] @ torch.diag(s) @ vh) + tensor(1.0486e-06) + >>> + >>> # extra dimensions + >>> a_big = torch.randn(7, 5, 3) + >>> u, s, vh = torch.linalg.svd(a_big, full_matrices=False) + >>> torch.dist(a_big, u @ torch.diag_embed(s) @ vh) + tensor(3.0957e-06) +""") + cond = _add_docstr(_linalg.linalg_cond, r""" linalg.cond(input, p=None, *, out=None) -> Tensor @@ -644,15 +731,15 @@ .. note:: Backpropagation is not supported for ``mode='r'``. Use ``mode='reduced'`` instead. - If you plan to backpropagate through QR, note that the current backward implementation - is only well-defined when the first :math:`\min(input.size(-1), input.size(-2))` - columns of :attr:`input` are linearly independent. - This behavior may change in the future. + Backpropagation is also not supported if the first + :math:`\min(input.size(-1), input.size(-2))` columns of any matrix + in :attr:`input` are not linearly independent. While no error will + be thrown when this occurs, the values of the "gradient" produced may + be anything. This behavior may change in the future. .. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, and may produce different (valid) decompositions on different device types - and different platforms, depending on the precise version of the - underlying library. + or different platforms. 
Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more @@ -666,11 +753,8 @@ * ``'r'``: computes only `R`; returns `(Q, R)` where `Q` is empty and `R` has dimensions (k, n) Keyword args: - out (tuple, optional): tuple of `Q` and `R` tensors - satisfying :code:`input = torch.matmul(Q, R)`. - The dimensions of `Q` and `R` are :math:`(*, m, k)` and :math:`(*, k, n)` - respectively, where :math:`k = \min(m, n)` if :attr:`mode` is `'reduced'` and - :math:`k = m` if :attr:`mode` is `'complete'`. + out (tuple, optional): tuple of `Q` and `R` tensors. + The dimensions of `Q` and `R` are detailed in the description of :attr:`mode` above. Example:: @@ -692,6 +776,11 @@ tensor([[ 1., 0., 0.], [ 0., 1., -0.], [ 0., -0., 1.]]) + >>> q2, r2 = torch.linalg.qr(a, mode='r') + >>> q2 + tensor([]) + >>> torch.equal(r, r2) + True >>> a = torch.randn(3, 4, 5) >>> q, r = torch.linalg.qr(a, mode='complete') >>> torch.allclose(torch.matmul(q, r), a) diff --git a/torch/nn/_reduction.py b/torch/nn/_reduction.py index 025ef157b958..2b5a91276bd9 100644 --- a/torch/nn/_reduction.py +++ b/torch/nn/_reduction.py @@ -4,8 +4,7 @@ # NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h -def get_enum(reduction): - # type: (str) -> int +def get_enum(reduction: str) -> int: if reduction == 'none': ret = 0 elif reduction == 'mean': @@ -25,8 +24,7 @@ def get_enum(reduction): # We use these functions in torch/legacy as well, in which case we'll silence the warning -def legacy_get_string(size_average, reduce, emit_warning=True): - # type: (Optional[bool], Optional[bool], bool) -> str +def legacy_get_string(size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True) -> str: warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." if size_average is None: @@ -45,6 +43,5 @@ def legacy_get_string(size_average, reduce, emit_warning=True): return ret -def legacy_get_enum(size_average, reduce, emit_warning=True): - # type: (Optional[bool], Optional[bool], bool) -> int +def legacy_get_enum(size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True) -> int: return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index fc71e4a0c449..162c22ea236a 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1,22 +1,25 @@ r"""Functional interface""" -import warnings +from typing import Callable, List, Optional, Tuple import math +import warnings import torch +from torch import _VF from torch._C import _infer_size, _add_docstr +from torch._torch_docs import reproducibility_notes, tf32_notes + +from .._jit_internal import boolean_dispatch, _overload +from ..overrides import has_torch_function, handle_torch_function from . import _reduction as _Reduction +from . import grad # noqa: F401 from .modules import utils from .modules.utils import _single, _pair, _triple, _list_with_default -from . 
import grad # noqa: F401 -from torch import _VF -from .._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple -from ..overrides import has_torch_function, handle_torch_function -from torch._torch_docs import reproducibility_notes, tf32_notes - Tensor = torch.Tensor -conv1d = _add_docstr(torch.conv1d, r""" +conv1d = _add_docstr( + torch.conv1d, + r""" conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Applies a 1D convolution over an input signal composed of several input @@ -28,7 +31,10 @@ Note: {cudnn_reproducibility_note} -""".format(**reproducibility_notes, **tf32_notes) + r""" +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` @@ -48,9 +54,12 @@ >>> filters = torch.randn(33, 16, 3) >>> inputs = torch.randn(20, 16, 50) >>> F.conv1d(inputs, filters) -""") +""", +) -conv2d = _add_docstr(torch.conv2d, r""" +conv2d = _add_docstr( + torch.conv2d, + r""" conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Applies a 2D convolution over an input image composed of several input @@ -62,7 +71,10 @@ Note: {cudnn_reproducibility_note} -""".format(**reproducibility_notes, **tf32_notes) + r""" +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` @@ -82,9 +94,12 @@ >>> filters = torch.randn(8,4,3,3) >>> inputs = torch.randn(1,4,5,5) >>> F.conv2d(inputs, filters, padding=1) -""") # noqa: E501 +""", +) # noqa: E501 -conv3d = _add_docstr(torch.conv3d, r""" +conv3d = _add_docstr( + torch.conv3d, + r""" conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Applies a 3D convolution over an input image composed of several input @@ -96,7 +111,10 @@ Note: {cudnn_reproducibility_note} -""".format(**reproducibility_notes, **tf32_notes) + r""" +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` @@ -116,9 +134,12 @@ >>> filters = torch.randn(33, 16, 3, 3, 3) >>> inputs = torch.randn(20, 16, 50, 10, 20) >>> F.conv3d(inputs, filters) -""") # noqa: E501 +""", +) # noqa: E501 -conv_transpose1d = _add_docstr(torch.conv_transpose1d, r""" +conv_transpose1d = _add_docstr( + torch.conv_transpose1d, + r""" conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor Applies a 1D transposed convolution operator over an input signal @@ -130,7 +151,10 @@ Note: {cudnn_reproducibility_note} -""".format(**reproducibility_notes, **tf32_notes) + r""" +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` @@ -153,9 +177,12 @@ >>> inputs = torch.randn(20, 16, 50) >>> weights = torch.randn(16, 33, 5) >>> F.conv_transpose1d(inputs, weights) -""") +""", +) -conv_transpose2d = _add_docstr(torch.conv_transpose2d, r""" +conv_transpose2d = _add_docstr( + torch.conv_transpose2d, + r""" conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor Applies a 2D transposed convolution operator over an input image @@ -167,7 +194,10 @@ Note: {cudnn_reproducibility_note} 
-""".format(**reproducibility_notes, **tf32_notes) + r""" +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` @@ -192,9 +222,12 @@ >>> inputs = torch.randn(1, 4, 5, 5) >>> weights = torch.randn(4, 8, 3, 3) >>> F.conv_transpose2d(inputs, weights, padding=1) -""") # noqa: E501 +""", +) # noqa: E501 -conv_transpose3d = _add_docstr(torch.conv_transpose3d, r""" +conv_transpose3d = _add_docstr( + torch.conv_transpose3d, + r""" conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor Applies a 3D transposed convolution operator over an input image @@ -206,7 +239,10 @@ Note: {cudnn_reproducibility_note} -""".format(**reproducibility_notes, **tf32_notes) + r""" +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` @@ -230,9 +266,12 @@ >>> inputs = torch.randn(20, 16, 50, 10, 20) >>> weights = torch.randn(16, 33, 3, 3, 3) >>> F.conv_transpose3d(inputs, weights) -""") # noqa: E501 +""", +) # noqa: E501 -conv_tbc = _add_docstr(torch.conv_tbc, r""" +conv_tbc = _add_docstr( + torch.conv_tbc, + r""" Applies a 1-dimensional sequence convolution over an input sequence. Input and output dimensions are (Time, Batch, Channels) - hence TBC. @@ -241,11 +280,14 @@ weight: filter of shape (:math:`\text{kernel width} \times \text{in\_channels} \times \text{out\_channels}`) bias: bias of shape (:math:`\text{out\_channels}`) pad: number of timesteps to pad. Default: 0 -""") +""", +) # Pooling -avg_pool1d = _add_docstr(torch.avg_pool1d, r""" +avg_pool1d = _add_docstr( + torch.avg_pool1d, + r""" avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor Applies a 1D average pooling over an input signal composed of several @@ -273,10 +315,13 @@ >>> F.avg_pool1d(input, kernel_size=3, stride=2) tensor([[[ 2., 4., 6.]]]) -""") +""", +) -avg_pool2d = _add_docstr(torch._C._nn.avg_pool2d, r""" +avg_pool2d = _add_docstr( + torch._C._nn.avg_pool2d, + r""" avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size @@ -299,9 +344,12 @@ averaging calculation. Default: ``True`` divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. Default: None -""") +""", +) -avg_pool3d = _add_docstr(torch._C._nn.avg_pool3d, r""" +avg_pool3d = _add_docstr( + torch._C._nn.avg_pool3d, + r""" avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor Applies 3D average-pooling operation in :math:`kT \times kH \times kW` regions by step @@ -324,12 +372,13 @@ averaging calculation divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. 
Default: None -""") +""", +) -def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None, - output_ratio=None, return_indices=False, - _random_samples=None): +def fractional_max_pool2d_with_indices( + input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None +): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa r"""Applies 2D fractional max pooling over an input signal composed of several input planes. @@ -363,50 +412,62 @@ def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - fractional_max_pool2d_with_indices, (input,), input, kernel_size, - output_size=output_size, output_ratio=output_ratio, - return_indices=return_indices, _random_samples=_random_samples) + fractional_max_pool2d_with_indices, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) if output_size is None and output_ratio is None: - raise ValueError("fractional_max_pool2d requires specifying either " - "an output_size or an output_ratio") + raise ValueError("fractional_max_pool2d requires specifying either " "an output_size or an output_ratio") if output_size is None: assert output_ratio is not None _output_ratio = _pair(output_ratio) - output_size = [int(input.size(2) * _output_ratio[0]), - int(input.size(3) * _output_ratio[1])] + output_size = [int(input.size(2) * _output_ratio[0]), int(input.size(3) * _output_ratio[1])] if _random_samples is None: _random_samples = torch.rand(input.size(0), input.size(1), 2, dtype=input.dtype, device=input.device) return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples) -def _fractional_max_pool2d(input, kernel_size, output_size=None, - output_ratio=None, return_indices=False, - _random_samples=None): +def _fractional_max_pool2d( + input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None +): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tensor # noqa if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - fractional_max_pool2d, (input,), input, kernel_size, - output_size=output_size, output_ratio=output_ratio, - return_indices=return_indices, _random_samples=_random_samples) - return fractional_max_pool2d_with_indices(input, kernel_size, output_size, - output_ratio, return_indices, - _random_samples)[0] + fractional_max_pool2d, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + return fractional_max_pool2d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + fractional_max_pool2d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=4, default=False, if_true=fractional_max_pool2d_with_indices, if_false=_fractional_max_pool2d, module_name=__name__, - func_name='fractional_max_pool2d') + func_name="fractional_max_pool2d", +) -def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None, - output_ratio=None, 
return_indices=False, - _random_samples=None): +def fractional_max_pool3d_with_indices( + input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None +): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa r"""Applies 3D fractional max pooling over an input signal composed of several input planes. @@ -441,50 +502,66 @@ def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - fractional_max_pool3d_with_indices, (input,), input, kernel_size, - output_size=output_size, output_ratio=output_ratio, - return_indices=return_indices, _random_samples=_random_samples) + fractional_max_pool3d_with_indices, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) if output_size is None and output_ratio is None: - raise ValueError("fractional_max_pool3d requires specifying either " - "an output_size or an output_ratio") + raise ValueError("fractional_max_pool3d requires specifying either " "an output_size or an output_ratio") if output_size is None: assert output_ratio is not None _output_ratio = _triple(output_ratio) - output_size = [int(input.size(2) * _output_ratio[0]), - int(input.size(3) * _output_ratio[1]), - int(input.size(4) * _output_ratio[2])] + output_size = [ + int(input.size(2) * _output_ratio[0]), + int(input.size(3) * _output_ratio[1]), + int(input.size(4) * _output_ratio[2]), + ] if _random_samples is None: _random_samples = torch.rand(input.size(0), input.size(1), 3, dtype=input.dtype, device=input.device) return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples) -def _fractional_max_pool3d(input, kernel_size, output_size=None, - output_ratio=None, return_indices=False, - _random_samples=None): +def _fractional_max_pool3d( + input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None +): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tensor # noqa if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - fractional_max_pool3d, (input,), input, kernel_size, - output_size=output_size, output_ratio=output_ratio, - return_indices=return_indices, _random_samples=_random_samples) - return fractional_max_pool3d_with_indices(input, kernel_size, output_size, - output_ratio, return_indices, - _random_samples)[0] + fractional_max_pool3d, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + return fractional_max_pool3d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + fractional_max_pool3d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=4, default=False, if_true=fractional_max_pool3d_with_indices, if_false=_fractional_max_pool3d, module_name=__name__, - func_name='fractional_max_pool3d') + func_name="fractional_max_pool3d", +) -def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0, - dilation=1, ceil_mode=False, return_indices=False): +def 
max_pool1d_with_indices( + input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False +): # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa r"""Applies a 1D max pooling over an input signal composed of several input planes. @@ -494,41 +571,55 @@ def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_pool1d_with_indices, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + max_pool1d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool1d_with_indices( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool1d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) -def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): +def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_pool1d, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + max_pool1d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool1d( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode) + max_pool1d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=6, default=False, if_true=max_pool1d_with_indices, if_false=_max_pool1d, module_name=__name__, - func_name='max_pool1d') + func_name="max_pool1d", +) -def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): +def max_pool2d_with_indices( + input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False +): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa r"""Applies a 2D max pooling over an input signal composed of several input planes. 
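For reference, a minimal usage sketch of the `max_pool1d` dispatch reformatted above: `boolean_dispatch` still routes `return_indices=True` to `max_pool1d_with_indices` and `False` to `_max_pool1d`. This is not part of the patch; the shapes are illustrative and assume the standard `torch` imports.

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 16)                      # illustrative (batch, channels, length)
out = F.max_pool1d(x, kernel_size=2)           # return_indices=False -> _max_pool1d -> torch.max_pool1d
vals, idx = F.max_pool1d(x, kernel_size=2, return_indices=True)  # -> max_pool1d_with_indices
assert out.shape == (2, 4, 8) and idx.shape == vals.shape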
@@ -538,40 +629,55 @@ def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_pool2d_with_indices, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + max_pool2d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) -def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): +def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_pool2d, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + max_pool2d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool2d( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) + max_pool2d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=6, default=False, if_true=max_pool2d_with_indices, if_false=_max_pool2d, module_name=__name__, - func_name='max_pool2d') + func_name="max_pool2d", +) -def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0, - dilation=1, ceil_mode=False, return_indices=False): +def max_pool3d_with_indices( + input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False +): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa r"""Applies a 3D max pooling over an input signal composed of several input planes. 
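The `max_pool2d` hunks above are likewise formatting-only: `stride=None` is still passed down as an empty list, so the ATen kernel falls back to `kernel_size`. A small sketch, separate from the patch, with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
y = F.max_pool2d(x, kernel_size=2)             # stride defaults to kernel_size -> (1, 3, 4, 4)
vals, idx = F.max_pool2d(x, kernel_size=2, return_indices=True)  # dispatched to max_pool2d_with_indices
assert y.shape == (1, 3, 4, 4) and idx.shape == vals.shape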
@@ -581,69 +687,86 @@ def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_pool3d_with_indices, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + max_pool3d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch._C._nn.max_pool3d_with_indices( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch._C._nn.max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) -def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): +def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_pool3d, (input,), input, kernel_size, stride=stride, padding=padding, - dilation=dilation, ceil_mode=ceil_mode, return_indices=return_indices) + max_pool3d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool3d( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode) + max_pool3d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=6, default=False, if_true=max_pool3d_with_indices, if_false=_max_pool3d, module_name=__name__, - func_name='max_pool3d') + func_name="max_pool3d", +) -def _unpool_output_size(input, kernel_size, stride, padding, output_size): - # type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int] +def _unpool_output_size( + input: Tensor, kernel_size: List[int], stride: List[int], padding: List[int], output_size: Optional[List[int]] +) -> List[int]: input_size = input.size() default_size = torch.jit.annotate(List[int], []) for d in range(len(kernel_size)): - default_size.append((input_size[d + 2] - 1) * stride[d] + - kernel_size[d] - 2 * padding[d]) + default_size.append((input_size[d + 2] - 1) * stride[d] + kernel_size[d] - 2 * padding[d]) if output_size is None: ret = default_size else: if len(output_size) == len(kernel_size) + 2: output_size = output_size[2:] if len(output_size) != len(kernel_size): - raise ValueError("output_size should be a sequence containing " - "{} or {} elements, but it has a length of '{}'" - .format(len(kernel_size), len(kernel_size) + 2, - len(output_size))) + raise ValueError( + "output_size should be a sequence containing " + "{} or {} elements, but it has a length of '{}'".format( + len(kernel_size), len(kernel_size) + 2, len(output_size) + ) + ) for d in range(len(kernel_size)): min_size = default_size[d] - stride[d] max_size = default_size[d] + stride[d] if not (min_size < output_size[d] < max_size): raise ValueError( - 'invalid output_size "{}" (dim {} must be between {} and {})' - 
.format(output_size, d, min_size, max_size)) + 'invalid output_size "{}" (dim {} must be between {} and {})'.format( + output_size, d, min_size, max_size + ) + ) ret = output_size return ret -def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, - output_size=None): +def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, output_size=None): # type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor # noqa r"""Computes a partial inverse of :class:`MaxPool1d`. @@ -652,26 +775,30 @@ def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_unpool1d, (input,), input, indices, kernel_size, - stride=stride, padding=padding, output_size=output_size) + max_unpool1d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _single(kernel_size) if stride is not None: _stride = _single(stride) else: _stride = kernel_size padding = _single(padding) - output_size = _unpool_output_size(input, kernel_size, _stride, padding, - output_size) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) if isinstance(output_size, list): output_size = output_size + [1] else: output_size = output_size + (1,) - return torch._C._nn.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), - output_size).squeeze(3) + return torch._C._nn.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), output_size).squeeze(3) -def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, - output_size=None): +def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, output_size=None): # type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor # noqa r"""Computes a partial inverse of :class:`MaxPool2d`. @@ -680,21 +807,26 @@ def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_unpool2d, (input,), input, indices, kernel_size, - stride=stride, padding=padding, output_size=output_size) + max_unpool2d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _pair(kernel_size) if stride is not None: _stride = _pair(stride) else: _stride = kernel_size padding = _pair(padding) - output_size = _unpool_output_size(input, kernel_size, _stride, padding, - output_size) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) return torch._C._nn.max_unpool2d(input, indices, output_size) -def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, - output_size=None): +def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, output_size=None): # type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor # noqa r"""Computes a partial inverse of :class:`MaxPool3d`. 
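`_unpool_output_size` is only reflowed here; per spatial dim it still infers the pre-pooling size as (input_size - 1) * stride + kernel_size - 2 * padding. A small round-trip sketch using `max_pool2d` and `max_unpool2d`, with made-up values, not part of the patch:

import torch
import torch.nn.functional as F

x = torch.arange(16.0).reshape(1, 1, 4, 4)                          # illustrative input
pooled, idx = F.max_pool2d(x, kernel_size=2, return_indices=True)   # -> (1, 1, 2, 2)
restored = F.max_unpool2d(pooled, idx, kernel_size=2)               # zeros except at the max positions
assert restored.shape == x.shape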
@@ -703,18 +835,23 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - max_unpool3d, (input,), input, indices, kernel_size, - stride=stride, padding=padding, output_size=output_size) + max_unpool3d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _triple(kernel_size) if stride is not None: _stride = _triple(stride) else: _stride = kernel_size padding = _triple(padding) - output_size = _unpool_output_size(input, kernel_size, _stride, padding, - output_size) - return torch._C._nn.max_unpool3d( - input, indices, output_size, _stride, padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding) def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False): @@ -728,15 +865,15 @@ def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, - ceil_mode=ceil_mode) + lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) kw, kh = utils._pair(kernel_size) if stride is not None: out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) else: out = avg_pool2d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) - return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type) + return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1.0 / norm_type) def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): @@ -750,14 +887,14 @@ def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, - ceil_mode=ceil_mode) + lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) if stride is not None: out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) else: out = avg_pool1d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) - return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. 
/ norm_type) + return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1.0 / norm_type) def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): @@ -774,8 +911,8 @@ def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - adaptive_max_pool1d_with_indices, (input,), input, output_size, - return_indices=return_indices) + adaptive_max_pool1d_with_indices, (input,), input, output_size, return_indices=return_indices + ) return torch.adaptive_max_pool1d(input, output_size) @@ -784,18 +921,20 @@ def _adaptive_max_pool1d(input, output_size, return_indices=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - adaptive_max_pool1d, (input,), input, output_size, - return_indices=return_indices) + adaptive_max_pool1d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool1d_with_indices(input, output_size)[0] + adaptive_max_pool1d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=2, default=False, if_true=adaptive_max_pool1d_with_indices, if_false=_adaptive_max_pool1d, module_name=__name__, - func_name='adaptive_max_pool1d') + func_name="adaptive_max_pool1d", +) def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): @@ -813,8 +952,8 @@ def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - adaptive_max_pool2d_with_indices, (input,), input, output_size, - return_indices=return_indices) + adaptive_max_pool2d_with_indices, (input,), input, output_size, return_indices=return_indices + ) output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool2d(input, output_size) @@ -824,18 +963,20 @@ def _adaptive_max_pool2d(input, output_size, return_indices=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - adaptive_max_pool2d, (input,), input, output_size, - return_indices=return_indices) + adaptive_max_pool2d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool2d_with_indices(input, output_size)[0] + adaptive_max_pool2d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=2, default=False, if_true=adaptive_max_pool2d_with_indices, if_false=_adaptive_max_pool2d, module_name=__name__, - func_name='adaptive_max_pool2d') + func_name="adaptive_max_pool2d", +) def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): @@ -853,8 +994,8 @@ def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - adaptive_max_pool3d_with_indices, (input,), input, output_size, - return_indices=return_indices) + adaptive_max_pool3d_with_indices, (input,), input, output_size, return_indices=return_indices + ) output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool3d(input, output_size) @@ -864,21 +1005,25 @@ def _adaptive_max_pool3d(input, output_size, return_indices=False): if not torch.jit.is_scripting(): if type(input) is not 
Tensor and has_torch_function((input,)): return handle_torch_function( - adaptive_max_pool3d, (input,), input, output_size, - return_indices=return_indices) + adaptive_max_pool3d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool3d_with_indices(input, output_size)[0] + adaptive_max_pool3d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=2, default=False, if_true=adaptive_max_pool3d_with_indices, if_false=_adaptive_max_pool3d, module_name=__name__, - func_name='adaptive_max_pool3d') + func_name="adaptive_max_pool3d", +) -adaptive_avg_pool1d = _add_docstr(torch.adaptive_avg_pool1d, r""" +adaptive_avg_pool1d = _add_docstr( + torch.adaptive_avg_pool1d, + r""" adaptive_avg_pool1d(input, output_size) -> Tensor Applies a 1D adaptive average pooling over an input signal composed of @@ -888,7 +1033,8 @@ def _adaptive_max_pool3d(input, output_size, return_indices=False): Args: output_size: the target output size (single integer) -""") +""", +) def adaptive_avg_pool2d(input, output_size): @@ -905,8 +1051,7 @@ def adaptive_avg_pool2d(input, output_size): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - adaptive_avg_pool2d, (input,), input, output_size) + return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool2d(input, _output_size) @@ -925,15 +1070,13 @@ def adaptive_avg_pool3d(input, output_size): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - adaptive_avg_pool3d, (input,), input, output_size) + return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool3d(input, _output_size) # Activation functions -def dropout(input, p=0.5, training=True, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor +def dropout(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: r""" During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p` using samples from a Bernoulli @@ -948,36 +1091,26 @@ def dropout(input, p=0.5, training=True, inplace=False): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - dropout, (input,), input, p=p, training=training, inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.dropout_(input, p, training) - if inplace - else _VF.dropout(input, p, training)) + return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training) -def alpha_dropout(input, p=0.5, training=False, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor +def alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace: bool = False) -> Tensor: r"""Applies alpha dropout to the input. See :class:`~torch.nn.AlphaDropout` for details. 
""" if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.alpha_dropout_(input, p, training) - if inplace - else _VF.alpha_dropout(input, p, training)) + return handle_torch_function(alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.alpha_dropout_(input, p, training) if inplace else _VF.alpha_dropout(input, p, training) -def dropout2d(input, p=0.5, training=True, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor +def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: r""" Randomly zero out entire channels (a channel is a 2D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the @@ -994,18 +1127,13 @@ def dropout2d(input, p=0.5, training=True, inplace=False): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - dropout2d, (input,), input, p=p, training=training, inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.feature_dropout_(input, p, training) - if inplace - else _VF.feature_dropout(input, p, training)) + return handle_torch_function(dropout2d, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) -def dropout3d(input, p=0.5, training=True, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor +def dropout3d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: r""" Randomly zero out entire channels (a channel is a 3D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the @@ -1024,18 +1152,13 @@ def dropout3d(input, p=0.5, training=True, inplace=False): # stack traces are not confusing. if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - dropout3d, (input,), input, p=p, training=training, inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.feature_dropout_(input, p, training) - if inplace - else _VF.feature_dropout(input, p, training)) + return handle_torch_function(dropout3d, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) -def feature_alpha_dropout(input, p=0.5, training=False, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor +def feature_alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace: bool = False) -> Tensor: r""" Randomly masks out entire channels (a channel is a feature map, e.g. 
the :math:`j`-th channel of the :math:`i`-th sample in the batch input @@ -1058,42 +1181,41 @@ def feature_alpha_dropout(input, p=0.5, training=False, inplace=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - feature_alpha_dropout, (input,), input, p=p, training=training, - inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.feature_alpha_dropout_(input, p, training) - if inplace - else _VF.feature_alpha_dropout(input, p, training)) - - -def _threshold(input, threshold, value, inplace=False): - # type: (Tensor, float, float, bool) -> Tensor + feature_alpha_dropout, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.feature_alpha_dropout_(input, p, training) if inplace else _VF.feature_alpha_dropout(input, p, training) + + +def _threshold(input: Tensor, threshold: float, value: float, inplace: bool = False) -> Tensor: r"""Thresholds each element of the input Tensor. See :class:`~torch.nn.Threshold` for more details. """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - _threshold, (input,), input, threshold, value, inplace=inplace) + return handle_torch_function(_threshold, (input,), input, threshold, value, inplace=inplace) if inplace: result = _VF.threshold_(input, threshold, value) else: result = _VF.threshold(input, threshold, value) return result + # We define this function as _threshold because it takes an argument # named threshold, which clobbers the recursive reference to the # function needed for __torch_function__ support threshold = _threshold -threshold_ = _add_docstr(_VF.threshold_, r""" +threshold_ = _add_docstr( + _VF.threshold_, + r""" threshold_(input, threshold, value) -> Tensor In-place version of :func:`~threshold`. -""") +""", +) def relu(input: Tensor, inplace: bool = False) -> Tensor: @@ -1112,11 +1234,14 @@ def relu(input: Tensor, inplace: bool = False) -> Tensor: return result -relu_ = _add_docstr(torch.relu_, r""" +relu_ = _add_docstr( + torch.relu_, + r""" relu_(input) -> Tensor In-place version of :func:`~relu`. 
-""") +""", +) def glu(input: Tensor, dim: int = -1) -> Tensor: @@ -1145,7 +1270,7 @@ def glu(input: Tensor, dim: int = -1) -> Tensor: return torch._C._nn.glu(input, dim) -def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor: +def hardtanh(input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False) -> Tensor: r""" hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor @@ -1154,9 +1279,7 @@ def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - hardtanh, (input,), input, min_val=min_val, max_val=max_val, - inplace=inplace) + return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace) if inplace: result = torch._C._nn.hardtanh_(input, min_val, max_val) else: @@ -1164,15 +1287,17 @@ def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: return result -hardtanh_ = _add_docstr(torch._C._nn.hardtanh_, r""" +hardtanh_ = _add_docstr( + torch._C._nn.hardtanh_, + r""" hardtanh_(input, min_val=-1., max_val=1.) -> Tensor In-place version of :func:`~hardtanh`. -""") +""", +) -def relu6(input, inplace=False): - # type: (Tensor, bool) -> Tensor +def relu6(input: Tensor, inplace: bool = False) -> Tensor: r"""relu6(input, inplace=False) -> Tensor Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`. @@ -1182,11 +1307,10 @@ def relu6(input, inplace=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function(relu6, (input,), input, inplace=inplace) - return hardtanh(input, 0., 6., inplace) + return hardtanh(input, 0.0, 6.0, inplace) -def elu(input, alpha=1., inplace=False): - # type: (Tensor, float, bool) -> Tensor +def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: r"""Applies element-wise, :math:`\text{ELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))`. @@ -1194,8 +1318,7 @@ def elu(input, alpha=1., inplace=False): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(elu, (input,), input, alpha=alpha, - inplace=inplace) + return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace) if inplace: result = torch._C._nn.elu_(input, alpha) else: @@ -1203,15 +1326,17 @@ def elu(input, alpha=1., inplace=False): return result -elu_ = _add_docstr(torch._C._nn.elu_, r""" +elu_ = _add_docstr( + torch._C._nn.elu_, + r""" elu_(input, alpha=1.) -> Tensor In-place version of :func:`~elu`. -""") +""", +) -def selu(input, inplace=False): - # type: (Tensor, bool) -> Tensor +def selu(input: Tensor, inplace: bool = False) -> Tensor: r"""selu(input, inplace=False) -> Tensor Applies element-wise, @@ -1231,15 +1356,17 @@ def selu(input, inplace=False): return result -selu_ = _add_docstr(torch.selu_, r""" +selu_ = _add_docstr( + torch.selu_, + r""" selu_(input) -> Tensor In-place version of :func:`~selu`. 
-""") +""", +) -def celu(input, alpha=1., inplace=False): - # type: (Tensor, float, bool) -> Tensor +def celu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: r"""celu(input, alpha=1., inplace=False) -> Tensor Applies element-wise, @@ -1249,19 +1376,22 @@ def celu(input, alpha=1., inplace=False): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(celu, (input,), input, alpha=alpha, - inplace=inplace) + return handle_torch_function(celu, (input,), input, alpha=alpha, inplace=inplace) if inplace: result = torch.celu_(input, alpha) else: result = torch.celu(input, alpha) return result -celu_ = _add_docstr(torch.celu_, r""" + +celu_ = _add_docstr( + torch.celu_, + r""" celu_(input, alpha=1.) -> Tensor In-place version of :func:`~celu`. -""") +""", +) def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False) -> Tensor: @@ -1275,9 +1405,7 @@ def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = Fals """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - leaky_relu, (input,), input, negative_slope=negative_slope, - inplace=inplace) + return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace) if inplace: result = torch._C._nn.leaky_relu_(input, negative_slope) else: @@ -1285,15 +1413,17 @@ def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = Fals return result -leaky_relu_ = _add_docstr(torch._C._nn.leaky_relu_, r""" +leaky_relu_ = _add_docstr( + torch._C._nn.leaky_relu_, + r""" leaky_relu_(input, negative_slope=0.01) -> Tensor In-place version of :func:`~leaky_relu`. -""") +""", +) -def prelu(input, weight): - # type: (Tensor, Tensor) -> Tensor +def prelu(input: Tensor, weight: Tensor) -> Tensor: r"""prelu(input, weight) -> Tensor Applies element-wise the function @@ -1308,8 +1438,9 @@ def prelu(input, weight): return torch.prelu(input, weight) -def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False): - # type: (Tensor, float, float, bool, bool) -> Tensor +def rrelu( + input: Tensor, lower: float = 1.0 / 8, upper: float = 1.0 / 3, training: bool = False, inplace: bool = False +) -> Tensor: r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor Randomized leaky ReLU. @@ -1319,8 +1450,8 @@ def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - rrelu, (input,), input, lower=lower, upper=upper, - training=training, inplace=inplace) + rrelu, (input,), input, lower=lower, upper=upper, training=training, inplace=inplace + ) if inplace: result = torch.rrelu_(input, lower, upper, training) else: @@ -1328,19 +1459,26 @@ def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False): return result -rrelu_ = _add_docstr(torch.rrelu_, r""" +rrelu_ = _add_docstr( + torch.rrelu_, + r""" rrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor In-place version of :func:`~rrelu`. -""") +""", +) -logsigmoid = _add_docstr(torch._C._nn.log_sigmoid, r""" +logsigmoid = _add_docstr( + torch._C._nn.log_sigmoid, + r""" logsigmoid(input) -> Tensor Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)` See :class:`~torch.nn.LogSigmoid` for more details. 
-""") +""", +) + def gelu(input): r"""gelu(input) -> Tensor @@ -1358,8 +1496,7 @@ def gelu(input): return torch._C._nn.gelu(input) -def hardshrink(input, lambd=0.5): - # type: (Tensor, float) -> Tensor +def hardshrink(input: Tensor, lambd: float = 0.5) -> Tensor: r""" hardshrink(input, lambd=0.5) -> Tensor @@ -1399,7 +1536,9 @@ def softsign(input): return input / (input.abs() + 1) -softplus = _add_docstr(torch._C._nn.softplus, r""" +softplus = _add_docstr( + torch._C._nn.softplus, + r""" softplus(input, beta=1, threshold=20) -> Tensor Applies element-wise, the function :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`. @@ -1408,13 +1547,16 @@ def softsign(input): when :math:`input \times \beta > threshold`. See :class:`~torch.nn.Softplus` for more details. -""") +""", +) -def _get_softmax_dim(name, ndim, stacklevel): - # type: (str, int, int) -> int - warnings.warn("Implicit dimension choice for {} has been deprecated. " - "Change the call to include dim=X as an argument.".format(name), stacklevel=stacklevel) +def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int: + warnings.warn( + "Implicit dimension choice for {} has been deprecated. " + "Change the call to include dim=X as an argument.".format(name), + stacklevel=stacklevel, + ) if ndim == 0 or ndim == 1 or ndim == 3: ret = 0 else: @@ -1422,8 +1564,7 @@ def _get_softmax_dim(name, ndim, stacklevel): return ret -def softmin(input, dim=None, _stacklevel=3, dtype=None): - # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor +def softmin(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor: r"""Applies a softmin function. Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula. @@ -1440,10 +1581,9 @@ def softmin(input, dim=None, _stacklevel=3, dtype=None): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + return handle_torch_function(softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: - dim = _get_softmax_dim('softmin', input.dim(), _stacklevel) + dim = _get_softmax_dim("softmin", input.dim(), _stacklevel) if dtype is None: ret = (-input).softmax(dim) else: @@ -1451,8 +1591,7 @@ def softmin(input, dim=None, _stacklevel=3, dtype=None): return ret -def softmax(input, dim=None, _stacklevel=3, dtype=None): - # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor +def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor: r"""Applies a softmax function. 
Softmax is defined as: @@ -1479,10 +1618,9 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: - dim = _get_softmax_dim('softmax', input.dim(), _stacklevel) + dim = _get_softmax_dim("softmax", input.dim(), _stacklevel) if dtype is None: ret = input.softmax(dim) else: @@ -1490,8 +1628,7 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None): return ret -def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): - # type: (Tensor, float, bool, float, int) -> Tensor +def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor: r""" Samples from the Gumbel-Softmax distribution (`Link 1`_ `Link 2`_) and optionally discretizes. @@ -1533,12 +1670,13 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): """ if not torch.jit.is_scripting(): if type(logits) is not Tensor and has_torch_function((logits,)): - return handle_torch_function( - gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) + return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) if eps != 1e-10: warnings.warn("`eps` parameter is deprecated and has no effect.") - gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() # ~Gumbel(0,1) + gumbels = ( + -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + ) # ~Gumbel(0,1) gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) y_soft = gumbels.softmax(dim) @@ -1553,8 +1691,7 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): return ret -def log_softmax(input, dim=None, _stacklevel=3, dtype=None): - # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor +def log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor: r"""Applies a softmax followed by a logarithm. While mathematically equivalent to log(softmax(x)), doing these two @@ -1572,10 +1709,9 @@ def log_softmax(input, dim=None, _stacklevel=3, dtype=None): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: - dim = _get_softmax_dim('log_softmax', input.dim(), _stacklevel) + dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel) if dtype is None: ret = input.log_softmax(dim) else: @@ -1583,13 +1719,16 @@ def log_softmax(input, dim=None, _stacklevel=3, dtype=None): return ret -softshrink = _add_docstr(torch._C._nn.softshrink, r""" +softshrink = _add_docstr( + torch._C._nn.softshrink, + r""" softshrink(input, lambd=0.5) -> Tensor Applies the soft shrinkage function elementwise See :class:`~torch.nn.Softshrink` for more details. 
-""") +""", +) def tanh(input): @@ -1615,8 +1754,7 @@ def sigmoid(input): return input.sigmoid() -def hardsigmoid(input, inplace=False): - # type: (Tensor, bool) -> Tensor +def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: r"""hardsigmoid(input) -> Tensor Applies the element-wise function @@ -1641,8 +1779,7 @@ def hardsigmoid(input, inplace=False): return torch._C._nn.hardsigmoid(input) -def linear(input, weight, bias=None): - # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor +def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: r""" Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. @@ -1671,8 +1808,7 @@ def linear(input, weight, bias=None): return ret -def bilinear(input1, input2, weight, bias=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tensor +def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: r""" Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b` @@ -1691,8 +1827,8 @@ def bilinear(input1, input2, weight, bias=None): """ return torch.bilinear(input1, input2, weight, bias) -def silu(input, inplace=False): - # type: (Tensor, bool) -> Tensor + +def silu(input: Tensor, inplace: bool = False) -> Tensor: r"""Applies the silu function, element-wise. .. math:: @@ -1715,6 +1851,7 @@ def silu(input, inplace=False): return torch._C._nn.silu_(input) return torch._C._nn.silu(input) + def hardswish(input: Tensor, inplace: bool = False) -> Tensor: r"""Applies the hardswish function, element-wise, as described in the paper: @@ -1740,15 +1877,20 @@ def hardswish(input: Tensor, inplace: bool = False) -> Tensor: return torch._C._nn.hardswish(input) -def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type): - # type: (Tensor, Tensor, float, float) -> Tensor +def _no_grad_embedding_renorm_(weight: Tensor, input: Tensor, max_norm: float, norm_type: float) -> Tensor: with torch.no_grad(): torch.embedding_renorm_(weight, input, max_norm, norm_type) -def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., - scale_grad_by_freq=False, sparse=False): - # type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor +def embedding( + input: Tensor, + weight: Tensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> Tensor: r"""A simple lookup table that looks up embeddings in a fixed dictionary and size. This module is often used to retrieve word embeddings using indices. 
@@ -1809,9 +1951,9 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., if padding_idx is not None: if padding_idx > 0: - assert padding_idx < weight.size(0), 'Padding_idx must be within num_embeddings' + assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -weight.size(0), 'Padding_idx must be within num_embeddings' + assert padding_idx >= -weight.size(0), "Padding_idx must be within num_embeddings" padding_idx = weight.size(0) + padding_idx else: padding_idx = -1 @@ -1828,10 +1970,18 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) -def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, - scale_grad_by_freq=False, mode='mean', sparse=False, - per_sample_weights=None, include_last_offset=False): - # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool, Optional[Tensor], bool) -> Tensor +def embedding_bag( + input: Tensor, + weight: Tensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, +) -> Tensor: r"""Computes sums, means or maxes of `bags` of embeddings, without instantiating the intermediate embeddings. @@ -1911,23 +2061,35 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, tens_ops = (input, weight) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - embedding_bag, tens_ops, input, weight, offsets=offsets, max_norm=max_norm, - norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, mode=mode, - sparse=sparse, per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset) + embedding_bag, + tens_ops, + input, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + ) # Check for backward compatibility. # Used to be embedding_bag(weight, input, ...) # Now is embedding_bag(input, weight, ...) if weight.dtype == torch.long and input.is_floating_point(): - warnings.warn("Argument order of nn.functional.embedding_bag was changed. " - "Usage `embedding_bag(weight, input, ...)` is deprecated, " - "and should now be `embedding_bag(input, weight, ...)`.") + warnings.warn( + "Argument order of nn.functional.embedding_bag was changed. " + "Usage `embedding_bag(weight, input, ...)` is deprecated, " + "and should now be `embedding_bag(input, weight, ...)`." 
+ ) weight, input = input, weight if per_sample_weights is not None and input.size() != per_sample_weights.size(): - raise ValueError("embedding_bag: If per_sample_weights ({}) is not None, " - "then it must have the same shape as the input ({})" - .format(per_sample_weights.shape, input.shape)) + raise ValueError( + "embedding_bag: If per_sample_weights ({}) is not None, " + "then it must have the same shape as the input ({})".format(per_sample_weights.shape, input.shape) + ) if input.dim() == 2: if offsets is not None: @@ -1935,12 +2097,13 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, # TODO: Remove this once script supports type() calls if not torch.jit.is_scripting(): type_str = str(type(offsets)) - raise ValueError("if input is 2D, then offsets has to be None" - ", as input is treated is a mini-batch of" - " fixed length sequences. However, found " - "offsets of type {}".format(type_str)) - offsets = torch.arange(0, input.numel(), input.size(1), - dtype=input.dtype, device=input.device) + raise ValueError( + "if input is 2D, then offsets has to be None" + ", as input is treated is a mini-batch of" + " fixed length sequences. However, found " + "offsets of type {}".format(type_str) + ) + offsets = torch.arange(0, input.numel(), input.size(1), dtype=input.dtype, device=input.device) input = input.reshape(-1) if per_sample_weights is not None: @@ -1951,13 +2114,12 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, if offsets.dim() != 1: raise ValueError("offsets has to be a 1D Tensor") else: - raise ValueError("input has to be 1D or 2D Tensor," - " but got Tensor of dimension {}".format(input.dim())) - if mode == 'sum': + raise ValueError("input has to be 1D or 2D Tensor," " but got Tensor of dimension {}".format(input.dim())) + if mode == "sum": mode_enum = 0 - elif mode == 'mean': + elif mode == "mean": mode_enum = 1 - elif mode == 'max': + elif mode == "max": mode_enum = 2 if scale_grad_by_freq: @@ -1976,28 +2138,23 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, # remove once script supports set_grad_enabled _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) - if per_sample_weights is not None and mode != 'sum': - raise NotImplementedError("embedding_bag: per_sample_weights was not None. " - "per_sample_weights is only supported for mode='sum' " - "(got mode='{}'). Please open a feature request on GitHub." - .format(mode)) + if per_sample_weights is not None and mode != "sum": + raise NotImplementedError( + "embedding_bag: per_sample_weights was not None. " + "per_sample_weights is only supported for mode='sum' " + "(got mode='{}'). 
Please open a feature request on GitHub.".format(mode) + ) ret, _, _, _ = torch.embedding_bag( - weight, - input, - offsets, - scale_grad_by_freq, - mode_enum, - sparse, - per_sample_weights, - include_last_offset) + weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights, include_last_offset + ) return ret + embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes) -def _verify_batch_size(size): - # type: (List[int]) -> None +def _verify_batch_size(size: List[int]) -> None: # XXX: JIT script does not support the reduce from functools, and mul op is a # builtin, which cannot be used as a value to a func yet, so rewrite this size # check to a simple equivalent for loop @@ -2011,12 +2168,20 @@ def _verify_batch_size(size): for i in range(len(size) - 2): size_prods *= size[i + 2] if size_prods == 1: - raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size)) - - -def batch_norm(input, running_mean, running_var, weight=None, bias=None, - training=False, momentum=0.1, eps=1e-5): - # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa + raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size)) + + +def batch_norm( + input: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + training: bool = False, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + # noqa r"""Applies Batch Normalization for each channel across a batch of data. See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`, @@ -2025,20 +2190,36 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - batch_norm, (input,), input, running_mean, running_var, weight=weight, - bias=bias, training=training, momentum=momentum, eps=eps) + batch_norm, + (input,), + input, + running_mean, + running_var, + weight=weight, + bias=bias, + training=training, + momentum=momentum, + eps=eps, + ) if training: _verify_batch_size(input.size()) return torch.batch_norm( - input, weight, bias, running_mean, running_var, - training, momentum, eps, torch.backends.cudnn.enabled + input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled ) -def instance_norm(input, running_mean=None, running_var=None, weight=None, - bias=None, use_input_stats=True, momentum=0.1, eps=1e-5): - # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa +def instance_norm( + input: Tensor, + running_mean: Optional[Tensor] = None, + running_var: Optional[Tensor] = None, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + use_input_stats: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + # noqa r"""Applies Instance Normalization for each channel in each data sample in a batch. 
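The `batch_norm` hunk above only reformats the signature and dispatch; `_verify_batch_size` still rejects training-mode inputs with a single value per channel. A functional-API sketch with illustrative shapes, separate from the patch:

import torch
import torch.nn.functional as F

x = torch.randn(8, 3, 32, 32)
running_mean = torch.zeros(3)
running_var = torch.ones(3)
# training=True normalizes with batch statistics and updates the running buffers in place
out = F.batch_norm(x, running_mean, running_var, training=True, momentum=0.1, eps=1e-5)
assert out.shape == x.shape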
@@ -2048,18 +2229,30 @@ def instance_norm(input, running_mean=None, running_var=None, weight=None, if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - instance_norm, (input,), input, running_mean=running_mean, - running_var=running_var, weight=weight, bias=bias, - use_input_stats=use_input_stats, momentum=momentum, eps=eps) + instance_norm, + (input,), + input, + running_mean=running_mean, + running_var=running_var, + weight=weight, + bias=bias, + use_input_stats=use_input_stats, + momentum=momentum, + eps=eps, + ) _verify_batch_size(input.size()) return torch.instance_norm( - input, weight, bias, running_mean, running_var, - use_input_stats, momentum, eps, torch.backends.cudnn.enabled + input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, torch.backends.cudnn.enabled ) -def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): - # type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor +def layer_norm( + input: Tensor, + normalized_shape: List[int], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: r"""Applies Layer Normalization for last certain number of dimensions. See :class:`~torch.nn.LayerNorm` for details. @@ -2067,30 +2260,26 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps) - return torch.layer_norm(input, normalized_shape, weight, bias, eps, - torch.backends.cudnn.enabled) + layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps + ) + return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled) -def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): - # type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor +def group_norm( + input: Tensor, num_groups: int, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1e-5 +) -> Tensor: r"""Applies Group Normalization for last certain number of dimensions. See :class:`~torch.nn.GroupNorm` for details. """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps) - _verify_batch_size([ - input.size(0) * input.size(1) // num_groups, num_groups] - + list(input.size()[2:])) - return torch.group_norm(input, num_groups, weight, bias, eps, - torch.backends.cudnn.enabled) + return handle_torch_function(group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps) + _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:])) + return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) -def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): - # type: (Tensor, int, float, float, float) -> Tensor +def local_response_norm(input: Tensor, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0) -> Tensor: r"""Applies local response normalization over an input signal composed of several input planes, where channels occupy the second dimension. Applies normalization across channels. 
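`layer_norm`, `group_norm`, and `local_response_norm` are reformatted above without functional change. A short sketch of the three calls, with illustrative shapes (the channel count must be divisible by `num_groups`); not part of the patch:

import torch
import torch.nn.functional as F

x = torch.randn(4, 6, 8, 8)
ln = F.layer_norm(x, normalized_shape=[8, 8])   # normalize over the last two dims
gn = F.group_norm(x, num_groups=3)              # 6 channels split into 3 groups
lrn = F.local_response_norm(x, size=2)          # normalize across neighbouring channels
assert ln.shape == gn.shape == lrn.shape == x.shape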
@@ -2099,12 +2288,15 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k) + return handle_torch_function(local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k) dim = input.dim() if dim < 3: - raise ValueError('Expected 3D or higher dimensionality \ - input (got {} dimensions)'.format(dim)) + raise ValueError( + "Expected 3D or higher dimensionality \ + input (got {} dimensions)".format( + dim + ) + ) div = input.mul(input).unsqueeze(1) if dim == 3: div = pad(div, (0, 0, size // 2, (size - 1) // 2)) @@ -2121,9 +2313,16 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): # loss -def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, - reduction='mean', zero_infinity=False): - # type: (Tensor, Tensor, Tensor, Tensor, int, str, bool) -> Tensor + +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: int = 0, + reduction: str = "mean", + zero_infinity: bool = False, +) -> Tensor: r"""The Connectionist Temporal Classification loss. See :class:`~torch.nn.CTCLoss` for details. @@ -2167,14 +2366,23 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) >>> loss.backward() """ - return torch.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), - zero_infinity) + return torch.ctc_loss( + log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), zero_infinity + ) + + ctc_loss.__doc__ = ctc_loss.__doc__.format(**reproducibility_notes) -def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor +def nll_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""The negative log likelihood loss. See :class:`~torch.nn.NLLLoss` for details. @@ -2220,17 +2428,26 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - nll_loss, tens_ops, input, target, weight=weight, size_average=size_average, - ignore_index=ignore_index, reduce=reduce, reduction=reduction) + nll_loss, + tens_ops, + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) dim = input.dim() if dim < 2: - raise ValueError('Expected 2 or more dimensions (got {})'.format(dim)) + raise ValueError("Expected 2 or more dimensions (got {})".format(dim)) if input.size(0) != target.size(0): - raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).' 
- .format(input.size(0), target.size(0))) + raise ValueError( + "Expected input batch_size ({}) to match target batch_size ({}).".format(input.size(0), target.size(0)) + ) if dim == 2: ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index) elif dim == 4: @@ -2241,8 +2458,7 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, c = input.size(1) out_size = (n,) + input.size()[2:] if target.size()[1:] != input.size()[2:]: - raise ValueError('Expected target size {}, got {}'.format( - out_size, target.size())) + raise ValueError("Expected target size {}, got {}".format(out_size, target.size())) input = input.contiguous() target = target.contiguous() # support empty batches, see #15870 @@ -2255,19 +2471,24 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, else: target = target.view(n, 0, 0) reduction_enum = _Reduction.get_enum(reduction) - if reduction != 'none': - ret = torch._C._nn.nll_loss2d( - input, target, weight, reduction_enum, ignore_index) + if reduction != "none": + ret = torch._C._nn.nll_loss2d(input, target, weight, reduction_enum, ignore_index) else: - out = torch._C._nn.nll_loss2d( - input, target, weight, reduction_enum, ignore_index) + out = torch._C._nn.nll_loss2d(input, target, weight, reduction_enum, ignore_index) ret = out.view(out_size) return ret -def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, bool, bool, Optional[bool], float, Optional[bool], str) -> Tensor +def poisson_nll_loss( + input: Tensor, + target: Tensor, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""Poisson negative log likelihood loss. See :class:`~torch.nn.PoissonNLLLoss` for details. 
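Note: a hedged sketch of how the legacy size_average/reduce arguments threaded through the nll_loss signature above relate to the modern reduction string; input values are illustrative.

import torch
import torch.nn.functional as F

logits = torch.randn(3, 5)                 # (N, C)
target = torch.tensor([1, 0, 4])
log_probs = F.log_softmax(logits, dim=1)

# modern call: pick the reduction explicitly
loss = F.nll_loss(log_probs, target, reduction="mean")

# legacy call: size_average/reduce are translated by _Reduction.legacy_get_string;
# size_average=False, reduce=True corresponds to reduction="sum" (and emits a deprecation warning)
legacy = F.nll_loss(log_probs, target, size_average=False, reduce=True)
assert torch.allclose(legacy, F.nll_loss(log_probs, target, reduction="sum"))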
@@ -2304,11 +2525,20 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - poisson_nll_loss, tens_ops, input, target, log_input=log_input, full=full, - size_average=size_average, eps=eps, reduce=reduce, reduction=reduction) + poisson_nll_loss, + tens_ops, + input, + target, + log_input=log_input, + full=full, + size_average=size_average, + eps=eps, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) - if reduction != 'none' and reduction != 'mean' and reduction != 'sum': + if reduction != "none" and reduction != "mean" and reduction != "sum": ret = input raise ValueError(reduction + " is not valid") @@ -2316,8 +2546,14 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non return ret -def kl_div(input, target, size_average=None, reduce=None, reduction='mean', log_target=False): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str, bool) -> Tensor +def kl_div( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + log_target: bool = False, +) -> Tensor: r"""The `Kullback-Leibler divergence Loss `__ @@ -2360,33 +2596,48 @@ def kl_div(input, target, size_average=None, reduce=None, reduction='mean', log_ tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - kl_div, tens_ops, input, target, size_average=size_average, - reduce=reduce, reduction=reduction, log_target=log_target) + kl_div, + tens_ops, + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + log_target=log_target, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: - if reduction == 'mean': - warnings.warn("reduction: 'mean' divides the total loss by both the batch size and the support size." - "'batchmean' divides only by the batch size, and aligns with the KL div math definition." - "'mean' will be changed to behave the same as 'batchmean' in the next major release.") + if reduction == "mean": + warnings.warn( + "reduction: 'mean' divides the total loss by both the batch size and the support size." + "'batchmean' divides only by the batch size, and aligns with the KL div math definition." + "'mean' will be changed to behave the same as 'batchmean' in the next major release." 
+ ) # special case for batchmean - if reduction == 'batchmean': - reduction_enum = _Reduction.get_enum('sum') + if reduction == "batchmean": + reduction_enum = _Reduction.get_enum("sum") else: reduction_enum = _Reduction.get_enum(reduction) reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target) - if reduction == 'batchmean' and input.dim() != 0: + if reduction == "batchmean" and input.dim() != 0: reduced = reduced / input.size()[0] return reduced -def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor +def cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""This criterion combines `log_softmax` and `nll_loss` in a single function. @@ -2431,17 +2682,29 @@ def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-1 tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - cross_entropy, tens_ops, input, target, weight=weight, - size_average=size_average, ignore_index=ignore_index, reduce=reduce, - reduction=reduction) + cross_entropy, + tens_ops, + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction) -def binary_cross_entropy(input, target, weight=None, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor +def binary_cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""Function that measures the Binary Cross Entropy between the target and the output. @@ -2479,27 +2742,41 @@ def binary_cross_entropy(input, target, weight=None, size_average=None, tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - binary_cross_entropy, tens_ops, input, target, weight=weight, - size_average=size_average, reduce=reduce, reduction=reduction) + binary_cross_entropy, + tens_ops, + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) if target.size() != input.size(): - raise ValueError("Using a target size ({}) that is different to the input size ({}) is deprecated. " - "Please ensure they have the same size.".format(target.size(), input.size())) + raise ValueError( + "Using a target size ({}) that is different to the input size ({}) is deprecated. 
" + "Please ensure they have the same size.".format(target.size(), input.size()) + ) if weight is not None: new_size = _infer_size(target.size(), weight.size()) weight = weight.expand(new_size) - return torch._C._nn.binary_cross_entropy( - input, target, weight, reduction_enum) + return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) -def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None, - reduce=None, reduction='mean', pos_weight=None): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor +def binary_cross_entropy_with_logits( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + pos_weight: Optional[Tensor] = None, +) -> Tensor: r"""Function that measures Binary Cross Entropy between target and output logits. @@ -2539,9 +2816,16 @@ def binary_cross_entropy_with_logits(input, target, weight=None, size_average=No tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - binary_cross_entropy_with_logits, tens_ops, input, target, weight=weight, - size_average=size_average, reduce=reduce, reduction=reduction, - pos_weight=pos_weight) + binary_cross_entropy_with_logits, + tens_ops, + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + pos_weight=pos_weight, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2553,8 +2837,14 @@ def binary_cross_entropy_with_logits(input, target, weight=None, size_average=No return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum) -def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean', beta=1.0): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str, float) -> Tensor +def smooth_l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> Tensor: r"""Function that uses a squared term if the absolute element-wise error falls below beta and an L1 term otherwise. @@ -2564,13 +2854,22 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - smooth_l1_loss, tens_ops, input, target, size_average=size_average, - reduce=reduce, reduction=reduction, beta=beta) + smooth_l1_loss, + tens_ops, + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + beta=beta, + ) if not (target.size() == input.size()): - warnings.warn("Using a target size ({}) that is different to the input size ({}). " - "This will likely lead to incorrect results due to broadcasting. " - "Please ensure they have the same size.".format(target.size(), input.size()), - stacklevel=2) + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. 
" + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) @@ -2578,8 +2877,13 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea return torch._C._nn.smooth_l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction), beta) -def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor Function that takes the mean element-wise absolute value difference. @@ -2590,23 +2894,29 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'): tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - l1_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, - reduction=reduction) + l1_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) if not (target.size() == input.size()): - warnings.warn("Using a target size ({}) that is different to the input size ({}). " - "This will likely lead to incorrect results due to broadcasting. " - "Please ensure they have the same size.".format(target.size(), input.size()), - stacklevel=2) + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) - expanded_input, expanded_target = torch.broadcast_tensors(input, target) return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) -def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def mse_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor Measures the element-wise mean squared error. @@ -2617,13 +2927,15 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - mse_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, - reduction=reduction) + mse_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) if not (target.size() == input.size()): - warnings.warn("Using a target size ({}) that is different to the input size ({}). " - "This will likely lead to incorrect results due to broadcasting. " - "Please ensure they have the same size.".format(target.size(), input.size()), - stacklevel=2) + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. 
" + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) @@ -2631,9 +2943,15 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) -def margin_ranking_loss(input1, input2, target, margin=0, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor +def margin_ranking_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.MarginRankingLoss` for details. @@ -2642,21 +2960,38 @@ def margin_ranking_loss(input1, input2, target, margin=0, size_average=None, tens_ops = (input1, input2, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - margin_ranking_loss, tens_ops, input1, input2, target, margin=margin, - size_average=size_average, reduce=reduce, reduction=reduction) + margin_ranking_loss, + tens_ops, + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) if input1.dim() == 0 or input2.dim() == 0 or target.dim() == 0: - raise RuntimeError(("margin_ranking_loss does not support scalars, got sizes: " - "input1: {}, input2: {}, target: {} ".format(input1.size(), input2.size(), target.size()))) + raise RuntimeError( + ( + "margin_ranking_loss does not support scalars, got sizes: " + "input1: {}, input2: {}, target: {} ".format(input1.size(), input2.size(), target.size()) + ) + ) return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum) -def hinge_embedding_loss(input, target, margin=1.0, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor +def hinge_embedding_loss( + input: Tensor, + target: Tensor, + margin: float = 1.0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.HingeEmbeddingLoss` for details. 
@@ -2665,8 +3000,15 @@ def hinge_embedding_loss(input, target, margin=1.0, size_average=None, tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - hinge_embedding_loss, tens_ops, input, target, margin=margin, - size_average=size_average, reduce=reduce, reduction=reduction) + hinge_embedding_loss, + tens_ops, + input, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2674,8 +3016,13 @@ def hinge_embedding_loss(input, target, margin=1.0, size_average=None, return torch.hinge_embedding_loss(input, target, margin, reduction_enum) -def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def multilabel_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.MultiLabelMarginLoss` for details. @@ -2684,8 +3031,14 @@ def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduct tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - multilabel_margin_loss, tens_ops, input, target, size_average=size_average, - reduce=reduce, reduction=reduction) + multilabel_margin_loss, + tens_ops, + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2693,8 +3046,13 @@ def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduct return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) -def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def soft_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.SoftMarginLoss` for details. 
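Note: the mechanical pattern repeated across these loss hunks, shown once in isolation; example_loss is a made-up placeholder name, not a function from this file.

from typing import Optional
from torch import Tensor

# Before: the signature is typed through a mypy-style comment under the def line.
def example_loss(input, target, size_average=None, reduce=None, reduction='mean'):
    # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
    ...

# After: the same signature with inline annotations and one parameter per line,
# which is the shape every loss functional takes in the hunks above and below.
def example_loss(  # noqa: F811 -- intentional redefinition for the side-by-side sketch
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:
    ...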
@@ -2703,8 +3061,8 @@ def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='m tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - soft_margin_loss, tens_ops, input, target, size_average=size_average, - reduce=reduce, reduction=reduction) + soft_margin_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2712,9 +3070,14 @@ def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='m return torch._C._nn.soft_margin_loss(input, target, reduction_enum) -def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor +def multilabel_soft_margin_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""multilabel_soft_margin_loss(input, target, weight=None, size_average=None) -> Tensor See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details. @@ -2723,8 +3086,15 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - multilabel_soft_margin_loss, tens_ops, input, target, weight=weight, - size_average=size_average, reduce=reduce, reduction=reduction) + multilabel_soft_margin_loss, + tens_ops, + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) @@ -2735,11 +3105,11 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, loss = loss.sum(dim=1) / input.size(1) # only return N loss values - if reduction == 'none': + if reduction == "none": ret = loss - elif reduction == 'mean': + elif reduction == "mean": ret = loss.mean() - elif reduction == 'sum': + elif reduction == "sum": ret = loss.sum() else: ret = input @@ -2747,9 +3117,15 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, return ret -def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor +def cosine_embedding_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.CosineEmbeddingLoss` for details. 
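Note: a quick check of the reduction branches ("none" / "mean" / "sum") rewritten in the multilabel_soft_margin_loss hunk above; shapes and labels are illustrative.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)                      # (N, C) multi-label logits
targets = torch.randint(0, 2, (4, 3)).float()   # multi-hot labels

# reduction="none" returns one value per sample (the "only return N loss values" branch)
per_sample = F.multilabel_soft_margin_loss(logits, targets, reduction="none")
print(per_sample.shape)                         # torch.Size([4])

# "mean" is simply the mean of those per-sample values
mean_loss = F.multilabel_soft_margin_loss(logits, targets, reduction="mean")
assert torch.allclose(mean_loss, per_sample.mean())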
@@ -2758,8 +3134,16 @@ def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, tens_ops = (input1, input2, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - cosine_embedding_loss, tens_ops, input1, input2, target, margin=margin, - size_average=size_average, reduce=reduce, reduction=reduction) + cosine_embedding_loss, + tens_ops, + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2767,9 +3151,16 @@ def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum) -def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, int, float, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: int = 1, + margin: float = 1.0, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor @@ -2779,23 +3170,33 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N tens_ops = (input, target) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - multi_margin_loss, tens_ops, input, target, p=p, margin=margin, - weight=weight, size_average=size_average, reduce=reduce, - reduction=reduction) + multi_margin_loss, + tens_ops, + input, + target, + p=p, + margin=margin, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) if p != 1 and p != 2: - raise ValueError('only p == 1 and p == 2 supported') + raise ValueError("only p == 1 and p == 2 supported") if weight is not None: if weight.dim() != 1: - raise ValueError('weight must be one-dimensional') + raise ValueError("weight must be one-dimensional") return torch._C._nn.multi_margin_loss(input, target, p, margin, weight, reduction_enum) -pixel_shuffle = _add_docstr(torch.pixel_shuffle, r""" +pixel_shuffle = _add_docstr( + torch.pixel_shuffle, + r""" pixel_shuffle(input, upscale_factor) -> Tensor Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a @@ -2813,9 +3214,12 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N >>> output = torch.nn.functional.pixel_shuffle(input, 3) >>> print(output.size()) torch.Size([1, 1, 12, 12]) -""") +""", +) -pixel_unshuffle = _add_docstr(torch.pixel_unshuffle, r""" +pixel_unshuffle = _add_docstr( + torch.pixel_unshuffle, + r""" pixel_unshuffle(input, downscale_factor) -> Tensor Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a @@ -2834,9 +3238,12 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N >>> output = torch.nn.functional.pixel_unshuffle(input, 3) >>> print(output.size()) torch.Size([1, 9, 4, 4]) -""") +""", +) -channel_shuffle = 
_add_docstr(torch.channel_shuffle, r""" +channel_shuffle = _add_docstr( + torch.channel_shuffle, + r""" channel_shuffle(input, groups) -> Tensor Divide the channels in a tensor of shape :math:`(*, C , H, W)` @@ -2873,20 +3280,23 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N [[13, 14], [15, 16]], ]] -""") +""", +) + @_overload # noqa: F811 -def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None): # noqa: F811 +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # noqa: F811 # type: (Tensor, Optional[int], Optional[float], str, Optional[bool]) -> Tensor pass + @_overload # noqa: F811 -def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None): # noqa: F811 +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # noqa: F811 # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor pass -def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None): # noqa: F811 +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # noqa: F811 r"""Upsamples the input to either the given :attr:`size` or the given :attr:`scale_factor` @@ -2945,26 +3355,38 @@ def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners= """ warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.") return interpolate(input, size, scale_factor, mode, align_corners) + + upsample.__doc__ = upsample.__doc__.format(**reproducibility_notes) + @_overload # noqa: F811 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 # type: (Tensor, Optional[int], Optional[List[float]], str, Optional[bool], Optional[bool]) -> Tensor pass + @_overload # noqa: F811 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 # type: (Tensor, Optional[List[int]], Optional[List[float]], str, Optional[bool], Optional[bool]) -> Tensor pass + @_overload # noqa: F811 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 # type: (Tensor, Optional[int], Optional[float], str, Optional[bool], Optional[bool]) -> Tensor pass + @_overload # noqa: F811 -def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool], Optional[bool]) -> Tensor +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, +) -> Tensor: # noqa: F811 pass def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 @@ -3040,20 +3462,30 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - interpolate, (input,), input, size=size, scale_factor=scale_factor, - mode=mode, align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor) - - if mode in ('nearest', 'area'): + interpolate, + (input,), + input, + 
size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + ) + + if mode in ("nearest", "area"): if align_corners is not None: - raise ValueError("align_corners option can only be set with the " - "interpolating modes: linear | bilinear | bicubic | trilinear") + raise ValueError( + "align_corners option can only be set with the " + "interpolating modes: linear | bilinear | bicubic | trilinear" + ) else: if align_corners is None: - warnings.warn("Default upsampling behavior when mode={} is changed " - "to align_corners=False since 0.4.0. Please specify " - "align_corners=True if the old behavior is desired. " - "See the documentation of nn.Upsample for details.".format(mode)) + warnings.warn( + "Default upsampling behavior when mode={} is changed " + "to align_corners=False since 0.4.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of nn.Upsample for details.".format(mode) + ) align_corners = False dim = input.dim() - 2 # Number of spatial dimensions. @@ -3063,14 +3495,15 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne # After this block, exactly one of output_size and scale_factors will # be non-None, and it will be a list (or tuple). if size is not None and scale_factor is not None: - raise ValueError('only one of size or scale_factor should be defined') + raise ValueError("only one of size or scale_factor should be defined") elif size is not None: assert scale_factor is None scale_factors = None if isinstance(size, (list, tuple)): if len(size) != dim: - raise ValueError('size shape must match input shape. ' - 'Input is {}D, size is {}'.format(dim, len(size))) + raise ValueError( + "size shape must match input shape. " "Input is {}D, size is {}".format(dim, len(size)) + ) output_size = size else: output_size = [size for _ in range(dim)] @@ -3079,13 +3512,15 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne output_size = None if isinstance(scale_factor, (list, tuple)): if len(scale_factor) != dim: - raise ValueError('scale_factor shape must match input shape. ' - 'Input is {}D, scale_factor is {}'.format(dim, len(scale_factor))) + raise ValueError( + "scale_factor shape must match input shape. " + "Input is {}D, scale_factor is {}".format(dim, len(scale_factor)) + ) scale_factors = scale_factor else: scale_factors = [scale_factor for _ in range(dim)] else: - raise ValueError('either size or scale_factor should be defined') + raise ValueError("either size or scale_factor should be defined") if recompute_scale_factor is None: # only warn when the scales have floating values since @@ -3093,11 +3528,13 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne if scale_factors is not None: for scale in scale_factors: if math.floor(scale) != scale: - warnings.warn("The default behavior for interpolate/upsample with float scale_factor changed " - "in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, " - "instead of relying on the computed output size. " - "If you wish to restore the old behavior, please set recompute_scale_factor=True. " - "See the documentation of nn.Upsample for details. ") + warnings.warn( + "The default behavior for interpolate/upsample with float scale_factor changed " + "in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, " + "instead of relying on the computed output size. 
" + "If you wish to restore the old behavior, please set recompute_scale_factor=True. " + "See the documentation of nn.Upsample for details. " + ) break elif recompute_scale_factor and size is not None: raise ValueError("recompute_scale_factor is not meaningful with an explicit size.") @@ -3112,71 +3549,80 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne # The C++ code will recompute it based on the (integer) output size. if not torch.jit.is_scripting() and torch._C._get_tracing_state(): # make scale_factor a tensor in tracing so constant doesn't get baked in - output_size = [(torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], - dtype=torch.float32)).float())) for i in range(dim)] + output_size = [ + (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float())) + for i in range(dim) + ] else: assert scale_factors is not None output_size = [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) for i in range(dim)] scale_factors = None - if input.dim() == 3 and mode == 'nearest': + if input.dim() == 3 and mode == "nearest": return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) - if input.dim() == 4 and mode == 'nearest': + if input.dim() == 4 and mode == "nearest": return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) - if input.dim() == 5 and mode == 'nearest': + if input.dim() == 5 and mode == "nearest": return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors) - if input.dim() == 3 and mode == 'area': + if input.dim() == 3 and mode == "area": assert output_size is not None return adaptive_avg_pool1d(input, output_size) - if input.dim() == 4 and mode == 'area': + if input.dim() == 4 and mode == "area": assert output_size is not None return adaptive_avg_pool2d(input, output_size) - if input.dim() == 5 and mode == 'area': + if input.dim() == 5 and mode == "area": assert output_size is not None return adaptive_avg_pool3d(input, output_size) - if input.dim() == 3 and mode == 'linear': + if input.dim() == 3 and mode == "linear": assert align_corners is not None return torch._C._nn.upsample_linear1d(input, output_size, align_corners, scale_factors) - if input.dim() == 4 and mode == 'bilinear': + if input.dim() == 4 and mode == "bilinear": assert align_corners is not None return torch._C._nn.upsample_bilinear2d(input, output_size, align_corners, scale_factors) - if input.dim() == 5 and mode == 'trilinear': + if input.dim() == 5 and mode == "trilinear": assert align_corners is not None return torch._C._nn.upsample_trilinear3d(input, output_size, align_corners, scale_factors) - if input.dim() == 4 and mode == 'bicubic': + if input.dim() == 4 and mode == "bicubic": assert align_corners is not None return torch._C._nn.upsample_bicubic2d(input, output_size, align_corners, scale_factors) - if input.dim() == 3 and mode == 'bilinear': + if input.dim() == 3 and mode == "bilinear": raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") - if input.dim() == 3 and mode == 'trilinear': + if input.dim() == 3 and mode == "trilinear": raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input") - if input.dim() == 4 and mode == 'linear': + if input.dim() == 4 and mode == "linear": raise NotImplementedError("Got 4D input, but linear mode needs 3D input") - if input.dim() == 4 and mode == 'trilinear': + if input.dim() == 4 and mode == "trilinear": raise NotImplementedError("Got 4D input, but trilinear mode needs 5D 
input") - if input.dim() == 5 and mode == 'linear': + if input.dim() == 5 and mode == "linear": raise NotImplementedError("Got 5D input, but linear mode needs 3D input") - if input.dim() == 5 and mode == 'bilinear': + if input.dim() == 5 and mode == "bilinear": raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input") - raise NotImplementedError("Input Error: Only 3D, 4D and 5D input Tensors supported" - " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear" - " (got {})".format(input.dim(), mode)) + raise NotImplementedError( + "Input Error: Only 3D, 4D and 5D input Tensors supported" + " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear" + " (got {})".format(input.dim(), mode) + ) + + interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes) + @_overload # noqa: F811 def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 # type: (Tensor, Optional[int], Optional[float]) -> Tensor pass + @_overload # noqa: F811 def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 # type: (Tensor, Optional[List[int]], Optional[float]) -> Tensor pass + def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 r"""Upsamples the input, using nearest neighbours' pixel values. @@ -3198,29 +3644,40 @@ def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 """ # DeprecationWarning is ignored by default warnings.warn("nn.functional.upsample_nearest is deprecated. Use nn.functional.interpolate instead.") - return interpolate(input, size, scale_factor, mode='nearest') + return interpolate(input, size, scale_factor, mode="nearest") + + upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes) + @_overload # noqa: F811 -def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[int], Optional[float]) -> Tensor +def upsample_bilinear( + input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None +) -> Tensor: # noqa: F811 pass + @_overload # noqa: F811 -def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[List[int]], Optional[float]) -> Tensor +def upsample_bilinear( # noqa: F811 + input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[float] = None +) -> Tensor: # noqa: F811 pass + @_overload # noqa: F811 -def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[int], Optional[List[float]]) -> Tensor +def upsample_bilinear( # noqa: F811 + input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None +) -> Tensor: # noqa: F811 pass + @_overload # noqa: F811 -def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[List[int]], Optional[List[float]]) -> Tensor +def upsample_bilinear( # noqa: F811 + input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None +) -> Tensor: # noqa: F811 pass + def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 r"""Upsamples the input, using bilinear upsampling. @@ -3242,24 +3699,31 @@ def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 """ # DeprecationWarning is ignored by default warnings.warn("nn.functional.upsample_bilinear is deprecated. 
Use nn.functional.interpolate instead.") - return interpolate(input, size, scale_factor, mode='bilinear', align_corners=True) + return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) + + upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format(**reproducibility_notes) GRID_SAMPLE_INTERPOLATION_MODES = { - 'bilinear': 0, - 'nearest': 1, - 'bicubic': 2, + "bilinear": 0, + "nearest": 1, + "bicubic": 2, } GRID_SAMPLE_PADDING_MODES = { - 'zeros': 0, - 'border': 1, - 'reflection': 2, + "zeros": 0, + "border": 1, + "reflection": 2, } -def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None): - # type: (Tensor, Tensor, str, str, Optional[bool]) -> Tensor +def grid_sample( + input: Tensor, + grid: Tensor, + mode: str = "bilinear", + padding_mode: str = "zeros", + align_corners: Optional[bool] = None, +) -> Tensor: r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the ``output`` using :attr:`input` values and pixel locations from :attr:`grid`. @@ -3316,7 +3780,7 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case) mode (str): interpolation mode to calculate output values ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'`` - Note: ``mode='bicubic'`` supports only 4-D input. + Note: ``mode='bicubic'`` supports only 4-D input. When ``mode='bilinear'`` and the input is 5-D, the interpolation mode used internally will actually be trilinear. However, when the input is 4-D, the interpolation mode will legitimately be bilinear. @@ -3349,11 +3813,11 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner in order to bring it in line with the default for :func:`interpolate`. .. note:: - ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`. - The constant :math:`\alpha` might be different from packages to packages. - For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively. - This algorithm may "overshoot" the range of values it's interpolating. - For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. + ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`. + The constant :math:`\alpha` might be different from packages to packages. + For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively. + This algorithm may "overshoot" the range of values it's interpolating. + For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. Clamp the results with :func: `torch.clamp` to ensure they are within the valid range. .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation .. 
_`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51 @@ -3363,42 +3827,47 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner tens_ops = (input, grid) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - grid_sample, tens_ops, input, grid, mode=mode, padding_mode=padding_mode, - align_corners=align_corners) - if mode != 'bilinear' and mode != 'nearest' and mode != 'bicubic': - raise ValueError("nn.functional.grid_sample(): expected mode to be " - "'bilinear', 'nearest' or 'bicubic', but got: '{}'".format(mode)) - if padding_mode != 'zeros' and padding_mode != 'border' and padding_mode != 'reflection': - raise ValueError("nn.functional.grid_sample(): expected padding_mode " - "to be 'zeros', 'border', or 'reflection', " - "but got: '{}'".format(padding_mode)) - - if mode == 'bilinear': + grid_sample, tens_ops, input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners + ) + if mode != "bilinear" and mode != "nearest" and mode != "bicubic": + raise ValueError( + "nn.functional.grid_sample(): expected mode to be " + "'bilinear', 'nearest' or 'bicubic', but got: '{}'".format(mode) + ) + if padding_mode != "zeros" and padding_mode != "border" and padding_mode != "reflection": + raise ValueError( + "nn.functional.grid_sample(): expected padding_mode " + "to be 'zeros', 'border', or 'reflection', " + "but got: '{}'".format(padding_mode) + ) + + if mode == "bilinear": mode_enum = 0 - elif mode == 'nearest': + elif mode == "nearest": mode_enum = 1 else: # mode == 'bicubic' mode_enum = 2 - if padding_mode == 'zeros': + if padding_mode == "zeros": padding_mode_enum = 0 - elif padding_mode == 'border': + elif padding_mode == "border": padding_mode_enum = 1 else: # padding_mode == 'reflection' padding_mode_enum = 2 if align_corners is None: - warnings.warn("Default grid_sample and affine_grid behavior has changed " - "to align_corners=False since 1.3.0. Please specify " - "align_corners=True if the old behavior is desired. " - "See the documentation of grid_sample for details.") + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) align_corners = False return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners) -def affine_grid(theta, size, align_corners=None): - # type: (Tensor, List[int], Optional[bool]) -> Tensor +def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] = None) -> Tensor: r"""Generates a 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`. @@ -3448,49 +3917,55 @@ def affine_grid(theta, size, align_corners=None): """ if not torch.jit.is_scripting(): if type(theta) is not Tensor and has_torch_function((theta,)): - return handle_torch_function( - affine_grid, (theta,), theta, size, align_corners=align_corners) + return handle_torch_function(affine_grid, (theta,), theta, size, align_corners=align_corners) if align_corners is None: - warnings.warn("Default grid_sample and affine_grid behavior has changed " - "to align_corners=False since 1.3.0. Please specify " - "align_corners=True if the old behavior is desired. 
" - "See the documentation of grid_sample for details.") + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) align_corners = False # enforce floating point dtype on theta if not theta.is_floating_point(): - raise ValueError("Expected theta to have floating point type, but got {}" - .format(theta.dtype)) + raise ValueError("Expected theta to have floating point type, but got {}".format(theta.dtype)) # check that shapes and sizes match if len(size) == 4: if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3: - raise ValueError("Expected a batch of 2D affine matrices of shape Nx2x3 " - "for size {}. Got {}.".format(size, theta.shape)) + raise ValueError( + "Expected a batch of 2D affine matrices of shape Nx2x3 " + "for size {}. Got {}.".format(size, theta.shape) + ) spatial_size = size[-2:] # spatial dimension sizes elif len(size) == 5: if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4: - raise ValueError("Expected a batch of 3D affine matrices of shape Nx3x4 " - "for size {}. Got {}.".format(size, theta.shape)) + raise ValueError( + "Expected a batch of 3D affine matrices of shape Nx3x4 " + "for size {}. Got {}.".format(size, theta.shape) + ) spatial_size = size[-3:] # spatial dimension sizes else: - raise NotImplementedError("affine_grid only supports 4D and 5D sizes, " - "for 2D and 3D affine transforms, respectively. " - "Got size {}.".format(size)) + raise NotImplementedError( + "affine_grid only supports 4D and 5D sizes, " + "for 2D and 3D affine transforms, respectively. " + "Got size {}.".format(size) + ) # check for empty span if align_corners and min(spatial_size) == 1: - warnings.warn("Since version 1.3.0, affine_grid behavior has changed " - "for unit-size grids when align_corners=True. " - "This is not an intended use case of affine_grid. " - "See the documentation of affine_grid for details.") + warnings.warn( + "Since version 1.3.0, affine_grid behavior has changed " + "for unit-size grids when align_corners=True. " + "This is not an intended use case of affine_grid. " + "See the documentation of affine_grid for details." + ) elif min(size) <= 0: - raise ValueError("Expected non-zero, positive output size. Got {}" - .format(size)) + raise ValueError("Expected non-zero, positive output size. Got {}".format(size)) return torch.affine_grid_generator(theta, size, align_corners) -def _pad(input, pad, mode='constant', value=0): - # type: (Tensor, List[int], str, float) -> Tensor +def _pad(input: Tensor, pad: List[int], mode: str = "constant", value: float = 0) -> Tensor: r"""Pads tensor. 
Padding size: @@ -3551,49 +4026,49 @@ def _pad(input, pad, mode='constant', value=0): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - _pad, (input,), input, pad, mode=mode, value=value) - assert len(pad) % 2 == 0, 'Padding length must be divisible by 2' - assert len(pad) // 2 <= input.dim(), 'Padding length too large' - if mode == 'constant': + return handle_torch_function(_pad, (input,), input, pad, mode=mode, value=value) + assert len(pad) % 2 == 0, "Padding length must be divisible by 2" + assert len(pad) // 2 <= input.dim(), "Padding length too large" + if mode == "constant": return _VF.constant_pad_nd(input, pad, value) else: assert value == 0, 'Padding mode "{}"" doesn\'t take in value argument'.format(mode) if input.dim() == 3: - assert len(pad) == 2, '3D tensors expect 2 values for padding' - if mode == 'reflect': + assert len(pad) == 2, "3D tensors expect 2 values for padding" + if mode == "reflect": return torch._C._nn.reflection_pad1d(input, pad) - elif mode == 'replicate': + elif mode == "replicate": return torch._C._nn.replication_pad1d(input, pad) - elif mode == 'circular': + elif mode == "circular": return _pad_circular(input, pad) else: raise NotImplementedError elif input.dim() == 4: - assert len(pad) == 4, '4D tensors expect 4 values for padding' - if mode == 'reflect': + assert len(pad) == 4, "4D tensors expect 4 values for padding" + if mode == "reflect": return torch._C._nn.reflection_pad2d(input, pad) - elif mode == 'replicate': + elif mode == "replicate": return torch._C._nn.replication_pad2d(input, pad) - elif mode == 'circular': + elif mode == "circular": return _pad_circular(input, pad) else: raise NotImplementedError elif input.dim() == 5: - assert len(pad) == 6, '5D tensors expect 6 values for padding' - if mode == 'reflect': + assert len(pad) == 6, "5D tensors expect 6 values for padding" + if mode == "reflect": raise NotImplementedError - elif mode == 'replicate': + elif mode == "replicate": return torch._C._nn.replication_pad3d(input, pad) - elif mode == 'circular': + elif mode == "circular": return _pad_circular(input, pad) else: raise NotImplementedError else: raise NotImplementedError("Only 3D, 4D, 5D padding with non-constant padding are supported for now") + # We define this function as _pad because it takes an argument # named pad, which clobbers the recursive reference to the pad # function needed for __torch_function__ support @@ -3602,15 +4077,16 @@ def _pad(input, pad, mode='constant', value=0): # distance -def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): - # type: (Tensor, Tensor, float, float, bool) -> Tensor +def pairwise_distance(x1: Tensor, x2: Tensor, p: float = 2.0, eps: float = 1e-6, keepdim: bool = False) -> Tensor: r""" See :class:`torch.nn.PairwiseDistance` for details """ return torch.pairwise_distance(x1, x2, p, eps, keepdim) -pdist = _add_docstr(torch.pdist, r""" +pdist = _add_docstr( + torch.pdist, + r""" pdist(input, p=2) -> Tensor Computes the p-norm distance between every pair of row vectors in the input. @@ -3631,10 +4107,13 @@ def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): input: input tensor of shape :math:`N \times M`. p: p value for the p-norm distance to calculate between each vector pair :math:`\in [0, \infty]`. 
-""") +""", +) -cosine_similarity = _add_docstr(torch.cosine_similarity, r""" +cosine_similarity = _add_docstr( + torch.cosine_similarity, + r""" cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor Returns cosine similarity between x1 and x2, computed along dim. @@ -3659,10 +4138,13 @@ def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): >>> input2 = torch.randn(100, 128) >>> output = F.cosine_similarity(input1, input2) >>> print(output) -""") +""", +) -one_hot = _add_docstr(torch._C._nn.one_hot, r""" +one_hot = _add_docstr( + torch._C._nn.one_hot, + r""" one_hot(tensor, num_classes=-1) -> LongTensor Takes LongTensor with index values of shape ``(*)`` and returns a tensor @@ -3706,12 +4188,22 @@ def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): [1, 0, 0]], [[0, 1, 0], [0, 0, 1]]]) -""") - - -def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None, - reduce=None, reduction="mean"): - # type: (Tensor, Tensor, Tensor, float, float, float, bool, Optional[bool], Optional[bool], str) -> Tensor +""", +) + + +def triplet_margin_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r""" See :class:`~torch.nn.TripletMarginLoss` for details """ @@ -3719,32 +4211,58 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s tens_ops = (anchor, positive, negative) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - triplet_margin_loss, tens_ops, anchor, positive, negative, margin=margin, - p=p, eps=eps, swap=swap, size_average=size_average, reduce=reduce, - reduction=reduction) + triplet_margin_loss, + tens_ops, + anchor, + positive, + negative, + margin=margin, + p=p, + eps=eps, + swap=swap, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) - return torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, - swap, reduction_enum) - - -def triplet_margin_with_distance_loss(anchor, positive, negative, *, distance_function=None, - margin=1.0, swap=False, reduction="mean"): - # type: (Tensor, Tensor, Tensor, Optional[Callable[[Tensor, Tensor], Tensor]], float, bool, str) -> Tensor + return torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction_enum) + + +def triplet_margin_with_distance_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + *, + distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean" +) -> Tensor: r""" See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details. """ if torch.jit.is_scripting(): - raise NotImplementedError("F.triplet_margin_with_distance_loss does not support JIT scripting: " - "functions requiring Callables cannot be scripted.") + raise NotImplementedError( + "F.triplet_margin_with_distance_loss does not support JIT scripting: " + "functions requiring Callables cannot be scripted." 
+ ) tens_ops = (anchor, positive, negative) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - triplet_margin_with_distance_loss, tens_ops, anchor, positive, negative, - distance_function=distance_function, margin=margin, swap=swap, reduction=reduction) + triplet_margin_with_distance_loss, + tens_ops, + anchor, + positive, + negative, + distance_function=distance_function, + margin=margin, + swap=swap, + reduction=reduction, + ) distance_function = distance_function if distance_function is not None else pairwise_distance @@ -3766,8 +4284,7 @@ def triplet_margin_with_distance_loss(anchor, positive, negative, *, distance_fu return output -def normalize(input, p=2, dim=1, eps=1e-12, out=None): - # type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor +def normalize(input: Tensor, p: float = 2, dim: int = 1, eps: float = 1e-12, out: Optional[Tensor] = None) -> Tensor: r"""Performs :math:`L_p` normalization of inputs over specified dimension. For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each @@ -3788,8 +4305,7 @@ def normalize(input, p=2, dim=1, eps=1e-12, out=None): """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - normalize, (input,), input, p=p, dim=dim, eps=eps, out=out) + return handle_torch_function(normalize, (input,), input, p=p, dim=dim, eps=eps, out=out) if out is None: denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input) return input / denom @@ -3798,8 +4314,7 @@ def normalize(input, p=2, dim=1, eps=1e-12, out=None): return torch.div(input, denom, out=out) -def assert_int_or_pair(arg, arg_name, message): - # type: (List[int], str, str) -> None +def assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None: assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name) @@ -3824,17 +4339,16 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - unfold, (input,), input, kernel_size, dilation=dilation, - padding=padding, stride=stride) + unfold, (input,), input, kernel_size, dilation=dilation, padding=padding, stride=stride + ) if input.dim() == 4: - msg = '{} must be int or 2-tuple for 4D input' - assert_int_or_pair(kernel_size, 'kernel_size', msg) - assert_int_or_pair(dilation, 'dilation', msg) - assert_int_or_pair(padding, 'padding', msg) - assert_int_or_pair(stride, 'stride', msg) - - return torch._C._nn.im2col(input, _pair(kernel_size), - _pair(dilation), _pair(padding), _pair(stride)) + msg = "{} must be int or 2-tuple for 4D input" + assert_int_or_pair(kernel_size, "kernel_size", msg) + assert_int_or_pair(dilation, "dilation", msg) + assert_int_or_pair(padding, "padding", msg) + assert_int_or_pair(stride, "stride", msg) + + return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride)) else: raise NotImplementedError("Input Error: Only 4D input Tensors are supported (got {}D)".format(input.dim())) @@ -3853,24 +4367,24 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1): if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): return handle_torch_function( - fold, (input,), input, output_size, kernel_size, dilation=dilation, - padding=padding, stride=stride) + fold, (input,), input, output_size, kernel_size, 
dilation=dilation, padding=padding, stride=stride + ) if input.dim() == 3: - msg = '{} must be int or 2-tuple for 3D input' - assert_int_or_pair(output_size, 'output_size', msg) - assert_int_or_pair(kernel_size, 'kernel_size', msg) - assert_int_or_pair(dilation, 'dilation', msg) - assert_int_or_pair(padding, 'padding', msg) - assert_int_or_pair(stride, 'stride', msg) - - return torch._C._nn.col2im(input, _pair(output_size), _pair(kernel_size), - _pair(dilation), _pair(padding), _pair(stride)) + msg = "{} must be int or 2-tuple for 3D input" + assert_int_or_pair(output_size, "output_size", msg) + assert_int_or_pair(kernel_size, "kernel_size", msg) + assert_int_or_pair(dilation, "dilation", msg) + assert_int_or_pair(padding, "padding", msg) + assert_int_or_pair(stride, "stride", msg) + + return torch._C._nn.col2im( + input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride) + ) else: raise NotImplementedError("Input Error: Only 3D input Tensors are supported (got {}D)".format(input.dim())) -def _pad_circular(input, padding): - # type: (Tensor, List[int]) -> Tensor +def _pad_circular(input: Tensor, padding: List[int]) -> Tensor: """Circularly pads tensor. Tensor values at the beginning are used to pad the end, and values at the @@ -3924,21 +4438,19 @@ def _pad_circular(input, padding): for idx, size in enumerate(paddable_shape): # Only supports wrapping around once - assert padding[-(idx * 2 + 1)] <= size, \ - "Padding value causes wrapping around more than once." - assert padding[-(idx * 2 + 2)] <= size, \ - "Padding value causes wrapping around more than once." + assert padding[-(idx * 2 + 1)] <= size, "Padding value causes wrapping around more than once." + assert padding[-(idx * 2 + 2)] <= size, "Padding value causes wrapping around more than once." # Negative padding should not result in negative sizes - assert padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)] + size >= 0, \ - "Negative padding value is resulting in an empty dimension." + assert ( + padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)] + size >= 0 + ), "Negative padding value is resulting in an empty dimension." # Get shape of padded tensor out_shape = in_shape[:2] for idx, size in enumerate(paddable_shape): out_shape += (size + padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)],) - out = torch.empty(out_shape, dtype=input.dtype, layout=input.layout, - device=input.device) + out = torch.empty(out_shape, dtype=input.dtype, layout=input.layout, device=input.device) # Put original array in padded array if ndim == 1: @@ -3962,8 +4474,7 @@ def _pad_circular(input, padding): in_h0 = max(-padding[-4], 0) in_h1 = in_shape[3] - max(-padding[-3], 0) - out[..., out_d0:out_d1, out_h0:out_h1] = \ - input[..., in_d0:in_d1, in_h0:in_h1] + out[..., out_d0:out_d1, out_h0:out_h1] = input[..., in_d0:in_d1, in_h0:in_h1] elif ndim == 3: out_d0 = max(padding[-2], 0) out_d1 = out_shape[2] - max(padding[-1], 0) @@ -3983,8 +4494,7 @@ def _pad_circular(input, padding): in_w0 = max(-padding[-6], 0) in_w1 = in_shape[4] - max(-padding[-5], 0) - out[..., out_d0:out_d1, out_h0:out_h1, out_w0:out_w1] = \ - input[..., in_d0:in_d1, in_h0:in_h1, in_w0:in_w1] + out[..., out_d0:out_d1, out_h0:out_h1, out_w0:out_w1] = input[..., in_d0:in_d1, in_h0:in_h1, in_w0:in_w1] # The following steps first pad the beginning of the tensor (left side), # and then pad the end of the tensor (right side). 
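(Side note, not part of the patch: the wrap-around behaviour that `_pad_circular` implements above is easiest to check through the public `F.pad(..., mode="circular")` entry point. A minimal sketch; the tensor shape and values are made up purely for illustration.)

import torch
import torch.nn.functional as F

# A 1x1x4 input so the wrap-around is easy to read off.
x = torch.arange(4.0).reshape(1, 1, 4)      # tensor([[[0., 1., 2., 3.]]])

# One element of circular padding on each side of the last dimension:
# the end wraps to the front and the beginning wraps to the end.
y = F.pad(x, (1, 1), mode="circular")
print(y)                                    # tensor([[[3., 0., 1., 2., 3., 0.]]])

# Padding wider than the dimension itself would need to wrap more than once,
# which the asserts in _pad_circular above reject.
# F.pad(x, (5, 0), mode="circular")         # expected to raise
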
@@ -4014,15 +4524,13 @@ def _pad_circular(input, padding): i1 = out_shape[3] - max(padding[-3], 0) o0 = 0 o1 = padding[-4] - out[:, :, :, o0:o1] = \ - out[:, :, :, i0:i1] + out[:, :, :, o0:o1] = out[:, :, :, i0:i1] if padding[-3] > 0: i0 = max(padding[-4], 0) i1 = max(padding[-4], 0) + padding[-3] o0 = out_shape[3] - padding[-3] o1 = out_shape[3] - out[:, :, :, o0:o1] = \ - out[:, :, :, i0:i1] + out[:, :, :, o0:o1] = out[:, :, :, i0:i1] # Pad third dimension (width) if len(padding) > 4: @@ -4031,43 +4539,42 @@ def _pad_circular(input, padding): i1 = out_shape[4] - max(padding[-5], 0) o0 = 0 o1 = padding[-6] - out[:, :, :, :, o0:o1] = \ - out[:, :, :, :, i0:i1] + out[:, :, :, :, o0:o1] = out[:, :, :, :, i0:i1] if padding[-5] > 0: i0 = max(padding[-6], 0) i1 = max(padding[-6], 0) + padding[-5] o0 = out_shape[4] - padding[-5] o1 = out_shape[4] - out[:, :, :, :, o0:o1] = \ - out[:, :, :, :, i0:i1] + out[:, :, :, :, o0:o1] = out[:, :, :, :, i0:i1] return out -def multi_head_attention_forward(query: Tensor, - key: Tensor, - value: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - bias_k: Optional[Tensor], - bias_v: Optional[Tensor], - add_zero_attn: bool, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - use_separate_proj_weight: bool = False, - q_proj_weight: Optional[Tensor] = None, - k_proj_weight: Optional[Tensor] = None, - v_proj_weight: Optional[Tensor] = None, - static_k: Optional[Tensor] = None, - static_v: Optional[Tensor] = None - ) -> Tuple[Tensor, Optional[Tensor]]: +def multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, +) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query, key, value: map a query and a set of key-value pairs to an output. @@ -4125,18 +4632,35 @@ def multi_head_attention_forward(query: Tensor, L is the target sequence length, S is the source sequence length. 
""" if not torch.jit.is_scripting(): - tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, - out_proj_weight, out_proj_bias) + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( - multi_head_attention_forward, tens_ops, query, key, value, - embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, - bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, - out_proj_bias, training=training, key_padding_mask=key_padding_mask, - need_weights=need_weights, attn_mask=attn_mask, + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, use_separate_proj_weight=use_separate_proj_weight, - q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, - v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + ) tgt_len, bsz, embed_dim = query.size() assert embed_dim == embed_dim_to_check # allow MHA to have different sizes for the feature dimension @@ -4151,7 +4675,7 @@ def multi_head_attention_forward(query: Tensor, # self-attention q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) - elif (key is value or torch.equal(key, value)): + elif key is value or torch.equal(key, value): # encoder-decoder attention # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -4219,8 +4743,8 @@ def multi_head_attention_forward(query: Tensor, if in_proj_bias is not None: q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) - k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) - v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) + k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)]) + v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :]) else: q = linear(query, q_proj_weight_non_opt, in_proj_bias) k = linear(key, k_proj_weight_non_opt, in_proj_bias) @@ -4228,9 +4752,13 @@ def multi_head_attention_forward(query: Tensor, q = q * scaling if attn_mask is not None: - assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ - attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \ - 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype) if attn_mask.dtype == torch.uint8: warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. 
Use bool tensor instead.") attn_mask = attn_mask.to(torch.bool) @@ -4238,17 +4766,19 @@ def multi_head_attention_forward(query: Tensor, if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError('The size of the 2D attn_mask is not correct.') + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: - raise RuntimeError('The size of the 3D attn_mask is not correct.') + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + warnings.warn( + "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) key_padding_mask = key_padding_mask.to(torch.bool) if bias_k is not None and bias_v is not None: @@ -4302,21 +4832,19 @@ def multi_head_attention_forward(query: Tensor, if attn_mask is not None: if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float('-inf')) + attn_output_weights.masked_fill_(attn_mask, float("-inf")) else: attn_output_weights += attn_mask - if key_padding_mask is not None: attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) attn_output_weights = attn_output_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), - float('-inf'), + float("-inf"), ) attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) - attn_output_weights = softmax( - attn_output_weights, dim=-1) + attn_output_weights = softmax(attn_output_weights, dim=-1) attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) attn_output = torch.bmm(attn_output_weights, v) diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index 208dc7c2df40..11eb4c404dc6 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -1,7 +1,7 @@ from torch import Tensor from torch.types import _size from typing import Any, Optional, Tuple, Dict, List, Callable, Sequence, Union -from .common_types import _ratio_any_t, _size_1_t, _size_2_t, _size_3_t, _size_2_opt_t, _size_3_opt_t +from .common_types import _ratio_any_t, _size_any_t, _size_1_t, _size_2_t, _size_3_t, _size_2_opt_t, _size_3_opt_t # 'TypedDict' is a new accepted type that represents a dictionary with a fixed set of allowed keys. # It is standards-track but not in `typing` yet. We leave this hear to be uncommented once the feature @@ -335,12 +335,12 @@ def normalize(input: Tensor, p: float = ..., dim: int = ..., eps: float = ..., def assert_int_or_pair(arg: Any, arg_name: Any, message: Any) -> None: ... -def unfold(input: Tensor, kernel_size: _size, dilation: _size = ..., padding: _size = ..., - stride: _size = ...) -> Tensor: ... +def unfold(input: Tensor, kernel_size: _size_any_t, dilation: _size_any_t = ..., padding: _size_any_t = ..., + stride: _size_any_t = ...) -> Tensor: ... -def fold(input: Tensor, output_size: _size, kernel_size: _size, dilation: _size = ..., padding: _size = ..., - stride: _size = ...) -> Tensor: ... 
+def fold(input: Tensor, output_size: _size_any_t, kernel_size: _size_any_t, dilation: _size_any_t = ..., padding: _size_any_t = ..., + stride: _size_any_t = ...) -> Tensor: ... def multi_head_attention_forward(query: Tensor, diff --git a/torch/nn/init.py b/torch/nn/init.py index c11dba648c5a..3c4149ff8e81 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -106,8 +106,7 @@ def calculate_gain(nonlinearity, param=None): raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) -def uniform_(tensor, a=0., b=1.): - # type: (Tensor, float, float) -> Tensor +def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor: r"""Fills the input Tensor with values drawn from the uniform distribution :math:`\mathcal{U}(a, b)`. @@ -123,8 +122,7 @@ def uniform_(tensor, a=0., b=1.): return _no_grad_uniform_(tensor, a, b) -def normal_(tensor, mean=0., std=1.): - # type: (Tensor, float, float) -> Tensor +def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor: r"""Fills the input Tensor with values drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. @@ -139,8 +137,7 @@ def normal_(tensor, mean=0., std=1.): """ return _no_grad_normal_(tensor, mean, std) -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): - # type: (Tensor, float, float, float, float) -> Tensor +def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor: r"""Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` @@ -162,8 +159,7 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): return _no_grad_trunc_normal_(tensor, mean, std, a, b) -def constant_(tensor, val): - # type: (Tensor, float) -> Tensor +def constant_(tensor: Tensor, val: float) -> Tensor: r"""Fills the input Tensor with the value :math:`\text{val}`. Args: @@ -177,8 +173,7 @@ def constant_(tensor, val): return _no_grad_fill_(tensor, val) -def ones_(tensor): - # type: (Tensor) -> Tensor +def ones_(tensor: Tensor) -> Tensor: r"""Fills the input Tensor with the scalar value `1`. Args: @@ -191,8 +186,7 @@ def ones_(tensor): return _no_grad_fill_(tensor, 1.) -def zeros_(tensor): - # type: (Tensor) -> Tensor +def zeros_(tensor: Tensor) -> Tensor: r"""Fills the input Tensor with the scalar value `0`. Args: @@ -284,8 +278,7 @@ def _calculate_fan_in_and_fan_out(tensor): return fan_in, fan_out -def xavier_uniform_(tensor, gain=1.): - # type: (Tensor, float) -> Tensor +def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor: r"""Fills the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform @@ -312,8 +305,7 @@ def xavier_uniform_(tensor, gain=1.): return _no_grad_uniform_(tensor, -a, a) -def xavier_normal_(tensor, gain=1.): - # type: (Tensor, float) -> Tensor +def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor: r"""Fills the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. 
(2010), using a normal diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index b4edae77e0a5..e9424673dda1 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -9,13 +9,14 @@ class SyncBatchNorm(Function): def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size): input = input.contiguous() - count = torch.empty(1, - dtype=running_mean.dtype, - device=input.device).fill_(input.numel() // input.size(1)) - # calculate mean/invstd for input. mean, invstd = torch.batch_norm_stats(input, eps) + count = torch.full((1,), input.numel() // input.size(1), + dtype=mean.dtype, + device=mean.device) + + num_channels = input.shape[1] # C, C, 1 -> (2C + 1) combined = torch.cat([mean, invstd, count], dim=0) diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 48e58d637ea6..64417069e2b7 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -54,7 +54,7 @@ def __init__( def reset_running_stats(self) -> None: if self.track_running_stats: - # running_mean/running_var/num_batches... are registerd at runtime depending + # running_mean/running_var/num_batches... are registered at runtime depending # if self.track_running_stats is on self.running_mean.zero_() # type: ignore[operator] self.running_var.fill_(1) # type: ignore[operator] diff --git a/torch/nn/quantized/dynamic/modules/linear.py b/torch/nn/quantized/dynamic/modules/linear.py index f220aed075c1..527ee76fdc76 100644 --- a/torch/nn/quantized/dynamic/modules/linear.py +++ b/torch/nn/quantized/dynamic/modules/linear.py @@ -1,5 +1,4 @@ import torch -from ....modules.linear import Linear as NNLinear import torch.nn.quantized as nnq from torch.nn.quantized.modules.utils import _quantize_weight @@ -80,7 +79,10 @@ def from_float(cls, mod): mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ - assert type(mod) == NNLinear, 'nn.quantized.dynamic.Linear.from_float only works for nn.Linear' + float_modules = [torch.nn.Linear, torch.nn.modules.linear._LinearWithBias] + assert type(mod) in float_modules, \ + 'nn.quantized.dynamic.Linear.from_float only works for one of' + \ + str([float_mod.__name__ for float_mod in float_modules]) assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' if mod.qconfig is not None and mod.qconfig.weight is not None: weight_observer = mod.qconfig.weight() diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py index 00ceba7ab367..b3bc78ff6941 100644 --- a/torch/nn/quantized/modules/conv.py +++ b/torch/nn/quantized/modules/conv.py @@ -1,7 +1,7 @@ # coding=utf-8 r"""Quantized convolution modules.""" -from typing import Optional, List +from typing import Optional, List, TypeVar import torch import torch.nn as nn @@ -16,11 +16,17 @@ class _ConvNd(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, - padding, dilation, - transposed, output_padding, - groups, bias, + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'): + # All subclasses have this signature - See PR #49702s + raise NotImplementedError + + def _init(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, + transposed, output_padding, + groups, bias, + padding_mode='zeros'): super(_ConvNd, self).__init__() if padding_mode != 'zeros': raise NotImplementedError( @@ -54,6 +60,15 @@ def 
__init__(self, in_channels, out_channels, kernel_size, stride, self.scale = 1.0 self.zero_point = 0 + def set_weight_bias(self, qweight, bias_float): + raise NotImplementedError + + def bias(self): + raise NotImplementedError + + def _weight_bias(self): + raise NotImplementedError + def extra_repr(self): s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}, scale={scale}, zero_point={zero_point}') @@ -155,7 +170,8 @@ def get_qconv(cls, mod, activation_post_process, weight_post_process=None): assert weight_post_process.dtype == torch.qint8, \ 'Weight observer must have a dtype of qint8' qweight = _quantize_weight(mod.weight.float(), weight_post_process) - qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, + # the __init__ call used is the one from derived classes and not the one from _ConvNd + qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg] mod.stride, mod.padding, mod.dilation, mod.groups, mod.bias is not None, mod.padding_mode) qconv.set_weight_bias(qweight, mod.bias) @@ -233,7 +249,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding = _pair_from_first(padding) dilation = _pair_from_first(dilation) - super(Conv1d, self).__init__( + # Subclasses of _ConvNd needs to call _init rather than __init__. See + # discussion on PR #49702 + super(Conv1d, self)._init( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _single(0), groups, bias, padding_mode) @@ -319,7 +337,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, stride = _pair(stride) padding = _pair(padding) dilation = _pair(dilation) - super(Conv2d, self).__init__( + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super(Conv2d, self)._init( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias, padding_mode) @@ -403,7 +423,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, stride = _triple(stride) padding = _triple(padding) dilation = _triple(dilation) - super(Conv3d, self).__init__( + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super(Conv3d, self)._init( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _triple(0), groups, bias, padding_mode) @@ -450,15 +472,20 @@ def from_float(cls, mod): return cls.get_qconv(mod, activation_post_process) # === Transposed Convolutions === +MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd) class _ConvTransposeNd(_ConvNd): + + _FLOAT_MODULE = MOD + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode): if padding_mode != 'zeros': raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__)) - - super(_ConvTransposeNd, self).__init__( + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super(_ConvTransposeNd, self)._init( in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode) @@ -477,9 +504,10 @@ def from_float(cls, mod): mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ - assert type(mod) == cls._FLOAT_MODULE, \ - ' nnq.' 
+ cls.__name__ + '.from_float only works for ' + \ - cls._FLOAT_MODULE.__name__ + # derived classes override cls._FLOAT_MODULE attribute + msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \ + cls._FLOAT_MODULE.__name__ + assert type(mod) == cls._FLOAT_MODULE, msg assert hasattr(mod, 'qconfig'), \ 'Input float module must have qconfig defined.' weight_post_process = mod.qconfig.weight() @@ -488,7 +516,8 @@ def from_float(cls, mod): assert weight_post_process.dtype == torch.qint8, \ 'Weight observer must have a dtype of qint8' qweight = _quantize_weight(mod.weight.float(), weight_post_process) - qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, + # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd + qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg] mod.stride, mod.padding, mod.output_padding, mod.groups, mod.bias is not None, mod.dilation, mod.padding_mode) qconv.set_weight_bias(qweight, mod.bias) diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py index c5a26c777776..d7f86ccb6216 100644 --- a/torch/nn/quantized/modules/linear.py +++ b/torch/nn/quantized/modules/linear.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable import torch import torch.nn as nn @@ -124,7 +125,7 @@ class Linear(torch.nn.Module): torch.Size([128, 30]) """ _version = 3 - _FLOAT_MODULE = nn.Linear + _FLOAT_MODULE = (nn.Linear, nn.modules.linear._LinearWithBias) def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8): @@ -252,8 +253,14 @@ def from_float(cls, mod): weight_post_process = mod.weight_fake_quant activation_post_process = mod.activation_post_process else: - assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \ - cls._FLOAT_MODULE.__name__ + # This function does not participate in JIT, so it is OK to ignore + # the type mismatch in assignment. Also, mypy has an issue with + # iterables not being implemented, so we are ignoring those too. + if not isinstance(cls._FLOAT_MODULE, Iterable): + cls._FLOAT_MODULE = [cls._FLOAT_MODULE] # type: ignore + supported_modules = ', '.join([float_mod.__name__ for float_mod in cls._FLOAT_MODULE]) # type: ignore + error_msg = 'nnq.{}.from_float only works for {}'.format(cls.__name__, supported_modules) + assert type(mod) in cls._FLOAT_MODULE, error_msg.format() # type: ignore assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' activation_post_process = mod.activation_post_process if type(mod) == nni.LinearReLU: diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index 851a551da0d8..c21940689a77 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -1,13 +1,14 @@ r""" Pruning methods """ -from abc import abstractmethod import numbers -import torch -from abc import ABC +from abc import ABC, abstractmethod from collections.abc import Iterable from typing import Tuple +import torch + + class BasePruningMethod(ABC): r"""Abstract base class for creation of new pruning techniques. @@ -40,7 +41,8 @@ def compute_mask(self, t, default_mask): method recipe. Args: - t (torch.Tensor): tensor representing the parameter to prune + t (torch.Tensor): tensor representing the importance scores of the + parameter to prune. default_mask (torch.Tensor): Base mask from previous pruning iterations, that need to be respected after the new mask is applied. Same dims as ``t``. 
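(For context on the `compute_mask` docstring change above: after this patch the tensor handed to `compute_mask` is interpreted as importance scores, falling back to the parameter itself when no scores are given. A hedged sketch of the standard `BasePruningMethod` subclassing pattern; `ThresholdPruning` and the 0.25 threshold are invented for illustration.)

import torch
from torch.nn.utils import prune

class ThresholdPruning(prune.BasePruningMethod):
    """Zeroes out entries whose importance-score magnitude is below a threshold."""
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # `t` is the importance-score tensor (the parameter itself when no
        # explicit scores are supplied); keep entries at or above the threshold.
        mask = default_mask.clone()
        mask[torch.abs(t) < self.threshold] = 0
        return mask

m = torch.nn.Linear(4, 4)
ThresholdPruning.apply(m, "weight", threshold=0.25)
print(int(m.weight_mask.sum()), "of", m.weight_mask.numel(), "weights kept")
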
@@ -64,9 +66,7 @@ def apply_mask(self, module): """ # to carry out the multiplication, the mask needs to have been computed, # so the pruning method must know what tensor it's operating on - assert ( - self._tensor_name is not None - ), "Module {} has to be pruned".format( + assert self._tensor_name is not None, "Module {} has to be pruned".format( module ) # this gets set in apply() mask = getattr(module, self._tensor_name + "_mask") @@ -75,7 +75,7 @@ def apply_mask(self, module): return pruned_tensor @classmethod - def apply(cls, module, name, *args, **kwargs): + def apply(cls, module, name, *args, importance_scores=None, **kwargs): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. @@ -86,6 +86,11 @@ def apply(cls, module, name, *args, **kwargs): will act. args: arguments passed on to a subclass of :class:`BasePruningMethod` + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the + corresponding elements in the parameter being pruned. + If unspecified or None, the parameter will be used in its place. kwargs: keyword arguments passed on to a subclass of a :class:`BasePruningMethod` """ @@ -101,10 +106,7 @@ def _get_composite_method(cls, module, name, *args, **kwargs): for k, hook in module._forward_pre_hooks.items(): # if it exists, take existing thing, remove hook, then # go through normal thing - if ( - isinstance(hook, BasePruningMethod) - and hook._tensor_name == name - ): + if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: old_method = hook hooks_to_remove.append(k) found += 1 @@ -150,8 +152,18 @@ def _get_composite_method(cls, module, name, *args, **kwargs): # Pruning is to be applied to the module's tensor named `name`, # starting from the state it is found in prior to this iteration of - # pruning + # pruning. The pruning mask is calculated based on importances scores. + orig = getattr(module, name) + if importance_scores is not None: + assert ( + importance_scores.shape == orig.shape + ), "importance_scores should have the same shape as parameter \ + {} of {}".format( + name, module + ) + else: + importance_scores = orig # If this is the first time pruning is applied, take care of moving # the original tensor to a new parameter called name + '_orig' and @@ -166,13 +178,17 @@ def _get_composite_method(cls, module, name, *args, **kwargs): # has been done before in a previous pruning iteration, so we're good # to go else: - default_mask = getattr(module, name + "_mask").detach().clone(memory_format=torch.contiguous_format) + default_mask = ( + getattr(module, name + "_mask") + .detach() + .clone(memory_format=torch.contiguous_format) + ) # Use try/except because if anything goes wrong with the mask # computation etc., you'd want to roll back. try: # get the final mask, computed according to the specific method - mask = method.compute_mask(orig, default_mask=default_mask) + mask = method.compute_mask(importance_scores, default_mask=default_mask) # reparametrize by saving mask to `module[name + '_mask']`... module.register_buffer(name + "_mask", mask) # ... 
and the new pruned tensor to `module[name]` @@ -190,13 +206,18 @@ def _get_composite_method(cls, module, name, *args, **kwargs): return method - def prune(self, t, default_mask=None): + def prune(self, t, default_mask=None, importance_scores=None): r"""Computes and returns a pruned version of input tensor ``t`` according to the pruning rule specified in :meth:`compute_mask`. Args: t (torch.Tensor): tensor to prune (of same dimensions as ``default_mask``). + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as ``t``) used to compute mask for pruning ``t``. + The values in this tensor indicate the importance of the + corresponding elements in the ``t`` that is being pruned. + If unspecified or None, the tensor ``t`` will be used in its place. default_mask (torch.Tensor, optional): mask from previous pruning iteration, if any. To be considered when determining what portion of the tensor that pruning should act on. If None, @@ -205,9 +226,14 @@ def prune(self, t, default_mask=None): Returns: pruned version of tensor ``t``. """ - if default_mask is None: - default_mask = torch.ones_like(t) - return t * self.compute_mask(t, default_mask=default_mask) + if importance_scores is not None: + assert ( + importance_scores.shape == t.shape + ), "importance_scores should have the same shape as tensor t" + else: + importance_scores = t + default_mask = default_mask if default_mask is not None else torch.ones_like(t) + return t * self.compute_mask(importance_scores, default_mask=default_mask) def remove(self, module): r"""Removes the pruning reparameterization from a module. The pruned @@ -249,7 +275,7 @@ class PruningContainer(BasePruningMethod): """ def __init__(self, *args): - self._pruning_methods: Tuple['BasePruningMethod', ...] = tuple() + self._pruning_methods: Tuple["BasePruningMethod", ...] = tuple() if not isinstance(args, Iterable): # only 1 item self._tensor_name = args._tensor_name self.add_pruning_method(args) @@ -319,6 +345,7 @@ def compute_mask(self, t, default_mask): pruning ``method`` (of same dimensions as ``default_mask`` and ``t``). """ + def _combine_masks(method, t, mask): r""" Args: @@ -360,13 +387,12 @@ def _combine_masks(method, t, mask): # if dim is still negative after subtracting it from n_dims if dim < 0: raise IndexError( - 'Index is out of bounds for tensor with dimensions {}' - .format(n_dims) + "Index is out of bounds for tensor with dimensions {}".format( + n_dims + ) ) # find channels along dim = dim that aren't already tots 0ed out - keep_channel = ( - mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 - ) + keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 # create slice to identify what to prune slc = [slice(None)] * n_dims slc[dim] = keep_channel @@ -470,9 +496,7 @@ def apply(cls, module, name, amount): fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. 
""" - return super(RandomUnstructured, cls).apply( - module, name, amount=amount - ) + return super(RandomUnstructured, cls).apply(module, name, amount=amount) class L1Unstructured(BasePruningMethod): @@ -509,16 +533,14 @@ def compute_mask(self, t, default_mask): if nparams_toprune != 0: # k=0 not supported by torch.kthvalue # largest=True --> top k; largest=False --> bottom k # Prune the smallest k - topk = torch.topk( - torch.abs(t).view(-1), k=nparams_toprune, largest=False - ) + topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False) # topk will have .indices and .values mask.view(-1)[topk.indices] = 0 return mask @classmethod - def apply(cls, module, name, amount): + def apply(cls, module, name, amount, importance_scores=None): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. @@ -531,8 +553,15 @@ def apply(cls, module, name, amount): If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. """ - return super(L1Unstructured, cls).apply(module, name, amount=amount) + return super(L1Unstructured, cls).apply( + module, name, amount=amount, importance_scores=importance_scores + ) class RandomStructured(BasePruningMethod): @@ -634,9 +663,7 @@ def apply(cls, module, name, amount, dim=-1): dim (int, optional): index of the dim along which we define channels to prune. Default: -1. """ - return super(RandomStructured, cls).apply( - module, name, amount=amount, dim=dim - ) + return super(RandomStructured, cls).apply(module, name, amount=amount, dim=dim) class LnStructured(BasePruningMethod): @@ -705,11 +732,7 @@ def compute_mask(self, t, default_mask): norm = _compute_norm(t, self.n, self.dim) # largest=True --> top k; largest=False --> bottom k # Keep the largest k channels along dim=self.dim - topk = torch.topk( - norm, - k=nparams_tokeep, - largest=True, - ) + topk = torch.topk(norm, k=nparams_tokeep, largest=True) # topk will have .indices and .values # Compute binary mask by initializing it to all 0s and then filling in @@ -737,7 +760,7 @@ def make_mask(t, dim, indices): return mask @classmethod - def apply(cls, module, name, amount, n, dim): + def apply(cls, module, name, amount, n, dim, importance_scores=None): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. @@ -754,9 +777,19 @@ def apply(cls, module, name, amount, n, dim): entries for argument ``p`` in :func:`torch.norm`. dim (int): index of the dim along which we define channels to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. 
""" return super(LnStructured, cls).apply( - module, name, amount=amount, n=n, dim=dim + module, + name, + amount=amount, + n=n, + dim=dim, + importance_scores=importance_scores, ) @@ -783,9 +816,7 @@ def apply(cls, module, name, mask): name (str): parameter name within ``module`` on which pruning will act. """ - return super(CustomFromMask, cls).apply( - module, name, mask - ) + return super(CustomFromMask, cls).apply(module, name, mask=mask) def identity(module, name): @@ -852,7 +883,7 @@ def random_unstructured(module, name, amount): return module -def l1_unstructured(module, name, amount): +def l1_unstructured(module, name, amount, importance_scores=None): r"""Prunes tensor corresponding to parameter called ``name`` in ``module`` by removing the specified `amount` of (currently unpruned) units with the lowest L1-norm. @@ -872,6 +903,11 @@ def l1_unstructured(module, name, amount): If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. Returns: module (nn.Module): modified (i.e. pruned) version of the input module @@ -881,7 +917,9 @@ def l1_unstructured(module, name, amount): >>> m.state_dict().keys() odict_keys(['bias', 'weight_orig', 'weight_mask']) """ - L1Unstructured.apply(module, name, amount) + L1Unstructured.apply( + module, name, amount=amount, importance_scores=importance_scores + ) return module @@ -922,7 +960,7 @@ def random_structured(module, name, amount, dim): return module -def ln_structured(module, name, amount, n, dim): +def ln_structured(module, name, amount, n, dim, importance_scores=None): r"""Prunes tensor corresponding to parameter called ``name`` in ``module`` by removing the specified ``amount`` of (currently unpruned) channels along the specified ``dim`` with the lowest L``n``-norm. @@ -945,6 +983,11 @@ def ln_structured(module, name, amount, n, dim): n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid entries for argument ``p`` in :func:`torch.norm`. dim (int): index of the dim along which we define channels to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. Returns: module (nn.Module): modified (i.e. pruned) version of the input module @@ -954,11 +997,13 @@ def ln_structured(module, name, amount, n, dim): nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf') ) """ - LnStructured.apply(module, name, amount, n, dim) + LnStructured.apply( + module, name, amount, n, dim, importance_scores=importance_scores + ) return module -def global_unstructured(parameters, pruning_method, **kwargs): +def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs): r""" Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``. 
@@ -977,6 +1022,12 @@ def global_unstructured(parameters, pruning_method, **kwargs): pruning_method (function): a valid pruning function from this module, or a custom one implemented by the user that satisfies the implementation guidelines and has ``PRUNING_TYPE='unstructured'``. + importance_scores (dict): a dictionary mapping (module, name) tuples to + the corresponding parameter's importance scores tensor. The tensor + should be the same shape as the parameter, and is used for computing + mask for pruning. + If unspecified or None, the parameter will be used in place of its + importance scores. kwargs: other keyword arguments such as: amount (int or float): quantity of parameters to prune across the specified parameters. @@ -1011,17 +1062,25 @@ def global_unstructured(parameters, pruning_method, **kwargs): """ # ensure parameters is a list or generator of tuples - assert isinstance(parameters, Iterable) + if not isinstance(parameters, Iterable): + raise TypeError("global_unstructured(): parameters is not an Iterable") - # flatten parameter values to consider them all at once in global pruning - t = torch.nn.utils.parameters_to_vector([getattr(*p) for p in parameters]) + importance_scores = importance_scores if importance_scores is not None else {} + if not isinstance(importance_scores, dict): + raise TypeError("global_unstructured(): importance_scores must be of type dict") + + # flatten importance scores to consider them all at once in global pruning + relevant_importance_scores = torch.nn.utils.parameters_to_vector( + [ + importance_scores.get((module, name), getattr(module, name)) + for (module, name) in parameters + ] + ) # similarly, flatten the masks (if they exist), or use a flattened vector # of 1s of the same dimensions as t default_mask = torch.nn.utils.parameters_to_vector( [ - getattr( - module, name + "_mask", torch.ones_like(getattr(module, name)) - ) + getattr(module, name + "_mask", torch.ones_like(getattr(module, name))) for (module, name) in parameters ] ) @@ -1044,7 +1103,7 @@ def global_unstructured(parameters, pruning_method, **kwargs): # use the `compute_mask` method from `PruningContainer` to combine the # mask computed by the new method with the pre-existing mask - final_mask = container.compute_mask(t, default_mask) + final_mask = container.compute_mask(relevant_importance_scores, default_mask) # Pointer for slicing the mask to match the shape of each parameter pointer = 0 @@ -1057,7 +1116,7 @@ def global_unstructured(parameters, pruning_method, **kwargs): param_mask = final_mask[pointer : pointer + num_param].view_as(param) # Assign the correct pre-computed mask to each parameter and add it # to the forward_pre_hooks like any other pruning method - custom_from_mask(module, name, param_mask) + custom_from_mask(module, name, mask=param_mask) # Increment the pointer to continue slicing the final_mask pointer += num_param @@ -1172,8 +1231,7 @@ def _validate_pruning_amount_init(amount): """ if not isinstance(amount, numbers.Real): raise TypeError( - "Invalid type for amount: {}. Must be int or float." - "".format(amount) + "Invalid type for amount: {}. Must be int or float." 
"".format(amount) ) if (isinstance(amount, numbers.Integral) and amount < 0) or ( @@ -1260,9 +1318,7 @@ def _validate_pruning_dim(t, dim): dim (int): index of the dim along which we define channels to prune """ if dim >= t.dim(): - raise IndexError( - "Invalid index {} for tensor of size {}".format(dim, t.shape) - ) + raise IndexError("Invalid index {} for tensor of size {}".format(dim, t.shape)) def _compute_norm(t, n, dim): diff --git a/torch/overrides.py b/torch/overrides.py index f8d9f2e152f6..a6ec0de5ffb5 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -803,6 +803,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.nansum: lambda input, dim=None: -1, torch.svd: lambda input, some=True, compute_uv=True, out=None: -1, torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1, + torch.linalg.svd: lambda input, full_matrices=True, compute_uv=True, out=None: -1, torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1, torch.swapaxes: lambda input, dim0, dim1: -1, torch.swapdims: lambda input, axis0, axis1: -1, diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 6c8fc9defc00..d9800552d8a7 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -39,6 +39,7 @@ nn.InstanceNorm3d: nnq.InstanceNorm3d, nn.LayerNorm: nnq.LayerNorm, nn.LeakyReLU: nnq.LeakyReLU, + nn.modules.linear._LinearWithBias: nnq.Linear, nn.Linear: nnq.Linear, nn.ReLU6: nnq.ReLU6, # Wrapper Modules: @@ -65,6 +66,7 @@ DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = { nn.Conv2d: nnqat.Conv2d, nn.Linear: nnqat.Linear, + nn.modules.linear._LinearWithBias: nnqat.Linear, # Intrinsic modules: nni.ConvBn1d: nniqat.ConvBn1d, nni.ConvBn2d: nniqat.ConvBn2d, @@ -78,6 +80,7 @@ DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = { nn.GRUCell: nnqd.GRUCell, nn.Linear: nnqd.Linear, + nn.modules.linear._LinearWithBias: nnqd.Linear, nn.LSTM: nnqd.LSTM, nn.GRU: nnqd.GRU, nn.LSTMCell: nnqd.LSTMCell, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 119750396f1e..b01deacc8ad1 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -533,7 +533,7 @@ def sample_inputs_linalg_solve(op_info, device, dtype, requires_grad=False): return out -def sample_inputs_svd(op_info, device, dtype, requires_grad=False): +def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_svd=False): """ This function generates input for torch.svd with distinct singular values so that autograd is always stable. Matrices of different size: @@ -546,6 +546,16 @@ def sample_inputs_svd(op_info, device, dtype, requires_grad=False): """ from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + # svd and linalg.svd returns V and V.T, respectively. So we need to slice + # along different dimensions when needed (this is used by + # test_cases2:wide_all and wide_all_batched below) + if is_linalg_svd: + def slice_V(v): + return v[..., :(S - 2), :] + else: + def slice_V(v): + return v[..., :, :(S - 2)] + test_cases1 = ( # some=True (default) # loss functions for complex-valued svd have to be "gauge invariant", # i.e. loss functions shouldn't change when sigh of the singular vectors change. 
@@ -575,11 +585,11 @@ def sample_inputs_svd(op_info, device, dtype, requires_grad=False): ) test_cases2 = ( # some=False (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)], - lambda usv: (abs(usv[0]), usv[1], abs(usv[2][:, :(S - 2)]))), # 'wide_all' + lambda usv: (abs(usv[0]), usv[1], abs(slice_V(usv[2])))), # 'wide_all' (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)], lambda usv: (abs(usv[0][:, :(S - 2)]), usv[1], abs(usv[2]))), # 'tall_all' (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :(S - 2), :], - lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))), # 'wide_all_batched' + lambda usv: (abs(usv[0]), usv[1], abs(slice_V(usv[2])))), # 'wide_all_batched' (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :, :(S - 2)], lambda usv: (abs(usv[0][..., :, :(S - 2)]), usv[1], abs(usv[2]))), # 'tall_all_batched' ) @@ -587,15 +597,27 @@ def sample_inputs_svd(op_info, device, dtype, requires_grad=False): out = [] for a, out_fn in test_cases1: a.requires_grad = requires_grad - out.append(SampleInput(a, output_process_fn_grad=out_fn)) + if is_linalg_svd: + kwargs = {'full_matrices': False} + else: + kwargs = {'some': True} + out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn)) for a, out_fn in test_cases2: a.requires_grad = requires_grad - kwargs = {'some': False} + if is_linalg_svd: + kwargs = {'full_matrices': True} + else: + kwargs = {'some': False} out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn)) return out +def sample_inputs_svd(op_info, device, dtype, requires_grad=False): + return _sample_inputs_svd(op_info, device, dtype, requires_grad, is_linalg_svd=False) + +def sample_inputs_linalg_svd(op_info, device, dtype, requires_grad=False): + return _sample_inputs_svd(op_info, device, dtype, requires_grad, is_linalg_svd=True) def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False): """ @@ -806,6 +828,20 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): SkipInfo('TestCommon', 'test_variant_consistency_jit', device_type='cuda', dtypes=[torch.float16]), )), + UnaryUfuncInfo('exp', + ref=np_unary_ufunc_integer_promotion_wrapper(np.exp), + dtypes=all_types_and_complex_and(torch.bool, torch.half), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/50093#pullrequestreview-561791547 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', dtypes=[torch.bfloat16]), + # Reference: https://github.com/pytorch/pytorch/issues/48010 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + ), + assert_autodiffed=True, + promotes_integers_to_float=True), SpectralFuncInfo('fft.fft', aten_name='fft_fft', ref=np.fft.fft, @@ -1175,6 +1211,20 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): # cuda gradchecks are very slow # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), + OpInfo('linalg.svd', + op=torch.linalg.svd, + aten_name='linalg_svd', + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_inputs_linalg_svd, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], 
+ skips=( + # gradgrad checks are slow + SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)), + # cuda gradchecks are very slow + # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 + SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), OpInfo('pinverse', op=torch.pinverse, dtypes=floating_and_complex_types(), @@ -1602,8 +1652,6 @@ def method_tests(): ('expand', (), (dont_convert(()),), 'scalar_to_scalar'), ('expand', (), (1, 3, 2), 'scalar_to_dims', (False,)), ('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (False,)), - ('exp', (S, S, S), NO_ARGS, '', (True,)), - ('exp', (), NO_ARGS, 'scalar', (True,)), ('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, ''), ('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), (0.2,), 'eps'), ('logit', uniform_scalar().clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, 'scalar'), diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py index 7a53d61feae5..a8ca66057b6b 100644 --- a/torch/utils/data/_utils/worker.py +++ b/torch/utils/data/_utils/worker.py @@ -177,7 +177,7 @@ def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event, continue if isinstance(r, _ResumeIteration): # Acknowledge the main process - data_queue.put(r) + data_queue.put((r, None)) iteration_end = False # Recreate the fetcher for worker-reuse policy fetcher = _DatasetKind.create_fetcher( diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index d4ef1a99a2df..a5eeeec671e3 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -962,8 +962,9 @@ def _reset(self, loader, first_iter=False): self._index_queues[idx].put(_utils.worker._ResumeIteration()) resume_iteration_cnt = self._num_workers while resume_iteration_cnt > 0: - data = self._get_data() - if isinstance(data, _utils.worker._ResumeIteration): + return_idx, return_data = self._get_data() + if isinstance(return_idx, _utils.worker._ResumeIteration): + assert return_data is None resume_iteration_cnt -= 1 # prime the prefetch loop for _ in range(self._prefetch_factor * self._num_workers):
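(Finally, on the DataLoader change just above: the worker now acknowledges a `_ResumeIteration` request as a `(sentinel, None)` pair so that everything coming off the data queue has the same two-element shape. A toy, single-process sketch of that message convention; the local `_ResumeIteration` class and the queue contents are stand-ins, not the real DataLoader machinery.)

import queue

class _ResumeIteration:
    """Local stand-in for the sentinel the worker echoes back on resume."""
    pass

data_queue = queue.Queue()

# Worker side: ordinary results and the resume acknowledgement share one format.
data_queue.put((0, "batch-0"))
data_queue.put((_ResumeIteration(), None))   # acknowledgement carries no payload

# Main-process side: one unpacking path handles both kinds of message.
while not data_queue.empty():
    return_idx, return_data = data_queue.get()
    if isinstance(return_idx, _ResumeIteration):
        assert return_data is None
        print("worker acknowledged resume")
    else:
        print("data for index", return_idx, ":", return_data)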