diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py index d5aa00c38f0ee0..3292ae330abacb 100644 --- a/tensorflow/python/debug/cli/analyzer_cli_test.py +++ b/tensorflow/python/debug/cli/analyzer_cli_test.py @@ -49,11 +49,23 @@ # Helper function to accommodate MKL-enabled TensorFlow: -# MatMul op is supported by MKL and its name is prefixed with "_Mkl" during the -# MKL graph rewrite pass. +# MatMul op is supported by MKL for some data types and its name is prefixed +# with "_Mkl" during the MKL graph rewrite pass. def _matmul_op_name(): - return "_MklMatMul" if test_util.IsMklEnabled() else "MatMul" + if (test_util.IsMklEnabled() and + _get_graph_matmul_dtype() in _mkl_matmul_supported_types()): + return "_MklMatMul" + else: + return "MatMul" + +# Helper function to get MklMatMul supported types +def _mkl_matmul_supported_types(): + return {"float32", "bfloat16"} +# Helper function to get dtype used in the graph of SetUpClass() +def _get_graph_matmul_dtype(): + # default dtype of matmul op created is float64 + return "float64" def _cli_config_from_temp_file(): return cli_config.CLIConfig(