diff --git a/docs/en/transformer_entries/E5VEmbeddings.md b/docs/en/transformer_entries/E5VEmbeddings.md new file mode 100644 index 00000000000000..68ff482cd2f900 --- /dev/null +++ b/docs/en/transformer_entries/E5VEmbeddings.md @@ -0,0 +1,133 @@ +{%- capture title -%} +E5VEmbeddings +{%- endcapture -%} + +{%- capture description -%} +Universal multimodal embeddings using E5-V. + +E5-V is a multimodal embedding model that bridges the modality gap between text and images, enabling strong performance in cross-modal retrieval, classification, clustering, and more. It supports both image+text and text-only embedding scenarios, and is fine-tuned from lmms-lab/llama3-llava-next-8b. The default model is `"e5v_int4"`. + +Note that this annotator is only supported for Spark Versions 3.4 and up. + +Pretrained models can be loaded with `pretrained` of the companion object: + +```scala +val embeddings = E5VEmbeddings.pretrained() + .setInputCols("image_assembler") + .setOutputCol("e5v") +``` + +For available pretrained models please see the +[Models Hub](https://sparknlp.org/models?q=E5V). + +For extended examples of usage, see +[E5VEmbeddingsTestSpec](https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddingsTestSpec.scala). + +**Sources** : + +- [E5-V: Universal Embeddings with Multimodal Large Language Models (arXiv)](https://arxiv.org/abs/2407.12580) +- [Hugging Face Model Card](https://huggingface.co/royokong/e5-v) +- [E5-V Github Repository](https://github.com/kongds/E5-V) +{%- endcapture -%} + +{%- capture input_anno -%} +IMAGE +{%- endcapture -%} + +{%- capture output_anno -%} +SENTENCE_EMBEDDINGS +{%- endcapture -%} + +{%- capture python_example -%} +# Image + Text Embedding +import sparknlp +from sparknlp.base import * +from sparknlp.annotator import * +from pyspark.ml import Pipeline +from pyspark.sql.functions import lit + +image_df = spark.read.format("image").option("dropInvalid", True).load(imageFolder) +imagePrompt = "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" +test_df = image_df.withColumn("text", lit(imagePrompt)) +imageAssembler = ImageAssembler() \ + .setInputCol("image") \ + .setOutputCol("image_assembler") +e5vEmbeddings = E5VEmbeddings.pretrained() \ + .setInputCols(["image_assembler"]) \ + .setOutputCol("e5v") +pipeline = Pipeline().setStages([ + imageAssembler, + e5vEmbeddings +]) +result = pipeline.fit(test_df).transform(test_df) +result.select("e5v.embeddings").show(truncate=False) + +# Text-Only Embedding +from sparknlp.util import EmbeddingsDataFrameUtils +textPrompt = "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" +textDesc = "A cat sitting in a box." 
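+# Text-only embedding: the annotator still reads the IMAGE input column, so build a
+# placeholder (empty) image row with the expected schema using EmbeddingsDataFrameUtils.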
+nullImageDF = spark.createDataFrame( + spark.sparkContext.parallelize([EmbeddingsDataFrameUtils.emptyImageRow]), + EmbeddingsDataFrameUtils.imageSchema) +textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc))) +e5vEmbeddings = E5VEmbeddings.pretrained() \ + .setInputCols(["image"]) \ + .setOutputCol("e5v") +result = e5vEmbeddings.transform(textDF) +result.select("e5v.embeddings").show(truncate=False) +{%- endcapture -%} + +{%- capture scala_example -%} +// Image + Text Embedding +import org.apache.spark.sql.functions.lit +import com.johnsnowlabs.nlp.base.ImageAssembler +import com.johnsnowlabs.nlp.embeddings.E5VEmbeddings +import org.apache.spark.ml.Pipeline + +val imageDF = spark.read.format("image").option("dropInvalid", value = true).load(imageFolder) +val imagePrompt = "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" +val testDF = imageDF.withColumn("text", lit(imagePrompt)) +val imageAssembler = new ImageAssembler().setInputCol("image").setOutputCol("image_assembler") +val e5vEmbeddings = E5VEmbeddings.pretrained() + .setInputCols("image_assembler") + .setOutputCol("e5v") +val pipeline = new Pipeline().setStages(Array(imageAssembler, e5vEmbeddings)) +val result = pipeline.fit(testDF).transform(testDF) +result.select("e5v.embeddings").show(truncate = false) + +// Text-Only Embedding +import com.johnsnowlabs.nlp.util.EmbeddingsDataFrameUtils.{emptyImageRow, imageSchema} +val textPrompt = "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" +val textDesc = "A cat sitting in a box." +val nullImageDF = spark.createDataFrame(spark.sparkContext.parallelize(Seq(emptyImageRow)), imageSchema) +val textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc))) +val e5vEmbeddings = E5VEmbeddings.pretrained() + .setInputCols("image") + .setOutputCol("e5v") +val result2 = e5vEmbeddings.transform(textDF) +result2.select("e5v.embeddings").show(truncate = false) +{%- endcapture -%} + +{%- capture api_link -%} +[E5VEmbeddings](/api/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings) +{%- endcapture -%} + +{%- capture python_api_link -%} +[E5VEmbeddings](/api/python/reference/autosummary/sparknlp/annotator/cv/e5v_embeddings/index.html#sparknlp.annotator.cv.e5v_embeddings.E5VEmbeddings) +{%- endcapture -%} + +{%- capture source_link -%} +[E5VEmbeddings](https://github.com/JohnSnowLabs/spark-nlp/tree/master/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings.scala) +{%- endcapture -%} + +{% include templates/anno_template.md + title=title + description=description + input_anno=input_anno + output_anno=output_anno + python_example=python_example + scala_example=scala_example + api_link=api_link + python_api_link=python_api_link + source_link=source_link +%} \ No newline at end of file diff --git a/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_E5VEmbeddings.ipynb b/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_E5VEmbeddings.ipynb new file mode 100644 index 00000000000000..a0757f51a05f0e --- /dev/null +++ b/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_E5VEmbeddings.ipynb @@ -0,0 +1,1530 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c9c512f5", + "metadata": {}, + "source": [ + "\n", + "\n", + 
"[](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_E5V.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "860a240a", + "metadata": {}, + "source": [ + "# Import OpenVINO E5V models from HuggingFace 🤗 into Spark NLP 🚀\n", + "\n", + "This notebook provides a detailed walkthrough on optimizing and importing E5V models from HuggingFace for use in Spark NLP, with [Intel OpenVINO toolkit](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/overview.html). The focus is on converting the model to the OpenVINO format and applying precision optimizations (INT8 and INT4), to enhance the performance and efficiency on CPU platforms using [Optimum Intel](https://huggingface.co/docs/optimum/main/en/intel/inference).\n", + "\n", + "Let's keep in mind a few things before we start 😊\n", + "\n", + "- OpenVINO support was introduced in `Spark NLP 5.4.0`, enabling high performance CPU inference for models. So please make sure you have upgraded to the latest Spark NLP release.\n", + "- Model quantization is a computationally expensive process, so it is recommended to use a runtime with more than 32GB memory for exporting the quantized model from HuggingFace.\n", + "- You can import E5V models via `E5V`. These models are usually under `Text Generation` category and have `E5V` in their labels.\n", + "- Reference: [E5V](https://huggingface.co/docs/transformers/model_doc/llama#transformers.E5V)\n", + "- Some [example models](https://huggingface.co/models?search=E5V)" + ] + }, + { + "cell_type": "markdown", + "id": "100a6911", + "metadata": {}, + "source": [ + "## 1. Export and Save the HuggingFace model\n", + "\n", + "- Let's install `transformers` and `openvino` packages with other dependencies. You don't need `openvino` to be installed for Spark NLP, however, we need it to load and save models from HuggingFace.\n", + "- We lock `transformers` on version `4.41.2`. This doesn't mean it won't work with the future release, but we wanted you to know which versions have been tested successfully." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "902635c5", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "529ad224", + "metadata": {}, + "outputs": [], + "source": [ + "# # Install OpenVINO and NNCF for model optimization\n", + "import platform\n", + "\n", + "%pip install -q \"einops\" \"torch>2.1\" \"torchvision\" \"matplotlib>=3.4\" \"timm>=0.9.8\" \"transformers==4.41.2\" \"pillow\" \"gradio>=4.19\" --extra-index-url https://download.pytorch.org/whl/cpu\n", + "%pip install -q -U --pre \"openvino>=2025.0\" \"openvino-tokenizers>=2025.0\" \"openvino-genai>=2025.0\" --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly\n", + "%pip install -q \"accelerate\" \"nncf>=2.14.0\" \"git+https://github.com/huggingface/optimum-intel.git\" --extra-index-url https://download.pytorch.org/whl/cpu\n", + "\n", + "if platform.system() == \"Darwin\":\n", + " %pip install -q \"numpy<2.0.0\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3997e780", + "metadata": {}, + "outputs": [], + "source": [ + "!wget https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg -O dog.jpg" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c1623528", + "metadata": {}, + "outputs": [], + "source": [ + "model_id = \"royokong/e5-v\"\n", + "output_dir = f\"./models/int4/{model_id}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "46678a0b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", + "Loading checkpoint shards: 100%|██████████| 4/4 [01:20<00:00, 20.18s/it]\n" + ] + }, + { + "data": { + "text/plain": [ + "111" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration\n", + "import torch\n", + "import gc\n", + "\n", + "processor = LlavaNextProcessor.from_pretrained(model_id)\n", + "image_encoder_model, input_embedding_model, language_model = None, None, None\n", + "\n", + "\n", + "class ImageEncoder(torch.nn.Module):\n", + " def __init__(self, config, vision_tower, multi_modal_projector):\n", + " super().__init__()\n", + " self.config = config\n", + " self.vision_tower = vision_tower\n", + " self.multi_modal_projector = multi_modal_projector\n", + "\n", + " def forward(self, pixel_values):\n", + " batch_size, num_patches, num_channels, height, width = pixel_values.shape\n", + " reshaped_pixel_values = pixel_values.view(\n", + " batch_size * num_patches, num_channels, height, width\n", + " )\n", + " image_features = self.vision_tower(\n", + " reshaped_pixel_values, output_hidden_states=True\n", + " )\n", + " selected_image_feature = image_features.hidden_states[\n", + " self.config.vision_feature_layer\n", + " ]\n", + " if self.config.vision_feature_select_strategy == \"default\":\n", + " selected_image_feature = selected_image_feature[:, 1:]\n", + " elif self.config.vision_feature_select_strategy == \"full\":\n", + " selected_image_feature = selected_image_feature\n", + " image_features = self.multi_modal_projector(selected_image_feature)\n", + " return image_features\n", + "\n", + "\n", + "model = LlavaNextForConditionalGeneration.from_pretrained(\n", + " model_id, low_cpu_mem_usage=True\n", + ")\n", + "model.config.save_pretrained(output_dir)\n", + "image_encoder_model = ImageEncoder(\n", + " model.config, model.vision_tower, model.multi_modal_projector\n", + ")\n", + "input_embedding_model = input_embedding_model = model.get_input_embeddings()\n", + "language_model = model.language_model\n", + "del model\n", + "gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1908bc09", + "metadata": {}, + "outputs": [], + "source": [ + "import openvino as ov\n", + "from pathlib import Path\n", + "\n", + "core = ov.Core()\n", + "device = \"CPU\"\n", + "# Load the model and convert it to OpenVINO format\n", + "output_dir = f\"./models/int4/{model_id}\"\n", + "output_dir = Path(output_dir)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2341d4b8", + "metadata": {}, + "outputs": [], + "source": [ + "IMAGE_ENCODER_PATH = output_dir / \"openvino_vision_embeddings_model.xml\"\n", + "LANGUAGE_MODEL_PATH = output_dir / \"openvino_language_model.xml\"\n", + "INPUT_EMBEDDING_PATH = output_dir / \"openvino_text_embeddings_model.xml\"\n", + "\n", + "IMAGE_PACKER_PATH = output_dir / \"openvino_image_packer.xml\"\n", + "MULTIMODAL_MERGER_PATH = output_dir / \"openvino_multimodal_merger.xml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6a0e77cd", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/transformers/modeling_utils.py:4481: 
FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead\n", + " warnings.warn(\n", + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/transformers/models/clip/modeling_clip.py:276: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", + " if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n", + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/transformers/models/clip/modeling_clip.py:316: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", + " if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):\n" + ] + }, + { + "data": { + "text/plain": [ + "7397" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import openvino as ov\n", + "import gc\n", + "\n", + "\n", + "def cleanup_torchscript_cache():\n", + " \"\"\"\n", + " Helper for removing cached model representation\n", + " \"\"\"\n", + " torch._C._jit_clear_class_registry()\n", + " torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()\n", + " torch.jit._state._clear_class_state()\n", + "\n", + "\n", + "if not IMAGE_ENCODER_PATH.exists():\n", + " ov_image_encoder = ov.convert_model(\n", + " image_encoder_model, example_input=torch.zeros((1, 5, 3, 336, 336))\n", + " )\n", + " ov.save_model(ov_image_encoder, IMAGE_ENCODER_PATH)\n", + " del ov_image_encoder\n", + " cleanup_torchscript_cache()\n", + "\n", + "del image_encoder_model\n", + "gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0147d547", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "117" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llm_input = None\n", + "\n", + "llm_input = input_embedding_model(torch.ones((2, 2), dtype=torch.int64))\n", + "\n", + "if not INPUT_EMBEDDING_PATH.exists():\n", + " ov_input_embeddings_model = ov.convert_model(\n", + " input_embedding_model, example_input=torch.ones((2, 2), dtype=torch.int64)\n", + " )\n", + " ov.save_model(ov_input_embeddings_model, INPUT_EMBEDDING_PATH)\n", + " del ov_input_embeddings_model\n", + " cleanup_torchscript_cache()\n", + "\n", + "del input_embedding_model\n", + "gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "18b0be05", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/openvino/runtime/__init__.py:10: DeprecationWarning: The `openvino.runtime` module is deprecated and will be removed in the 2026.0 release. 
Please replace `openvino.runtime` with `openvino`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from typing import Optional, Tuple, List\n", + "from openvino.runtime import opset13\n", + "import numpy as np\n", + "\n", + "\n", + "def model_has_state(ov_model: ov.Model):\n", + " return len(ov_model.get_sinks()) > 0\n", + "\n", + "\n", + "def model_has_input_output_name(ov_model: ov.Model, name: str):\n", + " \"\"\"\n", + " Helper function for checking that model has specified input or output name\n", + "\n", + " Parameters:\n", + " ov_model (ov.Model):\n", + " name (str):\n", + " name of input or output\n", + "\n", + " Returns:\n", + " True if input or output with requested name exists else False\n", + " \"\"\"\n", + " return name in sum(\n", + " [list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], []\n", + " )\n", + "\n", + "\n", + "def fuse_cache_reorder(\n", + " ov_model: ov.Model,\n", + " not_kv_inputs: List[str],\n", + " key_value_input_names: List[str],\n", + " gather_dim: int,\n", + "):\n", + " \"\"\"\n", + " Fuses reored_cache during generate cycle into ov.Model. Used with stateful models, because we can not modify model state directly.\n", + "\n", + " Adds a new beam_idx parameter and Gather op per each kv-cache input in a given model.\n", + " Should be run before make_stateful. Implements optimumum's _reorder_cache\n", + " inside the model in the beginning of each iteration.\n", + " Gather works along given gather_dim dimension that may vary from model to model.\n", + " KV-cache inputs are identified based on names in key_value_input_names.\n", + " Append the new beam_idx parameter to not_kv_inputs.\n", + "\n", + " Parameters:\n", + " ov_model (`ov.Model`):\n", + " openvino model for processing\n", + " not_kv_inputs (`List[str]`):\n", + " list of input nodes in model that not related to past key values\n", + " key_value_input_names (`List[str]`):\n", + " list of names for key value input layers\n", + " gather_dim (int):\n", + " dimension for gathering cache during reorder pass\n", + " \"\"\"\n", + "\n", + " if model_has_input_output_name(ov_model, \"beam_idx\"):\n", + " raise ValueError(\"Model already has fused cache\")\n", + " input_batch = ov_model.input(\"inputs_embeds\").get_partial_shape()[0]\n", + " beam_idx = opset13.parameter(\n", + " name=\"beam_idx\", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch])\n", + " )\n", + " beam_idx.output(0).get_tensor().add_names({\"beam_idx\"}) # why list is not accepted?\n", + " ov_model.add_parameters([beam_idx])\n", + " not_kv_inputs.append(ov_model.inputs[-1])\n", + " # Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx\n", + " for input_name in key_value_input_names:\n", + " parameter_output_port = ov_model.input(input_name)\n", + " consumers = parameter_output_port.get_target_inputs()\n", + " gather = opset13.gather(\n", + " parameter_output_port, beam_idx, opset13.constant(gather_dim)\n", + " )\n", + " for consumer in consumers:\n", + " consumer.replace_source_output(gather.output(0))\n", + " ov_model.validate_nodes_and_infer_types()\n", + "\n", + "\n", + "def build_state_initializer(ov_model: ov.Model, batch_dim: int):\n", + " \"\"\"\n", + " Build initialization ShapeOf Expression for all ReadValue ops\n", + "\n", + " Parameters:\n", + " ov_model (ov.Model):\n", + " openvino model\n", + " batch_dim (int):\n", + " index of dimension corresponding to batch size\n", + " \"\"\"\n", + " input_ids = ov_model.input(\"inputs_embeds\")\n", + " batch = 
opset13.gather(\n", + " opset13.shape_of(input_ids, output_type=\"i64\"),\n", + " opset13.constant([0]),\n", + " opset13.constant(0),\n", + " )\n", + " for op in ov_model.get_ops():\n", + " if op.get_type_name() == \"ReadValue\":\n", + " dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))]\n", + " dims[batch_dim] = batch\n", + " dims = [\n", + " (\n", + " opset13.constant(np.array([dim], dtype=np.int64))\n", + " if isinstance(dim, int)\n", + " else dim\n", + " )\n", + " for dim in dims\n", + " ]\n", + " shape = opset13.concat(dims, axis=0)\n", + " broadcast = opset13.broadcast(\n", + " opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape\n", + " )\n", + " op.set_arguments([broadcast])\n", + " ov_model.validate_nodes_and_infer_types()\n", + "\n", + "\n", + "def make_stateful(\n", + " ov_model: ov.Model,\n", + " not_kv_inputs: List[str],\n", + " key_value_input_names: List[str],\n", + " key_value_output_names: List[str],\n", + " batch_dim: int,\n", + " num_attention_heads: int,\n", + " num_beams_and_batch: int = None,\n", + "):\n", + " \"\"\"\n", + " Hides kv-cache inputs and outputs inside the model as variables.\n", + "\n", + " Parameters:\n", + " ov_model (ov.Model):\n", + " openvino model\n", + " not_kv_inputs (`List[str]`):\n", + " list of input nodes in model that not related to past key values\n", + " key_value_input_names (`List[str]`):\n", + " list of names for key value input layers\n", + " key_value_output_names (`List[str]`):\n", + " list of names for key value input layers\n", + " batch_dim (int):\n", + " index of batch dimension in key value layers\n", + " num_attention_heads (int):\n", + " number of attention heads for batch dimension initialization\n", + " num_beams_an_batch (int):\n", + " precalculated number of beams and batch for shapes initialization\n", + " \"\"\"\n", + " from openvino._offline_transformations import apply_make_stateful_transformation\n", + "\n", + " input_output_map = {}\n", + "\n", + " if num_beams_and_batch is not None:\n", + " # Set batch size for input_ids and attention mask to avoid dynamic dimension got propagated from the end of the model back to ReadValue\n", + " for input in not_kv_inputs:\n", + " shape = input.get_partial_shape()\n", + " if shape.rank.get_length() <= 2: # == 1 for beam_index\n", + " shape[0] = num_beams_and_batch\n", + " input.get_node().set_partial_shape(shape)\n", + " for kv_name_pair in zip(key_value_input_names, key_value_output_names):\n", + " input_output_map[kv_name_pair[0]] = kv_name_pair[1]\n", + " if num_beams_and_batch is not None:\n", + " input = ov_model.input(kv_name_pair[0])\n", + " shape = input.get_partial_shape()\n", + " shape[batch_dim] = num_beams_and_batch * num_attention_heads\n", + " input.get_node().set_partial_shape(shape)\n", + "\n", + " if num_beams_and_batch is not None:\n", + " # Re-validation model if shapes are altered above\n", + " ov_model.validate_nodes_and_infer_types()\n", + "\n", + " apply_make_stateful_transformation(ov_model, input_output_map)\n", + " if num_beams_and_batch is None:\n", + " build_state_initializer(ov_model, batch_dim)\n", + "\n", + "\n", + "def patch_stateful(ov_model):\n", + " key_value_input_names = [key.get_any_name() for key in ov_model.inputs[2:-1]]\n", + " key_value_output_names = [key.get_any_name() for key in ov_model.outputs[1:]]\n", + " not_kv_inputs = [\n", + " input\n", + " for input in ov_model.inputs\n", + " if not any(name in key_value_input_names for name in input.get_names())\n", + " ]\n", + " if not key_value_input_names 
or not key_value_output_names:\n", + " return\n", + " batch_dim = 0\n", + " num_attention_heads = 1\n", + "\n", + " fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)\n", + " make_stateful(\n", + " ov_model,\n", + " not_kv_inputs,\n", + " key_value_input_names,\n", + " key_value_output_names,\n", + " batch_dim,\n", + " num_attention_heads,\n", + " None,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "5cd69acd", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00, 1.00s/it]\n", + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/transformers/modeling_utils.py:4481: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead\n", + " warnings.warn(\n", + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:1060: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", + " if sequence_length != 1:\n", + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/torch/jit/_trace.py:165: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. 
(Triggered internally at /pytorch/build/aten/src/ATen/core/TensorBody.h:489.)\n", + " if a.grad is not None:\n" + ] + } + ], + "source": [ + "import types\n", + "\n", + "make_stateful_model = False\n", + "core = ov.Core()\n", + "model = LlavaNextForConditionalGeneration.from_pretrained(\n", + " model_id, low_cpu_mem_usage=True\n", + ")\n", + "language_model = model.language_model\n", + "if not LANGUAGE_MODEL_PATH.exists() or True:\n", + "\n", + " def forward_wrap(\n", + " self,\n", + " attention_mask,\n", + " position_ids=None,\n", + " past_key_values=None,\n", + " inputs_embeds=None,\n", + " ):\n", + " result = self._orig_forward(\n", + " input_ids=None,\n", + " attention_mask=attention_mask,\n", + " position_ids=position_ids,\n", + " past_key_values=past_key_values,\n", + " inputs_embeds=inputs_embeds,\n", + " output_hidden_states=True,\n", + " return_dict=True,\n", + " )\n", + " return result[\"hidden_states\"][-1][:, -1, :]\n", + "\n", + " model_inputs = [\"attention_mask\", \"position_ids\"]\n", + " model_outputs = [\"last_hidden_state\"]\n", + " model_inputs.append(\"inputs_embeds\")\n", + " language_model.config.torchscript = True\n", + " position_ids = torch.tensor([[2, 3], [2, 3]])\n", + " language_model._orig_forward = language_model.forward\n", + " language_model.forward = types.MethodType(forward_wrap, language_model)\n", + " ov_model = ov.convert_model(\n", + " language_model,\n", + " example_input={\n", + " \"inputs_embeds\": llm_input,\n", + " \"attention_mask\": torch.ones((2, 4)),\n", + " \"position_ids\": position_ids,\n", + " },\n", + " )\n", + "\n", + " for input, input_name in zip(ov_model.inputs, model_inputs):\n", + " input.get_tensor().set_names({input_name})\n", + "\n", + " for output, output_name in zip(ov_model.outputs, model_outputs):\n", + " output.get_tensor().set_names({output_name})\n", + " if make_stateful_model:\n", + " patch_stateful(ov_model)\n", + " ov.save_model(ov_model, LANGUAGE_MODEL_PATH)\n", + " del ov_model\n", + " cleanup_torchscript_cache()\n", + " gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "49838499", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:nncf:NNCF provides best results with torch==2.6.*, while current torch version is 2.7.0+cpu. 
If you encounter issues, consider switching to torch==2.6.*\n", + "INFO:nncf:Statistics of the bitwidth distribution:\n", + "┍━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑\n", + "│ Weight compression mode │ % all parameters (layers) │ % ratio-defining parameters (layers) │\n", + "┝━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥\n", + "│ int8_asym │ 1% (1 / 224) │ 0% (0 / 223) │\n", + "├───────────────────────────┼─────────────────────────────┼────────────────────────────────────────┤\n", + "│ int4_asym │ 99% (223 / 224) │ 100% (223 / 223) │\n", + "┕━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙\n" + ] + }, + { + "data": { + "text/html": [ + "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\" \n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n", + "</pre>\n" + ], + "text/plain": [ + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\" \n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import nncf\n", + "\n", + "compression_configuration = {\n", + " \"mode\": nncf.CompressWeightsMode.INT4_ASYM,\n", + " \"group_size\": 64,\n", + " \"ratio\": 1.0,\n", + "}\n", + "LANGUAGE_MODEL_PATH_INT4 = (\n", + " LANGUAGE_MODEL_PATH.parent / LANGUAGE_MODEL_PATH.name.replace(\".xml\", \"-int4.xml\")\n", + ")\n", + "ov_model = core.read_model(LANGUAGE_MODEL_PATH)\n", + "ov_model_compressed = nncf.compress_weights(ov_model, **compression_configuration)\n", + "ov.save_model(ov_model_compressed, LANGUAGE_MODEL_PATH_INT4)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "695c2fbf", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "\n", + "class UnpadImage(nn.Module):\n", + " def __init__(self):\n", + " super(UnpadImage, self).__init__()\n", + "\n", + " def forward(self, tensor, original_size, current_size):\n", + " \"\"\"\n", + " Unpads an image tensor to its original size based on the current size.\n", + " Args:\n", + " tensor (torch.Tensor): The input image tensor of shape (C, H, W).\n", + " original_size (torch.Tensor): The original size of the image tensor as (H, W).\n", + " current_size (torch.Tensor): The current size of the image tensor as (H, W).\n", + " \"\"\"\n", + " # tensor: (C, H, W)\n", + " original_size = original_size.to(torch.float32)\n", + " original_height, original_width = original_size[0], original_size[1]\n", + " current_height, current_width = current_size[0], current_size[1]\n", + "\n", + " original_aspect_ratio = original_width / original_height\n", + " current_aspect_ratio = current_width / current_height\n", + "\n", + " # Comparison\n", + " condition = original_aspect_ratio > current_aspect_ratio\n", + "\n", + " # Branch 1: vertical padding\n", + " 
scale_factor_1 = current_width.float() / original_width.float()\n", + " new_height = (original_height.float() * scale_factor_1).int()\n", + " pad_top = ((current_height.float() - new_height) / 2).floor().long()\n", + "\n", + " # Branch 2: horizontal padding\n", + " scale_factor_2 = current_height.float() / original_height.float()\n", + " new_width = (original_width.float() * scale_factor_2).int()\n", + " pad_left = ((current_width.float() - new_width) / 2).floor().long()\n", + "\n", + " zero = torch.zeros(1, dtype=pad_top.dtype, device=tensor.device).squeeze(0)\n", + "\n", + " # Use torch.where to conditionally compute slicing\n", + " y_start = torch.where(condition, pad_top, zero)\n", + " y_end = torch.where(condition, current_height - pad_top, current_height)\n", + "\n", + " x_start = torch.where(condition, zero, pad_left)\n", + " x_end = torch.where(condition, current_width - pad_left, current_width)\n", + " out = tensor[:, y_start.int() : y_end.int(), x_start.int() : x_end.int()]\n", + " return out # Remove batch dimension if needed\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "ba325001", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "\n", + "\n", + "class PackImageFeatures(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " self.config = config\n", + " self.unpad_image = UnpadImage()\n", + " self.height = config.vision_config.image_size // config.vision_config.patch_size\n", + " self.width = config.vision_config.image_size // config.vision_config.patch_size\n", + "\n", + " def forward(self, image_feature, image_sizes, num_patch_height, num_patch_width):\n", + " # we image features is a single image features, so we can remove the loop\n", + " base_image_features = image_feature[0]\n", + " features = image_feature[1:] # Skip the first token\n", + " features = (\n", + " features.view(\n", + " num_patch_height, num_patch_width, self.height, self.width, -1\n", + " )\n", + " .permute(4, 0, 2, 1, 3)\n", + " .contiguous()\n", + " .flatten(1, 2)\n", + " .flatten(2, 3)\n", + " )\n", + " features = self.unpad_image(\n", + " features, image_sizes[0], torch._shape_as_tensor(features)[1:3]\n", + " )\n", + " features = features.flatten(1, 2).transpose(0, 1)\n", + " features = torch.cat([base_image_features, features], dim=0)\n", + " return features.unsqueeze(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f911f0a0", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "\n", + "class MergeInputWithImageFeatures(nn.Module):\n", + " def __init__(self, pad_token_id=0, image_token_index=0):\n", + " super().__init__()\n", + " self.pad_token_id = pad_token_id\n", + " self.image_token_index = image_token_index\n", + "\n", + " def forward(self, image_features, inputs_embeds, input_ids, attention_mask):\n", + " num_images, num_image_patches, embed_dim = image_features.shape\n", + " batch_size, sequence_length = input_ids.shape\n", + "\n", + " # left_padding = torch.sum(input_ids[:, -1] == self.pad_token_id) == 0 # Removed, not needed now\n", + "\n", + " special_image_token_mask = input_ids == self.image_token_index # [B, S]\n", + " num_special_image_tokens = special_image_token_mask.sum(dim=-1) # [B]\n", + "\n", + " max_embed_dim = (\n", + " num_special_image_tokens.max() * (num_image_patches - 1)\n", + " ) + sequence_length # scalar\n", + "\n", + " batch_indices, non_image_indices = torch.where(\n", + " input_ids != self.image_token_index\n", + " ) 
# [N], [N]\n", + "\n", + " # Step 2: Compute new token positions\n", + " new_token_positions = (\n", + " torch.cumsum(special_image_token_mask * (num_image_patches - 1) + 1, dim=-1)\n", + " - 1\n", + " ) # [B, S]\n", + "\n", + " nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] # [B]\n", + "\n", + " # left_padding_flag = (input_ids[:, -1] != self.pad_token_id).to(nb_image_pad.dtype) # original\n", + " left_padding_flag = (\n", + " input_ids[:, -1] != self.pad_token_id\n", + " ).long() # more idiomatic torch\n", + " # new_token_positions = new_token_positions + (left_padding_flag[:, None] * nb_image_pad[:, None]) # original\n", + " new_token_positions += (\n", + " left_padding_flag[:, None] * nb_image_pad[:, None]\n", + " ) # updated\n", + "\n", + " text_to_overwrite = new_token_positions[batch_indices, non_image_indices] # [N]\n", + "\n", + " # Step 3: Init final tensors\n", + " final_embedding = torch.zeros(\n", + " batch_size,\n", + " max_embed_dim,\n", + " embed_dim,\n", + " dtype=inputs_embeds.dtype,\n", + " device=inputs_embeds.device,\n", + " )\n", + " final_attention_mask = torch.zeros(\n", + " batch_size,\n", + " max_embed_dim,\n", + " dtype=attention_mask.dtype,\n", + " device=inputs_embeds.device,\n", + " )\n", + "\n", + " # final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] # original\n", + " final_embedding.index_put_(\n", + " (batch_indices, text_to_overwrite),\n", + " inputs_embeds[batch_indices, non_image_indices],\n", + " ) # torch native\n", + "\n", + " # final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] # original\n", + " final_attention_mask.index_put_(\n", + " (batch_indices, text_to_overwrite),\n", + " attention_mask[batch_indices, non_image_indices],\n", + " ) # torch native\n", + "\n", + " # Step 5: fill in image features\n", + " image_to_overwrite = (final_embedding == 0).all(dim=-1) # [B, L]\n", + " image_to_overwrite &= (image_to_overwrite.cumsum(-1) - 1) >= nb_image_pad[\n", + " :, None\n", + " ] # apply pad cutoff\n", + "\n", + " flat_image_features = image_features.reshape(-1, embed_dim).to(\n", + " inputs_embeds.device\n", + " ) # [N_img, D]\n", + "\n", + " # final_embedding[image_to_overwrite] = flat_image_features # original\n", + " final_embedding[image_to_overwrite] = flat_image_features[\n", + " : image_to_overwrite.sum()\n", + " ] # safe assignment\n", + "\n", + " final_attention_mask |= image_to_overwrite # logical or with existing mask\n", + "\n", + " position_ids = final_attention_mask.cumsum(-1) - 1\n", + " position_ids = position_ids.masked_fill(final_attention_mask == 0, 1)\n", + "\n", + " # Step 6: remove pad token embeddings\n", + " batch_pad_indices, pad_token_positions = torch.where(\n", + " input_ids == self.pad_token_id\n", + " ) # [N_pad]\n", + " indices_to_mask = new_token_positions[\n", + " batch_pad_indices, pad_token_positions\n", + " ] # [N_pad]\n", + "\n", + " # final_embedding[batch_pad_indices, indices_to_mask] = 0 # original\n", + " final_embedding.index_put_(\n", + " (batch_pad_indices, indices_to_mask),\n", + " torch.zeros_like(final_embedding[batch_pad_indices, indices_to_mask]),\n", + " ) # updated\n", + "\n", + " return {\n", + " \"final_embedding\": final_embedding,\n", + " \"final_attention_mask\": final_attention_mask,\n", + " \"position_ids\": position_ids,\n", + " }\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfb25757", + "metadata": {}, + "outputs": [], + "source": [ + "# 
compile the models\n",
+    "language_model = core.read_model(LANGUAGE_MODEL_PATH)\n",
+    "compiled_language_model = core.compile_model(language_model, \"AUTO\")\n",
+    "\n",
+    "image_embed_model = core.compile_model(IMAGE_ENCODER_PATH, device)\n",
+    "text_embeddings_model = core.compile_model(INPUT_EMBEDDING_PATH, device)\n",
+    "\n",
+    "if IMAGE_PACKER_PATH.exists():\n",
+    "    image_packer_model = core.compile_model(IMAGE_PACKER_PATH, device)\n",
+    "else:\n",
+    "    image_packer_model = None\n",
+    "if MULTIMODAL_MERGER_PATH.exists():\n",
+    "    multimodal_merger_model = core.compile_model(MULTIMODAL_MERGER_PATH, device)\n",
+    "else:\n",
+    "    multimodal_merger_model = None\n",
+    "\n",
+    "# multimodal_merger_model = core.compile_model(MODEL_MERGER_PATH, device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "0d5643ee",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+      "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/huggingface_hub/file_download.py:943: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Image size: (360, 282), Mode: RGB\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import torch.nn.functional as F\n",
+    "import requests\n",
+    "from PIL import Image\n",
+    "from transformers import AutoTokenizer, AutoConfig\n",
+    "from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration\n",
+    "\n",
+    "llama3_template = \"<|start_header_id|>user<|end_header_id|>\\n\\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n \\n\"\n",
+    "\n",
+    "processor = LlavaNextProcessor.from_pretrained(\"royokong/e5-v\")\n",
+    "\n",
+    "config = AutoConfig.from_pretrained(\"royokong/e5-v\")\n",
+    "img_prompt = llama3_template.format(\"<image>\\nSummary above image in one word: \")\n",
+    "text_prompt = llama3_template.format(\"<sent>\\nSummary above sentence in one word: \")\n",
+    "\n",
+    "images = [Image.open(\"dog.jpg\").convert(\"RGB\")]\n",
+    "\n",
+    "for image in images:\n",
+    "    print(f\"Image size: {image.size}, Mode: {image.mode}\")\n",
+    "\n",
+    "texts = [\"A dog sitting in the grass.\"]\n",
+    "\n",
+    "text_inputs = processor(\n",
+    "    [text_prompt.replace(\"<sent>\", text) for text in texts],\n",
+    "    return_tensors=\"pt\",\n",
+    "    padding=True,\n",
+    ")\n",
+    "img_inputs = processor(\n",
+    "    [img_prompt] * len(images), images, return_tensors=\"pt\", padding=True\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "ad1b402c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "img_input_ids = img_inputs[\"input_ids\"]\n",
+    "img_attention_mask = img_inputs[\"attention_mask\"]\n",
+    "image_sizes = img_inputs[\"image_sizes\"]\n",
+    "pixel_values = img_inputs[\"pixel_values\"]\n",
+    "\n",
+    "text_input_ids = text_inputs[\"input_ids\"]\n",
+    "text_attention_mask = text_inputs[\"attention_mask\"]\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "2649d101",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "image_features = torch.from_numpy(image_embed_model(pixel_values)[0])\n",
+    "image_inputs_embeds = torch.from_numpy(text_embeddings_model(img_input_ids)[0])\n",
+    "text_inputs_embeds = 
torch.from_numpy(text_embeddings_model(text_input_ids)[0])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "844968c5", + "metadata": {}, + "outputs": [], + "source": [ + "image_packer = PackImageFeatures(config)\n", + "input_merger = MergeInputWithImageFeatures(\n", + " pad_token_id=processor.tokenizer.pad_token_id,\n", + " image_token_index=config.image_token_index,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "190da649", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from typing import Union, List, Tuple\n", + "import torch\n", + "\n", + "\n", + "def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:\n", + " \"\"\"\n", + " Selects the best resolution from a list of possible resolutions based on the original size.\n", + "\n", + " This is done by calculating the effective and wasted resolution for each possible resolution.\n", + "\n", + " The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.\n", + "\n", + " Args:\n", + " original_size (tuple):\n", + " The original size of the image in the format (height, width).\n", + " possible_resolutions (list):\n", + " A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].\n", + "\n", + " Returns:\n", + " tuple: The best fit resolution in the format (height, width).\n", + " \"\"\"\n", + " original_height, original_width = original_size\n", + " best_fit = None\n", + " max_effective_resolution = 0\n", + " min_wasted_resolution = float(\"inf\")\n", + "\n", + " for height, width in possible_resolutions:\n", + " scale = min(width / original_width, height / original_height)\n", + " downscaled_width, downscaled_height = (\n", + " int(original_width * scale),\n", + " int(original_height * scale),\n", + " )\n", + " effective_resolution = min(\n", + " downscaled_width * downscaled_height, original_width * original_height\n", + " )\n", + " wasted_resolution = (width * height) - effective_resolution\n", + "\n", + " if effective_resolution > max_effective_resolution or (\n", + " effective_resolution == max_effective_resolution\n", + " and wasted_resolution < min_wasted_resolution\n", + " ):\n", + " max_effective_resolution = effective_resolution\n", + " min_wasted_resolution = wasted_resolution\n", + " best_fit = (height, width)\n", + "\n", + " return best_fit\n", + "\n", + "\n", + "def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):\n", + " \"\"\"\n", + " Calculate the number of patches after the preprocessing for images of any resolution.\n", + "\n", + " Args:\n", + " image_size (`Union[torch.LongTensor, np.ndarray, Tuple[int, int]):\n", + " The size of the input image in the format (height, width). ?\n", + " grid_pinpoints (`List`):\n", + " A list containing possible resolutions. Each item in the list should be a tuple or list\n", + " of the form `(height, width)`.\n", + " patch_size (`int`):\n", + " The size of each image patch.\n", + "\n", + " Returns:\n", + " int: the number of patches\n", + " \"\"\"\n", + " if not isinstance(grid_pinpoints, list):\n", + " raise ValueError(\"grid_pinpoints should be a list of tuples or lists\")\n", + "\n", + " # ! 
VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate\n", + " if not isinstance(image_size, (list, tuple)):\n", + " if not isinstance(image_size, (torch.Tensor, np.ndarray)):\n", + " raise ValueError(\n", + " f\"image_size invalid type {type(image_size)} with value {image_size}\"\n", + " )\n", + " image_size = image_size.tolist()\n", + "\n", + " best_resolution = select_best_resolution(image_size, grid_pinpoints)\n", + " height, width = best_resolution\n", + " num_patches = 0\n", + " # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1\n", + " for i in range(0, height, patch_size):\n", + " for j in range(0, width, patch_size):\n", + " num_patches += 1\n", + " # add the base patch\n", + " num_patches += 1\n", + " return num_patches\n", + "\n", + "\n", + "def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):\n", + " \"\"\"\n", + " Calculate the shape of the image patch grid after the preprocessing for images of any resolution.\n", + "\n", + " Args:\n", + " image_size (`tuple`):\n", + " The size of the input image in the format (width, height).\n", + " grid_pinpoints (`List`):\n", + " A list containing possible resolutions. Each item in the list should be a tuple or list\n", + " of the form `(height, width)`.\n", + " patch_size (`int`):\n", + " The size of each image patch.\n", + "\n", + " Returns:\n", + " tuple: The shape of the image patch grid in the format (width, height).\n", + " \"\"\"\n", + " if not isinstance(grid_pinpoints, list):\n", + " raise ValueError(\"grid_pinpoints should be a list of tuples or lists\")\n", + "\n", + " # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate\n", + " if not isinstance(image_size, (list, tuple)):\n", + " if not isinstance(image_size, (torch.Tensor, np.ndarray)):\n", + " raise ValueError(\n", + " f\"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor\"\n", + " )\n", + " image_size = image_size.tolist()\n", + "\n", + " height, width = select_best_resolution(image_size, grid_pinpoints)\n", + " return height // patch_size, width // patch_size" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "bcbec245", + "metadata": {}, + "outputs": [], + "source": [ + "num_patch_width, num_patch_height = get_anyres_image_grid_shape(\n", + " image_sizes[0],\n", + " config.image_grid_pinpoints,\n", + " config.vision_config.image_size,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "40620525", + "metadata": {}, + "outputs": [], + "source": [ + "packed_image_features = image_packer(\n", + " image_features,\n", + " image_sizes,\n", + " num_patch_height=num_patch_height,\n", + " num_patch_width=num_patch_width\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "0eb947f8", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "if IMAGE_PACKER_PATH.exists():\n", + " IMAGE_PACKER_PATH.unlink()\n", + "\n", + "ov_image_packer = ov.convert_model(\n", + " image_packer,\n", + " example_input={\n", + " \"image_feature\": image_features,\n", + " \"image_sizes\": image_sizes,\n", + " \"num_patch_height\": torch.tensor(num_patch_height, dtype=torch.int64),\n", + " \"num_patch_width\": torch.tensor(num_patch_width, dtype=torch.int64)\n", + " }\n", + ")\n", + "ov.save_model(ov_image_packer, IMAGE_PACKER_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "f2fae423", + "metadata": 
{}, + "outputs": [], + "source": [ + "if MULTIMODAL_MERGER_PATH.exists():\n", + " MULTIMODAL_MERGER_PATH.unlink()\n", + "ov_multimodal_merger = ov.convert_model(\n", + " input_merger,\n", + " example_input={\n", + " \"image_features\": packed_image_features,\n", + " \"inputs_embeds\": image_inputs_embeds,\n", + " \"input_ids\": img_input_ids,\n", + " \"attention_mask\": img_attention_mask\n", + " }\n", + ")\n", + "ov.save_model(ov_multimodal_merger, MULTIMODAL_MERGER_PATH)\n", + "cleanup_torchscript_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "0599dd94", + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "import os\n", + "if not os.path.exists(f\"{output_dir}/assets\"):\n", + " output_dir = Path(output_dir)\n", + " assets_dir = output_dir/\"assets\"\n", + " assets_dir.mkdir(exist_ok=True)\n", + " processor.save_pretrained(output_dir)\n", + " # copy all the assets to the assets directory (json files, vocab files, etc.)\n", + " for file in output_dir.glob(\"*.json\"):\n", + " shutil.copy(file, assets_dir)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "27e894ab", + "metadata": {}, + "outputs": [], + "source": [ + "# delete the f32 language model\n", + "if LANGUAGE_MODEL_PATH.exists():\n", + " LANGUAGE_MODEL_PATH.unlink()\n", + "\n", + "# delete the f32 language model bin file if exists\n", + "if LANGUAGE_MODEL_PATH.with_suffix(\".bin\").exists():\n", + " LANGUAGE_MODEL_PATH.with_suffix(\".bin\").unlink()" + ] + }, + { + "cell_type": "markdown", + "id": "ff9ecebb", + "metadata": {}, + "source": [ + "## 2. Test the Exported model" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "17eaa581", + "metadata": {}, + "outputs": [], + "source": [ + "IMAGE_ENCODER_PATH = output_dir / \"openvino_vision_embeddings_model.xml\"\n", + "LANGUAGE_MODEL_PATH = output_dir / \"openvino_language_model-int4.xml\"\n", + "INPUT_EMBEDDING_PATH = output_dir / \"openvino_text_embeddings_model.xml\"\n", + "\n", + "IMAGE_PACKER_PATH = output_dir / \"openvino_image_packer.xml\"\n", + "MULTIMODAL_MERGER_PATH = output_dir / \"openvino_multimodal_merger.xml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "0782a4a9", + "metadata": {}, + "outputs": [], + "source": [ + "# compile the models\n", + "language_model = core.read_model(LANGUAGE_MODEL_PATH)\n", + "compiled_language_model = core.compile_model(language_model, \"AUTO\")\n", + "\n", + "image_embed_model = core.compile_model(IMAGE_ENCODER_PATH, device)\n", + "text_embeddings_model = core.compile_model(INPUT_EMBEDDING_PATH, device)\n", + "\n", + "if IMAGE_PACKER_PATH.exists():\n", + " image_packer_model = core.compile_model(IMAGE_PACKER_PATH, device)\n", + "else:\n", + " image_packer_model = None\n", + "if MULTIMODAL_MERGER_PATH.exists():\n", + " multimodal_merger_model = core.compile_model(MULTIMODAL_MERGER_PATH, device)\n", + "else:\n", + " multimodal_merger_model = None\n" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "4b88a40c", + "metadata": {}, + "outputs": [], + "source": [ + "# use openvino model to pack the image features\n", + "packed_image_features = image_packer_model({\n", + " 'image_feature': image_features,\n", + " 'image_sizes': image_sizes,\n", + " 'num_patch_height': torch.tensor(num_patch_height, dtype=torch.int64),\n", + " 'num_patch_width': torch.tensor(num_patch_width, dtype=torch.int64)\n", + "})[0]\n", + "packed_image_features = torch.from_numpy(packed_image_features)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 32, + "id": "1a69b30b", + "metadata": {}, + "outputs": [], + "source": [ + "# use openvino model to merge the image features with text features\n", + "merger_out = multimodal_merger_model({\n", + " \"image_features\": packed_image_features,\n", + " \"inputs_embeds\": image_inputs_embeds,\n", + " \"input_ids\": img_input_ids,\n", + " \"attention_mask\": img_attention_mask\n", + " }\n", + ")\n", + "image_final_embeds = torch.from_numpy(merger_out['final_embedding'])\n", + "image_final_attention_mask = torch.from_numpy(merger_out['final_attention_mask'])\n", + "image_position_ids = torch.from_numpy(merger_out['position_ids'])" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "131763dc", + "metadata": {}, + "outputs": [], + "source": [ + "request = compiled_language_model.create_infer_request()\n", + "img_input_lm = {\n", + " \"inputs_embeds\": image_final_embeds.detach().numpy(),\n", + " \"attention_mask\": image_final_attention_mask.detach().numpy(),\n", + " \"position_ids\": image_position_ids.detach().numpy(),\n", + "}\n", + "request.start_async(img_input_lm, share_inputs=True)\n", + "request.wait()\n", + "img_lm_output = torch.from_numpy(request.get_tensor(\"last_hidden_state\").data)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "68787196", + "metadata": {}, + "outputs": [], + "source": [ + "text_request = compiled_language_model.create_infer_request()\n", + "text_position_ids = text_attention_mask.long().cumsum(-1) - 1\n", + "text_position_ids.masked_fill_(text_attention_mask == 0, 1)\n", + "text_input_lm = {\n", + " \"inputs_embeds\": text_inputs_embeds.detach().numpy(),\n", + " \"attention_mask\": text_attention_mask.detach().numpy(),\n", + " \"position_ids\": text_position_ids.detach().numpy(),\n", + "}\n", + "text_request.start_async(text_input_lm, share_inputs=True)\n", + "text_request.wait()\n", + "text_lm_output = torch.from_numpy(text_request.get_tensor(\"last_hidden_state\").data)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "df6a5ae1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[0.7158]])\n" + ] + } + ], + "source": [ + "import torch.nn.functional as F\n", + "\n", + "txt_embed = F.normalize(text_lm_output, dim=-1)\n", + "img_embed = F.normalize(img_lm_output, dim=-1)\n", + "\n", + "print(txt_embed @ img_embed.T)" + ] + }, + { + "cell_type": "markdown", + "id": "3764af1b", + "metadata": {}, + "source": [ + "## 3 Import and Save E5V in Spark NLP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "265ecf82", + "metadata": {}, + "outputs": [], + "source": [ + "! 
wget -q http://setup.johnsnowlabs.com/colab.sh -O - | bash" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "285bb60c", + "metadata": {}, + "outputs": [], + "source": [ + "import sparknlp\n", + "\n", + "# let's start Spark with Spark NLP\n", + "spark = sparknlp.start()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "18611787", + "metadata": {}, + "outputs": [], + "source": [ + "model_id = \"royokong/e5-v\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8ca2060a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "25/06/10 03:45:32 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.\n", + "25/06/10 03:45:41 WARN NativeLibrary: Failed to load library null: java.lang.UnsatisfiedLinkError: Can't load library: /tmp/openvino-native4021672575912693842/libtbb.so.2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: An illegal reflective access operation has occurred\n", + "WARNING: Illegal reflective access by org.apache.spark.util.SizeEstimator$ (file:/home/prabod/spark/jars/spark-core_2.12-3.3.2.jar) to field java.util.regex.Pattern.pattern\n", + "WARNING: Please consider reporting this to the maintainers of org.apache.spark.util.SizeEstimator$\n", + "WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n", + "WARNING: All illegal access operations will be denied in a future release\n" + ] + } + ], + "source": [ + "e5v_embeddings_sn = E5VEmbeddings \\\n", + " .loadSavedModel(str(output_dir),spark) \\\n", + " .setInputCols(\"image_assembler\") \\\n", + " .setOutputCol(\"answer\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "d5b60572", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \r" + ] + } + ], + "source": [ + "e5v_embeddings_sn.write().overwrite().save(f\"file:///tmp/{model_id}_spark_nlp\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd9cf656", + "metadata": {}, + "outputs": [], + "source": [ + "import sparknlp\n", + "from sparknlp.base import *\n", + "from sparknlp.annotator import *\n", + "from pyspark.sql.functions import lit\n", + "from pyspark.ml import Pipeline\n", + "from sparknlp.util import EmbeddingsDataFrameUtils\n", + "\n", + "from pathlib import Path\n", + "import os\n", + "\n", + "# download two images to test into ./images folder\n", + "\n", + "url1 = \"https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11\"\n", + "\n", + "Path(\"images\").mkdir(exist_ok=True)\n", + "\n", + "!wget -q -O images/image1.jpg {url1}\n", + "\n", + "\n", + "\n", + "images_path = \"file://\" + os.getcwd() + \"/images/\"\n", + "image_df = spark.read.format(\"image\").load(\n", + " path=images_path\n", + ")\n", + "\n", + "imagePrompt = \"<|start_header_id|>user<|end_header_id|>\\n\\n<image>\\\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n \\n\"\n", + "image_df = spark.read.format(\"image\").option(\"dropInvalid\", True).load(images_path)\n", + "test_df = image_df.withColumn(\"text\", lit(imagePrompt))\n", + "\n", + "textPrompt = \"<|start_header_id|>user<|end_header_id|>\\n\\n<sent>\\\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n \\n\"\n", + "textDesc = \"A cat sitting in a box.\"\n", + 
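"# Text-only embedding: pair the text prompt with a placeholder (empty) image row so the\n",
+    "# same IMAGE-based pipeline can embed plain sentences alongside real images.\n",
+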
"nullImageDF = spark.createDataFrame(\n", + " [EmbeddingsDataFrameUtils.emptyImageRow], schema=\n", + " EmbeddingsDataFrameUtils.imageSchema)\n", + "textDF = nullImageDF.withColumn(\"text\", lit(textPrompt.replace(\"<sent>\", textDesc)))\n", + "\n", + "test_df = test_df.union(textDF)\n", + "\n", + "imageAssembler = ImageAssembler() \\\n", + " .setInputCol(\"image\") \\\n", + " .setOutputCol(\"image_assembler\")\n", + "e5v = E5VEmbeddings.load(f\"file:///tmp/{model_id}_spark_nlp\") \\\n", + " .setInputCols([\"image_assembler\"]) \\\n", + " .setOutputCol(\"e5v\")\n", + "pipeline = Pipeline().setStages([imageAssembler, e5v])\n", + "results = pipeline.fit(test_df).transform(test_df)\n", + "results.select(\"e5v.embeddings\").show(truncate=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "e5v", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/python/sparknlp/annotator/embeddings/__init__.py b/python/sparknlp/annotator/embeddings/__init__.py index da453d2c555037..f93ac2e3b11ec4 100644 --- a/python/sparknlp/annotator/embeddings/__init__.py +++ b/python/sparknlp/annotator/embeddings/__init__.py @@ -41,3 +41,4 @@ from sparknlp.annotator.embeddings.snowflake_embeddings import * from sparknlp.annotator.embeddings.nomic_embeddings import * from sparknlp.annotator.embeddings.auto_gguf_embeddings import * +from sparknlp.annotator.embeddings.e5v_embeddings import * \ No newline at end of file diff --git a/python/sparknlp/annotator/embeddings/e5v_embeddings.py b/python/sparknlp/annotator/embeddings/e5v_embeddings.py new file mode 100644 index 00000000000000..e8ee518a40333e --- /dev/null +++ b/python/sparknlp/annotator/embeddings/e5v_embeddings.py @@ -0,0 +1,138 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sparknlp.common import * + +class E5VEmbeddings(AnnotatorModel, + HasBatchedAnnotateImage, + HasImageFeatureProperties, + HasEngine, + HasRescaleFactor): + """Universal multimodal embeddings using the E5-V model (see https://huggingface.co/royokong/e5-v). + + E5-V bridges the modality gap between different input types (text, image) and demonstrates strong performance in multimodal embeddings, even without fine-tuning. It also supports a single-modality training approach, where the model is trained exclusively on text pairs, often yielding better performance than multimodal training. + + Pretrained models can be loaded with :meth:`.pretrained` of the companion object: + + >>> e5vEmbeddings = E5VEmbeddings.pretrained() \ + ... .setInputCols(["image_assembler"]) \ + ... .setOutputCol("e5v") + + The default model is ``"e5v_int4"``, if no name is provided. 
+
+    For available pretrained models please see the `Models Hub <https://sparknlp.org/models?q=E5V>`__.
+
+    ====================== ======================
+    Input Annotation types Output Annotation type
+    ====================== ======================
+    ``IMAGE``              ``SENTENCE_EMBEDDINGS``
+    ====================== ======================
+
+    Examples
+    --------
+    Image + Text Embedding:
+    >>> import sparknlp
+    >>> from sparknlp.base import *
+    >>> from sparknlp.annotator import *
+    >>> from pyspark.ml import Pipeline
+    >>> from pyspark.sql.functions import lit
+    >>> image_df = spark.read.format("image").option("dropInvalid", True).load(imageFolder)
+    >>> imagePrompt = "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
+    >>> test_df = image_df.withColumn("text", lit(imagePrompt))
+    >>> imageAssembler = ImageAssembler() \
+    ...     .setInputCol("image") \
+    ...     .setOutputCol("image_assembler")
+    >>> e5vEmbeddings = E5VEmbeddings.pretrained() \
+    ...     .setInputCols(["image_assembler"]) \
+    ...     .setOutputCol("e5v")
+    >>> pipeline = Pipeline().setStages([
+    ...     imageAssembler,
+    ...     e5vEmbeddings
+    ... ])
+    >>> result = pipeline.fit(test_df).transform(test_df)
+    >>> result.select("e5v.embeddings").show(truncate=False)
+
+    Text-Only Embedding:
+    >>> from sparknlp.util import EmbeddingsDataFrameUtils
+    >>> textPrompt = "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
+    >>> textDesc = "A cat sitting in a box."
+    >>> nullImageDF = spark.createDataFrame(spark.sparkContext.parallelize([EmbeddingsDataFrameUtils.emptyImageRow]), EmbeddingsDataFrameUtils.imageSchema)
+    >>> textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc)))
+    >>> e5vEmbeddings = E5VEmbeddings.pretrained() \
+    ...     .setInputCols(["image"]) \
+    ...     .setOutputCol("e5v")
+    >>> result = e5vEmbeddings.transform(textDF)
+    >>> result.select("e5v.embeddings").show(truncate=False)
+    """
+
+    name = "E5VEmbeddings"
+
+    inputAnnotatorTypes = [AnnotatorType.IMAGE]
+    outputAnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS
+
+    @keyword_only
+    def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.E5VEmbeddings", java_model=None):
+        """Initializes the E5VEmbeddings annotator.
+
+        Parameters
+        ----------
+        classname : str, optional
+            The Java class name of the annotator, by default "com.johnsnowlabs.nlp.embeddings.E5VEmbeddings"
+        java_model : Optional[java.lang.Object], optional
+            A pre-initialized Java model, by default None
+        """
+        super(E5VEmbeddings, self).__init__(classname=classname, java_model=java_model)
+        self._setDefault()
+
+    @staticmethod
+    def loadSavedModel(folder, spark_session, use_openvino=False):
+        """Loads a locally saved model.
+
+        Parameters
+        ----------
+        folder : str
+            Folder of the saved model
+        spark_session : pyspark.sql.SparkSession
+            The current SparkSession
+        use_openvino : bool, optional
+            Whether to use OpenVINO engine, by default False
+
+        Returns
+        -------
+        E5VEmbeddings
+            The restored model
+        """
+        from sparknlp.internal import _E5VEmbeddingsLoader
+        jModel = _E5VEmbeddingsLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj
+        return E5VEmbeddings(java_model=jModel)
+
+    @staticmethod
+    def pretrained(name="e5v_int4", lang="en", remote_loc=None):
+        """Downloads and loads a pretrained model.
+ + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default "e5v_int4" + lang : str, optional + Language of the pretrained model, by default "en" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use Spark NLPs repositories otherwise. + + Returns + ------- + E5VEmbeddings + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(E5VEmbeddings, name, lang, remote_loc) \ No newline at end of file diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index 25f34ce4eba599..e7300ab8586b5c 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -1165,3 +1165,11 @@ def __init__(self, path, jspark, use_openvino=False): jspark, use_openvino, ) +class _E5VEmbeddingsLoader(ExtendedJavaWrapper): + def __init__(self, path, jspark, use_openvino=False): + super(_E5VEmbeddingsLoader, self).__init__( + "com.johnsnowlabs.nlp.embeddings.E5VEmbeddings.loadSavedModel", + path, + jspark, + use_openvino + ) \ No newline at end of file diff --git a/python/sparknlp/util.py b/python/sparknlp/util.py index 0bbacd410a9e8f..8381337b5873a5 100644 --- a/python/sparknlp/util.py +++ b/python/sparknlp/util.py @@ -15,6 +15,9 @@ import sparknlp.internal as _internal +import numpy as np +from pyspark.sql import Row +from pyspark.sql.types import StructType, StructField, StringType, IntegerType, BinaryType def get_config_path(): @@ -33,3 +36,26 @@ def exportConllFiles(*args): _internal._CoNLLGeneratorExportFromTargetAndPipeline(*args).apply() else: raise NotImplementedError(f"No exportConllFiles alternative takes {num_args} parameters") + + +class EmbeddingsDataFrameUtils: + """ + Utility for creating DataFrames compatible with multimodal embedding models (e.g., E5VEmbeddings) for text-only scenarios. + Provides: + - imageSchema: the expected schema for Spark image DataFrames + - emptyImageRow: a dummy image row for text-only embedding + """ + imageSchema = StructType([ + StructField( + "image", + StructType([ + StructField("origin", StringType(), True), + StructField("height", IntegerType(), True), + StructField("width", IntegerType(), True), + StructField("nChannels", IntegerType(), True), + StructField("mode", IntegerType(), True), + StructField("data", BinaryType(), True), + ]), + ) + ]) + emptyImageRow = Row(Row("", 0, 0, 0, 0, bytes())) diff --git a/python/test/annotator/embeddings/e5v_embeddings_test.py b/python/test/annotator/embeddings/e5v_embeddings_test.py new file mode 100644 index 00000000000000..249484232284d3 --- /dev/null +++ b/python/test/annotator/embeddings/e5v_embeddings_test.py @@ -0,0 +1,64 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import unittest +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from pyspark.ml import Pipeline +from pyspark.sql.functions import lit +from test.util import SparkContextForTest + +@pytest.mark.slow +class E5VEmbeddingsTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + self.images_path = "file://"+os.getcwd() + "/../src/test/resources/image/" + + def test_image_and_text_embedding(self): + # Simulate image+text embedding (requires actual image files for full test) + image_folder = os.environ.get("E5V_IMAGE_TEST_FOLDER", self.images_path) + imagePrompt = "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" + image_df = self.spark.read.format("image").option("dropInvalid", True).load(image_folder) + test_df = image_df.withColumn("text", lit(imagePrompt)) + + imageAssembler = ImageAssembler() \ + .setInputCol("image") \ + .setOutputCol("image_assembler") + e5v = E5VEmbeddings.pretrained() \ + .setInputCols(["image_assembler"]) \ + .setOutputCol("e5v") + pipeline = Pipeline().setStages([imageAssembler, e5v]) + results = pipeline.fit(test_df).transform(test_df) + results.select("e5v.embeddings").show(truncate=True) + + def test_text_only_embedding(self): + # Simulate text-only embedding using emptyImageRow and imageSchema + from sparknlp.util import EmbeddingsDataFrameUtils + textPrompt = "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" + textDesc = "A cat sitting in a box." + nullImageDF = self.spark.createDataFrame( + self.spark.sparkContext.parallelize([EmbeddingsDataFrameUtils.emptyImageRow]), + EmbeddingsDataFrameUtils.imageSchema) + textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc))) + imageAssembler = ImageAssembler() \ + .setInputCol("image") \ + .setOutputCol("image_assembler") + e5v = E5VEmbeddings.pretrained() \ + .setInputCols(["image_assembler"]) \ + .setOutputCol("e5v") + pipeline = Pipeline().setStages([imageAssembler, e5v]) + results = pipeline.fit(textDF).transform(textDF) + results.select("e5v.embeddings").show(truncate=True) \ No newline at end of file diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/E5V.scala b/src/main/scala/com/johnsnowlabs/ml/ai/E5V.scala new file mode 100644 index 00000000000000..6653fe97cdebcb --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/ai/E5V.scala @@ -0,0 +1,412 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.johnsnowlabs.ml.ai + +import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig +import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers +import com.johnsnowlabs.ml.openvino.OpenvinoWrapper.E5VWrappers +import com.johnsnowlabs.ml.util.{ONNX, Openvino} +import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT +import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.nlp.annotators.common.SentenceSplit +import com.johnsnowlabs.ml.ai.util.transform.E5VUtils +import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor +import com.johnsnowlabs.nlp.annotators.cv.util.io.ImageIOUtils +import com.johnsnowlabs.nlp.annotators.cv.util.transform.ImageResizeUtils +import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.{BpeTokenizer, LLAVATokenizer, SpecialTokens} +import org.intel.openvino.InferRequest + +private[johnsnowlabs] class E5V( + val onnxWrappers: Option[DecoderWrappers], + val openvinoWrapper: Option[E5VWrappers], + merges: Map[(String, String), Int], + vocabulary: Map[String, Int], + addedTokens: Map[String, Int], + preprocessor: Preprocessor, + generationConfig: GenerationConfig, + imageToken: Int, + imageGridPinpoints: Map[Int, Array[Int]], + patchSize: Int) + extends Serializable { + + val detectedEngine: String = + if (onnxWrappers.isDefined) ONNX.name + else if (openvinoWrapper.isDefined) Openvino.name + else Openvino.name + + private val GenerationConfig( + bosTokenId: Int, + paddingTokenId: Int, + eosTokenId: Int, + vocabSize: Int, + beginSuppressTokens, + suppressTokenIds, + forcedDecoderIds) = + generationConfig + val reversedVocabulary: Map[Int, String] = vocabulary.map(_.swap) + val specialTokens: SpecialTokens = SpecialTokens( + vocabulary, + startTokenString = reversedVocabulary(bosTokenId), + endTokenString = reversedVocabulary(eosTokenId), + unkTokenString = reversedVocabulary(eosTokenId), + maskTokenString = reversedVocabulary(eosTokenId), + padTokenString = reversedVocabulary(paddingTokenId), + additionalStrings = addedTokens.keys.toArray) + + val bpeTokenizer: LLAVATokenizer = BpeTokenizer + .forModel( + "llava", + merges = merges, + vocab = vocabulary, + specialTokens = Some(specialTokens), + addPrefixSpaceToSentence = false, + alwaysAddPrefix = false, + prependString = "") + .asInstanceOf[LLAVATokenizer] + + /** Decode a sequence of sentences + * @param sentences + * Sequence of sentences + * @return + * Sequence of decoded sentences + */ + def decode(sentences: Array[Array[Int]]): Seq[String] = { + sentences.map(s => bpeTokenizer.decodeTokens(s.map(_.toInt))) + } + + /** Encode a sequence of sentences + * @param sentences + * Sequence of sentences + * @return + * Sequence of encoded sentences + */ + def encodeText(sentences: Seq[Annotation]): Seq[Array[Int]] = { + + val tokens = SentenceSplit + .unpack(sentences) + .map(s => { + val sentWithTask = s + bpeTokenizer + .tokenize(sentWithTask) + .map(bpeTokenizer.encode) + .flatMap(_.map(_.pieceId)) + }) + tokens + } + + def encode( + imageAnnotations: Seq[AnnotationImage], + sentences: Seq[Annotation], + preprocessor: Preprocessor): ( + Seq[Array[Int]], + Option[Array[Array[Array[Array[Array[Float]]]]]], + Option[Array[(Int, Int)]]) = { + val encodedText = encodeText(sentences).toArray + + // check if image annotations are present an height and width are > 0 + val imageAnnotationsFiltered = + imageAnnotations.filter(annot => annot.width > 0 && annot.height > 0) + + val preprocessedImages = if (imageAnnotationsFiltered.nonEmpty) { + Some(encodeImage(imageAnnotations.toArray, preprocessor)) + 
} else { + None + } + val imageSizes = if (imageAnnotationsFiltered.nonEmpty) { + Some(imageAnnotations.map(annot => (annot.width, annot.height)).toArray) + } else { + None + } + + (encodedText, preprocessedImages, imageSizes) + } + + def tag( + batch: Seq[Array[Int]], + images: Option[Array[Array[Array[Array[Array[Float]]]]]], + imageSizes: Option[Array[(Int, Int)]]): Array[Array[Float]] = { + + val pixelValues = images + val expandedDecoderInputsVals = batch + val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray + val numReturn_sequences = 1 + // from config + + var effectiveBatch_size = 1 + var effectiveBatch_mult = 1 + + effectiveBatch_size = expandedDecoderInputsVals.length + effectiveBatch_mult = 1 + + val inferRequestLanguageModel = + openvinoWrapper.get.languageModel.getCompiledModel().create_infer_request() + val inferRequestVisionEmbeddingsModel = + openvinoWrapper.get.visionEmbeddingsModel.getCompiledModel().create_infer_request() + val inferRequestTextEmbeddingsModel = + openvinoWrapper.get.textEmbeddingsModel.getCompiledModel().create_infer_request() + val inferRequestImagePackerModel = + openvinoWrapper.get.imagePackerModel.getCompiledModel().create_infer_request() + val inferRequestMergeModel = + openvinoWrapper.get.mergeModel.getCompiledModel().create_infer_request() + + val generatedEmbeddings = getModelOutputs( + decoderInputIds = expandedDecoderInputsVals.toArray, + pixelValues = pixelValues, + imageSizes = imageSizes, + inferRequestLanguageModel = inferRequestLanguageModel, + inferRequestVisionEmbeddingsModel = inferRequestVisionEmbeddingsModel, + inferRequestTextEmbeddingsModel = inferRequestTextEmbeddingsModel, + inferRequestImagePackerModel = inferRequestImagePackerModel, + inferRequestMergeModel = inferRequestMergeModel) + generatedEmbeddings + } + + def predict( + sentences: Seq[Annotation], + imageAnnotations: Seq[AnnotationImage]): Seq[Annotation] = { + + val (encodedText, preprocessedImages, imageSizes) = + encode(imageAnnotations, sentences, preprocessor) + val sentenceEmbeddings = tag(encodedText, preprocessedImages, imageSizes) + + val annotations = sentences.zip(sentenceEmbeddings).map { case (sentence, vectors) => + Annotation( + annotatorType = AnnotatorType.SENTENCE_EMBEDDINGS, + begin = sentence.begin, + end = sentence.end, + result = sentence.result, + metadata = sentence.metadata, + embeddings = vectors) + } + annotations + } + + def getModelOutputs( + decoderInputIds: Array[Array[Int]], + pixelValues: Option[Array[Array[Array[Array[Array[Float]]]]]], + imageSizes: Option[Array[(Int, Int)]], + inferRequestLanguageModel: InferRequest, + inferRequestVisionEmbeddingsModel: InferRequest, + inferRequestTextEmbeddingsModel: InferRequest, + inferRequestImagePackerModel: InferRequest, + inferRequestMergeModel: InferRequest): Array[Array[Float]] = { + + val (inputIdsLong, inputPositionIDsLong): (Array[Long], Array[Long]) = { + // First pass + val inpIdsLong = decoderInputIds.flatMap { tokenIds => tokenIds.map(_.toLong) } + val posIdsLong = decoderInputIds.flatMap { tokenIds => + tokenIds.zipWithIndex.map { case (_, i) => + i.toLong + } + } + (inpIdsLong, posIdsLong) + } + + val attentionMask: Array[Long] = decoderInputIds.flatMap { tokenIds => tokenIds.map(_ => 1L) } + val batchSize: Int = decoderInputIds.length + val shape: Array[Int] = Array(batchSize, inputIdsLong.length / batchSize) + + val decoderAttentionMask: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(Array(batchSize, decoderInputIds.head.length), attentionMask) + 
val decoderPositionIDs: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(shape, inputPositionIDsLong) + + val (finalEmbeds, finalAttentionMask, finalPositionIds) = getMultimodalEmbeddings( + decoderInputIds, + pixelValues, + imageSizes, + decoderAttentionMask, + inferRequestVisionEmbeddingsModel, + inferRequestTextEmbeddingsModel, + inferRequestImagePackerModel, + inferRequestMergeModel) + + inferRequestLanguageModel.set_tensor("inputs_embeds", finalEmbeds) + if (finalAttentionMask.isDefined) { + val finalAttentionMaskFloatTensor = new org.intel.openvino.Tensor( + finalAttentionMask.get.get_shape(), + // flat array of floats of values 1.0 + Array.fill(finalAttentionMask.get.get_shape().product)(1.0f)) + inferRequestLanguageModel.set_tensor("attention_mask", finalAttentionMaskFloatTensor) + } else { + val attentionMaskFloat: Array[Float] = + decoderInputIds.flatMap { tokenIds => tokenIds.map(_ => 1f) } + val attentionMaskFloatTensor = + new org.intel.openvino.Tensor( + Array(batchSize, decoderInputIds.head.length), + attentionMaskFloat) + inferRequestLanguageModel.set_tensor("attention_mask", attentionMaskFloatTensor) + } + if (finalPositionIds.isDefined) { + inferRequestLanguageModel.set_tensor("position_ids", finalPositionIds.get) + } else { + inferRequestLanguageModel.set_tensor("position_ids", decoderPositionIDs) + } + inferRequestLanguageModel.infer() + + val result = inferRequestLanguageModel.get_tensor("last_hidden_state") + val hiddenStateData = result.data() + val hiddenStateShape = result.get_shape() + val batchSizeResult = hiddenStateShape(0) + val hiddenSize = hiddenStateShape(1) + // Reshape to (batch, hidden_size) and return as Array[Array[Float]] + Array.tabulate(batchSizeResult) { b => + val start = b * hiddenSize + val end = start + hiddenSize + hiddenStateData.slice(start, end) + } + + } + + private def encodeImage( + annotations: Array[AnnotationImage], + preprocessor: Preprocessor): Array[Array[Array[Array[Array[Float]]]]] = { + + val batchProcessedImages = annotations.map { annot => + val bufferedImage = ImageIOUtils.byteToBufferedImage( + bytes = annot.result, + w = annot.width, + h = annot.height, + nChannels = annot.nChannels) + val bestResolution = E5VUtils.selectBestResolution( + (bufferedImage.getHeight, bufferedImage.getWidth), + imageGridPinpoints.map { case (_, pinpoints) => + (pinpoints(0), pinpoints(1)) + }.toList) + + val (newHeight, newWidth) = E5VUtils.getPatchOutputSize(bufferedImage, bestResolution) + val resizedForPatches = ImageResizeUtils.resizeBufferedImage( + width = newWidth, + height = newHeight, + resample = preprocessor.resample)(bufferedImage) + + val paddedForPatches = E5VUtils.padImage(resizedForPatches, bestResolution) + + var patches = E5VUtils.divideToPatches(paddedForPatches, patchSize) + + // add the reshaped original image as the first patch + val resizedOriginalImage = ImageResizeUtils.resizeBufferedImage( + width = preprocessor.size, + height = preprocessor.size, + resample = preprocessor.resample)(bufferedImage) + + patches = List(resizedOriginalImage) ++ patches + patches.map { patch => + ImageResizeUtils.normalizeAndConvertBufferedImage( + img = patch, + mean = preprocessor.image_mean, + std = preprocessor.image_std, + doNormalize = preprocessor.do_normalize, + doRescale = preprocessor.do_rescale, + rescaleFactor = preprocessor.rescale_factor) + }.toArray + } + + batchProcessedImages + + } + + def getMultimodalEmbeddings( + inputIds: Array[Array[Int]], + pixelValues: Option[Array[Array[Array[Array[Array[Float]]]]]], + 
imageSizes: Option[Array[(Int, Int)]], + attentionMask: org.intel.openvino.Tensor, + inferRequestVisionEmbeddingsModel: InferRequest, + inferRequestTextEmbeddingsModel: InferRequest, + inferRequestImagePackerModel: InferRequest, + inferRequestMergeModel: InferRequest): ( + org.intel.openvino.Tensor, + Option[org.intel.openvino.Tensor], + Option[org.intel.openvino.Tensor]) = { + + val inputIdsLong: Array[Long] = inputIds.flatMap(_.map(_.toLong)) + val batchSize: Int = inputIds.length + val shape: Array[Int] = Array(batchSize, inputIdsLong.length / batchSize) + val inputIdsLongTensor = new org.intel.openvino.Tensor(shape, inputIdsLong) + + // If pixelValues and imageSizes are present, do multimodal + (pixelValues, imageSizes, attentionMask) match { + case (Some(pixels), Some(sizes), attnMask) if pixels.nonEmpty && sizes.nonEmpty => + // 1. Get image features + val pixelShape = Array( + pixels.length, + pixels.head.length, + pixels.head.head.length, + pixels.head.head.head.length, + pixels.head.head.head.head.length) + // Flatten the pixel values to match the expected input shape + val flattenedPixels = pixels.flatten.flatten.flatten.flatten + val pixelTensor = + new org.intel.openvino.Tensor(pixelShape, flattenedPixels) + + inferRequestVisionEmbeddingsModel.set_tensor("pixel_values", pixelTensor) + inferRequestVisionEmbeddingsModel.infer() + val imageFeatures = inferRequestVisionEmbeddingsModel.get_output_tensor() + + // 2. Compute patch grid shape (dummy for now, should use config) + val (numPatchHeight, numPatchWidth) = + E5VUtils.getAnyResImageGridShape( + imageSizes.get.head, + imageGridPinpoints.map { case (_, pinpoints) => + (pinpoints(0), pinpoints(1)) + }.toList, + preprocessor.size) + + // 3. Pack image features + val imageSizesTensor = new org.intel.openvino.Tensor( + Array(sizes.length, 2), + sizes.flatMap(t => Array(t._1.toLong, t._2.toLong))) + + val numPatchHeightTensor = + new org.intel.openvino.Tensor(Array[Int](), Array(numPatchHeight.toLong)) + + val numPatchWidthTensor = + new org.intel.openvino.Tensor(Array[Int](), Array(numPatchWidth.toLong)) + + inferRequestImagePackerModel.set_tensor("image_feature", imageFeatures) + inferRequestImagePackerModel.set_tensor("image_sizes", imageSizesTensor) + inferRequestImagePackerModel.set_tensor("num_patch_height", numPatchHeightTensor) + inferRequestImagePackerModel.set_tensor("num_patch_width", numPatchWidthTensor) + inferRequestImagePackerModel.infer() + + val packedImageFeatures = inferRequestImagePackerModel.get_output_tensor() + + // 4. Get text embeddings + inferRequestTextEmbeddingsModel.set_input_tensor(inputIdsLongTensor) + inferRequestTextEmbeddingsModel.infer() + val textEmbeddings = inferRequestTextEmbeddingsModel.get_output_tensor() + + // 5. 
Merge image and text embeddings + inferRequestMergeModel.set_tensor("image_features", packedImageFeatures) + inferRequestMergeModel.set_tensor("inputs_embeds", textEmbeddings) + inferRequestMergeModel.set_tensor("input_ids", inputIdsLongTensor) + + inferRequestMergeModel.set_tensor("attention_mask", attnMask) + inferRequestMergeModel.infer() + ( + inferRequestMergeModel.get_tensor("final_embedding"), + Some(inferRequestMergeModel.get_tensor("final_attention_mask")), + Some(inferRequestMergeModel.get_tensor("position_ids"))) + case _ => + // Text-only + inferRequestTextEmbeddingsModel.set_input_tensor(inputIdsLongTensor) + inferRequestTextEmbeddingsModel.infer() + (inferRequestTextEmbeddingsModel.get_output_tensor(), None, None) + } + } + +} diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/util/transform/E5VUtils.scala b/src/main/scala/com/johnsnowlabs/ml/ai/util/transform/E5VUtils.scala new file mode 100644 index 00000000000000..a2561d2bd28613 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/ai/util/transform/E5VUtils.scala @@ -0,0 +1,134 @@ +package com.johnsnowlabs.ml.ai.util.transform + +import java.awt.image.BufferedImage +import java.awt.{Color, Graphics2D} + +object ChannelDimension extends Enumeration { + type ChannelDimension = Value + val FIRST, LAST = Value +} + +object E5VUtils { + import ChannelDimension._ + + def selectBestResolution( + originalSize: (Int, Int), + possibleResolutions: List[(Int, Int)]): (Int, Int) = { + val (originalHeight, originalWidth) = originalSize + var bestFit: (Int, Int) = possibleResolutions.head + var maxEffectiveResolution = 0 + var minWastedResolution = Double.PositiveInfinity + + for ((height, width) <- possibleResolutions) { + val scale = math.min(width.toDouble / originalWidth, height.toDouble / originalHeight) + val downscaledWidth = (originalWidth * scale).toInt + val downscaledHeight = (originalHeight * scale).toInt + val effectiveResolution = + math.min(downscaledWidth * downscaledHeight, originalWidth * originalHeight) + val wastedResolution = (width * height) - effectiveResolution + + if (effectiveResolution > maxEffectiveResolution || + (effectiveResolution == maxEffectiveResolution && wastedResolution < minWastedResolution)) { + maxEffectiveResolution = effectiveResolution + minWastedResolution = wastedResolution + bestFit = (height, width) + } + } + bestFit + } + + def imageSizeToNumPatches( + imageSize: (Int, Int), + gridPinpoints: List[(Int, Int)], + patchSize: Int): Int = { + val (height, width) = selectBestResolution(imageSize, gridPinpoints) + val numPatches = (0 until height by patchSize).size * (0 until width by patchSize).size + // add the base patch + numPatches + 1 + } + + def getAnyResImageGridShape( + imageSize: (Int, Int), + gridPinpoints: List[(Int, Int)], + patchSize: Int): (Int, Int) = { + val (height, width) = selectBestResolution(imageSize, gridPinpoints) + (height / patchSize, width / patchSize) + } + + def getImageSize(image: BufferedImage): (Int, Int) = { + (image.getHeight, image.getWidth) + } + + def expandToSquare(image: BufferedImage, backgroundColor: Color): BufferedImage = { + val width = image.getWidth + val height = image.getHeight + if (width == height) { + image + } else if (width > height) { + val result = new BufferedImage(width, width, image.getType) + val g = result.createGraphics() + g.setColor(backgroundColor) + g.fillRect(0, 0, width, width) + g.drawImage(image, 0, (width - height) / 2, null) + g.dispose() + result + } else { + val result = new BufferedImage(height, height, image.getType) + 
val g = result.createGraphics() + g.setColor(backgroundColor) + g.fillRect(0, 0, height, height) + g.drawImage(image, (height - width) / 2, 0, null) + g.dispose() + result + } + } + + def divideToPatches(image: BufferedImage, patchSize: Int): List[BufferedImage] = { + val width = image.getWidth + val height = image.getHeight + val patches = for { + i <- 0 until height by patchSize + j <- 0 until width by patchSize + } yield { + val w = math.min(patchSize, width - j) + val h = math.min(patchSize, height - i) + image.getSubimage(j, i, w, h) + } + patches.toList + } + + def getPatchOutputSize(image: BufferedImage, targetResolution: (Int, Int)): (Int, Int) = { + val (originalHeight, originalWidth) = getImageSize(image) + val (targetHeight, targetWidth) = targetResolution + + val scaleW = targetWidth.toDouble / originalWidth + val scaleH = targetHeight.toDouble / originalHeight + + if (scaleW < scaleH) { + val newWidth = targetWidth + val newHeight = math.min(math.ceil(originalHeight * scaleW).toInt, targetHeight) + (newHeight, newWidth) + } else { + val newHeight = targetHeight + val newWidth = math.min(math.ceil(originalWidth * scaleH).toInt, targetWidth) + (newHeight, newWidth) + } + } + + def padImage(image: BufferedImage, targetResolution: (Int, Int)): BufferedImage = { + val (targetHeight, targetWidth) = targetResolution + val (originalHeight, originalWidth) = getImageSize(image) + val (newHeight, newWidth) = getPatchOutputSize(image, targetResolution) + val result = new BufferedImage(targetWidth, targetHeight, image.getType) + val g = result.createGraphics() + g.setColor(Color.BLACK) + g.fillRect(0, 0, newWidth, newHeight) + g.drawImage( + image, + (targetWidth - originalWidth) / 2, + (targetHeight - originalHeight) / 2, + null) + g.dispose() + result + } +} diff --git a/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala index 274e085325aaf3..961995ffdd1511 100644 --- a/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala @@ -285,4 +285,10 @@ object OpenvinoWrapper { textEmbeddingsModel: OpenvinoWrapper, imageEmbedModel: OpenvinoWrapper, modelMergerModel: OpenvinoWrapper) + case class E5VWrappers( + languageModel: OpenvinoWrapper, + visionEmbeddingsModel: OpenvinoWrapper, + textEmbeddingsModel: OpenvinoWrapper, + imagePackerModel: OpenvinoWrapper, + mergeModel: OpenvinoWrapper) } diff --git a/src/main/scala/com/johnsnowlabs/nlp/ImageAssembler.scala b/src/main/scala/com/johnsnowlabs/nlp/ImageAssembler.scala index ae620dc78cbaa5..839346b79453b5 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/ImageAssembler.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/ImageAssembler.scala @@ -152,7 +152,21 @@ class ImageAssembler(override val uid: String) result = image.get.data, metadata = metadata, text = text.getOrElse(""))) - } else Seq.empty + } else if (text.isDefined) { + Seq( + AnnotationImage( + annotatorType = outputAnnotatorType, + origin = "", + height = 0, + width = 0, + nChannels = 0, + mode = 0, + result = Array.emptyByteArray, + metadata = metadata, + text = text.getOrElse(""))) + } else { + Seq.empty[AnnotationImage] + } } diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings.scala new file mode 100644 index 00000000000000..657d012c04734e --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings.scala @@ -0,0 
+1,641 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.embeddings + +import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig +import com.johnsnowlabs.ml.ai.E5V +import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadJsonStringAsset, + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor +import com.johnsnowlabs.ml.util.Openvino +import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, IMAGE, SENTENCE_EMBEDDINGS} +import com.johnsnowlabs.nlp._ +import org.json4s.{DefaultFormats, JValue} +import org.json4s.jackson.JsonMethods.parse +import com.johnsnowlabs.ml.openvino.{OpenvinoWrapper, ReadOpenvinoModel, WriteOpenvinoModel} +import com.johnsnowlabs.ml.openvino.OpenvinoWrapper.E5VWrappers +import com.johnsnowlabs.nlp.serialization.{MapFeature, StructFeature} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param.{IntArrayParam, IntParam} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.{Row, SparkSession} + +/** E5VEmbeddings provides universal multimodal embeddings using the E5-V model, which is + * fine-tuned from lmms-lab/llama3-llava-next-8b. + * + * E5-V bridges the modality gap between different input types (text, image) and demonstrates + * strong performance in multimodal embeddings, even without fine-tuning. It also supports a + * single-modality training approach, where the model is trained exclusively on text pairs, often + * yielding better performance than multimodal training. + * + * For more details, see the Hugging Face model card: https://huggingface.co/royokong/e5-v + * + * ==Overview== + * + * E5-V can embed both text and images into a shared space, enabling cross-modal retrieval and + * similarity tasks. The model is designed for universal embeddings and is suitable for scenarios + * where you want to compare or retrieve across modalities. 
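+ *
+ * As a minimal, illustrative sketch (not part of the annotator API; `result` is assumed to be
+ * the DataFrame produced in the examples below), the returned sentence embeddings can be
+ * collected and compared with a plain cosine similarity:
+ * {{{
+ * // one Array[Float] per input row
+ * val vectors: Array[Array[Float]] = result
+ *   .selectExpr("explode(e5v.embeddings) as embedding")
+ *   .collect()
+ *   .map(_.getSeq[Float](0).toArray)
+ *
+ * def cosine(a: Array[Float], b: Array[Float]): Double = {
+ *   val dot = a.zip(b).map { case (x, y) => x.toDouble * y }.sum
+ *   def norm(v: Array[Float]): Double = math.sqrt(v.map(x => x.toDouble * x).sum)
+ *   dot / (norm(a) * norm(b))
+ * }
+ *
+ * // e.g. similarity between the image row and the text row
+ * println(cosine(vectors(0), vectors(1)))
+ * }}}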
+ *
+ * ==Example==
+ *
+ * ===Image + Text Embedding===
+ * {{{
+ * import org.apache.spark.sql.functions.lit
+ * import com.johnsnowlabs.nlp.base.ImageAssembler
+ * import com.johnsnowlabs.nlp.embeddings.E5VEmbeddings
+ * import org.apache.spark.ml.Pipeline
+ *
+ * val imageDF = spark.read.format("image").option("dropInvalid", value = true).load(imageFolder)
+ * val imagePrompt = "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
+ * val testDF = imageDF.withColumn("text", lit(imagePrompt))
+ *
+ * val imageAssembler = new ImageAssembler().setInputCol("image").setOutputCol("image_assembler")
+ * val e5vEmbeddings = E5VEmbeddings.pretrained()
+ *   .setInputCols("image_assembler")
+ *   .setOutputCol("e5v")
+ *
+ * val pipeline = new Pipeline().setStages(Array(imageAssembler, e5vEmbeddings))
+ * val result = pipeline.fit(testDF).transform(testDF)
+ * result.select("e5v.embeddings").show(truncate = false)
+ * }}}
+ *
+ * ===Text-Only Embedding===
+ * {{{
+ * import org.apache.spark.sql.SparkSession
+ * import org.apache.spark.sql.functions.lit
+ * import com.johnsnowlabs.nlp.util.EmbeddingsDataFrameUtils.{emptyImageRow, imageSchema}
+ * import com.johnsnowlabs.nlp.embeddings.E5VEmbeddings
+ *
+ * val spark: SparkSession = ...
+ * val textPrompt = "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
+ * val textDesc = "A cat sitting in a box."
+ * val nullImageDF = spark.createDataFrame(spark.sparkContext.parallelize(Seq(emptyImageRow)), imageSchema)
+ * val textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc)))
+ *
+ * val e5vEmbeddings = E5VEmbeddings.pretrained()
+ *   .setInputCols("image")
+ *   .setOutputCol("e5v")
+ * val result = e5vEmbeddings.transform(textDF)
+ * result.select("e5v.embeddings").show(truncate = false)
+ * }}}
+ *
+ * ==References==
+ *   - Hugging Face model card: https://huggingface.co/royokong/e5-v
+ *   - Paper: https://arxiv.org/abs/2407.12580
+ *   - Code: https://github.com/kongds/E5-V
+ *
+ * @see
+ *   [[CLIPForZeroShotClassification]] for Zero Shot Image Classifier
+ * @see
+ *   [[https://sparknlp.org/docs/en/annotators Annotators Main Page]] for a list of transformer
+ *   based classifiers
+ * @groupname anno Annotator types
+ * @groupdesc anno
+ *   Required input and expected output annotator types
+ * @groupname Ungrouped Members
+ * @groupname param Parameters
+ * @groupname setParam Parameter setters
+ * @groupname getParam Parameter getters
+ * @groupname Ungrouped Members
+ * @groupprio param 1
+ * @groupprio anno 2
+ * @groupprio Ungrouped 3
+ * @groupprio setParam 4
+ * @groupprio getParam 5
+ * @groupdesc param
+ *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
+ *   parameter values through setters and getters, respectively.
+ */
+
+class E5VEmbeddings(override val uid: String)
+    extends AnnotatorModel[E5VEmbeddings]
+    with HasBatchedAnnotateImage[E5VEmbeddings]
+    with HasImageFeatureProperties
+    with WriteOpenvinoModel
+    with HasGeneratorProperties
+    with HasEngine {
+
+  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
+   * type
+   */
+  def this() = this(Identifiable.randomUID("E5VEmbeddings"))
+
+  /** Annotator reference id.
Used to identify elements in metadata or to refer to this annotator + * type + */ + override val inputAnnotatorTypes: Array[AnnotatorType] = Array(IMAGE) + override val outputAnnotatorType: AnnotatorType = SENTENCE_EMBEDDINGS + + /** @group setParam */ + def setRandomSeed(value: Int): E5VEmbeddings.this.type = { + if (randomSeed.isEmpty) { + this.randomSeed = Some(value) + } + this + } + + /** A list of token ids which are ignored in the decoder's output (Default: `Array()`) + * + * @group param + */ + var ignoreTokenIds = new IntArrayParam( + this, + "ignoreTokenIds", + "A list of token ids which are ignored in the decoder's output") + + /** @group setParam */ + def setIgnoreTokenIds(tokenIds: Array[Int]): E5VEmbeddings.this.type = { + set(ignoreTokenIds, tokenIds) + } + + /** @group getParam */ + def getIgnoreTokenIds: Array[Int] = $(ignoreTokenIds) + + /** Vocabulary used to encode the words to ids with bpeTokenizer.encode + * + * @group param + */ + val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected() + + /** @group setParam */ + def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value) + + /** Holding merges.txt coming from RoBERTa model + * + * @group param + */ + val merges: MapFeature[(String, String), Int] = new MapFeature(this, "merges").setProtected() + + /** @group setParam */ + def setMerges(value: Map[(String, String), Int]): this.type = set(merges, value) + + /** Additional tokens to be added to the vocabulary + * + * @group param + */ + val addedTokens: MapFeature[String, Int] = new MapFeature(this, "addedTokens").setProtected() + + /** @group setParam */ + def setAddedTokens(value: Map[String, Int]): this.type = set(addedTokens, value) + + /** Stop tokens to terminate the generation + * + * @group param + */ + override val stopTokenIds = + new IntArrayParam(this, "stopTokenIds", "Stop tokens to terminate the generation") + + /** @group setParam */ + override def setStopTokenIds(value: Array[Int]): this.type = { + set(stopTokenIds, value) + } + + /** @group getParam */ + override def getStopTokenIds: Array[Int] = $(stopTokenIds) + + private var _model: Option[Broadcast[E5V]] = None + val generationConfig: StructFeature[GenerationConfig] = + new StructFeature(this, "generationConfig").setProtected() + + def setGenerationConfig(value: GenerationConfig): this.type = + set(generationConfig, value) + + def getGenerationConfig: GenerationConfig = $$(generationConfig) + + val imageToken = + new IntParam(this, "imageToken", "Token id for image embeddings") + + /** @group setParam */ + def setImageToken(value: Int): this.type = set(imageToken, value) + + /** @group getParam */ + def getImageToken: Int = $(imageToken) + + /** Pinpoints for image grid, used to extract image features from the grid + * + * @group param + */ + val imageGridPinpoints: MapFeature[Int, Array[Int]] = new MapFeature(this, "imageGridPinpoints") + + /** @group setParam */ + def setImageGridPinpoints(value: Map[Int, Array[Int]]): this.type = + set(imageGridPinpoints, value) + + /** @group getParam */ + def getImageGridPinpoints: Map[Int, Array[Int]] = $$(imageGridPinpoints) + + /** Patch size for image embeddings + * + * @group param + */ + val patchSize: IntParam = + new IntParam(this, "patchSize", "Patch size for image embeddings, default is 336") + + /** @group setParam */ + def setPatchSize(value: Int): this.type = set(patchSize, value) + + /** @group getParam */ + def getPatchSize: Int = $(patchSize) + + /** @group setParam */ + def 
setModelIfNotSet( + spark: SparkSession, + preprocessor: Preprocessor, + onnxWrappers: Option[DecoderWrappers], + openvinoWrapper: Option[E5VWrappers]): this.type = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new E5V( + onnxWrappers, + openvinoWrapper, + $$(merges), + $$(vocabulary), + $$(addedTokens), + preprocessor, + generationConfig = getGenerationConfig, + imageToken = getImageToken, + imageGridPinpoints = getImageGridPinpoints, + patchSize = getPatchSize))) + } + this + } + + /** @group getParam */ + def getModelIfNotSet: E5V = _model.get.value + + setDefault( + minOutputLength -> 0, + maxOutputLength -> 20, + doSample -> false, + temperature -> 0.6, + topK -> -1, + topP -> 0.9, + repetitionPenalty -> 1.0, + noRepeatNgramSize -> 3, + ignoreTokenIds -> Array(), + batchSize -> 1, + beamSize -> 1, + maxInputLength -> 4096, + stopTokenIds -> Array(2), + imageToken -> 128256, + patchSize -> 336) + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + * Annotations in batches that correspond to inputAnnotationCols generated by previous + * annotators if any + * @return + * any number of annotations processed for every batch of input annotations. Not necessary + * one to one relationship + */ + override def batchAnnotate( + batchedAnnotations: Seq[Array[AnnotationImage]]): Seq[Seq[Annotation]] = { + + batchedAnnotations + // .filter { annotationImages => + // annotationImages.exists(_.text.nonEmpty) + // } + .map { cleanAnnotationImages => + val validImages = cleanAnnotationImages + val questionAnnotations = extractInputAnnotation(validImages) + + getModelIfNotSet.predict(questionAnnotations, validImages.toSeq) + } + } + + private def extractInputAnnotation( + annotationImages: Array[AnnotationImage]): Seq[Annotation] = { + val questions = annotationImages.map(annotationImage => { + val imageText = + if (annotationImage.text.nonEmpty) annotationImage.text + else + "<|user|> \n <|image|> This is an image\n <|end|>\n <|assistant|>\n" // default question + Annotation(imageText) + }) + + questions + } + + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + getEngine match { + case Openvino.name => + val wrappers = getModelIfNotSet.openvinoWrapper + writeOpenvinoModels( + path, + spark, + Seq((wrappers.get.languageModel, "openvino_language_model-int4.xml")), + E5VEmbeddings.suffix) + + writeOpenvinoModels( + path, + spark, + Seq((wrappers.get.visionEmbeddingsModel, "openvino_vision_embeddings_model.xml")), + E5VEmbeddings.suffix) + + writeOpenvinoModels( + path, + spark, + Seq((wrappers.get.textEmbeddingsModel, "openvino_text_embeddings_model.xml")), + E5VEmbeddings.suffix) + + writeOpenvinoModels( + path, + spark, + Seq((wrappers.get.imagePackerModel, "openvino_image_packer.xml")), + E5VEmbeddings.suffix) + + writeOpenvinoModels( + path, + spark, + Seq((wrappers.get.mergeModel, "openvino_multimodal_merger.xml")), + E5VEmbeddings.suffix) + case _ => + throw new Exception(notSupportedEngineError) + } + } + +} + +trait ReadablePretrainedE5VEmbeddings + extends ParamsAndFeaturesReadable[E5VEmbeddings] + with HasPretrained[E5VEmbeddings] { + + override val defaultModelName: Some[String] = Some("e5v_1_5_7b_int4") + + /** Java compliant-overrides */ + override def pretrained(): E5VEmbeddings = super.pretrained() + + override def pretrained(name: String): E5VEmbeddings = + super.pretrained(name) + + override def pretrained(name: String, 
lang: String): E5VEmbeddings = + super.pretrained(name, lang) + + override def pretrained(name: String, lang: String, remoteLoc: String): E5VEmbeddings = + super.pretrained(name, lang, remoteLoc) + +} + +trait ReadE5VEmbeddingsDLModel extends ReadOpenvinoModel { + this: ParamsAndFeaturesReadable[E5VEmbeddings] => + val suffix: String = "_e5v" + override val openvinoFile: String = "e5v_openvino" + def readModel(instance: E5VEmbeddings, path: String, spark: SparkSession): Unit = { + instance.getEngine match { + case Openvino.name => + val languageModelWrappers = + readOpenvinoModels(path, spark, Seq("openvino_language_model-int4.xml"), suffix) + + val visionEmbeddingsModelWrappers = + readOpenvinoModels(path, spark, Seq("openvino_vision_embeddings_model.xml"), suffix) + + val textEmbeddingsModelWrappers = + readOpenvinoModels(path, spark, Seq("openvino_text_embeddings_model.xml"), suffix) + + val imagePackerModelWrappers = + readOpenvinoModels(path, spark, Seq("openvino_image_packer.xml"), suffix) + + val mergeModelWrappers = + readOpenvinoModels(path, spark, Seq("openvino_multimodal_merger.xml"), suffix) + + val ovWrapper = E5VWrappers( + languageModel = languageModelWrappers("openvino_language_model-int4.xml"), + visionEmbeddingsModel = + visionEmbeddingsModelWrappers("openvino_vision_embeddings_model.xml"), + textEmbeddingsModel = textEmbeddingsModelWrappers("openvino_text_embeddings_model.xml"), + mergeModel = mergeModelWrappers("openvino_multimodal_merger.xml"), + imagePackerModel = imagePackerModelWrappers("openvino_image_packer.xml")) + val preprocessor = Preprocessor( + do_normalize = true, + do_resize = true, + "E5VFeatureExtractor", + instance.getImageMean, + instance.getImageStd, + instance.getResample, + instance.getSize) + instance.setModelIfNotSet(spark, preprocessor, None, Some(ovWrapper)) + case _ => { + throw new Exception(notSupportedEngineError) + } + } + } + + addReader(readModel) + + def loadSavedModel( + modelPath: String, + spark: SparkSession, + useOpenvino: Boolean = false): E5VEmbeddings = { + implicit val formats: DefaultFormats.type = DefaultFormats // for json4 + val (localModelPath, detectedEngine) = + modelSanityCheck( + modelPath, + isDecoder = false, + custom = Some( + List( + "openvino_language_model-int4", + "openvino_vision_embeddings_model", + "openvino_text_embeddings_model", + "openvino_image_packer", + "openvino_multimodal_merger"))) + val modelConfig: JValue = + parse(loadJsonStringAsset(localModelPath, "config.json")) + + val preprocessorConfigJsonContent = + loadJsonStringAsset(localModelPath, "preprocessor_config.json") + val preprocessorConfig = Preprocessor.loadPreprocessorConfig(preprocessorConfigJsonContent) + val beginSuppressTokens: Array[Int] = + (modelConfig \ "begin_suppress_tokens").extract[Array[Int]] + + val suppressTokenIds: Array[Int] = + (modelConfig \ "suppress_tokens").extract[Array[Int]] + + val forcedDecoderIds: Array[(Int, Int)] = + (modelConfig \ "forced_decoder_ids").extract[Array[Array[Int]]].map { + case idxWithTokenId: Array[Int] if idxWithTokenId.length == 2 => + (idxWithTokenId(0), idxWithTokenId(1)) + case _ => + throw new Exception( + "Could not extract forced_decoder_ids. 
Should be a list of tuples with 2 entries.") + } + + def arrayOrNone[T](array: Array[T]): Option[Array[T]] = + if (array.nonEmpty) Some(array) else None + + val bosTokenId = (modelConfig \ "text_config" \ "bos_token_id").extract[Int] + val eosTokenId = (modelConfig \ "text_config" \ "eos_token_id").extract[Int] + val padTokenId = (modelConfig \ "text_config" \ "eos_token_id").extract[Int] + val vocabSize = (modelConfig \ "text_config" \ "vocab_size").extract[Int] + + val imageToken = (modelConfig \ "image_token_index").extract[Int] + val imageGridPinpoints: Array[Array[Int]] = + (modelConfig \ "image_grid_pinpoints").extract[Array[Array[Int]]] + val imageGridPinpointsMap: Map[Int, Array[Int]] = + imageGridPinpoints.zipWithIndex.map { case (pinpoints, index) => + (index, pinpoints) + }.toMap + // Check if tokenizer.json exists + val tokenizerPath = s"$localModelPath/assets/tokenizer.json" + val tokenizerExists = new java.io.File(tokenizerPath).exists() + val (vocabs, addedTokens, bytePairs) = if (tokenizerExists) { + val tokenizerConfig: JValue = parse(loadJsonStringAsset(localModelPath, "tokenizer.json")) + // extract vocab from tokenizer.json ( model -> vocab) + var vocabs: Map[String, Int] = + (tokenizerConfig \ "model" \ "vocab").extract[Map[String, Int]] + + // extract merges from tokenizer.json ( model -> merges) +// val bytePairs = (tokenizerConfig \ "model" \ "merges") +// .extract[List[Array[String]]] +// .filter(w => w.length == 2) +// .map { case Array(c1, c2) => (c1, c2) } +// .zipWithIndex +// .toMap + val bytePairs = (tokenizerConfig \ "model" \ "merges") + .extract[List[String]] + .map(_.split(" ")) + .filter(w => w.length == 2) + .map { case Array(c1, c2) => (c1, c2) } + .zipWithIndex + .toMap + // extract added_tokens from tokenizer.json (added_tokens) + // "added_tokens": [ + // { + // "id": 128000, + // "content": "<|begin_of_text|>", + // "single_word": false, + // "lstrip": false, + // "rstrip": false, + // "normalized": false, + // "special": true + // }, ... 
+ // ] + val addedTokens = (tokenizerConfig \ "added_tokens") + .extract[List[Map[String, Any]]] + .map { token => + val id = token("id").asInstanceOf[BigInt].intValue() + val content = token("content").asInstanceOf[String] + (content, id) + } + .toMap + + // update vocab with added tokens + addedTokens.foreach { case (content, id) => + vocabs += (content -> id) + } + (vocabs, addedTokens, bytePairs) + } else { + val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap + val addedTokens = loadTextAsset(localModelPath, "added_tokens.txt").zipWithIndex.toMap + val bytePairs = loadTextAsset(localModelPath, "merges.txt") + .map(_.split(" ")) + .filter(w => w.length == 2) + .map { case Array(c1, c2) => (c1, c2) } + .zipWithIndex + .toMap + (vocabs, addedTokens, bytePairs) + } + + val annotatorModel = new E5VEmbeddings() + .setGenerationConfig( + GenerationConfig( + bosTokenId, + padTokenId, + eosTokenId, + vocabSize, + arrayOrNone(beginSuppressTokens), + arrayOrNone(suppressTokenIds), + arrayOrNone(forcedDecoderIds))) + .setVocabulary(vocabs) + .setMerges(bytePairs) + .setAddedTokens(addedTokens) + .setImageToken(imageToken) + .setSize(preprocessorConfig.size) + .setImageMean(preprocessorConfig.image_mean) + .setImageStd(preprocessorConfig.image_std) + .setResample(preprocessorConfig.resample) + .setImageGridPinpoints(imageGridPinpointsMap) + + val modelEngine = + if (useOpenvino) + Openvino.name + else + detectedEngine + annotatorModel.set(annotatorModel.engine, modelEngine) + + detectedEngine match { + case Openvino.name => + val visionWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine, + modelName = "openvino_vision_embeddings_model") + val textWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine, + modelName = "openvino_text_embeddings_model") + + val imagePackerModelWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine, + modelName = "openvino_image_packer") + + val mergeWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine, + modelName = "openvino_multimodal_merger") + val languageModelWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine, + modelName = "openvino_language_model-int4") + + val openvinoWrapper = E5VWrappers( + languageModel = languageModelWrapper, + visionEmbeddingsModel = visionWrapper, + textEmbeddingsModel = textWrapper, + imagePackerModel = imagePackerModelWrapper, + mergeModel = mergeWrapper) + annotatorModel.setModelIfNotSet(spark, preprocessorConfig, None, Some(openvinoWrapper)) + case _ => + throw new Exception(notSupportedEngineError) + } + + annotatorModel + } +} + +object E5VEmbeddings extends ReadablePretrainedE5VEmbeddings with ReadE5VEmbeddingsDLModel diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala index 6cb1ab2aa565f0..fcd7c11be57029 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala @@ -710,7 +710,8 @@ object PythonResourceDownloader { "PaliGemmaForMultiModal" -> PaliGemmaForMultiModal, "Gemma3ForMultiModal" -> 
Gemma3ForMultiModal, "InternVLForMultiModal" -> InternVLForMultiModal, - "Florence2Transformer" -> Florence2Transformer) + "Florence2Transformer" -> Florence2Transformer, + "E5VEmbeddings" -> E5VEmbeddings) // List pairs of types such as the one with key type can load a pretrained model from the value type val typeMapper: Map[String, String] = Map("ZeroShotNerModel" -> "RoBertaForQuestionAnswering") diff --git a/src/main/scala/com/johnsnowlabs/nlp/util/EmbeddingsDataFrameUtils.scala b/src/main/scala/com/johnsnowlabs/nlp/util/EmbeddingsDataFrameUtils.scala new file mode 100644 index 00000000000000..5c701306da0c22 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/util/EmbeddingsDataFrameUtils.scala @@ -0,0 +1,22 @@ +package com.johnsnowlabs.nlp.util + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +object EmbeddingsDataFrameUtils { + // Schema Spark expects for `format("image")` + val imageSchema: StructType = StructType( + Seq( + StructField( + "image", + StructType(Seq( + StructField("origin", StringType, true), + StructField("height", IntegerType, true), + StructField("width", IntegerType, true), + StructField("nChannels", IntegerType, true), + StructField("mode", IntegerType, true), + StructField("data", BinaryType, true)))))) + + // A reusable null image row for text-only embedding scenarios + val emptyImageRow: Row = Row(Row("", 0, 0, 0, 0, Array[Byte]())) +} diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddingsTestSpec.scala new file mode 100644 index 00000000000000..9bc4d00f9b7eb2 --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddingsTestSpec.scala @@ -0,0 +1,87 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.johnsnowlabs.nlp.embeddings + +import com.johnsnowlabs.nlp.{AssertAnnotations, ImageAssembler} +import com.johnsnowlabs.nlp.util.io.ResourceHelper +import com.johnsnowlabs.tags.{FastTest, SlowTest} +import org.apache.spark.ml.Pipeline +import org.apache.spark.sql.{DataFrame, Encoder, Encoders} +import org.apache.spark.sql.functions.{col, lit, size} +import org.scalatest.flatspec.AnyFlatSpec +import com.johnsnowlabs.nlp.util.EmbeddingsDataFrameUtils.{emptyImageRow, imageSchema} + +class E5VEmbeddingsTestSpec extends AnyFlatSpec { + lazy val model = getE5VEmbeddingsPipelineModel + + val textPrompt = + "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" + val imagePrompt = + "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" + + "E5V Embeddings" should "correctly embed sentences" taggedAs SlowTest in { + val testDF = getTestDF + val result = model.transform(testDF) + + result.select("e5v.embeddings").show(true) + + } + + private def getTestDF: DataFrame = { + val imageFolder = "src/test/resources/image1/" + val imageDF: DataFrame = ResourceHelper.spark.read + .format("image") + .option("dropInvalid", value = true) + .load(imageFolder) + + val testDF: DataFrame = imageDF.withColumn("text", lit(imagePrompt)) + val textDesc = "A cat sitting in a box." + + // Create DataFrame with a single null image row + val spark = ResourceHelper.spark + val nullImageDF = + spark.createDataFrame(spark.sparkContext.parallelize(Seq(emptyImageRow)), imageSchema) + + val textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc))) + + testDF.union(textDF) +// textDF + } + private def getE5VEmbeddingsPipelineModel = { + val testDF = getTestDF + + val imageAssembler: ImageAssembler = new ImageAssembler() + .setInputCol("image") + .setOutputCol("image_assembler") + + val loadModel = E5VEmbeddings + .pretrained() + .setInputCols("image_assembler") + .setOutputCol("e5v") + + val newPipeline: Pipeline = + new Pipeline().setStages(Array(imageAssembler, loadModel)) + + val pipelineModel = newPipeline.fit(testDF) + + pipelineModel + .transform(testDF) + .show(truncate = true) + + pipelineModel + } +}