
Commit e8cd6bc

Update Phi-3 vision example and add Phi-3.5 vision example (#1049)
### Description

This PR updates the Phi-3 vision example and adds a similar example for Phi-3.5 vision.

### Motivation and Context

Now that ONNX Runtime GenAI v0.5.0 is released, the Phi-3 vision example needs to be updated, and a similar example for Phi-3.5 vision can be created.
1 parent 83ddc3d commit e8cd6bc

File tree

3 files changed, +137 -39 lines changed:

- examples/python/phi-3-vision.md
- examples/python/phi-3.5-vision.md
- examples/python/phi3v.py


examples/python/phi-3-vision.md

Lines changed: 9 additions & 37 deletions
@@ -49,13 +49,14 @@ $ cd phi3-vision-128k-instruct/pytorch
 $ huggingface-cli download microsoft/Phi-3-vision-128k-instruct --local-dir .
 ```
 
-Now, let's download the modified PyTorch modeling files that have been uploaded to the Phi-3 vision ONNX repositories on Hugging Face. Here, let's use `microsoft/Phi-3-vision-128k-instruct-onnx-cpu` as the example ONNX repo.
-
 ### Download the modified PyTorch modeling files
+
+Now, let's download the modified PyTorch modeling files that have been uploaded to the Phi-3 vision ONNX repository on Hugging Face.
+
 ```bash
 # Download modified files
 $ cd ..
-$ huggingface-cli download microsoft/Phi-3-vision-128k-instruct-onnx-cpu --include onnx/* --local-dir .
+$ huggingface-cli download microsoft/Phi-3-vision-128k-instruct-onnx --include onnx/* --local-dir .
 ```
 
 ### Replace original PyTorch repo files with modified files
@@ -65,8 +66,7 @@ $ huggingface-cli download microsoft/Phi-3-vision-128k-instruct-onnx-cpu --include onnx/* --local-dir .
 $ rm pytorch/config.json
 $ mv onnx/config.json pytorch/
 
-# In our `modeling_phi3_v.py`, we replaced `from .image_embedding_phi3_v import Phi3ImageEmbedding`
-# with `from .image_embedding_phi3_v_for_onnx import Phi3ImageEmbedding`
+# In our `modeling_phi3_v.py`, we modified some classes for exporting to ONNX
 $ rm pytorch/modeling_phi3_v.py
 $ mv onnx/modeling_phi3_v.py pytorch/
 
@@ -103,47 +103,19 @@ $ python3 builder.py --input ./pytorch --output ./dml --precision fp16 --execution_provider dml
 
 ## 3. Build `genai_config.json` and `processor_config.json`
 
-Currently, both JSON files needed to run with ONNX Runtime GenAI are created by hand. Because the fields have been hand-crafted, it is recommended that you copy the already-uploaded JSON files and modify the fields as needed for your fine-tuned Phi-3 vision model. [Here](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/blob/main/cpu-int4-rtn-block-32-acc-level-4/genai_config.json) is an example for `genai_config.json` and [here](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/blob/main/cpu-int4-rtn-block-32-acc-level-4/processor_config.json) is an example for `processor_config.json`.
-
-### For DirectML
-Replace
-```json
-"provider_options": []
-```
-in `genai_config.json` with
-```json
-"provider_options": [
-    {
-        "dml" : {}
-    }
-]
-```
-
-### For CUDA
-Replace
-```json
-"provider_options": []
-```
-in `genai_config.json` with
-```json
-"provider_options": [
-    {
-        "cuda" : {}
-    }
-]
-```
+Currently, both JSON files needed to run with ONNX Runtime GenAI are created by hand. Because the fields have been hand-crafted, it is recommended that you copy the already-uploaded JSON files and modify the fields as needed for your fine-tuned Phi-3 vision model. [Here](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx/blob/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/genai_config.json) is an example for `genai_config.json` and [here](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx/blob/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/processor_config.json) is an example for `processor_config.json`.
 
 ## 4. Run Phi-3 vision ONNX models
 
 [Here](https://github.com/microsoft/onnxruntime-genai/blob/main/examples/python/phi3v.py) is an example of how you can run your Phi-3 vision model with the ONNX Runtime generate() API.
 
 ### CUDA
 ```bash
-$ python .\phi3v.py -m .\phi3-vision-128k-instruct\cuda
+$ python .\phi3v.py -m .\phi3-vision-128k-instruct\cuda -p cuda
 ```
 
 ### DirectML
 
 ```bash
-$ python .\phi3v.py -m .\phi3-vision-128k-instruct\dml
+$ python .\phi3v.py -m .\phi3-vision-128k-instruct\dml -p dml
 ```
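
The per-provider JSON snippets removed above boil down to filling in `provider_options` in the copied `genai_config.json`. Below is a minimal sketch of that edit, assuming the usual `genai_config.json` layout where `provider_options` sits under `model.decoder.session_options`; the file path and provider name are placeholders.

```python
# patch_provider.py: a minimal sketch for enabling an execution provider in a
# copied genai_config.json. The path and provider name are placeholders.
import json

config_path = "./cuda/genai_config.json"  # wherever you copied the config
provider = "cuda"  # or "dml" for DirectML

with open(config_path) as f:
    config = json.load(f)

# An empty "provider_options" list means CPU; adding {"cuda": {}} or {"dml": {}}
# selects that provider, mirroring the JSON shown in the removed sections.
config["model"]["decoder"]["session_options"]["provider_options"] = [{provider: {}}]

with open(config_path, "w") as f:
    json.dump(config, f, indent=4)
```

Note that the updated `phi3v.py` (diffed below) can also select the provider at runtime via the new `-p` flag, so hand-editing the config is no longer the only option.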

examples/python/phi-3.5-vision.md

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+# Build your Phi-3.5 vision ONNX models for ONNX Runtime GenAI
+
+## Steps
+0. [Pre-requisites](#pre-requisites)
+1. [Prepare Local Workspace](#prepare-local-workspace)
+2. [Build ONNX Components](#build-onnx-components)
+3. [Build ORT GenAI Configs](#build-genai_configjson-and-processor_configjson)
+4. [Run Phi-3.5 vision ONNX models](#run-phi-3.5-vision-onnx-models)
+
+## 0. Pre-requisites
+
+Please ensure you have the following Python packages installed to create the ONNX models.
+
+- `huggingface_hub[cli]`
+- `numpy`
+- `onnx`
+- `onnxruntime-genai`
+    - For CPU:
+      ```bash
+      pip install onnxruntime-genai
+      ```
+    - For CUDA:
+      ```bash
+      pip install onnxruntime-genai-cuda
+      ```
+    - For DirectML:
+      ```bash
+      pip install onnxruntime-genai-directml
+      ```
+- `pillow`
+- `requests`
+- `torch`
+    - Please install torch by following the [instructions](https://pytorch.org/get-started/locally/). For getting ONNX models that can run on CUDA or DirectML, please install torch with CUDA and ensure the CUDA version you choose in the instructions is the one you have installed.
+- `torchvision`
+- `transformers`
+
+## 1. Prepare Local Workspace
+
+Phi-3.5 vision is a multimodal model consisting of several models internally. In order to run Phi-3.5 vision with ONNX Runtime GenAI, each internal model needs to be created as a separate ONNX model. To get these ONNX models, some of the original PyTorch modeling files have to be modified.
+
+### Download the original PyTorch modeling files
+
+First, let's download the original PyTorch modeling files.
+
+```bash
+# Download PyTorch model and files
+$ mkdir -p phi3.5-vision-instruct/pytorch
+$ cd phi3.5-vision-instruct/pytorch
+$ huggingface-cli download microsoft/Phi-3.5-vision-instruct --local-dir .
+```
+
+### Download the modified PyTorch modeling files
+
+Now, let's download the modified PyTorch modeling files that have been uploaded to the Phi-3.5 vision ONNX repository on Hugging Face.
+
+```bash
+# Download modified files
+$ cd ..
+$ huggingface-cli download microsoft/Phi-3.5-vision-instruct-onnx --include onnx/* --local-dir .
+```
+
+### Replace original PyTorch repo files with modified files
+
+```bash
+# In our `config.json`, we replaced `flash_attention_2` with `eager` in `_attn_implementation`
+$ rm pytorch/config.json
+$ mv onnx/config.json pytorch/
+
+# In our `modeling_phi3_v.py`, we modified some classes for exporting to ONNX
+$ rm pytorch/modeling_phi3_v.py
+$ mv onnx/modeling_phi3_v.py pytorch/
+
+# Move the builder script to the root directory
+$ mv onnx/builder.py .
+
+# Delete empty `onnx` directory
+$ rm -rf onnx/
+```
+
+If you have your own fine-tuned version of Phi-3.5 vision, you can now replace the `*.safetensors` files in the `pytorch` folder with your `*.safetensors` files.
+
+## 2. Build ONNX Components
+
+Here are some examples of how you can build the components as INT4 ONNX models.
+
+```bash
+# Build INT4 components with FP32 inputs/outputs for CPU
+$ python3 builder.py --input ./pytorch --output ./cpu --precision fp32 --execution_provider cpu
+```
+
+```bash
+# Build INT4 components with FP16 inputs/outputs for CUDA
+$ python3 builder.py --input ./pytorch --output ./cuda --precision fp16 --execution_provider cuda
+```
+
+```bash
+# Build INT4 components with FP16 inputs/outputs for DirectML
+$ python3 builder.py --input ./pytorch --output ./dml --precision fp16 --execution_provider dml
+```
+
+## 3. Build `genai_config.json` and `processor_config.json`
+
+Currently, both JSON files needed to run with ONNX Runtime GenAI are created by hand. Because the fields have been hand-crafted, it is recommended that you copy the already-uploaded JSON files and modify the fields as needed for your fine-tuned Phi-3.5 vision model. [Here](https://huggingface.co/microsoft/Phi-3.5-vision-instruct-onnx/blob/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/genai_config.json) is an example for `genai_config.json` and [here](https://huggingface.co/microsoft/Phi-3.5-vision-instruct-onnx/blob/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/processor_config.json) is an example for `processor_config.json`.
+
+## 4. Run Phi-3.5 vision ONNX models
+
+[Here](https://github.com/microsoft/onnxruntime-genai/blob/main/examples/python/phi3v.py) is an example of how you can run your Phi-3.5 vision model with the ONNX Runtime generate() API.
+
+### CUDA
+```bash
+$ python .\phi3v.py -m .\phi3.5-vision-instruct\cuda -p cuda
+```
+
+### DirectML
+
+```bash
+$ python .\phi3v.py -m .\phi3.5-vision-instruct\dml -p dml
+```
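
After `builder.py` finishes, a quick structural sanity check of the exported components can be run with the `onnx` package from the pre-requisites list. This helper is not part of the example; the `./cpu` output directory and the glob pattern are assumptions.

```python
# check_components.py: a hypothetical sanity check for the ONNX components
# written by builder.py. Point it at whichever --output folder you built.
import glob

import onnx

for path in sorted(glob.glob("./cpu/*.onnx")):
    # Skip loading external weight files; we only validate graph structure.
    model = onnx.load(path, load_external_data=False)
    onnx.checker.check_model(model)  # raises if the graph is malformed
    print(f"{path}: OK")
```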

examples/python/phi3v.py

Lines changed: 10 additions & 2 deletions
@@ -14,7 +14,12 @@ def _complete(text, state):
 
 def run(args: argparse.Namespace):
     print("Loading model...")
-    model = og.Model(args.model_path)
+    config = og.Config(args.model_path)
+    config.clear_providers()
+    if args.provider != "cpu":
+        print(f"Setting model to {args.provider}...")
+        config.append_provider(args.provider)
+    model = og.Model(config)
     processor = model.create_multimodal_processor()
     tokenizer_stream = processor.create_stream()
 
@@ -73,7 +78,10 @@ def run(args: argparse.Namespace):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "-m", "--model_path", type=str, required=True, help="Path to the model"
+        "-m", "--model_path", type=str, required=True, help="Path to the folder containing the model"
+    )
+    parser.add_argument(
+        "-p", "--provider", type=str, required=True, help="Provider to run model"
     )
     args = parser.parse_args()
     run(args)
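
Taken together, the two hunks above change how the model is loaded. Here is a condensed, self-contained sketch of just that flow, using only the calls that appear in the diff; prompt and image handling and the generation loop from the full example are omitted.

```python
# minimal_phi3v.py: a condensed sketch of the new loading path in phi3v.py.
import argparse

import onnxruntime_genai as og

parser = argparse.ArgumentParser()
parser.add_argument(
    "-m", "--model_path", type=str, required=True,
    help="Path to the folder containing the model",
)
parser.add_argument(
    "-p", "--provider", type=str, required=True,
    help="Provider to run model (cpu, cuda, or dml)",
)
args = parser.parse_args()

config = og.Config(args.model_path)
config.clear_providers()  # start from an empty provider list (CPU fallback)
if args.provider != "cpu":
    config.append_provider(args.provider)  # e.g. "cuda" or "dml"
model = og.Model(config)

processor = model.create_multimodal_processor()
tokenizer_stream = processor.create_stream()
print("Model and multimodal processor loaded.")
```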
