Skip to content

Commit

Permalink
Vector store allowing to change embeddings models and similarity strategy (#686)
Browse files Browse the repository at this point in the history

* added parameters

* Apply spotless formatting

* added missing parameter

---------

Co-authored-by: Montagon <Montagon@users.noreply.github.com>
  • Loading branch information
Montagon and Montagon committed Mar 15, 2024
1 parent 7ce94b6 commit 301bd56
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.llm

import ai.xef.openai.OpenAIModel
import ai.xef.openai.StandardModel
import arrow.fx.coroutines.parMap
import com.xebia.functional.openai.apis.EmbeddingsApi
import com.xebia.functional.openai.models.CreateEmbeddingRequest
Expand All @@ -9,7 +11,9 @@ import com.xebia.functional.openai.models.ext.embedding.create.CreateEmbeddingRe

suspend fun EmbeddingsApi.embedDocuments(
texts: List<String>,
chunkSize: Int = 400
chunkSize: Int = 400,
embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel> =
StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002)
): List<Embedding> =
if (texts.isEmpty()) emptyList()
else
Expand All @@ -18,8 +22,7 @@ suspend fun EmbeddingsApi.embedDocuments(
.parMap {
createEmbedding(
CreateEmbeddingRequest(
model =
ai.xef.openai.StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002),
model = embeddingRequestModel,
input = CreateEmbeddingRequestInput.StringArrayValue(it)
)
)
Expand All @@ -28,5 +31,10 @@ suspend fun EmbeddingsApi.embedDocuments(
}
.flatten()

suspend fun EmbeddingsApi.embedQuery(text: String): List<Embedding> =
if (text.isNotEmpty()) embedDocuments(listOf(text)) else emptyList()
/**
 * Embeds a single query string using the given embedding model.
 *
 * Returns an empty list when [text] is empty; otherwise delegates to [embedDocuments] with a
 * one-element list and returns its result.
 *
 * @param text the query to embed.
 * @param embeddingRequestModel model placed in the embedding request. Defaults to
 *   `text_embedding_ada_002`, the same default [embedDocuments] declares, so callers that pass
 *   only [text] keep working.
 */
suspend fun EmbeddingsApi.embedQuery(
  text: String,
  embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel> =
    StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002)
): List<Embedding> =
  if (text.isEmpty()) emptyList()
  else embedDocuments(texts = listOf(text), embeddingRequestModel = embeddingRequestModel)
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package com.xebia.functional.xef.store

import ai.xef.openai.OpenAIModel
import ai.xef.openai.StandardModel
import arrow.atomic.Atomic
import arrow.atomic.AtomicInt
import arrow.atomic.getAndUpdate
import arrow.atomic.update
import com.xebia.functional.openai.apis.EmbeddingsApi
import com.xebia.functional.openai.models.CreateEmbeddingRequestModel
import com.xebia.functional.openai.models.Embedding
import com.xebia.functional.xef.llm.embedDocuments
import com.xebia.functional.xef.llm.embedQuery
Expand All @@ -25,9 +27,16 @@ private data class State(
private typealias AtomicState = Atomic<State>

class LocalVectorStore
private constructor(private val embeddings: EmbeddingsApi, private val state: AtomicState) :
VectorStore {
constructor(embeddings: EmbeddingsApi) : this(embeddings, Atomic(State.empty()))
private constructor(
private val embeddings: EmbeddingsApi,
private val state: AtomicState,
private val embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel>
) : VectorStore {
constructor(
embeddings: EmbeddingsApi,
embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel> =
StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002)
) : this(embeddings, Atomic(State.empty()), embeddingRequestModel)

override val indexValue: AtomicInt = AtomicInt(0)

Expand Down Expand Up @@ -68,15 +77,17 @@ private constructor(private val embeddings: EmbeddingsApi, private val state: At
}

override suspend fun addTexts(texts: List<String>) {
val embeddingsList = embeddings.embedDocuments(texts)
val embeddingsList =
embeddings.embedDocuments(texts, embeddingRequestModel = embeddingRequestModel)
state.getAndUpdate { prevState ->
val newEmbeddings = prevState.precomputedEmbeddings + texts.zip(embeddingsList)
State(prevState.orderedMemories, prevState.documents + texts, newEmbeddings)
}
}

override suspend fun similaritySearch(query: String, limit: Int): List<String> {
val queryEmbedding = embeddings.embedQuery(query).firstOrNull()
val queryEmbedding =
embeddings.embedQuery(query, embeddingRequestModel = embeddingRequestModel).firstOrNull()
return queryEmbedding?.let { similaritySearchByVector(it, limit) }.orEmpty()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import ai.xef.openai.OpenAIModel
import arrow.atomic.AtomicInt
import com.xebia.functional.openai.apis.EmbeddingsApi
import com.xebia.functional.openai.models.ChatCompletionRole
import com.xebia.functional.openai.models.CreateEmbeddingRequestModel
import com.xebia.functional.openai.models.Embedding
import com.xebia.functional.xef.llm.embedQuery
import com.xebia.functional.xef.llm.models.modelType
Expand All @@ -24,6 +25,7 @@ open class Lucene(
private val writer: IndexWriter,
private val embeddings: EmbeddingsApi?,
private val similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN,
private val embeddingAIModel: OpenAIModel<CreateEmbeddingRequestModel>
) : VectorStore, AutoCloseable {

override val indexValue: AtomicInt = AtomicInt(0)
Expand All @@ -47,12 +49,13 @@ open class Lucene(
}

override suspend fun <T> memories(
model: OpenAIModel<T>, conversationId: ConversationId, limitTokens: Int): List<Memory> =
model: OpenAIModel<T>, conversationId: ConversationId, limitTokens: Int
): List<Memory> =
getMemoryByConversationId(conversationId).reduceByLimitToken(model.modelType(), limitTokens).reversed()

override suspend fun addTexts(texts: List<String>) {
texts.forEach {
val embedding = embeddings?.embedQuery(it)
val embedding = embeddings?.embedQuery(text = it, embeddingRequestModel = embeddingAIModel)
val doc =
Document().apply {
add(TextField("contents", it, Field.Store.YES))
Expand Down Expand Up @@ -125,8 +128,9 @@ class DirectoryLucene(
private val directory: Directory,
writerConfig: IndexWriterConfig = IndexWriterConfig(),
embeddings: EmbeddingsApi?,
embeddingAIModel: OpenAIModel<CreateEmbeddingRequestModel>,
similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN
) : Lucene(IndexWriter(directory, writerConfig), embeddings, similarity) {
) : Lucene(IndexWriter(directory, writerConfig), embeddings, similarity, embeddingAIModel) {
override fun close() {
super.close()
directory.close()
Expand All @@ -138,17 +142,19 @@ fun InMemoryLucene(
path: Path,
writerConfig: IndexWriterConfig = IndexWriterConfig(),
embeddings: EmbeddingsApi?,
embeddingAIModel: OpenAIModel<CreateEmbeddingRequestModel>,
similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN
): DirectoryLucene = DirectoryLucene(MMapDirectory(path), writerConfig, embeddings, similarity)
): DirectoryLucene = DirectoryLucene(MMapDirectory(path), writerConfig, embeddings, embeddingAIModel, similarity)

@JvmOverloads
fun InMemoryLuceneBuilder(
path: Path,
useAIEmbeddings: Boolean = true,
writerConfig: IndexWriterConfig = IndexWriterConfig(),
embeddingAIModel: OpenAIModel<CreateEmbeddingRequestModel>,
similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN
): (EmbeddingsApi) -> DirectoryLucene = { embeddings ->
InMemoryLucene(path, writerConfig, embeddings.takeIf { useAIEmbeddings }, similarity)
InMemoryLucene(path, writerConfig, embeddings.takeIf { useAIEmbeddings }, embeddingAIModel, similarity)
}

/**
 * Concatenates the `embedding` values of every element of this list, in list order, converting
 * each value with `toFloat()`, and returns them as a single [FloatArray].
 */
fun List<Embedding>.toFloatArray(): FloatArray =
  flatMap { item -> item.embedding.map { component -> component.toFloat() } }.toFloatArray()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import ai.xef.openai.OpenAIModel
import arrow.atomic.AtomicInt
import com.xebia.functional.openai.apis.EmbeddingsApi
import com.xebia.functional.openai.models.ChatCompletionRole
import com.xebia.functional.openai.models.CreateEmbeddingRequestModel
import com.xebia.functional.openai.models.Embedding
import com.xebia.functional.xef.llm.embedDocuments
import com.xebia.functional.xef.llm.embedQuery
Expand All @@ -20,6 +21,7 @@ class PGVectorStore(
private val collectionName: String,
private val distanceStrategy: PGDistanceStrategy,
private val preDeleteCollection: Boolean,
private val embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel>,
private val chunkSize: Int = 400
) : VectorStore {

Expand Down Expand Up @@ -83,7 +85,7 @@ class PGVectorStore(

override suspend fun addTexts(texts: List<String>): Unit =
dataSource.connection {
val embeddings = embeddings.embedDocuments(texts, chunkSize)
val embeddings = embeddings.embedDocuments(texts, chunkSize, embeddingRequestModel)
val collection = getCollection(collectionName)
texts.zip(embeddings) { text, embedding ->
val uuid = UUID.generateUUID()
Expand All @@ -105,7 +107,7 @@ class PGVectorStore(
if (!hasEmbeddings) return emptyList()

val embeddings =
embeddings.embedQuery(query).ifEmpty {
embeddings.embedQuery(query, embeddingRequestModel).ifEmpty {
throw IllegalStateException(
"Embedding for text: '$query', has not been properly generated"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,17 @@ class PGVectorStoreSpec :
)
)

val embeddingsRequestModel = StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002)

fun StringSpecScope.pg() =
PGVectorStore(
vectorSize = 3,
dataSource = dataSource,
embeddings = TestEmbeddings(coroutineContext),
collectionName = "test_collection",
distanceStrategy = PGDistanceStrategy.Euclidean,
preDeleteCollection = false
preDeleteCollection = false,
embeddingRequestModel = embeddingsRequestModel
)

beforeContainer {
Expand All @@ -56,7 +59,8 @@ class PGVectorStoreSpec :
embeddings = TestEmbeddings(coroutineContext),
collectionName = "test_collection",
distanceStrategy = PGDistanceStrategy.Euclidean,
preDeleteCollection = false
preDeleteCollection = false,
embeddingRequestModel = embeddingsRequestModel
)
postgresVector.initialDbSetup()
postgresVector.createCollection()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package com.xebia.functional.xef.server.services

import ai.xef.openai.OpenAIModel
import ai.xef.openai.StandardModel
import com.xebia.functional.openai.apis.EmbeddingsApi
import com.xebia.functional.openai.models.CreateEmbeddingRequestModel
import com.xebia.functional.xef.llm.fromEnvironment
import com.xebia.functional.xef.llm.fromToken
import com.xebia.functional.xef.server.http.routes.Provider
Expand All @@ -21,6 +24,9 @@ class PostgresVectorStoreService(
private val vectorSize: Int,
private val preDeleteCollection: Boolean = false,
private val chunkSize: Int = 400,
private val distanceStrategy: PGDistanceStrategy = PGDistanceStrategy.Euclidean,
private val embeddingRequestModel: OpenAIModel<CreateEmbeddingRequestModel> =
StandardModel(CreateEmbeddingRequestModel.text_embedding_ada_002)
) : VectorStoreService() {

fun addCollection() {
Expand All @@ -45,8 +51,9 @@ class PostgresVectorStoreService(
dataSource = dataSource,
embeddings = embeddingsApi,
collectionName = collectionName,
distanceStrategy = PGDistanceStrategy.Euclidean,
distanceStrategy = distanceStrategy,
preDeleteCollection = preDeleteCollection,
embeddingRequestModel = embeddingRequestModel,
chunkSize = chunkSize
)
}
Expand Down

0 comments on commit 301bd56

Please sign in to comment.