Skip to content

Commit

Permalink
Add ranking_score_threshold parameter to search (#955)
Browse files Browse the repository at this point in the history
* Add ranking_score_threshold parameter to search

* Fix mypy error

* Change to ValidationError

* Fix tests

* Fix tests

* Update tests

* Fix range check

* Fix type

* Add more tests
  • Loading branch information
sanders41 committed Jun 10, 2024
1 parent 33f11c4 commit 9ff0c65
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
services:
meilisearch:
# image: getmeili/meilisearch:latest
image: getmeili/meilisearch:v1.9.0-rc.1
image: getmeili/meilisearch:v1.9.0-rc.2
ports:
- "7700:7700"
environment:
Expand Down
23 changes: 23 additions & 0 deletions meilisearch_python_sdk/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ async def search(
attributes_to_search_on: list[str] | None = None,
show_ranking_score: bool = False,
show_ranking_score_details: bool = False,
ranking_score_threshold: float | None = None,
vector: list[float] | None = None,
hybrid: Hybrid | None = None,
) -> SearchResults:
Expand Down Expand Up @@ -777,6 +778,9 @@ async def search(
Because this feature is experimental it may be removed or updated causing breaking
changes in this library without a major version bump so use with caution. This
feature became stable in Meilisearch v1.7.0.
ranking_score_threshold: If set, no document whose _rankingScore is under the
rankingScoreThreshold is returned. The value must be between 0.0 and 1.0. Defaults
to None.
vector: List of vectors for vector search. Defaults to None. Note: This parameter can
only be used with Meilisearch >= v1.3.0, and is experimental in Meilisearch v1.3.0.
In order to use this feature in Meilisearch v1.3.0 you first need to enable the
Expand Down Expand Up @@ -808,6 +812,9 @@ async def search(
>>> index = client.index("movies")
>>> search_results = await index.search("Tron")
"""
if ranking_score_threshold:
_validate_ranking_score_threshold(ranking_score_threshold)

body = _process_search_parameters(
q=query,
offset=offset,
Expand All @@ -831,6 +838,7 @@ async def search(
show_ranking_score_details=show_ranking_score_details,
vector=vector,
hybrid=hybrid,
ranking_score_threshold=ranking_score_threshold,
)
search_url = f"{self._base_url_with_uid}/search"

Expand Down Expand Up @@ -4909,6 +4917,7 @@ def search(
attributes_to_search_on: list[str] | None = None,
show_ranking_score: bool = False,
show_ranking_score_details: bool = False,
ranking_score_threshold: float | None = None,
vector: list[float] | None = None,
hybrid: Hybrid | None = None,
) -> SearchResults:
Expand Down Expand Up @@ -4949,6 +4958,9 @@ def search(
Because this feature is experimental it may be removed or updated causing breaking
changes in this library without a major version bump so use with caution. This
feature became stable in Meilisearch v1.7.0.
ranking_score_threshold: If set, no document whose _rankingScore is under the
rankingScoreThreshold is returned. The value must be between 0.0 and 1.0. Defaults
to None.
vector: List of vectors for vector search. Defaults to None. Note: This parameter can
only be used with Meilisearch >= v1.3.0, and is experimental in Meilisearch v1.3.0.
In order to use this feature in Meilisearch v1.3.0 you first need to enable the
Expand Down Expand Up @@ -4980,6 +4992,9 @@ def search(
>>> index = client.index("movies")
>>> search_results = index.search("Tron")
"""
if ranking_score_threshold:
_validate_ranking_score_threshold(ranking_score_threshold)

body = _process_search_parameters(
q=query,
offset=offset,
Expand All @@ -5003,6 +5018,7 @@ def search(
show_ranking_score_details=show_ranking_score_details,
vector=vector,
hybrid=hybrid,
ranking_score_threshold=ranking_score_threshold,
)

if self._pre_search_plugins:
Expand Down Expand Up @@ -7999,6 +8015,7 @@ def _process_search_parameters(
attributes_to_search_on: list[str] | None = None,
show_ranking_score: bool = False,
show_ranking_score_details: bool = False,
ranking_score_threshold: float | None = None,
vector: list[float] | None = None,
hybrid: Hybrid | None = None,
) -> JsonDict:
Expand All @@ -8025,6 +8042,7 @@ def _process_search_parameters(
"page": page,
"attributesToSearchOn": attributes_to_search_on,
"showRankingScore": show_ranking_score,
"rankingScoreThreshold": ranking_score_threshold,
}

if facet_name:
Expand Down Expand Up @@ -8118,3 +8136,8 @@ def _embedder_json_to_settings_model( # pragma: no cover
def _validate_file_type(file_path: Path) -> None:
if file_path.suffix not in (".json", ".csv", ".ndjson"):
raise MeilisearchError("File must be a json, ndjson, or csv file")


def _validate_ranking_score_threshold(ranking_score_threshold: float) -> None:
if not 0.0 <= ranking_score_threshold <= 1.0:
raise MeilisearchError("ranking_score_threshold must be between 0.0 and 1.0")
33 changes: 31 additions & 2 deletions meilisearch_python_sdk/models/search.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import List, Optional
from warnings import warn

import pydantic
from camel_converter.pydantic_base import CamelBase
from pydantic import Field

from meilisearch_python_sdk._utils import is_pydantic_2
from meilisearch_python_sdk.errors import MeilisearchError
from meilisearch_python_sdk.types import Filter, JsonDict


Expand All @@ -24,7 +27,7 @@ class Hybrid(CamelBase):

class SearchParams(CamelBase):
index_uid: str
query: Optional[str] = Field(None, alias="q")
query: Optional[str] = pydantic.Field(None, alias="q")
offset: int = 0
limit: int = 20
filter: Optional[Filter] = None
Expand All @@ -44,9 +47,35 @@ class SearchParams(CamelBase):
attributes_to_search_on: Optional[List[str]] = None
show_ranking_score: bool = False
show_ranking_score_details: bool = False
ranking_score_threshold: Optional[float] = None
vector: Optional[List[float]] = None
hybrid: Optional[Hybrid] = None

if is_pydantic_2():

    @pydantic.field_validator("ranking_score_threshold", mode="before")  # type: ignore[attr-defined]
    @classmethod
    def validate_ranking_score_threshold(cls, v: Optional[float]) -> Optional[float]:
        """Reject a ranking score threshold outside the range 0.0 to 1.0.

        Falsy values (``None`` and ``0.0``) skip the range check and are returned as is.

        Raises:
            MeilisearchError: If ``v`` is truthy and not between 0.0 and 1.0 inclusive.
        """
        if v and not 0.0 <= v <= 1.0:
            raise MeilisearchError("ranking_score_threshold must be between 0.0 and 1.0")

        return v

else:  # pragma: no cover
    warn(
        "The use of Pydantic less than version 2 is deprecated and will be removed in a future release",
        DeprecationWarning,
        stacklevel=2,
    )

    # Same check as the Pydantic 2 branch, registered through the Pydantic 1 decorator.
    # Named after the field it validates so both branches expose the same method name.
    @pydantic.validator("ranking_score_threshold", pre=True)
    @classmethod
    def validate_ranking_score_threshold(cls, v: Optional[float]) -> Optional[float]:
        """Reject a ranking score threshold outside the range 0.0 to 1.0.

        Falsy values (``None`` and ``0.0``) skip the range check and are returned as is.

        Raises:
            MeilisearchError: If ``v`` is truthy and not between 0.0 and 1.0 inclusive.
        """
        if v and not 0.0 <= v <= 1.0:
            raise MeilisearchError("ranking_score_threshold must be between 0.0 and 1.0")

        return v


class SearchResults(CamelBase):
hits: List[JsonDict]
Expand Down
48 changes: 47 additions & 1 deletion tests/test_async_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from meilisearch_python_sdk import AsyncClient
from meilisearch_python_sdk._task import async_wait_for_task
from meilisearch_python_sdk.errors import MeilisearchApiError
from meilisearch_python_sdk.errors import MeilisearchApiError, MeilisearchError
from meilisearch_python_sdk.models.search import Hybrid, SearchParams


Expand Down Expand Up @@ -405,3 +405,49 @@ async def test_custom_facet_search(async_index_with_documents):
)
assert response.facet_hits[0].value == "cartoon"
assert response.facet_hits[0].count == 1


@pytest.mark.parametrize("ranking_score_threshold", (-0.1, 1.1))
@pytest.mark.usefixtures("enable_vector_search")
async def test_search_invalid_ranking_score_threshold(
    ranking_score_threshold, async_index_with_documents
):
    """An out-of-range threshold passed to ``search`` raises ``MeilisearchError``."""
    index = await async_index_with_documents()
    with pytest.raises(MeilisearchError) as e:
        await index.search("", ranking_score_threshold=ranking_score_threshold)

    # Checked after the context manager exits: a statement placed after the raising
    # call inside the ``with`` body would never execute.
    assert "ranking_score_threshold must be between 0.0 and 1.0" in str(e.value)


@pytest.mark.usefixtures("enable_vector_search")
async def test_search_ranking_score_threshold(async_index_with_documents_and_vectors):
    """An empty-query search with a threshold of 0.5 still returns at least one hit."""
    populated_index = await async_index_with_documents_and_vectors()
    response = await populated_index.search("", ranking_score_threshold=0.5)
    hit_count = len(response.hits)
    assert hit_count > 0


@pytest.mark.parametrize("ranking_score_threshold", (-0.1, 1.1))
@pytest.mark.usefixtures("enable_vector_search")
async def test_multi_search_invalid_ranking_score_threshold(
    ranking_score_threshold, async_client, async_index_with_documents
):
    """An out-of-range threshold in a multi-search query raises ``MeilisearchError``."""
    index1 = await async_index_with_documents()
    with pytest.raises(MeilisearchError) as e:
        await async_client.multi_search(
            [
                SearchParams(
                    index_uid=index1.uid, query="", ranking_score_threshold=ranking_score_threshold
                ),
            ]
        )

    # Checked after the context manager exits: a statement placed after the raising
    # call inside the ``with`` body would never execute.
    assert "ranking_score_threshold must be between 0.0 and 1.0" in str(e.value)


@pytest.mark.usefixtures("enable_vector_search")
async def test_multi_search_ranking_score_threshold(async_client, async_index_with_documents):
    """A multi-search query with a threshold of 0.5 returns hits for its first result."""
    index1 = await async_index_with_documents()
    queries = [
        SearchParams(index_uid=index1.uid, query="", ranking_score_threshold=0.5),
    ]
    responses = await async_client.multi_search(queries)
    first_response = responses[0]
    assert len(first_response.hits) > 0
48 changes: 47 additions & 1 deletion tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from meilisearch_python_sdk import Client
from meilisearch_python_sdk._task import wait_for_task
from meilisearch_python_sdk.errors import MeilisearchApiError
from meilisearch_python_sdk.errors import MeilisearchApiError, MeilisearchError
from meilisearch_python_sdk.models.search import Hybrid, SearchParams


Expand Down Expand Up @@ -406,3 +406,49 @@ def test_custom_facet_search(index_with_documents):
)
assert response.facet_hits[0].value == "cartoon"
assert response.facet_hits[0].count == 1


@pytest.mark.parametrize("ranking_score_threshold", (-0.1, 1.1))
@pytest.mark.usefixtures("enable_vector_search")
def test_search_invalid_ranking_score_threshold(
    ranking_score_threshold, index_with_documents_and_vectors
):
    """An out-of-range threshold passed to ``search`` raises ``MeilisearchError``."""
    index = index_with_documents_and_vectors()
    with pytest.raises(MeilisearchError) as e:
        index.search("", ranking_score_threshold=ranking_score_threshold)

    # Checked after the context manager exits: a statement placed after the raising
    # call inside the ``with`` body would never execute.
    assert "ranking_score_threshold must be between 0.0 and 1.0" in str(e.value)


@pytest.mark.usefixtures("enable_vector_search")
def test_search_ranking_score_threshold(index_with_documents_and_vectors):
    """An empty-query search with a threshold of 0.5 still returns at least one hit."""
    populated_index = index_with_documents_and_vectors()
    response = populated_index.search("", ranking_score_threshold=0.5)
    hit_count = len(response.hits)
    assert hit_count > 0


@pytest.mark.parametrize("ranking_score_threshold", (-0.1, 1.1))
@pytest.mark.usefixtures("enable_vector_search")
def test_multi_search_invalid_ranking_score_threshold(
    ranking_score_threshold, client, index_with_documents
):
    """An out-of-range threshold in a multi-search query raises ``MeilisearchError``."""
    index1 = index_with_documents()
    with pytest.raises(MeilisearchError) as e:
        client.multi_search(
            [
                SearchParams(
                    index_uid=index1.uid, query="", ranking_score_threshold=ranking_score_threshold
                ),
            ]
        )

    # Checked after the context manager exits: a statement placed after the raising
    # call inside the ``with`` body would never execute.
    assert "ranking_score_threshold must be between 0.0 and 1.0" in str(e.value)


@pytest.mark.usefixtures("enable_vector_search")
def test_multi_search_ranking_score_threshold(client, index_with_documents):
    """A multi-search query with a threshold of 0.5 returns hits for its first result."""
    index1 = index_with_documents()
    queries = [
        SearchParams(index_uid=index1.uid, query="", ranking_score_threshold=0.5),
    ]
    responses = client.multi_search(queries)
    first_response = responses[0]
    assert len(first_response.hits) > 0

0 comments on commit 9ff0c65

Please sign in to comment.