Skip to content

Commit

Permalink
fix: fix encoding type access in JinaEmbeddings (#13315)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed May 7, 2024
1 parent f94dfdf commit 2c23764
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

API_URL = "https://api.jina.ai/v1/embeddings"


VALID_ENCODING = ["float", "ubinary", "binary"]


Expand Down Expand Up @@ -78,34 +77,39 @@ def class_name(cls) -> str:

def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return self._get_text_embeddings([query], encoding_type=self._encoding_queries)[
0
]
return self._get_embeddings([query], encoding_type=self._encoding_queries)[0]

async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding."""
result = await self._aget_text_embeddings(
result = await self._aget_embeddings(
[query], encoding_type=self._encoding_queries
)
return result[0]

def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return self._get_text_embeddings(
[text], encoding_type=self._encoding_documents
)[0]
return self._get_text_embeddings([text])[0]

async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
result = await self._aget_text_embeddings(
[text], encoding_type=self._encoding_documents
)
result = await self._aget_text_embeddings([text])
return result[0]

def _get_text_embeddings(
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
return self._get_embeddings(texts=texts, encoding_type=self._encoding_documents)

async def _aget_text_embeddings(
self,
texts: List[str],
) -> List[List[float]]:
return await self._aget_embeddings(
texts=texts, encoding_type=self._encoding_documents
)

def _get_embeddings(
self, texts: List[str], encoding_type: str = "float"
) -> List[List[float]]:
"""Get text embeddings."""
"""Get embeddings."""
# Call Jina AI Embedding API
resp = self._session.post( # type: ignore
API_URL,
Expand Down Expand Up @@ -134,7 +138,7 @@ def _get_text_embeddings(
]
return [result["embedding"] for result in sorted_embeddings]

async def _aget_text_embeddings(
async def _aget_embeddings(
self, texts: List[str], encoding_type: str = "float"
) -> List[List[float]]:
"""Asynchronously get text embeddings."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-jinaai"
readme = "README.md"
version = "0.1.4"
version = "0.1.5"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down

0 comments on commit 2c23764

Please sign in to comment.