Commit 2e17f5d

Merge pull request #50932 from Intel-tensorflow:mabuzain/improving-mkl-layout-pass-perf

PiperOrigin-RevId: 397805346
Change-Id: I084d43a61dbbb6efffc80e64f4c7f316e1f41564
tensorflower-gardener committed Sep 20, 2021
2 parents 58c7bfd + ac69afd commit 2e17f5d
Showing 5 changed files with 86 additions and 131 deletions.
45 changes: 8 additions & 37 deletions tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
@@ -68,25 +68,18 @@ class MklEagerOpRewrite : public EagerOpRewrite {
   static bool AlwaysRewrite(EagerOperation* op) { return true; }
 
   // Check if kernel is registered for a particular op.
-  bool FastCheckIfKernelRegistered(std::string op_name, DataType dt);
-
-  // This is called by FastCheckIfKernelRegistered once per unique op name
-  // and data type.
-  bool SlowCheckIfKernelRegistered(std::string op_name, DataType dt);
+  bool IsKernelRegistered(string op_name, DataType dt);
 
   // Helper function to insert mkl_eager_ops to Map
   void InsertMKLEagerOps(MklEagerOp op);
-
-  // Map used by FastCheckIfKernelRegistered.
-  std::unordered_map<std::string, bool> registered_kernels_map_;
 };
 
 REGISTER_REWRITE(EagerOpRewriteRegistry::POST_PLACEMENT, 10000,
                  MklEagerOpRewrite);
 
 // Constructor
 MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
-    : EagerOpRewrite(name, file, line), registered_kernels_map_() {
+    : EagerOpRewrite(name, file, line) {
   InsertMKLEagerOps({"AvgPool", AlwaysRewrite, CreateGenericMklOp});
   InsertMKLEagerOps({"AvgPoolGrad", AlwaysRewrite, CreateGenericMklOp});
   InsertMKLEagerOps({"AvgPool3D", AlwaysRewrite, CreateGenericMklOp});
@@ -188,7 +181,7 @@ bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op) {
     return false;
   }
   // Check if we have registered MKL kernel for this op.
-  bool kernel_found = FastCheckIfKernelRegistered(op->Name(), data_type);
+  bool kernel_found = IsKernelRegistered(op->Name(), data_type);
   if (!kernel_found) {
     return false;
   }
@@ -205,37 +198,15 @@ bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op) {
   return false;
 }
 
-bool MklEagerOpRewrite::FastCheckIfKernelRegistered(std::string op_name,
-                                                    DataType dt) {
-  // Check for kernel registration only once per op name and data type
-  // for performance reasons.
-  string registered_kernels_key = op_name + std::to_string(dt);
-  auto kernel_element = registered_kernels_map_.find(registered_kernels_key);
-  bool kernel_registered = false;
-  if (kernel_element == registered_kernels_map_.end()) {
-    // Kernel registration is not verified even once yet.
-    // So verify and store registration.
-    kernel_registered = SlowCheckIfKernelRegistered(op_name, dt);
-    registered_kernels_map_.insert(
-        std::make_pair(registered_kernels_key, kernel_registered));
-  } else {
-    // Kernel is visited at least once. Return stored registration result.
-    kernel_registered = kernel_element->second;
-  }
-
-  return kernel_registered;
-}
-
-bool MklEagerOpRewrite::SlowCheckIfKernelRegistered(string op_name,
-                                                    DataType dt) {
+bool MklEagerOpRewrite::IsKernelRegistered(string op_name, DataType dt) {
   // Find if the eager op_name exists in mkl_eager_ops_ list.
   auto element = mkl_eager_ops_.find(op_name);
   if (element != mkl_eager_ops_.end()) {
     // Eager Op exists. So verify registry and return registered or not.
-    return (mkl_op_registry::IsMklNameChangeOp(
-                mkl_op_registry::GetMklNativeOpName(op_name), dt) ||
-            mkl_op_registry::IsMklNameChangeOp(
-                mkl_op_registry::GetMklOpName(op_name), dt));
+    return (mkl_op_registry::IsMklOp(
+                mkl_op_registry::GetMklNativeOpName(op_name), dt, true) ||
+            mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(op_name), dt,
+                                     true));
   } else {
     return false;
   }
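Note: the deleted FastCheckIfKernelRegistered was a plain memoization wrapper keyed on op name plus data type. A self-contained sketch of that pattern, with toy stand-ins (SlowCheck and an int dtype in place of the real registry query and tensorflow::DataType):

#include <iostream>
#include <string>
#include <unordered_map>

// Toy stand-in for the expensive registry query (the old "slow" path).
static bool SlowCheck(const std::string& op_name, int dt) {
  std::cout << "slow path taken for " << op_name << "/" << dt << "\n";
  return op_name == "AvgPool";
}

class Rewriter {
 public:
  // Memoized wrapper: the slow path runs once per (op name, dtype) key.
  bool FastCheck(const std::string& op_name, int dt) {
    std::string key = op_name + std::to_string(dt);
    auto it = cache_.find(key);
    if (it != cache_.end()) return it->second;
    bool registered = SlowCheck(op_name, dt);
    cache_.emplace(key, registered);
    return registered;
  }

 private:
  std::unordered_map<std::string, bool> cache_;
};

int main() {
  Rewriter r;
  r.FastCheck("AvgPool", 1);  // slow path runs
  r.FastCheck("AvgPool", 1);  // answered from cache_
}

The per-instance cache can be deleted because the same memoization now lives inside mkl_op_registry::IsMklOp (see the mkl_graph_util.h diff below), where all callers share it.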
38 changes: 20 additions & 18 deletions tensorflow/core/common_runtime/mkl_layout_pass.cc
@@ -2200,7 +2200,7 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(
   // If this is an MKL op, then it will create extra output for MKL layout.
   DataType T;
   if (TryGetNodeAttr(n->def(), "T", &T) &&
-      mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
+      mkl_op_registry::IsMklOp(n->type_string(), T, false)) {
     // If this is an MKL op, then it will generate an edge that will receive
     // Mkl tensor from a node.
     // output slot number for Mkl tensor would be N+slot number of TensorFlow
@@ -3955,7 +3955,7 @@ bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
   // If graph node is not Mkl node, then return.
   DataType T = DT_INVALID;
   if (!TryGetNodeAttr(n->def(), "T", &T) ||
-      !mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
+      !mkl_op_registry::IsMklOp(n->type_string(), T, false)) {
     return result;
   }
 
@@ -3980,7 +3980,7 @@ bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
   // node, then we don't need to do anything.
   Node* e_src = e->src();
   if (TryGetNodeAttr(e_src->def(), "T", &T) &&
-      mkl_op_registry::IsMklLayoutDependentOp(e_src->type_string(), T)) {
+      mkl_op_registry::IsMklOp(e_src->type_string(), T, false)) {
     // Source node for edge 'e' is Mkl node.
     // Destination node and destination input slot of e is node 'n' and 'idx'
     // resp.
@@ -4093,24 +4093,26 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
 
   DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);
 
-  order.clear();
-  GetReversePostOrder(**g, &order);  // This will give us topological sort.
-  for (Node* n : order) {
-    // If node is not an op or it cannot run on CPU device, then skip.
-    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
-      continue;
-    }
-    if (FixMklMetaDataEdges(g, n)) {
-      string node_name = n->name();
-      string op_name = n->type_string();
-
-      VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
-              << node_name << " with op " << op_name;
-      result = true;
-    }
-  }
-  DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
-            &**g);
+  if (!NativeFormatEnabled()) {
+    order.clear();
+    GetReversePostOrder(**g, &order);  // This will give us topological sort.
+    for (Node* n : order) {
+      // If node is not an op or it cannot run on CPU device, then skip.
+      if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
+        continue;
+      }
+      if (FixMklMetaDataEdges(g, n)) {
+        string node_name = n->name();
+        string op_name = n->type_string();
+
+        VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
+                << node_name << " with op " << op_name;
+        result = true;
+      }
+    }
+    DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
+              &**g);
+  }
 
   return result;
 }
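The fix-up loop relies on the property noted in the code comment: the reverse post-order of a depth-first traversal of a DAG is a topological order, so every node is visited before the nodes it feeds. A minimal sketch of the idea (not TF's GetReversePostOrder, which operates on Graph/Node):

#include <algorithm>
#include <iostream>
#include <vector>

static void Dfs(int node, const std::vector<std::vector<int>>& adj,
                std::vector<bool>& seen, std::vector<int>& post_order) {
  seen[node] = true;
  for (int next : adj[node])
    if (!seen[next]) Dfs(next, adj, seen, post_order);
  post_order.push_back(node);  // emitted only after all descendants
}

static std::vector<int> GetReversePostOrder(
    const std::vector<std::vector<int>>& adj) {
  std::vector<bool> seen(adj.size(), false);
  std::vector<int> post_order;
  for (int n = 0; n < static_cast<int>(adj.size()); ++n)
    if (!seen[n]) Dfs(n, adj, seen, post_order);
  std::reverse(post_order.begin(), post_order.end());
  return post_order;
}

int main() {
  // DAG: 0 -> 1 -> 3, 0 -> 2 -> 3
  std::vector<std::vector<int>> adj = {{1, 2}, {3}, {3}, {}};
  for (int n : GetReversePostOrder(adj)) std::cout << n << ' ';  // e.g. 0 2 1 3
}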
2 changes: 1 addition & 1 deletion tensorflow/core/common_runtime/mkl_tfconversion_pass.cc
@@ -93,7 +93,7 @@ class MklToTfConversionPass : public GraphOptimizationPass {
   // @input T Datatype to use for checking input op
   // @return true if op is Mkl supported; false, otherwise.
   inline bool IsMklSupportedOp(const string& op_name, DataType T) const {
-    return mkl_op_registry::IsMklLayoutDependentOp(op_name, T);
+    return mkl_op_registry::IsMklOp(op_name, T, false);
   }
 
   // Is the input Op supported by Mkl-specific layout AND
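A readability aside on the new signature: a bare false at call sites like the one above does not say what it selects. A hypothetical way to keep such calls self-documenting (style sketch only, not part of the commit):

  // Hypothetical, clearer spelling of the same call:
  return mkl_op_registry::IsMklOp(op_name, T, /*is_native_op=*/false);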
1 change: 1 addition & 0 deletions tensorflow/core/graph/BUILD
@@ -18,6 +18,7 @@ package(
 cc_library(
     name = "mkl_graph_util",
     hdrs = ["mkl_graph_util.h"],
+    deps = ["@com_google_absl//absl/container:flat_hash_map"],
 )
 
 tf_cc_test(
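The new deps entry gives mkl_graph_util access to Abseil's flat_hash_map, which backs the cache introduced in mkl_graph_util.h below. A minimal standalone usage sketch, assuming Abseil is available on the include path:

#include <iostream>
#include <string>

#include "absl/container/flat_hash_map.h"

int main() {
  // Open-addressing hash table; for lookup-heavy workloads like a kernel
  // cache it is generally faster and more cache-friendly than
  // std::unordered_map.
  absl::flat_hash_map<std::string, bool> registered;
  registered.emplace("AvgPool" + std::to_string(1), true);
  auto it = registered.find("AvgPool1");
  std::cout << (it != registered.end() ? it->second : false) << "\n";
}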
131 changes: 56 additions & 75 deletions tensorflow/core/graph/mkl_graph_util.h
@@ -196,105 +196,86 @@ static inline void BF16UnsupportedWarning() {
 }
 
 // Check whether opname with type T is registered as MKL operator
-// that can accept input tensors in MKL layout.
+// that will go through name change or layout change pass.
 //
 // @input: name of the op
 // @input: T datatype to be used for checking op
-// @return: true if opname is registered as Mkl-layout dependent op;
-// false otherwise
-static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) {
-  string kernel = KernelsRegisteredForOp(op_name);
+// @return: true if opname is registered as MKL op that will go through name
+// change or layout change pass; false otherwise
+static inline bool IsMklOp(const string& op_name, DataType T,
+                           bool is_native_op) {
+  string label = is_native_op ? kMklNameChangeOpLabelPattern
+                              : kMklLayoutDependentOpLabelPattern;
+  string registered_kernels_key = op_name + label + std::to_string(T);
+  thread_local static auto* registered_kernels_map =
+      new absl::flat_hash_map<string, bool>();
+  auto kernel_element = registered_kernels_map->find(registered_kernels_key);
+  bool kernel_registered = false;
 
-  // Restrict regular ops to FLOAT and BFLOAT16
-  if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) {
-    if (T == DT_FLOAT) return true;
-    if (T == DT_BFLOAT16) {
-      if (IsBF16SupportedByOneDNNOnThisCPU()) {
-        return true;
-      } else {
-        // Restrict bfloat16 ops to platforms with at least AVX512 support, fall
-        // back to Eigen implementation otherwise.
-        BF16UnsupportedWarning();
-        return false;
+  if (kernel_element == registered_kernels_map->end()) {
+    string registered_kernels = KernelsRegisteredForOp(op_name);
+    // String returned by KernelsRegisteredForOp looks like below:
+    //
+    // Op = _MklMatMul, kernels =
+    // device='CPU'; label='MklNameChangeOp'; T in [DT_COMPLEX128]
+    // device='CPU'; label='MklNameChangeOp'; T in [DT_COMPLEX64]
+    // device='CPU'; label='MklNameChangeOp'; T in [DT_DOUBLE]
+    // device='CPU'; label='MklNameChangeOp'; T in [DT_FLOAT]
+
+    if (is_native_op &&
+        registered_kernels.find(kMklQuantizedOpLabelPattern) != string::npos) {
+      // Restrict quantized ops to QUINT8, QINT8 and DT_QINT32
+      kernel_registered = (T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32);
+    }
+
+    // Now we just construct a search string to match what we are looking for.
+    string search_string =
+        label + string("; T in [") + DataType_Name(T) + string("]");
+
+    if (registered_kernels.find(search_string) != string::npos) {
+      kernel_registered = is_native_op
+                              ? (T == DT_COMPLEX128 || T == DT_COMPLEX64 ||
+                                 T == DT_DOUBLE || T == DT_FLOAT)
+                              : T == DT_FLOAT;
+      if (!kernel_registered) {
+        if (T == DT_BFLOAT16) {
+          if (IsBF16SupportedByOneDNNOnThisCPU()) {
+            kernel_registered = true;
+          } else {
+            // Restrict bfloat16 ops to platforms with at least AVX512 support,
+            // fall back to Eigen implementation otherwise.
+            BF16UnsupportedWarning();
+            kernel_registered = false;
+          }
+        }
       }
     }
-    return false;
+    registered_kernels_map->insert(
+        std::make_pair(registered_kernels_key, kernel_registered));
+  } else {
+    // Kernel is visited at least once. Return stored registration result.
+    kernel_registered = kernel_element->second;
   }
-  return false;
+  return kernel_registered;
 }
 
 // TODO(mdfaijul): QuantizedConv2D is registered with input: QUINT8
 // filter:QINT8 for mkldnn integration. First a dummy kernel is created
 // and then it is replaced by an actual kernel.
 static inline bool IsMklQuantizedOp(const string& op_name, DataType Tinput,
                                     DataType Tfilter) {
-  string kernel = KernelsRegisteredForOp(op_name);
-
   // Restrict quantized ops to QUINT8 and QINT8 for now
-  if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) {
+  if (IsMklOp(op_name, Tinput, kMklQuantizedOpLabelPattern)) {
     return (Tfilter == DT_QINT8);
   }
   return false;
 }
 
-// Check whether opname with type T is registered as an MKL operator that
-// will go through name change.
-//
-// @input: name of the op
-// @input: T datatype to be used for checking op
-// @return: true if opname is registered as MKL op that will go through name
-// change; false otherwise
-static inline bool IsMklNameChangeOp(const string& op_name, DataType T) {
-  string kernel = KernelsRegisteredForOp(op_name);
-  // String returned by KernelsRegisteredForOp looks like below:
-  //
-  // Op = _MklMatMul, kernels =
-  // device='CPU'; label='MklNameChangeOp'; T in [DT_COMPLEX128]
-  // device='CPU'; label='MklNameChangeOp'; T in [DT_COMPLEX64]
-  // device='CPU'; label='MklNameChangeOp'; T in [DT_DOUBLE]
-  // device='CPU'; label='MklNameChangeOp'; T in [DT_FLOAT]
-
-  if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) {
-    // Restrict quantized ops to QUINT8, QINT8 and DT_QINT32
-    return (T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32);
-  }
-
-  // Now we just construct a search string to match what we are looking for.
-  string search_string = kMklNameChangeOpLabelPattern;
-  search_string += string(";") + string(" T in [");
-  search_string += DataType_Name(T) + string("]");
-
-  // Temporarily replacing earlier check by adding a type-specific check so
-  // that we can selectively decide which type is supported by MKL operators.
-  // That way kernel registration does not decide which operators we support.
-  // We are using this change to temporarily disable BFLOAT16 support. Once
-  // we want to enable it, we will go back to earlier check.
-  bool isTypeAllowed = false;
-  if (kernel.find(search_string) != string::npos) {
-    isTypeAllowed = (T == DT_COMPLEX128 || T == DT_COMPLEX64 ||
-                     T == DT_DOUBLE || T == DT_FLOAT);
-    if (!isTypeAllowed) {
-      if (T == DT_BFLOAT16) {
-        if (IsBF16SupportedByOneDNNOnThisCPU()) {
-          isTypeAllowed = true;
-        } else {
-          // Restrict bfloat16 ops to platforms with at least AVX512 support,
-          // fall back to Eigen implementation otherwise.
-          BF16UnsupportedWarning();
-          isTypeAllowed = false;
-        }
-      }
-    }
-    return isTypeAllowed;
-  }
-  return false;
-}
-
 // Check if the operator with 'op_name' and type 'T' is an MKL operator that
 // will either understand input tensors in MKL layout or will go through name
 // rewrite that some operators go through.
 static inline bool IsMklOp(const string& op_name, DataType T) {
-  return IsMklLayoutDependentOp(op_name, T) || IsMklNameChangeOp(op_name, T);
+  return IsMklOp(op_name, T, true) || IsMklOp(op_name, T, false);
 }
 
 static inline bool IsMklOp(const Node* n) {
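This hunk is the heart of the change: instead of re-scanning the kernel registration string on every call, IsMklOp now memoizes the answer per (op name, label, data type) key in a function-local, thread_local map. A self-contained sketch of the pattern, with toy stand-ins (ExpensiveRegistryScan and an int dtype in place of KernelsRegisteredForOp and tensorflow::DataType):

#include <iostream>
#include <string>

#include "absl/container/flat_hash_map.h"

// Toy stand-in for the expensive registry scan.
static bool ExpensiveRegistryScan(const std::string& key) {
  std::cout << "scanning registry for " << key << "\n";
  return key.find("Mkl") != std::string::npos;
}

static bool IsRegisteredCached(const std::string& op_name, bool is_native_op,
                               int dtype) {
  const std::string label =
      is_native_op ? "MklNameChangeOp" : "MklLayoutDependentOp";
  const std::string key = op_name + label + std::to_string(dtype);
  // One cache per thread: lookups need no mutex. The map is heap-allocated
  // and intentionally never deleted, so no destructor runs at thread exit.
  thread_local static auto* cache =
      new absl::flat_hash_map<std::string, bool>();
  auto it = cache->find(key);
  if (it != cache->end()) return it->second;
  bool result = ExpensiveRegistryScan(key);
  cache->emplace(key, result);
  return result;
}

int main() {
  IsRegisteredCached("_MklMatMul", true, 1);  // scans the registry
  IsRegisteredCached("_MklMatMul", true, 1);  // served from the cache
}

The thread_local cache trades memory (one map per thread, each paying its own warm-up misses) for lock-free lookups, a plausible fit here since graph passes are largely thread-confined.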
