Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion pinecone/_internal/data_plane_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any

from pinecone._internal.config import normalize_host
Expand Down Expand Up @@ -60,6 +60,87 @@ def _normalize_search_vector_dict(vector: Mapping[str, Any]) -> dict[str, Any]:
return result


def _legacy_search_query_to_dict(query: Any) -> dict[str, Any]:
if hasattr(query, "to_dict") and callable(query.to_dict):
raw = query.to_dict()
elif hasattr(query, "as_dict") and callable(query.as_dict):
raw = query.as_dict()
else:
raw = dict(query)
return dict(raw)


def _build_search_records_body(
*,
top_k: int | None,
inputs: Mapping[str, Any] | None,
vector: Sequence[float] | Mapping[str, Any] | None,
id: str | None,
filter: Mapping[str, Any] | None,
fields: Sequence[str] | None,
rerank: Mapping[str, Any] | None,
match_terms: Mapping[str, Any] | None,
query: Any | None,
wrap_dense_vector: bool = True,
) -> dict[str, Any]:
if rerank is not None:
if "model" not in rerank:
raise ValidationError("rerank requires 'model' to be specified")
if "rank_fields" not in rerank:
raise ValidationError("rerank requires 'rank_fields' to be specified")

if query is not None:
if any(value is not None for value in (top_k, inputs, vector, id, filter, match_terms)):
raise ValidationError(
"query cannot be combined with top_k, inputs, vector, id, filter, or match_terms"
)
query_body = _legacy_search_query_to_dict(query)
if "vector" in query_body and query_body["vector"] is not None:
query_vector = query_body["vector"]
if isinstance(query_vector, Mapping):
query_body["vector"] = _normalize_search_vector_dict(query_vector)
else:
values = list(query_vector)
query_body["vector"] = {"values": values} if wrap_dense_vector else values
else:
if top_k is None:
raise ValidationError("top_k is required unless query is provided")
query_body = {"top_k": top_k}
if inputs is not None:
query_body["inputs"] = inputs
if vector is not None:
if isinstance(vector, Mapping):
query_body["vector"] = _normalize_search_vector_dict(vector)
else:
values = list(vector)
query_body["vector"] = {"values": values} if wrap_dense_vector else values
if id is not None:
query_body["id"] = id
if filter is not None:
query_body["filter"] = filter
if match_terms is not None:
query_body["match_terms"] = match_terms

top_k_value = query_body.get("top_k")
if not isinstance(top_k_value, int) or top_k_value < 1:
raise ValidationError(f"top_k must be a positive integer, got {top_k_value}")
if (
query_body.get("inputs") is None
and query_body.get("vector") is None
and query_body.get("id") is None
):
raise ValidationError(
"At least one of inputs, vector, or id must be provided as a query source"
)

body: dict[str, Any] = {"query": query_body}
if fields is not None:
body["fields"] = fields
if rerank is not None:
body["rerank"] = rerank
return body


def _vector_to_dict(v: Vector) -> dict[str, Any]:
"""Serialize a Vector to a dict matching the API wire format."""
id_ = v.id
Expand Down
64 changes: 27 additions & 37 deletions pinecone/async_client/async_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pinecone._internal.config import PineconeConfig
from pinecone._internal.constants import DATA_PLANE_API_VERSION
from pinecone._internal.data_plane_helpers import (
_normalize_search_vector_dict,
_build_search_records_body,
_validate_host,
_vector_to_dict,
)
Expand All @@ -40,7 +40,12 @@
UpsertRecordsResponse,
UpsertResponse,
)
from pinecone.models.vectors.search import RerankConfig, SearchInputs, SearchRecordsResponse
from pinecone.models.vectors.search import (
RerankConfig,
SearchInputs,
SearchQuery,
SearchRecordsResponse,
)
from pinecone.models.vectors.sparse import SparseValues
from pinecone.models.vectors.vector import Vector

Expand Down Expand Up @@ -926,14 +931,15 @@ async def search(
self,
*,
namespace: str,
top_k: int,
top_k: int | None = None,
inputs: SearchInputs | Mapping[str, Any] | None = None,
vector: Sequence[float] | Mapping[str, Any] | None = None,
id: str | None = None,
filter: Mapping[str, Any] | None = None,
fields: Sequence[str] | None = None,
rerank: RerankConfig | Mapping[str, Any] | None = None,
match_terms: Mapping[str, Any] | None = None,
query: SearchQuery | Mapping[str, Any] | None = None,
timeout: float | None = None,
) -> SearchRecordsResponse:
"""Search records by text, vector, or ID with optional reranking.
Expand Down Expand Up @@ -967,6 +973,9 @@ async def search(
``"all"``) and ``"terms"`` (list of strings). Only supported
for sparse indexes using ``pinecone-sparse-english-v0``.
``None`` disables term matching.
query (dict[str, Any] | None): Legacy query body containing
``top_k`` plus one of ``inputs``, ``vector``, or ``id``. Prefer
passing these fields directly.

Returns:
:class:`SearchRecordsResponse` with hits and usage statistics.
Expand Down Expand Up @@ -995,40 +1004,19 @@ async def search(
raise ValidationError("namespace must be a string")
if not namespace or not namespace.strip():
raise ValidationError("namespace must be a non-empty string")
if top_k < 1:
raise ValidationError(f"top_k must be a positive integer, got {top_k}")
if rerank is not None:
if "model" not in rerank:
raise ValidationError("rerank requires 'model' to be specified")
if "rank_fields" not in rerank:
raise ValidationError("rerank requires 'rank_fields' to be specified")
if inputs is None and vector is None and id is None:
raise ValidationError(
"At least one of inputs, vector, or id must be provided as a query source"
)

query_body: dict[str, Any] = {"top_k": top_k}
if inputs is not None:
query_body["inputs"] = inputs
if vector is not None:
if isinstance(vector, Mapping):
query_body["vector"] = _normalize_search_vector_dict(vector)
else:
query_body["vector"] = {"values": list(vector)}
if id is not None:
query_body["id"] = id
if filter is not None:
query_body["filter"] = filter
if match_terms is not None:
query_body["match_terms"] = match_terms

body: dict[str, Any] = {"query": query_body}
if fields is not None:
body["fields"] = fields
if rerank is not None:
body["rerank"] = rerank
body = _build_search_records_body(
top_k=top_k,
inputs=inputs,
vector=vector,
id=id,
filter=filter,
fields=fields,
rerank=rerank,
match_terms=match_terms,
query=query,
)

logger.info("Searching namespace %r with top_k=%d", namespace, top_k)
logger.info("Searching namespace %r with top_k=%d", namespace, body["query"]["top_k"])
response = await self._http.post(
f"/records/namespaces/{namespace}/search", timeout=timeout, json=body
)
Expand All @@ -1040,14 +1028,15 @@ async def search_records(
self,
*,
namespace: str,
top_k: int,
top_k: int | None = None,
inputs: SearchInputs | Mapping[str, Any] | None = None,
vector: Sequence[float] | Mapping[str, Any] | None = None,
id: str | None = None,
filter: Mapping[str, Any] | None = None,
fields: Sequence[str] | None = None,
rerank: RerankConfig | Mapping[str, Any] | None = None,
match_terms: Mapping[str, Any] | None = None,
query: SearchQuery | Mapping[str, Any] | None = None,
timeout: float | None = None,
) -> SearchRecordsResponse:
"""Alias for :meth:`search`.
Expand All @@ -1063,6 +1052,7 @@ async def search_records(
fields=fields,
rerank=rerank,
match_terms=match_terms,
query=query,
timeout=timeout,
)

Expand Down
70 changes: 34 additions & 36 deletions pinecone/grpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pinecone._internal.batching import chunked, validate_batch_size, with_progress
from pinecone._internal.config import PineconeConfig
from pinecone._internal.constants import DATA_PLANE_API_VERSION
from pinecone._internal.data_plane_helpers import _validate_host
from pinecone._internal.data_plane_helpers import _build_search_records_body, _validate_host
from pinecone._internal.validation import require_in_range
from pinecone._internal.vector_factory import VectorFactory
from pinecone.errors.exceptions import (
Expand Down Expand Up @@ -45,7 +45,12 @@
UpsertRecordsResponse,
UpsertResponse,
)
from pinecone.models.vectors.search import RerankConfig, SearchInputs, SearchRecordsResponse
from pinecone.models.vectors.search import (
RerankConfig,
SearchInputs,
SearchQuery,
SearchRecordsResponse,
)
from pinecone.models.vectors.sparse import SparseValues
from pinecone.models.vectors.usage import Usage
from pinecone.models.vectors.vector import ScoredVector, Vector
Expand Down Expand Up @@ -1196,14 +1201,15 @@ def search(
self,
*,
namespace: str,
top_k: int,
top_k: int | None = None,
inputs: SearchInputs | Mapping[str, Any] | None = None,
vector: Sequence[float] | None = None,
vector: Sequence[float] | Mapping[str, Any] | None = None,
id: str | None = None,
filter: Mapping[str, Any] | None = None,
fields: Sequence[str] | None = None,
rerank: RerankConfig | Mapping[str, Any] | None = None,
match_terms: Mapping[str, Any] | None = None,
query: SearchQuery | Mapping[str, Any] | None = None,
timeout: float | None = None,
) -> SearchRecordsResponse:
"""Search records by text, vector, or ID with optional reranking.
Expand Down Expand Up @@ -1234,6 +1240,9 @@ def search(
``"all"``) and ``"terms"`` (list of strings). Only supported
for sparse indexes using ``pinecone-sparse-english-v0``.
``None`` disables term matching.
query (dict[str, Any] | None): Legacy query body containing
``top_k`` plus one of ``inputs``, ``vector``, or ``id``. Prefer
passing these fields directly.

Returns:
:class:`SearchRecordsResponse` with hits and usage statistics.
Expand Down Expand Up @@ -1284,37 +1293,24 @@ def search(
raise ValidationError("namespace must be a string")
if not namespace or not namespace.strip():
raise ValidationError("namespace must be a non-empty string")
if top_k < 1:
raise ValidationError(f"top_k must be a positive integer, got {top_k}")
if rerank is not None:
if "model" not in rerank:
raise ValidationError("rerank requires 'model' to be specified")
if "rank_fields" not in rerank:
raise ValidationError("rerank requires 'rank_fields' to be specified")
if inputs is None and vector is None and id is None:
raise ValidationError(
"At least one of inputs, vector, or id must be provided as a query source"
)
body = _build_search_records_body(
top_k=top_k,
inputs=inputs,
vector=vector,
id=id,
filter=filter,
fields=fields,
rerank=rerank,
match_terms=match_terms,
query=query,
wrap_dense_vector=False,
)

query_body: dict[str, Any] = {"top_k": top_k}
if inputs is not None:
query_body["inputs"] = inputs
if vector is not None:
query_body["vector"] = vector
if id is not None:
query_body["id"] = id
if filter is not None:
query_body["filter"] = filter
if match_terms is not None:
query_body["match_terms"] = match_terms

body: dict[str, Any] = {"query": query_body}
if fields is not None:
body["fields"] = fields
if rerank is not None:
body["rerank"] = rerank

logger.info("Searching namespace %r with top_k=%d (via REST)", namespace, top_k)
logger.info(
"Searching namespace %r with top_k=%d (via REST)",
namespace,
body["query"]["top_k"],
)
response = self._http.post(
f"/records/namespaces/{namespace}/search", timeout=timeout, json=body
)
Expand All @@ -1326,14 +1322,15 @@ def search_records(
self,
*,
namespace: str,
top_k: int,
top_k: int | None = None,
inputs: SearchInputs | Mapping[str, Any] | None = None,
vector: Sequence[float] | None = None,
vector: Sequence[float] | Mapping[str, Any] | None = None,
id: str | None = None,
filter: Mapping[str, Any] | None = None,
fields: Sequence[str] | None = None,
rerank: RerankConfig | Mapping[str, Any] | None = None,
match_terms: Mapping[str, Any] | None = None,
query: SearchQuery | Mapping[str, Any] | None = None,
timeout: float | None = None,
) -> SearchRecordsResponse:
"""Alias for :meth:`search`.
Expand All @@ -1350,6 +1347,7 @@ def search_records(
fields=fields,
rerank=rerank,
match_terms=match_terms,
query=query,
timeout=timeout,
)

Expand Down
Loading
Loading