diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index b0d5cae6..11fc05a1 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -1028,6 +1028,16 @@ def study_discrepancies( """ Computes different metrics for the discrepancies. Returns graphs. + + .. plot:: + :include-source: + + import torch + from onnx_diagnostic.helpers.torch_helper import study_discrepancies + + t1 = torch.randn((512, 1024)) * 10 + t2 = t1 + torch.randn((512, 1024)) + study_discrepancies(t1, t2, title="Random noise") """ assert t1.dtype == t2.dtype, f"Type mismatch {t1.dtype} != {t2.dtype}" assert t1.shape == t2.shape, f"Shape mismatch {t1.shape} != {t2.shape}"