Skip to content

Commit

Permalink
[minimizer]skip mode for minimizer (pytorch#109399)
Browse files Browse the repository at this point in the history
Summary:

- skip known issue nodes in minimizer and check the whole graph

Reviewed By: siyan-lin

Differential Revision: D48990707
  • Loading branch information
zejunh authored and facebook-github-bot committed Sep 15, 2023
1 parent ec8b58f commit 831c80d
Showing 1 changed file with 59 additions and 1 deletion.
60 changes: 59 additions & 1 deletion torch/fx/passes/net_min_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,61 @@ def _accumulate_traverse(self, nodes: NodeList) -> NodeSet:

return culprits

def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) -> NodeSet:
"""
Skip certain nodes in graph based on settings
"""
culprits: NodeSet = set()
nodes: NodeList = all_nodes[start_idx:end_idx]

report: List[str] = []
self.reports.append(report)
self.iteration += 1
report.append(f" Nodes block {self.iteration}.")
report.append(
f"From node index {start_idx} to {end_idx-1}. "
f"Size of the interested node list is {len(nodes)}"
)

cur_nodes: NodeSet = set(nodes)

for node in nodes:
if node in self.fusions:
cur_nodes.update(self.fusions[node])

try:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, [])
except (FxNetMinimizerResultMismatchError):
culprits.update(cur_nodes)
report.append(f"Found culprit from numeric error: {cur_nodes}")
self.print_report(report)
return culprits
except (FxNetMinimizerRunFuncError):
culprits.update(cur_nodes)
report.append(f"Found culprit from run error: {node}")
self.print_report(report)
return culprits


def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List)->NodeSet:
"""
Skip certain nodes in graph based on settings
"""
start_idx = 0
num_nodes = len(all_nodes)
idx = 0
while idx < num_nodes:
node = all_nodes[idx]
if (node.name in skip_nodes): # skip the node
if idx > start_idx:
self._skip_traverse_impl(all_nodes, start_idx, idx)
start_idx = idx + 1
elif idx == num_nodes - 1 and start_idx <= idx: # last node
self._skip_traverse_impl(all_nodes, start_idx, idx + 1)
idx += 1


def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList:
"""
Collect nodes in the model that between nodes with name of `start` and `end`.
Expand Down Expand Up @@ -583,7 +638,7 @@ def print_reports(self):
self.print_report(report)

def minimize(
self, start: Optional[str] = None, end: Optional[str] = None
self, start: Optional[str] = None, end: Optional[str] = None, skip_nodes: Optional[List] = None,
) -> NodeSet:
"""
Minimizing the model from node with name `start` to node with name `end` base
Expand Down Expand Up @@ -615,4 +670,7 @@ def minimize(
if self.settings.traverse_method == "accumulate":
return self._accumulate_traverse(nodes)

if(self.settings.traverse_method == "skip"):
return self._skip_traverse(nodes, skip_nodes)

raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!")

0 comments on commit 831c80d

Please sign in to comment.