diff --git a/docs/en/transformer_entries/E5VEmbeddings.md b/docs/en/transformer_entries/E5VEmbeddings.md new file mode 100644 index 00000000000000..68ff482cd2f900 --- /dev/null +++ b/docs/en/transformer_entries/E5VEmbeddings.md @@ -0,0 +1,133 @@ +{%- capture title -%} +E5VEmbeddings +{%- endcapture -%} + +{%- capture description -%} +Universal multimodal embeddings using E5-V. + +E5-V is a multimodal embedding model that bridges the modality gap between text and images, enabling strong performance in cross-modal retrieval, classification, clustering, and more. It supports both image+text and text-only embedding scenarios, and is fine-tuned from lmms-lab/llama3-llava-next-8b. The default model is `"e5v_int4"`. + +Note that this annotator is only supported for Spark Versions 3.4 and up. + +Pretrained models can be loaded with `pretrained` of the companion object: + +```scala +val embeddings = E5VEmbeddings.pretrained() + .setInputCols("image_assembler") + .setOutputCol("e5v") +``` + +For available pretrained models please see the +[Models Hub](https://sparknlp.org/models?q=E5V). + +For extended examples of usage, see +[E5VEmbeddingsTestSpec](https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddingsTestSpec.scala). + +**Sources** : + +- [E5-V: Universal Embeddings with Multimodal Large Language Models (arXiv)](https://arxiv.org/abs/2407.12580) +- [Hugging Face Model Card](https://huggingface.co/royokong/e5-v) +- [E5-V Github Repository](https://github.com/kongds/E5-V) +{%- endcapture -%} + +{%- capture input_anno -%} +IMAGE +{%- endcapture -%} + +{%- capture output_anno -%} +SENTENCE_EMBEDDINGS +{%- endcapture -%} + +{%- capture python_example -%} +# Image + Text Embedding +import sparknlp +from sparknlp.base import * +from sparknlp.annotator import * +from pyspark.ml import Pipeline +from pyspark.sql.functions import lit + +image_df = spark.read.format("image").option("dropInvalid", True).load(imageFolder) +imagePrompt = "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" +test_df = image_df.withColumn("text", lit(imagePrompt)) +imageAssembler = ImageAssembler() \ + .setInputCol("image") \ + .setOutputCol("image_assembler") +e5vEmbeddings = E5VEmbeddings.pretrained() \ + .setInputCols(["image_assembler"]) \ + .setOutputCol("e5v") +pipeline = Pipeline().setStages([ + imageAssembler, + e5vEmbeddings +]) +result = pipeline.fit(test_df).transform(test_df) +result.select("e5v.embeddings").show(truncate=False) + +# Text-Only Embedding +from sparknlp.util import EmbeddingsDataFrameUtils +textPrompt = "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" +textDesc = "A cat sitting in a box." 
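+# Text-only embedding: the annotator still reads the IMAGE input column, so build a
+# placeholder (empty) image row with the expected schema using EmbeddingsDataFrameUtils.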
+nullImageDF = spark.createDataFrame( + spark.sparkContext.parallelize([EmbeddingsDataFrameUtils.emptyImageRow]), + EmbeddingsDataFrameUtils.imageSchema) +textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc))) +e5vEmbeddings = E5VEmbeddings.pretrained() \ + .setInputCols(["image"]) \ + .setOutputCol("e5v") +result = e5vEmbeddings.transform(textDF) +result.select("e5v.embeddings").show(truncate=False) +{%- endcapture -%} + +{%- capture scala_example -%} +// Image + Text Embedding +import org.apache.spark.sql.functions.lit +import com.johnsnowlabs.nlp.base.ImageAssembler +import com.johnsnowlabs.nlp.embeddings.E5VEmbeddings +import org.apache.spark.ml.Pipeline + +val imageDF = spark.read.format("image").option("dropInvalid", value = true).load(imageFolder) +val imagePrompt = "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" +val testDF = imageDF.withColumn("text", lit(imagePrompt)) +val imageAssembler = new ImageAssembler().setInputCol("image").setOutputCol("image_assembler") +val e5vEmbeddings = E5VEmbeddings.pretrained() + .setInputCols("image_assembler") + .setOutputCol("e5v") +val pipeline = new Pipeline().setStages(Array(imageAssembler, e5vEmbeddings)) +val result = pipeline.fit(testDF).transform(testDF) +result.select("e5v.embeddings").show(truncate = false) + +// Text-Only Embedding +import com.johnsnowlabs.nlp.util.EmbeddingsDataFrameUtils.{emptyImageRow, imageSchema} +val textPrompt = "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" +val textDesc = "A cat sitting in a box." +val nullImageDF = spark.createDataFrame(spark.sparkContext.parallelize(Seq(emptyImageRow)), imageSchema) +val textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc))) +val e5vEmbeddings = E5VEmbeddings.pretrained() + .setInputCols("image") + .setOutputCol("e5v") +val result2 = e5vEmbeddings.transform(textDF) +result2.select("e5v.embeddings").show(truncate = false) +{%- endcapture -%} + +{%- capture api_link -%} +[E5VEmbeddings](/api/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings) +{%- endcapture -%} + +{%- capture python_api_link -%} +[E5VEmbeddings](/api/python/reference/autosummary/sparknlp/annotator/cv/e5v_embeddings/index.html#sparknlp.annotator.cv.e5v_embeddings.E5VEmbeddings) +{%- endcapture -%} + +{%- capture source_link -%} +[E5VEmbeddings](https://github.com/JohnSnowLabs/spark-nlp/tree/master/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings.scala) +{%- endcapture -%} + +{% include templates/anno_template.md + title=title + description=description + input_anno=input_anno + output_anno=output_anno + python_example=python_example + scala_example=scala_example + api_link=api_link + python_api_link=python_api_link + source_link=source_link +%} \ No newline at end of file diff --git a/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_E5VEmbeddings.ipynb b/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_E5VEmbeddings.ipynb new file mode 100644 index 00000000000000..a0757f51a05f0e --- /dev/null +++ b/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_E5VEmbeddings.ipynb @@ -0,0 +1,1530 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c9c512f5", + "metadata": {}, + "source": [ + "\n", + "\n", + 
"[](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_E5V.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "860a240a", + "metadata": {}, + "source": [ + "# Import OpenVINO E5V models from HuggingFace 🤗 into Spark NLP 🚀\n", + "\n", + "This notebook provides a detailed walkthrough on optimizing and importing E5V models from HuggingFace for use in Spark NLP, with [Intel OpenVINO toolkit](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/overview.html). The focus is on converting the model to the OpenVINO format and applying precision optimizations (INT8 and INT4), to enhance the performance and efficiency on CPU platforms using [Optimum Intel](https://huggingface.co/docs/optimum/main/en/intel/inference).\n", + "\n", + "Let's keep in mind a few things before we start 😊\n", + "\n", + "- OpenVINO support was introduced in `Spark NLP 5.4.0`, enabling high performance CPU inference for models. So please make sure you have upgraded to the latest Spark NLP release.\n", + "- Model quantization is a computationally expensive process, so it is recommended to use a runtime with more than 32GB memory for exporting the quantized model from HuggingFace.\n", + "- You can import E5V models via `E5V`. These models are usually under `Text Generation` category and have `E5V` in their labels.\n", + "- Reference: [E5V](https://huggingface.co/docs/transformers/model_doc/llama#transformers.E5V)\n", + "- Some [example models](https://huggingface.co/models?search=E5V)" + ] + }, + { + "cell_type": "markdown", + "id": "100a6911", + "metadata": {}, + "source": [ + "## 1. Export and Save the HuggingFace model\n", + "\n", + "- Let's install `transformers` and `openvino` packages with other dependencies. You don't need `openvino` to be installed for Spark NLP, however, we need it to load and save models from HuggingFace.\n", + "- We lock `transformers` on version `4.41.2`. This doesn't mean it won't work with the future release, but we wanted you to know which versions have been tested successfully." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "902635c5", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "529ad224", + "metadata": {}, + "outputs": [], + "source": [ + "# # Install OpenVINO and NNCF for model optimization\n", + "import platform\n", + "\n", + "%pip install -q \"einops\" \"torch>2.1\" \"torchvision\" \"matplotlib>=3.4\" \"timm>=0.9.8\" \"transformers==4.41.2\" \"pillow\" \"gradio>=4.19\" --extra-index-url https://download.pytorch.org/whl/cpu\n", + "%pip install -q -U --pre \"openvino>=2025.0\" \"openvino-tokenizers>=2025.0\" \"openvino-genai>=2025.0\" --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly\n", + "%pip install -q \"accelerate\" \"nncf>=2.14.0\" \"git+https://github.com/huggingface/optimum-intel.git\" --extra-index-url https://download.pytorch.org/whl/cpu\n", + "\n", + "if platform.system() == \"Darwin\":\n", + " %pip install -q \"numpy<2.0.0\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3997e780", + "metadata": {}, + "outputs": [], + "source": [ + "!wget https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg -O dog.jpg" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c1623528", + "metadata": {}, + "outputs": [], + "source": [ + "model_id = \"royokong/e5-v\"\n", + "output_dir = f\"./models/int4/{model_id}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "46678a0b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", + "Loading checkpoint shards: 100%|██████████| 4/4 [01:20<00:00, 20.18s/it]\n" + ] + }, + { + "data": { + "text/plain": [ + "111" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration\n", + "import torch\n", + "import gc\n", + "\n", + "processor = LlavaNextProcessor.from_pretrained(model_id)\n", + "image_encoder_model, input_embedding_model, language_model = None, None, None\n", + "\n", + "\n", + "class ImageEncoder(torch.nn.Module):\n", + " def __init__(self, config, vision_tower, multi_modal_projector):\n", + " super().__init__()\n", + " self.config = config\n", + " self.vision_tower = vision_tower\n", + " self.multi_modal_projector = multi_modal_projector\n", + "\n", + " def forward(self, pixel_values):\n", + " batch_size, num_patches, num_channels, height, width = pixel_values.shape\n", + " reshaped_pixel_values = pixel_values.view(\n", + " batch_size * num_patches, num_channels, height, width\n", + " )\n", + " image_features = self.vision_tower(\n", + " reshaped_pixel_values, output_hidden_states=True\n", + " )\n", + " selected_image_feature = image_features.hidden_states[\n", + " self.config.vision_feature_layer\n", + " ]\n", + " if self.config.vision_feature_select_strategy == \"default\":\n", + " selected_image_feature = selected_image_feature[:, 1:]\n", + " elif self.config.vision_feature_select_strategy == \"full\":\n", + " selected_image_feature = selected_image_feature\n", + " image_features = self.multi_modal_projector(selected_image_feature)\n", + " return image_features\n", + "\n", + "\n", + "model = LlavaNextForConditionalGeneration.from_pretrained(\n", + " model_id, low_cpu_mem_usage=True\n", + ")\n", + "model.config.save_pretrained(output_dir)\n", + "image_encoder_model = ImageEncoder(\n", + " model.config, model.vision_tower, model.multi_modal_projector\n", + ")\n", + "input_embedding_model = input_embedding_model = model.get_input_embeddings()\n", + "language_model = model.language_model\n", + "del model\n", + "gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1908bc09", + "metadata": {}, + "outputs": [], + "source": [ + "import openvino as ov\n", + "from pathlib import Path\n", + "\n", + "core = ov.Core()\n", + "device = \"CPU\"\n", + "# Load the model and convert it to OpenVINO format\n", + "output_dir = f\"./models/int4/{model_id}\"\n", + "output_dir = Path(output_dir)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2341d4b8", + "metadata": {}, + "outputs": [], + "source": [ + "IMAGE_ENCODER_PATH = output_dir / \"openvino_vision_embeddings_model.xml\"\n", + "LANGUAGE_MODEL_PATH = output_dir / \"openvino_language_model.xml\"\n", + "INPUT_EMBEDDING_PATH = output_dir / \"openvino_text_embeddings_model.xml\"\n", + "\n", + "IMAGE_PACKER_PATH = output_dir / \"openvino_image_packer.xml\"\n", + "MULTIMODAL_MERGER_PATH = output_dir / \"openvino_multimodal_merger.xml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6a0e77cd", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/transformers/modeling_utils.py:4481: 
FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead\n", + " warnings.warn(\n", + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/transformers/models/clip/modeling_clip.py:276: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", + " if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n", + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/transformers/models/clip/modeling_clip.py:316: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", + " if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):\n" + ] + }, + { + "data": { + "text/plain": [ + "7397" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import openvino as ov\n", + "import gc\n", + "\n", + "\n", + "def cleanup_torchscript_cache():\n", + " \"\"\"\n", + " Helper for removing cached model representation\n", + " \"\"\"\n", + " torch._C._jit_clear_class_registry()\n", + " torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()\n", + " torch.jit._state._clear_class_state()\n", + "\n", + "\n", + "if not IMAGE_ENCODER_PATH.exists():\n", + " ov_image_encoder = ov.convert_model(\n", + " image_encoder_model, example_input=torch.zeros((1, 5, 3, 336, 336))\n", + " )\n", + " ov.save_model(ov_image_encoder, IMAGE_ENCODER_PATH)\n", + " del ov_image_encoder\n", + " cleanup_torchscript_cache()\n", + "\n", + "del image_encoder_model\n", + "gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0147d547", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "117" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llm_input = None\n", + "\n", + "llm_input = input_embedding_model(torch.ones((2, 2), dtype=torch.int64))\n", + "\n", + "if not INPUT_EMBEDDING_PATH.exists():\n", + " ov_input_embeddings_model = ov.convert_model(\n", + " input_embedding_model, example_input=torch.ones((2, 2), dtype=torch.int64)\n", + " )\n", + " ov.save_model(ov_input_embeddings_model, INPUT_EMBEDDING_PATH)\n", + " del ov_input_embeddings_model\n", + " cleanup_torchscript_cache()\n", + "\n", + "del input_embedding_model\n", + "gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "18b0be05", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/openvino/runtime/__init__.py:10: DeprecationWarning: The `openvino.runtime` module is deprecated and will be removed in the 2026.0 release. 
Please replace `openvino.runtime` with `openvino`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from typing import Optional, Tuple, List\n", + "from openvino.runtime import opset13\n", + "import numpy as np\n", + "\n", + "\n", + "def model_has_state(ov_model: ov.Model):\n", + " return len(ov_model.get_sinks()) > 0\n", + "\n", + "\n", + "def model_has_input_output_name(ov_model: ov.Model, name: str):\n", + " \"\"\"\n", + " Helper function for checking that model has specified input or output name\n", + "\n", + " Parameters:\n", + " ov_model (ov.Model):\n", + " name (str):\n", + " name of input or output\n", + "\n", + " Returns:\n", + " True if input or output with requested name exists else False\n", + " \"\"\"\n", + " return name in sum(\n", + " [list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], []\n", + " )\n", + "\n", + "\n", + "def fuse_cache_reorder(\n", + " ov_model: ov.Model,\n", + " not_kv_inputs: List[str],\n", + " key_value_input_names: List[str],\n", + " gather_dim: int,\n", + "):\n", + " \"\"\"\n", + " Fuses reored_cache during generate cycle into ov.Model. Used with stateful models, because we can not modify model state directly.\n", + "\n", + " Adds a new beam_idx parameter and Gather op per each kv-cache input in a given model.\n", + " Should be run before make_stateful. Implements optimumum's _reorder_cache\n", + " inside the model in the beginning of each iteration.\n", + " Gather works along given gather_dim dimension that may vary from model to model.\n", + " KV-cache inputs are identified based on names in key_value_input_names.\n", + " Append the new beam_idx parameter to not_kv_inputs.\n", + "\n", + " Parameters:\n", + " ov_model (`ov.Model`):\n", + " openvino model for processing\n", + " not_kv_inputs (`List[str]`):\n", + " list of input nodes in model that not related to past key values\n", + " key_value_input_names (`List[str]`):\n", + " list of names for key value input layers\n", + " gather_dim (int):\n", + " dimension for gathering cache during reorder pass\n", + " \"\"\"\n", + "\n", + " if model_has_input_output_name(ov_model, \"beam_idx\"):\n", + " raise ValueError(\"Model already has fused cache\")\n", + " input_batch = ov_model.input(\"inputs_embeds\").get_partial_shape()[0]\n", + " beam_idx = opset13.parameter(\n", + " name=\"beam_idx\", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch])\n", + " )\n", + " beam_idx.output(0).get_tensor().add_names({\"beam_idx\"}) # why list is not accepted?\n", + " ov_model.add_parameters([beam_idx])\n", + " not_kv_inputs.append(ov_model.inputs[-1])\n", + " # Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx\n", + " for input_name in key_value_input_names:\n", + " parameter_output_port = ov_model.input(input_name)\n", + " consumers = parameter_output_port.get_target_inputs()\n", + " gather = opset13.gather(\n", + " parameter_output_port, beam_idx, opset13.constant(gather_dim)\n", + " )\n", + " for consumer in consumers:\n", + " consumer.replace_source_output(gather.output(0))\n", + " ov_model.validate_nodes_and_infer_types()\n", + "\n", + "\n", + "def build_state_initializer(ov_model: ov.Model, batch_dim: int):\n", + " \"\"\"\n", + " Build initialization ShapeOf Expression for all ReadValue ops\n", + "\n", + " Parameters:\n", + " ov_model (ov.Model):\n", + " openvino model\n", + " batch_dim (int):\n", + " index of dimension corresponding to batch size\n", + " \"\"\"\n", + " input_ids = ov_model.input(\"inputs_embeds\")\n", + " batch = 
opset13.gather(\n", + " opset13.shape_of(input_ids, output_type=\"i64\"),\n", + " opset13.constant([0]),\n", + " opset13.constant(0),\n", + " )\n", + " for op in ov_model.get_ops():\n", + " if op.get_type_name() == \"ReadValue\":\n", + " dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))]\n", + " dims[batch_dim] = batch\n", + " dims = [\n", + " (\n", + " opset13.constant(np.array([dim], dtype=np.int64))\n", + " if isinstance(dim, int)\n", + " else dim\n", + " )\n", + " for dim in dims\n", + " ]\n", + " shape = opset13.concat(dims, axis=0)\n", + " broadcast = opset13.broadcast(\n", + " opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape\n", + " )\n", + " op.set_arguments([broadcast])\n", + " ov_model.validate_nodes_and_infer_types()\n", + "\n", + "\n", + "def make_stateful(\n", + " ov_model: ov.Model,\n", + " not_kv_inputs: List[str],\n", + " key_value_input_names: List[str],\n", + " key_value_output_names: List[str],\n", + " batch_dim: int,\n", + " num_attention_heads: int,\n", + " num_beams_and_batch: int = None,\n", + "):\n", + " \"\"\"\n", + " Hides kv-cache inputs and outputs inside the model as variables.\n", + "\n", + " Parameters:\n", + " ov_model (ov.Model):\n", + " openvino model\n", + " not_kv_inputs (`List[str]`):\n", + " list of input nodes in model that not related to past key values\n", + " key_value_input_names (`List[str]`):\n", + " list of names for key value input layers\n", + " key_value_output_names (`List[str]`):\n", + " list of names for key value input layers\n", + " batch_dim (int):\n", + " index of batch dimension in key value layers\n", + " num_attention_heads (int):\n", + " number of attention heads for batch dimension initialization\n", + " num_beams_an_batch (int):\n", + " precalculated number of beams and batch for shapes initialization\n", + " \"\"\"\n", + " from openvino._offline_transformations import apply_make_stateful_transformation\n", + "\n", + " input_output_map = {}\n", + "\n", + " if num_beams_and_batch is not None:\n", + " # Set batch size for input_ids and attention mask to avoid dynamic dimension got propagated from the end of the model back to ReadValue\n", + " for input in not_kv_inputs:\n", + " shape = input.get_partial_shape()\n", + " if shape.rank.get_length() <= 2: # == 1 for beam_index\n", + " shape[0] = num_beams_and_batch\n", + " input.get_node().set_partial_shape(shape)\n", + " for kv_name_pair in zip(key_value_input_names, key_value_output_names):\n", + " input_output_map[kv_name_pair[0]] = kv_name_pair[1]\n", + " if num_beams_and_batch is not None:\n", + " input = ov_model.input(kv_name_pair[0])\n", + " shape = input.get_partial_shape()\n", + " shape[batch_dim] = num_beams_and_batch * num_attention_heads\n", + " input.get_node().set_partial_shape(shape)\n", + "\n", + " if num_beams_and_batch is not None:\n", + " # Re-validation model if shapes are altered above\n", + " ov_model.validate_nodes_and_infer_types()\n", + "\n", + " apply_make_stateful_transformation(ov_model, input_output_map)\n", + " if num_beams_and_batch is None:\n", + " build_state_initializer(ov_model, batch_dim)\n", + "\n", + "\n", + "def patch_stateful(ov_model):\n", + " key_value_input_names = [key.get_any_name() for key in ov_model.inputs[2:-1]]\n", + " key_value_output_names = [key.get_any_name() for key in ov_model.outputs[1:]]\n", + " not_kv_inputs = [\n", + " input\n", + " for input in ov_model.inputs\n", + " if not any(name in key_value_input_names for name in input.get_names())\n", + " ]\n", + " if not key_value_input_names 
or not key_value_output_names:\n", + " return\n", + " batch_dim = 0\n", + " num_attention_heads = 1\n", + "\n", + " fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)\n", + " make_stateful(\n", + " ov_model,\n", + " not_kv_inputs,\n", + " key_value_input_names,\n", + " key_value_output_names,\n", + " batch_dim,\n", + " num_attention_heads,\n", + " None,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "5cd69acd", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00, 1.00s/it]\n", + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/transformers/modeling_utils.py:4481: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead\n", + " warnings.warn(\n", + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:1060: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", + " if sequence_length != 1:\n", + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/torch/jit/_trace.py:165: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. 
(Triggered internally at /pytorch/build/aten/src/ATen/core/TensorBody.h:489.)\n", + " if a.grad is not None:\n" + ] + } + ], + "source": [ + "import types\n", + "\n", + "make_stateful_model = False\n", + "core = ov.Core()\n", + "model = LlavaNextForConditionalGeneration.from_pretrained(\n", + " model_id, low_cpu_mem_usage=True\n", + ")\n", + "language_model = model.language_model\n", + "if not LANGUAGE_MODEL_PATH.exists() or True:\n", + "\n", + " def forward_wrap(\n", + " self,\n", + " attention_mask,\n", + " position_ids=None,\n", + " past_key_values=None,\n", + " inputs_embeds=None,\n", + " ):\n", + " result = self._orig_forward(\n", + " input_ids=None,\n", + " attention_mask=attention_mask,\n", + " position_ids=position_ids,\n", + " past_key_values=past_key_values,\n", + " inputs_embeds=inputs_embeds,\n", + " output_hidden_states=True,\n", + " return_dict=True,\n", + " )\n", + " return result[\"hidden_states\"][-1][:, -1, :]\n", + "\n", + " model_inputs = [\"attention_mask\", \"position_ids\"]\n", + " model_outputs = [\"last_hidden_state\"]\n", + " model_inputs.append(\"inputs_embeds\")\n", + " language_model.config.torchscript = True\n", + " position_ids = torch.tensor([[2, 3], [2, 3]])\n", + " language_model._orig_forward = language_model.forward\n", + " language_model.forward = types.MethodType(forward_wrap, language_model)\n", + " ov_model = ov.convert_model(\n", + " language_model,\n", + " example_input={\n", + " \"inputs_embeds\": llm_input,\n", + " \"attention_mask\": torch.ones((2, 4)),\n", + " \"position_ids\": position_ids,\n", + " },\n", + " )\n", + "\n", + " for input, input_name in zip(ov_model.inputs, model_inputs):\n", + " input.get_tensor().set_names({input_name})\n", + "\n", + " for output, output_name in zip(ov_model.outputs, model_outputs):\n", + " output.get_tensor().set_names({output_name})\n", + " if make_stateful_model:\n", + " patch_stateful(ov_model)\n", + " ov.save_model(ov_model, LANGUAGE_MODEL_PATH)\n", + " del ov_model\n", + " cleanup_torchscript_cache()\n", + " gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "49838499", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:nncf:NNCF provides best results with torch==2.6.*, while current torch version is 2.7.0+cpu. 
If you encounter issues, consider switching to torch==2.6.*\n", + "INFO:nncf:Statistics of the bitwidth distribution:\n", + "┍━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑\n", + "│ Weight compression mode │ % all parameters (layers) │ % ratio-defining parameters (layers) │\n", + "┝━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥\n", + "│ int8_asym │ 1% (1 / 224) │ 0% (0 / 223) │\n", + "├───────────────────────────┼─────────────────────────────┼────────────────────────────────────────┤\n", + "│ int4_asym │ 99% (223 / 224) │ 100% (223 / 223) │\n", + "┕━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙\n" + ] + }, + { + "data": { + "text/html": [ + "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\" \n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n", + "</pre>\n" + ], + "text/plain": [ + "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\" \n", + "for Jupyter support\n", + " warnings.warn('install \"ipywidgets\" for Jupyter support')\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import nncf\n", + "\n", + "compression_configuration = {\n", + " \"mode\": nncf.CompressWeightsMode.INT4_ASYM,\n", + " \"group_size\": 64,\n", + " \"ratio\": 1.0,\n", + "}\n", + "LANGUAGE_MODEL_PATH_INT4 = (\n", + " LANGUAGE_MODEL_PATH.parent / LANGUAGE_MODEL_PATH.name.replace(\".xml\", \"-int4.xml\")\n", + ")\n", + "ov_model = core.read_model(LANGUAGE_MODEL_PATH)\n", + "ov_model_compressed = nncf.compress_weights(ov_model, **compression_configuration)\n", + "ov.save_model(ov_model_compressed, LANGUAGE_MODEL_PATH_INT4)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "695c2fbf", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "\n", + "class UnpadImage(nn.Module):\n", + " def __init__(self):\n", + " super(UnpadImage, self).__init__()\n", + "\n", + " def forward(self, tensor, original_size, current_size):\n", + " \"\"\"\n", + " Unpads an image tensor to its original size based on the current size.\n", + " Args:\n", + " tensor (torch.Tensor): The input image tensor of shape (C, H, W).\n", + " original_size (torch.Tensor): The original size of the image tensor as (H, W).\n", + " current_size (torch.Tensor): The current size of the image tensor as (H, W).\n", + " \"\"\"\n", + " # tensor: (C, H, W)\n", + " original_size = original_size.to(torch.float32)\n", + " original_height, original_width = original_size[0], original_size[1]\n", + " current_height, current_width = current_size[0], current_size[1]\n", + "\n", + " original_aspect_ratio = original_width / original_height\n", + " current_aspect_ratio = current_width / current_height\n", + "\n", + " # Comparison\n", + " condition = original_aspect_ratio > current_aspect_ratio\n", + "\n", + " # Branch 1: vertical padding\n", + " 
scale_factor_1 = current_width.float() / original_width.float()\n", + " new_height = (original_height.float() * scale_factor_1).int()\n", + " pad_top = ((current_height.float() - new_height) / 2).floor().long()\n", + "\n", + " # Branch 2: horizontal padding\n", + " scale_factor_2 = current_height.float() / original_height.float()\n", + " new_width = (original_width.float() * scale_factor_2).int()\n", + " pad_left = ((current_width.float() - new_width) / 2).floor().long()\n", + "\n", + " zero = torch.zeros(1, dtype=pad_top.dtype, device=tensor.device).squeeze(0)\n", + "\n", + " # Use torch.where to conditionally compute slicing\n", + " y_start = torch.where(condition, pad_top, zero)\n", + " y_end = torch.where(condition, current_height - pad_top, current_height)\n", + "\n", + " x_start = torch.where(condition, zero, pad_left)\n", + " x_end = torch.where(condition, current_width - pad_left, current_width)\n", + " out = tensor[:, y_start.int() : y_end.int(), x_start.int() : x_end.int()]\n", + " return out # Remove batch dimension if needed\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "ba325001", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "\n", + "\n", + "class PackImageFeatures(nn.Module):\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " self.config = config\n", + " self.unpad_image = UnpadImage()\n", + " self.height = config.vision_config.image_size // config.vision_config.patch_size\n", + " self.width = config.vision_config.image_size // config.vision_config.patch_size\n", + "\n", + " def forward(self, image_feature, image_sizes, num_patch_height, num_patch_width):\n", + " # we image features is a single image features, so we can remove the loop\n", + " base_image_features = image_feature[0]\n", + " features = image_feature[1:] # Skip the first token\n", + " features = (\n", + " features.view(\n", + " num_patch_height, num_patch_width, self.height, self.width, -1\n", + " )\n", + " .permute(4, 0, 2, 1, 3)\n", + " .contiguous()\n", + " .flatten(1, 2)\n", + " .flatten(2, 3)\n", + " )\n", + " features = self.unpad_image(\n", + " features, image_sizes[0], torch._shape_as_tensor(features)[1:3]\n", + " )\n", + " features = features.flatten(1, 2).transpose(0, 1)\n", + " features = torch.cat([base_image_features, features], dim=0)\n", + " return features.unsqueeze(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f911f0a0", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "\n", + "class MergeInputWithImageFeatures(nn.Module):\n", + " def __init__(self, pad_token_id=0, image_token_index=0):\n", + " super().__init__()\n", + " self.pad_token_id = pad_token_id\n", + " self.image_token_index = image_token_index\n", + "\n", + " def forward(self, image_features, inputs_embeds, input_ids, attention_mask):\n", + " num_images, num_image_patches, embed_dim = image_features.shape\n", + " batch_size, sequence_length = input_ids.shape\n", + "\n", + " # left_padding = torch.sum(input_ids[:, -1] == self.pad_token_id) == 0 # Removed, not needed now\n", + "\n", + " special_image_token_mask = input_ids == self.image_token_index # [B, S]\n", + " num_special_image_tokens = special_image_token_mask.sum(dim=-1) # [B]\n", + "\n", + " max_embed_dim = (\n", + " num_special_image_tokens.max() * (num_image_patches - 1)\n", + " ) + sequence_length # scalar\n", + "\n", + " batch_indices, non_image_indices = torch.where(\n", + " input_ids != self.image_token_index\n", + " ) 
# [N], [N]\n", + "\n", + " # Step 2: Compute new token positions\n", + " new_token_positions = (\n", + " torch.cumsum(special_image_token_mask * (num_image_patches - 1) + 1, dim=-1)\n", + " - 1\n", + " ) # [B, S]\n", + "\n", + " nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] # [B]\n", + "\n", + " # left_padding_flag = (input_ids[:, -1] != self.pad_token_id).to(nb_image_pad.dtype) # original\n", + " left_padding_flag = (\n", + " input_ids[:, -1] != self.pad_token_id\n", + " ).long() # more idiomatic torch\n", + " # new_token_positions = new_token_positions + (left_padding_flag[:, None] * nb_image_pad[:, None]) # original\n", + " new_token_positions += (\n", + " left_padding_flag[:, None] * nb_image_pad[:, None]\n", + " ) # updated\n", + "\n", + " text_to_overwrite = new_token_positions[batch_indices, non_image_indices] # [N]\n", + "\n", + " # Step 3: Init final tensors\n", + " final_embedding = torch.zeros(\n", + " batch_size,\n", + " max_embed_dim,\n", + " embed_dim,\n", + " dtype=inputs_embeds.dtype,\n", + " device=inputs_embeds.device,\n", + " )\n", + " final_attention_mask = torch.zeros(\n", + " batch_size,\n", + " max_embed_dim,\n", + " dtype=attention_mask.dtype,\n", + " device=inputs_embeds.device,\n", + " )\n", + "\n", + " # final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] # original\n", + " final_embedding.index_put_(\n", + " (batch_indices, text_to_overwrite),\n", + " inputs_embeds[batch_indices, non_image_indices],\n", + " ) # torch native\n", + "\n", + " # final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] # original\n", + " final_attention_mask.index_put_(\n", + " (batch_indices, text_to_overwrite),\n", + " attention_mask[batch_indices, non_image_indices],\n", + " ) # torch native\n", + "\n", + " # Step 5: fill in image features\n", + " image_to_overwrite = (final_embedding == 0).all(dim=-1) # [B, L]\n", + " image_to_overwrite &= (image_to_overwrite.cumsum(-1) - 1) >= nb_image_pad[\n", + " :, None\n", + " ] # apply pad cutoff\n", + "\n", + " flat_image_features = image_features.reshape(-1, embed_dim).to(\n", + " inputs_embeds.device\n", + " ) # [N_img, D]\n", + "\n", + " # final_embedding[image_to_overwrite] = flat_image_features # original\n", + " final_embedding[image_to_overwrite] = flat_image_features[\n", + " : image_to_overwrite.sum()\n", + " ] # safe assignment\n", + "\n", + " final_attention_mask |= image_to_overwrite # logical or with existing mask\n", + "\n", + " position_ids = final_attention_mask.cumsum(-1) - 1\n", + " position_ids = position_ids.masked_fill(final_attention_mask == 0, 1)\n", + "\n", + " # Step 6: remove pad token embeddings\n", + " batch_pad_indices, pad_token_positions = torch.where(\n", + " input_ids == self.pad_token_id\n", + " ) # [N_pad]\n", + " indices_to_mask = new_token_positions[\n", + " batch_pad_indices, pad_token_positions\n", + " ] # [N_pad]\n", + "\n", + " # final_embedding[batch_pad_indices, indices_to_mask] = 0 # original\n", + " final_embedding.index_put_(\n", + " (batch_pad_indices, indices_to_mask),\n", + " torch.zeros_like(final_embedding[batch_pad_indices, indices_to_mask]),\n", + " ) # updated\n", + "\n", + " return {\n", + " \"final_embedding\": final_embedding,\n", + " \"final_attention_mask\": final_attention_mask,\n", + " \"position_ids\": position_ids,\n", + " }\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfb25757", + "metadata": {}, + "outputs": [], + "source": [ + "# 
compile the models\n",
+    "language_model = core.read_model(LANGUAGE_MODEL_PATH)\n",
+    "compiled_language_model = core.compile_model(language_model, \"AUTO\")\n",
+    "\n",
+    "image_embed_model = core.compile_model(IMAGE_ENCODER_PATH, device)\n",
+    "text_embeddings_model = core.compile_model(INPUT_EMBEDDING_PATH, device)\n",
+    "\n",
+    "if IMAGE_PACKER_PATH.exists():\n",
+    "    image_packer_model = core.compile_model(IMAGE_PACKER_PATH, device)\n",
+    "else:\n",
+    "    image_packer_model = None\n",
+    "if MULTIMODAL_MERGER_PATH.exists():\n",
+    "    multimodal_merger_model = core.compile_model(MULTIMODAL_MERGER_PATH, device)\n",
+    "else:\n",
+    "    multimodal_merger_model = None\n",
+    "\n",
+    "# multimodal_merger_model = core.compile_model(MODEL_MERGER_PATH, device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "0d5643ee",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+      "/home/prabod/anaconda3/envs/e5v/lib/python3.11/site-packages/huggingface_hub/file_download.py:943: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Image size: (360, 282), Mode: RGB\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import torch.nn.functional as F\n",
+    "import requests\n",
+    "from PIL import Image\n",
+    "from transformers import AutoTokenizer, AutoConfig\n",
+    "from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration\n",
+    "\n",
+    "llama3_template = \"<|start_header_id|>user<|end_header_id|>\\n\\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n \\n\"\n",
+    "\n",
+    "processor = LlavaNextProcessor.from_pretrained(\"royokong/e5-v\")\n",
+    "\n",
+    "config = AutoConfig.from_pretrained(\"royokong/e5-v\")\n",
+    "img_prompt = llama3_template.format(\"<image>\\nSummary above image in one word: \")\n",
+    "text_prompt = llama3_template.format(\"<sent>\\nSummary above sentence in one word: \")\n",
+    "\n",
+    "images = [Image.open(\"dog.jpg\").convert(\"RGB\")]\n",
+    "\n",
+    "for image in images:\n",
+    "    print(f\"Image size: {image.size}, Mode: {image.mode}\")\n",
+    "\n",
+    "texts = [\"A dog sitting in the grass.\"]\n",
+    "\n",
+    "text_inputs = processor(\n",
+    "    [text_prompt.replace(\"<sent>\", text) for text in texts],\n",
+    "    return_tensors=\"pt\",\n",
+    "    padding=True,\n",
+    ")\n",
+    "img_inputs = processor(\n",
+    "    [img_prompt] * len(images), images, return_tensors=\"pt\", padding=True\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "ad1b402c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "img_input_ids = img_inputs[\"input_ids\"]\n",
+    "img_attention_mask = img_inputs[\"attention_mask\"]\n",
+    "image_sizes = img_inputs[\"image_sizes\"]\n",
+    "pixel_values = img_inputs[\"pixel_values\"]\n",
+    "\n",
+    "text_input_ids = text_inputs[\"input_ids\"]\n",
+    "text_attention_mask = text_inputs[\"attention_mask\"]\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "2649d101",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "image_features = torch.from_numpy(image_embed_model(pixel_values)[0])\n",
+    "image_inputs_embeds = torch.from_numpy(text_embeddings_model(img_input_ids)[0])\n",
+    "text_inputs_embeds = 
torch.from_numpy(text_embeddings_model(text_input_ids)[0])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "844968c5", + "metadata": {}, + "outputs": [], + "source": [ + "image_packer = PackImageFeatures(config)\n", + "input_merger = MergeInputWithImageFeatures(\n", + " pad_token_id=processor.tokenizer.pad_token_id,\n", + " image_token_index=config.image_token_index,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "190da649", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from typing import Union, List, Tuple\n", + "import torch\n", + "\n", + "\n", + "def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:\n", + " \"\"\"\n", + " Selects the best resolution from a list of possible resolutions based on the original size.\n", + "\n", + " This is done by calculating the effective and wasted resolution for each possible resolution.\n", + "\n", + " The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.\n", + "\n", + " Args:\n", + " original_size (tuple):\n", + " The original size of the image in the format (height, width).\n", + " possible_resolutions (list):\n", + " A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].\n", + "\n", + " Returns:\n", + " tuple: The best fit resolution in the format (height, width).\n", + " \"\"\"\n", + " original_height, original_width = original_size\n", + " best_fit = None\n", + " max_effective_resolution = 0\n", + " min_wasted_resolution = float(\"inf\")\n", + "\n", + " for height, width in possible_resolutions:\n", + " scale = min(width / original_width, height / original_height)\n", + " downscaled_width, downscaled_height = (\n", + " int(original_width * scale),\n", + " int(original_height * scale),\n", + " )\n", + " effective_resolution = min(\n", + " downscaled_width * downscaled_height, original_width * original_height\n", + " )\n", + " wasted_resolution = (width * height) - effective_resolution\n", + "\n", + " if effective_resolution > max_effective_resolution or (\n", + " effective_resolution == max_effective_resolution\n", + " and wasted_resolution < min_wasted_resolution\n", + " ):\n", + " max_effective_resolution = effective_resolution\n", + " min_wasted_resolution = wasted_resolution\n", + " best_fit = (height, width)\n", + "\n", + " return best_fit\n", + "\n", + "\n", + "def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):\n", + " \"\"\"\n", + " Calculate the number of patches after the preprocessing for images of any resolution.\n", + "\n", + " Args:\n", + " image_size (`Union[torch.LongTensor, np.ndarray, Tuple[int, int]):\n", + " The size of the input image in the format (height, width). ?\n", + " grid_pinpoints (`List`):\n", + " A list containing possible resolutions. Each item in the list should be a tuple or list\n", + " of the form `(height, width)`.\n", + " patch_size (`int`):\n", + " The size of each image patch.\n", + "\n", + " Returns:\n", + " int: the number of patches\n", + " \"\"\"\n", + " if not isinstance(grid_pinpoints, list):\n", + " raise ValueError(\"grid_pinpoints should be a list of tuples or lists\")\n", + "\n", + " # ! 
VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate\n", + " if not isinstance(image_size, (list, tuple)):\n", + " if not isinstance(image_size, (torch.Tensor, np.ndarray)):\n", + " raise ValueError(\n", + " f\"image_size invalid type {type(image_size)} with value {image_size}\"\n", + " )\n", + " image_size = image_size.tolist()\n", + "\n", + " best_resolution = select_best_resolution(image_size, grid_pinpoints)\n", + " height, width = best_resolution\n", + " num_patches = 0\n", + " # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1\n", + " for i in range(0, height, patch_size):\n", + " for j in range(0, width, patch_size):\n", + " num_patches += 1\n", + " # add the base patch\n", + " num_patches += 1\n", + " return num_patches\n", + "\n", + "\n", + "def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):\n", + " \"\"\"\n", + " Calculate the shape of the image patch grid after the preprocessing for images of any resolution.\n", + "\n", + " Args:\n", + " image_size (`tuple`):\n", + " The size of the input image in the format (width, height).\n", + " grid_pinpoints (`List`):\n", + " A list containing possible resolutions. Each item in the list should be a tuple or list\n", + " of the form `(height, width)`.\n", + " patch_size (`int`):\n", + " The size of each image patch.\n", + "\n", + " Returns:\n", + " tuple: The shape of the image patch grid in the format (width, height).\n", + " \"\"\"\n", + " if not isinstance(grid_pinpoints, list):\n", + " raise ValueError(\"grid_pinpoints should be a list of tuples or lists\")\n", + "\n", + " # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate\n", + " if not isinstance(image_size, (list, tuple)):\n", + " if not isinstance(image_size, (torch.Tensor, np.ndarray)):\n", + " raise ValueError(\n", + " f\"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor\"\n", + " )\n", + " image_size = image_size.tolist()\n", + "\n", + " height, width = select_best_resolution(image_size, grid_pinpoints)\n", + " return height // patch_size, width // patch_size" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "bcbec245", + "metadata": {}, + "outputs": [], + "source": [ + "num_patch_width, num_patch_height = get_anyres_image_grid_shape(\n", + " image_sizes[0],\n", + " config.image_grid_pinpoints,\n", + " config.vision_config.image_size,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "40620525", + "metadata": {}, + "outputs": [], + "source": [ + "packed_image_features = image_packer(\n", + " image_features,\n", + " image_sizes,\n", + " num_patch_height=num_patch_height,\n", + " num_patch_width=num_patch_width\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "0eb947f8", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "if IMAGE_PACKER_PATH.exists():\n", + " IMAGE_PACKER_PATH.unlink()\n", + "\n", + "ov_image_packer = ov.convert_model(\n", + " image_packer,\n", + " example_input={\n", + " \"image_feature\": image_features,\n", + " \"image_sizes\": image_sizes,\n", + " \"num_patch_height\": torch.tensor(num_patch_height, dtype=torch.int64),\n", + " \"num_patch_width\": torch.tensor(num_patch_width, dtype=torch.int64)\n", + " }\n", + ")\n", + "ov.save_model(ov_image_packer, IMAGE_PACKER_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "f2fae423", + "metadata": 
{}, + "outputs": [], + "source": [ + "if MULTIMODAL_MERGER_PATH.exists():\n", + " MULTIMODAL_MERGER_PATH.unlink()\n", + "ov_multimodal_merger = ov.convert_model(\n", + " input_merger,\n", + " example_input={\n", + " \"image_features\": packed_image_features,\n", + " \"inputs_embeds\": image_inputs_embeds,\n", + " \"input_ids\": img_input_ids,\n", + " \"attention_mask\": img_attention_mask\n", + " }\n", + ")\n", + "ov.save_model(ov_multimodal_merger, MULTIMODAL_MERGER_PATH)\n", + "cleanup_torchscript_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "0599dd94", + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "import os\n", + "if not os.path.exists(f\"{output_dir}/assets\"):\n", + " output_dir = Path(output_dir)\n", + " assets_dir = output_dir/\"assets\"\n", + " assets_dir.mkdir(exist_ok=True)\n", + " processor.save_pretrained(output_dir)\n", + " # copy all the assets to the assets directory (json files, vocab files, etc.)\n", + " for file in output_dir.glob(\"*.json\"):\n", + " shutil.copy(file, assets_dir)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "27e894ab", + "metadata": {}, + "outputs": [], + "source": [ + "# delete the f32 language model\n", + "if LANGUAGE_MODEL_PATH.exists():\n", + " LANGUAGE_MODEL_PATH.unlink()\n", + "\n", + "# delete the f32 language model bin file if exists\n", + "if LANGUAGE_MODEL_PATH.with_suffix(\".bin\").exists():\n", + " LANGUAGE_MODEL_PATH.with_suffix(\".bin\").unlink()" + ] + }, + { + "cell_type": "markdown", + "id": "ff9ecebb", + "metadata": {}, + "source": [ + "## 2. Test the Exported model" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "17eaa581", + "metadata": {}, + "outputs": [], + "source": [ + "IMAGE_ENCODER_PATH = output_dir / \"openvino_vision_embeddings_model.xml\"\n", + "LANGUAGE_MODEL_PATH = output_dir / \"openvino_language_model-int4.xml\"\n", + "INPUT_EMBEDDING_PATH = output_dir / \"openvino_text_embeddings_model.xml\"\n", + "\n", + "IMAGE_PACKER_PATH = output_dir / \"openvino_image_packer.xml\"\n", + "MULTIMODAL_MERGER_PATH = output_dir / \"openvino_multimodal_merger.xml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "0782a4a9", + "metadata": {}, + "outputs": [], + "source": [ + "# compile the models\n", + "language_model = core.read_model(LANGUAGE_MODEL_PATH)\n", + "compiled_language_model = core.compile_model(language_model, \"AUTO\")\n", + "\n", + "image_embed_model = core.compile_model(IMAGE_ENCODER_PATH, device)\n", + "text_embeddings_model = core.compile_model(INPUT_EMBEDDING_PATH, device)\n", + "\n", + "if IMAGE_PACKER_PATH.exists():\n", + " image_packer_model = core.compile_model(IMAGE_PACKER_PATH, device)\n", + "else:\n", + " image_packer_model = None\n", + "if MULTIMODAL_MERGER_PATH.exists():\n", + " multimodal_merger_model = core.compile_model(MULTIMODAL_MERGER_PATH, device)\n", + "else:\n", + " multimodal_merger_model = None\n" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "4b88a40c", + "metadata": {}, + "outputs": [], + "source": [ + "# use openvino model to pack the image features\n", + "packed_image_features = image_packer_model({\n", + " 'image_feature': image_features,\n", + " 'image_sizes': image_sizes,\n", + " 'num_patch_height': torch.tensor(num_patch_height, dtype=torch.int64),\n", + " 'num_patch_width': torch.tensor(num_patch_width, dtype=torch.int64)\n", + "})[0]\n", + "packed_image_features = torch.from_numpy(packed_image_features)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 32, + "id": "1a69b30b", + "metadata": {}, + "outputs": [], + "source": [ + "# use openvino model to merge the image features with text features\n", + "merger_out = multimodal_merger_model({\n", + " \"image_features\": packed_image_features,\n", + " \"inputs_embeds\": image_inputs_embeds,\n", + " \"input_ids\": img_input_ids,\n", + " \"attention_mask\": img_attention_mask\n", + " }\n", + ")\n", + "image_final_embeds = torch.from_numpy(merger_out['final_embedding'])\n", + "image_final_attention_mask = torch.from_numpy(merger_out['final_attention_mask'])\n", + "image_position_ids = torch.from_numpy(merger_out['position_ids'])" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "131763dc", + "metadata": {}, + "outputs": [], + "source": [ + "request = compiled_language_model.create_infer_request()\n", + "img_input_lm = {\n", + " \"inputs_embeds\": image_final_embeds.detach().numpy(),\n", + " \"attention_mask\": image_final_attention_mask.detach().numpy(),\n", + " \"position_ids\": image_position_ids.detach().numpy(),\n", + "}\n", + "request.start_async(img_input_lm, share_inputs=True)\n", + "request.wait()\n", + "img_lm_output = torch.from_numpy(request.get_tensor(\"last_hidden_state\").data)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "68787196", + "metadata": {}, + "outputs": [], + "source": [ + "text_request = compiled_language_model.create_infer_request()\n", + "text_position_ids = text_attention_mask.long().cumsum(-1) - 1\n", + "text_position_ids.masked_fill_(text_attention_mask == 0, 1)\n", + "text_input_lm = {\n", + " \"inputs_embeds\": text_inputs_embeds.detach().numpy(),\n", + " \"attention_mask\": text_attention_mask.detach().numpy(),\n", + " \"position_ids\": text_position_ids.detach().numpy(),\n", + "}\n", + "text_request.start_async(text_input_lm, share_inputs=True)\n", + "text_request.wait()\n", + "text_lm_output = torch.from_numpy(text_request.get_tensor(\"last_hidden_state\").data)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "df6a5ae1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[0.7158]])\n" + ] + } + ], + "source": [ + "import torch.nn.functional as F\n", + "\n", + "txt_embed = F.normalize(text_lm_output, dim=-1)\n", + "img_embed = F.normalize(img_lm_output, dim=-1)\n", + "\n", + "print(txt_embed @ img_embed.T)" + ] + }, + { + "cell_type": "markdown", + "id": "3764af1b", + "metadata": {}, + "source": [ + "## 3 Import and Save E5V in Spark NLP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "265ecf82", + "metadata": {}, + "outputs": [], + "source": [ + "! 
wget -q http://setup.johnsnowlabs.com/colab.sh -O - | bash" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "285bb60c", + "metadata": {}, + "outputs": [], + "source": [ + "import sparknlp\n", + "\n", + "# let's start Spark with Spark NLP\n", + "spark = sparknlp.start()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "18611787", + "metadata": {}, + "outputs": [], + "source": [ + "model_id = \"royokong/e5-v\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8ca2060a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "25/06/10 03:45:32 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.\n", + "25/06/10 03:45:41 WARN NativeLibrary: Failed to load library null: java.lang.UnsatisfiedLinkError: Can't load library: /tmp/openvino-native4021672575912693842/libtbb.so.2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: An illegal reflective access operation has occurred\n", + "WARNING: Illegal reflective access by org.apache.spark.util.SizeEstimator$ (file:/home/prabod/spark/jars/spark-core_2.12-3.3.2.jar) to field java.util.regex.Pattern.pattern\n", + "WARNING: Please consider reporting this to the maintainers of org.apache.spark.util.SizeEstimator$\n", + "WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n", + "WARNING: All illegal access operations will be denied in a future release\n" + ] + } + ], + "source": [ + "e5v_embeddings_sn = E5VEmbeddings \\\n", + " .loadSavedModel(str(output_dir),spark) \\\n", + " .setInputCols(\"image_assembler\") \\\n", + " .setOutputCol(\"answer\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "d5b60572", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \r" + ] + } + ], + "source": [ + "e5v_embeddings_sn.write().overwrite().save(f\"file:///tmp/{model_id}_spark_nlp\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd9cf656", + "metadata": {}, + "outputs": [], + "source": [ + "import sparknlp\n", + "from sparknlp.base import *\n", + "from sparknlp.annotator import *\n", + "from pyspark.sql.functions import lit\n", + "from pyspark.ml import Pipeline\n", + "from sparknlp.util import EmbeddingsDataFrameUtils\n", + "\n", + "from pathlib import Path\n", + "import os\n", + "\n", + "# download two images to test into ./images folder\n", + "\n", + "url1 = \"https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11\"\n", + "\n", + "Path(\"images\").mkdir(exist_ok=True)\n", + "\n", + "!wget -q -O images/image1.jpg {url1}\n", + "\n", + "\n", + "\n", + "images_path = \"file://\" + os.getcwd() + \"/images/\"\n", + "image_df = spark.read.format(\"image\").load(\n", + " path=images_path\n", + ")\n", + "\n", + "imagePrompt = \"<|start_header_id|>user<|end_header_id|>\\n\\n<image>\\\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n \\n\"\n", + "image_df = spark.read.format(\"image\").option(\"dropInvalid\", True).load(images_path)\n", + "test_df = image_df.withColumn(\"text\", lit(imagePrompt))\n", + "\n", + "textPrompt = \"<|start_header_id|>user<|end_header_id|>\\n\\n<sent>\\\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n \\n\"\n", + "textDesc = \"A cat sitting in a box.\"\n", + 
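"# Text-only embedding: pair the text prompt with a placeholder (empty) image row so the\n",
+    "# same IMAGE-based pipeline can embed plain sentences alongside real images.\n",
+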
"nullImageDF = spark.createDataFrame(\n", + " [EmbeddingsDataFrameUtils.emptyImageRow], schema=\n", + " EmbeddingsDataFrameUtils.imageSchema)\n", + "textDF = nullImageDF.withColumn(\"text\", lit(textPrompt.replace(\"<sent>\", textDesc)))\n", + "\n", + "test_df = test_df.union(textDF)\n", + "\n", + "imageAssembler = ImageAssembler() \\\n", + " .setInputCol(\"image\") \\\n", + " .setOutputCol(\"image_assembler\")\n", + "e5v = E5VEmbeddings.load(f\"file:///tmp/{model_id}_spark_nlp\") \\\n", + " .setInputCols([\"image_assembler\"]) \\\n", + " .setOutputCol(\"e5v\")\n", + "pipeline = Pipeline().setStages([imageAssembler, e5v])\n", + "results = pipeline.fit(test_df).transform(test_df)\n", + "results.select(\"e5v.embeddings\").show(truncate=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "e5v", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/python/sparknlp/annotator/embeddings/__init__.py b/python/sparknlp/annotator/embeddings/__init__.py index da453d2c555037..f93ac2e3b11ec4 100644 --- a/python/sparknlp/annotator/embeddings/__init__.py +++ b/python/sparknlp/annotator/embeddings/__init__.py @@ -41,3 +41,4 @@ from sparknlp.annotator.embeddings.snowflake_embeddings import * from sparknlp.annotator.embeddings.nomic_embeddings import * from sparknlp.annotator.embeddings.auto_gguf_embeddings import * +from sparknlp.annotator.embeddings.e5v_embeddings import * \ No newline at end of file diff --git a/python/sparknlp/annotator/embeddings/e5v_embeddings.py b/python/sparknlp/annotator/embeddings/e5v_embeddings.py new file mode 100644 index 00000000000000..e8ee518a40333e --- /dev/null +++ b/python/sparknlp/annotator/embeddings/e5v_embeddings.py @@ -0,0 +1,138 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sparknlp.common import * + +class E5VEmbeddings(AnnotatorModel, + HasBatchedAnnotateImage, + HasImageFeatureProperties, + HasEngine, + HasRescaleFactor): + """Universal multimodal embeddings using the E5-V model (see https://huggingface.co/royokong/e5-v). + + E5-V bridges the modality gap between different input types (text, image) and demonstrates strong performance in multimodal embeddings, even without fine-tuning. It also supports a single-modality training approach, where the model is trained exclusively on text pairs, often yielding better performance than multimodal training. + + Pretrained models can be loaded with :meth:`.pretrained` of the companion object: + + >>> e5vEmbeddings = E5VEmbeddings.pretrained() \ + ... .setInputCols(["image_assembler"]) \ + ... .setOutputCol("e5v") + + The default model is ``"e5v_int4"``, if no name is provided. 
+
+    For available pretrained models please see the `Models Hub <https://sparknlp.org/models?q=E5V>`__.
+
+    ====================== ======================
+    Input Annotation types Output Annotation type
+    ====================== ======================
+    ``IMAGE``              ``SENTENCE_EMBEDDINGS``
+    ====================== ======================
+
+    Examples
+    --------
+    Image + Text Embedding:
+    >>> import sparknlp
+    >>> from sparknlp.base import *
+    >>> from sparknlp.annotator import *
+    >>> from pyspark.ml import Pipeline
+    >>> from pyspark.sql.functions import lit
+    >>> image_df = spark.read.format("image").option("dropInvalid", True).load(imageFolder)
+    >>> imagePrompt = "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
+    >>> test_df = image_df.withColumn("text", lit(imagePrompt))
+    >>> imageAssembler = ImageAssembler() \
+    ...     .setInputCol("image") \
+    ...     .setOutputCol("image_assembler")
+    >>> e5vEmbeddings = E5VEmbeddings.pretrained() \
+    ...     .setInputCols(["image_assembler"]) \
+    ...     .setOutputCol("e5v")
+    >>> pipeline = Pipeline().setStages([
+    ...     imageAssembler,
+    ...     e5vEmbeddings
+    ... ])
+    >>> result = pipeline.fit(test_df).transform(test_df)
+    >>> result.select("e5v.embeddings").show(truncate=False)
+
+    Text-Only Embedding:
+    >>> from sparknlp.util import EmbeddingsDataFrameUtils
+    >>> textPrompt = "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
+    >>> textDesc = "A cat sitting in a box."
+    >>> nullImageDF = spark.createDataFrame(spark.sparkContext.parallelize([EmbeddingsDataFrameUtils.emptyImageRow]), EmbeddingsDataFrameUtils.imageSchema)
+    >>> textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc)))
+    >>> e5vEmbeddings = E5VEmbeddings.pretrained() \
+    ...     .setInputCols(["image"]) \
+    ...     .setOutputCol("e5v")
+    >>> result = e5vEmbeddings.transform(textDF)
+    >>> result.select("e5v.embeddings").show(truncate=False)
+    """
+
+    name = "E5VEmbeddings"
+
+    inputAnnotatorTypes = [AnnotatorType.IMAGE]
+    outputAnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS
+
+    @keyword_only
+    def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.E5VEmbeddings", java_model=None):
+        """Initializes the E5VEmbeddings annotator.
+
+        Parameters
+        ----------
+        classname : str, optional
+            The Java class name of the annotator, by default "com.johnsnowlabs.nlp.embeddings.E5VEmbeddings"
+        java_model : Optional[java.lang.Object], optional
+            A pre-initialized Java model, by default None
+        """
+        super(E5VEmbeddings, self).__init__(classname=classname, java_model=java_model)
+        self._setDefault()
+
+    @staticmethod
+    def loadSavedModel(folder, spark_session, use_openvino=False):
+        """Loads a locally saved model.
+
+        Parameters
+        ----------
+        folder : str
+            Folder of the saved model
+        spark_session : pyspark.sql.SparkSession
+            The current SparkSession
+        use_openvino : bool, optional
+            Whether to use OpenVINO engine, by default False
+
+        Returns
+        -------
+        E5VEmbeddings
+            The restored model
+        """
+        from sparknlp.internal import _E5VEmbeddingsLoader
+        jModel = _E5VEmbeddingsLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj
+        return E5VEmbeddings(java_model=jModel)
+
+    @staticmethod
+    def pretrained(name="e5v_int4", lang="en", remote_loc=None):
+        """Downloads and loads a pretrained model.
+ + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default "e5v_int4" + lang : str, optional + Language of the pretrained model, by default "en" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use Spark NLPs repositories otherwise. + + Returns + ------- + E5VEmbeddings + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(E5VEmbeddings, name, lang, remote_loc) \ No newline at end of file diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index 25f34ce4eba599..e7300ab8586b5c 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -1165,3 +1165,11 @@ def __init__(self, path, jspark, use_openvino=False): jspark, use_openvino, ) +class _E5VEmbeddingsLoader(ExtendedJavaWrapper): + def __init__(self, path, jspark, use_openvino=False): + super(_E5VEmbeddingsLoader, self).__init__( + "com.johnsnowlabs.nlp.embeddings.E5VEmbeddings.loadSavedModel", + path, + jspark, + use_openvino + ) \ No newline at end of file diff --git a/python/sparknlp/util.py b/python/sparknlp/util.py index 0bbacd410a9e8f..8381337b5873a5 100644 --- a/python/sparknlp/util.py +++ b/python/sparknlp/util.py @@ -15,6 +15,9 @@ import sparknlp.internal as _internal +import numpy as np +from pyspark.sql import Row +from pyspark.sql.types import StructType, StructField, StringType, IntegerType, BinaryType def get_config_path(): @@ -33,3 +36,26 @@ def exportConllFiles(*args): _internal._CoNLLGeneratorExportFromTargetAndPipeline(*args).apply() else: raise NotImplementedError(f"No exportConllFiles alternative takes {num_args} parameters") + + +class EmbeddingsDataFrameUtils: + """ + Utility for creating DataFrames compatible with multimodal embedding models (e.g., E5VEmbeddings) for text-only scenarios. + Provides: + - imageSchema: the expected schema for Spark image DataFrames + - emptyImageRow: a dummy image row for text-only embedding + """ + imageSchema = StructType([ + StructField( + "image", + StructType([ + StructField("origin", StringType(), True), + StructField("height", IntegerType(), True), + StructField("width", IntegerType(), True), + StructField("nChannels", IntegerType(), True), + StructField("mode", IntegerType(), True), + StructField("data", BinaryType(), True), + ]), + ) + ]) + emptyImageRow = Row(Row("", 0, 0, 0, 0, bytes())) diff --git a/python/test/annotator/embeddings/e5v_embeddings_test.py b/python/test/annotator/embeddings/e5v_embeddings_test.py new file mode 100644 index 00000000000000..249484232284d3 --- /dev/null +++ b/python/test/annotator/embeddings/e5v_embeddings_test.py @@ -0,0 +1,64 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import unittest +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from pyspark.ml import Pipeline +from pyspark.sql.functions import lit +from test.util import SparkContextForTest + +@pytest.mark.slow +class E5VEmbeddingsTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + self.images_path = "file://"+os.getcwd() + "/../src/test/resources/image/" + + def test_image_and_text_embedding(self): + # Simulate image+text embedding (requires actual image files for full test) + image_folder = os.environ.get("E5V_IMAGE_TEST_FOLDER", self.images_path) + imagePrompt = "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" + image_df = self.spark.read.format("image").option("dropInvalid", True).load(image_folder) + test_df = image_df.withColumn("text", lit(imagePrompt)) + + imageAssembler = ImageAssembler() \ + .setInputCol("image") \ + .setOutputCol("image_assembler") + e5v = E5VEmbeddings.pretrained() \ + .setInputCols(["image_assembler"]) \ + .setOutputCol("e5v") + pipeline = Pipeline().setStages([imageAssembler, e5v]) + results = pipeline.fit(test_df).transform(test_df) + results.select("e5v.embeddings").show(truncate=True) + + def test_text_only_embedding(self): + # Simulate text-only embedding using emptyImageRow and imageSchema + from sparknlp.util import EmbeddingsDataFrameUtils + textPrompt = "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" + textDesc = "A cat sitting in a box." + nullImageDF = self.spark.createDataFrame( + self.spark.sparkContext.parallelize([EmbeddingsDataFrameUtils.emptyImageRow]), + EmbeddingsDataFrameUtils.imageSchema) + textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc))) + imageAssembler = ImageAssembler() \ + .setInputCol("image") \ + .setOutputCol("image_assembler") + e5v = E5VEmbeddings.pretrained() \ + .setInputCols(["image_assembler"]) \ + .setOutputCol("e5v") + pipeline = Pipeline().setStages([imageAssembler, e5v]) + results = pipeline.fit(textDF).transform(textDF) + results.select("e5v.embeddings").show(truncate=True) \ No newline at end of file diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/E5V.scala b/src/main/scala/com/johnsnowlabs/ml/ai/E5V.scala new file mode 100644 index 00000000000000..6653fe97cdebcb --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/ai/E5V.scala @@ -0,0 +1,412 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.johnsnowlabs.ml.ai + +import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig +import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers +import com.johnsnowlabs.ml.openvino.OpenvinoWrapper.E5VWrappers +import com.johnsnowlabs.ml.util.{ONNX, Openvino} +import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT +import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.nlp.annotators.common.SentenceSplit +import com.johnsnowlabs.ml.ai.util.transform.E5VUtils +import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor +import com.johnsnowlabs.nlp.annotators.cv.util.io.ImageIOUtils +import com.johnsnowlabs.nlp.annotators.cv.util.transform.ImageResizeUtils +import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.{BpeTokenizer, LLAVATokenizer, SpecialTokens} +import org.intel.openvino.InferRequest + +private[johnsnowlabs] class E5V( + val onnxWrappers: Option[DecoderWrappers], + val openvinoWrapper: Option[E5VWrappers], + merges: Map[(String, String), Int], + vocabulary: Map[String, Int], + addedTokens: Map[String, Int], + preprocessor: Preprocessor, + generationConfig: GenerationConfig, + imageToken: Int, + imageGridPinpoints: Map[Int, Array[Int]], + patchSize: Int) + extends Serializable { + + val detectedEngine: String = + if (onnxWrappers.isDefined) ONNX.name + else if (openvinoWrapper.isDefined) Openvino.name + else Openvino.name + + private val GenerationConfig( + bosTokenId: Int, + paddingTokenId: Int, + eosTokenId: Int, + vocabSize: Int, + beginSuppressTokens, + suppressTokenIds, + forcedDecoderIds) = + generationConfig + val reversedVocabulary: Map[Int, String] = vocabulary.map(_.swap) + val specialTokens: SpecialTokens = SpecialTokens( + vocabulary, + startTokenString = reversedVocabulary(bosTokenId), + endTokenString = reversedVocabulary(eosTokenId), + unkTokenString = reversedVocabulary(eosTokenId), + maskTokenString = reversedVocabulary(eosTokenId), + padTokenString = reversedVocabulary(paddingTokenId), + additionalStrings = addedTokens.keys.toArray) + + val bpeTokenizer: LLAVATokenizer = BpeTokenizer + .forModel( + "llava", + merges = merges, + vocab = vocabulary, + specialTokens = Some(specialTokens), + addPrefixSpaceToSentence = false, + alwaysAddPrefix = false, + prependString = "") + .asInstanceOf[LLAVATokenizer] + + /** Decode a sequence of sentences + * @param sentences + * Sequence of sentences + * @return + * Sequence of decoded sentences + */ + def decode(sentences: Array[Array[Int]]): Seq[String] = { + sentences.map(s => bpeTokenizer.decodeTokens(s.map(_.toInt))) + } + + /** Encode a sequence of sentences + * @param sentences + * Sequence of sentences + * @return + * Sequence of encoded sentences + */ + def encodeText(sentences: Seq[Annotation]): Seq[Array[Int]] = { + + val tokens = SentenceSplit + .unpack(sentences) + .map(s => { + val sentWithTask = s + bpeTokenizer + .tokenize(sentWithTask) + .map(bpeTokenizer.encode) + .flatMap(_.map(_.pieceId)) + }) + tokens + } + + def encode( + imageAnnotations: Seq[AnnotationImage], + sentences: Seq[Annotation], + preprocessor: Preprocessor): ( + Seq[Array[Int]], + Option[Array[Array[Array[Array[Array[Float]]]]]], + Option[Array[(Int, Int)]]) = { + val encodedText = encodeText(sentences).toArray + + // check if image annotations are present an height and width are > 0 + val imageAnnotationsFiltered = + imageAnnotations.filter(annot => annot.width > 0 && annot.height > 0) + + val preprocessedImages = if (imageAnnotationsFiltered.nonEmpty) { + Some(encodeImage(imageAnnotations.toArray, preprocessor)) + 
} else { + None + } + val imageSizes = if (imageAnnotationsFiltered.nonEmpty) { + Some(imageAnnotations.map(annot => (annot.width, annot.height)).toArray) + } else { + None + } + + (encodedText, preprocessedImages, imageSizes) + } + + def tag( + batch: Seq[Array[Int]], + images: Option[Array[Array[Array[Array[Array[Float]]]]]], + imageSizes: Option[Array[(Int, Int)]]): Array[Array[Float]] = { + + val pixelValues = images + val expandedDecoderInputsVals = batch + val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray + val numReturn_sequences = 1 + // from config + + var effectiveBatch_size = 1 + var effectiveBatch_mult = 1 + + effectiveBatch_size = expandedDecoderInputsVals.length + effectiveBatch_mult = 1 + + val inferRequestLanguageModel = + openvinoWrapper.get.languageModel.getCompiledModel().create_infer_request() + val inferRequestVisionEmbeddingsModel = + openvinoWrapper.get.visionEmbeddingsModel.getCompiledModel().create_infer_request() + val inferRequestTextEmbeddingsModel = + openvinoWrapper.get.textEmbeddingsModel.getCompiledModel().create_infer_request() + val inferRequestImagePackerModel = + openvinoWrapper.get.imagePackerModel.getCompiledModel().create_infer_request() + val inferRequestMergeModel = + openvinoWrapper.get.mergeModel.getCompiledModel().create_infer_request() + + val generatedEmbeddings = getModelOutputs( + decoderInputIds = expandedDecoderInputsVals.toArray, + pixelValues = pixelValues, + imageSizes = imageSizes, + inferRequestLanguageModel = inferRequestLanguageModel, + inferRequestVisionEmbeddingsModel = inferRequestVisionEmbeddingsModel, + inferRequestTextEmbeddingsModel = inferRequestTextEmbeddingsModel, + inferRequestImagePackerModel = inferRequestImagePackerModel, + inferRequestMergeModel = inferRequestMergeModel) + generatedEmbeddings + } + + def predict( + sentences: Seq[Annotation], + imageAnnotations: Seq[AnnotationImage]): Seq[Annotation] = { + + val (encodedText, preprocessedImages, imageSizes) = + encode(imageAnnotations, sentences, preprocessor) + val sentenceEmbeddings = tag(encodedText, preprocessedImages, imageSizes) + + val annotations = sentences.zip(sentenceEmbeddings).map { case (sentence, vectors) => + Annotation( + annotatorType = AnnotatorType.SENTENCE_EMBEDDINGS, + begin = sentence.begin, + end = sentence.end, + result = sentence.result, + metadata = sentence.metadata, + embeddings = vectors) + } + annotations + } + + def getModelOutputs( + decoderInputIds: Array[Array[Int]], + pixelValues: Option[Array[Array[Array[Array[Array[Float]]]]]], + imageSizes: Option[Array[(Int, Int)]], + inferRequestLanguageModel: InferRequest, + inferRequestVisionEmbeddingsModel: InferRequest, + inferRequestTextEmbeddingsModel: InferRequest, + inferRequestImagePackerModel: InferRequest, + inferRequestMergeModel: InferRequest): Array[Array[Float]] = { + + val (inputIdsLong, inputPositionIDsLong): (Array[Long], Array[Long]) = { + // First pass + val inpIdsLong = decoderInputIds.flatMap { tokenIds => tokenIds.map(_.toLong) } + val posIdsLong = decoderInputIds.flatMap { tokenIds => + tokenIds.zipWithIndex.map { case (_, i) => + i.toLong + } + } + (inpIdsLong, posIdsLong) + } + + val attentionMask: Array[Long] = decoderInputIds.flatMap { tokenIds => tokenIds.map(_ => 1L) } + val batchSize: Int = decoderInputIds.length + val shape: Array[Int] = Array(batchSize, inputIdsLong.length / batchSize) + + val decoderAttentionMask: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(Array(batchSize, decoderInputIds.head.length), attentionMask) + 
val decoderPositionIDs: org.intel.openvino.Tensor = + new org.intel.openvino.Tensor(shape, inputPositionIDsLong) + + val (finalEmbeds, finalAttentionMask, finalPositionIds) = getMultimodalEmbeddings( + decoderInputIds, + pixelValues, + imageSizes, + decoderAttentionMask, + inferRequestVisionEmbeddingsModel, + inferRequestTextEmbeddingsModel, + inferRequestImagePackerModel, + inferRequestMergeModel) + + inferRequestLanguageModel.set_tensor("inputs_embeds", finalEmbeds) + if (finalAttentionMask.isDefined) { + val finalAttentionMaskFloatTensor = new org.intel.openvino.Tensor( + finalAttentionMask.get.get_shape(), + // flat array of floats of values 1.0 + Array.fill(finalAttentionMask.get.get_shape().product)(1.0f)) + inferRequestLanguageModel.set_tensor("attention_mask", finalAttentionMaskFloatTensor) + } else { + val attentionMaskFloat: Array[Float] = + decoderInputIds.flatMap { tokenIds => tokenIds.map(_ => 1f) } + val attentionMaskFloatTensor = + new org.intel.openvino.Tensor( + Array(batchSize, decoderInputIds.head.length), + attentionMaskFloat) + inferRequestLanguageModel.set_tensor("attention_mask", attentionMaskFloatTensor) + } + if (finalPositionIds.isDefined) { + inferRequestLanguageModel.set_tensor("position_ids", finalPositionIds.get) + } else { + inferRequestLanguageModel.set_tensor("position_ids", decoderPositionIDs) + } + inferRequestLanguageModel.infer() + + val result = inferRequestLanguageModel.get_tensor("last_hidden_state") + val hiddenStateData = result.data() + val hiddenStateShape = result.get_shape() + val batchSizeResult = hiddenStateShape(0) + val hiddenSize = hiddenStateShape(1) + // Reshape to (batch, hidden_size) and return as Array[Array[Float]] + Array.tabulate(batchSizeResult) { b => + val start = b * hiddenSize + val end = start + hiddenSize + hiddenStateData.slice(start, end) + } + + } + + private def encodeImage( + annotations: Array[AnnotationImage], + preprocessor: Preprocessor): Array[Array[Array[Array[Array[Float]]]]] = { + + val batchProcessedImages = annotations.map { annot => + val bufferedImage = ImageIOUtils.byteToBufferedImage( + bytes = annot.result, + w = annot.width, + h = annot.height, + nChannels = annot.nChannels) + val bestResolution = E5VUtils.selectBestResolution( + (bufferedImage.getHeight, bufferedImage.getWidth), + imageGridPinpoints.map { case (_, pinpoints) => + (pinpoints(0), pinpoints(1)) + }.toList) + + val (newHeight, newWidth) = E5VUtils.getPatchOutputSize(bufferedImage, bestResolution) + val resizedForPatches = ImageResizeUtils.resizeBufferedImage( + width = newWidth, + height = newHeight, + resample = preprocessor.resample)(bufferedImage) + + val paddedForPatches = E5VUtils.padImage(resizedForPatches, bestResolution) + + var patches = E5VUtils.divideToPatches(paddedForPatches, patchSize) + + // add the reshaped original image as the first patch + val resizedOriginalImage = ImageResizeUtils.resizeBufferedImage( + width = preprocessor.size, + height = preprocessor.size, + resample = preprocessor.resample)(bufferedImage) + + patches = List(resizedOriginalImage) ++ patches + patches.map { patch => + ImageResizeUtils.normalizeAndConvertBufferedImage( + img = patch, + mean = preprocessor.image_mean, + std = preprocessor.image_std, + doNormalize = preprocessor.do_normalize, + doRescale = preprocessor.do_rescale, + rescaleFactor = preprocessor.rescale_factor) + }.toArray + } + + batchProcessedImages + + } + + def getMultimodalEmbeddings( + inputIds: Array[Array[Int]], + pixelValues: Option[Array[Array[Array[Array[Array[Float]]]]]], + 
imageSizes: Option[Array[(Int, Int)]], + attentionMask: org.intel.openvino.Tensor, + inferRequestVisionEmbeddingsModel: InferRequest, + inferRequestTextEmbeddingsModel: InferRequest, + inferRequestImagePackerModel: InferRequest, + inferRequestMergeModel: InferRequest): ( + org.intel.openvino.Tensor, + Option[org.intel.openvino.Tensor], + Option[org.intel.openvino.Tensor]) = { + + val inputIdsLong: Array[Long] = inputIds.flatMap(_.map(_.toLong)) + val batchSize: Int = inputIds.length + val shape: Array[Int] = Array(batchSize, inputIdsLong.length / batchSize) + val inputIdsLongTensor = new org.intel.openvino.Tensor(shape, inputIdsLong) + + // If pixelValues and imageSizes are present, do multimodal + (pixelValues, imageSizes, attentionMask) match { + case (Some(pixels), Some(sizes), attnMask) if pixels.nonEmpty && sizes.nonEmpty => + // 1. Get image features + val pixelShape = Array( + pixels.length, + pixels.head.length, + pixels.head.head.length, + pixels.head.head.head.length, + pixels.head.head.head.head.length) + // Flatten the pixel values to match the expected input shape + val flattenedPixels = pixels.flatten.flatten.flatten.flatten + val pixelTensor = + new org.intel.openvino.Tensor(pixelShape, flattenedPixels) + + inferRequestVisionEmbeddingsModel.set_tensor("pixel_values", pixelTensor) + inferRequestVisionEmbeddingsModel.infer() + val imageFeatures = inferRequestVisionEmbeddingsModel.get_output_tensor() + + // 2. Compute patch grid shape (dummy for now, should use config) + val (numPatchHeight, numPatchWidth) = + E5VUtils.getAnyResImageGridShape( + imageSizes.get.head, + imageGridPinpoints.map { case (_, pinpoints) => + (pinpoints(0), pinpoints(1)) + }.toList, + preprocessor.size) + + // 3. Pack image features + val imageSizesTensor = new org.intel.openvino.Tensor( + Array(sizes.length, 2), + sizes.flatMap(t => Array(t._1.toLong, t._2.toLong))) + + val numPatchHeightTensor = + new org.intel.openvino.Tensor(Array[Int](), Array(numPatchHeight.toLong)) + + val numPatchWidthTensor = + new org.intel.openvino.Tensor(Array[Int](), Array(numPatchWidth.toLong)) + + inferRequestImagePackerModel.set_tensor("image_feature", imageFeatures) + inferRequestImagePackerModel.set_tensor("image_sizes", imageSizesTensor) + inferRequestImagePackerModel.set_tensor("num_patch_height", numPatchHeightTensor) + inferRequestImagePackerModel.set_tensor("num_patch_width", numPatchWidthTensor) + inferRequestImagePackerModel.infer() + + val packedImageFeatures = inferRequestImagePackerModel.get_output_tensor() + + // 4. Get text embeddings + inferRequestTextEmbeddingsModel.set_input_tensor(inputIdsLongTensor) + inferRequestTextEmbeddingsModel.infer() + val textEmbeddings = inferRequestTextEmbeddingsModel.get_output_tensor() + + // 5. 
Merge image and text embeddings + inferRequestMergeModel.set_tensor("image_features", packedImageFeatures) + inferRequestMergeModel.set_tensor("inputs_embeds", textEmbeddings) + inferRequestMergeModel.set_tensor("input_ids", inputIdsLongTensor) + + inferRequestMergeModel.set_tensor("attention_mask", attnMask) + inferRequestMergeModel.infer() + ( + inferRequestMergeModel.get_tensor("final_embedding"), + Some(inferRequestMergeModel.get_tensor("final_attention_mask")), + Some(inferRequestMergeModel.get_tensor("position_ids"))) + case _ => + // Text-only + inferRequestTextEmbeddingsModel.set_input_tensor(inputIdsLongTensor) + inferRequestTextEmbeddingsModel.infer() + (inferRequestTextEmbeddingsModel.get_output_tensor(), None, None) + } + } + +} diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/util/transform/E5VUtils.scala b/src/main/scala/com/johnsnowlabs/ml/ai/util/transform/E5VUtils.scala new file mode 100644 index 00000000000000..a2561d2bd28613 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/ai/util/transform/E5VUtils.scala @@ -0,0 +1,134 @@ +package com.johnsnowlabs.ml.ai.util.transform + +import java.awt.image.BufferedImage +import java.awt.{Color, Graphics2D} + +object ChannelDimension extends Enumeration { + type ChannelDimension = Value + val FIRST, LAST = Value +} + +object E5VUtils { + import ChannelDimension._ + + def selectBestResolution( + originalSize: (Int, Int), + possibleResolutions: List[(Int, Int)]): (Int, Int) = { + val (originalHeight, originalWidth) = originalSize + var bestFit: (Int, Int) = possibleResolutions.head + var maxEffectiveResolution = 0 + var minWastedResolution = Double.PositiveInfinity + + for ((height, width) <- possibleResolutions) { + val scale = math.min(width.toDouble / originalWidth, height.toDouble / originalHeight) + val downscaledWidth = (originalWidth * scale).toInt + val downscaledHeight = (originalHeight * scale).toInt + val effectiveResolution = + math.min(downscaledWidth * downscaledHeight, originalWidth * originalHeight) + val wastedResolution = (width * height) - effectiveResolution + + if (effectiveResolution > maxEffectiveResolution || + (effectiveResolution == maxEffectiveResolution && wastedResolution < minWastedResolution)) { + maxEffectiveResolution = effectiveResolution + minWastedResolution = wastedResolution + bestFit = (height, width) + } + } + bestFit + } + + def imageSizeToNumPatches( + imageSize: (Int, Int), + gridPinpoints: List[(Int, Int)], + patchSize: Int): Int = { + val (height, width) = selectBestResolution(imageSize, gridPinpoints) + val numPatches = (0 until height by patchSize).size * (0 until width by patchSize).size + // add the base patch + numPatches + 1 + } + + def getAnyResImageGridShape( + imageSize: (Int, Int), + gridPinpoints: List[(Int, Int)], + patchSize: Int): (Int, Int) = { + val (height, width) = selectBestResolution(imageSize, gridPinpoints) + (height / patchSize, width / patchSize) + } + + def getImageSize(image: BufferedImage): (Int, Int) = { + (image.getHeight, image.getWidth) + } + + def expandToSquare(image: BufferedImage, backgroundColor: Color): BufferedImage = { + val width = image.getWidth + val height = image.getHeight + if (width == height) { + image + } else if (width > height) { + val result = new BufferedImage(width, width, image.getType) + val g = result.createGraphics() + g.setColor(backgroundColor) + g.fillRect(0, 0, width, width) + g.drawImage(image, 0, (width - height) / 2, null) + g.dispose() + result + } else { + val result = new BufferedImage(height, height, image.getType) + 
val g = result.createGraphics() + g.setColor(backgroundColor) + g.fillRect(0, 0, height, height) + g.drawImage(image, (height - width) / 2, 0, null) + g.dispose() + result + } + } + + def divideToPatches(image: BufferedImage, patchSize: Int): List[BufferedImage] = { + val width = image.getWidth + val height = image.getHeight + val patches = for { + i <- 0 until height by patchSize + j <- 0 until width by patchSize + } yield { + val w = math.min(patchSize, width - j) + val h = math.min(patchSize, height - i) + image.getSubimage(j, i, w, h) + } + patches.toList + } + + def getPatchOutputSize(image: BufferedImage, targetResolution: (Int, Int)): (Int, Int) = { + val (originalHeight, originalWidth) = getImageSize(image) + val (targetHeight, targetWidth) = targetResolution + + val scaleW = targetWidth.toDouble / originalWidth + val scaleH = targetHeight.toDouble / originalHeight + + if (scaleW < scaleH) { + val newWidth = targetWidth + val newHeight = math.min(math.ceil(originalHeight * scaleW).toInt, targetHeight) + (newHeight, newWidth) + } else { + val newHeight = targetHeight + val newWidth = math.min(math.ceil(originalWidth * scaleH).toInt, targetWidth) + (newHeight, newWidth) + } + } + + def padImage(image: BufferedImage, targetResolution: (Int, Int)): BufferedImage = { + val (targetHeight, targetWidth) = targetResolution + val (originalHeight, originalWidth) = getImageSize(image) + val (newHeight, newWidth) = getPatchOutputSize(image, targetResolution) + val result = new BufferedImage(targetWidth, targetHeight, image.getType) + val g = result.createGraphics() + g.setColor(Color.BLACK) + g.fillRect(0, 0, newWidth, newHeight) + g.drawImage( + image, + (targetWidth - originalWidth) / 2, + (targetHeight - originalHeight) / 2, + null) + g.dispose() + result + } +} diff --git a/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala index 274e085325aaf3..961995ffdd1511 100644 --- a/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala @@ -285,4 +285,10 @@ object OpenvinoWrapper { textEmbeddingsModel: OpenvinoWrapper, imageEmbedModel: OpenvinoWrapper, modelMergerModel: OpenvinoWrapper) + case class E5VWrappers( + languageModel: OpenvinoWrapper, + visionEmbeddingsModel: OpenvinoWrapper, + textEmbeddingsModel: OpenvinoWrapper, + imagePackerModel: OpenvinoWrapper, + mergeModel: OpenvinoWrapper) } diff --git a/src/main/scala/com/johnsnowlabs/nlp/ImageAssembler.scala b/src/main/scala/com/johnsnowlabs/nlp/ImageAssembler.scala index ae620dc78cbaa5..839346b79453b5 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/ImageAssembler.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/ImageAssembler.scala @@ -152,7 +152,21 @@ class ImageAssembler(override val uid: String) result = image.get.data, metadata = metadata, text = text.getOrElse(""))) - } else Seq.empty + } else if (text.isDefined) { + Seq( + AnnotationImage( + annotatorType = outputAnnotatorType, + origin = "", + height = 0, + width = 0, + nChannels = 0, + mode = 0, + result = Array.emptyByteArray, + metadata = metadata, + text = text.getOrElse(""))) + } else { + Seq.empty[AnnotationImage] + } } diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings.scala new file mode 100644 index 00000000000000..657d012c04734e --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings.scala @@ -0,0 
+1,641 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.embeddings + +import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig +import com.johnsnowlabs.ml.ai.E5V +import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadJsonStringAsset, + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor +import com.johnsnowlabs.ml.util.Openvino +import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, IMAGE, SENTENCE_EMBEDDINGS} +import com.johnsnowlabs.nlp._ +import org.json4s.{DefaultFormats, JValue} +import org.json4s.jackson.JsonMethods.parse +import com.johnsnowlabs.ml.openvino.{OpenvinoWrapper, ReadOpenvinoModel, WriteOpenvinoModel} +import com.johnsnowlabs.ml.openvino.OpenvinoWrapper.E5VWrappers +import com.johnsnowlabs.nlp.serialization.{MapFeature, StructFeature} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param.{IntArrayParam, IntParam} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.{Row, SparkSession} + +/** E5VEmbeddings provides universal multimodal embeddings using the E5-V model, which is + * fine-tuned from lmms-lab/llama3-llava-next-8b. + * + * E5-V bridges the modality gap between different input types (text, image) and demonstrates + * strong performance in multimodal embeddings, even without fine-tuning. It also supports a + * single-modality training approach, where the model is trained exclusively on text pairs, often + * yielding better performance than multimodal training. + * + * For more details, see the Hugging Face model card: https://huggingface.co/royokong/e5-v + * + * ==Overview== + * + * E5-V can embed both text and images into a shared space, enabling cross-modal retrieval and + * similarity tasks. The model is designed for universal embeddings and is suitable for scenarios + * where you want to compare or retrieve across modalities. 
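+ *
+ * As a minimal, illustrative sketch (not part of the annotator API; `result` is assumed to be
+ * the DataFrame produced in the examples below), the returned sentence embeddings can be
+ * collected and compared with a plain cosine similarity:
+ * {{{
+ * // one Array[Float] per input row
+ * val vectors: Array[Array[Float]] = result
+ *   .selectExpr("explode(e5v.embeddings) as embedding")
+ *   .collect()
+ *   .map(_.getSeq[Float](0).toArray)
+ *
+ * def cosine(a: Array[Float], b: Array[Float]): Double = {
+ *   val dot = a.zip(b).map { case (x, y) => x.toDouble * y }.sum
+ *   def norm(v: Array[Float]): Double = math.sqrt(v.map(x => x.toDouble * x).sum)
+ *   dot / (norm(a) * norm(b))
+ * }
+ *
+ * // e.g. similarity between the image row and the text row
+ * println(cosine(vectors(0), vectors(1)))
+ * }}}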
+ *
+ * ==Example==
+ *
+ * ===Image + Text Embedding===
+ * {{{
+ * import org.apache.spark.sql.functions.lit
+ * import com.johnsnowlabs.nlp.base.ImageAssembler
+ * import com.johnsnowlabs.nlp.embeddings.E5VEmbeddings
+ * import org.apache.spark.ml.Pipeline
+ *
+ * val imageDF = spark.read.format("image").option("dropInvalid", value = true).load(imageFolder)
+ * val imagePrompt = "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
+ * val testDF = imageDF.withColumn("text", lit(imagePrompt))
+ *
+ * val imageAssembler = new ImageAssembler().setInputCol("image").setOutputCol("image_assembler")
+ * val e5vEmbeddings = E5VEmbeddings.pretrained()
+ *   .setInputCols("image_assembler")
+ *   .setOutputCol("e5v")
+ *
+ * val pipeline = new Pipeline().setStages(Array(imageAssembler, e5vEmbeddings))
+ * val result = pipeline.fit(testDF).transform(testDF)
+ * result.select("e5v.embeddings").show(truncate = false)
+ * }}}
+ *
+ * ===Text-Only Embedding===
+ * {{{
+ * import org.apache.spark.sql.SparkSession
+ * import org.apache.spark.sql.functions.lit
+ * import com.johnsnowlabs.nlp.util.EmbeddingsDataFrameUtils.{emptyImageRow, imageSchema}
+ * import com.johnsnowlabs.nlp.embeddings.E5VEmbeddings
+ *
+ * val spark: SparkSession = ...
+ * val textPrompt = "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
+ * val textDesc = "A cat sitting in a box."
+ * val nullImageDF = spark.createDataFrame(spark.sparkContext.parallelize(Seq(emptyImageRow)), imageSchema)
+ * val textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc)))
+ *
+ * val e5vEmbeddings = E5VEmbeddings.pretrained()
+ *   .setInputCols("image")
+ *   .setOutputCol("e5v")
+ * val result = e5vEmbeddings.transform(textDF)
+ * result.select("e5v.embeddings").show(truncate = false)
+ * }}}
+ *
+ * ==References==
+ *   - Hugging Face model card: https://huggingface.co/royokong/e5-v
+ *   - Paper: https://arxiv.org/abs/2407.12580
+ *   - Code: https://github.com/kongds/E5-V
+ *
+ * @see
+ *   [[CLIPForZeroShotClassification]] for Zero Shot Image Classifier
+ * @see
+ *   [[https://sparknlp.org/docs/en/annotators Annotators Main Page]] for a list of transformer
+ *   based classifiers
+ * @groupname anno Annotator types
+ * @groupdesc anno
+ *   Required input and expected output annotator types
+ * @groupname Ungrouped Members
+ * @groupname param Parameters
+ * @groupname setParam Parameter setters
+ * @groupname getParam Parameter getters
+ * @groupname Ungrouped Members
+ * @groupprio param 1
+ * @groupprio anno 2
+ * @groupprio Ungrouped 3
+ * @groupprio setParam 4
+ * @groupprio getParam 5
+ * @groupdesc param
+ *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
+ *   parameter values through setters and getters, respectively.
+ */
+
+class E5VEmbeddings(override val uid: String)
+    extends AnnotatorModel[E5VEmbeddings]
+    with HasBatchedAnnotateImage[E5VEmbeddings]
+    with HasImageFeatureProperties
+    with WriteOpenvinoModel
+    with HasGeneratorProperties
+    with HasEngine {
+
+  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
+   * type
+   */
+  def this() = this(Identifiable.randomUID("E5VEmbeddings"))
+
+  /** Annotator reference id.
Used to identify elements in metadata or to refer to this annotator + * type + */ + override val inputAnnotatorTypes: Array[AnnotatorType] = Array(IMAGE) + override val outputAnnotatorType: AnnotatorType = SENTENCE_EMBEDDINGS + + /** @group setParam */ + def setRandomSeed(value: Int): E5VEmbeddings.this.type = { + if (randomSeed.isEmpty) { + this.randomSeed = Some(value) + } + this + } + + /** A list of token ids which are ignored in the decoder's output (Default: `Array()`) + * + * @group param + */ + var ignoreTokenIds = new IntArrayParam( + this, + "ignoreTokenIds", + "A list of token ids which are ignored in the decoder's output") + + /** @group setParam */ + def setIgnoreTokenIds(tokenIds: Array[Int]): E5VEmbeddings.this.type = { + set(ignoreTokenIds, tokenIds) + } + + /** @group getParam */ + def getIgnoreTokenIds: Array[Int] = $(ignoreTokenIds) + + /** Vocabulary used to encode the words to ids with bpeTokenizer.encode + * + * @group param + */ + val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected() + + /** @group setParam */ + def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value) + + /** Holding merges.txt coming from RoBERTa model + * + * @group param + */ + val merges: MapFeature[(String, String), Int] = new MapFeature(this, "merges").setProtected() + + /** @group setParam */ + def setMerges(value: Map[(String, String), Int]): this.type = set(merges, value) + + /** Additional tokens to be added to the vocabulary + * + * @group param + */ + val addedTokens: MapFeature[String, Int] = new MapFeature(this, "addedTokens").setProtected() + + /** @group setParam */ + def setAddedTokens(value: Map[String, Int]): this.type = set(addedTokens, value) + + /** Stop tokens to terminate the generation + * + * @group param + */ + override val stopTokenIds = + new IntArrayParam(this, "stopTokenIds", "Stop tokens to terminate the generation") + + /** @group setParam */ + override def setStopTokenIds(value: Array[Int]): this.type = { + set(stopTokenIds, value) + } + + /** @group getParam */ + override def getStopTokenIds: Array[Int] = $(stopTokenIds) + + private var _model: Option[Broadcast[E5V]] = None + val generationConfig: StructFeature[GenerationConfig] = + new StructFeature(this, "generationConfig").setProtected() + + def setGenerationConfig(value: GenerationConfig): this.type = + set(generationConfig, value) + + def getGenerationConfig: GenerationConfig = $$(generationConfig) + + val imageToken = + new IntParam(this, "imageToken", "Token id for image embeddings") + + /** @group setParam */ + def setImageToken(value: Int): this.type = set(imageToken, value) + + /** @group getParam */ + def getImageToken: Int = $(imageToken) + + /** Pinpoints for image grid, used to extract image features from the grid + * + * @group param + */ + val imageGridPinpoints: MapFeature[Int, Array[Int]] = new MapFeature(this, "imageGridPinpoints") + + /** @group setParam */ + def setImageGridPinpoints(value: Map[Int, Array[Int]]): this.type = + set(imageGridPinpoints, value) + + /** @group getParam */ + def getImageGridPinpoints: Map[Int, Array[Int]] = $$(imageGridPinpoints) + + /** Patch size for image embeddings + * + * @group param + */ + val patchSize: IntParam = + new IntParam(this, "patchSize", "Patch size for image embeddings, default is 336") + + /** @group setParam */ + def setPatchSize(value: Int): this.type = set(patchSize, value) + + /** @group getParam */ + def getPatchSize: Int = $(patchSize) + + /** @group setParam */ + def 
setModelIfNotSet( + spark: SparkSession, + preprocessor: Preprocessor, + onnxWrappers: Option[DecoderWrappers], + openvinoWrapper: Option[E5VWrappers]): this.type = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new E5V( + onnxWrappers, + openvinoWrapper, + $$(merges), + $$(vocabulary), + $$(addedTokens), + preprocessor, + generationConfig = getGenerationConfig, + imageToken = getImageToken, + imageGridPinpoints = getImageGridPinpoints, + patchSize = getPatchSize))) + } + this + } + + /** @group getParam */ + def getModelIfNotSet: E5V = _model.get.value + + setDefault( + minOutputLength -> 0, + maxOutputLength -> 20, + doSample -> false, + temperature -> 0.6, + topK -> -1, + topP -> 0.9, + repetitionPenalty -> 1.0, + noRepeatNgramSize -> 3, + ignoreTokenIds -> Array(), + batchSize -> 1, + beamSize -> 1, + maxInputLength -> 4096, + stopTokenIds -> Array(2), + imageToken -> 128256, + patchSize -> 336) + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + * Annotations in batches that correspond to inputAnnotationCols generated by previous + * annotators if any + * @return + * any number of annotations processed for every batch of input annotations. Not necessary + * one to one relationship + */ + override def batchAnnotate( + batchedAnnotations: Seq[Array[AnnotationImage]]): Seq[Seq[Annotation]] = { + + batchedAnnotations + // .filter { annotationImages => + // annotationImages.exists(_.text.nonEmpty) + // } + .map { cleanAnnotationImages => + val validImages = cleanAnnotationImages + val questionAnnotations = extractInputAnnotation(validImages) + + getModelIfNotSet.predict(questionAnnotations, validImages.toSeq) + } + } + + private def extractInputAnnotation( + annotationImages: Array[AnnotationImage]): Seq[Annotation] = { + val questions = annotationImages.map(annotationImage => { + val imageText = + if (annotationImage.text.nonEmpty) annotationImage.text + else + "<|user|> \n <|image|> This is an image\n <|end|>\n <|assistant|>\n" // default question + Annotation(imageText) + }) + + questions + } + + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + getEngine match { + case Openvino.name => + val wrappers = getModelIfNotSet.openvinoWrapper + writeOpenvinoModels( + path, + spark, + Seq((wrappers.get.languageModel, "openvino_language_model-int4.xml")), + E5VEmbeddings.suffix) + + writeOpenvinoModels( + path, + spark, + Seq((wrappers.get.visionEmbeddingsModel, "openvino_vision_embeddings_model.xml")), + E5VEmbeddings.suffix) + + writeOpenvinoModels( + path, + spark, + Seq((wrappers.get.textEmbeddingsModel, "openvino_text_embeddings_model.xml")), + E5VEmbeddings.suffix) + + writeOpenvinoModels( + path, + spark, + Seq((wrappers.get.imagePackerModel, "openvino_image_packer.xml")), + E5VEmbeddings.suffix) + + writeOpenvinoModels( + path, + spark, + Seq((wrappers.get.mergeModel, "openvino_multimodal_merger.xml")), + E5VEmbeddings.suffix) + case _ => + throw new Exception(notSupportedEngineError) + } + } + +} + +trait ReadablePretrainedE5VEmbeddings + extends ParamsAndFeaturesReadable[E5VEmbeddings] + with HasPretrained[E5VEmbeddings] { + + override val defaultModelName: Some[String] = Some("e5v_1_5_7b_int4") + + /** Java compliant-overrides */ + override def pretrained(): E5VEmbeddings = super.pretrained() + + override def pretrained(name: String): E5VEmbeddings = + super.pretrained(name) + + override def pretrained(name: String, 
lang: String): E5VEmbeddings = + super.pretrained(name, lang) + + override def pretrained(name: String, lang: String, remoteLoc: String): E5VEmbeddings = + super.pretrained(name, lang, remoteLoc) + +} + +trait ReadE5VEmbeddingsDLModel extends ReadOpenvinoModel { + this: ParamsAndFeaturesReadable[E5VEmbeddings] => + val suffix: String = "_e5v" + override val openvinoFile: String = "e5v_openvino" + def readModel(instance: E5VEmbeddings, path: String, spark: SparkSession): Unit = { + instance.getEngine match { + case Openvino.name => + val languageModelWrappers = + readOpenvinoModels(path, spark, Seq("openvino_language_model-int4.xml"), suffix) + + val visionEmbeddingsModelWrappers = + readOpenvinoModels(path, spark, Seq("openvino_vision_embeddings_model.xml"), suffix) + + val textEmbeddingsModelWrappers = + readOpenvinoModels(path, spark, Seq("openvino_text_embeddings_model.xml"), suffix) + + val imagePackerModelWrappers = + readOpenvinoModels(path, spark, Seq("openvino_image_packer.xml"), suffix) + + val mergeModelWrappers = + readOpenvinoModels(path, spark, Seq("openvino_multimodal_merger.xml"), suffix) + + val ovWrapper = E5VWrappers( + languageModel = languageModelWrappers("openvino_language_model-int4.xml"), + visionEmbeddingsModel = + visionEmbeddingsModelWrappers("openvino_vision_embeddings_model.xml"), + textEmbeddingsModel = textEmbeddingsModelWrappers("openvino_text_embeddings_model.xml"), + mergeModel = mergeModelWrappers("openvino_multimodal_merger.xml"), + imagePackerModel = imagePackerModelWrappers("openvino_image_packer.xml")) + val preprocessor = Preprocessor( + do_normalize = true, + do_resize = true, + "E5VFeatureExtractor", + instance.getImageMean, + instance.getImageStd, + instance.getResample, + instance.getSize) + instance.setModelIfNotSet(spark, preprocessor, None, Some(ovWrapper)) + case _ => { + throw new Exception(notSupportedEngineError) + } + } + } + + addReader(readModel) + + def loadSavedModel( + modelPath: String, + spark: SparkSession, + useOpenvino: Boolean = false): E5VEmbeddings = { + implicit val formats: DefaultFormats.type = DefaultFormats // for json4 + val (localModelPath, detectedEngine) = + modelSanityCheck( + modelPath, + isDecoder = false, + custom = Some( + List( + "openvino_language_model-int4", + "openvino_vision_embeddings_model", + "openvino_text_embeddings_model", + "openvino_image_packer", + "openvino_multimodal_merger"))) + val modelConfig: JValue = + parse(loadJsonStringAsset(localModelPath, "config.json")) + + val preprocessorConfigJsonContent = + loadJsonStringAsset(localModelPath, "preprocessor_config.json") + val preprocessorConfig = Preprocessor.loadPreprocessorConfig(preprocessorConfigJsonContent) + val beginSuppressTokens: Array[Int] = + (modelConfig \ "begin_suppress_tokens").extract[Array[Int]] + + val suppressTokenIds: Array[Int] = + (modelConfig \ "suppress_tokens").extract[Array[Int]] + + val forcedDecoderIds: Array[(Int, Int)] = + (modelConfig \ "forced_decoder_ids").extract[Array[Array[Int]]].map { + case idxWithTokenId: Array[Int] if idxWithTokenId.length == 2 => + (idxWithTokenId(0), idxWithTokenId(1)) + case _ => + throw new Exception( + "Could not extract forced_decoder_ids. 
Should be a list of tuples with 2 entries.") + } + + def arrayOrNone[T](array: Array[T]): Option[Array[T]] = + if (array.nonEmpty) Some(array) else None + + val bosTokenId = (modelConfig \ "text_config" \ "bos_token_id").extract[Int] + val eosTokenId = (modelConfig \ "text_config" \ "eos_token_id").extract[Int] + val padTokenId = (modelConfig \ "text_config" \ "eos_token_id").extract[Int] + val vocabSize = (modelConfig \ "text_config" \ "vocab_size").extract[Int] + + val imageToken = (modelConfig \ "image_token_index").extract[Int] + val imageGridPinpoints: Array[Array[Int]] = + (modelConfig \ "image_grid_pinpoints").extract[Array[Array[Int]]] + val imageGridPinpointsMap: Map[Int, Array[Int]] = + imageGridPinpoints.zipWithIndex.map { case (pinpoints, index) => + (index, pinpoints) + }.toMap + // Check if tokenizer.json exists + val tokenizerPath = s"$localModelPath/assets/tokenizer.json" + val tokenizerExists = new java.io.File(tokenizerPath).exists() + val (vocabs, addedTokens, bytePairs) = if (tokenizerExists) { + val tokenizerConfig: JValue = parse(loadJsonStringAsset(localModelPath, "tokenizer.json")) + // extract vocab from tokenizer.json ( model -> vocab) + var vocabs: Map[String, Int] = + (tokenizerConfig \ "model" \ "vocab").extract[Map[String, Int]] + + // extract merges from tokenizer.json ( model -> merges) +// val bytePairs = (tokenizerConfig \ "model" \ "merges") +// .extract[List[Array[String]]] +// .filter(w => w.length == 2) +// .map { case Array(c1, c2) => (c1, c2) } +// .zipWithIndex +// .toMap + val bytePairs = (tokenizerConfig \ "model" \ "merges") + .extract[List[String]] + .map(_.split(" ")) + .filter(w => w.length == 2) + .map { case Array(c1, c2) => (c1, c2) } + .zipWithIndex + .toMap + // extract added_tokens from tokenizer.json (added_tokens) + // "added_tokens": [ + // { + // "id": 128000, + // "content": "<|begin_of_text|>", + // "single_word": false, + // "lstrip": false, + // "rstrip": false, + // "normalized": false, + // "special": true + // }, ... 
+ // ] + val addedTokens = (tokenizerConfig \ "added_tokens") + .extract[List[Map[String, Any]]] + .map { token => + val id = token("id").asInstanceOf[BigInt].intValue() + val content = token("content").asInstanceOf[String] + (content, id) + } + .toMap + + // update vocab with added tokens + addedTokens.foreach { case (content, id) => + vocabs += (content -> id) + } + (vocabs, addedTokens, bytePairs) + } else { + val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap + val addedTokens = loadTextAsset(localModelPath, "added_tokens.txt").zipWithIndex.toMap + val bytePairs = loadTextAsset(localModelPath, "merges.txt") + .map(_.split(" ")) + .filter(w => w.length == 2) + .map { case Array(c1, c2) => (c1, c2) } + .zipWithIndex + .toMap + (vocabs, addedTokens, bytePairs) + } + + val annotatorModel = new E5VEmbeddings() + .setGenerationConfig( + GenerationConfig( + bosTokenId, + padTokenId, + eosTokenId, + vocabSize, + arrayOrNone(beginSuppressTokens), + arrayOrNone(suppressTokenIds), + arrayOrNone(forcedDecoderIds))) + .setVocabulary(vocabs) + .setMerges(bytePairs) + .setAddedTokens(addedTokens) + .setImageToken(imageToken) + .setSize(preprocessorConfig.size) + .setImageMean(preprocessorConfig.image_mean) + .setImageStd(preprocessorConfig.image_std) + .setResample(preprocessorConfig.resample) + .setImageGridPinpoints(imageGridPinpointsMap) + + val modelEngine = + if (useOpenvino) + Openvino.name + else + detectedEngine + annotatorModel.set(annotatorModel.engine, modelEngine) + + detectedEngine match { + case Openvino.name => + val visionWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine, + modelName = "openvino_vision_embeddings_model") + val textWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine, + modelName = "openvino_text_embeddings_model") + + val imagePackerModelWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine, + modelName = "openvino_image_packer") + + val mergeWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine, + modelName = "openvino_multimodal_merger") + val languageModelWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine, + modelName = "openvino_language_model-int4") + + val openvinoWrapper = E5VWrappers( + languageModel = languageModelWrapper, + visionEmbeddingsModel = visionWrapper, + textEmbeddingsModel = textWrapper, + imagePackerModel = imagePackerModelWrapper, + mergeModel = mergeWrapper) + annotatorModel.setModelIfNotSet(spark, preprocessorConfig, None, Some(openvinoWrapper)) + case _ => + throw new Exception(notSupportedEngineError) + } + + annotatorModel + } +} + +object E5VEmbeddings extends ReadablePretrainedE5VEmbeddings with ReadE5VEmbeddingsDLModel diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala index 6cb1ab2aa565f0..fcd7c11be57029 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala @@ -710,7 +710,8 @@ object PythonResourceDownloader { "PaliGemmaForMultiModal" -> PaliGemmaForMultiModal, "Gemma3ForMultiModal" -> 
Gemma3ForMultiModal, "InternVLForMultiModal" -> InternVLForMultiModal, - "Florence2Transformer" -> Florence2Transformer) + "Florence2Transformer" -> Florence2Transformer, + "E5VEmbeddings" -> E5VEmbeddings) // List pairs of types such as the one with key type can load a pretrained model from the value type val typeMapper: Map[String, String] = Map("ZeroShotNerModel" -> "RoBertaForQuestionAnswering") diff --git a/src/main/scala/com/johnsnowlabs/nlp/util/EmbeddingsDataFrameUtils.scala b/src/main/scala/com/johnsnowlabs/nlp/util/EmbeddingsDataFrameUtils.scala new file mode 100644 index 00000000000000..5c701306da0c22 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/util/EmbeddingsDataFrameUtils.scala @@ -0,0 +1,22 @@ +package com.johnsnowlabs.nlp.util + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +object EmbeddingsDataFrameUtils { + // Schema Spark expects for `format("image")` + val imageSchema: StructType = StructType( + Seq( + StructField( + "image", + StructType(Seq( + StructField("origin", StringType, true), + StructField("height", IntegerType, true), + StructField("width", IntegerType, true), + StructField("nChannels", IntegerType, true), + StructField("mode", IntegerType, true), + StructField("data", BinaryType, true)))))) + + // A reusable null image row for text-only embedding scenarios + val emptyImageRow: Row = Row(Row("", 0, 0, 0, 0, Array[Byte]())) +} diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddingsTestSpec.scala new file mode 100644 index 00000000000000..9bc4d00f9b7eb2 --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddingsTestSpec.scala @@ -0,0 +1,87 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.johnsnowlabs.nlp.embeddings + +import com.johnsnowlabs.nlp.{AssertAnnotations, ImageAssembler} +import com.johnsnowlabs.nlp.util.io.ResourceHelper +import com.johnsnowlabs.tags.{FastTest, SlowTest} +import org.apache.spark.ml.Pipeline +import org.apache.spark.sql.{DataFrame, Encoder, Encoders} +import org.apache.spark.sql.functions.{col, lit, size} +import org.scalatest.flatspec.AnyFlatSpec +import com.johnsnowlabs.nlp.util.EmbeddingsDataFrameUtils.{emptyImageRow, imageSchema} + +class E5VEmbeddingsTestSpec extends AnyFlatSpec { + lazy val model = getE5VEmbeddingsPipelineModel + + val textPrompt = + "<|start_header_id|>user<|end_header_id|>\n\n<sent>\\nSummary above sentence in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" + val imagePrompt = + "<|start_header_id|>user<|end_header_id|>\n\n<image>\\nSummary above image in one word: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" + + "E5V Embeddings" should "correctly embed sentences" taggedAs SlowTest in { + val testDF = getTestDF + val result = model.transform(testDF) + + result.select("e5v.embeddings").show(true) + + } + + private def getTestDF: DataFrame = { + val imageFolder = "src/test/resources/image1/" + val imageDF: DataFrame = ResourceHelper.spark.read + .format("image") + .option("dropInvalid", value = true) + .load(imageFolder) + + val testDF: DataFrame = imageDF.withColumn("text", lit(imagePrompt)) + val textDesc = "A cat sitting in a box." + + // Create DataFrame with a single null image row + val spark = ResourceHelper.spark + val nullImageDF = + spark.createDataFrame(spark.sparkContext.parallelize(Seq(emptyImageRow)), imageSchema) + + val textDF = nullImageDF.withColumn("text", lit(textPrompt.replace("<sent>", textDesc))) + + testDF.union(textDF) +// textDF + } + private def getE5VEmbeddingsPipelineModel = { + val testDF = getTestDF + + val imageAssembler: ImageAssembler = new ImageAssembler() + .setInputCol("image") + .setOutputCol("image_assembler") + + val loadModel = E5VEmbeddings + .pretrained() + .setInputCols("image_assembler") + .setOutputCol("e5v") + + val newPipeline: Pipeline = + new Pipeline().setStages(Array(imageAssembler, loadModel)) + + val pipelineModel = newPipeline.fit(testDF) + + pipelineModel + .transform(testDF) + .show(truncate = true) + + pipelineModel + } +}