Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ranking_score_threshold parameter to search #955

Merged
merged 9 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
services:
meilisearch:
# image: getmeili/meilisearch:latest
image: getmeili/meilisearch:v1.9.0-rc.1
image: getmeili/meilisearch:v1.9.0-rc.2
ports:
- "7700:7700"
environment:
Expand Down
23 changes: 23 additions & 0 deletions meilisearch_python_sdk/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ async def search(
attributes_to_search_on: list[str] | None = None,
show_ranking_score: bool = False,
show_ranking_score_details: bool = False,
ranking_score_threshold: float | None = None,
vector: list[float] | None = None,
hybrid: Hybrid | None = None,
) -> SearchResults:
Expand Down Expand Up @@ -777,6 +778,9 @@ async def search(
Because this feature is experimental it may be removed or updated causing breaking
changes in this library without a major version bump so use with caution. This
feature became stable in Meiliseach v1.7.0.
ranking_score_threshold: If set, no document whose _rankingScore is under the
rankingScoreThreshold is returned. The value must be between 0.0 and 1.0. Defaults
to None.
vector: List of vectors for vector search. Defaults to None. Note: This parameter can
only be used with Meilisearch >= v1.3.0, and is experimental in Meilisearch v1.3.0.
In order to use this feature in Meilisearch v1.3.0 you first need to enable the
Expand Down Expand Up @@ -808,6 +812,9 @@ async def search(
>>> index = client.index("movies")
>>> search_results = await index.search("Tron")
"""
if ranking_score_threshold:
_validate_ranking_score_threshold(ranking_score_threshold)

body = _process_search_parameters(
q=query,
offset=offset,
Expand All @@ -831,6 +838,7 @@ async def search(
show_ranking_score_details=show_ranking_score_details,
vector=vector,
hybrid=hybrid,
ranking_score_threshold=ranking_score_threshold,
)
search_url = f"{self._base_url_with_uid}/search"

Expand Down Expand Up @@ -4909,6 +4917,7 @@ def search(
attributes_to_search_on: list[str] | None = None,
show_ranking_score: bool = False,
show_ranking_score_details: bool = False,
ranking_score_threshold: float | None = None,
vector: list[float] | None = None,
hybrid: Hybrid | None = None,
) -> SearchResults:
Expand Down Expand Up @@ -4949,6 +4958,9 @@ def search(
Because this feature is experimental it may be removed or updated causing breaking
changes in this library without a major version bump so use with caution. This
feature became stable in Meiliseach v1.7.0.
ranking_score_threshold: If set, no document whose _rankingScore is under the
rankingScoreThreshold is returned. The value must be between 0.0 and 1.0. Defaults
to None.
vector: List of vectors for vector search. Defaults to None. Note: This parameter can
only be used with Meilisearch >= v1.3.0, and is experimental in Meilisearch v1.3.0.
In order to use this feature in Meilisearch v1.3.0 you first need to enable the
Expand Down Expand Up @@ -4980,6 +4992,9 @@ def search(
>>> index = client.index("movies")
>>> search_results = index.search("Tron")
"""
if ranking_score_threshold:
_validate_ranking_score_threshold(ranking_score_threshold)

body = _process_search_parameters(
q=query,
offset=offset,
Expand All @@ -5003,6 +5018,7 @@ def search(
show_ranking_score_details=show_ranking_score_details,
vector=vector,
hybrid=hybrid,
ranking_score_threshold=ranking_score_threshold,
)

if self._pre_search_plugins:
Expand Down Expand Up @@ -7999,6 +8015,7 @@ def _process_search_parameters(
attributes_to_search_on: list[str] | None = None,
show_ranking_score: bool = False,
show_ranking_score_details: bool = False,
ranking_score_threshold: float | None = None,
vector: list[float] | None = None,
hybrid: Hybrid | None = None,
) -> JsonDict:
Expand All @@ -8025,6 +8042,7 @@ def _process_search_parameters(
"page": page,
"attributesToSearchOn": attributes_to_search_on,
"showRankingScore": show_ranking_score,
"rankingScoreThreshold": ranking_score_threshold,
}

if facet_name:
Expand Down Expand Up @@ -8118,3 +8136,8 @@ def _embedder_json_to_settings_model( # pragma: no cover
def _validate_file_type(file_path: Path) -> None:
if file_path.suffix not in (".json", ".csv", ".ndjson"):
raise MeilisearchError("File must be a json, ndjson, or csv file")


def _validate_ranking_score_threshold(ranking_score_threshold: float) -> None:
if not 0.0 <= ranking_score_threshold <= 1.0:
raise MeilisearchError("ranking_score_threshold must be between 0.0 and 1.0")
33 changes: 31 additions & 2 deletions meilisearch_python_sdk/models/search.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import List, Optional
from warnings import warn

import pydantic
from camel_converter.pydantic_base import CamelBase
from pydantic import Field

from meilisearch_python_sdk._utils import is_pydantic_2
from meilisearch_python_sdk.errors import MeilisearchError
from meilisearch_python_sdk.types import Filter, JsonDict


Expand All @@ -24,7 +27,7 @@ class Hybrid(CamelBase):

class SearchParams(CamelBase):
index_uid: str
query: Optional[str] = Field(None, alias="q")
query: Optional[str] = pydantic.Field(None, alias="q")
offset: int = 0
limit: int = 20
filter: Optional[Filter] = None
Expand All @@ -44,9 +47,35 @@ class SearchParams(CamelBase):
attributes_to_search_on: Optional[List[str]] = None
show_ranking_score: bool = False
show_ranking_score_details: bool = False
ranking_score_threshold: Optional[float] = None
vector: Optional[List[float]] = None
hybrid: Optional[Hybrid] = None

if is_pydantic_2():

@pydantic.field_validator("ranking_score_threshold", mode="before") # type: ignore[attr-defined]
@classmethod
def validate_ranking_score_threshold(cls, v: Optional[float]) -> Optional[float]:
if v and not 0.0 <= v <= 1.0:
raise MeilisearchError("ranking_score_threshold must be between 0.0 and 1.0")

return v

else: # pragma: no cover
warn(
"The use of Pydantic less than version 2 is depreciated and will be removed in a future release",
DeprecationWarning,
stacklevel=2,
)

@pydantic.validator("ranking_score_threshold", pre=True)
@classmethod
def validate_expires_at(cls, v: Optional[float]) -> Optional[float]:
if v and not 0.0 <= v <= 1.0:
raise MeilisearchError("ranking_score_threshold must be between 0.0 and 1.0")

return v


class SearchResults(CamelBase):
hits: List[JsonDict]
Expand Down
48 changes: 47 additions & 1 deletion tests/test_async_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from meilisearch_python_sdk import AsyncClient
from meilisearch_python_sdk._task import async_wait_for_task
from meilisearch_python_sdk.errors import MeilisearchApiError
from meilisearch_python_sdk.errors import MeilisearchApiError, MeilisearchError
from meilisearch_python_sdk.models.search import Hybrid, SearchParams


Expand Down Expand Up @@ -405,3 +405,49 @@ async def test_custom_facet_search(async_index_with_documents):
)
assert response.facet_hits[0].value == "cartoon"
assert response.facet_hits[0].count == 1


@pytest.mark.parametrize("ranking_score_threshold", (-0.1, 1.1))
@pytest.mark.usefixtures("enable_vector_search")
async def test_search_invalid_ranking_score_threshold(
ranking_score_threshold, async_index_with_documents
):
index = await async_index_with_documents()
with pytest.raises(MeilisearchError) as e:
await index.search("", ranking_score_threshold=ranking_score_threshold)
assert "ranking_score_threshold must be between 0.0 and 1.0" in str(e.value)


@pytest.mark.usefixtures("enable_vector_search")
async def test_search_ranking_score_threshold(async_index_with_documents_and_vectors):
index = await async_index_with_documents_and_vectors()
result = await index.search("", ranking_score_threshold=0.5)
assert len(result.hits) > 0


@pytest.mark.parametrize("ranking_score_threshold", (-0.1, 1.1))
@pytest.mark.usefixtures("enable_vector_search")
async def test_multi_search_invalid_ranking_score_threshold(
ranking_score_threshold, async_client, async_index_with_documents
):
index1 = await async_index_with_documents()
with pytest.raises(MeilisearchError) as e:
await async_client.multi_search(
[
SearchParams(
index_uid=index1.uid, query="", ranking_score_threshold=ranking_score_threshold
),
]
)
assert "ranking_score_threshold must be between 0.0 and 1.0" in str(e.value)


@pytest.mark.usefixtures("enable_vector_search")
async def test_multi_search_ranking_score_threshold(async_client, async_index_with_documents):
index1 = await async_index_with_documents()
result = await async_client.multi_search(
[
SearchParams(index_uid=index1.uid, query="", ranking_score_threshold=0.5),
]
)
assert len(result[0].hits) > 0
48 changes: 47 additions & 1 deletion tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from meilisearch_python_sdk import Client
from meilisearch_python_sdk._task import wait_for_task
from meilisearch_python_sdk.errors import MeilisearchApiError
from meilisearch_python_sdk.errors import MeilisearchApiError, MeilisearchError
from meilisearch_python_sdk.models.search import Hybrid, SearchParams


Expand Down Expand Up @@ -406,3 +406,49 @@ def test_custom_facet_search(index_with_documents):
)
assert response.facet_hits[0].value == "cartoon"
assert response.facet_hits[0].count == 1


@pytest.mark.parametrize("ranking_score_threshold", (-0.1, 1.1))
@pytest.mark.usefixtures("enable_vector_search")
def test_search_invalid_ranking_score_threshold(
ranking_score_threshold, index_with_documents_and_vectors
):
index = index_with_documents_and_vectors()
with pytest.raises(MeilisearchError) as e:
index.search("", ranking_score_threshold=ranking_score_threshold)
assert "ranking_score_threshold must be between 0.0 and 1.0" in str(e.value)


@pytest.mark.usefixtures("enable_vector_search")
def test_search_ranking_score_threshold(index_with_documents_and_vectors):
index = index_with_documents_and_vectors()
result = index.search("", ranking_score_threshold=0.5)
assert len(result.hits) > 0


@pytest.mark.parametrize("ranking_score_threshold", (-0.1, 1.1))
@pytest.mark.usefixtures("enable_vector_search")
def test_multi_search_invalid_ranking_score_threshold(
ranking_score_threshold, client, index_with_documents
):
index1 = index_with_documents()
with pytest.raises(MeilisearchError) as e:
client.multi_search(
[
SearchParams(
index_uid=index1.uid, query="", ranking_score_threshold=ranking_score_threshold
),
]
)
assert "ranking_score_threshold must be between 0.0 and 1.0" in str(e.value)


@pytest.mark.usefixtures("enable_vector_search")
def test_multi_search_ranking_score_threshold(client, index_with_documents):
index1 = index_with_documents()
result = client.multi_search(
[
SearchParams(index_uid=index1.uid, query="", ranking_score_threshold=0.5),
]
)
assert len(result[0].hits) > 0