Skip to content

Commit

Permalink
[fx][split] make sure we copy node.meta over during split
Browse files Browse the repository at this point in the history
Summary: Previously when we create placeholder nodes for sub graph modules, we didn't copy node.meta over.

Test Plan: CI

Differential Revision: D48330866

fbshipit-source-id: 7f35314b11395d37bb4c1ff44bf1cdaaf4618732
  • Loading branch information
842974287 authored and facebook-github-bot committed Aug 15, 2023
1 parent e1ee10e commit 2ce61f4
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions torch/fx/passes/split_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import copy
from dataclasses import dataclass, field
from typing import List, Optional, Dict
from typing import Dict, List, Optional

import torch.fx
from torch.fx._compatibility import compatibility
from torch.fx.graph import map_arg
from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module

from .tools_common import NodeList
from torch.fx._compatibility import compatibility
from torch.fx.passes.utils import lift_subgraph_as_module, HolderModule

__all__ = ['getattr_recursive', 'setattr_recursive', 'Component', 'split_by_tags']
__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]

@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
Expand Down Expand Up @@ -205,14 +207,14 @@ def remap_func(x):
# as a placeholder in current component's graph.
if x not in comp.orig_inputs:
comp.orig_inputs.append(x)
placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
placeholder.meta = copy.copy(x.meta)
comp.input_placeholders.append(
comp.graph.placeholder(x.name, type_expr=x.type)
placeholder
)
used_in_main[x] = None

return comp.input_placeholders[
next(i for i, y in enumerate(comp.orig_inputs) if x is y)
]
return comp.input_placeholders[comp.orig_inputs.index(x)]

n = comp.graph.node_copy(node, remap_func)
n.tag = node.tag # type: ignore[attr-defined]
Expand Down

0 comments on commit 2ce61f4

Please sign in to comment.