Support embedding func in ChromaDB memory #6267

Closed
@victordibia

Current Status

The current implementation of ChromaDBVectorMemory in the AutoGen extension package doesn't expose parameters for setting custom embedding functions. It relies entirely on ChromaDB's default embedding function (Sentence Transformers all-MiniLM-L6-v2).

Goal

Allow users to customize the embedding function used by ChromaDBVectorMemory through a flexible, declarative configuration system that supports:

  1. Default embedding function (current behavior)
  2. Alternative Sentence Transformer models
  3. OpenAI embeddings
  4. Custom user-defined embedding functions

Rough Sketch of an Implementation Plan

1. Create Base Configuration Classes

Create a hierarchy of embedding function configurations:

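from typing import Any, Callable, Dict, Literal

from pydantic import BaseModel, Field


class BaseEmbeddingFunctionConfig(BaseModel):
    """Base configuration for embedding functions."""
    # Discriminator: each subclass below pins this to its own value.
    function_type: Literal["default", "sentence_transformer", "openai", "custom"]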
    

class DefaultEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for the default embedding function."""
    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "default"


class SentenceTransformerEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for SentenceTransformer embedding functions."""
    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "sentence_transformer"
    model_name: str = Field(default="all-MiniLM-L6-v2", description="Model name to use")
    

class OpenAIEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for OpenAI embedding functions."""
    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "openai"
    api_key: str = Field(default="", description="OpenAI API key")
    model_name: str = Field(default="text-embedding-ada-002", description="Model name")

2. Support Custom Embedding Functions

Add a configuration for custom embedding functions using the direct function approach:

class CustomEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for custom embedding functions."""
    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "custom"
    function: Callable[..., Any] = Field(description="Function that returns an embedding function")
    params: Dict[str, Any] = Field(default_factory=dict, description="Parameters")

Note: Using a direct function in the configuration will make it non-serializable. The implementation should include appropriate warnings when users attempt to serialize configurations that contain function references.
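One way the warning could be surfaced, as a minimal sketch assuming Pydantic v2: a `field_serializer` on `function` that warns and emits the callable's dotted name in place of the callable. The serializer name `_serialize_function` and the dotted-name output are illustrative assumptions, not part of the proposal, and the dumped value cannot be loaded back into a working config.

```python
import warnings
from typing import Any, Callable, Dict, Literal

from pydantic import BaseModel, Field, field_serializer


class CustomEmbeddingFunctionConfig(BaseModel):
    """Sketch of the custom config with a serialization warning."""

    function_type: Literal["custom"] = "custom"
    function: Callable[..., Any] = Field(description="Function that returns an embedding function")
    params: Dict[str, Any] = Field(default_factory=dict, description="Parameters")

    @field_serializer("function")
    def _serialize_function(self, fn: Callable[..., Any]) -> str:
        # Hypothetical behavior: warn, then emit a readable name instead of the callable.
        warnings.warn(
            "CustomEmbeddingFunctionConfig holds a function reference; "
            "the serialized config cannot be used to recreate it.",
            UserWarning,
            stacklevel=2,
        )
        return f"{fn.__module__}.{fn.__qualname__}"
```

Calling `model_dump()` or `model_dump_json()` on such a config would then emit a `UserWarning` rather than failing or silently leaking a function object into the output.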

3. Update ChromaDBVectorMemory Configuration

Extend the existing ChromaDBVectorMemoryConfig class to include the embedding function configuration:

class ChromaDBVectorMemoryConfig(BaseModel):
    # Existing fields...
    embedding_function_config: BaseEmbeddingFunctionConfig = Field(
        default_factory=DefaultEmbeddingFunctionConfig,
        description="Configuration for the embedding function"
    )
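A possible refinement, sketched under the assumption that Pydantic v2 is in use: if the field is annotated with the base class, my understanding is that loading a plain dict (for example from a saved config) validates into the base class and subclass fields such as `model_name` are dropped. A discriminated union keyed on `function_type` lets Pydantic pick the right subclass. The names `_SketchBase` and `MemoryConfigSketch` are hypothetical stand-ins, each subclass narrows its `Literal` to a single value (which the discriminator requires), and the non-serializable custom config is left out.

```python
from typing import Literal, Union

from pydantic import BaseModel, ConfigDict, Field


class _SketchBase(BaseModel):
    # Allow the `model_name` field without Pydantic's protected-namespace warning.
    model_config = ConfigDict(protected_namespaces=())


class DefaultEmbeddingFunctionConfig(_SketchBase):
    function_type: Literal["default"] = "default"


class SentenceTransformerEmbeddingFunctionConfig(_SketchBase):
    function_type: Literal["sentence_transformer"] = "sentence_transformer"
    model_name: str = "all-MiniLM-L6-v2"


class OpenAIEmbeddingFunctionConfig(_SketchBase):
    function_type: Literal["openai"] = "openai"
    api_key: str = ""
    model_name: str = "text-embedding-ada-002"


class MemoryConfigSketch(BaseModel):
    # Stand-in for ChromaDBVectorMemoryConfig; only the new field is shown.
    embedding_function_config: Union[
        DefaultEmbeddingFunctionConfig,
        SentenceTransformerEmbeddingFunctionConfig,
        OpenAIEmbeddingFunctionConfig,
    ] = Field(
        default_factory=DefaultEmbeddingFunctionConfig,
        discriminator="function_type",
    )


# Round-trip through a plain dict keeps the concrete subclass and its fields.
original = MemoryConfigSketch(
    embedding_function_config=OpenAIEmbeddingFunctionConfig(model_name="text-embedding-3-small")
)
restored = MemoryConfigSketch.model_validate(original.model_dump())
```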

4. Implement Embedding Function Creation

Add a method to ChromaDBVectorMemory that creates embedding functions based on configuration:

def _create_embedding_function(self):
    """Create an embedding function based on the configuration."""
    from typing import cast  # needed for the narrowing casts below

    from chromadb.utils import embedding_functions

    config = self._config.embedding_function_config
    
    if config.function_type == "default":
        return embedding_functions.DefaultEmbeddingFunction()
    
    elif config.function_type == "sentence_transformer":
        cfg = cast(SentenceTransformerEmbeddingFunctionConfig, config)
        return embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=cfg.model_name
        )
    
    elif config.function_type == "openai":
        cfg = cast(OpenAIEmbeddingFunctionConfig, config)
        return embedding_functions.OpenAIEmbeddingFunction(
            api_key=cfg.api_key,
            model_name=cfg.model_name
        )
    
    elif config.function_type == "custom":
        cfg = cast(CustomEmbeddingFunctionConfig, config)
        return cfg.function(**cfg.params)
    
    else:
        raise ValueError(f"Unsupported embedding function type: {config.function_type}")
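As an alternative to the if/elif chain above, the same dispatch could be expressed as a registry that maps `function_type` to a builder, which would let new embedding backends be added without editing the method. This is only a sketch of that swapped-in pattern: `EmbeddingConfigSketch`, `register_builder`, and the tuple-returning builders are hypothetical stand-ins, and the real builders would return ChromaDB embedding functions.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass
class EmbeddingConfigSketch:
    """Hypothetical stand-in for the Pydantic config classes."""

    function_type: str = "default"
    params: Dict[str, Any] = field(default_factory=dict)


# Registry: function_type -> builder taking the config.
_BUILDERS: Dict[str, Callable[[EmbeddingConfigSketch], Any]] = {}


def register_builder(function_type: str):
    """Decorator that registers a builder under the given function_type."""

    def decorator(builder: Callable[[EmbeddingConfigSketch], Any]):
        _BUILDERS[function_type] = builder
        return builder

    return decorator


@register_builder("default")
def _build_default(config: EmbeddingConfigSketch) -> Any:
    # Stand-in: a real builder would return the ChromaDB default embedding function.
    return ("default", {})


@register_builder("sentence_transformer")
def _build_sentence_transformer(config: EmbeddingConfigSketch) -> Any:
    # Stand-in: a real builder would pass config.params to the model constructor.
    return ("sentence_transformer", dict(config.params))


def create_embedding_function(config: EmbeddingConfigSketch) -> Any:
    """Look up the builder for config.function_type and call it."""
    try:
        builder = _BUILDERS[config.function_type]
    except KeyError:
        raise ValueError(
            f"Unsupported embedding function type: {config.function_type}"
        ) from None
    return builder(config)
```

The trade-off is one extra level of indirection in exchange for keeping each backend's construction logic in its own function.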

5. Update Collection Initialization

Modify the _ensure_initialized method to use the embedding function:

def _ensure_initialized(self) -> None:
    # ... existing client initialization code ...
    
    if self._collection is None:
        try:
            # Create embedding function
            embedding_function = self._create_embedding_function()
            
            # Create or get collection with embedding function
            self._collection = self._client.get_or_create_collection(
                name=self._config.collection_name,
                metadata={"distance_metric": self._config.distance_metric},
                embedding_function=embedding_function
            )
        except Exception as e:
            logger.error(f"Failed to get/create collection: {e}")
            raise

Example Usage

# Using default embedding function
memory = ChromaDBVectorMemory(
    config=PersistentChromaDBVectorMemoryConfig()
)

# Using a specific Sentence Transformer model
memory = ChromaDBVectorMemory(
    config=PersistentChromaDBVectorMemoryConfig(
        embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
            model_name="paraphrase-multilingual-mpnet-base-v2"
        )
    )
)

# Using OpenAI embeddings
memory = ChromaDBVectorMemory(
    config=PersistentChromaDBVectorMemoryConfig(
        embedding_function_config=OpenAIEmbeddingFunctionConfig(
            api_key="sk-...",
            model_name="text-embedding-3-small"
        )
    )
)

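# Using a custom embedding function (direct function approach)
from chromadb import Documents, EmbeddingFunction, Embeddings

def create_my_embedder(param1="default"):
    # Return a ChromaDB-compatible embedding function
    class MyCustomEmbeddingFunction(EmbeddingFunction):
        def __init__(self, param1: str) -> None:
            self.param1 = param1

        def __call__(self, input: Documents) -> Embeddings:
            # Custom embedding logic goes here. Placeholder only:
            # returns one fixed-size zero vector per input document.
            return [[0.0] * 8 for _ in input]

    return MyCustomEmbeddingFunction(param1)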

memory = ChromaDBVectorMemory(
    config=PersistentChromaDBVectorMemoryConfig(
        embedding_function_config=CustomEmbeddingFunctionConfig(
            function=create_my_embedder,
            params={"param1": "custom_value"}
        )
    )
)
 
