diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml
index 41dca251..c5c389a9 100644
--- a/.github/workflows/run_tests.yml
+++ b/.github/workflows/run_tests.yml
@@ -71,7 +71,7 @@ jobs:
       - name: Run tests
        if: matrix.connection != 'plain' || matrix.redis-stack-version != 'latest'
        run: |
-          SKIP_VECTORIZERS=True poetry run test-cov
+          SKIP_VECTORIZERS=True SKIP_RERANKERS=True poetry run test-cov
 
      - name: Run notebooks
        if: matrix.connection == 'plain' && matrix.redis-stack-version == 'latest'
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 86e59e8b..45634012 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -68,6 +68,11 @@ Tests w/out vectorizers:
 SKIP_VECTORIZERS=true poetry run test-cov
 ```
 
+Tests w/out rerankers:
+```bash
+SKIP_RERANKERS=true poetry run test-cov
+```
+
 ### Getting Redis
 
 In order for your applications to use RedisVL, you must have [Redis](https://redis.io) accessible with Search & Query features enabled on [Redis Cloud](https://redis.com/try-free) or locally in docker with [Redis Stack](https://redis.io/docs/getting-started/install-stack/docker/):
diff --git a/conftest.py b/conftest.py
index 3c2abdc6..e7975057 100644
--- a/conftest.py
+++ b/conftest.py
@@ -5,6 +5,14 @@
 from redisvl.redis.connection import RedisConnectionFactory
 from testcontainers.compose import DockerCompose
 
+
+# @pytest.fixture(scope="session")
+# def event_loop():
+#     loop = asyncio.get_event_loop_policy().new_event_loop()
+#     yield loop
+#     loop.close()
+
+
 @pytest.fixture(scope="session", autouse=True)
 def redis_container():
     # Set the default Redis version if not already set
@@ -25,7 +33,7 @@ def redis_container():
 def redis_url():
     return os.getenv("REDIS_URL", "redis://localhost:6379")
 
-@pytest.fixture(scope="session")
+@pytest.fixture
 async def async_client(redis_url):
     client = await RedisConnectionFactory.get_async_redis_connection(redis_url)
     yield client
@@ -35,7 +43,7 @@ async def async_client(redis_url):
         if "Event loop is closed" not in str(e):
             raise
 
-@pytest.fixture(scope="session")
+@pytest.fixture
 def client():
     conn = RedisConnectionFactory.get_redis_connection(os.environ["REDIS_URL"])
     yield conn
diff --git a/docs/api/index.md b/docs/api/index.md
index 613e5542..b2c45ab6 100644
--- a/docs/api/index.md
+++ b/docs/api/index.md
@@ -16,6 +16,7 @@ searchindex
 query
 filter
 vectorizer
+reranker
 cache
 ```
diff --git a/docs/api/reranker.rst b/docs/api/reranker.rst
new file mode 100644
index 00000000..dd9fd425
--- /dev/null
+++ b/docs/api/reranker.rst
@@ -0,0 +1,14 @@
+***********
+Rerankers
+***********
+
+CohereReranker
+================
+
+.. _coherereranker_api:
+
+.. currentmodule:: redisvl.utils.rerank.cohere
+
+.. autoclass:: CohereReranker
+    :show-inheritance:
+    :members:
diff --git a/docs/api/vectorizer.rst b/docs/api/vectorizer.rst
index 61dd432c..8f2d8a97 100644
--- a/docs/api/vectorizer.rst
+++ b/docs/api/vectorizer.rst
@@ -1,4 +1,3 @@
-
 ***********
 Vectorizers
 ***********
diff --git a/redisvl/utils/rerank/__init__.py b/redisvl/utils/rerank/__init__.py
new file mode 100644
index 00000000..ef7fa9e4
--- /dev/null
+++ b/redisvl/utils/rerank/__init__.py
@@ -0,0 +1,7 @@
+from redisvl.utils.rerank.base import BaseReranker
+from redisvl.utils.rerank.cohere import CohereReranker
+
+__all__ = [
+    "BaseReranker",
+    "CohereReranker",
+]
diff --git a/redisvl/utils/rerank/base.py b/redisvl/utils/rerank/base.py
new file mode 100644
index 00000000..f4602662
--- /dev/null
+++ b/redisvl/utils/rerank/base.py
@@ -0,0 +1,48 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from pydantic.v1 import BaseModel, validator
+
+
+class BaseReranker(BaseModel, ABC):
+    model: str
+    rank_by: Optional[List[str]] = None
+    limit: int
+    return_score: bool
+
+    @validator("limit")
+    @classmethod
+    def check_limit(cls, value):
+        """Ensures the limit is a positive integer."""
+        if value <= 0:
+            raise ValueError("Limit must be a positive integer.")
+        return value
+
+    @validator("rank_by")
+    @classmethod
+    def check_rank_by(cls, value):
+        """Ensures that rank_by is a list of strings if provided."""
+        if value is not None and (
+            not isinstance(value, list)
+            or any(not isinstance(item, str) for item in value)
+        ):
+            raise ValueError("rank_by must be a list of strings.")
+        return value
+
+    @abstractmethod
+    def rank(
+        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
+    ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
+        """
+        Synchronously rerank the docs based on the provided query.
+        """
+        pass
+
+    @abstractmethod
+    async def arank(
+        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
+    ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
+        """
+        Asynchronously rerank the docs based on the provided query.
+        """
+        pass
diff --git a/redisvl/utils/rerank/cohere.py b/redisvl/utils/rerank/cohere.py
new file mode 100644
index 00000000..29a69788
--- /dev/null
+++ b/redisvl/utils/rerank/cohere.py
@@ -0,0 +1,185 @@
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from pydantic.v1 import PrivateAttr
+
+from redisvl.utils.rerank.base import BaseReranker
+
+
+class CohereReranker(BaseReranker):
+    """
+    The CohereReranker class uses Cohere's API to rerank documents based on an
+    input query.
+
+    This reranker is designed to interact with Cohere's /rerank API,
+    requiring an API key for authentication. The key can be provided
+    directly in the `api_config` dictionary or through the `COHERE_API_KEY`
+    environment variable. Users must obtain an API key from Cohere's website
+    (https://dashboard.cohere.com/). Additionally, the `cohere` python
+    client must be installed with `pip install cohere`.
+
+    .. code-block:: python
+
+        from redisvl.utils.rerank import CohereReranker
+
+        # requires the COHERE_API_KEY env var (or pass the key via api_config)
+        reranker = CohereReranker(limit=3)
+        docs = ["first doc", "second doc", "third doc"]
+        reranked, scores = reranker.rank(query="my query", docs=docs)
+
+    """
+
+    _client: Any = PrivateAttr()
+    _aclient: Any = PrivateAttr()
+
+    def __init__(
+        self,
+        model: str = "rerank-english-v3.0",
+        rank_by: Optional[List[str]] = None,
+        limit: int = 5,
+        return_score: bool = True,
+        api_config: Optional[Dict] = None,
+    ) -> None:
+        """
+        Initialize the CohereReranker with specified model, ranking criteria,
+        and API configuration.
+
+        Parameters:
+            model (str): The identifier for the Cohere model used for reranking.
+                Defaults to 'rerank-english-v3.0'.
+            rank_by (Optional[List[str]]): Optional list of keys specifying the
+                attributes in the documents that should be considered for
+                ranking. None means ranking will rely on the model's default
+                behavior.
+            limit (int): The maximum number of results to return after
+                reranking. Must be a positive integer.
+            return_score (bool): Whether to return scores alongside the
+                reranked results.
+            api_config (Optional[Dict]): Dictionary containing the API key.
+                Defaults to None.
+
+        Raises:
+            ImportError: If the cohere library is not installed.
+            ValueError: If the API key is not provided.
+        """
+        super().__init__(
+            model=model, rank_by=rank_by, limit=limit, return_score=return_score
+        )
+        self._initialize_clients(api_config)
+
+    def _initialize_clients(self, api_config: Optional[Dict]):
+        """
+        Setup the Cohere clients using the provided API key or an
+        environment variable.
+        """
+        # Dynamic import of the cohere module
+        try:
+            from cohere import AsyncClient, Client
+        except ImportError:
+            raise ImportError(
+                "Cohere reranker requires the cohere library. \
+                    Please install with `pip install cohere`"
+            )
+
+        # Fetch the API key from api_config or environment variable
+        api_key = (
+            api_config.get("api_key") if api_config else os.getenv("COHERE_API_KEY")
+        )
+        if not api_key:
+            raise ValueError(
+                "Cohere API key is required. "
+                "Provide it in api_config or set the COHERE_API_KEY environment variable."
+            )
+        self._client = Client(api_key=api_key, client_name="redisvl")
+        self._aclient = AsyncClient(api_key=api_key, client_name="redisvl")
+
+    def _preprocess(
+        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
+    ):
+        """
+        Prepare and validate reranking config based on provided input and
+        optional overrides.
+        """
+        limit = kwargs.get("limit", self.limit)
+        return_score = kwargs.get("return_score", self.return_score)
+        max_chunks_per_doc = kwargs.get("max_chunks_per_doc")
+        rank_by = kwargs.get("rank_by", self.rank_by) or []
+        rank_by = [rank_by] if isinstance(rank_by, str) else rank_by
+
+        reranker_kwargs = {
+            "model": self.model,
+            "query": query,
+            "top_n": limit,
+            "documents": docs,
+            "max_chunks_per_doc": max_chunks_per_doc,
+        }
+        # if we are working with a list of dicts
+        if all(isinstance(doc, dict) for doc in docs):
+            if rank_by:
+                reranker_kwargs["rank_fields"] = rank_by
+            else:
+                raise ValueError(
+                    "If reranking dictionary-like docs, "
+                    "you must provide a list of rank_by fields"
+                )
+
+        return reranker_kwargs, return_score
+
+    @staticmethod
+    def _postprocess(
+        docs: Union[List[Dict[str, Any]], List[str]],
+        rankings: Any,
+    ) -> Tuple[List[Any], List[float]]:
+        """
+        Post-process the rankings returned by the Cohere API, pairing each
+        reranked document with its relevance score.
+        """
+        reranked_docs, scores = [], []
+        for item in rankings.results:
+            scores.append(item.relevance_score)
+            reranked_docs.append(docs[item.index])
+        return reranked_docs, scores
+
+    def rank(
+        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
+    ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
+        """
+        Rerank documents based on the provided query using the Cohere rerank API.
+
+        This method processes the user's query and the provided documents to
+        rerank them in a manner that is potentially more relevant to the
+        query's context.
+
+        Parameters:
+            query (str): The user's search query.
+            docs (Union[List[Dict[str, Any]], List[str]]): The list of documents
+                to be ranked, either as dictionaries or strings.
+
+        Returns:
+            Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
+                The reranked list of documents, paired with their relevance
+                scores when `return_score` is True.
+        """
+        reranker_kwargs, return_score = self._preprocess(query, docs, **kwargs)
+        rankings = self._client.rerank(**reranker_kwargs)
+        reranked_docs, scores = self._postprocess(docs, rankings)
+        if return_score:
+            return reranked_docs, scores
+        return reranked_docs
+
+    async def arank(
+        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
+    ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
+        """
+        Asynchronously rerank documents based on the provided query using the
+        Cohere rerank API.
+
+        This method processes the user's query and the provided documents to
+        rerank them in a manner that is potentially more relevant to the
+        query's context.
+
+        Parameters:
+            query (str): The user's search query.
+            docs (Union[List[Dict[str, Any]], List[str]]): The list of documents
+                to be ranked, either as dictionaries or strings.
+
+        Returns:
+            Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
+                The reranked list of documents, paired with their relevance
+                scores when `return_score` is True.
+        """
+        reranker_kwargs, return_score = self._preprocess(query, docs, **kwargs)
+        rankings = await self._aclient.rerank(**reranker_kwargs)
+        reranked_docs, scores = self._postprocess(docs, rankings)
+        if return_score:
+            return reranked_docs, scores
+        return reranked_docs
diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py
index 46ba955d..f5ef8198 100644
--- a/redisvl/utils/vectorize/base.py
+++ b/redisvl/utils/vectorize/base.py
@@ -1,22 +1,24 @@
-from typing import Any, Callable, List, Optional
+from abc import ABC, abstractmethod
+from typing import Callable, List, Optional
 
 from pydantic.v1 import BaseModel, validator
 
 from redisvl.redis.utils import array_to_buffer
 
 
-class BaseVectorizer(BaseModel):
+class BaseVectorizer(BaseModel, ABC):
     model: str
     dims: int
-    client: Any
 
-    @validator("dims", pre=True)
+    @validator("dims")
     @classmethod
-    def check_dims(cls, v):
-        if v <= 0:
-            raise ValueError("Dimension must be a positive integer")
-        return v
+    def check_dims(cls, value):
+        """Ensures the dims are a positive integer."""
+        if value <= 0:
+            raise ValueError("Dims must be a positive integer.")
+        return value
 
+    @abstractmethod
     def embed_many(
         self,
         texts: List[str],
@@ -27,6 +29,7 @@ def embed_many(
     ) -> List[List[float]]:
         raise NotImplementedError
 
+    @abstractmethod
     def embed(
         self,
         text: str,
@@ -36,6 +39,7 @@ def embed(
     ) -> List[float]:
         raise NotImplementedError
 
+    @abstractmethod
     async def aembed_many(
         self,
         texts: List[str],
@@ -46,6 +50,7 @@ async def aembed_many(
     ) -> List[List[float]]:
         raise NotImplementedError
 
+    @abstractmethod
     async def aembed(
         self,
         text: str,
diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py
index 5ac527fa..fc13eb75 100644
--- a/redisvl/utils/vectorize/text/azureopenai.py
+++ b/redisvl/utils/vectorize/text/azureopenai.py
@@ -1,6 +1,7 @@
 import os
 from typing import Any, Callable, Dict, List, Optional
 
+from pydantic.v1 import PrivateAttr
 from tenacity import retry, stop_after_attempt, wait_random_exponential
 from tenacity.retry import retry_if_not_exception_type
 
@@ -47,7 +48,8 @@ class AzureOpenAITextVectorizer(BaseVectorizer):
 
     """
 
-    aclient: Any  # Since the OpenAI module is loaded dynamically
+    _client: Any = PrivateAttr()
+    _aclient: Any = PrivateAttr()
 
     def __init__(
         self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None
@@ -65,6 +67,14 @@ def __init__(
             ImportError: If the openai library is not installed.
             ValueError: If the AzureOpenAI API key, version, or endpoint are not provided.
         """
+        self._initialize_clients(api_config)
+        super().__init__(model=model, dims=self._set_model_dims(model))
+
+    def _initialize_clients(self, api_config: Optional[Dict]):
+        """
+        Setup the Azure OpenAI clients using the provided API key or an
+        environment variable.
+        """
         # Dynamic import of the openai module
         try:
             from openai import AsyncAzureOpenAI, AzureOpenAI
@@ -114,20 +124,17 @@ def __init__(
                 environment variable."
             )
 
-        client = AzureOpenAI(
+        self._client = AzureOpenAI(
             api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint
         )
-        dims = self._set_model_dims(client, model)
-        super().__init__(model=model, dims=dims, client=client)
-        self.aclient = AsyncAzureOpenAI(
+        self._aclient = AsyncAzureOpenAI(
             api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint
         )
 
-    @staticmethod
-    def _set_model_dims(client, model) -> int:
+    def _set_model_dims(self, model) -> int:
         try:
             embedding = (
-                client.embeddings.create(input=["dimension test"], model=model)
+                self._client.embeddings.create(input=["dimension test"], model=model)
                 .data[0]
                 .embedding
             )
@@ -175,7 +182,7 @@ def embed_many(
 
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
-            response = self.client.embeddings.create(input=batch, model=self.model)
+            response = self._client.embeddings.create(input=batch, model=self.model)
             embeddings += [
                 self._process_embedding(r.embedding, as_buffer) for r in response.data
             ]
@@ -213,7 +220,7 @@ def embed(
 
         if preprocess:
             text = preprocess(text)
-        result = self.client.embeddings.create(input=[text], model=self.model)
+        result = self._client.embeddings.create(input=[text], model=self.model)
         return self._process_embedding(result.data[0].embedding, as_buffer)
 
     @retry(
@@ -253,7 +260,7 @@ async def aembed_many(
 
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
-            response = await self.aclient.embeddings.create(
+            response = await self._aclient.embeddings.create(
                 input=batch, model=self.model
             )
             embeddings += [
@@ -293,5 +300,5 @@ async def aembed(
 
         if preprocess:
             text = preprocess(text)
-        result = await self.aclient.embeddings.create(input=[text], model=self.model)
+        result = await self._aclient.embeddings.create(input=[text], model=self.model)
         return self._process_embedding(result.data[0].embedding, as_buffer)
diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py
index 7eadd658..cec856dd 100644
--- a/redisvl/utils/vectorize/text/cohere.py
+++ b/redisvl/utils/vectorize/text/cohere.py
@@ -1,6 +1,7 @@
 import os
-from typing import Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
+from pydantic.v1 import PrivateAttr
 from tenacity import retry, stop_after_attempt, wait_random_exponential
 from tenacity.retry import retry_if_not_exception_type
 
@@ -43,6 +44,8 @@ class CohereTextVectorizer(BaseVectorizer):
 
     """
 
+    _client: Any = PrivateAttr()
+
     def __init__(
         self, model: str = "embed-english-v3.0", api_config: Optional[Dict] = None
     ):
@@ -59,10 +62,18 @@ def __init__(
             ImportError: If the cohere library is not installed.
             ValueError: If the API key is not provided.
+        """
+        self._initialize_client(api_config)
+        super().__init__(model=model, dims=self._set_model_dims(model))
+
+    def _initialize_client(self, api_config: Optional[Dict]):
+        """
+        Setup the Cohere client using the provided API key or an
+        environment variable.
         """
         # Dynamic import of the cohere module
         try:
-            import cohere
+            from cohere import Client
         except ImportError:
             raise ImportError(
                 "Cohere vectorizer requires the cohere library. \
                     Please install with `pip install cohere`"
             )
@@ -78,15 +89,11 @@ def __init__(
                 "Cohere API key is required. "
                 "Provide it in api_config or set the COHERE_API_KEY environment variable."
             )
+        self._client = Client(api_key=api_key, client_name="redisvl")
 
-        client = cohere.Client(api_key, client_name="redisvl")
-        dims = self._set_model_dims(client, model)
-        super().__init__(model=model, dims=dims, client=client)
-
-    @staticmethod
-    def _set_model_dims(client, model) -> int:
+    def _set_model_dims(self, model) -> int:
         try:
-            embedding = client.embed(
+            embedding = self._client.embed(
                 texts=["dimension test"],
                 model=model,
                 input_type="search_document",
@@ -150,7 +157,7 @@ def embed(
         )
         if preprocess:
             text = preprocess(text)
-        embedding = self.client.embed(
+        embedding = self._client.embed(
             texts=[text], model=self.model, input_type=input_type
         ).embeddings[0]
         return self._process_embedding(embedding, as_buffer)
@@ -219,7 +226,7 @@ def embed_many(
 
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
-            response = self.client.embed(
+            response = self._client.embed(
                 texts=batch, model=self.model, input_type=input_type
             )
             embeddings += [
@@ -227,3 +234,22 @@ def embed_many(
                 self._process_embedding(embedding, as_buffer)
                 for embedding in response.embeddings
             ]
         return embeddings
+
+    async def aembed_many(
+        self,
+        texts: List[str],
+        preprocess: Optional[Callable] = None,
+        batch_size: int = 1000,
+        as_buffer: bool = False,
+        **kwargs,
+    ) -> List[List[float]]:
+        raise NotImplementedError
+
+    async def aembed(
+        self,
+        text: str,
+        preprocess: Optional[Callable] = None,
+        as_buffer: bool = False,
+        **kwargs,
+    ) -> List[float]:
+        raise NotImplementedError
diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py
index 5d02ed97..cb72652e 100644
--- a/redisvl/utils/vectorize/text/huggingface.py
+++ b/redisvl/utils/vectorize/text/huggingface.py
@@ -1,4 +1,6 @@
-from typing import Callable, Dict, List, Optional
+from typing import Any, Callable, List, Optional
+
+from pydantic.v1 import PrivateAttr
 
 from redisvl.utils.vectorize.base import BaseVectorizer
 
@@ -28,6 +30,8 @@ class HFTextVectorizer(BaseVectorizer):
 
     """
 
+    _client: Any = PrivateAttr()
+
     def __init__(
         self, model: str = "sentence-transformers/all-mpnet-base-v2", **kwargs
     ):
@@ -42,7 +46,12 @@ def __init__(
             ImportError: If the sentence-transformers library is not installed.
             ValueError: If there is an error setting the embedding model dimensions.
""" - # Load the SentenceTransformer model + self._initialize_client(model) + super().__init__(model=model, dims=self._set_model_dims()) + + def _initialize_client(self, model: str): + """Setup the HuggingFace client""" + # Dynamic import of the cohere module\ try: from sentence_transformers import SentenceTransformer except ImportError: @@ -51,14 +60,11 @@ def __init__( "Please install with `pip install sentence-transformers`" ) - client = SentenceTransformer(model) - dims = self._set_model_dims(client) - super().__init__(model=model, dims=dims, client=client) + self._client = SentenceTransformer(model) - @staticmethod - def _set_model_dims(client): + def _set_model_dims(self): try: - embedding = client.encode(["dimension check"])[0] + embedding = self._client.encode(["dimension check"])[0] except (KeyError, IndexError) as ke: raise ValueError(f"Empty response from the embedding model: {str(ke)}") except Exception as e: # pylint: disable=broad-except @@ -93,7 +99,7 @@ def embed( if preprocess: text = preprocess(text) - embedding = self.client.encode([text])[0] + embedding = self._client.encode([text])[0] return self._process_embedding(embedding.tolist(), as_buffer) def embed_many( @@ -129,7 +135,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - batch_embeddings = self.client.encode(batch) + batch_embeddings = self._client.encode(batch) embeddings.extend( [ self._process_embedding(embedding.tolist(), as_buffer) @@ -137,3 +143,22 @@ def embed_many( ] ) return embeddings + + async def aembed_many( + self, + texts: List[str], + preprocess: Optional[Callable] = None, + batch_size: int = 1000, + as_buffer: bool = False, + **kwargs, + ) -> List[List[float]]: + raise NotImplementedError + + async def aembed( + self, + text: str, + preprocess: Optional[Callable] = None, + as_buffer: bool = False, + **kwargs, + ) -> List[float]: + raise NotImplementedError diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index 22afb2f5..b5d2070c 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -1,6 +1,7 @@ import os from typing import Any, Callable, Dict, List, Optional +from pydantic.v1 import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -42,7 +43,8 @@ class OpenAITextVectorizer(BaseVectorizer): """ - aclient: Any # Since the OpenAI module is loaded dynamically + _client: Any = PrivateAttr() + _aclient: Any = PrivateAttr() def __init__( self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None @@ -59,6 +61,14 @@ def __init__( ImportError: If the openai library is not installed. ValueError: If the OpenAI API key is not provided. """ + self._initialize_clients(api_config) + super().__init__(model=model, dims=self._set_model_dims(model)) + + def _initialize_clients(self, api_config: Optional[Dict]): + """ + Setup the OpenAI clients using the provided API key or an + environment variable. + """ # Dynamic import of the openai module try: from openai import AsyncOpenAI, OpenAI @@ -79,16 +89,13 @@ def __init__( environment variable." 
             )
 
-        client = OpenAI(api_key=api_key)
-        dims = self._set_model_dims(client, model)
-        super().__init__(model=model, dims=dims, client=client)
-        self.aclient = AsyncOpenAI(api_key=api_key)
+        self._client = OpenAI(api_key=api_key)
+        self._aclient = AsyncOpenAI(api_key=api_key)
 
-    @staticmethod
-    def _set_model_dims(client, model) -> int:
+    def _set_model_dims(self, model) -> int:
         try:
             embedding = (
-                client.embeddings.create(input=["dimension test"], model=model)
+                self._client.embeddings.create(input=["dimension test"], model=model)
                 .data[0]
                 .embedding
             )
@@ -136,7 +143,7 @@ def embed_many(
 
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
-            response = self.client.embeddings.create(input=batch, model=self.model)
+            response = self._client.embeddings.create(input=batch, model=self.model)
             embeddings += [
                 self._process_embedding(r.embedding, as_buffer) for r in response.data
             ]
@@ -174,7 +181,7 @@ def embed(
 
         if preprocess:
             text = preprocess(text)
-        result = self.client.embeddings.create(input=[text], model=self.model)
+        result = self._client.embeddings.create(input=[text], model=self.model)
         return self._process_embedding(result.data[0].embedding, as_buffer)
 
     @retry(
@@ -214,7 +221,7 @@ async def aembed_many(
 
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
-            response = await self.aclient.embeddings.create(
+            response = await self._aclient.embeddings.create(
                 input=batch, model=self.model
             )
             embeddings += [
@@ -254,5 +261,5 @@ async def aembed(
 
         if preprocess:
             text = preprocess(text)
-        result = await self.aclient.embeddings.create(input=[text], model=self.model)
+        result = await self._aclient.embeddings.create(input=[text], model=self.model)
         return self._process_embedding(result.data[0].embedding, as_buffer)
diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py
index 0aaf314a..1d67c672 100644
--- a/redisvl/utils/vectorize/text/vertexai.py
+++ b/redisvl/utils/vectorize/text/vertexai.py
@@ -1,6 +1,7 @@
 import os
-from typing import Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
+from pydantic.v1 import PrivateAttr
 from tenacity import retry, stop_after_attempt, wait_random_exponential
 from tenacity.retry import retry_if_not_exception_type
 
@@ -40,6 +41,8 @@ class VertexAITextVectorizer(BaseVectorizer):
 
     """
 
+    _client: Any = PrivateAttr()
+
     def __init__(
         self, model: str = "textembedding-gecko", api_config: Optional[Dict] = None
     ):
@@ -55,6 +58,14 @@ def __init__(
             ImportError: If the google-cloud-aiplatform library is not installed.
             ValueError: If the API key is not provided.
         """
+        self._initialize_client(model, api_config)
+        super().__init__(model=model, dims=self._set_model_dims())
+
+    def _initialize_client(self, model: str, api_config: Optional[Dict]):
+        """
+        Setup the VertexAI client using the project id and location provided
+        in api_config or via environment variables.
+        """
         # Fetch the project_id and location from api_config or environment variables
         project_id = (
             api_config.get("project_id") if api_config else os.getenv("GCP_PROJECT_ID")
         )
@@ -93,14 +104,11 @@ def __init__(
                 "Please install with `pip install google-cloud-aiplatform>=1.26`"
             )
 
-        client = TextEmbeddingModel.from_pretrained(model)
-        dims = self._set_model_dims(client)
-        super().__init__(model=model, dims=dims, client=client)
+        self._client = TextEmbeddingModel.from_pretrained(model)
 
-    @staticmethod
-    def _set_model_dims(client) -> int:
+    def _set_model_dims(self) -> int:
         try:
-            embedding = client.get_embeddings(["dimension test"])[0].values
+            embedding = self._client.get_embeddings(["dimension test"])[0].values
         except (KeyError, IndexError) as ke:
             raise ValueError(f"Unexpected response from the VertexAI API: {str(ke)}")
         except Exception as e:  # pylint: disable=broad-except
@@ -145,7 +153,7 @@ def embed_many(
 
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
-            response = self.client.get_embeddings(batch)
+            response = self._client.get_embeddings(batch)
             embeddings += [
                 self._process_embedding(r.values, as_buffer) for r in response
             ]
@@ -183,5 +191,24 @@ def embed(
 
         if preprocess:
             text = preprocess(text)
-        result = self.client.get_embeddings([text])
+        result = self._client.get_embeddings([text])
         return self._process_embedding(result[0].values, as_buffer)
+
+    async def aembed_many(
+        self,
+        texts: List[str],
+        preprocess: Optional[Callable] = None,
+        batch_size: int = 1000,
+        as_buffer: bool = False,
+        **kwargs,
+    ) -> List[List[float]]:
+        raise NotImplementedError
+
+    async def aembed(
+        self,
+        text: str,
+        preprocess: Optional[Callable] = None,
+        as_buffer: bool = False,
+        **kwargs,
+    ) -> List[float]:
+        raise NotImplementedError
diff --git a/scripts.py b/scripts.py
index cfc254da..61d36636 100644
--- a/scripts.py
+++ b/scripts.py
@@ -26,7 +26,7 @@ def test_verbose():
     subprocess.run(["python", "-m", "pytest", "-vv", "-s", "--log-level=CRITICAL"])
 
 def test_cov():
-    subprocess.run(["python", "-m", "pytest", "-vv", "--cov=./redisvl", "--cov-report=xml", "--log-level=CRITICAL"])
+    subprocess.run(["python", "-m", "pytest", "-vv", "--cov=./redisvl", "--cov-report=xml", "--log-level=CRITICAL"], check=True)
 
 def cov():
     subprocess.run(["coverage", "html"])
diff --git a/tests/integration/test_flow_async.py b/tests/integration/test_flow_async.py
index afb8e874..11762068 100644
--- a/tests/integration/test_flow_async.py
+++ b/tests/integration/test_flow_async.py
@@ -44,7 +44,7 @@
 }
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio
 @pytest.mark.parametrize("schema", [hash_schema, json_schema])
 async def test_simple(async_client, schema, sample_data):
     index = AsyncSearchIndex.from_dict(schema)
diff --git a/tests/integration/test_rerankers.py b/tests/integration/test_rerankers.py
new file mode 100644
index 00000000..4866aa58
--- /dev/null
+++ b/tests/integration/test_rerankers.py
@@ -0,0 +1,56 @@
+import os
+
+import pytest
+
+from redisvl.utils.rerank import CohereReranker
+
+
+# Fixture for the reranker instance
+@pytest.fixture
+def reranker():
+    skip_reranker = os.getenv("SKIP_RERANKERS", "False").lower() == "true"
+    if skip_reranker:
+        pytest.skip("Skipping reranker instantiation...")
+    return CohereReranker()
+
+
+# Test for basic ranking functionality
+def test_rank_documents(reranker):
+    docs = ["document one", "document two", "document three"]
+    query = "search query"
+
+    reranked_docs, scores = reranker.rank(query, docs)
+
+    assert isinstance(reranked_docs, list)
+    assert len(reranked_docs) == len(docs)  # Ensure we get back as many docs as we sent
+    assert all(isinstance(score, float) for score in scores)  # Scores should be floats
+
+
+# Test for asynchronous ranking functionality
+@pytest.mark.asyncio
+async def test_async_rank_documents(reranker):
+    docs = ["document one", "document two", "document three"]
+    query = "search query"
+
+    reranked_docs, scores = await reranker.arank(query, docs)
+
+    assert isinstance(reranked_docs, list)
+    assert len(reranked_docs) == len(docs)  # Ensure we get back as many docs as we sent
+    assert all(isinstance(score, float) for score in scores)  # Scores should be floats
+
+
+# Test handling of bad input
+def test_bad_input(reranker):
+    with pytest.raises(Exception):
+        reranker.rank("", [])  # Empty query or documents
+
+    with pytest.raises(Exception):
+        reranker.rank(123, ["valid document"])  # Invalid type for query
+
+    with pytest.raises(Exception):
+        reranker.rank("valid query", "not a list")  # Invalid type for documents
+
+    with pytest.raises(Exception):
+        reranker.rank(
+            "valid query", [{"field": "valid document"}], rank_by=["invalid_field"]
+        )  # Invalid rank_by field
diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py
index 352aaa2e..23952c65 100644
--- a/tests/integration/test_vectorizers.py
+++ b/tests/integration/test_vectorizers.py
@@ -15,7 +15,6 @@ def skip_vectorizer() -> bool:
     # os.getenv returns a string
     v = os.getenv("SKIP_VECTORIZERS", "False").lower() == "true"
-    print(v, flush=True)
     return v
 
 
@@ -46,10 +45,7 @@ def vectorizer(request, skip_vectorizer):
     )
 
 
-def test_vectorizer_embed(vectorizer, skip_vectorizer):
-    if skip_vectorizer:
-        pytest.skip("Skipping vectorizer tests")
-
+def test_vectorizer_embed(vectorizer):
     text = "This is a test sentence."
     if isinstance(vectorizer, CohereTextVectorizer):
         embedding = vectorizer.embed(text, input_type="search_document")
@@ -60,10 +56,7 @@ def test_vectorizer_embed(vectorizer, skip_vectorizer):
     assert len(embedding) == vectorizer.dims
 
 
-def test_vectorizer_embed_many(vectorizer, skip_vectorizer):
-    if skip_vectorizer:
-        pytest.skip("Skipping vectorizer tests")
-
+def test_vectorizer_embed_many(vectorizer):
     texts = ["This is the first test sentence.", "This is the second test sentence."]
     if isinstance(vectorizer, CohereTextVectorizer):
         embeddings = vectorizer.embed_many(texts, input_type="search_document")
@@ -77,10 +70,7 @@ def test_vectorizer_embed_many(vectorizer, skip_vectorizer):
     )
 
 
-def test_vectorizer_bad_input(vectorizer, skip_vectorizer):
-    if skip_vectorizer:
-        pytest.skip("Skipping vectorizer tests")
-
+def test_vectorizer_bad_input(vectorizer):
     with pytest.raises(TypeError):
         vectorizer.embed(1)
 
@@ -102,10 +92,7 @@ def avectorizer(request, skip_vectorizer):
 
 
 @pytest.mark.asyncio
-async def test_vectorizer_aembed(avectorizer, skip_vectorizer):
-    if skip_vectorizer:
-        pytest.skip("Skipping vectorizer tests")
-
+async def test_vectorizer_aembed(avectorizer):
     text = "This is a test sentence."
 
     embedding = await avectorizer.aembed(text)
 
@@ -114,10 +101,7 @@ async def test_vectorizer_aembed(avectorizer, skip_vectorizer):
 
 
 @pytest.mark.asyncio
-async def test_vectorizer_aembed_many(avectorizer, skip_vectorizer):
-    if skip_vectorizer:
-        pytest.skip("Skipping vectorizer tests")
-
+async def test_vectorizer_aembed_many(avectorizer):
     texts = ["This is the first test sentence.", "This is the second test sentence."]
     embeddings = await avectorizer.aembed_many(texts)
 
@@ -129,10 +113,7 @@ async def test_vectorizer_aembed_many(avectorizer, skip_vectorizer):
 
 
 @pytest.mark.asyncio
-async def test_avectorizer_bad_input(avectorizer, skip_vectorizer):
-    if skip_vectorizer:
-        pytest.skip("Skipping vectorizer tests")
-
+async def test_avectorizer_bad_input(avectorizer):
     with pytest.raises(TypeError):
         avectorizer.embed(1)
diff --git a/tests/unit/test_async_search_index.py b/tests/unit/test_async_search_index.py
index bfb615b1..ff1c7d9f 100644
--- a/tests/unit/test_async_search_index.py
+++ b/tests/unit/test_async_search_index.py
@@ -64,7 +64,7 @@ def test_search_index_set_client(async_client, client, async_index):
     assert async_index.client == None
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio
 async def test_search_index_create(async_client, async_index):
     async_index.set_client(async_client)
     await async_index.create(overwrite=True, drop=True)
@@ -74,7 +74,7 @@
     )
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio
 async def test_search_index_delete(async_client, async_index):
     async_index.set_client(async_client)
     await async_index.create(overwrite=True, drop=True)
@@ -85,7 +85,7 @@
     )
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio
 async def test_search_index_load_and_fetch(async_client, async_index):
     async_index.set_client(async_client)
     await async_index.create(overwrite=True, drop=True)
@@ -104,7 +104,7 @@
     assert not await async_index.fetch("1")
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio
 async def test_search_index_load_preprocess(async_client, async_index):
     async_index.set_client(async_client)
     await async_index.create(overwrite=True, drop=True)
@@ -129,7 +129,7 @@ async def bad_preprocess(record):
         await async_index.load(data, id_field="id", preprocess=bad_preprocess)
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio
 async def test_no_id_field(async_client, async_index):
     async_index.set_client(async_client)
     await async_index.create(overwrite=True, drop=True)
@@ -140,7 +140,7 @@
         await async_index.load(bad_data, id_field="key")
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio
 async def test_check_index_exists_before_delete(async_client, async_index):
     async_index.set_client(async_client)
     await async_index.create(overwrite=True, drop=True)
@@ -149,7 +149,7 @@
         await async_index.delete()
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio
 async def test_check_index_exists_before_search(async_client, async_index):
     async_index.set_client(async_client)
     await async_index.create(overwrite=True, drop=True)
@@ -165,7 +165,7 @@
     await async_index.search(query.query, query_params=query.params)
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio
 async def test_check_index_exists_before_info(async_client, async_index):
     async_index.set_client(async_client)
     await async_index.create(overwrite=True, drop=True)
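
Example usage of the reranker introduced in this diff (an illustrative sketch, not part of the patch): it assumes the `cohere` package is installed and `COHERE_API_KEY` is set; the query and documents are made up.

```python
from redisvl.utils.rerank import CohereReranker

# Assumes COHERE_API_KEY is exported (or pass api_config={"api_key": "..."}).
reranker = CohereReranker(limit=2)

docs = [
    "Redis is an in-memory data store.",
    "RedisVL layers vector search utilities on top of redis-py.",
    "Cohere hosts rerank models behind an HTTP API.",
]

# rank() returns (docs, scores) because return_score defaults to True.
reranked, scores = reranker.rank(query="vector search in redis", docs=docs)
for doc, score in zip(reranked, scores):
    print(f"{score:.3f} {doc}")
```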