
Building Tensorflow from source with XLA enabled and without GPU #11122

Closed
ankitachandak opened this issue Jun 28, 2017 · 18 comments · Fixed by #64942

Labels: type:build/install Build and install issues

@ankitachandak

Hi,
I built and installed TensorFlow from source with XLA enabled and GPU disabled (basically I answered N to everything while configuring via ./configure, except answering Y to enable XLA). There were a lot of warnings about deprecated syntax while building, but the build was successful.
I am able to import TensorFlow and run a basic print command in a session, but when I try to do some computation (e.g. a simple addition) it gives me the following error:
2017-06-28 15:09:22.366052: F tensorflow/compiler/xla/statusor.cc:41] Attempting to fetch value instead of handling error Not found: could not find registered computation placer for platform Executor -- check target linkage
Aborted

I did a bit of debugging, and this error occurs just after the call from client/session.py:1262 into pywrap_tensorflow:
tf_session.TF_Run(session, options,
feed_dict, fetch_list, target_list,
status, run_metadata)

So I believe it's failing because it is unable to link to _pywrap_tensorflow_internal.so.
Can you please provide a fix for this, or is there something I am doing wrong here?
This is blocking my further work, so any kind of help is appreciated!

Thanks & Regards

@ankitachandak changed the title from "Building Tensorflow from source with XLA and without GPU" to "Building Tensorflow from source with XLA enabled and without GPU" on Jun 28, 2017
@mikowals
Contributor

I encounter this too, having built on macOS. I noticed that the error goes away when the code runs inside a test block, if that helps identify what is happening.

This code does not error:

import tensorflow as tf

class BugTest(tf.test.TestCase):
  def bug_test(self):
    with tf.Session() as sess:  # this can also use self.test_session()
      x1 = tf.random_normal(shape=[64, 64, 32, 32], seed=1.)
      res = sess.run(x1)

if __name__ == '__main__':
  tf.test.main()

This code gives the error:

import tensorflow as tf
with tf.Session() as sess:
    x1 = tf.random_normal(shape=[64, 64, 32, 32], seed=1.)
    res = sess.run(x1)

@suiyuan2009
Contributor

suiyuan2009 commented Jun 29, 2017

I hit this error when running the CNN benchmark. I built TensorFlow from source with both GPU and XLA enabled; the latest version is this commit. I'm in a cloud environment using NFS. @mikowals both programs are OK in my environment.

@ankitachandak
Author

@mikowals are you running the test with bazel?

@ankitachandak
Author

@suiyuan2009 I think there's an issue when we do it without GPU support. Have you tried that?

@ramanishka

I faced the same thing running the cifar10 example from https://github.com/tensorflow/models.
Here is a snippet from @yaroslavvb (link):

import os
os.environ['CUDA_VISIBLE_DEVICES']=''
import tensorflow as tf
from tensorflow.contrib.compiler import jit
tf.reset_default_graph()
jit_scope = jit.experimental_jit_scope
with jit_scope(compile_ops=True):
    N = 500*1000*1000
    x = tf.Variable(tf.random_uniform(shape=(N,)))
    y = 0.1*x*x*x*x*x-0.5*x*x*x*x+.25*x*x*x+.75*x*x-1.5*x-2
    y0 = y[0]
import time
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(y.op)
start_time = time.time()
print(sess.run(y0))
end_time = time.time()
print("%.2f sec"%(end_time-start_time))

gives:

2017-06-29 14:37:05.431926: F tensorflow/compiler/xla/statusor.cc:41] Attempting to fetch value instead of handling error Not found: could not find registered computation placer for platform Executor -- check target linkage

After adding with tf.device('/cpu') it works:

import os
os.environ['CUDA_VISIBLE_DEVICES']=''
import tensorflow as tf
from tensorflow.contrib.compiler import jit
tf.reset_default_graph()
jit_scope = jit.experimental_jit_scope
with jit_scope(compile_ops=True):
    N = 500*1000*1000
    with tf.device("/cpu"):
        x = tf.Variable(tf.random_uniform(shape=(N,)))
        y = 0.1*x*x*x*x*x-0.5*x*x*x*x+.25*x*x*x+.75*x*x-1.5*x-2
        y0 = y[0]
import time
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(y.op)
start_time = time.time()
print(sess.run(y0))
end_time = time.time()
print("%.2f sec"%(end_time-start_time))
2017-06-29 14:52:30.979254: I tensorflow/compiler/xla/service/service.cc:193] XLA service 0x7ff440159ca0 executing computations on platform Host. Devices:
2017-06-29 14:52:30.979277: I tensorflow/compiler/xla/service/service.cc:201]   StreamExecutor device (0): <undefined>, <undefined>
-2.7442
0.68 sec

Built from source with XLA and CUDA, commit: 270a3e8

@mikowals
Contributor

Forcing ops onto the CPU also seems to be a workaround for my reproduction on a CPU-only build.

@ankitachandak I am not testing with Bazel, just running with Python.

This code works:

import tensorflow as tf
with tf.device('/cpu'):
    x1 = tf.random_normal(shape=[64, 64, 32, 32], seed=1.)

with tf.Session() as sess:
    res = sess.run(x1)

@suiyuan2009
Contributor

@ankitachandak, I didn't try the CPU version; our training data is very large.

@ali01

ali01 commented Jun 30, 2017

Please provide details about the platform you are using (operating system, architecture) and include your TensorFlow version. Also, did you compile from source or install a binary? If possible, include the exact command used to produce the output in your test case. If you are unclear about what to include, see the template displayed when opening a new GitHub issue.

We ask for this in the issue submission template because it is really difficult to help without that information. Thanks!

@ali01 added the stat:awaiting response (Status - Awaiting response from author) and type:build/install (Build and install issues) labels on Jun 30, 2017
@ramanishka

ramanishka commented Jun 30, 2017

@suiyuan2009 No need to "try the CPU version". The snippet @mikowals provided works fine without explicitly specifying the device if GPUs are available. But if they are not available for some reason (e.g. CUDA_VISIBLE_DEVICES=''), the code requires such an explicit specification. Interestingly, this breaks the cifar10 code.
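
As a minimal sketch of that guard (hypothetical code, not from the original snippets; TF 1.x API assumed), the op can be pinned to the CPU only when CUDA_VISIBLE_DEVICES hides every GPU:

import os
import tensorflow as tf

# Pin the op to the CPU only when CUDA_VISIBLE_DEVICES hides all GPUs;
# passing None to tf.device() leaves placement to TensorFlow's defaults.
no_visible_gpu = os.environ.get('CUDA_VISIBLE_DEVICES') == ''
device = '/cpu:0' if no_visible_gpu else None

with tf.device(device):
    x = tf.random_normal(shape=[64, 64, 32, 32], seed=1)

with tf.Session() as sess:
    res = sess.run(x)
    print(res.shape)  # (64, 64, 32, 32)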

@ali01

$ cat tf_env.txt
== cat /etc/issue ===============================================
Linux sv 4.8.0-56-generic #61~16.04.1-Ubuntu SMP Wed Jun 14 11:58:22 UTC 2017 x86_64 x86_64 x86_64 GNU/Linux
VERSION="16.04.2 LTS (Xenial Xerus)"
VERSION_ID="16.04"
VERSION_CODENAME=xenial

== are we in docker =============================================
No

== compiler =====================================================
c++ (Ubuntu 5.4.0-6ubuntu1~16.04.4) 5.4.0 20160609

== uname -a =====================================================
Linux sv800478lx 4.8.0-56-generic #61~16.04.1-Ubuntu SMP Wed Jun 14 11:58:22 UTC 2017 x86_64 x86_64 x86_64 GNU/Linux

== check pips ===================================================
numpy (1.13.0)
protobuf (3.3.0)
tensorflow (1.2.0)
tensorflow-tensorboard (0.1.2)

== check for virtualenv =========================================
False

== tensorflow import ============================================
tf.VERSION = 1.2.0
tf.GIT_VERSION = b'v1.2.0-1367-g270a3e8'
tf.COMPILER_VERSION = b'v1.2.0-1367-g270a3e8'
Sanity check: array([1], dtype=int32)

== env ==========================================================
LD_LIBRARY_PATH is unset
DYLD_LIBRARY_PATH is unset

== nvidia-smi ===================================================
Fri Jun 30 12:01:38 2017       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 375.51                 Driver Version: 375.51                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  TITAN X (Pascal)    On   | 0000:03:00.0      On |                  N/A |
| 27%   49C    P0    59W / 250W |   1017MiB / 12189MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  TITAN X (Pascal)    On   | 0000:04:00.0     Off |                  N/A |
| 23%   39C    P0    56W / 250W |      1MiB / 12189MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

== cuda libs  ===================================================
/usr/local/cuda-8.0/doc/man/man7/libcudart.7
/usr/local/cuda-8.0/doc/man/man7/libcudart.so.7
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudart.so.8.0.61
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudart_static.a
$ cat .tf_configure.bazelrc
................
build --define with_jemalloc=true
build --define with_xla_support=true
build:opt --cxxopt=-march=native --copt=-march=native
build --action_env TF_NEED_CUDA="1"
build --action_env TF_NEED_OPENCL="0"
build --action_env TF_CUDA_CLANG="0"
build --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda"
build --action_env TF_CUDA_VERSION="8.0"
build --action_env GCC_HOST_COMPILER_PATH="/usr/bin/gcc"
build --action_env TF_CUDNN_VERSION=""
build --action_env CUDNN_INSTALL_PATH="/usr/local/cuda-8.0"
build --action_env TF_CUDNN_VERSION="6"
build --action_env TF_CUDA_COMPUTE_CAPABILITIES="6.1"
build --config=cuda
test --config=cuda

Without explicit device specification, example.py:

import tensorflow as tf
x = tf.random_normal(shape=[64, 64, 32, 32], seed=1.)
with tf.Session() as sess: 
    res = sess.run(x)

CUDA_VISIBLE_DEVICES='' python example.py crashes with: "Attempting to fetch value instead of handling error Not found: could not find registered computation placer for platform Executor -- check target linkage", followed by Aborted (core dumped).

At the same time, the version from g8bbec0b works.

I would assume that something happened in between.

@ankitachandak
Copy link
Author

@ali01 I compiled from source.
I followed the instructions from https://www.tensorflow.org/install/install_sources and cloned the latest TensorFlow version.
OS Platform and Distribution: Linux Ubuntu 16.04
Python version: 2.7
Bazel: 0.5.2
CUDA/cuDNN version: None (as I said, I am installing the CPU version)
GPU model and memory: N/A

The exact commands I ran that gave the error were as simple as:

import tensorflow as tf

sess = tf.Session()
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
adder_node = a + b  # + provides a shortcut for tf.add(a, b)
print(sess.run(adder_node, {a: 3, b: 4.5}))

I placed it in a file and ran python example.py.
I debugged it and found that it gets as far as the pywrap_tensorflow run method:
tf_session.TF_Run(session, options,
feed_dict, fetch_list, target_list,
status, run_metadata)

@ankitachandak
Author

@ramanishka @mikowals Yes, forcing the CPU works! Thanks for the hack.

I hope TensorFlow resolves this issue, though, as the workaround won't be of much help later: I am trying to play around with LLVM and would prefer to run without any hacks.
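
For reference, a sketch of the same workaround applied to the addition snippet from the original report (TF 1.x assumed; the explicit /cpu:0 pin is the only change):

import tensorflow as tf

# Pin graph construction to the CPU so the Executor platform is never
# selected as the default device for these ops.
with tf.device('/cpu:0'):
    a = tf.placeholder(tf.float32)
    b = tf.placeholder(tf.float32)
    adder_node = a + b  # + provides a shortcut for tf.add(a, b)

with tf.Session() as sess:
    print(sess.run(adder_node, {a: 3, b: 4.5}))  # expected output: 7.5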

@oxydron

oxydron commented Jul 4, 2017

I used the following command, trying to guarantee that it uses SSE4.1/SSE4.2 and AVX:

bazel build -c opt --copt=-mavx --copt=-mavx2 --copt=-msse4.2 --copt=-msse4.1 --copt=-msse3 --copt=-mfma -k //tensorflow/tools/pip_package:build_pip_package

After the pip installation, I tried to run a Keras project and got this error:

Using TensorFlow backend.
2017-07-04 10:30:20.681180: F tensorflow/compiler/xla/statusor.cc:41] Attempting to fetch value instead of handling error Not found: could not find registered computation placer for platform Executor -- check target linkage

I tried it this way and also the standard way from the TensorFlow website.

My computer info:
OS: Ubuntu 16.04.2 LTS (Linux 4.8.0-51-generic)
Processor: i3-2100 CPU
Memory: 16.4 GB

@aselle removed the stat:awaiting response (Status - Awaiting response from author) label on Jul 5, 2017
@darrengarvey
Contributor

Looks like this came in with the ComputationPlacer in 7d3497a, although it might just have been exposed by that commit.

Specifically, this check fails for the Executor plugin. The TransferManager registers itself in the executor plugin (like this), so bodging in a similar registration for ComputationPlacer fixes the network on which I see the reported error, e.g.:

diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.cc b/tensorflow/compiler/plugin/executor/transfer_manager.cc
index 51c5dee..39d443e 100644
--- a/tensorflow/compiler/plugin/executor/transfer_manager.cc
+++ b/tensorflow/compiler/plugin/executor/transfer_manager.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -179,9 +180,15 @@ static std::unique_ptr<xla::TransferManager> CreateExecutorTransferManager() {
   return xla::MakeUnique<xla::executorplugin::ExecutorTransferManager>();
 }

+static std::unique_ptr<xla::ComputationPlacer> CreateExecutorComputationPlacer() {
+  return xla::MakeUnique<xla::ComputationPlacer>();
+}
+
 static bool InitModule() {
   xla::TransferManager::RegisterTransferManager(sep::kExecutorPlatformId,
                                                 &CreateExecutorTransferManager);
+  xla::ComputationPlacer::RegisterComputationPlacer(sep::kExecutorPlatformId,
+                                                &CreateExecutorComputationPlacer);
   return true;
 }
 static bool module_initialized = InitModule();

However, unsurprisingly, this bodge doesn't fix the other test cases in this ticket. For instance, the simple test case:

import tensorflow as tf
x = tf.random_normal(shape=[64, 64, 32, 32], seed=1.)
with tf.Session() as sess: 
    res = sess.run(x)

Now provokes:

tensorflow.python.framework.errors_impl.UnimplementedError: unhandled HLO ops for HloEvaluator: rng.
	 [[Node: cluster_0/_0/_1 = _XlaLaunch[Nresources=0, Targs=[], Tconstants=[], Tresults=[DT_FLOAT], function=cluster_0[_XlaCompiledKernel=true, _XlaNumConstantArgs=0, _XlaNumResourceArgs=0], _device="/job:localhost/replica:0/task:0/device:XLA_EXEC:0"]()]]
	 [[Node: cluster_0/_0/_1/_1 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/device:XLA_EXEC:0", send_device_incarnation=1, tensor_name="edge_6_cluster_0/_0/_1", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
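
One way to sidestep the unhandled rng op (a hypothetical workaround sketch, not something proposed in this thread) is to generate the random values on the host with NumPy and feed them through a placeholder, so no rng HLO ever reaches the executor plugin:

import numpy as np
import tensorflow as tf

# Generate randomness outside the graph so the compiled cluster contains no rng op.
x = tf.placeholder(tf.float32, shape=[64, 64, 32, 32])
values = np.random.randn(64, 64, 32, 32).astype(np.float32)

with tf.Session() as sess:
    res = sess.run(x, feed_dict={x: values})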

@JeremyCCHsu

JeremyCCHsu commented Jul 6, 2017

I have the same problem (built from source, XLA and GPU enabled).
(Note that this error doesn't occur in the binary version from the official website.)

Code to reproduce the issue:

import tensorflow as tf

x = tf.ones([2, 1])
y = tf.layers.dense(x, 3)

sv = tf.train.Supervisor()
sess = sv.prepare_or_wait_for_session()  # this triggers error

The error message:

2017-07-06 14:41:11.341941: F tensorflow/compiler/xla/statusor.cc:41] Attempting to fetch value instead of handling error Not found: could not find registered computation placer for platform Executor -- check target linkage
Aborted (core dumped)

However, replacing the last two lines with the following does not trigger the error:

sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(y)

This yielded normal output and did not trigger a core dump.

Unlike what @ramanishka mentioned, using with tf.device('/cpu') does not prevent the error in my case.

System Info

OS: Ubuntu 16.04
Tensorflow: 1.2.1, built from source
commit ee4259a
(and commit 43a819e)
Python: 3.5.3
Bazel: release 0.5.2
CUDA/cuDNN: 8.0/6.0
GPU: GeForce GTX TITAN X

@michaelisard

@hyouklee it looks as if you added the ComputationPlacer code: would you take a look?

@kayzhu
Contributor

kayzhu commented Jul 10, 2017

I will take a look at this, since @hyouklee is away this week.

@kayzhu self-assigned this on Jul 10, 2017
@kayzhu
Contributor

kayzhu commented Jul 12, 2017

I have a fix: the priority of XLA_EXEC is set too high in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/plugin/executor/device.cc#L50; changing it to a lower value (e.g. 40) fixes the problem.

A fix should be merged upstream soon.
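
To check which devices a build registers and where ops land by default, something like the following can be used (a rough sketch; device_lib is an internal TF 1.x module, assumed available here, and on an affected build the sess.run call itself may still abort with the reported error):

import tensorflow as tf
from tensorflow.python.client import device_lib

# List every device the runtime registered; on an affected build an
# XLA_EXEC executor device should appear alongside the CPU.
for d in device_lib.list_local_devices():
    print(d.name, d.device_type)

# Log placement decisions to see which device ops get assigned to by default.
x = tf.random_normal(shape=[4])
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    sess.run(x)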

@kayzhu
Contributor

kayzhu commented Jul 14, 2017

Thanks @frankchn for the merge. This should now be fixed, but please feel free to reopen if the same symptom remains.

@kayzhu closed this as completed on Jul 14, 2017
zhuangh pushed a commit to zhuangh/tensorflow that referenced this issue Jul 14, 2017
…t when TF

is built from source with XLA support. See Github issue tensorflow#11122.

The priority of the executor backend is set to be higher than the default (50)
and CPUs (<100), and is therefore selected as the default when tf.device is not
explicitly specified.

PiperOrigin-RevId: 161717173
copybara-service bot pushed a commit that referenced this issue Apr 2, 2024
Imported from GitHub PR openxla/xla#11122

Copybara import of the project:

--
94f7dc85f6ab07343ddb565b0aac0ed908fd66b4 by Dragan Mladjenovic <Dragan.Mladjenovic@amd.com>:

[ROCm] Allow ncclCommInitRankConfig on rocm5.7

--
2d196b645ccac046309d44f747d42777f00f1c9b by Dragan Mladjenovic <Dragan.Mladjenovic@amd.com>:

[ROCm] Fix sporadic hangs in CommInitRanks

The pointer to comm_handle must stay valid until the GroupEnd call on pre-2.18
NCCL; only after that does comm_handle contain a valid value. Enforce this by
using std::vector instead of a temporary local.

Merging this change closes #11122

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11122 from ROCm:5.7-rccl-hang 2d196b645ccac046309d44f747d42777f00f1c9b
PiperOrigin-RevId: 621235912
copybara-service bot pushed a commit that referenced this issue Apr 2, 2024
Imported from GitHub PR openxla/xla#11122

Copybara import of the project:

--
94f7dc85f6ab07343ddb565b0aac0ed908fd66b4 by Dragan Mladjenovic <Dragan.Mladjenovic@amd.com>:

[ROCm] Allow ncclCommInitRankConfig on rocm5.7

--
2d196b645ccac046309d44f747d42777f00f1c9b by Dragan Mladjenovic <Dragan.Mladjenovic@amd.com>:

[ROCm] Fix sporadic hangs in CommInitRanks

The pointer to comm_handle must stay valid until the GroupEnd call on pre-2.18
NCCL; only after that does comm_handle contain a valid value. Enforce this by
using std::vector instead of a temporary local.

Merging this change closes #11122

PiperOrigin-RevId: 621262510