diff --git a/extension/android/src/main/java/org/pytorch/executorch/EValue.java b/extension/android/src/main/java/org/pytorch/executorch/EValue.java index f133eb4ad60..599818a00d7 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/EValue.java +++ b/extension/android/src/main/java/org/pytorch/executorch/EValue.java @@ -12,7 +12,6 @@ import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Locale; -import java.util.Optional; import org.pytorch.executorch.annotations.Experimental; /** @@ -44,26 +43,8 @@ public class EValue { private static final int TYPE_CODE_INT = 4; private static final int TYPE_CODE_BOOL = 5; - private static final int TYPE_CODE_LIST_BOOL = 6; - private static final int TYPE_CODE_LIST_DOUBLE = 7; - private static final int TYPE_CODE_LIST_INT = 8; - private static final int TYPE_CODE_LIST_TENSOR = 9; - private static final int TYPE_CODE_LIST_SCALAR = 10; - private static final int TYPE_CODE_LIST_OPTIONAL_TENSOR = 11; - private String[] TYPE_NAMES = { - "None", - "Tensor", - "String", - "Double", - "Int", - "Bool", - "ListBool", - "ListDouble", - "ListInt", - "ListTensor", - "ListScalar", - "ListOptionalTensor", + "None", "Tensor", "String", "Double", "Int", "Bool", }; @DoNotStrip private final int mTypeCode; @@ -104,31 +85,6 @@ public boolean isString() { return TYPE_CODE_STRING == this.mTypeCode; } - @DoNotStrip - public boolean isBoolList() { - return TYPE_CODE_LIST_BOOL == this.mTypeCode; - } - - @DoNotStrip - public boolean isIntList() { - return TYPE_CODE_LIST_INT == this.mTypeCode; - } - - @DoNotStrip - public boolean isDoubleList() { - return TYPE_CODE_LIST_DOUBLE == this.mTypeCode; - } - - @DoNotStrip - public boolean isTensorList() { - return TYPE_CODE_LIST_TENSOR == this.mTypeCode; - } - - @DoNotStrip - public boolean isOptionalTensorList() { - return TYPE_CODE_LIST_OPTIONAL_TENSOR == this.mTypeCode; - } - /** Creates a new {@code EValue} of type {@code Optional} that contains no value. */ @DoNotStrip public static EValue optionalNone() { @@ -175,46 +131,6 @@ public static EValue from(String value) { return iv; } - /** Creates a new {@code EValue} of type {@code List[bool]}. */ - @DoNotStrip - public static EValue listFrom(boolean... list) { - final EValue iv = new EValue(TYPE_CODE_LIST_BOOL); - iv.mData = list; - return iv; - } - - /** Creates a new {@code EValue} of type {@code List[int]}. */ - @DoNotStrip - public static EValue listFrom(long... list) { - final EValue iv = new EValue(TYPE_CODE_LIST_INT); - iv.mData = list; - return iv; - } - - /** Creates a new {@code EValue} of type {@code List[double]}. */ - @DoNotStrip - public static EValue listFrom(double... list) { - final EValue iv = new EValue(TYPE_CODE_LIST_DOUBLE); - iv.mData = list; - return iv; - } - - /** Creates a new {@code EValue} of type {@code List[Tensor]}. */ - @DoNotStrip - public static EValue listFrom(Tensor... list) { - final EValue iv = new EValue(TYPE_CODE_LIST_TENSOR); - iv.mData = list; - return iv; - } - - /** Creates a new {@code EValue} of type {@code List[Optional[Tensor]]}. */ - @DoNotStrip - public static EValue listFrom(Optional... list) { - final EValue iv = new EValue(TYPE_CODE_LIST_OPTIONAL_TENSOR); - iv.mData = list; - return iv; - } - @DoNotStrip public Tensor toTensor() { preconditionType(TYPE_CODE_TENSOR, mTypeCode); @@ -245,36 +161,6 @@ public String toStr() { return (String) mData; } - @DoNotStrip - public boolean[] toBoolList() { - preconditionType(TYPE_CODE_LIST_BOOL, mTypeCode); - return (boolean[]) mData; - } - - @DoNotStrip - public long[] toIntList() { - preconditionType(TYPE_CODE_LIST_INT, mTypeCode); - return (long[]) mData; - } - - @DoNotStrip - public double[] toDoubleList() { - preconditionType(TYPE_CODE_LIST_DOUBLE, mTypeCode); - return (double[]) mData; - } - - @DoNotStrip - public Tensor[] toTensorList() { - preconditionType(TYPE_CODE_LIST_TENSOR, mTypeCode); - return (Tensor[]) mData; - } - - @DoNotStrip - public Optional[] toOptionalTensorList() { - preconditionType(TYPE_CODE_LIST_OPTIONAL_TENSOR, mTypeCode); - return (Optional[]) mData; - } - private void preconditionType(int typeCodeExpected, int typeCode) { if (typeCode != typeCodeExpected) { throw new IllegalStateException( @@ -294,8 +180,7 @@ private String getTypeName(int typeCode) { * Serializes an {@code EValue} into a byte array. * * @return The serialized byte array. - * @apiNote This method is experimental and subject to change without notice. This does NOT - * supoprt list type. + * @apiNote This method is experimental and subject to change without notice. */ public byte[] toByteArray() { if (isNone()) { @@ -331,8 +216,7 @@ public byte[] toByteArray() { * * @param bytes The byte array to deserialize from. * @return The deserialized {@code EValue}. - * @apiNote This method is experimental and subject to change without notice. This does NOT list - * type. + * @apiNote This method is experimental and subject to change without notice. */ public static EValue fromByteArray(byte[] bytes) { ByteBuffer buffer = ByteBuffer.wrap(bytes); diff --git a/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java b/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java index 9856329da78..cbeb3a7b634 100644 --- a/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java +++ b/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java @@ -14,7 +14,6 @@ import static org.junit.Assert.fail; import java.util.Arrays; -import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -66,70 +65,6 @@ public void testStringValue() { assertEquals(evalue.toStr(), "a"); } - @Test - public void testBoolListValue() { - boolean[] value = {true, false, true}; - EValue evalue = EValue.listFrom(value); - assertTrue(evalue.isBoolList()); - assertTrue(Arrays.equals(value, evalue.toBoolList())); - } - - @Test - public void testIntListValue() { - long[] value = {Long.MIN_VALUE, 0, Long.MAX_VALUE}; - EValue evalue = EValue.listFrom(value); - assertTrue(evalue.isIntList()); - assertTrue(Arrays.equals(value, evalue.toIntList())); - } - - @Test - public void testDoubleListValue() { - double[] value = {Double.MIN_VALUE, 0.1d, 0.01d, 0.001d, Double.MAX_VALUE}; - EValue evalue = EValue.listFrom(value); - assertTrue(evalue.isDoubleList()); - assertTrue(Arrays.equals(value, evalue.toDoubleList())); - } - - @Test - public void testTensorListValue() { - long[][] data = {{1, 2, 3}, {1, 2, 3, 4, 5, 6}}; - long[][] shape = {{1, 3}, {2, 3}}; - Tensor[] tensors = {Tensor.fromBlob(data[0], shape[0]), Tensor.fromBlob(data[1], shape[1])}; - - EValue evalue = EValue.listFrom(tensors); - assertTrue(evalue.isTensorList()); - - assertTrue(Arrays.equals(evalue.toTensorList()[0].shape, shape[0])); - assertTrue(Arrays.equals(evalue.toTensorList()[0].getDataAsLongArray(), data[0])); - - assertTrue(Arrays.equals(evalue.toTensorList()[1].shape, shape[1])); - assertTrue(Arrays.equals(evalue.toTensorList()[1].getDataAsLongArray(), data[1])); - } - - @Test - @SuppressWarnings("unchecked") - public void testOptionalTensorListValue() { - long[][] data = {{1, 2, 3}, {1, 2, 3, 4, 5, 6}}; - long[][] shape = {{1, 3}, {2, 3}}; - - EValue evalue = - EValue.listFrom( - Optional.empty(), - Optional.of(Tensor.fromBlob(data[0], shape[0])), - Optional.of(Tensor.fromBlob(data[1], shape[1]))); - assertTrue(evalue.isOptionalTensorList()); - - assertTrue(!evalue.toOptionalTensorList()[0].isPresent()); - - assertTrue(evalue.toOptionalTensorList()[1].isPresent()); - assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().shape, shape[0])); - assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().getDataAsLongArray(), data[0])); - - assertTrue(evalue.toOptionalTensorList()[2].isPresent()); - assertTrue(Arrays.equals(evalue.toOptionalTensorList()[2].get().shape, shape[1])); - assertTrue(Arrays.equals(evalue.toOptionalTensorList()[2].get().getDataAsLongArray(), data[1])); - } - @Test public void testAllIllegalCast() { EValue evalue = EValue.optionalNone(); @@ -174,46 +109,6 @@ public void testAllIllegalCast() { fail("Should have thrown an exception"); } catch (IllegalStateException e) { } - - // try bool list - assertFalse(evalue.isBoolList()); - try { - evalue.toBoolList(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - } - - // try int list - assertFalse(evalue.isIntList()); - try { - evalue.toIntList(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - } - - // try double list - assertFalse(evalue.isDoubleList()); - try { - evalue.toBool(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - } - - // try Tensor list - assertFalse(evalue.isTensorList()); - try { - evalue.toTensorList(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - } - - // try optional Tensor list - assertFalse(evalue.isOptionalTensorList()); - try { - evalue.toOptionalTensorList(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - } } @Test