From 0c66882ef41e7f3ba599f46fcecf8b6f844294a5 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 2 Nov 2023 19:20:10 +0530 Subject: [PATCH] Exponential standoff retry support for handling rate limited embedding functions (#614) Users ingesting data through rate-limited APIs no longer need to manually make the process sleep to counter rate limits. resolves #579 --- python/lancedb/conftest.py | 25 +++++++++++++++ python/lancedb/embeddings/base.py | 24 +++++++++++++- python/lancedb/embeddings/utils.py | 51 ++++++++++++++++++++++++++++++ python/lancedb/query.py | 4 +-- python/lancedb/table.py | 4 ++- python/tests/test_embeddings.py | 31 +++++++++++++++++- 6 files changed, 134 insertions(+), 5 deletions(-) diff --git a/python/lancedb/conftest.py b/python/lancedb/conftest.py index a88e967f8..df4907a79 100644 --- a/python/lancedb/conftest.py +++ b/python/lancedb/conftest.py @@ -1,4 +1,6 @@ import os +import time +from typing import Any import numpy as np import pytest @@ -38,3 +40,26 @@ def _compute_one_embedding(self, row): def ndims(self): return 10 + + +class RateLimitedAPI: + rate_limit = 0.1 # 1 request per 0.1 second + last_request_time = 0 + + @staticmethod + def make_request(): + current_time = time.time() + + if current_time - RateLimitedAPI.last_request_time < RateLimitedAPI.rate_limit: + raise Exception("Rate limit exceeded. 
Please try again later.") + + # Simulate a successful request + RateLimitedAPI.last_request_time = current_time + return "Request successful" + + +@registry.register("test-rate-limited") +class MockRateLimitedEmbeddingFunction(MockTextEmbeddingFunction): + def generate_embeddings(self, texts): + RateLimitedAPI.make_request() + return [self._compute_one_embedding(row) for row in texts] diff --git a/python/lancedb/embeddings/base.py b/python/lancedb/embeddings/base.py index a1d1aa056..e3e34608a 100644 --- a/python/lancedb/embeddings/base.py +++ b/python/lancedb/embeddings/base.py @@ -6,7 +6,7 @@ import pyarrow as pa from pydantic import BaseModel, Field, PrivateAttr -from .utils import TEXT +from .utils import TEXT, retry_with_exponential_backoff class EmbeddingFunction(BaseModel, ABC): @@ -21,6 +21,9 @@ class EmbeddingFunction(BaseModel, ABC): 3. ndims method which returns the number of dimensions of the vector column """ + max_retries: int = ( + 7 # Setting 0 disables retries. Maybe this should not be enabled by default, + ) _ndims: int = PrivateAttr() @classmethod @@ -44,6 +47,25 @@ def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]: """ pass + def compute_query_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]: + """ + Compute the embeddings for a given user query with retries + """ + return retry_with_exponential_backoff( + self.compute_query_embeddings, max_retries=self.max_retries + )( + *args, + **kwargs, + ) + + def compute_source_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]: + """ + Compute the embeddings for the source column in the database with retries + """ + return retry_with_exponential_backoff( + self.compute_source_embeddings, max_retries=self.max_retries + )(*args, **kwargs) + def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]: """ Sanitize the input to the embedding function. 
diff --git a/python/lancedb/embeddings/utils.py b/python/lancedb/embeddings/utils.py index e33bf4d30..1308b3580 100644 --- a/python/lancedb/embeddings/utils.py +++ b/python/lancedb/embeddings/utils.py @@ -12,8 +12,10 @@ # limitations under the License. import math +import random import socket import sys +import time import urllib.error from typing import Callable, List, Union @@ -162,6 +164,55 @@ def _chunker(arr): yield from _chunker(arr) +def retry_with_exponential_backoff( + func, + initial_delay: float = 1, + exponential_base: float = 2, + jitter: bool = True, + max_retries: int = 7, + # errors: tuple = (), +): + """Retry a function with exponential backoff. + + Args: + func (function): The function to be retried. + initial_delay (float): Initial delay in seconds (default is 1). + exponential_base (float): The base for exponential backoff (default is 2). + jitter (bool): Whether to add jitter to the delay (default is True). + max_retries (int): Maximum number of retries (default is 7). + errors (tuple): Reserved, not yet supported: the parameter is commented out and every Exception is currently retried. + + Returns: + function: The decorated function. + """ + + def wrapper(*args, **kwargs): + num_retries = 0 + delay = initial_delay + + # Loop until a successful response or max_retries is hit or an exception is raised + while True: + try: + return func(*args, **kwargs) + + # Currently retrying on all exceptions as there is no way to know the format of the error msgs used by different APIs + # We'll log the error and say that it is assumed that if this portion errors out, it's due to rate limit but the user + # should check the error message to be sure + except Exception as e: + num_retries += 1 + + if num_retries > max_retries: + raise Exception( + f"Maximum number of retries ({max_retries}) exceeded." 
+ ) + + delay *= exponential_base * (1 + jitter * random.random()) + LOGGER.info(f"Retrying in {delay:.2f} seconds due to {e}") + time.sleep(delay) + + return wrapper + + def url_retrieve(url: str): """ Parameters diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 7fb31af12..76e437282 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -140,7 +140,7 @@ def _resolve_query(cls, table, query, query_type, vector_column_name): if not isinstance(query, (list, np.ndarray)): conf = table.embedding_functions.get(vector_column_name) if conf is not None: - query = conf.function.compute_query_embeddings(query)[0] + query = conf.function.compute_query_embeddings_with_retry(query)[0] else: msg = f"No embedding function for {vector_column_name}" raise ValueError(msg) @@ -151,7 +151,7 @@ def _resolve_query(cls, table, query, query_type, vector_column_name): else: conf = table.embedding_functions.get(vector_column_name) if conf is not None: - query = conf.function.compute_query_embeddings(query)[0] + query = conf.function.compute_query_embeddings_with_retry(query)[0] return query, "vector" else: return query, "fts" diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 7d9717d90..be92cd0b0 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -86,7 +86,9 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem for vector_column, conf in functions.items(): func = conf.function if vector_column not in data.column_names: - col_data = func.compute_source_embeddings(data[conf.source_column]) + col_data = func.compute_source_embeddings_with_retry( + data[conf.source_column] + ) if schema is not None: dtype = schema.field(vector_column).type else: diff --git a/python/tests/test_embeddings.py b/python/tests/test_embeddings.py index 3eb4839bb..03af14eb7 100644 --- a/python/tests/test_embeddings.py +++ b/python/tests/test_embeddings.py @@ -15,13 +15,16 @@ import lance import numpy as np import 
pyarrow as pa +import pytest -from lancedb.conftest import MockTextEmbeddingFunction +import lancedb +from lancedb.conftest import MockRateLimitedEmbeddingFunction, MockTextEmbeddingFunction from lancedb.embeddings import ( EmbeddingFunctionConfig, EmbeddingFunctionRegistry, with_embeddings, ) +from lancedb.pydantic import LanceModel, Vector def mock_embed_func(input_data): @@ -83,3 +86,29 @@ def test_embedding_function(tmp_path): expected = func.compute_query_embeddings("hello world") assert np.allclose(actual, expected) + + +def test_embedding_function_rate_limit(tmp_path): + def _get_schema_from_model(model): + class Schema(LanceModel): + text: str = model.SourceField() + vector: Vector(model.ndims()) = model.VectorField() + + return Schema + + db = lancedb.connect(tmp_path) + registry = EmbeddingFunctionRegistry.get_instance() + model = registry.get("test-rate-limited").create(max_retries=0) + schema = _get_schema_from_model(model) + table = db.create_table("test", schema=schema, mode="overwrite") + table.add([{"text": "hello world"}]) + with pytest.raises(Exception): + table.add([{"text": "hello world"}]) + assert len(table) == 1 + + model = registry.get("test-rate-limited").create() + schema = _get_schema_from_model(model) + table = db.create_table("test", schema=schema, mode="overwrite") + table.add([{"text": "hello world"}]) + table.add([{"text": "hello world"}]) + assert len(table) == 2