Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import java.nio.channels.FileChannel

class FileEmbeddingStore(
private val file: File,
private val embeddingLength: Int,
private val embeddingDimension: Int,
val useCache: Boolean = true,
):
IEmbeddingStore {
Expand Down Expand Up @@ -53,12 +53,12 @@ class FileEmbeddingStore(
val batch = embeddingsList.subList(index, end)

// Allocate a smaller buffer for this batch
val batchBuffer = ByteBuffer.allocate(batch.size * (8 + 8 + embeddingLength * 4))
val batchBuffer = ByteBuffer.allocate(batch.size * (8 + 8 + embeddingDimension * 4))
.order(ByteOrder.LITTLE_ENDIAN)

for (embedding in batch) {
if (embedding.embeddings.size != embeddingLength) {
throw IllegalArgumentException("Embedding length must be $embeddingLength")
if (embedding.embeddings.size != embeddingDimension) {
throw IllegalArgumentException("Embedding dimension must be $embeddingDimension")
}
batchBuffer.putLong(embedding.id)
batchBuffer.putLong(embedding.date)
Expand Down Expand Up @@ -88,10 +88,10 @@ class FileEmbeddingStore(
repeat(count) {
val id = buffer.long
val date = buffer.long
val floats = FloatArray(embeddingLength)
val floats = FloatArray(embeddingDimension)
val fb = buffer.asFloatBuffer()
fb.get(floats)
buffer.position(buffer.position() + embeddingLength * 4)
buffer.position(buffer.position() + embeddingDimension * 4)
map[id] = Embedding(id, date, floats)
}
if (useCache) cache = map
Expand Down Expand Up @@ -121,7 +121,7 @@ class FileEmbeddingStore(
val existingCount = headerBuf.int

// Basic validation: each existing entry is at least id(8)+date(8)+EMBEDDING_LEN*4
val minEntryBytes = 8 + 8 + embeddingLength * 4
val minEntryBytes = 8 + 8 + embeddingDimension * 4
val maxCountFromSize = (channel.size() / minEntryBytes).toInt()
if (existingCount < 0 || existingCount > maxCountFromSize + 10_000) {
throw IOException("Corrupt embeddings header: count=$existingCount, fileSize=${channel.size()}")
Expand All @@ -139,10 +139,10 @@ class FileEmbeddingStore(
channel.position(channel.size())

for (embedding in newEmbeddings) {
if (embedding.embeddings.size != embeddingLength) {
throw IllegalArgumentException("Embedding length must be $embeddingLength")
if (embedding.embeddings.size != embeddingDimension) {
throw IllegalArgumentException("Embedding dimension must be $embeddingDimension")
}
val entryBytes = (8 + 8) + embeddingLength * 4
val entryBytes = (8 + 8) + embeddingDimension * 4
val buf = ByteBuffer.allocate(entryBytes).order(ByteOrder.LITTLE_ENDIAN)
buf.putLong(embedding.id)
buf.putLong(embedding.date)
Expand Down