diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 984c5fc1d853e..3b3d1f022d6e7 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -3109,6 +3109,25 @@ def fn(x, y): y = torch.randn(3, 4) self.assertEqual(opt_fn(x, y), fn(x, y)) + def test_methodcaller(self): + for name, args, kwargs in ( + ("size", (), {}), + ("size", (0,), {}), + ("add", (torch.randn(3, 4),), {}), + ("add", (torch.randn(3, 4),), {"alpha": 2.0}), + ): + with self.subTest(name=name, args=args, kwargs=kwargs): + + def fn(x, y): + caller = operator.methodcaller(name, *args, **kwargs) + return caller(x), caller(y) + + opt_fn = torch.compile(fullgraph=True)(fn) + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + self.assertEqual(opt_fn(x, y), fn(x, y)) + def gen_random_range_args(self): args_count = random.randint(1, 3) args = [random.randint(-10, 10) for _ in range(args_count)] diff --git a/torch/_dynamo/polyfills/operator.py b/torch/_dynamo/polyfills/operator.py index bf84895bdd013..297e837d45a0e 100644 --- a/torch/_dynamo/polyfills/operator.py +++ b/torch/_dynamo/polyfills/operator.py @@ -12,7 +12,7 @@ # Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`) -__all__ = ["attrgetter", "itemgetter"] +__all__ = ["attrgetter", "itemgetter", "methodcaller"] _T = TypeVar("_T") @@ -95,3 +95,15 @@ def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] return tuple(obj[item] for item in items) return getter + + +# Reference: https://docs.python.org/3/library/operator.html#operator.methodcaller +@substitute_in_graph(operator.methodcaller, is_embedded_type=True) # type: ignore[arg-type] +def methodcaller(name: str, /, *args: Any, **kwargs: Any) -> Callable[[Any], Any]: + if not isinstance(name, str): + raise TypeError("method name must be a string") + + def caller(obj: Any) -> Any: + return getattr(obj, name)(*args, **kwargs) + + return caller