@@ -32,12 +32,7 @@ def get_estimation_cache() -> Any:
 @lru_cache(maxsize=1)
 def clear_collective_cache_once() -> None:
     """Clear all collective benchmarks once per process (lru_cache ensures one-time)."""
-    from torch._inductor import config
-
-    if not config.aten_distributed_optimizations.benchmark_collectives:
-        return
-
-    get_estimation_cache().set_value("collective_benchmarking", value=None)
+    get_estimation_cache().set_value("collective_benchmarking", value={})


 def get_cached_runtime(key: str) -> Optional[float]:
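
In the hunk above, resetting the cache entry to an empty dict rather than None means the get/set helpers never have to special-case a missing entry, and @lru_cache(maxsize=1) on a zero-argument function is what guarantees the clear runs at most once per process. A minimal standalone sketch of that pattern, using a plain module dict in place of the inductor-internal get_estimation_cache():

from functools import lru_cache
from typing import Optional

# Hypothetical stand-in for the estimation cache: a plain module-level dict.
_estimation_cache: dict[str, dict[str, float]] = {}


@lru_cache(maxsize=1)
def clear_collective_cache_once() -> None:
    # lru_cache(maxsize=1) on a zero-argument function: the body runs at most
    # once per process; later calls just return the cached None result.
    # Resetting to {} (rather than None) keeps the getter/setter branch-free.
    _estimation_cache["collective_benchmarking"] = {}


def get_cached_runtime(key: str) -> Optional[float]:
    return _estimation_cache.get("collective_benchmarking", {}).get(key)


def set_cached_runtime(key: str, value: float) -> None:
    _estimation_cache.setdefault("collective_benchmarking", {})[key] = value
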
@@ -55,11 +50,6 @@ def set_cached_runtime(key: str, value: float) -> None:
     get_estimation_cache().set_value("collective_benchmarking", value=cache)


-# ============================================================================
-# Utilities
-# ============================================================================
-
-
 def get_hint(x: int | torch.SymInt) -> Optional[int]:
     if isinstance(x, int):
         return x
@@ -90,7 +80,6 @@ def _benchmark_collective_with_cuda_events_impl(
     n: torch.fx.Node,
     args: tuple[Any, ...],
     kwargs: dict[str, Any],
-    benchmark_tensor: torch.Tensor,
     nruns: int,
 ) -> float | None:
     """
@@ -99,16 +88,9 @@ def _benchmark_collective_with_cuda_events_impl(
     """
     import torch.distributed as c10d

-    # Replace tensors in args/kwargs with benchmark_tensor
-    bench_args, bench_kwargs = torch.utils._pytree.tree_map_only(
-        torch.Tensor,
-        lambda t: benchmark_tensor,
-        (args, kwargs),
-    )
-
     # Warmup: call collective once and wait
     torch.cuda.synchronize()
-    result = n.target(*bench_args, **bench_kwargs)  # type: ignore[operator]
+    result = n.target(*args, **kwargs)  # type: ignore[operator]
     torch.ops._c10d_functional.wait_tensor(result)

     # Benchmark with CUDA events
@@ -121,7 +103,7 @@ def _benchmark_collective_with_cuda_events_impl(
         end_evt = torch.cuda.Event(enable_timing=True)

         start_evt.record()
-        result = n.target(*bench_args, **bench_kwargs)  # type: ignore[operator]
+        result = n.target(*args, **kwargs)  # type: ignore[operator]
         torch.ops._c10d_functional.wait_tensor(result)
         end_evt.record()
         end_evt.synchronize()
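
The two hunks above replace bench_args/bench_kwargs with args/kwargs in the timed call (the tensor substitution now happens in the caller); the CUDA-event timing loop itself is unchanged. Pulled out of context, that timing pattern looks roughly like the helper below (time_gpu_op_ms and fn are hypothetical names; the real code times n.target(*bench_args, **bench_kwargs) followed by wait_tensor):

import torch


def time_gpu_op_ms(fn, nruns: int = 10) -> float:
    # Hypothetical standalone sketch of the CUDA-event timing used above.
    fn()                      # warmup launch
    torch.cuda.synchronize()  # drain the warmup before timing

    total_ms = 0.0
    for _ in range(nruns):
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        start_evt.record()
        fn()
        end_evt.record()
        end_evt.synchronize()  # block until the end event has actually fired
        total_ms += start_evt.elapsed_time(end_evt)  # elapsed time in ms
    return total_ms / nruns
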
@@ -172,11 +154,12 @@ def extract_tensor_info(t: torch.Tensor) -> torch.Tensor:
             actual_bytes = total_elems * t.dtype.itemsize
             actual_dtype = t.dtype
             actual_device = t.device
-            return t
+        else:
+            raise RuntimeError(f"should only be one input tensor to collective {n}")

     torch.utils._pytree.tree_map_only(torch.Tensor, extract_tensor_info, (args, kwargs))

-    if actual_bytes is None or actual_device is None or actual_bytes is None:
+    if actual_bytes is None or actual_device is None or actual_dtype is None:
         return None, ""

     # Find power-of-2 BYTE bounds
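
The hunk ends right before lower_pow2_bytes/upper_pow2_bytes are derived from actual_bytes. The PR's exact computation is not shown in this excerpt; one plausible way to get those bounds is sketched below (pow2_byte_bounds is a hypothetical helper name):

def pow2_byte_bounds(nbytes: int) -> tuple[int, int]:
    # Hypothetical sketch (assumes nbytes >= 1): largest power of two <= nbytes
    # and smallest power of two >= nbytes; they coincide when nbytes is itself
    # a power of two.
    upper = 1 << (nbytes - 1).bit_length()
    lower = upper if upper == nbytes else upper >> 1
    return lower, upper


# pow2_byte_bounds(6 * 1024 * 1024) == (4 * 1024 * 1024, 8 * 1024 * 1024)
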
@@ -203,9 +186,16 @@ def benchmark_bytes(
         # Create empty tensor for benchmarking
         benchmark_tensor = torch.empty(num_elements, dtype=dtype, device=actual_device)

+        # Replace all tensors in args/kwargs with benchmark_tensor
+        bench_args, bench_kwargs = torch.utils._pytree.tree_map_only(
+            torch.Tensor,
+            lambda t: benchmark_tensor,
+            (args, kwargs),
+        )
+
         # Benchmark using CUDA events
         runtime = _benchmark_collective_with_cuda_events_impl(
-            n, args, kwargs, benchmark_tensor, nruns
+            n, bench_args, bench_kwargs, nruns
         )

         if runtime is None:
@@ -216,7 +206,7 @@ def benchmark_bytes(
         return runtime, key

     # If exact power-of-2 bytes, just benchmark it
-    if actual_bytes == lower_pow2_bytes or actual_bytes == upper_pow2_bytes:
+    if actual_bytes in (lower_pow2_bytes, upper_pow2_bytes):
         return benchmark_bytes(actual_bytes, actual_dtype)

     # Otherwise, benchmark bounds and interpolate
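
The excerpt cuts off at the interpolation branch: for sizes that are not an exact power of two, the comment implies benchmarking both byte bounds and interpolating between the two measurements. A rough sketch under the assumption of linear interpolation (the formula and helper name are illustrative, not taken from the PR):

def interpolate_runtime_ms(
    actual_bytes: int,
    lower_pow2_bytes: int,
    upper_pow2_bytes: int,
    lower_ms: float,
    upper_ms: float,
) -> float:
    # Hypothetical linear interpolation between the runtimes measured at the
    # two power-of-2 byte bounds; the lower == upper case is already handled
    # by the exact power-of-2 branch above.
    frac = (actual_bytes - lower_pow2_bytes) / (upper_pow2_bytes - lower_pow2_bytes)
    return lower_ms + frac * (upper_ms - lower_ms)


# e.g. 6 MiB sits halfway between the 4 MiB and 8 MiB measurements:
# interpolate_runtime_ms(6 << 20, 4 << 20, 8 << 20, 1.0, 1.8) -> 1.4
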