diff --git a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/AndroidManifest.xml b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/AndroidManifest.xml index ff9e3b880ed..9958047f4a8 100644 --- a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/AndroidManifest.xml +++ b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/AndroidManifest.xml @@ -1,9 +1,13 @@ + xmlns:tools="http://schemas.android.com/tools" + package="com.example.executorchdemo"> + + - + tools:targetApi="34"> diff --git a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/ClassificationActivity.java b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/ClassificationActivity.java index bd5dfb3b443..8c4dd8f8def 100644 --- a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/ClassificationActivity.java +++ b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/ClassificationActivity.java @@ -18,11 +18,11 @@ import android.widget.Button; import android.widget.ImageView; import android.widget.TextView; -import com.example.executorchdemo.executor.EValue; -import com.example.executorchdemo.executor.Module; -import com.example.executorchdemo.executor.Tensor; -import com.example.executorchdemo.executor.TensorImageUtils; import java.io.IOException; +import org.pytorch.executorch.EValue; +import org.pytorch.executorch.Module; +import org.pytorch.executorch.Tensor; +import org.pytorch.executorch.TensorImageUtils; public class ClassificationActivity extends Activity implements Runnable { diff --git a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java index 9c10ad9ff0b..cdb1ac1983b 100644 --- a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java +++ b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java @@ -21,16 +21,16 @@ import android.widget.Button; import android.widget.ImageView; import android.widget.ProgressBar; -import com.example.executorchdemo.executor.EValue; -import com.example.executorchdemo.executor.Module; -import com.example.executorchdemo.executor.Tensor; -import com.example.executorchdemo.executor.TensorImageUtils; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.Objects; +import org.pytorch.executorch.EValue; +import org.pytorch.executorch.Module; +import org.pytorch.executorch.Tensor; +import org.pytorch.executorch.TensorImageUtils; public class MainActivity extends Activity implements Runnable { private ImageView mImageView; diff --git a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/res/values/styles.xml b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/res/values/styles.xml index fac92916801..391ec9ae3b7 100644 --- a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/res/values/styles.xml +++ b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/res/values/styles.xml @@ -7,4 +7,4 @@ @color/colorAccent - \ No newline at end of file + diff --git a/extension/android/src/main/java/org/pytorch/executorch/DType.java b/extension/android/src/main/java/org/pytorch/executorch/DType.java new file mode 100644 index 00000000000..8b3fb42a6ad --- /dev/null +++ b/extension/android/src/main/java/org/pytorch/executorch/DType.java @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +/** Codes representing tensor data types. */ +public enum DType { + // NOTE: "jniCode" must be kept in sync with scalar_type.h. + // NOTE: Never serialize "jniCode", because it can change between releases. + + /** Code for dtype torch::executor::Byte */ + UINT8(0), + /** Code for dtype torch::executor::Char */ + INT8(1), + /** Code for dtype torch::executor::Short */ + INT16(2), + /** Code for dtype torch::executor::Int */ + INT32(3), + /** Code for dtype torch::executor::Long */ + INT64(4), + /** Code for dtype torch::executor::Half */ + HALF(5), + /** Code for dtype torch::executor::Float */ + FLOAT(6), + /** Code for dtype torch::executor::Double */ + DOUBLE(7), + /** Code for dtype torch::executor::ComplexHalf */ + COMPLEX_HALF(8), + /** Code for dtype torch::executor::ComplexFloat */ + COMPLEX_FLOAT(9), + /** Code for dtype torch::executor::ComplexDouble */ + COMPLEX_DOUBLE(10), + /** Code for dtype torch::executor::Bool */ + BOOL(11), + /** Code for dtype torch::executor::QInt8 */ + QINT8(12), + /** Code for dtype torch::executor::QUInt8 */ + QUINT8(13), + /** Code for dtype torch::executor::QInt32 */ + QINT32(14), + /** Code for dtype torch::executor::BFloat16 */ + BFLOAT16(15), + /** Code for dtype torch::executor::QUInt4x2 */ + QINT4X2(16), + /** Code for dtype torch::executor::QUInt2x4 */ + QINT2X4(17), + /** Code for dtype torch::executor::Bits1x8 */ + BITS1X8(18), + /** Code for dtype torch::executor::Bits2x4 */ + BITS2X4(19), + /** Code for dtype torch::executor::Bits4x2 */ + BITS4X2(20), + /** Code for dtype torch::executor::Bits8 */ + BITS8(21), + /** Code for dtype torch::executor::Bits16 */ + BITS16(22), + ; + + final int jniCode; + + DType(int jniCode) { + this.jniCode = jniCode; + } +} diff --git a/extension/android/src/main/java/org/pytorch/executorch/EValue.java b/extension/android/src/main/java/org/pytorch/executorch/EValue.java new file mode 100644 index 00000000000..7926eadafff --- /dev/null +++ b/extension/android/src/main/java/org/pytorch/executorch/EValue.java @@ -0,0 +1,279 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import com.facebook.jni.annotations.DoNotStrip; +import java.util.Locale; +import java.util.Optional; + +/** + * Java representation of an ExecuTorch value, which is implemented as tagged union that can be one + * of the supported types: https://pytorch.org/docs/stable/jit.html#types . + * + *

Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}. + * + *

{@code EValue} objects are constructed with {@code EValue.from(value)}, {@code + * EValue.tupleFrom(value1, value2, ...)}, {@code EValue.listFrom(value1, value2, ...)}, or one of + * the {@code dict} methods, depending on the key type. + * + *

Data is retrieved from {@code EValue} objects with the {@code toX()} methods. Note that {@code + * str}-type EValues must be extracted with {@link #toStr()}, rather than {@link #toString()}. + * + *

{@code EValue} objects may retain references to objects passed into their constructors, and + * may return references to their internal state from {@code toX()}. + */ +@DoNotStrip +public class EValue { + private static final int TYPE_CODE_NONE = 0; + + private static final int TYPE_CODE_TENSOR = 1; + private static final int TYPE_CODE_STRING = 2; + private static final int TYPE_CODE_DOUBLE = 3; + private static final int TYPE_CODE_INT = 4; + private static final int TYPE_CODE_BOOL = 5; + + private static final int TYPE_CODE_LIST_BOOL = 6; + private static final int TYPE_CODE_LIST_DOUBLE = 7; + private static final int TYPE_CODE_LIST_INT = 8; + private static final int TYPE_CODE_LIST_TENSOR = 9; + private static final int TYPE_CODE_LIST_SCALAR = 10; + private static final int TYPE_CODE_LIST_OPTIONAL_TENSOR = 11; + + private String[] TYPE_NAMES = { + "None", + "Tensor", + "String", + "Double", + "Int", + "Bool", + "ListBool", + "ListDouble", + "ListInt", + "ListTensor", + "ListScalar", + "ListOptionalScalar", + }; + + @DoNotStrip private final int mTypeCode; + @DoNotStrip private Object mData; + + @DoNotStrip + private EValue(int typeCode) { + this.mTypeCode = typeCode; + } + + @DoNotStrip + public boolean isNone() { + return TYPE_CODE_NONE == this.mTypeCode; + } + + @DoNotStrip + public boolean isTensor() { + return TYPE_CODE_TENSOR == this.mTypeCode; + } + + @DoNotStrip + public boolean isBool() { + return TYPE_CODE_BOOL == this.mTypeCode; + } + + @DoNotStrip + public boolean isInt() { + return TYPE_CODE_INT == this.mTypeCode; + } + + @DoNotStrip + public boolean isDouble() { + return TYPE_CODE_DOUBLE == this.mTypeCode; + } + + @DoNotStrip + public boolean isString() { + return TYPE_CODE_STRING == this.mTypeCode; + } + + @DoNotStrip + public boolean isBoolList() { + return TYPE_CODE_LIST_BOOL == this.mTypeCode; + } + + @DoNotStrip + public boolean isIntList() { + return TYPE_CODE_LIST_INT == this.mTypeCode; + } + + @DoNotStrip + public boolean isDoubleList() { + return TYPE_CODE_LIST_DOUBLE == this.mTypeCode; + } + + @DoNotStrip + public boolean isTensorList() { + return TYPE_CODE_LIST_TENSOR == this.mTypeCode; + } + + @DoNotStrip + public boolean isOptionalTensorList() { + return TYPE_CODE_LIST_OPTIONAL_TENSOR == this.mTypeCode; + } + + /** Creates a new {@code EValue} of type {@code Optional} that contains no value. */ + @DoNotStrip + public static EValue optionalNone() { + return new EValue(TYPE_CODE_NONE); + } + + /** Creates a new {@code EValue} of type {@code Tensor}. */ + @DoNotStrip + public static EValue from(Tensor tensor) { + final EValue iv = new EValue(TYPE_CODE_TENSOR); + iv.mData = tensor; + return iv; + } + /** Creates a new {@code EValue} of type {@code bool}. */ + @DoNotStrip + public static EValue from(boolean value) { + final EValue iv = new EValue(TYPE_CODE_BOOL); + iv.mData = value; + return iv; + } + + /** Creates a new {@code EValue} of type {@code int}. */ + @DoNotStrip + public static EValue from(long value) { + final EValue iv = new EValue(TYPE_CODE_INT); + iv.mData = value; + return iv; + } + + /** Creates a new {@code EValue} of type {@code double}. */ + @DoNotStrip + public static EValue from(double value) { + final EValue iv = new EValue(TYPE_CODE_DOUBLE); + iv.mData = value; + return iv; + } + + /** Creates a new {@code EValue} of type {@code str}. */ + @DoNotStrip + public static EValue from(String value) { + final EValue iv = new EValue(TYPE_CODE_STRING); + iv.mData = value; + return iv; + } + + /** Creates a new {@code EValue} of type {@code List[bool]}. */ + @DoNotStrip + public static EValue listFrom(boolean... list) { + final EValue iv = new EValue(TYPE_CODE_LIST_BOOL); + iv.mData = list; + return iv; + } + + /** Creates a new {@code EValue} of type {@code List[int]}. */ + @DoNotStrip + public static EValue listFrom(long... list) { + final EValue iv = new EValue(TYPE_CODE_LIST_INT); + iv.mData = list; + return iv; + } + + /** Creates a new {@code EValue} of type {@code List[double]}. */ + @DoNotStrip + public static EValue listFrom(double... list) { + final EValue iv = new EValue(TYPE_CODE_LIST_DOUBLE); + iv.mData = list; + return iv; + } + + /** Creates a new {@code EValue} of type {@code List[Tensor]}. */ + @DoNotStrip + public static EValue listFrom(Tensor... list) { + final EValue iv = new EValue(TYPE_CODE_LIST_TENSOR); + iv.mData = list; + return iv; + } + + /** Creates a new {@code EValue} of type {@code List[Optional[Tensor]]}. */ + @DoNotStrip + public static EValue listFrom(Optional... list) { + final EValue iv = new EValue(TYPE_CODE_LIST_OPTIONAL_TENSOR); + iv.mData = list; + return iv; + } + + @DoNotStrip + public Tensor toTensor() { + preconditionType(TYPE_CODE_TENSOR, mTypeCode); + return (Tensor) mData; + } + + @DoNotStrip + public boolean toBool() { + preconditionType(TYPE_CODE_BOOL, mTypeCode); + return (boolean) mData; + } + + @DoNotStrip + public long toInt() { + preconditionType(TYPE_CODE_INT, mTypeCode); + return (long) mData; + } + + @DoNotStrip + public double toDouble() { + preconditionType(TYPE_CODE_DOUBLE, mTypeCode); + return (double) mData; + } + + @DoNotStrip + public String toStr() { + preconditionType(TYPE_CODE_STRING, mTypeCode); + return (String) mData; + } + + @DoNotStrip + public boolean[] toBoolList() { + preconditionType(TYPE_CODE_LIST_BOOL, mTypeCode); + return (boolean[]) mData; + } + + @DoNotStrip + public long[] toIntList() { + preconditionType(TYPE_CODE_LIST_INT, mTypeCode); + return (long[]) mData; + } + + @DoNotStrip + public double[] toDoubleList() { + preconditionType(TYPE_CODE_LIST_DOUBLE, mTypeCode); + return (double[]) mData; + } + + @DoNotStrip + public Tensor[] toTensorList() { + preconditionType(TYPE_CODE_LIST_TENSOR, mTypeCode); + return (Tensor[]) mData; + } + + private void preconditionType(int typeCodeExpected, int typeCode) { + if (typeCode != typeCodeExpected) { + throw new IllegalStateException( + String.format( + Locale.US, + "Expected EValue type %s, actual type %s", + getTypeName(typeCodeExpected), + getTypeName(typeCode))); + } + } + + private String getTypeName(int typeCode) { + return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown"; + } +} diff --git a/extension/android/src/main/java/org/pytorch/executorch/INativePeer.java b/extension/android/src/main/java/org/pytorch/executorch/INativePeer.java new file mode 100644 index 00000000000..976e732d3f5 --- /dev/null +++ b/extension/android/src/main/java/org/pytorch/executorch/INativePeer.java @@ -0,0 +1,21 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +/** Interface for the native peer object for entry points to the Module */ +interface INativePeer { + /** Clean up the native resources associated with this instance */ + void resetNative(); + + /** Run a "forward" call with the given inputs */ + EValue forward(EValue... inputs); + + /** Run an arbitrary method on the module */ + EValue execute(String methodName, EValue... inputs); +} diff --git a/extension/android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/src/main/java/org/pytorch/executorch/Module.java new file mode 100644 index 00000000000..3f1546d009a --- /dev/null +++ b/extension/android/src/main/java/org/pytorch/executorch/Module.java @@ -0,0 +1,79 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; +import java.util.Map; + +/** Java wrapper for ExecuTorch Module. */ +public class Module { + + /** Reference to the INativePeer object of this module. */ + private INativePeer mNativePeer; + + /** + * Loads a serialized ExecuTorch module from the specified path on the disk. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. + * @param extraFiles map with extra files names as keys, content of them will be loaded to values. + * @return new {@link org.pytorch.executorch.Module} object which owns torch::jit::Module. + */ + public static Module load(final String modelPath, final Map extraFiles) { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + return new Module(new NativePeer(modelPath, extraFiles)); + } + + /** + * Loads a serialized ExecuTorch module from the specified path on the disk to run on CPU. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. + * @return new {@link org.pytorch.executorch.Module} object which owns torch::jit::Module. + */ + public static Module load(final String modelPath) { + return load(modelPath, null); + } + + Module(INativePeer nativePeer) { + this.mNativePeer = nativePeer; + } + + /** + * Runs the 'forward' method of this module with the specified arguments. + * + * @param inputs arguments for the ExecuTorch module's 'forward' method. + * @return return value from the 'forward' method. + */ + public EValue forward(EValue... inputs) { + return mNativePeer.forward(inputs); + } + + /** + * Runs the specified method of this module with the specified arguments. + * + * @param methodName name of the ExecuTorch method to run. + * @param inputs arguments that will be passed to ExecuTorch method. + * @return return value from the method. + */ + public EValue execute(String methodName, EValue... inputs) { + return mNativePeer.execute(methodName, inputs); + } + + /** + * Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the + * native object will be destroyed when this object is garbage-collected. However, the timing of + * garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory + * more quickly. See {@link com.facebook.jni.HybridData#resetNative}. + */ + public void destroy() { + mNativePeer.resetNative(); + } +} diff --git a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java new file mode 100644 index 00000000000..029c173ee90 --- /dev/null +++ b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java @@ -0,0 +1,41 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; +import com.facebook.soloader.nativeloader.NativeLoader; +import java.util.Map; + +class NativePeer implements INativePeer { + static { + // Loads libexecutorch.so from jniLibs + NativeLoader.loadLibrary("executorch"); + } + + private final HybridData mHybridData; + + @DoNotStrip + private static native HybridData initHybrid( + String moduleAbsolutePath, Map extraFiles); + + NativePeer(String moduleAbsolutePath, Map extraFiles) { + mHybridData = initHybrid(moduleAbsolutePath, extraFiles); + } + + public void resetNative() { + mHybridData.resetNative(); + } + + @DoNotStrip + public native EValue forward(EValue... inputs); + + @DoNotStrip + public native EValue execute(String methodName, EValue... inputs); +} diff --git a/extension/android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/src/main/java/org/pytorch/executorch/Tensor.java new file mode 100644 index 00000000000..bc18dd43330 --- /dev/null +++ b/extension/android/src/main/java/org/pytorch/executorch/Tensor.java @@ -0,0 +1,675 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.util.Arrays; +import java.util.Locale; + +/** + * Representation of an ExecuTorch Tensor. Behavior is similar to PyTorch's tensor objects. + * + *

Most tensors will be constructed as {@code Tensor.fromBlob(data, shape)}, where {@code data} + * can be an array or a direct {@link Buffer} (of the proper subclass). Helper methods are provided + * to allocate buffers properly. + * + *

To access Tensor data, see {@link #dtype()}, {@link #shape()}, and various {@code getDataAs*} + * methods. + * + *

When constructing {@code Tensor} objects with {@code data} as an array, it is not specified + * whether this data is copied or retained as a reference so it is recommended not to modify it + * after constructing. {@code data} passed as a {@link Buffer} is not copied, so it can be modified + * between {@link Module} calls to avoid reallocation. Data retrieved from {@code Tensor} objects + * may be copied or may be a reference to the {@code Tensor}'s internal data buffer. {@code shape} + * is always copied. + */ +public abstract class Tensor { + private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null"; + private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null"; + private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null"; + private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Shape elements must be non negative"; + private static final String ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER = + "Data buffer must have native byte order (java.nio.ByteOrder#nativeOrder)"; + private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT = + "Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)"; + + @DoNotStrip final long[] shape; + + private static final int INT_SIZE_BYTES = 4; + private static final int FLOAT_SIZE_BYTES = 4; + private static final int LONG_SIZE_BYTES = 8; + private static final int DOUBLE_SIZE_BYTES = 8; + + /** + * Allocates a new direct {@link ByteBuffer} with native byte order with specified capacity that + * can be used in {@link Tensor#fromBlob(ByteBuffer, long[])}, {@link + * Tensor#fromBlobUnsigned(ByteBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static ByteBuffer allocateByteBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder()); + } + + /** + * Allocates a new direct {@link IntBuffer} with native byte order with specified capacity that + * can be used in {@link Tensor#fromBlob(IntBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static IntBuffer allocateIntBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asIntBuffer(); + } + + /** + * Allocates a new direct {@link FloatBuffer} with native byte order with specified capacity that + * can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static FloatBuffer allocateFloatBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * FLOAT_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer(); + } + + /** + * Allocates a new direct {@link LongBuffer} with native byte order with specified capacity that + * can be used in {@link Tensor#fromBlob(LongBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static LongBuffer allocateLongBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * LONG_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asLongBuffer(); + } + + /** + * Allocates a new direct {@link DoubleBuffer} with native byte order with specified capacity that + * can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static DoubleBuffer allocateDoubleBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * DOUBLE_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asDoubleBuffer(); + } + + /** + * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of + * bytes. + * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlobUnsigned(byte[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape)); + byteBuffer.put(data); + return new Tensor_uint8(byteBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of + * bytes. + * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlob(byte[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape)); + byteBuffer.put(data); + return new Tensor_int8(byteBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of + * ints. + * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlob(int[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final IntBuffer intBuffer = allocateIntBuffer((int) numel(shape)); + intBuffer.put(data); + return new Tensor_int32(intBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array + * of floats. + * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlob(float[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final FloatBuffer floatBuffer = allocateFloatBuffer((int) numel(shape)); + floatBuffer.put(data); + return new Tensor_float32(floatBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of + * longs. + * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlob(long[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final LongBuffer longBuffer = allocateLongBuffer((int) numel(shape)); + longBuffer.put(data); + return new Tensor_int64(longBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.float64 with specified shape and data as array + * of doubles. + * + * @param shape Tensor shape + * @param data Tensor elements + */ + public static Tensor fromBlob(double[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final DoubleBuffer doubleBuffer = allocateDoubleBuffer((int) numel(shape)); + doubleBuffer.put(data); + return new Tensor_float64(doubleBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_uint8(data, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int8 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlob(ByteBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_int8(data, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int32 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlob(IntBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_int32(data, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.float32 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlob(FloatBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_float32(data, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int64 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlob(LongBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_int64(data, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.float64 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlob(DoubleBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_float64(data, shape); + } + + @DoNotStrip private HybridData mHybridData; + + private Tensor(long[] shape) { + checkShape(shape); + this.shape = Arrays.copyOf(shape, shape.length); + } + + /** Returns the number of elements in this tensor. */ + public long numel() { + return numel(this.shape); + } + + /** Calculates the number of elements in a tensor with the specified shape. */ + public static long numel(long[] shape) { + checkShape(shape); + int result = 1; + for (long s : shape) { + result *= s; + } + return result; + } + + /** Returns the shape of this tensor. (The array is a fresh copy.) */ + public long[] shape() { + return Arrays.copyOf(shape, shape.length); + } + + /** @return data type of this tensor. */ + public abstract DType dtype(); + + // Called from native + @DoNotStrip + int dtypeJniCode() { + return dtype().jniCode; + } + + /** + * @return a Java byte array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-int8 tensor. + */ + public byte[] getDataAsByteArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array."); + } + + /** + * @return a Java byte array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-uint8 tensor. + */ + public byte[] getDataAsUnsignedByteArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array."); + } + + /** + * @return a Java int array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-int32 tensor. + */ + public int[] getDataAsIntArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as int array."); + } + + /** + * @return a Java float array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-float32 tensor. + */ + public float[] getDataAsFloatArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as float array."); + } + + /** + * @return a Java long array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-int64 tensor. + */ + public long[] getDataAsLongArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as long array."); + } + + /** + * @return a Java double array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-float64 tensor. + */ + public double[] getDataAsDoubleArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as double array."); + } + + @DoNotStrip + Buffer getRawDataBuffer() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer."); + } + + static class Tensor_uint8 extends Tensor { + private final ByteBuffer data; + + private Tensor_uint8(ByteBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public DType dtype() { + return DType.UINT8; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public byte[] getDataAsUnsignedByteArray() { + data.rewind(); + byte[] arr = new byte[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape)); + } + } + + static class Tensor_int8 extends Tensor { + private final ByteBuffer data; + + private Tensor_int8(ByteBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public DType dtype() { + return DType.INT8; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public byte[] getDataAsByteArray() { + data.rewind(); + byte[] arr = new byte[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape)); + } + } + + static class Tensor_int32 extends Tensor { + private final IntBuffer data; + + private Tensor_int32(IntBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public DType dtype() { + return DType.INT32; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public int[] getDataAsIntArray() { + data.rewind(); + int[] arr = new int[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape)); + } + } + + static class Tensor_float32 extends Tensor { + private final FloatBuffer data; + + Tensor_float32(FloatBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public float[] getDataAsFloatArray() { + data.rewind(); + float[] arr = new float[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public DType dtype() { + return DType.FLOAT; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.float32)", Arrays.toString(shape)); + } + } + + static class Tensor_int64 extends Tensor { + private final LongBuffer data; + + private Tensor_int64(LongBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public DType dtype() { + return DType.INT64; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public long[] getDataAsLongArray() { + data.rewind(); + long[] arr = new long[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape)); + } + } + + static class Tensor_float64 extends Tensor { + private final DoubleBuffer data; + + private Tensor_float64(DoubleBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public DType dtype() { + return DType.DOUBLE; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public double[] getDataAsDoubleArray() { + data.rewind(); + double[] arr = new double[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape)); + } + } + + // region checks + private static void checkArgument(boolean expression, String errorMessage, Object... args) { + if (!expression) { + throw new IllegalArgumentException(String.format(Locale.US, errorMessage, args)); + } + } + + private static void checkShape(long[] shape) { + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + for (int i = 0; i < shape.length; i++) { + checkArgument(shape[i] >= 0, ERROR_MSG_SHAPE_NON_NEGATIVE); + } + } + + private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] shape) { + final long numel = numel(shape); + checkArgument( + numel == dataCapacity, + "Inconsistent data capacity:%d and shape number elements:%d shape:%s", + dataCapacity, + numel, + Arrays.toString(shape)); + } + // endregion checks + + // Called from native + @DoNotStrip + private static Tensor nativeNewTensor( + ByteBuffer data, long[] shape, int dtype, HybridData hybridData) { + Tensor tensor = null; + + if (DType.FLOAT.jniCode == dtype) { + tensor = new Tensor_float32(data.asFloatBuffer(), shape); + } else if (DType.INT32.jniCode == dtype) { + tensor = new Tensor_int32(data.asIntBuffer(), shape); + } else if (DType.INT64.jniCode == dtype) { + tensor = new Tensor_int64(data.asLongBuffer(), shape); + } else if (DType.DOUBLE.jniCode == dtype) { + tensor = new Tensor_float64(data.asDoubleBuffer(), shape); + } else if (DType.UINT8.jniCode == dtype) { + tensor = new Tensor_uint8(data, shape); + } else if (DType.INT8.jniCode == dtype) { + tensor = new Tensor_int8(data, shape); + } else { + throw new IllegalArgumentException("Unknown Tensor dtype"); + } + tensor.mHybridData = hybridData; + return tensor; + } +} diff --git a/extension/android/src/main/java/org/pytorch/executorch/TensorImageUtils.java b/extension/android/src/main/java/org/pytorch/executorch/TensorImageUtils.java new file mode 100644 index 00000000000..d86f3ff7b14 --- /dev/null +++ b/extension/android/src/main/java/org/pytorch/executorch/TensorImageUtils.java @@ -0,0 +1,149 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import android.graphics.Bitmap; +import android.util.Log; +import java.nio.FloatBuffer; + +/** + * Contains utility functions for {@link Tensor} creation from {@link android.graphics.Bitmap} or + * {@link android.media.Image} source. + */ +public final class TensorImageUtils { + + public static float[] TORCHVISION_NORM_MEAN_RGB = new float[] {0.485f, 0.456f, 0.406f}; + public static float[] TORCHVISION_NORM_STD_RGB = new float[] {0.229f, 0.224f, 0.225f}; + + /** + * Creates new {@link Tensor} from full {@link android.graphics.Bitmap}, normalized with specified + * in parameters mean and std. + * + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + public static Tensor bitmapToFloat32Tensor( + final Bitmap bitmap, final float[] normMeanRGB, final float normStdRGB[]) { + checkNormMeanArg(normMeanRGB); + checkNormStdArg(normStdRGB); + + return bitmapToFloat32Tensor( + bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB); + } + + /** + * Writes tensor content from specified {@link android.graphics.Bitmap}, normalized with specified + * in parameters mean and std to specified {@link java.nio.FloatBuffer} with specified offset. + * + * @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data + * @param x - x coordinate of top left corner of bitmap's area + * @param y - y coordinate of top left corner of bitmap's area + * @param width - width of bitmap's area + * @param height - height of bitmap's area + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + public static void bitmapToFloatBuffer( + final Bitmap bitmap, + final int x, + final int y, + final int width, + final int height, + final float[] normMeanRGB, + final float[] normStdRGB, + final FloatBuffer outBuffer, + final int outBufferOffset) { + checkOutBufferCapacity(outBuffer, outBufferOffset, width, height); + checkNormMeanArg(normMeanRGB); + checkNormStdArg(normStdRGB); + final int pixelsCount = height * width; + final int[] pixels = new int[pixelsCount]; + bitmap.getPixels(pixels, 0, width, x, y, width, height); + final int offset_g = pixelsCount; + final int offset_b = 2 * pixelsCount; + for (int i = 0; i < 100; i++) { + final int c = pixels[i]; + Log.i("Image", ": " + i + " " + ((c >> 16) & 0xff)); + } + for (int i = 0; i < pixelsCount; i++) { + final int c = pixels[i]; + float r = ((c >> 16) & 0xff) / 255.0f; + float g = ((c >> 8) & 0xff) / 255.0f; + float b = ((c) & 0xff) / 255.0f; + outBuffer.put(outBufferOffset + i, (r - normMeanRGB[0]) / normStdRGB[0]); + outBuffer.put(outBufferOffset + offset_g + i, (g - normMeanRGB[1]) / normStdRGB[1]); + outBuffer.put(outBufferOffset + offset_b + i, (b - normMeanRGB[2]) / normStdRGB[2]); + } + } + + /** + * Creates new {@link Tensor} from specified area of {@link android.graphics.Bitmap}, normalized + * with specified in parameters mean and std. + * + * @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data + * @param x - x coordinate of top left corner of bitmap's area + * @param y - y coordinate of top left corner of bitmap's area + * @param width - width of bitmap's area + * @param height - height of bitmap's area + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + public static Tensor bitmapToFloat32Tensor( + final Bitmap bitmap, + int x, + int y, + int width, + int height, + float[] normMeanRGB, + float[] normStdRGB) { + checkNormMeanArg(normMeanRGB); + checkNormStdArg(normStdRGB); + + final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * width * height); + bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0); + return Tensor.fromBlob(floatBuffer, new long[] {1, 3, height, width}); + } + + private static void checkOutBufferCapacity( + FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) { + if (outBufferOffset + 3 * tensorWidth * tensorHeight > outBuffer.capacity()) { + throw new IllegalStateException("Buffer underflow"); + } + } + + private static void checkTensorSize(int tensorWidth, int tensorHeight) { + if (tensorHeight <= 0 || tensorWidth <= 0) { + throw new IllegalArgumentException("tensorHeight and tensorWidth must be positive"); + } + } + + private static void checkRotateCWDegrees(int rotateCWDegrees) { + if (rotateCWDegrees != 0 + && rotateCWDegrees != 90 + && rotateCWDegrees != 180 + && rotateCWDegrees != 270) { + throw new IllegalArgumentException("rotateCWDegrees must be one of 0, 90, 180, 270"); + } + } + + private static void checkNormStdArg(float[] normStdRGB) { + if (normStdRGB.length != 3) { + throw new IllegalArgumentException("normStdRGB length must be 3"); + } + } + + private static void checkNormMeanArg(float[] normMeanRGB) { + if (normMeanRGB.length != 3) { + throw new IllegalArgumentException("normMeanRGB length must be 3"); + } + } +}