diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py
index 89196674c48..378c86cb420 100644
--- a/backends/arm/test/models/test_llama.py
+++ b/backends/arm/test/models/test_llama.py
@@ -102,7 +102,7 @@ def prepare_model(self):
     def test_llama_tosa_MI(self):
         llama_model, llama_inputs, llama_meta = self.prepare_model()
 
-        if llama_model is None and llama_inputs is None and llama_meta is None:
+        if llama_model is None or llama_inputs is None:
             pytest.skip("Missing model and/or input files")
 
         with torch.no_grad():
@@ -123,3 +123,29 @@ def test_llama_tosa_MI(self):
                     rtol=1.1,  # TODO: MLETORCH-825 decrease tolerance
                 )
             )
+
+    @pytest.mark.xfail(reason="KeyError: scalar_tensor_1 (MLETORCH-907)")
+    def test_llama_tosa_BI(self):
+        llama_model, llama_inputs, llama_meta = self.prepare_model()
+
+        if llama_model is None or llama_inputs is None:
+            pytest.skip("Missing model and/or input files")
+
+        with torch.no_grad():
+            (
+                ArmTester(
+                    llama_model,
+                    example_inputs=llama_inputs,
+                    compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
+                    constant_methods=llama_meta,
+                )
+                .quantize()
+                .export()
+                .to_edge_transform_and_lower()
+                .to_executorch()
+                .run_method_and_compare_outputs(
+                    inputs=llama_inputs,
+                    atol=4.3,
+                    rtol=1.1,  # TODO: Tolerance needs to be updated after MLETORCH-907
+                )
+            )
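
Note on the added test (commentary, not part of the patch): test_llama_tosa_BI is marked xfail for MLETORCH-907 but still calls pytest.skip() when the model artifacts are missing. Under pytest, an imperative skip raised inside an xfail-marked test is reported as SKIPPED rather than XFAIL, so machines without the Llama checkpoint do not silently mask the expected failure. The following minimal, self-contained sketch (illustrative names only, not taken from the repository) demonstrates that behavior:

import pytest


@pytest.mark.xfail(reason="placeholder for a known lowering bug (e.g. MLETORCH-907)")
def test_skip_takes_precedence_over_xfail():
    # Stand-in for the prepare_model() availability check in test_llama.py.
    inputs_available = False
    if not inputs_available:
        # Reported as SKIPPED, not XFAIL: pytest handles the Skipped
        # exception before the xfail marker is evaluated.
        pytest.skip("Missing model and/or input files")
    # If the artifacts were present, a failure here would be reported
    # as XFAIL until the underlying bug is resolved.
    assert False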