Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 11 additions & 24 deletions tensorflow/java/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -70,52 +70,38 @@ filegroup(

tf_java_op_gen_srcjar(
name = "java_op_gen_sources",
gen_base_package = "org.tensorflow.op",
gen_tool = "java_op_gen_tool",
ops_libs = [
"array_ops",
"candidate_sampling_ops",
"control_flow_ops",
"data_flow_ops",
"image_ops",
"io_ops",
"linalg_ops",
"logging_ops",
"math_ops",
"nn_ops",
"no_op",
"parsing_ops",
"random_ops",
"sparse_ops",
"state_ops",
"string_ops",
"training_ops",
"user_ops",
api_def_srcs = [
"//tensorflow/core/api_def:base_api_def",
],
base_package = "org.tensorflow.op",
gen_tool = ":java_op_gen_tool",
)

# Build the gen tool as a library, as it will be linked to a core/ops binary
# file before making it an executable. See tf_java_op_gen_srcjar().
cc_library(
tf_cc_binary(
name = "java_op_gen_tool",
srcs = [
"src/gen/cc/op_gen_main.cc",
],
copts = tf_copts(),
linkopts = ["-lm"],
linkstatic = 1,
deps = [
":java_op_gen_lib",
"//tensorflow/core:ops",
],
)

cc_library(
name = "java_op_gen_lib",
srcs = [
"src/gen/cc/op_generator.cc",
"src/gen/cc/op_specs.cc",
"src/gen/cc/source_writer.cc",
],
hdrs = [
"src/gen/cc/java_defs.h",
"src/gen/cc/op_generator.h",
"src/gen/cc/op_specs.h",
"src/gen/cc/source_writer.h",
],
copts = tf_copts(),
Expand All @@ -124,6 +110,7 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:op_gen_lib",
],
)

Expand Down
1 change: 1 addition & 0 deletions tensorflow/java/build_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ JAVA_VERSION_OPTS = [
XLINT_OPTS = [
"-Werror",
"-Xlint:all",
"-Xlint:-processing",
"-Xlint:-serial",
"-Xlint:-try",
"-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile
Expand Down
76 changes: 42 additions & 34 deletions tensorflow/java/src/gen/cc/java_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ limitations under the License.

#include <string>
#include <list>
#include <map>
#include <utility>

namespace tensorflow {
namespace java {

// An enumeration of different modifiers commonly used in Java
enum Modifier {
PACKAGE = 0,
PUBLIC = (1 << 0),
PROTECTED = (1 << 1),
PRIVATE = (1 << 2),
Expand Down Expand Up @@ -72,6 +75,12 @@ class Type {
// Reflection API does
return Type(Type::PRIMITIVE, "void");
}
static Type Generic(const string& name) {
return Type(Type::GENERIC, name);
}
static Type Wildcard() {
return Type(Type::GENERIC, "");
}
static Type Class(const string& name, const string& package = "") {
return Type(Type::CLASS, name, package);
}
Expand All @@ -81,9 +90,6 @@ class Type {
static Type Enum(const string& name, const string& package = "") {
return Type(Type::ENUM, name, package);
}
static Type Generic(const string& name = "") {
return Type(Type::GENERIC, name);
}
static Type ClassOf(const Type& type) {
return Class("Class").add_parameter(type);
}
Expand All @@ -96,11 +102,10 @@ class Type {
const Kind& kind() const { return kind_; }
const string& name() const { return name_; }
const string& package() const { return package_; }
const string& description() const { return description_; }
Type& description(const string& description) {
description_ = description;
return *this;
const string canonical_name() const {
return package_.empty() ? name_ : package_ + "." + name_;
}
bool wildcard() const { return name_.empty(); } // only wildcards has no name
const std::list<Type>& parameters() const { return parameters_; }
Type& add_parameter(const Type& parameter) {
parameters_.push_back(parameter);
Expand All @@ -120,14 +125,6 @@ class Type {
}
return *this;
}
// Returns true if "type" is of a known collection type (only a few for now)
bool IsCollection() const {
return name_ == "List" || name_ == "Iterable";
}
// Returns true if this instance is a wildcard (<?>)
bool IsWildcard() const {
return kind_ == GENERIC && name_.empty();
}

protected:
Type(Kind kind, const string& name, const string& package = "")
Expand All @@ -137,7 +134,6 @@ class Type {
Kind kind_;
string name_;
string package_;
string description_;
std::list<Type> parameters_;
std::list<Annotation> annotations_;
std::list<Type> supertypes_;
Expand Down Expand Up @@ -180,16 +176,11 @@ class Variable {
const string& name() const { return name_; }
const Type& type() const { return type_; }
bool variadic() const { return variadic_; }
const string& description() const { return description_; }
Variable& description(const string& description) {
description_ = description;
return *this;
}

private:
string name_;
Type type_;
bool variadic_;
string description_;

Variable(const string& name, const Type& type, bool variadic)
: name_(name), type_(type), variadic_(variadic) {}
Expand All @@ -210,16 +201,6 @@ class Method {
bool constructor() const { return constructor_; }
const string& name() const { return name_; }
const Type& return_type() const { return return_type_; }
const string& description() const { return description_; }
Method& description(const string& description) {
description_ = description;
return *this;
}
const string& return_description() const { return return_description_; }
Method& return_description(const string& description) {
return_description_ = description;
return *this;
}
const std::list<Variable>& arguments() const { return arguments_; }
Method& add_argument(const Variable& var) {
arguments_.push_back(var);
Expand All @@ -235,15 +216,42 @@ class Method {
string name_;
Type return_type_;
bool constructor_;
string description_;
string return_description_;
std::list<Variable> arguments_;
std::list<Annotation> annotations_;

Method(const string& name, const Type& return_type, bool constructor)
: name_(name), return_type_(return_type), constructor_(constructor) {}
};

// A definition of a documentation bloc for a Java element (JavaDoc)
class Javadoc {
public:
static Javadoc Create(const string& brief = "") {
return Javadoc(brief);
}
const string& brief() const { return brief_; }
const string& details() const { return details_; }
Javadoc& details(const string& details) {
details_ = details;
return *this;
}
const std::list<std::pair<string, string>>& tags() const { return tags_; }
Javadoc& add_tag(const string& tag, const string& text) {
tags_.push_back(std::make_pair(tag, text));
return *this;
}
Javadoc& add_param_tag(const string& name, const string& text) {
return add_tag("param", name + " " + text);
}

private:
string brief_;
string details_;
std::list<std::pair<string, string>> tags_;

explicit Javadoc(const string& brief) : brief_(brief) {}
};

} // namespace java
} // namespace tensorflow

Expand Down
50 changes: 21 additions & 29 deletions tensorflow/java/src/gen/cc/op_gen_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,49 +36,41 @@ const char kUsageHeader[] =
"Operation wrappers are generated under the path specified by the "
"'--output_dir' argument. This path can be absolute or relative to the\n"
"current working directory and will be created if it does not exists.\n\n"
"The '--lib_name' argument is used to classify the set of operations. If "
"the chosen name contains more than one word, it must be provided in \n"
"snake_case. This value is declined into other meaningful names, such as "
"the group and package of the generated operations. For example,\n"
"'--lib_name=my_lib' generates the operations under the "
"'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n"
"group.\n\n"
"Note that the operator group assigned to the generated wrappers is just "
"an annotation tag at this stage. Operations will not be available "
"through\n"
"the 'org.tensorflow.op.Ops' API as a group until the generated classes "
"are compiled using an appropriate annotation processor.\n\n"
"Finally, the '--base_package' overrides the default parent package "
"under which the generated subpackage and classes are to be located.\n\n";
"Note that the operations will not be available through the "
"'org.tensorflow.op.Ops' API until the generated classes are compiled\n"
"using an appropriate annotation processor.\n\n"
"The '--base_package' overrides the default parent package under which "
"the generated subpackage and classes are to be located.\n\n"
"Finally, the `--api_dirs` argument takes a list of comma-seperated "
"directories of API definitions can be provided to override default\n"
"values found in the ops definitions. Directories are ordered by priority "
"(the last having precedence over the first).\n\n";

} // namespace java
} // namespace tensorflow

int main(int argc, char* argv[]) {
tensorflow::string lib_name;
tensorflow::string output_dir;
tensorflow::string base_package = "org.tensorflow.op";
tensorflow::string api_dirs_str;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("output_dir", &output_dir,
"Root directory into which output files are generated"),
tensorflow::Flag(
"lib_name", &lib_name,
"A name, in snake_case, used to classify this set of operations"),
tensorflow::Flag(
"base_package", &base_package,
"Package parent to the generated subpackage and classes")};
"Root directory into which output files are generated"),
tensorflow::Flag("base_package", &base_package,
"Package parent to the generated subpackage and classes"),
tensorflow::Flag("api_dirs", &api_dirs_str,
"List of directories that contains the ops api definitions")};
tensorflow::string usage = tensorflow::java::kUsageHeader;
usage += tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage;

tensorflow::java::OpGenerator generator;
QCHECK(parsed_flags_ok && !output_dir.empty()) << usage;
std::vector<tensorflow::string> api_dirs = tensorflow::str_util::Split(
api_dirs_str, ",", tensorflow::str_util::SkipEmpty());
tensorflow::java::OpGenerator generator(api_dirs);
tensorflow::OpList ops;
tensorflow::OpRegistry::Global()->Export(true, &ops);
tensorflow::Status status =
generator.Run(ops, lib_name, base_package, output_dir);
TF_QCHECK_OK(status);
tensorflow::OpRegistry::Global()->Export(false, &ops);
TF_CHECK_OK(generator.Run(ops, base_package, output_dir));

return 0;
}
Loading