@@ -6,7 +6,7 @@
 from collections import Counter, defaultdict
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Any, Callable
+from typing import Any, Callable, Literal
 
 import torch
 import torch.fx as fx
@@ -63,10 +63,12 @@ def estimate_collective_time(
 
     # Use benchmarking if configured
     from torch._inductor import config
+
     if config.aten_distributed_optimizations.benchmark_collectives:
         from torch._inductor.fx_passes.node_runtime_estimation import (
             benchmark_collective_with_cuda_events,
         )
+
         # Use cache during estimation
         runtime, _ = benchmark_collective_with_cuda_events(n, nruns=2)
         if runtime is not None:
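For readers unfamiliar with the helper used here, CUDA-event timing boils down to the pattern below. This is a simplified, hypothetical stand-in for `benchmark_collective_with_cuda_events` (the real helper also caches results and, per the docstring added further down, applies power-of-2 rounding and interpolation):

import torch

def time_with_cuda_events(fn, nruns: int = 2) -> float:
    """Return the median wall time of fn() in milliseconds, via CUDA events."""
    times = []
    for _ in range(nruns):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()  # e.g. replay the collective on a real process group
        end.record()
        torch.cuda.synchronize()  # events must complete before elapsed_time
        times.append(start.elapsed_time(end))  # milliseconds
    return sorted(times)[len(times) // 2]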
@@ -258,6 +260,7 @@ def __init__(
         compute_overlap_multipler: float,
         max_coll_distance: int,
         custom_runtime_estimation: Callable[[fx.Node], float | None] | None,
+        collective_estimator: Literal["analytical", "benchmark"],
     ):
         self.gm = gm
         self.graph = gm.graph
@@ -268,6 +271,7 @@ def __init__(
         self.collective_bucketing = collective_bucketing
         self.insert_overlap_deps = insert_overlap_deps
         self.max_compute_pre_fetch = max_compute_pre_fetch
+        self.collective_estimator = collective_estimator
 
         # Build structures
         stable_topological_sort(self.graph)
@@ -370,11 +374,56 @@ def _calculate_compute_node_domination_index(self) -> dict[fx.Node, int]:
 
         return domination_index
 
+    def _log_collective_benchmarks(
+        self,
+        collective_nodes: list[fx.Node],
+        collective_keys: list[str],
+        benchmarked_medians: list[float],
+        world_size: int,
+    ) -> None:
+        """Log collective benchmarks with analytical comparisons for tlparse."""
+        collective_benchmarks = {}
+        for key, benchmarked_ms, coll_node in zip(
+            collective_keys, benchmarked_medians, collective_nodes
+        ):
+            # NCCL estimator (deterministic, no need to align)
+            nccl_ms = torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
+                coll_node, None, use_nccl_estimator=True
+            )
+
+            # Inductor analytical (deterministic, no need to align)
+            inductor_ms = torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
+                coll_node, None, use_nccl_estimator=False
+            )
+
+            collective_benchmarks[key] = {
+                "benchmarked_ms": benchmarked_ms,
+                "analytical_nccl_ms": nccl_ms,
+                "analytical_inductor_ms": inductor_ms,
+            }
+
+        # Emit tlparse artifact
+        from torch._logging import trace_structured
+
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "node_runtime_estimation",
+                "encoding": "json",
+            },
+            payload_fn=lambda: {
+                "world_size": world_size,
+                "collective_benchmarks": collective_benchmarks,
+            },
+        )
+
     def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks(
         self,
     ) -> None:
         """Align runtime estimations across ranks (compute + collectives)."""
-        log.info("Overlap scheduling: Aligning runtime estimations across all distributed ranks")
+        log.info(
+            "Overlap scheduling: Aligning runtime estimations across all distributed ranks"
+        )
 
         # Benchmark compute nodes
         runtime_estimations_keys: list[str | None] = []
@@ -387,9 +436,9 @@ def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks(
             runtime_estimations_keys.append(key)
             compute_key_count += 1
 
-        # Benchmark collectives if enabled
-        from torch._inductor import config
-        if config.aten_distributed_optimizations.benchmark_collectives:
+        # Benchmark collectives if enabled (only CUDA events - others are deterministic)
+        collective_nodes: list[fx.Node] = []
+        if self.collective_estimator == "benchmark":
             from torch._inductor.fx_passes.node_runtime_estimation import (
                 benchmark_collective_with_cuda_events,
                 clear_collective_cache_once,
@@ -398,73 +447,72 @@ def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks(
             # Clear stale cache once per process
             clear_collective_cache_once()
 
-            collective_nodes = [info.start_node for info in self.collective_info.values()]
+            collective_nodes = [
+                info.start_node for info in self.collective_info.values()
+            ]
+
+            # Benchmark CUDA events (non-deterministic, needs alignment)
             for n in collective_nodes:
-                val, key = benchmark_collective_with_cuda_events(n, nruns=2)
-                # Skip if benchmarking failed (None)
-                if val is not None:
-                    runtime_estimations.append(val)
-                    runtime_estimations_keys.append(key)
+                cuda_val, cuda_key = benchmark_collective_with_cuda_events(n, nruns=2)
+                if cuda_val is not None:
+                    runtime_estimations.append(cuda_val)
+                    runtime_estimations_keys.append(cuda_key)
 
-        # All gather and compute median
+        # Single all_gather and compute medians
         import torch.distributed as dist
         from torch._subclasses.fake_tensor import unset_fake_temporarily
         from torch.distributed.distributed_c10d import _get_default_group
 
         world_size = dist.get_world_size()
         pg = _get_default_group()
+
         with unset_fake_temporarily():
-            gathered_runtime_estimations: list[list[float]] = [[] for _ in range(world_size)]
-            dist.all_gather_object(gathered_runtime_estimations, runtime_estimations, pg)
+            gathered_runtime_estimations: list[list[float]] = [
+                [] for _ in range(world_size)
+            ]
+            dist.all_gather_object(
+                gathered_runtime_estimations, runtime_estimations, pg
+            )
             median_runtime_estimations = torch.median(
                 torch.tensor(gathered_runtime_estimations), dim=0
             ).values.tolist()
 
-        # Cache medians (compute vs collective use different caches)
-        collective_benchmarks = {}
-        for idx, (key, median_runtime_estimation) in enumerate(zip(runtime_estimations_keys, median_runtime_estimations)):
+        # Cache medians
+        collective_keys = []
+        collective_medians = []
+        for idx, (key, median_runtime_estimation) in enumerate(
+            zip(runtime_estimations_keys, median_runtime_estimations)
+        ):
             if key is None:
                 continue
             if idx < compute_key_count:
                 # Compute node
                 set_cached_node_time(key, median_runtime_estimation)
             else:
-                # Collective node
-                from torch._inductor.fx_passes.node_runtime_estimation import set_cached_runtime
-                set_cached_runtime(key, median_runtime_estimation)
-
-                # Get analytical estimates for comparison
-                collective_node = collective_nodes[idx - compute_key_count]
-
-                # Inductor analytical model (bandwidth formulas)
-                analytical_inductor = torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
-                    collective_node, None, use_nccl_estimator=False
+                # Collective CUDA event benchmark
+                from torch._inductor.fx_passes.node_runtime_estimation import (
+                    set_cached_runtime,
                 )
 
-                # NCCL's built-in estimator
-                analytical_nccl = torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
-                    collective_node, None, use_nccl_estimator=True
-                )
+                set_cached_runtime(key, median_runtime_estimation)
 
-                collective_benchmarks[key] = {
-                    "benchmarked_ms": median_runtime_estimation,
-                    "analytical_inductor_ms": analytical_inductor,
-                    "analytical_nccl_ms": analytical_nccl,
-                }
-
-        # Emit tlparse artifact with collective benchmarks
-        if collective_benchmarks:
-            from torch._logging import trace_structured
-            trace_structured(
-                "artifact",
-                metadata_fn=lambda: {
-                    "name": "node_runtime_estimation",
-                    "encoding": "json",
-                },
-                payload_fn=lambda: {
-                    "world_size": world_size,
-                    "collective_benchmarks": collective_benchmarks,
-                },
+                # Update CollectiveInfo with aligned benchmark
+                coll_idx = idx - compute_key_count
+                coll_node = collective_nodes[coll_idx]
+                info = self.collective_info[coll_node]
+                info.estimated_time_ms = median_runtime_estimation
+                info.exposed_time_ms = median_runtime_estimation
+
+                collective_keys.append(key)
+                collective_medians.append(median_runtime_estimation)
+
+        # Log benchmarks with analytical comparisons
+        if collective_keys:
+            self._log_collective_benchmarks(
+                collective_nodes[: len(collective_keys)],
+                collective_keys,
+                collective_medians,
+                world_size,
             )
 
         log.info("Overlap scheduling: Runtime estimations aligned")
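The gather-and-median step above is the crux of cross-rank alignment: each rank contributes its local, noisy measurements, and all ranks adopt the per-key median so scheduling decisions stay identical everywhere. A minimal standalone sketch of the pattern, with hypothetical values (assumes an initialized default process group):

import torch
import torch.distributed as dist

# Hypothetical local measurements (ms), in the same key order on every rank.
local_estimates = [0.42, 1.10, 3.75]

world_size = dist.get_world_size()
gathered: list[list[float]] = [[] for _ in range(world_size)]
dist.all_gather_object(gathered, local_estimates)

# Per-key median over ranks; every rank computes the same aligned values.
aligned = torch.median(torch.tensor(gathered), dim=0).values.tolist()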
@@ -968,6 +1016,7 @@ def schedule_overlap_bucketing(
     compute_overlap_multipler: float = 1.0,
     max_coll_distance: int = 1000,
     custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
+    collective_estimator: Literal["analytical", "benchmark"] = "analytical",
 ) -> torch.fx.GraphModule:
     """Schedule nodes to maximize compute-collective overlap.
 
@@ -984,6 +1033,8 @@ def schedule_overlap_bucketing(
         max_coll_distance: Maximum node distance for overlap or bucketing. Mostly intended to reduce compile time.
         custom_runtime_estimation: Custom runtime estimation function that estimates runtime in ms for an fx node.
             If None, uses default estimations. This is currently limited to collectives and compute nodes.
+        collective_estimator: Method for estimating collective runtime. "analytical" uses bandwidth formulas,
+            "benchmark" uses CUDA events with power-of-2 rounding and interpolation.
     """
 
     return OverlapScheduler(
@@ -995,4 +1046,5 @@ def schedule_overlap_bucketing(
         custom_runtime_estimation=custom_runtime_estimation,
         collective_bucketing=collective_bucketing,
         insert_overlap_deps=insert_overlap_deps,
+        collective_estimator=collective_estimator,
     ).run()
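Finally, a hedged usage sketch of the new knob. The import path and surrounding setup are assumptions based on this file's location; only `collective_estimator` and its two values come from the diff above:

import torch.fx as fx
# Assumed module path for this pass:
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing

def optimize(gm: fx.GraphModule) -> fx.GraphModule:
    # "analytical" (the default) keeps deterministic bandwidth-formula
    # estimates; "benchmark" times each collective with CUDA events and
    # aligns the medians across ranks, as implemented above.
    return schedule_overlap_bucketing(gm, collective_estimator="benchmark")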