Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Revert D24024606: [FX] Shape propagation example" #45637

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 27 additions & 0 deletions test/test_fx.py
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph
from torch.fx.experimental import GraphManipulation
from torch.fx.experimental import shape_prop

from torch.fx.proxy import TraceError

Expand Down Expand Up @@ -733,5 +734,31 @@ def test_wrong_topo(self):
with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
graph.lint()

def test_example_shape_prop(self):
class TestCase(torch.nn.Module):
def __init__(self):
super().__init__()
self.attr = torch.randn(3, 4)
self.submod = torch.nn.Linear(4, 4)

def forward(self, x):
return torch.neg(self.submod(x.relu() + self.attr))
tc = TestCase()
tc_traced = symbolic_trace(tc)
ref_out = tc_traced(torch.rand(3, 4))

# Make sure we're testing all opcodes
opcodes = set()
for node in tc_traced.graph.nodes:
opcodes.add(node.op)
self.assertEqual(opcodes, set(['placeholder', 'get_attr', 'call_function', 'call_method', 'call_module']))

# Test shape propogation and make sure results match actual
shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))
self.assertEqual(tc_traced.graph.result.shape, ref_out.shape)




if __name__ == '__main__':
run_tests()
49 changes: 49 additions & 0 deletions torch/fx/experimental/shape_prop.py
@@ -0,0 +1,49 @@
import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())

def propagate(self, *args):
args_iter = iter(args)
env : Dict[str, Node] = {}

def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])

def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr

for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'call_method':
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype

env[node.name] = result

return load_arg(self.graph.result)