diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index ba91f444287..49f5981314d 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -17,7 +17,6 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Before -import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TestFileUtils.getTestFilePath @@ -40,15 +39,12 @@ class ModuleInstrumentationTest { inputStream.close() } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(IOException::class, URISyntaxException::class) fun testModuleLoadAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - val results = module.forward() + val results = module.forward(EValue.from(dummyInput())) Assert.assertTrue(results[0].isTensor) } @@ -58,9 +54,6 @@ class ModuleInstrumentationTest { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(IOException::class) fun testModuleLoadMethodAndForward() { @@ -68,19 +61,16 @@ class ModuleInstrumentationTest { module.loadMethod(FORWARD_METHOD) - val results = module.forward() + val results = module.forward(EValue.from(dummyInput())) Assert.assertTrue(results[0].isTensor) } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(IOException::class) fun testModuleLoadForwardExplicit() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - val results = module.execute(FORWARD_METHOD) + val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput())) Assert.assertTrue(results[0].isTensor) } @@ -135,9 +125,6 @@ class ModuleInstrumentationTest { Assert.assertThrows(IllegalStateException::class.java) { module.forward() } } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(InterruptedException::class, IOException::class) fun testForwardFromMultipleThreads() { @@ -151,7 +138,7 @@ class ModuleInstrumentationTest { try { latch.countDown() latch.await(5000, TimeUnit.MILLISECONDS) - val results = module.forward() + val results = module.forward(EValue.from(dummyInput())) Assert.assertTrue(results[0].isTensor) completed.incrementAndGet() } catch (_: InterruptedException) {} @@ -176,5 +163,8 @@ class ModuleInstrumentationTest { private const val NON_PTE_FILE_NAME = "/test.txt" private const val FORWARD_METHOD = "forward" private const val NONE_METHOD = "none" + private val inputShape = longArrayOf(1, 3, 224, 224) + + private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT) } }