From dc5db816dfe361c44dffb57baf9eccb9f5ce7814 Mon Sep 17 00:00:00 2001
From: Martin Yuan
Date: Thu, 21 Mar 2024 16:26:12 -0700
Subject: [PATCH 1/3] Add Llava model to examples

[ghstack-poisoned]
---
 .ci/scripts/test.sh                            |  4 ++
 .gitmodules                                    |  3 ++
 examples/models/__init__.py                    |  1 +
 examples/models/llava_encoder/README.md        | 17 ++++++
 examples/models/llava_encoder/__init__.py      | 11 ++++
 .../llava_encoder/install_requirements.sh      | 19 +++++++
 examples/models/llava_encoder/model.py         | 52 +++++++++++++++++++
 examples/third-party/LLaVA                     |  1 +
 8 files changed, 108 insertions(+)
 create mode 100644 examples/models/llava_encoder/README.md
 create mode 100644 examples/models/llava_encoder/__init__.py
 create mode 100644 examples/models/llava_encoder/install_requirements.sh
 create mode 100644 examples/models/llava_encoder/model.py
 create mode 160000 examples/third-party/LLaVA

diff --git a/.ci/scripts/test.sh b/.ci/scripts/test.sh
index de241834611..2d915506158 100755
--- a/.ci/scripts/test.sh
+++ b/.ci/scripts/test.sh
@@ -67,6 +67,10 @@ test_model() {
     run_portable_executor_runner
     rm "./${MODEL_NAME}.pte"
   fi
+  if [[ "${MODEL_NAME}" == "llava_encoder" ]]; then
+    # Install requirements for llava
+    bash examples/models/llava_encoder/install_requirements.sh
+  fi
   # python3 -m examples.portable.scripts.export --model_name="llama2" should works too
   "${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}"
   run_portable_executor_runner
diff --git a/.gitmodules b/.gitmodules
index b57b10fedd5..ae75796637e 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -59,3 +59,6 @@
 	path = third-party/lm-evaluation-harness
 	url = https://github.com/EleutherAI/lm-evaluation-harness
 	branch = v0.4.1
+[submodule "examples/third-party/LLaVA"]
+	path = examples/third-party/LLaVA
+	url = https://github.com/haotian-liu/LLaVA.git
diff --git a/examples/models/__init__.py b/examples/models/__init__.py
index a64686b239f..c66feb09629 100644
--- a/examples/models/__init__.py
+++ b/examples/models/__init__.py
@@ -26,6 +26,7 @@
     "ic4": ("inception_v4", "InceptionV4Model"),
     "resnet18": ("resnet", "ResNet18Model"),
     "resnet50": ("resnet", "ResNet50Model"),
+    "llava_encoder": ("llava_encoder", "LlavaModel"),
 }
 
 __all__ = [
diff --git a/examples/models/llava_encoder/README.md b/examples/models/llava_encoder/README.md
new file mode 100644
index 00000000000..c0947690107
--- /dev/null
+++ b/examples/models/llava_encoder/README.md
@@ -0,0 +1,17 @@
+## Summary
+In this example, we take a first step toward running multimodal models through ExecuTorch.
+- Demonstrate how to export the image encoder of the [LLaVA](https://github.com/haotian-liu/LLaVA) multimodal model.
+- Provide TODO steps on how to use the exported .pte file together with the existing [exported Llama2 model](https://github.com/pytorch/executorch/tree/main/examples/models/llama2) to build the multimodal pipeline.
+
+## Instructions
+Note that this folder does not host the pretrained LLaVA model.
+- To make LLaVA available, follow the [install instructions](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#install) in the LLaVA GitHub repo. Follow that repo's license when using LLaVA.
+- Since LLaVA may pin dependency versions (e.g. PyTorch) that differ from ExecuTorch's, `cd executorch` and rerun `./install_requirements.sh` to restore ExecuTorch's own dependencies.
+- Run `python3 -m examples.portable.scripts.export --model_name="llava_encoder"`. The llava_encoder.pte file will be generated.
+
+## TODO
+- Write the pipeline in C++
+  - Have image and text prompts as inputs.
+  - Call image processing functions to preprocess the image tensor.
+  - Load the llava_encoder.pte model and run it on the image tensor.
+  - Combine the encoder output with the text prompt as inputs to the Llama model. Call functions in llama_runner.cpp to run the Llama model and get outputs.
\ No newline at end of file
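The TODO above is the eventual C++ pipeline. The same flow can be prototyped in Python once llava_encoder.pte has been exported; this is a rough sketch only, assuming the ExecuTorch pybindings are built and expose `_load_for_executorch` with a `forward(...)` method (names can differ across ExecuTorch versions), and it leaves the combination with the Llama2 model as a placeholder:

```python
# Hypothetical Python prototype of the pipeline sketched in the README TODO;
# the production pipeline is intended to be written in C++.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

from examples.models.llava_encoder.model import LlavaModel

llava = LlavaModel()
(images_tensor,) = llava.get_example_inputs()  # already preprocessed image batch

# Load and run the exported image encoder on the image tensor.
encoder = _load_for_executorch("llava_encoder.pte")
image_embeddings = encoder.forward([images_tensor])[0]

# TODO: splice image_embeddings into the prompt embeddings and feed the result
# to the exported Llama2 model (llama_runner.cpp in the C++ version).
print(image_embeddings.shape)
```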
diff --git a/examples/models/llava_encoder/__init__.py b/examples/models/llava_encoder/__init__.py
new file mode 100644
index 00000000000..3029fd184f5
--- /dev/null
+++ b/examples/models/llava_encoder/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .model import LlavaModel
+
+__all__ = [
+    "LlavaModel",
+]
diff --git a/examples/models/llava_encoder/install_requirements.sh b/examples/models/llava_encoder/install_requirements.sh
new file mode 100644
index 00000000000..c3d527419d3
--- /dev/null
+++ b/examples/models/llava_encoder/install_requirements.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Install LLaVA from the submodule.
+pip install --force-reinstall -e examples/third-party/LLaVA
+
+# LLaVA's dependencies can pin different versions than ExecuTorch's; for
+# example, the torch version required by LLaVA is older than the one
+# ExecuTorch needs. To make both work, recover ExecuTorch's original
+# dependencies by rerunning install_requirements.sh.
+./install_requirements.sh
+
+# bitsandbytes depends on numpy 1.x, which is not compatible with numpy 2.x.
+# Reinstall bitsandbytes to make it compatible.
+pip install bitsandbytes -I
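For reference, the `examples.portable.scripts.export` entry point invoked in the README follows the standard ExecuTorch export flow. Below is a simplified sketch of that flow using the `LlavaModel` defined in `model.py` (an assumption-laden illustration, not the actual script, which adds argument parsing and shared helpers):

```python
# Simplified sketch of exporting the LLaVA image encoder to llava_encoder.pte
# via the generic torch.export -> to_edge -> to_executorch flow.
from executorch.exir import to_edge
from torch.export import export

from examples.models.llava_encoder.model import LlavaModel

llava = LlavaModel()
eager_model = llava.get_eager_model().eval()
example_inputs = llava.get_example_inputs()

aten_program = export(eager_model, example_inputs)   # capture an ATen-level graph
edge_program = to_edge(aten_program)                  # lower to the Edge dialect
executorch_program = edge_program.to_executorch()     # emit the ExecuTorch program

with open("llava_encoder.pte", "wb") as f:
    f.write(executorch_program.buffer)
```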
diff --git a/examples/models/llava_encoder/model.py b/examples/models/llava_encoder/model.py
new file mode 100644
index 00000000000..302613dae8d
--- /dev/null
+++ b/examples/models/llava_encoder/model.py
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from examples.models.model_base import EagerModelBase
+from llava.eval.run_llava import eval_model, load_images, process_images
+from llava.mm_utils import get_model_name_from_path
+
+from llava.model.builder import load_pretrained_model
+from torch import nn
+
+
+class EncoderModel(nn.Module):
+    def __init__(self, llava_model):
+        super().__init__()
+        self.model_ = llava_model
+
+    def forward(self, images_tensor):
+        features = self.model_.get_model().get_vision_tower()(images_tensor)
+        features = self.model_.get_model().mm_projector(features)
+        return features
+
+
+class LlavaModel(EagerModelBase):
+    def __init__(self):
+        model_path = "liuhaotian/llava-v1.5-7b"
+        tokenizer, self.model_, self.image_processor_, context_len = (
+            load_pretrained_model(
+                model_path=model_path,
+                model_base=None,
+                model_name=get_model_name_from_path(model_path),
+            )
+        )
+        self.device = "cpu"
+        self.dtype = torch.float32
+        self.model_.to(device=self.device, dtype=self.dtype)
+
+    def get_eager_model(self):
+        model = EncoderModel(self.model_)
+        return model
+
+    def get_example_inputs(self):
+        image_file = "https://llava-vl.github.io/static/images/view.jpg"
+        images = load_images([image_file])
+        images_tensor = process_images(
+            images, self.image_processor_, self.model_.config
+        ).to(self.model_.device)
+        return (images_tensor,)
diff --git a/examples/third-party/LLaVA b/examples/third-party/LLaVA
new file mode 160000
index 00000000000..7440ec9ee37
--- /dev/null
+++ b/examples/third-party/LLaVA
@@ -0,0 +1 @@
+Subproject commit 7440ec9ee37b0374c6b5548818e89878e38f3353

From 00b98b31f6a12469b4d19e1b30a19fdc3fddd04f Mon Sep 17 00:00:00 2001
From: Martin Yuan
Date: Fri, 29 Mar 2024 17:43:55 -0700
Subject: [PATCH 2/3] Update base for Update on "Add Llava model to examples"

[ghstack-poisoned]

From d9c23af8e808793bb978df76b94098246f845a4e Mon Sep 17 00:00:00 2001
From: Martin Yuan
Date: Sat, 30 Mar 2024 18:44:29 -0700
Subject: [PATCH 3/3] Update base for Update on "Add Llava model to examples"

[ghstack-poisoned]
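Before wiring the encoder into the C++ pipeline, the eager model defined in model.py can be smoke-tested directly; a minimal sketch, assuming the LLaVA checkpoint download and the example-image fetch succeed:

```python
# Quick eager smoke test of EncoderModel: run the example image through the
# vision tower and mm_projector and inspect the resulting feature tensor.
import torch

from examples.models.llava_encoder.model import LlavaModel

llava = LlavaModel()
(images_tensor,) = llava.get_example_inputs()

with torch.no_grad():
    image_features = llava.get_eager_model()(images_tensor)

# These projected image features are what the TODO pipeline combines with the
# text prompt before calling the Llama2 model.
print(images_tensor.shape, image_features.shape, image_features.dtype)
```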