🚀 The feature, motivation and pitch
I hit the following AssertionError when trying to unflatten a submodule created by fx split_module.
[rank0]: File "torch/distributed/pipelining/_IR.py", line 750, in _from_traced
[rank0]: new_submod = _outline_submodules(submodule.graph)
[rank0]: File "torch/distributed/pipelining/_unflatten.py", line 27, in _outline_submodules
[rank0]: ).run_outer()
[rank0]: File "torch/export/unflatten.py", line 1373, in run_outer
[rank0]: self.run_from(node_idx)
[rank0]: File "torch/export/unflatten.py", line 1389, in run_from
[rank0]: assert node.op != "placeholder", f"""
[rank0]: AssertionError:
[rank0]: node idx 1: gate_weight
[rank0]: node idx 2: submod_1 <---
[rank0]: node idx 3: lifted_tensor_0
[rank0]: node idx 4: lifted_tensor_1
Root Cause
(i) It seems that the unflatten function assumes that placeholders are the first n nodes of the graph. It therefore pre-processes placeholder nodes up to the first non-placeholder node and starts run_from from there. Inside run_from, it asserts that no node is a placeholder.
(ii) However, when split_module creates submodules, it may not put all placeholder nodes at the front of the graph. In the case above, submod_1 is a get_attr node surrounded by placeholders.
(iii) Here, submod_1 is a HOP (higher-order op) created by torch.export before the graph split.
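To illustrate the ordering issue, here is a minimal sketch (placeholders_come_first is a hypothetical helper, not an existing torch.fx utility) that walks a split submodule's graph and reports whether every placeholder precedes all other nodes; on the submodule above it would return False because submod_1, a get_attr node, sits between placeholders.

```python
import torch.fx as fx

def placeholders_come_first(gm: fx.GraphModule) -> bool:
    """Return True iff all placeholder nodes precede every non-placeholder node."""
    seen_non_placeholder = False
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if seen_non_placeholder:
                # A placeholder appears after e.g. a get_attr node --
                # this is the layout that trips the assertion in run_from.
                return False
        else:
            seen_non_placeholder = True
    return True
```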
Alternatives
If there is a way to lint the split graph so that the first n nodes are all placeholders, please let me know.
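One possible workaround, sketched below under the assumption that reordering is safe because placeholders take no inputs: a small FX pass (move_placeholders_to_front is a hypothetical name, not an existing torch.fx API) that hoists every placeholder ahead of the first non-placeholder node after split_module runs.

```python
import torch.fx as fx

def move_placeholders_to_front(gm: fx.GraphModule) -> None:
    # Placeholders have no inputs, so moving them to the front of the
    # graph cannot violate topological order.
    first_non_placeholder = next(
        (n for n in gm.graph.nodes if n.op != "placeholder"), None
    )
    if first_non_placeholder is None:
        return  # graph is empty or all placeholders; nothing to do
    for node in [n for n in gm.graph.nodes if n.op == "placeholder"]:
        # Node.prepend moves an existing node to sit right before
        # first_non_placeholder, preserving the placeholders' relative order.
        first_non_placeholder.prepend(node)
    gm.graph.lint()
    gm.recompile()
```

If such a pass were applied to each submodule produced by split_module before _outline_submodules is called, the assertion in run_from should no longer trigger, assuming the node ordering is the only problem.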
Additional context
I am calling unflatten on the subgraph for use by pipeline parallelism. The error above only occurs when there is a HOP in the subgraph.
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci