Skip to content

Commit

Permalink
chore(python): update embedding API to use openai 1.6.1 (lancedb#751)
Browse files Browse the repository at this point in the history
The OpenAI Python SDK's API has changed significantly; notably, `openai.Embedding.create` no
longer exists.
openai/openai-python#742

Update the OpenAI embedding function to use the new client API and set a minimum
openai SDK version.
  • Loading branch information
changhiskhan committed Dec 28, 2023
1 parent a0e2524 commit eb91a64
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
9 changes: 7 additions & 2 deletions python/lancedb/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import List, Union

import numpy as np
Expand Down Expand Up @@ -44,6 +45,10 @@ def generate_embeddings(
The texts to embed
"""
# TODO retry, rate limit, token limit
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
return [v.embedding for v in rs.data]

@cached_property
def _openai_client(self):
openai = self.safe_import("openai")
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
return [v["embedding"] for v in rs]
return openai.OpenAI()
2 changes: 1 addition & 1 deletion python/lancedb/embeddings/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def wrapper(*args, **kwargs):

if num_retries > max_retries:
raise Exception(
f"Maximum number of retries ({max_retries}) exceeded."
f"Maximum number of retries ({max_retries}) exceeded.", e
)

delay *= exponential_base * (1 + jitter * random.random())
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "
dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]

[project.scripts]
lancedb = "lancedb.cli.cli:cli"
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_embeddings_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

@pytest.mark.slow
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_sentence_transformer(alias, tmp_path):
def test_basic_text_embeddings(alias, tmp_path):
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get(alias).create(max_retries=0)
Expand Down

0 comments on commit eb91a64

Please sign in to comment.