Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUILD] Support building out of tree plugins #3007

Merged
merged 20 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,25 @@ if(TRITON_BUILD_PYTHON_MODULE)
add_link_options(${Python3_LINK_OPTIONS})
endif()

if (DEFINED TRITON_PLUGIN_DIRS)
foreach(PLUGIN_DIR ${TRITON_PLUGIN_DIRS})
# Generate a list of plugin names by getting the last component of the provided paths.
# Note that each path must not end with a slash, otherwise cmake_path will return empty
# string as the name.
cmake_path(GET PLUGIN_DIR FILENAME PLUGIN_NAME)
if("${PLUGIN_NAME}" STREQUAL "")
message(FATAL_ERROR "Plugin dir ${PLUGIN_DIR} must not end with a slash")
endif()
list(APPEND TRITON_PLUGIN_NAMES ${PLUGIN_NAME})

# Include the plugin as part of the build, placing the build output under
# ${TRITON_BINARY_DIR}/${PLUGIN_NAME}
cmake_path(APPEND TRITON_BINARY_DIR ${PLUGIN_NAME} OUTPUT_VARIABLE PLUGIN_DIR_BUILD_OUTPUT)
message(STATUS "Building plugin '${PLUGIN_NAME}' from ${PLUGIN_DIR} with output ${PLUGIN_DIR_BUILD_OUTPUT}")
add_subdirectory(${PLUGIN_DIR} ${PLUGIN_DIR_BUILD_OUTPUT})
endforeach()
endif()

foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()
Expand Down Expand Up @@ -191,6 +210,13 @@ if(TRITON_BUILD_PYTHON_MODULE)

# Define triton library
string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_CODEGEN_BACKENDS})

if (DEFINED TRITON_PLUGIN_NAMES)
string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_BACKENDS_TUPLE} ${TRITON_PLUGIN_NAMES})
endif()

message(STATUS "Triton backends tuple: ${TRITON_BACKENDS_TUPLE}")

set(TRITON_BACKENDS_TUPLE "(${TRITON_BACKENDS_TUPLE})")
add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE})
add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc
Expand Down
106 changes: 62 additions & 44 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,67 @@ class Backend:
src_dir: str
backend_dir: str
install_dir: str
is_external: bool


def _copy_backends(active):
ret = []
root_dir = os.path.join(os.pardir, "third_party")
for backend in active:
curr_path = os.path.join(root_dir, backend)
backend_path = os.path.abspath(os.path.join(curr_path, "backend"))
install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend)
# initialize submodule if there is one
try:
subprocess.run(["git", "submodule", "update", "--init", f"{backend}"], check=True,
stdout=subprocess.DEVNULL, cwd=root_dir)
except subprocess.CalledProcessError:
pass
except FileNotFoundError:
pass
# check conditions
assert backend in os.listdir(root_dir), f"{backend} is requested for install but not present in {root_dir}"
assert os.path.exists(backend_path), f"{backend_path} does not exist!"
class BackendInstaller:

@staticmethod
def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool = False):
# Initialize submodule if there is one for in-tree backends.
if not is_external:
root_dir = os.path.join(os.pardir, "third_party")
if backend_name not in os.listdir(root_dir):
raise Exception(f"{backend_name} is requested for install but not present in {root_dir}")

try:
subprocess.run(["git", "submodule", "update", "--init", f"{backend_name}"], check=True,
stdout=subprocess.DEVNULL, cwd=root_dir)
except subprocess.CalledProcessError:
pass
except FileNotFoundError:
pass

backend_src_dir = os.path.join(root_dir, backend_name)

backend_dir = os.path.abspath(os.path.join(backend_src_dir, "backend"))
if not os.path.exists(backend_dir):
raise Exception(f"{backend_dir} does not exist!")
for file in ["compiler.py", "driver.py"]:
assert os.path.exists(os.path.join(backend_path, file))
# update
package_data = [f"{os.path.relpath(p, backend_path)}/*" for p, _, _, in os.walk(backend_path)]
ret.append(
Backend(name=backend, package_data=package_data, src_dir=curr_path, backend_dir=backend_path,
install_dir=install_dir))
return ret
if not os.path.exists(os.path.join(backend_dir, file)):
raise Exception(f"${file} does not exist in ${backend_dir}")

install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend_name)
package_data = [f"{os.path.relpath(p, backend_dir)}/*" for p, _, _, in os.walk(backend_dir)]
return Backend(name=backend_name, package_data=package_data, src_dir=backend_src_dir, backend_dir=backend_dir,
install_dir=install_dir, is_external=is_external)

# Copy all in-tree backends under triton/third_party.
@staticmethod
def copy(active):
return [BackendInstaller.prepare(backend) for backend in active]

# Copy all external plugins provided by the `TRITON_PLUGIN_DIRS` env var.
# TRITON_PLUGIN_DIRS is a semicolon-separated list of paths to the plugins.
# The last components of the paths must be valid python identifiders.
@staticmethod
def copy_externals():

def get_backend_name(dir: str):
name = Path(dir).name
if not name.isidentifier():
raise Exception(f"{name} must be a valid python identifier")
return name

backend_dirs = os.getenv("TRITON_PLUGIN_DIRS")
if backend_dirs is None:
return []
backend_dirs = backend_dirs.strip().split(";")
backend_names = [get_backend_name(dir) for dir in backend_dirs]
return [
BackendInstaller.prepare(backend_name, backend_src_dir=backend_src_dir, is_external=True)
for backend_name, backend_src_dir in zip(backend_names, backend_dirs)
]


# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
Expand All @@ -76,16 +109,6 @@ def get_build_type():
return "TritonRelBuildWithAsserts"


def get_codegen_backends():
backends = []
env_prefix = "TRITON_CODEGEN_"
for name, _ in os.environ.items():
if name.startswith(env_prefix) and check_env_flag(name):
assert name.count(env_prefix) <= 1
backends.append(name.replace(env_prefix, '').lower())
return backends


# --- third party packages -----


Expand Down Expand Up @@ -298,7 +321,8 @@ def build_extension(self, ext):
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_TUTORIALS=OFF",
"-DTRITON_BUILD_PYTHON_MODULE=ON", "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
"-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", "-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends])
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
]
if lit_dir is not None:
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
Expand All @@ -308,13 +332,6 @@ def build_extension(self, ext):
cfg = get_build_type()
build_args = ["--config", cfg]

# third-party backend

# codegen_backends = get_codegen_backends()
# if len(codegen_backends) > 0:
# all_codegen_backends = ';'.join(codegen_backends)
# cmake_args += ["-DTRITON_CODEGEN_BACKENDS=" + all_codegen_backends]

if platform.system() == "Windows":
cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
if sys.maxsize > 2**32:
Expand Down Expand Up @@ -381,7 +398,8 @@ def build_extension(self, ext):
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-nvdisasm/12.3.52/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)
backends = _copy_backends(["nvidia", "amd"])

backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()]


def add_link_to_backends():
Expand Down
Loading