diff --git a/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java b/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java index 0d4745fe0b7a63..f586dae73e0ba9 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java +++ b/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java @@ -60,7 +60,7 @@ public String toString() { abstract long getUnsafeNativeHandle(int outputIdx); /** - * Returns the shape of the tensor of the {code outputIdx}th output of this operation. + * Returns the shape of the tensor of the {@code outputIdx}th output of this operation. * * @param outputIdx index of the output of this operation * @return output tensor shape @@ -68,10 +68,20 @@ public String toString() { abstract long[] shape(int outputIdx); /** - * Returns the datatype of the tensor of the {code outputIdx}th output of this operation. + * Returns the datatype of the tensor of the {@code outputIdx}th output of this operation. * * @param outputIdx index of the output of this operation * @return output tensor datatype */ abstract DataType dtype(int outputIdx); + + /** + * Returns the tensor of the {@code outputIdx}th output of this operation. + * + *

This is only supported in an eager execution environment. + * + * @param outputIdx index of the output of this operation + * @return output tensor + */ + abstract Tensor tensor(int outputIdx); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java index e989c00bb0cc0c..9c8c59ec4e7b17 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java @@ -15,7 +15,7 @@ package org.tensorflow; -import java.util.Arrays; +import java.util.concurrent.atomic.AtomicReferenceArray; /** * Implementation of an {@link Operation} executed eagerly. @@ -38,6 +38,7 @@ class EagerOperation extends AbstractOperation { this.type = type; this.name = name; this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles); + this.outputTensors = new AtomicReferenceArray>(outputNativeHandles.length); } @Override @@ -72,6 +73,12 @@ public long getUnsafeNativeHandle(int outputIndex) { @Override public long[] shape(int outputIndex) { + // If the tensor of this output has already been resolved, return its shape. + // Otherwise, retrieve the tensor shape from the native library. + Tensor tensor = outputTensors.get(outputIndex); + if (tensor != null) { + return tensor.shape(); + } long outputNativeHandle = getUnsafeNativeHandle(outputIndex); long[] shape = new long[numDims(outputNativeHandle)]; for (int i = 0; i < shape.length; ++i) { @@ -82,10 +89,43 @@ public long[] shape(int outputIndex) { @Override public DataType dtype(int outputIndex) { + // If the tensor of this output has already been resolved, return its datatype. + // Otherwise, retrieve the tensor datatype from the native library. 
+ Tensor tensor = outputTensors.get(outputIndex); + if (tensor != null) { + return tensor.dataType(); + } long outputNativeHandle = getUnsafeNativeHandle(outputIndex); return DataType.fromC(dataType(outputNativeHandle)); } + @Override + public Tensor tensor(int outputIndex) { + Tensor tensor = outputTensors.get(outputIndex); + if (tensor == null) { + tensor = resolveTensor(outputIndex); + } + return tensor; + } + + private final EagerSession session; + private final NativeReference nativeRef; + private final String type; + private final String name; + private final AtomicReferenceArray> outputTensors; + + private Tensor resolveTensor(int outputIndex) { + // Take an optimistic approach, where we attempt to resolve the output tensor without locking. + // If another thread has resolved it meanwhile, release our copy and reuse the existing one instead. + long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex)); + Tensor tensor = Tensor.fromHandle(tensorNativeHandle, session); + if (!outputTensors.compareAndSet(outputIndex, null, tensor)) { + tensor.close(); + tensor = outputTensors.get(outputIndex); + } + return tensor; + } + private static class NativeReference extends EagerSession.NativeReference { NativeReference( @@ -98,29 +138,26 @@ private static class NativeReference extends EagerSession.NativeReference { @Override void delete() { if (opHandle != 0L) { - for (long tensorHandle : outputHandles) { - if (tensorHandle != 0L) { - EagerOperation.deleteTensorHandle(tensorHandle); + for (int i = 0; i < outputHandles.length; ++i) { + if (outputHandles[i] != 0L) { + EagerOperation.deleteTensorHandle(outputHandles[i]); + outputHandles[i] = 0L; } } EagerOperation.delete(opHandle); opHandle = 0L; - Arrays.fill(outputHandles, 0L); } } private long opHandle; private final long[] outputHandles; } - - private final EagerSession session; - private final NativeReference nativeRef; - private final String type; - private final String name; - + private static 
native void delete(long handle); private static native void deleteTensorHandle(long handle); + + private static native long resolveTensorHandle(long handle); private static native int outputListLength(long handle, String name); diff --git a/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java index 0e43bc3eb43c18..590eff8a83ef8a 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java +++ b/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java @@ -138,6 +138,11 @@ DataType dtype(int outputIdx) { r.close(); } } + + @Override + Tensor tensor(int outputIdx) { + throw new IllegalStateException("Graph tensors must be fetched by running a session"); + } long getUnsafeNativeHandle() { return unsafeNativeHandle; diff --git a/tensorflow/java/src/main/java/org/tensorflow/Output.java b/tensorflow/java/src/main/java/org/tensorflow/Output.java index 15bb2e89e8d6b9..90668bb7ad3408 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Output.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Output.java @@ -47,6 +47,22 @@ public Shape shape() { public DataType dataType() { return operation.dtype(index); } + + /** + * Returns the tensor at this output. + * + *

This operation is only supported on the outputs of an operation executed eagerly. + * For graph environments, output tensors must be fetched by running a session, using + * {@link Session.Runner#fetch(Output)}. + * + * @return tensor + * @throws IllegalStateException if this output results from a graph + * @see EagerSession + */ + @SuppressWarnings("unchecked") + public Tensor tensor() { + return (Tensor)operation.tensor(index); + } @Override public Output asOutput() { diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index 89872537689815..253ceb65781896 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -140,15 +140,17 @@ private static Tensor create(Object obj, DataType dtype) { Tensor t = new Tensor(dtype); t.shapeCopy = new long[numDimensions(obj, dtype)]; fillShape(obj, 0, t.shapeCopy); + long nativeHandle; if (t.dtype != DataType.STRING) { int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy); - t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize); - setValue(t.nativeHandle, obj); + nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize); + setValue(nativeHandle, obj); } else if (t.shapeCopy.length != 0) { - t.nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj); + nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj); } else { - t.nativeHandle = allocateScalarBytes((byte[]) obj); + nativeHandle = allocateScalarBytes((byte[]) obj); } + t.nativeRef = new NativeReference(nativeHandle); return t; } @@ -314,23 +316,22 @@ private static Tensor allocateForBuffer(DataType dataType, long[] shape, } Tensor t = new Tensor(dataType); t.shapeCopy = Arrays.copyOf(shape, shape.length); - t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes); + long nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes); + t.nativeRef = new NativeReference(nativeHandle); 
return t; } /** * Release resources associated with the Tensor. * - *

WARNING:If not invoked, memory will be leaked. + *

WARNING:This must be invoked for all tensors that have not been produced by an eager + * operation or memory will be leaked. * *

The Tensor object is no longer usable after {@code close} returns. */ @Override public void close() { - if (nativeHandle != 0) { - delete(nativeHandle); - nativeHandle = 0; - } + nativeRef.release(); } /** Returns the {@link DataType} of elements stored in the Tensor. */ @@ -374,7 +375,7 @@ public long[] shape() { * @throws IllegalArgumentException if the Tensor does not represent a float scalar. */ public float floatValue() { - return scalarFloat(nativeHandle); + return scalarFloat(getNativeHandle()); } /** @@ -383,7 +384,7 @@ public float floatValue() { * @throws IllegalArgumentException if the Tensor does not represent a double scalar. */ public double doubleValue() { - return scalarDouble(nativeHandle); + return scalarDouble(getNativeHandle()); } /** @@ -392,7 +393,7 @@ public double doubleValue() { * @throws IllegalArgumentException if the Tensor does not represent a int scalar. */ public int intValue() { - return scalarInt(nativeHandle); + return scalarInt(getNativeHandle()); } /** @@ -401,7 +402,7 @@ public int intValue() { * @throws IllegalArgumentException if the Tensor does not represent a long scalar. */ public long longValue() { - return scalarLong(nativeHandle); + return scalarLong(getNativeHandle()); } /** @@ -410,7 +411,7 @@ public long longValue() { * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar. */ public boolean booleanValue() { - return scalarBoolean(nativeHandle); + return scalarBoolean(getNativeHandle()); } /** @@ -419,7 +420,7 @@ public boolean booleanValue() { * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar. 
*/ public byte[] bytesValue() { - return scalarBytes(nativeHandle); + return scalarBytes(getNativeHandle()); } /** @@ -448,7 +449,7 @@ public byte[] bytesValue() { */ public U copyTo(U dst) { throwExceptionIfTypeIsIncompatible(dst); - readNDArray(nativeHandle, dst); + readNDArray(getNativeHandle(), dst); return dst; } @@ -553,16 +554,27 @@ static Tensor fromHandle(long handle) { @SuppressWarnings("rawtypes") Tensor t = new Tensor(DataType.fromC(dtype(handle))); t.shapeCopy = shape(handle); - t.nativeHandle = handle; + t.nativeRef = new NativeReference(handle); + return t; + } + + /** + * Create an eager Tensor object from a handle to the C TF_Tensor object. + * + *

Takes ownership of the handle. + */ + static Tensor fromHandle(long handle, EagerSession session) { + Tensor t = fromHandle(handle); + t.nativeRef.eager(session, t); return t; } long getNativeHandle() { - return nativeHandle; + return nativeRef.tensorHandle; } - private long nativeHandle; - private DataType dtype; + private NativeReference nativeRef = null; + private final DataType dtype; private long[] shapeCopy = null; private Tensor(DataType t) { @@ -570,7 +582,7 @@ private Tensor(DataType t) { } private ByteBuffer buffer() { - return buffer(nativeHandle).order(ByteOrder.nativeOrder()); + return buffer(getNativeHandle()).order(ByteOrder.nativeOrder()); } private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) { @@ -609,6 +621,66 @@ private static void throwExceptionIfNotByteOfByteArrays(Object array) { } } + /** + * Reference to the underlying native tensor + * + *

Tensors are commonly allocated in a {@code try-with-resources} statement, where they get automatically + * released after executing the last line of the {@code try} block they were declared in. + * + *

They can also be attached to an eager session, in which case their lifetime ends either when + * this session is closed or when the Tensor instance is no longer referenced and has been garbage-collected. + * + *

This helper class wraps the tensor native handle and supports both situations. If an eager reference to + * the tensor exists, it will take care of releasing the tensor at the end of its life. If the tensor is + * being explicitly closed before this happens, it will take care of clearing its association with any eager + * session before cleaning up the resources. + */ + private static class NativeReference { + + /** + * Attaches this reference to an eager session + */ + private class EagerReference extends EagerSession.NativeReference { + + EagerReference(EagerSession session, Tensor tensor) { + super(session, tensor); + } + + @Override + void delete() { + // Mark this eager reference as cleared since it has been deleted by the session + NativeReference.this.eagerRef = null; + NativeReference.this.release(); + } + } + + NativeReference(long tensorHandle) { + this.tensorHandle = tensorHandle; + } + + void eager(EagerSession session, Tensor tensor) { + if (eagerRef != null) { + throw new IllegalStateException("The tensor is already attached to an eager session"); + } + eagerRef = new EagerReference(session, tensor); + } + + synchronized void release() { + if (tensorHandle != 0L) { + // Clear any remaining eager reference to this tensor + if (eagerRef != null) { + eagerRef.clear(); + eagerRef = null; + } + Tensor.delete(tensorHandle); + tensorHandle = 0L; + } + } + + private long tensorHandle; + private EagerReference eagerRef; + } + private static HashMap, DataType> classDataTypes = new HashMap<>(); static { diff --git a/tensorflow/java/src/main/native/eager_operation_jni.cc b/tensorflow/java/src/main/native/eager_operation_jni.cc index d5545e25718688..2dbe81efd358b5 100644 --- a/tensorflow/java/src/main/native/eager_operation_jni.cc +++ b/tensorflow/java/src/main/native/eager_operation_jni.cc @@ -59,6 +59,22 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle( TFE_DeleteTensorHandle(reinterpret_cast(handle)); } +JNIEXPORT jlong JNICALL 
Java_org_tensorflow_EagerOperation_resolveTensorHandle( + JNIEnv* env, jclass clazz, jlong handle) { + TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle); + if (tensor_handle == nullptr) return 0; + TF_Status* status = TF_NewStatus(); + TF_Tensor* tensor = TFE_TensorHandleResolve(tensor_handle, status); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + static_assert(sizeof(jlong) >= sizeof(TF_Tensor*), + "Cannot represent a C TF_Tensor as a Java long"); + return reinterpret_cast(tensor); +} + JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength( JNIEnv* env, jclass clazz, jlong handle, jstring name) { TFE_Op* op = requireOp(env, handle); diff --git a/tensorflow/java/src/main/native/eager_operation_jni.h b/tensorflow/java/src/main/native/eager_operation_jni.h index 732883aceeedf8..a6924f5aa3d94c 100644 --- a/tensorflow/java/src/main/native/eager_operation_jni.h +++ b/tensorflow/java/src/main/native/eager_operation_jni.h @@ -38,6 +38,14 @@ JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(JNIEnv *, JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle(JNIEnv *, jclass, jlong); +/** + * Class: org_tensorflow_EagerOperation + * Method: resolveTensorHandle + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle( + JNIEnv *, jclass, jlong); + /** * Class: org_tensorflow_EagerOperation * Method: outputListLength diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java index 1dabbb765715d3..41b0ed3936d4ac 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java @@ -53,6 +53,22 @@ public void outputDataTypeAndShape() { } } + @Test + public void outputTensor() { + try (EagerSession session = 
EagerSession.create()) { + EagerOperation add = opBuilder(session, "Add", "CompareResult") + .addInput(TestUtil.constant(session, "Const1", 2)) + .addInput(TestUtil.constant(session, "Const2", 4)) + .build(); + assertEquals(6, add.tensor(0).intValue()); + + // Validate that we retrieve the right shape and datatype from the tensor + // that has been resolved + assertEquals(0, add.shape(0).length); + assertEquals(DataType.INT32, add.dtype(0)); + } + } + @Test public void inputAndOutputListLengths() { try (EagerSession session = EagerSession.create()) { @@ -107,11 +123,10 @@ public void numOutputs() { @Test public void opNotAccessibleIfSessionIsClosed() { EagerSession session = EagerSession.create(); - EagerOperation add = - opBuilder(session, "Add", "SetDevice") - .addInput(TestUtil.constant(session, "Const1", 2)) - .addInput(TestUtil.constant(session, "Const2", 4)) - .build(); + EagerOperation add = opBuilder(session, "Add", "SessionClosed") + .addInput(TestUtil.constant(session, "Const1", 2)) + .addInput(TestUtil.constant(session, "Const2", 4)) + .build(); assertEquals(1, add.outputListLength("z")); session.close(); try { @@ -121,7 +136,41 @@ public void opNotAccessibleIfSessionIsClosed() { // expected } } - + + @Test + public void outputIndexOutOfBounds() { + try (EagerSession session = EagerSession.create()) { + EagerOperation add = opBuilder(session, "Add", "OutOfRange") + .addInput(TestUtil.constant(session, "Const1", 2)) + .addInput(TestUtil.constant(session, "Const2", 4)) + .build(); + try { + add.getUnsafeNativeHandle(1); + fail(); + } catch (IndexOutOfBoundsException e) { + // expected + } + try { + add.shape(1); + fail(); + } catch (IndexOutOfBoundsException e) { + // expected + } + try { + add.dtype(1); + fail(); + } catch (IndexOutOfBoundsException e) { + // expected + } + try { + add.tensor(1); + fail(); + } catch (IndexOutOfBoundsException e) { + // expected + } + } + } + private static EagerOperationBuilder opBuilder(EagerSession session, String type, 
String name) { return new EagerOperationBuilder(session, type, name); } diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java index 7331ad50e51542..bfbf5385b48c84 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java @@ -166,6 +166,17 @@ public void outputList() { } } } + + @Test + public void outputTensorNotSupported() { + try (Graph g = new Graph()) { + Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3); + try { + split.output(0).tensor(); + fail(); + } catch (IllegalStateException e) {} + } + } private static int split(int[] values, int num_split) { try (Graph g = new Graph()) { diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java index 3229cce2776dd3..21f4e25f5ab18c 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java @@ -18,6 +18,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -28,6 +29,7 @@ import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -519,6 +521,25 @@ public void useAfterClose() { // The expected exception. 
} } + + @Test + public void eagerTensorIsReleasedAfterSessionIsClosed() { + Tensor sum; + try (EagerSession session = EagerSession.create()) { + Output x = TestUtil.constant(session, "Const1", 10); + Output y = TestUtil.constant(session, "Const2", 20); + sum = TestUtil.addN(session, x, y).tensor(); + assertNotEquals(0L, sum.getNativeHandle()); + assertEquals(30, sum.intValue()); + } + assertEquals(0L, sum.getNativeHandle()); + try { + sum.intValue(); + fail(); + } catch (NullPointerException e) { + // expected. + } + } @Test public void fromHandle() { diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index c97bcaa3386c40..6e24d88a310398 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -67,8 +67,8 @@ public static Output placeholder(Graph g, String name, Class type) { .output(0); } - public static Output addN(Graph g, Output... inputs) { - return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0); + public static Output addN(ExecutionEnvironment env, Output... inputs) { + return env.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0); } public static Output matmul(