Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions .github/workflows/metal_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# CI workflow: build torchao from source with the experimental MPS (Metal)
# ops enabled and run their test suites on an Apple-silicon macOS runner.
name: Run TorchAO Experimental MPS Tests
# Run on pushes and pull requests targeting main or ghstack branches (gh/**).
on:
  push:
    branches:
      - main
      - 'gh/**'
  pull_request:
    branches:
      - main
      - 'gh/**'

jobs:
  test-mps-ops:
    name: test-mps-ops
    # Reuse pytorch/test-infra's shared macOS job wrapper, which provides
    # the conda environment exposed via ${CONDA_RUN}.
    uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
    with:
      runner: macos-m1-stable
      python-version: '3.11'
      submodules: 'recursive'
      # For pull requests, check out the PR head commit rather than the
      # merge commit; for pushes, use the pushed commit.
      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
      # Job timeout in minutes.
      timeout: 90
      script: |
        set -eux

        echo "::group::Install Torch"
        ${CONDA_RUN} pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu"
        echo "::endgroup::"

        echo "::group::Install requirements"
        ${CONDA_RUN} pip install -r dev-requirements.txt
        echo "::endgroup::"

        echo "::group::Install experimental MPS ops"
        ${CONDA_RUN} USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . --no-build-isolation
        echo "::endgroup::"

        echo "::group::Run lowbit tests"
        ${CONDA_RUN} python -m pytest torchao/experimental/ops/mps/test/test_lowbit.py
        echo "::endgroup::"

        echo "::group::Run quantizer tests"
        ${CONDA_RUN} python -m pytest torchao/experimental/ops/mps/test/test_quantizer.py
        echo "::endgroup::"
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ transformers
hypothesis # Avoid test derandomization warning
sentencepiece # for gpt-fast tokenizer
expecttest
pyyaml

# For prototype features and benchmarks
bitsandbytes # needed for testing triton quant / dequant ops for 8-bit optimizers
Expand Down
35 changes: 27 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ def get_cutlass_build_flags():
)


def bool_to_on_off(value):
    """Translate a truthy/falsy value into the CMake option string.

    Returns "ON" for truthy values and "OFF" for falsy ones, matching the
    strings CMake expects for -D<OPTION>=... flags.
    """
    if value:
        return "ON"
    return "OFF"


# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
class TorchAOBuildExt(BuildExtension):
def __init__(self, *args, **kwargs) -> None:
Expand All @@ -353,16 +357,19 @@ def build_extensions(self):
def build_cmake(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))

if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
# Use a unique build directory per CMake extension to avoid cache conflicts
# when multiple extensions use different CMakeLists.txt source directories
ext_build_temp = os.path.join(self.build_temp, ext.name.replace(".", "_"))
if not os.path.exists(ext_build_temp):
os.makedirs(ext_build_temp)

# Get the expected extension file name that Python will look for
# We force CMake to use this library name
ext_filename = os.path.basename(self.get_ext_filename(ext.name))
ext_basename = os.path.splitext(ext_filename)[0]

print(
"CMAKE COMMANG",
"CMAKE COMMAND",
[
"cmake",
ext.cmake_lists_dir,
Expand All @@ -384,9 +391,9 @@ def build_cmake(self, ext):
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DTORCHAO_CMAKE_EXT_SO_NAME=" + ext_basename,
],
cwd=self.build_temp,
cwd=ext_build_temp,
)
subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp)
subprocess.check_call(["cmake", "--build", "."], cwd=ext_build_temp)


class CMakeExtension(Extension):
Expand Down Expand Up @@ -772,9 +779,6 @@ def get_extensions():
if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":
build_options = BuildOptions()

def bool_to_on_off(value):
return "ON" if value else "OFF"

from distutils.sysconfig import get_python_lib

torch_dir = get_python_lib() + "/torch/share/cmake/Torch"
Expand All @@ -799,6 +803,21 @@ def bool_to_on_off(value):
)
)

if build_options.build_experimental_mps:
ext_modules.append(
CMakeExtension(
"torchao._C_mps",
cmake_lists_dir="torchao/experimental/ops/mps",
cmake_args=(
[
f"-DCMAKE_BUILD_TYPE={'Debug' if use_debug_mode() else 'Release'}",
f"-DTORCHAO_BUILD_MPS_OPS={bool_to_on_off(build_options.build_experimental_mps)}",
"-DTorch_DIR=" + torch_dir,
]
),
)
)

return ext_modules


Expand Down
Empty file.
6 changes: 4 additions & 2 deletions torchao/experimental/ops/mps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal)
file(GLOB METAL_FILES ${METAL_SHADERS_DIR}/*.metal)
set(METAL_SHADERS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal.yaml)
set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py)
set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
# Use the build directory for generated files to avoid permission issues during pip install
set(GENERATED_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated_include)
set(GENERATED_METAL_SHADER_LIB ${GENERATED_INCLUDE_DIR}/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
add_custom_command(
OUTPUT ${GENERATED_METAL_SHADER_LIB}
COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB}
Expand All @@ -45,7 +47,7 @@ endif()
message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")

include_directories(${TORCHAO_INCLUDE_DIRS})
include_directories(${CMAKE_INSTALL_PREFIX}/include)
include_directories(${GENERATED_INCLUDE_DIR})
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten OBJECT linear_fp_act_xbit_weight_aten.mm)
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib)

Expand Down
3 changes: 3 additions & 0 deletions torchao/experimental/ops/mps/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Importing this package loads the compiled MPS ops shared library as a side
# effect, registering the lowbit ops under the torch.ops.torchao namespace.
from torchao.experimental.ops.mps.utils import _load_torchao_mps_lib

_load_torchao_mps_lib()
24 changes: 1 addition & 23 deletions torchao/experimental/ops/mps/test/test_lowbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,13 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest

import torch
from parameterized import parameterized

# Need to import to load the ops
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer # noqa: F401

try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError:
try:
libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)
torch.ops.load_library(libpath)
except:
raise RuntimeError(f"Failed to load library {libpath}")
else:
try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError as e:
raise e
import torchao.experimental.ops.mps # noqa: F401


class TestLowBitQuantWeightsLinear(unittest.TestCase):
Expand Down
25 changes: 2 additions & 23 deletions torchao/experimental/ops/mps/test/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,15 @@

import copy
import itertools
import os
import unittest

import torch
from parameterized import parameterized

import torchao # noqa: F401
# Need to import to load the ops
import torchao.experimental.ops.mps # noqa: F401
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer, _quantize

try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError:
try:
libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)
torch.ops.load_library(libpath)
except:
raise RuntimeError(f"Failed to load library {libpath}")
else:
try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError as e:
raise e


class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
BITWIDTHS = range(1, 8)
Expand Down
69 changes: 69 additions & 0 deletions torchao/experimental/ops/mps/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import glob
import os

import torch


def _get_torchao_mps_lib_path():
    """Locate ``libtorchao_ops_mps_aten.dylib`` on disk.

    Searches in the following locations, in order:

    1. The torchao package directory (for pip-installed packages)
    2. The setuptools build directory (for development installs from source)
    3. The cmake-out directory relative to this file (for standalone CMake
       builds)

    Returns:
        The path to the library as a string, or ``None`` if it cannot be
        found in any of the known locations.
    """
    import torchao

    libname = "libtorchao_ops_mps_aten.dylib"

    # 1. pip install location: the dylib is placed in the package root.
    torchao_dir = os.path.dirname(torchao.__file__)
    pip_libpath = os.path.join(torchao_dir, libname)
    if os.path.exists(pip_libpath):
        return pip_libpath

    # 2. Editable/development install: setuptools puts build artifacts in
    # build/lib.<platform>-<pyver>/torchao/ at the repo root.
    repo_root = os.path.dirname(torchao_dir)
    build_pattern = os.path.join(repo_root, "build", "lib.*", "torchao", libname)
    # glob.glob returns matches in arbitrary order; sort so the selection is
    # deterministic when multiple build directories exist side by side.
    build_matches = sorted(glob.glob(build_pattern))
    if build_matches:
        return build_matches[0]

    # 3. Standalone CMake build: cmake-out/lib/ next to this file.
    cmake_libpath = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "cmake-out/lib/", libname)
    )
    if os.path.exists(cmake_libpath):
        return cmake_libpath

    return None


def _load_torchao_mps_lib():
    """Load the MPS ops library, if its ops are not already registered.

    Raises:
        RuntimeError: if the library cannot be located, or fails to load
            (with the underlying exception chained as the cause).
        AttributeError: if the ops are still unavailable after the library
            loads successfully.
    """

    def _check_ops_registered():
        # Probe every lowbit op; getattr raises AttributeError for any op
        # that is not registered under the torch.ops.torchao namespace.
        for nbit in range(1, 8):
            getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
            getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")

    try:
        _check_ops_registered()
        return  # Ops already registered; nothing to load.
    except AttributeError:
        pass

    libpath = _get_torchao_mps_lib_path()
    if libpath is None:
        raise RuntimeError(
            "Could not find libtorchao_ops_mps_aten.dylib. "
            "Please build with TORCHAO_BUILD_EXPERIMENTAL_MPS=1"
        )
    try:
        torch.ops.load_library(libpath)
    except Exception as e:
        # Chain the original exception so the root cause and its traceback
        # are preserved alongside the library path.
        raise RuntimeError(f"Failed to load library {libpath}: {e}") from e

    # Fail loudly (AttributeError) if loading did not register the ops.
    _check_ops_registered()
Loading