Skip to content

Commit

Permalink
Merge pull request #21616 from karllessard:java-ops-default-type-attrs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 210615110
  • Loading branch information
tensorflower-gardener committed Aug 28, 2018
2 parents af94082 + a053d7b commit 85631dc
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 43 deletions.
30 changes: 30 additions & 0 deletions tensorflow/java/src/gen/cc/java_defs.h
Expand Up @@ -21,6 +21,8 @@ limitations under the License.
#include <string>
#include <utility>

#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace java {

Expand Down Expand Up @@ -95,6 +97,34 @@ class Type {
static Type IterableOf(const Type& type) {
return Interface("Iterable").add_parameter(type);
}
static Type ForDataType(DataType data_type) {
switch (data_type) {
case DataType::DT_BOOL:
return Class("Boolean");
case DataType::DT_STRING:
return Class("String");
case DataType::DT_FLOAT:
return Class("Float");
case DataType::DT_DOUBLE:
return Class("Double");
case DataType::DT_UINT8:
return Class("UInt8", "org.tensorflow.types");
case DataType::DT_INT32:
return Class("Integer");
case DataType::DT_INT64:
return Class("Long");
case DataType::DT_RESOURCE:
// TODO(karllessard) create a Resource utility class that could be
// used to store a resource and its type (passed in a second argument).
// For now, we need to force a wildcard and we will unfortunately lose
// track of the resource type.
// Falling through...
default:
// Any other datatypes does not have a equivalent in Java and must
// remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
return Wildcard();
}
}
const Kind& kind() const { return kind_; }
const string& name() const { return name_; }
const string& package() const { return package_; }
Expand Down
74 changes: 74 additions & 0 deletions tensorflow/java/src/gen/cc/op_generator.cc
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_gen_lib.h"
Expand Down Expand Up @@ -100,6 +101,10 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
for (const AttributeSpec& attribute : op.attributes()) {
out->push_back(attribute.var().type());
out->push_back(attribute.jni_type());
if (attribute.has_default_value() &&
attribute.type().kind() == Type::GENERIC) {
out->push_back(Type::ForDataType(attribute.default_value()->type()));
}
}
for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
out->push_back(optional_attribute.var().type());
Expand Down Expand Up @@ -139,6 +144,60 @@ void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
}
}

void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
std::map<string, Type> default_types,
SourceWriter* writer) {
// Build the return type for the secondary factory, replacing generic
// parameters with their default value if any
Type return_type = Type::Class(op_class.name(), op_class.package());
for (const Type& parameter : op_class.parameters()) {
if (parameter.kind() == Type::GENERIC &&
default_types.find(parameter.name()) != default_types.end()) {
return_type.add_parameter(default_types.at(parameter.name()));
} else {
return_type.add_parameter(parameter);
}
}
Method factory = Method::Create("create", return_type);
Javadoc factory_doc = Javadoc::Create(
"Factory method to create a class to wrap a new " + op_class.name() +
" operation to the graph, using "
"default output types.");
Variable scope =
Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
AddArgument(scope, "current graph scope", &factory, &factory_doc);
std::stringstream factory_statement;
factory_statement << "return create(scope";
for (const ArgumentSpec& input : op.inputs()) {
AddArgument(input.var(), input.description(), &factory, &factory_doc);
factory_statement << ", " << input.var().name();
}
for (const AttributeSpec& attr : op.attributes()) {
// Only add attributes that are not types or have no default value to the
// signature of the secondary factory
factory_statement << ", ";
if (attr.type().kind() == Type::GENERIC &&
default_types.find(attr.type().name()) != default_types.end()) {
factory_statement << default_types.at(attr.type().name()).name()
<< ".class";
} else {
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
factory_statement << attr.var().name();
}
}
if (!op.optional_attributes().empty()) {
Variable options_var = Variable::Varargs("options", Type::Class("Options"));
AddArgument(options_var, "carries optional attributes values", &factory,
&factory_doc);
factory_statement << ", " << options_var.name();
}
factory_doc.add_tag("return", "a new instance of " + op_class.name());

writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
writer->Append(factory_statement.str().c_str()).Append(");").EndLine();
writer->EndMethod();
}

void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
SourceWriter* writer) {
Method factory = Method::Create("create", op_class);
Expand All @@ -151,8 +210,17 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
for (const ArgumentSpec& input : op.inputs()) {
AddArgument(input.var(), input.description(), &factory, &factory_doc);
}
std::map<string, Type> default_types;
for (const AttributeSpec& attr : op.attributes()) {
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
// If this attribute is a type with a default value, save its value
// for passing it implicitly in a secondary factory method
if (attr.has_default_value() && attr.type().kind() == Type::GENERIC) {
Type default_type = Type::ForDataType(attr.default_value()->type());
if (!default_type.wildcard()) {
default_types.insert(std::make_pair(attr.type().name(), default_type));
}
}
}
if (!op.optional_attributes().empty()) {
AddArgument(Variable::Varargs("options", Type::Class("Options")),
Expand Down Expand Up @@ -194,6 +262,12 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
.Append("(opBuilder.build());")
.EndLine();
writer->EndMethod();

// If this operation has type attributes with a default value, create a
// second factory method that infers those values implicitly
if (!default_types.empty()) {
RenderSecondaryFactoryMethod(op, op_class, default_types, writer);
}
}

void RenderConstructor(const OpSpec& op, const Type& op_class,
Expand Down
42 changes: 5 additions & 37 deletions tensorflow/java/src/gen/cc/op_specs.cc
Expand Up @@ -96,43 +96,10 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
*iterable_out = true;
visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
}

Type type = Type::Wildcard();
if (arg_def.type() != DataType::DT_INVALID) {
// resolve type from DataType
switch (arg_def.type()) {
case DataType::DT_BOOL:
type = Type::Class("Boolean");
break;
case DataType::DT_STRING:
type = Type::Class("String");
break;
case DataType::DT_FLOAT:
type = Type::Class("Float");
break;
case DataType::DT_DOUBLE:
type = Type::Class("Double");
break;
case DataType::DT_UINT8:
type = Type::Class("UInt8", "org.tensorflow.types");
break;
case DataType::DT_INT32:
type = Type::Class("Integer");
break;
case DataType::DT_INT64:
type = Type::Class("Long");
break;
case DataType::DT_RESOURCE:
// TODO(karllessard) create a Resource utility class that could be
// used to store a resource and its type (passed in a second argument).
// For now, we need to force a wildcard and we will unfortunately lose
// track of the resource type.
break;
default:
// Any other datatypes does not have a equivalent in Java and must
// remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
break;
}
type = Type::ForDataType(arg_def.type());

} else if (!arg_def.type_attr().empty()) {
// resolve type from attribute (if already visited, retrieve its type)
if (IsAttributeVisited(arg_def.type_attr())) {
Expand Down Expand Up @@ -337,7 +304,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
bool iterable = false;
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
Type var_type = types.first.kind() == Type::GENERIC
? Type::Class("Class").add_parameter(types.first)
? Type::ClassOf(types.first)
: types.first;
if (iterable) {
var_type = Type::ListOf(var_type);
Expand All @@ -346,7 +313,8 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
attr_api_def.name(),
Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
types.first, types.second, ParseDocumentation(attr_api_def.description()),
iterable, attr_api_def.has_default_value());
iterable,
attr_def.has_default_value() ? &attr_def.default_value() : nullptr);
}

ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
Expand Down
14 changes: 9 additions & 5 deletions tensorflow/java/src/gen/cc/op_specs.h
Expand Up @@ -94,26 +94,30 @@ class AttributeSpec {
// jni_type: the type of this attribute in JNI layer (see OperationBuilder)
// description: a description of this attribute, in javadoc
// iterable: true if this attribute is a list
// has_default_value: true if this attribute has a default value if not set
// default_value: default value for this attribute or nullptr if none. Any
// value referenced by this pointer must outlive the lifetime
// of the AttributeSpec. This is guaranteed if the value is
// issued by an OpDef of the global OpRegistry.
AttributeSpec(const string& op_def_name, const Variable& var,
const Type& type, const Type& jni_type,
const string& description, bool iterable,
bool has_default_value)
const AttrValue* default_value)
: op_def_name_(op_def_name),
var_(var),
type_(type),
description_(description),
iterable_(iterable),
jni_type_(jni_type),
has_default_value_(has_default_value) {}
default_value_(default_value) {}

const string& op_def_name() const { return op_def_name_; }
const Variable& var() const { return var_; }
const Type& type() const { return type_; }
const string& description() const { return description_; }
bool iterable() const { return iterable_; }
const Type& jni_type() const { return jni_type_; }
bool has_default_value() const { return has_default_value_; }
bool has_default_value() const { return default_value_ != nullptr; }
const AttrValue* default_value() const { return default_value_; }

private:
const string op_def_name_;
Expand All @@ -122,7 +126,7 @@ class AttributeSpec {
const string description_;
const bool iterable_;
const Type jni_type_;
const bool has_default_value_;
const AttrValue* default_value_;
};

class OpSpec {
Expand Down
1 change: 0 additions & 1 deletion tensorflow/java/src/gen/cc/source_writer.cc
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#include <string>
#include <algorithm>
#include <list>
#include <string>

#include "tensorflow/java/src/gen/cc/source_writer.h"

Expand Down

0 comments on commit 85631dc

Please sign in to comment.