From b3c225629e0f283c85b18975933b7860ba4b88c3 Mon Sep 17 00:00:00 2001 From: Horace He Date: Thu, 17 Mar 2022 01:54:01 +0000 Subject: [PATCH 1/7] added proxy tensor to core --- torch/fx/_proxy_tensor.py | 201 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 torch/fx/_proxy_tensor.py diff --git a/torch/fx/_proxy_tensor.py b/torch/fx/_proxy_tensor.py new file mode 100644 index 000000000000..d56d1059cdda --- /dev/null +++ b/torch/fx/_proxy_tensor.py @@ -0,0 +1,201 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import functools +from typing import Any, Dict, Optional, Tuple, Callable, Union +import torch +from torch._C import _disabled_torch_function_impl +import torch.utils._pytree as pytree +from torch.fx import Tracer, GraphModule +import torch.fx as fx +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from contextlib import contextmanager + +aten = torch.ops.aten + +CURRENT_DECOMPOSITION_TABLE = {} + + +@contextmanager +def no_dispatch(): + guard = torch._C._DisableTorchDispatch() + try: + yield + finally: + del guard + + +@contextmanager +def pythonkey_decompose(decomposition_table): + global CURRENT_DECOMPOSITION_TABLE + CURRENT_DECOMPOSITION_TABLE = decomposition_table + try: + yield CURRENT_DECOMPOSITION_TABLE + finally: + CURRENT_DECOMPOSITION_TABLE = {} + + +class ProxyTensor(torch.Tensor): + elem: torch.Tensor + + __slots__ = ['elem', 'proxy'] + + @staticmethod + def __new__(cls, elem, proxy): + # Wrapping something in ProxyTensor implicitly detaches + # gradients. If something required grad, we will collect it as if it + # were a leaf. A consequence of detaching in this way is you + # need to maintain a parameter cache when translating tensors + # into ProxyTensor, so you don't create multiple copies of + # a gradient (they are aliased, but they would count as independent + # leaves). An alternate strategy would be to avoid implicitly + # detaching and instead "catch" gradients as they exit the + # ProxyTensor boundary. + # assert not elem.requires_grad or not torch.is_grad_enabled() + + r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad) + r.proxy = proxy + proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r) + return r + + def __repr__(self): + # This is a bit goofy but whatever. Should fix up _tensor_str.py to + # work on subclasses when it calls tolist + return f"ProxyTensor({torch.Tensor._make_subclass(torch.Tensor, self)})" + + __torch_function__ = _disabled_torch_function_impl + + @classmethod + def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None): + func = func_overload.overloadpacket + if func_overload in CURRENT_DECOMPOSITION_TABLE: + return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs) + # Commenting this out for now since it causes some spurious failures (such as error checking) + # if func == aten._local_scalar_dense: + # raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! 
" + # "It's likely that this is caused by data-dependent control flow or similar.") + + def unwrap_proxy(e): + return e.proxy if isinstance(e, ProxyTensor) else e + + proxy_args = pytree.tree_map(unwrap_proxy, args) + proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs) + + proxy_out = func(*proxy_args, **proxy_kwargs) + + # Kind of a hacky way to test if an op is in-place or not + if func.__name__[-1] == "_" and func.__name__[0] != "_": + args[0].proxy = proxy_out + proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0]) + + with no_dispatch(): + real_out = func_overload(*args, **kwargs) + + def wrap_with_proxy(e, proxy): + # Some ops (like native_batch_norm_backward) return undefined tensors that get + # converted into None in python. + # As the function signature expects tensors, if we directly return these None + # tensors back to C++, we'll error. + if e is None: + e = torch.empty(()) + if type(e) == torch.Tensor: + return ProxyTensor(e, proxy) + else: + return e + if isinstance(real_out, tuple): + return tuple([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]) + elif isinstance(real_out, list): + return list([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]) + elif isinstance(real_out, torch.Tensor): + return wrap_with_proxy(real_out, proxy_out) + else: + return real_out + + +class PythonKeyTracer(Tracer): + def __init__(self): + super().__init__() + + def call_module( + self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Any: + return forward(*args, **kwargs) + + def _module_getattr(self, attr, attr_val, parameter_proxy_cache): + if isinstance(attr_val, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if attr_val is p: + if n not in parameter_proxy_cache: + proxy = self.create_proxy('get_attr', n, (), {}) + parameter_proxy_cache[n] = ProxyTensor(attr_val, proxy) + return parameter_proxy_cache[n] + return attr_val + return attr_val + + # We need to do this so that parameters entering the `make_fx` context have + # a reference to them (and also have requires_grad set on them correctly + # I'm not actually sure if this is the right thing to do ... 
+ def create_arg(self, a: Any): + if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node('get_attr', n, (), {}) + qualname: Optional[str] = None + + if not qualname: + i = 0 + while True: + qualname = f'_param_constant{i}' + if not hasattr(self.root, qualname): + break + i += 1 + setattr(self.root, qualname, a) + + return self.create_node('get_attr', qualname, (), {}) + return super().create_arg(a) + + +def pythonkey_trace( + root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None +) -> GraphModule: + tracer = PythonKeyTracer() + graph = tracer.trace(root, concrete_args) + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + return GraphModule(tracer.root, graph, name) + + +def wrap_key(f, inps): + flat_inps, _ = pytree.tree_flatten(inps) + + @functools.wraps(f) + def wrapped(*args): + flat_args, args_spec = pytree.tree_flatten(args) + assert(len(flat_args) == len(flat_inps)) + for idx, arg in enumerate(flat_args): + if isinstance(flat_inps[idx], torch.Tensor): + flat_args[idx] = ProxyTensor(flat_inps[idx], arg) + else: + flat_args[idx] = flat_inps[idx] + + tree_args = pytree.tree_unflatten(flat_args, args_spec) + out = f(*tree_args) + flat_outs, out_spec = pytree.tree_flatten(out) + for idx in range(len(flat_outs)): + if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], ProxyTensor): + flat_outs[idx] = flat_outs[idx].proxy + return pytree.tree_unflatten(flat_outs, out_spec) + + return wrapped + + +def make_fx(f, decomposition_table={}): + @functools.wraps(f) + def wrapped(*args): + phs = pytree.tree_map(lambda x: fx.PH, args) + with pythonkey_decompose(decomposition_table): + t = pythonkey_trace(wrap_key(f, args), concrete_args=tuple(phs)) + return t + + return wrapped From 00ec0635c1c88485cfa362d4ea6dd336b4d51b19 Mon Sep 17 00:00:00 2001 From: Horace He Date: Sat, 26 Mar 2022 02:26:52 +0000 Subject: [PATCH 2/7] resolved some monir comments --- torch/fx/_proxy_tensor.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torch/fx/_proxy_tensor.py b/torch/fx/_proxy_tensor.py index d56d1059cdda..3d51af7f85b3 100644 --- a/torch/fx/_proxy_tensor.py +++ b/torch/fx/_proxy_tensor.py @@ -28,13 +28,14 @@ def no_dispatch(): @contextmanager -def pythonkey_decompose(decomposition_table): +def decompose(decomposition_table): global CURRENT_DECOMPOSITION_TABLE + old_decomposition_table = CURRENT_DECOMPOSITION_TABLE CURRENT_DECOMPOSITION_TABLE = decomposition_table try: yield CURRENT_DECOMPOSITION_TABLE finally: - CURRENT_DECOMPOSITION_TABLE = {} + CURRENT_DECOMPOSITION_TABLE = old_decomposition_table class ProxyTensor(torch.Tensor): @@ -73,9 +74,9 @@ def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None): if func_overload in CURRENT_DECOMPOSITION_TABLE: return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs) # Commenting this out for now since it causes some spurious failures (such as error checking) - # if func == aten._local_scalar_dense: - # raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! " - # "It's likely that this is caused by data-dependent control flow or similar.") + if func == aten._local_scalar_dense: + raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! 
" + "It's likely that this is caused by data-dependent control flow or similar.") def unwrap_proxy(e): return e.proxy if isinstance(e, ProxyTensor) else e @@ -105,7 +106,7 @@ def wrap_with_proxy(e, proxy): else: return e if isinstance(real_out, tuple): - return tuple([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]) + return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)) elif isinstance(real_out, list): return list([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]) elif isinstance(real_out, torch.Tensor): @@ -157,7 +158,7 @@ def create_arg(self, a: Any): return super().create_arg(a) -def pythonkey_trace( +def dispatch_trace( root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None ) -> GraphModule: tracer = PythonKeyTracer() @@ -194,8 +195,8 @@ def make_fx(f, decomposition_table={}): @functools.wraps(f) def wrapped(*args): phs = pytree.tree_map(lambda x: fx.PH, args) - with pythonkey_decompose(decomposition_table): - t = pythonkey_trace(wrap_key(f, args), concrete_args=tuple(phs)) + with decompose(decomposition_table): + t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs)) return t return wrapped From a8ef585c5b3e6858c7c542c837b0e9d0fe7167dc Mon Sep 17 00:00:00 2001 From: Horace He Date: Sat, 23 Apr 2022 01:09:15 +0000 Subject: [PATCH 3/7] Add a test --- test/test_fx_experimental.py | 9 +++++++++ .../proxy_tensor.py} | 15 ++++++++++----- 2 files changed, 19 insertions(+), 5 deletions(-) rename torch/fx/{_proxy_tensor.py => experimental/proxy_tensor.py} (94%) diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 81e6faf81734..18c12a99f6e1 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -27,6 +27,7 @@ from torch.fx.experimental.rewriter import RewritingTracer from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema from torch.fx.experimental.meta_tracer import MetaTracer +from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.graph_module import GraphModule from torch.fx.node import Node from torch.fx.operator_schemas import ( @@ -689,6 +690,14 @@ def forward(self, x): gm = torch.fx.GraphModule(mttm, graph) torch.testing.assert_close(gm(x), mttm(x)) + def test_proxy_tensor(self): + def f(x): + val = x.cos().cos().sum() + return torch.autograd.grad(val, x) + + traced_graph = make_fx(f)(torch.randn(3, requires_grad=True)) + inp = torch.randn(3, requires_grad=True) + torch.testing.assert_close(traced_graph(inp), f(inp)) def test_call_to_assert_with_msg(self): class M(torch.nn.Module): diff --git a/torch/fx/_proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py similarity index 94% rename from torch/fx/_proxy_tensor.py rename to torch/fx/experimental/proxy_tensor.py index 3d51af7f85b3..8094e488ae66 100644 --- a/torch/fx/_proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -56,15 +56,20 @@ def __new__(cls, elem, proxy): # ProxyTensor boundary. 
# assert not elem.requires_grad or not torch.is_grad_enabled() - r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad) + # Hack to deal with super().__new__ not working for sparse tensors + if elem.is_sparse: + proxy.node.meta['tensor_meta'] = {} + r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad) + else: + r = super().__new__(cls, elem) + proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r) r.proxy = proxy - proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r) + return r def __repr__(self): - # This is a bit goofy but whatever. Should fix up _tensor_str.py to - # work on subclasses when it calls tolist - return f"ProxyTensor({torch.Tensor._make_subclass(torch.Tensor, self)})" + with no_dispatch(): + return f"ProxyTensor({self.as_subclass(torch.Tensor)})" __torch_function__ = _disabled_torch_function_impl From b672087665eec2f9ec80a35ebfd2f82c00f0accc Mon Sep 17 00:00:00 2001 From: Horace He Date: Wed, 27 Apr 2022 08:19:30 +0000 Subject: [PATCH 4/7] responded to comments --- torch/fx/experimental/proxy_tensor.py | 44 ++++++--------------------- 1 file changed, 10 insertions(+), 34 deletions(-) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 8094e488ae66..78140538a0c3 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -41,21 +41,10 @@ def decompose(decomposition_table): class ProxyTensor(torch.Tensor): elem: torch.Tensor - __slots__ = ['elem', 'proxy'] + __slots__ = ['proxy'] @staticmethod def __new__(cls, elem, proxy): - # Wrapping something in ProxyTensor implicitly detaches - # gradients. If something required grad, we will collect it as if it - # were a leaf. A consequence of detaching in this way is you - # need to maintain a parameter cache when translating tensors - # into ProxyTensor, so you don't create multiple copies of - # a gradient (they are aliased, but they would count as independent - # leaves). An alternate strategy would be to avoid implicitly - # detaching and instead "catch" gradients as they exit the - # ProxyTensor boundary. - # assert not elem.requires_grad or not torch.is_grad_enabled() - # Hack to deal with super().__new__ not working for sparse tensors if elem.is_sparse: proxy.node.meta['tensor_meta'] = {} @@ -69,7 +58,7 @@ def __new__(cls, elem, proxy): def __repr__(self): with no_dispatch(): - return f"ProxyTensor({self.as_subclass(torch.Tensor)})" + return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})" __torch_function__ = _disabled_torch_function_impl @@ -79,7 +68,7 @@ def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None): if func_overload in CURRENT_DECOMPOSITION_TABLE: return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs) # Commenting this out for now since it causes some spurious failures (such as error checking) - if func == aten._local_scalar_dense: + if func_overload == aten._local_scalar_dense.default: raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! " "It's likely that this is caused by data-dependent control flow or similar.") @@ -100,16 +89,14 @@ def unwrap_proxy(e): real_out = func_overload(*args, **kwargs) def wrap_with_proxy(e, proxy): - # Some ops (like native_batch_norm_backward) return undefined tensors that get - # converted into None in python. - # As the function signature expects tensors, if we directly return these None - # tensors back to C++, we'll error. 
- if e is None: - e = torch.empty(()) if type(e) == torch.Tensor: return ProxyTensor(e, proxy) else: return e + + # Unfortunately, tree_map cannot directly be used here. As the resulting + # object may be a proxy that represents a tuple, we may need to + # explicitly unwrap the proxy by simulating the flattening operations. if isinstance(real_out, tuple): return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)) elif isinstance(real_out, list): @@ -124,25 +111,14 @@ class PythonKeyTracer(Tracer): def __init__(self): super().__init__() + # In general, we don't want to make modules leaves. In priniple, users of + # this tracer might want to override this in order to turn a couple specific + # modules into leaves in the tracd graph. def call_module( self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Any: return forward(*args, **kwargs) - def _module_getattr(self, attr, attr_val, parameter_proxy_cache): - if isinstance(attr_val, torch.nn.Parameter): - for n, p in self.root.named_parameters(): - if attr_val is p: - if n not in parameter_proxy_cache: - proxy = self.create_proxy('get_attr', n, (), {}) - parameter_proxy_cache[n] = ProxyTensor(attr_val, proxy) - return parameter_proxy_cache[n] - return attr_val - return attr_val - - # We need to do this so that parameters entering the `make_fx` context have - # a reference to them (and also have requires_grad set on them correctly - # I'm not actually sure if this is the right thing to do ... def create_arg(self, a: Any): if isinstance(a, torch.nn.Parameter): for n, p in self.root.named_parameters(): From 51affbf433254a100a2bc0f83c618ebd45ebe480 Mon Sep 17 00:00:00 2001 From: Horace He Date: Tue, 3 May 2022 07:32:41 +0000 Subject: [PATCH 5/7] fix minor comments --- torch/fx/experimental/proxy_tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 78140538a0c3..894c4c97660e 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -13,6 +13,7 @@ from torch.fx.passes.shape_prop import _extract_tensor_metadata from contextlib import contextmanager +__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx"] aten = torch.ops.aten CURRENT_DECOMPOSITION_TABLE = {} @@ -67,7 +68,6 @@ def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None): func = func_overload.overloadpacket if func_overload in CURRENT_DECOMPOSITION_TABLE: return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs) - # Commenting this out for now since it causes some spurious failures (such as error checking) if func_overload == aten._local_scalar_dense.default: raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! " "It's likely that this is caused by data-dependent control flow or similar.") @@ -111,9 +111,9 @@ class PythonKeyTracer(Tracer): def __init__(self): super().__init__() - # In general, we don't want to make modules leaves. In priniple, users of + # In general, we don't want to make modules leaves. In principle, users of # this tracer might want to override this in order to turn a couple specific - # modules into leaves in the tracd graph. + # modules into leaves in the traced graph. 
def call_module( self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Any: From 7fa4527d4fd35bd632a2e88f8ad1cc190ac748b6 Mon Sep 17 00:00:00 2001 From: Horace He Date: Tue, 3 May 2022 09:02:54 +0000 Subject: [PATCH 6/7] fix some other stuff --- torch/fx/experimental/proxy_tensor.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 894c4c97660e..f4687eb0aeea 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -16,12 +16,12 @@ __all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx"] aten = torch.ops.aten -CURRENT_DECOMPOSITION_TABLE = {} +CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {} @contextmanager def no_dispatch(): - guard = torch._C._DisableTorchDispatch() + guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] try: yield finally: @@ -40,9 +40,7 @@ def decompose(decomposition_table): class ProxyTensor(torch.Tensor): - elem: torch.Tensor - - __slots__ = ['proxy'] + proxy: fx.Proxy @staticmethod def __new__(cls, elem, proxy): @@ -51,15 +49,15 @@ def __new__(cls, elem, proxy): proxy.node.meta['tensor_meta'] = {} r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad) else: - r = super().__new__(cls, elem) + r = super().__new__(cls, elem) # type: ignore[call-arg] proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r) - r.proxy = proxy + r.proxy = proxy # type: ignore[attr-defined] return r def __repr__(self): with no_dispatch(): - return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})" + return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})" # type: ignore[arg-type] __torch_function__ = _disabled_torch_function_impl @@ -140,7 +138,7 @@ def create_arg(self, a: Any): def dispatch_trace( - root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None + root: Union[torch.nn.Module, Callable], concrete_args: Optional[Tuple[Any, ...]] = None ) -> GraphModule: tracer = PythonKeyTracer() graph = tracer.trace(root, concrete_args) @@ -172,10 +170,12 @@ def wrapped(*args): return wrapped -def make_fx(f, decomposition_table={}): +def make_fx(f, decomposition_table=None): + if decomposition_table is None: + decomposition_table = {} @functools.wraps(f) def wrapped(*args): - phs = pytree.tree_map(lambda x: fx.PH, args) + phs = pytree.tree_map(lambda x: fx.PH, args) #type: ignore[attr-defined] with decompose(decomposition_table): t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs)) return t From c09f17c3c340d5963241513b41fd837186da3347 Mon Sep 17 00:00:00 2001 From: Horace He Date: Tue, 3 May 2022 17:36:42 +0000 Subject: [PATCH 7/7] fix flake issues --- torch/fx/experimental/proxy_tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index f4687eb0aeea..ea6eceec8b1b 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -173,9 +173,10 @@ def wrapped(*args): def make_fx(f, decomposition_table=None): if decomposition_table is None: decomposition_table = {} + @functools.wraps(f) def wrapped(*args): - phs = pytree.tree_map(lambda x: fx.PH, args) #type: ignore[attr-defined] + phs = pytree.tree_map(lambda x: fx.PH, args) # type: ignore[attr-defined] with decompose(decomposition_table): t = 
dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs))
         return t
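
Usage sketch (illustrative only, mirroring the test_proxy_tensor test added in PATCH 3/7): make_fx(f)(*example_inputs) runs f under __torch_dispatch__ tracing and returns a GraphModule whose graph records the ATen ops that executed; the optional decomposition_table maps OpOverloads to Python decompositions applied while tracing. The sigmoid decomposition below is an assumed example, not something defined by these patches.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    aten = torch.ops.aten

    def f(x):
        val = x.cos().cos().sum()
        return torch.autograd.grad(val, x)

    # Each ATen op hit while running f on the example input becomes a node
    # in the returned GraphModule's graph.
    traced = make_fx(f)(torch.randn(3, requires_grad=True))
    traced.graph.print_tabular()

    inp = torch.randn(3, requires_grad=True)
    torch.testing.assert_close(traced(inp), f(inp))

    # Ops can be rewritten during tracing via a decomposition table keyed on
    # OpOverloads (see CURRENT_DECOMPOSITION_TABLE above). Assumed example:
    def sigmoid_decomp(x):
        return 1 / (1 + torch.exp(-x))

    def g(x):
        return torch.sigmoid(x)

    traced_decomposed = make_fx(
        g, decomposition_table={aten.sigmoid.default: sigmoid_decomp}
    )(torch.randn(3))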