Initial reranker integration in RedisVL #139
Merged
15 commits
9e5ba54 WIP (tylerhutcherson)
26b5d3d add base reranker and pilot cohere implementation (tylerhutcherson)
8e57386 updates to reranker implementation (tylerhutcherson)
58678bb update to support sync and async, use PrivateAttr, and finish cohere … (tylerhutcherson)
66c5123 wip (tylerhutcherson)
e7c18a5 wip (tylerhutcherson)
511ba28 update api docs (tylerhutcherson)
f6e9b54 wip update vectorizers (tylerhutcherson)
969ebe9 use appropriate pydantic shim (tylerhutcherson)
5bae8b0 fix formatting and mypy (tylerhutcherson)
46347ef Add validator (tylerhutcherson)
6a43865 add reranker tests (tylerhutcherson)
e5b9deb add SKIP flag for rerankers (tylerhutcherson)
2676ced updates to flags and subprocess call (tylerhutcherson)
c6a9a96 use proper event loop policy from asyncio and pytest (tylerhutcherson)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ searchindex | |
| query | ||
| filter | ||
| vectorizer | ||
| reranker | ||
| cache | ||
| ``` | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| *********** | ||
| Rerankers | ||
| *********** | ||
|
|
||
| CohereReranker | ||
| ================ | ||
|
|
||
| .. _coherereranker_api: | ||
|
|
||
| .. currentmodule:: redisvl.utils.rerank.cohere | ||
|
|
||
| .. autoclass:: CohereReranker | ||
| :show-inheritance: | ||
| :members: |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,3 @@ | ||
|
|
||
| *********** | ||
| Vectorizers | ||
| *********** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| from redisvl.utils.rerank.base import BaseReranker | ||
| from redisvl.utils.rerank.cohere import CohereReranker | ||
|
|
||
| __all__ = [ | ||
| "BaseReranker", | ||
| "CohereReranker", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| from abc import ABC, abstractmethod | ||
| from typing import Any, Dict, List, Optional, Tuple, Union | ||
|
|
||
| from pydantic.v1 import BaseModel, validator | ||
|
|
||
|
|
||
| class BaseReranker(BaseModel, ABC): | ||
| model: str | ||
| rank_by: Optional[List[str]] = None | ||
| limit: int | ||
| return_score: bool | ||
|
|
||
| @validator("limit") | ||
| @classmethod | ||
| def check_limit(cls, value): | ||
| """Ensures the limit is a positive integer.""" | ||
| if value <= 0: | ||
| raise ValueError("Limit must be a positive integer.") | ||
| return value | ||
|
|
||
| @validator("rank_by") | ||
| @classmethod | ||
| def check_rank_by(cls, value): | ||
| """Ensures that rank_by is a list of strings if provided.""" | ||
| if value is not None and ( | ||
| not isinstance(value, list) | ||
| or any(not isinstance(item, str) for item in value) | ||
| ): | ||
| raise ValueError("rank_by must be a list of strings.") | ||
| return value | ||
|
|
||
| @abstractmethod | ||
| def rank( | ||
| self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs | ||
| ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]: | ||
| """ | ||
| Synchronously rerank the docs based on the provided query. | ||
| """ | ||
| pass | ||
|
|
||
| @abstractmethod | ||
| async def arank( | ||
| self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs | ||
| ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]: | ||
| """ | ||
| Asynchronously rerank the docs based on the provided query. | ||
| """ | ||
| pass |
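
Reviewer note: the contract that `BaseReranker` sets up (validated `limit` and `rank_by`, plus paired `rank` and `arank` methods that return either documents alone or documents with scores) can be illustrated without pydantic or Cohere. The sketch below is an assumption-laden stand-in, not code from this PR: `SketchBaseReranker` and `KeywordOverlapReranker` are hypothetical names, the validation is done in a plain `__init__` instead of pydantic validators, and the word-overlap scoring is a toy chosen only to make the example runnable.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

Docs = Union[List[Dict[str, Any]], List[str]]
Ranked = Union[Tuple[List[Any], List[float]], List[Any]]


class SketchBaseReranker(ABC):
    """Hypothetical, dependency-free stand-in for BaseReranker."""

    def __init__(
        self,
        model: str,
        rank_by: Optional[List[str]],
        limit: int,
        return_score: bool,
    ) -> None:
        # Mirrors check_limit: the limit must be a positive integer.
        if limit <= 0:
            raise ValueError("Limit must be a positive integer.")
        # Mirrors check_rank_by: rank_by is None or a list of strings.
        if rank_by is not None and (
            not isinstance(rank_by, list)
            or any(not isinstance(item, str) for item in rank_by)
        ):
            raise ValueError("rank_by must be a list of strings.")
        self.model, self.rank_by = model, rank_by
        self.limit, self.return_score = limit, return_score

    @abstractmethod
    def rank(self, query: str, docs: Docs, **kwargs) -> Ranked: ...

    @abstractmethod
    async def arank(self, query: str, docs: Docs, **kwargs) -> Ranked: ...


class KeywordOverlapReranker(SketchBaseReranker):
    """Toy reranker: scores each string doc by the share of query words it contains."""

    def rank(self, query, docs, **kwargs):
        terms = set(query.lower().split())
        scored = [
            (len(terms & set(doc.lower().split())) / max(len(terms), 1), doc)
            for doc in docs
        ]
        # Stable sort, highest score first, then truncate to the limit.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        top = scored[: kwargs.get("limit", self.limit)]
        ranked = [doc for _, doc in top]
        # Per-call kwargs override the instance defaults, as in the PR.
        if kwargs.get("return_score", self.return_score):
            return ranked, [score for score, _ in top]
        return ranked

    async def arank(self, query, docs, **kwargs):
        # No I/O here, so the async variant simply delegates to the sync one.
        return self.rank(query, docs, **kwargs)
```

With `return_score=True` the caller unpacks a `(docs, scores)` tuple; with `return_score=False` only the list comes back, so callers that toggle the flag per call need to handle both shapes.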
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,185 @@ | ||
| import os | ||
| from typing import Any, Dict, List, Optional, Tuple, Union | ||
|
|
||
| from pydantic.v1 import PrivateAttr | ||
|
|
||
| from redisvl.utils.rerank.base import BaseReranker | ||
|
|
||
|
|
||
| class CohereReranker(BaseReranker): | ||
| """ | ||
| The CohereReranker class uses Cohere's API to rerank documents based on an | ||
| input query. | ||
|
|
||
| This reranker is designed to interact with Cohere's /rerank API, | ||
| requiring an API key for authentication. The key can be provided | ||
| directly in the `api_config` dictionary or through the `COHERE_API_KEY` | ||
| environment variable. Users must obtain an API key from Cohere's website | ||
| (https://dashboard.cohere.com/). Additionally, the `cohere` Python | ||
| client must be installed with `pip install cohere`. | ||
|
|
||
| .. code-block:: python | ||
|
|
||
|
|
||
| """ | ||
|
|
||
| _client: Any = PrivateAttr() | ||
| _aclient: Any = PrivateAttr() | ||
|
|
||
| def __init__( | ||
| self, | ||
| model: str = "rerank-english-v3.0", | ||
| rank_by: Optional[List[str]] = None, | ||
| limit: int = 5, | ||
| return_score: bool = True, | ||
| api_config: Optional[Dict] = None, | ||
| ) -> None: | ||
| """ | ||
| Initialize the CohereReranker with specified model, ranking criteria, | ||
| and API configuration. | ||
|
|
||
| Parameters: | ||
| model (str): The identifier for the Cohere model used for reranking. | ||
| Defaults to 'rerank-english-v3.0'. | ||
| rank_by (Optional[List[str]]): Optional list of keys specifying the | ||
| attributes in the documents that should be considered for | ||
| ranking. None means ranking will rely on the model's default | ||
| behavior. | ||
| limit (int): The maximum number of results to return after | ||
| reranking. Must be a positive integer. | ||
| return_score (bool): Whether to return scores alongside the | ||
| reranked results. | ||
| api_config (Optional[Dict], optional): Dictionary containing the API key. | ||
| Defaults to None. | ||
|
|
||
| Raises: | ||
| ImportError: If the cohere library is not installed. | ||
| ValueError: If the API key is not provided. | ||
| """ | ||
| super().__init__( | ||
| model=model, rank_by=rank_by, limit=limit, return_score=return_score | ||
| ) | ||
| self._initialize_clients(api_config) | ||
|
|
||
| def _initialize_clients(self, api_config: Optional[Dict]): | ||
| """ | ||
| Set up the Cohere clients using the provided API key or an | ||
| environment variable. | ||
| """ | ||
| # Dynamic import of the cohere module | ||
| try: | ||
| from cohere import AsyncClient, Client | ||
| except ImportError: | ||
| raise ImportError( | ||
| "Cohere reranker requires the cohere library. " | ||
| "Please install with `pip install cohere`" | ||
| ) | ||
|
|
||
| # Fetch the API key from api_config or environment variable | ||
| api_key = ( | ||
| (api_config or {}).get("api_key") or os.getenv("COHERE_API_KEY") | ||
| ) | ||
| if not api_key: | ||
| raise ValueError( | ||
| "Cohere API key is required. " | ||
| "Provide it in api_config or set the COHERE_API_KEY environment variable." | ||
| ) | ||
| self._client = Client(api_key=api_key, client_name="redisvl") | ||
| self._aclient = AsyncClient(api_key=api_key, client_name="redisvl") | ||
|
|
||
| def _preprocess( | ||
| self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs | ||
| ): | ||
| """ | ||
| Prepare and validate reranking config based on provided input and | ||
| optional overrides. | ||
| """ | ||
| limit = kwargs.get("limit", self.limit) | ||
| return_score = kwargs.get("return_score", self.return_score) | ||
| max_chunks_per_doc = kwargs.get("max_chunks_per_doc") | ||
| rank_by = kwargs.get("rank_by", self.rank_by) or [] | ||
| rank_by = [rank_by] if isinstance(rank_by, str) else rank_by | ||
|
|
||
| reranker_kwargs = { | ||
| "model": self.model, | ||
| "query": query, | ||
| "top_n": limit, | ||
| "documents": docs, | ||
| "max_chunks_per_doc": max_chunks_per_doc, | ||
| } | ||
| # if we are working with list of dicts | ||
| if all(isinstance(doc, dict) for doc in docs): | ||
| if rank_by: | ||
| reranker_kwargs["rank_fields"] = rank_by | ||
| else: | ||
| raise ValueError( | ||
| "If reranking dictionary-like docs, " | ||
| "you must provide a list of rank_by fields" | ||
| ) | ||
|
|
||
| return reranker_kwargs, return_score | ||
|
|
||
| @staticmethod | ||
| def _postprocess( | ||
| docs: Union[List[Dict[str, Any]], List[str]], | ||
| rankings: List[Any], | ||
| ) -> Tuple[List[Any], List[float]]: | ||
| """ | ||
| Post-process the initial list of documents to include ranking scores, | ||
| if specified. | ||
| """ | ||
| reranked_docs, scores = [], [] | ||
| for item in rankings.results: # type: ignore | ||
| scores.append(item.relevance_score) | ||
| reranked_docs.append(docs[item.index]) | ||
| return reranked_docs, scores | ||
|
|
||
| def rank( | ||
| self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs | ||
| ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]: | ||
| """ | ||
| Rerank documents based on the provided query using the Cohere rerank API. | ||
|
|
||
| This method processes the user's query and the provided documents to | ||
| rerank them in a manner that is potentially more relevant to the | ||
| query's context. | ||
|
|
||
| Parameters: | ||
| query (str): The user's search query. | ||
| docs (Union[List[Dict[str, Any]], List[str]]): The list of documents | ||
| to be ranked, either as dictionaries or strings. | ||
|
|
||
| Returns: | ||
| Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]: The reranked documents, plus their relevance scores when return_score is True. | ||
| """ | ||
| reranker_kwargs, return_score = self._preprocess(query, docs, **kwargs) | ||
| rankings = self._client.rerank(**reranker_kwargs) | ||
| reranked_docs, scores = self._postprocess(docs, rankings) | ||
| if return_score: | ||
| return reranked_docs, scores | ||
| return reranked_docs | ||
|
|
||
| async def arank( | ||
| self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs | ||
| ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]: | ||
| """ | ||
| Asynchronously rerank documents based on the provided query using the Cohere rerank API. | ||
|
|
||
| This method processes the user's query and the provided documents to | ||
| rerank them in a manner that is potentially more relevant to the | ||
| query's context. | ||
|
|
||
| Parameters: | ||
| query (str): The user's search query. | ||
| docs (Union[List[Dict[str, Any]], List[str]]): The list of documents | ||
| to be ranked, either as dictionaries or strings. | ||
|
|
||
| Returns: | ||
| Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]: The reranked documents, plus their relevance scores when return_score is True. | ||
| """ | ||
| reranker_kwargs, return_score = self._preprocess(query, docs, **kwargs) | ||
| rankings = await self._aclient.rerank(**reranker_kwargs) | ||
| reranked_docs, scores = self._postprocess(docs, rankings) | ||
| if return_score: | ||
| return reranked_docs, scores | ||
| return reranked_docs | ||
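
Reviewer note: the request-building and result-mapping steps in `_preprocess` and `_postprocess` can be exercised without calling Cohere. The sketch below is a simplified illustration under stated assumptions: `build_rerank_kwargs` and `map_rankings` are hypothetical helper names, `max_chunks_per_doc` is left out for brevity, and the response is faked with `SimpleNamespace` objects that model only the attributes the diff reads (`results`, `index`, `relevance_score`).

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple, Union

Docs = Union[List[Dict[str, Any]], List[str]]


def build_rerank_kwargs(
    model: str,
    query: str,
    docs: Docs,
    limit: int,
    rank_by: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Assemble request arguments roughly the way _preprocess does (sketch)."""
    kwargs: Dict[str, Any] = {
        "model": model,
        "query": query,
        "top_n": limit,
        "documents": docs,
    }
    # Dictionary documents need explicit fields to rank on.
    if all(isinstance(doc, dict) for doc in docs):
        if not rank_by:
            raise ValueError(
                "If reranking dictionary-like docs, "
                "you must provide a list of rank_by fields"
            )
        kwargs["rank_fields"] = rank_by
    return kwargs


def map_rankings(docs: Docs, rankings: Any) -> Tuple[List[Any], List[float]]:
    """Reorder docs by the indices in the response, as _postprocess does."""
    reranked, scores = [], []
    for item in rankings.results:
        reranked.append(docs[item.index])
        scores.append(item.relevance_score)
    return reranked, scores


# Hypothetical response object standing in for the client's return value.
fake_response = SimpleNamespace(
    results=[
        SimpleNamespace(index=2, relevance_score=0.91),
        SimpleNamespace(index=0, relevance_score=0.40),
    ]
)
docs = [{"content": "a"}, {"content": "b"}, {"content": "c"}]
request = build_rerank_kwargs(
    "rerank-english-v3.0", "query text", docs, limit=2, rank_by=["content"]
)
reranked_docs, scores = map_rankings(docs, fake_response)
```

Because the response carries indices into the original list, the original documents come back untouched and in ranked order; documents the service did not return within `top_n` are simply absent from the output.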