6 changes: 6 additions & 0 deletions _doc/api/export/index.rst
@@ -6,6 +6,7 @@ onnx_diagnostic.export
:caption: modules

dynamic_shapes
validate

CoupleInputsDynamicShapes
+++++++++++++++++++++++++
@@ -19,6 +20,11 @@ ModelInputs
.. autoclass:: onnx_diagnostic.export.ModelInputs
:members:

validate_ep
+++++++++++

.. autofunction:: onnx_diagnostic.export.validate_ep

Other functions
+++++++++++++++

8 changes: 8 additions & 0 deletions _doc/api/export/validate.rst
@@ -0,0 +1,8 @@

onnx_diagnostic.export.validate
===============================

.. automodule:: onnx_diagnostic.export.validate
:members:
:no-undoc-members:
:exclude-members: validate_ep
84 changes: 84 additions & 0 deletions _unittests/ut_export/test_validate.py
@@ -0,0 +1,84 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.export import CoupleInputsDynamicShapes, validate_ep


class TestValidate(ExtTestCase):
@hide_stdout()
def test_validate_args(self):
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y

model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y)
ds = ({0: "a", 1: "b"}, {1: "b"})
cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
validate_ep(
ep,
model,
args=(x, y),
verbose=2,
copy=True,
dynamic_shapes=ds,
values_to_try={"a": [5, 10], "b": [10, 20]},
)

@hide_stdout()
def test_validate_kwargs(self):
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y

model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x=x, y=y)
ds = dict(x={0: "a", 1: "b"}, y={1: "b"})
cpl = CoupleInputsDynamicShapes((), dict(x=x, y=y), ds)
ep = torch.export.export(
model, (), kwargs=dict(x=x, y=y), dynamic_shapes=cpl.replace_string_by()
)
validate_ep(
ep,
model,
kwargs=dict(x=x, y=y),
verbose=2,
copy=True,
dynamic_shapes=ds,
values_to_try={"a": [5, 10], "b": [10, 20]},
)

@hide_stdout()
def test_validate_args_kwargs(self):
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y

model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y)
ds = dict(x={0: "a", 1: "b"}, y={1: "b"})
cpl = CoupleInputsDynamicShapes((x,), dict(y=y), ds, args_names=["x"])
ep = torch.export.export(
model, (x,), kwargs=dict(y=y), dynamic_shapes=cpl.replace_string_by()
)
validate_ep(
ep,
model,
args=(x,),
kwargs=dict(y=y),
verbose=2,
copy=True,
dynamic_shapes=ds,
values_to_try={"a": [5, 10], "b": [10, 20]},
)


if __name__ == "__main__":
unittest.main(verbosity=2)
1 change: 1 addition & 0 deletions onnx_diagnostic/export/__init__.py
@@ -1 +1,2 @@
from .dynamic_shapes import CoupleInputsDynamicShapes, ModelInputs
from .validate import validate_ep
35 changes: 26 additions & 9 deletions onnx_diagnostic/export/dynamic_shapes.py
@@ -147,7 +147,7 @@ def _valid_shapes_tensor(cls, inputs, ds):
issues[i] = f"d=[{d}]"
return issues if issues else None

def _generic_walker(self, processor: Callable):
def _generic_walker(self, processor: Callable, args_kwargs: bool = False):
"""
Generic walker going through inputs and dynamic_shapes at the same time.
The function returns a result with the same structure as the dynamic shapes.
@@ -157,14 +157,16 @@ def _generic_walker(self, processor: Callable):
f"Type mismatch, args={string_type(self.args)} and "
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
)
return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
res = self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
return (tuple(), res) if args_kwargs else res

if not self.kwargs:
assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
f"Type mismatch, args={string_type(self.args)} and "
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
)
return self._generic_walker_step(processor, self.args, self.dynamic_shapes)
res = self._generic_walker_step(processor, self.args, self.dynamic_shapes)
return (res, {}) if args_kwargs else res

assert isinstance(self.dynamic_shapes, dict), (
f"Both positional and named arguments (args and kwargs) are filled. "
@@ -192,7 +194,17 @@ def _generic_walker(self, processor: Callable):
)
kwargs = dict(zip(self.args_names, self.args))
kwargs.update(self.kwargs)
return self._generic_walker_step(processor, kwargs, self.dynamic_shapes)
res = self._generic_walker_step(processor, kwargs, self.dynamic_shapes)
if args_kwargs:
pgs = [None for _ in range(len(self.args))]
kws = {}
for k, v in res.items():
if k not in self.kwargs:
pgs[self.args_names.index(k)] = v
else:
kws[k] = v
return pgs, kws
return res

raise NotImplementedError(
f"Not yet implemented when args is filled, "
@@ -285,14 +297,14 @@ def _build_new_tensor(self, tensor: torch.Tensor, new_shape: Tuple[int, ...]):
tuple(alt_shape), dtype=tensor.dtype, device=tensor.device
)
mind = min(d0, d1)
indices = [slice(None) for _ in range(rank)]
indices: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
indices[i] = slice(0, mind)
ind = tuple(indices)
new_tensor[ind] = tensor[ind]
if d1 > mind:
for k in range(d1 - mind):
indices0 = [slice(None) for _ in range(rank)]
indices1 = [slice(None) for _ in range(rank)]
indices0: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
indices1: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
indices1[i] = mind + k
indices0[i] = k % mind
new_tensor[tuple(indices1)] = tensor[tuple(indices0)]
@@ -310,7 +322,9 @@ def __call__(self, inputs, ds):
new_shape = self._build_new_shape(inputs.shape, ds)
return self._build_new_tensor(inputs, new_shape)

def change_dynamic_dimensions(self, desired_values: Optional[Dict[str, int]] = None):
def change_dynamic_dimensions(
self, desired_values: Optional[Dict[str, int]] = None, args_kwargs: bool = False
):
"""
A model exported with dynamic shapes is not necessarily dynamic
just because the user specified dynamic shapes. The algorithm
@@ -321,6 +335,7 @@ def change_dynamic_dimensions(self, desired_values: Optional[Dict[str, int]] = None):
the model.

:param desired_values: fixes named dimensions to the desired values
:param args_kwargs: return the pair (args, kwargs) even if one of them is empty
:return: new inputs

Example:
@@ -343,7 +358,9 @@ def change_dynamic_dimensions(self, desired_values: Optional[Dict[str, int]] = None):
print("before:", string_type(kwargs, with_shape=True))
print("-after:", string_type(new_kwargs, with_shape=True))
"""
return self._generic_walker(self.ChangeDimensionProcessor(desired_values))
return self._generic_walker(
self.ChangeDimensionProcessor(desired_values), args_kwargs=args_kwargs
)


class ModelInputs:
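For context, a minimal sketch (not part of the diff) of the new ``args_kwargs`` flag, mirroring the mixed args/kwargs unit test above; shapes in the comments are what the walker should produce under these dynamic shapes:

import torch
from onnx_diagnostic.export import CoupleInputsDynamicShapes

x = torch.randn((5, 6))
y = torch.randn((1, 6))
ds = dict(x={0: "a", 1: "b"}, y={1: "b"})
# args_names maps the positional argument back to its name in dynamic_shapes
cpl = CoupleInputsDynamicShapes((x,), dict(y=y), ds, args_names=["x"])
# with args_kwargs=True the walker returns the pair (args, kwargs):
# new_args[0] should have shape (10, 20), new_kwargs["y"] shape (1, 20)
new_args, new_kwargs = cpl.change_dynamic_dimensions(
    {"a": 10, "b": 20}, args_kwargs=True
)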
170 changes: 170 additions & 0 deletions onnx_diagnostic/export/validate.py
@@ -0,0 +1,170 @@
import inspect
import itertools
import time
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from ..helpers import string_type, max_diff, string_diff
from ..helpers.torch_test_helper import torch_deepcopy
from .dynamic_shapes import CoupleInputsDynamicShapes


def compare_modules(
modep: torch.nn.Module,
mod: Optional[torch.nn.Module] = None,
args: Optional[Tuple[Any, ...]] = None,
kwargs: Optional[Dict[str, Any]] = None,
copy: bool = False,
exc: bool = True,
verbose: int = 0,
atol: float = 1e-2,
rtol: float = 1e-1,
) -> Dict[str, Any]:
"""
Compares two torch modules, usually one coming from an exported program,
the other being the original model.

:param modep: first module, usually obtained from an exported program
:param mod: second module (it produces the expected values)
:param args: positional arguments
:param kwargs: named arguments
:param copy: copy the inputs before executing the model (they may modify them inplace)
:param exc: raise exception if discrepancies are too high
:param verbose: verbosity level
:param atol: absolute tolerance
:param rtol: relative tolerance
:return: dictionary with inputs, outputs and measured discrepancies

Example:

.. runpython::
:showcode:

import torch
from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes

class Model(torch.nn.Module):
def forward(self, x, y):
return x + y

model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y) # to make sure it runs

ds = ({0: "a", 1: "b"}, {1: "b"})
cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
validate_ep(
ep,
model,
args=(x, y),
verbose=2,
copy=True,
dynamic_shapes=ds,
values_to_try={"a": [5, 10], "b": [10, 20]},
)

"""
args = args or ()
kwargs = kwargs or {}

def _get(a):
return torch_deepcopy(a) if copy else a

if verbose:
begin = time.perf_counter()
print(
f"[compare_modules] check ep with "
f"args={string_type(args, with_shape=True)}, "
f"kwargs={string_type(kwargs, with_shape=True)}..."
)
got = modep(*_get(args), **_get(kwargs))
if verbose:
d = time.perf_counter() - begin
print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}")
if mod:
if verbose:
begin = time.perf_counter()
print("[compare_modules] run torch module...")
expected = mod(*_get(args), **_get(kwargs))
diff = max_diff(expected, got)
if verbose:
d = time.perf_counter() - begin
print(
f"[compare_modules] done in {d} with "
f"output={string_type(expected, with_shape=True)}"
)
print(f"[compare_modules] discrepancies={string_diff(diff)}")
assert not exc or (
diff["abs"] <= atol and diff["rel"] <= rtol
), f"Discrepancies={string_diff(diff)} higher than expected."
return dict(args=args, kwargs=kwargs, expected=expected, got=got, diff=diff)
return dict(args=args, kwargs=kwargs, got=got)


def validate_ep(
ep: Union[torch.nn.Module, torch.export.ExportedProgram],
mod: Optional[torch.nn.Module] = None,
args: Optional[Tuple[Any, ...]] = None,
kwargs: Optional[Dict[str, Any]] = None,
copy: bool = False,
dynamic_shapes: Optional[Any] = None,
values_to_try: Optional[Dict[str, List[int]]] = None,
exc: bool = True,
verbose: int = 0,
atol: float = 1e-2,
rtol: float = 1e-1,
) -> List[Dict[str, Any]]:
"""
Validates an exported program.

:param ep: exported program, or a module obtained from one
:param mod: second module (it produces the expected values)
:param args: positional arguments
:param kwargs: named arguments
:param copy: copy the inputs before executing the model (they may modify them inplace)
:param dynamic_shapes: dynamic shapes, use strings rather than ``torch.export.Dim``
:param values_to_try: dictionary with the values to try for every dynamic dimension
:param exc: raise exception if discrepancies are too high
:param verbose: verbosity level
:param atol: absolute tolerance
:param rtol: relative tolerance
:return: list of dictionaries with inputs, outputs and discrepancies
"""
modep = ep.module() if isinstance(ep, torch.export.ExportedProgram) else ep

results = [
compare_modules(
modep, mod, args, kwargs, copy=copy, verbose=verbose, atol=atol, rtol=rtol
)
]

assert (dynamic_shapes and values_to_try) or (
not dynamic_shapes and not values_to_try
), "Either both dynamic_shapes and values_to_try are specified, either none."
if not dynamic_shapes or not values_to_try:
return results

items = list(values_to_try.items())
keys = [_[0] for _ in items]
values = [_[1] for _ in items]
all_vals = list(itertools.product(*values))
cpl = CoupleInputsDynamicShapes(
args or (),
kwargs or {},
dynamic_shapes,
args_names=(
list(inspect.signature(modep.forward).parameters) if args and kwargs else None
),
)
for i, vals in enumerate(all_vals):
change_dims = dict(zip(keys, vals))
if verbose:
print(f"[validate_ep] try {i}/{len(all_vals)}: {change_dims}")
new_params = cpl.change_dynamic_dimensions(change_dims, args_kwargs=True)
na, nkw = new_params
c = compare_modules(
modep, mod, na, nkw, copy=copy, verbose=max(verbose - 1, 0), atol=atol, rtol=rtol
)
results.append(c)
return results
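For reference, ``validate_ep`` expands ``values_to_try`` with ``itertools.product``, so every combination of dimension values is rebuilt and compared. A small standalone sketch of the grid it iterates over (illustration only, not part of the module):

import itertools

values_to_try = {"a": [5, 10], "b": [10, 20]}
keys = list(values_to_try)
values = list(values_to_try.values())
for i, vals in enumerate(itertools.product(*values)):
    # each combination becomes the desired_values passed to
    # CoupleInputsDynamicShapes.change_dynamic_dimensions
    print(i, dict(zip(keys, vals)))
# 0 {'a': 5, 'b': 10}
# 1 {'a': 5, 'b': 20}
# 2 {'a': 10, 'b': 10}
# 3 {'a': 10, 'b': 20}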