Skip to content

Commit 29e989b

Browse files
committed
python api for e5v
1 parent 0717956 commit 29e989b

File tree

5 files changed

+237
-0
lines changed

5 files changed

+237
-0
lines changed

python/sparknlp/annotator/embeddings/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,4 @@
4141
from sparknlp.annotator.embeddings.snowflake_embeddings import *
4242
from sparknlp.annotator.embeddings.nomic_embeddings import *
4343
from sparknlp.annotator.embeddings.auto_gguf_embeddings import *
44+
from sparknlp.annotator.embeddings.e5v_embeddings import *
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright 2017-2024 John Snow Labs
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from sparknlp.common import *
16+
17+
class E5VEmbeddings(AnnotatorModel,
18+
HasBatchedAnnotateImage,
19+
HasImageFeatureProperties,
20+
HasEngine,
21+
HasRescaleFactor):
22+
"""Universal multimodal embeddings using the E5-V model (see https://huggingface.co/royokong/e5-v).
23+
24+
E5-V bridges the modality gap between different input types (text, image) and demonstrates strong performance in multimodal embeddings, even without fine-tuning. It also supports a single-modality training approach, where the model is trained exclusively on text pairs, often yielding better performance than multimodal training.
25+
26+
Pretrained models can be loaded with :meth:`.pretrained` of the companion object:
27+
28+
>>> e5vEmbeddings = E5VEmbeddings.pretrained() \
29+
... .setInputCols(["image_assembler"]) \
30+
... .setOutputCol("e5v")
31+
32+
The default model is ``"e5v_int4"``, if no name is provided.
33+
34+
For available pretrained models please see the `Models Hub <https://sparknlp.org/models?task=Question+Answering>`__.
35+
36+
====================== ======================
37+
Input Annotation types Output Annotation type
38+
====================== ======================
39+
``IMAGE`` ``SENTENCE_EMBEDDINGS``
40+
====================== ======================
41+
42+
Examples
43+
--------
44+
Image + Text Embedding:
45+
>>> import sparknlp
46+
>>> from sparknlp.base import *
47+
>>> from sparknlp.annotator import *
48+
>>> from pyspark.ml import Pipeline
49+
>>> image_df = spark.read.format("image").option("dropInvalid", value = True).load(imageFolder)
50+
>>> imagePrompt = "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
51+
>>> test_df = image_df.withColumn("text", lit(imagePrompt))
52+
>>> imageAssembler = ImageAssembler() \
53+
... .setInputCol("image") \
54+
... .setOutputCol("image_assembler")
55+
>>> e5vEmbeddings = E5VEmbeddings.pretrained() \
56+
... .setInputCols(["image_assembler"]) \
57+
... .setOutputCol("e5v")
58+
>>> pipeline = Pipeline().setStages([
59+
... imageAssembler,
60+
... e5vEmbeddings
61+
... ])
62+
>>> result = pipeline.fit(test_df).transform(test_df)
63+
>>> result.select("e5v.embeddings").show(truncate = False)
64+
65+
Text-Only Embedding:
66+
>>> from sparknlp.util import EmbeddingsDataFrameUtils
67+
>>> textPrompt = "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
68+
>>> textDesc = "A cat sitting in a box."
69+
>>> nullImageDF = spark.createDataFrame(spark.sparkContext.parallelize([EmbeddingsDataFrameUtils.emptyImageRow]), EmbeddingsDataFrameUtils.imageSchema)
70+
>>> textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc)))
71+
>>> e5vEmbeddings = E5VEmbeddings.pretrained() \
72+
... .setInputCols(["image"]) \
73+
... .setOutputCol("e5v")
74+
>>> result = e5vEmbeddings.transform(textDF)
75+
>>> result.select("e5v.embeddings").show(truncate = False)
76+
"""
77+
78+
name = "E5VEmbeddings"
79+
80+
inputAnnotatorTypes = [AnnotatorType.IMAGE]
81+
outputAnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS
82+
83+
@keyword_only
84+
def __init__(self, classname="com.johnsnowlabs.nlp.annotators.embeddings.E5VEmbeddings", java_model=None):
85+
"""Initializes the E5VEmbeddings annotator.
86+
87+
Parameters
88+
----------
89+
classname : str, optional
90+
The Java class name of the annotator, by default "com.johnsnowlabs.nlp.annotators.embeddings.E5VEmbeddings"
91+
java_model : Optional[java.lang.Object], optional
92+
A pre-initialized Java model, by default None
93+
"""
94+
super(E5VEmbeddings, self).__init__(classname=classname, java_model=java_model)
95+
self._setDefault()
96+
97+
@staticmethod
98+
def loadSavedModel(folder, spark_session, use_openvino=False):
99+
"""Loads a locally saved model.
100+
101+
Parameters
102+
----------
103+
folder : str
104+
Folder of the saved model
105+
spark_session : pyspark.sql.SparkSession
106+
The current SparkSession
107+
use_openvino : bool, optional
108+
Whether to use OpenVINO engine, by default False
109+
110+
Returns
111+
-------
112+
E5VEmbeddings
113+
The restored model
114+
"""
115+
from sparknlp.internal import _E5VEmbeddingsLoader
116+
jModel = _E5VEmbeddingsLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj
117+
return E5VEmbeddings(java_model=jModel)
118+
119+
@staticmethod
120+
def pretrained(name="e5v_int4", lang="en", remote_loc=None):
121+
"""Downloads and loads a pretrained model.
122+
123+
Parameters
124+
----------
125+
name : str, optional
126+
Name of the pretrained model, by default "e5v_int4"
127+
lang : str, optional
128+
Language of the pretrained model, by default "en"
129+
remote_loc : str, optional
130+
Optional remote address of the resource, by default None. Will use Spark NLPs repositories otherwise.
131+
132+
Returns
133+
-------
134+
E5VEmbeddings
135+
The restored model
136+
"""
137+
from sparknlp.pretrained import ResourceDownloader
138+
return ResourceDownloader.downloadModel(E5VEmbeddings, name, lang, remote_loc)

python/sparknlp/internal/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,3 +1165,11 @@ def __init__(self, path, jspark, use_openvino=False):
11651165
jspark,
11661166
use_openvino,
11671167
)
1168+
class _E5VEmbeddingsLoader(ExtendedJavaWrapper):
1169+
def __init__(self, path, jspark, use_openvino=False):
1170+
super(_E5VEmbeddingsLoader, self).__init__(
1171+
"com.johnsnowlabs.nlp.embeddings.E5VEmbeddings.loadSavedModel",
1172+
path,
1173+
jspark,
1174+
use_openvino
1175+
)

python/sparknlp/util.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616

1717
import sparknlp.internal as _internal
18+
import numpy as np
19+
from pyspark.sql import Row
20+
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, BinaryType
1821

1922

2023
def get_config_path():
@@ -33,3 +36,26 @@ def exportConllFiles(*args):
3336
_internal._CoNLLGeneratorExportFromTargetAndPipeline(*args).apply()
3437
else:
3538
raise NotImplementedError(f"No exportConllFiles alternative takes {num_args} parameters")
39+
40+
41+
class EmbeddingsDataFrameUtils:
42+
"""
43+
Utility for creating DataFrames compatible with multimodal embedding models (e.g., E5VEmbeddings) for text-only scenarios.
44+
Provides:
45+
- imageSchema: the expected schema for Spark image DataFrames
46+
- emptyImageRow: a dummy image row for text-only embedding
47+
"""
48+
imageSchema = StructType([
49+
StructField(
50+
"image",
51+
StructType([
52+
StructField("origin", StringType(), True),
53+
StructField("height", IntegerType(), True),
54+
StructField("width", IntegerType(), True),
55+
StructField("nChannels", IntegerType(), True),
56+
StructField("mode", IntegerType(), True),
57+
StructField("data", BinaryType(), True),
58+
]),
59+
)
60+
])
61+
emptyImageRow = Row(Row("", 0, 0, 0, 0, bytes()))
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2017-2024 John Snow Labs
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
import unittest
16+
import pytest
17+
18+
from sparknlp.annotator import *
19+
from sparknlp.base import *
20+
from pyspark.ml import Pipeline
21+
from pyspark.sql.functions import lit
22+
from test.util import SparkContextForTest
23+
24+
@pytest.mark.slow
25+
class E5VEmbeddingsTestSpec(unittest.TestCase):
26+
def setUp(self):
27+
self.spark = SparkContextForTest.spark
28+
self.images_path = "file://"+os.getcwd() + "/../src/test/resources/image/"
29+
30+
def test_image_and_text_embedding(self):
31+
# Simulate image+text embedding (requires actual image files for full test)
32+
image_folder = os.environ.get("E5V_IMAGE_TEST_FOLDER", self.images_path)
33+
imagePrompt = "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
34+
image_df = self.spark.read.format("image").option("dropInvalid", True).load(image_folder)
35+
test_df = image_df.withColumn("text", lit(imagePrompt))
36+
37+
imageAssembler = ImageAssembler() \
38+
.setInputCol("image") \
39+
.setOutputCol("image_assembler")
40+
e5v = E5VEmbeddings.pretrained() \
41+
.setInputCols(["image_assembler"]) \
42+
.setOutputCol("e5v")
43+
pipeline = Pipeline().setStages([imageAssembler, e5v])
44+
results = pipeline.fit(test_df).transform(test_df)
45+
results.select("e5v.embeddings").show(truncate=True)
46+
47+
def test_text_only_embedding(self):
48+
# Simulate text-only embedding using emptyImageRow and imageSchema
49+
from sparknlp.util import EmbeddingsDataFrameUtils
50+
textPrompt = "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
51+
textDesc = "A cat sitting in a box."
52+
nullImageDF = self.spark.createDataFrame(
53+
self.spark.sparkContext.parallelize([EmbeddingsDataFrameUtils.emptyImageRow]),
54+
EmbeddingsDataFrameUtils.imageSchema)
55+
textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc)))
56+
imageAssembler = ImageAssembler() \
57+
.setInputCol("image") \
58+
.setOutputCol("image_assembler")
59+
e5v = E5VEmbeddings.pretrained() \
60+
.setInputCols(["image_assembler"]) \
61+
.setOutputCol("e5v")
62+
pipeline = Pipeline().setStages([imageAssembler, e5v])
63+
results = pipeline.fit(textDF).transform(textDF)
64+
results.select("e5v.embeddings").show(truncate=True)

0 commit comments

Comments
 (0)