Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ _cache/*
.coverage
dist/*
build/*
_sbs_*
.eggs/*
.olive-cache/*
.hypothesis/*
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Change Logs
* :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime
* :pr:`310`: splits patches into multiple files
* :pr:`308`: add option --save_ep to dump the exported program as well as torch input
* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`: improves side-by-side comparison, creates command line sbs
* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`, :pr:`319`: improves side-by-side comparison, creates command line sbs

0.8.2
+++++
Expand Down
8 changes: 7 additions & 1 deletion _unittests/ut_helpers/test_torch_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, to_array_extended
from onnx_diagnostic.helpers.torch_helper import to_tensor
from onnx_diagnostic.helpers.torch_helper import to_tensor, study_discrepancies

TFLOAT = onnx.TensorProto.FLOAT

Expand Down Expand Up @@ -425,6 +425,12 @@ def test_get_weight_type(self):
dt = get_weight_type(model)
self.assertEqual(torch.float32, dt)

def test_study_discrepancies(self):
t1 = torch.rand((3, 4))
t2 = torch.rand((3, 4))
ax = study_discrepancies(t1, t2)
self.assertEqual(ax.shape, ((3, 2)))


if __name__ == "__main__":
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,7 @@ def get_parser_sbs() -> ArgumentParser:
"--replay-threshold",
type=float,
required=False,
default=1e6,
default=1e9,
help="Triggers the replay if the discrepancies are higher than this value.",
)
parser.add_argument(
Expand Down
74 changes: 74 additions & 0 deletions onnx_diagnostic/helpers/torch_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
import ctypes
import inspect
import math
import os
import sys
import warnings
Expand Down Expand Up @@ -1003,3 +1004,76 @@ def get_weight_type(model: torch.nn.Module) -> torch.dtype:
counts[dt] += 1
final = max(list(counts.items()))
return final[0]


def closest_factor_pair(n: int):
"""Tries to find ``a, b`` such as ``n == a * b``."""
assert n > 0, f"n={n} must be a positive integer"
start = math.isqrt(n)
for a in range(start, 0, -1):
if n % a == 0:
b = n // a
return a, b
return 1, n


def study_discrepancies(
t1: torch.Tensor,
t2: torch.Tensor,
bins: int = 50,
figsize: Optional[Tuple[int, int]] = (15, 15),
title: Optional[str] = None,
name: Optional[str] = None,
) -> "matplotlib.axes.Axes": # noqa: F821
"""
Computes different metrics for the discrepancies.
Returns graphs.
"""
assert t1.dtype == t2.dtype, f"Type mismatch {t1.dtype} != {t2.dtype}"
assert t1.shape == t2.shape, f"Shape mismatch {t1.shape} != {t2.shape}"
d1, d2 = (
(t1, t2) if t1.dtype == torch.float64 else (t1.to(torch.float32), t2.to(torch.float32))
)

d1 = d1.squeeze()
d2 = d2.squeeze()
if len(d1.shape) == 1:
new_shape = closest_factor_pair(d1.shape[0])
d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
elif len(d1.shape) > 2:
new_shape = (-1, max(d1.shape))
d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)

import matplotlib.pyplot as plt

fig, ax = plt.subplots(3, 2, figsize=figsize)
vmin, vmax = d1.min().item(), d1.max().item()
ax[0, 0].imshow(d1.detach().cpu().numpy(), cmap="Greys", vmin=vmin, vmax=vmax)
ax[0, 0].set_title(
f"Color plot of the first tensor in\n[{vmin}, {vmax}]\n{t1.shape} -> {d1.shape}"
)

diff = d2 - d1
vmin, vmax = diff.min().item(), diff.max().item()
ax[0, 1].imshow(diff.detach().cpu().numpy(), cmap="seismic", vmin=vmin, vmax=vmax)
ax[0, 1].set_title(f"Color plot of the differences in \n[{vmin}, {vmax}]")

ax[1, 0].hist(d1.detach().cpu().numpy().ravel(), bins=bins)
ax[1, 0].set_title("Distribution of the first tensor")

ax[1, 1].hist(diff.detach().cpu().numpy().ravel(), bins=bins)
ax[1, 1].set_title("Distribution of the differences")

tf1 = d1.ravel()
td1 = diff.ravel()
ax[2, 1].plot(tf1.detach().cpu().numpy(), td1.detach().cpu().numpy(), ".")
ax[2, 1].set_title("Graph XY")
ax[2, 1].set_xlabel("First tensor values")
ax[2, 1].set_ylabel("Difference values")

if title:
fig.suptitle(title)
fig.tight_layout()
if name:
fig.savefig(name)
return ax
Loading
Loading