From 0c6f207050c0c00a6a9b694d435e558bf832c5e1 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Sun, 7 Nov 2021 18:22:35 -0800 Subject: [PATCH 01/30] linalg.cuda_prefer_cusolver flag py/c++ bindings --- aten/src/ATen/Context.cpp | 9 +++++++ aten/src/ATen/Context.h | 4 ++++ test/test_linalg.py | 5 ++++ torch/_C/__init__.pyi.in | 2 ++ torch/__init__.py | 2 ++ torch/backends/linalg/__init__.py | 40 +++++++++++++++++++++++++++++++ torch/csrc/Module.cpp | 18 ++++++++++++++ 7 files changed, 80 insertions(+) create mode 100644 torch/backends/linalg/__init__.py diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index f5875085503e6..46597597d9a98 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -147,6 +147,15 @@ void Context::setAllowTF32CuBLAS(bool b) { allow_tf32_cublas = b; } +bool Context::linalgCudaPreferCusolver() const { + return linalg_cuda_prefer_cusolver; + +} + +void Context::setLinalgCudaPreferCusolver(bool b) { + linalg_cuda_prefer_cusolver = b; +} + bool Context::hasMKL() { #if AT_MKL_ENABLED() return true; diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 28d7ea35094a3..86c16e5adf976 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -122,6 +122,9 @@ class TORCH_API Context { bool deterministicCuDNN() const; void setDeterministicCuDNN(bool); + bool linalgCudaPreferCusolver() const; + void setLinalgCudaPreferCusolver(bool); + // Note [Enabling Deterministic Operations] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Operations in PyTorch that normally act nondeterministically, but have an alternate @@ -240,6 +243,7 @@ class TORCH_API Context { bool allow_tf32_cudnn = true; bool allow_tf32_cublas = true; bool enabled_mkldnn = true; + bool linalg_cuda_prefer_cusolver = false; // if this is set to false, use existing heuristics #ifdef C10_MOBILE bool release_original_weights = true; #else diff --git a/test/test_linalg.py b/test/test_linalg.py index 45efb5a422725..fbf95ab9d0266 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8082,6 +8082,11 @@ def test_tensordot(self, device): an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0)) self.assertEqual(a, an) + def test_linalg_cuda_prefer_cusolver_get_set(self): + with torch.backends.linalg.flags(cuda_prefer_cusolver=False): + self.assertFalse(torch.backends.linalg.cuda_prefer_cusolver) + with torch.backends.linalg.flags(cuda_prefer_cusolver=True): + self.assertTrue(torch.backends.linalg.cuda_prefer_cusolver) instantiate_device_type_tests(TestLinalg, globals()) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index ebbb912c16e99..f179b09eaa0b8 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -594,6 +594,8 @@ def _get_warnAlways() -> _bool: ... # THPModule_warnAlways def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN +def _get_linalg_cuda_prefer_cusolver() -> _bool: ... # THPModule_linalgCudaPreferCusolver +def _set_linalg_cuda_prefer_cusolver(arg: _bool) -> None: ... # THPModule_setLinalgCudaPreferCusolver def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS def _set_cublas_allow_tf32(arg: _bool) -> None: ... 
# THPModule_setAllowTF32CuBLAS # NB: There is no Capsule type in typing, see diff --git a/torch/__init__.py b/torch/__init__.py index b403bb0f749e6..32ebc1f00290a 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -747,9 +747,11 @@ def _assert(condition, message): from torch import distributions as distributions from torch import testing as testing import torch.backends.cuda +import torch.backends.cudnn import torch.backends.mkl import torch.backends.mkldnn import torch.backends.openmp +import torch.backends.linalg import torch.backends.quantized import torch.utils.data from torch import __config__ as __config__ diff --git a/torch/backends/linalg/__init__.py b/torch/backends/linalg/__init__.py new file mode 100644 index 0000000000000..acb132c4e4a49 --- /dev/null +++ b/torch/backends/linalg/__init__.py @@ -0,0 +1,40 @@ +import sys +import torch +from contextlib import contextmanager +from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation + +def set_flags(_cuda_prefer_cusolver=None): + orig_flags = (torch._C._get_linalg_cuda_prefer_cusolver(),) + if _cuda_prefer_cusolver is not None: + torch._C._set_linalg_cuda_prefer_cusolver(_cuda_prefer_cusolver) + return orig_flags + + +@contextmanager +def flags(cuda_prefer_cusolver=True): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags(cuda_prefer_cusolver) + try: + yield + finally: + # recover the previous values + with __allow_nonbracketed_mutation(): + set_flags(*orig_flags) + + +# The magic here is to allow us to intercept code like this: +# +# torch.backends.linalg.cuda_prefer_cusolver = True + +class LinalgModule(PropModule): + def __init__(self, m, name): + super(LinalgModule, self).__init__(m, name) + + cuda_prefer_cusolver = ContextProp(torch._C._get_linalg_cuda_prefer_cusolver, torch._C._set_linalg_cuda_prefer_cusolver) + +# This is the sys.modules replacement trick, see +# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 +sys.modules[__name__] = LinalgModule(sys.modules[__name__], __name__) + +# Add type annotation for the replaced module +cuda_prefer_cusolver: bool diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index b04c8f9ceb6c3..3f764952c1a68 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -522,6 +522,22 @@ PyObject *THPModule_allowTF32CuBLAS(PyObject *_unused, PyObject *noargs) Py_RETURN_FALSE; } +PyObject *THPModule_setLinalgCudaPreferCusolver(PyObject *_unused, PyObject *arg) +{ + THPUtils_assert(PyBool_Check(arg), "set_linalg_cuda_prefer_cusolver expects a bool, " + "but got %s", THPUtils_typename(arg)); + at::globalContext().setLinalgCudaPreferCusolver(arg == Py_True); + Py_RETURN_NONE; +} + +PyObject *THPModule_linalgCudaPreferCusolver(PyObject *_unused, PyObject *noargs) +{ + if (at::globalContext().linalgCudaPreferCusolver()) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + PyObject *THPModule_setFlushDenormal(PyObject *_unused, PyObject *arg) { THPUtils_assert(PyBool_Check(arg), "flush_denormal expects a bool, " "but got %s", THPUtils_typename(arg)); @@ -676,6 +692,8 @@ static PyMethodDef TorchMethods[] = { {"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr}, {"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr}, {"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr}, + {"_get_linalg_cuda_prefer_cusolver", THPModule_linalgCudaPreferCusolver, METH_NOARGS, nullptr}, + {"_set_linalg_cuda_prefer_cusolver", THPModule_setLinalgCudaPreferCusolver, METH_O, nullptr},
{"_vmapmode_increment_nesting", THPModule_vmapmode_increment_nesting, METH_NOARGS, nullptr}, {"_vmapmode_decrement_nesting", THPModule_vmapmode_decrement_nesting, METH_NOARGS, nullptr}, {"_debug_only_display_vmap_fallback_warnings", THPModule_set_display_vmap_fallback_warnings_mode, METH_O, nullptr}, From e135ef16558db4127cae6de4d82e4e413ccad097 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Sun, 7 Nov 2021 19:04:56 -0800 Subject: [PATCH 02/30] global flag override heuristics --- .../ATen/native/cuda/BatchLinearAlgebra.cpp | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp index 4a7cba6321030..32137909b6fdc 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp @@ -1471,7 +1471,10 @@ Tensor _inverse_helper_cuda_legacy(const Tensor& self) { Tensor _inverse_helper_cuda(const Tensor& self) { #ifdef USE_CUSOLVER - if ((self.dim() == 2) || (/* self.dim() > 2 && */ batchCount(self) <= 2) || !use_magma_) { + if ((self.dim() == 2) || + (/* self.dim() > 2 && */ batchCount(self) <= 2) || + !use_magma_ || + at::globalContext().linalgCudaPreferCusolver()) { return _inverse_helper_cuda_lib(self); // cusolver or cublas } else { return _inverse_helper_cuda_legacy(self); // magma-cuda @@ -1503,7 +1506,10 @@ Tensor& _linalg_inv_out_helper_cuda(Tensor &result, Tensor& infos_lu, Tensor& in // This function calculates the inverse matrix in-place // result should be in column major order and contain matrices to invert #ifdef USE_CUSOLVER - if ((result.dim() == 2) || (/* result.dim() > 2 && */ batchCount(result) <= 2) || !use_magma_) { + if ((result.dim() == 2) || + (/* result.dim() > 2 && */ batchCount(result) <= 2) || + !use_magma_ || + at::globalContext().linalgCudaPreferCusolver()) { return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri); // cusolver or cublas } else { return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda @@ -1600,7 +1606,9 @@ Tensor _cholesky_solve_helper_cuda_magma(const Tensor& self, const Tensor& A, bo // Batched cholesky_solve is dispatched to magma. 
Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upper) { #ifdef USE_CUSOLVER - if (batchCount(self) == 1 || !use_magma_) { + if (batchCount(self) == 1 || + !use_magma_ || + at::globalContext().linalgCudaPreferCusolver()) { return _cholesky_solve_helper_cuda_cusolver(self, A, upper); } else { return _cholesky_solve_helper_cuda_magma(self, A, upper); @@ -1706,7 +1714,10 @@ void cholesky_helper_magma(const Tensor& input, bool upper, const Tensor& info) static void cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) { #ifdef USE_CUSOLVER - if (batchCount(input) == 1 || !use_magma_ || use_cusolver_potrf_batched_) { + if (batchCount(input) == 1 || + !use_magma_ || + use_cusolver_potrf_batched_ || + at::globalContext().linalgCudaPreferCusolver()) { cholesky_helper_cusolver(input, upper, info); } else { cholesky_helper_magma(input, upper, info); @@ -1777,7 +1788,9 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper) // result should be in column major order and contain matrices to invert // the content of result is overwritten by 'apply_cholesky_inverse' #ifdef USE_CUSOLVER - if (batchCount(result) == 1 || !use_magma_) { + if (batchCount(result) == 1 || + !use_magma_ || + at::globalContext().linalgCudaPreferCusolver()) { return cholesky_inverse_kernel_impl_cusolver(result, infos, upper); } else { return cholesky_inverse_kernel_impl_magma(result, infos, upper); @@ -1947,7 +1960,11 @@ static void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& in // Use a heuristic to determine that cusolver is faster than MAGMA for the following sizes. auto m = input.size(-2); // exclude complex128 since nan_to_num_ does not work with it. - if ((batch_size == 1 || (batch_size <= 8 && m <= 16) || !use_magma_ ) && !input.is_complex()) { + if ((batch_size == 1 || + (batch_size <= 8 && m <= 16) || + !use_magma_ || + at::globalContext().linalgCudaPreferCusolver()) + && !input.is_complex()) { lu_looped_cusolver(input, pivots, infos, compute_pivots); } #else From 400771a27c009e2445f3ecf31a1197448bf5e7f8 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Sun, 7 Nov 2021 19:43:21 -0800 Subject: [PATCH 03/30] doc and warning --- aten/src/ATen/Context.cpp | 8 +++++++- docs/source/backends.rst | 10 ++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 46597597d9a98..6a54978c045f5 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -149,11 +149,17 @@ void Context::setAllowTF32CuBLAS(bool b) { bool Context::linalgCudaPreferCusolver() const { return linalg_cuda_prefer_cusolver; - } void Context::setLinalgCudaPreferCusolver(bool b) { linalg_cuda_prefer_cusolver = b; + if (b) { + TORCH_WARN_ONCE( + "torch.backends.linalg.cuda_prefer_cusolver is an experimental feature. " + "If you see any error or regression when this flag is enabled, " + "you're encouraged to file an issue on GitHub."
+ ); + } } bool Context::hasMKL() { diff --git a/docs/source/backends.rst b/docs/source/backends.rst index 3136c4ee7820b..93ec140606bd9 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -10,6 +10,7 @@ These backends include: - ``torch.backends.cuda`` - ``torch.backends.cudnn`` +- ``torch.backends.linalg`` - ``torch.backends.mkl`` - ``torch.backends.mkldnn`` - ``torch.backends.openmp`` @@ -69,6 +70,15 @@ torch.backends.cudnn A :class:`bool` that, if True, causes cuDNN to benchmark multiple convolution algorithms and select the fastest. +torch.backends.linalg +^^^^^^^^^^^^^^^^^^^^^ + +.. attribute:: torch.backends.linalg.cuda_prefer_cusolver + + .. warning:: This flag is experimental and subject to change. + + A :class:`bool` that lets pytorch to prefer cuSOLVER implementations when calling linear algebra functions on GPU. + torch.backends.mkl ^^^^^^^^^^^^^^^^^^ From abbf65d9b35b9d7f8e4c4590e4f3f82a2b4e5631 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Sun, 7 Nov 2021 19:57:56 -0800 Subject: [PATCH 04/30] format --- docs/source/backends.rst | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/source/backends.rst b/docs/source/backends.rst index 93ec140606bd9..fd0361138ebd5 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -70,6 +70,7 @@ torch.backends.cudnn A :class:`bool` that, if True, causes cuDNN to benchmark multiple convolution algorithms and select the fastest. + torch.backends.linalg ^^^^^^^^^^^^^^^^^^^^^ @@ -77,7 +78,11 @@ torch.backends.linalg .. warning:: This flag is experimental and subject to change. - A :class:`bool` that lets pytorch to prefer cuSOLVER implementations when calling linear algebra functions on GPU. + The default value for this flag is *False*. + + A :class:`bool` that, if True, lets pytorch prefer cuSOLVER implementations when calling linear algebra functions on GPU. + + Note: cuSOLVER implementations may still be used in some functions even if this flag is set to False. 
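+
+   Example (a minimal usage sketch of the flag documented above; it assumes a CUDA build with cuSOLVER, and the tensor ``x`` is purely illustrative)::
+
+       x = torch.randn(4, 4, device='cuda')
+       # set the global flag
+       torch.backends.linalg.cuda_prefer_cusolver = True
+       torch.linalg.inv(x)
+       # or scope the flag with the context manager
+       with torch.backends.linalg.flags(cuda_prefer_cusolver=True):
+           torch.linalg.inv(x)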
torch.backends.mkl From b9465a568c1fe2e31b7b1b8f008348ad31e621ac Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 10 Nov 2021 20:16:23 -0800 Subject: [PATCH 05/30] LinalgBackend py/cpp bindings --- aten/src/ATen/Context.cpp | 15 ++++ aten/src/ATen/Context.h | 4 + c10/core/LinalgBackend.h | 39 +++++++++ tools/build_variables.bzl | 2 + torch/_C/__init__.pyi.in | 10 +++ torch/csrc/LinalgBackend.cpp | 81 +++++++++++++++++++ torch/csrc/LinalgBackend.h | 26 ++++++ torch/csrc/Module.cpp | 23 ++++++ .../csrc/jit/python/python_sugared_value.cpp | 5 ++ torch/csrc/utils/linalg_backends.cpp | 37 +++++++++ torch/csrc/utils/linalg_backends.h | 7 ++ torch/csrc/utils/python_arg_parser.cpp | 3 + torch/csrc/utils/python_arg_parser.h | 18 ++++- 13 files changed, 269 insertions(+), 1 deletion(-) create mode 100644 c10/core/LinalgBackend.h create mode 100644 torch/csrc/LinalgBackend.cpp create mode 100644 torch/csrc/LinalgBackend.h create mode 100644 torch/csrc/utils/linalg_backends.cpp create mode 100644 torch/csrc/utils/linalg_backends.h diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 894ba8eb5d7ae..e386e16a1a1ad 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -162,6 +162,21 @@ void Context::setLinalgCudaPreferCusolver(bool b) { } } +at::LinalgBackend Context::linalgPreferredBackend() const { + return linalg_preferred_backend; +} + +void Context::setLinalgPreferredBackend(at::LinalgBackend b) { + linalg_preferred_backend = b; + if (b != at::LinalgBackend::Default) { + TORCH_WARN_ONCE( + "torch.backends.linalg.preferred is an experimental feature. " + "If you see any error or regression when this flag is set, " + "you're encourged to file an issue on GitHub." + ); + } +} + bool Context::allowFP16ReductionCuBLAS() const { return allow_fp16_reduction_cublas; } diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index e6cbd61dee7f4..61493b63e318f 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -124,6 +125,8 @@ class TORCH_API Context { bool linalgCudaPreferCusolver() const; void setLinalgCudaPreferCusolver(bool); + at::LinalgBackend linalgPreferredBackend() const; + void setLinalgPreferredBackend(at::LinalgBackend); // Note [Enabling Deterministic Operations] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -247,6 +250,7 @@ class TORCH_API Context { bool allow_fp16_reduction_cublas = true; bool enabled_mkldnn = true; bool linalg_cuda_prefer_cusolver = false; // if this is set to false, use existing heuristics + at::LinalgBackend linalg_preferred_backend = at::LinalgBackend::Default; #ifdef C10_MOBILE bool release_original_weights = true; #else diff --git a/c10/core/LinalgBackend.h b/c10/core/LinalgBackend.h new file mode 100644 index 0000000000000..59e77e76db44e --- /dev/null +++ b/c10/core/LinalgBackend.h @@ -0,0 +1,39 @@ +#pragma once + +#include + +#include +#include + +namespace c10 { + +enum class LinalgBackend : int8_t { + Default, + Cusolver, + Magma +}; + +inline std::string LinalgBackendToString(at::LinalgBackend backend) { + switch (backend) { + case LinalgBackend::Default: + return "linalg_default"; + case LinalgBackend::Cusolver: + return "linalg_cusolver"; + case LinalgBackend::Magma: + return "linalg_magma"; + default: + TORCH_CHECK(false, "Unknown memory format"); + } +} + +inline std::string LinalgBackendToRepr(at::LinalgBackend backend) { + return std::string("torch.") + 
at::LinalgBackendToString(backend); +} + +inline std::ostream& operator<<( + std::ostream& stream, + at::LinalgBackend backend) { + return stream << LinalgBackendToString(backend); +} + +} // namespace c10 diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index e5339bf7e0b2a..3846d44b64f5a 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -737,6 +737,7 @@ libtorch_python_core_sources = [ "torch/csrc/Generator.cpp", "torch/csrc/Layout.cpp", "torch/csrc/MemoryFormat.cpp", + "torch/csrc/LinalgBackend.cpp", "torch/csrc/QScheme.cpp", "torch/csrc/Module.cpp", "torch/csrc/python_dimname.cpp", @@ -806,6 +807,7 @@ libtorch_python_core_sources = [ "torch/csrc/utils.cpp", "torch/csrc/utils/cuda_lazy_init.cpp", "torch/csrc/utils/invalid_arguments.cpp", + "torch/csrc/utils/linalg_backends.cpp", "torch/csrc/utils/object_ptr.cpp", "torch/csrc/utils/python_arg_parser.cpp", "torch/csrc/utils/python_dispatch.cpp", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 85bffeced427d..abb13daec54ae 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -117,6 +117,14 @@ channels_last: memory_format = ... channels_last_3d: memory_format = ... preserve_format: memory_format = ... +# Defined in torch/csrc/LinalgBackend.cpp +class linalg_backend: ... + +# Defined in torch/csrc/utils/linalg_backends.cpp +linalg_default: linalg_backend = ... +linalg_cusolver: linalg_backend = ... +linalg_magma: linalg_backend = ... + # Defined in torch/csrc/QScheme.cpp class qscheme: ... @@ -597,6 +605,8 @@ def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN def _get_linalg_cuda_prefer_cusolver() -> _bool: ... # THPModule_linalgCudaPreferCusolver def _set_linalg_cuda_prefer_cusolver(arg: _bool) -> None: ... # THPModule_setLinalgCudaPreferCusolver +def _get_linalg_preferred_backend() -> _bool: ... # THPModule_linalgPreferredBackend +def _set_linalg_preferred_backend(arg: _bool) -> None: ... # THPModule_setLinalgPreferredBackend def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS def _set_cublas_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuBLAS def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... 
#THPModule_allowFP16ReductionCuBLAS diff --git a/torch/csrc/LinalgBackend.cpp b/torch/csrc/LinalgBackend.cpp new file mode 100644 index 0000000000000..d3c9492cd3784 --- /dev/null +++ b/torch/csrc/LinalgBackend.cpp @@ -0,0 +1,81 @@ +#include + +#include +#include +#include + +#include + +#include +#include +#include + +PyObject *THPLinalgBackend_New(at::LinalgBackend linalg_backend) +{ + const std::string py_repr = at::LinalgBackendToRepr(linalg_backend); + auto type = (PyTypeObject*)&THPLinalgBackendType; + auto self = THPObjectPtr{type->tp_alloc(type, 0)}; + if (!self) throw python_error(); + auto self_ = reinterpret_cast(self.get()); + self_->linalg_backend = linalg_backend; + std::strncpy (self_->name, py_repr.c_str(), LINALG_BACKEND_NAME_LEN); + self_->name[LINALG_BACKEND_NAME_LEN] = '\0'; + return self.release(); +} + +PyObject *THPLinalgBackend_repr(THPLinalgBackend *self) +{ + return THPUtils_packString(self->name); +} + +PyTypeObject THPLinalgBackendType = { + PyVarObject_HEAD_INIT(nullptr, 0) + "torch.linalg_backend", /* tp_name */ + sizeof(THPLinalgBackend), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + (reprfunc)THPLinalgBackend_repr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + nullptr, /* tp_new */ +}; + +void THPLinalgBackend_init(PyObject *module) +{ + if (PyType_Ready(&THPLinalgBackendType) < 0) { + throw python_error(); + } + Py_INCREF(&THPLinalgBackendType); + if (PyModule_AddObject(module, "linalg_backend", (PyObject *)&THPLinalgBackendType) != 0) { + throw python_error(); + } +} diff --git a/torch/csrc/LinalgBackend.h b/torch/csrc/LinalgBackend.h new file mode 100644 index 0000000000000..ecbcb7b01b680 --- /dev/null +++ b/torch/csrc/LinalgBackend.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include + +#include + +const int LINALG_BACKEND_NAME_LEN = 64; + +struct THPLinalgBackend { + PyObject_HEAD + at::LinalgBackend linalg_backend; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + char name[LINALG_BACKEND_NAME_LEN + 1]; +}; + +extern PyTypeObject THPLinalgBackendType; + +inline bool THPLinalgBackend_Check(PyObject *obj) { + return Py_TYPE(obj) == &THPLinalgBackendType; +} + +PyObject * THPLinalgBackend_New(at::LinalgBackend linalg_backend); + +void THPLinalgBackend_init(PyObject *module); diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 2df2d180338df..11ec4b7ac2be8 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -43,6 +44,7 @@ #include #include #include +#include #include #include #include @@ -125,6 +127,7 @@ static PyObject * THPModule_initExtension(PyObject *_unused, 
PyObject *shm_manag return nullptr; } torch::utils::initializeLayouts(); + torch::utils::initializeLinalgBackends(); torch::utils::initializeMemoryFormats(); torch::utils::initializeQSchemes(); torch::utils::initializeDtypes(); @@ -539,6 +542,23 @@ PyObject *THPModule_linalgCudaPreferCusolver(PyObject *_unused, PyObject *noargs Py_RETURN_FALSE; } +PyObject *THPModule_setLinalgPreferredBackend(PyObject *_unused, PyObject *arg) +{ + THPUtils_assert(THPLinalgBackend_Check(arg), "set_linalg_preferred_backend expects a linalg_backend, " + "but got %s", THPUtils_typename(arg)); + at::globalContext().setLinalgPreferredBackend(reinterpret_cast(arg)->linalg_backend); + Py_RETURN_NONE; +} + +PyObject *THPModule_linalgPreferredBackend(PyObject *_unused, PyObject *noargs) +{ + HANDLE_TH_ERRORS + auto res = (PyObject*)THPLinalgBackend_New(at::globalContext().linalgPreferredBackend()); + Py_INCREF(res); + return res; + END_HANDLE_TH_ERRORS +} + PyObject *THPModule_setAllowFP16ReductionCuBLAS(PyObject *_unused, PyObject *arg) { THPUtils_assert(PyBool_Check(arg), "set_allow_fp16_reduction_cublas expects a bool, " @@ -711,6 +731,8 @@ static PyMethodDef TorchMethods[] = { {"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr}, {"_get_linalg_cuda_prefer_cusolver", THPModule_linalgCudaPreferCusolver, METH_NOARGS, nullptr}, {"_set_linalg_cuda_prefer_cusolver", THPModule_setLinalgCudaPreferCusolver, METH_O, nullptr}, + {"_get_linalg_preferred_backend", THPModule_linalgPreferredBackend, METH_NOARGS, nullptr}, + {"_set_linalg_preferred_backend", THPModule_setLinalgPreferredBackend, METH_O, nullptr}, {"_get_cublas_allow_fp16_reduced_precision_reduction", THPModule_allowFP16ReductionCuBLAS, METH_NOARGS, nullptr}, {"_set_cublas_allow_fp16_reduced_precision_reduction", THPModule_setAllowFP16ReductionCuBLAS, METH_O, nullptr}, {"_vmapmode_increment_nesting", THPModule_vmapmode_increment_nesting, METH_NOARGS, nullptr}, @@ -837,6 +859,7 @@ PyObject* initModule() { THPDTypeInfo_init(module); THPLayout_init(module); THPMemoryFormat_init(module); + THPLinalgBackend_init(module); THPQScheme_init(module); THPDevice_init(module); THPStream_init(module); diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 1a4ac0370ac65..3498d49ff0f91 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -1083,6 +1084,10 @@ std::shared_ptr toSugaredValue( auto memory_format = reinterpret_cast(obj.ptr()); const auto v = static_cast(memory_format->memory_format); return toSimple(g.insertConstant(v, loc)); + } else if (THPLinalgBackend_Check(obj.ptr())) { + auto linalg_backend = reinterpret_cast(obj.ptr()); + const auto v = static_cast(linalg_backend->linalg_backend); + return toSimple(g.insertConstant(v, loc)); } else if (THPDtype_Check(obj.ptr())) { auto dtype = reinterpret_cast(obj.ptr()); const auto v = static_cast(dtype->scalar_type); diff --git a/torch/csrc/utils/linalg_backends.cpp b/torch/csrc/utils/linalg_backends.cpp new file mode 100644 index 0000000000000..f27314e4e7f3b --- /dev/null +++ b/torch/csrc/utils/linalg_backends.cpp @@ -0,0 +1,37 @@ +#include + +#include +#include +#include +#include + +#include +#include + + +namespace torch { +namespace utils { + +#define _ADD_LINALG_BACKEND(format) \ + { \ + std::string name = at::LinalgBackendToString(format); \ + PyObject* linalg_backend = THPLinalgBackend_New(format); 
\ + Py_INCREF(linalg_backend); \ + if (PyModule_AddObject(torch_module, name.c_str(), linalg_backend) != 0) { \ + throw python_error(); \ + } \ + } + +void initializeLinalgBackends() { + auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); + if (!torch_module) { + throw python_error(); + } + + _ADD_LINALG_BACKEND(at::LinalgBackend::Default); + _ADD_LINALG_BACKEND(at::LinalgBackend::Cusolver); + _ADD_LINALG_BACKEND(at::LinalgBackend::Magma); +} + +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/linalg_backends.h b/torch/csrc/utils/linalg_backends.h new file mode 100644 index 0000000000000..9ea45bf6a7bf9 --- /dev/null +++ b/torch/csrc/utils/linalg_backends.h @@ -0,0 +1,7 @@ +#pragma once + +namespace torch { namespace utils { + +void initializeLinalgBackends(); + +}} // namespace torch::utils diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index b01681c79d707..9cab42ce65a55 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -35,6 +36,7 @@ static std::unordered_map type_map = { {"ScalarType", ParameterType::SCALARTYPE}, {"Layout", ParameterType::LAYOUT}, {"MemoryFormat", ParameterType::MEMORY_FORMAT}, + {"LinalgBackend", ParameterType::LINALG_BACKEND}, {"QScheme", ParameterType::QSCHEME}, {"Device", ParameterType::DEVICE}, {"Stream", ParameterType::STREAM}, @@ -529,6 +531,7 @@ auto FunctionParameter::check(PyObject* obj, std::vector &overloaded case ParameterType::SCALARTYPE: return THPDtype_Check(obj) || THPPythonScalarType_Check(obj); case ParameterType::LAYOUT: return THPLayout_Check(obj); case ParameterType::MEMORY_FORMAT: return THPMemoryFormat_Check(obj); + case ParameterType::LINALG_BACKEND: return THPLinalgBackend_Check(obj); case ParameterType::QSCHEME: return THPQScheme_Check(obj); case ParameterType::DEVICE: return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || THPDevice_Check(obj); diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 7e0c9b34026ad..8d7e3354c90ff 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -49,6 +49,7 @@ #include #include #include +#include #include #include #include @@ -80,7 +81,7 @@ namespace torch { enum class ParameterType { TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR, BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING, - DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST + DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST, LINALG_BACKEND }; struct FunctionParameter; @@ -194,7 +195,9 @@ struct PythonArgs { inline std::vector dimnamelist(int i); inline c10::optional> toDimnameListOptional(int i); inline at::MemoryFormat memoryformat(int i); + inline at::LinalgBackend linalgbackend(int i); inline c10::optional memoryformatOptional(int i); + inline c10::optional linalgbackendOptional(int i); inline at::QScheme toQScheme(int i); inline std::string string(int i); inline std::string stringWithDefault(int i, const std::string& default_str); @@ -583,6 +586,19 @@ inline c10::optional PythonArgs::memoryformatOptional(int i) { return memoryformat(i); } +inline at::LinalgBackend PythonArgs::linalgbackend(int i) { + if (!args[i]) return at::LinalgBackend::Default; + TORCH_CHECK(THPLinalgBackend_Check(args[i]), "linalg_backend arg must be an instance of the torch.linalg_backend"); + const auto 
linalg_backend = reinterpret_cast(args[i]); + return linalg_backend->linalg_backend; +} + +inline c10::optional PythonArgs::linalgbackendOptional(int i) { + if (!args[i]) + return c10::nullopt; + return linalgbackend(i); +} + inline at::QScheme PythonArgs::toQScheme(int i) { if (!args[i]) return at::kPerTensorAffine; TORCH_CHECK(THPQScheme_Check(args[i]), "qscheme arg must be an instance of the torch.qscheme"); From 185d534f5b9a1fe2a3a4f234e9eeac4a3cdf3cc8 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 10 Nov 2021 22:27:57 -0800 Subject: [PATCH 06/30] remove cuda_prefer_cusolver --- aten/src/ATen/Context.cpp | 15 --------------- aten/src/ATen/Context.h | 3 --- .../ATen/native/cuda/BatchLinearAlgebra.cpp | 18 ++++++------------ docs/source/backends.rst | 17 +++++++++++++---- test/test_linalg.py | 6 ------ torch/_C/__init__.pyi.in | 2 -- torch/backends/linalg/__init__.py | 18 ++++++++++-------- torch/csrc/Module.cpp | 18 ------------------ 8 files changed, 29 insertions(+), 68 deletions(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index e386e16a1a1ad..703d16f180588 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -147,21 +147,6 @@ void Context::setAllowTF32CuBLAS(bool b) { allow_tf32_cublas = b; } -bool Context::linalgCudaPreferCusolver() const { - return linalg_cuda_prefer_cusolver; -} - -void Context::setLinalgCudaPreferCusolver(bool b) { - linalg_cuda_prefer_cusolver = b; - if (b) { - TORCH_WARN_ONCE( - "torch.backends.linalg.cuda_prefer_cusolver is an experimental feature. " - "If you see any error or regression when this flag is enabled, " - "you're encouraged to file an issue on GitHub." - ); - } -} - at::LinalgBackend Context::linalgPreferredBackend() const { return linalg_preferred_backend; } diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 61493b63e318f..8703ba3caa931 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -123,8 +123,6 @@ class TORCH_API Context { bool deterministicCuDNN() const; void setDeterministicCuDNN(bool); - bool linalgCudaPreferCusolver() const; - void setLinalgCudaPreferCusolver(bool); at::LinalgBackend linalgPreferredBackend() const; void setLinalgPreferredBackend(at::LinalgBackend); @@ -249,7 +247,6 @@ class TORCH_API Context { bool allow_tf32_cublas = true; bool allow_fp16_reduction_cublas = true; bool enabled_mkldnn = true; - bool linalg_cuda_prefer_cusolver = false; // if this is set to false, use existing heuristics at::LinalgBackend linalg_preferred_backend = at::LinalgBackend::Default; #ifdef C10_MOBILE bool release_original_weights = true; #else diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp index 32137909b6fdc..f996f3732cdbf 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp @@ -1473,8 +1473,7 @@ Tensor _inverse_helper_cuda(const Tensor& self) { #ifdef USE_CUSOLVER if ((self.dim() == 2) || (/* self.dim() > 2 && */ batchCount(self) <= 2) || - !use_magma_ || - at::globalContext().linalgCudaPreferCusolver()) { + !use_magma_) { return _inverse_helper_cuda_lib(self); // cusolver or cublas } else { return _inverse_helper_cuda_legacy(self); // magma-cuda @@ -1508,8 +1507,7 @@ Tensor& _linalg_inv_out_helper_cuda(Tensor &result, Tensor& infos_lu, Tensor& in #ifdef USE_CUSOLVER if ((result.dim() == 2) || (/* result.dim() > 2 && */ batchCount(result) <= 2) || - !use_magma_ || -
at::globalContext().linalgCudaPreferCusolver()) { + !use_magma_) { return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri); // cusolver or cublas } else { return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda @@ -1607,8 +1605,7 @@ Tensor _cholesky_solve_helper_cuda_magma(const Tensor& self, const Tensor& A, bo Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upper) { #ifdef USE_CUSOLVER if (batchCount(self) == 1 || - !use_magma_ || - at::globalContext().linalgCudaPreferCusolver()) { + !use_magma_) { return _cholesky_solve_helper_cuda_cusolver(self, A, upper); } else { return _cholesky_solve_helper_cuda_magma(self, A, upper); @@ -1716,8 +1713,7 @@ static void cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) #ifdef USE_CUSOLVER if (batchCount(input) == 1 || !use_magma_ || - use_cusolver_potrf_batched_ || - at::globalContext().linalgCudaPreferCusolver()) { + use_cusolver_potrf_batched_) { cholesky_helper_cusolver(input, upper, info); } else { cholesky_helper_magma(input, upper, info); @@ -1789,8 +1785,7 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper) // the content of result is overwritten by 'apply_cholesky_inverse' #ifdef USE_CUSOLVER if (batchCount(result) == 1 || - !use_magma_ || - at::globalContext().linalgCudaPreferCusolver()) { + !use_magma_) { return cholesky_inverse_kernel_impl_cusolver(result, infos, upper); } else { return cholesky_inverse_kernel_impl_magma(result, infos, upper); @@ -1962,8 +1957,7 @@ static void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& in // exclude complex128 since nan_to_num_ does not work with it. if ((batch_size == 1 || (batch_size <= 8 && m <= 16) || - !use_magma_ || - at::globalContext().linalgCudaPreferCusolver()) + !use_magma_) && !input.is_complex()) { lu_looped_cusolver(input, pivots, infos, compute_pivots); } diff --git a/docs/source/backends.rst b/docs/source/backends.rst index 961d3f57dec23..9fa25139cc675 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -78,15 +78,24 @@ torch.backends.cudnn torch.backends.linalg ^^^^^^^^^^^^^^^^^^^^^ -.. attribute:: torch.backends.linalg.cuda_prefer_cusolver +.. attribute:: torch.backends.linalg.preferred .. warning:: This flag is experimental and subject to change. - The default value for this flag is *False*. + A flag that lets PyTorch prefer one of the many backend implementations when calling linear algebra functions on GPU. + Currently the available options are: - A :class:`bool` that, if True, lets pytorch prefer cuSOLVER implementations when calling linear algebra functions on GPU. + * `torch.linalg_default` + * `torch.linalg_cusolver` + * `torch.linalg_magma` - Note: cuSOLVER implementations may still be used in some functions even if this flag is set to False. + Usage: + + * Use as a global flag, e.g. `torch.backends.linalg.preferred = torch.linalg_cusolver` + * Use the context manager, e.g. `with torch.backends.linalg.flags(preferred=torch.linalg_cusolver):` + + Note: The usage of a backend is not guaranteed for all linear algebra operators even if that backend is set as preferred. + Explicitly setting a preferred backend may override existing PyTorch linear algebra heuristics and achieve better performance.
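+
+   Example (a minimal usage sketch of the two forms listed above; it assumes a CUDA build with both cuSOLVER and MAGMA available, and the tensor ``x`` is purely illustrative)::
+
+       x = torch.randn(2, 4, 4, device='cuda')
+       # set the global flag
+       torch.backends.linalg.preferred = torch.linalg_cusolver
+       torch.linalg.inv(x)
+       # or scope the preference with the context manager
+       with torch.backends.linalg.flags(preferred=torch.linalg_magma):
+           torch.linalg.inv(x)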
torch.backends.mkl diff --git a/test/test_linalg.py b/test/test_linalg.py index fbf95ab9d0266..86088a51cc4da 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8082,12 +8082,6 @@ def test_tensordot(self, device): an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0)) self.assertEqual(a, an) - def test_linalg_cuda_prefer_cusolver_get_set(self): - with torch.backends.linalg.flags(cuda_prefer_cusolver=False): - self.assertFalse(torch.backends.linalg.cuda_prefer_cusolver) - with torch.backends.linalg.flags(cuda_prefer_cusolver=True): - self.assertTrue(torch.backends.linalg.cuda_prefer_cusolver) - instantiate_device_type_tests(TestLinalg, globals()) if __name__ == '__main__': diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index abb13daec54ae..7e492291c404b 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -603,8 +603,6 @@ def _get_warnAlways() -> _bool: ... # THPModule_warnAlways def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN -def _get_linalg_cuda_prefer_cusolver() -> _bool: ... # THPModule_linalgCudaPreferCusolver -def _set_linalg_cuda_prefer_cusolver(arg: _bool) -> None: ... # THPModule_setLinalgCudaPreferCusolver def _get_linalg_preferred_backend() -> _bool: ... # THPModule_linalgPreferredBackend def _set_linalg_preferred_backend(arg: _bool) -> None: ... # THPModule_setLinalgPreferredBackend def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS diff --git a/torch/backends/linalg/__init__.py b/torch/backends/linalg/__init__.py index acb132c4e4a49..c21ae2299a527 100644 --- a/torch/backends/linalg/__init__.py +++ b/torch/backends/linalg/__init__.py @@ -3,17 +3,19 @@ from contextlib import contextmanager from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation -def set_flags(_cuda_prefer_cusolver=None): - orig_flags = (torch._C._get_linalg_cuda_prefer_cusolver(),) - if _cuda_prefer_cusolver is not None: - torch._C._set_linalg_cuda_prefer_cusolver(_cuda_prefer_cusolver) +def set_flags(_preferred=None): + orig_flags = (torch._C._get_linalg_preferred_backend(),) + if _preferred is not None: + if not isinstance(_preferred, torch.linalg_backend): + raise RuntimeError("must choose a linalg backend from: torch.linalg_default, torch.linalg_cusolver, torch.linalg_magma.") + torch._C._set_linalg_preferred_backend(_preferred) return orig_flags @contextmanager -def flags(cuda_prefer_cusolver=True): +def flags(preferred=torch.linalg_default): with __allow_nonbracketed_mutation(): - orig_flags = set_flags(cuda_prefer_cusolver) + orig_flags = set_flags(preferred) try: yield finally: @@ -30,11 +32,11 @@ class LinalgModule(PropModule): def __init__(self, m, name): super(LinalgModule, self).__init__(m, name) - cuda_prefer_cusolver = ContextProp(torch._C._get_linalg_cuda_prefer_cusolver, torch._C._set_linalg_cuda_prefer_cusolver) + preferred = ContextProp(torch._C._get_linalg_preferred_backend, torch._C._set_linalg_preferred_backend) # This is the sys.modules replacement trick, see # https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 sys.modules[__name__] = LinalgModule(sys.modules[__name__], __name__) # Add type annotation for the replaced module -cuda_prefer_cusolver: bool +preferred: torch.linalg_backend diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 
11ec4b7ac2be8..24e922c62f974 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -526,22 +526,6 @@ PyObject *THPModule_allowTF32CuBLAS(PyObject *_unused, PyObject *noargs) Py_RETURN_FALSE; } -PyObject *THPModule_setLinalgCudaPreferCusolver(PyObject *_unused, PyObject *arg) -{ - THPUtils_assert(PyBool_Check(arg), "set_linalg_cuda_prefer_cusolver expects a bool, " - "but got %s", THPUtils_typename(arg)); - at::globalContext().setLinalgCudaPreferCusolver(arg == Py_True); - Py_RETURN_NONE; -} - -PyObject *THPModule_linalgCudaPreferCusolver(PyObject *_unused, PyObject *noargs) -{ - if (at::globalContext().linalgCudaPreferCusolver()) { - Py_RETURN_TRUE; - } - Py_RETURN_FALSE; -} - PyObject *THPModule_setLinalgPreferredBackend(PyObject *_unused, PyObject *arg) { THPUtils_assert(THPLinalgBackend_Check(arg), "set_linalg_preferred_backend expects a linalg_backend, " @@ -729,8 +713,6 @@ static PyMethodDef TorchMethods[] = { {"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr}, {"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr}, {"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr}, - {"_get_linalg_cuda_prefer_cusolver", THPModule_linalgCudaPreferCusolver, METH_NOARGS, nullptr}, - {"_set_linalg_cuda_prefer_cusolver", THPModule_setLinalgCudaPreferCusolver, METH_O, nullptr}, {"_get_linalg_preferred_backend", THPModule_linalgPreferredBackend, METH_NOARGS, nullptr}, {"_set_linalg_preferred_backend", THPModule_setLinalgPreferredBackend, METH_O, nullptr}, {"_get_cublas_allow_fp16_reduced_precision_reduction", THPModule_allowFP16ReductionCuBLAS, METH_NOARGS, nullptr}, From fecbf34ac4ee748840f96e68b00990361da073c8 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 10 Nov 2021 23:24:31 -0800 Subject: [PATCH 07/30] heuristic and preferred_backend --- .../ATen/native/cuda/BatchLinearAlgebra.cpp | 193 +++++++++++++----- 1 file changed, 140 insertions(+), 53 deletions(-) diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp index f996f3732cdbf..d785c56fd1fa2 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp @@ -1471,12 +1471,18 @@ Tensor _inverse_helper_cuda_legacy(const Tensor& self) { Tensor _inverse_helper_cuda(const Tensor& self) { #ifdef USE_CUSOLVER - if ((self.dim() == 2) || - (/* self.dim() > 2 && */ batchCount(self) <= 2) || - !use_magma_) { - return _inverse_helper_cuda_lib(self); // cusolver or cublas - } else { - return _inverse_helper_cuda_legacy(self); // magma-cuda + auto preferred_backend = at::globalContext().linalgPreferredBackend(); + switch (preferred_backend) { + case at::LinalgBackend::Cusolver: + return _inverse_helper_cuda_lib(self); // cusolver or cublas + case at::LinalgBackend::Magma: + return _inverse_helper_cuda_legacy(self); // magma-cuda + default: + if (batchCount(self) <= 2 || !use_magma_) { + return _inverse_helper_cuda_lib(self); // cusolver or cublas + } else { + return _inverse_helper_cuda_legacy(self); // magma-cuda + } } #else return _inverse_helper_cuda_legacy(self); // magma-cuda @@ -1505,12 +1511,18 @@ Tensor& _linalg_inv_out_helper_cuda(Tensor &result, Tensor& infos_lu, Tensor& in // This function calculates the inverse matrix in-place // result should be in column major order and contain matrices to invert #ifdef USE_CUSOLVER - if ((result.dim() == 2) || - (/* result.dim() > 2 && */ batchCount(result) <= 2) || - !use_magma_) { - return 
_linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri); // cusolver or cublas - } else { - return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda + auto preferred_backend = at::globalContext().linalgPreferredBackend(); + switch (preferred_backend) { + case at::LinalgBackend::Cusolver: + return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri); // cusolver or cublas + case at::LinalgBackend::Magma: + return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda + default: + if (batchCount(result) <= 2 || !use_magma_) { + return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri); // cusolver or cublas + } else { + return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda + } } #else return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda @@ -1604,11 +1616,18 @@ Tensor _cholesky_solve_helper_cuda_magma(const Tensor& self, const Tensor& A, bo // Batched cholesky_solve is dispatched to magma. Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upper) { #ifdef USE_CUSOLVER - if (batchCount(self) == 1 || - !use_magma_) { - return _cholesky_solve_helper_cuda_cusolver(self, A, upper); - } else { - return _cholesky_solve_helper_cuda_magma(self, A, upper); + auto preferred_backend = at::globalContext().linalgPreferredBackend(); + switch (preferred_backend) { + case at::LinalgBackend::Cusolver: + return _cholesky_solve_helper_cuda_cusolver(self, A, upper); + case at::LinalgBackend::Magma: + return _cholesky_solve_helper_cuda_magma(self, A, upper); + default: + if (batchCount(self) == 1 || !use_magma_) { + return _cholesky_solve_helper_cuda_cusolver(self, A, upper); + } else { + return _cholesky_solve_helper_cuda_magma(self, A, upper); + } } #else return _cholesky_solve_helper_cuda_magma(self, A, upper); @@ -1711,12 +1730,20 @@ void cholesky_helper_magma(const Tensor& input, bool upper, const Tensor& info) static void cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) { #ifdef USE_CUSOLVER - if (batchCount(input) == 1 || - !use_magma_ || - use_cusolver_potrf_batched_) { - cholesky_helper_cusolver(input, upper, info); - } else { - cholesky_helper_magma(input, upper, info); + auto preferred_backend = at::globalContext().linalgPreferredBackend(); + switch (preferred_backend) { + case at::LinalgBackend::Cusolver: + cholesky_helper_cusolver(input, upper, info); + break; + case at::LinalgBackend::Magma: + cholesky_helper_magma(input, upper, info); + break; + default: + if (batchCount(input) == 1 || !use_magma_ || use_cusolver_potrf_batched_) { + cholesky_helper_cusolver(input, upper, info); + } else { + cholesky_helper_magma(input, upper, info); + } } #else cholesky_helper_magma(input, upper, info); @@ -1784,11 +1811,19 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper) // result should be in column major order and contain matrices to invert // the content of result is overwritten by 'apply_cholesky_inverse' #ifdef USE_CUSOLVER - if (batchCount(result) == 1 || - !use_magma_) { - return cholesky_inverse_kernel_impl_cusolver(result, infos, upper); - } else { - return cholesky_inverse_kernel_impl_magma(result, infos, upper); + auto preferred_backend = at::globalContext().linalgPreferredBackend(); + switch (preferred_backend) { + case at::LinalgBackend::Cusolver: + return cholesky_inverse_kernel_impl_cusolver(result, infos, upper); + case at::LinalgBackend::Magma: + return 
cholesky_inverse_kernel_impl_magma(result, infos, upper); + default: + if (batchCount(result) == 1 || + !use_magma_) { + return cholesky_inverse_kernel_impl_cusolver(result, infos, upper); + } else { + return cholesky_inverse_kernel_impl_magma(result, infos, upper); + } } #else return cholesky_inverse_kernel_impl_magma(result, infos, upper); @@ -1952,23 +1987,39 @@ static void lu_batched_magma(const Tensor& input, const Tensor& pivots, const Te static void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { int64_t batch_size = batchCount(input); #ifdef USE_CUSOLVER - // Use a heuristic to determine that cusolver is faster than MAGMA for the following sizes. - auto m = input.size(-2); - // exclude complex128 since nan_to_num_ does not work with it. - if ((batch_size == 1 || - (batch_size <= 8 && m <= 16) || - !use_magma_) - && !input.is_complex()) { - lu_looped_cusolver(input, pivots, infos, compute_pivots); + auto preferred_backend = at::globalContext().linalgPreferredBackend(); + switch (preferred_backend) { + case at::LinalgBackend::Cusolver: + lu_looped_cusolver(input, pivots, infos, compute_pivots); + break; + case at::LinalgBackend::Magma: + if (batch_size == 1) { + lu_looped_magma(input, pivots, infos, compute_pivots); + } else { + lu_batched_magma(input, pivots, infos, compute_pivots); + } + break; + default: + // Use a heuristic to determine that cusolver is faster than MAGMA for the following sizes. + auto m = input.size(-2); + // exclude complex128 since nan_to_num_ does not work with it. + if ((batch_size == 1 || + (batch_size <= 8 && m <= 16) || + !use_magma_) + && !input.is_complex()) { + lu_looped_cusolver(input, pivots, infos, compute_pivots); + } else { + lu_batched_magma(input, pivots, infos, compute_pivots); + } } #else if (batch_size == 1) { lu_looped_magma(input, pivots, infos, compute_pivots); } -#endif // USE_CUSOLVER else { lu_batched_magma(input, pivots, infos, compute_pivots); } +#endif // USE_CUSOLVER } REGISTER_CUDA_DISPATCH(lu_stub, &apply_lu); @@ -2075,12 +2126,12 @@ Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau) { // See discussions in https://github.com/pytorch/pytorch/pull/51348 for comparison of cuSOLVER-MAGMA // and Windows failure. // For reference here is the MAGMA-based implementation: https://gist.github.com/IvanYashchuk/2db50002c9d3c1462ff769e6410ad983 - #if defined(USE_CUSOLVER) - return orgqr_helper_cusolver(result, tau); // cusolver - #else - TORCH_CHECK(false, "Calling torch.orgqr on a CUDA tensor requires compiling ", - "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support."); - #endif +#if defined(USE_CUSOLVER) + return orgqr_helper_cusolver(result, tau); // cusolver +#else + TORCH_CHECK(false, "Calling torch.orgqr on a CUDA tensor requires compiling ", + "PyTorch with cuSOLVER. 
Please use PyTorch built with cuSOLVER support."); +#endif } REGISTER_CUDA_DISPATCH(orgqr_stub, &orgqr_kernel_impl); @@ -2147,7 +2198,14 @@ void geqrf_magma(const Tensor& input, const Tensor& tau) { // This is a backend library dispatching helper function for calling looped batch implementation void geqrf_looped(const Tensor& input, const Tensor& tau) { #if defined(USE_CUSOLVER) - return geqrf_cusolver(input, tau); + auto preferred_backend = at::globalContext().linalgPreferredBackend(); + switch (preferred_backend) { + case at::LinalgBackend::Magma: + return geqrf_magma(input, tau); + case at::LinalgBackend::Cusolver: + default: + return geqrf_cusolver(input, tau); + } #else return geqrf_magma(input, tau); #endif @@ -2284,9 +2342,16 @@ std::tuple linalg_qr_helper_magma(const Tensor& self, c10::strin std::tuple _linalg_qr_helper_cuda(const Tensor& input, c10::string_view mode) { #if defined(USE_CUSOLVER) - // _linalg_qr_helper_default is a generic function that is implemented using - // geqrf_stub and orgqr_stub. It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined - return _linalg_qr_helper_default(input, mode); + auto preferred_backend = at::globalContext().linalgPreferredBackend(); + switch (preferred_backend) { + case at::LinalgBackend::Magma: + return linalg_qr_helper_magma(input, mode); + case at::LinalgBackend::Cusolver: + default: + // _linalg_qr_helper_default is a generic function that is implemented using + // geqrf_stub and orgqr_stub. It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined + return _linalg_qr_helper_default(input, mode); + } #else return linalg_qr_helper_magma(input, mode); #endif @@ -2443,7 +2508,15 @@ void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, co void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { #if defined(USE_CUSOLVER) - linalg_eigh_cusolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); + auto preferred_backend = at::globalContext().linalgPreferredBackend(); + switch (preferred_backend) { + case at::LinalgBackend::Magma: + linalg_eigh_magma(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); + break; + case at::LinalgBackend::Cusolver: + default: + linalg_eigh_cusolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); + } #else linalg_eigh_magma(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); #endif @@ -2742,7 +2815,14 @@ std::tuple _svd_helper_cuda_legacy(const Tensor& self, b std::tuple _svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) { #ifdef USE_CUSOLVER - return _svd_helper_cuda_lib(self, some, compute_uv); + auto preferred_backend = at::globalContext().linalgPreferredBackend(); + switch (preferred_backend) { + case at::LinalgBackend::Magma: + return _svd_helper_cuda_legacy(self, some, compute_uv); + case at::LinalgBackend::Cusolver: + default: + return _svd_helper_cuda_lib(self, some, compute_uv); + } #else return _svd_helper_cuda_legacy(self, some, compute_uv); #endif @@ -3057,10 +3137,17 @@ void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& /*infos*/ void gels_looped(const Tensor& a, Tensor& b, Tensor& infos) { #if defined(USE_CUSOLVER) - // linalg_lstsq_gels is a generic function that is implemented using - // geqrf_stub, ormqr_stub, and triangular_solve_stub - // It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined - return linalg_lstsq_gels(a, b, infos); + auto preferred_backend = 
at::globalContext().linalgPreferredBackend(); + switch (preferred_backend) { + case at::LinalgBackend::Magma: + return gels_magma(a, b, infos); + case at::LinalgBackend::Cusolver: + default: + // linalg_lstsq_gels is a generic function that is implemented using + // geqrf_stub, ormqr_stub, and triangular_solve_stub + // It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined + return linalg_lstsq_gels(a, b, infos); + } #else return gels_magma(a, b, infos); #endif } From c8af8a8ac4dc5fa2c12fea31eb34d5753674a0ee Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 11 Nov 2021 00:04:47 -0800 Subject: [PATCH 08/30] doc --- docs/source/backends.rst | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/docs/source/backends.rst b/docs/source/backends.rst index 9fa25139cc675..8385ae5847fd3 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -91,12 +91,25 @@ torch.backends.linalg Usage: - * Use as a global flag, e.g. `torch.backends.linalg.preferred = torch.linalg_cusolver` - * Use the context manager, e.g. `with torch.backends.linalg.flags(preferred=torch.linalg_cusolver):` + * Use as a global flag, e.g. `torch.backends.linalg.preferred = torch.linalg_cusolver` + * Use the context manager, e.g. `with torch.backends.linalg.flags(preferred=torch.linalg_cusolver):` Note: The usage of a backend is not guaranteed for all linear algebra operators even if that backend is set as preferred. Explicitly setting a preferred backend may override existing PyTorch linear algebra heuristics and achieve better performance. + Currently supported linalg operators: + + * :func:`torch.linalg.inv` + * :func:`torch.linalg.inv_ex` + * :func:`torch.linalg.cholesky` + * :func:`torch.linalg.cholesky_ex` + * :func:`torch.cholesky_solve` + * :func:`torch.cholesky_inverse` + * :func:`torch.lu` + * :func:`torch.linalg.qr` + * :func:`torch.linalg.eigh` + * :func:`torch.linalg.svd` + torch.backends.mkl ^^^^^^^^^^^^^^^^^^ From 4937c1841c101cac98b894e2c97812396fe82809 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 11 Nov 2021 00:13:51 -0800 Subject: [PATCH 09/30] lint --- c10/core/LinalgBackend.h | 8 ++------ torch/csrc/jit/python/python_sugared_value.cpp | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/c10/core/LinalgBackend.h b/c10/core/LinalgBackend.h index 59e77e76db44e..0ea57ecbd5ca1 100644 --- a/c10/core/LinalgBackend.h +++ b/c10/core/LinalgBackend.h @@ -7,11 +7,7 @@ namespace c10 { -enum class LinalgBackend : int8_t { - Default, - Cusolver, - Magma -}; +enum class LinalgBackend : int8_t { Default, Cusolver, Magma }; inline std::string LinalgBackendToString(at::LinalgBackend backend) { switch (backend) { @@ -27,7 +23,7 @@ inline std::string LinalgBackendToString(at::LinalgBackend backend) { inline std::string LinalgBackendToRepr(at::LinalgBackend backend) { - return std::string("torch.") + at::LinalgBackendToString(backend); + return std::string("torch.") + at::LinalgBackendToString(backend); } inline std::ostream& operator<<( diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 3498d49ff0f91..17959d88c4375 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -3,8 +3,8 @@ #include #include #include -#include #include +#include #include #include #include From 1e48f00252b234f99d714978a82affd300cb9702 Mon Sep 17 00:00:00 2001 From:
Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 11 Nov 2021 00:16:34 -0800 Subject: [PATCH 10/30] mypy --- torch/_C/__init__.pyi.in | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 7e492291c404b..a7bcde1d52d5f 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -603,8 +603,8 @@ def _get_warnAlways() -> _bool: ... # THPModule_warnAlways def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN -def _get_linalg_preferred_backend() -> _bool: ... # THPModule_linalgPreferredBackend -def _set_linalg_preferred_backend(arg: _bool) -> None: ... # THPModule_setLinalgPreferredBackend +def _get_linalg_preferred_backend() -> torch.linalg_backend: ... # THPModule_linalgPreferredBackend +def _set_linalg_preferred_backend(arg: torch.linalg_backend) -> None: ... # THPModule_setLinalgPreferredBackend def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS def _set_cublas_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuBLAS def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... #THPModule_allowFP16ReductionCuBLAS From 03aef5ee84f839fb0edbca3ee777a484b10b0761 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 11 Nov 2021 00:18:07 -0800 Subject: [PATCH 11/30] flake8 --- torch/backends/linalg/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/backends/linalg/__init__.py b/torch/backends/linalg/__init__.py index c21ae2299a527..8dfd9706b5cb8 100644 --- a/torch/backends/linalg/__init__.py +++ b/torch/backends/linalg/__init__.py @@ -7,7 +7,8 @@ def set_flags(_preferred=None): orig_flags = (torch._C._get_linalg_preferred_backend(),) if _preferred is not None: if not isinstance(_preferred, torch.linalg_backend): - raise RuntimeError("must choose a linalg backend from: torch.linalg_default, torch.linalg_cusolver, torch.linalg_magma.") + raise RuntimeError("must choose a linalg backend from: " + "torch.linalg_default, torch.linalg_cusolver, torch.linalg_magma.") torch._C._set_linalg_preferred_backend(_preferred) return orig_flags From eac8a5a78dd156d98e47121bb3b0f944bd96df27 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 11 Nov 2021 00:22:04 -0800 Subject: [PATCH 12/30] warning --- c10/core/LinalgBackend.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/c10/core/LinalgBackend.h b/c10/core/LinalgBackend.h index 0ea57ecbd5ca1..442274ca4a6ef 100644 --- a/c10/core/LinalgBackend.h +++ b/c10/core/LinalgBackend.h @@ -9,6 +9,8 @@ namespace c10 { enum class LinalgBackend : int8_t { Default, Cusolver, Magma }; +// WARNING: These exact strings, e.g. "torch.linalg_default", are also used in python bindings. +// Modifying output strings is **very** likely to cause BC-breaking in python side. 
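+// For example, at this point in the series LinalgBackendToRepr(LinalgBackend::Cusolver)
+// returns "torch.linalg_cusolver", which is exactly the name exposed as a Python
+// object (torch.linalg_cusolver), so renaming "linalg_cusolver" here would silently
+// break that binding.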
inline std::string LinalgBackendToString(at::LinalgBackend backend) { switch (backend) { case LinalgBackend::Default: From 2a5a233831354d24e2183e603b6a8a1b81e2ecc9 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 11 Nov 2021 00:26:51 -0800 Subject: [PATCH 13/30] clang format --- c10/core/LinalgBackend.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/c10/core/LinalgBackend.h b/c10/core/LinalgBackend.h index 442274ca4a6ef..c41658780de27 100644 --- a/c10/core/LinalgBackend.h +++ b/c10/core/LinalgBackend.h @@ -9,8 +9,9 @@ namespace c10 { enum class LinalgBackend : int8_t { Default, Cusolver, Magma }; -// WARNING: These exact strings, e.g. "torch.linalg_default", are also used in python bindings. -// Modifying output strings is **very** likely to cause BC-breaking in python side. +// WARNING: These exact strings, e.g. "torch.linalg_default", are also used in +// python bindings. Modifying output strings is **very** likely to cause +// BC-breaking in python side. inline std::string LinalgBackendToString(at::LinalgBackend backend) { switch (backend) { case LinalgBackend::Default: From badba2df6a30f187d6ae6ed0d8544c07767b85c8 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 11 Nov 2021 00:33:02 -0800 Subject: [PATCH 14/30] typo --- c10/core/LinalgBackend.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c10/core/LinalgBackend.h b/c10/core/LinalgBackend.h index c41658780de27..55abb03cea693 100644 --- a/c10/core/LinalgBackend.h +++ b/c10/core/LinalgBackend.h @@ -21,7 +21,7 @@ inline std::string LinalgBackendToString(at::LinalgBackend backend) { case LinalgBackend::Magma: return "linalg_magma"; default: - TORCH_CHECK(false, "Unknown memory format"); + TORCH_CHECK(false, "Unknown linalg backend"); } } From c9344b87bb7644ea10d1c5ff43337554d4ad46e0 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 11 Nov 2021 01:56:54 -0800 Subject: [PATCH 15/30] add a test --- test/test_linalg.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/test_linalg.py b/test/test_linalg.py index 86088a51cc4da..0f398d7a47558 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8082,6 +8082,26 @@ def test_tensordot(self, device): an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0)) self.assertEqual(a, an) + @onlyCUDA + @skipCUDAIfNoMagma + @skipCUDAIfNoCusolver + def test_preferred_backends(self): + # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions. + x = torch.randint(2, 5, (2, 4, 4), device='cuda', dtype=torch.double) + with torch.backends.linalg.flags(preferred=torch.linalg_cusolver): + out1 = torch.linalg.inv(x) + with torch.backends.linalg.flags(preferred=torch.linalg_magma): + out2 = torch.linalg.inv(x) + with torch.backends.linalg.flags(preferred=torch.linalg_default): + # Although linalg preferred flags doesn't affect CPU currently, + # we set this to make sure the flag can switch back to default normally. + # We may also have different CPU backends in the future. 
+ out_ref = torch.linalg.inv(x.cpu()) + + self.assertEqual(out_ref, out1.cpu()) + self.assertEqual(out1, out2) + + instantiate_device_type_tests(TestLinalg, globals()) if __name__ == '__main__': From 96b88961003158599c7f3de7f8358650178a3aa5 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 11 Nov 2021 02:02:30 -0800 Subject: [PATCH 16/30] skipCUDAIfRocm --- test/test_linalg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_linalg.py b/test/test_linalg.py index 0f398d7a47558..27dfe58b06dec 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8083,6 +8083,7 @@ def test_tensordot(self, device): self.assertEqual(a, an) @onlyCUDA + @skipCUDAIfRocm @skipCUDAIfNoMagma @skipCUDAIfNoCusolver def test_preferred_backends(self): From 88ff857307e766c84405c5b5704cc126a5fa6976 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 11 Nov 2021 14:35:16 -0800 Subject: [PATCH 17/30] override and binding --- test/test_public_bindings.py | 1 + torch/overrides.py | 1 + 2 files changed, 2 insertions(+) diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 769e231597473..92338547fbab3 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -147,6 +147,7 @@ def test_no_new_bindings(self): "is_inference_mode_enabled", "JITException", "layout", + "linalg_backend", "ListType", "LiteScriptModule", "LockingLogger", diff --git a/torch/overrides.py b/torch/overrides.py index 48258c1c2d9b6..4a0a0e51fdcdd 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -111,6 +111,7 @@ def get_ignored_functions() -> Set[Callable]: torch.has_openmp, torch.iinfo, torch.memory_format, + torch.linalg_backend, torch.qscheme, torch.set_grad_enabled, torch.no_grad, From 7334ace5576d7813d4267dfc464ebcc5ee3cdb62 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 11 Nov 2021 14:47:16 -0800 Subject: [PATCH 18/30] doc coverage fix You added the following module(s) to the PyTorch namespace 'torch.backends.linalg' but they have no corresponding entry in a doc .rst file. You should either make sure that the .rst file that contains the module's documentation properly contains either '.. automodule:: mod_name' (if you do not want the paragraph added by the automodule, you can simply use py:module) or make the module private (by appending an '_' at the beginning of its name. --- docs/source/backends.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/backends.rst b/docs/source/backends.rst index 8385ae5847fd3..4ea36144622dd 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -77,6 +77,7 @@ torch.backends.cudnn torch.backends.linalg ^^^^^^^^^^^^^^^^^^^^^ +.. py:module:: torch.backends.linalg .. 
attribute:: torch.backends.linalg.preferred From 1e8e0f825e982c1939828c71e521aed43f714865 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 11 Nov 2021 14:49:18 -0800 Subject: [PATCH 19/30] remove rocm skip --- test/test_linalg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index b675f9928f529..03094db3edafe 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8182,7 +8182,6 @@ def test_tensordot(self, device): self.assertEqual(a, an) @onlyCUDA - @skipCUDAIfRocm @skipCUDAIfNoMagma @skipCUDAIfNoCusolver def test_preferred_backends(self): From 569a6b139d54e3e924171ae03673e1d5ad878b81 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Mon, 15 Nov 2021 15:09:58 -0800 Subject: [PATCH 20/30] doc and warning message updates --- aten/src/ATen/Context.cpp | 4 ++-- docs/source/backends.rst | 15 ++++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 703d16f180588..9b72db7a37097 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -156,8 +156,8 @@ void Context::setLinalgPreferredBackend(at::LinalgBackend b) { if (b != at::LinalgBackend::Default) { TORCH_WARN_ONCE( "torch.backends.linalg.preferred is an experimental feature. " - "If you see any error or regression when this flag is set, " - "you're encourged to file an issue on GitHub." + "If you see any error or regression when this flag is set " + "please file an issue on GitHub." ); } } diff --git a/docs/source/backends.rst b/docs/source/backends.rst index 4ea36144622dd..b236209c6cad8 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -83,20 +83,21 @@ torch.backends.linalg .. warning:: This flag is experimental and subject to change. - A flag that lets pytorch prefer one of the many backend implementations when calling linear algebra functions on GPU. - Currently the available options are: + When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries, + and if both are available it decides which to use with a heuristic. + This flag allows overriding those heuristics. - * `torch.linalg_default` - * `torch.linalg_cusolver` - * `torch.linalg_magma` + * If `torch.linalg_cusolver` is set then cuSOLVER will be used wherever possible. + * If `torch.linalg_magma` is set then MAGMA will be used wherever possible. + * If `torch.linalg_default` (the default) is set then heuristics will be used to pick between cuSOLVER and MAGMA if both are available. Usage: * Use as a global flag, e.g. `torch.backends.linalg.preferred = torch.linalg_cusolver` * Use the context manager, e.g. `with torch.backends.linalg.flags(preferred=torch.linalg_cusolver):` - Note: The usage of a backend is not guaranteed for all linear algebra operators even if that backend is set as preferred. - Explicitly setting a preferred backend may override existing pytorch linear algebra heuristics and achieve better performance. + Note: When a library is preferred other libraries may still be used if the preferred library doesn't implement the operation(s) called. + This flag may achieve better performance if PyTorch's heuristic library selection is incorrect for your application's inputs. 
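+
+   A minimal usage sketch of both forms (illustrative only; assumes a CUDA build with
+   both cuSOLVER and MAGMA available)::
+
+       import torch
+
+       x = torch.randn(4, 4, device='cuda', dtype=torch.float64)
+       torch.backends.linalg.preferred = torch.linalg_cusolver  # global flag
+       with torch.backends.linalg.flags(preferred=torch.linalg_magma):
+           torch.linalg.inv(x)  # MAGMA is preferred inside this block
+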
Currently supported linalg operators: From 9d061971bd046a1947261eaec54abb0c0f6a2e5f Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 17 Nov 2021 15:13:28 -0800 Subject: [PATCH 21/30] move c10/core/LinalgBackend.h to ATen --- aten/src/ATen/Context.h | 2 +- {c10/core => aten/src/ATen}/LinalgBackend.h | 2 +- torch/csrc/LinalgBackend.cpp | 4 ++-- torch/csrc/LinalgBackend.h | 4 ++-- torch/csrc/utils/linalg_backends.cpp | 3 ++- 5 files changed, 8 insertions(+), 7 deletions(-) rename {c10/core => aten/src/ATen}/LinalgBackend.h (98%) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 8703ba3caa931..3dc7464300c6b 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -10,7 +11,6 @@ #include #include #include -#include #include #include diff --git a/c10/core/LinalgBackend.h b/aten/src/ATen/LinalgBackend.h similarity index 98% rename from c10/core/LinalgBackend.h rename to aten/src/ATen/LinalgBackend.h index 55abb03cea693..36b5fc89d475a 100644 --- a/c10/core/LinalgBackend.h +++ b/aten/src/ATen/LinalgBackend.h @@ -5,7 +5,7 @@ #include #include -namespace c10 { +namespace at { enum class LinalgBackend : int8_t { Default, Cusolver, Magma }; diff --git a/torch/csrc/LinalgBackend.cpp b/torch/csrc/LinalgBackend.cpp index d3c9492cd3784..870a1e0169b69 100644 --- a/torch/csrc/LinalgBackend.cpp +++ b/torch/csrc/LinalgBackend.cpp @@ -1,11 +1,11 @@ #include +#include + #include #include #include -#include - #include #include #include diff --git a/torch/csrc/LinalgBackend.h b/torch/csrc/LinalgBackend.h index ecbcb7b01b680..51310bace3185 100644 --- a/torch/csrc/LinalgBackend.h +++ b/torch/csrc/LinalgBackend.h @@ -1,8 +1,8 @@ #pragma once -#include +#include -#include +#include #include diff --git a/torch/csrc/utils/linalg_backends.cpp b/torch/csrc/utils/linalg_backends.cpp index f27314e4e7f3b..b2e21d25e0fa6 100644 --- a/torch/csrc/utils/linalg_backends.cpp +++ b/torch/csrc/utils/linalg_backends.cpp @@ -1,9 +1,10 @@ #include +#include + #include #include #include -#include #include #include From 6b432f72314e4b882c34c2a2e0eaea5de337c77a Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 17 Nov 2021 15:19:34 -0800 Subject: [PATCH 22/30] remove jit changes --- torch/csrc/jit/python/python_sugared_value.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 17959d88c4375..1a4ac0370ac65 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -1084,10 +1083,6 @@ std::shared_ptr toSugaredValue( auto memory_format = reinterpret_cast(obj.ptr()); const auto v = static_cast(memory_format->memory_format); return toSimple(g.insertConstant(v, loc)); - } else if (THPLinalgBackend_Check(obj.ptr())) { - auto linalg_backend = reinterpret_cast(obj.ptr()); - const auto v = static_cast(linalg_backend->linalg_backend); - return toSimple(g.insertConstant(v, loc)); } else if (THPDtype_Check(obj.ptr())) { auto dtype = reinterpret_cast(obj.ptr()); const auto v = static_cast(dtype->scalar_type); From 4eaf20380479a489ea86c62308ce570c40d54152 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 17 Nov 2021 15:19:56 -0800 Subject: [PATCH 23/30] undef linalg 
macro --- torch/csrc/utils/linalg_backends.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/csrc/utils/linalg_backends.cpp b/torch/csrc/utils/linalg_backends.cpp index b2e21d25e0fa6..d4fc064018de9 100644 --- a/torch/csrc/utils/linalg_backends.cpp +++ b/torch/csrc/utils/linalg_backends.cpp @@ -34,5 +34,7 @@ void initializeLinalgBackends() { _ADD_LINALG_BACKEND(at::LinalgBackend::Magma); } +#undef _ADD_LINALG_BACKEND + } // namespace utils } // namespace torch From 4faf9f06204ce6c5d5cf485ee27caff67dbd935f Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 17 Nov 2021 15:27:18 -0800 Subject: [PATCH 24/30] revert python_arg_parser changes --- torch/csrc/utils/python_arg_parser.cpp | 3 --- torch/csrc/utils/python_arg_parser.h | 18 +----------------- 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 9cab42ce65a55..b01681c79d707 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include @@ -36,7 +35,6 @@ static std::unordered_map type_map = { {"ScalarType", ParameterType::SCALARTYPE}, {"Layout", ParameterType::LAYOUT}, {"MemoryFormat", ParameterType::MEMORY_FORMAT}, - {"LinalgBackend", ParameterType::LINALG_BACKEND}, {"QScheme", ParameterType::QSCHEME}, {"Device", ParameterType::DEVICE}, {"Stream", ParameterType::STREAM}, @@ -531,7 +529,6 @@ auto FunctionParameter::check(PyObject* obj, std::vector &overloaded case ParameterType::SCALARTYPE: return THPDtype_Check(obj) || THPPythonScalarType_Check(obj); case ParameterType::LAYOUT: return THPLayout_Check(obj); case ParameterType::MEMORY_FORMAT: return THPMemoryFormat_Check(obj); - case ParameterType::LINALG_BACKEND: return THPLinalgBackend_Check(obj); case ParameterType::QSCHEME: return THPQScheme_Check(obj); case ParameterType::DEVICE: return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || THPDevice_Check(obj); diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 8d7e3354c90ff..7e0c9b34026ad 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -49,7 +49,6 @@ #include #include #include -#include #include #include #include @@ -81,7 +80,7 @@ namespace torch { enum class ParameterType { TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR, BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING, - DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST, LINALG_BACKEND + DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST }; struct FunctionParameter; @@ -195,9 +194,7 @@ struct PythonArgs { inline std::vector dimnamelist(int i); inline c10::optional> toDimnameListOptional(int i); inline at::MemoryFormat memoryformat(int i); - inline at::LinalgBackend linalgbackend(int i); inline c10::optional memoryformatOptional(int i); - inline c10::optional linalgbackendOptional(int i); inline at::QScheme toQScheme(int i); inline std::string string(int i); inline std::string stringWithDefault(int i, const std::string& default_str); @@ -586,19 +583,6 @@ inline c10::optional PythonArgs::memoryformatOptional(int i) { return memoryformat(i); } -inline at::LinalgBackend PythonArgs::linalgbackend(int i) { - if (!args[i]) return at::LinalgBackend::Default; - TORCH_CHECK(THPLinalgBackend_Check(args[i]), "linalg_backend arg must be an instance of the torch.linalg_backend"); 
- const auto linalg_backend = reinterpret_cast(args[i]); - return linalg_backend->linalg_backend; -} - -inline c10::optional PythonArgs::linalgbackendOptional(int i) { - if (!args[i]) - return c10::nullopt; - return linalgbackend(i); -} - inline at::QScheme PythonArgs::toQScheme(int i) { if (!args[i]) return at::kPerTensorAffine; TORCH_CHECK(THPQScheme_Check(args[i]), "qscheme arg must be an instance of the torch.qscheme"); From 4114dd2ee1ff7900dcc6ffe53d42c584c762c8fd Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 17 Nov 2021 16:55:50 -0800 Subject: [PATCH 25/30] change py bindings to torch.backends.cuda.linalg.preferred_library --- aten/src/ATen/Context.cpp | 2 +- docs/source/backends.rst | 67 +++++++++++++------------------ test/test_linalg.py | 20 ++++----- torch/__init__.py | 1 - torch/backends/cuda/__init__.py | 11 +++++ torch/backends/linalg/__init__.py | 43 -------------------- 6 files changed, 51 insertions(+), 93 deletions(-) delete mode 100644 torch/backends/linalg/__init__.py diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 9b72db7a37097..b354384fcb5ed 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -155,7 +155,7 @@ void Context::setLinalgPreferredBackend(at::LinalgBackend b) { linalg_preferred_backend = b; if (b != at::LinalgBackend::Default) { TORCH_WARN_ONCE( - "torch.backends.linalg.preferred is an experimental feature. " + "torch.backends.cuda.linalg.preferred_library is an experimental feature. " "If you see any error or regression when this flag is set " "please file an issue on GitHub." ); diff --git a/docs/source/backends.rst b/docs/source/backends.rst index b236209c6cad8..e37416e291252 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -10,7 +10,6 @@ These backends include: - ``torch.backends.cuda`` - ``torch.backends.cudnn`` -- ``torch.backends.linalg`` - ``torch.backends.mkl`` - ``torch.backends.mkldnn`` - ``torch.backends.openmp`` @@ -46,6 +45,34 @@ torch.backends.cuda Clears the cuFFT plan cache. +.. attribute:: torch.backends.cuda.linalg.preferred_library + + .. warning:: This flag is experimental and subject to change. + + When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries, + and if both are available it decides which to use with a heuristic. + This flag allows overriding those heuristics. + + * If `torch.linalg_cusolver` is set then cuSOLVER will be used wherever possible. + * If `torch.linalg_magma` is set then MAGMA will be used wherever possible. + * If `torch.linalg_default` (the default) is set then heuristics will be used to pick between cuSOLVER and MAGMA if both are available. + + Note: When a library is preferred other libraries may still be used if the preferred library doesn't implement the operation(s) called. + This flag may achieve better performance if PyTorch's heuristic library selection is incorrect for your application's inputs. + + Currently supported linalg operators: + + * :func:`torch.linalg.inv` + * :func:`torch.linalg.inv_ex` + * :func:`torch.linalg.cholesky` + * :func:`torch.linalg.cholesky_ex` + * :func:`torch.cholesky_solve` + * :func:`torch.cholesky_inverse` + * :func:`torch.lu` + * :func:`torch.linalg.qr` + * :func:`torch.linalg.eigh` + * :func:`torch.linalg.svd` + torch.backends.cudnn ^^^^^^^^^^^^^^^^^^^^ @@ -75,44 +102,6 @@ torch.backends.cudnn and select the fastest. -torch.backends.linalg -^^^^^^^^^^^^^^^^^^^^^ -.. py:module:: torch.backends.linalg - -.. 
attribute:: torch.backends.linalg.preferred - - .. warning:: This flag is experimental and subject to change. - - When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries, - and if both are available it decides which to use with a heuristic. - This flag allows overriding those heuristics. - - * If `torch.linalg_cusolver` is set then cuSOLVER will be used wherever possible. - * If `torch.linalg_magma` is set then MAGMA will be used wherever possible. - * If `torch.linalg_default` (the default) is set then heuristics will be used to pick between cuSOLVER and MAGMA if both are available. - - Usage: - - * Use as a global flag, e.g. `torch.backends.linalg.preferred = torch.linalg_cusolver` - * Use the context manager, e.g. `with torch.backends.linalg.flags(preferred=torch.linalg_cusolver):` - - Note: When a library is preferred other libraries may still be used if the preferred library doesn't implement the operation(s) called. - This flag may achieve better performance if PyTorch's heuristic library selection is incorrect for your application's inputs. - - Currently supported linalg operators: - - * :func:`torch.linalg.inv` - * :func:`torch.linalg.inv_ex` - * :func:`torch.linalg.cholesky` - * :func:`torch.linalg.cholesky_ex` - * :func:`torch.cholesky_solve` - * :func:`torch.cholesky_inverse` - * :func:`torch.lu` - * :func:`torch.linalg.qr` - * :func:`torch.linalg.eigh` - * :func:`torch.linalg.svd` - - torch.backends.mkl ^^^^^^^^^^^^^^^^^^ diff --git a/test/test_linalg.py b/test/test_linalg.py index dbb126e71ed88..1de9b97c39967 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8194,15 +8194,17 @@ def test_tensordot(self, device): def test_preferred_backends(self): # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions. x = torch.randint(2, 5, (2, 4, 4), device='cuda', dtype=torch.double) - with torch.backends.linalg.flags(preferred=torch.linalg_cusolver): - out1 = torch.linalg.inv(x) - with torch.backends.linalg.flags(preferred=torch.linalg_magma): - out2 = torch.linalg.inv(x) - with torch.backends.linalg.flags(preferred=torch.linalg_default): - # Although linalg preferred flags doesn't affect CPU currently, - # we set this to make sure the flag can switch back to default normally. - # We may also have different CPU backends in the future. - out_ref = torch.linalg.inv(x.cpu()) + + torch.backends.cuda.linalg.preferred_library = torch.linalg_cusolver + out1 = torch.linalg.inv(x) + + torch.backends.cuda.linalg.preferred_library = torch.linalg_magma + out2 = torch.linalg.inv(x) + + torch.backends.cuda.linalg.preferred_library = torch.linalg_default + # Although linalg preferred flags doesn't affect CPU currently, + # we set this to make sure the flag can switch back to default normally. 
+ out_ref = torch.linalg.inv(x.cpu()) self.assertEqual(out_ref, out1.cpu()) self.assertEqual(out1, out2) diff --git a/torch/__init__.py b/torch/__init__.py index a5c6bb255cc9c..46f6c9d278a91 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -813,7 +813,6 @@ def _assert(condition, message): import torch.backends.mkl import torch.backends.mkldnn import torch.backends.openmp -import torch.backends.linalg import torch.backends.quantized import torch.utils.data from torch import __config__ as __config__ diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index 80861780a1976..c1f97ec7cae27 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -98,6 +98,17 @@ def __setattr__(self, name, value): return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value) raise AssertionError("Unknown attribute " + name) +class _LinalgModule: + def __getattr__(self, name): + if name == "preferred_library": + return torch._C._get_linalg_preferred_backend() + raise AssertionError("Unknown attribute " + name) + + def __setattr__(self, name, value): + if name == "preferred_library": + return torch._C._set_linalg_preferred_backend(value) + raise AssertionError("Unknown attribute " + name) cufft_plan_cache = cuFFTPlanCacheManager() matmul = cuBLASModule() +linalg = _LinalgModule() diff --git a/torch/backends/linalg/__init__.py b/torch/backends/linalg/__init__.py deleted file mode 100644 index 8dfd9706b5cb8..0000000000000 --- a/torch/backends/linalg/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -import sys -import torch -from contextlib import contextmanager -from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation - -def set_flags(_preferred=None): - orig_flags = (torch._C._get_linalg_preferred_backend(),) - if _preferred is not None: - if not isinstance(_preferred, torch.linalg_backend): - raise RuntimeError("must choose a linalg backend from: " - "torch.linalg_default, torch.linalg_cusolver, torch.linalg_magma.") - torch._C._set_linalg_preferred_backend(_preferred) - return orig_flags - - -@contextmanager -def flags(preferred=torch.linalg_default): - with __allow_nonbracketed_mutation(): - orig_flags = set_flags(preferred) - try: - yield - finally: - # recover the previous values - with __allow_nonbracketed_mutation(): - set_flags(*orig_flags) - - -# The magic here is to allow us to intercept code like this: -# -# torch.backends..enabled = True - -class LinalgModule(PropModule): - def __init__(self, m, name): - super(LinalgModule, self).__init__(m, name) - - preferred = ContextProp(torch._C._get_linalg_preferred_backend, torch._C._set_linalg_preferred_backend) - -# This is the sys.modules replacement trick, see -# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 -sys.modules[__name__] = LinalgModule(sys.modules[__name__], __name__) - -# Add type annotation for the replaced module -preferred: torch.linalg_backend From c5571336be26284fec89035e54e78c3b8a449163 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 18 Nov 2021 20:08:40 -0800 Subject: [PATCH 26/30] use pybind11 for py/c++ bindings --- aten/src/ATen/Context.cpp | 2 +- aten/src/ATen/LinalgBackend.h | 13 ++--- docs/source/backends.rst | 28 +--------- test/test_linalg.py | 8 +-- test/test_public_bindings.py | 1 - tools/build_variables.bzl | 2 - torch/_C/__init__.pyi.in | 10 ---- torch/backends/cuda/__init__.py | 63 ++++++++++++++++++---- torch/csrc/LinalgBackend.cpp | 81 
---------------------------- torch/csrc/LinalgBackend.h | 26 --------- torch/csrc/Module.cpp | 36 +++++-------- torch/csrc/utils/linalg_backends.cpp | 40 -------------- torch/csrc/utils/linalg_backends.h | 7 --- torch/overrides.py | 1 - 14 files changed, 74 insertions(+), 244 deletions(-) delete mode 100644 torch/csrc/LinalgBackend.cpp delete mode 100644 torch/csrc/LinalgBackend.h delete mode 100644 torch/csrc/utils/linalg_backends.cpp delete mode 100644 torch/csrc/utils/linalg_backends.h diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index b354384fcb5ed..78b0552548152 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -155,7 +155,7 @@ void Context::setLinalgPreferredBackend(at::LinalgBackend b) { linalg_preferred_backend = b; if (b != at::LinalgBackend::Default) { TORCH_WARN_ONCE( - "torch.backends.cuda.linalg.preferred_library is an experimental feature. " + "torch.backends.cuda.preferred_linalg_library is an experimental feature. " "If you see any error or regression when this flag is set " "please file an issue on GitHub." ); diff --git a/aten/src/ATen/LinalgBackend.h b/aten/src/ATen/LinalgBackend.h index 36b5fc89d475a..df9e51b453b28 100644 --- a/aten/src/ATen/LinalgBackend.h +++ b/aten/src/ATen/LinalgBackend.h @@ -9,26 +9,19 @@ namespace at { enum class LinalgBackend : int8_t { Default, Cusolver, Magma }; -// WARNING: These exact strings, e.g. "torch.linalg_default", are also used in -// python bindings. Modifying output strings is **very** likely to cause -// BC-breaking in python side. inline std::string LinalgBackendToString(at::LinalgBackend backend) { switch (backend) { case LinalgBackend::Default: - return "linalg_default"; + return "at::LinalgBackend::Default"; case LinalgBackend::Cusolver: - return "linalg_cusolver"; + return "at::LinalgBackend::Cusolver"; case LinalgBackend::Magma: - return "linalg_magma"; + return "at::LinalgBackend::Magma"; default: TORCH_CHECK(false, "Unknown linalg backend"); } } -inline std::string LinalgBackendToRepr(at::LinalgBackend backend) { - return std::string("torch.") + at::LinalgBackendToString(backend); -} - inline std::ostream& operator<<( std::ostream& stream, at::LinalgBackend backend) { diff --git a/docs/source/backends.rst b/docs/source/backends.rst index e37416e291252..45d6fdf2add2a 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -45,33 +45,7 @@ torch.backends.cuda Clears the cuFFT plan cache. -.. attribute:: torch.backends.cuda.linalg.preferred_library - - .. warning:: This flag is experimental and subject to change. - - When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries, - and if both are available it decides which to use with a heuristic. - This flag allows overriding those heuristics. - - * If `torch.linalg_cusolver` is set then cuSOLVER will be used wherever possible. - * If `torch.linalg_magma` is set then MAGMA will be used wherever possible. - * If `torch.linalg_default` (the default) is set then heuristics will be used to pick between cuSOLVER and MAGMA if both are available. - - Note: When a library is preferred other libraries may still be used if the preferred library doesn't implement the operation(s) called. - This flag may achieve better performance if PyTorch's heuristic library selection is incorrect for your application's inputs. 
- - Currently supported linalg operators: - - * :func:`torch.linalg.inv` - * :func:`torch.linalg.inv_ex` - * :func:`torch.linalg.cholesky` - * :func:`torch.linalg.cholesky_ex` - * :func:`torch.cholesky_solve` - * :func:`torch.cholesky_inverse` - * :func:`torch.lu` - * :func:`torch.linalg.qr` - * :func:`torch.linalg.eigh` - * :func:`torch.linalg.svd` +.. autofunction:: torch.backends.cuda.preferred_linalg_library torch.backends.cudnn diff --git a/test/test_linalg.py b/test/test_linalg.py index 1de9b97c39967..37365964b5ffc 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8191,17 +8191,17 @@ def test_tensordot(self, device): @onlyCUDA @skipCUDAIfNoMagma @skipCUDAIfNoCusolver - def test_preferred_backends(self): + def test_preferred_linalg_library(self): # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions. x = torch.randint(2, 5, (2, 4, 4), device='cuda', dtype=torch.double) - torch.backends.cuda.linalg.preferred_library = torch.linalg_cusolver + torch.backends.cuda.preferred_linalg_library('cusolver') out1 = torch.linalg.inv(x) - torch.backends.cuda.linalg.preferred_library = torch.linalg_magma + torch.backends.cuda.preferred_linalg_library('magma') out2 = torch.linalg.inv(x) - torch.backends.cuda.linalg.preferred_library = torch.linalg_default + torch.backends.cuda.preferred_linalg_library('default') # Although linalg preferred flags doesn't affect CPU currently, # we set this to make sure the flag can switch back to default normally. out_ref = torch.linalg.inv(x.cpu()) diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 92338547fbab3..769e231597473 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -147,7 +147,6 @@ def test_no_new_bindings(self): "is_inference_mode_enabled", "JITException", "layout", - "linalg_backend", "ListType", "LiteScriptModule", "LockingLogger", diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 0cd8b635dbd75..2f9041e432e98 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -748,7 +748,6 @@ libtorch_python_core_sources = [ "torch/csrc/Generator.cpp", "torch/csrc/Layout.cpp", "torch/csrc/MemoryFormat.cpp", - "torch/csrc/LinalgBackend.cpp", "torch/csrc/QScheme.cpp", "torch/csrc/Module.cpp", "torch/csrc/python_dimname.cpp", @@ -819,7 +818,6 @@ libtorch_python_core_sources = [ "torch/csrc/utils.cpp", "torch/csrc/utils/cuda_lazy_init.cpp", "torch/csrc/utils/invalid_arguments.cpp", - "torch/csrc/utils/linalg_backends.cpp", "torch/csrc/utils/object_ptr.cpp", "torch/csrc/utils/python_arg_parser.cpp", "torch/csrc/utils/python_dispatch.cpp", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 29f05ac77f212..dafeaef2bb244 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -117,14 +117,6 @@ channels_last: memory_format = ... channels_last_3d: memory_format = ... preserve_format: memory_format = ... -# Defined in torch/csrc/LinalgBackend.cpp -class linalg_backend: ... - -# Defined in torch/csrc/utils/linalg_backends.cpp -linalg_default: linalg_backend = ... -linalg_cusolver: linalg_backend = ... -linalg_magma: linalg_backend = ... - # Defined in torch/csrc/QScheme.cpp class qscheme: ... @@ -603,8 +595,6 @@ def _get_warnAlways() -> _bool: ... # THPModule_warnAlways def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN def _set_cudnn_allow_tf32(arg: _bool) -> None: ... 
# THPModule_setAllowTF32CuDNN -def _get_linalg_preferred_backend() -> torch.linalg_backend: ... # THPModule_linalgPreferredBackend -def _set_linalg_preferred_backend(arg: torch.linalg_backend) -> None: ... # THPModule_setLinalgPreferredBackend def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS def _set_cublas_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuBLAS def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... #THPModule_allowFP16ReductionCuBLAS diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index c1f97ec7cae27..f75f160bfc68a 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -1,6 +1,8 @@ import sys import torch +from typing import Union + def is_built(): r"""Returns whether PyTorch is built with CUDA support. Note that this @@ -98,17 +100,56 @@ def __setattr__(self, name, value): return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value) raise AssertionError("Unknown attribute " + name) -class _LinalgModule: - def __getattr__(self, name): - if name == "preferred_library": - return torch._C._get_linalg_preferred_backend() - raise AssertionError("Unknown attribute " + name) - - def __setattr__(self, name, value): - if name == "preferred_library": - return torch._C._set_linalg_preferred_backend(value) - raise AssertionError("Unknown attribute " + name) +_LinalgBackends = { + 'default': torch._C._LinalgBackend.Default, + 'cusolver': torch._C._LinalgBackend.Cusolver, + 'magma': torch._C._LinalgBackend.Magma, +} +_LinalgBackends_str = ', '.join(_LinalgBackends.keys()) + +def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend] = None) -> torch._C._LinalgBackend: + r''' + .. warning:: This flag is experimental and subject to change. + + When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries, + and if both are available it decides which to use with a heuristic. + This flag (a :class:`str`) allows overriding those heuristics. + + * If `"cusolver"` is set then cuSOLVER will be used wherever possible. + * If `"magma"` is set then MAGMA will be used wherever possible. + * If `"default"` (the default) is set then heuristics will be used to pick between cuSOLVER and MAGMA if both are available. + * When no input is given, this function returns the currently preferred library. + + Note: When a library is preferred other libraries may still be used if the preferred library doesn't implement the operation(s) called. + This flag may achieve better performance if PyTorch's heuristic library selection is incorrect for your application's inputs. + + Currently supported linalg operators: + + * :func:`torch.linalg.inv` + * :func:`torch.linalg.inv_ex` + * :func:`torch.linalg.cholesky` + * :func:`torch.linalg.cholesky_ex` + * :func:`torch.cholesky_solve` + * :func:`torch.cholesky_inverse` + * :func:`torch.lu` + * :func:`torch.linalg.qr` + * :func:`torch.linalg.eigh` + * :func:`torch.linalg.svd` + ''' + + if backend is None: + pass + elif isinstance(backend, str): + if backend not in _LinalgBackends: + raise RuntimeError("Unknown input value. 
" + f"Choose from: {_LinalgBackends_str}.") + torch._C._set_linalg_preferred_backend(_LinalgBackends[backend]) + elif isinstance(backend, torch._C._LinalgBackend): + torch._C._set_linalg_preferred_backend(backend) + else: + raise RuntimeError("Unknown input value type.") + + return torch._C._get_linalg_preferred_backend() cufft_plan_cache = cuFFTPlanCacheManager() matmul = cuBLASModule() -linalg = _LinalgModule() diff --git a/torch/csrc/LinalgBackend.cpp b/torch/csrc/LinalgBackend.cpp deleted file mode 100644 index 870a1e0169b69..0000000000000 --- a/torch/csrc/LinalgBackend.cpp +++ /dev/null @@ -1,81 +0,0 @@ -#include - -#include - -#include -#include -#include - -#include -#include -#include - -PyObject *THPLinalgBackend_New(at::LinalgBackend linalg_backend) -{ - const std::string py_repr = at::LinalgBackendToRepr(linalg_backend); - auto type = (PyTypeObject*)&THPLinalgBackendType; - auto self = THPObjectPtr{type->tp_alloc(type, 0)}; - if (!self) throw python_error(); - auto self_ = reinterpret_cast(self.get()); - self_->linalg_backend = linalg_backend; - std::strncpy (self_->name, py_repr.c_str(), LINALG_BACKEND_NAME_LEN); - self_->name[LINALG_BACKEND_NAME_LEN] = '\0'; - return self.release(); -} - -PyObject *THPLinalgBackend_repr(THPLinalgBackend *self) -{ - return THPUtils_packString(self->name); -} - -PyTypeObject THPLinalgBackendType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch.linalg_backend", /* tp_name */ - sizeof(THPLinalgBackend), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - (reprfunc)THPLinalgBackend_repr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - nullptr, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - nullptr, /* tp_new */ -}; - -void THPLinalgBackend_init(PyObject *module) -{ - if (PyType_Ready(&THPLinalgBackendType) < 0) { - throw python_error(); - } - Py_INCREF(&THPLinalgBackendType); - if (PyModule_AddObject(module, "linalg_backend", (PyObject *)&THPLinalgBackendType) != 0) { - throw python_error(); - } -} diff --git a/torch/csrc/LinalgBackend.h b/torch/csrc/LinalgBackend.h deleted file mode 100644 index 51310bace3185..0000000000000 --- a/torch/csrc/LinalgBackend.h +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - -#include - -#include - -#include - -const int LINALG_BACKEND_NAME_LEN = 64; - -struct THPLinalgBackend { - PyObject_HEAD - at::LinalgBackend linalg_backend; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - char name[LINALG_BACKEND_NAME_LEN + 1]; -}; - -extern PyTypeObject THPLinalgBackendType; - -inline bool THPLinalgBackend_Check(PyObject *obj) { - return Py_TYPE(obj) == &THPLinalgBackendType; -} - -PyObject * THPLinalgBackend_New(at::LinalgBackend linalg_backend); - -void THPLinalgBackend_init(PyObject *module); diff --git 
a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 24e922c62f974..be637f2b916b4 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -32,7 +33,6 @@ #include #include #include -#include #include #include #include @@ -44,7 +44,6 @@ #include #include #include -#include #include #include #include @@ -127,7 +126,6 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag return nullptr; } torch::utils::initializeLayouts(); - torch::utils::initializeLinalgBackends(); torch::utils::initializeMemoryFormats(); torch::utils::initializeQSchemes(); torch::utils::initializeDtypes(); @@ -526,23 +524,6 @@ PyObject *THPModule_allowTF32CuBLAS(PyObject *_unused, PyObject *noargs) Py_RETURN_FALSE; } -PyObject *THPModule_setLinalgPreferredBackend(PyObject *_unused, PyObject *arg) -{ - THPUtils_assert(THPLinalgBackend_Check(arg), "set_linalg_preferred_backend expects a linalg_backend, " - "but got %s", THPUtils_typename(arg)); - at::globalContext().setLinalgPreferredBackend(reinterpret_cast(arg)->linalg_backend); - Py_RETURN_NONE; -} - -PyObject *THPModule_linalgPreferredBackend(PyObject *_unused, PyObject *noargs) -{ - HANDLE_TH_ERRORS - auto res = (PyObject*)THPLinalgBackend_New(at::globalContext().linalgPreferredBackend()); - Py_INCREF(res); - return res; - END_HANDLE_TH_ERRORS -} - PyObject *THPModule_setAllowFP16ReductionCuBLAS(PyObject *_unused, PyObject *arg) { THPUtils_assert(PyBool_Check(arg), "set_allow_fp16_reduction_cublas expects a bool, " @@ -713,8 +694,6 @@ static PyMethodDef TorchMethods[] = { {"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr}, {"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr}, {"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr}, - {"_get_linalg_preferred_backend", THPModule_linalgPreferredBackend, METH_NOARGS, nullptr}, - {"_set_linalg_preferred_backend", THPModule_setLinalgPreferredBackend, METH_O, nullptr}, {"_get_cublas_allow_fp16_reduced_precision_reduction", THPModule_allowFP16ReductionCuBLAS, METH_NOARGS, nullptr}, {"_set_cublas_allow_fp16_reduced_precision_reduction", THPModule_setAllowFP16ReductionCuBLAS, METH_O, nullptr}, {"_vmapmode_increment_nesting", THPModule_vmapmode_increment_nesting, METH_NOARGS, nullptr}, @@ -841,7 +820,6 @@ PyObject* initModule() { THPDTypeInfo_init(module); THPLayout_init(module); THPMemoryFormat_init(module); - THPLinalgBackend_init(module); THPQScheme_init(module); THPDevice_init(module); THPStream_init(module); @@ -1020,6 +998,18 @@ Call this whenever a new thread is created in order to propagate values from input, weight, bias_opt, stride_, padding_, dilation_, transposed_, output_padding_, groups_); }); + py::enum_(py_module, "_LinalgBackend") + .value("Default", at::LinalgBackend::Default) + .value("Cusolver", at::LinalgBackend::Cusolver) + .value("Magma", at::LinalgBackend::Magma); + + py_module.def("_set_linalg_preferred_backend", [](at::LinalgBackend b) { + at::globalContext().setLinalgPreferredBackend(b); + }); + py_module.def("_get_linalg_preferred_backend", []() { + return at::globalContext().linalgPreferredBackend(); + }); + #ifdef USE_CUDA PyObject *has_cuda = Py_True; #else diff --git a/torch/csrc/utils/linalg_backends.cpp b/torch/csrc/utils/linalg_backends.cpp deleted file mode 100644 index d4fc064018de9..0000000000000 --- a/torch/csrc/utils/linalg_backends.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include - -#include - -#include -#include -#include - 
-#include -#include - - -namespace torch { -namespace utils { - -#define _ADD_LINALG_BACKEND(format) \ - { \ - std::string name = at::LinalgBackendToString(format); \ - PyObject* linalg_backend = THPLinalgBackend_New(format); \ - Py_INCREF(linalg_backend); \ - if (PyModule_AddObject(torch_module, name.c_str(), linalg_backend) != 0) { \ - throw python_error(); \ - } \ - } - -void initializeLinalgBackends() { - auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); - if (!torch_module) { - throw python_error(); - } - - _ADD_LINALG_BACKEND(at::LinalgBackend::Default); - _ADD_LINALG_BACKEND(at::LinalgBackend::Cusolver); - _ADD_LINALG_BACKEND(at::LinalgBackend::Magma); -} - -#undef _ADD_LINALG_BACKEND - -} // namespace utils -} // namespace torch diff --git a/torch/csrc/utils/linalg_backends.h b/torch/csrc/utils/linalg_backends.h deleted file mode 100644 index 9ea45bf6a7bf9..0000000000000 --- a/torch/csrc/utils/linalg_backends.h +++ /dev/null @@ -1,7 +0,0 @@ -#pragma once - -namespace torch { namespace utils { - -void initializeLinalgBackends(); - -}} // namespace torch::utils diff --git a/torch/overrides.py b/torch/overrides.py index 4e731988b9cbd..f373d65fa6a39 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -111,7 +111,6 @@ def get_ignored_functions() -> Set[Callable]: torch.has_openmp, torch.iinfo, torch.memory_format, - torch.linalg_backend, torch.qscheme, torch.set_grad_enabled, torch.no_grad, From a372948ba186c96ebf7d805d0133f728d6c9f230 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 18 Nov 2021 20:20:37 -0800 Subject: [PATCH 27/30] lint --- torch/backends/cuda/__init__.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index f75f160bfc68a..54fce27db5cf2 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -107,7 +107,7 @@ def __setattr__(self, name, value): } _LinalgBackends_str = ', '.join(_LinalgBackends.keys()) -def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend] = None) -> torch._C._LinalgBackend: +def preferred_linalg_library(backend: Union[None, str] = None): r''' .. warning:: This flag is experimental and subject to change. @@ -117,11 +117,14 @@ def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend] * If `"cusolver"` is set then cuSOLVER will be used wherever possible. * If `"magma"` is set then MAGMA will be used wherever possible. - * If `"default"` (the default) is set then heuristics will be used to pick between cuSOLVER and MAGMA if both are available. + * If `"default"` (the default) is set then heuristics will be used to pick between + cuSOLVER and MAGMA if both are available. * When no input is given, this function returns the currently preferred library. - Note: When a library is preferred other libraries may still be used if the preferred library doesn't implement the operation(s) called. - This flag may achieve better performance if PyTorch's heuristic library selection is incorrect for your application's inputs. + Note: When a library is preferred other libraries may still be used if the preferred library + doesn't implement the operation(s) called. + This flag may achieve better performance if PyTorch's heuristic library selection is incorrect + for your application's inputs. 
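+
+     A short illustrative sketch (assumes a CUDA build)::
+
+         torch.backends.cuda.preferred_linalg_library('cusolver')  # prefer cuSOLVER by name
+         torch.backends.cuda.preferred_linalg_library()  # no argument: returns the current setting
+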
Currently supported linalg operators: From 63106988b387fadf42b0aa3e1415a6eaba438e10 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 18 Nov 2021 21:55:30 -0800 Subject: [PATCH 28/30] mypy --- torch/_C/__init__.pyi.in | 6 ++++++ torch/backends/cuda/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index dafeaef2bb244..c7d8753d8fac1 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -621,6 +621,12 @@ def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_n def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython def _demangle(str) -> str: ... # c10::demangle def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_torch_function +def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ... +def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ... +class _LinalgBackend: Default: _LinalgBackend Cusolver: _LinalgBackend Magma: _LinalgBackend # Defined in `valgrind.h` and `callgrind.h` respecitively. def _valgrind_supported_platform() -> _bool: ... # NVALGRIND diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index 54fce27db5cf2..6b823389c8e41 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -107,7 +107,7 @@ def __setattr__(self, name, value): } _LinalgBackends_str = ', '.join(_LinalgBackends.keys()) -def preferred_linalg_library(backend: Union[None, str] = None): +def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend] = None) -> torch._C._LinalgBackend: r''' .. warning:: This flag is experimental and subject to change. From 7446e49d8a7818dc5d52e116fa0e680680be4b32 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Mon, 22 Nov 2021 15:01:42 -0800 Subject: [PATCH 29/30] wrap test with a try-finally --- test/test_linalg.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 37365964b5ffc..4cebcda2b7762 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -12,7 +12,7 @@ import random from random import randrange from itertools import product -from functools import reduce +from functools import reduce, wraps from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, @@ -39,6 +39,18 @@ if TEST_SCIPY: import scipy +def setLinalgBackendsToDefaultFinally(fn): + @wraps(fn) + def _fn(*args, **kwargs): + try: + fn(*args, **kwargs) + finally: + # Set linalg backend back to default to make sure potential failures in one test + # don't affect other linalg tests + torch.backends.cuda.preferred_linalg_library('default') + return _fn + + class TestLinalg(TestCase): def setUp(self): super(self.__class__, self).setUp() @@ -8191,6 +8203,7 @@ def test_tensordot(self, device): @onlyCUDA @skipCUDAIfNoMagma @skipCUDAIfNoCusolver + @setLinalgBackendsToDefaultFinally def test_preferred_linalg_library(self): # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions. 
x = torch.randint(2, 5, (2, 4, 4), device='cuda', dtype=torch.double) From 3a81c30d0d306fbd71229d72ff2602b30a7bf8e3 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 2 Dec 2021 13:00:37 -0800 Subject: [PATCH 30/30] change warning message --- aten/src/ATen/Context.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 78b0552548152..3d19cb7bdc137 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -156,7 +156,7 @@ void Context::setLinalgPreferredBackend(at::LinalgBackend b) { if (b != at::LinalgBackend::Default) { TORCH_WARN_ONCE( "torch.backends.cuda.preferred_linalg_library is an experimental feature. " - "If you see any error or regression when this flag is set " + "If you see any error or unexpected behavior when this flag is set, " "please file an issue on GitHub." ); }
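
After the full series, the user-facing surface is a single function, torch.backends.cuda.preferred_linalg_library. A minimal end-to-end sketch of the final API (illustrative; assumes a CUDA build with both cuSOLVER and MAGMA, and mirrors the reset pattern of the test decorator added in PATCH 29):

    import torch

    x = torch.randn(4, 4, device='cuda', dtype=torch.float64)
    try:
        torch.backends.cuda.preferred_linalg_library('cusolver')  # or 'magma'
        torch.linalg.inv(x)  # cuSOLVER is preferred for supported operators
    finally:
        # Restore the heuristic-driven default so later code is unaffected.
        torch.backends.cuda.preferred_linalg_library('default')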