Skip to content

Commit

Permalink
[BUILD] Support building out of tree plugins (#3007)
Browse files Browse the repository at this point in the history
This PR adds support for building out of tree plugins with triton.
External plugins can provide an environment variable before invoking
`python setup.py` to let triton know where to find the backends.

+ `TRITON_PLUGIN_DIRS`: semicolon-separated list of directories
containing the plugins. The last components of the paths will be used as
plugin names.

For instance, if we have `triton_shared` under
`/home/user/github/triton_shared`, to build `triton_shared` with
`triton`, we can do:

```
export TRITON_PLUGIN_DIRS="/home/user/github/triton_shared"
python3 setup.py build
```

This also assumes that the external backends have moved over to the new
external backend architecture; they have to expose a `backend` folder
under their root directory.

With this change, all external backends will still be copied to
`triton/runtime` and be used like:

```python
import triton
from triton.backends.triton_shared.driver import CPUDriver

triton.runtime.driver.set_active(CPUDriver())
```
  • Loading branch information
nhat-nguyen committed Feb 12, 2024
1 parent 91fb6ea commit 7acb4ff
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 42 deletions.
24 changes: 24 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,23 @@ if(TRITON_BUILD_PYTHON_MODULE)
add_link_options(${Python3_LINK_OPTIONS})
endif()

if (DEFINED TRITON_PLUGIN_DIRS)
foreach(PLUGIN_DIR ${TRITON_PLUGIN_DIRS})
# Read the plugin name under dir/backend/name.conf
cmake_path(APPEND PLUGIN_DIR "backend" "name.conf" OUTPUT_VARIABLE PLUGIN_NAME_PATH)
file(READ ${PLUGIN_NAME_PATH} PLUGIN_NAME)
string(STRIP ${PLUGIN_NAME} PLUGIN_NAME)

list(APPEND TRITON_PLUGIN_NAMES ${PLUGIN_NAME})

# Include the plugin as part of the build, placing the build output under
# ${TRITON_BINARY_DIR}/third_party/${PLUGIN_NAME}
cmake_path(APPEND TRITON_BINARY_DIR "third_party" ${PLUGIN_NAME} OUTPUT_VARIABLE PLUGIN_DIR_BUILD_OUTPUT)
message(STATUS "Building plugin '${PLUGIN_NAME}' from ${PLUGIN_DIR} with output ${PLUGIN_DIR_BUILD_OUTPUT}")
add_subdirectory(${PLUGIN_DIR} ${PLUGIN_DIR_BUILD_OUTPUT})
endforeach()
endif()

foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()
Expand Down Expand Up @@ -200,6 +217,13 @@ if(TRITON_BUILD_PYTHON_MODULE)

# Define triton library
string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_CODEGEN_BACKENDS})

if (DEFINED TRITON_PLUGIN_NAMES)
string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_BACKENDS_TUPLE} ${TRITON_PLUGIN_NAMES})
endif()

message(STATUS "Triton backends tuple: ${TRITON_BACKENDS_TUPLE}")

set(TRITON_BACKENDS_TUPLE "(${TRITON_BACKENDS_TUPLE})")
add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE})
add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc
Expand Down
94 changes: 52 additions & 42 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,59 @@ class Backend:
src_dir: str
backend_dir: str
install_dir: str
is_external: bool


def _copy_backends(active):
ret = []
root_dir = os.path.join(os.pardir, "third_party")
for backend in active:
curr_path = os.path.join(root_dir, backend)
backend_path = os.path.abspath(os.path.join(curr_path, "backend"))
install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend)
# initialize submodule if there is one
try:
subprocess.run(["git", "submodule", "update", "--init", f"{backend}"], check=True,
stdout=subprocess.DEVNULL, cwd=root_dir)
except subprocess.CalledProcessError:
pass
except FileNotFoundError:
pass
# check conditions
assert backend in os.listdir(root_dir), f"{backend} is requested for install but not present in {root_dir}"
class BackendInstaller:

@staticmethod
def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool = False):
# Initialize submodule if there is one for in-tree backends.
if not is_external:
root_dir = os.path.join(os.pardir, "third_party")
assert backend_name in os.listdir(
root_dir), f"{backend_name} is requested for install but not present in {root_dir}"

try:
subprocess.run(["git", "submodule", "update", "--init", f"{backend_name}"], check=True,
stdout=subprocess.DEVNULL, cwd=root_dir)
except subprocess.CalledProcessError:
pass
except FileNotFoundError:
pass

backend_src_dir = os.path.join(root_dir, backend_name)

backend_path = os.path.abspath(os.path.join(backend_src_dir, "backend"))
assert os.path.exists(backend_path), f"{backend_path} does not exist!"

for file in ["compiler.py", "driver.py"]:
assert os.path.exists(os.path.join(backend_path, file))
# update
assert os.path.exists(os.path.join(backend_path, file)), f"${file} does not exist in ${backend_path}"

install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend_name)
package_data = [f"{os.path.relpath(p, backend_path)}/*" for p, _, _, in os.walk(backend_path)]
ret.append(
Backend(name=backend, package_data=package_data, src_dir=curr_path, backend_dir=backend_path,
install_dir=install_dir))
return ret
return Backend(name=backend_name, package_data=package_data, src_dir=backend_src_dir, backend_dir=backend_path,
install_dir=install_dir, is_external=is_external)

# Copy all in-tree backends under triton/third_party.
@staticmethod
def copy(active):
return [BackendInstaller.prepare(backend) for backend in active]

# Copy all external plugins provided by the `TRITON_PLUGIN_DIRS` env var.
# TRITON_PLUGIN_DIRS is a semicolon-separated list of paths to the plugins.
# Expect to find the name of the backend under dir/backend/name.conf
@staticmethod
def copy_externals():
backend_dirs = os.getenv("TRITON_PLUGIN_DIRS")
if backend_dirs is None:
return []
backend_dirs = backend_dirs.strip().split(";")
backend_names = [Path(os.path.join(dir, "backend", "name.conf")).read_text().strip() for dir in backend_dirs]
return [
BackendInstaller.prepare(backend_name, backend_src_dir=backend_src_dir, is_external=True)
for backend_name, backend_src_dir in zip(backend_names, backend_dirs)
]


# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
Expand All @@ -76,16 +101,6 @@ def get_build_type():
return "TritonRelBuildWithAsserts"


def get_codegen_backends():
backends = []
env_prefix = "TRITON_CODEGEN_"
for name, _ in os.environ.items():
if name.startswith(env_prefix) and check_env_flag(name):
assert name.count(env_prefix) <= 1
backends.append(name.replace(env_prefix, '').lower())
return backends


# --- third party packages -----


Expand Down Expand Up @@ -298,7 +313,8 @@ def build_extension(self, ext):
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_TUTORIALS=OFF",
"-DTRITON_BUILD_PYTHON_MODULE=ON", "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
"-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", "-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends])
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
]
if lit_dir is not None:
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
Expand All @@ -308,13 +324,6 @@ def build_extension(self, ext):
cfg = get_build_type()
build_args = ["--config", cfg]

# third-party backend

# codegen_backends = get_codegen_backends()
# if len(codegen_backends) > 0:
# all_codegen_backends = ';'.join(codegen_backends)
# cmake_args += ["-DTRITON_CODEGEN_BACKENDS=" + all_codegen_backends]

if platform.system() == "Windows":
cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
if sys.maxsize > 2**32:
Expand Down Expand Up @@ -381,7 +390,8 @@ def build_extension(self, ext):
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-nvdisasm/12.3.52/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)
backends = _copy_backends(["nvidia", "amd"])

backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()]


def add_link_to_backends():
Expand Down

0 comments on commit 7acb4ff

Please sign in to comment.