169 changes: 123 additions & 46 deletions test/distributed/test_inductor_collectives.py
@@ -1621,13 +1621,29 @@ def test_reorder_peak_memory_bucketed(self):
Ensure the whole bucketed group including copy-ops get moved together rather than the copy ops preventing the
comm from moving due to data dependency.
"""
_mems = []
def _reset():
_mems.clear()

def foo():
_mems.append(torch.cuda.memory_allocated())
lib = torch.library.Library("_test", "FRAGMENT")
lib.define("foo() -> ()")
lib.impl("foo", foo, "BackendSelect")
from torch._higher_order_ops.effects import _EffectType, _register_effectful_op

_register_effectful_op(
torch.ops._test.foo.default,
_EffectType.ORDERED,
)

def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size):
# do some unrelated matmuls
y = torch.mm(x, w)

# cast the inputs
ag_0_cast = ag_0.to(torch.bfloat16)
torch.ops._test.foo()
ag_1_cast = ag_1.to(torch.bfloat16)

# allgather
@@ -1715,6 +1731,30 @@ def _reorder_communication_preserving_peak_memory(
) = _reorder_communication_preserving_peak_memory_internal(snodes)
return reordered_snodes

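# Pass-through "reorder" pass: it does not change the schedule, it just runs
# Inductor's estimate_peak_memory() on the final node order and stashes the
# estimated peak / per-node memory so they can be compared with the real
# allocator readings collected by _test::foo.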
from torch._inductor.virtualized import V
from torch._inductor.memory import estimate_peak_memory, FreeableInputBuffer, get_freeable_input_buf
from torch.utils._ordered_set import OrderedSet
import copy
estimate_outs = []
nodes = []
def _estimate_peak_memory(
snodes: list[BaseSchedulerNode],
) -> list[BaseSchedulerNode]:
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
snodes, graph_inputs
)
peak_memory, curr_memory = estimate_peak_memory(
snodes, name_to_freeable_input_buf, graph_outputs
)
nonlocal estimate_outs
estimate_outs.append(peak_memory)
estimate_outs.append(curr_memory)
nonlocal nodes
nodes = copy.copy(snodes)
return snodes

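# Bucket the collectives and run only the estimation pass; sink_waits_iterative
# and the reordering pass are disabled here while debugging the memory estimate.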
with torch._inductor.config.patch(
{
"bucket_all_gathers_fx": "all",
@@ -1723,59 +1763,96 @@ def _reorder_communication_preserving_peak_memory(
"bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
"reorder_for_compute_comm_overlap": True,
"reorder_for_compute_comm_overlap_passes": [
sink_waits_iterative,
_reorder_communication_preserving_peak_memory,
_estimate_peak_memory,
# sink_waits_iterative,
# _reorder_communication_preserving_peak_memory,
],
"allow_buffer_reuse": False,
}
):
compiled = torch.compile(func)
compiled = torch.compile(func, fullgraph=True)
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())

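# Compare the real allocator readings against Inductor's estimate: normalize both
# series to their first sample, then print per-node values and deltas so any
# mismatch can be attributed to a specific scheduler node.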
_reset()
compiled(*inputs, **self.get_world_trs())
print("--------------------------------------")
print(f"XXX FOO_MEMS {len(_mems)}:{_mems}")
_est_mems = [0] + estimate_outs[1]
print(f"XXX _EST_MEMS {len(_est_mems)}:{_est_mems}")
print("--------------------------------------")
print("NORMALIZED")
_mems_norm = [m - _mems[0] for m in _mems]
_est_mems_norm = [m - _est_mems[0] for m in _est_mems]
print("--------------------------------------")
print(f"XXX FOO_MEMS_NORMS {len(_mems_norm)}:{_mems_norm}")
print(f"XXX _EST_MEMS_NORMS {len(_est_mems_norm)}:{_est_mems_norm}")
print(f"XXX NODES_LEN:{len(nodes)}")
print("--------------------------------------")
peak_real = max(_mems_norm)
peak_est = max(_est_mems_norm)
print(f"XXX PEAK_REAL:{peak_real}")
print(f"XXX PEAK_EST :{peak_est}")
mem_prev = 0
emem_prev = 0
for i in range(len(_est_mems) - 1):
mem = _mems_norm[i]
mem_delta = mem - mem_prev
emem = _est_mems_norm[i]
emem_delta = emem - emem_prev
print(f"XXX {i:2} EST_MEM:{emem:7} EST_MEM_DELTA:{emem_delta:7}")
print(f"XXX {i:2} MEM:{mem:7} MEM_DELTA:{mem_delta:7}")
mem_prev = mem
emem_prev = emem
if i < len(nodes):
node = nodes[i]
print(f"XXX NODE[{i:2}]:{node.debug_str()} buf_names:{node.get_buffer_names()}")


# NOTE: The first return value should be the output of the first wait_tensor.
# We want to make sure no unnecessary copy is made.
(
FileCheck()
.check_count(
"torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
count=2,
exactly=True,
)
.check(
"extern_kernels.mm",
)
.check(
"extern_kernels.addmm",
)
.run(code)
)
(
FileCheck()
.check_count(
"torch.ops._c10d_functional.reduce_scatter_tensor.default(",
count=2,
exactly=True,
)
.check(
"extern_kernels.mm",
)
.check(
"extern_kernels.addmm",
)
.run(code)
)
out = compiled(*inputs, **self.get_world_trs())
correct = func(*inputs, **self.get_world_trs())
assert same(out, correct), f"{out} vs {correct}"
assert node_stats is not None
self.assertTrue(isinstance(node_stats, dict))
self.assertEqual(len(node_stats), 4)
it = iter(node_stats.values())
node_stat0 = next(it)
self.assertTrue(node_stat0.moves > 0)
self.assertTrue(node_stat0.limiting_factor == "None")
node_stat1 = next(it)
self.assertTrue(node_stat1.moves > 0)
self.assertTrue("collective ordering" in node_stat1.limiting_factor)
# (
# FileCheck()
# .check_count(
# "torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
# count=2,
# exactly=True,
# )
# .check(
# "extern_kernels.mm",
# )
# .check(
# "extern_kernels.addmm",
# )
# .run(code)
# )
# (
# FileCheck()
# .check_count(
# "torch.ops._c10d_functional.reduce_scatter_tensor.default(",
# count=2,
# exactly=True,
# )
# .check(
# "extern_kernels.mm",
# )
# .check(
# "extern_kernels.addmm",
# )
# .run(code)
# )
# out = compiled(*inputs, **self.get_world_trs())
# correct = func(*inputs, **self.get_world_trs())
# assert same(out, correct), f"{out} vs {correct}"
# assert node_stats is not None
# self.assertTrue(isinstance(node_stats, dict))
# self.assertEqual(len(node_stats), 4)
# it = iter(node_stats.values())
# node_stat0 = next(it)
# self.assertTrue(node_stat0.moves > 0)
# self.assertTrue(node_stat0.limiting_factor == "None")
# node_stat1 = next(it)
# self.assertTrue(node_stat1.moves > 0)
# self.assertTrue("collective ordering" in node_stat1.limiting_factor)

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_reorder_respects_wait_dep(self):
15 changes: 15 additions & 0 deletions torch/_inductor/comms.py
@@ -680,6 +680,7 @@ def _group_names(head, tail):
curr = snodes[-1]

processed_waits = OrderedSet() # type: ignore[var-annotated]
swap_i = 0
while _prev[curr] is not None:
if contains_wait(curr) and curr not in processed_waits:
processed_waits.add(curr)
@@ -761,6 +762,7 @@ def is_groupable(snode):

info.moves += 1
info.moves_info += f"+{candidate.get_name()}"
print(f"XXX SWAP {candidate.get_name()} vs {_group_names(group_head, group_tail)}")

# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
mem_deltas = {}
@@ -791,6 +793,19 @@ def is_groupable(snode):
_prev_curr_memory + mem_deltas[n]
)

### DEBUG_PEAK
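# Re-estimate peak memory from scratch after every swap and print it next to
# the incrementally maintained _curr_memory values; bail out immediately if a
# swap ever raises the peak above the original estimate.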
new_snodes = _group_nodes(_head, None)
new_peak_memory, new_curr_memory = estimate_peak_memory(
new_snodes, name_to_freeable_input_buf, graph_outputs
)
print(f"XXX SWAP {swap_i}")
print(f"XXX NEW_PEAK:{new_peak_memory} old_peak:{peak_memory}")
for i, node in enumerate(new_snodes):
print(f"XXX {node.get_name()} CURR:{new_curr_memory[i]} ITER_CURR:{_curr_memory[node]}")
swap_i += 1
if new_peak_memory > peak_memory:
assert False
### DEBUG_PEAK
candidate = _next[group_tail]
curr = _prev[curr] # type: ignore[assignment]

59 changes: 51 additions & 8 deletions torch/_inductor/fx_passes/bucketing.py
@@ -460,24 +460,67 @@ def merge_all_gather(
"pin_memory": False,
},
)
all_gather_copy_in = new_graph_call_function(
# INLINE AG_COPY_IN
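# Replace the opaque fsdp.all_gather_copy_in custom op with plain aten ops:
# slice this rank's shard out of the flat all_gather_output, split it by the
# per-parameter sizes, and _foreach_copy_ the flattened inputs into the
# resulting views, so Inductor can see and lower the copies itself (see the
# new _foreach_copy_ lowering registration below).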
all_gather_input = new_graph_call_function(
new_graph,
torch.ops.fsdp.all_gather_copy_in.default,
torch.ops.aten.slice.Tensor,
(
param_all_gather_inputs_flattened,
all_gather_output,
inp_split_sizes,
0,
all_gather_input_numel * rank,
all_gather_input_numel,
rank,
),
{},
)
all_gather_input = new_graph_call_function(
split_with_sizes = new_graph_call_function(
new_graph,
torch.ops.aten.split_with_sizes.default,
(
all_gather_input,
inp_split_sizes,
),
{},
)
splits = [
new_graph_call_function(
new_graph,
operator.getitem,
(
split_with_sizes,
i,
),
{},
)
for i in range(len(inp_split_sizes))
]
foreach_copy = new_graph_call_function(
new_graph,
operator.getitem,
(all_gather_copy_in, 0),
torch.ops.aten._foreach_copy_.default,
(
splits,
param_all_gather_inputs_flattened,
),
{},
)
# END
# all_gather_copy_in = new_graph_call_function(
# new_graph,
# torch.ops.fsdp.all_gather_copy_in.default,
# (
# param_all_gather_inputs_flattened,
# all_gather_output,
# inp_split_sizes,
# all_gather_input_numel,
# rank,
# ),
# {},
# )
# all_gather_input = new_graph_call_function(
# new_graph,
# operator.getitem,
# (all_gather_copy_in, 0),
# {},
# )
all_gather_into_tensor_out = new_graph_call_function(
new_graph,
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
5 changes: 4 additions & 1 deletion torch/_inductor/lowering.py
@@ -6691,7 +6691,7 @@ def make_triton_fallback(op):
register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum)
register_foreach_pointwise(aten._foreach_reciprocal, reciprocal)
register_foreach_pointwise(aten._foreach_sign, sign)
register_foreach_pointwise(aten._foreach_copy, copy)
foreach_copy = register_foreach_pointwise(aten._foreach_copy, copy)
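# Keep a handle to the out-of-place lowering so the new in-place
# aten._foreach_copy_ registration below can reuse it.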


# these are only encountered as outputs of the graph
@@ -6730,6 +6730,9 @@ def fn(*args, **kwargs):
register_foreach_inplace(
aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar
)
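# Lower the in-place aten._foreach_copy_ (now emitted by the inlined all-gather
# copy-in in bucketing.py) through the out-of-place foreach_copy lowering
# captured above.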
register_foreach_inplace(
aten._foreach_copy_.default, aten._foreach_copy.default, foreach_copy
)


def register_inplace(aten_op, outplace_op):
Expand Down
7 changes: 5 additions & 2 deletions torch/_inductor/memory.py
@@ -340,10 +340,13 @@ class BufferInfo:
memory = [0 for _ in range(len(nodes) + 1)]

# for each buffer, update memory when created and when freed
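# (debug) the prints below dump each buffer's alloc/free sizes and lifetime
# steps, and then the net memory change recorded after every node, so the
# estimate can be cross-checked against real allocator traces.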
for buf_info in buf_info_list:
for i, buf_info in enumerate(buf_info_list):
print(f"XXX BUF_INFO[{i}]:{buf_info.buffer.get_name()} size_alloc:{buf_info.size_alloc} size_free:{buf_info.size_free}")
print(f"XXX start_step:{buf_info.start_step} end_step:{buf_info.end_step}")
memory[buf_info.start_step] += buf_info.size_alloc
memory[buf_info.end_step + 1] -= buf_info.size_free

for i, m in enumerate(memory):
print(f"XXX MEM_ALLOC_AFTER_NODE[{i}]={m}")
# get peak memory by compute the cumulative memories
max_memory = 0
cur_memory = 0