Remove the Variant shape function registry and all references.

Variant shape information is now stored in HandleData.shape_and_type
during shape inference.

While we're at it, propagate DT_VARIANT ShapeAndType handles through
a number of ops ("AddN", "ZerosLike", and several resource variable ops).

For DT_VARIANT types stored inside a ResourceVariable, the ResourceVariable's
ShapeAndType array is extended. Previously it held a single entry storing the
dtype and shape of the tensor held by the Variable. Now, entries at indices
1,... contain the ShapeAndType data read from the DT_VARIANT passed as the
initializer at Variable creation time. ResourceVariable read ops propagate this
information back out by reading ShapeAndType[1:] from the resource tensor's
handle data.

PiperOrigin-RevId: 231057278
ebrevdo authored and tensorflower-gardener committed Jan 26, 2019
1 parent 441ea6d commit 52c8bdba0081e4ed428add13e3c0da7ccbfc8f06
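To make the ResourceVariable layout described above concrete, here is a minimal, hypothetical shape-function sketch (not code from this commit). It assumes a ReadVariableOp-style op whose input handle data was populated at Variable creation time: entry 0 describes the variable's tensor, entries 1,... describe the DT_VARIANT initializer.

#include <vector>

#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

// Hypothetical sketch only; not the actual implementation in this commit.
// Entry 0 of the variable's handle data holds the dtype/shape of the stored
// tensor; entries 1,... carry the ShapeAndType recorded from a DT_VARIANT
// initializer, which a read op can forward to its output.
Status ReadVariableShapeSketch(shape_inference::InferenceContext* c) {
  const auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data == nullptr || handle_data->empty()) {
    c->set_output(0, c->UnknownShape());
    return Status::OK();
  }
  // Entry 0: the tensor held by the ResourceVariable.
  c->set_output(0, (*handle_data)[0].shape);
  if (handle_data->size() > 1) {
    // Entries 1,...: variant ShapeAndType; forward them so consumers of the
    // read value still see the variant's shape/type information.
    std::vector<shape_inference::ShapeAndType> variant_info(
        handle_data->begin() + 1, handle_data->end());
    c->set_output_handle_shapes_and_types(0, variant_info);
  }
  return Status::OK();
}

}  // namespace tensorflow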
@@ -203,6 +203,10 @@ Status GetWindowedOutputSizeFromDims(

Status UnchangedShape(shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data != nullptr) {
c->set_output_handle_shapes_and_types(0, *handle_data);
}
return Status::OK();
}
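For an op that produces a DT_VARIANT rather than passing one through, the analogous move is to record the wrapped value's shape and dtype as output handle data. The sketch below is hypothetical (the rank-2 / DT_FLOAT payload is purely illustrative) and assumes the same shape_inference.h header as the code above.

Status VariantOutputShapeSketch(shape_inference::InferenceContext* c) {
  // The DT_VARIANT tensor itself is a scalar.
  c->set_output(0, c->Scalar());
  // Record the shape/dtype of the value wrapped inside the variant as handle
  // data so downstream shape functions (like UnchangedShape above) can
  // propagate it.
  shape_inference::ShapeAndType wrapped(c->UnknownShapeOfRank(2), DT_FLOAT);
  c->set_output_handle_shapes_and_types(0, {wrapped});
  return Status::OK();
}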

@@ -1259,7 +1259,6 @@ bool InferenceContext::RelaxHandleShapesAndMergeTypes(
return false;
}
std::vector<ShapeAndType> new_values(shapes_and_types.size());
bool refined = false;
for (int i = 0; i < shapes_and_types.size(); ++i) {
const ShapeAndType& existing = (*to_update)[i];
if (shapes_and_types[i].dtype == existing.dtype) {
@@ -1269,16 +1268,9 @@ bool InferenceContext::RelaxHandleShapesAndMergeTypes(
return false;
} else {
new_values[i].dtype = shapes_and_types[i].dtype;
refined = true;
}
}
Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
if (!existing.shape.SameHandle(new_values[i].shape)) {
refined = true;
}
}
if (!refined) {
return false;
}
to_update->swap(new_values);
return true;
@@ -588,9 +588,9 @@ class InferenceContext {
// position idx with the specified shapes and types. This requires idx to be
// in the [0, num_inputs) range.
//
// If the relax is successful and any of the new shapes differs from the old
// one, or any of the old dtypes was DT_INVALID, store the new shapes and
// return true. Return false otherwise.
// If the relax is successful (sizes are the same, old dtypes match new ones
// or are DT_INVALID), then store the relaxed shapes and return true.
// Return false otherwise.
//
// See 'RelaxInput' function for full details and examples.
bool RelaxInputHandleShapesAndMergeTypes(
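A small, hypothetical illustration of the relax semantics documented above; it assumes an InferenceContext whose input 0 already carries exactly one handle entry with dtype DT_INVALID.

bool RelaxVariantHandleInfoSketch(shape_inference::InferenceContext* c) {
  std::vector<shape_inference::ShapeAndType> incoming = {
      shape_inference::ShapeAndType(c->Matrix(2, 2), DT_VARIANT)};
  // Sizes match (one entry each) and the old dtype is DT_INVALID, so the
  // relaxed entry {DT_VARIANT, [2,2]} is stored and the call returns true.
  return c->RelaxInputHandleShapesAndMergeTypes(0, incoming);
}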
@@ -37,57 +37,6 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
return global_unary_variant_op_registry;
}

UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
const TypeIndex& type_index) {
auto found = shape_fns.find(type_index);
if (found == shape_fns.end()) return nullptr;
return &found->second;
}

void UnaryVariantOpRegistry::RegisterShapeFn(const TypeIndex& type_index,
const VariantShapeFn& shape_fn) {
VariantShapeFn* existing = GetShapeFn(type_index);
CHECK_EQ(existing, nullptr)
<< "Unary VariantShapeFn for type_index: "
<< port::MaybeAbiDemangle(type_index.name()) << " already registered";
shape_fns.insert(std::pair<TypeIndex, VariantShapeFn>(type_index, shape_fn));
}

Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
CHECK_EQ(variant_tensor.dtype(), DT_VARIANT);
CHECK_EQ(variant_tensor.dims(), 0);
const Variant& v = variant_tensor.scalar<Variant>()();
UnaryVariantOpRegistry::VariantShapeFn* shape_fn =
UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeId());
if (shape_fn == nullptr) {
return errors::Internal(
"No unary variant shape function found for Variant type_index: ",
port::MaybeAbiDemangle(v.TypeId().name()));
}
return (*shape_fn)(v, shape);
}

// Add some basic registrations for use by others, e.g., for testing.
namespace {
template <typename T>
Status ScalarShape(const T&, TensorShape* shape) {
*shape = TensorShape({});
return Status::OK();
}
} // namespace

#define REGISTER_VARIANT_SHAPE_TYPE(T) \
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, ScalarShape<T>);

// No encode/shape registered for std::complex<> and Eigen::half
// objects yet.
REGISTER_VARIANT_SHAPE_TYPE(int);
REGISTER_VARIANT_SHAPE_TYPE(float);
REGISTER_VARIANT_SHAPE_TYPE(bool);
REGISTER_VARIANT_SHAPE_TYPE(double);

#undef REGISTER_VARIANT_SHAPE_TYPE

UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
StringPiece type_name) {
auto found = decode_fns.find(type_name);
@@ -177,6 +126,37 @@ Status VariantDeviceCopy(
return (*device_copy_fn)(from, to, copy_fn);
}

namespace {
template <typename T>
Status DeviceCopyPrimitiveType(
const T& in, T* out,
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
// Dummy copy, we don't actually bother copying to the device and back for
// testing.
*out = in;
return Status::OK();
}
} // namespace

#define REGISTER_VARIANT_DEVICE_COPY_TYPE(T) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
T, VariantDeviceCopyDirection::HOST_TO_DEVICE, \
DeviceCopyPrimitiveType<T>); \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
T, VariantDeviceCopyDirection::DEVICE_TO_HOST, \
DeviceCopyPrimitiveType<T>); \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
T, VariantDeviceCopyDirection::DEVICE_TO_DEVICE, \
DeviceCopyPrimitiveType<T>);

// No zeros_like registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_DEVICE_COPY_TYPE(int);
REGISTER_VARIANT_DEVICE_COPY_TYPE(float);
REGISTER_VARIANT_DEVICE_COPY_TYPE(double);
REGISTER_VARIANT_DEVICE_COPY_TYPE(bool);

#undef REGISTER_VARIANT_DEVICE_COPY_TYPE

// Special casing UnaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) {
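Registering device-copy support for a user-defined variant payload follows the same pattern as the primitive registrations above. MyVariantPayload and CopyMyVariantPayload below are purely hypothetical names, and the snippet assumes it lives in a .cc file inside namespace tensorflow that includes variant_op_registry.h.

struct MyVariantPayload {
  string TypeName() const { return "MyVariantPayload"; }
  int value = 0;
};

static Status CopyMyVariantPayload(
    const MyVariantPayload& from, MyVariantPayload* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_tensor) {
  // This payload holds no Tensors, so there is nothing to move across
  // devices; a payload with Tensor members would call copy_tensor(...) for
  // each of them.
  *to = from;
  return Status::OK();
}

INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
    MyVariantPayload, VariantDeviceCopyDirection::HOST_TO_DEVICE,
    CopyMyVariantPayload);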
@@ -58,7 +58,6 @@ enum VariantDeviceCopyDirection {

class UnaryVariantOpRegistry {
public:
typedef std::function<Status(const Variant& v, TensorShape*)> VariantShapeFn;
typedef std::function<bool(Variant*)> VariantDecodeFn;
typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
VariantUnaryOpFn;
@@ -93,13 +92,6 @@ class UnaryVariantOpRegistry {
AsyncTensorDeviceCopyFn copy_fn)>
AsyncVariantDeviceCopyFn;

// Add a shape lookup function to the registry.
void RegisterShapeFn(const TypeIndex& type_index,
const VariantShapeFn& shape_fn);

// Returns nullptr if no shape function was found for the given TypeIndex.
VariantShapeFn* GetShapeFn(const TypeIndex& type_index);

// Add a decode function to the registry.
void RegisterDecodeFn(const string& type_name,
const VariantDecodeFn& decode_fn);
@@ -154,7 +146,6 @@ class UnaryVariantOpRegistry {
std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
};

gtl::FlatMap<TypeIndex, VariantShapeFn, TypeIndexHash> shape_fns;
gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;

// Map std::pair<Direction, type_name> to function.
@@ -235,15 +226,6 @@ inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
(lhs.type_index_ == rhs.type_index_);
}
// Gets a TensorShape from a Tensor containing a scalar Variant.
// Returns an Internal error if the Variant does not have a registered shape
// function, or if it's a serialized Variant that cannot be decoded.
//
// REQUIRES:
// variant_tensor.dtype() == DT_VARIANT
// variant_tensor.dims() == 0
//
Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape);

// Decodes the Variant whose data_type has a registered decode
// function. Returns an Internal error if the Variant does not have a
@@ -326,29 +308,6 @@ Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,

namespace variant_op_registry_fn_registration {

template <typename T>
class UnaryVariantShapeRegistration {
public:
typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn;

UnaryVariantShapeRegistration(const TypeIndex& type_index,
const LocalVariantShapeFn& shape_fn) {
const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterShapeFn(
type_index,
[type_index_name, shape_fn](const Variant& v,
TensorShape* s) -> Status {
const T* t = v.get<T>();
if (t == nullptr) {
return errors::Internal(
"VariantShapeFn: Could not access object, type_index: ",
type_index_name);
}
return shape_fn(*t, s);
});
}
};

template <typename T>
class UnaryVariantDecodeRegistration {
public:
@@ -471,23 +430,6 @@ class UnaryVariantBinaryOpRegistration {

}; // namespace variant_op_registry_fn_registration

// Register a unary shape variant function with the signature:
// Status ShapeFn(const T& t, TensorShape* s);
// to Variants having TypeIndex type_index.
#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, shape_function) \
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER( \
__COUNTER__, T, MakeTypeIndex<T>(), shape_function)

#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_index, \
shape_function) \
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, shape_function)

#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, \
shape_function) \
static variant_op_registry_fn_registration::UnaryVariantShapeRegistration<T> \
register_unary_variant_op_shape_registration_fn_##ctr(type_index, \
shape_function)

// Register a unary decode variant function for the given type.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \
REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name)
@@ -39,13 +39,6 @@ namespace {

struct VariantValue {
string TypeName() const { return "TEST VariantValue"; }
static Status ShapeFn(const VariantValue& v, TensorShape* s) {
if (v.early_exit) {
return errors::InvalidArgument("early exit!");
}
*s = TensorShape({-0xdeadbeef});
return Status::OK();
}
static Status CPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v,
VariantValue* v_out) {
if (v.early_exit) {
@@ -89,8 +82,6 @@ struct VariantValue {
int value;
};

REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");

INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
@@ -113,38 +104,6 @@ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,

} // namespace

TEST(VariantOpShapeRegistryTest, TestBasic) {
class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<Blah>()),
nullptr);

auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn(
MakeTypeIndex<VariantValue>());
EXPECT_NE(shape_fn, nullptr);
TensorShape shape;

VariantValue vv_early_exit{true /* early_exit */};
Variant v = vv_early_exit;
Status s0 = (*shape_fn)(v, &shape);
EXPECT_FALSE(s0.ok());
EXPECT_TRUE(str_util::StrContains(s0.error_message(), "early exit!"));

VariantValue vv_ok{false /* early_exit */};
v = vv_ok;
TF_EXPECT_OK((*shape_fn)(v, &shape));
EXPECT_EQ(shape, TensorShape({-0xdeadbeef}));
}

TEST(VariantOpShapeRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantShapeFn f;
class FjFjFj {};
const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
registry.RegisterShapeFn(kTypeIndex, f);
EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f),
"FjFjFj already registered");
}

TEST(VariantOpDecodeRegistryTest, TestBasic) {
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDecodeFn("YOU SHALL NOT PASS"),
nullptr);
@@ -3669,27 +3669,6 @@ tf_cc_test(
],
)

tf_cc_test(
name = "shape_op_test",
srcs = ["shape_op_test.cc"],
deps = [
":array",
":ops_testutil",
":ops_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

tf_cuda_cc_test(
name = "sparse_matmul_op_test",
size = "small",
@@ -179,20 +179,7 @@ class AddNOp<Device, Variant> : public OpKernel {
i, " has shape: ", ctx->input(i).shape().DebugString(), "."));
}

TensorShape common_shape;
OP_REQUIRES_OK(ctx, GetUnaryVariantShape(ctx->input(0), &common_shape));
// Step 2: access all variants and ensure shapes match.
for (int i = 1; i < num; ++i) {
TensorShape check_shape;
OP_REQUIRES_OK(ctx, GetUnaryVariantShape(ctx->input(i), &check_shape));
OP_REQUIRES(ctx, common_shape == check_shape,
errors::InvalidArgument(
"AddN of Variants of differing shapes; inputs[0] shape: ",
common_shape.DebugString(), ", inputs[", i,
"] shape: ", check_shape.DebugString()));
}

// Step 3: attempt to add using
// Step 2: attempt to add using
// BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...)
// For the output create a default-constructed variant object.
// TODO(ebrevdo): Perform summation in a tree-structure.
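For reference, a minimal hypothetical sketch of the call AddN now relies on: summing two scalar DT_VARIANT tensors goes straight to the binary-op registry, which returns an error if the two variants hold different stored types or no add function is registered. CPUDevice is assumed to be the usual Eigen::ThreadPoolDevice alias used by the CPU kernel.

Status AddTwoVariantsSketch(OpKernelContext* ctx, const Tensor& a,
                            const Tensor& b, Variant* out) {
  const Variant& v_a = a.scalar<Variant>()();
  const Variant& v_b = b.scalar<Variant>()();
  // Dispatches to the ADD_VARIANT_BINARY_OP function registered for the
  // variants' stored type (e.g. TensorList or OptionalVariant).
  return BinaryOpVariants<CPUDevice>(ctx, ADD_VARIANT_BINARY_OP, v_a, v_b,
                                     out);
}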
@@ -221,12 +221,5 @@ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
OptionalVariant,
OptionalBinaryAdd<CPUDevice>);

Status OptionalShape(const OptionalVariant& x, TensorShape* s) {
*s = TensorShape({});
return Status::OK();
}

REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(OptionalVariant, OptionalShape);

} // namespace data
} // namespace tensorflow
@@ -99,13 +99,6 @@ REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);

Status TensorListShape(const TensorList& t, TensorShape* s) {
*s = TensorShape({});
return Status::OK();
}

REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorListShape);

bool TensorList::Decode(const VariantTensorData& data) {
// TODO(srbs): Change the signature to Decode(VariantTensorData data) so
// that we do not have to copy each tensor individually below. This would
