
Commit ba5515e

Update
[ghstack-poisoned]
1 parent e68d185 commit ba5515e

2 files changed (+23 −25 lines)


test/distributed/test_aten_comm_compute_reordering.py

Lines changed: 8 additions & 0 deletions
@@ -54,6 +54,7 @@ def apply_reordering_and_get_graph(graph, out_li) -> None:
         "max_compute_pre_fetch",
         "custom_runtime_estimation",
         "insert_overlap_deps",
+        "collective_estimator",
     )
     for key in config_keys:
         if (val := getattr(dist_opts, key)) is not None:
@@ -963,6 +964,11 @@ def func(a):
             self.backend(device_type),
             fake_pg=not at_least_x_gpu(2),
         ):
+            # Clear any stale cache from previous tests
+            from torch._inductor.fx_passes.node_runtime_estimation import clear_collective_cache_once
+            clear_collective_cache_once.cache_clear()  # Reset the lru_cache
+            clear_collective_cache_once()  # Actually clear the cache
+
             inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
 
             with torch._inductor.config.patch(patches):

@@ -974,6 +980,8 @@ def func(a):
                 "mm"
             ).run(aten_graph_str)
 
+            # Test passes if compilation succeeded with benchmarking enabled
+            # Cache verification is tricky due to multiprocess test setup
             correct = func(inputs)
             self.assertTrue(same(out, correct))
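Note on the cache_clear-then-call pair added above: `@lru_cache(maxsize=1)` on a zero-argument function makes it a run-once helper, so an earlier test in the same process may already have "used up" the one call. A minimal standalone sketch of the pattern (illustrative names, not the real module):

from functools import lru_cache

@lru_cache(maxsize=1)
def clear_once() -> None:
    # Body executes only on the first call; later calls return the cached None
    print("cache cleared")

clear_once()              # runs the body
clear_once()              # no-op: result served from the lru_cache
clear_once.cache_clear()  # re-arms the helper (what the test does first)
clear_once()              # runs the body again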

torch/_inductor/fx_passes/node_runtime_estimation.py

Lines changed: 15 additions & 25 deletions
@@ -32,12 +32,7 @@ def get_estimation_cache() -> Any:
 @lru_cache(maxsize=1)
 def clear_collective_cache_once() -> None:
     """Clear all collective benchmarks once per process (lru_cache ensures one-time)."""
-    from torch._inductor import config
-
-    if not config.aten_distributed_optimizations.benchmark_collectives:
-        return
-
-    get_estimation_cache().set_value("collective_benchmarking", value=None)
+    get_estimation_cache().set_value("collective_benchmarking", value={})
 
 
 def get_cached_runtime(key: str) -> Optional[float]:
@@ -55,11 +50,6 @@ def set_cached_runtime(key: str, value: float) -> None:
     get_estimation_cache().set_value("collective_benchmarking", value=cache)
 
 
-# ============================================================================
-# Utilities
-# ============================================================================
-
-
 def get_hint(x: int | torch.SymInt) -> Optional[int]:
     if isinstance(x, int):
         return x
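For orientation, a hedged sketch of the round-trip that `get_cached_runtime`/`set_cached_runtime` perform against the "collective_benchmarking" entry; `_store` is a stand-in for `get_estimation_cache()` (the real bodies are not shown in this diff), and it illustrates why the first hunk now clears to `{}` rather than `None` — lookups must miss cleanly, not raise:

from typing import Optional

_store: dict = {}  # stand-in for get_estimation_cache()

def set_cached_runtime_sketch(key: str, value: float) -> None:
    cache = _store.get("collective_benchmarking") or {}
    cache[key] = value
    _store["collective_benchmarking"] = cache

def get_cached_runtime_sketch(key: str) -> Optional[float]:
    return (_store.get("collective_benchmarking") or {}).get(key)

_store["collective_benchmarking"] = {}  # cleared state
assert get_cached_runtime_sketch("all_reduce:2097152") is None  # clean miss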
@@ -90,7 +80,6 @@ def _benchmark_collective_with_cuda_events_impl(
     n: torch.fx.Node,
     args: tuple[Any, ...],
     kwargs: dict[str, Any],
-    benchmark_tensor: torch.Tensor,
     nruns: int,
 ) -> float | None:
     """
@@ -99,16 +88,9 @@
     """
     import torch.distributed as c10d
 
-    # Replace tensors in args/kwargs with benchmark_tensor
-    bench_args, bench_kwargs = torch.utils._pytree.tree_map_only(
-        torch.Tensor,
-        lambda t: benchmark_tensor,
-        (args, kwargs),
-    )
-
     # Warmup: call collective once and wait
     torch.cuda.synchronize()
-    result = n.target(*bench_args, **bench_kwargs)  # type: ignore[operator]
+    result = n.target(*args, **kwargs)  # type: ignore[operator]
     torch.ops._c10d_functional.wait_tensor(result)
 
     # Benchmark with CUDA events
@@ -121,7 +103,7 @@
         end_evt = torch.cuda.Event(enable_timing=True)
 
         start_evt.record()
-        result = n.target(*bench_args, **bench_kwargs)  # type: ignore[operator]
+        result = n.target(*args, **kwargs)  # type: ignore[operator]
         torch.ops._c10d_functional.wait_tensor(result)
         end_evt.record()
         end_evt.synchronize()
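The two hunks above strip the tensor substitution out of the timing helper, leaving the standard CUDA-event measurement. A minimal standalone sketch of that pattern (the median-of-runs reduction here is an assumption, not shown in the diff):

import torch

def time_cuda_ms(fn, nruns: int = 5) -> float:
    fn()                      # warmup
    torch.cuda.synchronize()
    times = []
    for _ in range(nruns):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        end.synchronize()     # block until `end` has been reached on the stream
        times.append(start.elapsed_time(end))  # elapsed GPU time in milliseconds
    return sorted(times)[len(times) // 2]      # median over nruns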
@@ -172,11 +154,12 @@ def extract_tensor_info(t: torch.Tensor) -> torch.Tensor:
             actual_bytes = total_elems * t.dtype.itemsize
             actual_dtype = t.dtype
             actual_device = t.device
-            return t
+        else:
+            raise RuntimeError(f"should only be one input tensor to collective {n}")
 
     torch.utils._pytree.tree_map_only(torch.Tensor, extract_tensor_info, (args, kwargs))
 
-    if actual_bytes is None or actual_device is None or actual_bytes is None:
+    if actual_bytes is None or actual_device is None or actual_dtype is None:
         return None, ""
 
     # Find power-of-2 BYTE bounds
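"Power-of-2 BYTE bounds" means bracketing the collective's actual message size between the nearest powers of two, so that only round sizes get benchmarked and cached. A sketch of that computation (the helper name is hypothetical):

import math

def pow2_bounds(actual_bytes: int) -> tuple[int, int]:
    # Nearest powers of two at or below / at or above the actual size
    lower = 1 << math.floor(math.log2(actual_bytes))
    upper = 1 << math.ceil(math.log2(actual_bytes))
    return lower, upper

assert pow2_bounds(3_000_000) == (2**21, 2**22)  # 2 MiB .. 4 MiB bracket
assert pow2_bounds(2**22) == (2**22, 2**22)      # exact power of two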
@@ -203,9 +186,16 @@ def benchmark_bytes(
         # Create empty tensor for benchmarking
         benchmark_tensor = torch.empty(num_elements, dtype=dtype, device=actual_device)
 
+        # Replace all tensors in args/kwargs with benchmark_tensor
+        bench_args, bench_kwargs = torch.utils._pytree.tree_map_only(
+            torch.Tensor,
+            lambda t: benchmark_tensor,
+            (args, kwargs),
+        )
+
         # Benchmark using CUDA events
         runtime = _benchmark_collective_with_cuda_events_impl(
-            n, args, kwargs, benchmark_tensor, nruns
+            n, bench_args, bench_kwargs, nruns
         )
 
         if runtime is None:
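`tree_map_only` rebuilds the (args, kwargs) pytree while substituting only leaves of the requested type, so non-tensor arguments such as group names or reduce ops pass through untouched. A small illustration:

import torch
import torch.utils._pytree as pytree

bench = torch.empty(1024)
args = (torch.randn(4, 4), "avg", {"group_name": "default"})
bench_args = pytree.tree_map_only(torch.Tensor, lambda t: bench, args)
assert bench_args[0] is bench       # tensor leaf swapped for the benchmark tensor
assert bench_args[1:] == args[1:]   # non-tensor leaves unchanged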
@@ -216,7 +206,7 @@
         return runtime, key
 
     # If exact power-of-2 bytes, just benchmark it
-    if actual_bytes == lower_pow2_bytes or actual_bytes == upper_pow2_bytes:
+    if actual_bytes in (lower_pow2_bytes, upper_pow2_bytes):
         return benchmark_bytes(actual_bytes, actual_dtype)
 
     # Otherwise, benchmark bounds and interpolate
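The interpolation itself falls outside this hunk; a hedged sketch of how a runtime for a non-power-of-2 size could be derived from the two benchmarked bounds (linear-in-bytes is an assumption, and the helper name is hypothetical):

def interpolate_runtime_ms(
    actual_bytes: int,
    lower_bytes: int, lower_ms: float,
    upper_bytes: int, upper_ms: float,
) -> float:
    # Linear interpolation between runtimes measured at the power-of-2 bounds
    frac = (actual_bytes - lower_bytes) / (upper_bytes - lower_bytes)
    return lower_ms + frac * (upper_ms - lower_ms)

# e.g. a 3 MB collective between 2 MiB (0.10 ms) and 4 MiB (0.18 ms) benchmarks
print(interpolate_runtime_ms(3_000_000, 2**21, 0.10, 2**22, 0.18))  # ~0.134 ms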
