diff --git a/.github/workflows/contrib-openai.yml b/.github/workflows/contrib-openai.yml index 4eda8d93071..5e4ba170370 100644 --- a/.github/workflows/contrib-openai.yml +++ b/.github/workflows/contrib-openai.yml @@ -53,7 +53,7 @@ jobs: AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }} OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }} run: | - coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py + coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py::test_retrievechat test/agentchat/contrib/test_qdrant_retrievechat.py::test_retrievechat coverage xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index 8b5d0aaab26..719ff086183 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -58,13 +58,10 @@ jobs: if [[ ${{ matrix.os }} != ubuntu-latest ]]; then echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV fi - - name: Test RetrieveChat - run: | - pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai - name: Coverage run: | pip install coverage>=5.3 - coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai + coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/vectordb --skip-openai coverage xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 diff --git a/autogen/agentchat/contrib/vectordb/__init__.py b/autogen/agentchat/contrib/vectordb/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py new file mode 100644 index 00000000000..187d0d6acbb --- /dev/null +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -0,0 +1,209 @@ +from typing import Any, List, Mapping, Optional, Protocol, Sequence, Tuple, TypedDict, Union, runtime_checkable + +Metadata = Union[Mapping[str, Any], None] +Vector = Union[Sequence[float], Sequence[int]] +ItemID = Union[str, int] # chromadb doesn't support int ids, VikingDB does + + +class Document(TypedDict): + """A Document is a record in the vector database. + + id: ItemID | the unique identifier of the document. + content: str | the text content of the chunk. + metadata: Metadata, Optional | contains additional information about the document such as source, date, etc. + embedding: Vector, Optional | the vector representation of the content. + """ + + id: ItemID + content: str + metadata: Optional[Metadata] + embedding: Optional[Vector] + + +"""QueryResults is the response from the vector database for a query/queries. +A query is a list containing one string while queries is a list containing multiple strings. +The response is a list of query results, each query result is a list of tuples containing the document and the distance. +""" +QueryResults = List[List[Tuple[Document, float]]] + + +@runtime_checkable +class VectorDB(Protocol): + """ + Abstract class for vector database. A vector database is responsible for storing and retrieving documents. + + Attributes: + active_collection: Any | The active collection in the vector database. Make get_collection faster. Default is None. + type: str | The type of the vector database, chroma, pgvector, etc. Default is "". + + Methods: + create_collection: Callable[[str, bool, bool], Any] | Create a collection in the vector database. + get_collection: Callable[[str], Any] | Get the collection from the vector database. + delete_collection: Callable[[str], Any] | Delete the collection from the vector database. + insert_docs: Callable[[List[Document], str, bool], None] | Insert documents into the collection of the vector database. + update_docs: Callable[[List[Document], str], None] | Update documents in the collection of the vector database. + delete_docs: Callable[[List[ItemID], str], None] | Delete documents from the collection of the vector database. + retrieve_docs: Callable[[List[str], str, int, float], QueryResults] | Retrieve documents from the collection of the vector database based on the queries. + get_docs_by_ids: Callable[[List[ItemID], str], List[Document]] | Retrieve documents from the collection of the vector database based on the ids. + """ + + active_collection: Any = None + type: str = "" + + def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any: + """ + Create a collection in the vector database. + Case 1. if the collection does not exist, create the collection. + Case 2. the collection exists, if overwrite is True, it will overwrite the collection. + Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, + otherwise it raise a ValueError. + + Args: + collection_name: str | The name of the collection. + overwrite: bool | Whether to overwrite the collection if it exists. Default is False. + get_or_create: bool | Whether to get the collection if it exists. Default is True. + + Returns: + Any | The collection object. + """ + ... + + def get_collection(self, collection_name: str = None) -> Any: + """ + Get the collection from the vector database. + + Args: + collection_name: str | The name of the collection. Default is None. If None, return the + current active collection. + + Returns: + Any | The collection object. + """ + ... + + def delete_collection(self, collection_name: str) -> Any: + """ + Delete the collection from the vector database. + + Args: + collection_name: str | The name of the collection. + + Returns: + Any + """ + ... + + def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None: + """ + Insert documents into the collection of the vector database. + + Args: + docs: List[Document] | A list of documents. Each document is a TypedDict `Document`. + collection_name: str | The name of the collection. Default is None. + upsert: bool | Whether to update the document if it exists. Default is False. + kwargs: Dict | Additional keyword arguments. + + Returns: + None + """ + ... + + def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs) -> None: + """ + Update documents in the collection of the vector database. + + Args: + docs: List[Document] | A list of documents. + collection_name: str | The name of the collection. Default is None. + kwargs: Dict | Additional keyword arguments. + + Returns: + None + """ + ... + + def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None: + """ + Delete documents from the collection of the vector database. + + Args: + ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`. + collection_name: str | The name of the collection. Default is None. + kwargs: Dict | Additional keyword arguments. + + Returns: + None + """ + ... + + def retrieve_docs( + self, + queries: List[str], + collection_name: str = None, + n_results: int = 10, + distance_threshold: float = -1, + **kwargs, + ) -> QueryResults: + """ + Retrieve documents from the collection of the vector database based on the queries. + + Args: + queries: List[str] | A list of queries. Each query is a string. + collection_name: str | The name of the collection. Default is None. + n_results: int | The number of relevant documents to return. Default is 10. + distance_threshold: float | The threshold for the distance score, only distance smaller than it will be + returned. Don't filter with it if < 0. Default is -1. + kwargs: Dict | Additional keyword arguments. + + Returns: + QueryResults | The query results. Each query result is a list of list of tuples containing the document and + the distance. + """ + ... + + def get_docs_by_ids( + self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs + ) -> List[Document]: + """ + Retrieve documents from the collection of the vector database based on the ids. + + Args: + ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None. + collection_name: str | The name of the collection. Default is None. + include: List[str] | The fields to include. Default is None. + If None, will include ["metadatas", "documents"], ids will always be included. + kwargs: dict | Additional keyword arguments. + + Returns: + List[Document] | The results. + """ + ... + + +class VectorDBFactory: + """ + Factory class for creating vector databases. + """ + + PREDEFINED_VECTOR_DB = ["chroma"] + + @staticmethod + def create_vector_db(db_type: str, **kwargs) -> VectorDB: + """ + Create a vector database. + + Args: + db_type: str | The type of the vector database. + kwargs: Dict | The keyword arguments for initializing the vector database. + + Returns: + VectorDB | The vector database. + """ + if db_type.lower() in ["chroma", "chromadb"]: + from .chromadb import ChromaVectorDB + + return ChromaVectorDB(**kwargs) + else: + raise ValueError( + f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}." + ) diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py new file mode 100644 index 00000000000..6e571d58abc --- /dev/null +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -0,0 +1,318 @@ +import os +from typing import Callable, List + +from .base import Document, ItemID, QueryResults, VectorDB +from .utils import chroma_results_to_query_results, filter_results_by_distance, get_logger + +try: + import chromadb + + if chromadb.__version__ < "0.4.15": + raise ImportError("Please upgrade chromadb to version 0.4.15 or later.") + import chromadb.utils.embedding_functions as ef + from chromadb.api.models.Collection import Collection +except ImportError: + raise ImportError("Please install chromadb: `pip install chromadb`") + +CHROMADB_MAX_BATCH_SIZE = os.environ.get("CHROMADB_MAX_BATCH_SIZE", 40000) +logger = get_logger(__name__) + + +class ChromaVectorDB(VectorDB): + """ + A vector database that uses ChromaDB as the backend. + """ + + def __init__( + self, *, client=None, path: str = None, embedding_function: Callable = None, metadata: dict = None, **kwargs + ) -> None: + """ + Initialize the vector database. + + Args: + client: chromadb.Client | The client object of the vector database. Default is None. + If provided, it will use the client object directly and ignore other arguments. + path: str | The path to the vector database. Default is None. + embedding_function: Callable | The embedding function used to generate the vector representation + of the documents. Default is None, SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used. + metadata: dict | The metadata of the vector database. Default is None. If None, it will use this + setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}. For more details of + the metadata, please refer to [distances](https://github.com/nmslib/hnswlib#supported-distances), + [hnsw](https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184), + and [ALGO_PARAMS](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md). + kwargs: dict | Additional keyword arguments. + + Returns: + None + """ + self.client = client + self.path = path + self.embedding_function = ( + ef.SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") + if embedding_function is None + else embedding_function + ) + self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32} + if not self.client: + if self.path is not None: + self.client = chromadb.PersistentClient(path=self.path, **kwargs) + else: + self.client = chromadb.Client(**kwargs) + self.active_collection = None + self.type = "chroma" + + def create_collection( + self, collection_name: str, overwrite: bool = False, get_or_create: bool = True + ) -> Collection: + """ + Create a collection in the vector database. + Case 1. if the collection does not exist, create the collection. + Case 2. the collection exists, if overwrite is True, it will overwrite the collection. + Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, + otherwise it raise a ValueError. + + Args: + collection_name: str | The name of the collection. + overwrite: bool | Whether to overwrite the collection if it exists. Default is False. + get_or_create: bool | Whether to get the collection if it exists. Default is True. + + Returns: + Collection | The collection object. + """ + try: + if self.active_collection and self.active_collection.name == collection_name: + collection = self.active_collection + else: + collection = self.client.get_collection(collection_name) + except ValueError: + collection = None + if collection is None: + return self.client.create_collection( + collection_name, + embedding_function=self.embedding_function, + get_or_create=get_or_create, + metadata=self.metadata, + ) + elif overwrite: + self.client.delete_collection(collection_name) + return self.client.create_collection( + collection_name, + embedding_function=self.embedding_function, + get_or_create=get_or_create, + metadata=self.metadata, + ) + elif get_or_create: + return collection + else: + raise ValueError(f"Collection {collection_name} already exists.") + + def get_collection(self, collection_name: str = None) -> Collection: + """ + Get the collection from the vector database. + + Args: + collection_name: str | The name of the collection. Default is None. If None, return the + current active collection. + + Returns: + Collection | The collection object. + """ + if collection_name is None: + if self.active_collection is None: + raise ValueError("No collection is specified.") + else: + logger.info( + f"No collection is specified. Using current active collection {self.active_collection.name}." + ) + else: + if not (self.active_collection and self.active_collection.name == collection_name): + self.active_collection = self.client.get_collection(collection_name) + return self.active_collection + + def delete_collection(self, collection_name: str) -> None: + """ + Delete the collection from the vector database. + + Args: + collection_name: str | The name of the collection. + + Returns: + None + """ + self.client.delete_collection(collection_name) + if self.active_collection and self.active_collection.name == collection_name: + self.active_collection = None + + def _batch_insert( + self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False + ) -> None: + batch_size = int(CHROMADB_MAX_BATCH_SIZE) + for i in range(0, len(documents), min(batch_size, len(documents))): + end_idx = i + min(batch_size, len(documents) - i) + collection_kwargs = { + "documents": documents[i:end_idx], + "ids": ids[i:end_idx], + "metadatas": metadatas[i:end_idx] if metadatas else None, + "embeddings": embeddings[i:end_idx] if embeddings else None, + } + if upsert: + collection.upsert(**collection_kwargs) + else: + collection.add(**collection_kwargs) + + def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None: + """ + Insert documents into the collection of the vector database. + + Args: + docs: List[Document] | A list of documents. Each document is a TypedDict `Document`. + collection_name: str | The name of the collection. Default is None. + upsert: bool | Whether to update the document if it exists. Default is False. + kwargs: Dict | Additional keyword arguments. + + Returns: + None + """ + if not docs: + return + if docs[0].get("content") is None: + raise ValueError("The document content is required.") + if docs[0].get("id") is None: + raise ValueError("The document id is required.") + documents = [doc.get("content") for doc in docs] + ids = [doc.get("id") for doc in docs] + collection = self.get_collection(collection_name) + if docs[0].get("embedding") is None: + logger.info( + "No content embedding is provided. Will use the VectorDB's embedding function to generate the content embedding." + ) + embeddings = None + else: + embeddings = [doc.get("embedding") for doc in docs] + if docs[0].get("metadata") is None: + metadatas = None + else: + metadatas = [doc.get("metadata") for doc in docs] + self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert) + + def update_docs(self, docs: List[Document], collection_name: str = None) -> None: + """ + Update documents in the collection of the vector database. + + Args: + docs: List[Document] | A list of documents. + collection_name: str | The name of the collection. Default is None. + + Returns: + None + """ + self.insert_docs(docs, collection_name, upsert=True) + + def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None: + """ + Delete documents from the collection of the vector database. + + Args: + ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`. + collection_name: str | The name of the collection. Default is None. + kwargs: Dict | Additional keyword arguments. + + Returns: + None + """ + collection = self.get_collection(collection_name) + collection.delete(ids, **kwargs) + + def retrieve_docs( + self, + queries: List[str], + collection_name: str = None, + n_results: int = 10, + distance_threshold: float = -1, + **kwargs, + ) -> QueryResults: + """ + Retrieve documents from the collection of the vector database based on the queries. + + Args: + queries: List[str] | A list of queries. Each query is a string. + collection_name: str | The name of the collection. Default is None. + n_results: int | The number of relevant documents to return. Default is 10. + distance_threshold: float | The threshold for the distance score, only distance smaller than it will be + returned. Don't filter with it if < 0. Default is -1. + kwargs: Dict | Additional keyword arguments. + + Returns: + QueryResults | The query results. Each query result is a list of list of tuples containing the document and + the distance. + """ + collection = self.get_collection(collection_name) + if isinstance(queries, str): + queries = [queries] + results = collection.query( + query_texts=queries, + n_results=n_results, + **kwargs, + ) + results["contents"] = results.pop("documents") + results = chroma_results_to_query_results(results) + results = filter_results_by_distance(results, distance_threshold) + return results + + @staticmethod + def _chroma_get_results_to_list_documents(data_dict) -> List[Document]: + """Converts a dictionary with list values to a list of Document. + + Args: + data_dict: A dictionary where keys map to lists or None. + + Returns: + List[Document] | The list of Document. + + Example: + data_dict = { + "key1s": [1, 2, 3], + "key2s": ["a", "b", "c"], + "key3s": None, + "key4s": ["x", "y", "z"], + } + + results = [ + {"key1": 1, "key2": "a", "key4": "x"}, + {"key1": 2, "key2": "b", "key4": "y"}, + {"key1": 3, "key2": "c", "key4": "z"}, + ] + """ + + results = [] + keys = [key for key in data_dict if data_dict[key] is not None] + + for i in range(len(data_dict[keys[0]])): + sub_dict = {} + for key in data_dict.keys(): + if data_dict[key] is not None and len(data_dict[key]) > i: + sub_dict[key[:-1]] = data_dict[key][i] + results.append(sub_dict) + return results + + def get_docs_by_ids( + self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs + ) -> List[Document]: + """ + Retrieve documents from the collection of the vector database based on the ids. + + Args: + ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None. + collection_name: str | The name of the collection. Default is None. + include: List[str] | The fields to include. Default is None. + If None, will include ["metadatas", "documents"], ids will always be included. + kwargs: dict | Additional keyword arguments. + + Returns: + List[Document] | The results. + """ + collection = self.get_collection(collection_name) + include = include if include else ["metadatas", "documents"] + results = collection.get(ids, include=include, **kwargs) + results = self._chroma_get_results_to_list_documents(results) + return results diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py new file mode 100644 index 00000000000..ae1ef125251 --- /dev/null +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -0,0 +1,112 @@ +import logging +from typing import Any, Dict, List + +from termcolor import colored + +from .base import QueryResults + + +class ColoredLogger(logging.Logger): + def __init__(self, name, level=logging.NOTSET): + super().__init__(name, level) + + def debug(self, msg, *args, color=None, **kwargs): + super().debug(colored(msg, color), *args, **kwargs) + + def info(self, msg, *args, color=None, **kwargs): + super().info(colored(msg, color), *args, **kwargs) + + def warning(self, msg, *args, color="yellow", **kwargs): + super().warning(colored(msg, color), *args, **kwargs) + + def error(self, msg, *args, color="light_red", **kwargs): + super().error(colored(msg, color), *args, **kwargs) + + def critical(self, msg, *args, color="red", **kwargs): + super().critical(colored(msg, color), *args, **kwargs) + + +def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger: + logger = ColoredLogger(name, level) + console_handler = logging.StreamHandler() + logger.addHandler(console_handler) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + logger.handlers[0].setFormatter(formatter) + return logger + + +logger = get_logger(__name__) + + +def filter_results_by_distance(results: QueryResults, distance_threshold: float = -1) -> QueryResults: + """Filters results based on a distance threshold. + + Args: + results: QueryResults | The query results. List[List[Tuple[Document, float]]] + distance_threshold: The maximum distance allowed for results. + + Returns: + QueryResults | A filtered results containing only distances smaller than the threshold. + """ + + if distance_threshold > 0: + results = [[(key, value) for key, value in data if value < distance_threshold] for data in results] + + return results + + +def chroma_results_to_query_results(data_dict: Dict[str, List[List[Any]]], special_key="distances") -> QueryResults: + """Converts a dictionary with list-of-list values to a list of tuples. + + Args: + data_dict: A dictionary where keys map to lists of lists or None. + special_key: The key in the dictionary containing the special values + for each tuple. + + Returns: + A list of tuples, where each tuple contains a sub-dictionary with + some keys from the original dictionary and the value from the + special_key. + + Example: + data_dict = { + "key1s": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "key2s": [["a", "b", "c"], ["c", "d", "e"], ["e", "f", "g"]], + "key3s": None, + "key4s": [["x", "y", "z"], ["1", "2", "3"], ["4", "5", "6"]], + "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + } + + results = [ + [ + ({"key1": 1, "key2": "a", "key4": "x"}, 0.1), + ({"key1": 2, "key2": "b", "key4": "y"}, 0.2), + ({"key1": 3, "key2": "c", "key4": "z"}, 0.3), + ], + [ + ({"key1": 4, "key2": "c", "key4": "1"}, 0.4), + ({"key1": 5, "key2": "d", "key4": "2"}, 0.5), + ({"key1": 6, "key2": "e", "key4": "3"}, 0.6), + ], + [ + ({"key1": 7, "key2": "e", "key4": "4"}, 0.7), + ({"key1": 8, "key2": "f", "key4": "5"}, 0.8), + ({"key1": 9, "key2": "g", "key4": "6"}, 0.9), + ], + ] + """ + + keys = [key for key in data_dict if key != special_key] + result = [] + + for i in range(len(data_dict[special_key])): + sub_result = [] + for j, distance in enumerate(data_dict[special_key][i]): + sub_dict = {} + for key in keys: + if data_dict[key] is not None and len(data_dict[key]) > i: + sub_dict[key[:-1]] = data_dict[key][i][j] # remove 's' in the end from key + sub_result.append((sub_dict, distance)) + result.append(sub_result) + + return result diff --git a/test/agentchat/contrib/vectordb/test_chromadb.py b/test/agentchat/contrib/vectordb/test_chromadb.py new file mode 100644 index 00000000000..ee4886f5154 --- /dev/null +++ b/test/agentchat/contrib/vectordb/test_chromadb.py @@ -0,0 +1,95 @@ +import os +import sys + +import pytest + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + +try: + import chromadb + import sentence_transformers + + from autogen.agentchat.contrib.vectordb.chromadb import ChromaVectorDB +except ImportError: + skip = True +else: + skip = False + + +@pytest.mark.skipif(skip, reason="dependency is not installed") +def test_chromadb(): + # test create collection + db = ChromaVectorDB(path=".db") + collection_name = "test_collection" + collection = db.create_collection(collection_name, overwrite=True, get_or_create=True) + assert collection.name == collection_name + + # test_delete_collection + db.delete_collection(collection_name) + pytest.raises(ValueError, db.get_collection, collection_name) + + # test more create collection + collection = db.create_collection(collection_name, overwrite=False, get_or_create=False) + assert collection.name == collection_name + pytest.raises(ValueError, db.create_collection, collection_name, overwrite=False, get_or_create=False) + collection = db.create_collection(collection_name, overwrite=True, get_or_create=False) + assert collection.name == collection_name + collection = db.create_collection(collection_name, overwrite=False, get_or_create=True) + assert collection.name == collection_name + + # test_get_collection + collection = db.get_collection(collection_name) + assert collection.name == collection_name + + # test_insert_docs + docs = [{"content": "doc1", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}] + db.insert_docs(docs, collection_name, upsert=False) + res = db.get_collection(collection_name).get(["1", "2"]) + assert res["documents"] == ["doc1", "doc2"] + + # test_update_docs + docs = [{"content": "doc11", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}] + db.update_docs(docs, collection_name) + res = db.get_collection(collection_name).get(["1", "2"]) + assert res["documents"] == ["doc11", "doc2"] + + # test_delete_docs + ids = ["1"] + collection_name = "test_collection" + db.delete_docs(ids, collection_name) + res = db.get_collection(collection_name).get(ids) + assert res["documents"] == [] + + # test_retrieve_docs + queries = ["doc2", "doc3"] + collection_name = "test_collection" + res = db.retrieve_docs(queries, collection_name) + assert [[r[0]["id"] for r in rr] for rr in res] == [["2", "3"], ["3", "2"]] + res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1) + print(res) + assert [[r[0]["id"] for r in rr] for rr in res] == [["2"], ["3"]] + + # test_get_docs_by_ids + res = db.get_docs_by_ids(["1", "2"], collection_name) + assert [r["id"] for r in res] == ["2"] # "1" has been deleted + res = db.get_docs_by_ids(collection_name=collection_name) + assert [r["id"] for r in res] == ["2", "3"] + + # test _chroma_get_results_to_list_documents + data_dict = { + "key1s": [1, 2, 3], + "key2s": ["a", "b", "c"], + "key3s": None, + "key4s": ["x", "y", "z"], + } + + results = [ + {"key1": 1, "key2": "a", "key4": "x"}, + {"key1": 2, "key2": "b", "key4": "y"}, + {"key1": 3, "key2": "c", "key4": "z"}, + ] + assert db._chroma_get_results_to_list_documents(data_dict) == results + + +if __name__ == "__main__": + test_chromadb() diff --git a/test/agentchat/contrib/vectordb/test_vectordb_utils.py b/test/agentchat/contrib/vectordb/test_vectordb_utils.py new file mode 100644 index 00000000000..8c26ac9c3cd --- /dev/null +++ b/test/agentchat/contrib/vectordb/test_vectordb_utils.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 -m pytest + +import os +import sys + +import pytest + +from autogen.agentchat.contrib.vectordb.utils import chroma_results_to_query_results, filter_results_by_distance + + +def test_retrieve_config(): + results = [ + [("id1", 1), ("id2", 2)], + [("id3", 2), ("id4", 3)], + ] + print(filter_results_by_distance(results, 2.1)) + filter_results = [ + [("id1", 1), ("id2", 2)], + [("id3", 2)], + ] + assert filter_results == filter_results_by_distance(results, 2.1) + + +def test_chroma_results_to_query_results(): + data_dict = { + "key1s": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "key2s": [["a", "b", "c"], ["c", "d", "e"], ["e", "f", "g"]], + "key3s": None, + "key4s": [["x", "y", "z"], ["1", "2", "3"], ["4", "5", "6"]], + "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + } + results = [ + [ + ({"key1": 1, "key2": "a", "key4": "x"}, 0.1), + ({"key1": 2, "key2": "b", "key4": "y"}, 0.2), + ({"key1": 3, "key2": "c", "key4": "z"}, 0.3), + ], + [ + ({"key1": 4, "key2": "c", "key4": "1"}, 0.4), + ({"key1": 5, "key2": "d", "key4": "2"}, 0.5), + ({"key1": 6, "key2": "e", "key4": "3"}, 0.6), + ], + [ + ({"key1": 7, "key2": "e", "key4": "4"}, 0.7), + ({"key1": 8, "key2": "f", "key4": "5"}, 0.8), + ({"key1": 9, "key2": "g", "key4": "6"}, 0.9), + ], + ] + assert chroma_results_to_query_results(data_dict) == results