Address review comments and fix unit test
rgomathi committed Oct 15, 2019
1 parent cdd8dbf commit 0883191
Showing 3 changed files with 103 additions and 26 deletions.
11 changes: 11 additions & 0 deletions tensorflow/core/graph/mkl_layout_pass.cc
@@ -1574,10 +1574,12 @@ rinfo_.push_back({csinfo_.tanh_grad,
int axis = -1;
string mode_string;
string round_mode_string;
DataType type;
TryGetNodeAttr(n->def(), "narrow_range", &narrow_range);
TryGetNodeAttr(n->def(), "axis", &axis);
TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
TF_CHECK_OK(GetNodeAttr(n->def(), "round_mode", &round_mode_string));
TF_CHECK_OK(GetNodeAttr(n->def(), "T", &type));

if (narrow_range) {
VLOG(1) << "QuantizeOpRewrite: narrow range is enabled for quantization."
@@ -1608,6 +1610,15 @@ rinfo_.push_back({csinfo_.tanh_grad,

return false;
}
if (mode_string == "MIN_FIRST") {
// MKL implements the MIN_FIRST path only for unsigned 8-bit output.
if (type != DT_QUINT8) {
VLOG(1) << "QuantizeOpRewrite: For MIN_FIRST mode the data type is "
<< "not DT_QUINT8. This case is not optimized by Intel MKL, "
<< "thus using Eigen op for Quantize op.";
return false;
}
}
return true;
}
static bool MaxpoolGradRewrite(const Node* n) {
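Note on the change above: the added MIN_FIRST branch gates the MKL rewrite on the node's "T" attribute. A minimal standalone sketch of that gate (the predicate name is hypothetical; GetNodeAttr, TF_CHECK_OK, and DT_QUINT8 are the real TensorFlow symbols used above):

static bool QuantizeRewriteAllowedSketch(const Node* n) {
  string mode_string;
  DataType type;
  TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
  TF_CHECK_OK(GetNodeAttr(n->def(), "T", &type));
  // MKL implements MIN_FIRST quantization only for unsigned 8-bit output;
  // any other type stays on the Eigen kernel.
  return mode_string != "MIN_FIRST" || type == DT_QUINT8;
}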
9 changes: 6 additions & 3 deletions tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -1470,7 +1470,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Negative_ConstInp) {
"A->D;B->D:1;C->D:2;D->E");
}

TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Negative_MinFirst) {
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_MinFirst) {
InitGraph(
"node { name: 'A' op: 'Input' } "
"node { name: 'B' op: 'Const' "
@@ -1491,8 +1491,11 @@ TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Negative_MinFirst) {
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Const);C(Const);D(QuantizeV2);E(Zeta)|"
"A->D;B->D:1;C->D:2;D->E");
"A(Input);B(Const);C(Const);D(_MklQuantizeV2);DMT/_0(Const);DMT/"
"_1(Const);DMT/_2(Const);E(Zeta)|"
"A->D;A:control->DMT/_0:control;A:control->DMT/"
"_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;D->E;DMT/"
"_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}

TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Negative_NarrowRange_True) {
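A short gloss on the updated expectation (a summary of what the strings above encode, not new test code):

// DoMklLayoutOptimizationPass() serializes the graph as
//   "Name(OpType);...|src->dst:input_port;...".
// After the rewrite, D becomes _MklQuantizeV2 and the pass inserts
// DMT/_0..DMT/_2, "dummy MKL tensor" Const nodes feeding MKL layout
// metadata into D's inputs 3..5; the A:control->DMT/_N:control edges tie
// the dummies to the data input's execution.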
109 changes: 86 additions & 23 deletions tensorflow/core/kernels/mkl_quantize_op.cc
@@ -70,7 +70,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
explicit MklReorderWithScalePrimitive(
const memory* from, const memory* to,
const MklReorderWithScaleFwdParams& fwdParams) {
// create reorder primitive
// Create reorder primitive
Setup(from, to, fwdParams);
}

@@ -87,7 +87,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
private:
// Primitive reuse context for reorder
struct ReorderContext {
// MKLDNN memory
// MKL-DNN memory
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> dst_mem;

@@ -131,14 +131,11 @@
auto const& post_op_params = fwdParams.post_op_params;
mkldnn::primitive_attr post_ops_attr;

if (post_op_params.name == "scale") {
DCHECK_EQ(post_op_params.param.size(), 1);
std::vector<float> scales;
scales.push_back(post_op_params.param[0]);
post_ops_attr.set_output_scales(0, scales);
} else {
DCHECK(post_op_params.name == "scale");
}
DCHECK(post_op_params.name == "scale");
DCHECK_EQ(post_op_params.param.size(), 1);
std::vector<float> scales;
scales.push_back(post_op_params.param[0]);
post_ops_attr.set_output_scales(0, scales);
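// Note: the first argument to set_output_scales is the scale mask; 0 means
// one scale applied to the whole tensor, so this reorder performs the copy
// and the quantizing multiplication in a single MKL-DNN primitive.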

// Create a reorder
context_.reorder_pd =
@@ -224,16 +221,16 @@ class MklReorderWithScalePrimitiveFactory : public MklPrimitiveFactory<T> {
}
};

/// Function to find (or create) a reorder from memory pointed by
/// from to memory pointed by to; it will create the primitive or
/// get it from the pool if it is cached.
/// Returns the primitive.
// Function to find (or create) a reorder from the memory pointed to by
// 'from' to the memory pointed to by 'to'; it creates the primitive or
// fetches it from the pool if it is cached.
// Returns the primitive.
template <typename T>
inline primitive FindOrCreateReorder(
const memory* from, const memory* to,
const MklReorderWithScaleFwdParams& fwdParams) {
CHECK_NOTNULL(from);
CHECK_NOTNULL(to);
DCHECK(from);
DCHECK(to);
MklReorderWithScalePrimitive* reorder_prim =
MklReorderWithScalePrimitiveFactory<T>::Get(from, to, fwdParams);
return *reorder_prim->GetPrimitive();
@@ -268,26 +265,79 @@ class MklQuantizeV2Op : public OpKernel {
}

~MklQuantizeV2Op() {
if (this->minfirst_input_ != nullptr) {
delete this->minfirst_input_;
if (minfirst_input_ != nullptr) {
delete[] minfirst_input_;
minfirst_input_ = nullptr;
}
}

float* GetMinfirstInputBuf(int size) {
if (!minfirst_input_) {
minfirst_input_ = new float[size];
minfirst_input_size_ = size;
} else if (size != minfirst_input_size_) {
delete[] minfirst_input_;
minfirst_input_ = new float[size];
minfirst_input_size_ = size;
}

return minfirst_input_;
}
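// An equivalent sketch with a std::vector member (an illustrative
// alternative, not what this commit does) would drop the manual delete:
//   std::vector<float> minfirst_input_;
//   float* GetMinfirstInputBuf(int size) {
//     if (minfirst_input_.size() != static_cast<size_t>(size))
//       minfirst_input_.resize(size);
//     return minfirst_input_.data();
//   }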

void Compute_Scalar(OpKernelContext* ctx, float min_range, float max_range) {
// TODO: Scalar support has to be added for SCALE mode.
OP_REQUIRES(ctx, (mode_ == QUANTIZE_MODE_MIN_FIRST),
errors::InvalidArgument(
"Scalar calculation in MKL is supported only for "
"MIN_FIRST mode for now."));

auto cpu_engine = engine(engine::cpu, 0);
const Tensor& input = ctx->input(0);
const unsigned int src_idx = 0;
const Tensor& src_tensor = MklGetInput(ctx, src_idx);

MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(false);

Tensor* output_tensor = nullptr;
AllocateOutputSetMklShape(ctx, 0, &output_tensor, src_tensor.shape(),
output_mkl_shape);
TensorShape min_tf_shape = {};
MklDnnShape min_mkl_shape;
min_mkl_shape.SetMklTensor(false);
Tensor* output_min_tensor = nullptr;
AllocateOutputSetMklShape(ctx, 1, &output_min_tensor, min_tf_shape,
min_mkl_shape);
TensorShape max_tf_shape = {};
MklDnnShape max_mkl_shape;
max_mkl_shape.SetMklTensor(false);
Tensor* output_max_tensor = nullptr;
AllocateOutputSetMklShape(ctx, 2, &output_max_tensor, max_tf_shape,
max_mkl_shape);

// Estimate the scale factor for quantization
float scale_factor = 0;
const int number_of_bits = sizeof(T) * 8;
const int64 number_of_steps = static_cast<int64>(1) << number_of_bits;
scale_factor = (number_of_steps - 1.0) / (max_range - min_range);

float* src_data = const_cast<float*>(src_tensor.flat<float>().data());
T* out_data = output_tensor->flat<T>().data();

out_data[0] = (src_data[0] - min_range) * scale_factor;
output_min_tensor->flat<float>()(0) = min_range;
output_max_tensor->flat<float>()(0) = max_range;

return;
}
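// Worked example for the scalar path above (values illustrative): with
// T = quint8, number_of_bits = 8 and number_of_steps = 256, a range of
// [0.0f, 2.55f] gives scale_factor = 255 / 2.55 = 100, so a scalar input
// of 1.27f quantizes to (1.27 - 0.0) * 100 = 127.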

void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(0);
const float input_min_range = ctx->input(1).flat<float>()(0);
const float input_max_range = ctx->input(2).flat<float>()(0);
float min_range = std::min(0.0f, input_min_range);
float max_range;
OP_REQUIRES(ctx, (input_max_range > input_min_range),
OP_REQUIRES(ctx, !(input_max_range < input_min_range),
errors::InvalidArgument(
"input_max_range must be larger than input_min_range."));

@@ -320,6 +370,9 @@ class MklQuantizeV2Op : public OpKernel {
// Set the dst layout to be the best mkl layout based on dims and type.
memory::format dst_layout_type;
switch (src_tf_shape.dims()) {
case 0:
Compute_Scalar(ctx, min_range, max_range);
return;
case 1:
dst_layout_type = memory::format::x;
break;
@@ -354,10 +407,19 @@
auto flat_input = input.flat<float>().data();
if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
float* minfirst_input = GetMinfirstInputBuf(input.NumElements());
#pragma omp parallel for schedule(static)
for (int i = 0; i < input.NumElements(); i++) {
minfirst_input[i] = flat_input[i] - min_range;
}
const Eigen::TensorOpCost cost(
sizeof(float) /* bytes loaded */,
sizeof(float) /* bytes stored */,
Eigen::TensorOpCost::AddCost<float>() /* cycles for the subtract */);

const CPUDevice& d = ctx->eigen_device<CPUDevice>();
auto ParallelSub = [&](int64 start, int64 end) {
for (int64 i = start; i < end; i++) {
minfirst_input[i] = flat_input[i] - min_range;
}
};
d.parallelFor(input.NumElements(), cost, ParallelSub);
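// Note on the change above: Eigen's parallelFor replaces the earlier
// '#pragma omp parallel for', sharding [0, NumElements()) across
// TensorFlow's own thread pool with shard sizes derived from the
// TensorOpCost estimate (one float load, one float store, and one
// add-cost subtract per element).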

src.SetUsrMem(src_md, minfirst_input);
} else {
src.SetUsrMem(src_md, &src_tensor);
@@ -464,6 +526,7 @@ class MklQuantizeV2Op : public OpKernel {
int mode_;
int round_mode_;
float* minfirst_input_ = nullptr;
int minfirst_input_size_;
};

REGISTER_KERNEL_BUILDER(Name("_MklQuantizeV2")
