Navigation Menu

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Java] Add eager tensor support #28636

Merged
merged 2 commits into from May 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -60,18 +60,28 @@ public String toString() {
abstract long getUnsafeNativeHandle(int outputIdx);

/**
* Returns the shape of the tensor of the {code outputIdx}th output of this operation.
* Returns the shape of the tensor of the {@code outputIdx}th output of this operation.
*
* @param outputIdx index of the output of this operation
* @return output tensor shape
*/
abstract long[] shape(int outputIdx);

/**
* Returns the datatype of the tensor of the {code outputIdx}th output of this operation.
* Returns the datatype of the tensor of the {@code outputIdx}th output of this operation.
*
* @param outputIdx index of the output of this operation
* @return output tensor datatype
*/
abstract DataType dtype(int outputIdx);

/**
* Returns the tensor of the {@code outputIdx}th output of this operation.
*
* <p>This is only supported in an eager execution environment.
*
* @param outputIdx index of the output of this operation
* @return output tensor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does it return if the user happens to call it in a non-eager env?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See GraphOperation class in this PR, it throws an IllegalStateException error.

It's documented in the Output.tensor() method but not in this internal one, to which Output.tensor() delegates to.

*/
abstract Tensor<?> tensor(int outputIdx);
}
59 changes: 48 additions & 11 deletions tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java
Expand Up @@ -15,7 +15,7 @@

package org.tensorflow;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReferenceArray;

/**
* Implementation of an {@link Operation} executed eagerly.
Expand All @@ -38,6 +38,7 @@ class EagerOperation extends AbstractOperation {
this.type = type;
this.name = name;
this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
this.outputTensors = new AtomicReferenceArray<Tensor<?>>(outputNativeHandles.length);
}

@Override
Expand Down Expand Up @@ -72,6 +73,12 @@ public long getUnsafeNativeHandle(int outputIndex) {

@Override
public long[] shape(int outputIndex) {
// If the tensor of this output has already been resolved, return its shape.
// Otherwise, retrieve the tensor shape from the native library.
Tensor<?> tensor = outputTensors.get(outputIndex);
if (tensor != null) {
return tensor.shape();
}
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
long[] shape = new long[numDims(outputNativeHandle)];
for (int i = 0; i < shape.length; ++i) {
Expand All @@ -82,10 +89,43 @@ public long[] shape(int outputIndex) {

@Override
public DataType dtype(int outputIndex) {
// If the tensor of this output has already been resolved, return its datatype.
// Otherwise, retrieve the tensor datatype from the native library.
Tensor<?> tensor = outputTensors.get(outputIndex);
if (tensor != null) {
return tensor.dataType();
}
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
return DataType.fromC(dataType(outputNativeHandle));
}

@Override
public Tensor<?> tensor(int outputIndex) {
Tensor<?> tensor = outputTensors.get(outputIndex);
if (tensor == null) {
tensor = resolveTensor(outputIndex);
}
return tensor;
}

private final EagerSession session;
private final NativeReference nativeRef;
private final String type;
private final String name;
private final AtomicReferenceArray<Tensor<?>> outputTensors;

private Tensor<?> resolveTensor(int outputIndex) {
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
// If another thread has resolved it meanwhile, release our copy and reuse the existing one instead.
long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
Tensor<?> tensor = Tensor.fromHandle(tensorNativeHandle, session);
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
tensor.close();
tensor = outputTensors.get(outputIndex);
}
return tensor;
}

private static class NativeReference extends EagerSession.NativeReference {

NativeReference(
Expand All @@ -98,29 +138,26 @@ private static class NativeReference extends EagerSession.NativeReference {
@Override
void delete() {
if (opHandle != 0L) {
for (long tensorHandle : outputHandles) {
if (tensorHandle != 0L) {
EagerOperation.deleteTensorHandle(tensorHandle);
for (int i = 0; i < outputHandles.length; ++i) {
if (outputHandles[i] != 0L) {
EagerOperation.deleteTensorHandle(outputHandles[i]);
outputHandles[i] = 0L;
}
}
EagerOperation.delete(opHandle);
opHandle = 0L;
Arrays.fill(outputHandles, 0L);
}
}

private long opHandle;
private final long[] outputHandles;
}

private final EagerSession session;
private final NativeReference nativeRef;
private final String type;
private final String name;


private static native void delete(long handle);

private static native void deleteTensorHandle(long handle);

private static native long resolveTensorHandle(long handle);

private static native int outputListLength(long handle, String name);

Expand Down
Expand Up @@ -138,6 +138,11 @@ DataType dtype(int outputIdx) {
r.close();
}
}

@Override
Tensor<?> tensor(int outputIdx) {
throw new IllegalStateException("Graph tensors must be fetched by running a session");
}

long getUnsafeNativeHandle() {
return unsafeNativeHandle;
Expand Down
16 changes: 16 additions & 0 deletions tensorflow/java/src/main/java/org/tensorflow/Output.java
Expand Up @@ -47,6 +47,22 @@ public Shape shape() {
public DataType dataType() {
return operation.dtype(index);
}

/**
* Returns the tensor at this output.
*
* <p>This operation is only supported on the outputs of an operation executed eagerly.
* For graph environments, output tensors must be fetched by running a session, using
* {@link Session.Runner#fetch(Output)}.
*
* @return tensor
* @throws IllegalStateException if this output results from a graph
* @see EagerSession
*/
@SuppressWarnings("unchecked")
public Tensor<T> tensor() {
return (Tensor<T>)operation.tensor(index);
}

@Override
public Output<T> asOutput() {
Expand Down
116 changes: 94 additions & 22 deletions tensorflow/java/src/main/java/org/tensorflow/Tensor.java
Expand Up @@ -140,15 +140,17 @@ private static Tensor<?> create(Object obj, DataType dtype) {
Tensor<?> t = new Tensor(dtype);
t.shapeCopy = new long[numDimensions(obj, dtype)];
fillShape(obj, 0, t.shapeCopy);
long nativeHandle;
if (t.dtype != DataType.STRING) {
int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
setValue(t.nativeHandle, obj);
nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
setValue(nativeHandle, obj);
} else if (t.shapeCopy.length != 0) {
t.nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
} else {
t.nativeHandle = allocateScalarBytes((byte[]) obj);
nativeHandle = allocateScalarBytes((byte[]) obj);
}
t.nativeRef = new NativeReference(nativeHandle);
return t;
}

Expand Down Expand Up @@ -314,23 +316,22 @@ private static <T> Tensor<T> allocateForBuffer(DataType dataType, long[] shape,
}
Tensor<T> t = new Tensor<T>(dataType);
t.shapeCopy = Arrays.copyOf(shape, shape.length);
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
long nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
t.nativeRef = new NativeReference(nativeHandle);
return t;
}

/**
* Release resources associated with the Tensor.
*
* <p><b>WARNING:</b>If not invoked, memory will be leaked.
* <p><b>WARNING:</b>This must be invoked for all tensors that were not been produced by an eager
* operation or memory will be leaked.
*
* <p>The Tensor object is no longer usable after {@code close} returns.
*/
@Override
public void close() {
if (nativeHandle != 0) {
delete(nativeHandle);
nativeHandle = 0;
}
nativeRef.release();
}

/** Returns the {@link DataType} of elements stored in the Tensor. */
Expand Down Expand Up @@ -374,7 +375,7 @@ public long[] shape() {
* @throws IllegalArgumentException if the Tensor does not represent a float scalar.
*/
public float floatValue() {
return scalarFloat(nativeHandle);
return scalarFloat(getNativeHandle());
}

/**
Expand All @@ -383,7 +384,7 @@ public float floatValue() {
* @throws IllegalArgumentException if the Tensor does not represent a double scalar.
*/
public double doubleValue() {
return scalarDouble(nativeHandle);
return scalarDouble(getNativeHandle());
}

/**
Expand All @@ -392,7 +393,7 @@ public double doubleValue() {
* @throws IllegalArgumentException if the Tensor does not represent a int scalar.
*/
public int intValue() {
return scalarInt(nativeHandle);
return scalarInt(getNativeHandle());
}

/**
Expand All @@ -401,7 +402,7 @@ public int intValue() {
* @throws IllegalArgumentException if the Tensor does not represent a long scalar.
*/
public long longValue() {
return scalarLong(nativeHandle);
return scalarLong(getNativeHandle());
}

/**
Expand All @@ -410,7 +411,7 @@ public long longValue() {
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
*/
public boolean booleanValue() {
return scalarBoolean(nativeHandle);
return scalarBoolean(getNativeHandle());
}

/**
Expand All @@ -419,7 +420,7 @@ public boolean booleanValue() {
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
*/
public byte[] bytesValue() {
return scalarBytes(nativeHandle);
return scalarBytes(getNativeHandle());
}

/**
Expand Down Expand Up @@ -448,7 +449,7 @@ public byte[] bytesValue() {
*/
public <U> U copyTo(U dst) {
throwExceptionIfTypeIsIncompatible(dst);
readNDArray(nativeHandle, dst);
readNDArray(getNativeHandle(), dst);
return dst;
}

Expand Down Expand Up @@ -553,24 +554,35 @@ static Tensor<?> fromHandle(long handle) {
@SuppressWarnings("rawtypes")
Tensor<?> t = new Tensor(DataType.fromC(dtype(handle)));
t.shapeCopy = shape(handle);
t.nativeHandle = handle;
t.nativeRef = new NativeReference(handle);
return t;
}

/**
* Create an eager Tensor object from a handle to the C TF_Tensor object.
*
* <p>Takes ownership of the handle.
*/
static Tensor<?> fromHandle(long handle, EagerSession session) {
Tensor<?> t = fromHandle(handle);
t.nativeRef.eager(session, t);
return t;
}

long getNativeHandle() {
return nativeHandle;
return nativeRef.tensorHandle;
}

private long nativeHandle;
private DataType dtype;
private NativeReference nativeRef = null;
private final DataType dtype;
private long[] shapeCopy = null;

private Tensor(DataType t) {
dtype = t;
}

private ByteBuffer buffer() {
return buffer(nativeHandle).order(ByteOrder.nativeOrder());
return buffer(getNativeHandle()).order(ByteOrder.nativeOrder());
}

private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) {
Expand Down Expand Up @@ -609,6 +621,66 @@ private static void throwExceptionIfNotByteOfByteArrays(Object array) {
}
}

/**
* Reference to the underlying native tensor
*
* <p>Tensors are commonly allocated in a `try-with-resources` statement, where they get automatically
* released after executing the last line of the `try` block they were declared in.
*
* <p>They can also be attached to an eager session, where in this case their lifetime ends either when
* this session is closed or when the Tensor instance is no longer referenced and have been garbage-collected.
*
* <p>This helper class wraps the tensor native handle and support both situations; If an eager reference to
* the tensor exists, it will take care of releasing the tensor at the end of its life. If the tensor is
* being explicetly closed before this happens, it will take cake of clearing its association with any eager
* session before cleaning up the resources.
*/
private static class NativeReference {

/**
* Attaches this reference to an eager session
*/
private class EagerReference extends EagerSession.NativeReference {

EagerReference(EagerSession session, Tensor<?> tensor) {
super(session, tensor);
}

@Override
void delete() {
// Mark this eager reference as cleared since it has been deleted by the session
NativeReference.this.eagerRef = null;
NativeReference.this.release();
}
}

NativeReference(long tensorHandle) {
this.tensorHandle = tensorHandle;
}

void eager(EagerSession session, Tensor<?> tensor) {
if (eagerRef != null) {
throw new IllegalStateException("The tensor is already attached to an eager session");
}
eagerRef = new EagerReference(session, tensor);
}

synchronized void release() {
if (tensorHandle != 0L) {
// Clear any remaining eager reference to this tensor
if (eagerRef != null) {
eagerRef.clear();
eagerRef = null;
}
Tensor.delete(tensorHandle);
tensorHandle = 0L;
}
}

private long tensorHandle;
private EagerReference eagerRef;
}

private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>();

static {
Expand Down