
Support EIGEN_USE_MKL_ALL macro for building tensorflow #34924

Open

refraction-ray opened this issue Dec 7, 2019 · 11 comments

Labels: stat:awaiting tensorflower (Status - Awaiting response from tensorflower), subtype:bazel (Bazel related Build_Installation issues), TF 1.15 (for issues seen on TF 1.15), TF 2.0 (Issues relating to TensorFlow 2.0), type:build/install (Build and install issues), type:feature (Feature requests)

Comments

@refraction-ray (Contributor)


System information

  • TensorFlow version (you are using): tf1.14, tf2.0
  • Are you willing to contribute it (Yes/No): No, I am not familiar with Bazel

Describe the feature and the current behavior/state.
If my understanding is correct, the Bazel compile flag --config=mkl only enables MKL-DNN support, which replaces several basic ops such as matrix multiplication with MKL implementations using JIT. However, it seems to me that this flag does not enable MKL linkage within Eigen. Therefore, all linear-algebra ops beyond the few covered by MKL-DNN are still executed by the plain single-threaded Eigen implementation, which is too slow to use (there can be an O(10) or even O(100) speed difference for large matrices with the eigh, svd, and qr ops, as previously noted in #7128, #13222, etc.).

Currently, in tensorflow.bzl, the -DEIGEN_USE_VML flag is set when compiling with --config=mkl. As explained in #30592, this indicates that Eigen is not backed by MKL at all. But simply replacing this flag with -DEIGEN_USE_MKL_ALL makes the build fail with an error that <mkl.h> is not found. Also, as noted in #12219, MKL-optimized TensorFlow does not support EIGEN_USE_MKL_ALL. I know little about the Bazel setup, so I don't know whether turning on such support is involved or as simple as a few small tweaks.
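For reference, here is a hypothetical sketch of the tweak being discussed. The exact location in tensorflow.bzl, the flag names shown, and the MKL install paths are assumptions for illustration; as described above, the flag swap alone does not produce a working build:

```python
# tensorflow.bzl (Starlark) -- hypothetical sketch, not a working patch.
#
# Replace the VML-only define with full MKL routing for Eigen:
#
#     if_mkl(["-DEIGEN_USE_VML"])        # current behavior
#     if_mkl(["-DEIGEN_USE_MKL_ALL"])    # desired behavior
#
# EIGEN_USE_MKL_ALL makes Eigen include <mkl.h>, so the MKL headers and
# libraries would also have to be wired into the build, e.g. via .bazelrc
# (paths assumed, vary per install):
#
#     build --copt=-I/opt/intel/mkl/include
#     build --linkopt=-L/opt/intel/mkl/lib/intel64
#     build --linkopt=-lmkl_rt
```

Whether this interacts badly with the existing MKL-DNN integration (as #12219 suggests) is exactly the open question in this issue.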

In sum, tuning the TF build system to "really" enable MKL behind TF is of great importance: it is vital for the speed of a large range of matrix ops, and it should be the expected behavior of the --config=mkl flag after all. Currently, the so-called "Intel-optimized" or "MKL-enabled" TensorFlow is somewhat misleading.

Will this change the current API? How?
Not at the user-level API.

Who will benefit with this feature?
Anyone whose workflow includes matrix operations such as EIG, SVD, and QR on CPU. (One could argue that there is no problem for the GPU implementation, but the cuSOLVER implementations of SVD and QR can still be much slower than MKL's CPU implementations, so fast CPU implementations are critical for these matrix-decomposition ops.)

Any Other info.

@amahendrakar amahendrakar self-assigned this Dec 9, 2019
@amahendrakar amahendrakar added type:build/install Build and install issues subtype:bazel Bazel related Build_Installation issues TF 1.15 for issues seen on TF 1.15 TF 2.0 Issues relating to TensorFlow 2.0 labels Dec 9, 2019
@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Dec 10, 2019
@Leslie-Fang (Contributor)

@refraction-ray Could you give the full build commands you used?

@refraction-ray (Contributor, Author)

@Leslie-Fang, I believe the build workflow is similar to the official doc using the devel docker image. The only tweak I made was changing the flag in tensorflow.bzl directly to -DEIGEN_USE_MKL_ALL.
This fails to build, of course, since no path (header or library) is specified for the MKL library that Eigen now requires. Since I have no time to learn the Bazel build system, I didn't try adding cxxopts or hacking more Bazel files, and #12219 gave me the impression that DEIGEN_USE_MKL_ALL support is not trivial in the TF build.
Hence this issue: either someone expert can explain how to build TF with the EIGEN_USE_MKL_ALL flag via some small tweaks, or the build-related codebase should be improved accordingly for MKL support.

@Leslie-Fang (Contributor)

@refraction-ray If that's the case in #12219, then, although I don't know why, MKL-DNN support and Eigen-with-MKL support apparently cannot be enabled simultaneously.
Which Eigen ops are the bottleneck of your application?
I have done some Eigen op optimization previously.

@refraction-ray (Contributor, Author)

@Leslie-Fang, as I mentioned, QR and SVD decompositions are slow with the Eigen implementation compared to the multithreaded MKL version.

@Leslie-Fang (Contributor)

@refraction-ray Do you know how to measure the performance of the QR and SVD ops against the multithreaded MKL version?

@Leslie-Fang (Contributor)

I see some comments in #7128 about how to measure the performance.
I will try to enable the native multithreaded Eigen version.
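For the measurement itself, a minimal benchmarking sketch could look like the following. It uses NumPy only; whether it actually exercises MKL depends on how the NumPy build was linked, and the TensorFlow side could be timed the same way with tf.linalg.qr / tf.linalg.svd. The matrix size is an arbitrary choice:

```python
import time
import numpy as np

def bench(fn, a, repeat=3):
    """Return the best wall-clock time of fn(a) over `repeat` runs."""
    best = float("inf")
    for _ in range(repeat):
        t0 = time.perf_counter()
        fn(a)
        best = min(best, time.perf_counter() - t0)
    return best

rng = np.random.default_rng(0)
a = rng.standard_normal((1000, 1000))

t_qr = bench(np.linalg.qr, a)
t_svd = bench(np.linalg.svd, a)
print(f"qr:  {t_qr:.3f} s")
print(f"svd: {t_svd:.3f} s")
```

Watching CPU utilization while this runs shows whether the backend is actually multithreaded; a single-threaded backend will pin one core only.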

@refraction-ray (Contributor, Author)

@Leslie-Fang, great! Be sure to use an MKL-linked NumPy for the benchmark.
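One way to check which BLAS/LAPACK backend NumPy was linked against is NumPy's own build-configuration dump (the output format varies across NumPy versions; an MKL-linked build mentions "mkl" in the library names):

```python
import numpy as np

# Prints NumPy's build configuration, including the BLAS/LAPACK
# libraries it was linked against; look for "mkl" in the output.
np.show_config()
```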

@Leslie-Fang (Contributor)

@refraction-ray I have checked the implementations of QR and SVD in TensorFlow.

```cpp
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                   MatrixMaps* outputs) final {
  Eigen::HouseholderQR<Matrix> qr(inputs[0]);
  const int m = inputs[0].rows();
  const int n = inputs[0].cols();
  const int min_size = std::min(m, n);
  if (full_matrices_) {
    outputs->at(0) = qr.householderQ();
    outputs->at(1) = qr.matrixQR().template triangularView<Eigen::Upper>();
  } else {
    // TODO(jpoulson): Exploit the fact that Householder transformations can
    // be expanded faster than they can be applied to an arbitrary matrix
    // (Cf. LAPACK's DORGQR).
    Matrix tmp = Matrix::Identity(m, min_size);
    outputs->at(0) = qr.householderQ() * tmp;
    auto qr_top = qr.matrixQR().block(0, 0, min_size, n);
    outputs->at(1) = qr_top.template triangularView<Eigen::Upper>();
  }
}
```

Both invoke Eigen's in-place decomposition:
https://eigen.tuxfamily.org/dox/classEigen_1_1HouseholderQR.html
I suspect there is no parallel optimization we can do in the TensorFlow op itself.

@refraction-ray (Contributor, Author)

@Leslie-Fang, this is why I believe the better solution is to enable MKL linkage in Eigen when building TensorFlow. If that works, all ops can enjoy the multithreaded MKL implementation, with no need to hack these ops one by one.

@Exferro commented Mar 29, 2021

Hi everyone, could you please tell us whether there has been any progress on this? I am not an expert, but enabling MKL support for Eigen operations sounds like a reasonable solution and would be highly appreciated. Otherwise, it is a bit frustrating to watch TensorFlow humbly use only one core while NumPy enjoys running SVD at >1000% CPU load :-)

@NeoZhangJianyu

@refraction-ray
Though Eigen supports MKL (https://eigen.tuxfamily.org/dox/TopicUsingIntelMKL.html), it is not available through TensorFlow.

@sushreebarsa sushreebarsa self-assigned this Aug 16, 2021
@sushreebarsa sushreebarsa removed their assignment Sep 11, 2021
@SuryanarayanaY SuryanarayanaY added the type:feature Feature requests label Apr 10, 2023
10 participants