Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fbgemm_gpu] Fix Triton OSS installation #2618

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/scripts/fbgemm_gpu_build.bash
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ prepare_fbgemm_gpu_build () {

# BUILD_VARIANT is provided by the github workflow file
if [ "$BUILD_VARIANT" == "cuda" ]; then
(install_triton "${env_name}") || return 1
(install_triton_pip "${env_name}") || return 1
fi

# shellcheck disable=SC2086
Expand Down
42 changes: 29 additions & 13 deletions .github/scripts/utils_pip.bash
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,17 @@ __export_package_variant_info () {
}

__export_pip_arguments () {
local include_variant="$1"

# shellcheck disable=SC2155
local postfix=$([ "$include_variant" = true ] && echo "${package_variant}/" || echo "")

# Extract the PIP channel
if [ "$package_channel" == "release" ]; then
export pip_channel="https://download.pytorch.org/whl/${package_variant}/"
export pip_channel="https://download.pytorch.org/whl/${postfix}"
else
echo "[INSTALL] Using a non-RELEASE channel: ${package_channel} ..."
export pip_channel="https://download.pytorch.org/whl/${package_channel}/${package_variant}/"
export pip_channel="https://download.pytorch.org/whl/${package_channel}/${postfix}"
fi
echo "[INSTALL] Extracted the full PIP channel: ${pip_channel}"

Expand All @@ -114,9 +119,13 @@ __export_pip_arguments () {
else
export pip_package="${package_name}"
fi

# shellcheck disable=SC2155
local postfix=$([ "$include_variant" = true ] && echo "+${package_variant}" || echo "")

# If a specific version is specified, then append with `==<version>`
if [ "$package_version" != "" ]; then
export pip_package="${pip_package}==${package_version}+${package_variant}"
export pip_package="${pip_package}==${package_version}${postfix}"
fi
echo "[INSTALL] Extracted the full PIP package: ${pip_package}"
}
Expand All @@ -125,8 +134,8 @@ __prepare_pip_arguments () {
local package_name_raw="$1"
local package_channel_version="$2"
local package_variant_type_version="$3"
if [ "$package_variant_type_version" == "" ]; then
echo "Usage: ${FUNCNAME[0]} PACKAGE_NAME PACKAGE_CHANNEL[/VERSION] PACKAGE_VARIANT_TYPE[/VARIANT_VERSION]"
if [ "$package_channel_version" == "" ]; then
echo "Usage: ${FUNCNAME[0]} PACKAGE_NAME PACKAGE_CHANNEL[/VERSION] [PACKAGE_VARIANT_TYPE[/VARIANT_VERSION]]"
return 1
else
echo "################################################################################"
Expand All @@ -146,27 +155,33 @@ __prepare_pip_arguments () {
# export variables to environment
__export_package_channel_info "$package_channel_version"

# Extract the package variant type and variant version from the tuple-string,
# and export variables to environment
__export_package_variant_info "${package_variant_type_version}"
if [ "$package_variant_type_version" != "" ]; then
# Extract the package variant type and variant version from the tuple-string,
# and export variables to environment
__export_package_variant_info "${package_variant_type_version}"
fi

# With all package_* variables exported, extract the arguments for PIP, and
# export variabels to environment
__export_pip_arguments
__export_pip_arguments "$([ "$package_variant_type_version" != "" ] && echo "true" || echo "false")"
}

install_from_pytorch_pip () {
local env_name="$1"
local package_name_raw="$2"
local package_channel_version="$3"
local package_variant_type_version="$4"
if [ "$package_variant_type_version" == "" ]; then
echo "Usage: ${FUNCNAME[0]} ENV_NAME PACKAGE_NAME PACKAGE_CHANNEL[/VERSION] PACKAGE_VARIANT_TYPE[/VARIANT_VERSION]"
if [ "$package_channel_version" == "" ]; then
echo "Usage: ${FUNCNAME[0]} ENV_NAME PACKAGE_NAME PACKAGE_CHANNEL[/VERSION] [PACKAGE_VARIANT_TYPE[/VARIANT_VERSION]]"
echo "Example(s):"
echo " ${FUNCNAME[0]} build_env torch 1.11.0 cpu # Install the CPU variant, specific version from release channel"
echo " ${FUNCNAME[0]} build_env torch release cpu # Install the CPU variant, latest version from release channel"
echo " ${FUNCNAME[0]} build_env fbgemm_gpu test/0.6.0rc0 cuda/12.1.0 # Install the CUDA 12.1 variant, specific version from test channel"
echo " ${FUNCNAME[0]} build_env fbgemm_gpu nightly rocm/5.3 # Install the ROCM 5.3 variant, latest version from nightly channel"
echo " ${FUNCNAME[0]} build_env pytorch_triton 1.11.0 # Install specific version from release channel"
echo " ${FUNCNAME[0]} build_env pytorch_triton release # Install latest version from release channel"
echo " ${FUNCNAME[0]} build_env pytorch_triton test/0.6.0rc0 # Install specific version from test channel"
echo " ${FUNCNAME[0]} build_env pytorch_triton_rocm nightly # Install latest version from nightly channel"
return 1
else
echo "################################################################################"
Expand All @@ -188,8 +203,8 @@ install_from_pytorch_pip () {
# shellcheck disable=SC2086
(exec_with_retries 3 conda run ${env_prefix} pip install ${pip_package} --index-url ${pip_channel}) || return 1

# Check only applies to non-CPU variants
if [ "$package_variant_type" != "cpu" ]; then
# Check applies to installation of packages with variants, and only to non-CPU variants
if [ "$package_variant_type_version" != "" ] && [ "$package_variant_type" != "cpu" ]; then
# Ensure that the package build is of the correct variant
# This test usually applies to the nightly builds
# shellcheck disable=SC2086
Expand All @@ -203,6 +218,7 @@ install_from_pytorch_pip () {
fi
}


################################################################################
# PyTorch PIP Download Functions
################################################################################
Expand Down
36 changes: 32 additions & 4 deletions .github/scripts/utils_triton.bash
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

# shellcheck disable=SC1091,SC2128
. "$( dirname -- "$BASH_SOURCE"; )/utils_base.bash"
# shellcheck disable=SC1091,SC2128
. "$( dirname -- "$BASH_SOURCE"; )/utils_pip.bash"

################################################################################
# Triton Setup Functions
################################################################################

install_triton () {
install_triton_gitmodule () {
local env_name="$1"
local triton_version="$2"
if [ "$env_name" == "" ]; then
Expand All @@ -24,7 +26,7 @@ install_triton () {
return 1
else
echo "################################################################################"
echo "# Build + Install Triton"
echo "# Build + Install Triton (gitmodule)"
echo "#"
echo "# [$(date --utc +%FT%T.%3NZ)] + ${FUNCNAME[0]} ${*}"
echo "################################################################################"
Expand All @@ -42,13 +44,39 @@ install_triton () {
(print_exec git checkout "${triton_version}") || return 1
fi

echo "[BUILD] Installing Triton ..."
echo "[BUILD] Installing Triton from gitmodule ..."
# shellcheck disable=SC2086
(exec_with_retries 3 conda run --no-capture-output ${env_prefix} python -m pip install -e .) || return 1

# shellcheck disable=SC2086
(test_python_import_package "${env_name}" triton) || return 1

cd - || return 1
echo "[INSTALL] Successfully installed Triton ${triton_version}"
echo "[INSTALL] Successfully installed Triton ${triton_version} from gitmodule"
}

install_triton_pip () {
local env_name="$1"
if [ "$env_name" == "" ]; then
echo "Usage: ${FUNCNAME[0]} ENV_NAME"
echo "Example(s):"
echo " ${FUNCNAME[0]} build_env"
return 1
else
echo "################################################################################"
echo "# Install PyTorch (PyTorch PIP)"
echo "#"
echo "# [$(date --utc +%FT%T.%3NZ)] + ${FUNCNAME[0]} ${*}"
echo "################################################################################"
echo ""
fi

echo "[BUILD] Installing Triton from PIP ..."
# shellcheck disable=SC2086
install_from_pytorch_pip "${env_name}" pytorch-triton nightly/3.0.0+45fff310c8 || return 1

# shellcheck disable=SC2086
(test_python_import_package "${env_name}" triton) || return 1

echo "[INSTALL] Successfully installed PyTorch through PyTorch PIP"
}
Loading