Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUILD] Support building out of tree plugins #3007

Merged
merged 20 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,23 @@ if(TRITON_BUILD_PYTHON_MODULE)
add_link_options(${Python3_LINK_OPTIONS})
endif()

if (DEFINED TRITON_PLUGIN_DIRS)
foreach(PLUGIN_DIR ${TRITON_PLUGIN_DIRS})
# Read the plugin name under dir/backend/name.conf
cmake_path(APPEND PLUGIN_DIR "backend" "name.conf" OUTPUT_VARIABLE PLUGIN_NAME_PATH)
file(READ ${PLUGIN_NAME_PATH} PLUGIN_NAME)
string(STRIP ${PLUGIN_NAME} PLUGIN_NAME)

list(APPEND TRITON_PLUGIN_NAMES ${PLUGIN_NAME})

# Include the plugin as part of the build, placing the build output under
# ${TRITON_BINARY_DIR}/third_party/${PLUGIN_NAME}
cmake_path(APPEND TRITON_BINARY_DIR "third_party" ${PLUGIN_NAME} OUTPUT_VARIABLE PLUGIN_DIR_BUILD_OUTPUT)
message(STATUS "Building plugin '${PLUGIN_NAME}' from ${PLUGIN_DIR} with output ${PLUGIN_DIR_BUILD_OUTPUT}")
add_subdirectory(${PLUGIN_DIR} ${PLUGIN_DIR_BUILD_OUTPUT})
endforeach()
endif()

foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()
Expand Down Expand Up @@ -191,6 +208,13 @@ if(TRITON_BUILD_PYTHON_MODULE)

# Define triton library
string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_CODEGEN_BACKENDS})

if (DEFINED TRITON_PLUGIN_NAMES)
string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_BACKENDS_TUPLE} ${TRITON_PLUGIN_NAMES})
endif()

message(STATUS "Triton backends tuple: ${TRITON_BACKENDS_TUPLE}")

set(TRITON_BACKENDS_TUPLE "(${TRITON_BACKENDS_TUPLE})")
add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE})
add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc
Expand Down
94 changes: 52 additions & 42 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,59 @@ class Backend:
src_dir: str
backend_dir: str
install_dir: str
is_external: bool


def _copy_backends(active):
ret = []
root_dir = os.path.join(os.pardir, "third_party")
for backend in active:
curr_path = os.path.join(root_dir, backend)
backend_path = os.path.abspath(os.path.join(curr_path, "backend"))
install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend)
# initialize submodule if there is one
try:
subprocess.run(["git", "submodule", "update", "--init", f"{backend}"], check=True,
stdout=subprocess.DEVNULL, cwd=root_dir)
except subprocess.CalledProcessError:
pass
except FileNotFoundError:
pass
# check conditions
assert backend in os.listdir(root_dir), f"{backend} is requested for install but not present in {root_dir}"
class BackendInstaller:

@staticmethod
def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool = False):
# Initialize submodule if there is one for in-tree backends.
if not is_external:
root_dir = os.path.join(os.pardir, "third_party")
assert backend_name in os.listdir(
root_dir), f"{backend_name} is requested for install but not present in {root_dir}"

try:
subprocess.run(["git", "submodule", "update", "--init", f"{backend_name}"], check=True,
stdout=subprocess.DEVNULL, cwd=root_dir)
except subprocess.CalledProcessError:
pass
except FileNotFoundError:
pass

backend_src_dir = os.path.join(root_dir, backend_name)

backend_path = os.path.abspath(os.path.join(backend_src_dir, "backend"))
assert os.path.exists(backend_path), f"{backend_path} does not exist!"

for file in ["compiler.py", "driver.py"]:
assert os.path.exists(os.path.join(backend_path, file))
# update
assert os.path.exists(os.path.join(backend_path, file)), f"${file} does not exist in ${backend_path}"

install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend_name)
package_data = [f"{os.path.relpath(p, backend_path)}/*" for p, _, _, in os.walk(backend_path)]
ret.append(
Backend(name=backend, package_data=package_data, src_dir=curr_path, backend_dir=backend_path,
install_dir=install_dir))
return ret
return Backend(name=backend_name, package_data=package_data, src_dir=backend_src_dir, backend_dir=backend_path,
install_dir=install_dir, is_external=is_external)

# Copy all in-tree backends under triton/third_party.
@staticmethod
def copy(active):
return [BackendInstaller.prepare(backend) for backend in active]

# Copy all external plugins provided by the `TRITON_PLUGIN_DIRS` env var.
# TRITON_PLUGIN_DIRS is a semicolon-separated list of paths to the plugins.
# Expect to find the name of the backend under dir/backend/name.conf
@staticmethod
def copy_externals():
backend_dirs = os.getenv("TRITON_PLUGIN_DIRS")
if backend_dirs is None:
return []
backend_dirs = backend_dirs.strip().split(";")
backend_names = [Path(os.path.join(dir, "backend", "name.conf")).read_text().strip() for dir in backend_dirs]
return [
BackendInstaller.prepare(backend_name, backend_src_dir=backend_src_dir, is_external=True)
for backend_name, backend_src_dir in zip(backend_names, backend_dirs)
]


# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
Expand All @@ -76,16 +101,6 @@ def get_build_type():
return "TritonRelBuildWithAsserts"


def get_codegen_backends():
backends = []
env_prefix = "TRITON_CODEGEN_"
for name, _ in os.environ.items():
if name.startswith(env_prefix) and check_env_flag(name):
assert name.count(env_prefix) <= 1
backends.append(name.replace(env_prefix, '').lower())
return backends


# --- third party packages -----


Expand Down Expand Up @@ -298,7 +313,8 @@ def build_extension(self, ext):
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_TUTORIALS=OFF",
"-DTRITON_BUILD_PYTHON_MODULE=ON", "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
"-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", "-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends])
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
]
if lit_dir is not None:
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
Expand All @@ -308,13 +324,6 @@ def build_extension(self, ext):
cfg = get_build_type()
build_args = ["--config", cfg]

# third-party backend

# codegen_backends = get_codegen_backends()
# if len(codegen_backends) > 0:
# all_codegen_backends = ';'.join(codegen_backends)
# cmake_args += ["-DTRITON_CODEGEN_BACKENDS=" + all_codegen_backends]

if platform.system() == "Windows":
cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
if sys.maxsize > 2**32:
Expand Down Expand Up @@ -381,7 +390,8 @@ def build_extension(self, ext):
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-nvdisasm/12.3.52/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)
backends = _copy_backends(["nvidia", "amd"])

backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()]


def add_link_to_backends():
Expand Down
Loading