From 9e5ba544e12f02fc1f0fa698bf574a18263185be Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 9 Feb 2024 17:22:28 -0500 Subject: [PATCH 01/15] WIP --- redisvl/utils/rerank/__init__.py | 0 redisvl/utils/rerank/base.py | 70 ++++++++++++++++++++++++++++++++ redisvl/utils/rerank/cohere.py | 0 3 files changed, 70 insertions(+) create mode 100644 redisvl/utils/rerank/__init__.py create mode 100644 redisvl/utils/rerank/base.py create mode 100644 redisvl/utils/rerank/cohere.py diff --git a/redisvl/utils/rerank/__init__.py b/redisvl/utils/rerank/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/redisvl/utils/rerank/base.py b/redisvl/utils/rerank/base.py new file mode 100644 index 00000000..d85cf475 --- /dev/null +++ b/redisvl/utils/rerank/base.py @@ -0,0 +1,70 @@ +from redis.asyncio import Redis +from redisvl.schema import IndexSchema +from redisvl.index import AsyncSearchIndex +from redisvl.utils.vectorize import CohereTextVectorizer +from redisvl.utils.rerank import CohereReranker +from redisvl.query import VectorQuery + + +vectorizer = CohereTextVectorizer() +reranker = CohereReranker( + model="rerank-english-v2.0", limit=5, rank_by="", max_chunks_per_doc +) +# when there's an overflow from context length + +client = Redis.from_url("redis://localhost:6379") +schema = IndexSchema.from_yaml("schema/schema.yaml") + + +async def main(data, query: str): + """To start""" + index = AsyncSearchIndex(schema, client) + + await index.create(overwrite=True, drop=True) + await index.load([data]) + + vector_query = VectorQuery( + vector=vectorizer.embed(query), + vector_field_name="", + return_fields=[], + num_results=20 + ) + + # TODO think about the scoring implementation + # add score to the dict + results = await index.query(vector_query) + ranked_results = await reranker.rank( + query, results, limit=4, rank_by="", return_score=True + ) + # How do we handle multiple fields for overflow? + # If you do provide multiple fields, truncate in order? + # Support single field to start + return ranked_results + + + + +async def main(data, query: str): + """Maybe in the future???""" + index = AsyncSearchIndex( + schema, + client, + vectorizer=vectorizer, + reranker=reranker + ) + + await index.create(overwrite=True, drop=True) + await index.load([data]) + + vector_query = VectorQuery( + vector_field_name="", + return_fields=[], + num_results=20 + ) + + results = await index.pipeline_run( + vectorizer, + vector_query, + reranker + ) + diff --git a/redisvl/utils/rerank/cohere.py b/redisvl/utils/rerank/cohere.py new file mode 100644 index 00000000..e69de29b From 26b5d3d849354417c0fcbb038b013f6ad874cead Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 14 Feb 2024 16:48:58 -0500 Subject: [PATCH 02/15] add base reranker and pilot cohere implementation --- redisvl/utils/rerank/__init__.py | 7 +++ redisvl/utils/rerank/base.py | 100 ++++++++++--------------------- redisvl/utils/rerank/cohere.py | 83 +++++++++++++++++++++++++ 3 files changed, 120 insertions(+), 70 deletions(-) diff --git a/redisvl/utils/rerank/__init__.py b/redisvl/utils/rerank/__init__.py index e69de29b..ef7fa9e4 100644 --- a/redisvl/utils/rerank/__init__.py +++ b/redisvl/utils/rerank/__init__.py @@ -0,0 +1,7 @@ +from redisvl.utils.rerank.base import BaseReranker +from redisvl.utils.rerank.cohere import CohereReranker + +__all__ = [ + "BaseReranker", + "CohereReranker", +] diff --git a/redisvl/utils/rerank/base.py b/redisvl/utils/rerank/base.py index d85cf475..4cbe0bcf 100644 --- a/redisvl/utils/rerank/base.py +++ b/redisvl/utils/rerank/base.py @@ -1,70 +1,30 @@ -from redis.asyncio import Redis -from redisvl.schema import IndexSchema -from redisvl.index import AsyncSearchIndex -from redisvl.utils.vectorize import CohereTextVectorizer -from redisvl.utils.rerank import CohereReranker -from redisvl.query import VectorQuery - - -vectorizer = CohereTextVectorizer() -reranker = CohereReranker( - model="rerank-english-v2.0", limit=5, rank_by="", max_chunks_per_doc -) -# when there's an overflow from context length - -client = Redis.from_url("redis://localhost:6379") -schema = IndexSchema.from_yaml("schema/schema.yaml") - - -async def main(data, query: str): - """To start""" - index = AsyncSearchIndex(schema, client) - - await index.create(overwrite=True, drop=True) - await index.load([data]) - - vector_query = VectorQuery( - vector=vectorizer.embed(query), - vector_field_name="", - return_fields=[], - num_results=20 - ) - - # TODO think about the scoring implementation - # add score to the dict - results = await index.query(vector_query) - ranked_results = await reranker.rank( - query, results, limit=4, rank_by="", return_score=True - ) - # How do we handle multiple fields for overflow? - # If you do provide multiple fields, truncate in order? - # Support single field to start - return ranked_results - - - - -async def main(data, query: str): - """Maybe in the future???""" - index = AsyncSearchIndex( - schema, - client, - vectorizer=vectorizer, - reranker=reranker - ) - - await index.create(overwrite=True, drop=True) - await index.load([data]) - - vector_query = VectorQuery( - vector_field_name="", - return_fields=[], - num_results=20 - ) - - results = await index.pipeline_run( - vectorizer, - vector_query, - reranker - ) - +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +from pydantic.v1 import BaseModel, validator + + +class BaseReranker(BaseModel, ABC): + model: str + rank_by: Optional[str] = None + limit: int = 5 + return_score: bool = True + + @validator("limit") + @classmethod + def check_limit(cls, value): + if value <= 0: + raise ValueError("limit must be a positive integer") + return value + + @abstractmethod + def rank( + self, query: str, results: List[Dict[str, Any]], **kwargs + ) -> List[Dict[str, Any]]: + pass + + @abstractmethod + async def arank( + self, query: str, results: List[Dict[str, Any]], **kwargs + ) -> List[Dict[str, Any]]: + pass diff --git a/redisvl/utils/rerank/cohere.py b/redisvl/utils/rerank/cohere.py index e69de29b..6e34a7a5 100644 --- a/redisvl/utils/rerank/cohere.py +++ b/redisvl/utils/rerank/cohere.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Optional + +import cohere + +from redisvl.utils.rerank.base import BaseReranker + + +class CohereReranker(BaseReranker): + def __init__(self, model: str = "rerank-english-v2.0", **data): + super().__init__(model=model, **data) + self.client = cohere.Client() + self.aclient = cohere.AsyncClient() + + @staticmethod + def _preprocess(results: List[Dict[str, Any]], rank_by: str) -> List[str]: + try: + docs = [result[rank_by] for result in results] + except (TypeError, KeyError): + raise ValueError( + "Must provide a valid rank_by field option. " + f"{rank_by} field is not present in the search results" + ) + return docs + + @staticmethod + def _postprocess( + results: List[Dict[str, Any]], rankings: List[Any], return_score: bool + ) -> List[Dict[str, Any]]: + reranked_results = [] + for item in rankings: + result = results[item.index] + if return_score: + result["score"] = item.relevance_score + reranked_results.append(result) + return reranked_results + + def rank( + self, + query: str, + results: List[Dict[str, Any]], + max_chunks_per_doc: Optional[int] = None, + **kwargs, + ) -> List[Dict[str, Any]]: + limit = kwargs.get("limit", self.limit) + return_score = kwargs.get("return_score", self.return_score) + rank_by = kwargs.get("rank_by", self.rank_by) + + docs = self._preprocess(results, rank_by) + + rankings = self.client.rerank( + model=self.model, + query=query, + documents=docs, + top_n=limit, + return_documents=False, + max_chunks_per_doc=max_chunks_per_doc, + ) + + return self._postprocess(results, rankings, return_score) + + async def arank( + self, + query: str, + results: List[Dict[str, Any]], + max_chunks_per_doc: Optional[int] = None, + **kwargs, + ) -> List[Dict[str, Any]]: + limit = kwargs.get("limit", self.limit) + return_score = kwargs.get("return_score", self.return_score) + rank_by = kwargs.get("rank_by", self.rank_by) + + docs = self._preprocess(results, rank_by) + + rankings = await self.aclient.rerank( + model=self.model, + query=query, + documents=docs, + top_n=limit, + return_documents=False, + max_chunks_per_doc=max_chunks_per_doc, + ) + + return self._postprocess(results, rankings, return_score) From 8e573867f5f54bf348db6a75882eec2f528daf3c Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 15 Feb 2024 09:48:07 -0500 Subject: [PATCH 03/15] updates to reranker implementation --- redisvl/utils/rerank/cohere.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/redisvl/utils/rerank/cohere.py b/redisvl/utils/rerank/cohere.py index 6e34a7a5..9e7da6c8 100644 --- a/redisvl/utils/rerank/cohere.py +++ b/redisvl/utils/rerank/cohere.py @@ -1,3 +1,4 @@ +import os from typing import Any, Dict, List, Optional import cohere @@ -6,10 +7,7 @@ class CohereReranker(BaseReranker): - def __init__(self, model: str = "rerank-english-v2.0", **data): - super().__init__(model=model, **data) - self.client = cohere.Client() - self.aclient = cohere.AsyncClient() + model: str = "rerank-english-v2.0" @staticmethod def _preprocess(results: List[Dict[str, Any]], rank_by: str) -> List[str]: @@ -46,13 +44,13 @@ def rank( rank_by = kwargs.get("rank_by", self.rank_by) docs = self._preprocess(results, rank_by) + client = cohere.Client(os.environ["COHERE_API_KEY"]) - rankings = self.client.rerank( + rankings = client.rerank( model=self.model, query=query, documents=docs, top_n=limit, - return_documents=False, max_chunks_per_doc=max_chunks_per_doc, ) @@ -70,13 +68,12 @@ async def arank( rank_by = kwargs.get("rank_by", self.rank_by) docs = self._preprocess(results, rank_by) - - rankings = await self.aclient.rerank( + client = cohere.AsyncClient(os.environ["COHERE_API_KEY"]) + rankings = await client.rerank( model=self.model, query=query, documents=docs, top_n=limit, - return_documents=False, max_chunks_per_doc=max_chunks_per_doc, ) From 58678bb13414f88617bafb740e932b1445b3486b Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 17 Apr 2024 17:27:46 -0400 Subject: [PATCH 04/15] update to support sync and async, use PrivateAttr, and finish cohere sample --- redisvl/utils/rerank/base.py | 9 ++++-- redisvl/utils/rerank/cohere.py | 51 ++++++++++++++++++++++++++++------ 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/redisvl/utils/rerank/base.py b/redisvl/utils/rerank/base.py index 4cbe0bcf..33a2b3e1 100644 --- a/redisvl/utils/rerank/base.py +++ b/redisvl/utils/rerank/base.py @@ -1,3 +1,4 @@ + from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional @@ -6,9 +7,9 @@ class BaseReranker(BaseModel, ABC): model: str - rank_by: Optional[str] = None - limit: int = 5 - return_score: bool = True + rank_by: Optional[str] + limit: int + return_score: bool @validator("limit") @classmethod @@ -28,3 +29,5 @@ async def arank( self, query: str, results: List[Dict[str, Any]], **kwargs ) -> List[Dict[str, Any]]: pass + + diff --git a/redisvl/utils/rerank/cohere.py b/redisvl/utils/rerank/cohere.py index 9e7da6c8..d47f750c 100644 --- a/redisvl/utils/rerank/cohere.py +++ b/redisvl/utils/rerank/cohere.py @@ -1,13 +1,50 @@ import os from typing import Any, Dict, List, Optional - -import cohere +from pydantic.v1 import PrivateAttr from redisvl.utils.rerank.base import BaseReranker class CohereReranker(BaseReranker): - model: str = "rerank-english-v2.0" + _client: Any = PrivateAttr() + _aclient: Any = PrivateAttr() + + def __init__( + self, + model: str = "rerank-english-v2.0", + rank_by: Optional[str] = None, + limit: int = 5, + return_score: bool = True, + api_config: Optional[Dict] = None + ) -> None: + # Dynamic import of the cohere module + try: + from cohere import Client, AsyncClient + except ImportError: + raise ImportError( + "Cohere reranker requires the cohere library. \ + Please install with `pip install cohere`" + ) + + # Fetch the API key from api_config or environment variable + api_key = ( + api_config.get("api_key") if api_config else os.getenv("COHERE_API_KEY") + ) + if not api_key: + raise ValueError( + "Cohere API key is required. " + "Provide it in api_config or set the COHERE_API_KEY environment variable." + ) + + self._client = Client(api_key=api_key) + self._aclient = AsyncClient(api_key=api_key) + + super().__init__( + model=model, + rank_by=rank_by, + limit=limit, + return_score=return_score, + ) @staticmethod def _preprocess(results: List[Dict[str, Any]], rank_by: str) -> List[str]: @@ -25,7 +62,7 @@ def _postprocess( results: List[Dict[str, Any]], rankings: List[Any], return_score: bool ) -> List[Dict[str, Any]]: reranked_results = [] - for item in rankings: + for item in rankings.results: result = results[item.index] if return_score: result["score"] = item.relevance_score @@ -44,9 +81,8 @@ def rank( rank_by = kwargs.get("rank_by", self.rank_by) docs = self._preprocess(results, rank_by) - client = cohere.Client(os.environ["COHERE_API_KEY"]) - rankings = client.rerank( + rankings = self._client.rerank( model=self.model, query=query, documents=docs, @@ -68,8 +104,7 @@ async def arank( rank_by = kwargs.get("rank_by", self.rank_by) docs = self._preprocess(results, rank_by) - client = cohere.AsyncClient(os.environ["COHERE_API_KEY"]) - rankings = await client.rerank( + rankings = await self._aclient.rerank( model=self.model, query=query, documents=docs, From 66c5123c4417480efdc0ed5b4c15c69733863881 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 24 Apr 2024 11:53:01 -0400 Subject: [PATCH 05/15] wip --- redisvl/utils/rerank/base.py | 8 ++-- redisvl/utils/rerank/cohere.py | 85 ++++++++++++++++------------------ 2 files changed, 45 insertions(+), 48 deletions(-) diff --git a/redisvl/utils/rerank/base.py b/redisvl/utils/rerank/base.py index 33a2b3e1..57973823 100644 --- a/redisvl/utils/rerank/base.py +++ b/redisvl/utils/rerank/base.py @@ -1,13 +1,13 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from pydantic.v1 import BaseModel, validator class BaseReranker(BaseModel, ABC): model: str - rank_by: Optional[str] + rank_by: Optional[List[str]] limit: int return_score: bool @@ -20,13 +20,13 @@ def check_limit(cls, value): @abstractmethod def rank( - self, query: str, results: List[Dict[str, Any]], **kwargs + self, query: str, results: Union[List[Dict[str, Any]], List[str]], **kwargs ) -> List[Dict[str, Any]]: pass @abstractmethod async def arank( - self, query: str, results: List[Dict[str, Any]], **kwargs + self, query: str, results: Union[List[Dict[str, Any]], List[str]], **kwargs ) -> List[Dict[str, Any]]: pass diff --git a/redisvl/utils/rerank/cohere.py b/redisvl/utils/rerank/cohere.py index d47f750c..2c50e555 100644 --- a/redisvl/utils/rerank/cohere.py +++ b/redisvl/utils/rerank/cohere.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from pydantic.v1 import PrivateAttr from redisvl.utils.rerank.base import BaseReranker @@ -11,8 +11,8 @@ class CohereReranker(BaseReranker): def __init__( self, - model: str = "rerank-english-v2.0", - rank_by: Optional[str] = None, + model: str = "rerank-english-v3.0", + rank_by: Optional[List[str]] = None, limit: int = 5, return_score: bool = True, api_config: Optional[Dict] = None @@ -46,16 +46,32 @@ def __init__( return_score=return_score, ) - @staticmethod - def _preprocess(results: List[Dict[str, Any]], rank_by: str) -> List[str]: - try: - docs = [result[rank_by] for result in results] - except (TypeError, KeyError): - raise ValueError( - "Must provide a valid rank_by field option. " - f"{rank_by} field is not present in the search results" - ) - return docs + def _preprocess( + self, + query: str, + results: Union[List[Dict[str, Any]], List[str]], + **kwargs, + ): + # parse optional overrides + limit = kwargs.get("limit", self.limit) + return_score = kwargs.get("return_score", self.return_score) + max_chunks_per_doc = kwargs.get("max_chunks_per_doc") + rank_by = kwargs.get("rank_by", self.rank_by) + if isinstance(rank_by, str): + rank_by = [rank_by] + + reranker_kwargs = { + "model": self.model, + "query": query, + "top_n": limit, + "documents": results, + "max_chunks_per_doc": max_chunks_per_doc + } + + if rank_by and all([isinstance(result, dict) for result in results]): + reranker_kwargs["rank_fields"] = rank_by + + return reranker_kwargs, return_score @staticmethod def _postprocess( @@ -72,44 +88,25 @@ def _postprocess( def rank( self, query: str, - results: List[Dict[str, Any]], - max_chunks_per_doc: Optional[int] = None, + results: Union[List[Dict[str, Any]], List[str]], **kwargs, ) -> List[Dict[str, Any]]: - limit = kwargs.get("limit", self.limit) - return_score = kwargs.get("return_score", self.return_score) - rank_by = kwargs.get("rank_by", self.rank_by) - - docs = self._preprocess(results, rank_by) - - rankings = self._client.rerank( - model=self.model, - query=query, - documents=docs, - top_n=limit, - max_chunks_per_doc=max_chunks_per_doc, + # preprocess inputs + reranker_kwargs, return_score = self._preprocess( + query, results, **kwargs ) - - return self._postprocess(results, rankings, return_score) + ranked_results = self._client.rerank(**reranker_kwargs) + return self._postprocess(results, ranked_results, return_score) async def arank( self, query: str, - results: List[Dict[str, Any]], - max_chunks_per_doc: Optional[int] = None, + results: Union[List[Dict[str, Any]], List[str]], **kwargs, ) -> List[Dict[str, Any]]: - limit = kwargs.get("limit", self.limit) - return_score = kwargs.get("return_score", self.return_score) - rank_by = kwargs.get("rank_by", self.rank_by) - - docs = self._preprocess(results, rank_by) - rankings = await self._aclient.rerank( - model=self.model, - query=query, - documents=docs, - top_n=limit, - max_chunks_per_doc=max_chunks_per_doc, + # preprocess inputs + reranker_kwargs, return_score = self._preprocess( + query, results, **kwargs ) - - return self._postprocess(results, rankings, return_score) + ranked_results = await self._client.arerank(**reranker_kwargs) + return self._postprocess(results, ranked_results, return_score) From e7c18a5bb9ea7ae7637ce90c2c63680fe9bceda3 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 25 Apr 2024 16:49:31 -0400 Subject: [PATCH 06/15] wip --- redisvl/utils/rerank/base.py | 62 +++++++++--- redisvl/utils/rerank/cohere.py | 173 ++++++++++++++++++++++++--------- 2 files changed, 177 insertions(+), 58 deletions(-) diff --git a/redisvl/utils/rerank/base.py b/redisvl/utils/rerank/base.py index 57973823..823932d3 100644 --- a/redisvl/utils/rerank/base.py +++ b/redisvl/utils/rerank/base.py @@ -1,33 +1,73 @@ - from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Tuple -from pydantic.v1 import BaseModel, validator +from pydantic import BaseModel, validator class BaseReranker(BaseModel, ABC): + """ + Base class for reranking services that defines the essential + framework for implementations. + + This class serves as a template for creating specialized reranker services + that can interact with different machine learning models to rerank a list of + docs based on a query. It uses abstract methods that must be implemented + by subclasses to provide concrete behavior. + + Attributes: + model (str): Identifier for the model used for reranking. + rank_by (Optional[List[str]], optional): An optional list of keys + specifying the attributes in the docs that should be considered + for ranking. + limit (int): The maximum number of results to return after reranking. + return_score (bool): Flag indicating whether to return scores + alongside the reranked results. + """ model: str - rank_by: Optional[List[str]] + rank_by: Optional[List[str]] = None limit: int return_score: bool @validator("limit") @classmethod def check_limit(cls, value): + """ Ensures the limit is a positive integer. """ if value <= 0: - raise ValueError("limit must be a positive integer") + raise ValueError("Limit must be a positive integer.") + return value + + @validator("rank_by") + @classmethod + def check_rank_by(cls, value): + """ Ensures that rank_by is a list of strings if provided. """ + if value is not None and ( + not isinstance(value, list) or any( + not isinstance(item, str) for item in value + ) + ): + raise ValueError("rank_by must be a list of strings.") return value @abstractmethod def rank( - self, query: str, results: Union[List[Dict[str, Any]], List[str]], **kwargs - ) -> List[Dict[str, Any]]: + self, + query: str, + docs: Union[List[Dict[str, Any]], List[str]], + **kwargs + ) -> Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: + """ + Synchronously rerank the docs based on the provided query. + """ pass @abstractmethod async def arank( - self, query: str, results: Union[List[Dict[str, Any]], List[str]], **kwargs - ) -> List[Dict[str, Any]]: + self, + query: str, + docs: Union[List[Dict[str, Any]], List[str]], + **kwargs + ) -> Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: + """ + Asynchronously rerank the docs based on the provided query. + """ pass - - diff --git a/redisvl/utils/rerank/cohere.py b/redisvl/utils/rerank/cohere.py index 2c50e555..c05042c3 100644 --- a/redisvl/utils/rerank/cohere.py +++ b/redisvl/utils/rerank/cohere.py @@ -1,11 +1,27 @@ import os -from typing import Any, Dict, List, Optional, Union -from pydantic.v1 import PrivateAttr +from typing import Any, Dict, List, Optional, Union, Tuple +from pydantic import BaseModel, PrivateAttr from redisvl.utils.rerank.base import BaseReranker class CohereReranker(BaseReranker): + """ + The CohereReranker class uses Cohere's API to rerank documents based on an + input query. + + This reranker is designed to interact with Cohere's /rerank API, + requiring an API key for authentication. The key can be provided + directly in the `api_config` dictionary or through the `COHERE_API_KEY` + environment variable. User must obtain an API key from Cohere's website + (https://dashboard.cohere.com/). Additionally, the `cohere` python + client must be installed with `pip install cohere`. + + .. code-block:: python + + + """ + _client: Any = PrivateAttr() _aclient: Any = PrivateAttr() @@ -17,12 +33,42 @@ def __init__( return_score: bool = True, api_config: Optional[Dict] = None ) -> None: + """ + Initialize the CohereReranker with specified model, ranking criteria, + and API configuration. + + Parameters: + model (str): The identifier for the Cohere model used for reranking. + Defaults to 'rerank-english-v3.0'. + rank_by (Optional[List[str]]): Optional list of keys specifying the + attributes in the documents that should be considered for + ranking. None means ranking will rely on the model's default + behavior. + limit (int): The maximum number of results to return after + reranking. Must be a positive integer. + return_score (bool): Whether to return scores alongside the + reranked results. + api_config (Optional[Dict], optional): Dictionary containing the API key. + Defaults to None. + + Raises: + ImportError: If the cohere library is not installed. + ValueError: If the API key is not provided. + """ + super().__init__(model=model, rank_by=rank_by, limit=limit, return_score=return_score) + self._initialize_clients(api_config) + + def _initialize_clients(self, api_config: Optional[Dict]): + """ + Setup the Cohere clients using the provided API key or an + environment variable. + """ # Dynamic import of the cohere module try: from cohere import Client, AsyncClient except ImportError: raise ImportError( - "Cohere reranker requires the cohere library. \ + "Cohere vectorizer requires the cohere library. \ Please install with `pip install cohere`" ) @@ -35,78 +81,111 @@ def __init__( "Cohere API key is required. " "Provide it in api_config or set the COHERE_API_KEY environment variable." ) - self._client = Client(api_key=api_key) self._aclient = AsyncClient(api_key=api_key) - super().__init__( - model=model, - rank_by=rank_by, - limit=limit, - return_score=return_score, - ) - def _preprocess( self, query: str, - results: Union[List[Dict[str, Any]], List[str]], - **kwargs, + docs: Union[List[Dict[str, Any]], List[str]], + **kwargs ): - # parse optional overrides + """ + Prepare and validate reranking config based on provided input and + optional overrides. + """ limit = kwargs.get("limit", self.limit) return_score = kwargs.get("return_score", self.return_score) max_chunks_per_doc = kwargs.get("max_chunks_per_doc") - rank_by = kwargs.get("rank_by", self.rank_by) - if isinstance(rank_by, str): - rank_by = [rank_by] + rank_by = kwargs.get("rank_by", self.rank_by) or [] + rank_by = [rank_by] if isinstance(rank_by, str) else rank_by reranker_kwargs = { "model": self.model, "query": query, "top_n": limit, - "documents": results, + "documents": docs, "max_chunks_per_doc": max_chunks_per_doc } - - if rank_by and all([isinstance(result, dict) for result in results]): - reranker_kwargs["rank_fields"] = rank_by + # if we are working with list of dicts + if all(isinstance(doc, dict) for doc in docs): + if rank_by: + reranker_kwargs["rank_fields"] = rank_by + else: + raise ValueError( + "If reranking dictionary-like docs, " + "you must provide a list of rank_by fields" + ) return reranker_kwargs, return_score @staticmethod def _postprocess( - results: List[Dict[str, Any]], rankings: List[Any], return_score: bool - ) -> List[Dict[str, Any]]: - reranked_results = [] + docs: Union[List[Dict[str, Any]], List[str]], + rankings: List[Any], + ) -> Tuple[Union[List[Dict[str, Any]], List[str]], float]: + """ + Post-process the initial list of documents to include ranking scores, + if specified. + """ + reranked_docs, scores = [], [] for item in rankings.results: - result = results[item.index] - if return_score: - result["score"] = item.relevance_score - reranked_results.append(result) - return reranked_results + scores.append(item.relevance_score) + reranked_docs.append(docs[item.index]) + return reranked_docs, scores def rank( self, query: str, - results: Union[List[Dict[str, Any]], List[str]], - **kwargs, - ) -> List[Dict[str, Any]]: - # preprocess inputs - reranker_kwargs, return_score = self._preprocess( - query, results, **kwargs - ) - ranked_results = self._client.rerank(**reranker_kwargs) - return self._postprocess(results, ranked_results, return_score) + docs: Union[List[Dict[str, Any]], List[str]], + **kwargs + ) -> Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: + """ + Rerank documents based on the provided query using the Cohere rerank API. + + This method processes the user's query and the provided documents to + rerank them in a manner that is potentially more relevant to the + query's context. + + Parameters: + query (str): The user's search query. + docs (Union[List[Dict[str, Any]], List[str]]): The list of documents + to be ranked, either as dictionaries or strings. + + Returns: + Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: The reranked list of documents and optionally associated scores. + """ + reranker_kwargs, return_score = self._preprocess(query, docs, **kwargs) + rankings = self._client.rerank(**reranker_kwargs) + reranked_docs, scores = self._postprocess(docs, rankings) + if return_score: + return reranked_docs, scores + return reranked_docs async def arank( self, query: str, - results: Union[List[Dict[str, Any]], List[str]], - **kwargs, - ) -> List[Dict[str, Any]]: - # preprocess inputs - reranker_kwargs, return_score = self._preprocess( - query, results, **kwargs - ) - ranked_results = await self._client.arerank(**reranker_kwargs) - return self._postprocess(results, ranked_results, return_score) + docs: Union[List[Dict[str, Any]], List[str]], + **kwargs + ) -> Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: + """ + Rerank documents based on the provided query using the Cohere rerank API. + + This method processes the user's query and the provided documents to + rerank them in a manner that is potentially more relevant to the + query's context. + + Parameters: + query (str): The user's search query. + docs (Union[List[Dict[str, Any]], List[str]]): The list of documents + to be ranked, either as dictionaries or strings. + + Returns: + Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: The reranked list of documents and optionally associated scores. + """ + reranker_kwargs, return_score = self._preprocess(query, docs, **kwargs) + rankings = await self._aclient.rerank(**reranker_kwargs) + reranked_docs, scores = self._postprocess(docs, rankings) + if return_score: + return reranked_docs, scores + return reranked_docs \ No newline at end of file From 511ba283f0989f177e3624f14bb49efcd09f06bb Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 25 Apr 2024 17:05:04 -0400 Subject: [PATCH 07/15] update api docs --- docs/api/index.md | 1 + docs/api/reranker.rst | 14 ++++++++++++++ docs/api/vectorizer.rst | 1 - 3 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 docs/api/reranker.rst diff --git a/docs/api/index.md b/docs/api/index.md index 613e5542..b2c45ab6 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -16,6 +16,7 @@ searchindex query filter vectorizer +reranker cache ``` diff --git a/docs/api/reranker.rst b/docs/api/reranker.rst new file mode 100644 index 00000000..dd9fd425 --- /dev/null +++ b/docs/api/reranker.rst @@ -0,0 +1,14 @@ +*********** +Rerankers +*********** + +CohereReranker +================ + +.. _coherereranker_api: + +.. currentmodule:: redisvl.utils.rerank.cohere + +.. autoclass:: CohereReranker + :show-inheritance: + :members: diff --git a/docs/api/vectorizer.rst b/docs/api/vectorizer.rst index 61dd432c..8f2d8a97 100644 --- a/docs/api/vectorizer.rst +++ b/docs/api/vectorizer.rst @@ -1,4 +1,3 @@ - *********** Vectorizers *********** From f6e9b54382ebe8827096af0c8fe86c4c8c6aca95 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 26 Apr 2024 11:23:14 -0400 Subject: [PATCH 08/15] wip update vectorizers --- redisvl/utils/rerank/base.py | 49 ++++++--------------- redisvl/utils/rerank/cohere.py | 44 +++++++++--------- redisvl/utils/vectorize/base.py | 8 +++- redisvl/utils/vectorize/text/cohere.py | 26 +++++++---- redisvl/utils/vectorize/text/huggingface.py | 22 +++++---- redisvl/utils/vectorize/text/openai.py | 31 ++++++++----- redisvl/utils/vectorize/text/vertexai.py | 26 +++++++---- 7 files changed, 108 insertions(+), 98 deletions(-) diff --git a/redisvl/utils/rerank/base.py b/redisvl/utils/rerank/base.py index 823932d3..4c2f2424 100644 --- a/redisvl/utils/rerank/base.py +++ b/redisvl/utils/rerank/base.py @@ -1,28 +1,10 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union -from pydantic import BaseModel, validator +from pydantic.v1 import BaseModel, validator class BaseReranker(BaseModel, ABC): - """ - Base class for reranking services that defines the essential - framework for implementations. - - This class serves as a template for creating specialized reranker services - that can interact with different machine learning models to rerank a list of - docs based on a query. It uses abstract methods that must be implemented - by subclasses to provide concrete behavior. - - Attributes: - model (str): Identifier for the model used for reranking. - rank_by (Optional[List[str]], optional): An optional list of keys - specifying the attributes in the docs that should be considered - for ranking. - limit (int): The maximum number of results to return after reranking. - return_score (bool): Flag indicating whether to return scores - alongside the reranked results. - """ model: str rank_by: Optional[List[str]] = None limit: int @@ -31,7 +13,7 @@ class BaseReranker(BaseModel, ABC): @validator("limit") @classmethod def check_limit(cls, value): - """ Ensures the limit is a positive integer. """ + """Ensures the limit is a positive integer.""" if value <= 0: raise ValueError("Limit must be a positive integer.") return value @@ -39,22 +21,20 @@ def check_limit(cls, value): @validator("rank_by") @classmethod def check_rank_by(cls, value): - """ Ensures that rank_by is a list of strings if provided. """ + """Ensures that rank_by is a list of strings if provided.""" if value is not None and ( - not isinstance(value, list) or any( - not isinstance(item, str) for item in value - ) + not isinstance(value, list) + or any(not isinstance(item, str) for item in value) ): raise ValueError("rank_by must be a list of strings.") return value @abstractmethod def rank( - self, - query: str, - docs: Union[List[Dict[str, Any]], List[str]], - **kwargs - ) -> Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: + self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs + ) -> Union[ + Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]] + ]: """ Synchronously rerank the docs based on the provided query. """ @@ -62,11 +42,10 @@ def rank( @abstractmethod async def arank( - self, - query: str, - docs: Union[List[Dict[str, Any]], List[str]], - **kwargs - ) -> Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: + self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs + ) -> Union[ + Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]] + ]: """ Asynchronously rerank the docs based on the provided query. """ diff --git a/redisvl/utils/rerank/cohere.py b/redisvl/utils/rerank/cohere.py index c05042c3..ca32147e 100644 --- a/redisvl/utils/rerank/cohere.py +++ b/redisvl/utils/rerank/cohere.py @@ -1,6 +1,7 @@ import os -from typing import Any, Dict, List, Optional, Union, Tuple -from pydantic import BaseModel, PrivateAttr +from typing import Any, Dict, List, Optional, Tuple, Union + +from pydantic import PrivateAttr from redisvl.utils.rerank.base import BaseReranker @@ -31,7 +32,7 @@ def __init__( rank_by: Optional[List[str]] = None, limit: int = 5, return_score: bool = True, - api_config: Optional[Dict] = None + api_config: Optional[Dict] = None, ) -> None: """ Initialize the CohereReranker with specified model, ranking criteria, @@ -55,7 +56,9 @@ def __init__( ImportError: If the cohere library is not installed. ValueError: If the API key is not provided. """ - super().__init__(model=model, rank_by=rank_by, limit=limit, return_score=return_score) + super().__init__( + model=model, rank_by=rank_by, limit=limit, return_score=return_score + ) self._initialize_clients(api_config) def _initialize_clients(self, api_config: Optional[Dict]): @@ -65,7 +68,7 @@ def _initialize_clients(self, api_config: Optional[Dict]): """ # Dynamic import of the cohere module try: - from cohere import Client, AsyncClient + from cohere import AsyncClient, Client except ImportError: raise ImportError( "Cohere vectorizer requires the cohere library. \ @@ -81,14 +84,11 @@ def _initialize_clients(self, api_config: Optional[Dict]): "Cohere API key is required. " "Provide it in api_config or set the COHERE_API_KEY environment variable." ) - self._client = Client(api_key=api_key) - self._aclient = AsyncClient(api_key=api_key) + self._client = Client(api_key=api_key, client_name="redisvl") + self._aclient = AsyncClient(api_key=api_key, client_name="redisvl") def _preprocess( - self, - query: str, - docs: Union[List[Dict[str, Any]], List[str]], - **kwargs + self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs ): """ Prepare and validate reranking config based on provided input and @@ -105,7 +105,7 @@ def _preprocess( "query": query, "top_n": limit, "documents": docs, - "max_chunks_per_doc": max_chunks_per_doc + "max_chunks_per_doc": max_chunks_per_doc, } # if we are working with list of dicts if all(isinstance(doc, dict) for doc in docs): @@ -135,11 +135,10 @@ def _postprocess( return reranked_docs, scores def rank( - self, - query: str, - docs: Union[List[Dict[str, Any]], List[str]], - **kwargs - ) -> Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: + self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs + ) -> Union[ + Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]] + ]: """ Rerank documents based on the provided query using the Cohere rerank API. @@ -163,11 +162,10 @@ def rank( return reranked_docs async def arank( - self, - query: str, - docs: Union[List[Dict[str, Any]], List[str]], - **kwargs - ) -> Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: + self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs + ) -> Union[ + Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]] + ]: """ Rerank documents based on the provided query using the Cohere rerank API. @@ -188,4 +186,4 @@ async def arank( reranked_docs, scores = self._postprocess(docs, rankings) if return_score: return reranked_docs, scores - return reranked_docs \ No newline at end of file + return reranked_docs diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index 46ba955d..b01452cc 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Any, Callable, List, Optional from pydantic.v1 import BaseModel, validator @@ -5,10 +6,9 @@ from redisvl.redis.utils import array_to_buffer -class BaseVectorizer(BaseModel): +class BaseVectorizer(BaseModel, ABC): model: str dims: int - client: Any @validator("dims", pre=True) @classmethod @@ -17,6 +17,7 @@ def check_dims(cls, v): raise ValueError("Dimension must be a positive integer") return v + @abstractmethod def embed_many( self, texts: List[str], @@ -27,6 +28,7 @@ def embed_many( ) -> List[List[float]]: raise NotImplementedError + @abstractmethod def embed( self, text: str, @@ -36,6 +38,7 @@ def embed( ) -> List[float]: raise NotImplementedError + @abstractmethod async def aembed_many( self, texts: List[str], @@ -46,6 +49,7 @@ async def aembed_many( ) -> List[List[float]]: raise NotImplementedError + @abstractmethod async def aembed( self, text: str, diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index 7eadd658..ae76f7de 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -1,6 +1,7 @@ import os -from typing import Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional +from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -43,6 +44,8 @@ class CohereTextVectorizer(BaseVectorizer): """ + _client: Any = PrivateAttr() + def __init__( self, model: str = "embed-english-v3.0", api_config: Optional[Dict] = None ): @@ -59,10 +62,18 @@ def __init__( ImportError: If the cohere library is not installed. ValueError: If the API key is not provided. + """ + self._initialize_client(api_config) + super().__init__(model=model, dims=self._set_model_dims(model)) + + def _initialize_client(self, api_config: Optional[Dict]): + """ + Setup the Cohere clients using the provided API key or an + environment variable. """ # Dynamic import of the cohere module try: - import cohere + from cohere import AsyncClient, Client except ImportError: raise ImportError( "Cohere vectorizer requires the cohere library. \ @@ -78,15 +89,12 @@ def __init__( "Cohere API key is required. " "Provide it in api_config or set the COHERE_API_KEY environment variable." ) + self._client = Client(api_key=api_key, client_name="redisvl") + self._aclient = AsyncClient(api_key=api_key, client_name="redisvl") - client = cohere.Client(api_key, client_name="redisvl") - dims = self._set_model_dims(client, model) - super().__init__(model=model, dims=dims, client=client) - - @staticmethod - def _set_model_dims(client, model) -> int: + def _set_model_dims(self, model) -> int: try: - embedding = client.embed( + embedding = self._client.embed( texts=["dimension test"], model=model, input_type="search_document", diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index 5d02ed97..cd404fc2 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -1,4 +1,6 @@ -from typing import Callable, Dict, List, Optional +from typing import Any, Callable, List, Optional + +from pydantic import PrivateAttr from redisvl.utils.vectorize.base import BaseVectorizer @@ -28,6 +30,8 @@ class HFTextVectorizer(BaseVectorizer): """ + _client: Any = PrivateAttr() + def __init__( self, model: str = "sentence-transformers/all-mpnet-base-v2", **kwargs ): @@ -42,7 +46,12 @@ def __init__( ImportError: If the sentence-transformers library is not installed. ValueError: If there is an error setting the embedding model dimensions. """ - # Load the SentenceTransformer model + self._initialize_client(model) + super().__init__(model=model, dims=self._set_model_dims()) + + def _initialize_client(self, model: str): + """Setup the HuggingFace client""" + # Dynamic import of the cohere module\ try: from sentence_transformers import SentenceTransformer except ImportError: @@ -51,14 +60,11 @@ def __init__( "Please install with `pip install sentence-transformers`" ) - client = SentenceTransformer(model) - dims = self._set_model_dims(client) - super().__init__(model=model, dims=dims, client=client) + self._client = SentenceTransformer(model) - @staticmethod - def _set_model_dims(client): + def _set_model_dims(self): try: - embedding = client.encode(["dimension check"])[0] + embedding = self._client.encode(["dimension check"])[0] except (KeyError, IndexError) as ke: raise ValueError(f"Empty response from the embedding model: {str(ke)}") except Exception as e: # pylint: disable=broad-except diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index 22afb2f5..3f6c3e65 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -1,6 +1,7 @@ import os from typing import Any, Callable, Dict, List, Optional +from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -42,7 +43,8 @@ class OpenAITextVectorizer(BaseVectorizer): """ - aclient: Any # Since the OpenAI module is loaded dynamically + _client: Any = PrivateAttr() + _aclient: Any = PrivateAttr() def __init__( self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None @@ -59,6 +61,14 @@ def __init__( ImportError: If the openai library is not installed. ValueError: If the OpenAI API key is not provided. """ + self._initialize_clients(api_config) + super().__init__(model=model, dims=self._set_model_dims(model)) + + def _initialize_clients(self, api_config: Optional[Dict]): + """ + Setup the OpenAI clients using the provided API key or an + environment variable. + """ # Dynamic import of the openai module try: from openai import AsyncOpenAI, OpenAI @@ -79,16 +89,13 @@ def __init__( environment variable." ) - client = OpenAI(api_key=api_key) - dims = self._set_model_dims(client, model) - super().__init__(model=model, dims=dims, client=client) - self.aclient = AsyncOpenAI(api_key=api_key) + self._client = OpenAI(api_key=api_key) + self._aclient = AsyncOpenAI(api_key=api_key) - @staticmethod - def _set_model_dims(client, model) -> int: + def _set_model_dims(self, model) -> int: try: embedding = ( - client.embeddings.create(input=["dimension test"], model=model) + self._client.embeddings.create(input=["dimension test"], model=model) .data[0] .embedding ) @@ -136,7 +143,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self.client.embeddings.create(input=batch, model=self.model) + response = self._client.embeddings.create(input=batch, model=self.model) embeddings += [ self._process_embedding(r.embedding, as_buffer) for r in response.data ] @@ -174,7 +181,7 @@ def embed( if preprocess: text = preprocess(text) - result = self.client.embeddings.create(input=[text], model=self.model) + result = self._client.embeddings.create(input=[text], model=self.model) return self._process_embedding(result.data[0].embedding, as_buffer) @retry( @@ -214,7 +221,7 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = await self.aclient.embeddings.create( + response = await self._aclient.embeddings.create( input=batch, model=self.model ) embeddings += [ @@ -254,5 +261,5 @@ async def aembed( if preprocess: text = preprocess(text) - result = await self.aclient.embeddings.create(input=[text], model=self.model) + result = await self._aclient.embeddings.create(input=[text], model=self.model) return self._process_embedding(result.data[0].embedding, as_buffer) diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index 0aaf314a..d29365f8 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -1,6 +1,7 @@ import os -from typing import Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional +from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -40,6 +41,8 @@ class VertexAITextVectorizer(BaseVectorizer): """ + _client: Any = PrivateAttr() + def __init__( self, model: str = "textembedding-gecko", api_config: Optional[Dict] = None ): @@ -55,6 +58,14 @@ def __init__( ImportError: If the google-cloud-aiplatform library is not installed. ValueError: If the API key is not provided. """ + self._initialize_client(model, api_config) + super().__init__(model=model, dims=self._set_model_dims()) + + def _initialize_client(self, model: str, api_config: Optional[Dict]): + """ + Setup the VertexAI clients using the provided API key or an + environment variable. + """ # Fetch the project_id and location from api_config or environment variables project_id = ( api_config.get("project_id") if api_config else os.getenv("GCP_PROJECT_ID") @@ -93,14 +104,11 @@ def __init__( "Please install with `pip install google-cloud-aiplatform>=1.26`" ) - client = TextEmbeddingModel.from_pretrained(model) - dims = self._set_model_dims(client) - super().__init__(model=model, dims=dims, client=client) + self._client = TextEmbeddingModel.from_pretrained(model) - @staticmethod - def _set_model_dims(client) -> int: + def _set_model_dims(self) -> int: try: - embedding = client.get_embeddings(["dimension test"])[0].values + embedding = self._client.get_embeddings(["dimension test"])[0].values except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the VertexAI API: {str(ke)}") except Exception as e: # pylint: disable=broad-except @@ -145,7 +153,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self.client.get_embeddings(batch) + response = self._client.get_embeddings(batch) embeddings += [ self._process_embedding(r.values, as_buffer) for r in response ] @@ -183,5 +191,5 @@ def embed( if preprocess: text = preprocess(text) - result = self.client.get_embeddings([text]) + result = self._client.get_embeddings([text]) return self._process_embedding(result[0].values, as_buffer) From 969ebe9844b32dc2a91ac9af2eb9bdd583739ffd Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 26 Apr 2024 20:39:28 -0400 Subject: [PATCH 09/15] use appropriate pydantic shim --- redisvl/utils/rerank/cohere.py | 2 +- redisvl/utils/vectorize/base.py | 13 ++------- redisvl/utils/vectorize/text/azureopenai.py | 32 ++++++++++++--------- redisvl/utils/vectorize/text/cohere.py | 26 ++++++++++++++--- redisvl/utils/vectorize/text/huggingface.py | 25 ++++++++++++++-- redisvl/utils/vectorize/text/openai.py | 2 +- redisvl/utils/vectorize/text/vertexai.py | 21 +++++++++++++- 7 files changed, 88 insertions(+), 33 deletions(-) diff --git a/redisvl/utils/rerank/cohere.py b/redisvl/utils/rerank/cohere.py index ca32147e..23209545 100644 --- a/redisvl/utils/rerank/cohere.py +++ b/redisvl/utils/rerank/cohere.py @@ -1,7 +1,7 @@ import os from typing import Any, Dict, List, Optional, Tuple, Union -from pydantic import PrivateAttr +from pydantic.v1 import PrivateAttr from redisvl.utils.rerank.base import BaseReranker diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index b01452cc..9d5fab78 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -1,21 +1,14 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, List, Optional +from typing import Callable, List, Optional -from pydantic.v1 import BaseModel, validator +from pydantic.v1 import BaseModel from redisvl.redis.utils import array_to_buffer class BaseVectorizer(BaseModel, ABC): model: str - dims: int - - @validator("dims", pre=True) - @classmethod - def check_dims(cls, v): - if v <= 0: - raise ValueError("Dimension must be a positive integer") - return v + dims: Optional[int] @abstractmethod def embed_many( diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index 5ac527fa..39628b67 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -3,7 +3,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type - +from pydantic.v1 import PrivateAttr from redisvl.utils.vectorize.base import BaseVectorizer # ignore that openai isn't imported @@ -47,7 +47,8 @@ class AzureOpenAITextVectorizer(BaseVectorizer): """ - aclient: Any # Since the OpenAI module is loaded dynamically + _client: Any = PrivateAttr() + _aclient: Any = PrivateAttr() def __init__( self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None @@ -65,6 +66,14 @@ def __init__( ImportError: If the openai library is not installed. ValueError: If the AzureOpenAI API key, version, or endpoint are not provided. """ + self._initialize_clients(api_config) + super().__init__(model=model, dims=self._set_model_dims(model)) + + def _initialize_clients(self, api_config: Optional[Dict]): + """ + Setup the OpenAI clients using the provided API key or an + environment variable. + """ # Dynamic import of the openai module try: from openai import AsyncAzureOpenAI, AzureOpenAI @@ -114,20 +123,17 @@ def __init__( environment variable." ) - client = AzureOpenAI( + self._client = AzureOpenAI( api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint ) - dims = self._set_model_dims(client, model) - super().__init__(model=model, dims=dims, client=client) - self.aclient = AsyncAzureOpenAI( + self._aclient = AsyncAzureOpenAI( api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint ) - @staticmethod - def _set_model_dims(client, model) -> int: + def _set_model_dims(self, model) -> int: try: embedding = ( - client.embeddings.create(input=["dimension test"], model=model) + self._client.embeddings.create(input=["dimension test"], model=model) .data[0] .embedding ) @@ -175,7 +181,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self.client.embeddings.create(input=batch, model=self.model) + response = self._client.embeddings.create(input=batch, model=self.model) embeddings += [ self._process_embedding(r.embedding, as_buffer) for r in response.data ] @@ -213,7 +219,7 @@ def embed( if preprocess: text = preprocess(text) - result = self.client.embeddings.create(input=[text], model=self.model) + result = self._client.embeddings.create(input=[text], model=self.model) return self._process_embedding(result.data[0].embedding, as_buffer) @retry( @@ -253,7 +259,7 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = await self.aclient.embeddings.create( + response = await self._aclient.embeddings.create( input=batch, model=self.model ) embeddings += [ @@ -293,5 +299,5 @@ async def aembed( if preprocess: text = preprocess(text) - result = await self.aclient.embeddings.create(input=[text], model=self.model) + result = await self._aclient.embeddings.create(input=[text], model=self.model) return self._process_embedding(result.data[0].embedding, as_buffer) diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index ae76f7de..cec856dd 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, Dict, List, Optional -from pydantic import PrivateAttr +from pydantic.v1 import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -90,7 +90,6 @@ def _initialize_client(self, api_config: Optional[Dict]): "Provide it in api_config or set the COHERE_API_KEY environment variable." ) self._client = Client(api_key=api_key, client_name="redisvl") - self._aclient = AsyncClient(api_key=api_key, client_name="redisvl") def _set_model_dims(self, model) -> int: try: @@ -158,7 +157,7 @@ def embed( ) if preprocess: text = preprocess(text) - embedding = self.client.embed( + embedding = self._client.embed( texts=[text], model=self.model, input_type=input_type ).embeddings[0] return self._process_embedding(embedding, as_buffer) @@ -227,7 +226,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self.client.embed( + response = self._client.embed( texts=batch, model=self.model, input_type=input_type ) embeddings += [ @@ -235,3 +234,22 @@ def embed_many( for embedding in response.embeddings ] return embeddings + + async def aembed_many( + self, + texts: List[str], + preprocess: Optional[Callable] = None, + batch_size: int = 1000, + as_buffer: bool = False, + **kwargs, + ) -> List[List[float]]: + raise NotImplementedError + + async def aembed( + self, + text: str, + preprocess: Optional[Callable] = None, + as_buffer: bool = False, + **kwargs, + ) -> List[float]: + raise NotImplementedError diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index cd404fc2..f49efd93 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -1,6 +1,6 @@ from typing import Any, Callable, List, Optional -from pydantic import PrivateAttr +from pydantic.v1 import PrivateAttr from redisvl.utils.vectorize.base import BaseVectorizer @@ -99,7 +99,7 @@ def embed( if preprocess: text = preprocess(text) - embedding = self.client.encode([text])[0] + embedding = self._client.encode([text])[0] return self._process_embedding(embedding.tolist(), as_buffer) def embed_many( @@ -135,7 +135,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - batch_embeddings = self.client.encode(batch) + batch_embeddings = self._client.encode(batch) embeddings.extend( [ self._process_embedding(embedding.tolist(), as_buffer) @@ -143,3 +143,22 @@ def embed_many( ] ) return embeddings + + async def aembed_many( + self, + texts: List[str], + preprocess: Optional[Callable] = None, + batch_size: int = 1000, + as_buffer: bool = False, + **kwargs, + ) -> List[List[float]]: + raise NotImplementedError + + async def aembed( + self, + text: str, + preprocess: Optional[Callable] = None, + as_buffer: bool = False, + **kwargs, + ) -> List[float]: + raise NotImplementedError \ No newline at end of file diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index 3f6c3e65..b5d2070c 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, Dict, List, Optional -from pydantic import PrivateAttr +from pydantic.v1 import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index d29365f8..2bc80733 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, Dict, List, Optional -from pydantic import PrivateAttr +from pydantic.v1 import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -193,3 +193,22 @@ def embed( text = preprocess(text) result = self._client.get_embeddings([text]) return self._process_embedding(result[0].values, as_buffer) + + async def aembed_many( + self, + texts: List[str], + preprocess: Optional[Callable] = None, + batch_size: int = 1000, + as_buffer: bool = False, + **kwargs, + ) -> List[List[float]]: + raise NotImplementedError + + async def aembed( + self, + text: str, + preprocess: Optional[Callable] = None, + as_buffer: bool = False, + **kwargs, + ) -> List[float]: + raise NotImplementedError \ No newline at end of file From 5bae8b042b1391cfb10109d6a23c77a3c3a22cd4 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 26 Apr 2024 20:54:48 -0400 Subject: [PATCH 10/15] fix formatting and mypy --- redisvl/utils/rerank/base.py | 8 ++------ redisvl/utils/rerank/cohere.py | 12 ++++-------- redisvl/utils/vectorize/base.py | 2 +- redisvl/utils/vectorize/text/azureopenai.py | 3 ++- redisvl/utils/vectorize/text/huggingface.py | 2 +- redisvl/utils/vectorize/text/vertexai.py | 2 +- 6 files changed, 11 insertions(+), 18 deletions(-) diff --git a/redisvl/utils/rerank/base.py b/redisvl/utils/rerank/base.py index 4c2f2424..f4602662 100644 --- a/redisvl/utils/rerank/base.py +++ b/redisvl/utils/rerank/base.py @@ -32,9 +32,7 @@ def check_rank_by(cls, value): @abstractmethod def rank( self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs - ) -> Union[ - Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]] - ]: + ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]: """ Synchronously rerank the docs based on the provided query. """ @@ -43,9 +41,7 @@ def rank( @abstractmethod async def arank( self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs - ) -> Union[ - Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]] - ]: + ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]: """ Asynchronously rerank the docs based on the provided query. """ diff --git a/redisvl/utils/rerank/cohere.py b/redisvl/utils/rerank/cohere.py index 23209545..29a69788 100644 --- a/redisvl/utils/rerank/cohere.py +++ b/redisvl/utils/rerank/cohere.py @@ -123,22 +123,20 @@ def _preprocess( def _postprocess( docs: Union[List[Dict[str, Any]], List[str]], rankings: List[Any], - ) -> Tuple[Union[List[Dict[str, Any]], List[str]], float]: + ) -> Tuple[List[Any], List[float]]: """ Post-process the initial list of documents to include ranking scores, if specified. """ reranked_docs, scores = [], [] - for item in rankings.results: + for item in rankings.results: # type: ignore scores.append(item.relevance_score) reranked_docs.append(docs[item.index]) return reranked_docs, scores def rank( self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs - ) -> Union[ - Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]] - ]: + ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]: """ Rerank documents based on the provided query using the Cohere rerank API. @@ -163,9 +161,7 @@ def rank( async def arank( self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs - ) -> Union[ - Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]] - ]: + ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]: """ Rerank documents based on the provided query using the Cohere rerank API. diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index 9d5fab78..835b955b 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -8,7 +8,7 @@ class BaseVectorizer(BaseModel, ABC): model: str - dims: Optional[int] + dims: int @abstractmethod def embed_many( diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index 39628b67..fc13eb75 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -1,9 +1,10 @@ import os from typing import Any, Callable, Dict, List, Optional +from pydantic.v1 import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type -from pydantic.v1 import PrivateAttr + from redisvl.utils.vectorize.base import BaseVectorizer # ignore that openai isn't imported diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index f49efd93..cb72652e 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -161,4 +161,4 @@ async def aembed( as_buffer: bool = False, **kwargs, ) -> List[float]: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index 2bc80733..1d67c672 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -211,4 +211,4 @@ async def aembed( as_buffer: bool = False, **kwargs, ) -> List[float]: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError From 46347ef6f3d9321736bdced19ee1d285abfce1b1 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 26 Apr 2024 20:56:05 -0400 Subject: [PATCH 11/15] Add validator --- redisvl/utils/vectorize/base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index 835b955b..f5ef8198 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Callable, List, Optional -from pydantic.v1 import BaseModel +from pydantic.v1 import BaseModel, validator from redisvl.redis.utils import array_to_buffer @@ -10,6 +10,14 @@ class BaseVectorizer(BaseModel, ABC): model: str dims: int + @validator("dims") + @classmethod + def check_dims(cls, value): + """Ensures the dims are a positive integer.""" + if value <= 0: + raise ValueError("Dims must be a positive integer.") + return value + @abstractmethod def embed_many( self, From 6a438655afc52cbd7cf82b6f5c3008a6ca71c0aa Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 26 Apr 2024 21:04:33 -0400 Subject: [PATCH 12/15] add reranker tests --- tests/integration/test_rerankers.py | 66 +++++++++++++++++++++++++++ tests/integration/test_vectorizers.py | 1 - 2 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 tests/integration/test_rerankers.py diff --git a/tests/integration/test_rerankers.py b/tests/integration/test_rerankers.py new file mode 100644 index 00000000..19774d7f --- /dev/null +++ b/tests/integration/test_rerankers.py @@ -0,0 +1,66 @@ +import os + +import pytest + +from redisvl.utils.rerank import CohereReranker + + +@pytest.fixture +def skip_reranker() -> bool: + # os.getenv returns a string + v = os.getenv("SKIP_RERANKERS", "False").lower() == "true" + return v + + +# Fixture for the reranker instance +@pytest.fixture +def reranker(): + return CohereReranker() + + +# Test for basic ranking functionality +def test_rank_documents(reranker, skip_reranker): + if skip_reranker: + pytest.skip("Skipping reranker instantiation...") + docs = ["document one", "document two", "document three"] + query = "search query" + + reranked_docs, scores = reranker.rank(query, docs) + + assert isinstance(reranked_docs, list) + assert len(reranked_docs) == len(docs) # Ensure we get back as many docs as we sent + assert all(isinstance(score, float) for score in scores) # Scores should be floats + + +# Test for asynchronous ranking functionality +@pytest.mark.asyncio +async def test_async_rank_documents(reranker, skip_reranker): + if skip_reranker: + pytest.skip("Skipping reranker instantiation...") + docs = ["document one", "document two", "document three"] + query = "search query" + + reranked_docs, scores = await reranker.arank(query, docs) + + assert isinstance(reranked_docs, list) + assert len(reranked_docs) == len(docs) # Ensure we get back as many docs as we sent + assert all(isinstance(score, float) for score in scores) # Scores should be floats + + +# Test handling of bad input +def test_bad_input(reranker, skip_reranker): + if skip_reranker: + pytest.skip("Skipping reranker instantiation...") + with pytest.raises(Exception): + reranker.rank("", []) # Empty query or documents + + with pytest.raises(Exception): + reranker.rank(123, ["valid document"]) # Invalid type for query + + with pytest.raises(Exception): + reranker.rank("valid query", "not a list") # Invalid type for documents + + with pytest.raises(Exception): + reranker.rank( + "valid query", [{"field": "valid document"}], rank_by=["invalid_field"] + ) # Invalid rank_by field diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py index 352aaa2e..614b8d0b 100644 --- a/tests/integration/test_vectorizers.py +++ b/tests/integration/test_vectorizers.py @@ -15,7 +15,6 @@ def skip_vectorizer() -> bool: # os.getenv returns a string v = os.getenv("SKIP_VECTORIZERS", "False").lower() == "true" - print(v, flush=True) return v From e5b9deb9cc145dc5413a15f6a1fa47ce5fadfcc3 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 26 Apr 2024 21:19:32 -0400 Subject: [PATCH 13/15] add SKIP flag for rerankers --- .github/workflows/run_tests.yml | 2 +- CONTRIBUTING.md | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 41dca251..c5c389a9 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -71,7 +71,7 @@ jobs: - name: Run tests if: matrix.connection != 'plain' || matrix.redis-stack-version != 'latest' run: | - SKIP_VECTORIZERS=True poetry run test-cov + SKIP_VECTORIZERS=True SKIP_RERANKERS=True poetry run test-cov - name: Run notebooks if: matrix.connection == 'plain' && matrix.redis-stack-version == 'latest' diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 86e59e8b..45634012 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -68,6 +68,11 @@ Tests w/out vectorizers: SKIP_VECTORIZERS=true poetry run test-cov ``` +Tests w/out rerankers: +```bash +SKIP_RERANKERS=true poetry run test-cov +``` + ### Getting Redis In order for your applications to use RedisVL, you must have [Redis](https://redis.io) accessible with Search & Query features enabled on [Redis Cloud](https://redis.com/try-free) or locally in docker with [Redis Stack](https://redis.io/docs/getting-started/install-stack/docker/): From 2676cedaff4bcc686869c148f58997c8546350c2 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 30 Apr 2024 21:36:39 -0400 Subject: [PATCH 14/15] updates to flags and subprocess call --- scripts.py | 2 +- tests/integration/test_rerankers.py | 22 ++++++-------------- tests/integration/test_vectorizers.py | 30 ++++++--------------------- 3 files changed, 13 insertions(+), 41 deletions(-) diff --git a/scripts.py b/scripts.py index cfc254da..61d36636 100644 --- a/scripts.py +++ b/scripts.py @@ -26,7 +26,7 @@ def test_verbose(): subprocess.run(["python", "-m", "pytest", "-vv", "-s", "--log-level=CRITICAL"]) def test_cov(): - subprocess.run(["python", "-m", "pytest", "-vv", "--cov=./redisvl", "--cov-report=xml", "--log-level=CRITICAL"]) + subprocess.run(["python", "-m", "pytest", "-vv", "--cov=./redisvl", "--cov-report=xml", "--log-level=CRITICAL"], check=True) def cov(): subprocess.run(["coverage", "html"]) diff --git a/tests/integration/test_rerankers.py b/tests/integration/test_rerankers.py index 19774d7f..4866aa58 100644 --- a/tests/integration/test_rerankers.py +++ b/tests/integration/test_rerankers.py @@ -5,23 +5,17 @@ from redisvl.utils.rerank import CohereReranker -@pytest.fixture -def skip_reranker() -> bool: - # os.getenv returns a string - v = os.getenv("SKIP_RERANKERS", "False").lower() == "true" - return v - - # Fixture for the reranker instance @pytest.fixture def reranker(): + skip_reranker = os.getenv("SKIP_RERANKERS", "False").lower() == "true" + if skip_reranker: + pytest.skip("Skipping reranker instantiation...") return CohereReranker() # Test for basic ranking functionality -def test_rank_documents(reranker, skip_reranker): - if skip_reranker: - pytest.skip("Skipping reranker instantiation...") +def test_rank_documents(reranker): docs = ["document one", "document two", "document three"] query = "search query" @@ -34,9 +28,7 @@ def test_rank_documents(reranker, skip_reranker): # Test for asynchronous ranking functionality @pytest.mark.asyncio -async def test_async_rank_documents(reranker, skip_reranker): - if skip_reranker: - pytest.skip("Skipping reranker instantiation...") +async def test_async_rank_documents(reranker): docs = ["document one", "document two", "document three"] query = "search query" @@ -48,9 +40,7 @@ async def test_async_rank_documents(reranker, skip_reranker): # Test handling of bad input -def test_bad_input(reranker, skip_reranker): - if skip_reranker: - pytest.skip("Skipping reranker instantiation...") +def test_bad_input(reranker): with pytest.raises(Exception): reranker.rank("", []) # Empty query or documents diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py index 614b8d0b..23952c65 100644 --- a/tests/integration/test_vectorizers.py +++ b/tests/integration/test_vectorizers.py @@ -45,10 +45,7 @@ def vectorizer(request, skip_vectorizer): ) -def test_vectorizer_embed(vectorizer, skip_vectorizer): - if skip_vectorizer: - pytest.skip("Skipping vectorizer tests") - +def test_vectorizer_embed(vectorizer): text = "This is a test sentence." if isinstance(vectorizer, CohereTextVectorizer): embedding = vectorizer.embed(text, input_type="search_document") @@ -59,10 +56,7 @@ def test_vectorizer_embed(vectorizer, skip_vectorizer): assert len(embedding) == vectorizer.dims -def test_vectorizer_embed_many(vectorizer, skip_vectorizer): - if skip_vectorizer: - pytest.skip("Skipping vectorizer tests") - +def test_vectorizer_embed_many(vectorizer): texts = ["This is the first test sentence.", "This is the second test sentence."] if isinstance(vectorizer, CohereTextVectorizer): embeddings = vectorizer.embed_many(texts, input_type="search_document") @@ -76,10 +70,7 @@ def test_vectorizer_embed_many(vectorizer, skip_vectorizer): ) -def test_vectorizer_bad_input(vectorizer, skip_vectorizer): - if skip_vectorizer: - pytest.skip("Skipping vectorizer tests") - +def test_vectorizer_bad_input(vectorizer): with pytest.raises(TypeError): vectorizer.embed(1) @@ -101,10 +92,7 @@ def avectorizer(request, skip_vectorizer): @pytest.mark.asyncio -async def test_vectorizer_aembed(avectorizer, skip_vectorizer): - if skip_vectorizer: - pytest.skip("Skipping vectorizer tests") - +async def test_vectorizer_aembed(avectorizer): text = "This is a test sentence." embedding = await avectorizer.aembed(text) @@ -113,10 +101,7 @@ async def test_vectorizer_aembed(avectorizer, skip_vectorizer): @pytest.mark.asyncio -async def test_vectorizer_aembed_many(avectorizer, skip_vectorizer): - if skip_vectorizer: - pytest.skip("Skipping vectorizer tests") - +async def test_vectorizer_aembed_many(avectorizer): texts = ["This is the first test sentence.", "This is the second test sentence."] embeddings = await avectorizer.aembed_many(texts) @@ -128,10 +113,7 @@ async def test_vectorizer_aembed_many(avectorizer, skip_vectorizer): @pytest.mark.asyncio -async def test_avectorizer_bad_input(avectorizer, skip_vectorizer): - if skip_vectorizer: - pytest.skip("Skipping vectorizer tests") - +async def test_avectorizer_bad_input(avectorizer): with pytest.raises(TypeError): avectorizer.embed(1) From c6a9a961ee7d658f1168e20ec1238a396584609c Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 1 May 2024 08:55:51 -0400 Subject: [PATCH 15/15] use proper event loop policy from asyncio and pytest --- conftest.py | 12 ++++++++++-- tests/integration/test_flow_async.py | 2 +- tests/unit/test_async_search_index.py | 16 ++++++++-------- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/conftest.py b/conftest.py index 3c2abdc6..e7975057 100644 --- a/conftest.py +++ b/conftest.py @@ -5,6 +5,14 @@ from redisvl.redis.connection import RedisConnectionFactory from testcontainers.compose import DockerCompose + +# @pytest.fixture(scope="session") +# def event_loop(): +# loop = asyncio.get_event_loop_policy().new_event_loop() +# yield loop +# loop.close() + + @pytest.fixture(scope="session", autouse=True) def redis_container(): # Set the default Redis version if not already set @@ -25,7 +33,7 @@ def redis_container(): def redis_url(): return os.getenv("REDIS_URL", "redis://localhost:6379") -@pytest.fixture(scope="session") +@pytest.fixture async def async_client(redis_url): client = await RedisConnectionFactory.get_async_redis_connection(redis_url) yield client @@ -35,7 +43,7 @@ async def async_client(redis_url): if "Event loop is closed" not in str(e): raise -@pytest.fixture(scope="session") +@pytest.fixture def client(): conn = RedisConnectionFactory.get_redis_connection(os.environ["REDIS_URL"]) yield conn diff --git a/tests/integration/test_flow_async.py b/tests/integration/test_flow_async.py index afb8e874..11762068 100644 --- a/tests/integration/test_flow_async.py +++ b/tests/integration/test_flow_async.py @@ -44,7 +44,7 @@ } -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio @pytest.mark.parametrize("schema", [hash_schema, json_schema]) async def test_simple(async_client, schema, sample_data): index = AsyncSearchIndex.from_dict(schema) diff --git a/tests/unit/test_async_search_index.py b/tests/unit/test_async_search_index.py index bfb615b1..ff1c7d9f 100644 --- a/tests/unit/test_async_search_index.py +++ b/tests/unit/test_async_search_index.py @@ -64,7 +64,7 @@ def test_search_index_set_client(async_client, client, async_index): assert async_index.client == None -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio async def test_search_index_create(async_client, async_index): async_index.set_client(async_client) await async_index.create(overwrite=True, drop=True) @@ -74,7 +74,7 @@ async def test_search_index_create(async_client, async_index): ) -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio async def test_search_index_delete(async_client, async_index): async_index.set_client(async_client) await async_index.create(overwrite=True, drop=True) @@ -85,7 +85,7 @@ async def test_search_index_delete(async_client, async_index): ) -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio async def test_search_index_load_and_fetch(async_client, async_index): async_index.set_client(async_client) await async_index.create(overwrite=True, drop=True) @@ -104,7 +104,7 @@ async def test_search_index_load_and_fetch(async_client, async_index): assert not await async_index.fetch("1") -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio async def test_search_index_load_preprocess(async_client, async_index): async_index.set_client(async_client) await async_index.create(overwrite=True, drop=True) @@ -129,7 +129,7 @@ async def bad_preprocess(record): await async_index.load(data, id_field="id", preprocess=bad_preprocess) -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio async def test_no_id_field(async_client, async_index): async_index.set_client(async_client) await async_index.create(overwrite=True, drop=True) @@ -140,7 +140,7 @@ async def test_no_id_field(async_client, async_index): await async_index.load(bad_data, id_field="key") -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio async def test_check_index_exists_before_delete(async_client, async_index): async_index.set_client(async_client) await async_index.create(overwrite=True, drop=True) @@ -149,7 +149,7 @@ async def test_check_index_exists_before_delete(async_client, async_index): await async_index.delete() -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio async def test_check_index_exists_before_search(async_client, async_index): async_index.set_client(async_client) await async_index.create(overwrite=True, drop=True) @@ -165,7 +165,7 @@ async def test_check_index_exists_before_search(async_client, async_index): await async_index.search(query.query, query_params=query.params) -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio async def test_check_index_exists_before_info(async_client, async_index): async_index.set_client(async_client) await async_index.create(overwrite=True, drop=True)