Skip to content

Commit

Permalink
feat: add Jina Embeddings MultiModal (#13861)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored Jun 3, 2024
1 parent 9c9ca30 commit 1386625
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 87 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
"""Jina embeddings file."""

from typing import Any, List, Optional

from urllib.parse import urlparse
from os.path import exists
import base64
import requests
import numpy as np
from llama_index.core.base.embeddings.base import (
DEFAULT_EMBED_BATCH_SIZE,
BaseEmbedding,
)

from llama_index.core.base.embeddings.base import DEFAULT_EMBED_BATCH_SIZE
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.base.llms.generic_utils import get_from_param_or_env
from llama_index.core.embeddings import MultiModalEmbedding
from llama_index.core.schema import ImageType

MAX_BATCH_SIZE = 2048

Expand All @@ -19,101 +21,26 @@
VALID_ENCODING = ["float", "ubinary", "binary"]


class JinaEmbedding(BaseEmbedding):
"""JinaAI class for embeddings.

Args:
model (str): Model for embedding.
Defaults to `jina-embeddings-v2-base-en`
"""

api_key: str = Field(default=None, description="The JinaAI API key.")
model: str = Field(
default="jina-embeddings-v2-base-en",
description="The model to use when calling Jina AI API",
)

_session: Any = PrivateAttr()
_encoding_queries: str = PrivateAttr()
_encoding_documents: str = PrivateAttr()

class _JinaAPICaller:
def __init__(
self,
model: str = "jina-embeddings-v2-base-en",
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
api_key: Optional[str] = None,
callback_manager: Optional[CallbackManager] = None,
encoding_queries: Optional[str] = None,
encoding_documents: Optional[str] = None,
**kwargs: Any,
) -> None:
super().__init__(
embed_batch_size=embed_batch_size,
callback_manager=callback_manager,
model=model,
api_key=api_key,
**kwargs,
)
self._encoding_queries = encoding_queries or "float"
self._encoding_documents = encoding_documents or "float"

assert (
self._encoding_documents in VALID_ENCODING
), f"Encoding Documents parameter {self._encoding_documents} not supported. Please choose one of {VALID_ENCODING}"
assert (
self._encoding_queries in VALID_ENCODING
), f"Encoding Queries parameter {self._encoding_documents} not supported. Please choose one of {VALID_ENCODING}"

self.api_key = get_from_param_or_env("api_key", api_key, "JINAAI_API_KEY", "")
self.model = model
self._session = requests.Session()
self._session.headers.update(
{"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"}
)

@classmethod
def class_name(cls) -> str:
return "JinaAIEmbedding"

def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return self._get_embeddings([query], encoding_type=self._encoding_queries)[0]

async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding."""
result = await self._aget_embeddings(
[query], encoding_type=self._encoding_queries
)
return result[0]

def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return self._get_text_embeddings([text])[0]

async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
result = await self._aget_text_embeddings([text])
return result[0]

def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
return self._get_embeddings(texts=texts, encoding_type=self._encoding_documents)

async def _aget_text_embeddings(
self,
texts: List[str],
) -> List[List[float]]:
return await self._aget_embeddings(
texts=texts, encoding_type=self._encoding_documents
)

def _get_embeddings(
self, texts: List[str], encoding_type: str = "float"
) -> List[List[float]]:
def get_embeddings(self, input, encoding_type: str = "float") -> List[List[float]]:
"""Get embeddings."""
# Call Jina AI Embedding API
resp = self._session.post( # type: ignore
API_URL,
json={"input": texts, "model": self.model, "encoding_type": encoding_type},
json={"input": input, "model": self.model, "encoding_type": encoding_type},
).json()
if "data" not in resp:
raise RuntimeError(resp["detail"])
Expand All @@ -138,8 +65,8 @@ def _get_embeddings(
]
return [result["embedding"] for result in sorted_embeddings]

async def _aget_embeddings(
self, texts: List[str], encoding_type: str = "float"
async def aget_embeddings(
self, input, encoding_type: str = "float"
) -> List[List[float]]:
"""Asynchronously get text embeddings."""
import aiohttp
Expand All @@ -152,7 +79,7 @@ async def _aget_embeddings(
async with session.post(
f"{API_URL}",
json={
"input": texts,
"input": input,
"model": self.model,
"encoding_type": encoding_type,
},
Expand Down Expand Up @@ -181,3 +108,139 @@ async def _aget_embeddings(
for result in sorted_embeddings
]
return [result["embedding"] for result in sorted_embeddings]


def is_local(url):
url_parsed = urlparse(url)
if url_parsed.scheme in ("file", ""): # Possibly a local file
return exists(url_parsed.path)
return False


def get_bytes_str(file_path):
with open(file_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")


class JinaEmbedding(MultiModalEmbedding):
"""
JinaAI class for embeddings.

Args:
model (str): Model for embedding.
Defaults to `jina-embeddings-v2-base-en`
"""

api_key: str = Field(default=None, description="The JinaAI API key.")
model: str = Field(
default="jina-embeddings-v2-base-en",
description="The model to use when calling Jina AI API",
)

_encoding_queries: str = PrivateAttr()
_encoding_documents: str = PrivateAttr()
_api: Any = PrivateAttr()

def __init__(
self,
model: str = "jina-embeddings-v2-base-en",
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
api_key: Optional[str] = None,
callback_manager: Optional[CallbackManager] = None,
encoding_queries: Optional[str] = None,
encoding_documents: Optional[str] = None,
**kwargs: Any,
) -> None:
super().__init__(
embed_batch_size=embed_batch_size,
callback_manager=callback_manager,
model=model,
api_key=api_key,
**kwargs,
)
self._encoding_queries = encoding_queries or "float"
self._encoding_documents = encoding_documents or "float"

assert (
self._encoding_documents in VALID_ENCODING
), f"Encoding Documents parameter {self._encoding_documents} not supported. Please choose one of {VALID_ENCODING}"
assert (
self._encoding_queries in VALID_ENCODING
), f"Encoding Queries parameter {self._encoding_documents} not supported. Please choose one of {VALID_ENCODING}"

self._api = _JinaAPICaller(model=model, api_key=api_key)

@classmethod
def class_name(cls) -> str:
return "JinaAIEmbedding"

def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return self._api.get_embeddings(
input=[query], encoding_type=self._encoding_queries
)[0]

async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding."""
result = await self._api.aget_embeddings(
input=[query], encoding_type=self._encoding_queries
)
return result[0]

def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return self._get_text_embeddings([text])[0]

async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
result = await self._aget_text_embeddings([text])
return result[0]

def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
return self._api.get_embeddings(
input=texts, encoding_type=self._encoding_documents
)

async def _aget_text_embeddings(
self,
texts: List[str],
) -> List[List[float]]:
return await self._api.aget_embeddings(
input=texts, encoding_type=self._encoding_documents
)

def _get_image_embedding(self, img_file_path: ImageType) -> List[float]:
if is_local(img_file_path):
input = [{"bytes": get_bytes_str(img_file_path)}]
else:
input = [{"url": img_file_path}]
return self._api.get_embeddings(input=input)[0]

async def _aget_image_embedding(self, img_file_path: ImageType) -> List[float]:
if is_local(img_file_path):
input = [{"bytes": get_bytes_str(img_file_path)}]
else:
input = [{"url": img_file_path}]
return await self._api.aget_embeddings(input=input)[0]

def _get_image_embeddings(
self, img_file_paths: List[ImageType]
) -> List[List[float]]:
input = []
for img_file_path in img_file_paths:
if is_local(img_file_path):
input.append({"bytes": get_bytes_str(img_file_path)})
else:
input.append({"url": img_file_path})
return self._api.get_embeddings(input=input)

async def _aget_image_embeddings(
self, img_file_paths: List[ImageType]
) -> List[List[float]]:
input = []
for img_file_path in img_file_paths:
if is_local(img_file_path):
input.append({"bytes": get_bytes_str(img_file_path)})
else:
input.append({"url": img_file_path})
return await self._api.aget_embeddings(input=input)
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-jinaai"
readme = "README.md"
version = "0.1.5"
version = "0.2.0"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.embeddings import MultiModalEmbedding
from llama_index.embeddings.jinaai import JinaEmbedding


def test_embedding_class():
emb = JinaEmbedding()
assert isinstance(emb, BaseEmbedding)
assert isinstance(emb, MultiModalEmbedding)

0 comments on commit 1386625

Please sign in to comment.