-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Partition modules #98628
Partition modules #98628
Changes from all commits
f830924
a9e4401
7cd1d24
757fc32
107f359
2dd528a
77e7b79
815230f
7434821
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Owner(s): ["module: fx"] | ||
|
||
import os | ||
import sys | ||
import unittest | ||
|
||
import torch | ||
|
||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) | ||
sys.path.append(pytorch_test_dir) | ||
from torch._dynamo.eval_frame import is_dynamo_supported | ||
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions, check_subgraphs_connected | ||
from torch.testing._internal.jit_utils import JitTestCase | ||
|
||
class TestSourceMatcher(JitTestCase): | ||
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") | ||
def test_module_partitioner_linear_relu_linear(self): | ||
class M(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.linear1 = torch.nn.Linear(3, 3) | ||
self.relu = torch.nn.ReLU() | ||
self.linear2 = torch.nn.Linear(3, 5) | ||
|
||
def forward(self, x): | ||
x = self.linear1(x) | ||
x = self.linear1(x) | ||
x = self.relu(x) | ||
x = self.linear2(x) | ||
return x | ||
|
||
inputs = (torch.randn(3, 3),) | ||
gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True) | ||
gm.graph.eliminate_dead_code() | ||
|
||
module_partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU]) | ||
|
||
self.assertEqual(len(module_partitions), 2) | ||
self.assertEqual(len(module_partitions[torch.nn.Linear]), 3) | ||
self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1) | ||
|
||
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Linear][0], module_partitions[torch.nn.ReLU][0])) | ||
self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.Linear][1], module_partitions[torch.nn.ReLU][0])) | ||
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Linear][2], module_partitions[torch.nn.ReLU][0])) | ||
|
||
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") | ||
def test_module_partitioner_conv_relu_maxpool(self): | ||
class M(torch.nn.Module): | ||
def __init__(self, constant_tensor: torch.Tensor) -> None: | ||
super().__init__() | ||
self.constant_tensor = constant_tensor | ||
self.conv1 = torch.nn.Conv2d( | ||
in_channels=3, out_channels=16, kernel_size=3, padding=1 | ||
) | ||
self.conv2 = torch.nn.Conv2d( | ||
in_channels=16, out_channels=16, kernel_size=3, padding=1 | ||
) | ||
self.conv3 = torch.nn.Conv2d( | ||
in_channels=16, out_channels=16, kernel_size=3, padding=1 | ||
) | ||
self.relu = torch.nn.ReLU() | ||
self.maxpool = torch.nn.MaxPool2d(kernel_size=3) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
a = self.conv1(x) | ||
b = self.conv2(a) | ||
c = a + self.constant_tensor | ||
z = self.conv3(b + c) | ||
return self.maxpool(self.relu(z)) | ||
|
||
inputs = (torch.randn(1, 3, 256, 256),) | ||
gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), *inputs, aten_graph=True) | ||
gm.graph.eliminate_dead_code() | ||
|
||
module_partitions = get_source_partitions(gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d]) | ||
|
||
self.assertEqual(len(module_partitions), 3) | ||
self.assertEqual(len(module_partitions[torch.nn.Conv2d]), 3) | ||
self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1) | ||
self.assertEqual(len(module_partitions[torch.nn.MaxPool2d]), 1) | ||
|
||
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][0], module_partitions[torch.nn.ReLU][0])) | ||
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][1], module_partitions[torch.nn.ReLU][0])) | ||
self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][2], module_partitions[torch.nn.ReLU][0])) | ||
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.MaxPool2d][0], module_partitions[torch.nn.ReLU][0])) | ||
self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.ReLU][0], module_partitions[torch.nn.MaxPool2d][0])) | ||
|
||
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") | ||
def test_module_partitioner_functional_conv_relu_conv(self): | ||
class FunctionalConv2d(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.stride = (1, 1) | ||
self.padding = (0, 0) | ||
self.dilation = (1, 1) | ||
self.groups = 1 | ||
|
||
def forward(self, x, weight, bias): | ||
return torch.nn.functional.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups) | ||
|
||
class M(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.conv1 = FunctionalConv2d() | ||
self.conv2 = FunctionalConv2d() | ||
|
||
def forward(self, x, weight, bias): | ||
x = self.conv1(x, weight, bias) | ||
x = torch.nn.functional.relu(x) | ||
x = self.conv2(x, weight, bias) | ||
return x | ||
|
||
inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3)) | ||
gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True) | ||
gm.graph.eliminate_dead_code() | ||
|
||
module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.conv2d]) | ||
|
||
self.assertEqual(len(module_partitions), 1) | ||
self.assertEqual(len(module_partitions[torch.nn.functional.conv2d]), 2) | ||
|
||
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") | ||
def test_module_partitioner_functional_linear_relu_linear(self): | ||
class M(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x, weight, bias): | ||
x = torch.nn.functional.linear(x, weight, bias) | ||
x = torch.nn.functional.linear(x, weight, bias) | ||
x = torch.nn.functional.relu(x) | ||
x = torch.nn.functional.linear(x, weight, bias) | ||
x = torch.nn.functional.linear(x, weight, bias) | ||
x = torch.nn.functional.relu(x) | ||
return x | ||
|
||
inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5)) | ||
gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True) | ||
gm.graph.eliminate_dead_code() | ||
|
||
module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu]) | ||
|
||
self.assertEqual(len(module_partitions), 2) | ||
self.assertEqual(len(module_partitions[torch.nn.functional.linear]), 4) | ||
self.assertEqual(len(module_partitions[torch.nn.functional.relu]), 2) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -295,10 +295,10 @@ def get_fused_kernel_name(node_schedule): | |
sources = [] | ||
for origin in all_origins: | ||
if origin.op == "call_function" and "source_fn" in origin.meta: | ||
if isinstance(origin.meta["source_fn"], str): | ||
sources.append(origin.meta["source_fn"]) | ||
if isinstance(origin.meta["source_fn"][1], str): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what ws this change for? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I modified the "source_fn" metadata to additionally return a unique qualifying name for each function that is called so that if there are 2 modules that are called one after the other then we can distinguish between the two. This change is just to make inductor compatible. |
||
sources.append(origin.meta["source_fn"][1]) | ||
else: | ||
sources.append(origin.meta["source_fn"].__name__) | ||
sources.append(origin.meta["source_fn"][1].__name__) | ||
sources = sorted(set(sources)) | ||
elif config.triton.descriptive_names == "inductor_node": | ||
sources = [ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from dataclasses import dataclass, field | ||
from torch.fx.graph import Graph | ||
from torch.fx.node import Node | ||
from torch.fx._compatibility import compatibility | ||
from typing import Dict, List, Any, Type | ||
import logging | ||
import os | ||
|
||
|
||
__all__ = ['get_source_partitions', 'check_subgraphs_connected'] | ||
|
||
# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs | ||
def _init_logger(): | ||
logger = logging.getLogger(__name__) | ||
|
||
level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper() | ||
logger.setLevel(level) | ||
console = logging.StreamHandler() | ||
formatter = logging.Formatter("%(filename)s > %(message)s") | ||
console.setFormatter(formatter) | ||
console.setLevel(level) | ||
# add the handlers to the logger | ||
logger.addHandler(console) | ||
logger.propagate = False | ||
return logger | ||
|
||
logger = _init_logger() | ||
|
||
|
||
@compatibility(is_backward_compatible=False) | ||
@dataclass | ||
class SourcePartition(): | ||
# Nodes in a particular partition | ||
nodes: List[Node] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reasoning for having the partition be a list of node rather than the partitioned graph itself? We can derive nodes from the graph, and having the graph can help preserve the partitions structure There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can generate a graph from the list. Do you want There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there a simple api to convert a List[Node] --> Graph? If not, in the case I might want to use something like subgraph_rewriter after to replace these partition modules, using a graph as the pattern to replace rather than a list of nodes would be easier There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, there exists |
||
|
||
# The source these nodes decomposed from | ||
source: Any | ||
|
||
# Nodes in the graph that are needed as inputs to the partition | ||
input_nodes: List[Node] = field(default_factory=list) | ||
|
||
# Nodes in the partition that are being used by nodes outside of the | ||
# partition | ||
output_nodes: List[Node] = field(default_factory=list) | ||
|
||
# Parameters that are being used | ||
params: List[str] = field(default_factory=list) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why was this list of strings instead of List[Node]? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wasn't sure how you wanted the parameters formatted so I just returned a list of the attributes of the parameters. But I can fix this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ideally we could have |
||
|
||
|
||
@compatibility(is_backward_compatible=False) | ||
def get_source_partitions( | ||
graph: Graph, | ||
wanted_sources: List[Any] | ||
) -> Dict[Any, List[SourcePartition]]: | ||
""" | ||
Args: | ||
graph: The graph we want to partition | ||
wanted_sources: List of sources of nodes that were decomposed from this | ||
source. This can be a function (ex. torch.nn.functional.linear) or a | ||
leaf module type (ex. torch.nn.Linear). | ||
|
||
Returns: | ||
Dictionary mapping sources that were given to a list of SourcePartitions | ||
that correspond to the list of nodes that were decomposed from the given | ||
source. | ||
""" | ||
modules: Dict[Type, Dict[str, List[Node]]] = {} | ||
|
||
for node in graph.nodes: | ||
# The metadata source_fn should contain a tuple of a unique name for the | ||
# source, and the source function if the node is decomposed from a | ||
# function, or the type of module if the node is decomposed from a leaf | ||
# module | ||
|
||
if (source_fn := node.meta.get("source_fn", None)) is None: | ||
continue | ||
|
||
if source_fn[1] not in wanted_sources: | ||
continue | ||
|
||
diff_modules = modules.setdefault(source_fn[1], {}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Angela, so why are we using source_fn and not nn_module_stack. Main difference I see is that source_fn trackes the leaf module/function whereas nn_module_stack tracked the entire module hierarchy. I think it is useful to use nn_module_stack so that a node can belong to multiple partitions but ecah partition it belongs to must have a strict parent->child relation. Reason why this might be useful is that when modules like LSTM or attentention get decomposed you still can get nodes that belong to higher level module like Attention. And if I were to quantize entire attention module, I can. |
||
partition = diff_modules.setdefault(source_fn[0], []) | ||
partition.append(node) | ||
|
||
def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition: | ||
input_nodes = set() | ||
output_nodes = set() | ||
params = set() | ||
for node in nodes: | ||
for arg in node.args: | ||
if isinstance(arg, Node) and arg not in nodes: | ||
input_nodes.add(arg) | ||
|
||
if node.op == "get_attr": | ||
params.add(node.target) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the reasons I mentioned above, like for module here, https://fburl.com/owodcrrr, if we have named parameters than it is easier to access. ALthough I dont know what happens to constants. If they are "burnt" in then it will be harder to figure out what are their "names". |
||
|
||
for user in node.users.keys(): | ||
if user not in nodes: | ||
output_nodes.add(node) | ||
|
||
return SourcePartition( | ||
nodes, | ||
module_type, | ||
list(input_nodes), | ||
list(output_nodes), | ||
list(params), # type: ignore[arg-type] | ||
) | ||
|
||
ret: Dict[Type[Any], List[SourcePartition]] = {} | ||
for k, v in modules.items(): | ||
ret[k] = [make_partition(partition, k) for partition in v.values()] | ||
|
||
return ret | ||
|
||
|
||
@compatibility(is_backward_compatible=False) | ||
def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is somewhat loose in that, two graphs maybe overlapping, right? I would expect you to check the output nodes being the input nodes to the second partition? That might be stricter? |
||
""" | ||
Given two subgraphs A and B (in the form of a list of nodes), checks if | ||
A has nodes connecting to at least one node in B -- aka there exists a node | ||
in B that uses a node in A (not the other way around). | ||
""" | ||
|
||
for node in reversed(subgraph1.nodes): | ||
for user in node.users.keys(): | ||
if user in subgraph2.nodes: | ||
return True | ||
return False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is
rv.node.name
the qualified name?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, it's the unique name of the node in the fx graph, so it will help us handle the case where if there are 2 linear module calls side by side in the graph.