
Commit e68d185

Update
[ghstack-poisoned]
1 parent: 0eea956

5 files changed: +122 -59 lines

test/distributed/test_aten_comm_compute_reordering.py

Lines changed: 4 additions & 2 deletions
@@ -954,7 +954,7 @@ def func(a):
 
         patches = {
             **get_patches(),
-            "aten_distributed_optimizations.benchmark_collectives": True,
+            "aten_distributed_optimizations.collective_estimator": "benchmark",
         }
 
         with _dynamo_dist_per_rank_init(
@@ -970,7 +970,9 @@ def func(a):
         out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
 
         # Verify wait_tensor is sinked (scheduling worked)
-        FileCheck().check("all_reduce").check("mm").check("wait_tensor").check("mm").run(aten_graph_str)
+        FileCheck().check("all_reduce").check("mm").check("wait_tensor").check(
+            "mm"
+        ).run(aten_graph_str)
 
         correct = func(inputs)
         self.assertTrue(same(out, correct))
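
For readers unfamiliar with the assertion style: FileCheck verifies that its patterns occur in the given relative order within the string it is run on. A standalone illustration (not part of the test; the string contents are invented):

    from torch.testing import FileCheck

    # Passes because the four patterns appear in this order; it would fail if,
    # for example, "wait_tensor" preceded the first "mm".
    graph_str = "all_reduce ... mm ... wait_tensor ... mm"
    FileCheck().check("all_reduce").check("mm").check("wait_tensor").check("mm").run(graph_str)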

torch/_inductor/config.py

Lines changed: 4 additions & 3 deletions
@@ -910,9 +910,10 @@ class aten_distributed_optimizations:
         None
     )
 
-    # Benchmark collectives using CUDA events instead of analytical model
-    # When enabled, uses power-of-2 rounding, interpolation, and distributed sync
-    benchmark_collectives: bool = False
+    # Method for estimating collective runtime
+    # "analytical": Use bandwidth formulas (default)
+    # "benchmark": Use CUDA events with power-of-2 rounding and interpolation
+    collective_estimator: Literal["analytical", "benchmark"] = "analytical"
 
 
 def parallel_compile_enabled_internally() -> bool:
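
Usage sketch (not from this commit): the new knob can be set directly on the config module or scoped with config.patch, mirroring the updated test above.

    from torch._inductor import config as inductor_config

    # Scoped override, as the test does:
    with inductor_config.patch(
        {"aten_distributed_optimizations.collective_estimator": "benchmark"}
    ):
        pass  # torch.compile(...) and run the distributed workload here

    # Or set it process-wide (the default remains "analytical"):
    inductor_config.aten_distributed_optimizations.collective_estimator = "benchmark"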

torch/_inductor/fx_passes/node_runtime_estimation.py

Lines changed: 11 additions & 4 deletions
@@ -10,8 +10,7 @@
 from typing import Any, Optional
 
 import torch
-from torch._dynamo.utils import dynamo_timed
-from torch._logging import getArtifactLogger, trace_structured
+from torch._logging import getArtifactLogger
 
 
 # Setup logger for artifact logging
@@ -22,9 +21,11 @@
 # Cache (following overlap_scheduling.py)
 # ============================================================================
 
+
 @functools.cache
 def get_estimation_cache() -> Any:
     from torch._inductor.codecache import LocalCache
+
     return LocalCache()
 
 
@@ -58,6 +59,7 @@ def set_cached_runtime(key: str, value: float) -> None:
 # Utilities
 # ============================================================================
 
+
 def get_hint(x: int | torch.SymInt) -> Optional[int]:
     if isinstance(x, int):
         return x
@@ -83,6 +85,7 @@ def can_benchmark_collective() -> bool:
 # Collective Benchmarking
 # ============================================================================
 
+
 def _benchmark_collective_with_cuda_events_impl(
     n: torch.fx.Node,
     args: tuple[Any, ...],
@@ -178,10 +181,14 @@ def extract_tensor_info(t: torch.Tensor) -> torch.Tensor:
 
     # Find power-of-2 BYTE bounds
     upper_pow2_bytes = next_power_of_2(actual_bytes)
-    lower_pow2_bytes = upper_pow2_bytes if upper_pow2_bytes == actual_bytes else upper_pow2_bytes // 2
+    lower_pow2_bytes = (
+        upper_pow2_bytes if upper_pow2_bytes == actual_bytes else upper_pow2_bytes // 2
+    )
 
     # Helper to benchmark a specific power-of-2 byte size
-    def benchmark_bytes(bytes_pow2: int, dtype: torch.dtype) -> tuple[float | None, str]:
+    def benchmark_bytes(
+        bytes_pow2: int, dtype: torch.dtype
+    ) -> tuple[float | None, str]:
         # Cache key by BYTES (dtype-agnostic)
         key = f"{n.target}: ({bytes_pow2} bytes)"
 
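
The hunks above are mostly line-wrapping, but they touch the core of the benchmark-based estimator: the message size is rounded to its bracketing powers of two, each bracket is benchmarked and cached by byte count, and the runtime for the actual size is interpolated. A minimal, self-contained sketch of that idea (next_power_of_2 is a local re-implementation and benchmark_at is a hypothetical stand-in for this file's benchmark_bytes helper, not its real API):

    def next_power_of_2(n: int) -> int:
        # Smallest power of two >= n (illustrative helper).
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    def estimate_runtime_ms(actual_bytes: int, benchmark_at) -> float:
        upper = next_power_of_2(actual_bytes)
        lower = upper if upper == actual_bytes else upper // 2
        t_lower, t_upper = benchmark_at(lower), benchmark_at(upper)
        if upper == lower:
            return t_lower
        # Linear interpolation between the two bracketing power-of-2 sizes.
        frac = (actual_bytes - lower) / (upper - lower)
        return t_lower + frac * (t_upper - t_lower)

    # Example: if 1 MiB measures 0.20 ms and 2 MiB measures 0.35 ms,
    # a 1.5 MiB message is estimated at roughly 0.275 ms.
    timings = {1 << 20: 0.20, 1 << 21: 0.35}
    print(estimate_runtime_ms(1_572_864, timings.__getitem__))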

torch/_inductor/fx_passes/overlap_scheduling.py

Lines changed: 102 additions & 50 deletions
@@ -6,7 +6,7 @@
 from collections import Counter, defaultdict
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Any, Callable
+from typing import Any, Callable, Literal
 
 import torch
 import torch.fx as fx
@@ -63,10 +63,12 @@ def estimate_collective_time(
 
     # Use benchmarking if configured
     from torch._inductor import config
+
     if config.aten_distributed_optimizations.benchmark_collectives:
         from torch._inductor.fx_passes.node_runtime_estimation import (
             benchmark_collective_with_cuda_events,
         )
+
         # Use cache during estimation
         runtime, _ = benchmark_collective_with_cuda_events(n, nruns=2)
         if runtime is not None:
@@ -258,6 +260,7 @@ def __init__(
         compute_overlap_multipler: float,
         max_coll_distance: int,
         custom_runtime_estimation: Callable[[fx.Node], float | None] | None,
+        collective_estimator: Literal["analytical", "benchmark"],
     ):
         self.gm = gm
         self.graph = gm.graph
@@ -268,6 +271,7 @@ def __init__(
         self.collective_bucketing = collective_bucketing
         self.insert_overlap_deps = insert_overlap_deps
         self.max_compute_pre_fetch = max_compute_pre_fetch
+        self.collective_estimator = collective_estimator
 
         # Build structures
         stable_topological_sort(self.graph)
@@ -370,11 +374,56 @@ def _calculate_compute_node_domination_index(self) -> dict[fx.Node, int]:
 
         return domination_index
 
+    def _log_collective_benchmarks(
+        self,
+        collective_nodes: list[fx.Node],
+        collective_keys: list[str],
+        benchmarked_medians: list[float],
+        world_size: int,
+    ) -> None:
+        """Log collective benchmarks with analytical comparisons for tlparse."""
+        collective_benchmarks = {}
+        for key, benchmarked_ms, coll_node in zip(
+            collective_keys, benchmarked_medians, collective_nodes
+        ):
+            # NCCL estimator (deterministic, no need to align)
+            nccl_ms = torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
+                coll_node, None, use_nccl_estimator=True
+            )
+
+            # Inductor analytical (deterministic, no need to align)
+            inductor_ms = torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
+                coll_node, None, use_nccl_estimator=False
+            )
+
+            collective_benchmarks[key] = {
+                "benchmarked_ms": benchmarked_ms,
+                "analytical_nccl_ms": nccl_ms,
+                "analytical_inductor_ms": inductor_ms,
+            }
+
+        # Emit tlparse artifact
+        from torch._logging import trace_structured
+
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "node_runtime_estimation",
+                "encoding": "json",
+            },
+            payload_fn=lambda: {
+                "world_size": world_size,
+                "collective_benchmarks": collective_benchmarks,
+            },
+        )
+
     def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks(
         self,
     ) -> None:
         """Align runtime estimations across ranks (compute + collectives)."""
-        log.info("Overlap scheduling: Aligning runtime estimations across all distributed ranks")
+        log.info(
+            "Overlap scheduling: Aligning runtime estimations across all distributed ranks"
+        )
 
         # Benchmark compute nodes
         runtime_estimations_keys: list[str | None] = []
@@ -387,9 +436,9 @@ def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks(
             runtime_estimations_keys.append(key)
             compute_key_count += 1
 
-        # Benchmark collectives if enabled
-        from torch._inductor import config
-        if config.aten_distributed_optimizations.benchmark_collectives:
+        # Benchmark collectives if enabled (only CUDA events - others are deterministic)
+        collective_nodes: list[fx.Node] = []
+        if self.collective_estimator == "benchmark":
            from torch._inductor.fx_passes.node_runtime_estimation import (
                benchmark_collective_with_cuda_events,
                clear_collective_cache_once,
@@ -398,73 +447,72 @@ def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks(
             # Clear stale cache once per process
             clear_collective_cache_once()
 
-            collective_nodes = [info.start_node for info in self.collective_info.values()]
+            collective_nodes = [
+                info.start_node for info in self.collective_info.values()
+            ]
+
+            # Benchmark CUDA events (non-deterministic, needs alignment)
             for n in collective_nodes:
-                val, key = benchmark_collective_with_cuda_events(n, nruns=2)
-                # Skip if benchmarking failed (None)
-                if val is not None:
-                    runtime_estimations.append(val)
-                    runtime_estimations_keys.append(key)
+                cuda_val, cuda_key = benchmark_collective_with_cuda_events(n, nruns=2)
+                if cuda_val is not None:
+                    runtime_estimations.append(cuda_val)
+                    runtime_estimations_keys.append(cuda_key)
 
-        # All gather and compute median
+        # Single all_gather and compute medians
         import torch.distributed as dist
         from torch._subclasses.fake_tensor import unset_fake_temporarily
         from torch.distributed.distributed_c10d import _get_default_group
 
         world_size = dist.get_world_size()
         pg = _get_default_group()
+
         with unset_fake_temporarily():
-            gathered_runtime_estimations: list[list[float]] = [[] for _ in range(world_size)]
-            dist.all_gather_object(gathered_runtime_estimations, runtime_estimations, pg)
+            gathered_runtime_estimations: list[list[float]] = [
+                [] for _ in range(world_size)
+            ]
+            dist.all_gather_object(
+                gathered_runtime_estimations, runtime_estimations, pg
+            )
             median_runtime_estimations = torch.median(
                 torch.tensor(gathered_runtime_estimations), dim=0
             ).values.tolist()
 
-        # Cache medians (compute vs collective use different caches)
-        collective_benchmarks = {}
-        for idx, (key, median_runtime_estimation) in enumerate(zip(runtime_estimations_keys, median_runtime_estimations)):
+        # Cache medians
+        collective_keys = []
+        collective_medians = []
+        for idx, (key, median_runtime_estimation) in enumerate(
+            zip(runtime_estimations_keys, median_runtime_estimations)
+        ):
             if key is None:
                 continue
             if idx < compute_key_count:
                 # Compute node
                 set_cached_node_time(key, median_runtime_estimation)
             else:
-                # Collective node
-                from torch._inductor.fx_passes.node_runtime_estimation import set_cached_runtime
-                set_cached_runtime(key, median_runtime_estimation)
-
-                # Get analytical estimates for comparison
-                collective_node = collective_nodes[idx - compute_key_count]
-
-                # Inductor analytical model (bandwidth formulas)
-                analytical_inductor = torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
-                    collective_node, None, use_nccl_estimator=False
+                # Collective CUDA event benchmark
+                from torch._inductor.fx_passes.node_runtime_estimation import (
+                    set_cached_runtime,
                 )
 
-                # NCCL's built-in estimator
-                analytical_nccl = torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
-                    collective_node, None, use_nccl_estimator=True
-                )
+                set_cached_runtime(key, median_runtime_estimation)
 
-                collective_benchmarks[key] = {
-                    "benchmarked_ms": median_runtime_estimation,
-                    "analytical_inductor_ms": analytical_inductor,
-                    "analytical_nccl_ms": analytical_nccl,
-                }
-
-        # Emit tlparse artifact with collective benchmarks
-        if collective_benchmarks:
-            from torch._logging import trace_structured
-            trace_structured(
-                "artifact",
-                metadata_fn=lambda: {
-                    "name": "node_runtime_estimation",
-                    "encoding": "json",
-                },
-                payload_fn=lambda: {
-                    "world_size": world_size,
-                    "collective_benchmarks": collective_benchmarks,
-                },
+                # Update CollectiveInfo with aligned benchmark
+                coll_idx = idx - compute_key_count
+                coll_node = collective_nodes[coll_idx]
+                info = self.collective_info[coll_node]
+                info.estimated_time_ms = median_runtime_estimation
+                info.exposed_time_ms = median_runtime_estimation
+
+                collective_keys.append(key)
+                collective_medians.append(median_runtime_estimation)
+
+        # Log benchmarks with analytical comparisons
+        if collective_keys:
+            self._log_collective_benchmarks(
+                collective_nodes[: len(collective_keys)],
+                collective_keys,
+                collective_medians,
+                world_size,
            )
 
         log.info("Overlap scheduling: Runtime estimations aligned")
@@ -968,6 +1016,7 @@ def schedule_overlap_bucketing(
     compute_overlap_multipler: float = 1.0,
     max_coll_distance: int = 1000,
     custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
+    collective_estimator: Literal["analytical", "benchmark"] = "analytical",
 ) -> torch.fx.GraphModule:
     """Schedule nodes to maximize compute-collective overlap.
 
@@ -984,6 +1033,8 @@ def schedule_overlap_bucketing(
         max_coll_distance: Maximum node distance for overlap or bucketing. Mostly intended to reduce compile time.
         custom_runtime_estimation: Custom runtime estimation function that estimates runtime in ms for an fx node.
             If None, uses default estimations. This is currently limited to collectives and compute nodes.
+        collective_estimator: Method for estimating collective runtime. "analytical" uses bandwidth formulas,
+            "benchmark" uses CUDA events with power-of-2 rounding and interpolation.
     """
 
     return OverlapScheduler(
@@ -995,4 +1046,5 @@ def schedule_overlap_bucketing(
         custom_runtime_estimation=custom_runtime_estimation,
         collective_bucketing=collective_bucketing,
         insert_overlap_deps=insert_overlap_deps,
+        collective_estimator=collective_estimator,
     ).run()
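
The key idea in the alignment step above: CUDA-event benchmarks come out slightly different on each rank, so every rank gathers all measurements and adopts the element-wise median, keeping downstream scheduling decisions identical across ranks. A condensed sketch of that step (assumes an initialized default process group; align_estimations and its names are illustrative, not this file's API):

    import torch
    import torch.distributed as dist

    def align_estimations(local_ms: list[float]) -> list[float]:
        # Gather every rank's measurements (same length and order on all ranks)...
        world_size = dist.get_world_size()
        gathered: list[list[float]] = [[] for _ in range(world_size)]
        dist.all_gather_object(gathered, local_ms)
        # ...then have every rank adopt the element-wise median, so the
        # overlap scheduler sees identical numbers everywhere.
        return torch.median(torch.tensor(gathered), dim=0).values.tolist()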

torch/_inductor/fx_passes/post_grad.py

Lines changed: 1 addition & 0 deletions
@@ -289,6 +289,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         "max_compute_pre_fetch",
         "custom_runtime_estimation",
         "insert_overlap_deps",
+        "collective_estimator",
     )
     for key in config_keys:
         if (val := getattr(dist_opts, key)) is not None:
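
The forwarding pattern shown here passes every non-None aten_distributed_optimizations option through to the overlap-scheduling pass by name, so adding the new key is the only change required. A runnable toy version (SimpleNamespace stands in for the real config object; the key tuple is abbreviated):

    from types import SimpleNamespace

    # Stand-in for torch._inductor.config.aten_distributed_optimizations
    dist_opts = SimpleNamespace(insert_overlap_deps=None, collective_estimator="benchmark")
    config_keys = ("insert_overlap_deps", "collective_estimator")

    # Only explicitly-set (non-None) options become keyword arguments.
    kwargs = {k: v for k in config_keys if (v := getattr(dist_opts, k)) is not None}
    print(kwargs)  # {'collective_estimator': 'benchmark'}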
