From 69eb83dd7a4671f8dafd63b6c51abbd011d2f7b4 Mon Sep 17 00:00:00 2001 From: Haiting Pu Date: Tue, 27 May 2025 12:41:22 -0700 Subject: [PATCH] Use lateinit var to remove !! and created a common TestFileUtils to share the same code for getTestFilePath --- .../LlmModuleInstrumentationTest.kt | 32 +++++++++---------- .../org/pytorch/executorch/ModuleE2ETest.kt | 7 +--- .../executorch/ModuleInstrumentationTest.kt | 8 +---- .../org/pytorch/executorch/TestFileUtils.kt | 16 ++++++++++ 4 files changed, 33 insertions(+), 30 deletions(-) create mode 100644 extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TestFileUtils.kt diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt index 43ce302a7a6..2df45f14985 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt @@ -18,10 +18,14 @@ import org.apache.commons.io.FileUtils import org.json.JSONException import org.json.JSONObject import org.junit.Assert +import org.junit.Assert.assertEquals +import org.junit.Assert.assertThat +import org.junit.Assert.assertTrue import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith +import org.pytorch.executorch.TestFileUtils.getTestFilePath import org.pytorch.executorch.extension.llm.LlmCallback import org.pytorch.executorch.extension.llm.LlmModule @@ -30,7 +34,7 @@ import org.pytorch.executorch.extension.llm.LlmModule class LlmModuleInstrumentationTest : LlmCallback { private val results: MutableList = ArrayList() private val tokensPerSecond: MutableList = ArrayList() - private var llmModule: LlmModule? = null + private lateinit var llmModule: LlmModule @Before @Throws(IOException::class) @@ -57,25 +61,25 @@ class LlmModuleInstrumentationTest : LlmCallback { @Test @Throws(IOException::class, URISyntaxException::class) fun testGenerate() { - val loadResult = llmModule!!.load() + val loadResult = llmModule.load() // Check that the model can be load successfully - Assert.assertEquals(OK.toLong(), loadResult.toLong()) + assertEquals(OK.toLong(), loadResult.toLong()) - llmModule!!.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) - Assert.assertEquals(results.size.toLong(), SEQ_LEN.toLong()) - Assert.assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0) + llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) + assertEquals(results.size.toLong(), SEQ_LEN.toLong()) + assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0) } @Test @Throws(IOException::class, URISyntaxException::class) fun testGenerateAndStop() { - llmModule!!.generate( + llmModule.generate( TEST_PROMPT, SEQ_LEN, object : LlmCallback { override fun onResult(result: String) { this@LlmModuleInstrumentationTest.onResult(result) - llmModule!!.stop() + llmModule.stop() } override fun onStats(stats: String) { @@ -85,7 +89,7 @@ class LlmModuleInstrumentationTest : LlmCallback { ) val stoppedResultSize = results.size - Assert.assertTrue(stoppedResultSize < SEQ_LEN) + assertTrue(stoppedResultSize < SEQ_LEN) } override fun onResult(result: String) { @@ -101,7 +105,8 @@ class LlmModuleInstrumentationTest : LlmCallback { val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms") tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000 tokensPerSecond.add(tps) - } catch (_: JSONException) {} + } catch (_: JSONException) { + } } companion object { @@ -110,12 +115,5 @@ class LlmModuleInstrumentationTest : LlmCallback { private const val TEST_PROMPT = "Hello" private const val OK = 0x00 private const val SEQ_LEN = 32 - - private fun getTestFilePath(fileName: String): String { - return InstrumentationRegistry.getInstrumentation() - .targetContext - .externalCacheDir - .toString() + fileName - } } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt index 2a1e9d4c8ff..e269f4aa38f 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt @@ -22,6 +22,7 @@ import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TensorImageUtils.bitmapToFloat32Tensor +import org.pytorch.executorch.TestFileUtils.getTestFilePath /** Unit tests for [Module]. */ @RunWith(AndroidJUnit4::class) @@ -90,12 +91,6 @@ class ModuleE2ETest { } companion object { - private fun getTestFilePath(fileName: String): String { - return InstrumentationRegistry.getInstrumentation() - .targetContext - .externalCacheDir - .toString() + fileName - } fun argmax(array: FloatArray): Int { require(array.isNotEmpty()) { "Array cannot be empty" } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index 1885660d0a1..58e9cc8bfef 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -23,6 +23,7 @@ import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith +import org.pytorch.executorch.TestFileUtils.getTestFilePath /** Unit tests for [Module]. */ @RunWith(AndroidJUnit4::class) @@ -173,12 +174,5 @@ class ModuleInstrumentationTest { private const val INVALID_STATE = 0x2 private const val INVALID_ARGUMENT = 0x12 private const val ACCESS_FAILED = 0x22 - - private fun getTestFilePath(fileName: String): String { - return InstrumentationRegistry.getInstrumentation() - .targetContext - .externalCacheDir - .toString() + fileName - } } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TestFileUtils.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TestFileUtils.kt new file mode 100644 index 00000000000..efa364f8e94 --- /dev/null +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TestFileUtils.kt @@ -0,0 +1,16 @@ +package org.pytorch.executorch + +import androidx.test.InstrumentationRegistry + +/** + * Test File Utils + */ +object TestFileUtils { + + fun getTestFilePath(fileName: String): String { + return InstrumentationRegistry.getInstrumentation() + .targetContext + .externalCacheDir + .toString() + fileName + } +}