From 3da85a811fcf96bf91a682e7b55cda3a1847d377 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 15 Apr 2025 22:29:17 +0200 Subject: [PATCH 1/4] jit --- _unittests/ut_export/test_jit.py | 90 ++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 _unittests/ut_export/test_jit.py diff --git a/_unittests/ut_export/test_jit.py b/_unittests/ut_export/test_jit.py new file mode 100644 index 00000000..07ca44af --- /dev/null +++ b/_unittests/ut_export/test_jit.py @@ -0,0 +1,90 @@ +import unittest +from typing import Callable +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout +from onnx_diagnostic.reference import ExtendedReferenceEvaluator +from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting + + +@torch.jit.script_if_tracing +def dummy_loop(padded: torch.Tensor, pos: torch.Tensor): + copy = torch.zeros(padded.shape) + for i in range(pos.shape[0]): + p = pos[i] + copy[i, :p] = padded[i, :p] + return copy + + +def wrap_for_export(f: Callable) -> Callable: + + class _wrapped(torch.nn.Module): + def __init__(self): + super().__init__() + self.f = f + + def forward(self, *args, **kwargs): + return self.f(*args, **kwargs) + + return _wrapped() + + +def select_when_exporting(mod, f): + if is_torchdynamo_exporting(): + return mod + return f + + +class TestJit(ExtTestCase): + @hide_stdout() + def test_export_loop(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.wrapped_f = wrap_for_export(dummy_loop) + + def forward(self, images, position): + return select_when_exporting(self.wrapped_f, dummy_loop)(images, position) + + model = Model() + x = torch.randn((5, 6)) + y = torch.arange(5, dtype=torch.int64) + 1 + expected = model(x, y) + + name = self.get_dump_file("test_export_loop.onnx") + torch.onnx.export( + model, + (x, y), + name, + dynamic_axes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}}, + dynamo=False, + ) + ref = ExtendedReferenceEvaluator(name) + feeds = dict(images=x.numpy(), position=y.numpy()) + got = ref.run(None, feeds)[0] + self.assertEqualArray(expected, got) + + DYN = torch.export.Dim.DYNAMIC + ep = torch.export.export( + model, + (x, y), + dynamic_shapes={"images": {0: DYN, 1: DYN}, "position": {0: DYN}}, + ) + print(ep) + + name2 = self.get_dump_file("test_export_loop.dynamo.onnx") + torch.onnx.export( + model, + (x, y), + name2, + dynamic_axes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}}, + dynamo=True, + fallback=False, + ) + ref = ExtendedReferenceEvaluator(name2) + feeds = dict(images=x.numpy(), position=y.numpy()) + got = ref.run(None, feeds)[0] + self.assertEqualArray(expected, got) + + +if __name__ == "__main__": + unittest.main(verbosity=2) From 5138b69701a116c92ac84dbb682523310baea158 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 17 Apr 2025 15:08:02 +0200 Subject: [PATCH 2/4] jit --- _unittests/ut_export/test_jit.py | 95 +++++++++++++++++++++++--------- 1 file changed, 69 insertions(+), 26 deletions(-) diff --git a/_unittests/ut_export/test_jit.py b/_unittests/ut_export/test_jit.py index 07ca44af..795977ee 100644 --- a/_unittests/ut_export/test_jit.py +++ b/_unittests/ut_export/test_jit.py @@ -1,10 +1,14 @@ import unittest -from typing import Callable import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout +from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings from onnx_diagnostic.reference import ExtendedReferenceEvaluator from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting +try: + from experimental_experiment.torch_interpreter import to_onnx +except ImportError: + to_onnx = None + @torch.jit.script_if_tracing def dummy_loop(padded: torch.Tensor, pos: torch.Tensor): @@ -15,42 +19,53 @@ def dummy_loop(padded: torch.Tensor, pos: torch.Tensor): return copy -def wrap_for_export(f: Callable) -> Callable: - - class _wrapped(torch.nn.Module): - def __init__(self): - super().__init__() - self.f = f +def dummy_loop_with_scan(padded: torch.Tensor, pos: torch.Tensor): + def pad_row(padded, p): + row = torch.zeros((padded.shape[0],)) + torch._check(p.item() > 0) + torch._check(p.item() < padded.shape[0]) + # this check is not always true, we add it anyway to make this dimension >= 2 + # and avoid raising an exception about dynamic dimension in {0, 1} + if is_torchdynamo_exporting(): + torch._check(p.item() > 1) + row[: p.item()] = padded[: p.item()] + return (row,) - def forward(self, *args, **kwargs): - return self.f(*args, **kwargs) + return torch.ops.higher_order.scan( + pad_row, + [], + [padded, pos], + [], + ) - return _wrapped() - -def select_when_exporting(mod, f): - if is_torchdynamo_exporting(): - return mod - return f +def select_when_exporting(f, f_scan): + return f_scan if is_torchdynamo_exporting() else f class TestJit(ExtTestCase): + def test_dummy_loop(self): + x = torch.randn((5, 6)) + y = torch.arange(5, dtype=torch.int64) + 1 + res = dummy_loop(x, y) + res_scan = dummy_loop_with_scan(x, y) + self.assertEqualArray(res, res_scan[0]) + @hide_stdout() - def test_export_loop(self): + @ignore_warnings(UserWarning) + def test_export_loop_onnxscript(self): class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.wrapped_f = wrap_for_export(dummy_loop) - def forward(self, images, position): - return select_when_exporting(self.wrapped_f, dummy_loop)(images, position) + return select_when_exporting(dummy_loop, dummy_loop_with_scan)( + images, position + ) model = Model() x = torch.randn((5, 6)) y = torch.arange(5, dtype=torch.int64) + 1 expected = model(x, y) - name = self.get_dump_file("test_export_loop.onnx") + name = self.get_dump_file("test_export_loop_onnxscript.onnx") torch.onnx.export( model, (x, y), @@ -68,15 +83,16 @@ def forward(self, images, position): model, (x, y), dynamic_shapes={"images": {0: DYN, 1: DYN}, "position": {0: DYN}}, + strict=False, ) - print(ep) + self.assertNotEmpty(ep) - name2 = self.get_dump_file("test_export_loop.dynamo.onnx") + name2 = self.get_dump_file("test_export_loop_onnxscript.dynamo.onnx") torch.onnx.export( model, (x, y), name2, - dynamic_axes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}}, + dynamic_shapes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}}, dynamo=True, fallback=False, ) @@ -85,6 +101,33 @@ def forward(self, images, position): got = ref.run(None, feeds)[0] self.assertEqualArray(expected, got) + @hide_stdout() + @ignore_warnings(UserWarning) + @unittest.skipIf(to_onnx is None, "missing to_onnx") + def test_export_loop_custom(self): + class Model(torch.nn.Module): + def forward(self, images, position): + return select_when_exporting(dummy_loop, dummy_loop_with_scan)( + images, position + ) + + model = Model() + x = torch.randn((5, 6)) + y = torch.arange(5, dtype=torch.int64) + 1 + expected = model(x, y) + + name2 = self.get_dump_file("test_export_loop.custom.onnx") + to_onnx( + model, + (x, y), + filename=name2, + dynamic_shapes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}}, + ) + ref = ExtendedReferenceEvaluator(name2) + feeds = dict(images=x.numpy(), position=y.numpy()) + got = ref.run(None, feeds)[0] + self.assertEqualArray(expected, got) + if __name__ == "__main__": unittest.main(verbosity=2) From 951257dd29723c4e2bef5c52c33652bdceb2cd70 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 18 Apr 2025 00:05:33 +0200 Subject: [PATCH 3/4] fix ut --- _unittests/ut_export/test_jit.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_export/test_jit.py b/_unittests/ut_export/test_jit.py index 795977ee..fb4b8082 100644 --- a/_unittests/ut_export/test_jit.py +++ b/_unittests/ut_export/test_jit.py @@ -1,6 +1,11 @@ import unittest import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + hide_stdout, + ignore_warnings, + requires_onnxscript, +) from onnx_diagnostic.reference import ExtendedReferenceEvaluator from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting @@ -53,6 +58,7 @@ def test_dummy_loop(self): @hide_stdout() @ignore_warnings(UserWarning) + @requires_onnxscript("0.4") def test_export_loop_onnxscript(self): class Model(torch.nn.Module): def forward(self, images, position): @@ -96,7 +102,9 @@ def forward(self, images, position): dynamo=True, fallback=False, ) - ref = ExtendedReferenceEvaluator(name2) + import onnxruntime + + ref = onnxruntime.InferenceSession(name2, providers=["CPUExecutionProvider"]) feeds = dict(images=x.numpy(), position=y.numpy()) got = ref.run(None, feeds)[0] self.assertEqualArray(expected, got) @@ -123,7 +131,9 @@ def forward(self, images, position): filename=name2, dynamic_shapes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}}, ) - ref = ExtendedReferenceEvaluator(name2) + import onnxruntime + + ref = onnxruntime.InferenceSession(name2, providers=["CPUExecutionProvider"]) feeds = dict(images=x.numpy(), position=y.numpy()) got = ref.run(None, feeds)[0] self.assertEqualArray(expected, got) From 78092a1f11feeff04a0280c57fe836f298a4c945 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 18 Apr 2025 00:19:14 +0200 Subject: [PATCH 4/4] ut --- _unittests/ut_export/test_jit.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/_unittests/ut_export/test_jit.py b/_unittests/ut_export/test_jit.py index fb4b8082..525eae01 100644 --- a/_unittests/ut_export/test_jit.py +++ b/_unittests/ut_export/test_jit.py @@ -1,3 +1,4 @@ +import inspect import unittest import torch from onnx_diagnostic.ext_test_case import ( @@ -15,6 +16,9 @@ to_onnx = None +has_scan_reverse = "reverse" in set(inspect.signature(torch.ops.higher_order.scan).parameters) + + @torch.jit.script_if_tracing def dummy_loop(padded: torch.Tensor, pos: torch.Tensor): copy = torch.zeros(padded.shape) @@ -36,12 +40,12 @@ def pad_row(padded, p): row[: p.item()] = padded[: p.item()] return (row,) - return torch.ops.higher_order.scan( - pad_row, - [], - [padded, pos], - [], - ) + if has_scan_reverse: + # torch==2.6 + return torch.ops.higher_order.scan( + pad_row, [], [padded, pos], additional_inputs=[], reverse=False, dim=0 + ) + return torch.ops.higher_order.scan(pad_row, [], [padded, pos], []) def select_when_exporting(f, f_scan):