[Intel MKL] Fixing threadpool bug #46562

Merged
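Every hunk in this PR applies the same fix: instead of passing the kernel's OpKernelContext straight to CreateStream, each op now builds a stack-scoped MklDnnThreadPool from that context and hands its address to CreateStream, presumably so the oneDNN stream executes on the op's Eigen threadpool. A minimal sketch of the before/after shape inside a Compute() method (ctx and cpu_engine stand in for the kernel-specific context pointer and engine; this illustrates the pattern and is not a compilable unit on its own):

    // Before: the stream was created directly from the OpKernelContext.
    //   std::shared_ptr<stream> cpu_stream;
    //   cpu_stream.reset(CreateStream(ctx, cpu_engine));

    // After: a threadpool object scoped to this Compute() call is built
    // from the context, and the stream is created from its address.
    MklDnnThreadPool eigen_tp(ctx);  // ctx: this kernel's OpKernelContext*
    std::shared_ptr<stream> cpu_stream;
    cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));

The include-file hunks below are incidental: they move the Eigen Tensor header (and in one file, mkldnn.hpp) to restore alphabetical include ordering.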
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/mkl/mkl_aggregate_ops.cc
@@ -175,7 +175,8 @@ class MklAddNOp : public OpKernel {
     }
 
     std::shared_ptr<stream> fwd_cpu_stream;
-    fwd_cpu_stream.reset(CreateStream(ctx, cpu_engine));
+    MklDnnThreadPool eigen_tp(ctx);
+    fwd_cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));
 
     // Create memory descriptor for MKL-DNN.
     // If all input in Tensorflow format, create block memory descriptor,
6 changes: 4 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc
@@ -129,7 +129,8 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
 
     T* dst_data = output_tensor->flat<T>().data();
     std::shared_ptr<stream> fwd_cpu_stream;
-    fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine()));
+    MklDnnThreadPool eigen_tp(context);
+    fwd_cpu_stream.reset(CreateStream(&eigen_tp, pooling_fwd->GetEngine()));
     // Execute pooling op.
     pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream);
 
@@ -250,7 +251,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
         MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
 
     std::shared_ptr<stream> bwd_cpu_stream;
-    bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine()));
+    MklDnnThreadPool eigen_tp(context);
+    bwd_cpu_stream.reset(CreateStream(&eigen_tp, pooling_bwd->GetEngine()));
     Tensor* output_tensor = nullptr;
     this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
                                orig_input_dims_mkl_order,
5 changes: 3 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc
@@ -25,7 +25,6 @@ limitations under the License.
 
 #if defined(INTEL_MKL)
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -40,6 +39,7 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/matmul_bcast.h"
 #include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace tensorflow {
 
@@ -144,7 +144,8 @@ class BatchMatMulMkl : public OpKernel {
         *params, false /* value for do_not_cache */);
     // Execute matmul primitive.
     std::shared_ptr<stream> cpu_stream;
-    cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
+    MklDnnThreadPool eigen_tp(ctx);
+    cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine()));
     matmul_prim->Execute(lhs.flat<Scalar>().data(), rhs.flat<Scalar>().data(),
                          out->flat<Scalar>().data(), cpu_stream);
   }
9 changes: 6 additions & 3 deletions tensorflow/core/kernels/mkl/mkl_concat_op.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #include <vector>
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -31,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 using mkldnn::concat;
 using mkldnn::stream;
@@ -732,7 +732,8 @@ class MklConcatOp : public OpKernel {
       DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
 
       std::shared_ptr<stream> fwd_cpu_stream;
-      fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
+      MklDnnThreadPool eigen_tp(context);
+      fwd_cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));
 
       if (dnn_shape_dst.IsMklTensor())
         dst_md = dnn_shape_dst.GetMklLayout();
@@ -769,7 +770,9 @@ class MklConcatOp : public OpKernel {
       dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout()
                                            : dst_md;
       std::shared_ptr<stream> fwd_cpu_stream;
-      fwd_cpu_stream.reset(CreateStream(context, concat_fwd->GetEngine()));
+      MklDnnThreadPool eigen_tp(context);
+      fwd_cpu_stream.reset(
+          CreateStream(&eigen_tp, concat_fwd->GetEngine()));
       dst.SetUsrMem(dst_md, dst_tensor);
       dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
       // Execute concat
4 changes: 3 additions & 1 deletion tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc
@@ -585,7 +585,9 @@ class MklConvCustomBackpropFilterOp
 
     // Execute convolution backward filter.
     std::shared_ptr<stream> bwd_cpu_stream;
-    bwd_cpu_stream.reset(CreateStream(context, conv_bwd_filter->GetEngine()));
+    MklDnnThreadPool eigen_tp(context);
+    bwd_cpu_stream.reset(
+        CreateStream(&eigen_tp, conv_bwd_filter->GetEngine()));
     if (bias_enabled) {
       T* diff_bias_data =
           static_cast<T*>(const_cast<T*>(diff_bias_tensor->flat<T>().data()));
4 changes: 3 additions & 1 deletion tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc
@@ -482,7 +482,9 @@ class MklConvCustomBackpropInputOp
     }
 
     std::shared_ptr<stream> bwd_cpu_stream;
-    bwd_cpu_stream.reset(CreateStream(context, conv_bwd_input->GetEngine()));
+    MklDnnThreadPool eigen_tp(context);
+    bwd_cpu_stream.reset(
+        CreateStream(&eigen_tp, conv_bwd_input->GetEngine()));
     // Execute conv bwd input primitive.
     conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data,
                             bwd_cpu_stream);
5 changes: 3 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_conv_ops.cc
@@ -24,8 +24,8 @@ limitations under the License.
 #include <unordered_map>
 #include <vector>
 
-#include "mkldnn.hpp"
 #include "absl/strings/str_join.h"
+#include "mkldnn.hpp"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -724,7 +724,8 @@ class MklConvOp : public OpKernel {
 
     // Execute convolution
     std::shared_ptr<stream> fwd_cpu_stream;
-    fwd_cpu_stream.reset(CreateStream(context, conv_fwd->GetEngine()));
+    MklDnnThreadPool eigen_tp(context);
+    fwd_cpu_stream.reset(CreateStream(&eigen_tp, conv_fwd->GetEngine()));
     if (fuse_biasadd_) {
       const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
       Tbias* bias_data =
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/mkl/mkl_dequantize_op.cc
@@ -75,7 +75,8 @@ class MklDequantizeOp : public OpKernel {
     MklDnnData<float> dst(&cpu_engine);
 
     std::shared_ptr<stream> reorder_stream;
-    reorder_stream.reset(CreateStream(ctx, cpu_engine));
+    MklDnnThreadPool eigen_tp(ctx);
+    reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine));
 
     // If input is in MKL layout, then simply grab input layout; otherwise,
     // construct input TF layout. For TF layout, although input shape
8 changes: 5 additions & 3 deletions tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc
@@ -14,7 +14,6 @@ limitations under the License.
 ==============================================================================*/
 #ifdef INTEL_MKL
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -23,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/no_op.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 #define GET_FLAG(bn_flag) static_cast<int>(mkldnn::normalization_flags::bn_flag)
 #define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
@@ -866,7 +866,8 @@ class MklFusedBatchNormOp : public OpKernel {
 
     // Execute
    std::shared_ptr<stream> fwd_cpu_stream;
-    fwd_cpu_stream.reset(CreateStream(context, bn_fwd->GetEngine()));
+    MklDnnThreadPool eigen_tp(context);
+    fwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_fwd->GetEngine()));
     bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
                     variance_op_data, fwd_cpu_stream, ws_data);
     float adjust_factor = 1.0;
@@ -1272,7 +1273,8 @@ class MklFusedBatchNormGradOp : public OpKernel {
 
     // Execute
     std::shared_ptr<stream> bwd_cpu_stream;
-    bwd_cpu_stream.reset(CreateStream(context, bn_bwd->GetEngine()));
+    MklDnnThreadPool eigen_tp(context);
+    bwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_bwd->GetEngine()));
     bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data,
                     weights_data, diff_src_data, diff_weights_data,
                     res_space_data, bwd_cpu_stream);
5 changes: 3 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_lrn_op.cc
@@ -26,7 +26,6 @@ limitations under the License.
 #include <vector>
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -35,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 #if !defined(IS_MOBILE_PLATFORM)
 #include "tensorflow/core/util/work_sharder.h"
@@ -167,7 +167,8 @@ class MklLRNOp : public OpKernel {
     src_dnn_data.CheckReorderToOpMem(lrn_prim_desc.src_desc(), cpu_engine_);
 
     std::vector<primitive> net;
-    fwd_stream_.reset(CreateStream(context, cpu_engine_));
+    MklDnnThreadPool eigen_tp(context);
+    fwd_stream_.reset(CreateStream(&eigen_tp, cpu_engine_));
     net.push_back(lrn_forward(lrn_prim_desc));
     std::vector<std::unordered_map<int, memory>> net_args;
     net_args.push_back({{MKLDNN_ARG_SRC, src_dnn_data.GetOpMem()},
6 changes: 2 additions & 4 deletions tensorflow/core/kernels/mkl/mkl_matmul_op.cc
@@ -156,11 +156,9 @@ class MklMatMulOp : public OpKernel {
     char char_transb = transb ? 'T' : 'N';
     VLOG(2) << "MKL DNN SGEMM called";
 #ifdef ENABLE_MKLDNN_THREADPOOL
-    auto eigen_tp =
-        MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx);
-
+    MklDnnThreadPool eigen_tp(ctx);
     dnnl_sgemm_tp(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb,
-                  beta, c, ldc, eigen_tp);
+                  beta, c, ldc, &eigen_tp);
 #else
     dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
                c, ldc);
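Note that the sgemm path above is the one place where the threadpool is consumed directly rather than through a stream: under ENABLE_MKLDNN_THREADPOOL, dnnl_sgemm_tp previously took a pointer obtained from the MklDnnThreadPoolWrapper singleton, and now takes the address of the same stack-scoped MklDnnThreadPool used by the stream-based kernels in the rest of this PR.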
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
@@ -249,7 +249,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
       }
     }
     std::shared_ptr<stream> cpu_stream;
-    cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
+    MklDnnThreadPool eigen_tp(ctx);
+    cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine()));
     // Execute fused matmul op.
     matmul_prim->Execute(src_data, weight_data, bias_data, dst_data,
                          cpu_stream);
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
@@ -718,7 +718,8 @@ void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
 
   // Execute matmul primitive.
   std::shared_ptr<stream> cpu_stream;
-  cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
+  MklDnnThreadPool eigen_tp(ctx);
+  cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine()));
   matmul_prim->Execute(a, b, c, cpu_stream);
 }
 
6 changes: 4 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc
@@ -153,7 +153,8 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
 
     T* dst_data = output_tensor->flat<T>().data();
     std::shared_ptr<stream> fwd_cpu_stream;
-    fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine()));
+    MklDnnThreadPool eigen_tp(context);
+    fwd_cpu_stream.reset(CreateStream(&eigen_tp, pooling_fwd->GetEngine()));
 
     if (int8_forward_inference) {
       // Execute pooling op
@@ -304,7 +305,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
         MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
 
     std::shared_ptr<stream> bwd_cpu_stream;
-    bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine()));
+    MklDnnThreadPool eigen_tp(context);
+    bwd_cpu_stream.reset(CreateStream(&eigen_tp, pooling_bwd->GetEngine()));
     // Allocate output tensor and memory primitive.
     Tensor* output_tensor = nullptr;
     this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc
@@ -289,7 +289,8 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
     }
 
     std::shared_ptr<stream> cpu_stream;
-    cpu_stream.reset(CreateStream(context, matmul_fwd->GetEngine()));
+    MklDnnThreadPool eigen_tp(context);
+    cpu_stream.reset(CreateStream(&eigen_tp, matmul_fwd->GetEngine()));
     // Execute inner-product
     Tbias* bias_data = this->GetBiasHandle(
         context, matmul_fwd_pd, bias_tensor, weight_tensor, cpu_stream);
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/mkl/mkl_quantize_op.cc
@@ -474,7 +474,8 @@ class MklQuantizeV2Op : public OpKernel {
         MklReorderWithScalePrimitiveFactory<T>::Get(src.GetUsrMem(),
                                                     dst.GetUsrMem(), fwdParams);
     std::shared_ptr<stream> cpu_stream;
-    cpu_stream.reset(CreateStream(ctx, reorder_prim->GetEngine()));
+    MklDnnThreadPool eigen_tp(ctx);
+    cpu_stream.reset(CreateStream(&eigen_tp, reorder_prim->GetEngine()));
     reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle(),
                           cpu_stream);
 
8 changes: 5 additions & 3 deletions tensorflow/core/kernels/mkl/mkl_relu_op.cc
@@ -19,13 +19,13 @@ limitations under the License.
 #include <unordered_map>
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 using mkldnn::algorithm;
 using mkldnn::eltwise_forward;
@@ -479,7 +479,8 @@ class MklReluOpBase : public OpKernel {
         MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
     auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd();
     std::shared_ptr<stream> fwd_cpu_stream;
-    fwd_cpu_stream.reset(CreateStream(context, eltwise_fwd->GetEngine()));
+    MklDnnThreadPool eigen_tp(context);
+    fwd_cpu_stream.reset(CreateStream(&eigen_tp, eltwise_fwd->GetEngine()));
     // Check if src needs to be reordered
     bool is_src_reordered = false;
     const T* src_data = src_tensor.flat<T>().data();
@@ -685,7 +686,8 @@ class MklReluGradOpBase : public OpKernel {
 
     auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
     std::shared_ptr<stream> bwd_cpu_stream;
-    bwd_cpu_stream.reset(CreateStream(context, eltwise_bwd->GetEngine()));
+    MklDnnThreadPool eigen_tp(context);
+    bwd_cpu_stream.reset(CreateStream(&eigen_tp, eltwise_bwd->GetEngine()));
     // check whether need reorder for src / diff_dst
     const T* src_data = src_tensor.flat<T>().data();
     if (src_md != eltwise_bwd_pd->src_desc()) {
5 changes: 3 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc
@@ -21,7 +21,6 @@ limitations under the License.
 #include <math.h>
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/type_traits.h"
@@ -30,6 +29,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/no_op.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace tensorflow {
 
@@ -129,7 +129,8 @@ class MklRequantizePerChannelOp : public OpKernel {
         ReorderPd(cpu_engine_, input_mem_prim->get_desc(), cpu_engine_,
                   output_mem_prim->get_desc(), reorder_attr);
     std::shared_ptr<stream> reorder_stream;
-    reorder_stream.reset(CreateStream(ctx, cpu_engine_));
+    MklDnnThreadPool eigen_tp(ctx);
+    reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine_));
     std::unordered_map<int, mkldnn::memory> reorder_args = {
         {MKLDNN_ARG_FROM, *input_mem_prim},
         {MKLDNN_ARG_TO, *output_mem_prim}};
5 changes: 3 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_slice_op.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #ifdef INTEL_MKL
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -27,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/prefetch.h"
 #include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 using mkldnn::stream;
 
@@ -453,7 +453,8 @@ class MklSliceOp : public OpKernel {
           MklSlicePrimitiveFactory<T>::Get(sliceParams);
       // Execute slice reorder.
      std::shared_ptr<stream> slice_stream;
-      slice_stream.reset(CreateStream(context, reorder_prim->GetEngine()));
+      MklDnnThreadPool eigen_tp(context);
+      slice_stream.reset(CreateStream(&eigen_tp, reorder_prim->GetEngine()));
       reorder_prim->Execute(sliceParams, slice_stream);
     } catch (mkldnn::error& e) {
       string error_msg = "Status: " + std::to_string(e.status) +
5 changes: 3 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_softmax_op.cc
@@ -18,14 +18,14 @@ limitations under the License.
 #ifdef INTEL_MKL
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 using mkldnn::prop_kind;
 using mkldnn::softmax_forward;
@@ -298,7 +298,8 @@ class MklSoftmaxOp : public OpKernel {
       const T* src_data = src_tensor.flat<T>().data();
      T* dst_data = reinterpret_cast<T*>(output_tensor->flat<T>().data());
       std::shared_ptr<stream> fwd_cpu_stream;
-      fwd_cpu_stream.reset(CreateStream(context, softmax_fwd->GetEngine()));
+      MklDnnThreadPool eigen_tp(context);
+      fwd_cpu_stream.reset(CreateStream(&eigen_tp, softmax_fwd->GetEngine()));
       softmax_fwd->Execute(src_data, dst_data, fwd_cpu_stream);
     } catch (mkldnn::error& e) {
       string error_msg = "Status: " + std::to_string(e.status) +