Commit

Exponential backoff retry support for handling rate-limited embedding functions (lancedb#614)

Users ingesting data with rate-limited APIs no longer need to manually make
the process sleep to work around rate limits.

Resolves lancedb#579
AyushExel committed Nov 2, 2023
1 parent b502c55 commit 0c66882
Showing 6 changed files with 134 additions and 5 deletions.
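
For orientation, a minimal sketch of how the new retries surface to users. It mirrors the test added at the bottom of this diff: "test-rate-limited" is the mock embedding function registered in conftest.py below, and max_retries is the new field on EmbeddingFunction; a real provider-backed embedding function would be wrapped the same way.

import lancedb
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector

# "test-rate-limited" is the mock registered in conftest.py below; swap in a
# real provider-backed embedding function for production use.
registry = EmbeddingFunctionRegistry.get_instance()
model = registry.get("test-rate-limited").create(max_retries=7)

class Schema(LanceModel):
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()

db = lancedb.connect("/tmp/lancedb")  # path is illustrative
table = db.create_table("docs", schema=Schema, mode="overwrite")

# Back-to-back adds hit the mock's rate limit; with max_retries > 0 the
# embedding call is retried with exponential backoff instead of raising.
table.add([{"text": "hello world"}])
table.add([{"text": "hello again"}])
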
25 changes: 25 additions & 0 deletions python/lancedb/conftest.py
@@ -1,4 +1,6 @@
import os
import time
from typing import Any

import numpy as np
import pytest
@@ -38,3 +40,26 @@ def _compute_one_embedding(self, row):

def ndims(self):
return 10


class RateLimitedAPI:
rate_limit = 0.1 # 1 request per 0.1 second
last_request_time = 0

@staticmethod
def make_request():
current_time = time.time()

if current_time - RateLimitedAPI.last_request_time < RateLimitedAPI.rate_limit:
raise Exception("Rate limit exceeded. Please try again later.")

# Simulate a successful request
RateLimitedAPI.last_request_time = current_time
return "Request successful"


@registry.register("test-rate-limited")
class MockRateLimitedEmbeddingFunction(MockTextEmbeddingFunction):
def generate_embeddings(self, texts):
RateLimitedAPI.make_request()
return [self._compute_one_embedding(row) for row in texts]
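
The mock allows at most one request per 0.1 s window and never sleeps itself, so two back-to-back embedding calls fail unless something retries them. A small sketch of its behaviour (RateLimitedAPI is the class added above):

import time
from lancedb.conftest import RateLimitedAPI

RateLimitedAPI.make_request()      # first call succeeds
try:
    RateLimitedAPI.make_request()  # issued < 0.1 s later, so it raises
except Exception as e:
    print(e)                       # "Rate limit exceeded. Please try again later."
time.sleep(0.1)                    # wait out the window
RateLimitedAPI.make_request()      # succeeds again
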
24 changes: 23 additions & 1 deletion python/lancedb/embeddings/base.py
@@ -6,7 +6,7 @@
import pyarrow as pa
from pydantic import BaseModel, Field, PrivateAttr

from .utils import TEXT
from .utils import TEXT, retry_with_exponential_backoff


class EmbeddingFunction(BaseModel, ABC):
@@ -21,6 +21,9 @@ class EmbeddingFunction(BaseModel, ABC):
3. ndims method which returns the number of dimensions of the vector column
"""

    max_retries: int = (
        7  # Setting this to 0 disables retries. Maybe this should not be enabled by default.
    )
_ndims: int = PrivateAttr()

@classmethod
@@ -44,6 +47,25 @@ def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
"""
pass

def compute_query_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for a given user query with retries
"""
return retry_with_exponential_backoff(
self.compute_query_embeddings, max_retries=self.max_retries
)(
*args,
**kwargs,
)

def compute_source_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for the source column in the database with retries
"""
return retry_with_exponential_backoff(
self.compute_source_embeddings, max_retries=self.max_retries
)(*args, **kwargs)

def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
"""
Sanitize the input to the embedding function.
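
The two *_with_retry methods are thin wrappers: they hand the bound embedding method and the function's own max_retries to retry_with_exponential_backoff (added to utils.py below) and call the result immediately. A minimal sketch of that equivalence, using the mock function from conftest.py:

from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.embeddings.utils import retry_with_exponential_backoff

# Any registered EmbeddingFunction works; the mock from conftest.py is used here.
func = EmbeddingFunctionRegistry.get_instance().get("test-rate-limited").create(max_retries=3)
texts = ["hello world"]

# The new convenience method...
vectors = func.compute_source_embeddings_with_retry(texts)

# ...does the same as wrapping the underlying method by hand:
retried = retry_with_exponential_backoff(
    func.compute_source_embeddings, max_retries=func.max_retries
)
vectors = retried(texts)
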
51 changes: 51 additions & 0 deletions python/lancedb/embeddings/utils.py
@@ -12,8 +12,10 @@
# limitations under the License.

import math
import random
import socket
import sys
import time
import urllib.error
from typing import Callable, List, Union

@@ -162,6 +164,55 @@ def _chunker(arr):
yield from _chunker(arr)


def retry_with_exponential_backoff(
func,
initial_delay: float = 1,
exponential_base: float = 2,
jitter: bool = True,
max_retries: int = 7,
# errors: tuple = (),
):
"""Retry a function with exponential backoff.
Args:
func (function): The function to be retried.
initial_delay (float): Initial delay in seconds (default is 1).
exponential_base (float): The base for exponential backoff (default is 2).
jitter (bool): Whether to add jitter to the delay (default is True).
        max_retries (int): Maximum number of retries (default is 7).
        errors (tuple): Specific exceptions to retry on (currently unused; all
            exceptions are retried).
Returns:
function: The decorated function.
"""

def wrapper(*args, **kwargs):
num_retries = 0
delay = initial_delay

# Loop until a successful response or max_retries is hit or an exception is raised
while True:
try:
return func(*args, **kwargs)

            # Retry on all exceptions for now: there is no reliable way to know
            # which exception types or error messages different embedding APIs use
            # for rate limiting. The error is logged so the user can confirm that
            # the failure really was a rate limit.
except Exception as e:
num_retries += 1

if num_retries > max_retries:
raise Exception(
f"Maximum number of retries ({max_retries}) exceeded."
)

delay *= exponential_base * (1 + jitter * random.random())
LOGGER.info(f"Retrying in {delay:.2f} seconds due to {e}")
time.sleep(delay)

return wrapper


def url_retrieve(url: str):
"""
Parameters
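
For intuition about the schedule above: with the defaults (initial_delay=1, exponential_base=2, jitter=True), each failed attempt multiplies the running delay by 2 * (1 + random()), so the n-th wait lands roughly between 2^n and 4^n seconds. A standalone sketch of just that arithmetic, with no network calls:

import random

delay = 1.0  # initial_delay
for attempt in range(1, 5):
    # exponential_base=2, jitter=True, mirroring the loop in retry_with_exponential_backoff
    delay *= 2 * (1 + random.random())
    print(f"after failure {attempt}: sleep ~{delay:.2f}s (bounds: {2**attempt}s to {4**attempt}s)")
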
4 changes: 2 additions & 2 deletions python/lancedb/query.py
@@ -140,7 +140,7 @@ def _resolve_query(cls, table, query, query_type, vector_column_name):
if not isinstance(query, (list, np.ndarray)):
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings(query)[0]
query = conf.function.compute_query_embeddings_with_retry(query)[0]
else:
msg = f"No embedding function for {vector_column_name}"
raise ValueError(msg)
@@ -151,7 +151,7 @@ def _resolve_query(cls, table, query, query_type, vector_column_name):
else:
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings(query)[0]
query = conf.function.compute_query_embeddings_with_retry(query)[0]
return query, "vector"
else:
return query, "fts"
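
With this change, a text query that is resolved through a configured embedding function gets the same protection as ingestion: the query embedding is computed via compute_query_embeddings_with_retry, so a transient rate-limit error is retried rather than failing the search. A hedged sketch, reusing a table set up as in the test at the bottom of this diff (to_df() is just one way to execute the query):

# `table` is assumed to have an embedding function configured on its vector
# column, as in test_embedding_function_rate_limit below.
results = table.search("hello world").limit(1).to_df()
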
4 changes: 3 additions & 1 deletion python/lancedb/table.py
@@ -86,7 +86,9 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
for vector_column, conf in functions.items():
func = conf.function
if vector_column not in data.column_names:
col_data = func.compute_source_embeddings(data[conf.source_column])
col_data = func.compute_source_embeddings_with_retry(
data[conf.source_column]
)
if schema is not None:
dtype = schema.field(vector_column).type
else:
31 changes: 30 additions & 1 deletion python/tests/test_embeddings.py
@@ -15,13 +15,16 @@
import lance
import numpy as np
import pyarrow as pa
import pytest

from lancedb.conftest import MockTextEmbeddingFunction
import lancedb
from lancedb.conftest import MockRateLimitedEmbeddingFunction, MockTextEmbeddingFunction
from lancedb.embeddings import (
EmbeddingFunctionConfig,
EmbeddingFunctionRegistry,
with_embeddings,
)
from lancedb.pydantic import LanceModel, Vector


def mock_embed_func(input_data):
@@ -83,3 +86,29 @@ def test_embedding_function(tmp_path):
expected = func.compute_query_embeddings("hello world")

assert np.allclose(actual, expected)


def test_embedding_function_rate_limit(tmp_path):
def _get_schema_from_model(model):
class Schema(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()

return Schema

db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
model = registry.get("test-rate-limited").create(max_retries=0)
schema = _get_schema_from_model(model)
table = db.create_table("test", schema=schema, mode="overwrite")
table.add([{"text": "hello world"}])
with pytest.raises(Exception):
table.add([{"text": "hello world"}])
assert len(table) == 1

model = registry.get("test-rate-limited").create()
schema = _get_schema_from_model(model)
table = db.create_table("test", schema=schema, mode="overwrite")
table.add([{"text": "hello world"}])
table.add([{"text": "hello world"}])
assert len(table) == 2
