Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ import org.apache.commons.io.FileUtils
import org.json.JSONException
import org.json.JSONObject
import org.junit.Assert
import org.junit.Assert.assertEquals
import org.junit.Assert.assertThat
import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.pytorch.executorch.TestFileUtils.getTestFilePath
import org.pytorch.executorch.extension.llm.LlmCallback
import org.pytorch.executorch.extension.llm.LlmModule

Expand All @@ -30,7 +34,7 @@ import org.pytorch.executorch.extension.llm.LlmModule
class LlmModuleInstrumentationTest : LlmCallback {
private val results: MutableList<String> = ArrayList()
private val tokensPerSecond: MutableList<Float> = ArrayList()
private var llmModule: LlmModule? = null
private lateinit var llmModule: LlmModule

@Before
@Throws(IOException::class)
Expand All @@ -57,25 +61,25 @@ class LlmModuleInstrumentationTest : LlmCallback {
@Test
@Throws(IOException::class, URISyntaxException::class)
fun testGenerate() {
val loadResult = llmModule!!.load()
val loadResult = llmModule.load()
// Check that the model can be load successfully
Assert.assertEquals(OK.toLong(), loadResult.toLong())
assertEquals(OK.toLong(), loadResult.toLong())

llmModule!!.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
Assert.assertEquals(results.size.toLong(), SEQ_LEN.toLong())
Assert.assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0)
llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
assertEquals(results.size.toLong(), SEQ_LEN.toLong())
assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0)
}

@Test
@Throws(IOException::class, URISyntaxException::class)
fun testGenerateAndStop() {
llmModule!!.generate(
llmModule.generate(
TEST_PROMPT,
SEQ_LEN,
object : LlmCallback {
override fun onResult(result: String) {
this@LlmModuleInstrumentationTest.onResult(result)
llmModule!!.stop()
llmModule.stop()
}

override fun onStats(stats: String) {
Expand All @@ -85,7 +89,7 @@ class LlmModuleInstrumentationTest : LlmCallback {
)

val stoppedResultSize = results.size
Assert.assertTrue(stoppedResultSize < SEQ_LEN)
assertTrue(stoppedResultSize < SEQ_LEN)
}

override fun onResult(result: String) {
Expand All @@ -101,7 +105,8 @@ class LlmModuleInstrumentationTest : LlmCallback {
val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms")
tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000
tokensPerSecond.add(tps)
} catch (_: JSONException) {}
} catch (_: JSONException) {
}
}

companion object {
Expand All @@ -110,12 +115,5 @@ class LlmModuleInstrumentationTest : LlmCallback {
private const val TEST_PROMPT = "Hello"
private const val OK = 0x00
private const val SEQ_LEN = 32

private fun getTestFilePath(fileName: String): String {
return InstrumentationRegistry.getInstrumentation()
.targetContext
.externalCacheDir
.toString() + fileName
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.pytorch.executorch.TensorImageUtils.bitmapToFloat32Tensor
import org.pytorch.executorch.TestFileUtils.getTestFilePath

/** Unit tests for [Module]. */
@RunWith(AndroidJUnit4::class)
Expand Down Expand Up @@ -90,12 +91,6 @@ class ModuleE2ETest {
}

companion object {
private fun getTestFilePath(fileName: String): String {
return InstrumentationRegistry.getInstrumentation()
.targetContext
.externalCacheDir
.toString() + fileName
}

fun argmax(array: FloatArray): Int {
require(array.isNotEmpty()) { "Array cannot be empty" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.pytorch.executorch.TestFileUtils.getTestFilePath

/** Unit tests for [Module]. */
@RunWith(AndroidJUnit4::class)
Expand Down Expand Up @@ -173,12 +174,5 @@ class ModuleInstrumentationTest {
private const val INVALID_STATE = 0x2
private const val INVALID_ARGUMENT = 0x12
private const val ACCESS_FAILED = 0x22

private fun getTestFilePath(fileName: String): String {
return InstrumentationRegistry.getInstrumentation()
.targetContext
.externalCacheDir
.toString() + fileName
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.pytorch.executorch

import androidx.test.InstrumentationRegistry

/**
* Test File Utils
*/
object TestFileUtils {

fun getTestFilePath(fileName: String): String {
return InstrumentationRegistry.getInstrumentation()
.targetContext
.externalCacheDir
.toString() + fileName
}
}
Loading