Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ jobs:
- name: Run tests
if: matrix.connection != 'plain' || matrix.redis-stack-version != 'latest'
run: |
SKIP_VECTORIZERS=True poetry run test-cov
SKIP_VECTORIZERS=True SKIP_RERANKERS=True poetry run test-cov

- name: Run notebooks
if: matrix.connection == 'plain' && matrix.redis-stack-version == 'latest'
Expand Down
5 changes: 5 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ Tests w/out vectorizers:
SKIP_VECTORIZERS=true poetry run test-cov
```

Tests w/out rerankers:
```bash
SKIP_RERANKERS=true poetry run test-cov
```

### Getting Redis

In order for your applications to use RedisVL, you must have [Redis](https://redis.io) accessible with Search & Query features enabled on [Redis Cloud](https://redis.com/try-free) or locally in docker with [Redis Stack](https://redis.io/docs/getting-started/install-stack/docker/):
Expand Down
12 changes: 10 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
from redisvl.redis.connection import RedisConnectionFactory
from testcontainers.compose import DockerCompose


# @pytest.fixture(scope="session")
# def event_loop():
# loop = asyncio.get_event_loop_policy().new_event_loop()
# yield loop
# loop.close()


@pytest.fixture(scope="session", autouse=True)
def redis_container():
# Set the default Redis version if not already set
Expand All @@ -25,7 +33,7 @@ def redis_container():
def redis_url():
return os.getenv("REDIS_URL", "redis://localhost:6379")

@pytest.fixture(scope="session")
@pytest.fixture
async def async_client(redis_url):
client = await RedisConnectionFactory.get_async_redis_connection(redis_url)
yield client
Expand All @@ -35,7 +43,7 @@ async def async_client(redis_url):
if "Event loop is closed" not in str(e):
raise

@pytest.fixture(scope="session")
@pytest.fixture
def client():
conn = RedisConnectionFactory.get_redis_connection(os.environ["REDIS_URL"])
yield conn
Expand Down
1 change: 1 addition & 0 deletions docs/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ searchindex
query
filter
vectorizer
reranker
cache
```

14 changes: 14 additions & 0 deletions docs/api/reranker.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
***********
Rerankers
***********

CohereReranker
================

.. _coherereranker_api:

.. currentmodule:: redisvl.utils.rerank.cohere

.. autoclass:: CohereReranker
:show-inheritance:
:members:
1 change: 0 additions & 1 deletion docs/api/vectorizer.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

***********
Vectorizers
***********
Expand Down
7 changes: 7 additions & 0 deletions redisvl/utils/rerank/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from redisvl.utils.rerank.base import BaseReranker
from redisvl.utils.rerank.cohere import CohereReranker

__all__ = [
"BaseReranker",
"CohereReranker",
]
48 changes: 48 additions & 0 deletions redisvl/utils/rerank/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic.v1 import BaseModel, validator


class BaseReranker(BaseModel, ABC):
model: str
rank_by: Optional[List[str]] = None
limit: int
return_score: bool

@validator("limit")
@classmethod
def check_limit(cls, value):
"""Ensures the limit is a positive integer."""
if value <= 0:
raise ValueError("Limit must be a positive integer.")
return value

@validator("rank_by")
@classmethod
def check_rank_by(cls, value):
"""Ensures that rank_by is a list of strings if provided."""
if value is not None and (
not isinstance(value, list)
or any(not isinstance(item, str) for item in value)
):
raise ValueError("rank_by must be a list of strings.")
return value

@abstractmethod
def rank(
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
"""
Synchronously rerank the docs based on the provided query.
"""
pass

@abstractmethod
async def arank(
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
"""
Asynchronously rerank the docs based on the provided query.
"""
pass
185 changes: 185 additions & 0 deletions redisvl/utils/rerank/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import os
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic.v1 import PrivateAttr

from redisvl.utils.rerank.base import BaseReranker


class CohereReranker(BaseReranker):
"""
The CohereReranker class uses Cohere's API to rerank documents based on an
input query.

This reranker is designed to interact with Cohere's /rerank API,
requiring an API key for authentication. The key can be provided
directly in the `api_config` dictionary or through the `COHERE_API_KEY`
environment variable. User must obtain an API key from Cohere's website
(https://dashboard.cohere.com/). Additionally, the `cohere` python
client must be installed with `pip install cohere`.

.. code-block:: python


"""

_client: Any = PrivateAttr()
_aclient: Any = PrivateAttr()

def __init__(
self,
model: str = "rerank-english-v3.0",
rank_by: Optional[List[str]] = None,
limit: int = 5,
return_score: bool = True,
api_config: Optional[Dict] = None,
) -> None:
"""
Initialize the CohereReranker with specified model, ranking criteria,
and API configuration.

Parameters:
model (str): The identifier for the Cohere model used for reranking.
Defaults to 'rerank-english-v3.0'.
rank_by (Optional[List[str]]): Optional list of keys specifying the
attributes in the documents that should be considered for
ranking. None means ranking will rely on the model's default
behavior.
limit (int): The maximum number of results to return after
reranking. Must be a positive integer.
return_score (bool): Whether to return scores alongside the
reranked results.
api_config (Optional[Dict], optional): Dictionary containing the API key.
Defaults to None.

Raises:
ImportError: If the cohere library is not installed.
ValueError: If the API key is not provided.
"""
super().__init__(
model=model, rank_by=rank_by, limit=limit, return_score=return_score
)
self._initialize_clients(api_config)

def _initialize_clients(self, api_config: Optional[Dict]):
"""
Setup the Cohere clients using the provided API key or an
environment variable.
"""
# Dynamic import of the cohere module
try:
from cohere import AsyncClient, Client
except ImportError:
raise ImportError(
"Cohere vectorizer requires the cohere library. \
Please install with `pip install cohere`"
)

# Fetch the API key from api_config or environment variable
api_key = (
api_config.get("api_key") if api_config else os.getenv("COHERE_API_KEY")
)
if not api_key:
raise ValueError(
"Cohere API key is required. "
"Provide it in api_config or set the COHERE_API_KEY environment variable."
)
self._client = Client(api_key=api_key, client_name="redisvl")
self._aclient = AsyncClient(api_key=api_key, client_name="redisvl")

def _preprocess(
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
):
"""
Prepare and validate reranking config based on provided input and
optional overrides.
"""
limit = kwargs.get("limit", self.limit)
return_score = kwargs.get("return_score", self.return_score)
max_chunks_per_doc = kwargs.get("max_chunks_per_doc")
rank_by = kwargs.get("rank_by", self.rank_by) or []
rank_by = [rank_by] if isinstance(rank_by, str) else rank_by

reranker_kwargs = {
"model": self.model,
"query": query,
"top_n": limit,
"documents": docs,
"max_chunks_per_doc": max_chunks_per_doc,
}
# if we are working with list of dicts
if all(isinstance(doc, dict) for doc in docs):
if rank_by:
reranker_kwargs["rank_fields"] = rank_by
else:
raise ValueError(
"If reranking dictionary-like docs, "
"you must provide a list of rank_by fields"
)

return reranker_kwargs, return_score

@staticmethod
def _postprocess(
docs: Union[List[Dict[str, Any]], List[str]],
rankings: List[Any],
) -> Tuple[List[Any], List[float]]:
"""
Post-process the initial list of documents to include ranking scores,
if specified.
"""
reranked_docs, scores = [], []
for item in rankings.results: # type: ignore
scores.append(item.relevance_score)
reranked_docs.append(docs[item.index])
return reranked_docs, scores

def rank(
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
"""
Rerank documents based on the provided query using the Cohere rerank API.

This method processes the user's query and the provided documents to
rerank them in a manner that is potentially more relevant to the
query's context.

Parameters:
query (str): The user's search query.
docs (Union[List[Dict[str, Any]], List[str]]): The list of documents
to be ranked, either as dictionaries or strings.

Returns:
Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: The reranked list of documents and optionally associated scores.
"""
reranker_kwargs, return_score = self._preprocess(query, docs, **kwargs)
rankings = self._client.rerank(**reranker_kwargs)
reranked_docs, scores = self._postprocess(docs, rankings)
if return_score:
return reranked_docs, scores
return reranked_docs

async def arank(
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
"""
Rerank documents based on the provided query using the Cohere rerank API.

This method processes the user's query and the provided documents to
rerank them in a manner that is potentially more relevant to the
query's context.

Parameters:
query (str): The user's search query.
docs (Union[List[Dict[str, Any]], List[str]]): The list of documents
to be ranked, either as dictionaries or strings.

Returns:
Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: The reranked list of documents and optionally associated scores.
"""
reranker_kwargs, return_score = self._preprocess(query, docs, **kwargs)
rankings = await self._aclient.rerank(**reranker_kwargs)
reranked_docs, scores = self._postprocess(docs, rankings)
if return_score:
return reranked_docs, scores
return reranked_docs
21 changes: 13 additions & 8 deletions redisvl/utils/vectorize/base.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from typing import Any, Callable, List, Optional
from abc import ABC, abstractmethod
from typing import Callable, List, Optional

from pydantic.v1 import BaseModel, validator

from redisvl.redis.utils import array_to_buffer


class BaseVectorizer(BaseModel):
class BaseVectorizer(BaseModel, ABC):
model: str
dims: int
client: Any

@validator("dims", pre=True)
@validator("dims")
@classmethod
def check_dims(cls, v):
if v <= 0:
raise ValueError("Dimension must be a positive integer")
return v
def check_dims(cls, value):
"""Ensures the dims are a positive integer."""
if value <= 0:
raise ValueError("Dims must be a positive integer.")
return value

@abstractmethod
def embed_many(
self,
texts: List[str],
Expand All @@ -27,6 +29,7 @@ def embed_many(
) -> List[List[float]]:
raise NotImplementedError

@abstractmethod
def embed(
self,
text: str,
Expand All @@ -36,6 +39,7 @@ def embed(
) -> List[float]:
raise NotImplementedError

@abstractmethod
async def aembed_many(
self,
texts: List[str],
Expand All @@ -46,6 +50,7 @@ async def aembed_many(
) -> List[List[float]]:
raise NotImplementedError

@abstractmethod
async def aembed(
self,
text: str,
Expand Down
Loading