From 6508416a1474d0a38df8217f5a6f84aecbe87d6b Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 17 Feb 2023 07:16:39 -0800 Subject: [PATCH 1/3] [MPS] Add optional `minor` argument to `is_macos13_or_newer` Will be needed if one wants to make accurate XFAIL validation --- aten/src/ATen/detail/MPSHooksInterface.h | 2 +- aten/src/ATen/mps/MPSHooks.cpp | 14 ++++++++++++-- aten/src/ATen/mps/MPSHooks.h | 2 +- torch/backends/mps/__init__.py | 4 ++-- torch/csrc/mps/Module.cpp | 9 ++++++--- 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h index 27f4f193c63a..827d441645f1 100644 --- a/aten/src/ATen/detail/MPSHooksInterface.h +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -28,7 +28,7 @@ struct TORCH_API MPSHooksInterface { return false; } - virtual bool isOnMacOS13orNewer() const { + virtual bool isOnMacOS13orNewer(unsigned minor = 0) const { AT_ERROR("MPS backend is not available."); } diff --git a/aten/src/ATen/mps/MPSHooks.cpp b/aten/src/ATen/mps/MPSHooks.cpp index e71bfcc73922..89adac6c34b1 100644 --- a/aten/src/ATen/mps/MPSHooks.cpp +++ b/aten/src/ATen/mps/MPSHooks.cpp @@ -17,8 +17,18 @@ bool MPSHooks::hasMPS() const { return at::mps::is_available(); } -bool MPSHooks::isOnMacOS13orNewer() const { - return at::mps::is_macos_13_or_newer(); +bool MPSHooks::isOnMacOS13orNewer(unsigned minor) const { + switch (minor) { + case 0: + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS); + case 1: + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS); + case 2: + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS); + default: + TORCH_WARN("Can't check whether running on 13.",minor,"+ returning one for 13.2+"); + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS); + } } Allocator* MPSHooks::getMPSDeviceAllocator() const { diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h index 260113891d51..9e913b38a2e1 100644 --- a/aten/src/ATen/mps/MPSHooks.h +++ b/aten/src/ATen/mps/MPSHooks.h @@ -13,7 +13,7 @@ struct MPSHooks : public at::MPSHooksInterface { MPSHooks(at::MPSHooksArgs) {} void initMPS() const override; bool hasMPS() const override; - bool isOnMacOS13orNewer() const override; + bool isOnMacOS13orNewer(unsigned minor) const override; Allocator* getMPSDeviceAllocator() const override; const Generator& getDefaultMPSGenerator() const override; void deviceSynchronize() const override; diff --git a/torch/backends/mps/__init__.py b/torch/backends/mps/__init__.py index 32f284f1d500..2c6ef64665bc 100644 --- a/torch/backends/mps/__init__.py +++ b/torch/backends/mps/__init__.py @@ -19,9 +19,9 @@ def is_available() -> bool: @_lru_cache() -def is_macos13_or_newer() -> bool: +def is_macos13_or_newer(minor: int = 0) -> bool: r"""Returns a bool indicating whether MPS is running on MacOS 13 or newer.""" - return torch._C._mps_is_on_macos_13_or_newer() + return torch._C._mps_is_on_macos_13_or_newer(minor) # Register prims as implementation of var_mean and group_norm diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index ffbc3b9eceaa..e1ac1c151388 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -61,9 +61,12 @@ static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) { static PyObject* MPSModule_isMacOS13orNewer( PyObject* _unused, - PyObject* noargs) { + PyObject* args) { HANDLE_TH_ERRORS - if (at::detail::getMPSHooks().isOnMacOS13orNewer()) { + THPUtils_assert( + THPUtils_checkLong(args), "invalid argument to isOnMacOS13orNewer()"); + auto minor = THPUtils_unpackUInt32(args); + if (at::detail::getMPSHooks().isOnMacOS13orNewer(minor)) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; @@ -124,7 +127,7 @@ static struct PyMethodDef _MPSModule_methods[] = { {"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr}, {"_mps_is_on_macos_13_or_newer", MPSModule_isMacOS13orNewer, - METH_NOARGS, + METH_O, nullptr}, {"_mps_get_default_generator", MPSModule_getDefaultMPSGenerator, From a5e9081876b96bc5d297115646b7e860e21933a5 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 17 Feb 2023 08:23:48 -0800 Subject: [PATCH 2/3] Fix lint --- torch/csrc/mps/Module.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index e1ac1c151388..0a1c45c0838d 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -59,9 +59,7 @@ static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) { END_HANDLE_TH_ERRORS } -static PyObject* MPSModule_isMacOS13orNewer( - PyObject* _unused, - PyObject* args) { +static PyObject* MPSModule_isMacOS13orNewer(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS THPUtils_assert( THPUtils_checkLong(args), "invalid argument to isOnMacOS13orNewer()"); From 751e9a5d001c9cec7eb6398ece5caa9586035ed6 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 17 Feb 2023 09:08:19 -0800 Subject: [PATCH 3/3] And here as well --- torch/_C/__init__.pyi.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 1bd547cc3c6b..b4f8510f6fc6 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1207,7 +1207,7 @@ def _mps_setMemoryFraction(fraction: _float) -> None: ... def _mps_currentAllocatedMemory() -> _int: ... def _mps_driverAllocatedMemory() -> _int: ... def _mps_is_available() -> _bool: ... -def _mps_is_on_macos_13_or_newer() -> _bool: ... +def _mps_is_on_macos_13_or_newer(minor: _int) -> _bool: ... # Defined in torch/csrc/cuda/Module.cpp def _cuda_getCurrentStream(device: _int) -> Tuple: ...