Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,4 @@ dmypy.json
*~

tests/integration/proxy_config/logs
benchmark_results.json
10 changes: 6 additions & 4 deletions pinecone/db_data/resources/asyncio/vector_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from pinecone.utils.tqdm import tqdm
import logging
import asyncio
import json
from typing import List, Any, Literal, AsyncIterator

import orjson

from pinecone.core.openapi.db_data.api.vector_operations_api import AsyncioVectorOperationsApi
from pinecone.core.openapi.db_data.models import (
QueryResponse as OpenAPIQueryResponse,
Expand Down Expand Up @@ -571,11 +572,12 @@ async def query_namespaces(
from pinecone.openapi_support.rest_utils import RESTResponse

if isinstance(raw_result, RESTResponse):
response = json.loads(raw_result.data.decode("utf-8"))
response = orjson.loads(raw_result.data)
aggregator.add_results(response)
else:
# Fallback: if somehow we got an OpenAPIQueryResponse, parse it
response = json.loads(raw_result.to_dict())
# Fallback: if somehow we got an OpenAPIQueryResponse, use dict directly
# to_dict() returns a dict, not JSON, so no parsing needed
response = raw_result.to_dict()
aggregator.add_results(response)

final_results = aggregator.get_results()
Expand Down
5 changes: 3 additions & 2 deletions pinecone/db_data/resources/sync/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from pinecone.utils.tqdm import tqdm
import logging
import json
from typing import Any, Literal

import orjson
from multiprocessing.pool import ApplyResult
from concurrent.futures import as_completed

Expand Down Expand Up @@ -649,7 +650,7 @@ def query_namespaces(
futures: list[Future[Any]] = cast(list[Future[Any]], async_futures)
for result in as_completed(futures):
raw_result = result.result()
response = json.loads(raw_result.data.decode("utf-8"))
response = orjson.loads(raw_result.data)
aggregator.add_results(response)

final_results = aggregator.get_results()
Expand Down
9 changes: 5 additions & 4 deletions pinecone/openapi_support/api_client_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json
import mimetypes
import io
import mimetypes
import os
from urllib3.fields import RequestField
from urllib.parse import quote
from urllib3.fields import RequestField

import orjson
from typing import Any
from .serializer import Serializer
from .exceptions import PineconeApiValueError
Expand Down Expand Up @@ -116,7 +116,8 @@ def parameters_to_multipart(params, collection_types):
if isinstance(
v, collection_types
): # v is instance of collection_type, formatting as application/json
v = json.dumps(v, ensure_ascii=False).encode("utf-8")
# orjson.dumps() returns bytes, no need to encode
v = orjson.dumps(v)
field = RequestField(k, v)
field.make_multipart(content_type="application/json; charset=utf-8")
new_params.append(field)
Expand Down
7 changes: 4 additions & 3 deletions pinecone/openapi_support/asyncio_api_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
import io
from urllib3.fields import RequestField
import logging
from urllib3.fields import RequestField

import orjson
from typing import Any


Expand Down Expand Up @@ -203,7 +203,8 @@ def parameters_to_multipart(self, params, collection_types):
if isinstance(
v, collection_types
): # v is instance of collection_type, formatting as application/json
v = json.dumps(v, ensure_ascii=False).encode("utf-8")
# orjson.dumps() returns bytes, no need to encode
v = orjson.dumps(v)
field = RequestField(k, v)
field.make_multipart(content_type="application/json; charset=utf-8")
new_params.append(field)
Expand Down
5 changes: 3 additions & 2 deletions pinecone/openapi_support/deserializer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import re
from typing import TypeVar, Type, Any

import orjson

from .model_utils import deserialize_file, file_type, validate_and_convert_types

T = TypeVar("T")
Expand Down Expand Up @@ -53,7 +54,7 @@ def deserialize(

# fetch data from response object
try:
received_data = json.loads(response.data)
received_data = orjson.loads(response.data)
except ValueError:
received_data = response.data

Expand Down
10 changes: 7 additions & 3 deletions pinecone/openapi_support/rest_aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ssl
import certifi
import json

import orjson
from .rest_utils import RestClientInterface, RESTResponse, raise_exceptions_or_return
from ..config.openapi_configuration import Configuration

Expand Down Expand Up @@ -61,7 +62,7 @@ async def request(
headers["Content-Type"] = "application/json"

if "application/x-ndjson" in headers.get("Content-Type", "").lower():
ndjson_data = "\n".join(json.dumps(record) for record in body)
ndjson_data = "\n".join(orjson.dumps(record).decode("utf-8") for record in body)

async with self._retry_client.request(
method, url, params=query_params, headers=headers, data=ndjson_data
Expand All @@ -72,8 +73,11 @@ async def request(
)

else:
# Pre-serialize with orjson for better performance than aiohttp's json parameter
# which uses standard library json
body_data = orjson.dumps(body) if body is not None else None
async with self._retry_client.request(
method, url, params=query_params, headers=headers, json=body
method, url, params=query_params, headers=headers, data=body_data
) as resp:
content = await resp.read()
return raise_exceptions_or_return(
Expand Down
13 changes: 8 additions & 5 deletions pinecone/openapi_support/rest_urllib3.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import logging
import ssl
import os
import ssl
from urllib.parse import urlencode, quote

import orjson
from ..config.openapi_configuration import Configuration
from .rest_utils import raise_exceptions_or_return, RESTResponse, RestClientInterface

Expand Down Expand Up @@ -141,7 +142,7 @@ def request(
+ bcolors.ENDC
)
else:
formatted_body = json.dumps(body)
formatted_body = orjson.dumps(body).decode("utf-8")
print(
bcolors.OKBLUE
+ "curl -X {method} '{url}' {formatted_headers} -d '{data}'".format(
Expand Down Expand Up @@ -184,9 +185,11 @@ def request(
if content_type == "application/x-ndjson":
# for x-ndjson requests, we are expecting an array of elements
# that need to be converted to a newline separated string
request_body = "\n".join(json.dumps(element) for element in body)
request_body = "\n".join(
orjson.dumps(element).decode("utf-8") for element in body
)
else: # content_type == "application/json":
request_body = json.dumps(body)
request_body = orjson.dumps(body).decode("utf-8")
r = self.pool_manager.request(
method,
url,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ classifiers = [
dependencies = [
"typing-extensions>=3.7.4",
"certifi>=2019.11.17",
"orjson>=3.0.0",
"pinecone-plugin-interface>=0.0.7,<0.1.0",
"python-dateutil>=2.5.3",
"pinecone-plugin-assistant==3.0.0",
Expand Down
158 changes: 158 additions & 0 deletions tests/perf/test_orjson_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Performance tests comparing orjson vs standard json library.

These tests measure the performance improvements from using orjson
for JSON serialization and deserialization in REST API requests/responses.
"""

import json
import random

import orjson
import pytest


def create_vector_payload(
    num_vectors: int, dimension: int, *, seed: "int | None" = None
) -> list[dict]:
    """Create a typical upsert payload with vectors.

    Args:
        num_vectors: Number of vector records to generate.
        dimension: Number of float components in each record's ``values``.
        seed: Optional seed for a private random generator. When given, the
            same arguments always yield the same payload, which lets two
            benchmarks be run on identical data. When ``None`` (the default)
            the module-level ``random`` state is used.

    Returns:
        A list of ``num_vectors`` dicts. Each dict has an ``id`` string, a
        ``values`` list of ``dimension`` floats, and a ``metadata`` dict with
        ``category``, ``score`` and ``tags`` entries.
    """
    # Fall back to the module itself so callers that seed the global
    # generator beforehand keep getting the sequence they expect.
    rng = random if seed is None else random.Random(seed)
    return [
        {
            "id": f"vec_{i}",
            "values": [rng.random() for _ in range(dimension)],
            "metadata": {
                "category": f"cat_{i % 10}",
                "score": rng.randint(0, 100),
                "tags": [f"tag_{j}" for j in range(3)],
            },
        }
        for i in range(num_vectors)
    ]


def create_query_response(
    num_matches: int,
    dimension: int,
    include_values: bool = True,
    *,
    seed: "int | None" = None,
) -> dict:
    """Create a typical query response payload.

    Args:
        num_matches: Number of entries to place in the ``matches`` list.
        dimension: Number of floats in each entry's ``values`` list; only
            used when ``include_values`` is true.
        include_values: Whether each entry carries a ``values`` list.
        seed: Optional seed for a private random generator. When given, the
            same arguments always yield the same response, which lets two
            benchmarks be run on identical data. When ``None`` (the default)
            the module-level ``random`` state is used.

    Returns:
        A dict with a single ``matches`` key holding ``num_matches`` dicts,
        each with ``id``, ``score`` and ``metadata`` entries plus ``values``
        when requested.
    """
    # Fall back to the module itself so callers that seed the global
    # generator beforehand keep getting the sequence they expect.
    rng = random if seed is None else random.Random(seed)
    matches = []
    for i in range(num_matches):
        # ``entry`` rather than ``match``: the latter reads like the
        # ``match`` statement keyword.
        entry = {
            "id": f"vec_{i}",
            "score": rng.random(),
            "metadata": {"category": f"cat_{i % 10}", "score": rng.randint(0, 100)},
        }
        if include_values:
            entry["values"] = [rng.random() for _ in range(dimension)]
        matches.append(entry)
    return {"matches": matches}


class TestOrjsonSerialization:
    """Serialization benchmarks: ``json.dumps()`` against ``orjson.dumps()``."""

    @pytest.mark.parametrize(("num_vectors", "dimension"), [(10, 128), (100, 128), (100, 512)])
    def test_json_dumps_vectors(self, benchmark, num_vectors, dimension):
        """Time the standard library encoder on an upsert-style vector list."""
        vectors = create_vector_payload(num_vectors, dimension)
        encoded = benchmark(json.dumps, vectors)
        # json.dumps() produces text.
        assert isinstance(encoded, str)
        assert encoded

    @pytest.mark.parametrize(("num_vectors", "dimension"), [(10, 128), (100, 128), (100, 512)])
    def test_orjson_dumps_vectors(self, benchmark, num_vectors, dimension):
        """Time the orjson encoder on an upsert-style vector list."""
        vectors = create_vector_payload(num_vectors, dimension)
        encoded = benchmark(orjson.dumps, vectors)
        # orjson.dumps() produces bytes rather than text.
        assert isinstance(encoded, bytes)
        assert encoded

    @pytest.mark.parametrize(("num_matches", "dimension"), [(10, 128), (100, 128), (1000, 128)])
    def test_json_dumps_query_response(self, benchmark, num_matches, dimension):
        """Time the standard library encoder on a query-response dict."""
        response = create_query_response(num_matches, dimension)
        encoded = benchmark(json.dumps, response)
        assert isinstance(encoded, str)
        assert encoded

    @pytest.mark.parametrize(("num_matches", "dimension"), [(10, 128), (100, 128), (1000, 128)])
    def test_orjson_dumps_query_response(self, benchmark, num_matches, dimension):
        """Time the orjson encoder on a query-response dict."""
        response = create_query_response(num_matches, dimension)
        encoded = benchmark(orjson.dumps, response)
        assert isinstance(encoded, bytes)
        assert encoded


class TestOrjsonDeserialization:
    """Deserialization benchmarks: ``json.loads()`` against ``orjson.loads()``."""

    @pytest.mark.parametrize(("num_vectors", "dimension"), [(10, 128), (100, 128), (100, 512)])
    def test_json_loads_vectors(self, benchmark, num_vectors, dimension):
        """Time the standard library decoder on an encoded vector list."""
        document = json.dumps(create_vector_payload(num_vectors, dimension))
        decoded = benchmark(json.loads, document)
        assert isinstance(decoded, list)
        assert num_vectors == len(decoded)

    @pytest.mark.parametrize(("num_vectors", "dimension"), [(10, 128), (100, 128), (100, 512)])
    def test_orjson_loads_vectors(self, benchmark, num_vectors, dimension):
        """Time the orjson decoder on a UTF-8 encoded vector list."""
        document = json.dumps(create_vector_payload(num_vectors, dimension))
        raw = document.encode("utf-8")
        decoded = benchmark(orjson.loads, raw)
        assert isinstance(decoded, list)
        assert num_vectors == len(decoded)

    @pytest.mark.parametrize(("num_matches", "dimension"), [(10, 128), (100, 128), (1000, 128)])
    def test_json_loads_query_response(self, benchmark, num_matches, dimension):
        """Time the standard library decoder on an encoded query response."""
        document = json.dumps(create_query_response(num_matches, dimension))
        decoded = benchmark(json.loads, document)
        assert isinstance(decoded, dict)
        assert num_matches == len(decoded["matches"])

    @pytest.mark.parametrize(("num_matches", "dimension"), [(10, 128), (100, 128), (1000, 128)])
    def test_orjson_loads_query_response(self, benchmark, num_matches, dimension):
        """Time the orjson decoder on a UTF-8 encoded query response."""
        document = json.dumps(create_query_response(num_matches, dimension))
        raw = document.encode("utf-8")
        decoded = benchmark(orjson.loads, raw)
        assert isinstance(decoded, dict)
        assert num_matches == len(decoded["matches"])

    @pytest.mark.parametrize(("num_matches", "dimension"), [(10, 128), (100, 128), (1000, 128)])
    def test_orjson_loads_from_string(self, benchmark, num_matches, dimension):
        """Time the orjson decoder on ``str`` input (like a decoded response)."""
        # Deliberately left as text, not bytes, to cover the str input path.
        document = json.dumps(create_query_response(num_matches, dimension))
        decoded = benchmark(orjson.loads, document)
        assert isinstance(decoded, dict)
        assert num_matches == len(decoded["matches"])


class TestRoundTrip:
    """Benchmarks for a full encode-then-decode cycle with each library."""

    @pytest.mark.parametrize(("num_vectors", "dimension"), [(10, 128), (100, 128)])
    def test_json_round_trip(self, benchmark, num_vectors, dimension):
        """Time ``json.dumps()`` followed by ``json.loads()`` on a vector list."""
        vectors = create_vector_payload(num_vectors, dimension)

        def encode_then_decode(data):
            # Both halves of the cycle run inside the timed callable.
            return json.loads(json.dumps(data))

        restored = benchmark(encode_then_decode, vectors)
        assert isinstance(restored, list)
        assert num_vectors == len(restored)

    @pytest.mark.parametrize(("num_vectors", "dimension"), [(10, 128), (100, 128)])
    def test_orjson_round_trip(self, benchmark, num_vectors, dimension):
        """Time ``orjson.dumps()`` followed by ``orjson.loads()`` on a vector list."""
        vectors = create_vector_payload(num_vectors, dimension)

        def encode_then_decode(data):
            # Both halves of the cycle run inside the timed callable.
            return orjson.loads(orjson.dumps(data))

        restored = benchmark(encode_then_decode, vectors)
        assert isinstance(restored, list)
        assert num_vectors == len(restored)
Loading
Loading