Skip to content

Commit 67ebeed

Browse files
Feature/chromadb embedding functions #6267 (#6648)
## Why are these changes needed? This PR adds support for configurable embedding functions in ChromaDBVectorMemory, addressing the need for users to customize how embeddings are generated for vector similarity search. Currently, ChromaDB memory is limited to default embedding functions, which restricts flexibility for different use cases that may require specific embedding models or custom embedding logic. The implementation allows users to: - Use different SentenceTransformer models for domain-specific embeddings - Integrate with OpenAI's embedding API for consistent embedding generation - Define custom embedding functions for specialized requirements - Maintain backward compatibility with existing default behavior ## Related issue number Closes #6267 ## Checks - [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. --------- Co-authored-by: Victor Dibia <victordibia@microsoft.com> Co-authored-by: Victor Dibia <victor.dibia@gmail.com>
1 parent 150ea01 commit 67ebeed

File tree

5 files changed

+1029
-475
lines changed

5 files changed

+1029
-475
lines changed

python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/memory.ipynb

Lines changed: 563 additions & 439 deletions
Large diffs are not rendered by default.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from ._chroma_configs import (
2+
ChromaDBVectorMemoryConfig,
3+
CustomEmbeddingFunctionConfig,
4+
DefaultEmbeddingFunctionConfig,
5+
HttpChromaDBVectorMemoryConfig,
6+
OpenAIEmbeddingFunctionConfig,
7+
PersistentChromaDBVectorMemoryConfig,
8+
SentenceTransformerEmbeddingFunctionConfig,
9+
)
10+
from ._chromadb import ChromaDBVectorMemory
11+
12+
__all__ = [
13+
"ChromaDBVectorMemory",
14+
"ChromaDBVectorMemoryConfig",
15+
"PersistentChromaDBVectorMemoryConfig",
16+
"HttpChromaDBVectorMemoryConfig",
17+
"DefaultEmbeddingFunctionConfig",
18+
"SentenceTransformerEmbeddingFunctionConfig",
19+
"OpenAIEmbeddingFunctionConfig",
20+
"CustomEmbeddingFunctionConfig",
21+
]
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
"""Configuration classes for ChromaDB vector memory."""
2+
3+
from typing import Any, Callable, Dict, Literal, Union
4+
5+
from pydantic import BaseModel, Field
6+
from typing_extensions import Annotated
7+
8+
9+
class DefaultEmbeddingFunctionConfig(BaseModel):
    """Select ChromaDB's built-in default embedding function.

    This variant has no tunable options. It only contributes the
    ``function_type`` tag that distinguishes it from the other embedding
    function configurations. ChromaDB's default embedding function is
    Sentence Transformers ``all-MiniLM-L6-v2``.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.
    """

    # Discriminator tag identifying this variant.
    function_type: Literal["default"] = "default"
19+
20+
21+
class SentenceTransformerEmbeddingFunctionConfig(BaseModel):
    """Select a specific SentenceTransformer model for generating embeddings.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    Args:
        model_name (str): Name of the SentenceTransformer model to use.
            Defaults to "all-MiniLM-L6-v2".

    Example:
        .. code-block:: python

            config = SentenceTransformerEmbeddingFunctionConfig(model_name="paraphrase-multilingual-mpnet-base-v2")
    """

    # Discriminator tag identifying this variant.
    function_type: Literal["sentence_transformer"] = "sentence_transformer"
    model_name: str = Field(
        default="all-MiniLM-L6-v2",
        description="SentenceTransformer model name to use",
    )
41+
42+
43+
class OpenAIEmbeddingFunctionConfig(BaseModel):
    """Configuration for OpenAI embedding functions.

    Uses OpenAI's embedding API for generating embeddings.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    Args:
        api_key (str): OpenAI API key. Leave empty to rely on a key supplied
            through the environment; this model does not read the environment
            itself, so resolution is left to whatever consumes the config.
            The value is excluded from ``repr()`` so it does not leak into
            logs or tracebacks.
        model_name (str): OpenAI embedding model name. Defaults to "text-embedding-ada-002".

    Example:
        .. code-block:: python

            config = OpenAIEmbeddingFunctionConfig(api_key="sk-...", model_name="text-embedding-3-small")
    """

    # Discriminator tag identifying this variant.
    function_type: Literal["openai"] = "openai"
    # repr=False keeps the secret out of repr()/str() output; validation and
    # serialization of the field are unchanged.
    api_key: str = Field(default="", repr=False, description="OpenAI API key")
    model_name: str = Field(default="text-embedding-ada-002", description="OpenAI embedding model name")
64+
65+
66+
class CustomEmbeddingFunctionConfig(BaseModel):
    """Supply a user-provided factory that builds the embedding function.

    ``function`` is called with ``params`` as keyword arguments and must
    return a ChromaDB-compatible embedding function.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    .. warning::
       Configurations containing custom functions are not serializable.

    Args:
        function (Callable): Function that returns a ChromaDB-compatible embedding function.
        params (Dict[str, Any]): Parameters to pass to the function.

    Example:
        .. code-block:: python

            from chromadb import Documents, EmbeddingFunction, Embeddings


            def create_my_embedder(param1="default"):
                # Return a ChromaDB-compatible embedding function
                class MyCustomEmbeddingFunction(EmbeddingFunction):
                    def __init__(self, param1):
                        self.param1 = param1

                    def __call__(self, input: Documents) -> Embeddings:
                        # Custom embedding logic here (placeholder vectors shown)
                        embeddings = [[0.0] for _ in input]
                        return embeddings

                return MyCustomEmbeddingFunction(param1)


            config = CustomEmbeddingFunctionConfig(function=create_my_embedder, params={"param1": "custom_value"})
    """

    # Discriminator tag identifying this variant.
    function_type: Literal["custom"] = "custom"
    function: Callable[..., Any] = Field(
        description="Function that returns an embedding function",
    )
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description="Parameters to pass to the function",
    )
100+
101+
102+
# All concrete embedding function configuration variants.
_EmbeddingFunctionVariants = Union[
    DefaultEmbeddingFunctionConfig,
    SentenceTransformerEmbeddingFunctionConfig,
    OpenAIEmbeddingFunctionConfig,
    CustomEmbeddingFunctionConfig,
]

# Tagged union: the "function_type" field of each variant selects the
# concrete model during validation.
EmbeddingFunctionConfig = Annotated[
    _EmbeddingFunctionVariants,
    Field(discriminator="function_type"),
]
112+
113+
114+
class ChromaDBVectorMemoryConfig(BaseModel):
    """Settings shared by every ChromaDB-backed memory, regardless of client type.

    Subclasses fix ``client_type`` and add the connection-specific fields.

    .. versionchanged:: v0.4.1
       Added support for custom embedding functions via embedding_function_config.
    """

    # Which kind of ChromaDB client this configuration describes; no default
    # here, so it must be given explicitly or by a subclass.
    client_type: Literal["persistent", "http"]
    collection_name: str = Field(
        default="memory_store",
        description="Name of the ChromaDB collection",
    )
    distance_metric: str = Field(
        default="cosine",
        description="Distance metric for similarity search",
    )
    k: int = Field(
        default=3,
        description="Number of results to return in queries",
    )
    score_threshold: float | None = Field(
        default=None,
        description="Minimum similarity score threshold",
    )
    allow_reset: bool = Field(
        default=False,
        description="Whether to allow resetting the ChromaDB client",
    )
    tenant: str = Field(
        default="default_tenant",
        description="Tenant to use",
    )
    database: str = Field(
        default="default_database",
        description="Database to use",
    )
    # When omitted, a fresh DefaultEmbeddingFunctionConfig is created.
    embedding_function_config: EmbeddingFunctionConfig = Field(
        default_factory=DefaultEmbeddingFunctionConfig,
        description="Configuration for the embedding function",
    )
132+
133+
134+
class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """ChromaDB memory settings for the persistent client type."""

    # Defaults the inherited client selector to the persistent variant.
    client_type: Literal["persistent", "http"] = "persistent"
    persistence_path: str = Field(
        default="./chroma_db",
        description="Path for persistent storage",
    )
139+
140+
141+
class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """ChromaDB memory settings for the HTTP client type (remote server)."""

    # Defaults the inherited client selector to the HTTP variant.
    client_type: Literal["persistent", "http"] = "http"
    host: str = Field(
        default="localhost",
        description="Host of the remote server",
    )
    port: int = Field(
        default=8000,
        description="Port of the remote server",
    )
    ssl: bool = Field(
        default=False,
        description="Whether to use HTTPS",
    )
    headers: Dict[str, str] | None = Field(
        default=None,
        description="Headers to send to the server",
    )

python/packages/autogen-ext/src/autogen_ext/memory/chromadb.py renamed to python/packages/autogen-ext/src/autogen_ext/memory/chromadb/_chromadb.py

Lines changed: 96 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import uuid
3-
from typing import Any, Dict, List, Literal
3+
from typing import Any, List
44

55
from autogen_core import CancellationToken, Component, Image
66
from autogen_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
@@ -9,9 +9,18 @@
99
from chromadb import HttpClient, PersistentClient
1010
from chromadb.api.models.Collection import Collection
1111
from chromadb.api.types import Document, Metadata
12-
from pydantic import BaseModel, Field
1312
from typing_extensions import Self
1413

14+
from ._chroma_configs import (
15+
ChromaDBVectorMemoryConfig,
16+
CustomEmbeddingFunctionConfig,
17+
DefaultEmbeddingFunctionConfig,
18+
HttpChromaDBVectorMemoryConfig,
19+
OpenAIEmbeddingFunctionConfig,
20+
PersistentChromaDBVectorMemoryConfig,
21+
SentenceTransformerEmbeddingFunctionConfig,
22+
)
23+
1524
logger = logging.getLogger(__name__)
1625

1726

@@ -23,36 +32,6 @@
2332
) from e
2433

2534

26-
class ChromaDBVectorMemoryConfig(BaseModel):
27-
"""Base configuration for ChromaDB-based memory implementation."""
28-
29-
client_type: Literal["persistent", "http"]
30-
collection_name: str = Field(default="memory_store", description="Name of the ChromaDB collection")
31-
distance_metric: str = Field(default="cosine", description="Distance metric for similarity search")
32-
k: int = Field(default=3, description="Number of results to return in queries")
33-
score_threshold: float | None = Field(default=None, description="Minimum similarity score threshold")
34-
allow_reset: bool = Field(default=False, description="Whether to allow resetting the ChromaDB client")
35-
tenant: str = Field(default="default_tenant", description="Tenant to use")
36-
database: str = Field(default="default_database", description="Database to use")
37-
38-
39-
class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
40-
"""Configuration for persistent ChromaDB memory."""
41-
42-
client_type: Literal["persistent", "http"] = "persistent"
43-
persistence_path: str = Field(default="./chroma_db", description="Path for persistent storage")
44-
45-
46-
class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
47-
"""Configuration for HTTP ChromaDB memory."""
48-
49-
client_type: Literal["persistent", "http"] = "http"
50-
host: str = Field(default="localhost", description="Host of the remote server")
51-
port: int = Field(default=8000, description="Port of the remote server")
52-
ssl: bool = Field(default=False, description="Whether to use HTTPS")
53-
headers: Dict[str, str] | None = Field(default=None, description="Headers to send to the server")
54-
55-
5635
class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
5736
"""
5837
Store and retrieve memory using vector similarity search powered by ChromaDB.
@@ -86,10 +65,15 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
8665
from pathlib import Path
8766
from autogen_agentchat.agents import AssistantAgent
8867
from autogen_core.memory import MemoryContent, MemoryMimeType
89-
from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig
68+
from autogen_ext.memory.chromadb import (
69+
ChromaDBVectorMemory,
70+
PersistentChromaDBVectorMemoryConfig,
71+
SentenceTransformerEmbeddingFunctionConfig,
72+
OpenAIEmbeddingFunctionConfig,
73+
)
9074
from autogen_ext.models.openai import OpenAIChatCompletionClient
9175
92-
# Initialize ChromaDB memory with custom config
76+
# Initialize ChromaDB memory with default embedding function
9377
memory = ChromaDBVectorMemory(
9478
config=PersistentChromaDBVectorMemoryConfig(
9579
collection_name="user_preferences",
@@ -99,6 +83,28 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
9983
)
10084
)
10185
86+
# Using a custom SentenceTransformer model
87+
memory_custom_st = ChromaDBVectorMemory(
88+
config=PersistentChromaDBVectorMemoryConfig(
89+
collection_name="multilingual_memory",
90+
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
91+
embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
92+
model_name="paraphrase-multilingual-mpnet-base-v2"
93+
),
94+
)
95+
)
96+
97+
# Using OpenAI embeddings
98+
memory_openai = ChromaDBVectorMemory(
99+
config=PersistentChromaDBVectorMemoryConfig(
100+
collection_name="openai_memory",
101+
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
102+
embedding_function_config=OpenAIEmbeddingFunctionConfig(
103+
api_key="sk-...", model_name="text-embedding-3-small"
104+
),
105+
)
106+
)
107+
102108
# Add user preferences to memory
103109
await memory.add(
104110
MemoryContent(
@@ -138,6 +144,55 @@ def collection_name(self) -> str:
138144
"""Get the name of the ChromaDB collection."""
139145
return self._config.collection_name
140146

147+
def _create_embedding_function(self) -> Any:
    """Create an embedding function based on the configuration.

    Dispatches on the concrete type of
    ``self._config.embedding_function_config``.

    Returns:
        A ChromaDB-compatible embedding function.

    Raises:
        ImportError: If ``chromadb.utils.embedding_functions`` cannot be
            imported, or if constructing the SentenceTransformer or OpenAI
            embedding function fails for any reason. Non-import failures are
            also reported as ImportError; the original error is chained and
            included in the message.
        ValueError: If the custom factory raises, or if the embedding
            function config type is unsupported.
    """
    try:
        from chromadb.utils import embedding_functions
    except ImportError as e:
        raise ImportError(
            "ChromaDB embedding functions not available. Ensure chromadb is properly installed."
        ) from e

    config = self._config.embedding_function_config

    if isinstance(config, DefaultEmbeddingFunctionConfig):
        return embedding_functions.DefaultEmbeddingFunction()

    if isinstance(config, SentenceTransformerEmbeddingFunctionConfig):
        try:
            return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=config.model_name)
        except Exception as e:
            raise ImportError(
                f"Failed to create SentenceTransformer embedding function with model '{config.model_name}'. "
                f"Ensure sentence-transformers is installed and the model is available. Error: {e}"
            ) from e

    if isinstance(config, OpenAIEmbeddingFunctionConfig):
        # The config's api_key defaults to "" meaning "not provided". Forward
        # None in that case so the key can be resolved from the environment
        # instead of being used as a blank credential.
        api_key = config.api_key or None
        try:
            return embedding_functions.OpenAIEmbeddingFunction(api_key=api_key, model_name=config.model_name)
        except Exception as e:
            raise ImportError(
                f"Failed to create OpenAI embedding function with model '{config.model_name}'. "
                f"Ensure openai is installed and API key is valid. Error: {e}"
            ) from e

    if isinstance(config, CustomEmbeddingFunctionConfig):
        try:
            # The user-supplied factory receives its configured params as keyword arguments.
            return config.function(**config.params)
        except Exception as e:
            raise ValueError(f"Failed to create custom embedding function. Error: {e}") from e

    raise ValueError(f"Unsupported embedding function config type: {type(config)}")
195+
141196
def _ensure_initialized(self) -> None:
142197
"""Ensure ChromaDB client and collection are initialized."""
143198
if self._client is None:
@@ -171,8 +226,14 @@ def _ensure_initialized(self) -> None:
171226

172227
if self._collection is None:
173228
try:
229+
# Create embedding function
230+
embedding_function = self._create_embedding_function()
231+
232+
# Create or get collection with embedding function
174233
self._collection = self._client.get_or_create_collection(
175-
name=self._config.collection_name, metadata={"distance_metric": self._config.distance_metric}
234+
name=self._config.collection_name,
235+
metadata={"distance_metric": self._config.distance_metric},
236+
embedding_function=embedding_function,
176237
)
177238
except Exception as e:
178239
logger.error(f"Failed to get/create collection: {e}")

0 commit comments

Comments
 (0)