From 47957b3c85cb683b075ac260e160f027a9ee318e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 21 Nov 2025 10:23:21 +0100 Subject: [PATCH] improve max_diff --- _doc/technical/plot_gemm_or_matmul_add.py | 4 +-- onnx_diagnostic/helpers/helper.py | 32 ++++++++++++++--------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/_doc/technical/plot_gemm_or_matmul_add.py b/_doc/technical/plot_gemm_or_matmul_add.py index 0e574552..0363133c 100644 --- a/_doc/technical/plot_gemm_or_matmul_add.py +++ b/_doc/technical/plot_gemm_or_matmul_add.py @@ -10,8 +10,8 @@ What an operator Gemm in :epkg:`onnxruntime`, the most simple way to represent a linear neural layer. -A model with three choices -========================== +A model with many choices +========================= """ import cpuinfo diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index 665954f5..b9f6c927 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -1486,19 +1486,27 @@ def max_diff( dev=dev, ) if hist: - if isinstance(hist, bool): - hist = torch.tensor( - [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype - ) - hist = hist.to(diff.device) - ind = torch.bucketize(diff.reshape((-1,)), hist, right=False) - cou = torch.bincount(ind, minlength=ind.shape[0] + 1) - res["rep"] = dict( - zip( - [f">{x}" for x in hist], - [int(i) for i in (cou.sum() - torch.cumsum(cou, 0))], + if isinstance(hist, list) and len(hist) == 1: + res["rep"] = {f">{hist[0]}": (diff > hist[0]).sum().item()} + elif isinstance(hist, list) and len(hist) == 2: + res["rep"] = { + f">{hist[0]}": (diff > hist[0]).sum().item(), + f">{hist[1]}": (diff > hist[1]).sum().item(), + } + else: + if isinstance(hist, bool): + hist = torch.tensor( + [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype + ) + hist = torch.tensor(hist).to(diff.device) + ind = torch.bucketize(diff.reshape((-1,)), hist, right=False) + cou = torch.bincount(ind, minlength=ind.shape[0] + 1) + res["rep"] = dict( + zip( + [f">{x}" for x in hist], + [int(i) for i in (cou.sum() - torch.cumsum(cou, 0))], + ) ) - ) return res # type: ignore if isinstance(expected, int) and isinstance(got, torch.Tensor):