Proper way to install now on Colab/Linux, also "squeeze" gradient not implemented? #528

Closed
chiehminwei opened this issue Mar 26, 2019 · 2 comments


chiehminwei commented Mar 26, 2019

I'm aware this project is still under active development and not all things are ready yet, but I'd like some quick feedback on where things stand. I've dug through the code and tried a couple of different things.

I first saw the Colab notebook (https://github.com/pytorch/xla/blob/master/contrib/colab/PyTorch_TPU_XRT_1_13.ipynb) and tried it on Colab. I noticed some discrepancies between the code there and the code elsewhere in the repo. I'm guessing the "train" method is no longer wrapped under XlaModel?

More importantly, the code there (MNIST) works, but once I try using it on my own model (a transformer, BERT), it throws an error:

RuntimeError: differentiation of aten::squeeze is not supported, or it is missing necessary type information

This is odd, since when I looked through this repo I found "squeeze.cpp" implemented under xla/torch_xla/csrc/ops/.
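For reference, here is a minimal, hypothetical repro of the failure mode (the model below is illustrative, not taken from BERT; squeeze differentiates fine in eager PyTorch, which is what makes the JIT-path error surprising):

import torch

# Illustrative toy model; BERT-style code often squeezes singleton dims,
# e.g. turning (batch, 1) logits into (batch,).
class SqueezeNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 1)

    def forward(self, x):
        # .squeeze(-1) lowers to aten::squeeze, the op named in the error.
        return self.linear(x).squeeze(-1)

model = SqueezeNet()
loss = model(torch.randn(4, 8)).sum()
loss.backward()  # works in eager mode; the JIT-traced XLA path rejected it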

So I thought maybe the Colab wheels (http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch-1.0.0a0+1d94a2b-cp36-cp36m-linux_x86_64.whl) are outdated. I then looked under http://storage.googleapis.com/pytorch-tpu-releases/ and found a lot of newer releases, but they are built for Python 3.5 and I couldn't install them on Colab.

So I then tried installing them on a Linux machine (a Google Cloud TPU VM). Interestingly, I also had to pin MKL to an older version, as done in kokoro/ubuntu/common.sh; otherwise importing torch would throw an error about MKL. This time, however, I can't even get the Colab code to work. It throws this message:

RuntimeError: tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:109 : Check failed: session_work.first->session()->Run( session_work.second.feed_inputs, session_work.second.outputs_handles, &outputs) == ::tensorflow::Status::OK() (Not found: Op type not registered 'XRTAllocateFromTensor' in binary running on n-80018309-w-0. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) tf.contrib.resampler should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed. vs. OK)

I'm guessing this means I should install from source?

So finally, I tried installing from source following the directions in README.md. I got PyTorch to compile, but got the following message when trying to compile xla:

ERROR: /home/cmw025/pytorch/xla/third_party/tensorflow/tensorflow/core/kernels/BUILD:3371:1: C++ compilation of rule '//tensorflow/core/kernels:reduction_ops' failed (Exit 1)
In file included from external/eigen_archive/unsupported/Eigen/CXX11/Tensor:124:0,
                 from ./third_party/eigen3/unsupported/Eigen/CXX11/Tensor:1,
                 from ./tensorflow/core/kernels/reduction_ops_common.h:27,
                 from tensorflow/core/kernels/reduction_ops_sum.cc:16:
external/eigen_archive/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h: In static member function 'static void std::_Function_handler<void(_ArgTypes ...), _Functor>::_M_invoke(const std::_Any_data&, _ArgTypes&& ...) [with _Functor = Eigen::internal::TensorExecutor<Expression, Eigen::ThreadPoolDevice, Vectorizable, Tileable>::run(const Expression&, const Eigen::ThreadPoolDevice&) [with Expression = const Eigen::TensorAssignOp<Eigen::TensorMap<Eigen::Tensor<std::complex, 0, 1, long int>, 16, Eigen::MakePointer>, const Eigen::TensorReductionOp<Eigen::internal::SumReducer<std::complex >, const Eigen::IndexList<Eigen::type2index<0l> >, const Eigen::TensorMap<Eigen::Tensor<const std::complex, 1, 1, long int>, 16, Eigen::MakePointer>, Eigen::MakePointer> >; bool Vectorizable = true; bool Tileable = false]::<lambda(Eigen::internal::TensorExecutor<const Eigen::TensorAssignOp<Eigen::TensorMap<Eigen::Tensor<std::complex, 0, 1, long int>, 16, Eigen::MakePointer>, const Eigen::TensorReductionOp<Eigen::internal::SumReducer<std::complex >, const Eigen::IndexList<Eigen::type2index<0l> >, const Eigen::TensorMap<Eigen::Tensor<const std::complex, 1, 1, long int>, 16, Eigen::MakePointer>, Eigen::MakePointer> >, Eigen::ThreadPoolDevice, true, false>::StorageIndex, Eigen::internal::TensorExecutor<const Eigen::TensorAssignOp<Eigen::TensorMap<Eigen::Tensor<std::complex, 0, 1, long int>, 16, Eigen::MakePointer>, const Eigen::TensorReductionOp<Eigen::internal::SumReducer<std::complex >, const Eigen::IndexList<Eigen::type2index<0l> >, const Eigen::TensorMap<Eigen::Tensor<const std::complex, 1, 1, long int>, 16, Eigen::MakePointer>, Eigen::MakePointer> >, Eigen::ThreadPoolDevice, true, false>::StorageIndex)>; _ArgTypes = {long int, long int}]':
external/eigen_archive/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h:801:9: internal compiler error: in emit_move_insn, at expr.c:3547
   values[i] = internal::InnerMostDimReducer<Self, Op>::reduce(*this, firstIndex + i * num_values_to_reduce,
   ^~~~~~
Please submit a full bug report, with preprocessed source if appropriate.
See file:///usr/share/doc/gcc-6/README.Bugs for instructions.
Target //tensorflow/compiler/xla/xla_client:libxla_computation_client.so failed to build
Use --verbose_failures to see the command lines of failed build steps.
INFO: Elapsed time: 1306.383s, Critical Path: 52.54s
INFO: 736 processes: 736 local.
FAILED: Build did NOT complete successfully

I also get this Bazel warning; I'm not sure if that's what's causing the problem:

WARNING: build_bazel_rules_apple depends on bazel_skylib loaded from https://github.com/bazelbuild/bazel-skylib.git (tag 0.6.0), but we have detected it already loaded into your workspace from None (tag None). You may run into compatibility issues. To silence this warning, pass ignore_version_differences = True to apple_rules_dependencies().

I'm also a little confused, since the Python build script (setup.py) doesn't seem to make use of the kokoro/ubuntu/common.sh script. How do you actually build now? Any advice is appreciated. Thank you so much.

@dlibenzi
Collaborator

> I'm aware this project is still under active development and not all things are ready yet, but I'd like some quick feedback on where things stand. I've dug through the code and tried a couple of different things.
>
> I first saw the Colab notebook (https://github.com/pytorch/xla/blob/master/contrib/colab/PyTorch_TPU_XRT_1_13.ipynb) and tried it on Colab. I noticed some discrepancies between the code there and the code elsewhere in the repo. I'm guessing the "train" method is no longer wrapped under XlaModel?

Sorry about that.
The Colab is wrong. Will fix it soon:

#530

Please look at this for reference for the Colab:

device = xm.xla_device()
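For instance, a minimal sketch of this device-based style, assuming the later module layout (torch_xla.core.xla_model; the path has moved between releases) and the xm.optimizer_step helper:

import torch
import torch_xla.core.xla_model as xm  # module path is an assumption; it has changed across releases

device = xm.xla_device()                    # first available XLA device (e.g. a TPU core)
model = torch.nn.Linear(10, 2).to(device)   # move parameters to the XLA device
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

data = torch.randn(4, 10, device=device)
target = torch.randn(4, 2, device=device)

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(data), target)
loss.backward()
xm.optimizer_step(optimizer, barrier=True)  # steps the optimizer; the barrier triggers XLA execution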

> More importantly, the code there (MNIST) works, but once I try using it on my own model (a transformer, BERT), it throws an error:
>
> RuntimeError: differentiation of aten::squeeze is not supported, or it is missing necessary type information
>
> This is odd, since when I looked through this repo I found "squeeze.cpp" implemented under xla/torch_xla/csrc/ops/.

We have two modes of operation.
One uses the PT JIT, which is what happens when you use the XlaModel() wrapper:

xla_model = xm.XlaModel(
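The call above is truncated in the link; a rough, hypothetical sketch of how the old wrapper was driven, where the module path and the XlaModel/train signatures are assumptions based on the old MNIST example rather than a documented API:

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torch_xla_py.xla_model as xm  # old module path; an assumption

model = torch.nn.Linear(784, 10)
train_loader = DataLoader(
    TensorDataset(torch.randn(256, 784), torch.randint(0, 10, (256,))),
    batch_size=64)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# The wrapper traced the model through the PT JIT using sample inputs; any op
# the tracer could not differentiate (e.g. aten::squeeze) failed at this stage.
xla_model = xm.XlaModel(model, [torch.zeros(64, 784)],
                        loss_fn=F.cross_entropy, num_cores=1)
xla_model.train(train_loader, optimizer, 64, log_interval=10)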

This uses the PT JIT, and quite a few operations are not supported.
We are likely going to abandon the JIT integration in favor of a new XTen architecture (the operators you see in ops/ are the ones driving it), which automatically falls back to CPU for operations that are not supported.
Unfortunately that code lives at HEAD and, because PT+PT/XLA must stay compatible with the TF 1.13 backend, it is not yet available in Colab.
To use the XTen version, you have the following options:

  1. Wait for the next TF release in Colab, so that new wheels can be generated from the most recent bits
  2. Build from source (see below) on any machine and use XLA CPU (see README; sketched after this list)
  3. Wait for TF 1.14 to be out (early May)
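For option 2, a sketch of pointing a from-source build at the local XLA CPU device; the XRT_DEVICE_MAP/XRT_WORKERS variable names follow the README of that era, and the exact values here are illustrative assumptions:

import os

# Illustrative XRT configuration for the local XLA CPU device (assumptions;
# the README has the authoritative values):
os.environ["XRT_DEVICE_MAP"] = "CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0"
os.environ["XRT_WORKERS"] = "localservice:0;grpc://localhost:40934"

import torch_xla.core.xla_model as xm  # module path as in the sketch above

device = xm.xla_device()  # resolves to the XLA CPU device under this config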

We will have more frequent Cloud TPU (and Colab) TF releases in the future, so catching up with the latest PT/XLA developments will be easier.
Sorry about that, but we are just booting the PT/XLA project :)

> So I thought maybe the Colab wheels (http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch-1.0.0a0+1d94a2b-cp36-cp36m-linux_x86_64.whl) are outdated. I then looked under http://storage.googleapis.com/pytorch-tpu-releases/ and found a lot of newer releases, but they are built for Python 3.5 and I couldn't install them on Colab.
>
> So I then tried installing them on a Linux machine (a Google Cloud TPU VM). Interestingly, I also had to pin MKL to an older version, as done in kokoro/ubuntu/common.sh; otherwise importing torch would throw an error about MKL. This time, however, I can't even get the Colab code to work. It throws this message:

> RuntimeError: tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:109 : Check failed: session_work.first->session()->Run( session_work.second.feed_inputs, session_work.second.outputs_handles, &outputs) == ::tensorflow::Status::OK() (Not found: Op type not registered 'XRTAllocateFromTensor' in binary running on n-80018309-w-0. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) tf.contrib.resampler should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed. vs. OK)

To use that Colab, you need the "old" JIT version, but you may find missing operators when trying new models.
Things like MNIST and ResNet work, but newer models like BERT or MaskRCNN will not.

Yes, MKL needs to be pinned until the PT side fixes the source code to stop using the deprecated APIs:

# Pin MKL to older version as as of March 22 2019 the new MKL shipped with

> I'm guessing this means I should install from source?

> So finally, I tried installing from source following the directions in README.md. I got PyTorch to compile, but got the following message when trying to compile xla:

> ERROR: /home/cmw025/pytorch/xla/third_party/tensorflow/tensorflow/core/kernels/BUILD:3371:1: C++ compilation of rule '//tensorflow/core/kernels:reduction_ops' failed (Exit 1)
> [same build log as quoted in full above]
> external/eigen_archive/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h:801:9: internal compiler error: in emit_move_insn, at expr.c:3547
> FAILED: Build did NOT complete successfully

> I also get this Bazel warning; I'm not sure if that's what's causing the problem:
>
> WARNING: build_bazel_rules_apple depends on bazel_skylib loaded from https://github.com/bazelbuild/bazel-skylib.git (tag 0.6.0), but we have detected it already loaded into your workspace from None (tag None). You may run into compatibility issues. To silence this warning, pass ignore_version_differences = True to apple_rules_dependencies().

> I'm also a little confused, since the Python build script (setup.py) doesn't seem to make use of the kokoro/ubuntu/common.sh script. How do you actually build now? Any advice is appreciated. Thank you so much.

That is a GCC internal compiler error (ICE), unfortunately.
We use clang-7 to build, as described in:

https://github.com/pytorch/xla/blob/master/README.md

The Bazel message is just a warning; we get that as well.

@asuhan
Contributor

asuhan commented Sep 5, 2019

Closing as it should all work now, feel free to reopen if you have follow-up questions.

asuhan closed this as completed Sep 5, 2019