Skip to content

Commit

Permalink
Add support for cuda 10.1 (TF 2.1+). (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifeif committed Dec 17, 2019
1 parent 35e2e42 commit 62e530e
Show file tree
Hide file tree
Showing 10 changed files with 4,539 additions and 562 deletions.
40 changes: 31 additions & 9 deletions configure.sh
Expand Up @@ -62,17 +62,32 @@ if is_windows; then
echo "On windows, skipping toolchain flags.."
else
while [[ "$PIP_MANYLINUX2010" == "" ]]; do
read -p "Does the pip package have tag manylinux2010 (usually the case for nightly release after Aug 1, 2019, or official releases past 1.14.0)?"\
" Y or enter for manylinux2010, N for manylinux1. [Y/n] " INPUT
read -p "Does the pip package have tag manylinux2010 (usually the case for nightly release after Aug 1, 2019, or official releases past 1.14.0)?. Y or enter for manylinux2010, N for manylinux1. [Y/n] " INPUT
case $INPUT in
[Yy]* ) echo "Build against pip package with manylinux2010 tag. --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0:toolchain will be added to bazel command."; PIP_MANYLINUX2010=1;;
[Nn]* ) echo "Build against pip package with manylinux1."; PIP_MANYLINUX2010=0;;
"" ) echo "Build against pip package with manylinux2010 tag. --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0:toolchain will be added to bazel command."; PIP_MANYLINUX2010=1;;
[Yy]* ) PIP_MANYLINUX2010=1;;
[Nn]* ) PIP_MANYLINUX2010=0;;
"" ) PIP_MANYLINUX2010=1;;
* ) echo "Invalid selection: " $INPUT;;
esac
done

if [[ "$PIP_MANYLINUX2010" == "1" ]]; then
while [[ "$TF_CUDA_VERSION" == "" ]]; do
read -p "Are you building against TensorFlow 2.1(including RCs) or newer?[Y/n] " INPUT
case $INPUT in
[Yy]* ) echo "Build with the latest manylinux2010 compatible toolchains."; TF_CUDA_VERSION=10.1;;
[Nn]* ) echo "Build with prvious manylinux2010 compatible toolchains."; TF_CUDA_VERSION=10.0;;
"" ) echo "Build with the latest manylinux2010 compatible toolchains."; TF_CUDA_VERSION=10.1;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
fi
fi





# CPU
if [[ "$TF_NEED_CUDA" == "0" ]]; then

Expand Down Expand Up @@ -123,7 +138,9 @@ if is_linux; then
if [[ "$PIP_MANYLINUX2010" == "0" ]]; then
write_to_bazelrc "build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain"
fi
write_to_bazelrc "build:manylinux2010 --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0:toolchain"
write_to_bazelrc "build:manylinux2010cuda100 --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0:toolchain"
write_to_bazelrc "build:manylinux2010cuda101 --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"

fi
write_to_bazelrc "build --spawn_strategy=standalone"
write_to_bazelrc "build --strategy=Genrule=standalone"
Expand Down Expand Up @@ -162,7 +179,7 @@ write_action_env_to_bazelrc "TF_NEED_CUDA" ${TF_NEED_CUDA}
# TODO(yifeif): do not hardcode path
if [[ "$TF_NEED_CUDA" == "1" ]]; then
write_action_env_to_bazelrc "CUDNN_INSTALL_PATH" "/usr/lib/x86_64-linux-gnu"
write_action_env_to_bazelrc "TF_CUDA_VERSION" "10.0"
write_action_env_to_bazelrc "TF_CUDA_VERSION" ${TF_CUDA_VERSION}
write_action_env_to_bazelrc "TF_CUDNN_VERSION" "7"
write_action_env_to_bazelrc "CUDA_TOOLKIT_PATH" "/usr/local/cuda"
write_to_bazelrc "build --config=cuda"
Expand All @@ -171,6 +188,11 @@ fi


if [[ "$PIP_MANYLINUX2010" == "1" ]]; then
write_to_bazelrc "build --config=manylinux2010"
write_to_bazelrc "test --config=manylinux2010"
if [[ "$TF_CUDA_VERSION" == "10.0" ]]; then
write_to_bazelrc "build --config=manylinux2010cuda100"
write_to_bazelrc "test --config=manylinux2010cuda100"
else
write_to_bazelrc "build --config=manylinux2010cuda101"
write_to_bazelrc "test --config=manylinux2010cuda101"
fi
fi
84 changes: 65 additions & 19 deletions gpu/crosstool/BUILD.tpl
@@ -1,3 +1,8 @@
# This file is expanded from a template by cuda_configure.bzl
# Update cuda_configure.bzl#verify_build_defines when adding new variables.

load(":cc_toolchain_config.bzl", "cc_toolchain_config")

licenses(["restricted"])

package(default_visibility = ["//visibility:public"])
Expand All @@ -24,6 +29,7 @@ cc_toolchain_suite(
"x64_windows|msvc-cl": ":cc-compiler-windows",
"x64_windows": ":cc-compiler-windows",
"arm": ":cc-compiler-local",
"aarch64": ":cc-compiler-local",
"k8": ":cc-compiler-local",
"piii": ":cc-compiler-local",
"ppc": ":cc-compiler-local",
Expand All @@ -33,51 +39,91 @@ cc_toolchain_suite(

cc_toolchain(
name = "cc-compiler-local",
all_files = "%{linker_files}",
compiler_files = ":empty",
cpu = "local",
all_files = "%{compiler_deps}",
compiler_files = "%{compiler_deps}",
ar_files = "%{compiler_deps}",
as_files = "%{compiler_deps}",
dwp_files = ":empty",
dynamic_runtime_libs = [":empty"],
linker_files = "%{linker_files}",
linker_files = "%{compiler_deps}",
objcopy_files = ":empty",
static_runtime_libs = [":empty"],
strip_files = ":empty",
# To support linker flags that need to go to the start of command line
# we need the toolchain to support parameter files. Parameter files are
# last on the command line and contain all shared libraries to link, so all
# regular options will be left of them.
supports_param_files = 1,
toolchain_identifier = "local_linux",
toolchain_config = ":cc-compiler-local-config",
)

cc_toolchain_config(
name = "cc-compiler-local-config",
cpu = "local",
builtin_include_directories = [%{cxx_builtin_include_directories}],
extra_no_canonical_prefixes_flags = [%{extra_no_canonical_prefixes_flags}],
host_compiler_path = "%{host_compiler_path}",
host_compiler_prefix = "%{host_compiler_prefix}",
host_compiler_warnings = [%{host_compiler_warnings}],
host_unfiltered_compile_flags = [%{unfiltered_compile_flags}],
linker_bin_path = "%{linker_bin_path}",
builtin_sysroot = "%{builtin_sysroot}",
cuda_path = "%{cuda_toolkit_path}",
)

cc_toolchain(
name = "cc-compiler-darwin",
all_files = "%{linker_files}",
compiler_files = ":empty",
cpu = "darwin",
all_files = "%{compiler_deps}",
compiler_files = "%{compiler_deps}",
ar_files = "%{compiler_deps}",
as_files = "%{compiler_deps}",
dwp_files = ":empty",
dynamic_runtime_libs = [":empty"],
linker_files = "%{linker_files}",
linker_files = "%{compiler_deps}",
objcopy_files = ":empty",
static_runtime_libs = [":empty"],
strip_files = ":empty",
supports_param_files = 0,
toolchain_identifier = "local_darwin",
toolchain_config = ":cc-compiler-local-darwin",
)

cc_toolchain_config(
name = "cc-compiler-local-darwin",
cpu = "darwin",
builtin_include_directories = [%{cxx_builtin_include_directories}],
extra_no_canonical_prefixes_flags = [%{extra_no_canonical_prefixes_flags}],
host_compiler_path = "%{host_compiler_path}",
host_compiler_prefix = "%{host_compiler_prefix}",
host_compiler_warnings = [%{host_compiler_warnings}],
host_unfiltered_compile_flags = [%{unfiltered_compile_flags}],
linker_bin_path = "%{linker_bin_path}",
)

cc_toolchain(
name = "cc-compiler-windows",
all_files = "%{win_linker_files}",
compiler_files = ":empty",
cpu = "x64_windows",
all_files = "%{win_compiler_deps}",
compiler_files = "%{win_compiler_deps}",
ar_files = "%{win_compiler_deps}",
as_files = "%{win_compiler_deps}",
dwp_files = ":empty",
dynamic_runtime_libs = [":empty"],
linker_files = "%{win_linker_files}",
linker_files = "%{win_compiler_deps}",
objcopy_files = ":empty",
static_runtime_libs = [":empty"],
strip_files = ":empty",
supports_param_files = 1,
toolchain_identifier = "local_windows",
toolchain_config = ":cc-compiler-windows-config",
)

cc_toolchain_config(
name = "cc-compiler-windows-config",
cpu = "x64_windows",
builtin_include_directories = [%{cxx_builtin_include_directories}],
msvc_cl_path = "%{msvc_cl_path}",
msvc_env_include = "%{msvc_env_include}",
msvc_env_lib = "%{msvc_env_lib}",
msvc_env_path = "%{msvc_env_path}",
msvc_env_tmp = "%{msvc_env_tmp}",
msvc_lib_path = "%{msvc_lib_path}",
msvc_link_path = "%{msvc_link_path}",
msvc_ml_path = "%{msvc_ml_path}",
)

filegroup(
Expand All @@ -93,4 +139,4 @@ filegroup(
filegroup(
name = "windows_msvc_wrapper_files",
srcs = glob(["windows/msvc_*"]),
)
)

0 comments on commit 62e530e

Please sign in to comment.