Partition modules #98628

Closed · wants to merge 9 commits
145 changes: 145 additions & 0 deletions test/fx/test_source_matcher_utils.py
@@ -0,0 +1,145 @@
# Owner(s): ["module: fx"]

import os
import sys
import unittest

import torch

pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch._dynamo.eval_frame import is_dynamo_supported
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions, check_subgraphs_connected
from torch.testing._internal.jit_utils import JitTestCase

class TestSourceMatcher(JitTestCase):
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
def test_module_partitioner_linear_relu_linear(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(3, 3)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(3, 5)

def forward(self, x):
x = self.linear1(x)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x

inputs = (torch.randn(3, 3),)
gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
gm.graph.eliminate_dead_code()

module_partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])

self.assertEqual(len(module_partitions), 2)
self.assertEqual(len(module_partitions[torch.nn.Linear]), 3)
self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)

self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Linear][0], module_partitions[torch.nn.ReLU][0]))
self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.Linear][1], module_partitions[torch.nn.ReLU][0]))
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Linear][2], module_partitions[torch.nn.ReLU][0]))

@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
def test_module_partitioner_conv_relu_maxpool(self):
class M(torch.nn.Module):
def __init__(self, constant_tensor: torch.Tensor) -> None:
super().__init__()
self.constant_tensor = constant_tensor
self.conv1 = torch.nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=3, padding=1
)
self.conv2 = torch.nn.Conv2d(
in_channels=16, out_channels=16, kernel_size=3, padding=1
)
self.conv3 = torch.nn.Conv2d(
in_channels=16, out_channels=16, kernel_size=3, padding=1
)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

def forward(self, x: torch.Tensor) -> torch.Tensor:
a = self.conv1(x)
b = self.conv2(a)
c = a + self.constant_tensor
z = self.conv3(b + c)
return self.maxpool(self.relu(z))

inputs = (torch.randn(1, 3, 256, 256),)
gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), *inputs, aten_graph=True)
gm.graph.eliminate_dead_code()

module_partitions = get_source_partitions(gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d])

self.assertEqual(len(module_partitions), 3)
self.assertEqual(len(module_partitions[torch.nn.Conv2d]), 3)
self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)
self.assertEqual(len(module_partitions[torch.nn.MaxPool2d]), 1)

self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][0], module_partitions[torch.nn.ReLU][0]))
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][1], module_partitions[torch.nn.ReLU][0]))
self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][2], module_partitions[torch.nn.ReLU][0]))
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.MaxPool2d][0], module_partitions[torch.nn.ReLU][0]))
self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.ReLU][0], module_partitions[torch.nn.MaxPool2d][0]))

@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
def test_module_partitioner_functional_conv_relu_conv(self):
class FunctionalConv2d(torch.nn.Module):
def __init__(self):
super().__init__()
self.stride = (1, 1)
self.padding = (0, 0)
self.dilation = (1, 1)
self.groups = 1

def forward(self, x, weight, bias):
return torch.nn.functional.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = FunctionalConv2d()
self.conv2 = FunctionalConv2d()

def forward(self, x, weight, bias):
x = self.conv1(x, weight, bias)
x = torch.nn.functional.relu(x)
x = self.conv2(x, weight, bias)
return x

inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3))
gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
gm.graph.eliminate_dead_code()

module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.conv2d])

self.assertEqual(len(module_partitions), 1)
self.assertEqual(len(module_partitions[torch.nn.functional.conv2d]), 2)

@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
def test_module_partitioner_functional_linear_relu_linear(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, weight, bias):
x = torch.nn.functional.linear(x, weight, bias)
x = torch.nn.functional.linear(x, weight, bias)
x = torch.nn.functional.relu(x)
x = torch.nn.functional.linear(x, weight, bias)
x = torch.nn.functional.linear(x, weight, bias)
x = torch.nn.functional.relu(x)
return x

inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5))
gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
gm.graph.eliminate_dead_code()

module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu])

self.assertEqual(len(module_partitions), 2)
self.assertEqual(len(module_partitions[torch.nn.functional.linear]), 4)
self.assertEqual(len(module_partitions[torch.nn.functional.relu]), 2)
1 change: 1 addition & 0 deletions test/test_fx.py
@@ -46,6 +46,7 @@
from fx.test_common_passes import TestCommonPass # noqa: F401
from fx.test_cse_pass import TestCSEPass # noqa: F401
from fx.test_matcher_utils import TestMatcher # noqa: F401
from fx.test_source_matcher_utils import TestSourceMatcher # noqa: F401

from fx.test_gradual_type import AnnotationsTest # noqa: F401
from fx.test_gradual_type import TypeCheckerTest # noqa: F401
7 changes: 5 additions & 2 deletions torch/_dynamo/output_graph.py
@@ -1049,12 +1049,15 @@ def create_proxy(
rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

if kind in {"call_function", "call_method"}:
rv.node.meta["source_fn"] = target
rv.node.meta["source_fn"] = (rv.node.name, target)
elif kind == "call_module":
if self.parent is not None:
unimplemented("Invoking an nn.Module inside HigherOrderOperator")
# For modules we store the class
rv.node.meta["source_fn"] = rv.node.meta["nn_module_stack"][target][1]
rv.node.meta["source_fn"] = (
rv.node.name,
Contributor: Is rv.node.name the qualified name?

Contributor Author: Yeah, it's the unique name of the node in the FX graph, so it will help us handle the case where there are two Linear module calls side by side in the graph.

rv.node.meta["nn_module_stack"][target][1],
)

frame_summaries: List[traceback.FrameSummary] = []
while tx:
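
For reference, a minimal sketch of how the updated source_fn metadata can be inspected after export. The toy module and input shapes below are illustrative placeholders, and the exact node names depend on the dynamo version:

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(3, 3)
        self.linear2 = torch.nn.Linear(3, 3)

    def forward(self, x):
        # Two back-to-back Linear calls; the unique name disambiguates them.
        return self.linear2(self.linear1(x))

gm, _ = torch._dynamo.export(Toy(), torch.randn(3, 3), aten_graph=True)

for node in gm.graph.nodes:
    source_fn = node.meta.get("source_fn")
    if source_fn is not None:
        # With this PR, source_fn is a (unique_node_name, source) tuple, where
        # source is the originating function or leaf-module class (e.g.
        # torch.nn.Linear), so adjacent calls to the same module type can be
        # told apart by the first element.
        unique_name, source = source_fn
        print(node.name, unique_name, source)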
6 changes: 3 additions & 3 deletions torch/_inductor/utils.py
@@ -295,10 +295,10 @@ def get_fused_kernel_name(node_schedule):
sources = []
for origin in all_origins:
if origin.op == "call_function" and "source_fn" in origin.meta:
if isinstance(origin.meta["source_fn"], str):
sources.append(origin.meta["source_fn"])
if isinstance(origin.meta["source_fn"][1], str):
Contributor: What was this change for?

Contributor Author: I modified the "source_fn" metadata to additionally return a unique qualifying name for each function that is called, so that if two modules are called one after the other we can distinguish between them. This change just keeps inductor compatible.

sources.append(origin.meta["source_fn"][1])
else:
sources.append(origin.meta["source_fn"].__name__)
sources.append(origin.meta["source_fn"][1].__name__)
sources = sorted(set(sources))
elif config.triton.descriptive_names == "inductor_node":
sources = [
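
A small sketch of the compatibility pattern above, assuming the (name, source) tuple format introduced in this PR; the helper name is made up for illustration:

def _source_display_name(node) -> str:
    # The second tuple element is either a string (e.g. for call_method
    # targets) or a callable/module class, so handle both when building a
    # readable kernel name.
    _unique_name, source = node.meta["source_fn"]
    return source if isinstance(source, str) else source.__name__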
128 changes: 128 additions & 0 deletions torch/fx/passes/utils/source_matcher_utils.py
@@ -0,0 +1,128 @@
from dataclasses import dataclass, field
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx._compatibility import compatibility
from typing import Dict, List, Any, Type
import logging
import os


__all__ = ['get_source_partitions', 'check_subgraphs_connected']

# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
def _init_logger():
logger = logging.getLogger(__name__)

level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
logger.setLevel(level)
console = logging.StreamHandler()
formatter = logging.Formatter("%(filename)s > %(message)s")
console.setFormatter(formatter)
console.setLevel(level)
# add the handlers to the logger
logger.addHandler(console)
logger.propagate = False
return logger

logger = _init_logger()


@compatibility(is_backward_compatible=False)
@dataclass
class SourcePartition():
# Nodes in a particular partition
nodes: List[Node]
Contributor: Any reasoning for having the partition be a list of nodes rather than the partitioned graph itself? We can derive the nodes from the graph, and having the graph can help preserve the partition's structure.

Contributor: You can generate a graph from the list. Do you want a Graph instead of List[Node]? If so, why?

Contributor: Is there a simple API to convert a List[Node] to a Graph? If not, then in the case where I want to use something like subgraph_rewriter afterwards to replace these partitioned modules, using a graph as the pattern to replace, rather than a list of nodes, would be easier.

Contributor Author: Yeah, fuse_as_graphmodule exists for that (see the sketch below, after the dataclass).


# The source these nodes decomposed from
source: Any

# Nodes in the graph that are needed as inputs to the partition
input_nodes: List[Node] = field(default_factory=list)

# Nodes in the partition that are being used by nodes outside of the
# partition
output_nodes: List[Node] = field(default_factory=list)

# Parameters that are being used
params: List[str] = field(default_factory=list)
Contributor: Why is this a list of strings instead of List[Node]?

Contributor Author: I wasn't sure how you wanted the parameters formatted, so I just returned a list of the attribute names of the parameters. But I can fix this.

Contributor: Ideally we could have NamedParameters = Tuple[str, Tensor] so that a weight name would correspond to its weight tensor. But from talking to Sherlock, I remember this was harder. For now this is fine.


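As noted in the thread above, a partition's node list can be lifted into a standalone GraphModule with fuse_as_graphmodule. A minimal sketch, assuming its (gm, nodes, module_name) signature, where gm stands in for an exported GraphModule and partition for a SourcePartition returned by get_source_partitions:

from torch.fx.passes.utils.fuser_utils import fuse_as_graphmodule

# Returns the fused sub-GraphModule plus the original input/output nodes,
# which can then be handed to utilities such as the subgraph rewriter.
sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, partition.nodes, "fused_partition_0")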

@compatibility(is_backward_compatible=False)
def get_source_partitions(
graph: Graph,
wanted_sources: List[Any]
) -> Dict[Any, List[SourcePartition]]:
"""
Args:
graph: The graph we want to partition
wanted_sources: List of sources of nodes that were decomposed from this
source. This can be a function (ex. torch.nn.functional.linear) or a
leaf module type (ex. torch.nn.Linear).

Returns:
Dictionary mapping sources that were given to a list of SourcePartitions
that correspond to the list of nodes that were decomposed from the given
source.
"""
modules: Dict[Type, Dict[str, List[Node]]] = {}

for node in graph.nodes:
# The metadata source_fn should contain a tuple of a unique name for the
# source, and the source function if the node is decomposed from a
# function, or the type of module if the node is decomposed from a leaf
# module

if (source_fn := node.meta.get("source_fn", None)) is None:
continue

if source_fn[1] not in wanted_sources:
continue

diff_modules = modules.setdefault(source_fn[1], {})
Contributor: Angela, why are we using source_fn and not nn_module_stack? The main difference I see is that source_fn tracks the leaf module/function, whereas nn_module_stack tracks the entire module hierarchy. I think it is useful to use nn_module_stack so that a node can belong to multiple partitions, where each partition it belongs to has a strict parent->child relation. The reason this might be useful is that when modules like LSTM or attention get decomposed, you still get nodes that belong to a higher-level module like Attention, and if I want to quantize the entire attention module, I can.

partition = diff_modules.setdefault(source_fn[0], [])
partition.append(node)

def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition:
input_nodes = set()
output_nodes = set()
params = set()
for node in nodes:
for arg in node.args:
if isinstance(arg, Node) and arg not in nodes:
input_nodes.add(arg)

if node.op == "get_attr":
params.add(node.target)
Contributor: For the reasons I mentioned above, like for the module here, https://fburl.com/owodcrrr, if we have named parameters then they are easier to access. Although I don't know what happens to constants; if they are "burned" in, then it will be harder to figure out what their "names" are.


for user in node.users.keys():
if user not in nodes:
output_nodes.add(node)

return SourcePartition(
nodes,
module_type,
list(input_nodes),
list(output_nodes),
list(params), # type: ignore[arg-type]
)

ret: Dict[Type[Any], List[SourcePartition]] = {}
for k, v in modules.items():
ret[k] = [make_partition(partition, k) for partition in v.values()]

return ret
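
A brief usage sketch of the API above; gm is assumed to come from torch._dynamo.export with aten_graph=True, as in the tests:

import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])

for source, parts in partitions.items():
    for p in parts:
        # Each SourcePartition records the decomposed nodes plus the partition
        # boundary: external inputs, externally used outputs, and the targets
        # of any get_attr (parameter) nodes.
        print(source, len(p.nodes), p.input_nodes, p.output_nodes, p.params)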


@compatibility(is_backward_compatible=False)
def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
Contributor: This is somewhat loose in that two subgraphs may be overlapping, right? I would expect you to check that the output nodes of the first partition are the input nodes of the second partition; that would be stricter. (A sketch of that stricter variant appears after the function below.)

"""
Given two subgraphs A and B (in the form of a list of nodes), checks if
A has nodes connecting to at least one node in B -- aka there exists a node
in B that uses a node in A (not the other way around).
"""

for node in reversed(subgraph1.nodes):
for user in node.users.keys():
if user in subgraph2.nodes:
return True
return False
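
A sketch of the stricter variant the reviewer describes, checking that boundary outputs of the first partition feed the boundary inputs of the second; this is illustrative only and not part of the PR:

def check_subgraphs_strictly_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
    # Connected here means: at least one output node of subgraph1 is consumed
    # as an input node of subgraph2.
    subgraph2_inputs = set(subgraph2.input_nodes)
    return any(node in subgraph2_inputs for node in subgraph1.output_nodes)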