1
1
import logging
2
2
import uuid
3
- from typing import Any , Dict , List , Literal
3
+ from typing import Any , List
4
4
5
5
from autogen_core import CancellationToken , Component , Image
6
6
from autogen_core .memory import Memory , MemoryContent , MemoryMimeType , MemoryQueryResult , UpdateContextResult
9
9
from chromadb import HttpClient , PersistentClient
10
10
from chromadb .api .models .Collection import Collection
11
11
from chromadb .api .types import Document , Metadata
12
- from pydantic import BaseModel , Field
13
12
from typing_extensions import Self
14
13
14
+ from ._chroma_configs import (
15
+ ChromaDBVectorMemoryConfig ,
16
+ CustomEmbeddingFunctionConfig ,
17
+ DefaultEmbeddingFunctionConfig ,
18
+ HttpChromaDBVectorMemoryConfig ,
19
+ OpenAIEmbeddingFunctionConfig ,
20
+ PersistentChromaDBVectorMemoryConfig ,
21
+ SentenceTransformerEmbeddingFunctionConfig ,
22
+ )
23
+
15
24
logger = logging .getLogger (__name__ )
16
25
17
26
23
32
) from e
24
33
25
34
26
- class ChromaDBVectorMemoryConfig (BaseModel ):
27
- """Base configuration for ChromaDB-based memory implementation."""
28
-
29
- client_type : Literal ["persistent" , "http" ]
30
- collection_name : str = Field (default = "memory_store" , description = "Name of the ChromaDB collection" )
31
- distance_metric : str = Field (default = "cosine" , description = "Distance metric for similarity search" )
32
- k : int = Field (default = 3 , description = "Number of results to return in queries" )
33
- score_threshold : float | None = Field (default = None , description = "Minimum similarity score threshold" )
34
- allow_reset : bool = Field (default = False , description = "Whether to allow resetting the ChromaDB client" )
35
- tenant : str = Field (default = "default_tenant" , description = "Tenant to use" )
36
- database : str = Field (default = "default_database" , description = "Database to use" )
37
-
38
-
39
- class PersistentChromaDBVectorMemoryConfig (ChromaDBVectorMemoryConfig ):
40
- """Configuration for persistent ChromaDB memory."""
41
-
42
- client_type : Literal ["persistent" , "http" ] = "persistent"
43
- persistence_path : str = Field (default = "./chroma_db" , description = "Path for persistent storage" )
44
-
45
-
46
- class HttpChromaDBVectorMemoryConfig (ChromaDBVectorMemoryConfig ):
47
- """Configuration for HTTP ChromaDB memory."""
48
-
49
- client_type : Literal ["persistent" , "http" ] = "http"
50
- host : str = Field (default = "localhost" , description = "Host of the remote server" )
51
- port : int = Field (default = 8000 , description = "Port of the remote server" )
52
- ssl : bool = Field (default = False , description = "Whether to use HTTPS" )
53
- headers : Dict [str , str ] | None = Field (default = None , description = "Headers to send to the server" )
54
-
55
-
56
35
class ChromaDBVectorMemory (Memory , Component [ChromaDBVectorMemoryConfig ]):
57
36
"""
58
37
Store and retrieve memory using vector similarity search powered by ChromaDB.
@@ -86,10 +65,15 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
86
65
from pathlib import Path
87
66
from autogen_agentchat.agents import AssistantAgent
88
67
from autogen_core.memory import MemoryContent, MemoryMimeType
89
- from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig
68
+ from autogen_ext.memory.chromadb import (
69
+ ChromaDBVectorMemory,
70
+ PersistentChromaDBVectorMemoryConfig,
71
+ SentenceTransformerEmbeddingFunctionConfig,
72
+ OpenAIEmbeddingFunctionConfig,
73
+ )
90
74
from autogen_ext.models.openai import OpenAIChatCompletionClient
91
75
92
- # Initialize ChromaDB memory with custom config
76
+ # Initialize ChromaDB memory with default embedding function
93
77
memory = ChromaDBVectorMemory(
94
78
config=PersistentChromaDBVectorMemoryConfig(
95
79
collection_name="user_preferences",
@@ -99,6 +83,28 @@ class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
99
83
)
100
84
)
101
85
86
+ # Using a custom SentenceTransformer model
87
+ memory_custom_st = ChromaDBVectorMemory(
88
+ config=PersistentChromaDBVectorMemoryConfig(
89
+ collection_name="multilingual_memory",
90
+ persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
91
+ embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
92
+ model_name="paraphrase-multilingual-mpnet-base-v2"
93
+ ),
94
+ )
95
+ )
96
+
97
+ # Using OpenAI embeddings
98
+ memory_openai = ChromaDBVectorMemory(
99
+ config=PersistentChromaDBVectorMemoryConfig(
100
+ collection_name="openai_memory",
101
+ persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
102
+ embedding_function_config=OpenAIEmbeddingFunctionConfig(
103
+ api_key="sk-...", model_name="text-embedding-3-small"
104
+ ),
105
+ )
106
+ )
107
+
102
108
# Add user preferences to memory
103
109
await memory.add(
104
110
MemoryContent(
@@ -138,6 +144,55 @@ def collection_name(self) -> str:
138
144
"""Get the name of the ChromaDB collection."""
139
145
return self ._config .collection_name
140
146
147
def _create_embedding_function(self) -> Any:
    """Build the embedding function selected by the memory configuration.

    The concrete type of ``self._config.embedding_function_config`` decides
    which embedding function is constructed: the ChromaDB default, a
    SentenceTransformer model, an OpenAI model, or a user-supplied factory.

    Returns:
        The embedding function object produced for the configured type.

    Raises:
        ImportError: If ``chromadb.utils.embedding_functions`` cannot be
            imported, or if constructing the SentenceTransformer or OpenAI
            embedding function fails for any reason (the underlying error
            is chained and included in the message).
        ValueError: If the custom factory call fails, or the configuration
            object is of an unrecognised type.
    """
    # Imported lazily so a broken chromadb install surfaces with a clear message.
    try:
        from chromadb.utils import embedding_functions as chroma_ef
    except ImportError as exc:
        raise ImportError(
            "ChromaDB embedding functions not available. Ensure chromadb is properly installed."
        ) from exc

    ef_config = self._config.embedding_function_config

    # Each branch below either returns or raises, so plain guard clauses suffice.
    if isinstance(ef_config, DefaultEmbeddingFunctionConfig):
        return chroma_ef.DefaultEmbeddingFunction()

    if isinstance(ef_config, SentenceTransformerEmbeddingFunctionConfig):
        try:
            return chroma_ef.SentenceTransformerEmbeddingFunction(model_name=ef_config.model_name)
        except Exception as exc:
            raise ImportError(
                f"Failed to create SentenceTransformer embedding function with model '{ef_config.model_name}'. "
                f"Ensure sentence-transformers is installed and the model is available. Error: {exc}"
            ) from exc

    if isinstance(ef_config, OpenAIEmbeddingFunctionConfig):
        try:
            return chroma_ef.OpenAIEmbeddingFunction(
                api_key=ef_config.api_key,
                model_name=ef_config.model_name,
            )
        except Exception as exc:
            raise ImportError(
                f"Failed to create OpenAI embedding function with model '{ef_config.model_name}'. "
                f"Ensure openai is installed and API key is valid. Error: {exc}"
            ) from exc

    if isinstance(ef_config, CustomEmbeddingFunctionConfig):
        # The user-provided factory is invoked with its configured keyword arguments.
        try:
            return ef_config.function(**ef_config.params)
        except Exception as exc:
            raise ValueError(f"Failed to create custom embedding function. Error: {exc}") from exc

    raise ValueError(f"Unsupported embedding function config type: {type(ef_config)}")
195
+
141
196
def _ensure_initialized (self ) -> None :
142
197
"""Ensure ChromaDB client and collection are initialized."""
143
198
if self ._client is None :
@@ -171,8 +226,14 @@ def _ensure_initialized(self) -> None:
171
226
172
227
if self ._collection is None :
173
228
try :
229
+ # Create embedding function
230
+ embedding_function = self ._create_embedding_function ()
231
+
232
+ # Create or get collection with embedding function
174
233
self ._collection = self ._client .get_or_create_collection (
175
- name = self ._config .collection_name , metadata = {"distance_metric" : self ._config .distance_metric }
234
+ name = self ._config .collection_name ,
235
+ metadata = {"distance_metric" : self ._config .distance_metric },
236
+ embedding_function = embedding_function ,
176
237
)
177
238
except Exception as e :
178
239
logger .error (f"Failed to get/create collection: { e } " )
0 commit comments