Skip to content

Commit

Permalink
Allow larger Java output buffers for TFLite outputs
Browse files Browse the repository at this point in the history
Allow the user to provide a larger output buffer than is necessary
when copying from an output tensor in the TFLite Java bindings.
This makes it easier to accommodate outputs that might have variable
output size using a single, pre-allocated output.

See also PR #39266.

PiperOrigin-RevId: 314454310
Change-Id: I83fd82344196831cdd240f106a588996ad87e88b
  • Loading branch information
jdduke authored and tensorflower-gardener committed Jun 3, 2020
1 parent 5febf24 commit 6afec5e
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 26 deletions.
62 changes: 44 additions & 18 deletions tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ void setTo(Object src) {
throw new IllegalArgumentException(
"Null inputs are allowed only if the Tensor is bound to a buffer handle.");
}
throwIfDataIsIncompatible(src);
throwIfTypeIsIncompatible(src);
throwIfSrcShapeIsIncompatible(src);
if (isBuffer(src)) {
setTo((Buffer) src);
} else if (src.getClass().isArray()) {
Expand Down Expand Up @@ -247,7 +248,8 @@ Object copyTo(Object dst) {
throw new IllegalArgumentException(
"Null outputs are allowed only if the Tensor is bound to a buffer handle.");
}
throwIfDataIsIncompatible(dst);
throwIfTypeIsIncompatible(dst);
throwIfDstShapeIsIncompatible(dst);
if (isBuffer(dst)) {
copyTo((Buffer) dst);
} else {
Expand Down Expand Up @@ -387,11 +389,6 @@ static void fillShape(Object o, int dim, int[] shape) {
}
}

private void throwIfDataIsIncompatible(Object o) {
throwIfTypeIsIncompatible(o);
throwIfShapeIsIncompatible(o);
}

private void throwIfTypeIsIncompatible(Object o) {
// ByteBuffer payloads can map to any type, so exempt it from the check.
if (isByteBuffer(o)) {
Expand All @@ -413,29 +410,58 @@ private void throwIfTypeIsIncompatible(Object o) {
}
}

private void throwIfShapeIsIncompatible(Object o) {
if (isBuffer(o)) {
Buffer oBuffer = (Buffer) o;
private void throwIfSrcShapeIsIncompatible(Object src) {
if (isBuffer(src)) {
Buffer srcBuffer = (Buffer) src;
int bytes = numBytes();
// Note that we allow the client to provide a ByteBuffer even for non-byte Tensors.
// In such cases, we only care that the raw byte capacity matches the tensor byte capacity.
int oBytes = isByteBuffer(o) ? oBuffer.capacity() : oBuffer.capacity() * dtype.byteSize();
if (bytes != oBytes) {
int srcBytes =
isByteBuffer(src) ? srcBuffer.capacity() : srcBuffer.capacity() * dtype.byteSize();
if (bytes != srcBytes) {
throw new IllegalArgumentException(
String.format(
"Cannot copy to a TensorFlowLite tensor (%s) with %d bytes from a "
+ "Java Buffer with %d bytes.",
name(), bytes, srcBytes));
}
return;
}
int[] srcShape = computeShapeOf(src);
if (!Arrays.equals(srcShape, shapeCopy)) {
throw new IllegalArgumentException(
String.format(
"Cannot copy to a TensorFlowLite tensor (%s) with shape %s from a Java object "
+ "with shape %s.",
name(), Arrays.toString(shapeCopy), Arrays.toString(srcShape)));
}
}

private void throwIfDstShapeIsIncompatible(Object dst) {
if (isBuffer(dst)) {
Buffer dstBuffer = (Buffer) dst;
int bytes = numBytes();
// Note that we allow the client to provide a ByteBuffer even for non-byte Tensors.
// In such cases, we only care that the raw byte capacity fits the tensor byte capacity.
// This is subtly different than Buffer *inputs*, where the size should be exact.
int dstBytes =
isByteBuffer(dst) ? dstBuffer.capacity() : dstBuffer.capacity() * dtype.byteSize();
if (bytes > dstBytes) {
throw new IllegalArgumentException(
String.format(
"Cannot convert between a TensorFlowLite buffer with %d bytes and a "
"Cannot copy from a TensorFlowLite tensor (%s) with %d bytes to a "
+ "Java Buffer with %d bytes.",
bytes, oBytes));
name(), bytes, dstBytes));
}
return;
}
int[] oShape = computeShapeOf(o);
if (!Arrays.equals(oShape, shapeCopy)) {
int[] dstShape = computeShapeOf(dst);
if (!Arrays.equals(dstShape, shapeCopy)) {
throw new IllegalArgumentException(
String.format(
"Cannot copy between a TensorFlowLite tensor with shape %s and a Java object "
"Cannot copy from a TensorFlowLite tensor (%s) with shape %s to a Java object "
+ "with shape %s.",
Arrays.toString(shapeCopy), Arrays.toString(oShape)));
name(), Arrays.toString(shapeCopy), Arrays.toString(dstShape)));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ public void testRunWithString_wrongShapeError() {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot copy between a TensorFlowLite tensor with shape [2, 4, 4, 12] and "
+ "a Java object with shape [2, 4, 4, 10]");
"Cannot copy from a TensorFlowLite tensor (output_tensor) with shape [2, 4, 4, 12] "
+ "to a Java object with shape [2, 4, 4, 10]");
}
}
}
Expand Down Expand Up @@ -365,7 +365,7 @@ public void testRunWithByteBufferHavingFloats() {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot convert between a TensorFlowLite buffer with 768 bytes and a "
"Cannot copy to a TensorFlowLite tensor (input) with 768 bytes from a "
+ "Java Buffer with 3072 bytes.");
}
int[] inputDims = {4, 8, 8, 3};
Expand Down Expand Up @@ -393,7 +393,7 @@ public void testRunWithByteBufferHavingWrongSize() {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot convert between a TensorFlowLite buffer with 192 bytes and a "
"Cannot copy to a TensorFlowLite tensor (input) with 192 bytes from a "
+ "Java Buffer with 336 bytes.");
}
}
Expand Down Expand Up @@ -494,7 +494,7 @@ public void testRunWithWrongInputNumOfDims() {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot copy between a TensorFlowLite tensor with shape [8, 7, 3] and a "
"Cannot copy from a TensorFlowLite tensor (output) with shape [8, 7, 3] to a "
+ "Java object with shape [2, 8, 8, 3].");
}
}
Expand All @@ -518,7 +518,7 @@ public void testRunWithWrongInputDims() {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot copy between a TensorFlowLite tensor with shape [2, 8, 7, 3] and a "
"Cannot copy from a TensorFlowLite tensor (output) with shape [2, 8, 7, 3] to a "
+ "Java object with shape [2, 8, 8, 3].");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,20 @@ public void testCopyToByteBuffer() {
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}

@Test
public void testCopyToLargerByteBuffer() {
// Allocate a ByteBuffer that is larger than the Tensor, and ensure we can copy to it.
ByteBuffer parsedOutput =
ByteBuffer.allocateDirect(10 * 2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
tensor.copyTo(parsedOutput);
assertThat(parsedOutput.position()).isEqualTo(2 * 8 * 8 * 3 * 4);
float[] outputOneD = {
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
};
float[] expected = {3.69f, 19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}

@Test
public void testCopyToByteBufferAsFloatBuffer() {
FloatBuffer parsedOutput =
Expand Down Expand Up @@ -203,8 +217,8 @@ public void testCopyToWrongShape() {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot copy between a TensorFlowLite tensor with shape [2, 8, 8, 3] "
+ "and a Java object with shape [1, 8, 8, 3].");
"Cannot copy from a TensorFlowLite tensor (output) with shape [2, 8, 8, 3] "
+ "to a Java object with shape [1, 8, 8, 3].");
}
}

Expand Down

0 comments on commit 6afec5e

Please sign in to comment.