diff --git a/.ci/docker/centos-rocm/Dockerfile b/.ci/docker/centos-rocm/Dockerfile
index bcf028812a887..6cb82a1f770c5 100644
--- a/.ci/docker/centos-rocm/Dockerfile
+++ b/.ci/docker/centos-rocm/Dockerfile
@@ -77,6 +77,9 @@ RUN rm install_rocm.sh
 COPY ./common/install_rocm_magma.sh install_rocm_magma.sh
 RUN bash ./install_rocm_magma.sh
 RUN rm install_rocm_magma.sh
+COPY ./common/install_amdsmi.sh install_amdsmi.sh
+RUN bash ./install_amdsmi.sh
+RUN rm install_amdsmi.sh
 ENV PATH /opt/rocm/bin:$PATH
 ENV PATH /opt/rocm/hcc/bin:$PATH
 ENV PATH /opt/rocm/hip/bin:$PATH
diff --git a/.ci/docker/common/install_amdsmi.sh b/.ci/docker/common/install_amdsmi.sh
new file mode 100644
index 0000000000000..c16c262f0e61f
--- /dev/null
+++ b/.ci/docker/common/install_amdsmi.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+
+set -ex
+
+cd /opt/rocm/share/amd_smi && pip install .
diff --git a/.ci/docker/common/install_rocm.sh b/.ci/docker/common/install_rocm.sh
index 5659b487f8380..6b746d2f92b48 100644
--- a/.ci/docker/common/install_rocm.sh
+++ b/.ci/docker/common/install_rocm.sh
@@ -39,7 +39,8 @@ install_ubuntu() {
                    rocm-libs \
                    rccl \
                    rocprofiler-dev \
-                   roctracer-dev
+                   roctracer-dev \
+                   amd-smi-lib
     if [[ $(ver $ROCM_VERSION) -ge $(ver 6.1) ]]; then
         DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev
@@ -106,7 +107,8 @@ install_centos() {
                    rocm-libs \
                    rccl \
                    rocprofiler-dev \
-                   roctracer-dev
+                   roctracer-dev \
+                   amd-smi-lib
   # precompiled miopen kernels; search for all unversioned packages
   # if search fails it will abort this script; use true to avoid case where search fails
diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile
index 9964f5c3fa91b..cc43d9ec24142 100644
--- a/.ci/docker/ubuntu-rocm/Dockerfile
+++ b/.ci/docker/ubuntu-rocm/Dockerfile
@@ -78,6 +78,11 @@ ENV MAGMA_HOME /opt/rocm/magma
 ENV LANG C.UTF-8
 ENV LC_ALL C.UTF-8
+# Install amdsmi
+COPY ./common/install_amdsmi.sh install_amdsmi.sh
+RUN bash ./install_amdsmi.sh
+RUN rm install_amdsmi.sh
+
 # (optional) Install non-default CMake version
 ARG CMAKE_VERSION
 COPY ./common/install_cmake.sh install_cmake.sh
diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh
index 4aa5dc39d0f5f..46f91f71283ff 100755
--- a/.ci/pytorch/build.sh
+++ b/.ci/pytorch/build.sh
@@ -289,6 +289,9 @@ else
     fi
     WERROR=1 python setup.py bdist_wheel
   else
+    if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then
+      source .ci/pytorch/install_cache_xla.sh
+    fi
     python setup.py bdist_wheel
   fi
   pip_install_whl "$(echo dist/*.whl)"
diff --git a/.ci/pytorch/install_cache_xla.sh b/.ci/pytorch/install_cache_xla.sh
new file mode 100755
index 0000000000000..bfc2da177f6ed
--- /dev/null
+++ b/.ci/pytorch/install_cache_xla.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+# Script for installing sccache on the xla build job, which uses xla's docker
+# image and doesn't have sccache installed on it. This is mostly copied from
+# .ci/docker/install_cache.sh. Changes are: removing checks that will always
+# return the same thing (e.g. checks for rocm and CUDA), changing the path
+# where sccache is installed, and not changing /etc/environment.
+ +set -ex + +install_binary() { + echo "Downloading sccache binary from S3 repo" + curl --retry 3 https://s3.amazonaws.com/ossci-linux/sccache -o /tmp/cache/bin/sccache +} + +mkdir -p /tmp/cache/bin +mkdir -p /tmp/cache/lib +export PATH="/tmp/cache/bin:$PATH" + +install_binary +chmod a+x /tmp/cache/bin/sccache + +function write_sccache_stub() { + # Unset LD_PRELOAD for ps because of asan + ps issues + # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=90589 + # shellcheck disable=SC2086 + # shellcheck disable=SC2059 + printf "#!/bin/sh\nif [ \$(env -u LD_PRELOAD ps -p \$PPID -o comm=) != sccache ]; then\n exec sccache $(which $1) \"\$@\"\nelse\n exec $(which $1) \"\$@\"\nfi" > "/tmp/cache/bin/$1" + chmod a+x "/tmp/cache/bin/$1" +} + +write_sccache_stub cc +write_sccache_stub c++ +write_sccache_stub gcc +write_sccache_stub g++ +write_sccache_stub clang +write_sccache_stub clang++ diff --git a/BUILD.bazel b/BUILD.bazel index 831d64b44c2f6..3f7e6327452c0 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -772,7 +772,7 @@ cc_library( [ "torch/*.h", "torch/csrc/**/*.h", - "torch/csrc/distributed/c10d/**/*.hpp", + "torch/csrc/distributed/c10d/*.hpp", "torch/lib/libshm/*.h", ], exclude = [ diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 2d01bdeca500b..f0c73cde2dda3 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -35,7 +35,11 @@ namespace { // directly against incoming TensorImpl*s. using weakref_type = c10::weak_intrusive_ptr; using val_type = std::tuple; -ska::flat_hash_map cached_casts; + +static ska::flat_hash_map& get_cached_casts() { + static ska::flat_hash_map cached_casts; + return cached_casts; +} std::mutex cached_casts_mutex; @@ -82,7 +86,7 @@ thread_local bool cache_enabled = true; void clear_cache() { const std::lock_guard lock(cached_casts_mutex); - cached_casts.clear(); + get_cached_casts().clear(); } int increment_nesting() { @@ -124,12 +128,12 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_ if (can_try_cache) { const std::lock_guard lock(cached_casts_mutex); - auto it = cached_casts.find(arg.unsafeGetTensorImpl()); - if (it != cached_casts.end()) { + auto it = get_cached_casts().find(arg.unsafeGetTensorImpl()); + if (it != get_cached_casts().end()) { return std::get<1>(it->second); } else { auto casted_arg = arg.to(to_type); - cached_casts.emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg}); + get_cached_casts().emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg}); return casted_arg; } } else { diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 01d3d513c4ebb..e93a8561b2ced 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -268,7 +268,7 @@ void CUDAGraph::debug_dump(const std::string& debug_path) { TORCH_WARN("DEBUG: calling debug_dump()"); if (has_graph_) { TORCH_WARN("DEBUG: calling cudaGraphDebugDotPrint() with ", debug_path); - C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), 1<<10)); // most verbose output + C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), cudaGraphDebugDotFlagsVerbose)); // most verbose output AT_CUDA_CHECK(cudaGraphDestroy(graph_)); } } else { diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index 1cc53da3584ea..587809ea57c8d 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ 
b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -308,7 +308,18 @@ void gemm_notrans_( } -static float compute_dot(const float16_t *a, const float16_t *b, int64_t l) { +inline float32x4_t load_as_float32x4(const Half* ptr) { + return vcvt_f32_f16(vld1_f16(reinterpret_cast(ptr))); +} + +inline float32x4_t load_as_float32x4(const BFloat16* ptr) { + int32x4_t shift = vdupq_n_s32(16); + uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast(ptr))); + return vreinterpretq_f32_u32(vshlq_u32(as_int, shift)); +} + +template +static float compute_dot(const T* a, const T* b, int64_t l) { if ((l&3) != 0) { return sum(l, [&](int64_t i) -> float { return float(a[i]) * float(b[i]); @@ -316,8 +327,8 @@ static float compute_dot(const float16_t *a, const float16_t *b, int64_t l) { } float32x4_t rcv = vdupq_n_f32(0); for (int64_t idx = 0; idx < l; idx += 4) { - float32x4_t aVec = vcvt_f32_f16(vld1_f16(a + idx)); - float32x4_t bVec = vcvt_f32_f16(vld1_f16(b + idx)); + float32x4_t aVec = load_as_float32x4(a + idx); + float32x4_t bVec = load_as_float32x4(b + idx); rcv = vaddq_f32(rcv, vmulq_f32(aVec, bVec)); } auto sum = vpaddq_f32(rcv, rcv); @@ -343,7 +354,35 @@ void gemm_transa_( for (const auto i : c10::irange(begin, end)) { const auto *b_ = b; for (const auto j : c10::irange(n)) { - const auto dot = compute_dot(reinterpret_cast(a_), reinterpret_cast(b_), k); + const auto dot = compute_dot(a_, b_, k); + b_ += ldb; + if (beta == 0) { + c[j*ldc+i] = alpha*dot; + } else { + c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot; + } + } + a_ += lda; + } + }); +} + +template <> +void gemm_transa_( + TransposeType transa, + int64_t m, int64_t n, int64_t k, + float alpha, + const at::BFloat16 *a, int64_t lda, + const at::BFloat16 *b, int64_t ldb, + float beta, + at::BFloat16 *c, int64_t ldc) { + // c = alpha * (a.T @ b) + beta * c + parallel_for(0, m, 1, [&](int64_t begin, int64_t end) { + const auto *a_ = a + begin * lda; + for (const auto i : c10::irange(begin, end)) { + const auto *b_ = b; + for (const auto j : c10::irange(n)) { + const auto dot = compute_dot(a_, b_, k); b_ += ldb; if (beta == 0) { c[j*ldc+i] = alpha*dot; diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm index f378af1326a73..4da5c302214d1 100644 --- a/aten/src/ATen/native/mps/operations/TensorCompare.mm +++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm @@ -276,8 +276,6 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements, bool invert, const Tensor& out, string op_name) { - TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS), - "isin_Tensor_Tensor_out supported on MPS from MacOs_14_0 onwards"); if (elements.numel() == 0) { return; } @@ -295,6 +293,10 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements, TORCH_CHECK(elements.is_mps() && test_elements.is_mps()); TORCH_CHECK(elements.dtype() == test_elements.dtype()); + TORCH_CHECK( + !(!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) && !supportedFloatingType(elements.scalar_type())), + "isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. 
Received dtype: ", + elements.scalar_type()); @autoreleasepool { string key = diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 6ea7a31a39150..f40f40396992e 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -3485,6 +3485,18 @@ def get_example_inputs(self): action="store_true", help="Measure speedup with TorchInductor", ) + group.add_argument( + "--quantization", + choices=[ + "int8dynamic", + "int8weightonly", + "int4weightonly", + "autoquant", + "noquant", + ], + default=None, + help="Measure speedup of torchao quantization with TorchInductor baseline", + ) group.add_argument( "--export", action="store_true", @@ -3679,6 +3691,9 @@ def run(runner, args, original_dir=None): if args.inductor: assert args.backend is None args.backend = "inductor" + if args.quantization: + assert args.backend is None + args.backend = "torchao" if args.dynamic_batch_only: args.dynamic_shapes = True torch._dynamo.config.assume_static_by_default = True @@ -3957,6 +3972,20 @@ def run(runner, args, original_dir=None): # AOTInductor doesn't support control flow yet runner.skip_models.update(runner.skip_models_due_to_control_flow) + elif args.backend == "torchao": + assert "cuda" in args.devices, "Quantization requires CUDA device." + assert args.bfloat16, "Quantization requires dtype bfloat16." + from .torchao import setup_baseline, torchao_optimize_ctx + + setup_baseline() + baseline_ctx = functools.partial( + torch.compile, + backend="inductor", + fullgraph=args.nopython, + mode=args.inductor_compile_mode, + ) + runner.model_iter_fn = baseline_ctx(runner.model_iter_fn) + optimize_ctx = torchao_optimize_ctx(args.quantization) else: optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython) experiment = speedup_experiment diff --git a/benchmarks/dynamo/microbenchmarks/analyze_templates.py b/benchmarks/dynamo/microbenchmarks/analyze_templates.py new file mode 100644 index 0000000000000..65fa547123a4b --- /dev/null +++ b/benchmarks/dynamo/microbenchmarks/analyze_templates.py @@ -0,0 +1,219 @@ +""" +This script uses linear programming to analyze outputs of triton mm config tuning. +To generate output that can be fed into this script set the env varTORCHINDUCTOR_MM_LOGGING_FILE. + +That file can be fed into this script to generate the minimizes total, weighted matmul time as a function of allowed templates. 
+""" +import json + +import click +import pulp + + +def parse_log_file(file_path): + with open(file_path) as f: + logs = json.load(f) + + occurrence_count = {} + benchmark_logs = {} + + # Parse the logs + for entry in logs: + if "invoke" in entry: + shape = entry["invoke"] + if shape not in occurrence_count: + occurrence_count[shape] = 0 + occurrence_count[shape] += 1 + else: + for shape, timings in entry.items(): + if shape not in benchmark_logs: + benchmark_logs[shape] = [] + benchmark_logs[shape].extend(timings) + + return occurrence_count, benchmark_logs + + +def optimize_templates(N, occurrence_count, benchmark_logs, verbose=False): + # Set of all possible Triton templates keyed by their attributes + triton_templates = set() + for timings in benchmark_logs.values(): + for timing in timings: + if timing["type"] == "triton": + triton_templates.add( + ( + timing["BLOCK_M"], + timing["BLOCK_N"], + timing["BLOCK_K"], + timing["num_stages"], + timing["num_warps"], + ) + ) + + # Print the initial data + if verbose: + print("Occurrence Count:", occurrence_count) + print("Triton Templates:", triton_templates) + + # Create a dictionary to store template selection variables + template_vars = { + template: pulp.LpVariable(f"Template_{template}", 0, 1, pulp.LpBinary) + for template in triton_templates + } + + # Variables to select specific timing option for each shape + selection_vars = { + (shape, "cublas"): pulp.LpVariable( + f"Select_{shape}_cublas", 0, 1, pulp.LpBinary + ) + for shape in occurrence_count + } + for shape in occurrence_count: + for template in triton_templates: + selection_vars[(shape, template)] = pulp.LpVariable( + f"Select_{shape}_{template}", 0, 1, pulp.LpBinary + ) + + # Variables for the total time for each shape + min_time_vars = pulp.LpVariable.dicts( + "MinTime", occurrence_count.keys(), 0, None, pulp.LpContinuous + ) + + # Define the problem + prob = pulp.LpProblem("MatrixMultiplicationOptimization", pulp.LpMinimize) + + # Objective: Minimize the weighted total time + prob += pulp.lpSum( + [occurrence_count[shape] * min_time_vars[shape] for shape in occurrence_count] + ) + + # Constraints to select exactly N templates + prob += pulp.lpSum([template_vars[template] for template in triton_templates]) == N + + # Store triton options per shape for debugging + triton_options_per_shape = {} + + # Constraints for the total time for each shape + for shape in occurrence_count: + # Get cuBLAS time + cublas_times = [ + timing["time"] + for timing in benchmark_logs[shape] + if timing["type"] == "cublas" + ] + min_cublas_time = min(cublas_times) + + # Collect Triton options + triton_options = [] + for template in triton_templates: + triton_times = [ + timing["time"] + for timing in benchmark_logs[shape] + if timing["type"] == "triton" + and ( + timing["BLOCK_M"], + timing["BLOCK_N"], + timing["BLOCK_K"], + timing["num_stages"], + timing["num_warps"], + ) + == template + ] + if triton_times: + min_triton_time = min(triton_times) + triton_options.append((min_triton_time, template)) + + # Save triton options for debugging + triton_options_per_shape[shape] = triton_options + + # Ensure exactly one timing option is selected for each shape + prob += ( + pulp.lpSum( + [selection_vars[(shape, "cublas")]] + + [ + selection_vars[(shape, template)] + for triton_time, template in triton_options + ] + ) + == 1 + ) + + # Ensure min_time_vars[shape] matches the selected timing option + prob += min_time_vars[shape] == ( + selection_vars[(shape, "cublas")] * min_cublas_time + + pulp.lpSum( + [ + 
selection_vars[(shape, template)] * triton_time + for triton_time, template in triton_options + ] + ) + ) + + # Ensure Triton templates can only be selected if they are included in the N allowed templates + for triton_time, template in triton_options: + prob += selection_vars[(shape, template)] <= template_vars[template] + + # Print the constraints + if verbose: + print("Constraints:") + for constraint in prob.constraints.values(): + print(constraint) + + # Solve the problem with suppressed output + prob.solve(pulp.PULP_CBC_CMD(msg=False)) + + # Output the selected templates and their configurations + selected_templates = [ + template + for template in triton_templates + if pulp.value(template_vars[template]) == 1 + ] + total_time = sum( + pulp.value(min_time_vars[shape]) * occurrence_count[shape] + for shape in occurrence_count + ) + + # Print the values of the decision variables after solving + if verbose: + print("Decision Variable Values:") + for var in prob.variables(): + print(f"{var.name} = {var.varValue}") + + # # Debugging information + if verbose: + for shape in occurrence_count: + print(f"Shape: {shape}") + print(f" Min Time: {pulp.value(min_time_vars[shape])}") + print(f" Occurrences: {occurrence_count[shape]}") + print( + f" Min CuBLAS Time: {min_cublas_time} Selected: {pulp.value(selection_vars[(shape, 'cublas')])}" + ) + for triton_time, template in triton_options_per_shape[shape]: + print( + f" Triton Template: {template} Time: {triton_time} Selected: {pulp.value(selection_vars[(shape, template)])}" + ) + + return selected_templates, total_time + + +# Main code to parse the log file and optimize templates +@click.command() +@click.argument("filename") +@click.option("--min-templates", default=0, help="Minimum number of templates.") +@click.option("--max-templates", default=10, help="Maximum number of templates.") +@click.option("--verbose", is_flag=True, help="Enable verbose output.") +def main(filename, min_templates, max_templates, verbose): + occurrence_count, benchmark_logs = parse_log_file(filename) + times = [] + for N in range(min_templates, max_templates + 1): + selected_templates, total_time = optimize_templates( + N, occurrence_count, benchmark_logs, verbose + ) + print(f"N = {N}") + print(f"Selected Templates: {selected_templates}") + print(f"Total Weighted Time: {total_time}") + times.append(total_time) + print(times) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/torchao.py b/benchmarks/dynamo/torchao.py new file mode 100644 index 0000000000000..29e7d55d76ce1 --- /dev/null +++ b/benchmarks/dynamo/torchao.py @@ -0,0 +1,54 @@ +from typing import Any, Callable + +import torch + + +def setup_baseline(): + torch._dynamo.epilogue_fusion = False + torch._dynamo.config.automatic_dynamic_shapes = False + torch._dynamo.config.force_parameter_static_shapes = False + torch._dynamo.config.cache_size_limit = 10000 + torch._inductor.config.force_fuse_int_mm_with_mul = True + torch._inductor.config.use_mixed_mm = True + + +def torchao_optimize_ctx(quantization: str): + import torchao + from torchao.quantization import ( + change_linear_weights_to_int4_woqtensors, + change_linear_weights_to_int8_dqtensors, + change_linear_weights_to_int8_woqtensors, + ) + + def inner(model_iter_fn: Callable): + def _torchao_apply(module: torch.nn.Module, example_inputs: Any): + if getattr(module, "_quantized", None) is None: + if quantization == "int8dynamic": + change_linear_weights_to_int8_dqtensors(module) + elif quantization == "int8weightonly": + 
change_linear_weights_to_int8_woqtensors(module)
+                elif quantization == "int4weightonly":
+                    change_linear_weights_to_int4_woqtensors(module)
+                elif quantization == "autoquant":
+                    torchao.autoquant(module, error_on_unseen=False)
+                    if isinstance(example_inputs, dict):
+                        module(**example_inputs)
+                    else:
+                        module(*example_inputs)
+                    from torchao.quantization.autoquant import AUTOQUANT_CACHE
+
+                    assert (
+                        len(AUTOQUANT_CACHE) > 0
+                    ), f"Err: found no autoquantizable layers in model {type(module)}, stopping autoquantization"
+                elif quantization == "noquant":
+                    pass
+                else:
+                    raise AssertionError(
+                        f"Unsupported quantization mode {quantization}."
+                    )
+                setattr(module, "_quantized", True)  # noqa: B010
+            model_iter_fn(module, example_inputs)
+
+        return _torchao_apply
+
+    return inner
diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py
index 3086bddc4bb5b..2a9437e08b698 100755
--- a/benchmarks/dynamo/torchbench.py
+++ b/benchmarks/dynamo/torchbench.py
@@ -423,13 +423,19 @@ def compute_loss(self, pred):
     def forward_pass(self, mod, inputs, collect_outputs=True):
         with self.autocast(**self.autocast_arg):
-            return mod(*inputs)
+            if isinstance(inputs, dict):
+                return mod(**inputs)
+            else:
+                return mod(*inputs)
 
     def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
         cloned_inputs = clone_inputs(inputs)
         self.optimizer_zero_grad(mod)
         with self.autocast(**self.autocast_arg):
-            pred = mod(*cloned_inputs)
+            if isinstance(cloned_inputs, dict):
+                pred = mod(**cloned_inputs)
+            else:
+                pred = mod(*cloned_inputs)
         loss = self.compute_loss(pred)
         self.grad_scaler.scale(loss).backward()
         self.optimizer_step()
diff --git a/build_variables.bzl b/build_variables.bzl
index 152324a4d90cb..3f16f9b847c1c 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -487,7 +487,6 @@ libtorch_core_sources = sorted(
 # These files are the only ones that are supported on Windows.
libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/Backend.cpp", - "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp", "torch/csrc/distributed/c10d/FileStore.cpp", "torch/csrc/distributed/c10d/Functional.cpp", "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp", diff --git a/docs/source/fx.experimental.rst b/docs/source/fx.experimental.rst index 6abfb89971cd9..d6885eb41ca07 100644 --- a/docs/source/fx.experimental.rst +++ b/docs/source/fx.experimental.rst @@ -30,6 +30,7 @@ torch.fx.experimental.symbolic_shapes CallMethodKey PropagateUnbackedSymInts DivideByKey + InnerTensorKey hint_int is_concrete_int diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index ac77325188ee0..c3d3fe2f00ec8 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2027,6 +2027,7 @@ "uninteresting_files", "CallMethodKey", "DivideByKey", + "InnerTensorKey", "PropagateUnbackedSymInts", "ShapeEnvSettings", "log_lru_cache_stats", @@ -2752,4 +2753,4 @@ "torch.utils.hipify.hipify_python": [ "TrieNode" ] -} \ No newline at end of file +} diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index aef97daae2e47..d83a9494112c6 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -69,7 +69,7 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { const std::vector& outputs = {}, bool record = false) override { return c10::make_intrusive( - device, simulateError_, rank, opType, seq_); + device, simulateError_, rank, opType, seqCollective_); } size_t getNCCLCommCacheSize() { @@ -131,7 +131,7 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { const std::vector& outputs = {}, bool record = false) override { return c10::make_intrusive( - device, setTimedoutError_, rank, opType, seq_); + device, setTimedoutError_, rank, opType, seqCollective_); } void setTimedoutError() { diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py index 354f99dabd739..d4960681b27ed 100644 --- a/test/distributed/_composable/test_replicate_with_compiler.py +++ b/test/distributed/_composable/test_replicate_with_compiler.py @@ -279,12 +279,16 @@ def bwd(loss): self.assertEqual(counters["inductor"]["ddp_buckets"], 3) return code - def test_bucketing_coalesced_op(self): - torch._inductor.config._fuse_ddp_communication_passes = [ + @torch._inductor.config.patch( + _fuse_ddp_communication_passes=[ "fuse_ddp_with_coalesced_op", "schedule_comm_wait", ] - + ) + # todo: This pass mucks things up since Inductor thinks its inference + # and can apply this. Should turn off these passes in compiled autograd + @torch._inductor.config.patch(reorder_for_locality=False) + def test_bucketing_coalesced_op(self): # Gradient is None code = self._test_bucketing() self.assertEqual(counters["inductor"]["ddp_buckets"], 3) @@ -311,12 +315,16 @@ def test_bucketing_coalesced_op(self): fc.run(code) - def test_bucketing_concat_op(self): - torch._inductor.config._fuse_ddp_communication_passes = [ + @torch._inductor.config.patch( + _fuse_ddp_communication_passes=[ "fuse_ddp_with_concat_op", "schedule_comm_wait", ] - + ) + # todo: This pass mucks things up since Inductor thinks its inference + # and can apply this. 
Should turn off these passes in compiled autograd + @torch._inductor.config.patch(reorder_for_locality=False) + def test_bucketing_concat_op(self): # Gradient is None code = self._test_bucketing() self.assertEqual(counters["inductor"]["ddp_buckets"], 3) diff --git a/test/distributed/pipelining/model_registry.py b/test/distributed/pipelining/model_registry.py index f88bebd3a5598..ca811de3d75db 100644 --- a/test/distributed/pipelining/model_registry.py +++ b/test/distributed/pipelining/model_registry.py @@ -17,9 +17,8 @@ def __init__(self, d_hid: int = default_dhid): self.lin0 = torch.nn.Linear(d_hid, d_hid) self.lin1 = torch.nn.Linear(d_hid, d_hid) - def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)): + def forward(self, x): x = torch.mm(x, self.mm_param0) - x = x + y x = torch.relu(x) # try passing a value that doesn't require_grad across skip boundaries a_constant = self.cval.clone() @@ -32,6 +31,29 @@ def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)): return x +class ModelWithKwargs(torch.nn.Module): + default_dhid = 512 + default_batch_size = 256 + + def __init__(self, d_hid: int = default_dhid): + super().__init__() + self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin0 = torch.nn.Linear(d_hid, d_hid) + self.lin1 = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)): + x = torch.mm(x, self.mm_param0) + x = x + y + x = self.lin0(x) + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param1) + x = self.lin1(x) + x = torch.relu(x) + return x + + # MLP Layer class MLPModule(torch.nn.Module): def __init__(self, d_hid): diff --git a/test/distributed/pipelining/test_stage_backward.py b/test/distributed/pipelining/test_backward.py similarity index 100% rename from test/distributed/pipelining/test_stage_backward.py rename to test/distributed/pipelining/test_backward.py diff --git a/test/distributed/pipelining/test_chunkspec.py b/test/distributed/pipelining/test_chunkspec.py index 050a7b11a21bc..1b104e59ec779 100644 --- a/test/distributed/pipelining/test_chunkspec.py +++ b/test/distributed/pipelining/test_chunkspec.py @@ -16,7 +16,7 @@ torch.manual_seed(0) -class ExampleCode(torch.nn.Module): +class ModelWithKwargs(torch.nn.Module): def __init__(self): super().__init__() self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) @@ -44,7 +44,7 @@ def forward(self, x, y, z=torch.zeros(batch_size, d_hid)): class ChunkSpecTests(TestCase): def test_chunk_spec(self): - mod = ExampleCode() + mod = ModelWithKwargs() x = torch.randn(batch_size, d_hid) y = torch.randn(batch_size, d_hid) diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index 8357f3b66108d..c1fb6b075f766 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -8,7 +8,7 @@ import torch import torch.distributed as dist -from model_registry import ExampleCode, MultiMLP +from model_registry import ModelWithKwargs, MultiMLP from torch.distributed.pipelining import ( pipeline, PipelineStage, @@ -50,60 +50,11 @@ def setUpClass(cls): dev_id = cls.rank % torch.cuda.device_count() cls.device = torch.device(f"cuda:{dev_id}") - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") - def test_ec_forward(self): - # Setting this flag for numerical stability - 
torch.distributed.pipelining.microbatch._debug_mask_minibatches = True - - mod = ExampleCode(d_hid) - mod.to(self.device) - - x = torch.randn(batch_size, d_hid, device=self.device) - y = torch.randn(batch_size, d_hid, device=self.device) - - pipe = pipeline( - mod, - chunks, - example_args=(x,), - example_kwargs={"y": y}, - ) - - stage = PipelineStage( - pipe, - self.rank, - device=self.device, - ) - - # Attach to a schedule - schedule = ScheduleGPipe(stage, chunks) - - # Run - if self.rank == 0: - schedule.step(x, y=y) - else: - out = schedule.step() - - dist.barrier() - - # Last rank checks result - if self.rank == self.world_size - 1: - ref_out = mod(x, y=y) - torch.testing.assert_close(out, ref_out) - - # Test qualname mapping - submod_keys = stage.submod.state_dict().keys() - # Confirm keys are consistent with original model - old_keys = mod.state_dict().keys() - assert all(k in old_keys for k in submod_keys) - # Reset this flag - torch.distributed.pipelining.microbatch._debug_mask_minibatches = False - @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_ec_backward(self, ScheduleClass): - mod = ExampleCode(d_hid) + mod = ModelWithKwargs(d_hid) mod.to(self.device) x = torch.randn(batch_size, d_hid, device=self.device) diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py new file mode 100644 index 0000000000000..20f40ea5fa298 --- /dev/null +++ b/test/distributed/pipelining/test_stage.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] +import os +import sys +import tempfile + +import torch +import torch.distributed as dist + +from model_registry import ExampleCode, ModelWithKwargs, MultiMLP +from torch.distributed.pipelining import ( + ManualPipelineStage, + pipeline, + PipelineStage, + ScheduleGPipe, +) +from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_distributed import ( + MultiProcContinousTest, + requires_nccl, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + skip_but_pass_in_sandcastle_if, +) + + +d_hid = 512 +batch_size = 256 +chunks = 4 + +torch.manual_seed(0) + + +class StageTest(MultiProcContinousTest): + @classmethod + def backend_str(cls) -> str: + # Testing with NCCL backend + return "nccl" + + @classmethod + def setUpClass(cls): + """ + Class-scope test fixture. Run once for entire test class, before any test starts. + Set up the device. 
+ """ + super().setUpClass() + dev_id = cls.rank % torch.cuda.device_count() + cls.device = torch.device(f"cuda:{dev_id}") + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ModelClass", [ExampleCode, MultiMLP]) + def test_tracer(self, ModelClass): + mod = ModelClass(d_hid) + mod.to(self.device) + + x = torch.randn(batch_size, d_hid, device=self.device) + + pipe = pipeline( + mod, + chunks, + example_args=(x,), + ) + + stage = PipelineStage( + pipe, + self.rank, + device=self.device, + ) + + # Attach to a schedule + schedule = ScheduleGPipe(stage, chunks) + + # Run + if self.rank == 0: + schedule.step(x) + else: + out = schedule.step() + + # Last rank checks result + if self.rank == self.world_size - 1: + ref_out = mod(x) + torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=5e-2) + + # Test qualname mapping + submod_keys = stage.submod.state_dict().keys() + # Confirm keys are consistent with original model + old_keys = mod.state_dict().keys() + assert all(k in old_keys for k in submod_keys) + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ModelClass", [ModelWithKwargs]) + def test_tracer_kwargs(self, ModelClass): + mod = ModelClass(d_hid) + mod.to(self.device) + + x = torch.randn(batch_size, d_hid, device=self.device) + y = torch.randn(batch_size, d_hid, device=self.device) + + pipe = pipeline( + mod, + chunks, + example_args=(x,), + example_kwargs={"y": y}, + ) + + stage = PipelineStage( + pipe, + self.rank, + device=self.device, + ) + + # Attach to a schedule + schedule = ScheduleGPipe(stage, chunks) + + # Run + if self.rank == 0: + schedule.step(x, y=y) + else: + out = schedule.step() + + # Last rank checks result + if self.rank == self.world_size - 1: + ref_out = mod(x, y=y) + torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=5e-2) + + # Test qualname mapping + submod_keys = stage.submod.state_dict().keys() + # Confirm keys are consistent with original model + old_keys = mod.state_dict().keys() + assert all(k in old_keys for k in submod_keys) + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_manual(self): + full_mod = MultiMLP(d_hid).to(self.device) + stage_mod = full_mod.get_submodule(f"mlp{self.rank}") + stage_mod.to(self.device) + + x = torch.randn(batch_size, d_hid, device=self.device) + + stage = ManualPipelineStage( + stage_mod, + self.rank, + self.world_size, + self.device, + chunks, + input_args=x.chunk(chunks)[0], + ) + + # Attach to a schedule + schedule = ScheduleGPipe(stage, chunks) + + # Run + if self.rank == 0: + schedule.step(x) + else: + out = schedule.step() + + # Last rank checks result + if self.rank == self.world_size - 1: + ref_out = full_mod(x) + torch.testing.assert_close(out, ref_out) + + +instantiate_parametrized_tests(StageTest) + +if __name__ == "__main__": + # Check if GPU and NCCL are available + if not ( + dist.is_available() + and dist.is_nccl_available() + and torch.cuda.device_count() > 1 + ): + print( + "c10d NCCL not available or not enough GPUs, skipping tests", + file=sys.stderr, + ) + sys.exit(0) + + rank = int(os.getenv("RANK", -1)) + world_size = int(os.getenv("WORLD_SIZE", 2)) + + if rank != -1: + # Launched with torchrun or other multi-proc launchers. Directly run the test. + StageTest.run_rank(rank, world_size) + else: + # Launched as a single process. Spawn subprocess to run the tests. 
+ # Also need a rendezvous file for `init_process_group` purpose. + rdvz_file = tempfile.NamedTemporaryFile(delete=False).name + torch.multiprocessing.spawn( + StageTest.run_rank, + nprocs=world_size, + args=(world_size, rdvz_file), + ) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 5a958acdbdd74..e71bfb52b2254 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3524,7 +3524,7 @@ def test_short(self, timing_enabled): t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) ver = t["version"] - self.assertEqual(ver, "1.5") + self.assertEqual(ver, "2.0") pg_config = t["pg_config"] self.assertEqual(len(pg_config), 1) default_pg_info = pg_config["0"] @@ -3548,7 +3548,7 @@ def test_short(self, timing_enabled): self.assertIn("test_c10d_nccl.py", str(last["frames"])) self.assertEqual(last["input_sizes"], ((3, 4),)) self.assertEqual(last["output_sizes"], ((3, 4),)) - self.assertEqual(last["seq_id"], 2) + self.assertEqual(last["collective_seq_id"], 2) now = datetime.now() event_created_time = datetime.fromtimestamp( last["time_created_ns"] / 1000000000 @@ -3629,7 +3629,7 @@ def test_long(self): self.assertIn("test_c10d_nccl.py", str(last["frames"])) self.assertEqual(last["input_sizes"], ((3, 4),)) self.assertEqual(last["output_sizes"], ((3, 4),)) - self.assertEqual(last["seq_id"] - first["seq_id"], 9) + self.assertEqual(last["collective_seq_id"] - first["collective_seq_id"], 9) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @@ -3659,10 +3659,10 @@ def test_trace_while_active(self, timing_enabled): t = t["entries"] self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") if self.rank == 0: - self.assertEqual(t[-1]["seq_id"], 1) + self.assertEqual(t[-1]["collective_seq_id"], 1) self.assertEqual(t[-1]["state"], "completed") else: - self.assertEqual(t[-1]["seq_id"], 2) + self.assertEqual(t[-1]["collective_seq_id"], 2) self.assertEqual( t[-1]["state"], self.started_or_scheduled(timing_enabled) ) @@ -3704,10 +3704,10 @@ def gather_trace(): t = t["entries"] self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") if self.rank == 0: - self.assertEqual(t[-1]["seq_id"], 1) + self.assertEqual(t[-1]["collective_seq_id"], 1) self.assertEqual(t[-1]["state"], "completed") else: - self.assertEqual(t[-1]["seq_id"], 2) + self.assertEqual(t[-1]["collective_seq_id"], 2) self.assertEqual( t[-1]["state"], self.started_or_scheduled(timing_enabled) ) @@ -3799,7 +3799,9 @@ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled): self.assertEqual( t["entries"][p2p_op_idx]["profiling_name"], profiling_name ) - self.assertEqual(t["entries"][p2p_op_idx]["seq_id"], expected_seq) + self.assertEqual( + t["entries"][p2p_op_idx]["collective_seq_id"], expected_seq + ) self.assertEqual(t["entries"][p2p_op_idx]["op_id"], expected_op_id) expected_op_id += 1 self.assertEqual(t["entries"][p2p_op_idx]["input_sizes"], [input_sizes]) @@ -3819,7 +3821,9 @@ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled): self.assertEqual( t["entries"][coalesced_op]["profiling_name"], "nccl:coalesced" ) - self.assertEqual(t["entries"][coalesced_op]["seq_id"], expected_seq) + self.assertEqual( + t["entries"][coalesced_op]["collective_seq_id"], expected_seq + ) expected_seq += 1 self.assertEqual(t["entries"][coalesced_op]["state"], "completed") self.assertEqual(t["entries"][coalesced_op]["input_sizes"], []) @@ -3875,7 +3879,7 @@ def test_individual_send_recv(self, op_sizes, 
timing_enabled): input_sizes = op_sizes[seq % ops_per_repeat] profiling_name = "nccl:recv 0<-1" if self.rank == 0 else "nccl:send 1->0" self.assertEqual(t["entries"][seq]["profiling_name"], profiling_name) - self.assertEqual(t["entries"][seq]["seq_id"], expected_seq) + self.assertEqual(t["entries"][seq]["p2p_seq_id"], expected_seq) expected_seq += 1 self.assertEqual(t["entries"][seq]["op_id"], expected_op_id) expected_op_id += 1 @@ -3935,7 +3939,7 @@ def test_coalescing_manager_collective(self, timing_enabled): self.assertEqual( t["entries"][0]["profiling_name"], "nccl:reduce_scatter_tensor_coalesced" ) - self.assertEqual(t["entries"][0]["seq_id"], 1) + self.assertEqual(t["entries"][0]["collective_seq_id"], 1) self.assertEqual(t["entries"][0]["input_sizes"], [[2, 2], [2, 2]]) self.assertEqual( t["entries"][0]["output_sizes"], @@ -4003,9 +4007,9 @@ def test_timeout_dumps(self, timing_enabled): t = pickle.load(f) t = t["entries"] self.assertEqual(len(t), 2) - self.assertEqual(t[0]["seq_id"], 1) + self.assertEqual(t[0]["collective_seq_id"], 1) self.assertEqual(t[0]["state"], "completed") - self.assertEqual(t[1]["seq_id"], 2) + self.assertEqual(t[1]["collective_seq_id"], 2) self.assertEqual( t[1]["state"], self.started_or_scheduled(timing_enabled) ) @@ -4066,7 +4070,7 @@ def test_timeout_dumps_on_stuck_ranks(self): t = pickle.load(f) t = t["entries"] self.assertEqual(len(t), 1) - self.assertEqual(t[0]["seq_id"], 1) + self.assertEqual(t[0]["collective_seq_id"], 1) self.assertEqual(t[0]["state"], "completed") return diff --git a/test/distributed/test_control_collectives.py b/test/distributed/test_control_collectives.py deleted file mode 100644 index fb0067f2dd2e9..0000000000000 --- a/test/distributed/test_control_collectives.py +++ /dev/null @@ -1,189 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -from datetime import timedelta -from multiprocessing.pool import ThreadPool - -import torch -import torch.distributed as dist -from torch.testing._internal.common_utils import run_tests, TestCase - - -class TestCollectives(TestCase): - def test_barrier(self) -> None: - store = dist.HashStore() - - world_size = 2 - - def f(rank: int) -> None: - collectives = dist._StoreCollectives(store, rank, world_size) - collectives.barrier("foo", timedelta(seconds=10), True) - - with ThreadPool(world_size) as pool: - pool.map(f, range(world_size)) - - def test_broadcast(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(seconds=10) - - def f(rank: int) -> None: - collectives = dist._StoreCollectives(store, rank, world_size) - if rank == 2: - collectives.broadcast_send("foo", b"data", timeout) - else: - out = collectives.broadcast_recv("foo", timeout) - self.assertEqual(out, b"data") - - with ThreadPool(world_size) as pool: - pool.map(f, range(world_size)) - - def test_gather(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(seconds=10) - - def f(rank: int) -> None: - collectives = dist._StoreCollectives(store, rank, world_size) - if rank == 2: - out = collectives.gather_recv("foo", str(rank), timeout) - self.assertEqual(out, [b"0", b"1", b"2", b"3"]) - else: - collectives.gather_send("foo", str(rank), timeout) - - with ThreadPool(world_size) as pool: - pool.map(f, range(world_size)) - - def test_scatter(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(seconds=10) - - def f(rank: int) -> None: - collectives = dist._StoreCollectives(store, rank, world_size) - if rank == 2: - out = collectives.scatter_send( - "foo", 
[str(i) for i in range(world_size)], timeout - ) - else: - out = collectives.scatter_recv("foo", timeout) - self.assertEqual(out, str(rank).encode()) - - with ThreadPool(world_size) as pool: - pool.map(f, range(world_size)) - - def test_all_sum(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(seconds=10) - - def f(rank: int) -> None: - collectives = dist._StoreCollectives(store, rank, world_size) - out = collectives.all_sum("foo", rank, timeout) - self.assertEqual(out, sum(range(world_size))) - - with ThreadPool(world_size) as pool: - pool.map(f, range(world_size)) - - def test_broadcast_timeout(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(milliseconds=1) - collectives = dist._StoreCollectives(store, 1, world_size) - with self.assertRaisesRegex(Exception, "Wait timeout"): - collectives.broadcast_recv("foo", timeout) - - def test_gather_timeout(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(milliseconds=1) - collectives = dist._StoreCollectives(store, 1, world_size) - with self.assertRaisesRegex( - Exception, "gather failed -- missing ranks: 0, 2, 3" - ): - collectives.gather_recv("foo", "data", timeout) - - def test_scatter_timeout(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(milliseconds=1) - collectives = dist._StoreCollectives(store, 1, world_size) - with self.assertRaisesRegex(Exception, "Wait timeout"): - collectives.scatter_recv("foo", timeout) - - def test_all_gather_timeout(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(milliseconds=1) - collectives = dist._StoreCollectives(store, 1, world_size) - with self.assertRaisesRegex( - Exception, "all_gather failed -- missing ranks: 0, 2, 3" - ): - collectives.all_gather("foo", "data", timeout) - - def test_barrier_timeout(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(milliseconds=1) - collectives = dist._StoreCollectives(store, 1, world_size) - with self.assertRaisesRegex( - Exception, "barrier failed -- missing ranks: 0, 2, 3" - ): - collectives.barrier("foo", timeout, True) - - def test_all_sum_timeout(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(milliseconds=1) - collectives = dist._StoreCollectives(store, 1, world_size) - with self.assertRaisesRegex( - Exception, "barrier failed -- missing ranks: 0, 2, 3" - ): - collectives.all_sum("foo", 1, timeout) - - def test_unique(self) -> None: - store = dist.HashStore() - - collectives = dist._StoreCollectives(store, 1, 1) - collectives.broadcast_send("foo", "bar") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.broadcast_send("foo", "bar") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.broadcast_recv("foo") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.gather_send("foo", "bar") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.gather_recv("foo", "asdf") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.scatter_send("foo", ["asdf"]) - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.scatter_recv("foo") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.all_gather("foo", "bar") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): 
- collectives.all_sum("foo", 2) - - -if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" - - run_tests() diff --git a/test/dynamo/test_comptime.py b/test/dynamo/test_comptime.py index 6d874a005047b..a14c889a3bce7 100644 --- a/test/dynamo/test_comptime.py +++ b/test/dynamo/test_comptime.py @@ -189,6 +189,43 @@ def _(ctx): """, ) + # Just make sure it doesn't crash + def test_print_direct(self): + cnt = torch._dynamo.testing.CompileCounter() + + @torch._dynamo.optimize(cnt) + def f(x, z): + y = x * 2 + lambda: z + comptime.print(z) + return y + 3 + + f(torch.randn(2), torch.randn(2)) + + # Just make sure it doesn't crash + def test_get_local_closure_variable(self): + global SELF + SELF = self + cnt = torch._dynamo.testing.CompileCounter() + + @torch._dynamo.optimize(cnt) + def f(x): + z = 3 + + def g(): + @comptime + def _(ctx): + r = ctx.get_local("z") + SELF.assertEqual(repr(r), "3") + + comptime.print(z) + return 2 + + y = x * g() + return y + 3 + + f(torch.randn(2)) + def test_print_bt(self): global FILE FILE = StringIO() diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 5d7f780457d09..f07021c315585 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1358,6 +1358,21 @@ def f(x): self.assertRaises(torch._dynamo.exc.UserError, lambda: f(torch.tensor([3]))) + def test_assert(self): + @torch.compile + def fn1(x): + assert x.shape != x.shape + + with self.assertRaises(AssertionError): + a = torch.randn(10) + fn1(a) + + def fn2(x): + assert x.shape == x.shape + return x.abs() + + torch._dynamo.testing.standard_test(self, fn=fn2, nargs=1, expected_ops=1) + def test_config_obj(self): class Cfg: def __init__(self): diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 96bf924e09990..85b95370db240 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -4709,8 +4709,7 @@ def forward(self, primals_1, primals_2): _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]); primals_1 = primals_2 = None getitem = _foreach_copy[0]; _foreach_copy = None mm = torch.ops.aten.mm.default(getitem, getitem) - t_1 = torch.ops.aten.t.default(getitem); getitem = None - return [mm, t_1]""", + return [mm, getitem]""", ) self.assertEqual(out_ref, out_test) diff --git a/test/dynamo_expected_failures/TestScript.test_unspecialized_any_binding b/test/dynamo_expected_failures/TestScript.test_unspecialized_any_binding deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_sort_overflow_cpu_int16 b/test/dynamo_expected_failures/TestSortAndSelectCPU.test_sort_overflow_cpu_int16 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestPythonRegistration.test_alias_analysis b/test/dynamo_skips/TestNNDeviceTypeCPU.test_conv_empty_input_cpu_complex128 similarity index 100% rename from test/dynamo_expected_failures/TestPythonRegistration.test_alias_analysis rename to test/dynamo_skips/TestNNDeviceTypeCPU.test_conv_empty_input_cpu_complex128 diff --git a/test/export/test_export.py b/test/export/test_export.py index 406e1f55dd804..6de95a1b6b406 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -148,6 +148,7 @@ class Inp: NON_STRICT_SUFFIX = "_non_strict" RETRACEABILITY_SUFFIX = "_retraceability" +PREDISPATCH_SUFFIX = "_pre_dispatch" def is_non_strict_test(test_name): @@ -3279,6 +3280,159 
@@ def dynamify_inp(x): with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be >= 3, but got 2"): ep.module()(*test_inp) + def test_nested_module(self): + class M1(torch.nn.Module): + def forward(self, x): + return x + x + + class M2(torch.nn.Module): + def forward(self, x): + m = M1() + return m(x) * x + + inps = (torch.randn(3, 3),) + ep = export(M2(), inps) + self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps))) + + add_nodes = [ + node + for node in ep.graph.nodes + if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor + ] + self.assertEqual(len(add_nodes), 1) + add_node = add_nodes[0] + self.assertEqual(len(add_node.meta["nn_module_stack"]), 1) + self.assertTrue("M2" in list(add_node.meta["nn_module_stack"].values())[0][1]) + + self.assertExpectedInline( + str(ep.graph).strip(), + """\ +graph(): + %x : [num_users=2] = placeholder[target=x] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) + return (mul,)""", + ) + + unflattened = unflatten(ep) + self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) + + def test_nested_module_with_init_buffer(self): + class M1(torch.nn.Module): + def __init__(self): + super().__init__() + self.b = torch.ones(3, 3) + + def forward(self, x): + return x + self.b + + class M2(torch.nn.Module): + def forward(self, x): + m = M1() + return m(x) * x + + inps = (torch.randn(3, 3),) + ep = export(M2(), inps) + self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps))) + + self.assertEqual(len(ep.state_dict), 0) + self.assertEqual(len(ep.constants), 0) + + self.assertExpectedInline( + str(ep.graph).strip(), + """\ +graph(): + %x : [num_users=2] = placeholder[target=x] + %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False}) + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %ones), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) + return (mul,)""", + ) + + unflattened = unflatten(ep) + self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) + + @testing.expectedFailureRetraceability # Retracing tensor constants results in buffers + def test_nested_module_with_constant_buffer(self): + class M1(torch.nn.Module): + def __init__(self): + super().__init__() + self.b = torch.tensor(5) + + def forward(self, x): + return x + self.b + + class M2(torch.nn.Module): + def forward(self, x): + m = M1() + return m(x) * x + + inps = (torch.randn(3, 3),) + ep = export(M2(), inps) + self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps))) + + self.assertEqual(len(ep.state_dict), 0) + self.assertEqual(len(ep.constants), 1) + + self.assertExpectedInline( + str(ep.graph).strip(), + """\ +graph(): + %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0] + %x : [num_users=2] = placeholder[target=x] + %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {}) + %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {}) + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %detach), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = 
{}) + return (mul,)""", + ) + + unflattened = unflatten(ep) + self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) + + def test_nested_module_with_parameter(self): + class M1(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Parameter(torch.ones(3, 3)) + self.b = torch.nn.Parameter(torch.tensor(5.0)) + + def forward(self, x): + return x + self.a * self.b + + class M2(torch.nn.Module): + def forward(self, x): + m = M1() + return m(x) * x + + inps = (torch.randn(3, 3),) + # Strict export segfaults (Issue #128109) + ep = torch.export.export(M2(), inps, strict=False) + self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps))) + + self.assertEqual(len(ep.state_dict), 0) + self.assertEqual(len(ep.constants), 1) + + self.assertExpectedInline( + str(ep.graph).strip(), + """\ +graph(): + %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0] + %x : [num_users=2] = placeholder[target=x] + %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False}) + %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {}) + %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {}) + %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {}) + %detach_2 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_2), kwargs = {}) + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {}) + %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) + return (mul_1,)""", + ) + + unflattened = unflatten(ep) + self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) + def test_lazy_module_kwargs(self): class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module): def initialize_parameters(self, *args, **kwargs): diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 49afa50b78ac2..5545cab3c0788 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -257,20 +257,24 @@ def forward(self, x): example_inputs = (torch.randn(32, 64, device=self.device),) self.check_model(Model(), example_inputs) - def test_large(self): + def test_large_weight(self): class Model(torch.nn.Module): def __init__(self): super().__init__() - self.linear = torch.nn.Linear(512, 250112) + self.linear = torch.nn.Linear(2048, 262144) def forward(self, x, y): return x + self.linear(y) example_inputs = ( - torch.randn(1, 250112, device=self.device), - torch.randn(1, 512, device=self.device), + torch.randn(1, 262144, device=self.device), + torch.randn(1, 2048, device=self.device), ) - self.check_model(Model(), example_inputs) + + # We only test compilation since we often get OOM running in CI. 
+ model = Model() + model = model.to(self.device) + AOTIRunnerUtil.compile(model, example_inputs) def test_large_mmaped_weights(self): class Model(torch.nn.Module): @@ -1208,8 +1212,9 @@ def forward(self, x): return self.foo + x example_inputs = (torch.rand(4, 4, device=self.device),) - torch._export.aot_compile(Model(self.device), example_inputs) - self.check_model(Model(self.device), example_inputs) + with config.patch({"aot_inductor.allow_buffer_mutation": True}): + torch._export.aot_compile(Model(self.device), example_inputs) + self.check_model(Model(self.device), example_inputs) def test_non_tensor_input(self): def fn(a, b, alpha=1.0): @@ -1241,8 +1246,9 @@ def forward(self, x): self.foo[5] = self.bar[0] return x + self.bar, x * self.foo - example_inputs = (torch.randn(10, device=self.device),) - self.check_model(Model(self.device), example_inputs) + with config.patch({"aot_inductor.allow_buffer_mutation": True}): + example_inputs = (torch.randn(10, device=self.device),) + self.check_model(Model(self.device), example_inputs) def test_buffer_mutation_3(self): class KVCache(torch.nn.Module): @@ -1282,7 +1288,8 @@ def forward(self, inp_pos, k, v): torch.randn(1, 6, 1, 48, device=self.device), torch.randn(1, 6, 1, 48, device=self.device), ) - self.check_model(Model(self.device), example_inputs) + with config.patch({"aot_inductor.allow_buffer_mutation": True}): + self.check_model(Model(self.device), example_inputs) @requires_multigpu() def test_replicate_on_devices(self): @@ -2872,8 +2879,10 @@ def fail_non_abi_compatible_cuda(is_skip=False): "test_duplicate_constant_folding": fail_with_and_without_stack_allocation( is_skip=True ), - "test_dup_unbacked_sym_decl": fail_with_and_without_stack_allocation(), - "test_dup_unbacked_sym_decl_with_refinement": fail_with_and_without_stack_allocation(), + "test_dup_unbacked_sym_decl": fail_minimal_arrayref_interface(is_skip=True), + "test_dup_unbacked_sym_decl_with_refinement": fail_minimal_arrayref_interface( + is_skip=True + ), "test_dynamic_cat": fail_minimal_arrayref_interface(), # https://github.com/pytorch/pytorch/issues/122978 "test_dynamic_scalar": fail_stack_allocation(is_skip=True), @@ -2950,8 +2959,6 @@ def fail_non_abi_compatible_cuda(is_skip=False): CUDA_TEST_FAILURES = { # test_failures, xfail by default, set is_skip=True to skip - "test_dup_unbacked_sym_decl": fail_abi_compatible_cuda(), - "test_dup_unbacked_sym_decl_with_refinement": fail_abi_compatible_cuda(), "test_large_grid": fail_cuda(), "test_normal_functional": fail_abi_compatible_cuda(), # There is a double-free issue which will be fixed in another PR @@ -2970,12 +2977,10 @@ def fail_non_abi_compatible_cuda(is_skip=False): if TEST_WITH_ROCM: CUDA_TEST_FAILURES.update( { - "test_dup_unbacked_sym_decl": fail_cuda(is_skip=True), - "test_dup_unbacked_sym_decl_with_refinement": fail_cuda(is_skip=True), "test_addmm_multiple_dynamic": fail_cuda(is_skip=True), "test_bmm_multiple_dynamic": fail_cuda(is_skip=True), "test_convolution": fail_cuda(is_skip=True), - "test_large": fail_cuda(is_skip=True), + "test_large_weight": fail_cuda(is_skip=True), "test_large_mmaped_weights": fail_cuda(is_skip=True), "test_missing_cubin": fail_cuda(is_skip=True), "test_multi_device": fail_cuda(is_skip=True), @@ -3020,7 +3025,7 @@ def fail_non_abi_compatible_cuda(is_skip=False): "test_constant_folding": fail_minimal_arrayref_interface(is_skip=True), "test_convolution": fail_minimal_arrayref_interface(is_skip=True), "test_empty_graph": fail_minimal_arrayref_interface(is_skip=True), - "test_large": 
fail_minimal_arrayref_interface(is_skip=True), + "test_large_weight": fail_minimal_arrayref_interface(is_skip=True), "test_large_mmaped_weights": fail_minimal_arrayref_interface(is_skip=True), "test_misc_1": fail_minimal_arrayref_interface(is_skip=True), "test_missing_output": fail_minimal_arrayref_interface(is_skip=True), diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index 7100837e9b92f..b2d0ed91809f9 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -136,8 +136,6 @@ class KernelCounts(NamedTuple): "test_sgd_momentum_foreach_cuda": 5, "test_sgd_weight_decay_maximize_cuda": 4, "test_sgd_weight_decay_maximize_cpu": 4, - "test_sgd_weight_decay_cpu": 4, - "test_sgd_weight_decay_cuda": 4, "test_sgd_momentum_weight_decay_foreach_cuda": 2, "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2, "test_sgd_cuda": 4, diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index 47a5980b6d79c..833693dab934a 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -2,6 +2,7 @@ import itertools import torch +import torch._dynamo.testing from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import ( diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index f303330bc1140..da3869c5a3acc 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1135,6 +1135,32 @@ def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3): fn(*args) torch.cuda.synchronize() # shake out Triton Error [CUDA]: misaligned address + def test_non_commutative_scan_op(self): + from torch._higher_order_ops.associative_scan import associative_scan + + a = torch.randn(1024, 8192, dtype=torch.float64, device="cuda") + b = torch.randn(1024, 8192, dtype=torch.float64, device="cuda") + + def baseline(v, u): + A = [] + A.append(b[:, 0]) + for i in range(1, v.shape[1]): + A.append(a[:, i] * A[i - 1] + b[:, i]) + return torch.stack(A, dim=1) + + def combine_fn(i, j): + ia, ib = i + ja, jb = j + return ia * ja, ib * ja + jb + + @torch.compile + def compiled_scan(a, b): + return associative_scan(combine_fn, (a, b), dim=-1)[1] + + out1 = baseline(a, b) + out2 = compiled_scan(a, b) + self.assertEqual(out1, out2) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index f3a9026a3c805..245e8f16ab641 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1,7 +1,7 @@ # Owner(s): ["module: inductor"] +# flake8: noqa: B950 import functools -import unittest from collections import namedtuple from typing import Callable, Optional @@ -12,6 +12,7 @@ from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop +from torch._inductor import metrics from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import run_and_get_code from torch.nn.attention._flex_attention import ( @@ -499,9 +500,8 @@ def score_mod_func(score, b, h, q, kv): ) query, key, value = make_tensor(), make_tensor(), make_tensor() # floor_div is not decomposed in decompostion_table is empty - gm = make_fx(_flex_attention, decomposition_table={})( - query, key, value, score_mod_func - ) + flex_attention = 
functools.partial(_flex_attention, score_mod=score_mod_func) + gm = make_fx(flex_attention, decomposition_table={})(query, key, value) self.assertExpectedInline( gm.sdpa_score0.code.strip(), """\ @@ -513,8 +513,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): ) # floor_div is decomposed for core_aten_decompositions - gm = make_fx(_flex_attention, decomposition_table=core_aten_decompositions())( - query, key, value, score_mod_func + gm = make_fx(flex_attention, decomposition_table=core_aten_decompositions())( + query, key, value ) self.assertExpectedInline( gm.sdpa_score0.code.strip(), @@ -528,7 +528,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) - @unittest.skip("Silu decomp failing for full in backwards") def test_silu_on_score(self, dtype): def silu_score(score, b, h, q, kv): return torch.nn.functional.silu(score) @@ -645,6 +644,50 @@ def f(q, k1, k2, k3, v1, v2, v3): out2 = torch.compile(f)(query, *keys, *values) self.assertTrue((out - out2).abs().mean() < 1e-2) + @supported_platform + def test_inputs_are_realized(self): + def f(q, k, v): + x = torch.randn(1024, device="cuda") + x = x * 2 + + def func(qk, b, h, q, kv): + return qk + x[q] + + return _flex_attention(q.sin(), k, v, score_mod=func).cos() + + q, k, v = ( + torch.randn(1, 8, 1024, 64, device="cuda", requires_grad=True) + for _ in range(3) + ) + ref = f(q, k, v) + out = torch.compile(f)(q, k, v) + self.assertTrue((ref - out).abs().mean() < 1e-2) + gradOut = torch.randn_like(q) + + ref_grads = torch.autograd.grad(ref, (q, k, v), gradOut) + out_grads = torch.autograd.grad(out, (q, k, v), gradOut) + for ref, out in zip(ref_grads, out_grads): + self.assertTrue((ref - out).abs().mean() < 1e-2) + + @supported_platform + def test_epilogue_fused(self): + @torch.compile + def f(q, k, v): + out = _flex_attention(q, k, v) + return out.cos() + + q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3)) + metrics.reset() + f(q, k, v) + accessed_bytes = 1 * 8 * 1024 * 64 * torch.float32.itemsize + num_accesses = 4 # q, k, v reads, one output. + # TODO: Get rid of this fudge factor + # We need this fudge factor for now, since + # 1. For some reason we materialize the output of the attention unnecessarily (it's related to the mutation somehow) + # 2. 
We also write the extraneous logsumexp + num_accesses += 2 + self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses) + @supported_platform @skip("Triton bug ") # https://github.com/pytorch/pytorch/issues/124571 @common_utils.parametrize("dtype", test_dtypes) @@ -919,7 +962,13 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): joint_graph, """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", alias_3: "f64[2, 2, 8, 4]", alias_5: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): + def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): + alias: "f64[2, 2, 8, 4]" = torch.ops.aten.alias.default(getitem); getitem = None + alias_2: "f64[2, 2, 8, 4]" = torch.ops.aten.alias.default(alias); alias = None + alias_3: "f64[2, 2, 8, 4]" = torch.ops.aten.alias.default(alias_2); alias_2 = None + alias_1: "f32[2, 2, 8]" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None + alias_4: "f32[2, 2, 8]" = torch.ops.aten.alias.default(alias_1); alias_1 = None + alias_5: "f32[2, 2, 8]" = torch.ops.aten.alias.default(alias_4); alias_4 = None fw_graph = self.fw_graph joint_graph = self.joint_graph flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, alias_3, alias_5, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_3 = alias_5 = tangents_1 = fw_graph = joint_graph = None diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py index 6dd2ff51219d7..b203a0f63e8b1 100644 --- a/test/inductor/test_group_batch_fusion.py +++ b/test/inductor/test_group_batch_fusion.py @@ -5,6 +5,7 @@ import torch import torch._inductor +import torch._inductor.fx_passes.group_batch_fusion from torch._dynamo.utils import counters, optimus_scuba_log from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import HAS_CUDA diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index c5f0afa118f87..1859ca391e02a 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -463,6 +463,27 @@ def fn(a, b, c): self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2) self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) + @skipIfRocm + @fresh_inductor_cache() + @config.patch(max_autotune=True, max_fusion_size=2) + def test_jit_fusion_matches_aot_fusion(self): + # In this example, AOTInductor's JIT-compile will fuse(buf1, buf2) due + # to proximity, we want to make sure AOT-compile pass does the same. + # AOT could do fuse(buf2, buf4) instead if buf3 was pushed to the end + # of the V.graph.buffers list because fuse(buf2, buf4) would have a + # better proximity score than fuse(buf1, buf2). This scenario is possible + # since finalizing MultiTemplateBuffers needs to replace buffers. 
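Relatedly, the byte-count assertion in the epilogue-fusion test above can be reproduced on a toy kernel. A rough sketch, assuming a CUDA device and that Inductor fuses the whole pointwise chain into a single kernel; the function name is made up here.

    import torch
    from torch._inductor import metrics

    @torch.compile
    def pointwise_chain(x):
        return (x.sin() + 1).cos()

    x = torch.randn(1024, 1024, device="cuda")
    metrics.reset()
    pointwise_chain(x)
    # With full fusion we expect roughly one read of x plus one write of the
    # output, i.e. about 2 * numel * 4 bytes for float32.
    print(metrics.num_bytes_accessed, 2 * x.numel() * torch.float32.itemsize)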
+ def fn(x, number): + buf0 = x + x + buf1 = number.item() + buf2 = x * x + buf3 = x @ x # MultiTemplateBuffer + buf4 = x**2 + return buf0, buf1, buf2, buf3, buf4 + + inputs = (torch.rand([256, 256], device="cuda"), torch.tensor(3, device="cuda")) + torch._export.aot_compile(fn, args=inputs) + @config.patch(autotune_local_cache=False, autotune_remote_cache=False) def test_precompilations(self): def fn(a, b, c): diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 756de35df84cf..9c39f1c140018 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -1688,10 +1688,13 @@ def matcher_check_fn(): to_bf16_after_binary = 2 * (add_fn == add_fn_list[2] and fq_x2) self.assertEqual( counters["inductor"]["qlinear_binary_matcher_nodes"], - 5 + 2 * use_relu + to_bf16_after_binary, + (4 if is_dynamic else 5) + 2 * use_relu + to_bf16_after_binary, ) - for is_qat in [False, True]: + is_qat_list = [False, True] + is_dynamic_list = [False, True] + cases = itertools.product(is_qat_list, is_dynamic_list) + for is_qat, is_dynamic in cases: self._test_common( mod, (v,), @@ -1699,6 +1702,7 @@ def matcher_check_fn(): check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, matcher_check_fn=matcher_check_fn, is_qat=is_qat, + is_dynamic=is_dynamic, ) @skipIfNoDynamoSupport diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 09e913350e143..5e1af26f4bfac 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -8,9 +8,9 @@ import torch._inductor.config as config import torch.autograd from torch._inductor import metrics -from torch._inductor.compile_fx import compile_fx, count_bytes_inner +from torch._inductor.compile_fx import compile_fx, compile_fx_inner from torch._inductor.test_case import TestCase as InductorTestCase -from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm +from torch.testing._internal.common_utils import skipIfRocm ######################## # Explanation of Tests # @@ -36,21 +36,12 @@ aten = torch.ops.aten -def count_bytes_inductor(gm, example_inputs): - return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner) +def compile_but_use_eager(gm, example_inputs): + def inner_compile(gm, *args, **kwargs): + compile_fx_inner(gm, *args, **kwargs) + return gm - -# We don't support torch.compile() on Windows -if not IS_WINDOWS: - - @torch._dynamo.optimize(count_bytes_inductor) - def f(x): - return torch.cat([x, x.cos()]) - -else: - - def f(x): - return torch.cat([x, x.cos()]) + return compile_fx(gm, example_inputs, inner_compile=inner_compile) def count_numel(f, *args): @@ -58,7 +49,7 @@ def count_numel(f, *args): Assumes all inputs are fp32 """ metrics.reset() - torch._dynamo.optimize(count_bytes_inductor)(f)(*args) + torch.compile(f, backend=compile_but_use_eager)(*args) print(metrics.nodes_num_elem) return str(metrics.num_bytes_accessed // 4) @@ -69,7 +60,7 @@ def count_numel_train(f, *args): """ metrics.reset() - f = torch._dynamo.optimize(count_bytes_inductor)(f) + f = torch.compile(f, backend=compile_but_use_eager) out = f(*args) res = 0 for o in out: diff --git a/test/inductor/test_snode_runtime.py b/test/inductor/test_snode_runtime.py index b62c219f85e81..0d9ed849e0d50 100644 --- a/test/inductor/test_snode_runtime.py +++ b/test/inductor/test_snode_runtime.py @@ -8,7 +8,7 @@ from torch._inductor import metrics from torch._inductor.comm_analysis import estimate_nccl_collective_runtime -from 
torch._inductor.compile_fx import compile_fx, count_bytes_inner +from torch._inductor.compile_fx import compile_fx, compile_fx_inner from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import is_collective from torch.testing._internal.inductor_utils import HAS_CUDA @@ -18,8 +18,12 @@ _c10d = torch.ops._c10d_functional -def count_bytes_inductor(gm, example_inputs): - return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner) +def compile_but_use_eager(gm, example_inputs): + def inner_compile(gm, *args, **kwargs): + compile_fx_inner(gm, *args, **kwargs) + return gm + + return compile_fx(gm, example_inputs, inner_compile=inner_compile) def calculate_runtime(f, *args) -> float: @@ -27,7 +31,7 @@ def calculate_runtime(f, *args) -> float: Assumes all inputs are fp32 """ metrics.reset() - torch._dynamo.optimize(count_bytes_inductor)(f)(*args) + torch.compile(f, backend=compile_but_use_eager)(*args) print(metrics.node_runtimes) ret = 0.0 @@ -187,7 +191,7 @@ def _verify_runtime_estimation(self, fn, inps): ) try: metrics.reset() - torch._dynamo.optimize(count_bytes_inductor)(fn)(*inps) + torch.compile(fn)(*inps) found_collective = False for snode, runtime in metrics.node_runtimes: if not is_collective(snode.node): diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 218b30bd9e33f..3202900a28624 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1198,84 +1198,63 @@ def test_linear(self): node_list, ) - @skipIfNoX86 - def test_linear_unary(self): + def _test_linear_unary_helper( + self, + post_op_module, + post_op_aten, + post_op_aten_inplace, + post_op_algo_list=None, + is_qat=False, + is_dynamic=False, + ): """ Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer. 
""" use_bias_list = [True, False] inplace_list = [True, False] - postop_list = [nn.ReLU, nn.LeakyReLU] # only test two to save time - cases = itertools.product(use_bias_list, inplace_list, postop_list) - post_op_map = { - nn.ReLU: [torch.ops.aten.relu_.default, torch.ops.aten.relu.default], - nn.LeakyReLU: [ - torch.ops.aten.leaky_relu_.default, - torch.ops.aten.leaky_relu.default, - ], - } + if post_op_algo_list is None: + post_op_algo_list = [None] + cases = itertools.product(use_bias_list, inplace_list, post_op_algo_list) with override_quantized_engine("x86"), torch.no_grad(): - for use_bias, inplace, postop in cases: + for use_bias, inplace, post_op_algo in cases: + if inplace and post_op_aten_inplace is None: + continue m = TestHelperModules.LinearUnaryModule( - use_bias=use_bias, postop=postop, inplace_postop=inplace + use_bias=use_bias, + postop=post_op_module, + inplace_postop=inplace, + post_op_algo=post_op_algo, ).eval() example_inputs = (torch.randn(2, 4),) quantizer = X86InductorQuantizer().set_global( - xiq.get_default_x86_inductor_quantization_config() + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) ) - node_occurrence = { - # one for input and weight of the conv, one for output for the relu - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, - # quantize_per_channel for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, - } - node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.linear.default, - post_op_map[postop][0 if inplace else 1], - ] - self._test_quantizer( - m, - example_inputs, - quantizer, - node_occurrence, - node_list, + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default ) - - @skipIfNoX86 - def test_linear_unary_gelu(self): - """ - Test pattern of linear with unary post ops (e.g. gelu) with X86InductorQuantizer. 
- """ - use_bias_list = [True, False] - postop = nn.GELU - post_op_algorithm = ["none", "tanh"] - cases = itertools.product(use_bias_list, post_op_algorithm) - with override_quantized_engine("x86"), torch.no_grad(): - for use_bias, post_op_algo in cases: - m = TestHelperModules.LinearUnaryModule( - use_bias=use_bias, postop=postop, post_op_algo=post_op_algo - ).eval() - example_inputs = (torch.randn(2, 4),) - quantizer = X86InductorQuantizer().set_global( - xiq.get_default_x86_inductor_quantization_config() + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) node_occurrence = { - # one for input and weight of the conv, one for output for the gelu - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # one for input of the linear + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: 1, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, + quantize_per_tensor_op, + dequantize_per_tensor_op, torch.ops.aten.linear.default, - torch.ops.aten.gelu.default, + post_op_aten_inplace if inplace else post_op_aten, ] self._test_quantizer( m, @@ -1283,8 +1262,71 @@ def test_linear_unary_gelu(self): quantizer, node_occurrence, node_list, + is_qat=is_qat, ) + @skipIfNoX86 + def test_linear_unary(self): + aten = torch.ops.aten + self._test_linear_unary_helper(nn.ReLU, aten.relu.default, aten.relu_.default) + self._test_linear_unary_helper( + nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default + ) + self._test_linear_unary_helper( + nn.GELU, aten.gelu.default, None, ["none", "tanh"] + ) + + @skipIfNoX86 + def test_linear_unary_qat(self): + aten = torch.ops.aten + self._test_linear_unary_helper( + nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True + ) + self._test_linear_unary_helper( + nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default, is_qat=True + ) + self._test_linear_unary_helper( + nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_qat=True + ) + + @skipIfNoX86 + def test_linear_unary_dynamic(self): + aten = torch.ops.aten + self._test_linear_unary_helper( + nn.ReLU, aten.relu.default, aten.relu_.default, is_dynamic=True + ) + self._test_linear_unary_helper( + nn.LeakyReLU, + aten.leaky_relu.default, + aten.leaky_relu_.default, + is_dynamic=True, + ) + self._test_linear_unary_helper( + nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_dynamic=True + ) + + @skipIfNoX86 + def test_linear_unary_dynamic_qat(self): + aten = torch.ops.aten + self._test_linear_unary_helper( + nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True, is_dynamic=True + ) + self._test_linear_unary_helper( + nn.LeakyReLU, + aten.leaky_relu.default, + aten.leaky_relu_.default, + is_qat=True, + is_dynamic=True, + ) + self._test_linear_unary_helper( + nn.GELU, + aten.gelu.default, + None, + ["none", "tanh"], + is_qat=True, + is_dynamic=True, + ) + def _check_annotation_stat(self, gm, expected_stat_dict): # Check expected annotation statistics to ensure the annotation is correct @@ -1302,8 +1344,7 @@ def _check_annotation(node): for op_stat in expected_stat_dict.values(): assert all(v == 0 for v in 
op_stat.values()) - @skipIfNoX86 - def test_linear_binary(self): + def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False): """ Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer. Currently, only add as binary post op is supported. @@ -1313,7 +1354,20 @@ def test_linear_binary(self): inplace_add_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( - xiq.get_default_x86_inductor_quantization_config() + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + ) + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) cases = itertools.product(linear_pos_list, inplace_add_list) with override_quantized_engine("x86"), torch.no_grad(): @@ -1325,26 +1379,28 @@ def test_linear_binary(self): node_occurrence = { # Only one 1 q-dq for input of the linear # No q-dq for extra input node of add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: 1, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } else: + # convert_pt2e disables duplicate dequant for dynamic quant + num_dequant = 1 if is_dynamic else 2 node_occurrence = { # One quantize_per_tensor for both linear nodes (shared) # Two dequantize_per_tensor for two linear nodes # No q-dq for extra input node of add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: num_dequant, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, + quantize_per_tensor_op, + dequantize_per_tensor_op, torch.ops.aten.linear.default, torch.ops.aten.add_.Tensor if inplace_add @@ -1356,6 +1412,7 @@ def test_linear_binary(self): quantizer, node_occurrence, node_list, + is_qat=is_qat, )[-1] # One linear and add are fused. 
The other linear is quantized alone if present aten = torch.ops.aten @@ -1369,6 +1426,22 @@ def test_linear_binary(self): } self._check_annotation_stat(fq_m, expected_annotation_stat) + @skipIfNoX86 + def test_linear_binary(self): + self._test_linear_binary_helper() + + @skipIfNoX86 + def test_linear_binary_qat(self): + self._test_linear_binary_helper(is_qat=True) + + @skipIfNoX86 + def test_linear_binary_dynamic(self): + self._test_linear_binary_helper(is_dynamic=True) + + @skipIfNoX86 + def test_linear_binary_dynamic_qat(self): + self._test_linear_binary_helper(is_qat=True, is_dynamic=True) + @skipIfNoX86 def test_linear_binary2(self): """ @@ -1379,28 +1452,43 @@ def test_linear_binary2(self): Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1 """ example_inputs = (torch.randn(2, 16),) - quantizer = X86InductorQuantizer().set_global( - xiq.get_default_x86_inductor_quantization_config() - ) # TODO test for inplace add after refactoring of capture_pre_autograd_graph inplace_add_list = [False] + is_qat_list = [False, True] + is_dynamic_list = [False, True] + cases = itertools.product(inplace_add_list, is_qat_list, is_dynamic_list) with override_quantized_engine("x86"), torch.no_grad(): - for inplace_add in inplace_add_list: + for inplace_add, is_qat, is_dynamic in cases: + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, is_dynamic=is_dynamic + ) + ) m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval() + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) # Two q-dq nodes for inputs of linear nodes # No q-dq for extra input node of add node_occurrence = { - torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + quantize_per_tensor_op: 2, + dequantize_per_tensor_op: 2, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, + quantize_per_tensor_op, + dequantize_per_tensor_op, torch.ops.aten.linear.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor, @@ -1425,7 +1513,7 @@ def test_linear_binary2(self): self._check_annotation_stat(fq_m, expected_annotation_stat) @skipIfNoX86 - def test_linear_binary_unary(self): + def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False): """ Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer. Currently, only add as binary post op and relu as unary post op are supported. 
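For readers unfamiliar with the pattern being parameterized here, the linear + add + relu shape the helper exercises, and the PT2E flow it drives, look roughly like the sketch below. The module definition, sizes, and single-batch calibration are illustrative and not copied from TestHelperModules.

    import torch
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

    class LinearAddReLU(torch.nn.Module):
        # linear -> add (binary post op) -> relu (unary post op)
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)

        def forward(self, x, other):
            return torch.relu(self.linear(x) + other)

    m = LinearAddReLU().eval()
    example_inputs = (torch.randn(2, 16), torch.randn(2, 16))

    quantizer = X86InductorQuantizer().set_global(
        xiq.get_default_x86_inductor_quantization_config(is_qat=False, is_dynamic=True)
    )
    gm = capture_pre_autograd_graph(m, example_inputs)
    gm = prepare_pt2e(gm, quantizer)
    gm(*example_inputs)  # one calibration batch, purely for illustration
    gm = convert_pt2e(gm)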
@@ -1437,7 +1525,20 @@ def test_linear_binary_unary(self): inplace_relu_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( - xiq.get_default_x86_inductor_quantization_config() + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + ) + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list) with override_quantized_engine("x86"), torch.no_grad(): @@ -1451,26 +1552,28 @@ def test_linear_binary_unary(self): node_occurrence = { # Only one q-dq node for input of the linear # No q-dq node for extra input node of add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: 1, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } else: + # convert_pt2e disables duplicate dequant for dynamic quant + num_dequant = 1 if is_dynamic else 2 node_occurrence = { # One quantize_per_tensor for both linear nodes (shared) # Two dequantize_per_tensor for two linear nodes # No q-dq for extra input node of add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: num_dequant, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, + quantize_per_tensor_op, + dequantize_per_tensor_op, torch.ops.aten.linear.default, torch.ops.aten.add_.Tensor if inplace_add @@ -1498,57 +1601,91 @@ def test_linear_binary_unary(self): } self._check_annotation_stat(fq_m, expected_annotation_stat) + @skipIfNoX86 + def test_linear_binary_unary(self): + self._test_linear_binary_unary_helper() + + @skipIfNoX86 + def test_linear_binary_unary_qat(self): + self._test_linear_binary_unary_helper(is_qat=True) + + @skipIfNoX86 + def test_linear_binary_unary_dynamic(self): + self._test_linear_binary_unary_helper(is_dynamic=True) + + @skipIfNoX86 + def test_linear_binary_unary_dynamic_qat(self): + self._test_linear_binary_unary_helper(is_qat=True, is_dynamic=True) + @skipIfNoX86 def test_linear_binary_unary_serials(self): """ Test pattern of 2 following up linear add relu with X86InductorQuantizer. 
""" + is_qat_list = [False, True] + is_dynamic_list = [False, True] + cases = itertools.product(is_qat_list, is_dynamic_list) with override_quantized_engine("x86"), torch.no_grad(): - m = TestHelperModules.SerialsLinearAddReLUModule().eval() - example_inputs = (torch.randn(2, 16),) - quantizer = X86InductorQuantizer().set_global( - xiq.get_default_x86_inductor_quantization_config() - ) - node_occurrence = { - # quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4 - # dequantize_per_tensor: 1 for each linear - # No q-dq for extra input node of add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, - # quantize_per_channel for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 4, - } - node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.linear.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.linear.default, - torch.ops.aten.linear.default, - torch.ops.aten.add.Tensor, - torch.ops.aten.relu.default, - ] - fq_m = self._test_quantizer( - m, - example_inputs, - quantizer, - node_occurrence, - node_list, - )[-1] - # Two linear nodes are quantized alone - # The other two are fused with add and relu - aten = torch.ops.aten - expected_annotation_stat = { - aten.linear.default: { - "annotated": 4, - "is_quant_out": 2, - }, - aten.add.Tensor: {"annotated": 2, "is_quant_out": 0}, - aten.relu.default: {"annotated": 2, "is_quant_out": 2}, - } - self._check_annotation_stat(fq_m, expected_annotation_stat) + for is_qat, is_dynamic in cases: + m = TestHelperModules.SerialsLinearAddReLUModule().eval() + example_inputs = (torch.randn(2, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + ) + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + # convert_pt2e disables duplicate dequant for dynamic quant + num_dequant = 3 if is_dynamic else 4 + node_occurrence = { + # quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4 + # dequantize_per_tensor: 1 for each linear + # No q-dq for extra input node of add + quantize_per_tensor_op: 3, + dequantize_per_tensor_op: num_dequant, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 4, + } + node_list = [ + quantize_per_tensor_op, + dequantize_per_tensor_op, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + torch.ops.aten.add.Tensor, + torch.ops.aten.relu.default, + ] + fq_m = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + )[-1] + # Two linear nodes are quantized alone + # The other two are fused with add and relu + aten = torch.ops.aten + 
expected_annotation_stat = { + aten.linear.default: { + "annotated": 4, + "is_quant_out": 2, + }, + aten.add.Tensor: {"annotated": 2, "is_quant_out": 0}, + aten.relu.default: {"annotated": 2, "is_quant_out": 2}, + } + self._check_annotation_stat(fq_m, expected_annotation_stat) @skipIfTorchDynamo("very slow") @skipIfNoX86 diff --git a/test/test_cuda.py b/test/test_cuda.py index cc3e2380f2664..c1b990381bbe8 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -37,11 +37,7 @@ instantiate_device_type_tests, onlyCUDA, ) -from torch.testing._internal.common_optimizers import ( - _get_optim_inputs_including_global_cliquey_kwargs, - optim_db, - optims, -) +from torch.testing._internal.common_optimizers import optim_db, optims from torch.testing._internal.common_utils import ( freeze_rng_state, gcIfJetson, @@ -3204,6 +3200,111 @@ def _test_graphed_optimizer( for p_control, p_graphed in zip(params_control, params_graphed): self.assertEqual(p_control, p_graphed) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_optims(self): + # Needs generalization if we want to extend this test to non-Adam-like optimizers. + cases = ( + [ + ( + optimizer_ctor, + { + "lr": 0.1, + "betas": (0.8, 0.7), + "foreach": foreach, + "decoupled_weight_decay": decoupled_weight_decay, + "weight_decay": weight_decay, + }, + ) + for optimizer_ctor, foreach, decoupled_weight_decay, weight_decay in product( + ( + torch.optim.NAdam, + torch.optim.RAdam, + ), + ( + False, + True, + ), + ( + False, + True, + ), + ( + 0.0, + 0.1, + ), + ) + ] + + [ + ( + torch.optim.Rprop, + {"lr": 0.1, "foreach": foreach, "maximize": maximize}, + ) + for foreach, maximize in product( + ( + False, + True, + ), + ( + False, + True, + ), + ) + ] + + [ + ( + optimizer_ctor, + { + "lr": 0.1, + "betas": (0.8, 0.7), + "foreach": foreach, + "amsgrad": amsgrad, + }, + ) + for optimizer_ctor, foreach, amsgrad in product( + (torch.optim.Adam, torch.optim.AdamW), + (False, True), + (False, True), + ) + ] + + [ + ( + optimizer_ctor, + {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}, + ) + for optimizer_ctor, amsgrad in product( + (torch.optim.Adam, torch.optim.AdamW), (False, True) + ) + ] + + [ + ( + optimizer_ctor, + { + "lr": 0.1, + "foreach": foreach, + "maximize": maximize, + "weight_decay": weight_decay, + }, + ) + for optimizer_ctor, foreach, maximize, weight_decay in product( + ( + torch.optim.Adamax, + torch.optim.ASGD, + torch.optim.Adadelta, + torch.optim.RMSprop, + ), + (False, True), + (False, True), + (0, 0.1), + ) + ] + ) + + for optimizer_ctor, kwargs in cases: + with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs): + self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs) + @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -3275,6 +3376,123 @@ def test_graph_optims_with_explicitly_capturable_param_groups(self): self.assertEqual(ref_p1, param1) self.assertEqual(ref_p2, param2) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_scaling_fused_optimizers(self): + cases = [ + ( + optimizer_ctor, + {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}, + ) + for optimizer_ctor, amsgrad in product( + (torch.optim.Adam, torch.optim.AdamW), (False, True) + ) + ] + list( + product( + (torch.optim.SGD,), + [ + { + "lr": 0.1, + "momentum": 0.0, + "dampening": d, + "weight_decay": w, + "nesterov": n, + "fused": True, + } + for d, w, n in 
product((0.0, 0.5), (0.0, 0.5), (False,)) + ] + + [ + { + "lr": 0.1, + "momentum": 0.5, + "dampening": d, + "weight_decay": w, + "nesterov": n, + "fused": True, + } + for d, w, n in product((0.0,), (0.0, 0.5), (True, False)) + ], + ) + ) + + steps_warmup = 3 + steps_train = 2 + + for OptClass, kwargs in cases: + has_capturable_arg = OptClass in (torch.optim.Adam, torch.optim.AdamW) + for actually_do_graphs in (True, False) if has_capturable_arg else (True,): + params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] + params_control = [p.clone().requires_grad_() for p in params] + params_graphed = [p.clone().requires_grad_() for p in params] + + # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. + grads = [ + [torch.randn_like(p) for p in params] + for _ in range(steps_warmup + steps_train) + ] + with torch.no_grad(): + grads_control = [[g.clone() for g in gs] for gs in grads] + grads_graphed = [[g.clone() for g in gs] for gs in grads] + + # Gradient Scaler + scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0) + with torch.no_grad(): + scaler_for_control._lazy_init_scale_growth_tracker( + torch.device("cuda") + ) + + scaler_for_graphed = torch.cuda.amp.GradScaler() + scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) + with torch.no_grad(): + scaler_for_graphed._lazy_init_scale_growth_tracker( + torch.device("cuda") + ) + + # Control (capturable=False) + if has_capturable_arg: + kwargs["capturable"] = False + opt = OptClass(params_control, **kwargs) + + for i in range(steps_warmup + steps_train): + for j, p in enumerate(params_control): + p.grad = grads_control[i][j] + scaler_for_control.step(opt) + scaler_for_control.update() + + # capturable=True + if has_capturable_arg: + kwargs["capturable"] = True + opt = OptClass(params_graphed, **kwargs) + + for i in range(steps_warmup): + for j, p in enumerate(params_graphed): + p.grad = grads_graphed[i][j] + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + if actually_do_graphs: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + for i in range(steps_train): + if actually_do_graphs: + for j, p in enumerate(params_graphed): + p.grad.copy_(grads_graphed[i + steps_warmup][j]) + g.replay() + else: + # Passing capturable=True to the constructor and running without graphs should still be + # numerically correct, even if it's not ideal for performance. 
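The capture pattern exercised in this test, stripped of the GradScaler plumbing, is essentially the following. Shapes and step counts are arbitrary, and a CUDA build with graph support is assumed.

    import torch

    params = [torch.randn(5, 5, device="cuda", requires_grad=True)]
    opt = torch.optim.Adam(params, lr=0.1, capturable=True)

    # Warm up outside the graph so the optimizer state tensors exist on device.
    for _ in range(3):
        for p in params:
            p.grad = torch.randn_like(p)
        opt.step()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        opt.step()

    # Later iterations only refresh .grad in place and replay the capture.
    for p in params:
        p.grad.copy_(torch.randn_like(p))
    g.replay()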
+ for j, p in enumerate(params_graphed): + p.grad = grads_graphed[i + steps_warmup][j] + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + for p_control, p_graphed in zip(params_control, params_graphed): + self.assertEqual(p_control, p_graphed) + @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -4033,7 +4251,10 @@ def free(): @unittest.skipIf(TEST_PYNVML, "pynvml is not available") def test_nvml_get_handler(self): - self.assertTrue(torch.cuda._get_pynvml_handler() is not None) + if not torch.version.hip: + self.assertTrue(torch.cuda._get_pynvml_handler() is not None) + else: + self.assertTrue(torch.cuda._get_amdsmi_handler() is not None) @unittest.skipIf(TEST_PYNVML, "pynvml is not available") def test_temperature(self): @@ -4480,175 +4701,10 @@ def test_no_triton_on_import(self): self.assertEqual(rc, "False", "Triton was imported when importing torch!") -@torch.testing._internal.common_utils.markDynamoStrictTest class TestCudaOptims(TestCase): # These tests will be instantiate with instantiate_device_type_tests # to apply the new OptimizerInfo structure. - @onlyCUDA - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >=5.3 required for graphs" - ) - @optims( - [optim for optim in optim_db if optim.has_capturable_arg], - dtypes=[torch.float32], - ) - def test_graph_optims(self, device, dtype, optim_info): - optim_cls = optim_info.optim_cls - all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( - device, dtype, optim_info, skip=("differentiable",) - ) - - steps_warmup = 3 - steps_train = 2 - - for optim_input in all_optim_inputs: - kwargs = optim_input.kwargs - - # lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam - # and torch.optim.adamw - kwargs["lr"] = 0.1 - - for actually_do_graphs in (True, False): - params = [ - torch.randn((i + 5, i + 5), device=device) for i in range(2) - ] + [torch.randn((), device=device)] - params_control = [p.clone().requires_grad_() for p in params] - params_graphed = [p.clone().requires_grad_() for p in params] - - grads = [ - [torch.randn_like(p) for p in params] - for _ in range(steps_warmup + steps_train) - ] - - # Control (capturable=False) - kwargs["capturable"] = False - - opt = optim_cls(params_control, **kwargs) - for i in range(steps_warmup + steps_train): - for j, p in enumerate(params_control): - p.grad = grads[i][j] - opt.step() - - # capturable=True - kwargs["capturable"] = True - opt = optim_cls(params_graphed, **kwargs) - - for i in range(steps_warmup): - for j, p in enumerate(params_graphed): - p.grad = grads[i][j] - opt.step() - - if actually_do_graphs: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - opt.step() - - for i in range(steps_train): - if actually_do_graphs: - for j, p in enumerate(params_graphed): - p.grad.copy_(grads[i + steps_warmup][j]) - g.replay() - else: - # Passing capturable=True to the constructor and running without graphs should still be - # numerically correct, even if it's not ideal for performance. 
- for j, p in enumerate(params_graphed): - p.grad = grads[i + steps_warmup][j] - opt.step() - - for p_control, p_graphed in zip(params_control, params_graphed): - self.assertEqual(p_control, p_graphed) - - @onlyCUDA - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - @optims( - [optim for optim in optim_db if "fused" in optim.supported_impls], - dtypes=[torch.float32], - ) - def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info): - optim_cls = optim_info.optim_cls - - steps_warmup = 3 - steps_train = 2 - - optim_inputs = optim_info.optim_inputs_func(device=device) - - for optim_input in optim_inputs: - kwargs = optim_input.kwargs - kwargs["fused"] = True - - for actually_do_graphs in ( - (True, False) if optim_info.has_capturable_arg else (True,) - ): - params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)] - params_control = [p.clone().requires_grad_() for p in params] - params_graphed = [p.clone().requires_grad_() for p in params] - - # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. - grads = [ - [torch.randn_like(p) for p in params] - for _ in range(steps_warmup + steps_train) - ] - with torch.no_grad(): - grads_control = [[g.clone() for g in gs] for gs in grads] - grads_graphed = [[g.clone() for g in gs] for gs in grads] - - # Gradient Scaler - scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0) - with torch.no_grad(): - scaler_for_control._lazy_init_scale_growth_tracker(device) - - scaler_for_graphed = torch.cuda.amp.GradScaler() - scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) - with torch.no_grad(): - scaler_for_graphed._lazy_init_scale_growth_tracker(device) - - # Control (capturable=False) - if optim_info.has_capturable_arg: - kwargs["capturable"] = False - opt = optim_cls(params_control, **kwargs) - - for i in range(steps_warmup + steps_train): - for j, p in enumerate(params_control): - p.grad = grads_control[i][j] - scaler_for_control.step(opt) - scaler_for_control.update() - - # capturable=True - if optim_info.has_capturable_arg: - kwargs["capturable"] = True - opt = optim_cls(params_graphed, **kwargs) - - for i in range(steps_warmup): - for j, p in enumerate(params_graphed): - p.grad = grads_graphed[i][j] - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - if actually_do_graphs: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - for i in range(steps_train): - if actually_do_graphs: - for j, p in enumerate(params_graphed): - p.grad.copy_(grads_graphed[i + steps_warmup][j]) - g.replay() - else: - # Passing capturable=True to the constructor and running without graphs should still be - # numerically correct, even if it's not ideal for performance. 
- for j, p in enumerate(params_graphed): - p.grad = grads_graphed[i + steps_warmup][j] - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - for p_control, p_graphed in zip(params_control, params_graphed): - self.assertEqual(p_control, p_graphed) - @onlyCUDA @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" diff --git a/test/test_mps.py b/test/test_mps.py index 24c4e2d45e48e..cbf8874e1c220 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -61,10 +61,17 @@ ) ) +def xfailIf(condition): + def wrapper(func): + if condition: + return unittest.expectedFailure(func) + else: + return func + return wrapper + def xfailIfMacOS14_4Plus(func): return unittest.expectedFailure(func) if product_version > 14.3 else func # noqa: F821 - def mps_ops_grad_modifier(ops): XFAILLIST_GRAD = { @@ -901,9 +908,9 @@ def mps_ops_modifier(ops): 'fft.rfft2': None, 'fft.rfftn': None, 'stft': None, - # Error in TestConsistencyCPU.test_output_match_isin_cpu_int32, + # Error in TestConsistencyCPU.test_output_match_isin_cpu fails for integers, # not reproducible in later OS. Added assert to op if used in < 14.0 - 'isin': None, + 'isin': [torch.int64, torch.int32, torch.int16, torch.uint8, torch.int8], }) UNDEFINED_XFAILLIST = { @@ -8218,7 +8225,6 @@ def helper(dtype): [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]] - @unittest.skipIf(product_version < 14.0, "Skipped on MacOS < 14.0") def test_isin(self): def helper(dtype): shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]), @@ -8237,15 +8243,19 @@ def helper(dtype): B_mps = B.clone().detach().to('mps') cpu_ref = torch.isin(A, B, invert=inverted) - if dtype is torch.float16: + if dtype in [torch.float16, torch.bfloat16]: cpu_ref.type(dtype) mps_out = torch.isin(A_mps, B_mps, invert=inverted) self.assertEqual(mps_out, cpu_ref) - [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8]] + dtypes = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.uint8, torch.int8] + if product_version < 14.0: + # Int types expected to fail on MacOS < 14.0 + dtypes = [torch.float32, torch.float16, torch.bfloat16] + + [helper(dtype) for dtype in dtypes] - @unittest.skipIf(product_version < 14.0, "Skipped on MacOS < 14.0") def test_isin_asserts(self): A = torch.randn(size=[1, 4], device='mps', dtype=torch.float32) B = torch.randn(size=[1, 4], device='mps', dtype=torch.float16) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 842033d005f09..597180129f727 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -4246,7 +4246,7 @@ def in_proj(input_packed, qkv_linear=qkv_linear): # Low Precision Math Reference out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( q, k, v)[0].transpose(-2, -3) - output_ref_atol, output_ref_rtol = get_tolerances(out, out_lp_ref) + output_ref_atol, output_ref_rtol = get_tolerances(out, out_lp_ref, fudge_factor=2) self.assertEqual(out, out_component, atol=output_ref_atol, rtol=output_ref_rtol) diff --git a/tools/stats/monitor.py b/tools/stats/monitor.py index cd43fe5a8e7f1..e5c56f49de3b2 100644 --- a/tools/stats/monitor.py +++ b/tools/stats/monitor.py @@ -2,28 +2,10 @@ import datetime import json import signal -import sys import time from typing import Any, Dict, List import psutil # type: ignore[import] -import pynvml # type: ignore[import] - -# ROCm does not currently have the rocm_smi module installed to 
a pythonic location. -# Must import from ROCm installation path. -# Cannot use the high-level rocm_smi cmdline module due to its use of exit(). -# Must use the lower-level ctypes wrappers exposed through rsmiBindings. -sys.path.append("/opt/rocm/libexec/rocm_smi") -try: - from ctypes import byref, c_uint32, c_uint64 - - from rsmiBindings import ( # type: ignore[import] - rocmsmi, - rsmi_process_info_t, - rsmi_status_t, - ) -except ImportError as e: - pass def get_processes_running_python_tests() -> List[Any]: @@ -76,78 +58,42 @@ def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]: return per_process_info -def rocm_ret_ok(ret: int) -> Any: - return ret == rsmi_status_t.RSMI_STATUS_SUCCESS - - -def rocm_list_devices() -> List[int]: - num = c_uint32(0) - ret = rocmsmi.rsmi_num_monitor_devices(byref(num)) - if rocm_ret_ok(ret): - return list(range(num.value)) - return [] - - -def rocm_get_mem_use(device: int) -> float: - memoryUse = c_uint64() - memoryTot = c_uint64() - - ret = rocmsmi.rsmi_dev_memory_usage_get(device, 0, byref(memoryUse)) - if rocm_ret_ok(ret): - ret = rocmsmi.rsmi_dev_memory_total_get(device, 0, byref(memoryTot)) - if rocm_ret_ok(ret): - return float(memoryUse.value) / float(memoryTot.value) - return 0.0 - - -def rocm_get_gpu_use(device: int) -> float: - percent = c_uint32() - ret = rocmsmi.rsmi_dev_busy_percent_get(device, byref(percent)) - if rocm_ret_ok(ret): - return float(percent.value) - return 0.0 - - -def rocm_get_pid_list() -> List[Any]: - num_items = c_uint32() - ret = rocmsmi.rsmi_compute_process_info_get(None, byref(num_items)) - if rocm_ret_ok(ret): - buff_sz = num_items.value + 10 - procs = (rsmi_process_info_t * buff_sz)() - procList = [] - ret = rocmsmi.rsmi_compute_process_info_get(byref(procs), byref(num_items)) - for i in range(num_items.value): - procList.append(procs[i].process_id) - return procList - return [] - - -def rocm_get_per_process_gpu_info() -> List[Dict[str, Any]]: +def rocm_get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]: + processes = amdsmi.amdsmi_get_gpu_process_list(handle) per_process_info = [] - for pid in rocm_get_pid_list(): - proc = rsmi_process_info_t() - ret = rocmsmi.rsmi_compute_process_info_by_pid_get(int(pid), byref(proc)) - if rocm_ret_ok(ret): - info = {"pid": pid, "gpu_memory": proc.vram_usage} - per_process_info.append(info) + for p in processes: + proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p) + info = { + "pid": proc_info["pid"], + "gpu_memory": proc_info["memory_usage"]["vram_mem"], + } + per_process_info.append(info) return per_process_info if __name__ == "__main__": handle = None try: - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(0) - except pynvml.NVMLError: + import pynvml # type: ignore[import] + + try: + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + except pynvml.NVMLError: + pass + except ModuleNotFoundError: # no pynvml avaliable, probably because not cuda pass - - rsmi_handles = [] try: - ret = rocmsmi.rsmi_init(0) - rsmi_handles = rocm_list_devices() - except Exception: - # no rocmsmi available, probably because not rocm + import amdsmi # type: ignore[import] + + try: + amdsmi.amdsmi_init() + amdsmi_handle = amdsmi.amdsmi_get_processor_handles()[0] + except amdsmi.AmdSmiException: + pass + except ModuleNotFoundError: + # no amdsmi is available pass kill_now = False @@ -171,17 +117,16 @@ def exit_gracefully(*args: Any) -> None: gpu_utilization = pynvml.nvmlDeviceGetUtilizationRates(handle) stats["total_gpu_utilization"] = 
gpu_utilization.gpu stats["total_gpu_mem_utilization"] = gpu_utilization.memory - if rsmi_handles: - stats["per_process_gpu_info"] = rocm_get_per_process_gpu_info() - # There are 1 to 4 GPUs in use; these values may sum > 1.0. - gpu_utilization = 0.0 - gpu_memory = 0.0 - for dev in rsmi_handles: - gpu_utilization += rocm_get_gpu_use(dev) - gpu_memory += rocm_get_mem_use(dev) - stats["total_gpu_utilization"] = gpu_utilization - stats["total_gpu_mem_utilization"] = gpu_memory - + if amdsmi_handle is not None: + stats["per_process_gpu_info"] = rocm_get_per_process_gpu_info( + amdsmi_handle + ) + stats["total_gpu_utilization"] = amdsmi.amdsmi_get_gpu_activity( + amdsmi_handle + )["gfx_activity"] + stats["total_gpu_mem_utilization"] = amdsmi.amdsmi_get_gpu_activity( + amdsmi_handle + )["umc_activity"] except Exception as e: stats = { "time": datetime.datetime.utcnow().isoformat("T") + "Z", diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 74a73a3ddaa46..28d790e3d6903 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -210,20 +210,6 @@ class PrefixStore(Store): @property def underlying_store(self) -> Store: ... -class _ControlCollectives: - def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ... - def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ... - def broadcast_recv(self, key: str, timeout: timedelta) -> str: ... - def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ... - def gather_recv(self, key: str, timeout: timedelta) -> str: ... - def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ... - def scatter_recv(self, key: str, timeout: timedelta) -> str: ... - def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ... - def all_sum(self, key: str, data: str, timeout: timedelta) -> int: ... - -class _StoreCollectives(_ControlCollectives): - def __init__(self, store: Store, rank: int, world_size: int) -> None: ... - class _DistributedBackendOptions: def __init__(self): ... @property diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index 23000c464fdbb..80880588b54e3 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -14,7 +14,9 @@ from torch.fx.experimental.symbolic_shapes import free_symbols from .exc import unimplemented +from .variables import NewCellVariable from .variables.constant import ConstantVariable +from .variables.misc import ClosureVariable from .variables.tensor import SymNodeVariable @@ -146,7 +148,20 @@ def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar: Retrieve the compile-time known information about a local. 
""" tx = self.__get_tx(stacklevel) - return ComptimeVar(tx.symbolic_locals[name]) + + # This is analogous to LOAD_DEREF + if hasattr(tx, "closure_cells") and name in tx.closure_cells: + cell = tx.closure_cells[name] + if isinstance(cell, ClosureVariable): + return ComptimeVar(tx.output.root_tx.symbolic_locals[cell.name]) + else: + return ComptimeVar(tx.output.side_effects.load_cell(cell)) + else: + r = tx.symbolic_locals[name] + if isinstance(r, NewCellVariable): + return ComptimeVar(tx.output.side_effects.load_cell(r)) + else: + return ComptimeVar(r) def graph_break(self, msg="ComptimeContext.graph_break"): """ diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index d5c24a67d9e25..6dcb84fab8fc1 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1,4 +1,3 @@ -import base64 import collections import cProfile import dis @@ -350,20 +349,13 @@ def profile_wrapper(*args, **kwargs): ps.sort_stats(pstats.SortKey.TIME).print_stats(20) ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20) - maybe_upload_prof_stats_to_manifold(str(profile_path)) # fb-only - - torch._logging.trace_structured( - "artifact", - lambda: { - "name": "dynamo_cprofile_prof", - "type": "prof", - "encoding": "base64", - }, - payload_fn=lambda: base64.encodebytes( - open(profile_path, "rb").read() - ).decode("ascii"), - ) - + if manifold_link := maybe_upload_prof_stats_to_manifold( + str(profile_path) + ): # fb-only + torch._logging.trace_structured( + "link", + lambda: {"name": "cprofile_manifold_url", "url": manifold_link}, + ) return retval return profile_wrapper diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 093809703405f..864e53777941e 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -343,9 +343,11 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): ): error_msg: VariableTracker = self.pop() # Skip over things like `assert True` - if value.is_python_constant() and bool(value.as_python_constant()): - self.jump(inst) - return + if value.is_python_constant(): + if bool(value.as_python_constant()): + return self.jump(inst) + else: + jump_graph_break(self, inst, value) # TODO maybe should respect DtoH sync intention of users later?? 
# Manually insert torch._assert_async instead of python assert and jump over diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 8a2c12ee4e84e..b43f447737128 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1302,9 +1302,6 @@ "torch._C.parse_schema", "torch._C.parse_type_comment", "torch._C.read_vitals", - "torch._C.set_flush_denormal", - "torch._C.set_num_interop_threads", - "torch._C.set_num_threads", "torch._C.set_vital", "torch._C.unify_type_list", "torch._C.vitals_enabled", @@ -2430,7 +2427,10 @@ "torch.cpu.synchronize", "torch.cuda._check_capability", "torch.cuda._check_cubins", + "torch.cuda._device_count_amdsmi", "torch.cuda._device_count_nvml", + "torch.cuda._get_amdsmi_handler", + "torch.cuda._get_amdsmi_device_index", "torch.cuda._get_device", "torch.cuda._get_generator", "torch.cuda._get_nvml_device_index", @@ -2461,7 +2461,9 @@ "torch.cuda._memory_viz.trace", "torch.cuda._nvml_based_avail", "torch.cuda._parse_visible_devices", + "torch.cuda._raw_device_count_amdsmi", "torch.cuda._raw_device_count_nvml", + "torch.cuda._raw_device_uuid_amdsmi", "torch.cuda._raw_device_uuid_nvml", "torch.cuda._register_triton_kernels", "torch.cuda._set_rng_state_offset", diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 3e9495b3c7ca8..02f4f8f47b279 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -218,7 +218,8 @@ def call_function( # TODO: support an expression form as well assert not kwargs - assert len(args) == 1 + # Second argument is runtime lambda, ignored + assert len(args) <= 2 fn = args[0] if isinstance(fn, UserFunctionVariable): fn.get_function()(ComptimeContext(tx)) diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 0956ee7e367c4..b1a9502bcf3dd 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -533,12 +533,13 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: # Populate depth for the nodes. Depth is the distance from the inputs. depths = {} - output_node = next(iter(gm.graph.find_nodes(op="output"))) for node in gm.graph.nodes: if node.op == "placeholder": depths[node] = 0 else: - depths[node] = max([depths[arg] for arg in node.all_input_nodes], default=0) + depths[node] = ( + max((depths[arg] for arg in node.all_input_nodes), default=0) + 1 + ) def insert_node_in_graph(node): if node in env: @@ -802,6 +803,8 @@ def should_ban_recomputation(node): return False if node.target == operator.getitem: return False + if op_types.is_view(node): + return False if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: return False # NB: "recompute" == 0 means that must save this node. @@ -854,6 +857,14 @@ def is_materialized(node): def get_node_weight(node) -> float: mem_sz = _size_of(node) + if op_types.is_view(node): + # We never choose to save views, since views are free to recompute. + # It makes it a bit simpler to analyze + # NB: If they're not free to recompute (e.g. nested tensors)... I + # think we should modify checks for view_ops to `is_view` and check + # that. Basically, with nested tensors, `aten.view` is not a "view + # op". 
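The rationale in the partitioner comment above (views are free to recompute, so they should never be chosen as saved activations) is easy to sanity-check in eager mode: a view op allocates no new storage, so re-running it is metadata-only work. A small illustration:

    import torch

    x = torch.randn(1024, 1024)
    v = x.view(-1)                       # metadata-only: no data copy
    assert v.data_ptr() == x.data_ptr()  # shares storage with its base
    v_again = x.view(-1)                 # "recomputing" the view is equally free
    assert v_again.data_ptr() == x.data_ptr()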
+ return math.inf if isinstance(node.meta["val"], py_sym_types): # We never want to save symfloats diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 70b4671431115..6efa7e0db572c 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -773,11 +773,23 @@ def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]: subdir = FxGraphCache._get_tmp_dir_for_key(key) if os.path.exists(subdir): for path in sorted(os.listdir(subdir)): - with open(os.path.join(subdir, path), "rb") as f: - yield pickle.load(f) + try: + with open(os.path.join(subdir, path), "rb") as f: + yield pickle.load(f) + except Exception: + log.warning( + "fx graph cache unable to load compiled graph", + exc_info=True, + ) + if remote_cache: - if (data := remote_cache.get(key)) is not None: - yield pickle.loads(data) + try: + if (data := remote_cache.get(key)) is not None: + yield pickle.loads(data) + except Exception: + log.warning( + "fx graph cache unable to load compiled graph", exc_info=True + ) # Iterate over any entries in the subdir for this key and evaluate # their guards to determine whether there's a hit. @@ -890,32 +902,39 @@ def _save_graph( try: content = pickle.dumps(disk_compiled_graph) - except Exception as e: - log.debug("fx graph cache unable to serialize compiled graph: %s", e) + except Exception: + log.warning( + "fx graph cache unable to serialize compiled graph", exc_info=True + ) counters["inductor"]["fxgraph_cache_pickle_error"] += 1 return - if local: - subdir = FxGraphCache._get_tmp_dir_for_key(key) - if not os.path.exists(subdir): - os.makedirs(subdir, exist_ok=True) - - # Use a hash of the serialized CompiledFxGraph to get a unique file - # name. The specific name doesn't matter since a lookup involves - # iterating over all entries in the parent subdir. - path = os.path.join(subdir, sha256_hash(content)) - write_atomic(path, content, make_dirs=True) - - if remote_cache: - cache_data = ( - { - "data": content, - "time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS - } - if config.is_fbcode() - else content - ) - remote_cache.put(key, cache_data) + try: + if local: + subdir = FxGraphCache._get_tmp_dir_for_key(key) + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + # Use a hash of the serialized CompiledFxGraph to get a unique file + # name. The specific name doesn't matter since a lookup involves + # iterating over all entries in the parent subdir. + path = os.path.join(subdir, sha256_hash(content)) + write_atomic(path, content, make_dirs=True) + + if remote_cache: + cache_data = ( + { + "data": content, + "time_taken_ms": time_taken_ns + // 1000000, # Convert from NS to MS + } + if config.is_fbcode() + else content + ) + remote_cache.put(key, cache_data) + except Exception: + log.warning("fx graph unable to write to cache", exc_info=True) + counters["inductor"]["fxgraph_cache_write_error"] += 1 @staticmethod def _check_can_cache(gm: torch.fx.GraphModule): @@ -1924,12 +1943,19 @@ def _compile_consts_linux(consts: bytes) -> str: run_command_and_check(cmd) log.debug("aot constant binary command: %s", cmd) - # .data section is between .text and .bss. When the size of .data is large, - # during the linking, the relocation of .text against .bss may overflow. - # Rename it to .ldata so that it won't be in between the .text and .bss section + if config.aot_inductor.allow_buffer_mutation: + # .data section is between .text and .bss. 
When the size of .data is large, + # during the linking, the relocation of .text against .bss may overflow. + # Rename it to .ldata so that it won't be in between the .text and .bss section + rename_data = " .data=.ldata" + else: + # if no buffer mutation is needed, we could instead set the data region + # as read-only (i.e. .lrodata) which could accomodate larger size of data + # to be linked. + rename_data = " .data=.lrodata,alloc,load,readonly,data,contents" cmd = ( f"{objcopy_command} --rename-section" - " .data=.ldata" + f"{rename_data}" " --set-section-alignment .data=64" # following the gAlignment of CPU in c10/core/alignment.h f" {consts_o} {consts_o}" ) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index c664ba7fae45d..e0a4c0993549c 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -39,7 +39,8 @@ const int64_t M = {{kernel.size(GemmOut, 0)}}; const int64_t M0_blocks = (M + M0 - 1) / M0; {%- if num_threads > 1 %} - const auto [Mt_blocks, Nt_blocks, Kt_blocks] = mm_get_thread_blocking(M, N, K, M0, N0, K0, num_threads); + int64_t Mt_blocks, Nt_blocks, Kt_blocks; + mm_get_thread_blocking(num_threads, M, N, K, M0, N0, K0, Mt_blocks, Nt_blocks, Kt_blocks); {%- else %} const auto Mt_blocks = M0_blocks; const auto Nt_blocks = N0_blocks; @@ -362,7 +363,9 @@ def render( # type: ignore[override] if epilogue_nodes: Y = cast(ir.Buffer, epilogue_nodes[-1]) assert Y.get_name() in V.kernel.inplace_update_buffers - if Y.get_stride() == list(reversed(template_buffer.get_stride())): + if Y.get_size() == list( + reversed(template_buffer.get_size()) + ) and Y.get_stride() == list(reversed(template_buffer.get_stride())): Y_is_transposed = True micro_gemm = create_micro_gemm( diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index fe083308988b3..6898a8a52112e 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -311,14 +311,17 @@ atomic_add(volatile T *addr, T offset) { atomic_addr->fetch_add(offset, std::memory_order_relaxed); } -std::tuple mm_get_thread_blocking( +void mm_get_thread_blocking( + int num_threads, int64_t M, int64_t N, int64_t K, int64_t M0, int64_t N0, int64_t K0, - int num_threads) { + int64_t& Mt, + int64_t& Nt, + int64_t& Kt) { auto get_factors = [](int64_t number) { int count = 0; for (int64_t i = std::sqrt(number); i > 0; --i) { @@ -359,27 +362,30 @@ std::tuple mm_get_thread_blocking( int64_t factor = factors[i]; if (n_blocks % factor == 0 && m_blocks % (num_threads / factor) == 0) { - return get_blocking( + std::tie(Mt, Nt, Kt) = get_blocking( num_threads, factor, m_blocks, n_blocks, k_blocks); + return; } } for (int i = 0; i < count; ++i) { int64_t factor = factors[i]; if (n_blocks % factor == 0) { - return get_blocking( + std::tie(Mt, Nt, Kt) = get_blocking( num_threads, factor, m_blocks, n_blocks, k_blocks); + return; } int64_t cofactor = num_threads / factor; if (m_blocks % cofactor == 0) { - return get_blocking( + std::tie(Mt, Nt, Kt) = get_blocking( num_threads, factor, m_blocks, n_blocks, k_blocks); + return; } } assert(false && "Should not reach here."); // Dummy return to avoid compiler warning - return std::make_tuple(0, 0, 0); + return; } inline void mm_get_thread_blocks( diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 9595f1da6f957..38df2331315ed 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ 
b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -46,6 +46,10 @@ def __init__(self): self.supports_intermediate_hooks = False self.outputs_need_copy = set() self.kernel_callsite_id = count() + self.var_array_id = ( + count() + ) # for different types of local array variable declarations + self.declared_var_array_vars = set() self.int_array_id = count() # for int array local variable declarations self.declared_int_array_vars = set() self.tmp_tensor_id = count() # for tmp tensor local variable declarations @@ -1511,6 +1515,43 @@ def codegen_int_array_var( writer.writeline(f"const {ctype} {var}[] = {int_array};") return var + @functools.lru_cache(None) + def codegen_var_array( + self, + var_array: str, + writer=None, + known_statically=False, + graph=None, # for per-graph caching + type_hint=None, # ['int64_t', 'tensor', 'bool'] + ): + # Because the memory planning is done in two passes (see the implementation + # of self.generate), the writeline behavior is different in the two passes. + # As a result, the emitted int array declarations may appear in a later + # position of the generated code, so the second pass codegen should not + # reuse int array declarations generated in the first pass + if writer is None: + # The first pass codegen uses `self` as the writer + writer = self + if not type_hint or type_hint in ["bool", "int64_t"]: + return self.codegen_int_array_var( + var_array, + writer, + known_statically, + graph, + is_bool=type_hint == "bool", + ) + + var = f"var_array_{next(self.var_array_id)}" + assert type_hint == "tensor" + ctype = "AtenTensorHandle*" + if var not in self.declared_var_array_vars: + self.declared_var_array_vars.add(var) + if known_statically: + writer.writeline(f"static constexpr {ctype} {var}[] = {var_array};") + else: + writer.writeline(f"const {ctype} {var}[] = {var_array};") + return var + def make_buffer_allocation(self, buffer): return self.make_allocation( buffer.get_name(), @@ -2243,7 +2284,7 @@ def generate_reset_kernel_saved_flags(self): def generate_save_uncompiled_kernels(self): pass - def val_to_cpp_arg_str(self, type_, val) -> str: + def val_to_cpp_arg_str(self, val, type_) -> str: if config.abi_compatible and isinstance(type_, torch.OptionalType): if val is None: return "0" # nullptr is not available in C @@ -2280,9 +2321,9 @@ def val_to_cpp_arg_str(self, type_, val) -> str: self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();") return f"&{var_name}" - return self.val_to_arg_str(val) + return self.val_to_arg_str(val, type_) - def val_to_arg_str(self, val) -> str: + def val_to_arg_str(self, val, type_=None) -> str: if val is None: # When None is passed as an argument, it represents an optional that does not contain a value. 
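A condensed restatement of the dispatch that codegen_var_array adds above (illustrative only, not the wrapper codegen itself): bool/int64_t hints keep flowing through the existing int-array path, while a "tensor" hint emits an array of AtenTensorHandle* for List[Optional[Tensor]] arguments.

def array_decl_ctype(type_hint=None) -> str:
    if not type_hint or type_hint in ("bool", "int64_t"):
        return "int64_t"            # delegated to codegen_int_array_var
    assert type_hint == "tensor"
    return "AtenTensorHandle*"      # the new tensor-array declaration path

assert array_decl_ctype() == "int64_t"
assert array_decl_ctype("tensor") == "AtenTensorHandle*"
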
if config.abi_compatible: @@ -2317,14 +2358,28 @@ def val_to_arg_str(self, val) -> str: if config.abi_compatible: assert len(val) > 0, "Empty array is not supported in C" static = self.is_statically_known_list_of_ints(val) + type_hint = "bool" if isinstance(val[0], bool) else "int64_t" + if ( + type_ is not None + and isinstance(type_, torch._C.ListType) + and isinstance(type_.getElementType(), torch._C.OptionalType) + and isinstance( + type_.getElementType().getElementType(), torch._C.TensorType + ) + ): + type_hint = "tensor" + tmp_arg_list = "" + for x in val: + tmp_arg_list += f"&{x}_handle, " + result = f"{{{tmp_arg_list}}}" # Need to pass the array length because we can't use std::vector - int_var_array = self.codegen_int_array_var( + var_array = self.codegen_var_array( result, known_statically=static, graph=self.get_codegened_graph(), - is_bool=isinstance(val[0], bool), + type_hint=type_hint, ) - return f"{int_var_array}, {len(val)}" + return f"{var_array}, {len(val)}" else: return result else: diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 183d28605b87a..eddb7bdcdcd14 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -13,7 +13,6 @@ import torch import torch._logging -import torch.utils._pytree as pytree from torch._dynamo.utils import preserve_rng_state from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties @@ -1650,13 +1649,22 @@ def cse_multiple(line, n, masks): ) if not self.persistent_reduction: - partial_reduce_vars = pytree.tree_map( - self.reduction_resize, - cse_multiple( - f"tl.reduce(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})", - len(values), - None, - ), + + def sum_fn(a, b): + return [ops.add(ai, bi) for ai, bi in zip(a, b)] + + sum_helper_fn = self._lift_helper(sum_fn, len(values)) + pre_reduce_vars = ", ".join( + f"{scan_var} * (rbase == (RBLOCK - 1))" + for scan_var in partial_scan_vars + ) + # tl.reduce doesn't work for non-commutative operators, so instead + # of repeating the scan op as a reduction, we use sum to select the + # last scan value + partial_reduce_vars = cse_multiple( + f"tl.reduce(({pre_reduce_vars}), -1, {sum_helper_fn}, keep_dims=True)", + len(values), + masks, ) accs_next = combine_fn(tuple(accumulators), partial_reduce_vars) full_scan_vars = combine_fn(tuple(accumulators), partial_scan_vars) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 5d9a11de149e0..ff2fe1ec87cc0 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1412,10 +1412,10 @@ def writelines(self, lines): def enter_context(self, ctx): self.lines.append(LineContext(ctx)) - def val_to_cpp_arg_str(self, type_, val) -> str: + def val_to_cpp_arg_str(self, val, type_) -> str: raise NotImplementedError - def val_to_arg_str(self, s): + def val_to_arg_str(self, s, type_=None): from torch.utils._triton import dtype_to_string, has_triton_package if has_triton_package(): diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index bdbfef2eee28f..5ad2b40894189 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -338,31 +338,6 @@ def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]): return contextlib.nullcontext() -@DebugContext.wrap -def count_bytes_inner( - gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], - num_fixed: int = 0, - **kwargs, -): - shape_env = _shape_env_from_inputs(example_inputs) - fake_mode = 
fake_tensor_prop(gm, example_inputs) - - with V.set_fake_mode(fake_mode): - _recursive_post_grad_passes(gm, False) - - graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed) - with V.set_graph_handler(graph), V.set_real_inputs( - example_inputs - ), maybe_disable_comprehensive_padding(example_inputs): - graph.run(*example_inputs) - num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() - metrics.num_bytes_accessed += num_bytes - metrics.nodes_num_elem += nodes_num_elem - metrics.node_runtimes += node_runtimes - return make_boxed_func(gm.forward) - - def fake_tensor_prop( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], @@ -795,6 +770,7 @@ def fx_codegen_and_compile( const_code=const_code, const_module=const_graph, ) + metrics_helper = metrics.CachedMetricsHelper() with V.set_graph_handler(graph): graph.run(*example_inputs) output_strides: List[Optional[Tuple[int, ...]]] = [] @@ -814,8 +790,11 @@ def fx_codegen_and_compile( else: output_strides.append(None) - metrics_helper = metrics.CachedMetricsHelper() compiled_fn = graph.compile_to_fn() + num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() + metrics.num_bytes_accessed += num_bytes + metrics.node_runtimes += node_runtimes + metrics.nodes_num_elem += nodes_num_elem if ( cudagraphs diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 6c8edbc0f6faf..729f30a82cf77 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -745,6 +745,9 @@ class aot_inductor: # rather than embedded into the data section. Needed to support 1B+ parameter models force_mmap_weights: bool = False + # flag to allow buffer mutation. This would remove the read-only property from buffers. + allow_buffer_mutation: bool = False + class cuda: # CUDA arch to use for CUDA template kernel compilation. 
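The new aot_inductor.allow_buffer_mutation flag added just above is what drives the .ldata versus .lrodata choice made earlier in this diff's codecache.py change. A small sketch of how the two fit together (objcopy_rename_arg is an invented helper name for illustration):

import torch._inductor.config as inductor_config

def objcopy_rename_arg(allow_buffer_mutation: bool) -> str:
    if allow_buffer_mutation:
        # buffers may be written at runtime, so keep them in a writable section
        return " .data=.ldata"
    # read-only constants can go to .lrodata, which also tolerates larger payloads
    return " .data=.lrodata,alloc,load,readonly,data,contents"

inductor_config.aot_inductor.allow_buffer_mutation = False  # the default per this diff
print(objcopy_rename_arg(inductor_config.aot_inductor.allow_buffer_mutation))
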
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index bfb7b8dea7ebb..0adf356f6262a 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1656,14 +1656,10 @@ def codegen_subgraph(self, parent_graph): self.scheduler.codegen() def count_bytes(self): - from .scheduler import Scheduler - - scheduler = Scheduler(self.buffers) - total_bytes = 0 node_counts = [] node_runtimes = [] - for node in scheduler.nodes: + for node in self.scheduler.nodes: num_bytes = node.get_read_write_buffers_sizes() total_bytes += num_bytes node_counts.append((node, num_bytes // 4)) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 8dc87c35836d8..8377fafab7b25 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -4412,7 +4412,7 @@ def codegen_args(self): type_ = self.arg_properties[i].get("type") args.append( V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] - type_, x + x, type_ ) ) else: @@ -4447,7 +4447,7 @@ def codegen_kwargs(self, skip_out=False): ) kwargs.append( V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] - type_, v + v, type_ ) ) else: @@ -5415,7 +5415,7 @@ def __repr__(self): if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): args = self.fill_non_provided_args(args, kwargs) args = [ - V.graph.wrapper_code.val_to_cpp_arg_str(param.real_type, x) + V.graph.wrapper_code.val_to_cpp_arg_str(x, param.real_type) for param, x in zip(self.op_overload._schema.arguments, args) ] else: diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 32dff9d46668c..ddddbed11c829 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -115,9 +115,10 @@ def build_subgraph_buffer( # already created TensorBoxes as args from torch.utils._pytree import tree_map - env[node] = lowerings[node.target]( - *tree_map(lambda x: env[x] if x in env else x, node.args) + args, kwargs = tree_map( + lambda x: env[x] if x in env else x, (node.args, node.kwargs) ) + env[node] = lowerings[node.target](*args, **kwargs) elif node.op == "output": # For the output node we need to create a ComputedBuffer # which represents the actual score modification @@ -367,6 +368,8 @@ def _get_default_config_bwd(query) -> Tuple[int, int, int, int]: @register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None) def flex_attention(*args, **kwargs): query, key, value, subgraph, *other_buffers = args + for buf in [query, key, value]: + buf.realize() placeholder_inps = [ create_placeholder(name, dtype, query.get_device()) for name, dtype in [ @@ -640,6 +643,8 @@ def flex_attention_backward(*args, **kwargs): joint_graph, *other_buffers, ) = args + for buf in [query, key, value, grad_out]: + buf.realize() device = query.get_device() dtype = query.get_dtype() diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 07899fe2ccd09..adf8c542d33e6 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -5919,7 +5919,7 @@ def wrapped_combine_fn(lhs, rhs): kwargs["dtypes"] = tuple(x.get_dtype() for x in input) kwargs["inner_fns"] = tuple(x.make_loader() for x in input) result = ir.Scan.create(**kwargs, combine_fn=wrapped_combine_fn) - if result is None: + if result[0] is None: raise RuntimeError("Unable to generate code for associative_scan op") return result diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index 3a1b83045f4a0..76f15243c5ba1 100644 --- a/torch/_inductor/metrics.py +++ 
b/torch/_inductor/metrics.py @@ -1,6 +1,7 @@ from __future__ import annotations import csv +import dataclasses import inspect import os import re @@ -78,6 +79,11 @@ class CachedMetricsDeltas: generated_cpp_vec_kernel_count: int ir_nodes_pre_fusion: int cpp_to_dtype_count: int + num_bytes_accessed: int + + +def get_metric_fields(): + return [field.name for field in dataclasses.fields(CachedMetricsDeltas)] class CachedMetricsHelper: @@ -88,40 +94,21 @@ class CachedMetricsHelper: """ def __init__(self): - global generated_kernel_count - global generated_cpp_vec_kernel_count - global ir_nodes_pre_fusion - global cpp_to_dtype_count - - self.generated_kernel_count = generated_kernel_count - self.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count - self.ir_nodes_pre_fusion = ir_nodes_pre_fusion - self.cpp_to_dtype_count = cpp_to_dtype_count + self.cached_metrics = {} + for metric in get_metric_fields(): + self.cached_metrics[metric] = globals()[metric] def get_deltas(self) -> CachedMetricsDeltas: - global generated_kernel_count - global generated_cpp_vec_kernel_count - global ir_nodes_pre_fusion - global cpp_to_dtype_count - - return CachedMetricsDeltas( - generated_kernel_count - self.generated_kernel_count, - generated_cpp_vec_kernel_count - self.generated_cpp_vec_kernel_count, - ir_nodes_pre_fusion - self.ir_nodes_pre_fusion, - cpp_to_dtype_count - self.cpp_to_dtype_count, - ) + delta_metrics = {} + for metric in get_metric_fields(): + delta_metrics[metric] = globals()[metric] - self.cached_metrics[metric] + + return CachedMetricsDeltas(**delta_metrics) @staticmethod def apply_deltas(delta: CachedMetricsDeltas): - global generated_kernel_count - global generated_cpp_vec_kernel_count - global ir_nodes_pre_fusion - global cpp_to_dtype_count - - generated_kernel_count += delta.generated_kernel_count - generated_cpp_vec_kernel_count += delta.generated_cpp_vec_kernel_count - ir_nodes_pre_fusion += delta.ir_nodes_pre_fusion - cpp_to_dtype_count += delta.cpp_to_dtype_count + for metric in get_metric_fields(): + globals()[metric] += getattr(delta, metric) REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {} diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 456e0c50567d5..edaa944722700 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -28,6 +28,7 @@ import torch from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.metrics import get_metric_table, is_metric_table_enabled +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.utils._sympy.symbol import free_symbol_is_type, SymT from torch.utils._triton import has_triton @@ -505,12 +506,16 @@ def get_read_write_buffers_sizes(self) -> int: if isinstance(self, ExternKernelSchedulerNode) and isinstance( self.node, MultiOutput ): + # todo: Calculate this - it's kinda annoying. 
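The metrics.py rewrite above replaces the hand-enumerated counters with a dataclass-driven loop, so adding a metric only requires adding a field. A standalone sketch of that pattern with toy counters (not inductor's real globals):

import dataclasses

generated_kernel_count = 0
num_bytes_accessed = 0

@dataclasses.dataclass
class Deltas:
    generated_kernel_count: int
    num_bytes_accessed: int

class MetricsHelper:
    def __init__(self):
        # snapshot every field the dataclass declares
        self.cached = {f.name: globals()[f.name] for f in dataclasses.fields(Deltas)}

    def get_deltas(self) -> Deltas:
        return Deltas(**{f.name: globals()[f.name] - self.cached[f.name]
                         for f in dataclasses.fields(Deltas)})

    @staticmethod
    def apply_deltas(delta: Deltas) -> None:
        for f in dataclasses.fields(Deltas):
            globals()[f.name] += getattr(delta, f.name)

helper = MetricsHelper()
generated_kernel_count += 3
MetricsHelper.apply_deltas(helper.get_deltas())  # replay the recorded +3
assert generated_kernel_count == 6
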
return 0 + def try_size_hint(s): + return V.graph.sizevars.size_hint(s, fallback=0) + if isinstance(self, SchedulerNode): - node_numel = V.graph.sizevars.size_hint( + node_numel = try_size_hint( sympy_product(self.get_ranges()[0]) - * sympy_product(self.get_ranges()[1]) + * sympy_product(self.get_ranges()[1]), ) else: node_numel = int(1e9) @@ -545,16 +550,24 @@ def is_materialized(buf, snodes): continue def get_buf_elems(buf): - return V.graph.sizevars.size_hint(sympy_product(buf.get_size())) - - # Kind of a lazy way to get the MultiOutput nodes corresponding to - # a MultiOutputLayout - if isinstance(buf.layout, MultiOutputLayout): - users = self.scheduler.name_to_node[buf.get_name()].users - buf_elems = sum(get_buf_elems(user.node.node) for user in users) - else: - buf_elems = get_buf_elems(buf) + # Kind of a lazy way to get the MultiOutput nodes corresponding to + # a MultiOutputLayout + if isinstance(buf.layout, MultiOutputLayout): + users = self.scheduler.name_to_node[buf.get_name()].users + tot = 0 + for user in users: + if isinstance(user.node.node, MultiOutput): + tot += get_buf_elems(user.node.node) + else: + # Buf is a MultiOutputLayout but not all of its + # users are MultiOutputs... + # TODO: Figure out what's going on + return 0 + return tot + else: + return try_size_hint(sympy_product(buf.get_size())) + buf_elems = get_buf_elems(buf) node_bytes += min(buf_elems, buf_accessed_elems) * get_dtype_size( buf.get_dtype() ) @@ -580,13 +593,20 @@ def get_estimated_runtime(self) -> float: layout = self.node.get_layout() dtype = self.node.get_dtype() - if not is_gpu(layout.device.type): + if layout.device is not None and not is_gpu(layout.device.type): # default to no reordering based on runtime return 0 # Collective kernels if is_collective(self.node): - return estimate_nccl_collective_runtime(self.node) + try: + return estimate_nccl_collective_runtime(self.node) + except ValueError as e: + # We don't know how to estimate runtime for this collective, + # falling back to 0 + log.info(e) + return 0 + elif is_wait(self.node): # ir.Wait is only used for collective ops. 
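The runtime-estimation changes above consistently prefer returning 0 over raising when no estimate is available (unknown collective, missing device, unbacked sizes). The same fallback pattern in isolation, assuming a generic estimator callable:

import logging

log = logging.getLogger(__name__)

def safe_estimate(estimator, node) -> float:
    try:
        return estimator(node)
    except ValueError as exc:
        # We don't know how to estimate this node; treat it as free so the
        # reordering pass ignores it instead of failing compilation.
        log.info(exc)
        return 0.0
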
# The time needed for the collective op is already estimated and considered @@ -611,7 +631,14 @@ def get_estimated_runtime(self) -> float: from torch._subclasses.fake_tensor import FakeTensorMode from torch.utils.flop_counter import FlopCounterMode - assert self.node.fx_node is not None + if any( + len(free_unbacked_symbols(n.get_numel())) > 0 + for n in self.node.inputs + ): + # Tensor has unbacked symints, we don't know how to estimate + # runtime for that today + return 0 + with FakeTensorMode() as fake_mode, FlopCounterMode( display=False ) as flop_counter_mode, V.set_current_node( @@ -619,7 +646,6 @@ def get_estimated_runtime(self) -> float: ), V.set_fake_mode( fake_mode ): - assert V.current_node is not None from .ir import ir_node_to_tensor fake_inputs = [ @@ -1752,7 +1778,9 @@ def replace_buffer(orig_node: ir.MultiTemplateBuffer, new_node: ir.Buffer): del V.graph.name_to_buffer[replaced_name] new_node.name = orig_name - V.graph.buffers.remove(orig_node) + orig = V.graph.buffers.index(orig_node) + V.graph.buffers.remove(new_node) + V.graph.buffers[orig] = new_node V.graph.name_to_buffer[orig_name] = new_node for i, node in enumerate(self.nodes): diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 6e923883ca68d..c3bbe917e1471 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2,6 +2,7 @@ import functools import inspect import itertools +import json import logging import math @@ -17,6 +18,7 @@ from unittest.mock import patch import sympy +from filelock import FileLock import torch from torch._dynamo.testing import rand_strided @@ -912,6 +914,34 @@ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType } +@functools.lru_cache(None) +def get_mm_log_filename() -> Optional[str]: + mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None) + if not mm_file_name: + return None + + if "json" not in mm_file_name: + mm_file_name = f"{mm_file_name}.json" + + return mm_file_name + + +def append_to_log(filename, data): + lock_file = filename.replace(".json", ".lock") + lock = FileLock(lock_file) + with lock: + try: + with open(filename) as f: + log_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + log_data = [] + + log_data.append(data) + + with open(filename, "w") as f: + json.dump(log_data, f, indent=4) + + class DataProcessorChoiceCallerWrapper: def __init__(self, wrapped, preprocessor, postprocessor): self._wrapped = wrapped @@ -1048,6 +1078,11 @@ def __call__( # TODO(nmacchioni): remove once CI tests are fixed choices = [choice for choice in choices if choice is not None] + if mm_file_name := get_mm_log_filename(): + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + append_to_log(mm_file_name, {"invoke": str((M, K, N))}) + if len(choices) == 0: raise NoValidChoicesError( "No choices to select, please consider adding ATEN into max_autotune_gemm_backends " @@ -1431,9 +1466,48 @@ def log_results( for n in input_nodes ] ) + n = None if log.getEffectiveLevel() == logging.DEBUG else 10 top_k = sorted(timings, key=timings.__getitem__)[:n] best = top_k[0] + + def get_choice_info(choice): + if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller): + return {"type": "cublas", "time": timings[choice]} + + assert isinstance( + choice, torch._inductor.select_algorithm.TritonTemplateCaller + ) + + info = choice.info_dict() + tile = info["tile_shape"] + + tile_vals = eval(tile) # type: ignore[arg-type] + BLOCK_M = 
tile_vals[0] + BLOCK_K = tile_vals[1] + BLOCK_N = tile_vals[2] + + return { + "type": "triton", + "time": timings[choice], + "BLOCK_M": BLOCK_M, + "BLOCK_K": BLOCK_K, + "BLOCK_N": BLOCK_N, + "num_stages": info["num_stages"], + "num_warps": info["num_warps"], + } + + mm_filename = get_mm_log_filename() + if mm_filename and "mm" in name: + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + + out_dict = { + str((M, K, N)): [get_choice_info(choice) for choice in timings.keys()] + } + + append_to_log(mm_filename, out_dict) + best_time = timings[best] sys.stderr.write(f"AUTOTUNE {name}({sizes})\n") for choice in top_k: diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index a61adb6a826b6..ac07f588107a2 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -3,7 +3,7 @@ import os import sys import tempfile -from typing import Any, Dict +from typing import Any, Dict, Optional import torch @@ -195,6 +195,6 @@ def max_clock_rate(): REQUIRES_SET_PYTHON_MODULE = False -def maybe_upload_prof_stats_to_manifold(profile_path: str) -> None: +def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]: print("Uploading profile stats (fb-only otherwise no-op)") - pass + return None diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 4cc05e46c6a70..ecb9a14c0a4c6 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -776,10 +776,8 @@ def _annotate_conv2d_fusion_pattern(self, model: torch.fx.GraphModule): def _annotate_linear_fusion_pattern(self, model: torch.fx.GraphModule): if config := self._get_aten_operator_qconfig(torch.ops.aten.linear.default): - if config.input_activation and not config.input_activation.is_dynamic: - # Weiwen: Dynamic Quant of linear unary will be supported in next step - self._annotate_linear_binary_unary(model, config) - self._annotate_linear_unary(model, config) + self._annotate_linear_binary_unary(model, config) + self._annotate_linear_unary(model, config) self._annotate_linear(model, config) def _annotate_matmul(self, model: torch.fx.GraphModule): diff --git a/torch/csrc/distributed/c10d/HashStore.hpp b/torch/csrc/distributed/c10d/HashStore.hpp index 3697d62301ba3..1453c0a72808a 100644 --- a/torch/csrc/distributed/c10d/HashStore.hpp +++ b/torch/csrc/distributed/c10d/HashStore.hpp @@ -22,7 +22,7 @@ class TORCH_API HashStore : public Store { std::vector get(const std::string& key) override; void wait(const std::vector& keys) override { - wait(keys, timeout_); + wait(keys, Store::kDefaultTimeout); } void wait( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 7586058475ff1..2319db06db643 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -931,7 +931,7 @@ void ProcessGroupNCCL::setSequenceNumberForGroup() { } // NCCL just starts sequence numbers at 0. uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() { - return seq_; + return seqCollective_; } void ProcessGroupNCCL::registerOnCompletionHook( @@ -2246,7 +2246,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( device, rank, opType, - seq_, + seqCollective_, profilingTitle, profilingTitle != nullptr ? 
std::optional>(inputs) : c10::nullopt, @@ -2254,6 +2254,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( enableTiming_.load(), dist_debug_level_); if (record) { + bool isP2P = isP2POp(opType); // Ideally record every work that we enqueue, rather than every work we // create. // - at the time of this PR we do not currently enqueue every created work @@ -2270,13 +2271,15 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( r->trace_id_ = NCCLTraceBuffer::get()->record( uid_, std::make_tuple(pg_name_, pg_desc_), - seq_, + seqCollective_, + seqP2P_, op_id_, profilingTitle ? profilingTitle : "", inputs, outputs, r->ncclStartEvent_.get(), - r->ncclEndEvent_.get()); + r->ncclEndEvent_.get(), + isP2P); } return r; } @@ -2328,10 +2331,6 @@ ProcessGroupNCCL::Options::Options(bool is_high_priority_stream) static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; void ProcessGroupNCCL::startCoalescing() { - coalescedDevice_.set_index(-1); - coalescedComm_ = nullptr; - coalescing_state_ |= CoalActive; - groupStart(); // Other collective ops bump seq_ before creating a work. Thus, if coalesced // ops bump seq_ only after initing a work they will collide with (reuse) the // seq_ of the last non-coalesced collective. Previously, seq_ was bumped @@ -2340,10 +2339,19 @@ void ProcessGroupNCCL::startCoalescing() { // same seq_ for those ops and its 'endCoalescing' op. Hence we bump during // start, which has one minor downside- we burn a seq_ if someone ever does a // 'start' and 'end' coalescing region without doing an operation inbetween. - seq_++; - // Don't bump op_id_ here, becuase startCoalescing isn't a logical operation. + // Don't bump op_id_ here, because startCoalescing isn't a logical operation. // Bump it for each logical op inside the coalescing group. + if (coalescing_state_ & CoalP2P) { + seqP2P_++; + } else { + seqCollective_++; + } + + coalescedDevice_.set_index(-1); + coalescedComm_ = nullptr; + coalescing_state_ |= CoalActive; + groupStart(); } // `optype` is for specifying a composite optype, such as ALLGATHER and @@ -2441,7 +2449,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( errorIfCapturingNonCapturableNCCL(capture_status); // Bump collective counter - seq_++; + seqCollective_++; op_id_++; auto device = getDevice(input); @@ -2596,9 +2604,10 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( errorIfCapturingNonCapturableNCCL(capture_status); // Bump collective counter - seq_++; + seqCollective_++; + // For coalescingManager collectives, there is no individual c++ call per - // collective so there is no flight record and we increment seq_ and op_id_ + // collective so there is no flight record and we increment seq*_ and op_id_ // together. Compare this to startCoalesing/endCoalescing flow where we // increment seq_ once per group and increment op_id_ once per indvidual // operation within the group @@ -2826,9 +2835,9 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; if (!coalescing_state_) { - // Bump sequence number. Don't do so if it's a batch P2P, it will be - // bumped in `endCoalescing`. - seq_++; + // Bump P2P sequence number. Don't do so if it's a batch P2P, it will be + // bumped in `startCoalescing`. 
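A toy restatement of the sequence-number split this file introduces: collectives and point-to-point ops now advance independent counters, so a send/recv no longer shifts the collective sequence that the flight recorder uses to match entries across ranks (the class and counter names below are illustrative only):

class SeqTracker:
    def __init__(self):
        self.seq_collective = 0
        self.seq_p2p = 0
        self.op_id = 0

    def on_collective(self):
        self.seq_collective += 1
        self.op_id += 1

    def on_p2p(self, coalescing=False):
        if not coalescing:  # batched P2P bumps once, in startCoalescing
            self.seq_p2p += 1
        self.op_id += 1

t = SeqTracker()
t.on_collective(); t.on_p2p(); t.on_collective()
assert (t.seq_collective, t.seq_p2p, t.op_id) == (2, 1, 3)
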
+ seqP2P_++; } } @@ -2869,13 +2878,15 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( auto trace_id = NCCLTraceBuffer::get()->record( uid_, std::make_tuple(pg_name_, pg_desc_), - seq_, + seqCollective_, + seqP2P_, op_id_, profilingTitle, {tensor}, {tensor}, nullptr, - nullptr); + nullptr, + /*isP2P=*/true); // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get // their timings/states updated by proxy when the Work obj representing the // coalesce group gets its update, we could accumulate these trace_ids @@ -2894,19 +2905,21 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // output, not sure what work->outputs_ = std::make_shared>(); work->outputs_->push_back(tensor); - // TODO(whc) becuase we don't pass output {tensor} to initWork, we tell + // TODO(whc) because we don't pass output {tensor} to initWork, we tell // initWork to not record, and then we manually call record passing all the // information it wants. work->trace_id_ = NCCLTraceBuffer::get()->record( uid_, std::make_tuple(pg_name_, pg_desc_), - seq_, + seqCollective_, + seqP2P_, op_id_, profilingTitle, {tensor}, {tensor}, work->ncclStartEvent_.get(), - work->ncclEndEvent_.get()); + work->ncclEndEvent_.get(), + /*isP2P=*/true); } // is gpuGuard needed for the if block below, or can i swap them diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 07f3730b1338b..995ae003a1cf0 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -1055,13 +1055,16 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Counting for the sequential number of NCCL collective call. // (specifically, how many actual kernels we launched, which differs from // op_id_ when coalescing is enabled) - uint64_t seq_{0}; + uint64_t seqCollective_{0}; + + // Counting for the sequential number of NCCL P2P calls. + uint64_t seqP2P_{0}; // Incrementing counter for logical operations (collective or p2p) issued on // the ProcessGroup uint64_t op_id_{0}; - // the sequential number of the last colletive enqueued into workMetaList_ + // the sequential number of the last collective enqueued into workMetaList_ // This is useful for indentifying a rank that has not join a collective // initialized to be -1 to indicate no collective has been enqueued int64_t lastEnqueuedSeq_{-1}; @@ -1069,10 +1072,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { // the name of the last collective enqueued into workMetaList_ std::string lastEnqueuedWorkName_; - // the sequential number of the last colletive started as the kernal + // the sequential number of the last collective started as the kernel int64_t lastStartedSeq_{-1}; - // the name of the last collective started as the kernal + // the name of the last collective started as the kernel std::string lastStartedWorkName_; // the sequential number of the last colletive completed marked by diff --git a/torch/csrc/distributed/c10d/Store.hpp b/torch/csrc/distributed/c10d/Store.hpp index 993284fa7cc56..af715ba98a794 100644 --- a/torch/csrc/distributed/c10d/Store.hpp +++ b/torch/csrc/distributed/c10d/Store.hpp @@ -97,33 +97,4 @@ class TORCH_API Store : public torch::CustomClassHolder { std::chrono::milliseconds timeout_; }; -/* -StoreTimeoutGuard is a RAII guard that will set the store timeout and restore it -when it returns. 
-*/ -class StoreTimeoutGuard { - public: - explicit StoreTimeoutGuard( - Store& store, - const std::chrono::milliseconds& timeout) - : store_(store) { - oldTimeout_ = store.getTimeout(); - store.setTimeout(timeout); - } - - ~StoreTimeoutGuard() { - store_.setTimeout(oldTimeout_); - } - - /* Disabling copy and move semantics */ - StoreTimeoutGuard(const StoreTimeoutGuard&) = delete; - StoreTimeoutGuard& operator=(const StoreTimeoutGuard&) = delete; - StoreTimeoutGuard(StoreTimeoutGuard&&) = delete; - StoreTimeoutGuard& operator=(StoreTimeoutGuard&&) = delete; - - private: - Store& store_; - std::chrono::milliseconds oldTimeout_; -}; - } // namespace c10d diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index 181f2208160b7..5b2fcc45c8f3f 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -2,17 +2,23 @@ #include #include +#include #include #include +#include + +#include #include #include #include #include +#include #include #include #include + namespace c10d { static c10::IValue entries_key = "entries"; @@ -20,12 +26,14 @@ static c10::IValue nccl_comm_key = "nccl_comm_state"; static c10::IValue version_key = "version"; // Update whenever changing contents or formatting of the dump // (minor when adding fields, major when changing existing fields) -static c10::IValue version_val = "1.5"; +static c10::IValue version_val = "2.0"; static c10::IValue pg_config_key = "pg_config"; static c10::IValue record_id_key = "record_id"; static c10::IValue pg_id_key = "pg_id"; static c10::IValue pg_name_key = "process_group"; -static c10::IValue seq_id_key = "seq_id"; +static c10::IValue collective_seq_id_key = "collective_seq_id"; +static c10::IValue p2p_seq_id_key = "p2p_seq_id"; +static c10::IValue is_p2p_key = "is_p2p"; static c10::IValue op_id_key = "op_id"; static c10::IValue profiling_name_key = "profiling_name"; static c10::IValue input_sizes_key = "input_sizes"; @@ -428,11 +436,14 @@ struct NCCLTraceBuffer { size_t pg_id_; std::tuple pg_name_; // - // Both seq_id_ and op_id_ are per_pg incrementing counters - // seq_id refers to actual kernel launches (e.g. 1 per coalesced group) - // op_id refers to logical operations (e.g. one per op inside coalesced - // group) - size_t seq_id_; + // collective_seq_id and p2p_seq_id refer to actual kernel launches (e.g. 1 + // per coalesced group). + // collective_seq_id only increments for true collective operations (over + // all ranks in the group). p2p_seq_id only increments over non-collective + // operations in the group. op_id refers to logical operations (e.g. one per + // op inside coalesced group) + size_t collective_seq_id_; + size_t p2p_seq_id_; size_t op_id_; std::string profiling_name_; @@ -445,6 +456,10 @@ struct NCCLTraceBuffer { // timestamp when the entry was created, likely close to the time the work // was 'enqueued'- not necessarily started c10::time_t time_created_; + + // Is this a P2P event? + bool isP2P_; + std::optional duration_; // timestamp when our CPU threads discovered that the kernel started. 
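With the trace-buffer format version bumped to "2.0" above, each flight-recorder entry now carries separate collective and P2P sequence ids plus an is_p2p flag in place of the old single seq_id. A sketch of the shape of one dumped entry (keys follow the constants added above; the values are illustrative only, consult an actual dump for the full key set):

example_entry = {
    "record_id": 7,
    "pg_id": 0,
    "process_group": ("0", "default_pg"),   # (name, desc) tuple
    "collective_seq_id": 3,                  # replaces the old "seq_id"
    "p2p_seq_id": 0,
    "is_p2p": False,
    "op_id": 5,
    "profiling_name": "nccl:all_reduce",
}
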
@@ -479,13 +494,15 @@ struct NCCLTraceBuffer { std::optional record( size_t pg_id, const std::tuple& pg_name, - size_t seq_id, + size_t collective_seq_id, + size_t p2p_seq_id, size_t op_id, std::string profiling_name, const std::vector& inputs, const std::vector& outputs, Event* start, - Event* end) { + Event* end, + bool isP2P) { if (!enabled_) { return c10::nullopt; } @@ -497,13 +514,15 @@ struct NCCLTraceBuffer { id_, pg_id, pg_name, - seq_id, + collective_seq_id, + p2p_seq_id, op_id, std::move(profiling_name), std::move(traceback), std::move(start), std::move(end), - c10::getTime()}; + c10::getTime(), + isP2P}; for (const auto& input : inputs) { c10::IntArrayRef sizes = input.sizes(); @@ -656,7 +675,8 @@ struct NCCLTraceBuffer { dict.insert(record_id_key, int64_t(e.id_)); dict.insert(pg_id_key, int64_t(e.pg_id_)); dict.insert(pg_name_key, e.pg_name_); - dict.insert(seq_id_key, int64_t(e.seq_id_)); + dict.insert(collective_seq_id_key, int64_t(e.collective_seq_id_)); + dict.insert(p2p_seq_id_key, int64_t(e.p2p_seq_id_)); dict.insert(op_id_key, int64_t(e.op_id_)); dict.insert(profiling_name_key, e.profiling_name_); dict.insert(time_created_key, int64_t(e.time_created_)); @@ -699,6 +719,7 @@ struct NCCLTraceBuffer { ? int64_t(*e.time_discovered_completed_) : c10::IValue()); dict.insert(retired_key, e.retired_); + dict.insert(is_p2p_key, e.isP2P_); auto frames = new_list(); for (int64_t frame : tb) { diff --git a/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp b/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp deleted file mode 100644 index b98f9a71fb024..0000000000000 --- a/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp +++ /dev/null @@ -1,59 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include -#include - -namespace c10d { - -using namespace std::chrono_literals; - -class TORCH_API ControlCollectives : public torch::CustomClassHolder { - public: - virtual void barrier( - const std::string& key, - std::chrono::milliseconds timeout = 5min, - bool block = true) = 0; - - virtual void broadcastSend( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) = 0; - virtual std::vector broadcastRecv( - const std::string& key, - std::chrono::milliseconds timeout = 5min) = 0; - - virtual void gatherSend( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) = 0; - virtual std::vector> gatherRecv( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) = 0; - - virtual std::vector scatterSend( - const std::string& key, - const std::vector>& data, - std::chrono::milliseconds timeout = 5min) = 0; - virtual std::vector scatterRecv( - const std::string& key, - std::chrono::milliseconds timeout = 5min) = 0; - - virtual std::vector> allGather( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) = 0; - - virtual int64_t allSum( - const std::string& key, - int64_t data, - std::chrono::milliseconds timeout = 5min) = 0; -}; - -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp deleted file mode 100644 index 995899441d461..0000000000000 --- a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp +++ /dev/null @@ -1,222 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -namespace { 
-std::string getRankKey(const std::string& key, int rank) { - return fmt::format("{}/{}", key, rank); -} -} // namespace - -namespace c10d { - -StoreCollectives::StoreCollectives( - c10::intrusive_ptr<::c10d::Store> store, - int rank, - int worldSize) - : store_(std::move(store)), rank_(rank), worldSize_(worldSize) {} - -void StoreCollectives::barrier( - const std::string& key, - std::chrono::milliseconds timeout, - bool blocking) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - auto num_members_key = fmt::format("{}/num_members", key); - auto last_members_key = fmt::format("{}/last_members", key); - - auto idx = store_->add(num_members_key, 1); - store_->set(getRankKey(key, rank_), "joined"); - - if (idx == worldSize_) { - store_->set(last_members_key, ""); - } else if (blocking) { - try { - store_->wait({last_members_key}); - } catch (const std::exception& e) { - std::string msg = "barrier failed -- missing ranks: "; - for (int i = 0; i < worldSize_; i++) { - if (i == rank_) { - continue; - } - auto rank_key = getRankKey(key, i); - if (!store_->check({rank_key})) { - msg += fmt::format("{}, ", i); - } - } - throw std::runtime_error(msg + e.what()); - } - } -} - -void StoreCollectives::broadcastSend( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - store_->set(key, data); -} - -std::vector StoreCollectives::broadcastRecv( - const std::string& key, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - return store_->get(key); -} - -void StoreCollectives::gatherSend( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - auto rank_key = getRankKey(key, rank_); - store_->set(rank_key, data); -} - -std::vector> StoreCollectives::gatherRecv( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - std::vector keys; - keys.reserve(worldSize_); - - for (int i = 0; i < worldSize_; i++) { - if (i == rank_) { - continue; - } - auto rank_key = getRankKey(key, i); - keys.emplace_back(rank_key); - } - - std::vector> results; - results.reserve(worldSize_); - - try { - results = store_->multiGet(keys); - } catch (const std::exception& e) { - std::string msg = "gather failed -- missing ranks: "; - for (int i = 0; i < worldSize_; i++) { - if (i == rank_) { - continue; - } - auto rank_key = getRankKey(key, i); - if (!store_->check({rank_key})) { - msg += fmt::format("{}, ", i); - } - } - throw std::runtime_error(msg + e.what()); - } - - // insert local data - results.insert(results.begin() + rank_, data); - return results; -} - -std::vector StoreCollectives::scatterSend( - const std::string& key, - const std::vector>& data, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - std::vector keys; - keys.reserve(worldSize_); - for (int i = 0; i < worldSize_; i++) { - if (i == rank_) { - continue; - } - auto rank_key = getRankKey(key, i); - keys.emplace_back(rank_key); - } - auto local = data.at(rank_); - - std::vector> toSend{data}; - - toSend.erase(toSend.begin() + rank_); - - store_->multiSet(keys, toSend); - - return local; -} - -std::vector StoreCollectives::scatterRecv( - const std::string& key, - std::chrono::milliseconds timeout) { - enforceUnique(key); - 
StoreTimeoutGuard g{*store_, timeout}; - - auto rank_key = getRankKey(key, rank_); - return store_->get(rank_key); -} - -std::vector> StoreCollectives::allGather( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - auto localKey = getRankKey(key, rank_); - store_->set(localKey, data); - - std::vector keys; - keys.reserve(worldSize_); - - for (int i = 0; i < worldSize_; i++) { - auto rank_key = getRankKey(key, i); - keys.emplace_back(rank_key); - } - - try { - return store_->multiGet(keys); - } catch (const std::exception& e) { - std::string msg = "all_gather failed -- missing ranks: "; - for (int i = 0; i < worldSize_; i++) { - if (i == rank_) { - continue; - } - auto rank_key = getRankKey(key, i); - if (!store_->check({rank_key})) { - msg += fmt::format("{}, ", i); - } - } - throw std::runtime_error(msg + e.what()); - } -} - -int64_t StoreCollectives::allSum( - const std::string& key, - int64_t value, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - store_->add(key, value); - - barrier(key + "/barrier", timeout); - - return store_->add(key, 0); -} - -void StoreCollectives::enforceUnique(const std::string& key) { - auto it = seenKeys_.find(key); - TORCH_INTERNAL_ASSERT( - it == seenKeys_.end(), "Key ", key, " has already been used."); - seenKeys_.emplace(key); -} - -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp deleted file mode 100644 index 7d3eb5038565e..0000000000000 --- a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp +++ /dev/null @@ -1,68 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -namespace c10d { - -class TORCH_API StoreCollectives : public ControlCollectives { - public: - explicit StoreCollectives( - c10::intrusive_ptr store, - int rank, - int worldSize); - - void barrier( - const std::string& key, - std::chrono::milliseconds timeout = 5min, - bool block = true) override; - - void broadcastSend( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) override; - std::vector broadcastRecv( - const std::string& key, - std::chrono::milliseconds timeout = 5min) override; - - void gatherSend( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) override; - std::vector> gatherRecv( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) override; - - std::vector scatterSend( - const std::string& key, - const std::vector>& data, - std::chrono::milliseconds timeout = 5min) override; - std::vector scatterRecv( - const std::string& key, - std::chrono::milliseconds timeout = 5min) override; - - std::vector> allGather( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) override; - - int64_t allSum( - const std::string& key, - int64_t data, - std::chrono::milliseconds timeout = 5min) override; - - private: - void enforceUnique(const std::string& key); - - private: - c10::intrusive_ptr store_; - int rank_; - int worldSize_; - - c10::FastSet seenKeys_{}; -}; - -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 505b64e2a6976..483becbce0094 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ 
-6,9 +6,6 @@ #include #include #include -#include -#include -#include #ifndef _WIN32 #include #include @@ -139,34 +136,6 @@ namespace torch::distributed::c10d { namespace { -py::bytes toPyBytes(const std::vector& data) { - return py::bytes(reinterpret_cast(data.data()), data.size()); -} - -std::vector toPyBytes( - const std::vector>& data) { - std::vector out; - out.reserve(data.size()); - for (const std::vector& data_ : data) { - out.emplace_back(reinterpret_cast(data_.data()), data_.size()); - } - return out; -} - -std::vector toVec8(const std::string& data) { - std::vector out{data.begin(), data.end()}; - return out; -} - -std::vector> toVec8(const std::vector& data) { - std::vector> out; - out.reserve(data.size()); - for (auto& data_ : data) { - out.emplace_back(toVec8(data_)); - } - return out; -} - template using shared_ptr_class_ = py::class_>; @@ -197,7 +166,8 @@ class PythonStore : public ::c10d::Store { pybind11::get_overload(static_cast(this), "set"); TORCH_INTERNAL_ASSERT(fn, "Not implemented."); // Call function with a py::bytes object for the value. - fn(key, toPyBytes(value)); + fn(key, + py::bytes(reinterpret_cast(value.data()), value.size())); } // Note: this function manually calls the Python-side overload @@ -214,7 +184,7 @@ class PythonStore : public ::c10d::Store { // std::vector. There is no API for directly accessing // the contents of the py::bytes object. std::string str = pybind11::cast(fn(key)); - return toVec8(str); + return std::vector(str.begin(), str.end()); } // Note: this function manually calls the Python-side overload @@ -234,8 +204,14 @@ class PythonStore : public ::c10d::Store { // std::vector. There is no API for directly accessing // the contents of the py::bytes object. std::string str = pybind11::cast( - fn(key, toPyBytes(expectedValue), toPyBytes(desiredValue))); - return toVec8(str); + fn(key, + py::bytes( + reinterpret_cast(expectedValue.data()), + expectedValue.size()), + py::bytes( + reinterpret_cast(desiredValue.data()), + desiredValue.size()))); + return std::vector(str.begin(), str.end()); } int64_t add(const std::string& key, int64_t value) override { @@ -277,7 +253,8 @@ class PythonStore : public ::c10d::Store { return Store::append(key, value); } // Call function with a py::bytes object for the value. - fn(key, toPyBytes(value)); + fn(key, + py::bytes(reinterpret_cast(value.data()), value.size())); } std::vector> multiGet( @@ -310,7 +287,14 @@ class PythonStore : public ::c10d::Store { return Store::multiSet(keys, values); } - fn(keys, toPyBytes(values)); + std::vector bytes; + bytes.reserve(values.size()); + for (auto& value : values) { + bytes.emplace_back( + reinterpret_cast(value.data()), value.size()); + } + + fn(keys, bytes); } bool hasExtendedApi() const override { @@ -989,7 +973,10 @@ and :class:`~torch.distributed.HashStore`). 
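The Store bindings reworked here back the public torch.distributed Store API; a single-process HashStore round-trip (a hedged sketch relying on the documented semantics, not a test from this PR) shows the byte-oriented behaviour the inline conversions implement:

import torch.distributed as dist

store = dist.HashStore()
store.set("k", "v1")                 # values go in as str
assert store.get("k") == b"v1"       # and always come back as bytes
# compare_set swaps only when the stored value matches the expected one.
assert store.compare_set("k", "stale", "v2") == b"v1"
assert store.compare_set("k", "v1", "v2") == b"v2"
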
"set", [](::c10d::Store& store, const std::string& key, - const std::string& value) { store.set(key, toVec8(value)); }, + const std::string& value) { + std::vector value_(value.begin(), value.end()); + store.set(key, value_); + }, py::call_guard(), R"( Inserts the key-value pair into the store based on the supplied ``key`` and @@ -1014,9 +1001,14 @@ Example:: const std::string& key, const std::string& expected_value, const std::string& desired_value) -> py::bytes { - auto value = store.compareSet( - key, toVec8(expected_value), toVec8(desired_value)); - return toPyBytes(value); + std::vector expectedValue_( + expected_value.begin(), expected_value.end()); + std::vector desiredValue_( + desired_value.begin(), desired_value.end()); + auto value = + store.compareSet(key, expectedValue_, desiredValue_); + return py::bytes( + reinterpret_cast(value.data()), value.size()); }, py::call_guard(), R"( @@ -1048,7 +1040,8 @@ Example:: py::gil_scoped_release guard; return store.get(key); }(); - return toPyBytes(value); + return py::bytes( + reinterpret_cast(value.data()), value.size()); }, R"( Retrieves the value associated with the given ``key`` in the store. If ``key`` is not @@ -1247,7 +1240,8 @@ Example:: [](::c10d::Store& store, const std::string& key, const std::string& value) { - store.append(key, toVec8(value)); + std::vector value_(value.begin(), value.end()); + store.append(key, value_); }, py::call_guard(), R"( @@ -1274,7 +1268,14 @@ Example:: py::gil_scoped_release guard; return store.multiGet(keys); }(); - return toPyBytes(values); + std::vector res; + for (auto& value : values) { + auto bytes = py::bytes( + reinterpret_cast(value.data()), + value.size()); + res.push_back(bytes); + } + return res; }, R"( Retrieve all values in ``keys``. If any key in ``keys`` is not @@ -1297,7 +1298,12 @@ Example:: [](::c10d::Store& store, const std::vector& keys, const std::vector& values) { - store.multiSet(keys, toVec8(values)); + std::vector> vals; + vals.reserve(values.size()); + for (auto& value : values) { + vals.emplace_back(value.begin(), value.end()); + } + store.multiSet(keys, vals); }, py::call_guard(), R"( @@ -1481,212 +1487,6 @@ that adds a prefix to each key inserted to the store. &::c10d::PrefixStore::getUnderlyingNonPrefixStore, R"(Recursively to get the store before layers of wrapping with PrefixStore.)"); - using namespace std::chrono_literals; - - auto collectives = - py::class_< - ::c10d::ControlCollectives, - c10::intrusive_ptr<::c10d::ControlCollectives>>( - module, - "_ControlCollectives", - R"( -Base class for all ControlCollectives implementations. -)") - .def( - "barrier", - &::c10d::ControlCollectives::barrier, - py::arg("key"), - py::arg("timeout") = 5min, - py::arg("block") = true, - py::call_guard(), - R"( -Blocks until all workers have entered this function. - -Arguments: - key (str): The unique key used to identify this operation. - timeout (duration): The timeout for this operation. - block (bool): whether to block this working waiting on the results of the barrier. -)") - .def( - "all_sum", - &::c10d::ControlCollectives::allSum, - py::arg("key"), - py::arg("data"), - py::arg("timeout") = 5min, - py::call_guard(), - R"( -Computes a sum across all workers and returns the final value. - -Arguments: - key (str): The unique key used to identify this operation. - data (int): The data to sum. - timeout (duration): The timeout for this operation. 
-)") - .def( - "broadcast_send", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - const std::string& data, - std::chrono::milliseconds timeout = 5min) { - collectives.broadcastSend(key, toVec8(data), timeout); - }, - py::arg("key"), - py::arg("data"), - py::arg("timeout") = 5min, - py::call_guard(), - R"( -Sends data to all other workers. Must be only called from one worker. - -Arguments: - key (str): The unique key used to identify this operation. - data (str): The data to send. - timeout (duration): The timeout for this operation. -)") - .def( - "broadcast_recv", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - std::chrono::milliseconds timeout = 5min) { - auto out = [&]() { - py::gil_scoped_release guard; - return collectives.broadcastRecv(key, timeout); - }(); - return toPyBytes(out); - }, - py::arg("key"), - py::arg("timeout") = 5min, - R"( -Receives data broadcasted from 1 worker. - -Arguments: - key (str): The unique key used to identify this operation. - timeout (duration): The timeout for this operation. -)") - .def( - "gather_send", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - const std::string& data, - std::chrono::milliseconds timeout = 5min) { - collectives.gatherSend(key, toVec8(data), timeout); - }, - py::arg("key"), - py::arg("data"), - py::arg("timeout") = 5min, - py::call_guard(), - R"( -Sends data to one other worker. - -Arguments: - key (str): The unique key used to identify this operation. - data (str): The data to send. - timeout (duration): The timeout for this operation. -)") - .def( - "gather_recv", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - const std::string& data, - std::chrono::milliseconds timeout = 5min) { - auto out = [&]() { - py::gil_scoped_release guard; - return collectives.gatherRecv(key, toVec8(data), timeout); - }(); - return toPyBytes(out); - }, - py::arg("key"), - py::arg("data"), - py::arg("timeout") = 5min, - R"( -Receives data broadcasted from all workers. Must only be called by one worker. - -Arguments: - key (str): The unique key used to identify this operation. - timeout (duration): The timeout for this operation. -)") - - .def( - "scatter_send", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) { - auto out = [&]() { - py::gil_scoped_release guard; - return collectives.scatterSend(key, toVec8(data), timeout); - }(); - return toPyBytes(out); - }, - py::arg("key"), - py::arg("data"), - py::arg("timeout") = 5min, - R"( -Sends rank specific data to all other workers. - -Arguments: - key (str): The unique key used to identify this operation. - data (str): The data to send. - timeout (duration): The timeout for this operation. -)") - .def( - "scatter_recv", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - std::chrono::milliseconds timeout = 5min) { - auto out = [&]() { - py::gil_scoped_release guard; - return collectives.scatterRecv(key, timeout); - }(); - return toPyBytes(out); - }, - py::arg("key"), - py::arg("timeout") = 5min, - R"( -Receives rank specific data from one worker. - -Arguments: - key (str): The unique key used to identify this operation. - timeout (duration): The timeout for this operation. 
-)") - - .def( - "all_gather", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - const std::string& data, - std::chrono::milliseconds timeout = 5min) { - auto out = [&]() { - py::gil_scoped_release guard; - return collectives.allGather(key, toVec8(data), timeout); - }(); - return toPyBytes(out); - }, - py::arg("key"), - py::arg("data"), - py::arg("timeout") = 5min, - R"( -Sends data to all workers and receives data from all other workers. - -Arguments: - key (str): The unique key used to identify this operation. - data (str): The data to send. - timeout (duration): The timeout for this operation. -)"); - - intrusive_ptr_class_<::c10d::StoreCollectives>( - module, - "_StoreCollectives", - collectives, - R"( -An implementation of ControlCollectives that uses the provided store as the underlying -communication mechanism. - )") - .def( - py::init, int, int>(), - py::arg("store"), - py::arg("rank"), - py::arg("world_size")); - auto processGroup = py::class_< ::c10d::ProcessGroup, diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 96ff6c88779d9..da7b87bae6110 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -651,6 +651,10 @@ mobile::Module _load_for_mobile( std::optional device, ExtraFilesMap& extra_files, uint64_t module_load_options) { + auto observer = torch::observerConfig().getModuleObserver(); + if (observer) { + extra_files.insert(std::make_pair("model_path", filename)); + } auto format = getFileFormat(filename); if (format == FileFormat::FlatbufferFileFormat) { diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 1cc88e8adc578..8c19788d1055d 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -53,9 +53,18 @@ _HAS_PYNVML = False _PYNVML_ERR = None try: - import pynvml # type: ignore[import] + try: + import pynvml # type: ignore[import] + + _HAS_PYNVML = True + except ModuleNotFoundError: + pass + try: + import amdsmi # type: ignore[import] - _HAS_PYNVML = True + _HAS_PYNVML = True + except ModuleNotFoundError: + pass except ImportError as err: _PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later @@ -571,7 +580,9 @@ def set_stream(stream: Stream): def _parse_visible_devices() -> Union[List[int], List[str]]: r"""Parse CUDA_VISIBLE_DEVICES environment variable.""" - var = os.getenv("CUDA_VISIBLE_DEVICES") + var = os.getenv( + "CUDA_VISIBLE_DEVICES" if not torch.version.hip else "HIP_VISIBLE_DEVICES" + ) if var is None: return list(range(64)) @@ -617,6 +628,16 @@ def parse_list_with_prefix(lst: str, prefix: str) -> List[str]: return rc +def _raw_device_count_amdsmi() -> int: + try: + amdsmi.amdsmi_init() + except amdsmi.AmdSmiException as e: + warnings.warn(f"Can't initialize amdsmi - Error code: {e.err_code}") + return -1 + socket_handles = amdsmi.amdsmi_get_processor_handles() + return len(socket_handles) + + def _raw_device_count_nvml() -> int: r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed.""" from ctypes import byref, c_int, CDLL @@ -635,6 +656,36 @@ def _raw_device_count_nvml() -> int: return dev_count.value +def _raw_device_uuid_amdsmi() -> Optional[List[str]]: + from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer + + try: + amdsmi.amdsmi_init() + except amdsmi.AmdSmiException: + warnings.warn("Can't initialize amdsmi") + return None + try: + socket_handles = amdsmi.amdsmi_get_processor_handles() + dev_count = 
len(socket_handles) + except amdsmi.AmdSmiException: + warnings.warn("Can't get amdsmi device count") + return None + uuids: List[str] = [] + for idx in range(dev_count): + try: + handler = amdsmi.amdsmi_get_processor_handles()[idx] + except amdsmi.AmdSmiException: + warnings.warn("Cannot get amd device handler") + return None + try: + uuid = amdsmi.amdsmi_get_gpu_device_uuid(handler) + except amdsmi.AmdSmiException: + warnings.warn("Cannot get uuid for amd device") + return None + uuids.append(str(uuid)) + return uuids + + def _raw_device_uuid_nvml() -> Optional[List[str]]: r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed.""" from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer @@ -694,6 +745,28 @@ def uuid_to_orinal(candidate: str, uuids: List[str]) -> int: return rc +def _device_count_amdsmi() -> int: + visible_devices = _parse_visible_devices() + if not visible_devices: + return 0 + try: + if type(visible_devices[0]) is str: + return -1 + else: + raw_cnt = _raw_device_count_amdsmi() + if raw_cnt <= 0: + return raw_cnt + # Trim the list up to a maximum available device + for idx, val in enumerate(visible_devices): + if cast(int, val) >= raw_cnt: + return idx + except OSError: + return -1 + except AttributeError: + return -1 + return len(visible_devices) + + def _device_count_nvml() -> int: r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account. @@ -758,7 +831,7 @@ def device_count() -> int: if _cached_device_count is not None: return _cached_device_count # bypass _device_count_nvml() if rocm (not supported) - nvml_count = -1 if torch.version.hip else _device_count_nvml() + nvml_count = _device_count_amdsmi() if torch.version.hip else _device_count_nvml() r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count # NB: Do not cache the device count prior to CUDA initialization, because # the number of devices can change due to changes to CUDA_VISIBLE_DEVICES @@ -916,6 +989,68 @@ def _get_pynvml_handler(device: Optional[Union[Device, int]] = None): return handle +def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None): + if not _HAS_PYNVML: + raise ModuleNotFoundError( + "amdsmi does not seem to be installed or it can't be imported." 
+ ) from _PYNVML_ERR + try: + amdsmi.amdsmi_init() + except amdsmi.AmdSmiException as e: + raise RuntimeError( + "amdsmi driver can't be loaded, requires >=ROCm5.6 installation" + ) from e + device = _get_amdsmi_device_index(device) + handle = amdsmi.amdsmi_get_processor_handles()[device] + return handle + + +def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int: + r"""Return the amdsmi index of the device, taking HIP_VISIBLE_DEVICES into account.""" + idx = _get_device_index(device, optional=True) + visible_devices = _parse_visible_devices() + if type(visible_devices[0]) is str: + raise RuntimeError("HIP_VISIBLE_DEVICES should be indices and not strings") + idx_map = dict(enumerate(cast(List[int], visible_devices))) + if idx not in idx_map: + raise RuntimeError( + f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})" + ) + return idx_map[idx] + + +def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler() + device = _get_amdsmi_device_index(device) + return amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"] + + +def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler() + device = _get_amdsmi_device_index(device) + handle = amdsmi.amdsmi_get_processor_handles()[device] + return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"] + + +def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler(device) + return amdsmi.amdsmi_get_temp_metric( + handle, + amdsmi.AmdSmiTemperatureType.JUNCTION, + amdsmi.AmdSmiTemperatureMetric.CURRENT, + ) + + +def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler(device) + return amdsmi.amdsmi_get_power_info(handle)["average_socket_power"] + + +def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler(device) + return amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)["cur_clk"] + + def memory_usage(device: Optional[Union[Device, int]] = None) -> int: r"""Return the percent of time over the past sample period during which global (device) memory was being read or written as given by `nvidia-smi`. @@ -928,11 +1063,13 @@ def memory_usage(device: Optional[Union[Device, int]] = None) -> int: Warning: Each sample period may be between 1 second and 1/6 second, depending on the product being queried. """ - handle = _get_pynvml_handler() - - device = _get_nvml_device_index(device) - handle = pynvml.nvmlDeviceGetHandleByIndex(device) - return pynvml.nvmlDeviceGetUtilizationRates(handle).memory + if not torch.version.hip: + handle = _get_pynvml_handler() + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + return pynvml.nvmlDeviceGetUtilizationRates(handle).memory + else: + return _get_amdsmi_memory_usage(device) def utilization(device: Optional[Union[Device, int]] = None) -> int: @@ -947,10 +1084,13 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int: Warning: Each sample period may be between 1 second and 1/6 second, depending on the product being queried. 
""" - handle = _get_pynvml_handler(device) - device = _get_nvml_device_index(device) - handle = pynvml.nvmlDeviceGetHandleByIndex(device) - return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu + if not torch.version.hip: + handle = _get_pynvml_handler(device) + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu + else: + return _get_amdsmi_utilization(device) def temperature(device: Optional[Union[Device, int]] = None) -> int: @@ -966,9 +1106,12 @@ def temperature(device: Optional[Union[Device, int]] = None) -> int: Warning: Each sample period may be between 1 second and 1/6 second, depending on the product being queried. """ - handle = _get_pynvml_handler(device) - # 0 refers to the temperature sensor for the GPU die. - return pynvml.nvmlDeviceGetTemperature(handle, 0) + if not torch.version.hip: + handle = _get_pynvml_handler(device) + # 0 refers to the temperature sensor for the GPU die. + return pynvml.nvmlDeviceGetTemperature(handle, 0) + else: + return _get_amdsmi_temperature(device) def power_draw(device: Optional[Union[Device, int]] = None) -> int: @@ -983,8 +1126,11 @@ def power_draw(device: Optional[Union[Device, int]] = None) -> int: Warning: Each sample period may be between 1 second and 1/6 second, depending on the product being queried. """ - handle = _get_pynvml_handler(device) - return pynvml.nvmlDeviceGetPowerUsage(handle) + if not torch.version.hip: + handle = _get_pynvml_handler(device) + return pynvml.nvmlDeviceGetPowerUsage(handle) + else: + return _get_amdsmi_power_draw(device) def clock_rate(device: Optional[Union[Device, int]] = None) -> int: @@ -998,8 +1144,11 @@ def clock_rate(device: Optional[Union[Device, int]] = None) -> int: Warning: Each sample period may be between 1 second and 1/6 second, depending on the product being queried. """ - handle = _get_pynvml_handler(device) - return pynvml.nvmlDeviceGetClockInfo(handle, 1) + if not torch.version.hip: + handle = _get_pynvml_handler(device) + return pynvml.nvmlDeviceGetClockInfo(handle, 1) + else: + return _get_amdsmi_clock_rate(device) def _get_device(device: Union[int, str, torch.device]) -> torch.device: diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 8453842ef14a2..8a5110b10c98b 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -15,7 +15,13 @@ from torch.types import Device from .._utils import _dummy_type -from . import _get_device_index, _get_nvml_device_index, _lazy_init, is_initialized +from . import ( + _get_amdsmi_device_index, + _get_device_index, + _get_nvml_device_index, + _lazy_init, + is_initialized, +) from ._memory_viz import memory as _memory, segments as _segments @@ -609,26 +615,48 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str: printout for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). """ - try: - import pynvml # type: ignore[import] - except ModuleNotFoundError: - return "pynvml module not found, please install pynvml" - from pynvml import NVMLError_DriverNotLoaded + if not torch.version.hip: + try: + import pynvml # type: ignore[import] + except ModuleNotFoundError: + return "pynvml module not found, please install pynvml" + from pynvml import NVMLError_DriverNotLoaded + + try: + pynvml.nvmlInit() + except NVMLError_DriverNotLoaded: + return "cuda driver can't be loaded, is cuda enabled?" 
+ + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) + else: + try: + import amdsmi # type: ignore[import] + except ModuleNotFoundError: + return "amdsmi module not found, please install amdsmi" + try: + amdsmi.amdsmi_init() # type: ignore[attr-defined] + except amdsmi.AmdSmiException: # type: ignore[attr-defined] + return "amdsmi driver can't be loaded, is ROCm installed?" + + device = _get_amdsmi_device_index(device) + handle = amdsmi.amdsmi_get_processor_handles()[device] # type: ignore[attr-defined] + procs = amdsmi.amdsmi_get_gpu_process_list(handle) # type: ignore[attr-defined] - try: - pynvml.nvmlInit() - except NVMLError_DriverNotLoaded: - return "cuda driver can't be loaded, is cuda enabled?" - device = _get_nvml_device_index(device) - handle = pynvml.nvmlDeviceGetHandleByIndex(device) - procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) lines = [] lines.append(f"GPU:{device}") if len(procs) == 0: lines.append("no processes are running") for p in procs: - mem = p.usedGpuMemory / (1024 * 1024) - lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory") + if not torch.version.hip: + mem = p.usedGpuMemory / (1024 * 1024) + pid = p.pid + else: + proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p) # type: ignore[possibly-undefined] + mem = proc_info["memory_usage"]["vram_mem"] / (1024 * 1024) + pid = proc_info["pid"] + lines.append(f"process {pid:>10d} uses {mem:>12.3f} MB GPU memory") return "\n".join(lines) diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 3e7dce97b54c9..eb7a690fa9589 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -54,8 +54,6 @@ def is_available() -> bool: set_debug_level, set_debug_level_from_env, _make_nccl_premul_sum, - _ControlCollectives, - _StoreCollectives, ) class _DistributedPdb(pdb.Pdb): diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index 8ea9923dd44d1..45352d3da1b90 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -8,7 +8,7 @@ pipeline, SplitPoint, ) -from ._PipelineStage import PipelineStage +from ._PipelineStage import ManualPipelineStage, PipelineStage from .PipelineSchedule import ( Schedule1F1B, ScheduleGPipe, @@ -24,6 +24,7 @@ "pipeline", "ArgsChunkSpec", "KwargsChunkSpec", + "ManualPipelineStage", "PipelineStage", "Schedule1F1B", "ScheduleGPipe", diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 9976c4e9beca2..cdd9995bbcd36 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -29,6 +29,7 @@ from torch.utils._traceback import CapturedTraceback import logging from torch._library.fake_class_registry import FakeScriptObject +import warnings from torch.overrides import TorchFunctionMode @@ -921,6 +922,10 @@ def disable_autocast_cache(): torch.set_autocast_cache_enabled(old_value) +class _ModuleNotInstalledAsSubmoduleError(NameError): + pass + + class _ModuleStackTracer(PythonKeyTracer): r"""Customized version of PythonKeyTracer that retains module stack information in node.meta["nn_module_stack"]. 
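# [reviewer note] list_gpu_processes() keeps returning a plain string on both
# backends, so existing callers need no change; the PID/MB values sketched in
# the comment below are illustrative.
import torch

if torch.cuda.is_available():
    print(torch.cuda.list_gpu_processes())
    # GPU:0
    # no processes are running
    # (or one "process <pid> uses <n> MB GPU memory" line per running process)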
@@ -998,7 +1003,10 @@ def path_of_module(self, mod: torch.nn.Module) -> str: if isinstance(mod, self.proxy_type): return self.proxy_paths[mod] - return Tracer.path_of_module(self, mod) + try: + return Tracer.path_of_module(self, mod) + except NameError as e: + raise _ModuleNotInstalledAsSubmoduleError from e def getattr(self, attr, attr_val, parameter_proxy_cache): if not isinstance(attr_val, torch.nn.Module) or isinstance(attr_val, torch.fx.GraphModule): @@ -1070,7 +1078,17 @@ def call_module(self, m, forward, args, kwargs): # use cases don't need to work with HOO. if isinstance(m, (OptimizedModule, GraphModule)): return forward(*args, **kwargs) - return Tracer.call_module(self, m, forward, args, kwargs) + + try: + return Tracer.call_module(self, m, forward, args, kwargs) + except _ModuleNotInstalledAsSubmoduleError as e: + warnings.warn( + f"Unable to find the path of the module {m}. " + "This might be because the module was not properly registered " + "as a submodule, which is not good practice. We will trace " + "through the module without recording stack information." + ) + return forward(*args, **kwargs) def is_leaf_module(self, m, module_qualified_name): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index be1be24137f88..ca6e5957e20e6 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -489,6 +489,18 @@ def get(self, o: Any) -> Any: return getattr(o, self.name)() +@dataclass(frozen=True) +class InnerTensorKey: + inner_name: str + + def __str__(self) -> str: + return f".{self.inner_name}" + + def get(self, o: Any) -> Any: + """Get the inner tensor attribute""" + return getattr(o, self.inner_name) + + @dataclass(frozen=True) class DivideByKey: divisor: int @@ -538,6 +550,14 @@ def free_unbacked_symbols_with_path( real=real[i] if real is not None else None ) ) + elif is_traceable_wrapper_subclass(a): + # TODO: Determine if this is correct + attrs, _ = a.__tensor_flatten__() + for attr in attrs: + sub = getattr(a, attr) + r.update( + free_unbacked_symbols_with_path(sub, path + (InnerTensorKey(attr),)) + ) elif isinstance(a, torch.Tensor): r.update( free_unbacked_symbols_with_path( @@ -4397,6 +4417,9 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No Use this instead of `self.replacements[a] = tgt`. """ + if tgt == self.replacements.get(a, None): + return + # Precondition: a == tgt assert isinstance(a, sympy.Symbol) @@ -4487,14 +4510,24 @@ def issubset(x, y): "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so) return - if config.print_specializations and isinstance(tgt, (sympy.Integer, sympy.Float)): - # specializing to a constant, which is likely unexpected + if isinstance(tgt, (sympy.Integer, sympy.Float)): + # specializing to a constant, which is likely unexpected (unless + # you specified dynamic=True) + + user_tb = TracingContext.extract_stack() + trace_structured( + "symbolic_shape_specialization", + metadata_fn=lambda: { + "symbol": repr(a), + "sources": [s.name() for s in self.var_to_sources.get(a, [])], + "value": repr(tgt), + "reason": msg, + "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), + "user_stack": structured.from_traceback(user_tb) if user_tb else None, + } + ) - # NOTE(avik): It is possible that we try logging the same specialization multiple times, e.g., - # when adding a to self.replacements, and again when simplifying an expression containing a. 
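# [reviewer note] A hedged sketch of what now emits the structured
# "symbolic_shape_specialization" artifact added above: a size that was being
# tracked symbolically gets pinned to a constant by a guard, so
# _set_replacement() maps the symbol to that constant and records the symbol,
# sources, value, reason and both stacks. The toy function is illustrative;
# inspecting the artifact (e.g. through TORCH_TRACE/tlparse) is out of scope
# here.
import torch

@torch.compile(dynamic=True)
def f(x):
    # Branching on the size forces the ShapeEnv to specialize s0 to 8.
    return x + 1 if x.shape[0] == 8 else x - 1

f(torch.randn(8, 4))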
- # Thus to avoid duplication, checking whether a is in self.replacements isn't enough; if it is, - # it must not already map to `tgt`. Fortunately this check is cheap because `tgt` is a constant. - if a not in self.replacements or tgt != self.replacements[a]: + if config.print_specializations: self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt) self.log.debug("SPECIALIZATION", stack_info=True) log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 0d45defe8a48c..843f5f37e1dab 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -57,6 +57,7 @@ def insert_deferred_runtime_asserts( ConvertIntKey, DivideByKey, free_symbols, + InnerTensorKey, ) from torch.utils._sympy.interp import sympy_interp from torch.utils._sympy.reference import PythonReferenceAnalysis @@ -225,6 +226,13 @@ def go(node, keypath): ), keypath[1:], ) + elif isinstance(keypath[0], InnerTensorKey): + return go( + graph.call_function( + getattr, (node, keypath[0].inner_name) + ), + keypath[1:], + ) else: raise AssertionError(f"unrecognized keypath {keypath}") diff --git a/torch/nn/attention/_flex_attention.py b/torch/nn/attention/_flex_attention.py index c56374fcbc40d..bd999ec39118d 100644 --- a/torch/nn/attention/_flex_attention.py +++ b/torch/nn/attention/_flex_attention.py @@ -28,11 +28,21 @@ def inner(score, b, h, m, n): ] +def _identity( + score: torch.Tensor, + batch: torch.Tensor, + head: torch.Tensor, + token_q: torch.Tensor, + token_kv: torch.Tensor, +) -> torch.Tensor: + return score + + def _flex_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - score_mod: _score_mod_signature, + score_mod: _score_mod_signature = _identity, ) -> torch.Tensor: r"""This function implements scaled dot product attention with an arbitrary attention score modification function. @@ -63,7 +73,7 @@ def score_mod( query (Tensor): Query tensor; shape :math:`(B, H, L, E)`. key (Tensor): Key tensor; shape :math:`(B, H, S, E)`. value (Tensor): Value tensor; shape :math:`(B, H, S, Ev)`. - score_mod (Callable): Function to modify attention scores + score_mod (Callable): Function to modify attention scores. By default no score_mod is applied. Returns: output (Tensor): Attention output; shape :math:`(B, H, L, Ev)`. 
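# [reviewer note] With _identity as the default, the prototype entry point can
# be called without a score_mod. A minimal sketch of the intended use; the
# shapes, dtype and the torch.compile wrapper are illustrative choices, and
# the module path is private/experimental at this point.
import torch
from torch.nn.attention._flex_attention import _flex_attention

if torch.cuda.is_available():
    q, k, v = (
        torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )
    flex = torch.compile(_flex_attention)
    out = flex(q, k, v)  # identity score_mod applied by default

    # Equivalent to passing a score_mod explicitly, e.g. a causal mask:
    def causal(score, b, h, q_idx, kv_idx):
        return torch.where(q_idx >= kv_idx, score, float("-inf"))

    out_causal = flex(q, k, v, score_mod=causal)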
@@ -114,16 +124,6 @@ def score_mod( """Some common used score_mod functions for flex_attention in PyTorch.""" -def _identity( - score: torch.Tensor, - batch: torch.Tensor, - head: torch.Tensor, - token_q: torch.Tensor, - token_kv: torch.Tensor, -) -> torch.Tensor: - return score - - def _causal( score: torch.Tensor, batch: torch.Tensor, diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 5abacf2df1d61..c81efb093cd8b 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -123,8 +123,6 @@ def __init__( supported_impls: Tuple[str] = ("foreach", "differentiable"), # the optim supports passing in sparse gradients as well as dense grads supports_sparse: bool = False, - # the optimizer constructor supports passing in capturable as a kwarg - has_capturable_arg: bool = False, # the optim only supports one config: sparse grads w/ dense params, see SparseAdam only_supports_sparse_grads: bool = False, # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests, @@ -149,7 +147,6 @@ def __init__( self.scheduler_inputs = scheduler_inputs self.supported_impls = supported_impls self.supports_sparse = supports_sparse - self.has_capturable_arg = has_capturable_arg self.metadata_for_sparse = metadata_for_sparse self.only_supports_sparse_grads = only_supports_sparse_grads self.supports_complex = supports_complex @@ -314,11 +311,10 @@ def optim_inputs_func_adadelta(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), - OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize, weight_decay", + desc="maximize", ), OptimizerInput( params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho" @@ -530,15 +526,10 @@ def optim_inputs_func_adamax(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), - OptimizerInput( - params=None, - kwargs={"maximize": True}, - desc="maximize", - ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize, weight_decay", + desc="maximize", ), ] + (cuda_supported_configs if "cuda" in str(device) else []) @@ -690,22 +681,16 @@ def optim_inputs_func_nadam(device, dtype=None): kwargs={"momentum_decay": 6e-3}, desc="non-zero momentum_decay", ), - OptimizerInput( - params=None, - kwargs={ - "weight_decay": 0.1, - }, - desc="weight_decay", - ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3}, - desc="weight_decay, momentum_decay", + desc="weight_decay", ), OptimizerInput( params=None, kwargs={ "weight_decay": 0.1, + "momentum_decay": 6e-3, "decoupled_weight_decay": True, }, desc="decoupled_weight_decay", @@ -833,26 +818,11 @@ def optim_inputs_func_rmsprop(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), - OptimizerInput( - params=None, - kwargs={ - "maximize": True, - }, - desc="maximize", - ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "centered": True}, desc="centered", ), - OptimizerInput( - params=None, - kwargs={ - "maximize": True, - "weight_decay": 0.1, - }, - desc="maximize, weight_decay", - ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1}, @@ -866,7 +836,7 @@ def optim_inputs_func_rmsprop(device, dtype=None): "momentum": 0.1, 
"maximize": True, }, - desc="maximize, centered, weight_decay, w/ momentum", + desc="maximize", ), ] + (cuda_supported_configs if "cuda" in str(device) else []) @@ -937,15 +907,7 @@ def optim_inputs_func_sgd(device, dtype=None): OptimizerInput( params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr" ), - OptimizerInput( - params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay" - ), OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"), - OptimizerInput( - params=None, - kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize", - ), OptimizerInput( params=None, kwargs={"momentum": 0.9, "dampening": 0.5}, @@ -954,13 +916,18 @@ def optim_inputs_func_sgd(device, dtype=None): OptimizerInput( params=None, kwargs={"momentum": 0.9, "weight_decay": 0.1}, - desc="weight_decay w/ momentum", + desc="non-zero weight_decay", ), OptimizerInput( params=None, kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1}, desc="nesterov", ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), ] @@ -1130,7 +1097,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adadelta, optim_error_inputs_func=optim_error_inputs_func_adadelta, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1266,7 +1232,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_error_inputs_func=optim_error_inputs_func_adam, supported_impls=("foreach", "differentiable", "fused"), supports_fused_on=("cpu", "cuda"), - has_capturable_arg=True, decorators=( # Expected floating point error between fused and compiled forloop DecorateInfo( @@ -1333,7 +1298,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adamax, optim_error_inputs_func=optim_error_inputs_func_adamax, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1384,7 +1348,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_error_inputs_func=optim_error_inputs_func_adamw, supported_impls=("foreach", "differentiable", "fused"), supports_fused_on=("cpu", "cuda"), - has_capturable_arg=True, decorators=( # Expected error between compiled forloop and fused optimizers DecorateInfo( @@ -1451,7 +1414,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_asgd, optim_error_inputs_func=optim_error_inputs_func_asgd, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1544,7 +1506,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_nadam, optim_error_inputs_func=optim_error_inputs_func_nadam, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1600,7 +1561,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_radam, optim_error_inputs_func=optim_error_inputs_func_radam, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1646,7 +1606,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( 
optim_inputs_func=optim_inputs_func_rmsprop, optim_error_inputs_func=optim_error_inputs_func_rmsprop, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1696,7 +1655,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_rprop, optim_error_inputs_func=optim_error_inputs_func_rprop, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # Rprop doesn't update for non-contiguous, see #118117 diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 0f6209a01c3f1..e8db1e394b96f 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -12,8 +12,6 @@ LazyVal, IS_FBCODE, ) -from torch._dynamo.backends.registry import register_backend -from torch._inductor.compile_fx import compile_fx, count_bytes_inner from torch.testing._internal.common_utils import TestCase def test_cpu(): @@ -48,10 +46,6 @@ def test_cpu(): GPU_TYPE = "cuda" if len(tmp_gpus) == 0 else tmp_gpus.pop() del tmp_gpus -@register_backend -def count_bytes_inductor(gm, example_inputs): - return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner) - def _check_has_dynamic_shape( self: TestCase, code, diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 0f100f1858419..4c4c967ef9a9a 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -19,8 +19,8 @@ class Sampler(Generic[T_co]): r"""Base class for all Samplers. Every Sampler subclass has to provide an :meth:`__iter__` method, providing a - way to iterate over indices or lists of indices (batches) of dataset elements, and a :meth:`__len__` method - that returns the length of the returned iterators. + way to iterate over indices or lists of indices (batches) of dataset elements, + and may provide a :meth:`__len__` method that returns the length of the returned iterators. Args: data_source (Dataset): This argument is not used and will be removed in 2.2.0. diff --git a/torch/utils/hipify/constants.py b/torch/utils/hipify/constants.py index fb56e7a77a3ed..a9053b261ad44 100644 --- a/torch/utils/hipify/constants.py +++ b/torch/utils/hipify/constants.py @@ -2,7 +2,7 @@ The constants defined here are used to annotate the mapping tuples in cuda_to_hip_mappings.py. They are based on -https://github.com/ROCm-Developer-Tools/HIP/blob/master/hipify-clang/src/Statistics.h +https://github.com/ROCm/HIPIFY/blob/master/src/Statistics.h and fall in three categories: 1) type of mapping, 2) API of mapping, 3) unsupported mapping. """ diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index 3c84e1bff4c9d..976e12e42d336 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -4163,6 +4163,7 @@ ("cudaGraphLaunch", ("hipGraphLaunch", CONV_TYPE, API_RUNTIME)), ("cudaGraphGetNodes", ("hipGraphGetNodes", CONV_TYPE, API_RUNTIME)), ("cudaGraphDebugDotPrint", ("hipGraphDebugDotPrint", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsVerbose", ("hipGraphDebugDotFlagsVerbose", CONV_NUMERIC_LITERAL, API_RUNTIME)), ("cudaGraphRetainUserObject", ("hipGraphRetainUserObject", CONV_TYPE, API_RUNTIME)), ("cudaGraphUserObjectMove", ("hipGraphUserObjectMove", CONV_TYPE, API_RUNTIME)), ("cudaUserObject_t", ("hipUserObject_t", CONV_TYPE, API_RUNTIME)),
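# [reviewer note] Matching the relaxed Sampler docstring above: __iter__ is
# required, __len__ is optional. A minimal sketch of a sampler that
# legitimately cannot report a length; the RandomCycler name is made up for
# illustration.
import random
from torch.utils.data import Sampler

class RandomCycler(Sampler[int]):
    """Yields random indices forever and deliberately defines no __len__."""

    def __init__(self, num_items: int) -> None:
        self.num_items = num_items

    def __iter__(self):
        while True:
            yield random.randrange(self.num_items)

# e.g. DataLoader(dataset, sampler=RandomCycler(len(dataset))), iterating with
# an explicit step budget instead of len(loader).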