
Commit 396acce

Skip parallel mapfn optimization if the tensorlist cannot be found in the current function block.

PiperOrigin-RevId: 627270701
cky9301 authored and tensorflower-gardener committed Apr 23, 2024
1 parent 363294b commit 396acce
Showing 2 changed files with 50 additions and 0 deletions.
46 changes: 46 additions & 0 deletions tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir
@@ -638,3 +638,49 @@ func.func private @tf.NestedWhileRegion_cond(%arg0: tensor<i32>, %arg1: tensor<i
return %2 : tensor<i1>
}

// -----

// Test that the while to map_fn conversion is skipped if the tensor list cannot be found in the current function body.

// CHECK-LABEL: map/while_cond
func.func private @"map/while_cond"(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<!tf_type.variant<tensor<*xf32>>>, %arg3: tensor<?xf32>) -> tensor<i1> {
%cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<3> : tensor<i32>} : () -> tensor<i32>
%0 = "tf.Less"(%arg0, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
%1 = "tf.Less"(%arg1, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
%2 = "tf.LogicalAnd"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<i1>, tensor<i1>) -> tensor<i1>
return %2 : tensor<i1>
}

// CHECK-LABEL: map/while_body
func.func private @"map/while_body"(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<!tf_type.variant<tensor<*xf32>>>, %arg3: tensor<?xf32>) -> (tensor<i32>, tensor<i32>, tensor<!tf_type.variant<tensor<*xf32>>>, tensor<?xf32>) {
%cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00, 9.000000e+00]> : tensor<9xf32>} : () -> tensor<9xf32>
%cst_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor<i32>} : () -> tensor<i32>
%cst_1 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
%cst_2 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<3> : tensor<2xi32>} : () -> tensor<2xi32>
%cst_3 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]> : tensor<9xf32>} : () -> tensor<9xf32>
%cst_4 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<1> : tensor<i32>} : () -> tensor<i32>
%0 = "tf.AddV2"(%arg0, %cst_4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%1 = "tf.Mul"(%arg3, %cst_3) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<?xf32>, tensor<9xf32>) -> tensor<9xf32>
%2 = "tf.Reshape"(%1, %cst_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<9xf32>, tensor<2xi32>) -> tensor<3x3xf32>
%3 = "tf.AddV2"(%arg1, %cst_4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%4 = "tf.GatherV2"(%cst_1, %arg1, %cst_0) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<3xi32>, tensor<i32>, tensor<i32>) -> tensor<i32>
%5 = "tf.Cast"(%4) {Truncate = false, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<i32>) -> tensor<f32>
%6 = "tf.Mul"(%5, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<f32>, tensor<9xf32>) -> tensor<9xf32>
%7 = "tf.Reshape"(%6, %cst_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<9xf32>, tensor<2xi32>) -> tensor<3x3xf32>
%8 = "tf.MatMul"(%2, %7) {device = "/job:localhost/replica:0/task:0/device:CPU:0", transpose_a = false, transpose_b = false} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
%9 = "tf.MatrixDeterminant"(%8) {T = f32, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<3x3xf32>) -> tensor<f32>
%10 = "tf.TensorListSetItem"(%arg2, %arg1, %9) {device = "/job:localhost/replica:0/task:0/device:CPU:0", resize_if_index_out_of_bounds = false} : (tensor<!tf_type.variant<tensor<*xf32>>>, tensor<i32>, tensor<f32>) -> tensor<!tf_type.variant<tensor<*xf32>>>
return %0, %3, %10, %arg3 : tensor<i32>, tensor<i32>, tensor<!tf_type.variant<tensor<*xf32>>>, tensor<?xf32>
}

// CHECK-LABEL: @func
func.func @func(%arg0: tensor<?xf32>, %arg1: tensor<!tf_type.variant<tensor<*xf32>>>) -> tensor<3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input:0", outputs = "PartitionedCall:0"}} {
%cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor<i32>} : () -> tensor<i32>
%cst_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
%cst_1 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%cst_2 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NOT: tf_map_fn
%1:4 = "tf.While"(%cst, %cst, %arg1, %arg0) {_lower_using_switch_merge = true, _num_original_outputs = 6 : i64, _read_only_resource_inputs = [], _xla_propagate_compile_time_consts = true, body = @"map/while_body", cond = @"map/while_cond", device = "/job:localhost/replica:0/task:0/device:CPU:0", is_stateless = true, parallel_iterations = 4 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<!tf_type.variant<tensor<*xf32>>>, tensor<?xf32>) -> (tensor<i32>, tensor<i32>, tensor<!tf_type.variant<tensor<*xf32>>>, tensor<?xf32>)
%2 = "tf.TensorListStack"(%1#2, %cst_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0", num_elements = 3 : i64} : (tensor<!tf_type.variant<tensor<*xf32>>>, tensor<0xi32>) -> tensor<3xf32>
return %2 : tensor<3xf32>
}
@@ -286,6 +286,10 @@ class WhileToMapFnPass
  for (auto tensor_list_index : loop_info.tensor_list_or_flow_in) {
    mlir::Operation *tensor_list_or_flow_in_defining_op =
        while_op.getOperand(tensor_list_index).getDefiningOp();
    // The tensor list is not defined in the current function body (e.g. it is
    // passed in as a function argument), so the pass cannot analyze how it was
    // constructed; skip the parallel map_fn conversion for this while loop.
    if (tensor_list_or_flow_in_defining_op == nullptr) {
      return mlir::failure();
    }

    mlir::Operation *max_iterations = nullptr;
    if (loop_info.max_iterations_arg_idx.has_value()) {
      max_iterations =
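For context on the guard above: in MLIR, a value that enters a function as a block (or function) argument has no defining operation, so getDefiningOp() returns nullptr. A minimal standalone sketch of that check, using a hypothetical helper name that is not part of this commit, could look like the following.

#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"

// Hypothetical helper (illustration only): fail when the tensor list operand
// is not produced by an op in the current function body, e.g. when it is
// passed in as a function argument, as in the test case above.
mlir::LogicalResult CheckTensorListDefinedLocally(mlir::Value tensor_list) {
  // Block and function arguments have no defining op, so getDefiningOp()
  // returns nullptr; without the defining op the pass cannot see how the
  // tensor list was built, so it conservatively skips the rewrite.
  if (tensor_list.getDefiningOp() == nullptr) return mlir::failure();
  return mlir::success();
}

In the pass itself, returning mlir::failure() from this match step simply leaves the tf.While untouched, so the loop still executes; it is just not converted into a parallel map_fn.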
