From af5d8ee1632277a0bd593b8330e5f47932272f28 Mon Sep 17 00:00:00 2001 From: Ivan Kobzarev Date: Fri, 13 Sep 2019 12:05:55 -0700 Subject: [PATCH] [android] Tensor renaming to dtype, shape; support long, double --- .../org/pytorch/PytorchInstrumentedTests.java | 57 +-- .../src/main/cpp/pytorch_jni.cpp | 94 ++--- .../src/main/java/org/pytorch/Module.java | 4 +- .../src/main/java/org/pytorch/Tensor.java | 329 +++++++++++------- .../TorchVisionInstrumentedTests.java | 2 +- .../pytorch/torchvision/TensorImageUtils.java | 8 +- 6 files changed, 293 insertions(+), 201 deletions(-) diff --git a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java index 5cdeb30b63b23..517c9b35d771e 100644 --- a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java +++ b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java @@ -1,13 +1,11 @@ package org.pytorch; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - import android.content.Context; -import androidx.test.ext.junit.runners.AndroidJUnit4; -import androidx.test.platform.app.InstrumentationRegistry; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + import java.io.File; import java.io.FileOutputStream; import java.io.IOException; @@ -15,9 +13,14 @@ import java.io.OutputStream; import java.util.HashMap; import java.util.Map; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; + +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.platform.app.InstrumentationRegistry; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; @RunWith(AndroidJUnit4.class) public class PytorchInstrumentedTests { @@ -33,7 +36,7 @@ public void setUp() { public void testForwardNull() throws IOException { final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); final IValue input = - IValue.tensor(Tensor.newByteTensor(new long[] {1}, Tensor.allocateByteBuffer(1))); + IValue.tensor(Tensor.newTensor(new long[] {1}, Tensor.allocateByteBuffer(1))); assertTrue(input.isTensor()); final IValue output = module.forward(input); assertTrue(output.isNull()); @@ -94,13 +97,13 @@ public void testEqFloat() throws IOException { @Test public void testEqTensor() throws IOException { - final long[] inputTensorDims = new long[] {1, 3, 224, 224}; - final long numElements = Tensor.numElements(inputTensorDims); + final long[] inputTensorShape = new long[] {1, 3, 224, 224}; + final long numElements = Tensor.numel(inputTensorShape); final float[] inputTensorData = new float[(int) numElements]; for (int i = 0; i < numElements; ++i) { inputTensorData[i] = i; } - final Tensor inputTensor = Tensor.newFloatTensor(inputTensorDims, inputTensorData); + final Tensor inputTensor = Tensor.newTensor(inputTensorShape, inputTensorData); final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); final IValue input = IValue.tensor(inputTensor); @@ -110,7 +113,7 @@ public void testEqTensor() throws IOException { assertTrue(output.isTensor()); final Tensor outputTensor = output.getTensor(); assertNotNull(outputTensor); - assertArrayEquals(inputTensorDims, outputTensor.dims); + assertArrayEquals(inputTensorShape, outputTensor.shape); float[] outputData = outputTensor.getDataAsFloatArray(); for (int i = 0; i < numElements; i++) { assertTrue(inputTensorData[i] == outputData[i]); @@ -216,8 +219,8 @@ public void testRunUndefinedMethod() throws IOException { @Test public void testTensorMethods() { - long[] dims = new long[] {1, 3, 224, 224}; - final int numel = (int) Tensor.numElements(dims); + long[] shape = new long[] {1, 3, 224, 224}; + final int numel = (int) Tensor.numel(shape); int[] ints = new int[numel]; float[] floats = new float[numel]; @@ -228,16 +231,16 @@ public void testTensorMethods() { floats[i] = i / 1000.f; } - Tensor tensorBytes = Tensor.newByteTensor(dims, bytes); - assertTrue(tensorBytes.isByteTensor()); + Tensor tensorBytes = Tensor.newTensor(shape, bytes); + assertTrue(tensorBytes.dtype() == Tensor.DTYPE_BYTE); assertArrayEquals(bytes, tensorBytes.getDataAsByteArray()); - Tensor tensorInts = Tensor.newIntTensor(dims, ints); - assertTrue(tensorInts.isIntTensor()); + Tensor tensorInts = Tensor.newTensor(shape, ints); + assertTrue(tensorInts.dtype() == Tensor.DTYPE_INT32); assertArrayEquals(ints, tensorInts.getDataAsIntArray()); - Tensor tensorFloats = Tensor.newFloatTensor(dims, floats); - assertTrue(tensorFloats.isFloatTensor()); + Tensor tensorFloats = Tensor.newTensor(shape, floats); + assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32); float[] floatsOut = tensorFloats.getDataAsFloatArray(); assertTrue(floatsOut.length == numel); for (int i = 0; i < numel; i++) { @@ -247,11 +250,11 @@ public void testTensorMethods() { @Test(expected = IllegalStateException.class) public void testTensorIllegalStateOnWrongType() { - long[] dims = new long[] {1, 3, 224, 224}; - final int numel = (int) Tensor.numElements(dims); + long[] shape = new long[] {1, 3, 224, 224}; + final int numel = (int) Tensor.numel(shape); float[] floats = new float[numel]; - Tensor tensorFloats = Tensor.newFloatTensor(dims, floats); - assertTrue(tensorFloats.isFloatTensor()); + Tensor tensorFloats = Tensor.newTensor(shape, floats); + assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32); tensorFloats.getDataAsByteArray(); } diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni.cpp index a86cd501c374a..897311f8b477e 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni.cpp @@ -10,9 +10,11 @@ namespace pytorch_jni { -constexpr static int kTensorTypeCodeByte = 1; -constexpr static int kTensorTypeCodeInt32 = 2; -constexpr static int kTensorTypeCodeFloat32 = 3; +constexpr static int kTensorDTypeByte = 1; +constexpr static int kTensorDTypeInt32 = 2; +constexpr static int kTensorDTypeFloat32 = 3; +constexpr static int kTensorDTypeLong64 = 4; +constexpr static int kTensorDTypeDouble64 = 5; template struct JHashMap @@ -42,34 +44,40 @@ struct JHashMap static at::Tensor newAtTensor( facebook::jni::alias_ref jbuffer, - facebook::jni::alias_ref jdims, - jint typeCode) { - const auto rank = jdims->size(); - const auto dimsArr = jdims->getRegion(0, rank); - std::vector dimsVec{}; - dimsVec.reserve(rank); + facebook::jni::alias_ref jshape, + jint jdtype) { + const auto rank = jshape->size(); + const auto shapeArr = jshape->getRegion(0, rank); + std::vector shapeVec{}; + shapeVec.reserve(rank); auto numel = 1; for (auto i = 0; i < rank; ++i) { - dimsVec.push_back(dimsArr[i]); - numel *= dimsArr[i]; + shapeVec.push_back(shapeArr[i]); + numel *= shapeArr[i]; } JNIEnv* jni = facebook::jni::Environment::current(); caffe2::TypeMeta typeMeta{}; int dataElementSizeBytes = 0; - if (kTensorTypeCodeFloat32 == typeCode) { + if (kTensorDTypeFloat32 == jdtype) { dataElementSizeBytes = 4; typeMeta = caffe2::TypeMeta::Make(); - } else if (kTensorTypeCodeInt32 == typeCode) { + } else if (kTensorDTypeInt32 == jdtype) { dataElementSizeBytes = 4; - typeMeta = caffe2::TypeMeta::Make(); - } else if (kTensorTypeCodeByte == typeCode) { + typeMeta = caffe2::TypeMeta::Make(); + } else if (kTensorDTypeByte == jdtype) { dataElementSizeBytes = 1; - typeMeta = caffe2::TypeMeta::Make(); + typeMeta = caffe2::TypeMeta::Make(); + } else if (kTensorDTypeLong64 == jdtype) { + dataElementSizeBytes = 8; + typeMeta = caffe2::TypeMeta::Make(); + } else if (kTensorDTypeDouble64 == jdtype) { + dataElementSizeBytes = 8; + typeMeta = caffe2::TypeMeta::Make(); } else { facebook::jni::throwNewJavaException( facebook::jni::gJavaLangIllegalArgumentException, - "Unknown Tensor typeCode %d", - typeCode); + "Unknown Tensor jdtype %d", + jdtype); } const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get()); if (dataCapacity != numel) { @@ -84,7 +92,7 @@ static at::Tensor newAtTensor( } return torch::from_blob( jni->GetDirectBufferAddress(jbuffer.get()), - torch::IntArrayRef(dimsVec), + torch::IntArrayRef(shapeVec), at::TensorOptions(typeMeta)); } @@ -94,8 +102,8 @@ class JTensor : public facebook::jni::JavaClass { static facebook::jni::local_ref newJTensor( facebook::jni::alias_ref jBuffer, - facebook::jni::alias_ref jDims, - jint typeCode) { + facebook::jni::alias_ref jShape, + jint jdtype) { static auto jMethodNewTensor = JTensor::javaClassStatic() ->getStaticMethod( @@ -103,35 +111,39 @@ class JTensor : public facebook::jni::JavaClass { facebook::jni::alias_ref, jint)>("nativeNewTensor"); return jMethodNewTensor( - JTensor::javaClassStatic(), jBuffer, jDims, typeCode); + JTensor::javaClassStatic(), jBuffer, jShape, jdtype); } static facebook::jni::local_ref newJTensorFromAtTensor( const at::Tensor& tensor) { const auto scalarType = tensor.scalar_type(); - int typeCode = 0; + int jdtype = 0; if (at::kFloat == scalarType) { - typeCode = kTensorTypeCodeFloat32; + jdtype = kTensorDTypeFloat32; } else if (at::kInt == scalarType) { - typeCode = kTensorTypeCodeInt32; + jdtype = kTensorDTypeInt32; } else if (at::kByte == scalarType) { - typeCode = kTensorTypeCodeByte; + jdtype = kTensorDTypeByte; + } else if (at::kLong == scalarType) { + jdtype = kTensorDTypeLong64; + } else if (at::kDouble == scalarType) { + jdtype = kTensorDTypeDouble64; } else { facebook::jni::throwNewJavaException( facebook::jni::gJavaLangIllegalArgumentException, "at::Tensor scalar type is not supported on java side"); } - const auto& tensorDims = tensor.sizes(); - std::vector tensorDimsVec; - for (const auto& dim : tensorDims) { - tensorDimsVec.push_back(dim); + const auto& tensorShape = tensor.sizes(); + std::vector tensorShapeVec; + for (const auto& s : tensorShape) { + tensorShapeVec.push_back(s); } - facebook::jni::local_ref jTensorDims = - facebook::jni::make_long_array(tensorDimsVec.size()); + facebook::jni::local_ref jTensorShape = + facebook::jni::make_long_array(tensorShapeVec.size()); - jTensorDims->setRegion(0, tensorDimsVec.size(), tensorDimsVec.data()); + jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data()); facebook::jni::local_ref jTensorBuffer = facebook::jni::JByteBuffer::allocateDirect(tensor.nbytes()); @@ -140,18 +152,18 @@ class JTensor : public facebook::jni::JavaClass { jTensorBuffer->getDirectBytes(), tensor.storage().data(), tensor.nbytes()); - return JTensor::newJTensor(jTensorBuffer, jTensorDims, typeCode); + return JTensor::newJTensor(jTensorBuffer, jTensorShape, jdtype); } static at::Tensor newAtTensorFromJTensor( facebook::jni::alias_ref jtensor) { - static const auto typeCodeMethod = - JTensor::javaClassStatic()->getMethod("getTypeCode"); - jint typeCode = typeCodeMethod(jtensor); + static const auto dtypeMethod = + JTensor::javaClassStatic()->getMethod("dtype"); + jint jdtype = dtypeMethod(jtensor); - static const auto dimsField = - JTensor::javaClassStatic()->getField("dims"); - auto jdims = jtensor->getFieldValue(dimsField); + static const auto shapeField = + JTensor::javaClassStatic()->getField("shape"); + auto jshape = jtensor->getFieldValue(shapeField); static auto dataBufferMethod = JTensor::javaClassStatic() @@ -160,7 +172,7 @@ class JTensor : public facebook::jni::JavaClass { "getRawDataBuffer"); facebook::jni::local_ref jbuffer = dataBufferMethod(jtensor); - return newAtTensor(jbuffer, jdims, typeCode); + return newAtTensor(jbuffer, jshape, jdtype); } }; diff --git a/android/pytorch_android/src/main/java/org/pytorch/Module.java b/android/pytorch_android/src/main/java/org/pytorch/Module.java index 38dfc8ddf59d3..51c612778e67a 100644 --- a/android/pytorch_android/src/main/java/org/pytorch/Module.java +++ b/android/pytorch_android/src/main/java/org/pytorch/Module.java @@ -12,8 +12,8 @@ public static Module load(final String modelAbsolutePath) { return new Module(modelAbsolutePath); } - private Module(final String modelAbsolutePath) { - this.mNativePeer = new NativePeer(modelAbsolutePath); + private Module(final String moduleAbsolutePath) { + this.mNativePeer = new NativePeer(moduleAbsolutePath); } public IValue forward(IValue... inputs) { diff --git a/android/pytorch_android/src/main/java/org/pytorch/Tensor.java b/android/pytorch_android/src/main/java/org/pytorch/Tensor.java index ee595fe3ab411..81c5f1182748a 100644 --- a/android/pytorch_android/src/main/java/org/pytorch/Tensor.java +++ b/android/pytorch_android/src/main/java/org/pytorch/Tensor.java @@ -3,31 +3,46 @@ import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; +import java.nio.LongBuffer; import java.util.Arrays; import java.util.Locale; public abstract class Tensor { - private static final int TYPE_CODE_BYTE = 1; - private static final int TYPE_CODE_INT32 = 2; - private static final int TYPE_CODE_FLOAT32 = 3; + public static final int DTYPE_BYTE = 1; + public static final int DTYPE_INT32 = 2; + public static final int DTYPE_FLOAT32 = 3; + public static final int DTYPE_LONG64 = 4; + public static final int DTYPE_DOUBLE64 = 5; private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null"; private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null"; - private static final String ERROR_MSG_DIMS_NOT_NULL = "Dims must be not null"; - private static final String ERROR_MSG_DIMS_NOT_EMPTY = "Dims must be not empty"; - private static final String ERROR_MSG_INDEX_NOT_NULL = "Index must be not null"; - private static final String ERROR_MSG_DIMS_NON_NEGATIVE = "Dims must be non negative"; + private static final String ERROR_MSG_SHAPE_NOT_NULL = "Dims must be not null"; + private static final String ERROR_MSG_SHAPE_NOT_EMPTY = "Dims must be not empty"; + private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Dims must be non negative"; private static final String ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER = "Data buffer must have native byte order (java.nio.ByteOrder#nativeOrder)"; private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT = "Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)"; - public final long[] dims; + public final long[] shape; - private static final int FLOAT_SIZE_BYTES = 4; private static final int INT_SIZE_BYTES = 4; + private static final int FLOAT_SIZE_BYTES = 4; + private static final int LONG_SIZE_BYTES = 8; + private static final int DOUBLE_SIZE_BYTES = 8; + + public static ByteBuffer allocateByteBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder()); + } + + public static IntBuffer allocateIntBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asIntBuffer(); + } public static FloatBuffer allocateFloatBuffer(int numElements) { return ByteBuffer.allocateDirect(numElements * FLOAT_SIZE_BYTES) @@ -35,99 +50,120 @@ public static FloatBuffer allocateFloatBuffer(int numElements) { .asFloatBuffer(); } - public static IntBuffer allocateIntBuffer(int numElements) { - return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES) + public static LongBuffer allocateLongBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * LONG_SIZE_BYTES) .order(ByteOrder.nativeOrder()) - .asIntBuffer(); + .asLongBuffer(); } - public static ByteBuffer allocateByteBuffer(int numElements) { - return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder()); + public static DoubleBuffer allocateDoubleBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * DOUBLE_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asDoubleBuffer(); } - public static Tensor newFloatTensor(long[] dims, float[] data) { + public static Tensor newTensor(long[] shape, byte[] data) { checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkDims(dims); - checkDimsAndDataCapacityConsistency(data.length, dims); - final int bufferCapacity = (int) numElements(dims); - final FloatBuffer floatBuffer = allocateFloatBuffer(bufferCapacity); - floatBuffer.put(data); - return new Tensor_float32(floatBuffer, dims); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape)); + byteBuffer.put(data); + return new Tensor_byte(byteBuffer, shape); } - public static Tensor newIntTensor(long[] dims, int[] data) { + public static Tensor newTensor(long[] shape, int[] data) { checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkDims(dims); - checkDimsAndDataCapacityConsistency(data.length, dims); - final int bufferCapacity = (int) numElements(dims); - final IntBuffer intBuffer = allocateIntBuffer(bufferCapacity); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final IntBuffer intBuffer = allocateIntBuffer((int) numel(shape)); intBuffer.put(data); - return new Tensor_int32(intBuffer, dims); + return new Tensor_int32(intBuffer, shape); } - public static Tensor newByteTensor(long[] dims, byte[] data) { + public static Tensor newTensor(long[] shape, float[] data) { checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkDims(dims); - checkDimsAndDataCapacityConsistency(data.length, dims); - final int bufferCapacity = (int) numElements(dims); - final ByteBuffer byteBuffer = allocateByteBuffer(bufferCapacity); - byteBuffer.put(data); - return new Tensor_byte(byteBuffer, dims); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final FloatBuffer floatBuffer = allocateFloatBuffer((int) numel(shape)); + floatBuffer.put(data); + return new Tensor_float32(floatBuffer, shape); + } + + public static Tensor newTensor(long[] shape, long[] data) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final LongBuffer longBuffer = allocateLongBuffer((int) numel(shape)); + longBuffer.put(data); + return new Tensor_long64(longBuffer, shape); + } + + public static Tensor newTensor(long[] shape, double[] data) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final DoubleBuffer doubleBuffer = allocateDoubleBuffer((int) numel(shape)); + doubleBuffer.put(data); + return new Tensor_double64(doubleBuffer, shape); } - public static Tensor newFloatTensor(long[] dims, FloatBuffer data) { + public static Tensor newTensor(long[] shape, FloatBuffer data) { checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkDims(dims); - checkDimsAndDataCapacityConsistency(data.capacity(), dims); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); checkArgument( (data.order() == ByteOrder.nativeOrder()), ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); - return new Tensor_float32(data, dims); + return new Tensor_float32(data, shape); } - public static Tensor newIntTensor(long[] dims, IntBuffer data) { + public static Tensor newTensor(long[] shape, IntBuffer data) { checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkDims(dims); - checkDimsAndDataCapacityConsistency(data.capacity(), dims); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); checkArgument( (data.order() == ByteOrder.nativeOrder()), ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); - return new Tensor_int32(data, dims); + return new Tensor_int32(data, shape); } - public static Tensor newByteTensor(long[] dims, ByteBuffer data) { + public static Tensor newTensor(long[] shape, ByteBuffer data) { checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkDims(dims); - checkDimsAndDataCapacityConsistency(data.capacity(), dims); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); checkArgument( (data.order() == ByteOrder.nativeOrder()), ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); - return new Tensor_byte(data, dims); + return new Tensor_byte(data, shape); } - private Tensor(long[] dims) { - checkDims(dims); - this.dims = Arrays.copyOf(dims, dims.length); + private Tensor(long[] shape) { + checkShape(shape); + this.shape = Arrays.copyOf(shape, shape.length); } - public static long numElements(long[] dims) { - checkDims(dims); + public static long numel(long[] shape) { + checkShape(shape); int result = 1; - for (long dim : dims) { + for (long dim : shape) { result *= dim; } return result; } + public abstract int dtype(); + public byte[] getDataAsByteArray() { throw new IllegalStateException( "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array."); @@ -143,31 +179,85 @@ public float[] getDataAsFloatArray() { "Tensor of type " + getClass().getSimpleName() + " cannot return data as float array."); } - public boolean isByteTensor() { - return TYPE_CODE_BYTE == getTypeCode(); - } - - public boolean isIntTensor() { - return TYPE_CODE_INT32 == getTypeCode(); + public long[] getDataAsLongArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as float array."); } - public boolean isFloatTensor() { - return TYPE_CODE_FLOAT32 == getTypeCode(); + public double[] getDataAsDoubleArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as double array."); } - abstract int getTypeCode(); - Buffer getRawDataBuffer() { throw new IllegalStateException( "Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer."); } - private static String invalidIndexErrorMessage(int[] index, long dims[]) { - return String.format( - Locale.US, - "Invalid index %s for tensor dimensions %s", - Arrays.toString(index), - Arrays.toString(dims)); + static class Tensor_byte extends Tensor { + private final ByteBuffer data; + + private Tensor_byte(ByteBuffer data, long[] dims) { + super(dims); + this.data = data; + } + + @Override + public int dtype() { + return DTYPE_BYTE; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public byte[] getDataAsByteArray() { + data.rewind(); + byte[] arr = new byte[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format( + "Tensor_byte{shape:%s numel:%d}", Arrays.toString(shape), data.capacity()); + } + } + + static class Tensor_int32 extends Tensor { + private final IntBuffer data; + + private Tensor_int32(IntBuffer data, long[] dims) { + super(dims); + this.data = data; + } + + @Override + public int dtype() { + return DTYPE_INT32; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public int[] getDataAsIntArray() { + data.rewind(); + int[] arr = new int[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format( + "Tensor_int32{shape:%s numel:%d}", Arrays.toString(shape), data.capacity()); + } } static class Tensor_float32 extends Tensor { @@ -187,8 +277,8 @@ public float[] getDataAsFloatArray() { } @Override - int getTypeCode() { - return TYPE_CODE_FLOAT32; + public int dtype() { + return DTYPE_FLOAT32; } @Override @@ -199,22 +289,21 @@ Buffer getRawDataBuffer() { @Override public String toString() { return String.format( - "Tensor_float32{dims:%s data:%s}", - Arrays.toString(dims), Arrays.toString(getDataAsFloatArray())); + "Tensor_float32{shape:%s capacity:%d}", Arrays.toString(shape), data.capacity()); } } - static class Tensor_int32 extends Tensor { - private final IntBuffer data; + static class Tensor_long64 extends Tensor { + private final LongBuffer data; - private Tensor_int32(IntBuffer data, long[] dims) { + private Tensor_long64(LongBuffer data, long[] dims) { super(dims); this.data = data; } @Override - int getTypeCode() { - return TYPE_CODE_INT32; + public int dtype() { + return DTYPE_LONG64; } @Override @@ -223,9 +312,9 @@ Buffer getRawDataBuffer() { } @Override - public int[] getDataAsIntArray() { + public long[] getDataAsLongArray() { data.rewind(); - int[] arr = new int[data.remaining()]; + long[] arr = new long[data.remaining()]; data.get(arr); return arr; } @@ -233,22 +322,21 @@ public int[] getDataAsIntArray() { @Override public String toString() { return String.format( - "Tensor_int32{dims:%s data:%s}", - Arrays.toString(dims), Arrays.toString(getDataAsIntArray())); + "Tensor_long64{shape:%s numel:%d}", Arrays.toString(shape), data.capacity()); } } - static class Tensor_byte extends Tensor { - private final ByteBuffer data; + static class Tensor_double64 extends Tensor { + private final DoubleBuffer data; - private Tensor_byte(ByteBuffer data, long[] dims) { - super(dims); + private Tensor_double64(DoubleBuffer data, long[] shape) { + super(shape); this.data = data; } @Override - int getTypeCode() { - return TYPE_CODE_BYTE; + public int dtype() { + return DTYPE_DOUBLE64; } @Override @@ -257,9 +345,9 @@ Buffer getRawDataBuffer() { } @Override - public byte[] getDataAsByteArray() { + public double[] getDataAsDoubleArray() { data.rewind(); - byte[] arr = new byte[data.remaining()]; + double[] arr = new double[data.remaining()]; data.get(arr); return arr; } @@ -267,8 +355,7 @@ public byte[] getDataAsByteArray() { @Override public String toString() { return String.format( - "Tensor_byte{dims:%s data:%s}", - Arrays.toString(dims), Arrays.toString(getDataAsByteArray())); + "Tensor_double64{shape:%s numel:%d}", Arrays.toString(shape), data.capacity()); } } @@ -279,30 +366,16 @@ private static void checkArgument(boolean expression, String errorMessage, Objec } } - private static void checkDims(long[] dims) { - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkArgument(dims.length > 0, ERROR_MSG_DIMS_NOT_EMPTY); - for (int i = 0; i < dims.length; i++) { - checkArgument(dims[i] >= 0, ERROR_MSG_DIMS_NON_NEGATIVE); - } - } - - private static void checkIndex(int[] index, long dims[]) { - checkArgument(dims != null, ERROR_MSG_INDEX_NOT_NULL); - - if (index.length != dims.length) { - throw new IllegalArgumentException(invalidIndexErrorMessage(index, dims)); - } - - for (int i = 0; i < index.length; i++) { - if (index[i] >= dims[i]) { - throw new IllegalArgumentException(invalidIndexErrorMessage(index, dims)); - } + private static void checkShape(long[] shape) { + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkArgument(shape.length > 0, ERROR_MSG_SHAPE_NOT_EMPTY); + for (int i = 0; i < shape.length; i++) { + checkArgument(shape[i] >= 0, ERROR_MSG_SHAPE_NON_NEGATIVE); } } - private static void checkDimsAndDataCapacityConsistency(int dataCapacity, long[] dims) { - final long numElements = numElements(dims); + private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] dims) { + final long numElements = numel(dims); checkArgument( numElements == dataCapacity, "Inconsistent data capacity:%d and dims number elements:%d dims:%s", @@ -313,14 +386,18 @@ private static void checkDimsAndDataCapacityConsistency(int dataCapacity, long[] // endregion checks // Called from native - private static Tensor nativeNewTensor(ByteBuffer data, long[] dims, int typeCode) { - if (TYPE_CODE_FLOAT32 == typeCode) { - return new Tensor_float32(data.asFloatBuffer(), dims); - } else if (TYPE_CODE_INT32 == typeCode) { - return new Tensor_int32(data.asIntBuffer(), dims); - } else if (TYPE_CODE_BYTE == typeCode) { - return new Tensor_byte(data, dims); + private static Tensor nativeNewTensor(ByteBuffer data, long[] shape, int dtype) { + if (DTYPE_FLOAT32 == dtype) { + return new Tensor_float32(data.asFloatBuffer(), shape); + } else if (DTYPE_INT32 == dtype) { + return new Tensor_int32(data.asIntBuffer(), shape); + } else if (DTYPE_LONG64 == dtype) { + return new Tensor_long64(data.asLongBuffer(), shape); + } else if (DTYPE_DOUBLE64 == dtype) { + return new Tensor_double64(data.asDoubleBuffer(), shape); + } else if (DTYPE_BYTE == dtype) { + return new Tensor_byte(data, shape); } - throw new IllegalArgumentException("Unknown Tensor typeCode"); + throw new IllegalArgumentException("Unknown Tensor dtype"); } } diff --git a/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/TorchVisionInstrumentedTests.java b/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/TorchVisionInstrumentedTests.java index ecb64840d2481..305bcc48fad63 100644 --- a/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/TorchVisionInstrumentedTests.java +++ b/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/TorchVisionInstrumentedTests.java @@ -23,6 +23,6 @@ public void setUp() { public void smokeTest() { Bitmap bitmap = Bitmap.createBitmap(320, 240, Bitmap.Config.ARGB_8888); Tensor tensor = TensorImageUtils.bitmapToFloatTensorTorchVisionForm(bitmap); - assertArrayEquals(new long[] {1l, 3l, 240l, 320l}, tensor.dims); + assertArrayEquals(new long[] {1l, 3l, 240l, 320l}, tensor.shape); } } diff --git a/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java b/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java index 70222d482cdf5..d68542bd3b61b 100644 --- a/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java +++ b/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java @@ -39,8 +39,8 @@ public static Tensor bitmapToFloatTensorTorchVisionForm( floatArray[offset_g + i] = (g - NORM_MEAN_G) / NORM_STD_G; floatArray[offset_b + i] = (b - NORM_MEAN_B) / NORM_STD_B; } - final long dims[] = new long[] {1, 3, height, width}; - return Tensor.newFloatTensor(dims, floatArray); + final long shape[] = new long[] {1, 3, height, width}; + return Tensor.newTensor(shape, floatArray); } public static Tensor imageYUV420CenterCropToFloatTensorTorchVisionForm( @@ -130,8 +130,8 @@ public static Tensor imageYUV420CenterCropToFloatTensorTorchVisionForm( floatArray[tensorInputOffsetB + offset] = ((b / 255.f) - NORM_MEAN_B) / NORM_STD_B; } } - final long dims[] = new long[] {1, 3, tensorHeight, tensorHeight}; - return Tensor.newFloatTensor(dims, floatArray); + final long shape[] = new long[] {1, 3, tensorHeight, tensorHeight}; + return Tensor.newTensor(shape, floatArray); } private static final int clamp(int c, int min, int max) {