diff --git a/.ci/scripts/test_model.sh b/.ci/scripts/test_model.sh
index 74eb75c6ddd..de28597b1d5 100755
--- a/.ci/scripts/test_model.sh
+++ b/.ci/scripts/test_model.sh
@@ -131,13 +131,13 @@ test_model_with_xnnpack() {
     return 0
   fi
 
-  # Delegation
+  # Delegation and test with pybindings
   if [[ ${WITH_QUANTIZATION} == true ]]; then
     SUFFIX="q8"
-    "${PYTHON_EXECUTABLE}" -m examples.xnnpack.aot_compiler --model_name="${MODEL_NAME}" --delegate --quantize
+    "${PYTHON_EXECUTABLE}" -m examples.xnnpack.aot_compiler --model_name="${MODEL_NAME}" --delegate --quantize --test_after_export
   else
     SUFFIX="fp32"
-    "${PYTHON_EXECUTABLE}" -m examples.xnnpack.aot_compiler --model_name="${MODEL_NAME}" --delegate
+    "${PYTHON_EXECUTABLE}" -m examples.xnnpack.aot_compiler --model_name="${MODEL_NAME}" --delegate --test_after_export
   fi
 
   OUTPUT_MODEL_PATH="${MODEL_NAME}_xnnpack_${SUFFIX}.pte"
diff --git a/examples/xnnpack/aot_compiler.py b/examples/xnnpack/aot_compiler.py
index 81eeb75c72c..9a78138adf3 100644
--- a/examples/xnnpack/aot_compiler.py
+++ b/examples/xnnpack/aot_compiler.py
@@ -61,6 +61,14 @@
         default="",
         help="Generate and save an ETRecord to the given file location",
     )
+    parser.add_argument(
+        "-t",
+        "--test_after_export",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Test the pte with pybindings",
+    )
     parser.add_argument("-o", "--output_dir", default=".", help="output directory")
 
     args = parser.parse_args()
@@ -117,3 +125,24 @@
     quant_tag = "q8" if args.quantize else "fp32"
     model_name = f"{args.model_name}_xnnpack_{quant_tag}"
     save_pte_program(exec_prog, model_name, args.output_dir)
+
+    if args.test_after_export:
+        logging.info("Testing the pte with pybind")
+        from executorch.extension.pybindings.portable_lib import (
+            _load_for_executorch_from_buffer,
+        )
+
+        # Import custom ops. This requires portable_lib to be loaded first.
+        from executorch.extension.llm.custom_ops import (  # noqa: F401, F403
+            custom_ops,
+        )  # usort: skip
+
+        # Import quantized ops. This requires portable_lib to be loaded first.
+        from executorch.kernels import quantized  # usort: skip # noqa: F401, F403
+        from torch.utils._pytree import tree_flatten
+
+        m = _load_for_executorch_from_buffer(exec_prog.buffer)
+        logging.info("Successfully loaded the model")
+        flattened = tree_flatten(example_inputs)[0]
+        res = m.run_method("forward", flattened)
+        logging.info("Successfully ran the model")
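
A minimal usage sketch of the new flag outside CI, assuming the `mv2` (MobileNetV2) model shortname from examples/models; any registered model name should work the same way:

    # Export with XNNPACK delegation and q8 quantization, then load and run
    # the resulting .pte through the pybindings before the script exits.
    python -m examples.xnnpack.aot_compiler \
        --model_name=mv2 \
        --delegate \
        --quantize \
        --test_after_export

Because the check runs against exec_prog.buffer in-process, a broken delegate or missing kernel surfaces as a CI failure at export time rather than later, when the saved .pte is first loaded on-device.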