Skip to content

Commit

Permalink
Decouple graph operations from their implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
karllessard committed Jan 8, 2019
1 parent 9b884d8 commit 0c3ab56
Show file tree
Hide file tree
Showing 22 changed files with 1,084 additions and 727 deletions.
12 changes: 6 additions & 6 deletions tensorflow/java/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ tf_java_test(
)

tf_java_test(
name = "OperationBuilderTest",
name = "GraphNodeBuilderTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/OperationBuilderTest.java"],
srcs = ["src/test/java/org/tensorflow/GraphNodeBuilderTest.java"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.OperationBuilderTest",
test_class = "org.tensorflow.GraphNodeBuilderTest",
deps = [
":tensorflow",
":testutil",
Expand All @@ -160,11 +160,11 @@ tf_java_test(
)

tf_java_test(
name = "OperationTest",
name = "GraphNodeTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/OperationTest.java"],
srcs = ["src/test/java/org/tensorflow/GraphNodeTest.java"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.OperationTest",
test_class = "org.tensorflow.GraphNodeTest",
deps = [
":tensorflow",
":testutil",
Expand Down
15 changes: 7 additions & 8 deletions tensorflow/java/src/gen/cc/op_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,11 @@ void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
}
Method factory = Method::Create("create", return_type);
Javadoc factory_doc = Javadoc::Create(
"Factory method to create a class to wrap a new " + op_class.name() +
" operation to the graph, using "
"default output types.");
"Factory method to create a class wrapping a new " + op_class.name() +
" operation using default output types.");
Variable scope =
Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
AddArgument(scope, "current graph scope", &factory, &factory_doc);
AddArgument(scope, "current scope", &factory, &factory_doc);
std::stringstream factory_statement;
factory_statement << "return create(scope";
for (const ArgumentSpec& input : op.inputs()) {
Expand Down Expand Up @@ -202,11 +201,11 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
SourceWriter* writer) {
Method factory = Method::Create("create", op_class);
Javadoc factory_doc =
Javadoc::Create("Factory method to create a class to wrap a new " +
op_class.name() + " operation to the graph.");
Javadoc::Create("Factory method to create a class wrapping a new " +
op_class.name() + " operation.");
Variable scope =
Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
AddArgument(scope, "current graph scope", &factory, &factory_doc);
AddArgument(scope, "current scope", &factory, &factory_doc);
for (const ArgumentSpec& input : op.inputs()) {
AddArgument(input.var(), input.description(), &factory, &factory_doc);
}
Expand All @@ -229,7 +228,7 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
factory_doc.add_tag("return", "a new instance of " + op_class.name());

writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" +
writer->Append("OperationBuilder opBuilder = scope.env().opBuilder(\"" +
op.graph_op_name() + "\", scope.makeOpName(\"" +
op_class.name() + "\"));");
writer->EndLine();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,6 @@

package org.tensorflow.processor;

import com.google.common.base.CaseFormat;
import com.google.common.base.Strings;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterSpec;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.TypeVariableName;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
Expand All @@ -35,6 +23,7 @@
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.annotation.processing.AbstractProcessor;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Messager;
Expand All @@ -55,6 +44,19 @@
import javax.lang.model.util.Elements;
import javax.tools.Diagnostic.Kind;

import com.google.common.base.CaseFormat;
import com.google.common.base.Strings;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterSpec;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.TypeVariableName;

/**
* A compile-time Processor that aggregates classes annotated with {@link
* org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please
Expand Down Expand Up @@ -148,11 +150,12 @@ public Set<String> getSupportedAnnotationTypes() {

private static final Pattern JAVADOC_TAG_PATTERN =
Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*");
private static final TypeName T_OP = ClassName.get("org.tensorflow.op", "Op");
private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops");
private static final TypeName T_OPERATOR =
ClassName.get("org.tensorflow.op.annotation", "Operator");
private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope");
private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph");
private static final TypeName T_EXEC_ENV = ClassName.get("org.tensorflow", "ExecutionEnvironment");
private static final TypeName T_STRING = ClassName.get(String.class);

private Filer filer;
Expand Down Expand Up @@ -272,9 +275,9 @@ private MethodSpec buildOpMethod(
private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) {
StringBuilder javadoc = new StringBuilder();
javadoc
.append("Adds an {@link ")
.append("Builds an {@link ")
.append(opClassName.simpleName())
.append("} operation to the graph\n\n");
.append("} operation\n\n");

// Add all javadoc tags found in the operator factory method but the first one, which should be
// in all cases the
Expand Down Expand Up @@ -305,10 +308,10 @@ private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> met
TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops")
.addModifiers(Modifier.PUBLIC, Modifier.FINAL)
.addJavadoc(
"An API for adding {@code $L} operations to a {@link $T Graph}\n\n"
"An API for building {@code $L} operations as {@link $T Op}s\n\n"
+ "@see {@link $T}\n",
group,
T_GRAPH,
T_OP,
T_OPS)
.addMethods(methods)
.addMethod(ctorBuilder.build());
Expand All @@ -335,7 +338,7 @@ private static TypeSpec buildTopClass(
TypeSpec.classBuilder("Ops")
.addModifiers(Modifier.PUBLIC, Modifier.FINAL)
.addJavadoc(
"An API for building a {@link $T} with operation wrappers\n<p>\n"
"An API for building operations as {@link $T Op}s\n<p>\n"
+ "Any operation wrapper found in the classpath properly annotated as an"
+ "{@link $T @Operator} is exposed\n"
+ "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n"
Expand Down Expand Up @@ -363,7 +366,7 @@ private static TypeSpec buildTopClass(
+ " sub.withName(“bar”).constant(4); // “sub/bar”\n"
+ "}\n"
+ "}</pre>\n",
T_GRAPH,
T_OP,
T_OPERATOR)
.addMethods(methods)
.addMethod(ctorBuilder.build());
Expand All @@ -375,7 +378,7 @@ private static TypeSpec buildTopClass(
.returns(T_OPS)
.addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS)
.addJavadoc(
"Returns an API that adds operations to the graph with the provided name prefix.\n"
"Returns an API that builds operations with the provided name prefix.\n"
+ "\n@see {@link $T#withSubScope(String)}\n",
T_SCOPE)
.build());
Expand Down Expand Up @@ -415,17 +418,17 @@ private static TypeSpec buildTopClass(
.returns(entry.getValue())
.addStatement("return $L", entry.getKey())
.addJavadoc(
"Returns an API for adding {@code $L} operations to the graph\n", entry.getKey())
"Returns an API for building {@code $L} operations\n", entry.getKey())
.build());
}

opsBuilder.addMethod(
MethodSpec.methodBuilder("create")
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
.addParameter(T_GRAPH, "graph")
.addParameter(T_EXEC_ENV, "env")
.returns(T_OPS)
.addStatement("return new Ops(new $T(graph))", T_SCOPE)
.addJavadoc("Creates an API for adding operations to the provided {@code graph}\n")
.addStatement("return new Ops(new $T(env))", T_SCOPE)
.addJavadoc("Creates an API for building operations in the provided environment\n")
.build());

return opsBuilder.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

/**
* Base class for {@link Operation} implementations.
*
* <p>As opposed to {@link Operation} itself, this class is package private and
* therefore its usage is limited to internal purposes only.
*/
abstract class AbstractOperation implements Operation {

AbstractOperation(long unsafeNativeHandle) {
this.unsafeNativeHandle = unsafeNativeHandle;
}

@Override
public String toString() {
return String.format("<%s '%s'>", type(), name());
}

/**
* Returns the native handle of the {@code outputIdx}th output of this operation.
*
* <p>The nature of the returned value varies depending on current the execution environment.
* <ul>
* <li>In eager mode, the value is a handle to the tensor returned at this output.</li>
* <li>In graph mode, the value is a handle to the operation itself, which should be paired with
* the index of the output when calling the native layer.</li>
* </ul>
*
* @param outputIdx index of the output in this operation
* @return a native handle, see method description for more details
*/
abstract long getUnsafeNativeHandle(int outputIdx);

/**
* Returns the shape of the tensor of the {code outputIdx}th output of this operation.
*
* @param outputIdx index of the output of this operation
* @return output tensor shape
*/
abstract long[] shape(int outputIdx);

/**
* Returns the datatype of the tensor of the {code outputIdx}th output of this operation.
*
* @param outputIdx index of the output of this operation
* @return output tensor datatype
*/
abstract DataType dtype(int outputIdx);

long getUnsafeNativeHandle() {
return unsafeNativeHandle;
}

private final long unsafeNativeHandle;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

/**
* Defines an environment for creating and executing TensorFlow {@link Operation}s.
*/
public interface ExecutionEnvironment {

/**
* Returns a builder to create a new {@link Operation}.
*
* @param type of the Operation (i.e., identifies the computation to be performed)
* @param name to refer to the created Operation in this environment scope.
* @return an {@link OperationBuilder} to create an Operation when {@link
* OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked,
* then some resources may leak.
*/
OperationBuilder opBuilder(String type, String name);
}
21 changes: 11 additions & 10 deletions tensorflow/java/src/main/java/org/tensorflow/Graph.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
* <p><b>WARNING:</b> Resources consumed by the Graph object must be explicitly freed by invoking
* the {@link #close()} method then the Graph object is no longer needed.
*/
public final class Graph implements AutoCloseable {
public final class Graph implements ExecutionEnvironment, AutoCloseable {

/** Create an empty Graph. */
public Graph() {
Expand Down Expand Up @@ -68,13 +68,13 @@ public void close() {
*
* <p>Or {@code null} if no such operation exists in the Graph.
*/
public Operation operation(String name) {
public GraphNode operation(String name) {
synchronized (nativeHandleLock) {
long oph = operation(nativeHandle, name);
if (oph == 0) {
return null;
}
return new Operation(this, oph);
return new GraphNode(this, oph);
}
}

Expand All @@ -97,8 +97,9 @@ public Iterator<Operation> operations() {
* OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked,
* then some resources may leak.
*/
public OperationBuilder opBuilder(String type, String name) {
return new OperationBuilder(this, type, name);
@Override
public GraphNodeBuilder opBuilder(String type, String name) {
return new GraphNodeBuilder(this, type, name);
}

/**
Expand Down Expand Up @@ -177,19 +178,19 @@ public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Out

try (Reference ref = ref()) {
for (int i = 0; i < y.length; ++i) {
yHandles[i] = y[i].op().getUnsafeNativeHandle();
yHandles[i] = y[i].getUnsafeNativeHandle();
yIndices[i] = y[i].index();
}
for (int i = 0; i < x.length; ++i) {
xHandles[i] = x[i].op().getUnsafeNativeHandle();
xHandles[i] = x[i].getUnsafeNativeHandle();
xIndices[i] = x[i].index();
}
if (dx != null && dx.length > 0) {
dxHandles = new long[dx.length];
dxIndices = new int[dx.length];

for (int i = 0; i < dx.length; ++i) {
dxHandles[i] = dx[i].op().getUnsafeNativeHandle();
dxHandles[i] = dx[i].getUnsafeNativeHandle();
dxIndices[i] = dx[i].index();
}
}
Expand All @@ -214,7 +215,7 @@ public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Out
+ " were expected");
}
for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
Operation op = new Operation(this, dyHandlesAndIndices[i]);
GraphNode op = new GraphNode(this, dyHandlesAndIndices[i]);
dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
}
}
Expand Down Expand Up @@ -302,7 +303,7 @@ private final void advance() {
long[] nativeReturn = nextOperation(reference.nativeHandle(), this.position);

if ((nativeReturn != null) && (nativeReturn[0] != 0)) {
this.operation = new Operation(this.graph, nativeReturn[0]);
this.operation = new GraphNode(this.graph, nativeReturn[0]);
this.position = (int) nativeReturn[1];
}
} finally {
Expand Down

0 comments on commit 0c3ab56

Please sign in to comment.