From 609f76f27a11b12fbf04ed51538fb9d60e135c00 Mon Sep 17 00:00:00 2001
From: James Reed
Date: Mon, 1 Feb 2021 11:34:54 -0800
Subject: [PATCH] [WIP][FX] Add Interpreter and Transformer (#50420)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50420

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D25880330

Pulled By: jamesr66a

fbshipit-source-id: 27d34888e36e39924821fed891d79f969237a104
---
 docs/source/fx.rst                  |   6 +
 test/test_fx.py                     | 143 +++++++++-
 torch/fx/__init__.py                |   1 +
 torch/fx/__init__.pyi               |   1 +
 torch/fx/experimental/shape_prop.py |  54 +---
 torch/fx/interpreter.py             | 388 ++++++++++++++++++++++++++++
 6 files changed, 548 insertions(+), 45 deletions(-)
 create mode 100644 torch/fx/interpreter.py

diff --git a/docs/source/fx.rst b/docs/source/fx.rst
index 4830b76954ad..b2ec1d7e254a 100644
--- a/docs/source/fx.rst
+++ b/docs/source/fx.rst
@@ -322,3 +322,9 @@ API Reference
   :members:
 
 .. autoclass:: torch.fx.Proxy
+
+.. autoclass:: torch.fx.Interpreter
+  :members:
+
+.. autoclass:: torch.fx.Transformer
+  :members:
diff --git a/test/test_fx.py b/test/test_fx.py
index f380259f5371..4e6dae1c288e 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -13,7 +13,8 @@
 from math import sqrt
 from pathlib import Path
 from torch.multiprocessing import Process
-from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph, wrap
+from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap
+from torch.fx.node import Target
 from torch.fx.experimental import shape_prop
 from torch.fx.immutable_collections import immutable_dict, immutable_list
 from copy import deepcopy
@@ -957,6 +958,146 @@ def forward(self, x):
         # Test shape propagation and make sure results match actual
         self.assertEqual(output_shape, ref_out.shape)
 
+    def test_interpreter(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 4))
+                self.linear = torch.nn.Linear(4, 5)
+
+            def forward(self, x):
+                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
+
+        m = MyModule()
+        gm = torch.fx.symbolic_trace(m)
+
+        interpreter = Interpreter(gm)
+        input = torch.randn(3, 4)
+        self.assertEqual(interpreter.run(input), gm(input))
+        self.assertEqual(interpreter.run(input), m(input))
+
+    def test_interpreter_run_node_override(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 4))
+                self.linear = torch.nn.Linear(4, 5)
+
+            def forward(self, x):
+                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
+
+        m = MyModule()
+        gm = torch.fx.symbolic_trace(m)
+
+        class RunNodeInterpreter(Interpreter):
+            def __init__(self, module):
+                super().__init__(module)
+
+            def run_node(self, n : Node) -> Any:
+                result = super().run_node(n)
+                n.cached_value = result
+                return result
+
+        input = torch.randn(3, 4)
+        RunNodeInterpreter(gm).run(input)
+        for node in gm.graph.nodes:
+            assert hasattr(node, 'cached_value')
+
+    def test_interpreter_onthefly_swap(self):
+
+        def fn(x):
+            return torch.sigmoid(x).neg()
+
+        gm = torch.fx.symbolic_trace(fn)
+
+        class NegSigmSwapInterpreter(Interpreter):
+            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
+                if target == torch.sigmoid:
+                    return torch.neg(*args, **kwargs)
+                return super().call_function(target, args, kwargs)
+
+            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
+                if target == 'neg':
+                    call_self, *args_tail = args
+                    return call_self.sigmoid(*args_tail, **kwargs)
+                return super().call_method(target, args, kwargs)
+
+        input = torch.randn(3, 4)
+        result = NegSigmSwapInterpreter(gm).run(input)
+        self.assertEqual(result, torch.neg(input).sigmoid())
+
+    def test_interpreter_partial_eval(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 4))
+                self.linear = torch.nn.Linear(4, 5)
+
+            def forward(self, x):
+                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
+
+        gm = torch.fx.symbolic_trace(MyModule())
+        interp = Interpreter(gm)
+        env = {}
+        for node in gm.graph.nodes:
+            if node.op == 'call_module' and node.target == 'linear':
+                env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0
+                break
+        assert len(env) == 1
+        x = torch.randn(3, 4)
+        result = interp.run(x, initial_env=env)
+        self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0))
+
+    def test_interpreter_star_args(self):
+        def with_star_args(x, *args):
+            return x + args[0]
+
+        gm = torch.fx.symbolic_trace(with_star_args)
+        interp = Interpreter(gm)
+        result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
+        self.assertEqual(result, torch.ones(3, 4) * 2.0)
+
+    def test_transformer_noop(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 4))
+                self.linear = torch.nn.Linear(4, 5)
+
+            def forward(self, x):
+                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
+
+        m = MyModule()
+        gm = torch.fx.symbolic_trace(m)
+
+        new_gm = Transformer(gm).transform()
+
+        input = torch.randn(3, 4)
+        self.assertEqual(new_gm(input), gm(input))
+
+    def test_transformer_op_swap(self):
+
+        def fn(x):
+            return torch.sigmoid(x).neg()
+
+        gm = torch.fx.symbolic_trace(fn)
+
+        class NegSigmSwapXformer(Transformer):
+            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
+                if target == torch.sigmoid:
+                    return torch.neg(*args, **kwargs)
+                return super().call_function(target, args, kwargs)
+
+            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
+                if target == 'neg':
+                    call_self, *args_tail = args
+                    return call_self.sigmoid(*args_tail, **kwargs)
+                return super().call_method(target, args, kwargs)
+
+        transformed = NegSigmSwapXformer(gm).transform()
+        input = torch.randn(3, 4)
+        self.assertEqual(transformed(input), torch.neg(input).sigmoid())
+
     def test_fn_type_annotations(self):
         class Foo(torch.nn.Module):
             def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py
index df87e6514fdb..e5fe2e2eefee 100644
--- a/torch/fx/__init__.py
+++ b/torch/fx/__init__.py
@@ -82,3 +82,4 @@ def forward(self, x):
 from .graph import Graph
 from .node import Node, map_arg
 from .proxy import Proxy
+from .interpreter import Interpreter as Interpreter, Transformer as Transformer
diff --git a/torch/fx/__init__.pyi b/torch/fx/__init__.pyi
index 9939f25f4c14..3143095c9379 100644
--- a/torch/fx/__init__.pyi
+++ b/torch/fx/__init__.pyi
@@ -3,3 +3,4 @@ from .graph_module import GraphModule as GraphModule
 from .node import Node as Node, map_arg as map_arg
 from .proxy import Proxy as Proxy
 from .symbolic_trace import Tracer as Tracer, symbolic_trace as symbolic_trace, wrap as wrap
+from .interpreter import Interpreter as Interpreter, Transformer as Transformer
diff --git a/torch/fx/experimental/shape_prop.py b/torch/fx/experimental/shape_prop.py
index 52264796c7d4..a583f776ee7c 100644
--- a/torch/fx/experimental/shape_prop.py
+++ b/torch/fx/experimental/shape_prop.py
@@ -1,51 +1,17 @@
 import torch
 import torch.fx
 from torch.fx.node import Node
+from typing import Any
 
-from typing import Dict
+class ShapeProp(torch.fx.Interpreter):
+    def run_node(self, n : Node) -> Any:
+        result = super().run_node(n)
 
-class ShapeProp:
-    def __init__(self, mod):
-        self.mod = mod
-        self.graph = mod.graph
-        self.modules = dict(self.mod.named_modules())
+        if isinstance(result, torch.Tensor):
+            n.shape = result.shape  # type: ignore
+            n.dtype = result.dtype  # type: ignore
 
-    def propagate(self, *args):
-        args_iter = iter(args)
-        env : Dict[str, Node] = {}
-
-        def load_arg(a):
-            return torch.fx.node.map_arg(a, lambda n: env[n.name])
-
-        def fetch_attr(target : str):
-            target_atoms = target.split('.')
-            attr_itr = self.mod
-            for i, atom in enumerate(target_atoms):
-                if not hasattr(attr_itr, atom):
-                    raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
-                attr_itr = getattr(attr_itr, atom)
-            return attr_itr
-
-        for node in self.graph.nodes:
-            if node.op == 'placeholder':
-                result = next(args_iter)
-            elif node.op == 'get_attr':
-                result = fetch_attr(node.target)
-            elif node.op == 'call_function':
-                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
-            elif node.op == 'call_method':
-                self_obj, *args = load_arg(node.args)
-                kwargs = load_arg(node.kwargs)
-                result = getattr(self_obj, node.target)(*args, **kwargs)
-            elif node.op == 'call_module':
-                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
-            elif node.op == 'output':
-                return load_arg(node.args[0])
+        return result
 
-            if isinstance(result, torch.Tensor):
-                node.shape = result.shape
-                node.dtype = result.dtype
-
-            env[node.name] = result
-
-        return None
+    def propagate(self, *args):
+        return super().run(*args)
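
As a quick illustration of the rewritten ``ShapeProp`` (a minimal usage sketch, not itself part of this diff; ``M`` is a made-up example module)::

    import torch
    import torch.fx
    from torch.fx.experimental.shape_prop import ShapeProp

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x).relu()

    gm = torch.fx.symbolic_trace(M())
    ShapeProp(gm).propagate(torch.randn(3, 4))
    for node in gm.graph.nodes:
        # run_node() above tags every Tensor-producing node with shape and dtype
        print(node.name, getattr(node, 'shape', None), getattr(node, 'dtype', None))
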
diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py
new file mode 100644
index 000000000000..06c893560766
--- /dev/null
+++ b/torch/fx/interpreter.py
@@ -0,0 +1,388 @@
+from .graph_module import GraphModule
+from .graph import Graph
+from .node import Argument, Node, Target, map_arg
+from .proxy import Proxy
+from .symbolic_trace import Tracer
+from typing import Any, Dict, Iterator, Optional, Tuple
+
+class Interpreter:
+    """
+    An Interpreter executes an FX graph Node-by-Node. This pattern
+    can be useful for many things, including writing code
+    transformations as well as analysis passes.
+
+    Methods in the Interpreter class can be overridden to customize
+    the behavior of execution. The map of overrideable methods
+    in terms of call hierarchy::
+
+        run()
+            +-- run_node
+                +-- placeholder()
+                +-- get_attr()
+                +-- call_function()
+                +-- call_method()
+                +-- call_module()
+                +-- output()
+
+    Example:
+
+        Suppose we want to swap all instances of ``torch.neg`` with
+        ``torch.sigmoid`` and vice versa (including their ``Tensor``
+        method equivalents). We could subclass Interpreter like so::
+
+            class NegSigmSwapInterpreter(Interpreter):
+                def call_function(self, target : Target,
+                                  args : Tuple, kwargs : Dict) -> Any:
+                    if target == torch.sigmoid:
+                        return torch.neg(*args, **kwargs)
+                    return super().call_function(target, args, kwargs)
+
+                def call_method(self, target : Target,
+                                args : Tuple, kwargs : Dict) -> Any:
+                    if target == 'neg':
+                        call_self, *args_tail = args
+                        return call_self.sigmoid(*args_tail, **kwargs)
+                    return super().call_method(target, args, kwargs)
+
+            def fn(x):
+                return torch.sigmoid(x).neg()
+
+            gm = torch.fx.symbolic_trace(fn)
+            input = torch.randn(3, 4)
+            result = NegSigmSwapInterpreter(gm).run(input)
+            torch.testing.assert_allclose(result, torch.neg(input).sigmoid())
+
+    Args:
+        module (GraphModule): The module to be executed
+    """
+    def __init__(self, module : GraphModule):
+        assert isinstance(module, GraphModule)
+        self.module = module
+        self.submodules = dict(self.module.named_modules())
+        self.env : Dict[Node, Any] = {}
+
+    def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None) -> Any:
+        """
+        Run `module` via interpretation and return the result.
+
+        Args:
+            *args: The arguments to the Module to run, in positional order
+            initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
+                This is a dict mapping `Node` to any value. This can be used, for example, to
+                pre-populate results for certain `Nodes` so as to do only partial evaluation within
+                the interpreter.
+
+        Returns:
+            Any: The value returned from executing the Module
+        """
+        self.env = initial_env if initial_env else {}
+
+        # Positional function args are consumed left-to-right by
+        # `placeholder` nodes. Use an iterator to keep track of
+        # position and extract those values.
+        self.args_iter : Iterator[Any] = iter(args)
+
+        for node in self.module.graph.nodes:
+            if node in self.env:
+                # Short circuit if we have this value. This could
+                # be used, for example, for partial evaluation
+                # where the caller has pre-populated `env` with
+                # values for a subset of the program.
+                continue
+
+            self.env[node] = self.run_node(node)
+
+            if node.op == 'output':
+                output_val = self.env[node]
+                return output_val
+
+    def run_node(self, n : Node) -> Any:
+        """
+        Run a specific node ``n`` and return the result.
+        Calls into placeholder, get_attr, call_function,
+        call_method, call_module, or output depending
+        on ``node.op``
+
+        Args:
+            n (Node): The Node to execute
+
+        Returns:
+            Any: The result of executing ``n``
+        """
+        args, kwargs = self.fetch_args_kwargs_from_env(n)
+        assert isinstance(args, tuple)
+        assert isinstance(kwargs, dict)
+        return getattr(self, n.op)(n.target, args, kwargs)
+
+    # Main Node running APIs
+
+    def placeholder(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+        """
+        Execute a ``placeholder`` node. Note that this is stateful:
+        ``Interpreter`` maintains an internal iterator over
+        arguments passed to ``run`` and this method returns the
+        result of calling ``next()`` on that iterator.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            Any: The argument value that was retrieved.
+        """
+        assert isinstance(target, str)
+        if target.startswith('*'):
+            # For a starred parameter, e.g. `*args`, retrieve all
+            # remaining values from the args list.
+            return list(self.args_iter)
+        else:
+            return next(self.args_iter)
+
+    def get_attr(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+        """
+        Execute a ``get_attr`` node. Will retrieve an attribute
+        value from the ``Module`` hierarchy of ``self.module``.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            Any: The value of the attribute that was retrieved
+        """
+        assert isinstance(target, str)
+        return self.fetch_attr(target)
+
+    def call_function(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+        """
+        Execute a ``call_function`` node and return the result.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            Any: The value returned by the function invocation
+        """
+        assert not isinstance(target, str)
+
+        # Execute the function and return the result
+        return target(*args, **kwargs)
+
+    def call_method(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+        """
+        Execute a ``call_method`` node and return the result.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            Any: The value returned by the method invocation
+        """
+        # args[0] is the `self` object for this method call
+        self_obj, *args_tail = args  # type: ignore
+
+        # Execute the method and return the result
+        assert isinstance(target, str)
+        return getattr(self_obj, target)(*args_tail, **kwargs)
+
+    def call_module(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+        """
+        Execute a ``call_module`` node and return the result.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            Any: The value returned by the module invocation
+        """
+        # args and kwargs have already been resolved to concrete values by run_node
+
+        # Execute the module and return the result
+        assert isinstance(target, str)
+        submod = self.fetch_attr(target)
+
+        return submod(*args, **kwargs)
+
+    def output(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+        """
+        Execute an ``output`` node. This really just retrieves
+        the value referenced by the ``output`` node and returns it.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            Any: The return value referenced by the output node
+        """
+        return args[0]
+
+    # Helper methods
+
+    def fetch_attr(self, target : str):
+        """
+        Fetch an attribute from the ``Module`` hierarchy of ``self.module``.
+
+        Args:
+            target (str): The fully-qualified name of the attribute to fetch
+
+        Returns:
+            Any: The value of the attribute.
+        """
+        target_atoms = target.split('.')
+        attr_itr = self.module
+        for i, atom in enumerate(target_atoms):
+            if not hasattr(attr_itr, atom):
+                raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
+            attr_itr = getattr(attr_itr, atom)
+        return attr_itr
+
+    def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
+        """
+        Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
+        from the current execution environment.
+
+        Args:
+            n (Node): The node for which ``args`` and ``kwargs`` should be fetched.
+
+        Returns:
+            Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
+        """
+        args = self.map_nodes_to_values(n.args, n)
+        assert isinstance(args, tuple)
+        kwargs = self.map_nodes_to_values(n.kwargs, n)
+        assert isinstance(kwargs, dict)
+        return args, kwargs
+
+    def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
+        """
+        Recursively descend through ``args`` and look up the concrete value
+        for each ``Node`` in the current execution environment.
+
+        Args:
+            args (Argument): Data structure within which to look up concrete values
+
+            n (Node): Node to which ``args`` belongs. This is only used for error reporting.
+        """
+        def load_arg(n_arg : Node) -> Any:
+            if n_arg not in self.env:
+                raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
+                                   f'to diagnose such issues')
+            return self.env[n_arg]
+        return map_arg(args, load_arg)
+
+class Transformer(Interpreter):
+    """
+    ``Transformer`` is a special type of interpreter that produces a
+    new ``Module``. It exposes a ``transform()`` method that returns
+    the transformed ``Module``. Unlike ``Interpreter``, ``Transformer``
+    does not require actual inputs to run, since it operates entirely
+    symbolically.
+
+    Example:
+
+        Suppose we want to swap all instances of ``torch.neg`` with
+        ``torch.sigmoid`` and vice versa (including their ``Tensor``
+        method equivalents). We could subclass ``Transformer`` like so::
+
+            class NegSigmSwapXformer(Transformer):
+                def call_function(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+                    if target == torch.sigmoid:
+                        return torch.neg(*args, **kwargs)
+                    return super().call_function(target, args, kwargs)
+
+                def call_method(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+                    if target == 'neg':
+                        call_self, *args_tail = args
+                        return call_self.sigmoid(*args_tail, **kwargs)
+                    return super().call_method(target, args, kwargs)
+
+            def fn(x):
+                return torch.sigmoid(x).neg()
+
+            gm = torch.fx.symbolic_trace(fn)
+
+            transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
+            input = torch.randn(3, 4)
+            torch.testing.assert_allclose(transformed(input), torch.neg(input).sigmoid())
+
+    Args:
+        module (GraphModule): The ``Module`` to be transformed.
+    """
+    def __init__(self, module):
+        super().__init__(module)
+        self.new_graph = Graph()
+
+        class TransformerTracer(Tracer):
+            def __init__(self, graph: Graph):
+                super().__init__()
+                self.graph = graph
+
+            def is_leaf_module(self, _, __) -> bool:
+                return True
+        self.tracer = TransformerTracer(self.new_graph)
+        self.tracer.root = module
+
+    def placeholder(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Proxy:
+        """
+        Execute a ``placeholder`` node. In ``Transformer``, this is
+        overridden to insert a new ``placeholder`` into the output
+        graph.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+        """
+        assert isinstance(target, str)
+        return Proxy(self.new_graph.placeholder(target), self.tracer)
+
See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + """ + assert isinstance(target, str) + return Proxy(self.new_graph.placeholder(target), self.tracer) + + def get_attr(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Proxy: + """ + Execute a ``get_attr`` node. In ``Transformer``, this is + overridden to insert a new ``get_attr`` node into the output + graph. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + """ + assert isinstance(target, str) + return Proxy(self.new_graph.get_attr(target), self.tracer) + + def transform(self) -> GraphModule: + """ + Transform ``self.module`` and return the transformed + ``GraphModule``. + """ + result = super().run() + if result is not None: + assert isinstance(result, Proxy) + self.new_graph.output(result.node) + return GraphModule(self.module, self.new_graph)