[oneDNN] Improving Graph Rewrite Performance #50932

Merged
45 changes: 8 additions & 37 deletions tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
@@ -68,25 +68,18 @@ class MklEagerOpRewrite : public EagerOpRewrite {
static bool AlwaysRewrite(EagerOperation* op) { return true; }

// Check if kernel is registered for a particular op.
bool FastCheckIfKernelRegistered(std::string op_name, DataType dt);

// This is called by FastCheckIfKernelRegistered once per unique op name
// and data type.
bool SlowCheckIfKernelRegistered(std::string op_name, DataType dt);
bool IsKernelRegistered(string op_name, DataType dt);

// Helper function to insert mkl_eager_ops to Map
void InsertMKLEagerOps(MklEagerOp op);

// Map used by FastCheckIfKernelRegistered.
std::unordered_map<std::string, bool> registered_kernels_map_;
};

REGISTER_REWRITE(EagerOpRewriteRegistry::POST_PLACEMENT, 10000,
MklEagerOpRewrite);

// Constructor
MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
: EagerOpRewrite(name, file, line), registered_kernels_map_() {
: EagerOpRewrite(name, file, line) {
InsertMKLEagerOps({"AvgPool", AlwaysRewrite, CreateGenericMklOp});
InsertMKLEagerOps({"AvgPoolGrad", AlwaysRewrite, CreateGenericMklOp});
InsertMKLEagerOps({"AvgPool3D", AlwaysRewrite, CreateGenericMklOp});
@@ -188,7 +181,7 @@ bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op) {
return false;
}
// Check if we have registered MKL kernel for this op.
bool kernel_found = FastCheckIfKernelRegistered(op->Name(), data_type);
bool kernel_found = IsKernelRegistered(op->Name(), data_type);
if (!kernel_found) {
return false;
}
@@ -205,37 +198,15 @@ bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op) {
return false;
}

bool MklEagerOpRewrite::FastCheckIfKernelRegistered(std::string op_name,
DataType dt) {
// Check for kernel registration only once per op name and data type
// for performance reasons.
string registered_kernels_key = op_name + std::to_string(dt);
auto kernel_element = registered_kernels_map_.find(registered_kernels_key);
bool kernel_registered = false;
if (kernel_element == registered_kernels_map_.end()) {
// Kernel registration is not verified even once yet.
// So verify and store registration.
kernel_registered = SlowCheckIfKernelRegistered(op_name, dt);
registered_kernels_map_.insert(
std::make_pair(registered_kernels_key, kernel_registered));
} else {
// Kernel is visited at least once. Return stored registration result.
kernel_registered = kernel_element->second;
}

return kernel_registered;
}

bool MklEagerOpRewrite::SlowCheckIfKernelRegistered(string op_name,
DataType dt) {
bool MklEagerOpRewrite::IsKernelRegistered(string op_name, DataType dt) {
// Find if the eager op_name exists in mkl_eager_ops_ list.
auto element = mkl_eager_ops_.find(op_name);
if (element != mkl_eager_ops_.end()) {
// Eager Op exists. So verify registry and return registered or not.
return (mkl_op_registry::IsMklNameChangeOp(
mkl_op_registry::GetMklNativeOpName(op_name), dt) ||
mkl_op_registry::IsMklNameChangeOp(
mkl_op_registry::GetMklOpName(op_name), dt));
return (mkl_op_registry::IsMklOp(
mkl_op_registry::GetMklNativeOpName(op_name), dt, true) ||
mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(op_name), dt,
true));
} else {
return false;
}
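The per-instance cache (registered_kernels_map_) and the FastCheckIfKernelRegistered / SlowCheckIfKernelRegistered pair are removed from MklEagerOpRewrite because the memoization now lives inside mkl_op_registry::IsMklOp (see the mkl_graph_util.h hunk further down), keyed by op name, kernel label, and data type. The sketch below shows that caching pattern in isolation; SlowRegistryLookup is a hypothetical stand-in for the expensive KernelsRegisteredForOp query, and a plain std::unordered_map stands in for the absl::flat_hash_map used in the real code.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for an expensive registry query such as
// KernelsRegisteredForOp(); not TensorFlow's actual API.
static bool SlowRegistryLookup(const std::string& op_name, int data_type) {
  std::cout << "slow lookup for " << op_name << "/" << data_type << "\n";
  return op_name == "MatMul" && data_type == 1;  // pretend only this is registered
}

// Memoized wrapper: the slow lookup runs at most once per unique
// (op name, data type) pair per thread; later calls hit the cache.
bool IsKernelRegisteredCached(const std::string& op_name, int data_type) {
  // thread_local avoids locking: each thread keeps its own cache, mirroring
  // the thread_local flat_hash_map used in mkl_graph_util.h below.
  thread_local static auto* cache = new std::unordered_map<std::string, bool>();
  const std::string key = op_name + std::to_string(data_type);
  auto it = cache->find(key);
  if (it != cache->end()) return it->second;  // cache hit: no slow lookup
  const bool registered = SlowRegistryLookup(op_name, data_type);
  cache->emplace(key, registered);             // remember for next time
  return registered;
}

int main() {
  IsKernelRegisteredCached("MatMul", 1);  // slow path, result cached
  IsKernelRegisteredCached("MatMul", 1);  // served from the cache
}
```

Because the map is thread_local, no lock is needed, at the cost of one cache per thread.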
38 changes: 20 additions & 18 deletions tensorflow/core/common_runtime/mkl_layout_pass.cc
@@ -2200,7 +2200,7 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(
// If this is an MKL op, then it will create extra output for MKL layout.
DataType T;
if (TryGetNodeAttr(n->def(), "T", &T) &&
mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
mkl_op_registry::IsMklOp(n->type_string(), T, false)) {
// If this is an MKL op, then it will generate an edge that will receive
// Mkl tensor from a node.
// output slot number for Mkl tensor would be N+slot number of TensorFlow
@@ -3955,7 +3955,7 @@ bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
// If graph node is not Mkl node, then return.
DataType T = DT_INVALID;
if (!TryGetNodeAttr(n->def(), "T", &T) ||
!mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
!mkl_op_registry::IsMklOp(n->type_string(), T, false)) {
return result;
}

@@ -3980,7 +3980,7 @@ bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
// node, then we don't need to do anything.
Node* e_src = e->src();
if (TryGetNodeAttr(e_src->def(), "T", &T) &&
mkl_op_registry::IsMklLayoutDependentOp(e_src->type_string(), T)) {
mkl_op_registry::IsMklOp(e_src->type_string(), T, false)) {
// Source node for edge 'e' is Mkl node.
// Destination node and destination input slot of e is node 'n' and 'idx'
// resp.
@@ -4093,24 +4093,26 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {

DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);

order.clear();
GetReversePostOrder(**g, &order); // This will give us topological sort.
for (Node* n : order) {
// If node is not an op or it cannot run on CPU device, then skip.
if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
continue;
}
if (FixMklMetaDataEdges(g, n)) {
string node_name = n->name();
string op_name = n->type_string();
if (!NativeFormatEnabled()) {
order.clear();
GetReversePostOrder(**g, &order); // This will give us topological sort.
for (Node* n : order) {
// If node is not an op or it cannot run on CPU device, then skip.
if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
continue;
}
if (FixMklMetaDataEdges(g, n)) {
string node_name = n->name();
string op_name = n->type_string();

VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
<< node_name << " with op " << op_name;
result = true;
VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
<< node_name << " with op " << op_name;
result = true;
}
}
DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
&**g);
}
DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
&**g);

return result;
}
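In RunPass above, the metadata-edge fixup stage is now wrapped in !NativeFormatEnabled(): with native-format oneDNN ops there are no Mkl metadata tensors to repair, so the extra reverse-post-order traversal is skipped entirely rather than visiting every node and doing nothing. A toy sketch of that gating, where ToyNode, NeedsMetadataFixup, and RunFixupStage are illustrative names rather than TensorFlow's:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins; the real pass works on TensorFlow's Graph/Node types.
struct ToyNode { std::string name; };

static bool NeedsMetadataFixup(const ToyNode& n) {
  return n.name.rfind("_Mkl", 0) == 0;  // pretend only _Mkl* nodes need work
}

// Sketch of the restructured fixup stage: when the native-format path is
// active there is nothing to fix, so the whole traversal -- not just the
// per-node work -- is skipped.
bool RunFixupStage(const std::vector<ToyNode>& topo_order, bool native_format) {
  bool changed = false;
  if (!native_format) {                    // mirrors !NativeFormatEnabled()
    for (const ToyNode& n : topo_order) {  // one pass over the sorted graph
      if (NeedsMetadataFixup(n)) {
        std::cout << "fixed metadata edges for " << n.name << "\n";
        changed = true;
      }
    }
  }
  return changed;
}

int main() {
  std::vector<ToyNode> order = {{"_MklConv2D"}, {"Relu"}};
  RunFixupStage(order, /*native_format=*/false);  // does the work
  RunFixupStage(order, /*native_format=*/true);   // skips the traversal
}
```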
2 changes: 1 addition & 1 deletion tensorflow/core/common_runtime/mkl_tfconversion_pass.cc
@@ -93,7 +93,7 @@ class MklToTfConversionPass : public GraphOptimizationPass {
// @input T Datatype to use for checking input op
// @return true if op is Mkl supported; false, otherwise.
inline bool IsMklSupportedOp(const string& op_name, DataType T) const {
return mkl_op_registry::IsMklLayoutDependentOp(op_name, T);
return mkl_op_registry::IsMklOp(op_name, T, false);
}

// Is the input Op supported by Mkl-specific layout AND
1 change: 1 addition & 0 deletions tensorflow/core/graph/BUILD
@@ -18,6 +18,7 @@ package(
cc_library(
name = "mkl_graph_util",
hdrs = ["mkl_graph_util.h"],
deps = ["@com_google_absl//absl/container:flat_hash_map"],
)

tf_cc_test(
131 changes: 56 additions & 75 deletions tensorflow/core/graph/mkl_graph_util.h
@@ -196,105 +196,86 @@ static inline void BF16UnsupportedWarning() {
}

// Check whether opname with type T is registered as MKL operator
// that can accept input tensors in MKL layout.
// that will go through name change or layout change pass.
//
// @input: name of the op
// @input: T datatype to be used for checking op
// @return: true if opname is registered as Mkl-layout dependent op;
// false otherwise
static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) {
string kernel = KernelsRegisteredForOp(op_name);
// @return: true if opname is registered as MKL op that will go through name
// change or layout change pass; false otherwise
static inline bool IsMklOp(const string& op_name, DataType T,
bool is_native_op) {
string label = is_native_op ? kMklNameChangeOpLabelPattern
: kMklLayoutDependentOpLabelPattern;
string registered_kernels_key = op_name + label + std::to_string(T);
thread_local static auto* registered_kernels_map =
new absl::flat_hash_map<string, bool>();
auto kernel_element = registered_kernels_map->find(registered_kernels_key);
bool kernel_registered = false;

if (kernel_element == registered_kernels_map->end()) {
string registered_kernels = KernelsRegisteredForOp(op_name);
// String returned by KernelsRegisteredForOp looks like below:
//
// Op = _MklMatMul, kernels =
// device='CPU'; label='MklNameChangeOp'; T in [DT_COMPLEX128]
// device='CPU'; label='MklNameChangeOp'; T in [DT_COMPLEX64]
// device='CPU'; label='MklNameChangeOp'; T in [DT_DOUBLE]
// device='CPU'; label='MklNameChangeOp'; T in [DT_FLOAT]

if (is_native_op &&
registered_kernels.find(kMklQuantizedOpLabelPattern) != string::npos) {
// Restrict quantized ops to QUINT8, QINT8 and DT_QINT32
kernel_registered = (T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32);
}

// Restrict regular ops to FLOAT and BFLOAT16
if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) {
if (T == DT_FLOAT) return true;
if (T == DT_BFLOAT16) {
if (IsBF16SupportedByOneDNNOnThisCPU()) {
return true;
} else {
// Restrict bfloat16 ops to platforms with at least AVX512 support, fall
// back to Eigen implementation otherwise.
BF16UnsupportedWarning();
return false;
// Now we just construct a search string to match what we are looking for.
string search_string =
label + string("; T in [") + DataType_Name(T) + string("]");

if (registered_kernels.find(search_string) != string::npos) {
kernel_registered = is_native_op
? (T == DT_COMPLEX128 || T == DT_COMPLEX64 ||
T == DT_DOUBLE || T == DT_FLOAT)
: T == DT_FLOAT;
if (!kernel_registered) {
if (T == DT_BFLOAT16) {
if (IsBF16SupportedByOneDNNOnThisCPU()) {
kernel_registered = true;
} else {
// Restrict bfloat16 ops to platforms with at least AVX512 support,
// fall back to Eigen implementation otherwise.
BF16UnsupportedWarning();
kernel_registered = false;
}
}
}
}
return false;
registered_kernels_map->insert(
std::make_pair(registered_kernels_key, kernel_registered));
} else {
// Kernel is visited at least once. Return stored registration result.
kernel_registered = kernel_element->second;
}
return false;
return kernel_registered;
}

// TODO(mdfaijul): QuantizedConv2D is registered with input: QUINT8
// filter:QINT8 for mkldnn integration. First a dummy kernel is created
// and then it is replaced by an actual kernel.
static inline bool IsMklQuantizedOp(const string& op_name, DataType Tinput,
DataType Tfilter) {
string kernel = KernelsRegisteredForOp(op_name);

// Restrict quantized ops to QUINT8 and QINT8 for now
if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) {
if (IsMklOp(op_name, Tinput, kMklQuantizedOpLabelPattern)) {
return (Tfilter == DT_QINT8);
}
return false;
}

// Check whether opname with type T is registered as an MKL operator that
// will go through name change.
//
// @input: name of the op
// @input: T datatype to be used for checking op
// @return: true if opname is registered as MKL op that will go through name
// change; false otherwise
static inline bool IsMklNameChangeOp(const string& op_name, DataType T) {
string kernel = KernelsRegisteredForOp(op_name);
// String returned by KernelsRegisteredForOp looks like below:
//
// Op = _MklMatMul, kernels =
// device='CPU'; label='MklNameChangeOp'; T in [DT_COMPLEX128]
// device='CPU'; label='MklNameChangeOp'; T in [DT_COMPLEX64]
// device='CPU'; label='MklNameChangeOp'; T in [DT_DOUBLE]
// device='CPU'; label='MklNameChangeOp'; T in [DT_FLOAT]

if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) {
// Restrict quantized ops to QUINT8, QINT8 and DT_QINT32
return (T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32);
}

// Now we just construct a search string to match what we are looking for.
string search_string = kMklNameChangeOpLabelPattern;
search_string += string(";") + string(" T in [");
search_string += DataType_Name(T) + string("]");

// Temporarily replacing earlier check by adding a type-specific check so
// that we can selectively decide which type is supported by MKL operators.
// That way kernel registration does not decide which operators we support.
// We are using this change to temporarily disable BFLOAT16 support. Once
// we want to enable it, we will go back to earlier check.
bool isTypeAllowed = false;
if (kernel.find(search_string) != string::npos) {
isTypeAllowed = (T == DT_COMPLEX128 || T == DT_COMPLEX64 ||
T == DT_DOUBLE || T == DT_FLOAT);
if (!isTypeAllowed) {
if (T == DT_BFLOAT16) {
if (IsBF16SupportedByOneDNNOnThisCPU()) {
isTypeAllowed = true;
} else {
// Restrict bfloat16 ops to platforms with at least AVX512 support,
// fall back to Eigen implementation otherwise.
BF16UnsupportedWarning();
isTypeAllowed = false;
}
}
}
return isTypeAllowed;
}
return false;
}

// Check if the operator with 'op_name' and type 'T' is an MKL operator that
// will either understand input tensors in MKL layout or will go through name
// rewrite that some operators go through.
static inline bool IsMklOp(const string& op_name, DataType T) {
return IsMklLayoutDependentOp(op_name, T) || IsMklNameChangeOp(op_name, T);
return IsMklOp(op_name, T, true) || IsMklOp(op_name, T, false);
}

static inline bool IsMklOp(const Node* n) {
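The consolidated IsMklOp decides type support by building a search string of the form label + "; T in [" + DataType_Name(T) + "]" and looking for it in the text returned by KernelsRegisteredForOp, then caching the result. The sketch below isolates just the substring match, using the sample registry output quoted in the comment in the hunk above; plain strings stand in for kMklNameChangeOpLabelPattern and DataType_Name.

```cpp
#include <iostream>
#include <string>

// Sample of what KernelsRegisteredForOp() returns for _MklMatMul,
// per the comment in the hunk above.
const std::string kRegisteredKernels =
    "Op = _MklMatMul, kernels =\n"
    "  device='CPU'; label='MklNameChangeOp'; T in [DT_COMPLEX128]\n"
    "  device='CPU'; label='MklNameChangeOp'; T in [DT_COMPLEX64]\n"
    "  device='CPU'; label='MklNameChangeOp'; T in [DT_DOUBLE]\n"
    "  device='CPU'; label='MklNameChangeOp'; T in [DT_FLOAT]\n";

// Substring test in the style of IsMklOp: label + "; T in [" + type + "]".
bool TypeIsRegistered(const std::string& label, const std::string& dtype_name) {
  const std::string search_string = label + "; T in [" + dtype_name + "]";
  return kRegisteredKernels.find(search_string) != std::string::npos;
}

int main() {
  std::cout << TypeIsRegistered("label='MklNameChangeOp'", "DT_FLOAT") << "\n";     // 1
  std::cout << TypeIsRegistered("label='MklNameChangeOp'", "DT_BFLOAT16") << "\n";  // 0
}
```

Because the match is a plain substring test, the allowed types are still filtered explicitly afterwards in the real function: DT_FLOAT for layout-dependent ops; complex, double, and float for name-change ops; bfloat16 only when the CPU supports it.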