@@ -43,7 +43,9 @@ def cli_export(command, model_dir):
 
 
 def check_causal_lm_output_quality(
-    model_id: str, generated_tokens: List[int], max_perplexity_threshold: float = 100.0
+    model_id: str,
+    generated_tokens: List[int],
+    max_perplexity_threshold: float = 100.0,
 ):
     """
     Evaluates the quality of text generated by a causal language model by calculating its perplexity.
@@ -58,12 +60,24 @@ def check_causal_lm_output_quality(
     """
     logging.info(f"Starting perplexity check with model '{model_id}' ...")
     # Load model
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        low_cpu_mem_usage=True,
-        use_cache=False,
-        torch_dtype=torch.bfloat16,
-    )
+    cls_name = AutoModelForCausalLM
+    if "llava" in model_id:
+        from transformers import LlavaForConditionalGeneration
+
+        cls_name = LlavaForConditionalGeneration
+    try:
+        model = cls_name.from_pretrained(
+            model_id,
+            low_cpu_mem_usage=True,
+            use_cache=False,
+            torch_dtype=torch.bfloat16,
+        )
+    except TypeError:
+        model = cls_name.from_pretrained(
+            model_id,
+            low_cpu_mem_usage=True,
+            torch_dtype=torch.bfloat16,
+        )
 
     with torch.no_grad():
         outputs = model(input_ids=generated_tokens, labels=generated_tokens)
@@ -156,6 +170,86 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
     assert check_causal_lm_output_quality(model_id, generated_tokens) is True
 
 
+def test_llm_with_image_modality(
+    model_id, model_dir, recipe, *, quantize=True, run_only=False
+):
+    command = [
+        "optimum-cli",
+        "export",
+        "executorch",
+        "--model",
+        model_id,
+        "--task",
+        "multimodal-text-to-text",
+        "--recipe",
+        recipe,
+        "--output_dir",
+        model_dir,
+        "--use_custom_sdpa",
+        "--use_custom_kv_cache",
+        "--qlinear",
+        "8da4w",
+        "--qembedding",
+        "8w",
+    ]
+    if not run_only:
+        cli_export(command, model_dir)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.save_pretrained(model_dir)
+
+    # input
+    processor = AutoProcessor.from_pretrained(model_id)
+    image_url = "https://llava-vl.github.io/static/images/view.jpg"
+    conversation = [
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
+                }
+            ],
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "url": image_url},
+                {
+                    "type": "text",
+                    "text": "What are the things I should be cautious about when I visit here?",
+                },
+            ],
+        },
+    ]
+    inputs = processor.apply_chat_template(
+        conversation,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    )
+
+    from executorch.extension.llm.runner import GenerationConfig, MultimodalRunner
+
+    runner = MultimodalRunner(f"{model_dir}/model.pte", f"{model_dir}/tokenizer.model")
+    generated_text = runner.generate_text_hf(
+        inputs,
+        GenerationConfig(max_new_tokens=128, temperature=0, echo=False),
+        processor.image_token_id,
+    )
+    print(f"\nGenerated text:\n\t{generated_text}")
+    # Free memory before loading eager for quality check
+    del runner
+    gc.collect()
+    assert (
+        check_causal_lm_output_quality(
+            model_id, tokenizer.encode(generated_text, return_tensors="pt")
+        )
+        is True
+    )
+
+
 def test_fill_mask(model_id, model_dir, recipe, *, quantize=True, run_only=False):
     command = [
         "optimum-cli",
@@ -353,6 +447,9 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
         required=False,
         help="When provided, write the pte file to this directory. Otherwise, a temporary directory is created for the test.",
     )
+    parser.add_argument(
+        "--run_only", action="store_true", help="Skip export and only run the test"
+    )
     args = parser.parse_args()
 
     _text_generation_mapping = {
@@ -384,8 +481,16 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
         "vit": ("google/vit-base-patch16-224", test_vit),
     }
 
+    _multimodal_model_mapping = {
+        "gemma3-4b": ("google/gemma-3-4b-it", test_llm_with_image_modality),
+        "llava": ("llava-hf/llava-1.5-7b-hf", test_llm_with_image_modality),
+    }
+
     model_to_model_id_and_test_function = (
-        _text_generation_mapping | _mask_fill_mapping | _misc_model_mapping
+        _text_generation_mapping
+        | _mask_fill_mapping
+        | _misc_model_mapping
+        | _multimodal_model_mapping
     )
 
     if args.model not in model_to_model_id_and_test_function:
@@ -400,4 +505,5 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
             model_dir=tmp_dir if args.model_dir is None else args.model_dir,
             recipe=args.recipe,
             quantize=args.quantize,
+            run_only=args.run_only,
         )
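
For reference, a minimal sketch of driving the new multimodal test directly rather than through argparse. It assumes it runs in the same module that defines test_llm_with_image_modality; the "xnnpack" recipe is an assumed example value, not taken from the diff:

import tempfile

# Sketch only: invoke the new test the same way __main__ dispatches mapped tests.
with tempfile.TemporaryDirectory() as tmp_dir:
    test_llm_with_image_modality(
        "llava-hf/llava-1.5-7b-hf",  # model_id listed in _multimodal_model_mapping
        tmp_dir,                     # model_dir: where model.pte and the tokenizer are written
        "xnnpack",                   # recipe: assumed example value
        run_only=False,              # export with optimum-cli first, then run the .pte
    )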