diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 94b842ec4177..c425ac767270 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1272,7 +1272,8 @@ Tensor cholesky(const Tensor &self, bool upper) { "and\n" "U = torch.cholesky(A, upper=True)\n", "should be replaced with\n", - "U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj()" + "U = torch.linalg.cholesky(A).transpose(-2, -1).conj().\n" + "This transform will produce equivalent results for all valid (symmetric positive definite) inputs." ); if (self.numel() == 0) { return at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); @@ -1310,7 +1311,8 @@ Tensor& cholesky_out(const Tensor &self, bool upper, Tensor &result) { "and\n" "U = torch.cholesky(A, upper=True)\n", "should be replaced with\n", - "U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj()" + "U = torch.linalg.cholesky(A).transpose(-2, -1).conj().\n" + "This transform will produce equivalent results for all valid (symmetric positive definite) inputs." ); checkSameDevice("cholesky", result, self); checkLinalgCompatibleDtype("cholesky", result, self); diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index 958ef3b625ee..69adbbd071d5 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -127,6 +127,18 @@ TORCH_META_FUNC2(bitwise_right_shift, Tensor) ( build_borrowing_binary_op(maybe_get_output(), self, other); } +TORCH_META_FUNC2(bitwise_and, Tensor) (const Tensor& self, const Tensor& other) { + build_borrowing_binary_op(maybe_get_output(), self, other); +} + +TORCH_META_FUNC2(bitwise_or, Tensor) (const Tensor& self, const Tensor& other) { + build_borrowing_binary_op(maybe_get_output(), self, other); +} + +TORCH_META_FUNC2(bitwise_xor, Tensor) (const Tensor& self, const Tensor& other) { + build_borrowing_binary_op(maybe_get_output(), self, other); +} + TORCH_META_FUNC2(fmod, Tensor) (const Tensor& self, const Tensor& other) { build_borrowing_binary_op(maybe_get_output(), self, other); } @@ -366,6 +378,9 @@ TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& other, const Tensor func_stub(device_type(), *this); \ } +CREATE_BINARY_TORCH_IMPL_FUNC(bitwise_and_out, bitwise_and_stub); +CREATE_BINARY_TORCH_IMPL_FUNC(bitwise_or_out, bitwise_or_stub); +CREATE_BINARY_TORCH_IMPL_FUNC(bitwise_xor_out, bitwise_xor_stub); CREATE_BINARY_TORCH_IMPL_FUNC(maximum_out, maximum_stub); CREATE_BINARY_TORCH_IMPL_FUNC(minimum_out, minimum_stub); CREATE_BINARY_TORCH_IMPL_FUNC(fmax_out, fmax_stub); @@ -711,33 +726,16 @@ Tensor rsub(const Tensor& self, const Scalar& other, const Scalar& alpha) { return native::rsub(self, wrapped_scalar_tensor(other), alpha); } -Tensor& bitwise_and_out(const Tensor& self, const Tensor& other, Tensor& result) { - auto iter = TensorIterator::binary_op(result, self, other); - bitwise_and_stub(iter.device_type(), iter); - return result; -} - -Tensor bitwise_and(const Tensor& self, const Tensor& other) { - Tensor result = at::empty({0}, self.options()); - at::bitwise_and_out(result, self, other); - return result; -} - -Tensor& bitwise_and_(Tensor& self, const Tensor& other) { - return at::bitwise_and_out(self, self, other); -} - Tensor& bitwise_and_out(const Tensor& self, const Scalar& other, Tensor& result) { return at::bitwise_and_out(result, self, wrapped_scalar_tensor(other)); } Tensor bitwise_and(const Tensor& self, const Scalar& 
other) { - Tensor result = at::empty({0}, self.options()); - return at::bitwise_and_out(result, self, other); + return at::bitwise_and(self, wrapped_scalar_tensor(other)); } Tensor& bitwise_and_(Tensor& self, const Scalar& other) { - return at::bitwise_and_out(self, self, other); + return self.bitwise_and_(wrapped_scalar_tensor(other)); } // Legacy and interfaces. They are aliased to bitwise_and* functions @@ -757,33 +755,16 @@ Tensor& __iand__(Tensor& self, const Scalar& other) { return self.bitwise_and_(other); } -Tensor& bitwise_or_out(const Tensor& self, const Tensor& other, Tensor& result) { - auto iter = TensorIterator::binary_op(result, self, other); - bitwise_or_stub(iter.device_type(), iter); - return result; -} - -Tensor bitwise_or(const Tensor& self, const Tensor& other) { - Tensor result = at::empty({0}, self.options()); - at::bitwise_or_out(result, self, other); - return result; -} - -Tensor& bitwise_or_(Tensor& self, const Tensor& other) { - return at::bitwise_or_out(self, self, other); -} - Tensor& bitwise_or_out(const Tensor& self, const Scalar& other, Tensor& result) { return at::bitwise_or_out(result, self, wrapped_scalar_tensor(other)); } Tensor bitwise_or(const Tensor& self, const Scalar& other) { - Tensor result = at::empty({0}, self.options()); - return at::bitwise_or_out(result, self, other); + return at::bitwise_or(self, wrapped_scalar_tensor(other)); } Tensor& bitwise_or_(Tensor& self, const Scalar& other) { - return at::bitwise_or_out(self, self, other); + return self.bitwise_or_(wrapped_scalar_tensor(other)); } // Legacy or interfaces. They are aliased to bitwise_or* functions @@ -803,33 +784,16 @@ Tensor& __ior__(Tensor& self, const Scalar& other) { return self.bitwise_or_(other); } -Tensor& bitwise_xor_out(const Tensor& self, const Tensor& other, Tensor& result) { - auto iter = TensorIterator::binary_op(result, self, other); - bitwise_xor_stub(iter.device_type(), iter); - return result; -} - -Tensor bitwise_xor(const Tensor& self, const Tensor& other) { - Tensor result = at::empty({0}, self.options()); - at::bitwise_xor_out(result, self, other); - return result; -} - -Tensor& bitwise_xor_(Tensor& self, const Tensor& other) { - return at::bitwise_xor_out(self, self, other); -} - Tensor& bitwise_xor_out(const Tensor& self, const Scalar& other, Tensor& result) { return at::bitwise_xor_out(result, self, wrapped_scalar_tensor(other)); } Tensor bitwise_xor(const Tensor& self, const Scalar& other) { - Tensor result = at::empty({0}, self.options()); - return at::bitwise_xor_out(result, self, other); + return at::bitwise_xor(self, wrapped_scalar_tensor(other)); } Tensor& bitwise_xor_(Tensor& self, const Scalar& other) { - return at::bitwise_xor_out(self, self, other); + return self.bitwise_xor_(wrapped_scalar_tensor(other)); } // Legacy xor interfaces. 
They are aliased to bitwise_xor* functions diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index bc83e9941a5a..dc8f4b5dfe8e 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -54,9 +54,9 @@ DECLARE_DISPATCH(structured_binary_fn, div_floor_stub); DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub); DECLARE_DISPATCH(structured_binary_fn, atan2_stub); DECLARE_DISPATCH(structured_binary_fn, remainder_stub); -DECLARE_DISPATCH(binary_fn, bitwise_and_stub); -DECLARE_DISPATCH(binary_fn, bitwise_or_stub); -DECLARE_DISPATCH(binary_fn, bitwise_xor_stub); +DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub); +DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub); +DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub); DECLARE_DISPATCH(structured_binary_fn, lshift_stub); DECLARE_DISPATCH(structured_binary_fn, rshift_stub); DECLARE_DISPATCH(binary_fn, logical_xor_stub); diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index d0781c624768..8de66bc0c515 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -240,7 +240,7 @@ void remainder_kernel(TensorIteratorBase& iter) { } } -void bitwise_and_kernel(TensorIterator& iter) { +void bitwise_and_kernel(TensorIteratorBase& iter) { if (iter.dtype() == ScalarType::Bool) { cpu_kernel( iter, @@ -261,7 +261,7 @@ void bitwise_and_kernel(TensorIterator& iter) { } } -void bitwise_or_kernel(TensorIterator& iter) { +void bitwise_or_kernel(TensorIteratorBase& iter) { if (iter.dtype() == ScalarType::Bool) { cpu_kernel( iter, @@ -282,7 +282,7 @@ void bitwise_or_kernel(TensorIterator& iter) { } } -void bitwise_xor_kernel(TensorIterator& iter) { +void bitwise_xor_kernel(TensorIteratorBase& iter) { if (iter.dtype() == ScalarType::Bool) { // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and // integral types. diff --git a/aten/src/ATen/native/cuda/BinaryBitwiseOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryBitwiseOpsKernels.cu index 30894b568762..0462c473a01f 100644 --- a/aten/src/ATen/native/cuda/BinaryBitwiseOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryBitwiseOpsKernels.cu @@ -23,7 +23,7 @@ struct BitwiseAndFunctor { } }; -void bitwise_and_kernel_cuda(TensorIterator& iter) { +void bitwise_and_kernel_cuda(TensorIteratorBase& iter) { AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_and_cuda", [&]() { BitwiseAndFunctor f; gpu_kernel_with_scalars(iter, f); @@ -44,7 +44,7 @@ struct BitwiseOrFunctor { } }; -void bitwise_or_kernel_cuda(TensorIterator& iter) { +void bitwise_or_kernel_cuda(TensorIteratorBase& iter) { AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_or_cuda", [&]() { BitwiseOrFunctor f; gpu_kernel_with_scalars(iter, f); @@ -65,7 +65,7 @@ struct BitwiseXorFunctor { } }; -void bitwise_xor_kernel_cuda(TensorIterator& iter) { +void bitwise_xor_kernel_cuda(TensorIteratorBase& iter) { AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_xor_cuda", [&]() { BitwiseXorFunctor f; gpu_kernel_with_scalars(iter, f); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a1f0f6145fa3..d0cd5663f631 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5692,6 +5692,8 @@ - func: bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase variants: function dispatch: CPU, CUDA: bitwise_and_out @@ -5700,19 +5702,18 @@ device_check: NoCheck # TensorIterator variants: function dispatch: - CPU, CUDA: bitwise_and_out + CompositeExplicitAutograd: bitwise_and_out - func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: bitwise_and + CompositeExplicitAutograd: bitwise_and - func: bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function - dispatch: - CPU, CUDA: bitwise_and + structured_delegate: bitwise_and.Tensor_out - func: bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5721,6 +5722,7 @@ - func: bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method + structured_delegate: bitwise_and.Tensor_out - func: __and__.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator @@ -5740,6 +5742,8 @@ - func: bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase variants: function dispatch: CPU, CUDA: bitwise_or_out @@ -5748,7 +5752,7 @@ device_check: NoCheck # TensorIterator variants: function dispatch: - CPU, CUDA: bitwise_or_out + CompositeExplicitAutograd: bitwise_or_out - func: bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator @@ -5757,6 +5761,7 @@ - func: bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function + structured_delegate: bitwise_or.Tensor_out - func: bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5765,6 +5770,7 @@ - func: bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method + structured_delegate: bitwise_or.Tensor_out - func: __or__.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator @@ -5784,6 +5790,8 @@ - func: bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase variants: function dispatch: CPU, CUDA: bitwise_xor_out @@ -5792,7 +5800,7 @@ device_check: NoCheck # TensorIterator variants: function dispatch: - CPU, CUDA: bitwise_xor_out + CompositeExplicitAutograd: bitwise_xor_out - func: bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator @@ -5801,6 +5809,7 @@ - func: bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function + structured_delegate: bitwise_xor.Tensor_out - func: bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5809,6 +5818,7 @@ - func: bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
device_check: NoCheck # TensorIterator variants: method + structured_delegate: bitwise_xor.Tensor_out - func: __xor__.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 15b5973486c1..f201947f84d6 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2082,7 +2082,9 @@ def merge_dicts(*dicts): .. code:: python - U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj() + U = torch.linalg.cholesky(A).transpose(-2, -1).conj() + + This transform will produce equivalent results for all valid (symmetric positive definite) inputs. Args: input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp index d325d6dce46d..d92be5e6b513 100644 --- a/torch/csrc/api/src/nn/modules/rnn.cpp +++ b/torch/csrc/api/src/nn/modules/rnn.cpp @@ -127,10 +127,9 @@ void RNNImplBase::reset() { layer_params.emplace_back(w_hr); param_names.emplace_back("weight_hr_l{layer}{suffix}"); } - for(const auto i : c10::irange(param_names.size())) { // NOLINT(modernize-loop-convert) - std::string x = std::regex_replace(param_names[i], std::regex("\\{layer\\}"), c10::str(layer)); - x = std::regex_replace(x, std::regex("\\{suffix\\}"), c10::str(suffix)); - param_names[i] = x; + for(auto& param_name : param_names) { + std::string x = std::regex_replace(param_name, std::regex("\\{layer\\}"), c10::str(layer)); + param_name = std::regex_replace(x, std::regex("\\{suffix\\}"), c10::str(suffix)); } for(const auto i : c10::irange(param_names.size())) { diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index f881d3f6c691..e63480f96921 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -577,6 +577,12 @@ void GraphTask::exec_post_processing() { // 2. The callback's results can safely be used on (user-facing) caller_current_streams // after backward(). c10::MultiStreamGuard g(caller_current_streams_filtered); + + // Set the ThreadLocalState before calling the function. + // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask + // always saves ThreadLocalState without grad_mode. + at::ThreadLocalStateGuard tls_guard(this->thread_locals_); + // WARNING: Don't use a range-for loop here because more callbacks may be // added in between callback calls, so iterators may become invalidated. 
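The torch.cholesky deprecation note earlier in this patch states that U = torch.linalg.cholesky(A).transpose(-2, -1).conj() reproduces the old upper=True result for all valid (symmetric positive definite) inputs. A minimal illustrative sketch of that equivalence, assuming a PyTorch build where both the deprecated and the new call are still available (not part of the patch itself):

    import torch

    # Build a symmetric positive-definite matrix so both factorizations are defined.
    A = torch.randn(3, 3, dtype=torch.float64)
    A = A @ A.T + 3 * torch.eye(3, dtype=torch.float64)

    U_old = torch.cholesky(A, upper=True)                      # deprecated path
    U_new = torch.linalg.cholesky(A).transpose(-2, -1).conj()  # recommended replacement
    print(torch.allclose(U_old, U_new))                        # True, by uniqueness of the factor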
// NOLINTNEXTLINE(modernize-loop-convert) diff --git a/torch/csrc/cuda/utils.cpp b/torch/csrc/cuda/utils.cpp index 6f544d09dc68..d18efbf66735 100644 --- a/torch/csrc/cuda/utils.cpp +++ b/torch/csrc/cuda/utils.cpp @@ -1,6 +1,5 @@ #include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include +#include #include #include diff --git a/torch/csrc/deploy/deploy.h b/torch/csrc/deploy/deploy.h index 906bada1e22c..faead8af1f8d 100644 --- a/torch/csrc/deploy/deploy.h +++ b/torch/csrc/deploy/deploy.h @@ -1,10 +1,9 @@ #pragma once -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include #include #include #include #include +#include #include #include #include diff --git a/torch/csrc/deploy/example/benchmark.cpp b/torch/csrc/deploy/example/benchmark.cpp index c5e9781c706a..4a747ec6d67a 100644 --- a/torch/csrc/deploy/example/benchmark.cpp +++ b/torch/csrc/deploy/example/benchmark.cpp @@ -1,22 +1,21 @@ +#include + +#include +#include +#include + +#include + #include #include #include +#include #include #include #include #include #include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include -#include - -#include -#include -#include - -#include - typedef void (*function_type)(const char*); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/torch/csrc/deploy/interpreter/interpreter_impl.cpp b/torch/csrc/deploy/interpreter/interpreter_impl.cpp index 7cbffe4f7bc3..92da5ae1b75f 100644 --- a/torch/csrc/deploy/interpreter/interpreter_impl.cpp +++ b/torch/csrc/deploy/interpreter/interpreter_impl.cpp @@ -3,16 +3,14 @@ #define PY_SSIZE_T_CLEAN #include #include -#include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include #include #include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include #include #include + +#include +#include #include #include #include diff --git a/torch/csrc/jit/codegen/cuda/instrumentation.h b/torch/csrc/jit/codegen/cuda/instrumentation.h index e67cced4ec0b..06dde895c028 100644 --- a/torch/csrc/jit/codegen/cuda/instrumentation.h +++ b/torch/csrc/jit/codegen/cuda/instrumentation.h @@ -2,9 +2,8 @@ #include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include #include +#include namespace torch { namespace jit { diff --git a/torch/csrc/jit/frontend/strtod.cpp b/torch/csrc/jit/frontend/strtod.cpp index 8c7bd277a1df..34023bcd52d3 100644 --- a/torch/csrc/jit/frontend/strtod.cpp +++ b/torch/csrc/jit/frontend/strtod.cpp @@ -2,10 +2,8 @@ // https://github.com/JuliaLang/julia/blob/v1.1.0/src/support/strtod.c #include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include +#include +#include #if defined(__APPLE__) || defined(__FreeBSD__) #include @@ -23,14 +21,10 @@ // respective // C stdlib functions -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include +#include +#include +#include +#include #include #define D_PNAN ((double)+NAN) diff --git a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp index b4acab100fdc..5f040ecf1ad1 100644 --- a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp +++ b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp @@ -147,9 +147,7 @@ void InplaceMKLDNNSubgraph(std::shared_ptr graph) { continue; } Node* last = nullptr; - // NOLINTNEXTLINE(modernize-loop-convert) - for (auto it = set.second->begin(); it != set.second->end(); it++) { - 
Value* v = *it; + for (const auto& v : *set.second) { auto k = v->node()->kind(); if (k == prim::Constant || k == prim::ConstantMKLDNNTensor || k == prim::Param) { diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 3b27d8074941..d8fd4c0f6b8d 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -487,9 +487,8 @@ std::vector ComputeShapeFromReshape( c10::optional<::c10::SymbolicShape> ComputeShapeFromExpand( const std::vector<::c10::ShapeSymbol>& input_shape, const std::vector& reshape) { - // NOLINTNEXTLINE(modernize-loop-convert) - for (auto it = reshape.begin(); it != reshape.end(); ++it) { - if (*it < 0) { + for (const auto& it : reshape) { + if (it < 0) { return c10::nullopt; } } @@ -530,9 +529,8 @@ c10::optional<::c10::SymbolicShape> ComputeShapeFromTile( TORCH_INTERNAL_ASSERT( input_shape.size() == reshape.size(), "ONNX Tile input shapes do not match."); - // NOLINTNEXTLINE(modernize-loop-convert) - for (auto it = reshape.begin(); it != reshape.end(); ++it) { - if (*it < 0) { + for (const auto& it : reshape) { + if (it < 0) { return c10::nullopt; } } diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp index 1660a9c17553..3bdc8ca98566 100644 --- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp +++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp @@ -235,10 +235,9 @@ struct SymbolicShapeAnalyzer { for (size_t i = 0; i < symbolic_set.size(); ++i) { Value* v = symbolic_set[i]; Value* dominating_value = v; - // NOLINTNEXTLINE(modernize-loop-convert) - for (size_t j = 0; j < symbolic_set.size(); ++j) { - if (dominating_value->node()->isDominatedBy(symbolic_set[j]->node())) { - dominating_value = symbolic_set[j]; + for (const auto& sym_set : symbolic_set) { + if (dominating_value->node()->isDominatedBy(sym_set->node())) { + dominating_value = sym_set; } } if (dominating_value != v) { diff --git a/torch/csrc/jit/runtime/instruction.h b/torch/csrc/jit/runtime/instruction.h index 3fde40e3f9d8..889bbf8ff8df 100644 --- a/torch/csrc/jit/runtime/instruction.h +++ b/torch/csrc/jit/runtime/instruction.h @@ -1,6 +1,6 @@ #pragma once -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include + +#include #include #include diff --git a/torch/csrc/python_headers.h b/torch/csrc/python_headers.h index 2778d53435be..1e5b16eebbff 100644 --- a/torch/csrc/python_headers.h +++ b/torch/csrc/python_headers.h @@ -1,6 +1,5 @@ #pragma once -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include +#include // workaround for Python 2 issue: https://bugs.python.org/issue17120 // NOTE: It looks like this affects Python 3 as well. 
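The ONNX shape-inference hunks above return c10::nullopt whenever the requested expand/tile sizes contain a negative entry. A short Python illustration of the semantics involved (an aside, not part of the patch): in Tensor.expand, -1 means "keep the input's size for this dimension", so a negative entry cannot be turned into a concrete output dimension from the argument list alone.

    import torch

    x = torch.randn(1, 4)
    y = x.expand(3, -1)   # -1 keeps the existing size of that dimension
    print(y.shape)        # torch.Size([3, 4]) -- only knowable from x's shape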
#pragma push_macro("_XOPEN_SOURCE") diff --git a/torch/lib/libshm/manager.cpp b/torch/lib/libshm/manager.cpp index 5df897cabd28..a87465dd4fa9 100644 --- a/torch/lib/libshm/manager.cpp +++ b/torch/lib/libshm/manager.cpp @@ -1,14 +1,13 @@ -#include -#include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include -#include -#include -#include -#include #include +#include +#include #include +#include +#include +#include +#include #include +#include #include diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index f897694ed125..0d76e9000c42 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -64,42 +64,23 @@ Examples:: - >>> a = torch.randn(2, 2, dtype=torch.complex128) - >>> a = a @ a.t().conj() + torch.eye(2) # creates a Hermitian positive-definite matrix - >>> l = torch.linalg.cholesky(a) - >>> a + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A @ A.T.conj() + torch.eye(2) # creates a Hermitian positive-definite matrix + >>> A tensor([[2.5266+0.0000j, 1.9586-2.0626j], [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) - >>> l + >>> L = torch.linalg.cholesky(A) + >>> L tensor([[1.5895+0.0000j, 0.0000+0.0000j], [1.2322+1.2976j, 2.4928+0.0000j]], dtype=torch.complex128) - >>> l @ l.t().conj() - tensor([[2.5266+0.0000j, 1.9586-2.0626j], - [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) - - >>> a = torch.randn(3, 2, 2, dtype=torch.float64) - >>> a = a @ a.transpose(-2, -1) + torch.eye(2).squeeze(0) # symmetric positive definite matrices - >>> l = torch.linalg.cholesky(a) - >>> a - tensor([[[ 1.1629, 2.0237], - [ 2.0237, 6.6593]], - - [[ 0.4187, 0.1830], - [ 0.1830, 0.1018]], - - [[ 1.9348, -2.5744], - [-2.5744, 4.6386]]], dtype=torch.float64) - >>> l - tensor([[[ 1.0784, 0.0000], - [ 1.8766, 1.7713]], - - [[ 0.6471, 0.0000], - [ 0.2829, 0.1477]], - - [[ 1.3910, 0.0000], - [-1.8509, 1.1014]]], dtype=torch.float64) - >>> torch.allclose(l @ l.transpose(-2, -1), a) - True + >>> torch.dist(L @ L.T.conj(), A) + tensor(4.4692e-16, dtype=torch.float64) + + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> A = A @ A.transpose(-2, -1) + torch.eye(2) # batch of symmetric positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> torch.dist(L @ L.transpose(-2, -1), A) + tensor(5.8747e-16, dtype=torch.float64) """) cholesky_ex = _add_docstr(_linalg.linalg_cholesky_ex, r""" @@ -144,13 +125,13 @@ Examples:: - >>> a = torch.randn(2, 2, dtype=torch.complex128) - >>> a = a @ a.t().conj() # creates a Hermitian positive-definite matrix - >>> l, info = torch.linalg.cholesky_ex(a) - >>> a + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A @ A.t().conj() # creates a Hermitian positive-definite matrix + >>> L, info = torch.linalg.cholesky_ex(A) + >>> A tensor([[ 2.3792+0.0000j, -0.9023+0.9831j], [-0.9023-0.9831j, 0.8757+0.0000j]], dtype=torch.complex128) - >>> l + >>> L tensor([[ 1.5425+0.0000j, 0.0000+0.0000j], [-0.5850-0.6374j, 0.3567+0.0000j]], dtype=torch.complex128) >>> info @@ -214,28 +195,19 @@ Examples:: - >>> x = torch.rand(4, 4) - >>> y = torch.linalg.inv(x) - >>> z = x @ y - >>> z - tensor([[ 1.0000, -0.0000, -0.0000, 0.0000], - [ 0.0000, 1.0000, 0.0000, 0.0000], - [ 0.0000, 0.0000, 1.0000, 0.0000], - [ 0.0000, -0.0000, -0.0000, 1.0000]]) - >>> torch.dist(z, torch.eye(4)) + >>> A = torch.randn(4, 4) + >>> Ainv = torch.linalg.inv(A) + >>> torch.dist(A @ Ainv, torch.eye(4)) tensor(1.1921e-07) - >>> # Batched inverse example - >>> x = torch.randn(2, 3, 4, 4) - >>> y = torch.linalg.inv(x) - >>> z = x @ y - >>> 
torch.dist(z, torch.eye(4).expand_as(x)) + >>> A = torch.randn(2, 3, 4, 4) # Batch of matrices + >>> Ainv = torch.linalg.inv(A) + >>> torch.dist(A @ Ainv, torch.eye(4)) tensor(1.9073e-06) - >>> x = torch.rand(4, 4, dtype=torch.cdouble) - >>> y = torch.linalg.inv(x) - >>> z = x @ y - >>> torch.dist(z, torch.eye(4, dtype=torch.cdouble)) + >>> A = torch.randn(4, 4, dtype=torch.complex128) # Complex matrix + >>> Ainv = torch.linalg.inv(A) + >>> torch.dist(A @ Ainv, torch.eye(4)) tensor(7.5107e-16, dtype=torch.float64) .. _invertible: @@ -282,16 +254,10 @@ Examples:: - >>> a = torch.randn(3, 3) - >>> inverse, info = torch.linalg.inv_ex(a) - >>> a - tensor([[-0.0464, 0.2302, -1.3568], - [-0.5437, -1.2301, -0.6918], - [ 0.2328, -1.4910, -0.3003]]) - >>> l - tensor([[ 0.4320, -1.3653, 1.1931], - [ 0.2117, -0.2152, -0.4605], - [-0.7159, 0.0102, -0.1190]]) + >>> A = torch.randn(3, 3) + >>> Ainv, info = torch.linalg.inv_ex(A) + >>> torch.dist(torch.linalg.inv(A), Ainv) + tensor(0.) >>> info tensor(0, dtype=torch.int32) @@ -325,31 +291,12 @@ Examples:: - >>> a = torch.randn(3, 3) - >>> a - tensor([[ 0.9478, 0.9158, -1.1295], - [ 0.9701, 0.7346, -1.8044], - [-0.2337, 0.0557, 0.6929]]) - >>> torch.linalg.det(a) - tensor(0.0934) - - >>> out = torch.empty(0) - >>> torch.linalg.det(a, out=out) - tensor(0.0934) - >>> out + >>> A = torch.randn(3, 3) + >>> torch.linalg.det(A) tensor(0.0934) - >>> a = torch.randn(3, 2, 2) - >>> a - tensor([[[ 0.9254, -0.6213], - [-0.5787, 1.6843]], - - [[ 0.3242, -0.9665], - [ 0.4539, -0.0887]], - - [[ 1.1336, -0.4025], - [-0.7089, 0.9032]]]) - >>> torch.linalg.det(a) + >>> A = torch.randn(3, 2, 2) + >>> torch.linalg.det(A) tensor([1.1990, 0.4099, 0.7386]) """) @@ -481,23 +428,23 @@ Examples:: - >>> a = torch.randn(2, 2, dtype=torch.complex128) - >>> a + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A tensor([[ 0.9828+0.3889j, -0.4617+0.3010j], [ 0.1662-0.7435j, -0.6139+0.0562j]], dtype=torch.complex128) - >>> w, v = torch.linalg.eig(a) - >>> w + >>> L, V = torch.linalg.eig(A) + >>> L tensor([ 1.1226+0.5738j, -0.7537-0.1286j], dtype=torch.complex128) - >>> v + >>> V tensor([[ 0.9218+0.0000j, 0.1882-0.2220j], [-0.0270-0.3867j, 0.9567+0.0000j]], dtype=torch.complex128) - >>> torch.allclose(torch.matmul(v, torch.matmul(w.diag_embed(), v.inverse())), a) - True + >>> torch.dist(V @ torch.diag(L) @ torch.linalg.inv(V), A) + tensor(7.7119e-16, dtype=torch.float64) - >>> a = torch.randn(3, 2, 2, dtype=torch.float64) - >>> w, v = torch.linalg.eig(a) - >>> torch.allclose(torch.matmul(v, torch.matmul(w.diag_embed(), v.inverse())).real, a) - True + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> L, V = torch.linalg.eig(A) + >>> torch.dist(V @ torch.diag_embed(L) @ torch.linalg.inv(V), A) + tensor(3.2841e-16, dtype=torch.float64) .. _diagonalizable: https://en.wikipedia.org/wiki/Diagonalizable_matrix#Definition @@ -545,13 +492,13 @@ Examples:: - >>> a = torch.randn(2, 2, dtype=torch.complex128) - >>> a - tensor([[ 0.9828+0.3889j, -0.4617+0.3010j], - [ 0.1662-0.7435j, -0.6139+0.0562j]], dtype=torch.complex128) - >>> w = torch.linalg.eigvals(a) - >>> w + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> L = torch.linalg.eigvals(A) + >>> L tensor([ 1.1226+0.5738j, -0.7537-0.1286j], dtype=torch.complex128) + + >>> torch.dist(L, torch.linalg.eig(A).eigenvalues) + tensor(2.4576e-07) """) eigh = _add_docstr(_linalg.linalg_eigh, r""" @@ -601,14 +548,14 @@ ..
warning:: Gradients computed using the `eigenvectors` tensor will only be finite when :attr:`A` has unique eigenvalues. - Furthermore, if the distance between any two eigvalues is close to zero, + Furthermore, if the distance between any two eigenvalues is close to zero, the gradient will be numerically unstable, as it depends on the eigenvalues :math:`\lambda_i` through the computation of :math:`\frac{1}{\min_{i \neq j} \lambda_i - \lambda_j}`. .. seealso:: - :func:`torch.linalg.eigvalsh` computes only the eigenvalues values of a Hermitian matrix. + :func:`torch.linalg.eigvalsh` computes only the eigenvalues of a Hermitian matrix. Unlike :func:`torch.linalg.eigh`, the gradients of :func:`~eigvalsh` are always numerically stable. @@ -643,26 +590,25 @@ `eigenvectors` will have the same dtype as :attr:`A` and will contain the eigenvectors as its columns. Examples:: - - >>> a = torch.randn(2, 2, dtype=torch.complex128) - >>> a = a + a.t().conj() # creates a Hermitian matrix - >>> a + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A + A.T.conj() # creates a Hermitian matrix + >>> A tensor([[2.9228+0.0000j, 0.2029-0.0862j], [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128) - >>> w, v = torch.linalg.eigh(a) - >>> w + >>> L, Q = torch.linalg.eigh(A) + >>> L tensor([0.3277, 2.9415], dtype=torch.float64) - >>> v + >>> Q tensor([[-0.0846+-0.0000j, -0.9964+0.0000j], [ 0.9170+0.3898j, -0.0779-0.0331j]], dtype=torch.complex128) - >>> torch.allclose(torch.matmul(v, torch.matmul(w.to(v.dtype).diag_embed(), v.t().conj())), a) - True - - >>> a = torch.randn(3, 2, 2, dtype=torch.float64) - >>> a = a + a.transpose(-2, -1) # creates a symmetric matrix - >>> w, v = torch.linalg.eigh(a) - >>> torch.allclose(torch.matmul(v, torch.matmul(w.diag_embed(), v.transpose(-2, -1))), a) - True + >>> torch.dist(Q @ torch.diag(L.cdouble()) @ Q.T.conj(), A) + tensor(6.1062e-16, dtype=torch.float64) + + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> A = A + A.transpose(-2, -1) # creates a batch of symmetric matrices + >>> L, Q = torch.linalg.eigh(A) + >>> torch.dist(Q @ torch.diag_embed(L) @ Q.transpose(-2, -1).conj(), A) + tensor(1.5423e-15, dtype=torch.float64) """) eigvalsh = _add_docstr(_linalg.linalg_eigvalsh, r""" @@ -715,28 +661,17 @@ Examples:: - >>> a = torch.randn(2, 2, dtype=torch.complex128) - >>> a = a + a.t().conj() # creates a Hermitian matrix - >>> a + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A + A.T.conj() # creates a Hermitian matrix + >>> A tensor([[2.9228+0.0000j, 0.2029-0.0862j], [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128) - >>> w = torch.linalg.eigvalsh(a) - >>> w + >>> torch.linalg.eigvalsh(A) tensor([0.3277, 2.9415], dtype=torch.float64) - >>> a = torch.randn(3, 2, 2, dtype=torch.float64) - >>> a = a + a.transpose(-2, -1) # creates a symmetric matrix - >>> a - tensor([[[ 2.8050, -0.3850], - [-0.3850, 3.2376]], - - [[-1.0307, -2.7457], - [-2.7457, -1.7517]], - - [[ 1.7166, 2.2207], - [ 2.2207, -2.0898]]], dtype=torch.float64) - >>> w = torch.linalg.eigvalsh(a) - >>> w + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> A = A + A.transpose(-2, -1) # creates a batch of symmetric matrices + >>> torch.linalg.eigvalsh(A) tensor([[ 2.5797, 3.4629], [-4.1605, 1.3780], [-3.1113, 2.7381]], dtype=torch.float64) @@ -790,16 +725,16 @@ Examples:: - >>> a = torch.randn(2, 2) - >>> h, tau = torch.geqrf(a) - >>> q = torch.linalg.householder_product(h, tau) - >>> torch.allclose(q, torch.linalg.qr(a)[0]) - True + >>> A = torch.randn(2, 2) + >>> h, 
tau = torch.geqrf(A) + >>> Q = torch.linalg.householder_product(h, tau) + >>> torch.dist(Q, torch.linalg.qr(A).Q) + tensor(0.) >>> h = torch.randn(3, 2, 2, dtype=torch.complex128) >>> tau = torch.randn(3, 1, dtype=torch.complex128) - >>> q = torch.linalg.householder_product(h, tau) - >>> q + >>> Q = torch.linalg.householder_product(h, tau) + >>> Q tensor([[[ 1.8034+0.4184j, 0.2588-1.0174j], [-0.6853+0.7953j, 2.0790+0.5620j]], @@ -849,11 +784,11 @@ See also the `full description of these drivers`_ -:attr:`cond` is used to determine the effective rank of the matrices in :attr:`A` +:attr:`rcond` is used to determine the effective rank of the matrices in :attr:`A` when :attr:`driver` is one of (`'gelsy'`, `'gelsd'`, `'gelss'`). In this case, if :math:`\sigma_i` are the singular values of `A` in decreasing order, -:math:`\sigma_i` will be rounded down to zero if :math:`\sigma_i \leq \text{cond} \cdot \sigma_1`. -If :attr:`cond`\ `= None` (default), :attr:`cond` is set to the machine precision of the dtype of :attr:`A`. +:math:`\sigma_i` will be rounded down to zero if :math:`\sigma_i \leq \text{rcond} \cdot \sigma_1`. +If :attr:`rcond`\ `= None` (default), :attr:`rcond` is set to the machine precision of the dtype of :attr:`A`. This function returns the solution to the problem and some extra information in a named tuple of four tensors `(solution, residuals, rank, singular_values)`. For inputs :attr:`A`, :attr:`B` @@ -876,9 +811,8 @@ otherwise it is an empty tensor. .. note:: - While `X = \ `:attr:`A`\ `.pinv() @ \ `:attr:`B`, this function computes the - solution in a faster and more numerically stable way than performing the - computations separately. + This function computes `X = \ `:attr:`A`\ `.pinverse() @ \ `:attr:`B` in a faster and + more numerically stable way than performing the computations separately. .. warning:: The default value of :attr:`rcond` may change in a future PyTorch release. @@ -902,20 +836,19 @@ Examples:: - >>> a = torch.tensor([[10, 2, 3], [3, 10, 5], [5, 6, 12]], dtype=torch.float) - >>> a.unsqueeze_(0) - >>> b = torch.tensor([[[2, 5, 1], [3, 2, 1], [5, 1, 9]], - [[4, 2, 9], [2, 0, 3], [2, 5, 3]]], dtype=torch.float) - >>> x = torch.linalg.lstsq(a, b).solution - >>> torch.dist(x, a.pinverse() @ b) + >>> A = torch.tensor([[[10, 2, 3], [3, 10, 5], [5, 6, 12]]], dtype=torch.float) # shape (1, 3, 3) + >>> B = torch.tensor([[[2, 5, 1], [3, 2, 1], [5, 1, 9]], + [[4, 2, 9], [2, 0, 3], [2, 5, 3]]], dtype=torch.float) # shape (2, 3, 3) + >>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3) + >>> torch.dist(X, torch.linalg.pinv(A) @ B) tensor(2.0862e-07) - >>> sv = torch.linalg.lstsq(a, driver='gelsd').singular_values - >>> torch.dist(sv, a.svd().S) + >>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values + >>> torch.dist(S, torch.linalg.svdvals(A)) tensor(5.7220e-06) - >>> a[:, 0].zero_() - >>> xx, rank, _ = torch.linalg.lstsq(a, b) + >>> A[:, 0].zero_() # Decrease the rank of A + >>> rank = torch.linalg.lstsq(A, B).rank >>> rank tensor([2]) @@ -949,7 +882,7 @@ .. seealso:: - :func:`torch.linalg.solve` computes :attr:`A`\ `.inv() @ \ `:attr:`B` with a + :func:`torch.linalg.solve` computes :attr:`A`\ `.inverse() @ \ `:attr:`B` with a numerically stable algorithm. 
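A minimal sketch of the equivalence mentioned in the seealso above, assuming a recent PyTorch (an illustrative aside, not part of the documented example set):

    import torch

    A = torch.randn(4, 4, dtype=torch.float64)
    B = torch.randn(4, 2, dtype=torch.float64)

    X = torch.linalg.solve(A, B)                       # solves A X = B without forming A^{-1}
    print(torch.allclose(X, torch.linalg.inv(A) @ B))  # True, up to rounding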
Args: @@ -965,20 +898,16 @@ Examples:: - >>> a = torch.randn(3, 3) - >>> a - tensor([[-0.2270, 0.6663, -1.3515], - [-0.9838, -0.4002, -1.9313], - [-0.7886, -0.0450, 0.0528]]) - >>> torch.linalg.matrix_power(a, 0) + >>> A = torch.randn(3, 3) + >>> torch.linalg.matrix_power(A, 0) tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) - >>> torch.linalg.matrix_power(a, 3) + >>> torch.linalg.matrix_power(A, 3) tensor([[ 1.0756, 0.4980, 0.0100], [-1.6617, 1.4994, -1.9980], [-0.4509, 0.2731, 0.8001]]) - >>> torch.linalg.matrix_power(a.expand(2, -1, -1), -2) + >>> torch.linalg.matrix_power(A.expand(2, -1, -1), -2) tensor([[[ 0.2640, 0.4571, -0.5511], [-1.0163, 0.3491, -1.5292], [-0.4899, 0.0822, 0.2773]], @@ -1036,34 +965,34 @@ Examples:: - >>> a = torch.eye(10) - >>> torch.linalg.matrix_rank(a) + >>> A = torch.eye(10) + >>> torch.linalg.matrix_rank(A) tensor(10) - >>> b = torch.eye(10) - >>> b[0, 0] = 0 - >>> torch.linalg.matrix_rank(b) + >>> B = torch.eye(10) + >>> B[0, 0] = 0 + >>> torch.linalg.matrix_rank(B) tensor(9) - >>> a = torch.randn(4, 3, 2) - >>> torch.linalg.matrix_rank(a) + >>> A = torch.randn(4, 3, 2) + >>> torch.linalg.matrix_rank(A) tensor([2, 2, 2, 2]) - >>> a = torch.randn(2, 4, 2, 3) - >>> torch.linalg.matrix_rank(a) + >>> A = torch.randn(2, 4, 2, 3) + >>> torch.linalg.matrix_rank(A) tensor([[2, 2, 2, 2], [2, 2, 2, 2]]) - >>> a = torch.randn(2, 4, 3, 3, dtype=torch.complex64) - >>> torch.linalg.matrix_rank(a) + >>> A = torch.randn(2, 4, 3, 3, dtype=torch.complex64) + >>> torch.linalg.matrix_rank(A) tensor([[3, 3, 3, 3], [3, 3, 3, 3]]) - >>> torch.linalg.matrix_rank(a, hermitian=True) + >>> torch.linalg.matrix_rank(A, hermitian=True) tensor([[3, 3, 3, 3], [3, 3, 3, 3]]) - >>> torch.linalg.matrix_rank(a, tol=1.0) + >>> torch.linalg.matrix_rank(A, tol=1.0) tensor([[3, 2, 2, 2], [1, 2, 1, 2]]) - >>> torch.linalg.matrix_rank(a, tol=1.0, hermitian=True) + >>> torch.linalg.matrix_rank(A, tol=1.0, hermitian=True) tensor([[2, 2, 2, 1], [1, 2, 2, 2]]) """) @@ -1140,43 +1069,43 @@ >>> a = torch.arange(9, dtype=torch.float) - 4 >>> a tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.]) - >>> b = a.reshape((3, 3)) - >>> b + >>> B = a.reshape((3, 3)) + >>> B tensor([[-4., -3., -2.], [-1., 0., 1.], [ 2., 3., 4.]]) >>> LA.norm(a) tensor(7.7460) - >>> LA.norm(b) + >>> LA.norm(B) tensor(7.7460) - >>> LA.norm(b, 'fro') + >>> LA.norm(B, 'fro') tensor(7.7460) >>> LA.norm(a, float('inf')) tensor(4.) - >>> LA.norm(b, float('inf')) + >>> LA.norm(B, float('inf')) tensor(9.) >>> LA.norm(a, -float('inf')) tensor(0.) - >>> LA.norm(b, -float('inf')) + >>> LA.norm(B, -float('inf')) tensor(2.) >>> LA.norm(a, 1) tensor(20.) - >>> LA.norm(b, 1) + >>> LA.norm(B, 1) tensor(7.) >>> LA.norm(a, -1) tensor(0.) - >>> LA.norm(b, -1) + >>> LA.norm(B, -1) tensor(6.) >>> LA.norm(a, 2) tensor(7.7460) - >>> LA.norm(b, 2) + >>> LA.norm(B, 2) tensor(7.3485) >>> LA.norm(a, -2) tensor(0.) 
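For the matrix example that follows, ord=-2 is the smallest singular value of B, which is why the result sits at machine-epsilon scale for this rank-deficient B. A small illustrative check (an aside; it assumes torch.linalg.svdvals, which is documented later in this file):

    import torch

    B = (torch.arange(9, dtype=torch.float64) - 4).reshape(3, 3)
    smallest_sv = torch.linalg.svdvals(B).min()
    print(torch.allclose(torch.linalg.norm(B, -2), smallest_sv))  # True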
- >>> LA.norm(b.double(), -2) + >>> LA.norm(B.double(), -2) tensor(1.8570e-16, dtype=torch.float64) >>> LA.norm(a, 3) tensor(5.8480) @@ -1196,10 +1125,10 @@ Using the :attr:`dim` argument to compute matrix norms:: - >>> m = torch.arange(8, dtype=torch.float).reshape(2, 2, 2) - >>> LA.norm(m, dim=(1,2)) + >>> A = torch.arange(8, dtype=torch.float).reshape(2, 2, 2) + >>> LA.norm(A, dim=(1,2)) tensor([ 3.7417, 11.2250]) - >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) + >>> LA.norm(A[0, :, :]), LA.norm(A[1, :, :]) (tensor(3.7417), tensor(11.2250)) """) @@ -1264,14 +1193,14 @@ >>> a = torch.arange(9, dtype=torch.float) - 4 >>> a tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.]) - >>> b = a.reshape((3, 3)) - >>> b + >>> B = a.reshape((3, 3)) + >>> B tensor([[-4., -3., -2.], [-1., 0., 1.], [ 2., 3., 4.]]) >>> LA.vector_norm(a, ord=3.5) tensor(5.4345) - >>> LA.vector_norm(b, ord=3.5) + >>> LA.vector_norm(B, ord=3.5) tensor(5.4345) """) @@ -1408,16 +1337,12 @@ >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])]) tensor([[8]]) - >>> a = torch.arange(2 * 3).view(2, 3) - >>> b = torch.arange(3 * 2).view(3, 2) - >>> c = torch.arange(2 * 2).view(2, 2) - >>> multi_dot((a, b, c)) + >>> A = torch.arange(2 * 3).view(2, 3) + >>> B = torch.arange(3 * 2).view(3, 2) + >>> C = torch.arange(2 * 2).view(2, 2) + >>> multi_dot((A, B, C)) tensor([[ 26, 49], [ 80, 148]]) - - >>> multi_dot((a.to(torch.float), torch.empty(3, 0), torch.empty(0, 2))) - tensor([[0., 0.], - [0., 0.]]) """) svd = _add_docstr(_linalg.linalg_svd, r""" @@ -1499,7 +1424,7 @@ numerically stable. :func:`torch.linalg.eig` for a function that computes another type of spectral - decomposition of a matrix. The eigendecomposition works just on on square matrices. + decomposition of a matrix. The eigendecomposition works just on square matrices. :func:`torch.linalg.eigh` for a (faster) function that computes the eigenvalue decomposition for Hermitian and symmetric matrices. @@ -1528,33 +1453,22 @@ Examples:: - >>> a = torch.randn(5, 3) - >>> a - tensor([[-0.3357, -0.2987, -1.1096], - [ 1.4894, 1.0016, -0.4572], - [-1.9401, 0.7437, 2.0968], - [ 0.1515, 1.3812, 1.5491], - [-1.8489, -0.5907, -2.5673]]) - >>> - >>> # reconstruction in the full_matrices=False case - >>> u, s, vh = torch.linalg.svd(a, full_matrices=False) - >>> u.shape, s.shape, vh.shape + >>> A = torch.randn(5, 3) + >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False) + >>> U.shape, S.shape, Vh.shape (torch.Size([5, 3]), torch.Size([3]), torch.Size([3, 3])) - >>> torch.dist(a, u @ torch.diag(s) @ vh) + >>> torch.dist(A, U @ torch.diag(S) @ Vh) tensor(1.0486e-06) - >>> - >>> # reconstruction in the full_matrices=True case - >>> u, s, vh = torch.linalg.svd(a) - >>> u.shape, s.shape, vh.shape + + >>> U, S, Vh = torch.linalg.svd(A) + >>> U.shape, S.shape, Vh.shape (torch.Size([5, 5]), torch.Size([3]), torch.Size([3, 3])) - >>> torch.dist(a, u[:, :3] @ torch.diag(s) @ vh) - >>> torch.dist(a, u[:, :3] @ torch.diag(s) @ vh) + >>> torch.dist(A, U[:, :3] @ torch.diag(S) @ Vh) tensor(1.0486e-06) - >>> - >>> # extra dimensions - >>> a_big = torch.randn(7, 5, 3) - >>> u, s, vh = torch.linalg.svd(a_big, full_matrices=False) - >>> torch.dist(a_big, u @ torch.diag_embed(s) @ vh) + + >>> A = torch.randn(7, 5, 3) + >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False) + >>> torch.dist(A, U @ torch.diag_embed(S) @ Vh) tensor(3.0957e-06) .. 
_the resulting vectors will span the same subspace: @@ -1593,17 +1507,13 @@ Examples:: - >>> import torch - >>> a = torch.randn(5, 3) - >>> a - tensor([[-1.3490, -0.1723, 0.7730], - [-1.6118, -0.3385, -0.6490], - [ 0.0908, 2.0704, 0.5647], - [-0.6451, 0.1911, 0.7353], - [ 0.5247, 0.5160, 0.5110]]) - >>> s = torch.linalg.svdvals(a) - >>> s + >>> A = torch.randn(5, 3) + >>> S = torch.linalg.svdvals(A) + >>> S tensor([2.5139, 2.1087, 1.1066]) + + >>> torch.dist(S, torch.linalg.svd(A, full_matrices=False).S) + tensor(2.4576e-07) """) cond = _add_docstr(_linalg.linalg_cond, r""" @@ -1659,7 +1569,7 @@ In these cases, it is computed using :func:`torch.linalg.svd`. For these norms, the matrix (or every matrix in the batch) :attr:`A` may have any shape. -.. note :: When inputs are on a CUDA device, this function synchronizes that device with the CPU if +.. note :: When inputs are on a CUDA device, this function synchronizes that device with the CPU if :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)`. .. seealso:: @@ -1689,55 +1599,36 @@ Examples:: - >>> a = torch.randn(3, 4, 4, dtype=torch.complex64) - >>> torch.linalg.cond(a) - >>> a = torch.tensor([[1., 0, -1], [0, 1, 0], [1, 0, 1]]) - >>> torch.linalg.cond(a) + >>> A = torch.randn(3, 4, 4, dtype=torch.complex64) + >>> torch.linalg.cond(A) + >>> A = torch.tensor([[1., 0, -1], [0, 1, 0], [1, 0, 1]]) + >>> torch.linalg.cond(A) tensor([1.4142]) - >>> torch.linalg.cond(a, 'fro') + >>> torch.linalg.cond(A, 'fro') tensor(3.1623) - >>> torch.linalg.cond(a, 'nuc') + >>> torch.linalg.cond(A, 'nuc') tensor(9.2426) - >>> torch.linalg.cond(a, float('inf')) + >>> torch.linalg.cond(A, float('inf')) tensor(2.) - >>> torch.linalg.cond(a, float('-inf')) + >>> torch.linalg.cond(A, float('-inf')) tensor(1.) - >>> torch.linalg.cond(a, 1) + >>> torch.linalg.cond(A, 1) tensor(2.) - >>> torch.linalg.cond(a, -1) + >>> torch.linalg.cond(A, -1) tensor(1.) - >>> torch.linalg.cond(a, 2) + >>> torch.linalg.cond(A, 2) tensor([1.4142]) - >>> torch.linalg.cond(a, -2) + >>> torch.linalg.cond(A, -2) tensor([0.7071]) - >>> a = torch.randn(2, 3, 3) - >>> a - tensor([[[-0.9204, 1.1140, 1.2055], - [ 0.3988, -0.2395, -0.7441], - [-0.5160, 0.3115, 0.2619]], - - [[-2.2128, 0.9241, 2.1492], - [-1.1277, 2.7604, -0.8760], - [ 1.2159, 0.5960, 0.0498]]]) - >>> torch.linalg.cond(a) + >>> A = torch.randn(2, 3, 3) + >>> torch.linalg.cond(A) tensor([[9.5917], [3.2538]]) - - >>> a = torch.randn(2, 3, 3, dtype=torch.complex64) - >>> a - tensor([[[-0.4671-0.2137j, -0.1334-0.9508j, 0.6252+0.1759j], - [-0.3486-0.2991j, -0.1317+0.1252j, 0.3025-0.1604j], - [-0.5634+0.8582j, 0.1118-0.4677j, -0.1121+0.7574j]], - - [[ 0.3964+0.2533j, 0.9385-0.6417j, -0.0283-0.8673j], - [ 0.2635+0.2323j, -0.8929-1.1269j, 0.3332+0.0733j], - [ 0.1151+0.1644j, -1.1163+0.3471j, -0.5870+0.1629j]]]) - >>> torch.linalg.cond(a) + >>> A = torch.randn(2, 3, 3, dtype=torch.complex64) + >>> torch.linalg.cond(A) tensor([[4.6245], [4.5671]]) - >>> torch.linalg.cond(a, 1) - tensor([9.2589, 9.3486]) """) pinv = _add_docstr(_linalg.linalg_pinv, r""" @@ -1765,7 +1656,7 @@ .. 
note:: Consider using :func:`torch.linalg.lstsq` if possible for multiplying a matrix on the left by - the the pseudoinverse, as:: + the pseudoinverse, as:: torch.linalg.lstsq(A, B).solution == A.pinv() @ B @@ -1812,48 +1703,16 @@ [ 0.1356, 0.3933, -0.5023], [-0.0308, -0.1725, -0.5216]]) - Batched linalg.pinv example >>> A = torch.randn(2, 6, 3) - >>> B = torch.linalg.pinv(A) - >>> torch.matmul(B, A).round() - tensor([[[1., -0., 0.], - [0., 1., -0.], - [0., 0., 1.]], - - [[1., -0., 0.], - [-0., 1., 0.], - [-0., -0., 1.]]]) + >>> Apinv = torch.linalg.pinv(A) + >>> torch.dist(Apinv @ A, torch.eye(3)) + tensor(8.5633e-07) - Hermitian input example >>> A = torch.randn(3, 3, dtype=torch.complex64) - >>> A = A + A.t().conj() # creates a Hermitian matrix - >>> B = torch.linalg.pinv(A, hermitian=True) - >>> torch.matmul(B, A) - tensor([[ 1.0000e+00+0.0000e+00j, -1.1921e-07-2.3842e-07j, - 5.9605e-08-2.3842e-07j], - [ 5.9605e-08+2.3842e-07j, 1.0000e+00+2.3842e-07j, - -4.7684e-07+1.1921e-07j], - [-1.1921e-07+0.0000e+00j, -2.3842e-07-2.9802e-07j, - 1.0000e+00-1.7897e-07j]]) - - Non-default rcond example - >>> rcond = 0.5 - >>> A = torch.randn(3, 3) - >>> torch.linalg.pinv(A) - tensor([[ 0.2971, -0.4280, -2.0111], - [-0.0090, 0.6426, -0.1116], - [-0.7832, -0.2465, 1.0994]]) - >>> torch.linalg.pinv(A, rcond) - tensor([[-0.2672, -0.2351, -0.0539], - [-0.0211, 0.6467, -0.0698], - [-0.4400, -0.3638, -0.0910]]) - - Matrix-wise rcond example - >>> A = torch.randn(5, 6, 2, 3, 3) - >>> rcond = torch.rand(2) # different rcond values for each matrix in a[:, :, 0] and a[:, :, 1] - >>> torch.linalg.pinv(A, rcond) - >>> rcond = torch.randn(5, 6, 2) # different rcond value for each matrix in 'a' - >>> torch.linalg.pinv(A, rcond) + >>> A = A + A.T.conj() # creates a Hermitian matrix + >>> Apinv = torch.linalg.pinv(A, hermitian=True) + >>> torch.dist(Apinv @ A, torch.eye(3)) + tensor(1.0830e-06) .. _defined algebraically: https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Existence_and_uniqueness @@ -1889,9 +1748,8 @@ This function then returns the solution of the resulting batch of systems of linear equations. .. note:: - While `X = \ `:attr:`A`\ `.inv() @ \ `:attr:`B`, this function computes the - solution in a faster and more numerically stable way than performing the - computations separately. + This function computes `X = \ `:attr:`A`\ `.inverse() @ \ `:attr:`B` in a faster and + more numerically stable way than performing the computations separately. """ + fr""" .. 
note:: {common_notes["sync_note"]} @@ -1911,7 +1769,7 @@ Examples:: - >>> A = torch.rand(3, 3) + >>> A = torch.randn(3, 3) >>> b = torch.randn(3) >>> x = torch.linalg.solve(A, b) >>> torch.allclose(A @ x, b) @@ -1924,8 +1782,6 @@ >>> torch.allclose(A @ X, B) True -Broadcasting:: - >>> A = torch.randn(2, 3, 3) >>> b = torch.randn(3, 1) >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3, 1) >>> x.shape torch.Size([2, 3, 1]) >>> torch.allclose(A @ x, b) True - >>> b = torch.rand(3) + >>> b = torch.randn(3) >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3) >>> x.shape torch.Size([2, 3]) @@ -1995,18 +1851,18 @@ Examples:: - >>> a = torch.eye(4 * 6).reshape((4, 6, 8, 3)) - >>> ainv = torch.linalg.tensorinv(a, ind=2) - >>> ainv.shape + >>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3)) + >>> Ainv = torch.linalg.tensorinv(A, ind=2) + >>> Ainv.shape torch.Size([8, 3, 4, 6]) - >>> b = torch.randn(4, 6) - >>> torch.allclose(torch.tensordot(ainv, b), torch.linalg.tensorsolve(a, b)) + >>> B = torch.randn(4, 6) + >>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B)) True - >>> a = torch.randn(4, 4) - >>> a_tensorinv = torch.linalg.tensorinv(a, ind=1) - >>> a_inv = torch.inverse(a) - >>> torch.allclose(a_tensorinv, a_inv) + >>> A = torch.randn(4, 4) + >>> Atensorinv = torch.linalg.tensorinv(A, ind=1) + >>> Ainv = torch.linalg.inv(A) + >>> torch.allclose(Atensorinv, Ainv) True """) @@ -2053,23 +1909,23 @@ Examples:: - >>> a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) - >>> b = torch.randn(2 * 3, 4) - >>> x = torch.linalg.tensorsolve(a, b) - >>> x.shape + >>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) + >>> B = torch.randn(2 * 3, 4) + >>> X = torch.linalg.tensorsolve(A, B) + >>> X.shape torch.Size([2, 3, 4]) - >>> torch.allclose(torch.tensordot(a, x, dims=x.ndim), b) + >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B) True - >>> a = torch.randn(6, 4, 4, 3, 2) - >>> b = torch.randn(4, 3, 2) - >>> x = torch.linalg.tensorsolve(a, b, dims=(0, 2)) - >>> x.shape + >>> A = torch.randn(6, 4, 4, 3, 2) + >>> B = torch.randn(4, 3, 2) + >>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2)) + >>> X.shape torch.Size([6, 4]) - >>> a = a.permute(1, 3, 4, 0, 2) - >>> a.shape[b.ndim:] + >>> A = A.permute(1, 3, 4, 0, 2) + >>> A.shape[B.ndim:] torch.Size([6, 4]) - >>> torch.allclose(torch.tensordot(a, x, dims=x.ndim), b, atol=1e-6) + >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6) True """) @@ -2144,33 +2000,33 @@ Examples:: - >>> a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) - >>> q, r = torch.linalg.qr(a) - >>> q + >>> A = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> Q, R = torch.linalg.qr(A) + >>> Q tensor([[-0.8571, 0.3943, 0.3314], [-0.4286, -0.9029, -0.0343], [ 0.2857, -0.1714, 0.9429]]) - >>> r + >>> R tensor([[ -14.0000, -21.0000, 14.0000], [ 0.0000, -175.0000, 70.0000], [ 0.0000, 0.0000, -35.0000]]) - >>> torch.mm(q, r).round() + >>> (Q @ R).round() tensor([[ 12., -51., 4.], [ 6., 167., -68.], [ -4., 24., -41.]]) - >>> torch.mm(q.t(), q).round() + >>> (Q.T @ Q).round() tensor([[ 1., 0., 0.], [ 0., 1., -0.], [ 0., -0., 1.]]) - >>> q2, r2 = torch.linalg.qr(a, mode='r') - >>> q2 + >>> Q2, R2 = torch.linalg.qr(A, mode='r') + >>> Q2 tensor([]) - >>> torch.equal(r, r2) - True - >>> a = torch.randn(3, 4, 5) - >>> q, r = torch.linalg.qr(a, mode='complete') - >>> torch.allclose(torch.matmul(q, r), a, atol=1e-5) - True - >>> torch.allclose(torch.matmul(q.transpose(-2, -1), q),
torch.eye(4), atol=1e-5) + >>> torch.equal(R, R2) True + >>> A = torch.randn(3, 4, 5) + >>> Q, R = torch.linalg.qr(A, mode='complete') + >>> torch.dist(Q @ R, A) + tensor(1.6099e-06) + >>> torch.dist(Q.transpose(-2, -1) @ Q, torch.eye(4)) + tensor(6.2158e-07) """) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5540c4f09a29..e13fc9473c0b 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -6599,7 +6599,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): autodiff_nonfusible_nodes=['aten::mul'],), OpInfo('__rand__', op=torch.Tensor.__rand__, - dtypes=integral_types_and(), + dtypes=integral_types_and(torch.bool), sample_inputs_func=sample_inputs_rbinops, supports_out=False, skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit',),), @@ -6607,7 +6607,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): supports_forward_ad=True,), OpInfo('__ror__', op=torch.Tensor.__ror__, - dtypes=integral_types_and(), + dtypes=integral_types_and(torch.bool), sample_inputs_func=sample_inputs_rbinops, supports_out=False, skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit',),), @@ -6615,7 +6615,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): supports_forward_ad=True,), OpInfo('__rxor__', op=torch.Tensor.__rxor__, - dtypes=integral_types_and(), + dtypes=integral_types_and(torch.bool), sample_inputs_func=sample_inputs_rbinops, supports_out=False, skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit',),),
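The OpInfo updates above add torch.bool to the dtypes exercised for the reflected bitwise operators. A brief usage sketch of what those reflected calls look like from Python (illustrative only, not part of the test change):

    import torch

    a = torch.tensor([True, False, True])
    print(True & a)    # Tensor.__rand__ -> tensor([ True, False,  True])
    print(False | a)   # Tensor.__ror__  -> tensor([ True, False,  True])
    print(True ^ a)    # Tensor.__rxor__ -> tensor([False,  True, False])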