diff --git a/.github/workflows/metal_test.yml b/.github/workflows/metal_test.yml new file mode 100644 index 0000000000..c82cff7442 --- /dev/null +++ b/.github/workflows/metal_test.yml @@ -0,0 +1,43 @@ +name: Run TorchAO Experimental MPS Tests +on: + push: + branches: + - main + - 'gh/**' + pull_request: + branches: + - main + - 'gh/**' + +jobs: + test-mps-ops: + name: test-mps-ops + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + runner: macos-m1-stable + python-version: '3.11' + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + set -eux + + echo "::group::Install Torch" + ${CONDA_RUN} pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu" + echo "::endgroup::" + + echo "::group::Install requirements" + ${CONDA_RUN} pip install -r dev-requirements.txt + echo "::endgroup::" + + echo "::group::Install experimental MPS ops" + ${CONDA_RUN} USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . 
--no-build-isolation + echo "::endgroup::" + + echo "::group::Run lowbit tests" + ${CONDA_RUN} python -m pytest torchao/experimental/ops/mps/test/test_lowbit.py + echo "::endgroup::" + + echo "::group::Run quantizer tests" + ${CONDA_RUN} python -m pytest torchao/experimental/ops/mps/test/test_quantizer.py + echo "::endgroup::" diff --git a/dev-requirements.txt b/dev-requirements.txt index ef00257bb7..c55d9bd661 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -7,6 +7,7 @@ transformers hypothesis # Avoid test derandomization warning sentencepiece # for gpt-fast tokenizer expecttest +pyyaml # For prototype features and benchmarks bitsandbytes # needed for testing triton quant / dequant ops for 8-bit optimizers diff --git a/setup.py b/setup.py index 9d2d7bce1c..11b869facb 100644 --- a/setup.py +++ b/setup.py @@ -329,6 +329,10 @@ def get_cutlass_build_flags(): ) +def bool_to_on_off(value): + return "ON" if value else "OFF" + + # BuildExtension is a subclass of from setuptools.command.build_ext.build_ext class TorchAOBuildExt(BuildExtension): def __init__(self, *args, **kwargs) -> None: @@ -353,8 +357,11 @@ def build_extensions(self): def build_cmake(self, ext): extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) - if not os.path.exists(self.build_temp): - os.makedirs(self.build_temp) + # Use a unique build directory per CMake extension to avoid cache conflicts + # when multiple extensions use different CMakeLists.txt source directories + ext_build_temp = os.path.join(self.build_temp, ext.name.replace(".", "_")) + if not os.path.exists(ext_build_temp): + os.makedirs(ext_build_temp) # Get the expected extension file name that Python will look for # We force CMake to use this library name @@ -362,7 +369,7 @@ def build_cmake(self, ext): ext_basename = os.path.splitext(ext_filename)[0] print( - "CMAKE COMMANG", + "CMAKE COMMAND", [ "cmake", ext.cmake_lists_dir, @@ -384,9 +391,9 @@ def build_cmake(self, ext): 
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTORCHAO_CMAKE_EXT_SO_NAME=" + ext_basename, ], - cwd=self.build_temp, + cwd=ext_build_temp, ) - subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp) + subprocess.check_call(["cmake", "--build", "."], cwd=ext_build_temp) class CMakeExtension(Extension): @@ -772,9 +779,6 @@ def get_extensions(): if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1": build_options = BuildOptions() - def bool_to_on_off(value): - return "ON" if value else "OFF" - from distutils.sysconfig import get_python_lib torch_dir = get_python_lib() + "/torch/share/cmake/Torch" @@ -799,6 +803,21 @@ def bool_to_on_off(value): ) ) + if build_options.build_experimental_mps: + ext_modules.append( + CMakeExtension( + "torchao._C_mps", + cmake_lists_dir="torchao/experimental/ops/mps", + cmake_args=( + [ + f"-DCMAKE_BUILD_TYPE={'Debug' if use_debug_mode() else 'Release'}", + f"-DTORCHAO_BUILD_MPS_OPS={bool_to_on_off(build_options.build_experimental_mps)}", + "-DTorch_DIR=" + torch_dir, + ] + ), + ) + ) + return ext_modules diff --git a/torchao/experimental/ops/__init__.py b/torchao/experimental/ops/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/experimental/ops/mps/CMakeLists.txt b/torchao/experimental/ops/mps/CMakeLists.txt index 8dcdec523e..34d4b178e0 100644 --- a/torchao/experimental/ops/mps/CMakeLists.txt +++ b/torchao/experimental/ops/mps/CMakeLists.txt @@ -30,7 +30,9 @@ set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal) file(GLOB METAL_FILES ${METAL_SHADERS_DIR}/*.metal) set(METAL_SHADERS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal.yaml) set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py) -set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h) +# Use the build directory for generated files to avoid permission issues during pip install 
+set(GENERATED_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated_include) +set(GENERATED_METAL_SHADER_LIB ${GENERATED_INCLUDE_DIR}/torchao/experimental/kernels/mps/src/metal_shader_lib.h) add_custom_command( OUTPUT ${GENERATED_METAL_SHADER_LIB} COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB} @@ -45,7 +47,7 @@ endif() message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") include_directories(${TORCHAO_INCLUDE_DIRS}) -include_directories(${CMAKE_INSTALL_PREFIX}/include) +include_directories(${GENERATED_INCLUDE_DIR}) add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten OBJECT linear_fp_act_xbit_weight_aten.mm) add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib) diff --git a/torchao/experimental/ops/mps/__init__.py b/torchao/experimental/ops/mps/__init__.py new file mode 100644 index 0000000000..1f29e43c27 --- /dev/null +++ b/torchao/experimental/ops/mps/__init__.py @@ -0,0 +1,3 @@ +from torchao.experimental.ops.mps.utils import _load_torchao_mps_lib + +_load_torchao_mps_lib() diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py index dc2460110e..ddaae39c15 100644 --- a/torchao/experimental/ops/mps/test/test_lowbit.py +++ b/torchao/experimental/ops/mps/test/test_lowbit.py @@ -4,35 +4,13 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-import os import unittest import torch from parameterized import parameterized # Need to import to load the ops -from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer # noqa: F401 - -try: - for nbit in range(1, 8): - getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight") - getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") -except AttributeError: - try: - libname = "libtorchao_ops_mps_aten.dylib" - libpath = os.path.abspath( - os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname) - ) - torch.ops.load_library(libpath) - except: - raise RuntimeError(f"Failed to load library {libpath}") - else: - try: - for nbit in range(1, 8): - getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight") - getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") - except AttributeError as e: - raise e +import torchao.experimental.ops.mps # noqa: F401 class TestLowBitQuantWeightsLinear(unittest.TestCase): diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py index e7d035fb61..23e7f16727 100644 --- a/torchao/experimental/ops/mps/test/test_quantizer.py +++ b/torchao/experimental/ops/mps/test/test_quantizer.py @@ -6,36 +6,15 @@ import copy import itertools -import os import unittest import torch from parameterized import parameterized -import torchao # noqa: F401 +# Need to import to load the ops +import torchao.experimental.ops.mps # noqa: F401 from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer, _quantize -try: - for nbit in range(1, 8): - getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight") - getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") -except AttributeError: - try: - libname = "libtorchao_ops_mps_aten.dylib" - libpath = os.path.abspath( - os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname) - ) - torch.ops.load_library(libpath) - except: - raise RuntimeError(f"Failed to load library {libpath}") - else: - try: - for 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import glob
import os
from typing import List, Optional

import torch

# File name of the shared library that registers the MPS ops.
_MPS_LIB_NAME = "libtorchao_ops_mps_aten.dylib"


def _get_torchao_mps_lib_path() -> Optional[str]:
    """Get the path to the MPS ops library.

    Searches in the following locations, in order:
    1. The torchao package directory (for pip-installed packages)
    2. The build directory (for development installs from source)
    3. The cmake-out directory relative to this file (for standalone CMake builds)

    Returns:
        The path of the first library file found, or ``None`` when none of
        the locations contains it.
    """
    # Imported lazily, as in the original, so importing this module does not
    # itself require ``torchao`` to be importable yet.
    import torchao

    # Try the torchao package directory first (pip install location).
    torchao_dir = os.path.dirname(torchao.__file__)
    pip_libpath = os.path.join(torchao_dir, _MPS_LIB_NAME)
    if os.path.exists(pip_libpath):
        return pip_libpath

    # Try the build directory (for editable/development installs).
    # The build directory is typically at the repo root level.
    repo_root = os.path.dirname(torchao_dir)
    build_pattern = os.path.join(
        repo_root, "build", "lib.*", "torchao", _MPS_LIB_NAME
    )
    # glob.glob() yields matches in arbitrary order; sort so that the choice
    # is deterministic when several ``lib.*`` build directories exist.
    build_matches = sorted(glob.glob(build_pattern))
    if build_matches:
        return build_matches[0]

    # Fall back to cmake-out directory (standalone CMake build).
    cmake_libpath = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "cmake-out/lib/", _MPS_LIB_NAME)
    )
    if os.path.exists(cmake_libpath):
        return cmake_libpath

    return None


def _missing_mps_ops() -> List[str]:
    """Return the expected ``torch.ops.torchao`` op names that are not registered."""
    expected = []
    for nbit in range(1, 8):
        expected.append(f"_linear_fp_act_{nbit}bit_weight")
        expected.append(f"_pack_weight_{nbit}bit")
    # hasattr() swallows only AttributeError, which is exactly the exception
    # the original probing loop treated as "op not registered".
    return [name for name in expected if not hasattr(torch.ops.torchao, name)]


def _load_torchao_mps_lib() -> None:
    """Load the MPS ops library.

    Does nothing when every ``_linear_fp_act_{n}bit_weight`` and
    ``_pack_weight_{n}bit`` op (n = 1..7) is already available under
    ``torch.ops.torchao``. Otherwise locates the shared library, loads it with
    ``torch.ops.load_library`` and verifies that the ops are now present.

    Raises:
        RuntimeError: If the library cannot be found, or if loading it fails
            (the underlying error is attached as ``__cause__``).
        AttributeError: If the library loads but some expected ops are still
            missing.
    """
    if not _missing_mps_ops():
        return

    libpath = _get_torchao_mps_lib_path()
    if libpath is None:
        raise RuntimeError(
            "Could not find libtorchao_ops_mps_aten.dylib. "
            "Please build with TORCHAO_BUILD_EXPERIMENTAL_MPS=1"
        )

    try:
        torch.ops.load_library(libpath)
    except Exception as e:
        # Chain explicitly so the real loader error is the reported cause.
        raise RuntimeError(f"Failed to load library {libpath}: {e}") from e

    still_missing = _missing_mps_ops()
    if still_missing:
        # Same exception type the original verification loop surfaced, with
        # a message that says which ops are absent.
        raise AttributeError(
            f"Loaded {libpath} but torch.ops.torchao is still missing: "
            + ", ".join(still_missing)
        )