
[ROCm] Bazel build and continuous integration infrastructure#20277

Merged
tensorflow-copybara merged 1 commit into tensorflow:master from ROCm:upstream-staging
Sep 27, 2018

Conversation

@whchung
Contributor

@whchung whchung commented Jun 25, 2018

This pull request starts introducing support for the ROCm platform in TensorFlow. This initial pull request addresses two components:

  • bazel build system
  • continuous integration logic

Authors:

@whchung
Contributor Author

whchung commented Jul 6, 2018

ping?

yifeif previously approved these changes Jul 10, 2018

@yifeif (Contributor) left a comment

configure and build script change lgtm

caisq
caisq previously approved these changes Jul 10, 2018
apt-get clean && \
rm -rf /var/lib/apt/lists/*

# Workaround: use HIP PR#457 and then build from source
Contributor

Probably better to provide the full link here for posterity.

Contributor Author

HIP PR#457 ( ROCm/hip#457 ) has been merged into HIP mainline since this PR was created. I'll amend this PR to reflect that.

/cc @paralleo for awareness

Contributor Author

@caisq pushed a new commit to address this.

Comment thread configure.py Outdated
set_trisycl_include_dir(environ_cp)

set_action_env_var(environ_cp, 'TF_NEED_ROCM', 'ROCm', False)
if environ_cp.get('TF_NEED_ROCM') == '1':
Contributor

nit: These two if statements can be consolidated as one, because neither of them has an else branch.

Contributor Author

I'd like to keep them as is for now; there are additional ROCm-specific checks / env vars coming in future PRs. For example, just as the CUDA path can build TensorFlow with either nvcc or CUDA clang, we are working on a similar route to switch between the incumbent HIP/HCC toolchain and the upcoming HIP clang toolchain.

Contributor

I thought the eventual goal was to just have clang?

Contributor Author

@gunan for the time being HIP/HCC would be the incumbent toolchain on ROCm and we'll switch to HIP clang toolchain once its performance beats the incumbent solution.

I'm working on revising this PR to address comments from all reviewers now and will ping you once it's ready.

Comment thread configure.py
else:
set_trisycl_include_dir(environ_cp)

set_action_env_var(environ_cp, 'TF_NEED_ROCM', 'ROCm', False)
Contributor

TF_NEED_CUDA, TF_NEED_SYCL and TF_NEED_ROCM are all mutually exclusive, right? If so, we need a sanity check here that at most one of them is true.

@whchung (Contributor Author) commented Jul 10, 2018

From the TensorFlow perspective they are not mutually exclusive: it's possible to enable all three and still have TensorFlow build. Although I'd be surprised to see such a configuration in real life. :)

Contributor

Looking at some of the options below, enabling both at the same time would break other things. For example, the _gpu cc_test targets. Which GPU would they run on?
Let's add the check @caisq requested here.

Contributor Author

ok will do.

Contributor Author

Revised the PR to raise an error when more than one GPU platform (CUDA/SYCL/ROCm) is specified.
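For illustration, the requested sanity check could look roughly like this in configure.py style. This is a sketch, not the actual patch: the function name and error message are assumptions, and `environ_cp` is assumed to be the environment dict that configure.py already threads through its helpers.

```python
def validate_gpu_platforms(environ_cp):
    """Raise an error if more than one GPU platform is enabled.

    environ_cp: dict of environment variables, as configure.py passes
    around (assumed interface for this sketch).
    """
    flags = ['TF_NEED_CUDA', 'TF_NEED_SYCL', 'TF_NEED_ROCM']
    enabled = [f for f in flags if environ_cp.get(f) == '1']
    if len(enabled) > 1:
        raise ValueError('At most one GPU platform may be enabled, got: %s'
                         % ', '.join(enabled))
```

Enabling a single platform passes, while enabling, say, both TF_NEED_CUDA and TF_NEED_ROCM raises.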

@whchung whchung dismissed stale reviews from caisq and yifeif via 76be23a July 10, 2018 18:58
Comment thread tensorflow/tensorflow.bzl Outdated
hdrs=[],
**kwargs):
copts = copts + _cuda_copts() + if_cuda(cuda_copts) + tf_copts()
copts=copts + tf_copts() + _cuda_copts() + _rocm_copts() + if_cuda_is_configured(cuda_copts) + if_rocm_is_configured(cuda_copts)
Contributor

let's hide the additional complexities here by adding an arg to "_cuda_copts" and "_rocm_copts"

So this line becomes:
copts=copts + tf_copts() + _cuda_copts(opts=cuda_copts) + _rocm_copts(opts=cuda_copts)

Contributor Author

@gunan Thanks. I'll address this.

@@ -0,0 +1,97 @@
# This Dockerfile provides a starting point for a ROCm installation of
# MIOpen and tensorflow.
FROM ubuntu:xenial-20170619
Contributor

Why not a more generic image? "ubuntu:xenial"

Contributor Author

@parallelo Please weigh in. I believe the specific tag was specified to ensure the proper Linux kernel version is used?

Contributor

Correct. We typically use a 4.13-45 linux kernel.

Contributor

Do we have an upper limit on the kernel version?
If we depend on one very specific kernel version this will be very very brittle.
Is there a way to lift this restriction?

Contributor Author

We had an issue with 4.15+ kernels in the older version of the ROCm stack. The issue has been fixed in the upcoming ROCm release, so I'll remove this.

if [[ "${TF_NEED_ROCM}" -eq 1 ]]; then
# ROCm requires the video group in order to use the GPU for compute. If it
# exists on the host, add it to the container.
getent group video || addgroup video && adduser "${CI_BUILD_USER}" video
Contributor

@caisq I am quite unfamiliar with this script. Are the changes here OK?

@whchung (Contributor Author) commented Jul 10, 2018

@parallelo could weigh in.

This line is to fulfill permission requirements from ROCm stack specified at: https://github.com/RadeonOpenCompute/ROCm#next-set-your-permissions

Contributor

In the case of ROCm, we add the container's user to the video group (but only if the host user was a member of the video group). This group membership is currently a requirement for ROCm.

bazel test --config=rocm --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \
--test_lang_filters=cc --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
--build_tests_only --test_output=errors --local_test_jobs=1 --config=opt \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \
Contributor

I recommend omitting this flag, and setting local_test_jobs=1.
This is highly specialized for VMs with 8 k80 GPUs attached.

Contributor Author

thanks. will do.

bazel test --config=rocm --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \
--test_lang_filters=py --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
--build_tests_only --test_output=errors --local_test_jobs=1 --config=opt \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \
Contributor

ditto above. remove this option.

Contributor Author

thanks. will do.

bazel test --config=rocm --test_tag_filters=-no_gpu,-benchmark-test,-no_oss -k \
--jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
--build_tests_only --test_output=errors --local_test_jobs=1 \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
Contributor

remove this option.

Contributor Author

thanks. will do.

@@ -0,0 +1,277 @@
major_version: "local"
Contributor

@meteorcloudy @mhlopko could you help review this file?

@@ -0,0 +1,239 @@
#!/usr/bin/env python
Contributor

@meteorcloudy @mhlopko could you review?

Comment thread third_party/gpus/rocm/BUILD.tpl Outdated
},
)

config_setting(
Contributor

I do not think we need these duplicated here. Is there a reason to not reuse them from //tensorflow/BUILD?

Contributor Author

@gunan configs here are implemented following tensorflow/third_party/gpus/cuda/BUILD.tpl where CUDA-specific configs are specified. These CUDA-specific configs aren't in //tensorflow/BUILD either.

Since the purpose of this PR is to introduce ROCm-specific build scripts without major refactoring of the existing TensorFlow build infrastructure, I'd like to propose keeping this file as is. Thoughts?

@@ -0,0 +1,32 @@
# Macros for building ROCm code.
Contributor

@meteorcloudy @mhlopko could you help review?

as is as a string to --compiler-options of hipcc. When "-x rocm" is not
present, this wrapper invokes gcc with the input arguments as is.

NOTES:
Member

I think we can remove this note, we don't have crosstool_wrapper_driver_rocm internally.

Contributor Author

ok will do

Comment thread third_party/gpus/rocm_configure.bzl Outdated

def _find_rocm_lib(lib, repository_ctx, cpu_value, basedir, version="",
static=False):
"""Finds the given CUDA or cuDNN library on the system.
Member

Fix comment, please

Contributor Author

thanks. will do.

Comment thread third_party/gpus/rocm_configure.bzl Outdated
lib: The name of the library, such as "rocmrt"
repository_ctx: The repository context.
cpu_value: The name of the host operating system.
basedir: The install directory of CUDA or cuDNN.
Member

ditto

Contributor Author

thanks. will do.

Comment thread third_party/gpus/rocm_configure.bzl Outdated
return struct(file_name=file_name, path=str(path.realpath))

elif cpu_value == "Windows":
path = repository_ctx.path("%s/lib/x64/%s" % (basedir, file_name))
Member

Is ROCm support actually available on Windows?

Contributor Author

No, it's not yet available on Windows. Some infrastructure work is underway, but it'll be some time before we can actually enable it. I'll remove these checks.

Comment thread third_party/gpus/rocm_configure.bzl Outdated
auto_configure_fail("Cannot find rocm library %s" % file_name)

def _find_libs(repository_ctx, rocm_config):
"""Returns the CUDA and cuDNN libraries on the system.
Member

ditto

Contributor Author

thanks. will fix it.

Comment thread third_party/gpus/rocm_configure.bzl Outdated

Args:
repository_ctx: The repository context.
rocm_config: The CUDA config as returned by _get_rocm_config
Member

ditto

Contributor Author

thanks. will fix it.

Comment thread third_party/gpus/rocm_configure.bzl Outdated
}

def _rocmrt_static_linkopt(cpu_value):
"""Returns additional platform-specific linkopts for rocmrt."""
Member

This is a copy of _cudart_static_linkopt, right? -lrt is needed when linking the cudart static library on Linux; I'm not sure the same option is needed for ROCm. Does the rocmrt static library even exist?

Contributor Author

Thanks. rocmrt_static is actually a static library in HIP, but it's not really needed for TensorFlow. The logic here is indeed a copy of the CUDA counterpart. I'll remove it.

Comment thread third_party/gpus/rocm_configure.bzl Outdated
_tpl(repository_ctx, "rocm:BUILD",
{
"%{rocmrt_static_lib}": rocm_libs["hip"].file_name,
"%{rocmrt_static_linkopt}": '',
@meteorcloudy (Member) commented Jul 11, 2018

If rocmrt_static_linkopt is not needed, we can remove them from both BUILD.tpl and rocm_configure.bzl
Does %{rocmrt_static_lib} library exist? Why is it the same as %{rocmrt_lib}?

Contributor Author

I'll remove it.

cc = find_cc(repository_ctx)
host_compiler_includes = _host_compiler_includes(repository_ctx, cc)
rocm_defines = {
"%{rocm_include_path}": _rocm_include_path(repository_ctx,
Member

Looks like you have hard-coded rocm include paths in CROSSTOOL, should we remove this field or not hard-code?

Contributor Author

I'll migrate hard-coded paths from CROSSTOOL to rocm_configure.bzl so it's easier to maintain. In CROSSTOOL we'll honor %{rocm_include_path}.

# linker_flag: "-Wl,--detect-odr-violations"

# Include directory for ROCm headers.
cxx_builtin_include_directory: "/opt/rocm/hsa/include"
Member

Should we replace them with %{rocm_include_path} so that they can be configured?

Contributor Author

thanks. let me see how to remove these cxx_builtin_include_directory and have them populated from rocm_configure.bzl.

Comment thread tensorflow/workspace.bzl Outdated
nccl_configure(name="local_config_nccl")
git_configure(name="local_config_git")
sycl_configure(name="local_config_sycl")
rocm_configure(name="local_config_rocm")
Member

You probably need to add an exclude to build_pip_package.sh for this too.
here:

for f in `find . ! -type d ! -name '*.py' ! -path '*local_config_cuda*' ! -path '*local_config_tensorrt*' ! -path '*org_tensorflow*'`; do


Comment thread configure.py
if environ_cp.get('TF_NEED_ROCM') == '1':
if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get(
'LD_LIBRARY_PATH') != '1':
write_action_env_to_bazelrc('LD_LIBRARY_PATH',
Contributor

Currently, if both CUDA and ROCm are enabled, we write the environment variable twice, one possibly overriding the other.

We should merge the logic that writes LD_LIBRARY_PATH to check both TF_NEED_CUDA and TF_NEED_ROCM and write it only once.

Contributor Author

TF_NEED_ROCM and TF_NEED_CUDA would be changed so they are mutually exclusive.
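The merged write-once logic the reviewer asked for could look roughly like this (a sketch in configure.py style; write_action_env_to_bazelrc is the real helper named in the diff above, but the wrapper function and its injected-callback shape are assumptions for this illustration):

```python
def maybe_write_ld_library_path(environ_cp, write_action_env_to_bazelrc):
    """Write LD_LIBRARY_PATH to .bazelrc at most once, whether CUDA or
    ROCm (or both) is enabled.

    The writer is passed in as a callable so this sketch stays
    self-contained; configure.py calls its module-level helper instead.
    """
    needs_gpu = (environ_cp.get('TF_NEED_CUDA') == '1' or
                 environ_cp.get('TF_NEED_ROCM') == '1')
    ld_path = environ_cp.get('LD_LIBRARY_PATH')
    # The `!= '1'` guard mirrors the existing check in the diff above.
    if needs_gpu and ld_path and ld_path != '1':
        write_action_env_to_bazelrc('LD_LIBRARY_PATH', ld_path)
```

With both flags set, the variable is still written exactly once, which is the behavior the reviewer requested.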

)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda", "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm", "if_rocm_is_configured")
Contributor

@jlebar Internally we do not have equivalents of these.
Are we planning to create the internal equivalents of these macros, just like what we have for CUDA?

Contributor

Are we planning to create the internal equivalents of these macros, just like what we have for CUDA?

It's @Artem-B and my opinion that --config=cuda (and by corollary if_cuda) is a harmful hack. TF uses it as a switch to say "build for GPU or not", but it shouldn't: This should be decided by BUILD rules. (E.g. you depend on :tensorflow_cpu or :tensorflow_gpu.)

The reason we introduced --config=cuda is because at the time, Skylark was missing some features we needed in order to make the toolchain work properly. We think these features are there now.

Unfortunately the direction of these patches -- including the Eigen patch -- sort of doubles down on the notion of --config=cuda. If we can't build a TF which includes both cuda and rocm bits, then we'll effectively never be able to get rid of --config=cuda.

That's why in XLA we've insisted that we maintain the ability to build for both cuda and rocm, determined by BUILD dependencies.

That said, it's unclear to me what is the alternative to these macros for TF, given that the structure of the rocm Eigen patches does not (?) allow us to build for AMDGPU and NVGPU in the same binary. So TF may need its own versions of them internally; I don't see how else to do it.

Contributor Author

@jlebar Unfortunately given how GPU common runtime is designed I think it's hard to let TF be configured and built with both CUDA and ROCm at the same time. In gpu_device.cc, EigenCudaStreamDevice has a compile-time dependency to CUDA constructs. In my current implementation for ROCm, I renamed EigenCudaStreamDevice to EigenGpuStreamDevice and use TENSORFLOW_USE_ROCM macro to switch to ROCm-functional equivalents.

For XLA compiler, it's relatively easy to specify a new set of compiler backend and let it target AMDGPU. But for GPU common runtime, a bigger overhaul might be required if the ultimate goal is to get rid of --config=cuda.

That said, I believe such an effort to modularize the TF runtime should be deferred to future PRs, after we have better consensus on how it should be achieved.

Member

Let me try to clarify the way I'd like to see things working.
The goal is to be able to build any (and all) variants without specifying any extra bazel flags.
I.e. one should be able to say bazel build //my_project:app_cuda //my_project:app_rocm //my_project:app_cpu and get all three executables.

Let's suppose the app consists of main.cpp, kernel.cu (for CUDA and ROCm) and kernel_cpu.cpp (for CPU-only). All three app variants would use the same main.o; app_cuda would use kernel-cuda.o built from kernel.cu with CUDA-specific options/defines, app_rocm would use kernel-rocm.o built from kernel.cu with ROCm-specific options/defines, and app_cpu would be compiled from kernel_cpu.cpp. The user should be able to build any combination of them simultaneously.

When you change anything in the build system, there's only one build configuration to test, and if you've built an app there's no confusion about what it supports (or does not). I can't count the number of times someone attempted to run CPU-only TF and complained that it does not see the GPUs, or ran CUDA-enabled TF on a machine without GPUs and complained that it failed. A single build configuration also saves on overall build time, since objects that don't care about CUDA/ROCm are built only once instead of once per build config. Considering that CUDA files constitute a relatively small subset of TensorFlow, the difference is substantial.

--config=cuda was inherited from the internal Google build and exists for a number of reasons that are not relevant to open-source TensorFlow. I understand that it is convenient to continue adding extra dimensions to config parameters, but now that we're adding support for another accelerator is a good time to make sure we do it right. Maintaining and debugging multiple build configurations is a royal pain; having a single set of build rules for everything makes things somewhat easier to deal with.

Contributor Author

@Artem-B / @jlebar let me see if I understand this correctly so I can assess the effort. It seems you expect one bazel build to produce executables for all specified targets, and to ditch --config=XXX.

As a corollary, would the following command be good for you? One bazel build produces 3 pip packages:

# one bazel build produces 3 pip packages
bazel build //tensorflow/tools/pip_package:build_cpu_pip_package //tensorflow/tools/pip_package:build_cuda_pip_package //tensorflow/tools/pip_package:build_rocm_pip_package

Also, I'm wondering how we should deal with test targets in bazel test?

@whchung (Contributor Author) commented Jul 27, 2018

@gregestren sorry for inviting you to a new discussion thread without prior notification.

I, @whchung, am discussing with @jlebar / @Artem-B / @gunan how to improve the TensorFlow build system to make it support multiple CPU / GPU platforms. I've discovered your work on dynamic bazel configurations [1], skylark build configurations [2], and the 2018 Bazel Configurability / Multiplatform Roadmap [3], which eventually led me to the bazel roadmap pages [4] [5].

I'm very new to the implementation of bazel, let alone testing / adapting all these upcoming features in TensorFlow. Could you share some working example projects with targets tied to different toolchains, so I can learn from them? Thank you very much.

Member

Sorry, I'd completely forgotten that OSS TF still uses a custom crosstool to compile with nvcc. :-(
That's going to be a problem. Crosstool is not something that can be switched per build target, and NVCC does not support clang as the host compiler. That forces us to have multiple build configurations. We would still be able to do something like bazel --config=nvcc foo_cuda foo_cpu, or bazel foo_rocm foo_cpu (probably with something like --config=rocm), but we will not be able to do all three at once. This complicates things.

I'll need to think about it a bit.

@gregestren -- It sounds like bazel has grown a lot of features lately that I've been missing for CUDA compilation. I'm glad to see that my handwavy proposal for TF compilation is roughly in line with the general direction the bazel configurability roadmap is heading. I'll need to take a closer look at the recent changes to get a better idea of what we can do these days, but it appears I may have more tools at my disposal than I used to.

Contributor

Ok to merge then?

Member

Considering that at the moment we do not have a working alternative, that's probably the least bad option.
It does entangle the build with multiple crosstools, but it's a marginal bump in the amount of work we'll need to do in addition to what's needed to deal with --config=nvcc. Making sure the rocm build remains working will be a bit of a pain, but I expect the bulk of the issues to be figured out on the nvcc build, so overall it should not be a major issue.

I'm OK with the patch, but it's ultimately TF team's call.


Hi Jack, Artem,

Apologies for my delayed response - I was doing some personal traveling last week.

I'd love to discuss more both your goals and how Bazel's multiplatform changes could help them. We're at a weird state now where lots of new possibilities are opening up (like being able to really support per-target crosstools) but the public APIs are still all coming together. So it's not as simple as "just follow this pattern in the Bazel documentation" but that doesn't mean there aren't options for you.

Since I'm just coming in late into this conversation, I suggest we all get on the same basic page of what's desired. Then we can clarify what features can address your goals, and how well.

Would that work?

Comment thread tensorflow/core/kernels/BUILD Outdated
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
] + if_cuda([
Contributor

I feel like this whole library should only be added if cuda is enabled. What is the need for a double check?

Contributor Author

Unfortunately the header file in cuda_solvers target is not CUDA-specific. There are some utility structures used by operators such as SegmentSumGPUOp which is supported on ROCm.

I'll submit additional PRs to refactor SegmentSumGPUOp. And I'll remove this check here in this PR.

deps = MATH_DEPS + if_cuda_is_configured([
":cuda_solvers",
]) + if_rocm_is_configured([
":cuda_solvers",
Contributor

This is quite surprising to me.
cuda_solvers is an empty library if cuda is not enabled. why even link it here?

Contributor Author

Unfortunately the header file in cuda_solvers target is not CUDA-specific. There are some utility structures used by operators such as SegmentSumGPUOp which is supported on ROCm.

I'll submit additional PRs to refactor SegmentSumGPUOp. And I'll remove cuda_solvers dependency here.

Comment thread tensorflow/tensorflow.bzl
"cuda_default_copts",
)
load(
"@local_config_rocm//rocm:build_defs.bzl",
Contributor

Ditto. @jlebar we need to decide on the internal versions of these before we can merge this.

Contributor

So do we approve, and start wrestling with this change internally?

Comment thread tensorflow/tensorflow.bzl Outdated
extra_copts=extra_copts,
linkopts=linkopts,
args=args)
if if_cuda_is_configured(True) or if_rocm_is_configured(True):
Contributor

Internally, we do not have a configure script, so we need this enabled unconditionally. Please revert, and just expand the switch statements.

Contributor Author

will do.
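For context on the objection: if_cuda_is_configured(x) passes x through when CUDA was configured and returns an empty list otherwise, so the snippet above leans on the truthiness of the macro's result. A small Python model of that assumed behavior:

```python
def if_cuda_is_configured(x, cuda_configured=False):
    # Model: x survives only if ./configure set up CUDA; otherwise the
    # empty list makes the whole expression falsy.
    return x if cuda_configured else []

# The questioned pattern: the macro's return value doubles as a boolean.
assert bool(if_cuda_is_configured(True, cuda_configured=True)) is True
assert bool(if_cuda_is_configured(True, cuda_configured=False)) is False
```

Builds that never run the configure step never populate these values, which is why expanding the branches explicitly was requested.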

Comment thread tensorflow/tensorflow.bzl Outdated
hdrs=[],
**kwargs):
copts = copts + _cuda_copts() + if_cuda(cuda_copts) + tf_copts()
copts=copts + tf_copts() + _cuda_copts(opts=cuda_copts) + _rocm_copts(opts=cuda_copts)
Contributor

Nit: please revert the spacing change around =.
It needs to be copts = copts + .... for our code linter checks to pass.
Note the difference between a local variable assignment here, vs. a function kwarg, which does not need the space around the =.

Contributor Author

thanks. will fix.

Comment thread configure.py
else:
set_trisycl_include_dir(environ_cp)

set_action_env_var(environ_cp, 'TF_NEED_ROCM', 'ROCm', False)
Contributor

Looking at some of the options below, enabling both at the same time would break other things. For example, the _gpu cc_test targets. Which GPU would they run on?
Let's add the check @caisq requested here.
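A minimal sketch of the mutual-exclusion check being requested here, assuming configure.py's environ_cp dictionary convention (the function name is hypothetical):

```python
def check_exclusive_gpu_backends(environ_cp):
    """Fail configuration if both CUDA and ROCm are requested, since a
    single _gpu test target can only run against one backend."""
    cuda_enabled = environ_cp.get('TF_NEED_CUDA', '0') == '1'
    rocm_enabled = environ_cp.get('TF_NEED_ROCM', '0') == '1'
    if cuda_enabled and rocm_enabled:
        raise ValueError('TF_NEED_CUDA and TF_NEED_ROCM are mutually '
                         'exclusive; enable at most one GPU backend.')

check_exclusive_gpu_backends({'TF_NEED_ROCM': '1'})  # ROCm alone: fine
```

Failing fast at configure time avoids ambiguity later over which GPU the _gpu cc_test targets should run on.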

@@ -0,0 +1,97 @@
# This Dockerfile provides a starting point for a ROCm installation of
# MIOpen and tensorflow.
FROM ubuntu:xenial-20170619
Contributor

Do we have an upper limit on the kernel version?
If we depend on one very specific kernel version this will be very very brittle.
Is there a way to lift this restriction?

@whchung whchung force-pushed the upstream-staging branch from 4885f5e to e2b5544 Compare July 17, 2018 19:16
@rmlarsen
Contributor

@gunan could you please take another look?

@whchung
Contributor Author

whchung commented Jul 17, 2018

@rmlarsen / @gunan please hold it for now. I haven't finished addressing all the comments from reviewers yet. Will ping you guys once I feel more comfortable with the PR. And I think I'll probably need to squash the commits to make the commit history look nicer.

@whchung whchung force-pushed the upstream-staging branch 2 times, most recently from acb4dc5 to 6d5ca1c Compare July 17, 2018 23:14
@whchung
Contributor Author

whchung commented Jul 17, 2018

@rmlarsen / @gunan I believe I've already addressed all the comments and updated the PR. Please help review it again. Thanks!

@whchung
Contributor Author

whchung commented Jul 20, 2018

a gentle ping?

@gunan
Contributor

gunan commented Aug 30, 2018

No, we do not have the hardware, toolchains, license reviews or anything in place for us to be able to build with ROCm. So they definitely won't become blocking presubmits for now.
This, of course, will be subject to review later.

In the meantime, we can work with you to setup community supported builds, as outlined here:
https://github.com/tensorflow/community/blob/master/sigs/build/community-builds.md

@whchung
Contributor Author

whchung commented Aug 31, 2018

@gunan Thanks for sharing the community build page with me. @parallelo will look into it and adapt our CI infrastructure to accommodate that.

Also, we would like to revive our other outstanding PRs, specifically those in StreamExecutor and the GPU common runtime which are blocked by this particular PR. We'll also start submitting PRs to enable operators on ROCm which are conditional on TENSORFLOW_USE_ROCM, introduced in this PR. Would it be possible to help expedite merging this PR? Thanks a lot.

With more developers working on the TensorFlow ROCm port now, we expect subsequent PRs to be revised and maintained in a timely fashion.

@yifeif
Contributor

yifeif commented Sep 5, 2018

This will need a manual pull for sure @gunan. @whchung do you mind resolving the latest conflicts? Reviewers, let me know if this is ready and I can give pulling a shot.

The commit contains the following components to support TensorFlow on the ROCm platform:

- bazel build system
- continuous integration logic

Authors:

- Jack Chung: jack.chung@amd.com
- Jeffrey Poznanovic: Jeffrey.Poznanovic@amd.com
- Peng Sun: Peng.Sun@amd.com
@whchung
Contributor Author

whchung commented Sep 6, 2018

@yifeif @gunan, my colleague @deven-amd has rebased and modified the PR.

@gunan gunan added the kokoro:force-run Tests on submitted change label Sep 11, 2018
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Sep 11, 2018
@aaroey aaroey self-requested a review September 12, 2018 20:19
@yifeif yifeif added ready to pull PR ready for merge process kokoro:force-run Tests on submitted change labels Sep 16, 2018
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Sep 17, 2018
@dagamayank

@yifeif any update to pull this PR in?

@drpngx
Contributor

drpngx commented Sep 21, 2018

@aaroey @meteorcloudy any additional comments?

@drpngx
Contributor

drpngx commented Sep 24, 2018

Oh, it looks like we have to import this manually.

@gunan
Contributor

gunan commented Sep 24, 2018

@yifeif is working on the manual import of this, she has been working on this for 2 weeks now.

@tensorflow-copybara tensorflow-copybara merged commit 69d3b8f into tensorflow:master Sep 27, 2018
tensorflow-copybara pushed a commit that referenced this pull request Sep 27, 2018
@gunan
Contributor

gunan commented Sep 27, 2018

We had to revert some "if_cuda_is_configured" uses to if_cuda to make all internal tests pass, but other than that, all of this change has been merged.

@yifeif
Contributor

yifeif commented Sep 27, 2018

We finally got this change merged! Thanks for the patience. We needed to change if_cuda_is_configured in tf_cuda_library back to if_cuda to get some internal targets to pass. Feel free to send another PR if this causes any issue and we can work out a patch.
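For readers tracking the distinction: if_cuda keys on a build-time setting (e.g. --config=cuda), while if_cuda_is_configured keys on what ./configure recorded, which is why builds without a configure step needed the former. A simplified Python model of the assumed semantics:

```python
def if_cuda(if_true, if_false=None, cuda_build_flag=False):
    # Model: resolved per build invocation via a flag such as --config=cuda.
    return if_true if cuda_build_flag else (if_false or [])

def if_cuda_is_configured(x, cuda_configured=False):
    # Model: resolved once, when ./configure sets up the CUDA repository.
    return x if cuda_configured else []

# Without a configure step the *_is_configured macros always come up empty,
# even when the build itself requests CUDA.
print(if_cuda_is_configured(["dep"]))          # []
print(if_cuda(["dep"], cuda_build_flag=True))  # ['dep']
```

This is only a sketch of the difference being described, not the actual Starlark implementations in build_defs.bzl.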

@drpngx
Contributor

drpngx commented Sep 27, 2018

Woohoo! Thank you @yifeif !

@whchung
Contributor Author

whchung commented Sep 27, 2018

Thank you @yifeif and @gunan . We’ll revise other pending PRs for ROCm , as well as submitting new ones :)


Labels

cla: yes ready to pull PR ready for merge process
