Handle length safe embedding only if needed (langchain-ai#3723)
Re: langchain-ai#3722

Copy pasting context from the issue:


https://github.com/hwchase17/langchain/blob/1bf1c37c0cccb7c8c73d87ace27cf742f814dbe5/langchain/embeddings/openai.py#L210-L211

This means the length-safe embedding method is always used. The initial
implementation in langchain-ai#991 set `embedding_ctx_length` to -1, so you
had to opt in to the length-safe method. langchain-ai#2330 changed that to
the max context length of OpenAI embeddings v2, which means the length-safe
method is now used at all times.

How about changing that if branch to use the length-safe method only when
needed, i.e. when the text is longer than the max context length?
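The proposed dispatch can be sketched as follows. This is a standalone illustration, not the actual `OpenAIEmbeddings` class: the helper names are stand-ins, and the 8191-token limit is an assumption based on the v2 embedding models.

```python
from typing import List

MAX_CTX = 8191  # assumed max context length of OpenAI embeddings v2

def _length_safe_embed(text: str) -> List[float]:
    # stand-in for the chunk-and-recombine length-safe path
    return [0.0]

def _direct_embed(text: str) -> List[float]:
    # stand-in for the single-call path
    return [1.0]

def embed(text: str, ctx_length: int = MAX_CTX) -> List[float]:
    # Use the length-safe method only when the text exceeds the
    # context limit, instead of unconditionally as before this change.
    if len(text) > ctx_length:
        return _length_safe_embed(text)
    return _direct_embed(text)
```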
ravwojdyla authored and samching committed May 1, 2023
1 parent 34d1eaf commit 59b874f
Showing 1 changed file with 4 additions and 15 deletions.
19 changes: 4 additions & 15 deletions langchain/embeddings/openai.py
```diff
--- a/langchain/embeddings/openai.py
+++ b/langchain/embeddings/openai.py
@@ -207,7 +207,7 @@ def _get_len_safe_embeddings(
     def _embedding_func(self, text: str, *, engine: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint."""
         # handle large input text
-        if self.embedding_ctx_length > 0:
+        if len(text) > self.embedding_ctx_length:
             return self._get_len_safe_embeddings([text], engine=engine)[0]
         else:
             # replace newlines, which can negatively affect performance.
@@ -229,20 +229,9 @@ def embed_documents(
         Returns:
             List of embeddings, one for each text.
         """
-        # handle batches of large input text
-        if self.embedding_ctx_length > 0:
-            return self._get_len_safe_embeddings(texts, engine=self.deployment)
-        else:
-            results = []
-            _chunk_size = chunk_size or self.chunk_size
-            for i in range(0, len(texts), _chunk_size):
-                response = embed_with_retry(
-                    self,
-                    input=texts[i : i + _chunk_size],
-                    engine=self.deployment,
-                )
-                results += [r["embedding"] for r in response["data"]]
-            return results
+        # NOTE: to keep things simple, we assume the list may contain texts longer
+        # than the maximum context and use length-safe embedding function.
+        return self._get_len_safe_embeddings(texts, engine=self.deployment)
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.
```
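After this change, `embed_documents` always routes batches through the length-safe helper rather than branching. A minimal sketch of that shape, using stand-in names rather than the real implementation:

```python
from typing import List

def get_len_safe_embeddings(texts: List[str]) -> List[List[float]]:
    # stand-in for the length-safe helper: one embedding per input text
    return [[float(len(t))] for t in texts]

def embed_documents(texts: List[str]) -> List[List[float]]:
    # A batch may contain texts longer than the max context, so always
    # delegate to the length-safe helper, which handles chunking itself.
    return get_len_safe_embeddings(texts)
```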
