diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/DataType.java index 407e01c6e17f02..527346c3c9b1d4 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/DataType.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/DataType.java @@ -30,7 +30,10 @@ public enum DataType { INT64(4), /** Strings. */ - STRING(5); + STRING(5), + + /** 8-bit signed integer. */ + INT8(9); private final int value; @@ -45,6 +48,7 @@ public int byteSize() { return 4; case INT32: return 4; + case INT8: case UINT8: return 1; case INT64: @@ -83,6 +87,7 @@ String toStringName() { return "float"; case INT32: return "int"; + case INT8: case UINT8: return "byte"; case INT64: diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index 68952ff6e4999a..8ed019dc3f1652 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -311,7 +311,13 @@ private void throwIfTypeIsIncompatible(Object o) { return; } DataType oType = dataTypeOf(o); + if (oType != dtype) { + // INT8 and UINT8 have the same string name, "byte" + if (oType.toStringName().equals(dtype.toStringName())) { + return; + } + throw new IllegalArgumentException( String.format( "Cannot convert between a TensorFlowLite tensor with type %s and a Java " diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc index f2cb1f81ab8083..8beafa0c48e17f 100644 --- a/tensorflow/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc @@ -126,6 +126,7 @@ size_t WriteOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, env->GetLongArrayRegion(long_array, 0, num_elements, long_dst); return to_copy; } + case kTfLiteInt8: case kTfLiteUInt8: { jbyteArray byte_array = static_cast(array); jbyte* byte_dst = static_cast(dst); @@ -174,6 +175,7 @@ size_t ReadOneDimensionalArray(JNIEnv* env, TfLiteType data_type, static_cast(src)); return size; } + case kTfLiteInt8: case kTfLiteUInt8: { jbyteArray byte_array = static_cast(dst); env->SetByteArrayRegion(byte_array, 0, len, diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java index 8412ec0e9dacd5..d1e9c03ddd681a 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java @@ -39,4 +39,11 @@ public void testConversion() { assertThat(DataType.fromC(dataType.c())).isEqualTo(dataType); } } + + @Test + public void testINT8AndUINT8() { + assertThat(DataType.INT8.toStringName()).isEqualTo("byte"); + assertThat(DataType.UINT8.toStringName()).isEqualTo("byte"); + assertThat(DataType.INT8.toStringName()).isEqualTo(DataType.UINT8.toStringName()); + } }