Building Tensorflow from source with XLA enabled and without GPU #11122
Comments
I also encounter this, having built on macOS. I noticed that the error goes away when running in a test block, if that helps identify what is happening. This code does not error:
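A minimal sketch of the kind of test-wrapped snippet described here, assuming a simple placeholder addition (the class name, shapes, and feed values are illustrative, not the original code):

```python
# Hypothetical reconstruction: wrapping the computation in a
# tf.test.TestCase session avoids the abort on an affected build.
import tensorflow as tf

class SimpleAddTest(tf.test.TestCase):

  def testAdd(self):
    a = tf.placeholder(tf.float32)
    b = tf.placeholder(tf.float32)
    with self.test_session() as sess:
      print(sess.run(a + b, feed_dict={a: 1.0, b: 2.0}))

if __name__ == "__main__":
  tf.test.main()
```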
This code gives the error:
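A matching sketch of the failing variant, under the same assumptions: the identical graph run through a plain tf.Session aborts on an XLA-enabled, CPU-only build.

```python
# Hypothetical failing variant: same graph, but run through a plain
# tf.Session; on an affected build this aborts inside the XLA statusor.
import tensorflow as tf

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)

with tf.Session() as sess:
    print(sess.run(a + b, feed_dict={a: 1.0, b: 2.0}))
```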
I hit this error when running the CNN benchmark. I built TensorFlow from source with both GPU and XLA enabled; the latest version is this commit. I'm in a cloud environment using NFS. @mikowals, both programs are OK in my environment.
@mikowals are you running the test with Bazel?
@suiyuan2009 I think there's an issue when we do it without GPU support. Have you tried that?
Faced the same thing running the cifar10 example from https://github.com/tensorflow/models; it gives the same error. After adding an explicit CPU device placement, it works. Built from source with XLA and CUDA, commit: 270a3e8
Forcing ops onto the CPU is a workaround for my reproduction on a CPU-only build as well. @ankitachandak I am not testing with Bazel, just running with Python. This code works:
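A minimal sketch of the CPU-pinning workaround, assuming the same placeholder-addition graph as above (the exact ops in the original snippet are not preserved here):

```python
# Hypothetical reconstruction: pinning graph construction to /cpu:0
# sidesteps the bad default placement on the XLA executor device.
import tensorflow as tf

with tf.device('/cpu:0'):
    a = tf.placeholder(tf.float32)
    b = tf.placeholder(tf.float32)
    c = a + b

with tf.Session() as sess:
    print(sess.run(c, feed_dict={a: 1.0, b: 2.0}))
```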
@ankitachandak, I didn't try the CPU version; our training data is very large.
Please provide details about the platform you are using (operating system, architecture), and include your TensorFlow version. Also, did you compile from source or install a binary? If possible, include the exact command needed to reproduce the output in your test case. We ask for this in the issue submission template because it is really difficult to help without that information. Thanks!
@suiyuan2009 No need to "try the cpu version". The snippet @mikowals provided works fine without an explicit device specification if GPUs are available. But if they are not available for some reason (e.g. CUDA_VISIBLE_DEVICES=''), the code requires such an explicit specification. Interestingly, this also breaks the cifar10 code.
Without explicit device specification, example.py aborts with the error above. At the same time, the version from g8bbec0b works, so I would assume that something happened in between.
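One hypothetical way to reproduce the failure mode described above on a GPU build, without touching the graph code, is to hide the GPUs before importing TensorFlow; this is an illustration, not the actual example.py from this thread:

```python
# Hypothetical reproduction: hiding all GPUs makes default placement fall
# through to the broken XLA executor device, so on affected builds this
# aborts with the "could not find registered computation placer" error.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # must be set before importing TF

import tensorflow as tf

a = tf.placeholder(tf.float32)

with tf.Session() as sess:
    print(sess.run(a + 1.0, feed_dict={a: 41.0}))
```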
@ali01 I compiled from source. The exact code that gave the error was as simple as:

```python
a = tf.placeholder(tf.float32)
```

I placed it in a file and ran `python example.py`.
@ramanishka @mikowals Yes, force-setting the CPU works! Thanks for the hack. Although I hope TensorFlow resolves this issue, as the workaround won't be of much help later: I am trying to play around with LLVM and I would prefer to run without any hacks.
Used the following command, trying to guarantee that it uses SSE4.1/4.2 and AVX:

```
bazel build -c opt --copt=-mavx --copt=-mavx2 --copt=-msse4.2 --copt=-msse4.1 --copt=-msse3 --copt=-mfma -k //tensorflow/tools/pip_package:build_pip_package
```

After the pip installation, I tried to run a Keras project and got this error: "Using TensorFlow backend." I tried both this way and the standard way from the TensorFlow website. My computer info:
Looks like this came in with a recent change; specifically, this check fails for the Executor plugin.
However, unsurprisingly, this bodge doesn't fix the other test cases in this ticket. For instance, the simple placeholder-addition test case from earlier in the thread now provokes the same `could not find registered computation placer` abort quoted in the issue description.
I have the same problem (built from source, XLA and GPU enabled). Code to reproduce the issue:

```python
import tensorflow as tf

x = tf.ones([2, 1])
y = tf.layers.dense(x, 3)
sv = tf.train.Supervisor()
sess = sv.prepare_or_wait_for_session()  # this triggers the error
```

The error message is the same as reported above. However, replacing the last two lines with the following does not trigger the error:

```python
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(y)
```

This yielded normal output and did not trigger a core dump. Unlike what @ramanishka mentioned, this worked without an explicit device specification.

System info:
OS: Ubuntu 16.04
@hyouklee it looks as if you added the ComputationPlacer code; would you take a look?
I will take a look at this, since @hyouklee is away this week. |
Have a fix: the priority of XLA_EXEC is set too high in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/plugin/executor/device.cc#L50; changing it to a lower value (e.g. 40) fixes the problem. A fix should be merged upstream soon.
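The fix itself is a one-line priority change in the C++ device registration. As an illustrative check (not part of the fix), logging device placement from Python shows which device wins when none is specified; on a broken build the ops land on the XLA executor device instead of the CPU:

```python
# Illustrative check: log_device_placement prints the device each op is
# assigned to, making the bad default placement visible.
import tensorflow as tf

a = tf.constant(1.0)
b = tf.constant(2.0)
c = a + b

with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    print(sess.run(c))
```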
Thanks @frankchn for the merge. This should now be fixed, but please feel free to reopen if the same symptom remains. |
…t when TF is built from source with XLA support. See Github issue tensorflow#11122. The priority of the executor backend is set to be higher than the default (50) and CPUs (<100), and is therefore selected as the default when tf.device is not explicitly specified. PiperOrigin-RevId: 161717173
Hi,
I built and installed TensorFlow from source with XLA enabled and GPU disabled (basically I answered N to everything while configuring via ./configure, except for enabling XLA, which I set to Y). There were a lot of warnings regarding deprecated syntax while building, but the build was successful.
I am able to import TensorFlow and run a basic print command in a session. But when I try to do some computation (e.g. a simple addition), it gives me the following error:
2017-06-28 15:09:22.366052: F tensorflow/compiler/xla/statusor.cc:41] Attempting to fetch value instead of handling error Not found: could not find registered computation placer for platform Executor -- check target linkage
Aborted
I did a bit of debugging, and this error comes just after the call from client/sessions.py:1262 to pywrap_tensorflow:

```python
tf_session.TF_Run(session, options,
                  feed_dict, fetch_list, target_list,
                  status, run_metadata)
```
So I believe this is because it is unable to link to _pywrap_tensorflow_internal.so.
Can you please provide a fix for this, or is there something I am doing wrong here?
This is blocking my further work, so any kind of help is appreciated!
Thanks & Regards