Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Java] Graph environment decoupling in preparation of eager execution #24858

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 6 additions & 6 deletions tensorflow/java/BUILD
Expand Up @@ -147,11 +147,11 @@ tf_java_test(
)

tf_java_test(
name = "OperationBuilderTest",
name = "GraphOperationBuilderTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/OperationBuilderTest.java"],
srcs = ["src/test/java/org/tensorflow/GraphOperationBuilderTest.java"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.OperationBuilderTest",
test_class = "org.tensorflow.GraphOperationBuilderTest",
deps = [
":tensorflow",
":testutil",
Expand All @@ -160,11 +160,11 @@ tf_java_test(
)

tf_java_test(
name = "OperationTest",
name = "GraphOperationTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/OperationTest.java"],
srcs = ["src/test/java/org/tensorflow/GraphOperationTest.java"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.OperationTest",
test_class = "org.tensorflow.GraphOperationTest",
deps = [
":tensorflow",
":testutil",
Expand Down
15 changes: 7 additions & 8 deletions tensorflow/java/src/gen/cc/op_generator.cc
Expand Up @@ -160,12 +160,11 @@ void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
}
Method factory = Method::Create("create", return_type);
Javadoc factory_doc = Javadoc::Create(
"Factory method to create a class to wrap a new " + op_class.name() +
" operation to the graph, using "
"default output types.");
"Factory method to create a class wrapping a new " + op_class.name() +
" operation using default output types.");
Variable scope =
Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
AddArgument(scope, "current graph scope", &factory, &factory_doc);
AddArgument(scope, "current scope", &factory, &factory_doc);
std::stringstream factory_statement;
factory_statement << "return create(scope";
for (const ArgumentSpec& input : op.inputs()) {
Expand Down Expand Up @@ -202,11 +201,11 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
SourceWriter* writer) {
Method factory = Method::Create("create", op_class);
Javadoc factory_doc =
Javadoc::Create("Factory method to create a class to wrap a new " +
op_class.name() + " operation to the graph.");
Javadoc::Create("Factory method to create a class wrapping a new " +
op_class.name() + " operation.");
Variable scope =
Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
AddArgument(scope, "current graph scope", &factory, &factory_doc);
AddArgument(scope, "current scope", &factory, &factory_doc);
for (const ArgumentSpec& input : op.inputs()) {
AddArgument(input.var(), input.description(), &factory, &factory_doc);
}
Expand All @@ -229,7 +228,7 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
factory_doc.add_tag("return", "a new instance of " + op_class.name());

writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" +
writer->Append("OperationBuilder opBuilder = scope.env().opBuilder(\"" +
op.graph_op_name() + "\", scope.makeOpName(\"" +
op_class.name() + "\"));");
writer->EndLine();
Expand Down
Expand Up @@ -15,18 +15,6 @@

package org.tensorflow.processor;

import com.google.common.base.CaseFormat;
import com.google.common.base.Strings;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterSpec;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.TypeVariableName;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
Expand All @@ -35,6 +23,7 @@
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.annotation.processing.AbstractProcessor;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Messager;
Expand All @@ -55,6 +44,19 @@
import javax.lang.model.util.Elements;
import javax.tools.Diagnostic.Kind;

import com.google.common.base.CaseFormat;
import com.google.common.base.Strings;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterSpec;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.TypeVariableName;

/**
* A compile-time Processor that aggregates classes annotated with {@link
* org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please
Expand Down Expand Up @@ -148,11 +150,12 @@ public Set<String> getSupportedAnnotationTypes() {

private static final Pattern JAVADOC_TAG_PATTERN =
Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*");
private static final TypeName T_OP = ClassName.get("org.tensorflow.op", "Op");
private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops");
private static final TypeName T_OPERATOR =
ClassName.get("org.tensorflow.op.annotation", "Operator");
private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope");
private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph");
private static final TypeName T_EXEC_ENV = ClassName.get("org.tensorflow", "ExecutionEnvironment");
private static final TypeName T_STRING = ClassName.get(String.class);

private Filer filer;
Expand Down Expand Up @@ -272,9 +275,9 @@ private MethodSpec buildOpMethod(
private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) {
StringBuilder javadoc = new StringBuilder();
javadoc
.append("Adds an {@link ")
.append("Builds an {@link ")
.append(opClassName.simpleName())
.append("} operation to the graph\n\n");
.append("} operation\n\n");

// Add all javadoc tags found in the operator factory method but the first one, which should be
// in all cases the
Expand Down Expand Up @@ -305,10 +308,10 @@ private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> met
TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops")
.addModifiers(Modifier.PUBLIC, Modifier.FINAL)
.addJavadoc(
"An API for adding {@code $L} operations to a {@link $T Graph}\n\n"
"An API for building {@code $L} operations as {@link $T Op}s\n\n"
+ "@see {@link $T}\n",
group,
T_GRAPH,
T_OP,
T_OPS)
.addMethods(methods)
.addMethod(ctorBuilder.build());
Expand All @@ -335,7 +338,7 @@ private static TypeSpec buildTopClass(
TypeSpec.classBuilder("Ops")
.addModifiers(Modifier.PUBLIC, Modifier.FINAL)
.addJavadoc(
"An API for building a {@link $T} with operation wrappers\n<p>\n"
"An API for building operations as {@link $T Op}s\n<p>\n"
+ "Any operation wrapper found in the classpath properly annotated as an"
+ "{@link $T @Operator} is exposed\n"
+ "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n"
Expand Down Expand Up @@ -363,7 +366,7 @@ private static TypeSpec buildTopClass(
+ " sub.withName(\"bar\").constant(4); // \"sub/bar\"\n"
+ "}\n"
+ "}</pre>\n",
T_GRAPH,
T_OP,
T_OPERATOR)
.addMethods(methods)
.addMethod(ctorBuilder.build());
Expand All @@ -375,7 +378,7 @@ private static TypeSpec buildTopClass(
.returns(T_OPS)
.addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS)
.addJavadoc(
"Returns an API that adds operations to the graph with the provided name prefix.\n"
"Returns an API that builds operations with the provided name prefix.\n"
+ "\n@see {@link $T#withSubScope(String)}\n",
T_SCOPE)
.build());
Expand Down Expand Up @@ -415,17 +418,17 @@ private static TypeSpec buildTopClass(
.returns(entry.getValue())
.addStatement("return $L", entry.getKey())
.addJavadoc(
"Returns an API for adding {@code $L} operations to the graph\n", entry.getKey())
"Returns an API for building {@code $L} operations\n", entry.getKey())
.build());
}

opsBuilder.addMethod(
MethodSpec.methodBuilder("create")
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
.addParameter(T_GRAPH, "graph")
.addParameter(T_EXEC_ENV, "env")
.returns(T_OPS)
.addStatement("return new Ops(new $T(graph))", T_SCOPE)
.addJavadoc("Creates an API for adding operations to the provided {@code graph}\n")
.addStatement("return new Ops(new $T(env))", T_SCOPE)
.addJavadoc("Creates an API for building operations in the provided environment\n")
.build());

return opsBuilder.build();
Expand Down
@@ -0,0 +1,61 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

/**
* Base class for {@link Operation} implementations.
*
* <p>As opposed to {@link Operation} itself, this class is package private and
* therefore its usage is limited to internal purposes only.
*/
abstract class AbstractOperation implements Operation {

@Override
public String toString() {
return String.format("<%s '%s'>", type(), name());
}

/**
* Returns the native handle of the {@code outputIdx}th output of this operation.
*
* <p>The nature of the returned value varies depending on current the execution environment.
* <ul>
* <li>In eager mode, the value is a handle to the tensor returned at this output.</li>
* <li>In graph mode, the value is a handle to the operation itself, which should be paired with
* the index of the output when calling the native layer.</li>
* </ul>
*
* @param outputIdx index of the output in this operation
* @return a native handle, see method description for more details
*/
abstract long getUnsafeNativeHandle(int outputIdx);

/**
* Returns the shape of the tensor of the {code outputIdx}th output of this operation.
*
* @param outputIdx index of the output of this operation
* @return output tensor shape
*/
abstract long[] shape(int outputIdx);

/**
* Returns the datatype of the tensor of the {code outputIdx}th output of this operation.
*
* @param outputIdx index of the output of this operation
* @return output tensor datatype
*/
abstract DataType dtype(int outputIdx);
}
@@ -0,0 +1,33 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

/**
* Defines an environment for creating and executing TensorFlow {@link Operation}s.
*/
public interface ExecutionEnvironment {

/**
* Returns a builder to create a new {@link Operation}.
*
* @param type of the Operation (i.e., identifies the computation to be performed)
* @param name to refer to the created Operation in this environment scope.
* @return an {@link OperationBuilder} to create an Operation when {@link
* OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked,
* then some resources may leak.
*/
OperationBuilder opBuilder(String type, String name);
}