Skip to content

[MPS] Add optional minor argument to is_macos13_or_newer #95065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aten/src/ATen/detail/MPSHooksInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct TORCH_API MPSHooksInterface {
return false;
}

virtual bool isOnMacOS13orNewer() const {
virtual bool isOnMacOS13orNewer(unsigned minor = 0) const {
AT_ERROR("MPS backend is not available.");
}

Expand Down
14 changes: 12 additions & 2 deletions aten/src/ATen/mps/MPSHooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,18 @@ bool MPSHooks::hasMPS() const {
return at::mps::is_available();
}

bool MPSHooks::isOnMacOS13orNewer() const {
return at::mps::is_macos_13_or_newer();
bool MPSHooks::isOnMacOS13orNewer(unsigned minor) const {
switch (minor) {
case 0:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS);
case 1:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS);
case 2:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
default:
TORCH_WARN("Can't check whether running on 13.",minor,"+ returning one for 13.2+");
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
}
}

Allocator* MPSHooks::getMPSDeviceAllocator() const {
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSHooks.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct MPSHooks : public at::MPSHooksInterface {
MPSHooks(at::MPSHooksArgs) {}
void initMPS() const override;
bool hasMPS() const override;
bool isOnMacOS13orNewer() const override;
bool isOnMacOS13orNewer(unsigned minor) const override;
Allocator* getMPSDeviceAllocator() const override;
const Generator& getDefaultMPSGenerator() const override;
void deviceSynchronize() const override;
Expand Down
2 changes: 1 addition & 1 deletion torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -1207,7 +1207,7 @@ def _mps_setMemoryFraction(fraction: _float) -> None: ...
def _mps_currentAllocatedMemory() -> _int: ...
def _mps_driverAllocatedMemory() -> _int: ...
def _mps_is_available() -> _bool: ...
def _mps_is_on_macos_13_or_newer() -> _bool: ...
def _mps_is_on_macos_13_or_newer(minor: _int) -> _bool: ...

# Defined in torch/csrc/cuda/Module.cpp
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
Expand Down
4 changes: 2 additions & 2 deletions torch/backends/mps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def is_available() -> bool:


@_lru_cache()
def is_macos13_or_newer() -> bool:
def is_macos13_or_newer(minor: int = 0) -> bool:
r"""Returns a bool indicating whether MPS is running on MacOS 13 or newer."""
return torch._C._mps_is_on_macos_13_or_newer()
return torch._C._mps_is_on_macos_13_or_newer(minor)


# Register prims as implementation of var_mean and group_norm
Expand Down
11 changes: 6 additions & 5 deletions torch/csrc/mps/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
END_HANDLE_TH_ERRORS
}

static PyObject* MPSModule_isMacOS13orNewer(
PyObject* _unused,
PyObject* noargs) {
static PyObject* MPSModule_isMacOS13orNewer(PyObject* _unused, PyObject* args) {
HANDLE_TH_ERRORS
if (at::detail::getMPSHooks().isOnMacOS13orNewer()) {
THPUtils_assert(
THPUtils_checkLong(args), "invalid argument to isOnMacOS13orNewer()");
auto minor = THPUtils_unpackUInt32(args);
if (at::detail::getMPSHooks().isOnMacOS13orNewer(minor)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
Expand Down Expand Up @@ -124,7 +125,7 @@ static struct PyMethodDef _MPSModule_methods[] = {
{"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
{"_mps_is_on_macos_13_or_newer",
MPSModule_isMacOS13orNewer,
METH_NOARGS,
METH_O,
nullptr},
{"_mps_get_default_generator",
MPSModule_getDefaultMPSGenerator,
Expand Down