diff --git a/_scripts/test_backend_onnxruntime.py b/_scripts/test_backend_onnxruntime.py index fb4b8c68..222df32f 100644 --- a/_scripts/test_backend_onnxruntime.py +++ b/_scripts/test_backend_onnxruntime.py @@ -5,6 +5,7 @@ import unittest import warnings from typing import Any +import packaging.version as pv import numpy import onnx.backend.base import onnx.backend.test @@ -140,6 +141,9 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): backend_test.exclude("(test_adagrad|test_adam|test_add_uint8)") +if pv.Version(onnxruntime.__version__) <= pv.Version("1.24"): + backend_test.exclude("(test_attention_4d_with|test_attention_4d_gqa)") + # import all test cases at global scope to make them visible to python.unittest globals().update(backend_test.test_cases) diff --git a/_unittests/ut_helpers/test_log_helper.py b/_unittests/ut_helpers/test_log_helper.py index a8d1981f..065e3b4d 100644 --- a/_unittests/ut_helpers/test_log_helper.py +++ b/_unittests/ut_helpers/test_log_helper.py @@ -514,13 +514,13 @@ def test_cube_sbs_no_time(self): cube = CubeLogs( df, keys=["^m_*", "exporter", "opt"], values=["time_p", "perf"], time="date" ).load() - sbs, sbs_agg = cube.sbs( + sbs, sbs_agg, _ = cube.sbs( dict(CFA=dict(exporter="E1", opt="O"), CFB=dict(exporter="E2", opt="O")) ) - self.assertEqual(sbs.shape, (4, 9)) + self.assertEqual(sbs.shape, (4, 11)) self.assertEqual(sbs.index.names, ["METRICS", "m_name", "date"]) self.assertEqual(sorted(sbs.columns.names), ["CONF", "exporter"]) - self.assertEqual(sbs_agg.shape, (2, 9)) + self.assertEqual(sbs_agg.shape, (2, 11)) self.assertEqual(sbs_agg.index.names, ["date", "METRICS"]) self.assertEqual(sorted(sbs_agg.columns.names), ["CONF", "exporter"]) @@ -604,13 +604,13 @@ def test_cube_sbs_with_time(self): cube = CubeLogs( df, keys=["^m_*", "exporter", "opt"], values=["time_p", "perf"], time="date" ).load() - sbs, sbs_agg = cube.sbs( + sbs, sbs_agg, _ = cube.sbs( dict(CFA=dict(exporter="E1", opt="O"), CFB=dict(exporter="E2", opt="O")) ) - self.assertEqual(sbs.shape, (8, 9)) + self.assertEqual(sbs.shape, (8, 11)) self.assertEqual(sbs.index.names, ["METRICS", "m_name", "date"]) self.assertEqual(sorted(sbs.columns.names), ["CONF", "exporter"]) - self.assertEqual(sbs_agg.shape, (4, 9)) + self.assertEqual(sbs_agg.shape, (4, 11)) self.assertEqual(sbs_agg.index.names, ["date", "METRICS"]) self.assertEqual(sorted(sbs_agg.columns.names), ["CONF", "exporter"]) diff --git a/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py b/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py index de727031..13d3ce38 100644 --- a/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py +++ b/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py @@ -1,6 +1,7 @@ import unittest import warnings from typing import Any +import packaging.version as pv import numpy import onnx.backend.base import onnx.backend.test @@ -9,6 +10,7 @@ from onnx import ModelProto from onnx.backend.base import Device, DeviceType from onnx.defs import onnx_opset_version +import onnxruntime from onnx_diagnostic.reference import OnnxruntimeEvaluator ORT_OPSET = max(21, onnx_opset_version() - 2) @@ -95,10 +97,12 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): dft_atol = 1e-3 stft_atol = 1e-4 ql_atol = 1e-5 +fp16_atol = 1e-3 backend_test = onnx.backend.test.BackendTest( OnnxruntimeEvaluatorBackend, __name__, test_kwargs={ + "test_attention_4d_fp16": {"atol": fp16_atol}, "test_dft": {"atol": dft_atol, "rtol": numpy.inf}, "test_dft_axis": {"atol": dft_atol, "rtol": numpy.inf}, "test_dft_axis_opset19": {"atol": dft_atol, "rtol": numpy.inf}, @@ -287,6 +291,9 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): ) backend_test.exclude(f"({exc})") +if pv.Version(onnxruntime.__version__) <= pv.Version("1.24"): + backend_test.exclude("(test_attention_4d_with|test_attention_4d_gqa)") + # import all test cases at global scope to make them visible to python.unittest globals().update(backend_test.test_cases) diff --git a/onnx_diagnostic/helpers/log_helper.py b/onnx_diagnostic/helpers/log_helper.py index 826f94eb..15e57bce 100644 --- a/onnx_diagnostic/helpers/log_helper.py +++ b/onnx_diagnostic/helpers/log_helper.py @@ -1210,7 +1210,7 @@ def to_excel( for k, v in sbs.items(): print(f"[CubeLogs.to_excel] sbs {k}: {v}") name = "∧".join(sbs) - sbs_raw, sbs_agg = self.sbs(sbs) + sbs_raw, sbs_agg, sbs_col = self.sbs(sbs) if verbose: print(f"[CubeLogs.to_excel] add sheet {name!r} with shape {sbs_raw.shape}") print( @@ -1234,6 +1234,14 @@ def to_excel( sbs_agg.index.nlevels, ), ) + sbs_col.to_excel( + writer, + sheet_name=f"{name}-COL", + freeze_panes=( + sbs_col.columns.nlevels + 1, + sbs_col.index.nlevels, + ), + ) if plots: from openpyxl.drawing.image import Image @@ -1314,7 +1322,7 @@ def cube_time(self, fill_other_dates: bool = False, threshold: float = 1.2) -> " def sbs( self, configs: Dict[str, Dict[str, Any]], column_name: str = "CONF" - ) -> Tuple[pandas.DataFrame, pandas.DataFrame]: + ) -> Tuple[pandas.DataFrame, pandas.DataFrame, pandas.DataFrame]: """ Creates a side-by-side for two configurations. Every configuration a dictionary column:value which filters in @@ -1325,7 +1333,7 @@ def sbs( :param configs: example ``dict(CFA=dict(exporter="E1", opt="O"), CFB=dict(exporter="E2", opt="O"))`` :param column_name: column to add with the name of the configuration - :return: data and aggregated date + :return: data, aggregated date, data with a row per model """ assert ( len(configs) >= 2 @@ -1433,6 +1441,8 @@ def _mkc(m, s): _mkc(m, f"{n1}<{n2}"): (si < sj).astype(int), _mkc(m, f"{n1}=={n2}"): (si == sj).astype(int), _mkc(m, f"{n1}>{n2}"): (si > sj).astype(int), + _mkc(m, f"{n1}*({n1}∧{n2})"): si * (~sinan & ~sjnan).astype(float), + _mkc(m, f"{n2}*({n1}∧{n2})"): sj * (~sinan & ~sjnan).astype(float), } ) nas.columns.names = view_res.columns.names @@ -1452,7 +1462,7 @@ def _mkc(m, s): } flat = view_res.groupby(self.time).agg(aggs) flat = flat.stack("METRICS", future_stack=True) - return res, flat + return res, flat, view_res.T.sort_index().T class CubeLogsPerformance(CubeLogs):