From c51a4e64c0183120e3bc7b8f0b0cf36d8ff7f406 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 19 Jan 2024 05:51:15 +0000 Subject: [PATCH 01/12] Add support for compiling SDPAParams (#117207) Allows us to `allow_in_graph` this `torch._C` struct for supporting scaled dot product attention. helps unblock https://github.com/pytorch/pytorch/pull/116071 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117207 Approved by: https://github.com/voznesenskym --- test/dynamo/test_dynamic_shapes.py | 3 + test/dynamo/test_sdpa.py | 106 ++++++++++++++++++++++++++++ torch/_dynamo/variables/__init__.py | 2 + torch/_dynamo/variables/builder.py | 11 ++- torch/_dynamo/variables/sdpa.py | 83 ++++++++++++++++++++++ torch/_dynamo/variables/torch.py | 11 +++ torch/backends/cuda/__init__.py | 1 + 7 files changed, 216 insertions(+), 1 deletion(-) create mode 100644 test/dynamo/test_sdpa.py create mode 100644 torch/_dynamo/variables/sdpa.py diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 8231465c568ad..4bf6c78b88cf3 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -17,6 +17,7 @@ test_misc, test_modules, test_repros, + test_sdpa, test_subgraphs, ) except ImportError: @@ -28,6 +29,7 @@ import test_misc import test_modules import test_repros + import test_sdpa import test_subgraphs @@ -69,6 +71,7 @@ def make_dynamic_cls(cls): test_higher_order_ops.HigherOrderOpTests, test_higher_order_ops.FuncTorchHigherOrderOpTests, test_aot_autograd.AotAutogradFallbackTests, + test_sdpa.TestSDPA, ] for test in tests: make_dynamic_cls(test) diff --git a/test/dynamo/test_sdpa.py b/test/dynamo/test_sdpa.py new file mode 100644 index 0000000000000..ba351410a203d --- /dev/null +++ b/test/dynamo/test_sdpa.py @@ -0,0 +1,106 @@ +# Owner(s): ["module: dynamo"] +import contextlib + +import torch._dynamo.test_case +import torch._dynamo.testing +from torch._dynamo.testing import CompileCounter +from torch.backends.cuda import SDPAParams + + +@contextlib.contextmanager +def allow_in_graph_sdpa_params(): + global SDPAParams + try: + old = SDPAParams + SDPAParams = torch._dynamo.allow_in_graph(SDPAParams) + yield + finally: + SDPAParams = old + + +class TestSDPA(torch._dynamo.test_case.TestCase): + def assert_ref_equals_params(self, actual, expected): + self.assertIs(actual.query, expected.query) + self.assertIs(actual.key, expected.key) + self.assertIs(actual.value, expected.value) + self.assertIs(actual.attn_mask, expected.attn_mask) + + def test_returns_SDPAParams(self): + with allow_in_graph_sdpa_params(): + counter = CompileCounter() + + @torch.compile(fullgraph=True, backend=counter) + def fn(q, k, v, m): + return SDPAParams(q, k, v, m, 0.1, True) + + q = torch.randn(10) + k = torch.randn(10) + v = torch.randn(10) + m = torch.randn(10) + o = fn(q, k, v, m) + self.assertTrue(isinstance(o, SDPAParams)) + self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True)) + self.assertEqual(counter.frame_count, 1) + + def test_graph_break_SDPAParams(self): + with allow_in_graph_sdpa_params(): + counter = CompileCounter() + + @torch.compile(backend=counter) + def fn(q, k, v, m): + z = SDPAParams(q, k, v, m, 0.1, True) + torch._dynamo.graph_break() + return z, q + 1 + + q = torch.randn(10) + k = torch.randn(10) + v = torch.randn(10) + m = torch.randn(10) + o, _ = fn(q, k, v, m) + self.assertTrue(isinstance(o, SDPAParams)) + self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True)) + self.assertEqual(counter.frame_count, 2) + + def 
test_input_SDPAParams(self): + with allow_in_graph_sdpa_params(): + counter = CompileCounter() + + @torch.compile(backend=counter) + def fn(sdpap, q): + torch._dynamo.graph_break() + return sdpap, sdpap.query + q + + q = torch.randn(10) + k = torch.randn(10) + v = torch.randn(10) + m = torch.randn(10) + s = SDPAParams(q, k, v, m, 0.1, True) + o, _ = fn(s, q) + self.assertIs(o, s) + self.assertEqual(counter.frame_count, 1) + + def test_intermediate_attr_access_SDPAParams(self): + with allow_in_graph_sdpa_params(): + counter = CompileCounter() + + @torch.compile(fullgraph=True, backend=counter) + def fn(q, k, v, m): + q += 1 + z = SDPAParams(q, k, v, m, 0.1, True) + a = z.query + return a + 1, z, q + + q = torch.randn(10) + k = torch.randn(10) + v = torch.randn(10) + m = torch.randn(10) + _, o, _ = fn(q, k, v, m) + expected = SDPAParams(q, k, v, m, 0.1, True) + self.assert_ref_equals_params(o, expected) + self.assertEqual(counter.frame_count, 1) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 2c1af0c9c2e10..811d0fce8e1b4 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -62,6 +62,7 @@ UnknownVariable, ) from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable +from .sdpa import SDPAParamsVariable from .tensor import ( FakeItemVariable, NumpyNdarrayVariable, @@ -127,4 +128,5 @@ "UserMethodVariable", "VariableTracker", "WithExitFunctionVariable", + "SDPAParamsVariable", ] diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index ce101fa9c76c4..bd192e197dc97 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -138,9 +138,10 @@ SkipFilesVariable, TypingVariable, ) - from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable from .optimizer import OptimizerVariable + +from .sdpa import SDPAParamsVariable from .tensor import ( NumpyNdarrayVariable, SymNodeVariable, @@ -586,6 +587,9 @@ def build_key_value(k, v): value.device, source=self.source, ) + elif isinstance(value, torch._C._SDPAParams): + self.install_guards(GuardBuilder.TYPE_MATCH) + return SDPAParamsVariable.create(self.tx, value, self.source) elif isinstance(value, _EventBase): self.install_guards(GuardBuilder.ID_MATCH) return EventVariable( @@ -1513,6 +1517,11 @@ def _clone_input(value): ]: proxy.node.meta["example_value"] = example_value return ConstantVariable.create(example_value, **options) + elif isinstance(example_value, torch.backends.cuda.SDPAParams): + from .sdpa import SDPAParamsVariable + + proxy.node.meta["example_value"] = example_value + return SDPAParamsVariable(proxy, **options) else: unimplemented( "torch.* op returned non-Tensor " diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py new file mode 100644 index 0000000000000..19d02c2a4905d --- /dev/null +++ b/torch/_dynamo/variables/sdpa.py @@ -0,0 +1,83 @@ +from inspect import getattr_static + +from ..bytecode_transformation import create_call_function +from ..exc import Unsupported +from .base import VariableTracker + + +class SDPAParamsVariable(VariableTracker): + """Represents the c++ params struct for scaled dot product attention. 
+ This is a read-only container.""" + + @staticmethod + def create(tx, value, source): + from torch.backends.cuda import SDPAParams + from ..source import AttrSource + from .builder import VariableBuilder + from .torch import TorchInGraphFunctionVariable + + query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query) + key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key) + value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value) + attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))( + value.attn_mask + ) + dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout) + is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))( + value.is_causal + ) + param_vars = [ + query_var, + key_var, + value_var, + attn_mask_var, + dropout_var, + is_causal_var, + ] + return TorchInGraphFunctionVariable(SDPAParams).call_function( + tx, param_vars, {} + ) + + def __init__(self, proxy, param_vars, **kwargs): + self.proxy = proxy + self.param_vars = param_vars + super().__init__(**kwargs) + + def reconstruct(self, codegen): + assert self.source is None + assert self.param_vars is not None + codegen.load_import_from("torch._C", "_SDPAParams") + for var in self.param_vars: + codegen(var) + return create_call_function(len(self.param_vars), True) + + def as_proxy(self): + return self.proxy + + def var_getattr(self, tx, name: str) -> VariableTracker: + import torch._C + from ..source import AttrSource + from .builder import wrap_fx_proxy + from .misc import GetAttrVariable + + try: + getattr_static(torch._C._SDPAParams, name) + except AttributeError: + # Using raise from is too verbose here + raise Unsupported( # noqa: TRY200 + f"Unsupported torch._C._SDPAParams attribute {name}" + ) + + proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) + if self.source is not None: + return wrap_fx_proxy( + tx=tx, proxy=proxy, source=AttrSource(self.source, name) + ) + else: + return wrap_fx_proxy(tx=tx, proxy=proxy) + + @staticmethod + def is_sdpa_params(value): + from torch.backends.cuda import SDPAParams + + return value is SDPAParams diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 07c97d0ab0302..7547945e70249 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -209,6 +209,7 @@ def call_function( DeterministicAlgorithmsVariable, DisabledSavedTensorsHooksVariable, GradModeVariable, + SDPAParamsVariable, StreamContextVariable, SymNodeVariable, TensorVariable, @@ -447,6 +448,16 @@ def call_function( ) ): return ConstantVariable(None) + elif SDPAParamsVariable.is_sdpa_params(self.value): + return wrap_fx_proxy( + tx, + proxy=tx.output.create_proxy( + "call_function", + torch._C._SDPAParams, + *proxy_args_kwargs(args, kwargs), + ), + param_vars=args, + ) elif is_constant_pg_functions(self.value): # becuase the input is a "ProcessGroupVariable", we'll be guarding on its # ID_MATCH based on how it was constructed. 
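# --- Editor's note: illustrative usage, not part of the diff above -----------
# A minimal sketch of what this patch enables, adapted from the PR's own
# test/dynamo/test_sdpa.py: once SDPAParams is wrapped with allow_in_graph,
# torch.compile can construct and return the struct without a graph break.
import torch
from torch.backends.cuda import SDPAParams

SDPAParams = torch._dynamo.allow_in_graph(SDPAParams)

@torch.compile(fullgraph=True)
def make_params(q, k, v, m):
    # (query, key, value, attn_mask, dropout, is_causal), as in the new tests
    return SDPAParams(q, k, v, m, 0.1, True)

q, k, v, m = (torch.randn(10) for _ in range(4))
params = make_params(q, k, v, m)  # fullgraph=True would raise on a graph break
# ------------------------------------------------------------------------------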
diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index 77ce755d14a98..7fa4e9e669ec0 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -209,6 +209,7 @@ def preferred_linalg_library( # Set the __module__ attribute SDPBackend.__module__ = "torch.backends.cuda" SDPAParams.__module__ = "torch.backends.cuda" +SDPAParams.__name__ = "SDPAParams" def flash_sdp_enabled(): From 4057d005ffff1ee848af113c21c62af262fceacc Mon Sep 17 00:00:00 2001 From: suo Date: Thu, 18 Jan 2024 17:34:43 -0800 Subject: [PATCH 02/12] Initial torchbind support in PT2 (#117697) This PR adds the bare minimum functionality to get torchbind working in an e2e testable way on PT2. It implements: * ProxyTensor support * Simple torch.export support (proxytensor-only path, e.g. non-strict). * add some tests exercising the path. Because all this is not fully baked, I hide the functionality behind a feature flag (`enable_torchbind_tracing()`) so it does not affect regular users for now. Still on the agenda: * Dynamo support * Actual FakeMode support * Mutability support Hoping to get this first bit in as a standalone, as it will unblock some more extensive experimentation/testing going on internally. Differential Revision: [D51825372](https://our.internmc.facebook.com/intern/diff/D51825372/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/117697 Approved by: https://github.com/SherlockNoMad --- docs/source/conf.py | 2 +- test/allowlist_for_publicAPI.json | 2 +- .../jit/test_custom_class_registrations.cpp | 26 ++++- test/export/test_torchbind.py | 108 ++++++++++++++++++ torch/_dynamo/compiled_autograd.py | 4 +- torch/_export/non_strict_utils.py | 8 ++ torch/_export/passes/lift_constants_pass.py | 24 ++++ .../collect_metadata_analysis.py | 2 +- torch/_functorch/_aot_autograd/utils.py | 4 +- torch/_higher_order_ops/torchbind.py | 94 +++++++++++++++ torch/distributed/_spmd/comm_tensor.py | 4 +- torch/export/_trace.py | 12 +- torch/export/graph_signature.py | 9 +- torch/fx/experimental/proxy_tensor.py | 19 ++- 14 files changed, 296 insertions(+), 22 deletions(-) create mode 100644 test/export/test_torchbind.py create mode 100644 torch/_higher_order_ops/torchbind.py diff --git a/docs/source/conf.py b/docs/source/conf.py index f2388219fa6aa..3ece6fcc65a4a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -896,7 +896,7 @@ "extract_val", "fake_signature", "fetch_sym_proxy", - "fetch_tensor_proxy", + "fetch_object_proxy", "get_innermost_proxy_mode", "get_isolated_graphmodule", "get_proxy_slot", diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 579c4a49956c4..196df6175fdf8 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -1930,7 +1930,7 @@ "extract_val", "fake_signature", "fetch_sym_proxy", - "fetch_tensor_proxy", + "fetch_object_proxy", "get_isolated_graphmodule", "get_proxy_slot", "get_torch_dispatch_modes", diff --git a/test/cpp/jit/test_custom_class_registrations.cpp b/test/cpp/jit/test_custom_class_registrations.cpp index 3982407f19660..cec3db025e743 100644 --- a/test/cpp/jit/test_custom_class_registrations.cpp +++ b/test/cpp/jit/test_custom_class_registrations.cpp @@ -46,6 +46,9 @@ struct Foo : torch::CustomClassHolder { int64_t add(int64_t z) { return (x + y) * z; } + at::Tensor add_tensor(at::Tensor z) { + return (x + y) * z; + } void increment(int64_t z) { this->x += z; this->y += z; @@ -317,8 +320,18 @@ TORCH_LIBRARY(_TorchScriptTesting, m) { .def("info", 
&Foo::info) .def("increment", &Foo::increment) .def("add", &Foo::add) + .def("add_tensor", &Foo::add_tensor) .def("__eq__", &Foo::eq) - .def("combine", &Foo::combine); + .def("combine", &Foo::combine) + .def_pickle( + [](c10::intrusive_ptr self) { // __getstate__ + return std::vector{self->x, self->y}; + }, + [](std::vector state) { // __setstate__ + return c10::make_intrusive(state[0], state[1]); + }); + m.def( + "takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor"); m.class_("_FooGetterSetter") .def(torch::init()) @@ -436,4 +449,15 @@ TORCH_LIBRARY(_TorchScriptTesting, m) { }); } +at::Tensor takes_foo(c10::intrusive_ptr foo, at::Tensor x) { + return foo->add_tensor(x); +} + +TORCH_LIBRARY_IMPL(_TorchScriptTesting, CPU, m) { + m.impl("takes_foo", takes_foo); +} +TORCH_LIBRARY_IMPL(_TorchScriptTesting, Meta, m) { + m.impl("takes_foo", &takes_foo); +} + } // namespace diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py new file mode 100644 index 0000000000000..9d63c65dbe937 --- /dev/null +++ b/test/export/test_torchbind.py @@ -0,0 +1,108 @@ +# Owner(s): ["module: export"] +import unittest + +import torch +from torch._higher_order_ops.torchbind import enable_torchbind_tracing +from torch.export import export +from torch.testing._internal.common_utils import ( + find_library_location, + IS_FBCODE, + IS_MACOS, + IS_SANDCASTLE, + IS_WINDOWS, + run_tests, + skipIfTorchDynamo, + TestCase, +) + + +@skipIfTorchDynamo("torchbind not supported with dynamo yet") +class TestExportTorchbind(TestCase): + def setUp(self): + if IS_MACOS: + raise unittest.SkipTest("non-portable load_library call used in test") + elif IS_SANDCASTLE or IS_FBCODE: + torch.ops.load_library( + "//caffe2/test/cpp/jit:test_custom_class_registrations" + ) + elif IS_WINDOWS: + lib_file_path = find_library_location("torchbind_test.dll") + torch.ops.load_library(str(lib_file_path)) + else: + lib_file_path = find_library_location("libtorchbind_test.so") + torch.ops.load_library(str(lib_file_path)) + + def _test_export_same_as_eager(self, f, args, kwargs=None, strict=True): + kwargs = kwargs or {} + with enable_torchbind_tracing(): + exported_program = export(f, args, kwargs, strict=strict) + reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)} + self.assertEqual(exported_program(*args, **kwargs), f(*args, **kwargs)) + self.assertEqual( + exported_program(*args, **reversed_kwargs), f(*args, **reversed_kwargs) + ) + + def test_none(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) + + def forward(self, x, n): + return x + self.attr.add_tensor(x) + + self._test_export_same_as_eager( + MyModule(), (torch.ones(2, 3), None), strict=False + ) + + def test_attribute(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) + + def forward(self, x): + return x + self.attr.add_tensor(x) + + self._test_export_same_as_eager(MyModule(), (torch.ones(2, 3),), strict=False) + + def test_attribute_as_custom_op_argument(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) + + def forward(self, x): + return x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x) + + self._test_export_same_as_eager(MyModule(), (torch.ones(2, 3),), strict=False) + + def test_input(self): + class MyModule(torch.nn.Module): + def 
__init__(self): + super().__init__() + + def forward(self, x, cc): + return x + cc.add_tensor(x) + + cc = torch.classes._TorchScriptTesting._Foo(10, 20) + self._test_export_same_as_eager( + MyModule(), (torch.ones(2, 3), cc), strict=False + ) + + def test_input_as_custom_op_argument(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, cc): + return x + torch.ops._TorchScriptTesting.takes_foo(cc, x) + + cc = torch.classes._TorchScriptTesting._Foo(10, 20) + self._test_export_same_as_eager( + MyModule(), (torch.ones(2, 3), cc), strict=False + ) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 0da712d2e996f..3da03b8beab51 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -14,7 +14,7 @@ decompose, disable_autocast_cache, disable_proxy_modes_tracing, - fetch_tensor_proxy, + fetch_object_proxy, ProxyTorchDispatchMode, PythonKeyTracer, track_tensor_tree, @@ -211,7 +211,7 @@ def to_proxy(self, t): if isinstance(t, tuple): return tuple(self.to_proxy(x) for x in t) assert isinstance(t, (torch.Tensor, torch.SymInt)) - return fetch_tensor_proxy(self.fx_tracer)(t).proxy + return fetch_object_proxy(self.fx_tracer)(t).proxy def bind_tensors_to_proxies(self, tensors, proxies): if isinstance(proxies, torch.fx.Proxy): diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index c200f306a7fe7..eefea7fc88aef 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -14,6 +14,7 @@ from torch._guards import Source from torch._subclasses.fake_tensor import FakeTensorMode from torch.export import Constraint +from torch.export.custom_obj import ScriptObjectMeta from torch.fx.experimental.symbolic_shapes import ( ConstraintViolationError, DimDynamic, @@ -29,6 +30,11 @@ def fakify(mode, t, t_constraints, source, sources): by tensor ids, the source for the tensor, and an accumulator mapping tensor dimensions to their sources. """ + if t is None or isinstance(t, torch.ScriptObject): + return t + if not isinstance(t, torch.Tensor): + raise ValueError("Only tensors allowed as input") + n_dims = len(t.shape) symbolic_context = StatelessSymbolicContext( dynamic_sizes=[DimDynamic.STATIC] * n_dims, @@ -151,6 +157,8 @@ def make_constraints(fake_mode, src_equalities, original_signature, gm): for node in gm.graph.nodes: if node.op != "placeholder": continue + if node.meta["val"] is None or isinstance(node.meta["val"], ScriptObjectMeta): + continue for i, d in enumerate(node.meta["val"].shape): if isinstance(d, torch.SymInt): range_constraints[d.node.expr] = shape_env.var_to_range[d.node.expr] diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 0b0ea262f383f..468a5a60bb33a 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -101,3 +101,27 @@ def lift_constants_pass( first_user_input_loc += 1 return all_constants + + +def rewrite_script_object_meta( + gm: torch.fx.GraphModule, +) -> Dict[str, Union[torch.Tensor, torch.ScriptObject]]: + """When tracing, we produce a graph with an actual ScriptObject in the + meta["val"]. Eventually we want to change this behavior, when FakeMode infra + for ScriptObjects lands. + + For now, we rewrie meta["val"] to be a placeholder ScriptObjectMeta. 
+ """ + constants: Dict[str, Union[torch.Tensor, torch._C.ScriptObject]] = {} + for node in gm.graph.nodes: + if "val" not in node.meta or not isinstance( + node.meta["val"], torch.ScriptObject + ): + continue + + old_meta = node.meta["val"] + new_meta = ScriptObjectMeta(node.name) + constants[node.name] = old_meta + node.meta["val"] = new_meta + + return constants diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index bda600bf09281..ce6181dcb6d35 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -92,7 +92,7 @@ def _to_fun(t): @wraps(f) def inner(*flat_args): # This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args. - assert all(isinstance(a, KNOWN_TYPES) for a in flat_args) + assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args) input_info: List[InputAliasInfo] = [] output_info: List[OutputAliasInfo] = [] diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index d3dc9dbc09834..bcf79c64b10dc 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -12,9 +12,7 @@ import torch.utils._pytree as pytree from torch.fx.experimental.proxy_tensor import py_sym_types -KNOWN_TYPES = tuple( - [torch.Tensor, int, str, float, bool, type(None)] + list(py_sym_types) -) +KNOWN_TYPES = [torch.Tensor, int, str, float, bool, type(None)] + list(py_sym_types) original_zip = zip diff --git a/torch/_higher_order_ops/torchbind.py b/torch/_higher_order_ops/torchbind.py new file mode 100644 index 0000000000000..bb880dde33d25 --- /dev/null +++ b/torch/_higher_order_ops/torchbind.py @@ -0,0 +1,94 @@ +from contextlib import contextmanager + +import torch +from torch._C import DispatchKey # @manual +from torch._functorch._aot_autograd.utils import KNOWN_TYPES +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.fx.node import has_side_effect +from torch.utils import _pytree as pytree + +# The call_torchbind operator represents a method invocation on a torchbind +# object. The calling convention is: +# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs) +# We do not expect users to write this operator directly. Instead it will be +# emitted by Dynamo when tracing encounters a torchbind object. +call_torchbind = HigherOrderOperator("call_torchbind") + +# Register this operator as side-effectful with FX. +# TODO: this is not really sufficient. While passes (hopefully) check +# Node.is_impure() and make good decisions, we also assume we can execute the +# graph as many times as we want without changing behavior, which is NOT true of +# ops that mutate torchbind object state. +has_side_effect(call_torchbind) + +_orig_scriptmethod_call = torch.ScriptMethod.__call__ + + +def torchbind_method_redispatch(self, *args, **kwargs): + if isinstance(self.owner, torch.ScriptObject): + return call_torchbind(self.owner, self.name, *args, **kwargs) + return _orig_scriptmethod_call(self, *args, **kwargs) + + +@contextmanager +def enable_torchbind_tracing(): + """Context manager that acts as a feature flag to enable torchbind tracing + behavior. 
Once torchbind tracing has been stabilized, we can remove this and + turn it always on. + """ + try: + KNOWN_TYPES.append(torch.ScriptObject) + torch.ScriptMethod.__call__ = torchbind_method_redispatch # type: ignore[method-assign] + yield + finally: + assert ( + KNOWN_TYPES.pop() is torch.ScriptObject + ), "Someone else messed with KNOWN_TYPES during tracing, exploding." + torch.ScriptMethod.__call__ = _orig_scriptmethod_call # type: ignore[method-assign] + + +@call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd) +def call_torchbind_impl(obj, method, *args, **kwargs): + return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs) + + +@call_torchbind.py_impl(ProxyTorchDispatchMode) +def inner(mode, *args, **kwargs): + if mode.enable_tracing: + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + + out_proxy = mode.tracer.create_proxy( + "call_function", + call_torchbind, + proxy_args, + proxy_kwargs, + ) + out = call_torchbind_impl(*args, **kwargs) + + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + else: + return call_torchbind(*args, **kwargs) + + +# TODO: currently we just run the C++ implementation with fake tensors. +# But we should make it possible to register a fake torchbind implementation. +@call_torchbind.py_impl(FakeTensorMode) +def call_torchbind_fake(mode, *args, **kwargs): + with mode: + return call_torchbind_impl(*args, **kwargs) + + +call_torchbind.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(call_torchbind, deferred_error=True) +) + + +@call_torchbind.py_functionalize_impl +def call_torchbind_func(ctx, *args, **kwargs): + args = ctx.unwrap_tensors(args) + with ctx.redispatch_to_next(): + return ctx.wrap_tensors(call_torchbind(*args, **kwargs)) diff --git a/torch/distributed/_spmd/comm_tensor.py b/torch/distributed/_spmd/comm_tensor.py index 0f29b9a809e94..292f5b2508612 100644 --- a/torch/distributed/_spmd/comm_tensor.py +++ b/torch/distributed/_spmd/comm_tensor.py @@ -6,7 +6,7 @@ from torch._C import _disabled_torch_function_impl from torch.fx.experimental.proxy_tensor import ( _ProxyTensor, - fetch_tensor_proxy, + fetch_object_proxy, get_innermost_proxy_mode, get_proxy_slot, set_proxy_slot, @@ -193,7 +193,7 @@ def set_work(work: torch.distributed._Work, e: Any): lambda e: e.proxy, tree_map_only( torch.Tensor, - fetch_tensor_proxy(tracer), + fetch_object_proxy(tracer), (unwrapped_args, unwrapped_kwargs), ), ) diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 48e969b0b0b75..cd1d00ee514eb 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -18,7 +18,10 @@ _AddRuntimeAssertionsForInlineConstraintsPass, ) from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass -from torch._export.passes.lift_constants_pass import lift_constants_pass +from torch._export.passes.lift_constants_pass import ( + lift_constants_pass, + rewrite_script_object_meta, +) from torch._export.wrappers import _wrap_submodules from torch._functorch.aot_autograd import aot_export_module, GraphSignature from torch._guards import detect_fake_mode @@ -452,9 +455,10 @@ def make_argument_spec(node) -> ArgumentSpec: input_specs=input_specs, output_specs=output_specs ) - constants: Dict[ - str, Union[torch.Tensor, torch._C.ScriptObject] - ] = lift_constants_pass(gm, export_graph_signature) + constants = rewrite_script_object_meta(gm) + more_constants = lift_constants_pass(gm, export_graph_signature) + for k, v in 
more_constants.items(): + constants[k] = v @dataclasses.dataclass class _ExportedProgramNonStrict: diff --git a/torch/export/graph_signature.py b/torch/export/graph_signature.py index fff7fa4325d72..4deb057d10a6d 100644 --- a/torch/export/graph_signature.py +++ b/torch/export/graph_signature.py @@ -5,6 +5,7 @@ __all__ = [ "ConstantArgument", + "CustomObjArgument", "ExportBackwardSignature", "ExportGraphSignature", "InputKind", @@ -28,13 +29,13 @@ class SymIntArgument: @dataclasses.dataclass -class ConstantArgument: - value: Union[int, float, bool, None] +class CustomObjArgument: + name: str @dataclasses.dataclass -class CustomObjArgument: - name: str +class ConstantArgument: + value: Union[int, float, bool, None] ArgumentSpec = Union[ diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 92fcc5ff53ff3..18714c77b7517 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -36,7 +36,7 @@ from torch.fx import Proxy import torch.fx.traceback as fx_traceback from torch import SymInt, SymFloat, SymBool -from torch.utils.weak import WeakTensorKeyDictionary +from torch.utils.weak import WeakTensorKeyDictionary, WeakIdKeyDictionary, _WeakHashRef from torch._ops import unset_mode_pre_dispatch, _set_mode_pre_dispatch, _get_dispatch_mode_pre_dispatch __all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"] @@ -166,6 +166,9 @@ def set_proxy_slot(obj, tracer, proxy): # We DO want to clobber proxies whenever we run an inplace operation # on a tensor, and it affects the metadata on the proxy. tracer.tensor_tracker[obj] = proxy + elif isinstance(obj, torch.ScriptObject): + # We DO want to clobber proxies, with a similar rationale as for tensors. + tracer.script_object_tracker[obj] = proxy else: # NB: Never clobber pre-existing proxy. 
Although the proxies # are in principle equivalent, when we do graph partitioning @@ -187,6 +190,8 @@ def has_proxy_slot(obj, tracer): def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x): if isinstance(obj, torch.Tensor): tracker = tracer.tensor_tracker + elif isinstance(obj, torch.ScriptObject): + tracker = tracer.script_object_tracker else: assert isinstance(obj, py_sym_types), type(obj) tracker = tracer.symnode_tracker @@ -205,6 +210,8 @@ def extract_val(val): return snapshot_fake(val) elif isinstance(val, py_sym_types): return val + elif isinstance(val, torch.ScriptObject): + return val elif isinstance(val, (list, tuple)): return val.__class__([extract_val(x) for x in val]) elif isinstance(val, torch.Tensor): @@ -302,6 +309,9 @@ def wrap_with_proxy(e, proxy, constant): # NB: eagerly set meta here, so that the numbering is in order set_meta(proxy, e) set_proxy_slot(e, tracer, lambda: proxy) + elif isinstance(e, torch.ScriptObject): + set_proxy_slot(e, tracer, proxy) + set_meta(proxy, e) elif isinstance(e, (tuple, list)): if isinstance(proxy, fx.Proxy): set_meta(proxy, e) @@ -386,7 +396,7 @@ def inner(e): return inner -def fetch_tensor_proxy(tracer): +def fetch_object_proxy(tracer): return lambda t: get_proxy_slot(t, tracer, t) HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter, FakeTensor) @@ -422,7 +432,7 @@ def can_handle_tensor(x): return r tracer = proxy_mode.tracer - f_args, f_kwargs = pytree.tree_map_only(torch.Tensor, fetch_tensor_proxy(tracer), (args, kwargs)) + f_args, f_kwargs = pytree.tree_map_only((torch.Tensor, torch.ScriptObject), fetch_object_proxy(tracer), (args, kwargs)) # If there are SymInts, we also should not consider this constant. # However, fake tensor handling of SymInts is sufficiently broken that @@ -567,6 +577,7 @@ def __init__(self): super().__init__(autowrap_modules=()) self.tensor_tracker = WeakTensorKeyDictionary() self.symnode_tracker = _SymHashingDict() # type: ignore[var-annotated] + self.script_object_tracker = WeakIdKeyDictionary(dict=None, ref_type=_WeakHashRef) # In general, we don't want to make modules leaves. 
In principle, users of # this tracer might want to override this in order to turn a couple specific @@ -607,6 +618,8 @@ def unwrap_proxy(self, e): return get_proxy_slot(e, self, e, lambda e: e.proxy) elif isinstance(e, (torch.SymInt, torch.SymFloat, torch.SymBool)): return get_proxy_slot(e, self, e, lambda e: e()) + elif isinstance(e, torch.ScriptObject): + return get_proxy_slot(e, self, e) else: return e From f2d6e99f8dd35e297ebaa5bc5d06a37422dc930b Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 19 Jan 2024 10:04:41 +0000 Subject: [PATCH 03/12] Workaround a cusolver bug on CUDA < 12.1 in triangular_solve (#117636) Fix https://github.com/pytorch/pytorch/issues/79191 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117636 Approved by: https://github.com/malfet --- .../cuda/linalg/BatchLinearAlgebraLibBlas.cpp | 14 ++++++++++++++ test/test_linalg.py | 9 +++++++++ 2 files changed, 23 insertions(+) diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp index 38e7b8dd3288b..d882791b58baa 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp @@ -218,6 +218,20 @@ static void apply_triangular_solve_batched(const Tensor& A, const Tensor& B, boo } void triangular_solve_batched_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) { + // Workaround the following a bug on CUDA < 12.1 + // RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasStrsmBatched + // See https://github.com/pytorch/pytorch/issues/79191#issuecomment-1154222580 +#if defined(CUSOLVER_VERSION) && CUSOLVER_VERSION < 12100 + constexpr auto max_batch_size = 524280; + if (B.size(-1) > max_batch_size) { + auto n_chunks = (B.size(-1) + max_batch_size - 1) / max_batch_size; // ceildiv + auto splits = B.split(n_chunks, /*dim=*/-1); + for (const Tensor& b : splits) { + triangular_solve_batched_cublas(A, b, left, upper, transpose, unitriangular); + } + return; + } +#endif AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{ apply_triangular_solve_batched(A, B, left, upper, transpose, unitriangular); }); diff --git a/test/test_linalg.py b/test/test_linalg.py index b25ee8379d6bf..6fd7a7e97c5bc 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -4287,6 +4287,15 @@ def run_test(A_dims, b_dims, device, upper, transpose, unitriangular): run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular) # broadcasting A run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular) # broadcasting A & b + @onlyCUDA + @dtypes(torch.float) + def test_triangular_solve_large(self, device, dtype): + # Repro for https://github.com/pytorch/pytorch/issues/79191 + A = torch.randn(1, 2, 2, device=device, dtype=dtype).tril_() + B = torch.randn(1, 2, 524281, device=device, dtype=dtype) + X = torch.linalg.solve_triangular(A, B, upper=False) + self.assertEqual(A @ X, B) + @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(*floating_and_complex_types()) From 5756b7a08e00292330500f5f237784f178dd4b99 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 19 Jan 2024 12:56:17 +0000 Subject: [PATCH 04/12] Remove math_compat.h (#117828) Follows #116167 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117828 Approved by: https://github.com/malfet --- c10/util/math_compat.h | 256 ----------------------------------------- 1 file 
changed, 256 deletions(-) delete mode 100644 c10/util/math_compat.h diff --git a/c10/util/math_compat.h b/c10/util/math_compat.h deleted file mode 100644 index 39cc5b04c0812..0000000000000 --- a/c10/util/math_compat.h +++ /dev/null @@ -1,256 +0,0 @@ -#pragma once - -#include - -// Android NDK platform < 21 with libstdc++ has spotty C++11 support. -// Various hacks in this header allow the rest of the codebase to use -// standard APIs. -#if (defined(__ANDROID__) && __ANDROID_API__ < 21 && defined(__GLIBCXX__)) || \ - defined(__NEWLIB__) -#include - -namespace std { -// Import double versions of these functions from the global namespace. -using ::acosh; -using ::asinh; -using ::atanh; -using ::erf; -using ::erfc; -using ::expm1; -using ::lgamma; -using ::log1p; -using ::nearbyint; -using ::round; -using ::tgamma; -using ::trunc; -using ::truncf; - -// Define float versions the same way as more recent libstdc++ -inline float acosh(float x) { - return __builtin_acoshf(x); -} -inline float asinh(float x) { - return __builtin_asinhf(x); -} -inline float atanh(float x) { - return __builtin_atanhf(x); -} -inline float copysign(float x, float y) { - return __builtin_copysignf(x, y); -} -inline float erf(float x) { - return __builtin_erff(x); -} -inline float erfc(float x) { - return __builtin_erfcf(x); -} -inline float expm1(float x) { - return __builtin_expm1f(x); -} -inline float fmax(float x, float y) { - return __builtin_fmaxf(x, y); -} -inline float fmin(float x, float y) { - return __builtin_fminf(x, y); -} -inline float lgamma(float x) { - return __builtin_lgammaf(x); -} -inline float log1p(float x) { - return __builtin_log1pf(x); -} -inline float nearbyint(float x) { - return __builtin_nearbyintf(x); -} -inline float remainder(float x, float y) { - return __builtin_remainderf(x, y); -} -inline float round(float x) { - return __builtin_roundf(x); -} -inline float tgamma(float x) { - return __builtin_tgammaf(x); -} -inline float trunc(float x) { - return __builtin_truncf(x); -} - -// __builtin_nexttoward isn't doesn't work. It appears to try to -// link against the global nexttoward function, which is not present -// prior to API 18. Just bail for now. -inline float nexttoward(float x, long double y) { - throw std::runtime_error("std::nexttoward is not present on older Android"); -} -inline double nexttoward(double x, long double y) { - throw std::runtime_error("std::nexttoward is not present on older Android"); -} - -#if !defined(__NEWLIB__) -// TODO: this function needs to be implemented and tested. Currently just throw -// an error. -inline float hypot(float x, float y) { - throw std::runtime_error("std::hypot is not implemented on older Android"); -} -inline double hypot(double x, double y) { - throw std::runtime_error("std::hypot is not implemented on older Android"); -} -#else -inline float hypot(float x, float y) { - return hypot((double)x, (double)y); -} -#endif - -// TODO: this function needs to be implemented and tested. Currently just throw -// an error. 
-inline float igamma(float x, float y) { - throw std::runtime_error("igamma is not implemented on older Android"); -} -inline double igamma(double x, double y) { - throw std::runtime_error("igamma is not implemented on older Android"); -} -inline float igammac(float x, float y) { - throw std::runtime_error("igammac is not implemented on older Android"); -} -inline double igammac(double x, double y) { - throw std::runtime_error("igammac is not implemented on older Android"); -} - -// Note: std::signbit returns true for negative zero (-0), but this -// implementation returns false. -inline bool signbit(float x) { - return x < 0; -} -inline bool signbit(double x) { - return x < 0; -} -inline bool signbit(long double x) { - return x < 0; -} - -#if !defined(__NEWLIB__) -// TODO: this function needs to be implemented and tested. Currently just throw -// an error. -inline float nextafter(float x, float y) { - throw std::runtime_error( - "std::nextafter is not implemented on older Android"); -} -inline double nextafter(double x, double y) { - throw std::runtime_error( - "std::nextafter is not implemented on older Android"); -} -#else -inline float nextafter(float x, float y) { - return nextafter((double)x, (double)y); -} -#endif - -#if !defined(__NEWLIB__) -// TODO: this function needs to be implemented and tested. Currently just throw -// an error. -inline float exp2(float x) { - throw std::runtime_error("std::exp2 is not implemented on older Android"); -} -inline double exp2(double x) { - throw std::runtime_error("std::exp2 is not implemented on older Android"); -} -#else -inline float exp2(float x) { - return exp2((double)x); -} -#endif - -// Define integral versions the same way as more recent libstdc++ -template -typename std::enable_if::value, double>::type acosh(T x) { - return __builtin_acosh(x); -} -template -typename std::enable_if::value, double>::type asinh(T x) { - return __builtin_asinh(x); -} -template -typename std::enable_if::value, double>::type atanh(T x) { - return __builtin_atanh(x); -} -template -typename std::enable_if::value, double>::type erf(T x) { - return __builtin_erf(x); -} -template -typename std::enable_if::value, double>::type erfc(T x) { - return __builtin_erfc(x); -} -template -typename std::enable_if::value, double>::type expm1(T x) { - return __builtin_expm1(x); -} -template -typename std::enable_if::value, double>::type lgamma(T x) { - return __builtin_lgamma(x); -} -template -typename std::enable_if::value, double>::type log1p(T x) { - return __builtin_log1p(x); -} -template -typename std::enable_if::value, double>::type nearbyint( - T x) { - return __builtin_nearbyint(x); -} -template -typename std::enable_if::value, double>::type round(T x) { - return __builtin_round(x); -} -template -typename std::enable_if::value, double>::type tgamma(T x) { - return __builtin_tgamma(x); -} -template -typename std::enable_if::value, double>::type trunc(T x) { - return __builtin_trunc(x); -} - -// Convoluted definition of these binary functions for overloads other than -// (float,float) and (double,double). Using a template from __gnu_cxx -// is dirty, but this code is only enabled on a dead platform, so there -// shouldn't be any risk of it breaking due to updates. 
-template -typename __gnu_cxx::__promote_2::__type fmax(T x, U y) { - typedef typename __gnu_cxx::__promote_2::__type type; - return fmax(type(x), type(y)); -} -template -typename __gnu_cxx::__promote_2::__type fmin(T x, U y) { - typedef typename __gnu_cxx::__promote_2::__type type; - return fmin(type(x), type(y)); -} -template -typename __gnu_cxx::__promote_2::__type copysign(T x, U y) { - typedef typename __gnu_cxx::__promote_2::__type type; - return copysign(type(x), type(y)); -} -template -typename __gnu_cxx::__promote_2::__type remainder(T x, U y) { - typedef typename __gnu_cxx::__promote_2::__type type; - return remainder(type(x), type(y)); -} - -// log2 is a macro on Android API < 21, so we need to define it ourselves. -inline float log2(float arg) { - return ::log(arg) / ::log(2.0); -} -#if !defined(__NEWLIB__) -inline double log2(double arg) { - return ::log(arg) / ::log(2.0); -} -#endif -inline long double log2(long double arg) { - return ::log(arg) / ::log(2.0); -} -template -typename std::enable_if::value, double>::type log2(T x) { - return ::log(x) / ::log(2.0); -} -} // namespace std - -#endif From f115f1cde16092a4a79efa40992f961455c255e8 Mon Sep 17 00:00:00 2001 From: le-zheng Date: Mon, 15 Jan 2024 17:14:23 +0800 Subject: [PATCH 05/12] [Quant] Enable QConv2d with hardswish post op (#117487) **Summary** Enable QConv2d implementation with post op `hardswish` **Test Plan** ``` python -m pytest test_quantized_op.py -k test_qconv2d_hardswish_pt2e ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/117487 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5 --- aten/src/ATen/native/quantized/cpu/qconv.cpp | 6 ++- test/quantization/core/test_quantized_op.py | 55 ++++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 9f3c790d52c75..404a95b376470 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -1658,6 +1658,8 @@ static at::Tensor _quantized_convolution_onednn( auto upper_bound_value = unary_scalars[1].get().toOptional().value().to(); op_attr = ideep::attr_t::fuse_clamp(lower_bound_value, upper_bound_value); + } else if (has_unary_post_op && unary_attr.value()=="hardswish") { + op_attr = ideep::attr_t::fuse_hardswish(); } else { op_attr = ideep::attr_t(); } @@ -1851,8 +1853,8 @@ class QConvoneDNN final { } else { // Conv2D post op check TORCH_CHECK( - attr == "none" || attr == "relu" || attr == "hardtanh", - "none post_op or post_op relu/hardtanh is supported for quantized pointwise conv2d. Got unary_post_op: ", + attr == "none" || attr == "relu" || attr == "hardtanh" || attr == "hardswish", + "none post_op or post_op relu/hardtanh/hardswish is supported for quantized pointwise conv2d. 
Got unary_post_op: ", attr, ".") } diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 47564b1ccac19..3fe8c06839943 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -6574,6 +6574,10 @@ def _test_qconv_impl_cpu_tensor( assert len(post_op.scalars) == 2, "For post op hardtanh, expect 2 parameters passed in" hardtanh = torch.nn.Hardtanh(min_val=post_op.scalars[0], max_val=post_op.scalars[1]) result_ref = hardtanh(result_ref) + elif post_op.unary_attr == "hardswish": + assert not use_transpose, "Cannot fuse hardswish with ConvTranspose" + hardswish = torch.nn.Hardswish() + result_ref = hardswish(result_ref) # Quantize reference results for comparison result_ref_q = torch.quantize_per_tensor( @@ -6968,6 +6972,57 @@ def test_qconv2d_hardtanh_pt2e(self): qconv_output_dtype=output_dtype, ) + # Test qconv with post op hardswish + @skipIfNoONEDNN + def test_qconv2d_hardswish_pt2e(self): + input_channels_per_group = 2 + output_channels_per_group = 2 + groups_list = [1, 10] + input_feature_map_shape = (10, 10) + kernels = (3, 3) + strides = (2, 2) + pads = (1, 1) + dilations = (1, 1) + W_scale = [1.5] + W_zero_point = [0] + use_bias_list = [False, True] + use_channelwise_list = [False, True] + output_dtype_list = [None, torch.float32, torch.bfloat16] + options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) + + for groups, use_bias, use_channelwise, output_dtype in options: + qconv = torch.ops.onednn.qconv2d_pointwise + qconv_prepack = torch.ops.onednn.qconv_prepack + conv_op = torch.nn.Conv2d( + input_channels_per_group * groups, + output_channels_per_group * groups, + kernels, + strides, + pads, + dilations, + groups, + ) + pointwise_post_op = PointwisePostOp(unary_attr="hardswish") + self._test_qconv_impl_cpu_tensor( + qconv, + qconv_prepack, + conv_op, + input_channels_per_group=input_channels_per_group, + input_feature_map_shape=input_feature_map_shape, + output_channels_per_group=output_channels_per_group, + groups=groups, + kernels=kernels, + strides=strides, + pads=pads, + dilations=dilations, + W_scale=W_scale, + W_zero_point=W_zero_point, + use_bias=use_bias, + post_op=pointwise_post_op, + use_channelwise=use_channelwise, + qconv_output_dtype=output_dtype, + ) + # Test qconv with post op sum @skipIfNoONEDNN def test_qconv2d_sum_pt2e(self): From 17c5f6985299ae9c5520c6b51fabc3c4c4baf7bb Mon Sep 17 00:00:00 2001 From: rzou Date: Thu, 18 Jan 2024 21:29:01 -0800 Subject: [PATCH 06/12] Run test_jit with PYTORCH_TEST_WITH_DYNAMO=1 in CI (#117765) Gets rid of all the single test excludes. 
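Suites that still hang or misbehave under Dynamo are opted out with the existing `skipIfTorchDynamo` decorator rather than a blanket exclude in test.sh, and the remaining per-test failures are tracked in `torch/testing/_internal/dynamo_test_failures.py`. A minimal sketch of that pattern (class name and reason string taken from this diff):

```python
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase

@skipIfTorchDynamo("somehow causing hanging during python shutdown")
class TestFreezing(JitTestCase):
    def test_freeze_module(self):
        ...
```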
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117765 Approved by: https://github.com/voznesenskym --- .ci/pytorch/test.sh | 2 - test/jit/test_backend_nnapi.py | 3 +- test/jit/test_freezing.py | 6 +- test/jit/test_torchbind.py | 2 + .../testing/_internal/dynamo_test_failures.py | 343 ++++++++++++++++++ 5 files changed, 352 insertions(+), 4 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index a402a3d94b301..d3ac34280368f 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -274,8 +274,6 @@ test_dynamo_shard() { --exclude-inductor-tests \ --exclude-jit-executor \ --exclude-distributed-tests \ - --exclude \ - test_jit \ --shard "$1" "$NUM_TEST_SHARDS" \ --verbose assert_git_not_dirty diff --git a/test/jit/test_backend_nnapi.py b/test/jit/test_backend_nnapi.py index e0f7f671147c3..8ca4083e5f8bb 100644 --- a/test/jit/test_backend_nnapi.py +++ b/test/jit/test_backend_nnapi.py @@ -7,7 +7,7 @@ import torch import torch._C from pathlib import Path -from torch.testing._internal.common_utils import IS_FBCODE +from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo # hacky way to skip these tests in fbcode: # during test execution in fbcode, test_nnapi is available during test discovery, @@ -40,6 +40,7 @@ # First skip is needed for IS_WINDOWS or IS_MACOS to skip the tests. torch_root = Path(__file__).resolve().parent.parent.parent lib_path = torch_root / 'build' / 'lib' / 'libnnapi_backend.so' +@skipIfTorchDynamo("weird py38 failures") @unittest.skipIf(not os.path.exists(lib_path), "Skipping the test as libnnapi_backend.so was not found") @unittest.skipIf(IS_FBCODE, "test_nnapi.py not found") diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index 854a54ce042c0..783d29728905d 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -12,7 +12,7 @@ from torch.testing import FileCheck from torch.testing._internal.common_quantization import skipIfNoFBGEMM from torch.testing._internal.common_quantized import override_quantized_engine -from torch.testing._internal.common_utils import set_default_dtype, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM +from torch.testing._internal.common_utils import set_default_dtype, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM, skipIfTorchDynamo from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDA from torch.testing._internal.jit_utils import JitTestCase from torch.utils import mkldnn as mkldnn_utils @@ -35,6 +35,8 @@ def removeExceptions(graph): for n in graph.findAllNodes('prim::RaiseException'): n.destroy() + +@skipIfTorchDynamo("somehow causing hanging during python shutdown") class TestFreezing(JitTestCase): def test_freeze_module(self): class M(nn.Module): @@ -1988,6 +1990,7 @@ def make_prediction(self, x): mod.forward(x), unscripted_mod.forward(x), atol=1e-5, rtol=1e-5 ) +@skipIfTorchDynamo("somehow causing hanging during python shutdown") class TestFrozenOptimizations(JitTestCase): def setUp(self): super().setUp() @@ -2979,6 +2982,7 @@ def forward(self, x): FileCheck().check("aten::detach").run(frozen_mod.graph) self.assertEqual(frozen_mod(inp), mod(inp)) +@skipIfTorchDynamo("somehow causing hanging during python shutdown") @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") class TestMKLDNNReinplacing(JitTestCase): def setUp(self): diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py index 5e046a083edb5..e7ffb056c19ea 100644 --- a/test/jit/test_torchbind.py +++ b/test/jit/test_torchbind.py @@ -8,6 +8,7 @@ 
import torch from typing import Optional +from torch.testing._internal.common_utils import skipIfTorchDynamo # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -29,6 +30,7 @@ "instead." ) +@skipIfTorchDynamo("skipping as a precaution") class TestTorchbind(JitTestCase): def setUp(self): if IS_SANDCASTLE or IS_MACOS or IS_FBCODE: diff --git a/torch/testing/_internal/dynamo_test_failures.py b/torch/testing/_internal/dynamo_test_failures.py index 3a07af3f88c03..49547587b1511 100644 --- a/torch/testing/_internal/dynamo_test_failures.py +++ b/torch/testing/_internal/dynamo_test_failures.py @@ -8228,6 +8228,349 @@ "LoggingTests.test_trace_call", # known py311 fail "LoggingTests.test_trace_call_graph_break", # known py311 fail "LoggingTests.test_trace_call_inline_call", # known py311 fail + "TestPythonBuiltinOP.test_stepped_tuple_slicing", # known py38 fail + "TestPythonBuiltinOP.test_advancedindex", # known py38 fail + "TestCustomOperators.test_dynamic_op_registry", # known py38 fail + "TestComplex.test_complex_constants_and_ops", # known py38 fail + "TestPythonBuiltinOP.test_index", # known py38 fail + "TestHash.test_hash_tuple_nested_unhashable_type", # test_jit + "TestPeephole.test_peephole_dict_getitem_no_optimization_missing_key", # test_jit + "TestMisc.test_joined_str", # test_jit + "TestNnapiBackend.test_to", # test_jit + "TestIsinstance.test_dict_tensor", # test_jit + "TestPythonBuiltinOP.test_adv_indexing_list", # test_jit + "TestUnion.test_union_redundant_arguments_are_skipped", # test_jit + "TestPythonBuiltinOP.test_inf", # test_jit + "TestSymbolicShapeAnalysis.test_register_function_error_checking", # test_jit + "TestPythonBuiltinOP.test_pow", # test_jit + "TestTyping.test_tuple_io", # test_jit + "TestPeephole.test_peephole_dict_len_no_optimization_unsupported_type", # test_jit + "TestRemoveMutation.test_lists_append", # test_jit + "TestSlice.test_slice_tensor", # test_jit + "TestMisc.test_str_refine_any", # test_jit + "TestNnapiBackend.test_prelu", # test_jit + "TestFreezing.test_freeze_module_with_fork2", # test_jit + "TestPeephole.test_peephole_dict_len_no_optimization_overlapping_keys", # test_jit + "TestPeephole.test_peephole_with_non_output_writes", # test_jit + "TestCustomOperators.test_script_graph_contains_custom_op", # test_jit + "TestTorchbind.test_torchbind_getattr", # test_jit + "TestRecursiveScript.test_inner_traced_module", # test_jit + "TestAliasAnalysis.test_recursive_calls", # test_jit + "TestNnapiBackend.test_pointwise_unary", # test_jit + "TestDeviceAnalysis.test_device_apply", # test_jit + "TestList.test_mutable_list_function_inline", # test_jit + "TestList.test_comprehensions_two_comps", # test_jit + "TestNnapiBackend.test_seblock_mul", # test_jit + "TestTorchbind.test_torchbind_return_instance", # test_jit + "TestRemoveMutation.test_if_output", # test_jit + "TestModels.test_time_sequence_prediction", # test_jit + "TestRemoveMutation.test_list_indexing_removal", # test_jit + "TestTypesAndAnnotation.test_types_as_values", # test_jit + "TestAwait.test_await_multiout_save", # test_jit + "TestHash.test_hash_device", # test_jit + "TestPeephole.test_peephole_dict_len", # test_jit + "TestSlice.test_slice_dynamic_index", # test_jit + "TestGenerator.test_default_generator", # test_jit + "TestMisc.test_parse_ir_single_element_tensor_negative", # test_jit + "TestTyping.test_list_unification", # test_jit + "TestList.test_del", # test_jit + "TestAwait.test_script", # test_jit + 
"TestNnapiBackend.test_avg_pool2d", # test_jit + "TestIsinstance.test_list_tensor_type_true", # test_jit + "TestDtypeAnalysis.test_conv_no_mixed_args", # test_jit + "TestBackends.test_errors", # test_jit + "TestMisc.test_parse_ir_annotate", # test_jit + "TestTorchbind.test_torchbind_pickle_serialization", # test_jit + "TestList.test_copy_list_immutable", # test_jit + "TestAsync.test_async_grad_guard_with_grad", # test_jit + "TestUnion.test_union_branching_does_not_autoinfer_undeclared_union", # test_jit + "TestNnapiBackend.test_slice", # test_jit + "TestWarn.test_warn", # test_jit + "TestRemoveMutation.test_special_mapped_op", # test_jit + "TestWarn.test_warn_multiple_calls_same_func_diff_stack", # test_jit + "TestSymbolicShapeAnalysis.test_convolution_backward", # test_jit + "TestUnion.test_union_type_refinement_statically_false", # test_jit + "TestTorchbind.test_default_args", # test_jit + "TestUpgraders.test_aten_full_other_variants", # test_jit + "TestScriptDict.test_reference_semantics", # test_jit + "TestUnion.test_union_does_not_replace_existing_annotated_type_union", # test_jit + "TestTyping.test_dict_invalid_annotations", # test_jit + "TestWith.test_with_no_grad", # test_jit + "TestUnion.test_union_branching_does_not_widen_existing_inferred_type", # test_jit + "TestTorchbind.test_torchbind_return_tuple", # test_jit + "TestTorchbind.test_staticmethod", # test_jit + "TestUnion.test_union_variable_can_be_reassigned", # test_jit + "TestTorchbind.test_torchbind_def_property_readwrite", # test_jit + "TestTorchbind.test_torchbind_attr_exception", # test_jit + "TestFunctionalToInplaceActivation.test_no_functional_to_inplace", # test_jit + "TestTorchbind.test_torchbind_class_attr_recursive", # test_jit + "TestBuiltins.test_del", # test_jit + "TestNnapiBackend.test_mean", # test_jit + "TestNnapiBackend.test_reshape", # test_jit + "TestFrozenOptimizations.test_collapse_adjacent_conversions", # test_jit + "TestTorchbind.test_torchbind_python_deepcopy", # test_jit + "TestPythonBindings.test_aliasdb", # test_jit + "TestParametrization.test_scriptable", # test_jit + "TestMKLDNNReinplacing.test_always_alive_values", # test_jit + "TestAsync.test_async_script_multi_waits", # test_jit + "TestTyping.test_dict_type_refinement_annotation_value_mismatch", # test_jit + "TestScriptList.test_reference_semantics", # test_jit + "TestPeephole.test_peephole_arith", # test_jit + "TestPythonBuiltinOP.test_gather", # test_jit + "TestNnapiBackend.test_upsample_nearest2d", # test_jit + "TestList.test_copy_list_mutable", # test_jit + "TestWarn.test_warn_only_once", # test_jit + "TestPythonBuiltinOP.test_str_to_float", # test_jit + "TestIsinstance.test_optional", # test_jit + "TestCustomOperators.test_calling_scripted_custom_op", # test_jit + "TestUnion.test_union_T_None_is_equivalent_to_optional_T", # test_jit + "TestSlice.test_slice_tensor_multidim_with_dots", # test_jit + "TestNnapiBackend.test_multi_output", # test_jit + "TestSymbolicShapeAnalysis.test_squeeze_dims", # test_jit + "TestPeephole.test_peephole_int", # test_jit + "TestUnion.test_unions_of_a_single_argument_vanish", # test_jit + "TestTorchbind.test_profiler_custom_op", # test_jit + "TestTorchbind.test_torchbind_class_attribute", # test_jit + "TestUnion.test_check_union_annotation", # test_jit + "TestTypesAndAnnotation.test_optional_no_element_type_annotation", # test_jit + "TestList.test_comprehension_iterable", # test_jit + "TestUpgraders.test_aten_test_serialization", # test_jit + "TestPythonBuiltinOP.test_mul", # test_jit + "TestAwait.test_nowait", 
# test_jit + "TestBuiltins.test_del_multiple_operands", # test_jit + "TestTypesAndAnnotation.test_bad_types", # test_jit + "TestSymbolicShapeAnalysis.test_cross_entropy_loss", # test_jit + "TestRemoveMutation.test_aten_inplace", # test_jit + "TestWarn.test_warn_only_once_in_loop_func", # test_jit + "TestDataclasses.test_use_unregistered_dataclass_raises", # test_jit + "TestTorchbind.test_torchbind_optional_explicit_attr", # test_jit + "TestTorchbind.test_torchbind_pass_wrong_type", # test_jit + "TestList.test_list_variance", # test_jit + "TestMisc.test_subexpression_Dict_int_Future", # test_jit + "TestMisc.test_future_isinstance", # test_jit + "TestPythonBuiltinOP.test_slice", # test_jit + "TestPeephole.test_short_circuit_optimization", # test_jit + "TestPeephole.test_peephole_slice_optimization_not_applied_list_modified", # test_jit + "TestTyping.test_namedtuple_good_error", # test_jit + "TestMisc.test_subexpression_List_Future", # test_jit + "TestDtypeAnalysis.test_combined", # test_jit + "TestFunctionalBlocks.test_subgraph_creation", # test_jit + "TestList.test_extend_list_mutable", # test_jit + "TestPythonBindings.test_cu_get_functions", # test_jit + "TestLogging.test_trace_numeric_counter", # test_jit + "TestBatchMM.test_batch_mm_side_prohibited_mutation_common_side", # test_jit + "TestPeephole.test_peephole_dynamic", # test_jit + "TestTorchbind.test_torchbind_def_property_getter_setter", # test_jit + "TestSymbolicShapeAnalysis.test_size_and_sizes", # test_jit + "TestAsync.test_async_script", # test_jit + "TestAsync.test_async_parsing", # test_jit + "TestAwait.test_await_func_arg", # test_jit + "TestTyping.test_dict_type_refinement_annotation_key_mismatch", # test_jit + "TestNnapiBackend.test_softmax", # test_jit + "TestDataclasses.test__post_init__", # test_jit + "TestPeephole.test_normalized_is_op", # test_jit + "TestMisc.test_broadcasting_list", # test_jit + "TestIsinstance.test_optional_no_contained_type", # test_jit + "TestUnion.test_union_argument_order_is_ignored", # test_jit + "TestUnion.test_union_argument_order_is_ignored_container", # test_jit + "TestAutodiffSubgraphSlicing.test_chunk_constant_script_ad", # test_jit + "TestBackends.test_save_load", # test_jit + "TestIsinstance.test_list_tensor", # test_jit + "TestComplex.test_tensor_attributes", # test_jit + "TestRemoveMutation.test_lists_insert", # test_jit + "TestNnapiBackend.test_qlinear", # test_jit + "TestNnapiBackend.test_quantize", # test_jit + "TestNnapiBackend.test_unsqueeze", # test_jit + "TestTorchbind.test_lambda_as_constructor", # test_jit + "TestTyping.test_dict_comprehension_with_type_annotation", # test_jit + "TestAtenPow.test_aten_pow_zero_negative_exponent", # test_jit + "TestUnion.test_union_as_dict_key", # test_jit + "TestTyping.test_optional_refinement", # test_jit + "TestPeephole.test_peephole_type_refinements", # test_jit + "TestSlice.test_slice_kwarg", # test_jit + "TestStringFormatting.test_string_interpolation_with_too_many_arguments", # test_jit + "TestTorchbind.test_torchbind_getstate", # test_jit + "TestTyping.test_dict_comprehension_scope", # test_jit + "TestRemoveMutation.test_if_output_fail", # test_jit + "TestMisc.test_legacy_tensor_constructor", # test_jit + "TestBatchMM.test_batch_mm_prohibited_mutation_multiple_adds", # test_jit + "TestSlice.test_slice_tensor_multidim", # test_jit + "TestPeephole.test_peephole_slice_two_empty_args", # test_jit + "TestTyping.test_namedtuple_py2", # test_jit + "TestUnion.test_union_type_refinement_statically_true", # test_jit + 
"TestRecursiveScript.test_script_function_attribute", # test_jit + "TestPeephole.test_peephole", # test_jit + "TestAwait.test_await_python", # test_jit + "TestPythonBuiltinOP.test_triple", # test_jit + "TestTorchbind.test_torchbind_take_as_arg", # test_jit + "TestNnapiBackend.test_qadd", # test_jit + "TestTypesAndAnnotation.test_pep585_type", # test_jit + "TestNnapiBackend.test_detach", # test_jit + "TestAsync.test_async_script_multi_forks", # test_jit + "TestPythonBindings.test_invalidation", # test_jit + "TestTyping.test_for_tuple_unpack", # test_jit + "TestTorchbind.test_torchbind_deepcopy", # test_jit + "TestTorchbind.test_torchbind_instantiate_missing_class", # test_jit + "TestSymbolicShapeAnalysis.test_if_propagation", # test_jit + "TestPeephole.test_normalized_rsub", # test_jit + "TestPythonIr.test_param_strides", # test_jit + "TestComplex.test_complex_list_sum", # test_jit + "TestUnion.test_union_redundant_arguments_are_skipped_optional", # test_jit + "TestNnapiBackend.test_conv2d", # test_jit + "TestDtypeAnalysis.test_unary", # test_jit + "TestPeephole.test_peephole_dict_len_no_optimization_keys_might_overlap", # test_jit + "TestIsinstance.test_dict_no_contained_type", # test_jit + "TestList.test_extend_list_immutable", # test_jit + "TestFrozenOptimizations.test_conv_add_folding", # test_jit + "TestGenerator.test_generator_arg", # test_jit + "TestTensorBuiltins.test_method_on_number", # test_jit + "TestUnion.test_union_optional_of_union_is_flattened", # test_jit + "TestUnion.test_union_type_refinement_tuple_rhs_union", # test_jit + "TestList.test_no_element_type_annotation", # test_jit + "TestParametrization.test_traceable", # test_jit + "TestSymbolicShapeAnalysis.test_shape_analysis", # test_jit + "TestScriptProfile.test_script", # test_jit + "TestSymbolicShapeAnalysis.test_write", # test_jit + "TestPeephole.test_peephole_slice_optimization_not_applied_non_const_args", # test_jit + "TestNnapiBackend.test_cat", # test_jit + "TestList.test_mutable_list_pop_empty", # test_jit + "TestMisc.test_subexpression_Optional", # test_jit + "TestUnion.test_union_does_not_replace_existing_annotated_type", # test_jit + "TestTorchbind.test_torchbind_return_instance_from_method", # test_jit + "TestTyping.test_opt_opt_refinement", # test_jit + "TestIsinstance.test_tuple_tensor", # test_jit + "TestUpgraders.test_populated_test_upgrader_graph", # test_jit + "TestList.test_slice_index", # test_jit + "TestTyping.test_tuple_assignments", # test_jit + "TestAsync.test_async_python", # test_jit + "TestBatchMM.test_batch_mm_prohibited_mutation", # test_jit + "TestFreezing.test_freeze_module_with_fork_calling_module_method", # test_jit + "TestUnion.test_unions_of_unions_are_flattened", # test_jit + "TestTypeSharing.test_script_function_attribute_different", # test_jit + "TestTorchbind.test_torchbind_lambda_method", # test_jit + "TestTypesAndAnnotation.test_unimported_type_resolution", # test_jit + "TestUnion.test_union_redundant_arguments_are_skipped_container", # test_jit + "TestPythonBindings.test_cu_create_function", # test_jit + "TestTorchbind.test_torchbind_tracing", # test_jit + "TestWarn.test_warn_once_per_func_in_loop", # test_jit + "TestBackendsWithCompiler.test_errors", # test_jit + "TestSaveLoadForOpVersion.test_versioned_div_tensor_inplace", # test_jit + "TestList.test_to_list", # test_jit + "TestUpgraders.test_populated_upgrader_graph", # test_jit + "TestWarn.test_warn_multiple_calls_multiple_warnings", # test_jit + "TestLogging.test_counter_aggregation", # test_jit + 
"TestTorchbind.test_torchbind_take_instance_as_method_arg", # test_jit + "TestComplex.test_complex_parse", # test_jit + "TestTorchbind.test_torchbind_save_load", # test_jit + "TestPeephole.test_integer_refinement", # test_jit + "TestBatchMM.test_batch_mm_prohibited_mutation_if_node", # test_jit + "TestHash.test_hash_tensor", # test_jit + "TestAsync.test_trace_fork_wait_inline", # test_jit + "TestTensorBuiltins.test_tensor_item", # test_jit + "TestList.test_list_keyword", # test_jit + "TestTypesAndAnnotation.test_ignore_with_types", # test_jit + "TestPeephole.test_peephole_slice_one_empty_arg", # test_jit + "TestAsync.test_async_script_nested", # test_jit + "TestNnapiBackend.test_flatten", # test_jit + "TestAsync.test_future_subtyping", # test_jit + "TestTorchbind.test_torchbind_no_init", # test_jit + "TestModels.test_vae_quantized", # test_jit + "TestSymbolicShapeAnalysis.test_shared_shape_graph", # test_jit + "TestNnapiBackend.test_dequantize", # test_jit + "TestPeephole.test_peephole_optional_refine", # test_jit + "TestTorchbind.test_torchbind", # test_jit + "TestAwait.test_await_out_of_interpreter", # test_jit + "TestNnapiBackend.test_conv2d_transpose", # test_jit + "TestNnapiBackend.test_max_pool2d", # test_jit + "TestPeephole.test_peephole_list_ops", # test_jit + "TestTyping.test_optional_conversion", # test_jit + "TestNnapiBackend.test_linear", # test_jit + "TestPythonBuiltinOP.test_add", # test_jit + "TestIsinstance.test_tuple_no_contained_type", # test_jit + "TestTyping.test_bool_list_io", # test_jit + "TestPeephole.test_peephole_dict_getitem_no_optimization_dict_modified", # test_jit + "TestNnapiBackend.test_compile_spec_santiy", # test_jit + "TestDtypeAnalysis.test_custom_rules", # test_jit + "TestPeephole.test_peephole_len_list", # test_jit + "TestTyping.test_dict_in_not_in", # test_jit + "TestUnion.test_union_redundant_arguments_are_skipped_subtyping", # test_jit + "TestTensorMethods.test_getitem", # test_jit + "TestPeephole.test_peephole_dict_getitem_no_optimization_overlapping_keys", # test_jit + "TestDataclasses.test_comparators", # test_jit + "TestTyping.test_tuple_specialization", # test_jit + "TestModels.test_snli_quantized", # test_jit + "TestGenerator.test_script", # test_jit + "TestAsync.test_async_script_error", # test_jit + "TestUnion.test_union_with_collections", # test_jit + "TestList.test_list_index_not_existing", # test_jit + "TestStringFormatting.test_string_interpolation_with_exponent_placeholder_and_string_variable", # test_jit + "TestStringFormatting.test_string_interpolation_with_too_few_arguments", # test_jit + "TestMisc.test_unsafe_hacked_twin", # test_jit + "TestPeephole.test_peephole_dict_getitem_no_optimization_get_input_arg", # test_jit + "TestTyping.test_singleton_tuple_unpack", # test_jit + "TestUnion.test_union_with_scalar_values", # test_jit + "TestAwait.test_jit_trace", # test_jit + "TestBackendsWithCompiler.test_execution", # test_jit + "TestPeephole.test_normalized_isnot_op", # test_jit + "TestTyping.test_list_type_refinement_annotation_element_mismatch", # test_jit + "TestTorchbind.test_torchbind_def_property_just_getter", # test_jit + "TestNnapiBackend.test_tensor_input", # test_jit + "TestPythonBindings.test_graph_iterator_keepalive", # test_jit + "TestUnion.test_union_subclasses_larger_union", # test_jit + "TestPeephole.test_peephole_dict_getitem_simple", # test_jit + "TestBackends.test_execution", # test_jit + "TestPeephole.test_peephole_with_writes", # test_jit + "TestRecursiveScript.test_script_basic", # test_jit + 
"TestScriptProfile.test_section", # test_jit + "TestPeephole.test_peephole_add_zero", # test_jit + "TestAsync.test_trace_fork_wait", # test_jit + "TestAliasAnalysis.test_nested_list_construct_not_wildcard", # test_jit + "TestList.test_mutable_list_remove_not_existing", # test_jit + "TestMisc.test_parse_ir_single_element_tensor_positive", # test_jit + "TestNnapiBackend.test_log_softmax", # test_jit + "TestOpDecompositions.test_registered_decomposition", # test_jit + "TestStringFormatting.test_string_interpolation_with_percent_in_string", # test_jit + "TestBatchMM.test_batch_mm_permitted_mutation", # test_jit + "TestTorchbind.test_torchbind_tracing_nested", # test_jit + "TestNnapiBackend.test_hardtanh", # test_jit + "TestBatchMM.test_batch_mm_no_mutation", # test_jit + "TestIsinstance.test_type_refinement", # test_jit + "TestPeephole.test_peephole_slice_all_three_args", # test_jit + "TestTyping.test_tuple_keyword", # test_jit + "TestOpDecompositions.test_op_decomposition", # test_jit + "TestBatchMM.test_batch_mm_side_permitted_mutation", # test_jit + "TestNnapiBackend.test_pointwise_binary_const", # test_jit + "TestTypeSharing.test_script_function_attribute_same", # test_jit + "TestTypesAndAnnotation.test_type_annotate_py3", # test_jit + "TestPeephole.test_peephole_dict_getitem_no_optimization_unsupported_type", # test_jit + "TestMisc.test_subexpression_Tuple_int_int_Future", # test_jit + "TestMisc.test_subexpression_Future_annotate", # test_jit + "TestStringFormatting.test_string_interpolation_with_char_placeholder_and_true_string_variable", # test_jit + "TestSlice.test_tuple_slicing", # test_jit + "TestAwait.test_await_isinstance", # test_jit + "TestNnapiBackend.test_adaptive_avg_pool2d", # test_jit + "TestIsinstance.test_list_no_contained_type", # test_jit + "TestPeephole.test_peephole_no_output_aliasing", # test_jit + "TestStringFormatting.test_string_interpolation_with_unknown_format_specifier", # test_jit + "TestPeephole.test_refine_integer_values", # test_jit + "TestStringFormatting.test_string_interpolation_with_digit_placeholder_and_string_variable", # test_jit + "TestPeephole.test_peephole_dict_getitem_no_optimization_keys_might_overlap", # test_jit + "TestPeephole.test_peephole_list_len", # test_jit + "TestMisc.test_list_literal_infer", # test_jit + "TestIgnorableArgs.test_add_out_ignorable_args", # test_jit + "TestTensorBuiltins.test_scalar_to_num_conversions", # test_jit + "TestWarn.test_warn_once_per_func", # test_jit + "TestAsync.test_async_grad_guard_no_grad", # test_jit + "TestPythonBuiltinOP.test_random", # test_jit + "TestSymbolicShapeAnalysis.test_stitching_concat", # test_jit + "TestMisc.test_if_returning_any", # test_jit + "TestBatchMM.test_batch_mm_side_prohibited_mutation_uncommon_side", # test_jit + "TestList.test_tensor_list_index_not_existing", # test_jit + "TestMisc.test_script_many_decorators", # test_jit + "TestUnion.test_union_does_not_replace_existing_annotated_type_empty_container", # test_jit + "TestNnapiBackend.test_pointwise_binary", # test_jit + "TestTypesAndAnnotation.test_tuple_no_element_type_annotation", # test_jit + "TestFrozenOptimizations.test_conv_bn_folding", # test_jit.py } From 98a044d33e401a7636e26217583b537b12a58a70 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 19 Jan 2024 14:31:12 +0000 Subject: [PATCH 07/12] [CI] Build M1 conda binaries on M1 runners (#117801) As usual, almost no work on PyTorch side, all changes are on the builder end, namely: - https://github.com/pytorch/builder/commit/8b67d32929b950c4851066800f5ef57c7646994c - 
depend on `blas * mkl` only on x86 machines - https://github.com/pytorch/builder/commit/eb78393f1e4bd68134d87e4059b9b25194af7dbb - install arm64 conda when running on Apple Silicon - https://github.com/pytorch/builder/commit/0d3aea4ee08e00b76fc263ce58e4c10df9f58e44 - constrain llvmdev-9 to x86 machines only - https://github.com/pytorch/builder/commit/6c6a33b2712bdb4be4406a10f75e3a404541ccd7 - set correct DEVELOPER_DIR path TODO: - We should auto-detect this `DEVELOPER_DIR` via `xcode-select` Pull Request resolved: https://github.com/pytorch/pytorch/pull/117801 Approved by: https://github.com/atalman --- .github/scripts/generate_ci_workflows.py | 3 ++- .../generated-macos-arm64-binary-conda-nightly.yml | 12 +++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 1075db4255ed0..a8f2a39a2d983 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -342,7 +342,8 @@ class OperatingSystem: BinaryBuildWorkflow( os=OperatingSystem.MACOS_ARM64, package_type="conda", - cross_compile_arm64=True, + cross_compile_arm64=False, + macos_runner="macos-13-xlarge", build_configs=generate_binary_build_matrix.generate_conda_matrix( OperatingSystem.MACOS_ARM64 ), diff --git a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml index d078c0dd3a95d..46f5a7621d7af 100644 --- a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml @@ -27,8 +27,6 @@ env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} SKIP_ALL_TESTS: 1 - CROSS_COMPILE_ARM64: 1 - concurrency: group: macos-arm64-binary-conda-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true @@ -36,7 +34,7 @@ concurrency: jobs: conda-py3_8-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-12-xl + runs-on: macos-13-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -154,7 +152,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_9-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-12-xl + runs-on: macos-13-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -272,7 +270,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-12-xl + runs-on: macos-13-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -390,7 +388,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_11-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-12-xl + runs-on: macos-13-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -508,7 +506,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_12-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-12-xl + runs-on: macos-13-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch From 03b12e56c758431df6f95075ce3a0113ccaeb3f9 Mon Sep 17 00:00:00 2001 From: Qingpeng Li <43924785+qingpeng9802@users.noreply.github.com> Date: Fri, 19 Jan 2024 14:56:39 +0000 Subject: [PATCH 08/12] accelerate 
`binary_cross_entropy_with_logits` by using `log_sigmoid` operator (#115539)

When I was reimplementing BCEWithLogits, I found that the `log_sigmoid` operator could accelerate the function.

Simple benchmark on an AMD 3600 CPU, Ubuntu 22.04:

|avg time (ms)|with `pos_weight`|no `pos_weight`|
|-|-|-|
|original|1986|1658|
|this PR|1295|995|

i.e. 35-40% faster. The speedup probably comes from the vectorized `log_sigmoid` code. A CUDA benchmark was not obtained, but I believe CUDA can also benefit from reducing kernel launches, as https://github.com/pytorch/pytorch/pull/11054#issuecomment-442233714 and https://github.com/pytorch/pytorch/pull/78267#issue-1248398454 mention.

The simple benchmark cpp file: [demo.txt](https://github.com/pytorch/pytorch/files/13635355/demo.txt)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115539
Approved by: https://github.com/lezcano
---
 aten/src/ATen/native/Loss.cpp         | 26 +++++------
 test/profiler/test_memory_profiler.py | 62 +++++++++++++--------------
 torch/_decomp/decompositions.py       | 11 +----
 3 files changed, 46 insertions(+), 53 deletions(-)

diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp
index 0eafdf27648d2..1eedceba1479e 100644
--- a/aten/src/ATen/native/Loss.cpp
+++ b/aten/src/ATen/native/Loss.cpp
@@ -30,6 +30,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -358,21 +359,20 @@ Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& targe
   c10::MaybeOwned pos_weight_maybe_owned = at::borrow_from_optional_tensor(pos_weight_opt);
   const Tensor& pos_weight = *pos_weight_maybe_owned;
-  Tensor loss;
-  auto max_val = (-input).clamp_min_(0);
-  if (pos_weight.defined()) {
-    // pos_weight need to be broadcasted, thus mul(target) is not inplace.
-    auto log_weight = (pos_weight - 1).mul(target).add_(1);
-    loss = (1 - target).mul_(input).add_(log_weight.mul_(((-max_val).exp_().add_((-input - max_val).exp_())).log_().add_(max_val)));
-  } else {
-    loss = (1 - target).mul_(input).add_(max_val).add_((-max_val).exp_().add_((-input -max_val).exp_()).log_());
-  }
+  Tensor loss;
+  if (pos_weight.defined()) {
+    // pos_weight need to be broadcasted, thus mul(target) is not inplace.
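+    // Note: max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) == -log(sigmoid(x)),
+    // so the log-sum-exp expression removed above is equivalent to the single at::log_sigmoid(input) call used below.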
+ auto log_weight = (pos_weight - 1).mul(target).add_(1); + loss = (1 - target).mul_(input).sub_(log_weight.mul_(at::log_sigmoid(input))); + } else { + loss = (1 - target).mul_(input).sub_(at::log_sigmoid(input)); + } - if (weight.defined()) { - loss.mul_(weight); - } + if (weight.defined()) { + loss.mul_(weight); + } - return apply_loss_reduction(loss, reduction); + return apply_loss_reduction(loss, reduction); } Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool log_input, const bool full, const double eps, const int64_t reduction) diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py index f9348f8aaa0fc..a4927305cf1d4 100644 --- a/test/profiler/test_memory_profiler.py +++ b/test/profiler/test_memory_profiler.py @@ -1147,26 +1147,26 @@ def step_fn(mark_region): aten::mul.Tensor 1 (INPUT), 3 (INPUT) -> 4 (INPUT) aten::mul.Tensor 1 (INPUT), 5 (INPUT) -> 6 (INPUT) aten::cat 4 (INPUT), 6 (INPUT) -> 7 (INPUT) - aten::binary_cross_entropy_with_logits 7 (INPUT), 2 (INPUT) -> 13 (INPUT) + aten::binary_cross_entropy_with_logits 7 (INPUT), 2 (INPUT) -> 11 (INPUT) -- Backward --------------------------------------------------------------------------------------------- - aten::ones_like 13 (INPUT) -> 16 (INPUT) - aten::sigmoid 7 (INPUT) -> 17 (TEMPORARY) - aten::sub.Tensor 17 (TEMPORARY), 2 (INPUT) -> 18 (TEMPORARY) - aten::mul.Tensor 18 (TEMPORARY), 16 (INPUT) -> 19 (AUTOGRAD_DETAIL) - aten::div_.Scalar 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) - aten::slice.Tensor 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) - aten::slice.Tensor 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) - aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL) + aten::ones_like 11 (INPUT) -> 14 (INPUT) + aten::sigmoid 7 (INPUT) -> 15 (TEMPORARY) + aten::sub.Tensor 15 (TEMPORARY), 2 (INPUT) -> 16 (TEMPORARY) + aten::mul.Tensor 16 (TEMPORARY), 14 (INPUT) -> 17 (AUTOGRAD_DETAIL) + aten::div_.Scalar 17 (AUTOGRAD_DETAIL) -> 17 (AUTOGRAD_DETAIL) + aten::slice.Tensor 17 (AUTOGRAD_DETAIL) -> 17 (AUTOGRAD_DETAIL) + aten::slice.Tensor 17 (AUTOGRAD_DETAIL) -> 17 (AUTOGRAD_DETAIL) + aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 20 (AUTOGRAD_DETAIL) + aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT) + aten::view 21 (GRADIENT) -> 21 (GRADIENT) + aten::detach 21 (GRADIENT) -> 21 (GRADIENT) + aten::detach 21 (GRADIENT) -> ??? + aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL) aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT) aten::view 23 (GRADIENT) -> 23 (GRADIENT) aten::detach 23 (GRADIENT) -> 23 (GRADIENT) - aten::detach 23 (GRADIENT) -> ??? 
- aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 1 (INPUT) -> 24 (AUTOGRAD_DETAIL) - aten::sum.dim_IntList 24 (AUTOGRAD_DETAIL) -> 25 (GRADIENT) - aten::view 25 (GRADIENT) -> 25 (GRADIENT) - aten::detach 25 (GRADIENT) -> 25 (GRADIENT) - aten::detach 25 (GRADIENT) -> ???""", + aten::detach 23 (GRADIENT) -> ???""", ) def test_categories_e2e_simple_fwd_bwd_step(self) -> None: @@ -1199,30 +1199,30 @@ def step_fn(mark_region): aten::mul.Tensor 1 (INPUT), 3 (PARAMETER) -> 4 (ACTIVATION) aten::mul.Tensor 1 (INPUT), 5 (PARAMETER) -> 6 (ACTIVATION) aten::cat 4 (ACTIVATION), 6 (ACTIVATION) -> 7 (ACTIVATION) - aten::binary_cross_entropy_with_logits 7 (ACTIVATION), 2 (INPUT) -> 13 (ACTIVATION) + aten::binary_cross_entropy_with_logits 7 (ACTIVATION), 2 (INPUT) -> 11 (ACTIVATION) -- Backward --------------------------------------------------------------------------------------------- - aten::ones_like 13 (ACTIVATION) -> 16 (ACTIVATION) - aten::sigmoid 7 (ACTIVATION) -> 17 (TEMPORARY) - aten::sub.Tensor 17 (TEMPORARY), 2 (INPUT) -> 18 (TEMPORARY) - aten::mul.Tensor 18 (TEMPORARY), 16 (ACTIVATION) -> 19 (AUTOGRAD_DETAIL) - aten::div_.Scalar 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) - aten::slice.Tensor 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) - aten::slice.Tensor 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) - aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL) + aten::ones_like 11 (ACTIVATION) -> 14 (ACTIVATION) + aten::sigmoid 7 (ACTIVATION) -> 15 (TEMPORARY) + aten::sub.Tensor 15 (TEMPORARY), 2 (INPUT) -> 16 (TEMPORARY) + aten::mul.Tensor 16 (TEMPORARY), 14 (ACTIVATION) -> 17 (AUTOGRAD_DETAIL) + aten::div_.Scalar 17 (AUTOGRAD_DETAIL) -> 17 (AUTOGRAD_DETAIL) + aten::slice.Tensor 17 (AUTOGRAD_DETAIL) -> 17 (AUTOGRAD_DETAIL) + aten::slice.Tensor 17 (AUTOGRAD_DETAIL) -> 17 (AUTOGRAD_DETAIL) + aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 20 (AUTOGRAD_DETAIL) + aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT) + aten::view 21 (GRADIENT) -> 21 (GRADIENT) + aten::detach 21 (GRADIENT) -> 21 (GRADIENT) + aten::detach 21 (GRADIENT) -> 21 (GRADIENT) + aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL) aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT) aten::view 23 (GRADIENT) -> 23 (GRADIENT) aten::detach 23 (GRADIENT) -> 23 (GRADIENT) aten::detach 23 (GRADIENT) -> 23 (GRADIENT) - aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 1 (INPUT) -> 24 (AUTOGRAD_DETAIL) - aten::sum.dim_IntList 24 (AUTOGRAD_DETAIL) -> 25 (GRADIENT) - aten::view 25 (GRADIENT) -> 25 (GRADIENT) - aten::detach 25 (GRADIENT) -> 25 (GRADIENT) - aten::detach 25 (GRADIENT) -> 25 (GRADIENT) -- Optimizer -------------------------------------------------------------------------------------------- - aten::add_.Tensor 3 (PARAMETER), 25 (GRADIENT) -> 3 (PARAMETER) - aten::add_.Tensor 5 (PARAMETER), 23 (GRADIENT) -> 5 (PARAMETER)""", + aten::add_.Tensor 3 (PARAMETER), 23 (GRADIENT) -> 3 (PARAMETER) + aten::add_.Tensor 5 (PARAMETER), 21 (GRADIENT) -> 5 (PARAMETER)""", ) def test_categories_e2e_simple_module_fwd(self) -> None: diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 0e2e4ee6e2883..078ccde41bc64 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -3821,18 +3821,11 @@ def mv(self, vec): def binary_cross_entropy_with_logits( self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value ): - max_val = (-self).clamp_min(0) if pos_weight is not None: log_weight = (pos_weight - 1) * target + 1 - loss = (1 - target) * 
self + log_weight * ( - ((-max_val).exp() + (-self - max_val).exp()).log() + max_val - ) + loss = (1 - target) * self - (log_weight * F.logsigmoid(self)) else: - loss = ( - (1 - target) * self - + max_val - + ((-max_val).exp() + (-self - max_val).exp()).log() - ) + loss = (1 - target) * self - F.logsigmoid(self) if weight is not None: loss = loss * weight From de257183007b218ce95aa6c9c084161f7d70bc77 Mon Sep 17 00:00:00 2001 From: atalman Date: Fri, 19 Jan 2024 15:01:46 +0000 Subject: [PATCH 09/12] [release] Docker Release build trigger on rc for testing (#117849) Enable triggering the Docker Release builds on RC. Use test channel in this case. Hence following logic is applied: 1. On RC trigger use test channel and upload to pytorch-test : https://github.com/orgs/pytorch/packages/container/package/pytorch-test 2. On Final RC use prod channel and upload to pytorch : https://github.com/orgs/pytorch/packages/container/package/pytorch 3. Nightly: https://github.com/orgs/pytorch/packages/container/package/pytorch-nightly Pull Request resolved: https://github.com/pytorch/pytorch/pull/117849 Approved by: https://github.com/malfet --- .github/workflows/docker-release.yml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index 51af9af70883f..1d566cb21099d 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -11,8 +11,10 @@ on: branches: - nightly tags: - # We want to run this build on final release tag + # Final Release tags look like: v1.11.0 - v[0-9]+.[0-9]+.[0-9]+ + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ - ciflow/nightly/* concurrency: @@ -101,6 +103,16 @@ jobs: echo "${RUNNER_TEMP}/bin" >> "${GITHUB_PATH}" # Generate PyTorch version to use echo "PYTORCH_VERSION=$(python3 .github/scripts/generate_pytorch_version.py --no-build-suffix)" >> "${GITHUB_ENV}" + - name: Setup test specific variables + if: ${{ startsWith(github.event.ref, 'refs/tags/v') }} + run: | + if [[ ${{ github.event.ref }} =~ ^refs/tags/v[0-9]+\.[0-9]+\.[0-9]+-rc[0-9]+$ ]]; then + { + echo "DOCKER_IMAGE=pytorch-test"; + echo "INSTALL_CHANNEL=pytorch-test"; + echo "TRITON_VERSION=$(cut -f 1 .ci/docker/triton_version.txt)"; + } >> "${GITHUB_ENV}" + fi - name: Setup nightly specific variables if: ${{ github.event.ref == 'refs/heads/nightly' || startsWith(github.event.ref, 'refs/tags/ciflow/nightly/') }} run: | From 6c5c2121b1f8e4fe1ce7822d9dd1e47e2059e1fe Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 19 Jan 2024 16:45:35 +0000 Subject: [PATCH 10/12] Run some OOMing tests serially (#117759) They were disabled due to being flaky due to OOMs but got renamed. Seeing if running serially helps I kind of want to keep this test disabled since the rest of the file is probably fine... 
Issues in question: #113132 #113136 #113140 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117759 Approved by: https://github.com/malfet, https://github.com/huydhn --- test/run_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/run_test.py b/test/run_test.py index 64e1496a4a292..b53a23652407c 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -361,6 +361,9 @@ def skip_test_p(name: str) -> bool: "test_native_mha", # OOM "test_module_hooks", # OOM "inductor/test_max_autotune", # Testing, probably revert later + "inductor/test_torchinductor", # OOM on test_large_block_sizes + "inductor/test_torchinductor_dynamic_shapes", # OOM on test_large_block_sizes + "inductor/test_torchinductor_codegen_dynamic_shapes", # OOM on test_large_block_sizes ] # A subset of onnx tests that cannot run in parallel due to high memory usage. ONNX_SERIAL_LIST = [ From 249a2261139dc06d4a365262d84432e1727f34dc Mon Sep 17 00:00:00 2001 From: angelayi Date: Fri, 19 Jan 2024 17:13:39 +0000 Subject: [PATCH 11/12] [export] Error on not pytree-flattened nodes (#117598) Attempts to make the input/output mismatch error better by first checking if the inputs/outputs are able to be pytree flattened into supporting types (tensors, symints, ...). So if user passes in some datastructure which does not have a pytree flatten registration, this will error with the message "It looks like one of the inputs is with type CustomType is not supported or pytree flatten-able.... please register a pytree flatten/unflatten function using the pytree.register_pytree_node API". The check inside of produce_matching should now only error if something unexpected happens (dynamo accidentally adds an input or removes an output), and should be considered an internal error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/117598 Approved by: https://github.com/avikchaudhuri, https://github.com/BowenBao --- test/dynamo/test_export.py | 61 +++++++---- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 3 +- torch/_dynamo/eval_frame.py | 102 ++++++++++-------- torch/_dynamo/exc.py | 1 + 4 files changed, 105 insertions(+), 62 deletions(-) diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index eb68db84e4318..3c3cb3736cf24 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -2254,33 +2254,55 @@ def f(t): return t.x + t.y with self.assertRaisesRegex( - AssertionError, - "graph-captured input #1, of type .*Tensor.*, " - "is not among original inputs of types: .*Tensors", + UserError, + "It looks like one of the inputs with type .*Tensors.* " + "is not supported or pytree-flattenable", ): - torch._dynamo.export( - f, Tensors(x=torch.randn(10), y=torch.randn(10)), aten_graph=False + torch._dynamo.export(f, aten_graph=False)( + Tensors(x=torch.randn(10), y=torch.randn(10)) ) def f(x, y): return Tensors(x=x.sin(), y=y.cos()) with self.assertRaisesRegex( - AssertionError, - "original output #1 is .*Tensors.*, " - "but only the following types are supported", + UserError, + "It looks like one of the outputs with type .*Tensors.* " + "is not supported or pytree-flattenable", ): - torch._dynamo.export(f, torch.randn(10), torch.randn(10), aten_graph=False) + torch._dynamo.export(f, aten_graph=False)(torch.randn(10), torch.randn(10)) + + def test_empty(self): + def f(x): + return x + + exported = torch._dynamo.export(f)(torch.randn(3, 3)) + out_graph = exported[0] + inp = torch.randn(3, 3) + self.assertTrue(torch._dynamo.utils.same(inp, out_graph(inp))) + + class M(torch.nn.Module): 
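+            # M takes no forward() inputs; its only output is the constant attribute set below.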
+ def __init__(self): + super().__init__() + self.a = torch.ones(3, 3) + + def forward(self): + return self.a + + exported = torch._dynamo.export(M())() + out_graph = exported[0] + self.assertTrue(torch._dynamo.utils.same(torch.ones(3, 3), out_graph())) def test_none_out(self): def f(x, y): _ = x + y with self.assertRaisesRegex( - AssertionError, - "original output #1 is None, but only the following types are supported", + UserError, + "It looks like one of the outputs with type .*None.* " + "is not supported or pytree-flattenable", ): - torch._dynamo.export(f, torch.randn(10), torch.randn(10), aten_graph=False) + torch._dynamo.export(f, aten_graph=False)(torch.randn(10), torch.randn(10)) def test_primitive_constant_output(self): class Foo(torch.nn.Module): @@ -2292,8 +2314,9 @@ def forward(self, x): foo = Foo() with self.assertRaisesRegex( - AssertionError, - "original output #2 is 5, but only the following types are supported", + UserError, + "It looks like one of the outputs with type .*int.* " + "is not supported or pytree-flattenable", ): torch.export.export(foo, (torch.tensor(3),)) @@ -2305,8 +2328,9 @@ def forward(self, x, y): # new behavior with self.assertRaisesRegex( - AssertionError, - "original output #2 is 5, but only the following types are supported", + UserError, + "It looks like one of the outputs with type .*int.* " + "is not supported or pytree-flattenable", ): torch.export.export(bar, (torch.tensor(3), 5)) @@ -2317,8 +2341,9 @@ def forward(self, x, y): qux = Qux() with self.assertRaisesRegex( - AssertionError, - "original output #2 is 4, but only the following types are supported", + UserError, + "It looks like one of the outputs with type .*int.* " + "is not supported or pytree-flattenable", ): torch.export.export(qux, (torch.tensor(3), 5)) diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 0e43bfcb82deb..5aab27bc5c70c 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -95,8 +95,7 @@ def func(x): self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (tensor_x,)) @pytorch_test_common.xfail( - error_message="graph-captured input #2, of type , " - "is not among original inputs of types: ()", + error_message="Unexpectedly found a in the inputs.", reason="https://github.com/pytorch/pytorch/issues/96379", ) def test_func_with_args_and_tensor_kwargs(self): diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index aca6be7753694..6c3671ffeb10b 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -938,60 +938,78 @@ def rewrite_signature( ): orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec) - supported_types = (torch.Tensor, torch.SymInt, torch.SymFloat, torch.SymBool) - - def is_supported_type(val): - return isinstance(val, supported_types) - - def produce_matching(sources, candidates): - source_types = " or ".join( - [ - desc - + " of types: (" - + ", ".join([str(type(val)) for val in vals]) - + ")" - for desc, vals in sources.items() - ] - ) - source_vals = [val for vals in sources.values() for val in vals] + constant_types = [ + int, + str, + bool, + float, + torch.memory_format, + torch.device, + torch.dtype, + torch.layout, + ] + + def check_user_input_output(flat_values, error_type): + supported_types = [ + torch.Tensor, + torch.SymInt, + torch.SymFloat, + torch.SymBool, + torch._C.ScriptObject, + ] + if error_type == UserErrorType.INVALID_INPUT: + 
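+                # Constant Python values (int, str, bool, float, dtype, device, layout, memory_format) are accepted as inputs but not as outputs.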
supported_types.extend(constant_types) + + def is_supported_type(val): + return isinstance(val, tuple(supported_types)) + + value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output" + # We only check that the outputs are not None. Inputs can be None. + for v in flat_values: + if not is_supported_type(v): + if error_type == UserErrorType.INVALID_INPUT and v is None: + continue + + raise UserError( + error_type, + f"It looks like one of the {value_type}s with type `{type(v)}` " + "is not supported or pytree-flattenable. \n" + f"Exported graphs {value_type}s can only contain the " + f"following supported types: {supported_types}. \n" + "If you are using a custom class object, " + "please register a pytree_flatten/unflatten function " + "using `torch.utils._pytree.register_pytree_node` or " + "`torch.export.register_dataclass`.", + ) + + check_user_input_output(flat_args, UserErrorType.INVALID_INPUT) + flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result) + check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT) + + def produce_matching(debug_type, sources, candidates): matched_elements_positions = [] dict_of_source_vals = {} - for i, val in enumerate(source_vals): + for i, val in enumerate(sources): dict_of_source_vals[id(val)] = i - for candidate_desc, candidate_vals in candidates.items(): - for i, val in enumerate(candidate_vals): - if is_supported_type(val): - if id(val) in dict_of_source_vals: - matched_elements_positions.append(dict_of_source_vals[id(val)]) - else: - raise AssertionError( - f"{candidate_desc} #{i+1}, of type {type(val)}, is not among {source_types}" - 'Set TORCH_LOGS="+export" for more information.' - ) - else: - raise AssertionError( - f"{candidate_desc} #{i+1} is {val}, but only " - f"the following types are supported: {supported_types}" - 'Set TORCH_LOGS="+export" for more information.' 
- ) + for i, val in enumerate(candidates): + if id(val) not in dict_of_source_vals: + raise AssertionError( + f"Unexpectedly found a {type(val)} in the {debug_type}.\n" + 'Please file an issue along with a paste of the logs from TORCH_LOGS="+export"' + ) + + matched_elements_positions.append(dict_of_source_vals[id(val)]) return matched_elements_positions matched_input_elements_positions = produce_matching( - sources={"original inputs": flat_args}, - candidates={"graph-captured input": graph_captured_input}, + "inputs", flat_args, graph_captured_input ) - flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result) - assert graph_captured_output is not None matched_output_elements_positions = produce_matching( - sources={ - "graph-captured outputs": list(graph_captured_output), - "original inputs": flat_args, - }, - candidates={"original output": flat_results_traced}, + "outputs", list(graph_captured_output) + flat_args, flat_results_traced ) new_graph = FlattenInputOutputSignature( diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index c47cce8fa04d2..b664bb6e2b528 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -133,6 +133,7 @@ class UserErrorType(Enum): CONSTRAINT_VIOLATION = auto() DYNAMIC_DIM = auto() INVALID_INPUT = auto() + INVALID_OUTPUT = auto() class UserError(Unsupported): From f316c35a341fb6ab231511545c331856d11704f9 Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Fri, 19 Jan 2024 17:16:45 +0000 Subject: [PATCH 12/12] [export] Support preserving submodule callling convention in non-strict export (#117796) Summary: Title Test Plan: CI Reviewed By: zhxchen17 Differential Revision: D52889236 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117796 Approved by: https://github.com/angelayi --- test/export/test_unflatten.py | 75 ++++++++++++++++++----------------- torch/_export/wrappers.py | 4 +- torch/export/_trace.py | 74 +++++++++++++++++++++++++++------- 3 files changed, 101 insertions(+), 52 deletions(-) diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index 4eea5dc9e51a6..9d709751a7f98 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -30,7 +30,7 @@ from torch.export import Constraint, Dim, export from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import skipIfTorchDynamo, run_tests, TestCase from torch.utils._pytree import ( LeafSpec, tree_flatten, @@ -199,6 +199,7 @@ def forward(self, x): id(getattr(unflattened_module.sub_net, "2")), ) + @skipIfTorchDynamo("Non strict mode is not meant to run with dynamo") def test_unflatten_preserve_signature(self): class NestedChild(torch.nn.Module): def forward(self, zx, y): @@ -234,41 +235,43 @@ def forward(self, x, y): orig_eager = MyModule() inps = torch.rand(2, 3), torch.rand(2, 3) - export_module = export( - orig_eager, - inps, - {}, - preserve_module_call_signature=("foo.nested",), - ) - unflattened = unflatten(export_module) - self.compare_outputs(export_module, unflattened, inps) - unflattened.foo.nested = NestedChild() - self.compare_outputs(export_module, unflattened, inps) - - # Test tree spec mismatched input - orig_outs = export_module(*inps) - new_inps = *inps, torch.rand(2, 3) - with self.assertRaisesRegex( - TypeError, - "There is no flat args adapter sepcified. 
Are you sure you are calling this with the right arguments?", - ): - unflattened(new_inps) - - # With flat args adapter - class KeepTwoFlatArgsAdapter(FlatArgsAdapter): - def adapt( - self, - target_spec: TreeSpec, - input_spec: TreeSpec, - input_args: List[Any], - ) -> List[Any]: - while len(input_args) > 2: - input_args.pop(-1) - return input_args - - unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter()) - new_outs = unflattened(*new_inps) - self.assertTrue(torch.allclose(orig_outs, new_outs)) + for strict in [True, False]: + export_module = export( + orig_eager, + inps, + {}, + preserve_module_call_signature=("foo.nested",), + strict=strict + ) + unflattened = unflatten(export_module) + self.compare_outputs(export_module, unflattened, inps) + unflattened.foo.nested = NestedChild() + self.compare_outputs(export_module, unflattened, inps) + + # Test tree spec mismatched input + orig_outs = export_module(*inps) + new_inps = *inps, torch.rand(2, 3) + with self.assertRaisesRegex( + TypeError, + "There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?", + ): + unflattened(new_inps) + + # With flat args adapter + class KeepTwoFlatArgsAdapter(FlatArgsAdapter): + def adapt( + self, + target_spec: TreeSpec, + input_spec: TreeSpec, + input_args: List[Any], + ) -> List[Any]: + while len(input_args) > 2: + input_args.pop(-1) + return input_args + + unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter()) + new_outs = unflattened(*new_inps) + self.assertTrue(torch.allclose(orig_outs, new_outs)) def test_unflatten_param_list_dict(self): class Mod(torch.nn.Module): diff --git a/torch/_export/wrappers.py b/torch/_export/wrappers.py index b791550d5ec8c..51f0c64ba1840 100644 --- a/torch/_export/wrappers.py +++ b/torch/_export/wrappers.py @@ -61,7 +61,9 @@ def _wrap_submodule(mod, path, module_call_specs): submodule = getattr(submodule, name) def update_module_call_signatures(path, in_spec, out_spec): - assert path not in module_call_specs + if path in module_call_specs: + assert module_call_specs[path]["in_spec"] == in_spec + assert module_call_specs[path]["out_spec"] == out_spec module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec} assert "forward" not in submodule.__dict__ diff --git a/torch/export/_trace.py b/torch/export/_trace.py index cd1d00ee514eb..0784b0519c17c 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -517,12 +517,24 @@ def _export( constraints = constraints or [] kwargs = kwargs or {} + flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs)) + if not strict: assert isinstance(f, torch.nn.Module) - assert len(preserve_module_call_signature) == 0 assert len(kwargs) == 0, "keyword arguments NYI" out_spec = None + module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {} + + def strip_root(x): + if isinstance(x, str) and x.startswith("_export_root"): + stripped = x[len("_export_root") :] + return stripped[1:] if stripped.startswith(".") else stripped + return x + + def fixup_key(x): + return "L__self__" + strip_root(x) + def _tuplify_outputs(aot_export): def _aot_export_non_strict(mod, args, **kwargs): class Wrapper(torch.nn.Module): @@ -537,16 +549,16 @@ def forward(self, *args, **kwargs): ) return tuple(flat_outs) - gm, sig = aot_export(Wrapper(mod), args, **kwargs) - - def strip_root(x): - if isinstance(x, str) and x.startswith("_export_root"): - stripped = x[len("_export_root") :] - return stripped[1:] if stripped.startswith(".") else stripped - return x - - def fixup_key(x): - 
return "L__self__" + strip_root(x) + wrapped_mod = Wrapper(mod) + # Patch export_root to the signatures so that wrapper module correctly populates the + # in/out spec + new_preserved_call_signatures = [ + "_export_root." + i for i in preserve_module_call_signature + ] + with _wrap_submodules( + wrapped_mod, new_preserved_call_signatures, module_call_specs + ): + gm, sig = aot_export(wrapped_mod, args, **kwargs) sig.parameters = pytree.tree_map(strip_root, sig.parameters) sig.buffers = pytree.tree_map(strip_root, sig.buffers) @@ -585,9 +597,39 @@ def fixup_key(x): fake_mode, src_equalities, original_signature, ep_non_strict.gm ) assert out_spec is not None + + gm = ep_non_strict.gm + + module_call_signatures = { + strip_root(fqn): ModuleCallSignature(inputs=[], outputs=[], **specs) + for fqn, specs in module_call_specs.items() + } + + if len(preserve_module_call_signature) > 0: + for node in gm.graph.nodes: + if node.target == torch.ops.higher_order._export_tracepoint: + if "path" in node.kwargs: + path = strip_root(node.kwargs["path"]) + with gm.graph.inserting_before(node): + new_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order._export_tracepoint, + args=node.args, + kwargs={ + "path": path, + "kind": node.kwargs["kind"], + }, + ) + node.replace_all_uses_with(new_node) + gm.graph.erase_node(node) + + res = CollectTracepointsPass(module_call_signatures, ep_non_strict.sig)(gm) + assert res is not None + gm = res.graph_module + return ExportedProgram( - root=ep_non_strict.gm, - graph=ep_non_strict.gm.graph, + root=gm, + graph=gm.graph, graph_signature=ep_non_strict.sig, state_dict=_get_params_buffers(f), range_constraints=range_constraints, @@ -595,9 +637,12 @@ def fixup_key(x): ModuleCallEntry( "", ModuleCallSignature( - [], [], pytree.tree_flatten((args, {}))[1], out_spec + inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=out_spec ), ) + ] + + [ + ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items() ], example_inputs=(args, kwargs), constants=ep_non_strict.constants, @@ -768,7 +813,6 @@ def _aot_export_strict(gm_torch_level: torch.fx.GraphModule, args, **kwargs): ), len(export_graph_signature.input_specs), ) - flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs)) range_constraints = _process_constraints( gm, num_lifted,