Commit

Merge pull request #64520 from Intel-tensorflow:amin/xla-disable-remapper

PiperOrigin-RevId: 627598905
tensorflower-gardener committed Apr 24, 2024
2 parents bdd3d08 + 7b93181 commit 4a80864
Showing 2 changed files with 114 additions and 4 deletions.
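
This change makes the Grappler remapper skip its CPU fusions when XLA CPU global JIT compilation is enabled, so that XLA can compile the unfused graph itself. For reference, a minimal sketch of how that mode is switched on for a process (it mirrors the test added below; setenv comes from <cstdlib>):

    // Illustrative only: enable XLA CPU global JIT via TF_XLA_FLAGS,
    // which the new remapper check detects.
    setenv("TF_XLA_FLAGS", "--tf_xla_cpu_global_jit", /*overwrite=*/1);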
54 changes: 50 additions & 4 deletions tensorflow/core/grappler/optimizers/remapper.cc
@@ -109,20 +109,23 @@ constexpr int kMissingIndex = -1;
struct RemapperContext {
explicit RemapperContext(GrapplerItem* item, Status* status,
RewriterConfig::CpuLayout cpu_layout_conversion,
- bool xla_auto_clustering_on)
+ bool xla_auto_clustering_on,
+ bool xla_cpu_jit_disable_fusion)
: nodes_to_preserve(item->NodesToPreserve()),
graph_view(&item->graph, status),
graph_properties(*item),
inferred_graph_properties(false),
cpu_layout_conversion(cpu_layout_conversion),
- xla_auto_clustering_on(xla_auto_clustering_on) {}
+ xla_auto_clustering_on(xla_auto_clustering_on),
+ xla_cpu_jit_disable_fusion(xla_cpu_jit_disable_fusion) {}

std::unordered_set<string> nodes_to_preserve;
utils::MutableGraphView graph_view;
GraphProperties graph_properties;
bool inferred_graph_properties;
RewriterConfig::CpuLayout cpu_layout_conversion;
bool xla_auto_clustering_on;
bool xla_cpu_jit_disable_fusion;
};

// FusedBatchNorm that can be replaced with a cheaper set of primitives.
@@ -446,6 +449,9 @@ bool IsCpuCompatibleDepthwiseConv2dNative(const NodeDef* dw_conv2d) {
// Checks if we can rewrite a pattern to the `_Fused{Conv2D,MatMul}` on CPU.
template <typename Pattern>
bool IsCpuCompatible(const RemapperContext& ctx, const Pattern& matched) {
// Disable fusions on CPU when XLA JIT compilation is enabled.
if (ctx.xla_cpu_jit_disable_fusion) return false;

const NodeDef& node = ctx.graph_view.graph()->node(matched.contraction);
if (IsConv2D(node)) {
return IsCpuCompatibleConv2D(ctx, &node);
@@ -998,6 +1004,11 @@ bool FindConv2DWithBatchNorm(const RemapperContext& ctx, int node_index,
const auto* conv2d_node_view = regular_fanin_0.node_view();
const auto* conv2d_node_def = conv2d_node_view->node();

// Disable fusions on CPU when XLA JIT compilation is enabled.
if (NodeIsOnCpu(conv2d_node_def) && ctx.xla_cpu_jit_disable_fusion) {
return false;
}

if (!IsConv2D(*conv2d_node_def) || !NodeIsOnCpu(conv2d_node_def) ||
!HaveSameDataType(node_def, conv2d_node_def) ||
!IsCpuCompatibleDataType(conv2d_node_def) ||
@@ -1695,6 +1706,12 @@ bool FindMatMulBiasAddAndGelu(RemapperContext* ctx, int node_index,
// Check if the MatMul to be fused is device compatible.
NodeDef* matmul_node =
ctx->graph_view.GetNode(matched_nodes_map->at("matmul"))->node();

// Disable fusions on CPU when XLA JIT compilation is enabled.
if (NodeIsOnCpu(matmul_node) && ctx->xla_cpu_jit_disable_fusion) {
return false;
}

DataType matmul_dtype = GetDataTypeFromAttr(*matmul_node, "T");

bool cpu_ok = IsMKLEnabled() && IsCpuCompatibleMatMul(*ctx, matmul_node);
@@ -1732,6 +1749,11 @@ bool FindMatMulBiasAddAndGelu(RemapperContext* ctx, int node_index,
NodeDef* matmul_node =
ctx->graph_view.GetNode(matched_nodes_map->at("matmul"))->node();

// Disable fusions on CPU when XLA JIT compilation is enabled.
if (NodeIsOnCpu(matmul_node) && ctx->xla_cpu_jit_disable_fusion) {
return false;
}

// matmul_node is already the _FusedMatMul and we don't need to check its
// data type again.
if (!IsMKLEnabled() && !NodeIsOnGpu(matmul_node)) return false;
@@ -2298,6 +2320,10 @@ bool FindFusedBatchNorm(const RemapperContext& ctx, int node_index,
FusedBatchNorm* matched) {
const auto* node_view = ctx.graph_view.GetNode(node_index);
const auto* node_def = node_view->node();

// Disable fusions on CPU when XLA JIT compilation is enabled.
if (ctx.xla_cpu_jit_disable_fusion && NodeIsOnCpu(node_def)) return false;

if (!IsFusedBatchNorm(*node_def)) return false;
if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;

@@ -2380,6 +2406,8 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
// should be processed when it's on GPU and oneDNN CPU is enabled.
if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false;
} else {
// Disable fusions on CPU when XLA JIT compilation is enabled.
if (ctx.xla_cpu_jit_disable_fusion) return false;
if (IsMKLEnabled() && !IsDataTypeSupportedByOneDNNOnThisCPU(t_dtype))
return false;
}
@@ -4793,15 +4821,29 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index,
is_matmul_gelu_exact_fusion_candidate() ||
is_act_biasadd_matmul_candidate();
}

inline bool IsXlaCpuGlobalJitOn() {
std::vector<string> tf_xla_flags;
const std::string tf_xla_cpu_global_jit = "--tf_xla_cpu_global_jit";
TF_CHECK_OK(ReadStringsFromEnvVar("TF_XLA_FLAGS", "", &tf_xla_flags));
return std::find(tf_xla_flags.begin(), tf_xla_flags.end(),
tf_xla_cpu_global_jit) != tf_xla_flags.end();
}
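
// For illustration only (not part of this commit): a self-contained
// approximation of IsXlaCpuGlobalJitOn() above. It assumes TF_XLA_FLAGS
// holds whitespace-separated flags; the real tokenization is done by
// ReadStringsFromEnvVar and may differ, so treat this as a sketch of the
// lookup rather than the library's exact behavior. Standalone use needs
// <cstdlib>, <sstream>, and <string>.
inline bool XlaCpuGlobalJitOnSketch() {
  const char* flags = std::getenv("TF_XLA_FLAGS");
  if (flags == nullptr) return false;
  std::istringstream tokens(flags);
  std::string token;
  // Scan each whitespace-separated token for an exact flag match.
  while (tokens >> token) {
    if (token == "--tf_xla_cpu_global_jit") return true;
  }
  return false;
}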
} // namespace

Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
GrapplerItem mutable_item = item;
Status status;
bool xla_cpu_jit_disable_fusion =
xla_auto_clustering_on_ && IsXlaCpuGlobalJitOn();
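// The #ifdef below keeps fusions enabled on aarch64 builds that use the
// Arm Compute Library, presumably because oneDNN+ACL fused kernels are
// still preferred there even under XLA CPU global JIT; the build guard
// on the new test in remapper_test.cc mirrors this.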
#ifdef DNNL_AARCH64_USE_ACL
xla_cpu_jit_disable_fusion = false;
#endif // DNNL_AARCH64_USE_ACL
RemapperContext ctx(&mutable_item, &status, cpu_layout_conversion_,
- xla_auto_clustering_on_);
+ xla_auto_clustering_on_, xla_cpu_jit_disable_fusion);
TF_RETURN_IF_ERROR(status);

// Processing the graph in reverse-topological order allows remapping
// longer chains of dependent ops in one pass.
TF_RETURN_IF_ERROR(
@@ -4852,7 +4894,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
AddInputShapesAttr(ctx, i);
}

- if (IsMKLEnabled()) {
+ if (IsMKLEnabled() && !ctx.xla_cpu_jit_disable_fusion) {
const auto* node_view = ctx.graph_view.GetNode(i);
const auto* node_def = node_view->node();
const string& type_attr = "T";
@@ -5049,6 +5091,9 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
continue;
}

// On XLA CPU, fusions are disabled inside IsCpuCompatible(...), which is
// invoked by the remappings below.
//
// Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd into the
// _Fused{Conv2D,DepthwiseConv2dNative,MatMul}
ContractionWithBiasAdd contract_with_bias;
@@ -5119,6 +5164,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
continue;
}

// This fusion is enabled on GPU only.
FusedBatchNormGradEx fused_batch_norm_grad_ex;
if (allow_non_differentiable_rewrites &&
FindFusedBatchNormGradEx(ctx, i, &fused_batch_norm_grad_ex)) {
64 changes: 64 additions & 0 deletions tensorflow/core/grappler/optimizers/remapper_test.cc
@@ -3044,5 +3044,69 @@ TEST_F(RemapperControlDependencyPatternMatcher, BF16) {
RunTest<DT_BFLOAT16>();
}

class XlaCpuJitDisableFusionTest : public RemapperTest {
protected:
void SetUp() override {
setenv("TF_XLA_FLAGS", "--tf_xla_cpu_global_jit", /*overwrite=*/1);
}

template <DataType DTYPE>
void RunTest() {
using ::tensorflow::ops::Placeholder;

tensorflow::Scope s = tensorflow::Scope::NewRootScope();

auto lhs_shape = ops::Placeholder::Shape({8, 32});
auto rhs_shape = ops::Placeholder::Shape({32, 64});
auto bias_shape = ops::Placeholder::Shape({64});

auto lhs = Placeholder(s.WithOpName("lhs"), DTYPE, lhs_shape);
auto rhs = Placeholder(s.WithOpName("rhs"), DTYPE, rhs_shape);
auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);

auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
auto fetch = ops::Identity(s.WithOpName("fetch"), bias_add);

auto lhs_t = GenerateTensorWithSetRandom<DTYPE>({8, 32});
auto rhs_t = GenerateTensorWithSetRandom<DTYPE>({32, 64});
auto bias_t = GenerateTensorWithSetRandom<DTYPE>({64});

GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"lhs", lhs_t}, {"rhs", rhs_t}, {"bias", bias_t}};
TF_ASSERT_OK(s.ToGraphDef(&item.graph));

const string device = "/device:CPU:0";

// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device(device);
}

Remapper optimizer(RewriterConfig::ON, RewriterConfig::NO_CONVERSION_ON_CPU,
/*xla_clustering_on=*/true);
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

// Fusion should not take place on CPU when tf_xla_cpu_global_jit is on.
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "bias_add") {
EXPECT_EQ(node.op(), "BiasAdd");
found++;
} else if (node.name() == "matmul") {
EXPECT_EQ(node.op(), "MatMul");
found++;
}
}
EXPECT_EQ(2, found);
}
};

#if !(DNNL_AARCH64_USE_ACL || GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
TEST_F(XlaCpuJitDisableFusionTest, MatMulWithBias) { RunTest<DT_FLOAT>(); }
#endif // !(DNNL_AARCH64_USE_ACL || GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
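
// The guard above limits this test to CPU-only, non-ACL builds: under
// DNNL_AARCH64_USE_ACL the remapper deliberately keeps fusions on (see
// the #ifdef override in remapper.cc), and the expectation presumably
// would not hold as written in CUDA/ROCm builds.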

} // namespace grappler
} // namespace tensorflow
