
[WIP][FX] Add Interpreter and Transformer
ghstack-source-id: 9d200c9512549670d0c78e7ba86bb38cf3f79b4a
Pull Request resolved: #50420
James Reed committed Jan 29, 2021
1 parent dbfaf96 commit 964ab86
Showing 6 changed files with 528 additions and 45 deletions.
6 changes: 6 additions & 0 deletions docs/source/fx.rst
@@ -322,3 +322,9 @@ API Reference
:members:

.. autoclass:: torch.fx.Proxy

.. autoclass:: torch.fx.Interpreter
:members:

.. autoclass:: torch.fx.Transformer
:members:
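
For orientation, a minimal usage sketch of the two classes documented above, distilled from the tests added in this commit (the toy function below is illustrative only):

import torch
from torch.fx import symbolic_trace, Interpreter, Transformer

def fn(x):
    return torch.sigmoid(x).neg()

gm = symbolic_trace(fn)
x = torch.randn(3, 4)

# Interpreter executes the traced graph node by node and returns the output value.
interpreted = Interpreter(gm).run(x)

# Transformer performs the same traversal, but records it into a new GraphModule.
new_gm = Transformer(gm).transform()

assert torch.equal(interpreted, fn(x))
assert torch.equal(new_gm(x), fn(x))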
143 changes: 142 additions & 1 deletion test/test_fx.py
@@ -13,7 +13,8 @@
from math import sqrt
from pathlib import Path
from torch.multiprocessing import Process
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph, wrap
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap
from torch.fx.node import Target
from torch.fx.experimental import shape_prop
from torch.fx.immutable_collections import immutable_dict, immutable_list
from copy import deepcopy
@@ -957,6 +958,146 @@ def forward(self, x):
# Test shape propagation and make sure results match actual
self.assertEqual(output_shape, ref_out.shape)

def test_interpreter(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)

def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

interpreter = Interpreter(gm)
input = torch.randn(3, 4)
self.assertEqual(interpreter.run(input), gm(input))
self.assertEqual(interpreter.run(input), m(input))

def test_interpreter_run_node_override(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)

def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

class RunNodeInterpreter(Interpreter):
def __init__(self, module):
super().__init__(module)

def run_node(self, n : Node) -> Any:
result = super().run_node(n)
n.cached_value = result
return result

input = torch.randn(3, 4)
RunNodeInterpreter(gm).run(input)
for node in gm.graph.nodes:
assert hasattr(node, 'cached_value')

def test_interpreter_onthefly_swap(self):

def fn(x):
return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)

class NegSigmSwapInterpreter(Interpreter):
def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(target, args, kwargs)

def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
if target == 'neg':
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_method(target, args, kwargs)

input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
self.assertEqual(result, torch.neg(input).sigmoid())

def test_interpreter_partial_eval(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)

def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)

gm = torch.fx.symbolic_trace(MyModule())
interp = Interpreter(gm)
env = {}
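# Pre-seed the result of the 'linear' call_module node so run() reuses it instead of executing that node (partial evaluation).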
for node in gm.graph.nodes:
if node.op == 'call_module' and node.target == 'linear':
env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0
break
assert len(env) == 1
x = torch.randn(3, 4)
result = interp.run(x, initial_env=env)
self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0))

def test_interpreter_star_args(self):
def with_star_args(x, *args):
return x + args[0]

gm = torch.fx.symbolic_trace(with_star_args)
interp = Interpreter(gm)
result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
self.assertEqual(result, torch.ones(3, 4) * 2.0)

def test_transformer_noop(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)

def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

new_gm = Transformer(gm).transform()

input = torch.randn(3, 4)
self.assertEqual(new_gm(input), gm(input))

def test_transformer_op_swap(self):

def fn(x):
return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)

class NegSigmSwapXformer(Transformer):
def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(target, args, kwargs)

def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
if target == 'neg':
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_method(target, args, kwargs)

transformed = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
self.assertEqual(transformed(input), torch.neg(input).sigmoid())

def test_fn_type_annotations(self):
class Foo(torch.nn.Module):
def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
1 change: 1 addition & 0 deletions torch/fx/__init__.py
@@ -82,3 +82,4 @@ def forward(self, x):
from .graph import Graph
from .node import Node, map_arg
from .proxy import Proxy
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
1 change: 1 addition & 0 deletions torch/fx/__init__.pyi
@@ -3,3 +3,4 @@ from .graph_module import GraphModule as GraphModule
from .node import Node as Node, map_arg as map_arg
from .proxy import Proxy as Proxy
from .symbolic_trace import Tracer as Tracer, symbolic_trace as symbolic_trace, wrap as wrap
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
54 changes: 10 additions & 44 deletions torch/fx/experimental/shape_prop.py
@@ -1,51 +1,17 @@
 import torch
 import torch.fx
 from torch.fx.node import Node
-
-from typing import Dict
+from typing import Any
 
-class ShapeProp:
-    def __init__(self, mod):
-        self.mod = mod
-        self.graph = mod.graph
-        self.modules = dict(self.mod.named_modules())
-
-    def propagate(self, *args):
-        args_iter = iter(args)
-        env : Dict[str, Node] = {}
-
-        def load_arg(a):
-            return torch.fx.node.map_arg(a, lambda n: env[n.name])
-
-        def fetch_attr(target : str):
-            target_atoms = target.split('.')
-            attr_itr = self.mod
-            for i, atom in enumerate(target_atoms):
-                if not hasattr(attr_itr, atom):
-                    raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
-                attr_itr = getattr(attr_itr, atom)
-            return attr_itr
-
-        for node in self.graph.nodes:
-            if node.op == 'placeholder':
-                result = next(args_iter)
-            elif node.op == 'get_attr':
-                result = fetch_attr(node.target)
-            elif node.op == 'call_function':
-                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
-            elif node.op == 'call_method':
-                self_obj, *args = load_arg(node.args)
-                kwargs = load_arg(node.kwargs)
-                result = getattr(self_obj, node.target)(*args, **kwargs)
-            elif node.op == 'call_module':
-                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
-            elif node.op == 'output':
-                return load_arg(node.args[0])
-
-            if isinstance(result, torch.Tensor):
-                node.shape = result.shape
-                node.dtype = result.dtype
-
-            env[node.name] = result
-
-        return None
+class ShapeProp(torch.fx.Interpreter):
+    def run_node(self, n : Node) -> Any:
+        result = super().run_node(n)
+
+        if isinstance(result, torch.Tensor):
+            n.shape = result.shape  # type: ignore
+            n.dtype = result.dtype  # type: ignore
+
+        return result
+
+    def propagate(self, *args):
+        return super().run(*args)
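
A minimal usage sketch of the rewritten ShapeProp (the example module below is illustrative; propagate and the shape/dtype annotations are as defined above):

import torch
import torch.fx
from torch.fx.experimental.shape_prop import ShapeProp

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

gm = torch.fx.symbolic_trace(M())
ShapeProp(gm).propagate(torch.randn(2, 3))

# Nodes that produced a Tensor now carry the annotated shape and dtype.
for node in gm.graph.nodes:
    if hasattr(node, 'shape'):
        print(node.name, node.shape, node.dtype)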
