Skip to content

Commit

Permalink
Support resource types in CastCompatible checks.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 281999124
Change-Id: Ib3a9749114e8e5c5463c25e9f6618e4d811d1449
  • Loading branch information
gmagogsfm authored and tensorflower-gardener committed Nov 22, 2019
1 parent f635dd1 commit 2f889d7
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 3 deletions.
35 changes: 32 additions & 3 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
#include "mlir/IR/DialectImplementation.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
Expand Down Expand Up @@ -109,12 +110,40 @@ static inline bool HasRankAtMost(Value *value, int64_t rank) {
static bool AreCastCompatible(Type a, Type b) {
if (TensorCastOp::areCastCompatible(a, b)) return true;

// Resource types may optionally contain subtypes information that does not
// match. Check subtypes compatibility when possible, otherwise treat them as
// compatible.
auto a_or_element_type = getElementTypeOrSelf(a);
auto b_or_element_type = getElementTypeOrSelf(b);

auto a_kind = a_or_element_type.getKind();
auto b_kind = b_or_element_type.getKind();

if (a_kind == TensorFlowTypes::RESOURCE &&
b_kind == TensorFlowTypes::RESOURCE) {
auto a_resource_type = a_or_element_type.dyn_cast<ResourceType>();
auto b_resource_type = b_or_element_type.dyn_cast<ResourceType>();
bool a_has_subtype = !a_resource_type.getSubtypes().empty();
bool b_has_subtype = !b_resource_type.getSubtypes().empty();

if (!a_has_subtype || !b_has_subtype) return true;

assert(a_resource_type.getSubtypes().size() <= 1 &&
"Resource type must have at most one subtype");
assert(b_resource_type.getSubtypes().size() <= 1 &&
"Resource type must have at most one subtype");

return TensorCastOp::areCastCompatible(
a_resource_type.getSubtypes().front(),
b_resource_type.getSubtypes().front());
}

// Variant types may optionally contain subtypes information that need not
// match. It is also not possible to compare subtypes for compatibility as
// their interpretation depends on the ops operating on them. So, accept all
// their interpretation depends on the ops operating on them. So, accept all
// pairs of variant types.
return getElementTypeOrSelf(a).getKind() == TensorFlowTypes::VARIANT &&
getElementTypeOrSelf(b).getKind() == TensorFlowTypes::VARIANT;
return a_kind == TensorFlowTypes::VARIANT &&
b_kind == TensorFlowTypes::VARIANT;
}

static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
Expand Down
56 changes: 56 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
Expand Up @@ -1042,6 +1042,62 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) {

// -----

func @testWhileCond(tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<i1>)
func @testWhileBody(tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<!tf.resource<tensor<16xf32>>>)

// Test invalid 'While' operation verifier that detects incompatible tf.resource
// subtypes.
func @testWhileResult(tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<!tf.resource<tensor<16xf32>>>) {
^bb0(%arg0: tensor<*x!tf.resource<tensor<32xf32>>>):
// expected-error @+1 {{operand type tensor<*x!tf.resource<tensor<32xf32>>> is incompatible with result type}}
%1 = "tf.While"(%arg0) {
cond = @testWhileCond,
body = @testWhileBody,
is_stateless = false
} : (tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<!tf.resource<tensor<16xf32>>>)

return %1 : tensor<!tf.resource<tensor<16xf32>>>
}

// -----

func @testWhileCond(tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<i1>)
func @testWhileBody(tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<!tf.resource<tensor<*xf32>>>)

// Test 'While' operation verifier allows compatible tf.resource subtypes.
// CHECK-LABEL: func @testWhileResult
func @testWhileResult(tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<!tf.resource<tensor<*xf32>>>) {
^bb0(%arg0: tensor<*x!tf.resource<tensor<32xf32>>>):
%1 = "tf.While"(%arg0) {
cond = @testWhileCond,
body = @testWhileBody,
is_stateless = false
} : (tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<!tf.resource<tensor<*xf32>>>)

return %1 : tensor<!tf.resource<tensor<*xf32>>>
}

// -----

func @testWhileCond(tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<i1>)
func @testWhileBody(tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<!tf.resource>)

// Test 'While' operation verifier treats tf.resource with subtype and without
// subtype as compatible types.
// CHECK-LABEL: func @testWhileResult
func @testWhileResult(tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<!tf.resource>) {
^bb0(%arg0: tensor<*x!tf.resource<tensor<32xf32>>>):
%1 = "tf.While"(%arg0) {
cond = @testWhileCond,
body = @testWhileBody,
is_stateless = false
} : (tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<!tf.resource>)

return %1 : tensor<!tf.resource>
}

// -----

// CHECK-LABEL: func @testValidShape
func @testValidShape(tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<4xi32>, tensor<?xi32>) {
^bb0(%arg0: tensor<1x32x32x16xf32>, %arg1: tensor<*xf32>):
Expand Down

0 comments on commit 2f889d7

Please sign in to comment.