Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Java] Render secondary factory for default output types #21616

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 30 additions & 0 deletions tensorflow/java/src/gen/cc/java_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ limitations under the License.
#include <string>
#include <utility>

#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace java {

Expand Down Expand Up @@ -95,6 +97,34 @@ class Type {
static Type IterableOf(const Type& type) {
return Interface("Iterable").add_parameter(type);
}
static Type ForDataType(DataType data_type) {
switch (data_type) {
case DataType::DT_BOOL:
return Class("Boolean");
case DataType::DT_STRING:
return Class("String");
case DataType::DT_FLOAT:
return Class("Float");
case DataType::DT_DOUBLE:
return Class("Double");
case DataType::DT_UINT8:
return Class("UInt8", "org.tensorflow.types");
case DataType::DT_INT32:
return Class("Integer");
case DataType::DT_INT64:
return Class("Long");
case DataType::DT_RESOURCE:
// TODO(karllessard) create a Resource utility class that could be
// used to store a resource and its type (passed in a second argument).
// For now, we need to force a wildcard and we will unfortunately lose
// track of the resource type.
// Falling through...
default:
// Any other datatypes does not have a equivalent in Java and must
// remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
return Wildcard();
}
}
const Kind& kind() const { return kind_; }
const string& name() const { return name_; }
const string& package() const { return package_; }
Expand Down
74 changes: 74 additions & 0 deletions tensorflow/java/src/gen/cc/op_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <set>
#include <string>
#include <vector>
#include <utility>

#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/errors.h"
Expand Down Expand Up @@ -100,6 +101,10 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
for (const AttributeSpec& attribute : op.attributes()) {
out->push_back(attribute.var().type());
out->push_back(attribute.jni_type());
if (attribute.has_default_value()
&& attribute.type().kind() == Type::GENERIC) {
out->push_back(Type::ForDataType(attribute.default_value()->type()));
}
}
for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
out->push_back(optional_attribute.var().type());
Expand Down Expand Up @@ -139,6 +144,60 @@ void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
}
}

void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
std::map<string, Type> default_types,
SourceWriter* writer) {
// Build the return type for the secondary factory, replacing generic
// parameters with their default value if any
Type return_type = Type::Class(op_class.name(), op_class.package());
for (const Type& parameter : op_class.parameters()) {
if (parameter.kind() == Type::GENERIC
&& default_types.find(parameter.name()) != default_types.end()) {
return_type.add_parameter(default_types.at(parameter.name()));
} else {
return_type.add_parameter(parameter);
}
}
Method factory = Method::Create("create", return_type);
Javadoc factory_doc =
Javadoc::Create("Factory method to create a class to wrap a new " +
op_class.name() + " operation to the graph, using "
"default output types.");
Variable scope =
Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
AddArgument(scope, "current graph scope", &factory, &factory_doc);
std::stringstream factory_statement;
factory_statement << "return create(scope";
for (const ArgumentSpec& input : op.inputs()) {
AddArgument(input.var(), input.description(), &factory, &factory_doc);
factory_statement << ", " << input.var().name();
}
for (const AttributeSpec& attr : op.attributes()) {
// Only add attributes that are not types or have no default value to the
// signature of the secondary factory
factory_statement << ", ";
if (attr.type().kind() == Type::GENERIC
&& default_types.find(attr.type().name()) != default_types.end()) {
factory_statement << default_types.at(attr.type().name()).name()
<< ".class";
} else {
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
factory_statement << attr.var().name();
}
}
if (!op.optional_attributes().empty()) {
Variable options_var = Variable::Varargs("options", Type::Class("Options"));
AddArgument(options_var, "carries optional attributes values", &factory,
&factory_doc);
factory_statement << ", " << options_var.name();
}
factory_doc.add_tag("return", "a new instance of " + op_class.name());

writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
writer->Append(factory_statement.str()).Append(");").EndLine();
writer->EndMethod();
}

void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
SourceWriter* writer) {
Method factory = Method::Create("create", op_class);
Expand All @@ -151,8 +210,17 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
for (const ArgumentSpec& input : op.inputs()) {
AddArgument(input.var(), input.description(), &factory, &factory_doc);
}
std::map<string, Type> default_types;
for (const AttributeSpec& attr : op.attributes()) {
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
// If this attribute is a type with a default value, save its value
// for passing it implicitly in a secondary factory method
if (attr.has_default_value() && attr.type().kind() == Type::GENERIC) {
Type default_type = Type::ForDataType(attr.default_value()->type());
if (!default_type.wildcard()) {
default_types.insert(std::make_pair(attr.type().name(), default_type));
}
}
}
if (!op.optional_attributes().empty()) {
AddArgument(Variable::Varargs("options", Type::Class("Options")),
Expand Down Expand Up @@ -194,6 +262,12 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
.Append("(opBuilder.build());")
.EndLine();
writer->EndMethod();

// If this operation has type attributes with a default value, create a
// second factory method that infers those values implicitly
if (!default_types.empty()) {
RenderSecondaryFactoryMethod(op, op_class, default_types, writer);
}
}

void RenderConstructor(const OpSpec& op, const Type& op_class,
Expand Down
45 changes: 6 additions & 39 deletions tensorflow/java/src/gen/cc/op_specs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,43 +96,10 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
*iterable_out = true;
visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
}

Type type = Type::Wildcard();
if (arg_def.type() != DataType::DT_INVALID) {
// resolve type from DataType
switch (arg_def.type()) {
case DataType::DT_BOOL:
type = Type::Class("Boolean");
break;
case DataType::DT_STRING:
type = Type::Class("String");
break;
case DataType::DT_FLOAT:
type = Type::Class("Float");
break;
case DataType::DT_DOUBLE:
type = Type::Class("Double");
break;
case DataType::DT_UINT8:
type = Type::Class("UInt8", "org.tensorflow.types");
break;
case DataType::DT_INT32:
type = Type::Class("Integer");
break;
case DataType::DT_INT64:
type = Type::Class("Long");
break;
case DataType::DT_RESOURCE:
// TODO(karllessard) create a Resource utility class that could be
// used to store a resource and its type (passed in a second argument).
// For now, we need to force a wildcard and we will unfortunately lose
// track of the resource type.
break;
default:
// Any other datatypes does not have a equivalent in Java and must
// remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
break;
}
type = Type::ForDataType(arg_def.type());

} else if (!arg_def.type_attr().empty()) {
// resolve type from attribute (if already visited, retrieve its type)
if (IsAttributeVisited(arg_def.type_attr())) {
Expand Down Expand Up @@ -337,16 +304,16 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
bool iterable = false;
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
Type var_type = types.first.kind() == Type::GENERIC
? Type::Class("Class").add_parameter(types.first)
: types.first;
? Type::ClassOf(types.first) : types.first;
if (iterable) {
var_type = Type::ListOf(var_type);
}
return AttributeSpec(
attr_api_def.name(),
Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
types.first, types.second, ParseDocumentation(attr_api_def.description()),
iterable, attr_api_def.has_default_value());
types.first, types.second,
ParseDocumentation(attr_api_def.description()), iterable,
attr_def.has_default_value() ? &attr_def.default_value() : nullptr);
}

ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
Expand Down
14 changes: 9 additions & 5 deletions tensorflow/java/src/gen/cc/op_specs.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,26 +94,30 @@ class AttributeSpec {
// jni_type: the type of this attribute in JNI layer (see OperationBuilder)
// description: a description of this attribute, in javadoc
// iterable: true if this attribute is a list
// has_default_value: true if this attribute has a default value if not set
// default_value: default value for this attribute or nullptr if none. Any
// value referenced by this pointer must outlive the lifetime
// of the AttributeSpec. This is guaranteed if the value is
// issued by an OpDef of the global OpRegistry.
AttributeSpec(const string& op_def_name, const Variable& var,
const Type& type, const Type& jni_type,
const string& description, bool iterable,
bool has_default_value)
const AttrValue* default_value)
: op_def_name_(op_def_name),
var_(var),
type_(type),
description_(description),
iterable_(iterable),
jni_type_(jni_type),
has_default_value_(has_default_value) {}
default_value_(default_value) {}

const string& op_def_name() const { return op_def_name_; }
const Variable& var() const { return var_; }
const Type& type() const { return type_; }
const string& description() const { return description_; }
bool iterable() const { return iterable_; }
const Type& jni_type() const { return jni_type_; }
bool has_default_value() const { return has_default_value_; }
bool has_default_value() const { return default_value_ != nullptr; }
const AttrValue* default_value() const { return default_value_; }

private:
const string op_def_name_;
Expand All @@ -122,7 +126,7 @@ class AttributeSpec {
const string description_;
const bool iterable_;
const Type jni_type_;
const bool has_default_value_;
const AttrValue* default_value_;
};

class OpSpec {
Expand Down
1 change: 0 additions & 1 deletion tensorflow/java/src/gen/cc/source_writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#include <string>
#include <algorithm>
#include <list>
#include <string>

#include "tensorflow/java/src/gen/cc/source_writer.h"

Expand Down