## Description

### Current Status

The current implementation of `ChromaDBVectorMemory` in the AutoGen extension package doesn't expose parameters for setting custom embedding functions. It relies entirely on ChromaDB's default embedding function (Sentence Transformers `all-MiniLM-L6-v2`).
## Goal

Allow users to customize the embedding function used by `ChromaDBVectorMemory` through a flexible, declarative configuration system that supports:
- Default embedding function (current behavior)
- Alternative Sentence Transformer models
- OpenAI embeddings
- Custom user-defined embedding functions
## Rough Sketch of an Implementation Plan

### 1. Create Base Configuration Classes

Create a hierarchy of embedding function configurations:
```python
from typing import Literal

from pydantic import BaseModel, Field


class BaseEmbeddingFunctionConfig(BaseModel):
    """Base configuration for embedding functions."""

    function_type: Literal["default", "sentence_transformer", "openai", "custom"]


class DefaultEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for the default embedding function."""

    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "default"


class SentenceTransformerEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for SentenceTransformer embedding functions."""

    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "sentence_transformer"
    model_name: str = Field(default="all-MiniLM-L6-v2", description="Model name to use")


class OpenAIEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for OpenAI embedding functions."""

    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "openai"
    api_key: str = Field(default="", description="OpenAI API key")
    model_name: str = Field(default="text-embedding-ada-002", description="Model name")
```
### 2. Support Custom Embedding Functions

Add a configuration for custom embedding functions using the direct-function approach:
```python
from typing import Any, Callable, Dict


class CustomEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for custom embedding functions."""

    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "custom"
    function: Callable[..., Any] = Field(description="Function that returns an embedding function")
    params: Dict[str, Any] = Field(default_factory=dict, description="Parameters passed to the function")
```
**Note:** Using a direct function reference in the configuration makes it non-serializable. The implementation should emit an appropriate warning when users attempt to serialize configurations that contain function references.
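One way to surface that warning is a Pydantic field serializer that drops the callable and warns on dump. This is a minimal sketch assuming Pydantic v2, with the class body abbreviated:

```python
import warnings
from typing import Any, Callable, Dict

from pydantic import BaseModel, Field, field_serializer


class CustomEmbeddingFunctionConfig(BaseModel):
    """Abbreviated stand-in for the config class proposed above."""

    function_type: str = "custom"
    function: Callable[..., Any]
    params: Dict[str, Any] = Field(default_factory=dict)

    @field_serializer("function")
    def _serialize_function(self, value: Callable[..., Any]) -> None:
        # Callables cannot round-trip through JSON/YAML; warn and drop them.
        warnings.warn(
            "CustomEmbeddingFunctionConfig.function is not serializable and "
            "will be omitted; re-attach it after loading the config."
        )
        return None
```

With this, `model_dump()` yields `{"function": None, ...}` plus a `UserWarning`, rather than surprising the user at load time.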
### 3. Update ChromaDBVectorMemory Configuration

Extend the existing `ChromaDBVectorMemoryConfig` class to include the embedding function configuration:
```python
class ChromaDBVectorMemoryConfig(BaseModel):
    # Existing fields...
    embedding_function_config: BaseEmbeddingFunctionConfig = Field(
        default_factory=DefaultEmbeddingFunctionConfig,
        description="Configuration for the embedding function",
    )
```
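If configurations also need to round-trip from plain dicts or YAML, the `function_type` field can double as a Pydantic discriminator so the correct subclass is reconstructed automatically. A sketch assuming Pydantic v2; note that discriminated unions require each member to narrow `function_type` to a single-value `Literal`, unlike the broad literals above:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class DefaultEmbeddingFunctionConfig(BaseModel):
    function_type: Literal["default"] = "default"


class SentenceTransformerEmbeddingFunctionConfig(BaseModel):
    function_type: Literal["sentence_transformer"] = "sentence_transformer"
    model_name: str = "all-MiniLM-L6-v2"


# Tagged union: Pydantic inspects `function_type` to pick the subclass.
EmbeddingFunctionConfig = Annotated[
    Union[DefaultEmbeddingFunctionConfig, SentenceTransformerEmbeddingFunctionConfig],
    Field(discriminator="function_type"),
]


class ChromaDBVectorMemoryConfig(BaseModel):
    embedding_function_config: EmbeddingFunctionConfig = Field(
        default_factory=DefaultEmbeddingFunctionConfig
    )


# A plain dict deserializes to the correct subclass.
cfg = ChromaDBVectorMemoryConfig.model_validate(
    {"embedding_function_config": {"function_type": "sentence_transformer"}}
)
```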
### 4. Implement Embedding Function Creation

Add a method to `ChromaDBVectorMemory` that creates embedding functions based on the configuration:
```python
def _create_embedding_function(self):
    """Create an embedding function based on the configuration."""
    from chromadb.utils import embedding_functions

    config = self._config.embedding_function_config
    if config.function_type == "default":
        return embedding_functions.DefaultEmbeddingFunction()
    elif config.function_type == "sentence_transformer":
        cfg = cast(SentenceTransformerEmbeddingFunctionConfig, config)
        return embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=cfg.model_name
        )
    elif config.function_type == "openai":
        cfg = cast(OpenAIEmbeddingFunctionConfig, config)
        return embedding_functions.OpenAIEmbeddingFunction(
            api_key=cfg.api_key,
            model_name=cfg.model_name,
        )
    elif config.function_type == "custom":
        cfg = cast(CustomEmbeddingFunctionConfig, config)
        return cfg.function(**cfg.params)
    else:
        raise ValueError(f"Unsupported embedding function type: {config.function_type}")
```
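If more function types are added later, the `if`/`elif` chain could alternatively be table-driven, so each type registers a factory instead of the method being edited. A stdlib-only sketch with placeholder factories; the registry and names here are illustrative, not part of the proposal:

```python
from typing import Any, Callable, Dict

# Registry mapping function_type -> factory (placeholder stand-ins only).
_EMBEDDING_FACTORIES: Dict[str, Callable[[Any], Any]] = {}


def register_embedding_factory(function_type: str) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]:
    """Register a factory under a function_type key."""
    def decorator(factory: Callable[[Any], Any]) -> Callable[[Any], Any]:
        _EMBEDDING_FACTORIES[function_type] = factory
        return factory
    return decorator


@register_embedding_factory("default")
def _make_default(config: Any) -> str:
    # Real code would return embedding_functions.DefaultEmbeddingFunction().
    return "default-embedder"


def create_embedding_function(config: Any) -> Any:
    """Look up and invoke the factory for config.function_type."""
    try:
        factory = _EMBEDDING_FACTORIES[config.function_type]
    except KeyError:
        raise ValueError(f"Unsupported embedding function type: {config.function_type!r}") from None
    return factory(config)
```

This keeps the "unsupported type" error in one place while letting extensions add entries without touching `ChromaDBVectorMemory` itself.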
### 5. Update Collection Initialization

Modify the `_ensure_initialized` method to use the embedding function:
```python
def _ensure_initialized(self) -> None:
    # ... existing client initialization code ...
    if self._collection is None:
        try:
            # Create the embedding function from the configuration
            embedding_function = self._create_embedding_function()
            # Create or get the collection with the embedding function attached
            self._collection = self._client.get_or_create_collection(
                name=self._config.collection_name,
                metadata={"distance_metric": self._config.distance_metric},
                embedding_function=embedding_function,
            )
        except Exception as e:
            logger.error(f"Failed to get/create collection: {e}")
            raise
```
## Example Usage
```python
# Using the default embedding function
memory = ChromaDBVectorMemory(
    config=PersistentChromaDBVectorMemoryConfig()
)

# Using a specific Sentence Transformer model
memory = ChromaDBVectorMemory(
    config=PersistentChromaDBVectorMemoryConfig(
        embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
            model_name="paraphrase-multilingual-mpnet-base-v2"
        )
    )
)

# Using OpenAI embeddings
memory = ChromaDBVectorMemory(
    config=PersistentChromaDBVectorMemoryConfig(
        embedding_function_config=OpenAIEmbeddingFunctionConfig(
            api_key="sk-...",
            model_name="text-embedding-3-small"
        )
    )
)


# Using a custom embedding function (direct-function approach)
def create_my_embedder(param1="default"):
    # Return a ChromaDB-compatible embedding function
    class MyCustomEmbeddingFunction(EmbeddingFunction):
        def __init__(self, param1):
            self._param1 = param1

        def __call__(self, input: Documents) -> Embeddings:
            embeddings: Embeddings = []
            # Custom embedding logic here, e.g. using self._param1
            return embeddings

    return MyCustomEmbeddingFunction(param1)


memory = ChromaDBVectorMemory(
    config=PersistentChromaDBVectorMemoryConfig(
        embedding_function_config=CustomEmbeddingFunctionConfig(
            function=create_my_embedder,
            params={"param1": "custom_value"}
        )
    )
)
```