Skip to content

Commit

Permalink
TF OpDefBuilder: support compound types when using tuple restrictions.
Browse files Browse the repository at this point in the history
Can now use type restrictions of the form
  'attr("T: {numbertype, bool, string}")'.

PiperOrigin-RevId: 168003571
  • Loading branch information
ebrevdo authored and tensorflower-gardener committed Sep 8, 2017
1 parent c0525d3 commit 7565452
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 28 deletions.
75 changes: 48 additions & 27 deletions tensorflow/core/framework/op_def_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,37 @@ bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
} \
} while (false)

bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
auto capture_begin = sp->begin();
if (sp->Consume("numbertype") || sp->Consume("numerictype") ||
sp->Consume("quantizedtype") || sp->Consume("realnumbertype") ||
sp->Consume("realnumberictype")) {
*out = StringPiece(capture_begin, sp->begin() - capture_begin);
return true;
}
return false;
}

bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
if (type_string == "numbertype" || type_string == "numerictype") {
for (DataType dt : NumberTypes()) {
allowed->mutable_list()->add_type(dt);
}
} else if (type_string == "quantizedtype") {
for (DataType dt : QuantizedTypes()) {
allowed->mutable_list()->add_type(dt);
}
} else if (type_string == "realnumbertype" ||
type_string == "realnumerictype") {
for (DataType dt : RealNumberTypes()) {
allowed->mutable_list()->add_type(dt);
}
} else {
return false;
}
return true;
}

void FinalizeAttr(StringPiece spec, OpDef* op_def,
std::vector<string>* errors) {
OpDef::AttrDef* attr = op_def->add_attr();
Expand All @@ -123,6 +154,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
// Read "<type>" or "list(<type>)".
bool is_list = ConsumeListPrefix(&spec);
string type;
StringPiece type_string; // Used if type == "type"
if (spec.Consume("string")) {
type = "string";
} else if (spec.Consume("int")) {
Expand All @@ -139,29 +171,15 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
type = "tensor";
} else if (spec.Consume("func")) {
type = "func";
} else if (spec.Consume("numbertype") || spec.Consume("numerictype")) {
} else if (ConsumeCompoundAttrType(&spec, &type_string)) {
type = "type";
AttrValue* allowed = attr->mutable_allowed_values();
for (DataType dt : NumberTypes()) {
allowed->mutable_list()->add_type(dt);
}
} else if (spec.Consume("quantizedtype")) {
type = "type";
AttrValue* allowed = attr->mutable_allowed_values();
for (DataType dt : QuantizedTypes()) {
allowed->mutable_list()->add_type(dt);
}
} else if (spec.Consume("realnumbertype") ||
spec.Consume("realnumerictype")) {
type = "type";
AttrValue* allowed = attr->mutable_allowed_values();
for (DataType dt : RealNumberTypes()) {
allowed->mutable_list()->add_type(dt);
}
VERIFY(ProcessCompoundType(type_string, allowed),
"Expected to see a compound type, saw: ", type_string);
} else if (spec.Consume("{")) {
// e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
str_util::RemoveLeadingWhitespace(&spec);
AttrValue* allowed = attr->mutable_allowed_values();
str_util::RemoveLeadingWhitespace(&spec);
if (spec.starts_with("\"") || spec.starts_with("'")) {
type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
while (true) {
Expand All @@ -172,8 +190,8 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
string unescaped;
string error;
VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error),
"Trouble unescaping \"", escaped_string, "\", got error: ",
error);
"Trouble unescaping \"", escaped_string,
"\", got error: ", error);
allowed->mutable_list()->add_s(unescaped);
if (spec.Consume(",")) {
str_util::RemoveLeadingWhitespace(&spec);
Expand All @@ -184,16 +202,19 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
break;
}
}
} else { // "{ int32, float, bool }"
} else { // "{ bool, numbertype, string }"
type = "type";
while (true) {
StringPiece type_string;
VERIFY(ConsumeAttrType(&spec, &type_string),
"Trouble parsing type string at '", spec, "'");
DataType dt;
VERIFY(DataTypeFromString(type_string, &dt),
"Unrecognized type string '", type_string, "'");
allowed->mutable_list()->add_type(dt);
if (ProcessCompoundType(type_string, allowed)) {
// Processed a compound type.
} else {
DataType dt;
VERIFY(DataTypeFromString(type_string, &dt),
"Unrecognized type string '", type_string, "'");
allowed->mutable_list()->add_type(dt);
}
if (spec.Consume(",")) {
str_util::RemoveLeadingWhitespace(&spec);
if (spec.Consume("}")) break; // Allow ending with ", }".
Expand All @@ -204,7 +225,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
}
}
}
} else {
} else { // if spec.Consume("{")
VERIFY(false, "Trouble parsing type string at '", spec, "'");
}
str_util::RemoveLeadingWhitespace(&spec);
Expand Down
4 changes: 3 additions & 1 deletion tensorflow/core/framework/op_def_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ class OpDefBuilder {
// (by convention only using capital letters for attrs that can be inferred)
// <type> can be:
// "string", "int", "float", "bool", "type", "shape", or "tensor"
// "numbertype", "realnumbertype", "quantizedtype", "{int32,int64}"
// "numbertype", "realnumbertype", "quantizedtype"
// (meaning "type" with a restriction on valid values)
// "{int32,int64}" or {realnumbertype,quantizedtype,string}"
// (meaning "type" with a restriction containing unions of value types)
// "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
// (meaning "string" with a restriction on valid values)
// "list(string)", ..., "list(tensor)", "list(numbertype)", ...
Expand Down
19 changes: 19 additions & 0 deletions tensorflow/core/framework/op_def_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,27 @@ TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
"DT_QINT32] } } }");
ExpectSuccess(
b().Attr("a:{numbertype, variant}"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
"DT_QINT32, DT_VARIANT] } } }");
ExpectSuccess(b().Attr("a:realnumbertype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
"DT_INT16, DT_UINT16, DT_INT8] } } }");
ExpectSuccess(b().Attr("a:{realnumbertype, variant , string, }"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
"DT_INT16, DT_UINT16, DT_INT8, DT_VARIANT, DT_STRING] } } }");
ExpectSuccess(b().Attr("a:quantizedtype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }");
ExpectSuccess(b().Attr("a:{quantizedtype ,string}"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16, "
"DT_STRING]} } }");
ExpectSuccess(b().Attr("a:{string,int32}"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_STRING, DT_INT32] } } }");
Expand Down Expand Up @@ -202,6 +216,11 @@ TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_HALF] } } }");
ExpectSuccess(
b().Attr("a:list({realnumbertype, variant})"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_HALF, DT_VARIANT] } } }");
ExpectSuccess(
b().Attr("a:list(quantizedtype)"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
Expand Down
16 changes: 16 additions & 0 deletions tensorflow/docs_src/extend/adding_an_op.md
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,22 @@ define an attr with constraints, you can use the following `<attr-type-expr>`s:
tf.number_type(t=tf.bool) # Invalid
```

Lists can be combined with other lists and single types. The following
op allows attr `t` to be any of the numberic types, or the bool type:

```c++
REGISTER_OP("NumberOrBooleanType")
.Attr("t: {numbertype, bool}");
```

For this op:

```python
tf.number_or_boolean_type(t=tf.int32) # Valid
tf.number_or_boolean_type(t=tf.bool) # Valid
tf.number_or_boolean_type(t=tf.string) # Invalid
```

* `int >= <n>`: The value must be an int whose value is greater than or equal to
`<n>`, where `<n>` is a natural number.

Expand Down

0 comments on commit 7565452

Please sign in to comment.