diff --git a/ml/src/androidTest/kotlin/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/ClipTextEmbedderTest.kt b/ml/src/androidTest/kotlin/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/ClipTextEmbedderTest.kt index cc11437..f5f1816 100644 --- a/ml/src/androidTest/kotlin/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/ClipTextEmbedderTest.kt +++ b/ml/src/androidTest/kotlin/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/ClipTextEmbedderTest.kt @@ -115,4 +115,28 @@ class ClipTextEmbedderInstrumentedTest { verify(exactly = 1) { (mockModel as AutoCloseable).close() } } + + @Test + fun `embed handles strings longer than 77 tokens`() = runBlocking { + val embedder = ClipTextEmbedder(context, ResourceId(0)) + val mockModel = mockk(relaxed = true) + every { mockModel.isLoaded() } returns true + every { mockModel.getInputNames() } returns listOf("input") + every { mockModel.getEnv() } returns mockk() + + val raw = Array(1) { FloatArray(embedder.embeddingDim) { 1.0f } } + every { mockModel.run(any>()) } returns mapOf("out" to raw) + + val field = embedder::class.java.getDeclaredField("model") + field.isAccessible = true + field.set(embedder, mockModel) + + val longText = "a".repeat(2000) + val embedding = embedder.embed(longText) + + assertEquals(embedder.embeddingDim, embedding.size) + val l2 = sqrt(embedding.map { it * it }.sum()) + assertTrue(abs(l2 - 1.0f) < 1e-3) + } + + } diff --git a/ml/src/main/java/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/ClipTextEmbedder.kt b/ml/src/main/java/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/ClipTextEmbedder.kt index b1ac8de..3a2b154 100644 --- a/ml/src/main/java/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/ClipTextEmbedder.kt +++ b/ml/src/main/java/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/ClipTextEmbedder.kt @@ -50,8 +50,9 @@ class ClipTextEmbedder( if (!isInitialized()) throw IllegalStateException("Model not initialized") val clean = 
Regex("[^A-Za-z0-9 ]").replace(data, "").lowercase() - var tokens = mutableListOf(tokenBOS) + tokenizer.encode(clean) + tokenEOS - tokens = tokens.take(77) + List(77 - tokens.size) { 0 } + /* Cap only the encoded text at 75 ids so tokenBOS and tokenEOS both survive truncation; taking the first 77 ids of the full sequence would cut tokenEOS off long inputs. */ val unpadded = listOf(tokenBOS) + tokenizer.encode(clean).take(77 - 2) + tokenEOS + /* Zero-fill up to the 77 ids the input buffer below is allocated for; unpadded never exceeds 77 ids, so the fill count is never negative. */ val tokens = unpadded + List(77 - unpadded.size) { 0 } + val inputIds = LongBuffer.allocate(1 * 77).apply { tokens.forEach { put(it.toLong()) }