diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 73e3f09394b72..8786471a7bddf 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -149,6 +149,21 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; + pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks) + CUDA_VERSION=12.1.1 + CUDNN_VERSION=8 + ANACONDA_PYTHON_VERSION=3.12 + GCC_VERSION=9 + PROTOBUF=yes + DB=yes + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + CONDA_CMAKE=yes + TRITON=yes + INDUCTOR_BENCHMARKS=yes + ;; pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9) CUDA_VERSION=11.8.0 CUDNN_VERSION=8 diff --git a/.ci/docker/common/install_acl.sh b/.ci/docker/common/install_acl.sh index f5e5ce92af4af..8a6dc4d1c79c6 100644 --- a/.ci/docker/common/install_acl.sh +++ b/.ci/docker/common/install_acl.sh @@ -1,6 +1,6 @@ set -euo pipefail -readonly version=v23.08 +readonly version=v24.04 readonly src_host=https://review.mlplatform.org/ml readonly src_repo=ComputeLibrary diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 9f0dfe973dc9f..bb356dce5da9b 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -42,6 +42,7 @@ jobs: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks, pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9, pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks, pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9, pytorch-linux-focal-py3.8-clang10, pytorch-linux-focal-py3.11-clang10, diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index c00630b5e8b66..3d1c3a539686c 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -107,6 +107,27 @@ jobs: secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + linux-focal-cuda12_1-py3_12-gcc9-inductor-build: + name: cuda12.1-py3.12-gcc9-sm86 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks + cuda-arch-list: '8.6' + test-matrix: | + { include: [ + { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + ]} + + linux-focal-cuda12_1-py3_12-gcc9-inductor-test: + name: cuda12.1-py3.12-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_1-py3_12-gcc9-inductor-build + with: + build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86 + docker-image: ${{ needs.linux-focal-cuda12_1-py3_12-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_1-py3_12-gcc9-inductor-build.outputs.test-matrix }} + linux-jammy-cpu-py3_8-gcc11-inductor-build: name: linux-jammy-cpu-py3.8-gcc11-inductor uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 6c2bd277166c8..00813edd3d913 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -194,6 +194,7 @@ jobs: { include: [ { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" }, { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" }, + { config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu" }, ]} linux-focal-rocm6_1-py3_8-test: @@ -209,4 +210,4 @@ jobs: build-environment: linux-focal-rocm6.1-py3.8 docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }} 
test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }} - tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor" + tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl" \ No newline at end of file diff --git a/.github/workflows/upload_test_stats_intermediate.yml b/.github/workflows/upload_test_stats_intermediate.yml new file mode 100644 index 0000000000000..14b65f6a75ef3 --- /dev/null +++ b/.github/workflows/upload_test_stats_intermediate.yml @@ -0,0 +1,43 @@ +name: Upload test stats intermediate + +on: + workflow_dispatch: + inputs: + workflow_id: + description: workflow_id of the run + required: true + workflow_run_attempt: + description: workflow_run_attempt of the run + required: true + +jobs: + intermediate_upload_test_stats: + name: Intermediate upload test stats for ${{ inputs.workflow_id }} + runs-on: ubuntu-22.04 + environment: upload-stats + steps: + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + with: + fetch-depth: 1 + submodules: false + + - uses: actions/setup-python@v4 + with: + python-version: '3.11' + cache: pip + + - run: | + pip3 install requests==2.26 rockset==1.0.3 boto3==1.19.12 + + - name: Upload test stats + env: + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + WORKFLOW_RUN_ID: ${{ inputs.workflow_id }} + WORKFLOW_RUN_ATTEMPT: ${{ inputs.workflow_run_attempt }} + run: | + python3 -m tools.stats.upload_test_stats_intermediate \ + --workflow-run-id "${WORKFLOW_RUN_ID}" \ + --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" \ diff --git a/.gitmodules b/.gitmodules index c9b84a3701674..bd62cb8280ea9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,10 +2,6 @@ ignore = dirty path = third_party/pybind11 url = https://github.com/pybind/pybind11.git -[submodule "third_party/cub"] - ignore = dirty - path = third_party/cub - url = https://github.com/NVlabs/cub.git [submodule "third_party/eigen"] ignore = dirty path = third_party/eigen @@ -50,10 +46,6 @@ ignore = dirty path = third_party/psimd url = https://github.com/Maratyszcza/psimd.git -[submodule "third_party/zstd"] - ignore = dirty - path = third_party/zstd - url = https://github.com/facebook/zstd.git [submodule "third_party/cpuinfo"] ignore = dirty path = third_party/cpuinfo @@ -66,10 +58,6 @@ ignore = dirty path = third_party/onnx url = https://github.com/onnx/onnx.git -[submodule "third_party/onnx-tensorrt"] - ignore = dirty - path = third_party/onnx-tensorrt - url = https://github.com/onnx/onnx-tensorrt [submodule "third_party/sleef"] ignore = dirty path = third_party/sleef @@ -152,3 +140,7 @@ [submodule "third_party/opentelemetry-cpp"] path = third_party/opentelemetry-cpp url = https://github.com/open-telemetry/opentelemetry-cpp.git +[submodule "third_party/cpp-httplib"] + path = third_party/cpp-httplib + url = https://github.com/yhirose/cpp-httplib.git + branch = v0.15.3 diff --git a/.lintrunner.toml b/.lintrunner.toml index 8dfc1554041a1..50eb09984fec7 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1929,8 +1929,6 @@ exclude_patterns = [ 'torch/utils/_mode_utils.py', 'torch/utils/_python_dispatch.py', 'torch/utils/_stats.py', - 'torch/utils/_sympy/__init__.py', - 'torch/utils/_sympy/functions.py', 
'torch/utils/_traceback.py', 'torch/utils/_zip.py', 'torch/utils/backcompat/__init__.py', diff --git a/BUILD.bazel b/BUILD.bazel index 3f7e6327452c0..831d64b44c2f6 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -772,7 +772,7 @@ cc_library( [ "torch/*.h", "torch/csrc/**/*.h", - "torch/csrc/distributed/c10d/*.hpp", + "torch/csrc/distributed/c10d/**/*.hpp", "torch/lib/libshm/*.h", ], exclude = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index f7561d606cbdb..3c6320e68d390 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -265,7 +265,6 @@ option(USE_PYTORCH_QNNPACK "Use ATen/QNNPACK (quantized 8-bit operators)" ON) option(USE_SNPE "Use Qualcomm's SNPE library" OFF) option(USE_SYSTEM_EIGEN_INSTALL "Use system Eigen instead of the one under third_party" OFF) -option(USE_TENSORRT "Using Nvidia TensorRT library" OFF) cmake_dependent_option( USE_VALGRIND "Use Valgrind. Only available on Linux." ON "LINUX" OFF) @@ -279,11 +278,13 @@ endif() option(USE_SLEEF_FOR_ARM_VEC256 "Use sleef for arm" OFF) option(USE_SOURCE_DEBUG_ON_MOBILE "Enable" ON) option(USE_LITE_INTERPRETER_PROFILER "Enable" ON) +cmake_dependent_option( + USE_LITE_AOTI "Include AOTI sources" OFF + "BUILD_LITE_INTERPRETER" OFF) option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF) option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF) # option USE_XNNPACK: try to enable xnnpack by default. option(USE_XNNPACK "Use XNNPACK" ON) -option(USE_ZSTD "Use ZSTD" OFF) option(USE_ROCM_KERNEL_ASSERT "Use Kernel Assert for ROCm" OFF) # Ensure that an ITT build is the default for x86 CPUs cmake_dependent_option( @@ -413,7 +414,6 @@ option(USE_SYSTEM_FXDIV "Use system-provided fxdiv." OFF) option(USE_SYSTEM_BENCHMARK "Use system-provided google benchmark." OFF) option(USE_SYSTEM_ONNX "Use system-provided onnx." OFF) option(USE_SYSTEM_XNNPACK "Use system-provided xnnpack." OFF) -option(USE_SYSTEM_ZSTD "Use system-provided zstd." 
OFF) option(USE_GOLD_LINKER "Use ld.gold to link" OFF) if(USE_SYSTEM_LIBS) set(USE_SYSTEM_CPUINFO ON) @@ -435,9 +435,6 @@ if(USE_SYSTEM_LIBS) if(USE_TBB) set(USE_SYSTEM_TBB ON) endif() - if(USE_ZSTD) - set(USE_SYSTEM_ZSTD ON) - endif() endif() # Used when building Caffe2 through setup.py diff --git a/WORKSPACE b/WORKSPACE index 8eabea571a571..f7e6043322131 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -355,9 +355,4 @@ local_repository( path = "third_party/onnx/third_party/benchmark", ) -local_repository( - name = "unused_onnx_tensorrt_benchmark", - path = "third_party/onnx-tensorrt/third_party/onnx/third_party/benchmark", -) - ### Unused repos end diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 6c505f8b656cf..3086fa18add6f 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -887,12 +887,12 @@ c10::intrusive_ptr ivalue::Object::create( } IValue IValue::deepcopy(std::optional device) const { - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; return deepcopy(memo, device); } IValue IValue::deepcopy( - IValue::HashAliasedIValueMap& memo, + IValue::HashIdentityIValueMap& memo, std::optional device) const { if (memo.count(*this)) { return memo.at(*this); @@ -1028,12 +1028,12 @@ c10::intrusive_ptr ivalue::Object::copy_to_weak_compilation_ref( c10::intrusive_ptr ivalue::Object::deepcopy( std::optional device) const { - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; return deepcopy(memo, device); } c10::intrusive_ptr ivalue::Object::deepcopy( - IValue::HashAliasedIValueMap& memo, + IValue::HashIdentityIValueMap& memo, std::optional device) const { auto cu = type_.cu_; auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes()); diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 7715ffbe3c31d..922b10b8efeb5 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -1117,6 +1117,23 @@ struct TORCH_API IValue final { using HashAliasedIValueMap = std::unordered_map; + struct HashIdentityIValue { + size_t operator()(const IValue& val) const { + return val.payload.u.as_int; + } + }; + + struct CompIdentityIValues { + bool operator()(const IValue& lhs, const IValue& rhs) const { + return lhs.is(rhs); + } + }; + + using HashIdentityIValues = + std::unordered_set; + using HashIdentityIValueMap = + std::unordered_map; + + // Checks if this and rhs have subvalues in common. // [t1,t2] and [t2, t3] returns true.
bool overlaps(const IValue& rhs) const; @@ -1130,7 +1147,7 @@ struct TORCH_API IValue final { void visit(const std::function& visitor) const; IValue deepcopy(std::optional device = c10::nullopt) const; IValue deepcopy( - HashAliasedIValueMap& memo, + HashIdentityIValueMap& memo, std::optional device = c10::nullopt) const; private: diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index b1124c12cfb34..b99229f2759c4 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -1589,7 +1589,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target { std::optional device = c10::nullopt) const; c10::intrusive_ptr deepcopy( - IValue::HashAliasedIValueMap& memo, + IValue::HashIdentityIValueMap& memo, std::optional device = c10::nullopt) const; bool is_weak_compilation_ref() const { diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 2502456e285b9..ce991a9bcad4e 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1422,10 +1422,13 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); +#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60200) + // Amax support in ROCm as of 6.2 + if (isFloat8Type(result_dtype)) { + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr); + } +#endif #ifndef USE_ROCM -if (isFloat8Type(result_dtype)) { - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr); -} computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode); #endif CuBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't'); diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index 48a077814880b..af34ae5c582ae 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -215,6 +215,87 @@ static inline float16_t reduce(float16x8_t x) { return reduce(vadd_f16(vget_low_f16(x), vget_high_f16(x))); } +/* + * The below reduce overload and + * fp16_gemv_trans_fp16_arith_by_dot_products function is adapted from + * llama.cpp's ggml_vec_dot_f16 and surrounding utility functions, so + * here is the required copyright notice: + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#define F16_ELEMENTS_PER_ITERATION 32 +#define F16_ELEMENTS_PER_REGISTER 8 +#define F16_REGISTERS_PER_ITERATION (F16_ELEMENTS_PER_ITERATION / F16_ELEMENTS_PER_REGISTER) +static inline double reduce(float16x8_t x[F16_REGISTERS_PER_ITERATION]) { + int offset = F16_REGISTERS_PER_ITERATION / 2; + for (int i = 0; i < offset; ++i) { + x[i] = vaddq_f16(x[i], x[offset + i]); + } + offset /= 2; + for (int i = 0; i < offset; ++i) { + x[i] = vaddq_f16(x[i], x[offset + i]); + } + offset /= 2; + for (int i = 0; i < offset; ++i) { + x[i] = vaddq_f16(x[i], x[offset + i]); + } + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16(x[0])); + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); + return (double)vaddvq_f32(vaddq_f32(t0, t1)); + +} + +static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) { +#ifdef __ARM_FEATURE_FMA + return vfmaq_f16(a, b, c); +#else + return vaddq_f16(a, vmulq_f16(b, c)); +#endif +} + +// Rather than unrolling to process multiple rows (transposed columns) +// of matrix A at once as done in fp16_gemv_trans_fp16_arith, unroll +// along an individual dot product. +static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) { + parallel_for(0, n, 1, [&](int begin, int end) { + for (int i = begin; i < end; ++i) { + float16x8_t sum[F16_REGISTERS_PER_ITERATION] = {vdupq_n_f16(0)}; + float16x8_t ax[F16_REGISTERS_PER_ITERATION]; + float16x8_t ay[F16_REGISTERS_PER_ITERATION]; + + for (int j = 0; j < m; j += F16_ELEMENTS_PER_ITERATION) { + for (int k = 0; k < F16_REGISTERS_PER_ITERATION; ++k) { + ax[k] = vld1q_f16(x + j + k * F16_ELEMENTS_PER_REGISTER); + ay[k] = vld1q_f16(a + lda * i + j + k * F16_ELEMENTS_PER_REGISTER); + sum[k] = f16_fma(sum[k], ax[k], ay[k]); + } + } + // TODO: add a tail fixup so we don't have to have such a + // restrictive gate to enter this path. 
+ y[i * incy] = reduce(sum); + } + }); +} static void fp16_gemv_trans_fp16_arith(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) { parallel_for(0, n / 4, 1, [&](int begin, int end) { @@ -230,13 +311,13 @@ static void fp16_gemv_trans_fp16_arith(const int m, const int n, const float16_t for (auto j = 0; j < m; j += 8) { float16x8_t xVec = vld1q_f16(x + j); float16x8_t a0Vec = vld1q_f16(row0 + j); - sum0Vec = vaddq_f16(sum0Vec, vmulq_f16(a0Vec, xVec)); + sum0Vec = f16_fma(sum0Vec, a0Vec, xVec); float16x8_t a1Vec = vld1q_f16(row1 + j); - sum1Vec = vaddq_f16(sum1Vec, vmulq_f16(a1Vec, xVec)); + sum1Vec = f16_fma(sum1Vec, a1Vec, xVec); float16x8_t a2Vec = vld1q_f16(row2 + j); - sum2Vec = vaddq_f16(sum2Vec, vmulq_f16(a2Vec, xVec)); + sum2Vec = f16_fma(sum2Vec, a2Vec, xVec); float16x8_t a3Vec = vld1q_f16(row3 + j); - sum3Vec = vaddq_f16(sum3Vec, vmulq_f16(a3Vec, xVec)); + sum3Vec = f16_fma(sum3Vec, a3Vec, xVec); } y[(i + 0) * incy] = reduce(sum0Vec); y[(i + 1) * incy] = reduce(sum1Vec); @@ -245,6 +326,7 @@ static void fp16_gemv_trans_fp16_arith(const int m, const int n, const float16_t } }); } + #endif static inline float reduce(float32x4_t x) { @@ -252,6 +334,14 @@ static inline float reduce(float32x4_t x) { return vgetq_lane_f32(vpaddq_f32(sum, sum), 0); } +static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) { +#ifdef __ARM_FEATURE_FMA + return vfmaq_f32(a, b, c); +#else + return vaddq_f32(a, vmulq_f32(b, c)); +#endif +} + static void fp16_gemv_trans_fp32_arith(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) { parallel_for(0, n / 4, 1, [&](int begin, int end) { for (auto i = begin * 4 ; i < end * 4; i += 4) { @@ -266,13 +356,13 @@ static void fp16_gemv_trans_fp32_arith(const int m, const int n, const float16_t for (auto j = 0; j < m; j += 4) { float32x4_t xVec = vcvt_f32_f16(vld1_f16(x + j)); float32x4_t a0Vec = vcvt_f32_f16(vld1_f16(row0 + j)); - sum0Vec = vaddq_f32(sum0Vec, vmulq_f32(a0Vec, xVec)); + sum0Vec = f32_fma(sum0Vec, a0Vec, xVec); float32x4_t a1Vec = vcvt_f32_f16(vld1_f16(row1 + j)); - sum1Vec = vaddq_f32(sum1Vec, vmulq_f32(a1Vec, xVec)); + sum1Vec = f32_fma(sum1Vec, a1Vec, xVec); float32x4_t a2Vec = vcvt_f32_f16(vld1_f16(row2 + j)); - sum2Vec = vaddq_f32(sum2Vec, vmulq_f32(a2Vec, xVec)); + sum2Vec = f32_fma(sum2Vec, a2Vec, xVec); float32x4_t a3Vec = vcvt_f32_f16(vld1_f16(row3 + j)); - sum3Vec = vaddq_f32(sum3Vec, vmulq_f32(a3Vec, xVec)); + sum3Vec = f32_fma(sum3Vec, a3Vec, xVec); } y[(i + 0) * incy] = reduce(sum0Vec); y[(i + 1) * incy] = reduce(sum1Vec); @@ -295,11 +385,16 @@ void fp16_gemv_trans( const int incy) { if (incx == 1 && alpha == 1.0 && beta == 0.0 && m % 4 == 0 && n % 4 == 0) { #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC - return at::globalContext().allowFP16ReductionCPU() && m % 8 == 0 ? 
fp16_gemv_trans_fp16_arith(m, n, a, lda, x, y, incy) - : fp16_gemv_trans_fp32_arith(m, n, a, lda, x, y, incy); -#else - return fp16_gemv_trans_fp32_arith(m, n, a, lda, x, y, incy); + if (at::globalContext().allowFP16ReductionCPU()) { + if (m % 32 == 0 && n % 32 == 0) { + return fp16_gemv_trans_fp16_arith_by_dot_products(m, n, a, lda, x, y, incy); + } + if (m % 8 == 0) { + return fp16_gemv_trans_fp16_arith(m, n, a, lda, x, y, incy); + } + } #endif + return fp16_gemv_trans_fp32_arith(m, n, a, lda, x, y, incy); } for (const auto i : c10::irange(n)) { float sum = 0; diff --git a/aten/src/ATen/native/ConvolutionMM2d.cpp b/aten/src/ATen/native/ConvolutionMM2d.cpp index 686948584c728..10ab4a70f0914 100644 --- a/aten/src/ATen/native/ConvolutionMM2d.cpp +++ b/aten/src/ATen/native/ConvolutionMM2d.cpp @@ -543,6 +543,11 @@ Tensor& slow_conv2d_forward_out_cpu( IntArrayRef padding, Tensor& output) { // See [Note: hacky wrapper removal for optional tensor] + + TORCH_CHECK(kernel_size.size() == 2, "2D kernel_size expected"); + TORCH_CHECK(stride.size() == 2, "2D stride expected"); + TORCH_CHECK(padding.size() == 2, "2D padding expected"); + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); const Tensor& bias = *bias_maybe_owned; diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp index 7091e4f78aef9..4afc7619c2ebd 100644 --- a/aten/src/ATen/native/TypeProperties.cpp +++ b/aten/src/ATen/native/TypeProperties.cpp @@ -191,8 +191,8 @@ ScalarType result_type(const Scalar& scalar1, const Scalar& scalar2) { return result_type(state); } -bool can_cast(const at::ScalarType from, const at::ScalarType to) { - return at::canCast(from, to); +bool can_cast(const at::ScalarType from_, const at::ScalarType to) { + return at::canCast(from_, to); } ScalarType promote_types(ScalarType type1, ScalarType type2) { diff --git a/aten/src/ATen/native/cpu/int4mm_kernel.cpp b/aten/src/ATen/native/cpu/int4mm_kernel.cpp index acb4b927f23f5..2ffef25a10ff4 100644 --- a/aten/src/ATen/native/cpu/int4mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int4mm_kernel.cpp @@ -341,12 +341,46 @@ inline void tinygemm_kernel( #if !defined(C10_MOBILE) && defined(__aarch64__) #include -template -inline void tinygemm_kernel( - const Half* RESTRICT A, + +inline float32x4x2_t load_as_float32x4x2(const Half* ptr) { + float16x4x2_t f16_val = vld2_f16(reinterpret_cast(ptr)); + auto val_low = vcvt_f32_f16(f16_val.val[0]); + auto val_high = vcvt_f32_f16(f16_val.val[1]); + return {val_low, val_high}; +} + +inline void store_float32x4(Half* ptr, float32x4_t val) { + vst1_f16(reinterpret_cast(ptr), vcvt_f16_f32(val)); +} + +inline float32x4x2_t load_as_float32x4x2(const BFloat16* ptr) { + int32x4_t shift = vdupq_n_s32(16); + uint16x4x2_t u16_val = vld2_u16(reinterpret_cast(ptr)); + uint32x4_t int_low = vmovl_u16(u16_val.val[0]); + uint32x4_t int_high = vmovl_u16(u16_val.val[1]); + return {vreinterpretq_f32_u32(vshlq_u32(int_low, shift)), vreinterpretq_f32_u32(vshlq_u32(int_high, shift))}; +} + +inline void store_float32x4(BFloat16* ptr, float32x4_t val) { + int32x4_t shift = vdupq_n_s32(-16); + uint32x4_t uint32_val = vshlq_u32(vreinterpretq_u32_f32(val), shift); + vst1_u16(reinterpret_cast(ptr), vmovn_u32(uint32_val)); +} + +inline float32x4x2_t load_as_float32x4x2(const float* ptr) { + return vld2q_f32(ptr); +} + +inline void store_float32x4(float* ptr, float32x4_t val) { + vst1q_f32(ptr, val); +} + +template +inline void tinygemm_kernel_( + const T* RESTRICT A, const uint8_t* RESTRICT B, - 
const Half* RESTRICT ScaleAndZeros, - Half* RESTRICT C, + const T* RESTRICT ScaleAndZeros, + T* RESTRICT C, int lda, int ldb, int ldc, @@ -368,9 +402,9 @@ inline void tinygemm_kernel( if (is_block_start(k, BLOCK_K)) { int kb = k / BLOCK_K; c10::ForcedUnroll<4>{}([&](auto i) { - auto scales_and_zeros = vld2_f16(reinterpret_cast(ScaleAndZeros + kb * ldc * 2 + n * 2 + i * 8)); - scales[i] = vcvt_f32_f16(scales_and_zeros.val[0]); - zeros[i] = vcvt_f32_f16(scales_and_zeros.val[1]); + auto scales_and_zeros = load_as_float32x4x2(ScaleAndZeros + kb * ldc * 2 + n * 2 + i * 8); + scales[i] = scales_and_zeros.val[0]; + zeros[i] = scales_and_zeros.val[1]; }); } c10::ForcedUnroll<4>{}([&](auto i) { @@ -383,11 +417,53 @@ inline void tinygemm_kernel( }); } c10::ForcedUnroll<4>{}([&](auto i) { - vst1_f16(reinterpret_cast(C + m * ldc + n + i * 4), vcvt_f16_f32(c_val[i])); + store_float32x4(C + m * ldc + n + i * 4, c_val[i]); }); } } } + +template +inline void tinygemm_kernel( + const Half* RESTRICT A, + const uint8_t* RESTRICT B, + const Half* RESTRICT ScaleAndZeros, + Half* RESTRICT C, + int lda, + int ldb, + int ldc, + int K, + int BLOCK_K) { + tinygemm_kernel_(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K); +} + +template +inline void tinygemm_kernel( + const BFloat16* RESTRICT A, + const uint8_t* RESTRICT B, + const BFloat16* RESTRICT ScaleAndZeros, + BFloat16* RESTRICT C, + int lda, + int ldb, + int ldc, + int K, + int BLOCK_K) { + tinygemm_kernel_(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K); +} + +template +inline void tinygemm_kernel( + const float* RESTRICT A, + const uint8_t* RESTRICT B, + const float* RESTRICT ScaleAndZeros, + float* RESTRICT C, + int lda, + int ldb, + int ldc, + int K, + int BLOCK_K) { + tinygemm_kernel_(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K); +} #endif template diff --git a/aten/src/ATen/native/cpu/int8mm_kernel.cpp b/aten/src/ATen/native/cpu/int8mm_kernel.cpp index bd266030b2566..d61a1933afc73 100644 --- a/aten/src/ATen/native/cpu/int8mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int8mm_kernel.cpp @@ -250,10 +250,18 @@ inline void tinygemm_kernel_( }); } +#if __OPTIMIZE__ float32x4_t scale_val = load_as_float32x4(scales); c10::ForcedUnroll{}([&](auto i) { C[m * ldc + i] = reduce(c_val[i]) * vgetq_lane_f32(scale_val, i); }); +#else + // Workaround GCCs inability to infer lane index at compile time + // See https://github.com/pytorch/pytorch/issues/126283 + c10::ForcedUnroll{}([&](auto i) { + C[m * ldc + i] = reduce(c_val[i]) * float(scales[i]); + }); +#endif } } diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index c0ed650cf0219..84c59a4fd0d71 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #endif @@ -988,6 +989,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, else #endif { +#if defined(USE_ROCM) && ROCM_VERSION >= 60200 + // hipBlasLT requires scaleD to be set to something in order to use AMAX + auto dummy_options = TensorOptions().dtype(kFloat).device(kCUDA); + auto dummy_scale = at::ones(1, dummy_options); +#endif at::cuda::blas::scaled_gemm( args.transa, args.transb, @@ -1005,15 +1011,19 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, bias ? bias->data_ptr(): nullptr, bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, args.result->data_ptr(), +#if defined(USE_ROCM) && ROCM_VERSION >= 60200 + scale_result ? 
scale_result->data_ptr() : dummy_scale.data_ptr(), +#else scale_result ? scale_result->data_ptr() : nullptr, +#endif args.result_ld, out_dtype_, amax.data_ptr(), use_fast_accum); } -#if defined(USE_ROCM) - // rocm's hipblaslt does not yet support amax, so calculate separately +#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && ROCM_VERSION < 60200 + // ROCm's hipBLASLt does not support amax before 6.2, so calculate separately amax = at::max(at::abs(out.to(kFloat))); #endif diff --git a/aten/src/ATen/native/cuda/FusedSgdKernel.cu b/aten/src/ATen/native/cuda/FusedSgdKernel.cu index 61da02ce0b888..e644b1048c9b6 100644 --- a/aten/src/ATen/native/cuda/FusedSgdKernel.cu +++ b/aten/src/ATen/native/cuda/FusedSgdKernel.cu @@ -86,12 +86,8 @@ struct FusedSgdMathFunctor { init_args(args, tl, chunk_idx, chunk_size, tensor_loc)}; const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size; -#ifndef USE_ROCM const auto use_faster_load_store = (n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned; -#else - const auto use_faster_load_store{false}; -#endif if (use_faster_load_store) { for (auto i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8cf229c69c238..10d8b1ad79cad 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7714,7 +7714,7 @@ - func: result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType -- func: can_cast(ScalarType from, ScalarType to) -> bool +- func: can_cast(ScalarType from_, ScalarType to) -> bool variants: function - func: promote_types(ScalarType type1, ScalarType type2) -> ScalarType diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index 687691a370bf4..e41d3d3d6abef 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -788,7 +788,7 @@ TEST_F(VulkanAPITest, avg_pool2d) { ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, batch_norm_invalid_inputs) { +TEST_F(VulkanAPITest, DISABLED_batch_norm_invalid_inputs) { c10::InferenceMode mode; // Act: Vulkan batchnorm only supports evaluation mode diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index 5b5646e854875..20fb340690ac9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,0 +hf_BigBird,pass,46 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index 0dd9ce3482f4a..5131c2e9ade4b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,6 +hf_BigBird,pass,52 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv index 3e0af614a38c4..40382a4f277ce 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv @@ -138,7 +138,7 @@ hf_Bert_large,pass,0
-hf_BigBird,fail_accuracy,0 +hf_BigBird,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 07bbe765f1616..431a91d106696 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_to_run,0 +hf_BigBird,pass,46 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index 80035c453fbf0..1e1a4be4149e8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,fail_to_run,3 +hf_BigBird,pass,52 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 07bbe765f1616..f652e5ffa91a6 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_to_run,0 +hf_BigBird,fail_accuracy,46 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index eb1195caa9a14..ee58808c0bb03 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,fail_to_run,3 +hf_BigBird,pass,52 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index 5b5646e854875..20fb340690ac9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,0 +hf_BigBird,pass,46 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index 0dd9ce3482f4a..cfc5244266440 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,6 +hf_BigBird,pass,52 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index 4ced1b19f2455..108bc6543aa92 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_accuracy,0 +hf_BigBird,fail_accuracy,46 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index 0dd9ce3482f4a..cfc5244266440 100644 --- 
a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,6 +hf_BigBird,pass,52 diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 096dbc48ec7da..6ea7a31a39150 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -354,6 +354,24 @@ def deterministic_torch_manual_seed(*args, **kwargs): torch.manual_seed = deterministic_torch_manual_seed +def empty_gpu_cache(device): + """ + Explicitly empty gpu cache to avoid OOM in subsequent run. + """ + + if device not in ["cuda", "xpu"]: + log.warning( + "Trying to call the empty_gpu_cache for device: %s, which is not in list [cuda, xpu]", + device, + ) + return + + if device == "cuda": + torch.cuda.empty_cache() + elif device == "xpu": + torch.xpu.empty_cache() + + def synchronize(): pass @@ -1234,7 +1252,7 @@ def wrapper(self, *args, **kwargs) -> Any: ) time.sleep(wait) else: - raise RuntimeError( # noqa: TRY200 + raise RuntimeError( # noqa: B904 f"Failed to load model '{args}' with following error(s): {str(e)}." ) @@ -2278,7 +2296,7 @@ def decay_batch_exp(self, batch_size, factor=0.5, divisor=2): def batch_size_finder(self, device, model_name, initial_batch_size=1024): batch_size = initial_batch_size while batch_size >= 1: - torch.cuda.empty_cache() + empty_gpu_cache(current_device) try: device, name, model, example_inputs, _ = self.load_model( device, @@ -2468,7 +2486,7 @@ def record_status(accuracy_status, dynamo_start_stats): fp64_outputs = None finally: del model_fp64, inputs_fp64 - torch.cuda.empty_cache() + empty_gpu_cache(current_device) tolerance, cos_similarity = self.get_tolerance_and_cosine_flag( self.args.training, current_device, name @@ -2497,7 +2515,7 @@ def record_status(accuracy_status, dynamo_start_stats): return record_status(accuracy_status, dynamo_start_stats=start_stats) finally: del model_copy - torch.cuda.empty_cache() + empty_gpu_cache(current_device) # Rerun native pytorch reset_rng_state() @@ -2518,7 +2536,7 @@ def record_status(accuracy_status, dynamo_start_stats): return record_status(accuracy_status, dynamo_start_stats=start_stats) finally: del model_copy - torch.cuda.empty_cache() + empty_gpu_cache(current_device) # Two eager runs should have exactly same result is_same = True @@ -2719,7 +2737,7 @@ def warmup(fn, model, example_inputs, mode, niters=5): try: if current_device == "cuda": torch.cuda.reset_peak_memory_stats() - torch.cuda.empty_cache() + empty_gpu_cache(current_device) t0 = time.perf_counter() for _ in range(niters): fn(model, example_inputs) @@ -2949,7 +2967,7 @@ def run_one_model( name, model, example_inputs, optimize_ctx, experiment, tag ) print(status) - torch.cuda.empty_cache() + empty_gpu_cache(current_device) self.maybe_preserve_compile_debug(name, status) diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index d6014706479e3..a998d10bf33c2 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -15,6 +15,10 @@ log = logging.getLogger(__name__) +# Enable FX graph caching +if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: + torch._inductor.config.fx_graph_cache = True + def pip_install(package): subprocess.check_call([sys.executable, "-m", "pip", "install", package]) diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index ed5132001827a..1d291e8d1d75b 100755 --- a/benchmarks/dynamo/timm_models.py +++ 
b/benchmarks/dynamo/timm_models.py @@ -13,6 +13,10 @@ from torch._dynamo.testing import collect_results, reduce_to_scalar_loss from torch._dynamo.utils import clone_inputs +# Enable FX graph caching +if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: + torch._inductor.config.fx_graph_cache = True + def pip_install(package): subprocess.check_call([sys.executable, "-m", "pip", "install", package]) diff --git a/benchmarks/transformer/score_mod.py b/benchmarks/transformer/score_mod.py index 2c5f41502f7ea..57088c45f8a06 100644 --- a/benchmarks/transformer/score_mod.py +++ b/benchmarks/transformer/score_mod.py @@ -3,7 +3,7 @@ from collections import defaultdict from dataclasses import asdict, dataclass from functools import partial -from typing import Callable, List +from typing import Callable, List, Optional, Tuple import numpy as np import torch @@ -29,28 +29,32 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> @dataclass(frozen=True) class ExperimentConfig: - batch_size: int - num_heads: int - q_seq_len: int - k_seq_len: int - head_dim: int + shape: Tuple[int] score_mod: Callable dtype: torch.dtype + calculate_bwd_time: bool + + def __post_init__(self): + assert len(self.shape) == 4, "Shape must be of length 4" def asdict(self): - return asdict(self) + # Convert the dataclass instance to a dictionary + d = asdict(self) + # Remove the 'calculate_bwd_time' key + d.pop("calculate_bwd_time", None) + return d @dataclass(frozen=True) -class ExperimentResults: +class Times: eager_time: float compiled_time: float - def get_entries(self) -> List: - return [ - f"{self.eager_time:2f}", - f"{self.compiled_time:2f}", - ] + +@dataclass(frozen=True) +class ExperimentResults: + fwd_times: Times + bwd_times: Optional[Times] @dataclass(frozen=True) @@ -58,29 +62,31 @@ class Experiment: config: ExperimentConfig results: ExperimentResults - def get_entries(self) -> List: - return self.config.get_entries() + self.results.get_entries() - def asdict(self): - dict1 = asdict(self.config) + dict1 = self.config.asdict() dict2 = asdict(self.results) return {**dict1, **dict2} def generate_inputs( - batch_size, - num_heads, - q_sequence_length, - kv_sequence_length, - head_dim, - dtype, - device, + batch_size: int, + num_heads: int, + q_sequence_length: int, + kv_sequence_length: int, + head_dim: int, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, ): q_shape = (batch_size, q_sequence_length, num_heads * head_dim) kv_shape = (batch_size, kv_sequence_length, num_heads * head_dim) - make_q = partial(torch.rand, q_shape, device=device, dtype=dtype) - make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype) + make_q = partial( + torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad + ) + make_kv = partial( + torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad + ) query = ( make_q() .view(batch_size, q_sequence_length, num_heads, head_dim) @@ -101,14 +107,16 @@ def generate_inputs( def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults: device = torch.device("cuda") + batch_size, num_heads, q_seq_len, head_dim = config.shape query, key, value = generate_inputs( - config.batch_size, - config.num_heads, - config.q_seq_len, - config.k_seq_len, - config.head_dim, + batch_size, + num_heads, + q_seq_len, + q_seq_len, + head_dim, config.dtype, device, + requires_grad=config.calculate_bwd_time, ) def eager_sdpa(query, key, value, _): @@ -125,23 +133,47 @@ def eager_sdpa(query, key, 
value, _): compiled_sdpa, query, key, value, score_mod ) - return ExperimentResults( - eager_time=forward_eager_time, - compiled_time=forward_compiled_time, - ) + if config.calculate_bwd_time: + out_eager = eager_sdpa(query, key, value, score_mod) + dOut = torch.randn_like(out_eager) + backward_eager_time = benchmark_torch_function_in_microseconds( + out_eager.backward, dOut, retain_graph=True + ) + + out_compile = compiled_sdpa(query, key, value, score_mod) + dOut = torch.randn_like(out_eager) + backward_compile_time = benchmark_torch_function_in_microseconds( + out_compile.backward, dOut, retain_graph=True + ) + + return ExperimentResults( + fwd_times=Times(forward_eager_time, forward_compiled_time), + bwd_times=Times(backward_eager_time, backward_compile_time), + ) + else: + return ExperimentResults( + fwd_times=Times(forward_eager_time, forward_compiled_time), + bwd_times=None, + ) -def calculate_speedup(results: ExperimentResults) -> float: - return results.eager_time / results.compiled_time +def calculate_speedup(results: ExperimentResults, type: str) -> float: + if type == "fwd": + return results.fwd_times.eager_time / results.fwd_times.compiled_time + elif type == "bwd": + assert results.bwd_times is not None + return results.bwd_times.eager_time / results.bwd_times.compiled_time + else: + raise ValueError(f"Invalid type {type}") def get_func_name(func): return func.__name__.split(".")[-1].split(" at ")[0] -def get_average_speedups(results: List[Experiment]): +def get_average_speedups(results: List[Experiment], type: str): # Calculate speedups - speedups = [calculate_speedup(r.results) for r in results] + speedups = [calculate_speedup(r.results, type) for r in results] # Find indices of max and min speedups max_speedup_index = np.argmax(speedups) @@ -177,20 +209,39 @@ def print_results(results: List[Experiment]): table_data = defaultdict(list) for experiment in results: for key, value in experiment.asdict().items(): - if key == "eager_time" or key == "compiled_time": - value = float(value) - table_data[key].append(value) + if key == "fwd_times": + for name, time in value.items(): + table_data[f"fwd_{name}"].append(float(time)) + elif key == "bwd_times": + if experiment.config.calculate_bwd_time: + for name, time in value.items(): + table_data[f"bwd_{name}"].append(float(time)) + else: + table_data[key].append(value) # Calculate speedups - speedups = [calculate_speedup(r.results) for r in results] - table_data["speedup"] = speedups + fwd_speedups = [calculate_speedup(r.results, type="fwd") for r in results] + table_data["fwd_speedup"] = fwd_speedups + if results[0].config.calculate_bwd_time: + bwd_speedups = [calculate_speedup(r.results, type="bwd") for r in results] + table_data["bwd_speedup"] = bwd_speedups table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]] print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f")) - average_data = get_average_speedups(results) + print("\n") + print("FWD Speedups".center(125, "=")) + print("\n") + average_data = get_average_speedups(results, type="fwd") print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f")) + if results[0].config.calculate_bwd_time: + print("\n") + print("BWD Speedups".center(125, "=")) + print("\n") + average_data = get_average_speedups(results, type="bwd") + print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f")) + def generate_score_mods() -> List[Callable]: def noop(score, b, h, m, n): @@ -208,8 +259,8 @@ def 
head_bias(score, b, h, m, n): return [noop, causal_mask, relative_bias, head_bias] -def generate_experiment_configs() -> List[ExperimentConfig]: - batch_sizes = [1, 8, 16] +def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]: + batch_sizes = [2, 8, 16] num_heads = [16] q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)] head_dims = [64, 128, 256] @@ -228,41 +279,49 @@ def generate_experiment_configs() -> List[ExperimentConfig]: ) in itertools.product( batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes ): + assert q_seq_len == kv_seq_len, "Only equal length inputs supported for now." all_configs.append( ExperimentConfig( - batch_size=bsz, - num_heads=n_heads, - q_seq_len=q_seq_len, - k_seq_len=kv_seq_len, - head_dim=head_dim, + shape=(bsz, n_heads, q_seq_len, head_dim), score_mod=score_mod, dtype=dtype, + calculate_bwd_time=calculate_bwd, ) ) return all_configs -def main(dynamic=False): +def main(dynamic: bool, calculate_bwd: bool): seed = 123 np.random.seed(seed) torch.manual_seed(seed) results = [] - for config in tqdm(generate_experiment_configs()): + for config in tqdm(generate_experiment_configs(calculate_bwd)): results.append( Experiment(config, run_single_experiment(config, dynamic=dynamic)) ) + for config in tqdm(generate_experiment_configs(calculate_bwd)): + results.append(Experiment(config, run_single_experiment(config))) print_results(results) if __name__ == "__main__": - parser = argparse.ArgumentParser() + # Set up the argument parser + parser = argparse.ArgumentParser( + description="Run sweep over sizes and score mods for flex attention" + ) parser.add_argument( "--dynamic", action="store_true", help="Runs a dynamic shapes version of compiled flex attention.", ) + parser.add_argument( + "--calculate-bwd", action="store_true", help="Calculate backward pass times" + ) + # Parse arguments args = parser.parse_args() - main(args.dynamic) + + main(args.dynamic, args.calculate_bwd) diff --git a/build_variables.bzl b/build_variables.bzl index 3f16f9b847c1c..152324a4d90cb 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -487,6 +487,7 @@ libtorch_core_sources = sorted( # These files are the only ones that are supported on Windows. 
libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/Backend.cpp", + "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp", "torch/csrc/distributed/c10d/FileStore.cpp", "torch/csrc/distributed/c10d/Functional.cpp", "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp", diff --git a/c10/util/Float8_e4m3fn.h b/c10/util/Float8_e4m3fn.h index 8e05e2e43bb01..e7a59e343c1f1 100644 --- a/c10/util/Float8_e4m3fn.h +++ b/c10/util/Float8_e4m3fn.h @@ -19,7 +19,7 @@ #include #include -#if defined(__cplusplus) && (__cplusplus >= 201103L) +#if defined(__cplusplus) #include #include #elif !defined(__OPENCL_VERSION__) diff --git a/c10/util/Float8_e4m3fnuz.h b/c10/util/Float8_e4m3fnuz.h index 86ece9ebdadbb..cf73b322e8993 100644 --- a/c10/util/Float8_e4m3fnuz.h +++ b/c10/util/Float8_e4m3fnuz.h @@ -22,7 +22,7 @@ #include #include -#if defined(__cplusplus) && (__cplusplus >= 201103L) +#if defined(__cplusplus) #include #elif !defined(__OPENCL_VERSION__) #include diff --git a/c10/util/Float8_e5m2fnuz.h b/c10/util/Float8_e5m2fnuz.h index f63773914c112..145464e2cfff6 100644 --- a/c10/util/Float8_e5m2fnuz.h +++ b/c10/util/Float8_e5m2fnuz.h @@ -21,7 +21,7 @@ #include #include -#if defined(__cplusplus) && (__cplusplus >= 201103L) +#if defined(__cplusplus) #include #elif !defined(__OPENCL_VERSION__) #include diff --git a/c10/util/Half.h b/c10/util/Half.h index 3d5a38cb365c7..af3435941e48b 100644 --- a/c10/util/Half.h +++ b/c10/util/Half.h @@ -16,7 +16,7 @@ #include #include -#if defined(__cplusplus) && (__cplusplus >= 201103L) +#if defined(__cplusplus) #include #elif !defined(__OPENCL_VERSION__) #include diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index bd2588b5aef35..369bb9b106a0d 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -457,6 +457,9 @@ if(BUILD_LITE_INTERPRETER) append_filelist("libtorch_lite_cmake_sources" LIBTORCH_CMAKE_SRCS) list(APPEND LIBTORCH_CMAKE_SRCS ${LITE_EAGER_SYMOBLICATION_SRCS}) list(APPEND LIBTORCH_CMAKE_SRCS ${LITE_PROFILER_SRCS}) + if(USE_LITE_AOTI) + append_filelist("inductor_core_resources" LIBTORCH_CMAKE_SRCS) + endif() set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) else() append_filelist("libtorch_cmake_sources" LIBTORCH_CMAKE_SRCS) diff --git a/caffe2/core/macros.h.in b/caffe2/core/macros.h.in index e1d8b89325ecc..2497effd8637d 100644 --- a/caffe2/core/macros.h.in +++ b/caffe2/core/macros.h.in @@ -25,7 +25,6 @@ #cmakedefine USE_MKLDNN #cmakedefine CAFFE2_USE_NVTX #cmakedefine CAFFE2_USE_ITT -#cmakedefine CAFFE2_USE_TRT #ifndef EIGEN_MPL2_ONLY #cmakedefine EIGEN_MPL2_ONLY @@ -67,7 +66,6 @@ {"USE_MKLDNN", "${USE_MKLDNN}"}, \ {"USE_NVTX", "${CAFFE2_USE_NVTX}"}, \ {"USE_ITT", "${CAFFE2_USE_ITT}"}, \ - {"USE_TRT", "${CAFFE2_USE_TRT}"}, \ {"USE_ROCM_KERNEL_ASSERT", "${USE_ROCM_KERNEL_ASSERT}"}, \ {"USE_CUSPARSELT", "${USE_CUSPARSELT}"}, \ } diff --git a/cmake/Caffe2Config.cmake.in b/cmake/Caffe2Config.cmake.in index 30e53c5fc7528..c23b3990aff8a 100644 --- a/cmake/Caffe2Config.cmake.in +++ b/cmake/Caffe2Config.cmake.in @@ -79,7 +79,6 @@ if(@USE_CUDA@) # If Caffe2 was compiled with the libraries below, they must # be found again when including the Caffe2 target. set(CAFFE2_USE_CUDA @USE_CUDA@) - set(CAFFE2_USE_TENSORRT @USE_TENSORRT@) # Add current directory to module path so we pick up FindCUDAToolkit.cmake set(old_CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}") @@ -93,12 +92,6 @@ if(@USE_CUDA@) "libraries. 
Please set the proper CUDA prefixes and / or install " "CUDA.") endif() - if(@CAFFE2_USE_TENSORRT@ AND NOT CAFFE2_USE_TENSORRT) - message(FATAL_ERROR - "Your installed Caffe2 version uses TensorRT but I cannot find the TensorRT " - "libraries. Please set the proper TensorRT prefixes and / or install " - "TensorRT.") - endif() endif() if(@USE_XPU@) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index a9a3aab8c5107..a7e38ee73bcce 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -40,7 +40,6 @@ if(USE_CUDA) set(CAFFE2_USE_CUDNN ${USE_CUDNN}) set(CAFFE2_USE_CUSPARSELT ${USE_CUSPARSELT}) set(CAFFE2_USE_NVRTC ${USE_NVRTC}) - set(CAFFE2_USE_TENSORRT ${USE_TENSORRT}) include(${CMAKE_CURRENT_LIST_DIR}/public/cuda.cmake) if(CAFFE2_USE_CUDA) # A helper variable recording the list of Caffe2 dependent libraries @@ -63,11 +62,6 @@ if(USE_CUDA) else() caffe2_update_option(USE_CUSPARSELT OFF) endif() - if(CAFFE2_USE_TENSORRT) - list(APPEND Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS caffe2::tensorrt) - else() - caffe2_update_option(USE_TENSORRT OFF) - endif() find_program(SCCACHE_EXECUTABLE sccache) if(SCCACHE_EXECUTABLE) # Using RSP/--options-file renders output noncacheable by sccache @@ -84,12 +78,10 @@ if(USE_CUDA) caffe2_update_option(USE_CUDNN OFF) caffe2_update_option(USE_CUSPARSELT OFF) caffe2_update_option(USE_NVRTC OFF) - caffe2_update_option(USE_TENSORRT OFF) set(CAFFE2_USE_CUDA OFF) set(CAFFE2_USE_CUDNN OFF) set(CAFFE2_USE_CUSPARSELT OFF) set(CAFFE2_USE_NVRTC OFF) - set(CAFFE2_USE_TENSORRT OFF) endif() endif() @@ -1300,11 +1292,10 @@ endif() # ---[ CUB if(USE_CUDA) find_package(CUB) - if(CUB_FOUND) - include_directories(SYSTEM ${CUB_INCLUDE_DIRS}) - else() - include_directories(SYSTEM ${CMAKE_CURRENT_LIST_DIR}/../third_party/cub) + if(NOT CUB_FOUND) + message(FATAL_ERROR "Cannot find CUB.") endif() + include_directories(SYSTEM ${CUB_INCLUDE_DIRS}) endif() if(USE_DISTRIBUTED AND USE_TENSORPIPE) @@ -1426,25 +1417,6 @@ if(USE_NNAPI AND NOT ANDROID) caffe2_update_option(USE_NNAPI OFF) endif() -if(USE_ZSTD) - if(USE_SYSTEM_ZSTD) - find_package(zstd REQUIRED) - if(TARGET zstd::libzstd_shared) - set(ZSTD_TARGET zstd::libzstd_shared) - else() - set(ZSTD_TARGET zstd::libzstd_static) - endif() - list(APPEND Caffe2_DEPENDENCY_LIBS ${ZSTD_TARGET}) - get_property(ZSTD_INCLUDE_DIR TARGET ${ZSTD_TARGET} PROPERTY INTERFACE_INCLUDE_DIRECTORIES) - include_directories(SYSTEM ${ZSTD_INCLUDE_DIR}) - else() - list(APPEND Caffe2_DEPENDENCY_LIBS libzstd_static) - include_directories(SYSTEM ${CMAKE_CURRENT_LIST_DIR}/../third_party/zstd/lib) - add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/zstd/build/cmake) - set_property(TARGET libzstd_static PROPERTY POSITION_INDEPENDENT_CODE ON) - endif() -endif() - # ---[ Onnx if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX) if(EXISTS "${CAFFE2_CUSTOM_PROTOC_EXECUTABLE}") @@ -1511,27 +1483,6 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX) set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS}) endif() -# --[ TensorRT integration with onnx-trt -function(add_onnx_tensorrt_subdir) - # We pass the paths we found to onnx tensorrt. 
- set(CUDNN_INCLUDE_DIR "${CUDNN_INCLUDE_PATH}") - set(CUDNN_LIBRARY "${CUDNN_LIBRARY_PATH}") - set(CMAKE_VERSION_ORIG "{CMAKE_VERSION}") - # TODO: this WAR is for https://github.com/pytorch/pytorch/issues/18524 - set(CMAKE_VERSION "3.9.0") - add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/onnx-tensorrt EXCLUDE_FROM_ALL) - set(CMAKE_VERSION "{CMAKE_VERSION_ORIG}") -endfunction() -if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO) - if(USE_TENSORRT) - add_onnx_tensorrt_subdir() - include_directories("${CMAKE_CURRENT_LIST_DIR}/../third_party/onnx-tensorrt") - caffe2_interface_library(nvonnxparser_static onnx_trt_library) - list(APPEND Caffe2_DEPENDENCY_WHOLE_LINK_LIBS onnx_trt_library) - set(CAFFE2_USE_TRT 1) - endif() -endif() - # --[ ATen checks set(USE_LAPACK 0) diff --git a/defs.bzl b/defs.bzl index 6c32f5f9c8b48..d2978f3bfb973 100644 --- a/defs.bzl +++ b/defs.bzl @@ -1,7 +1,7 @@ def get_blas_gomp_arch_deps(): return [ ("x86_64", [ - "third-party//IntelComposerXE:{}".format(native.read_config("fbcode", "mkl_lp64", "mkl_lp64_omp")), + "fbsource//third-party/mkl:{}".format(native.read_config("fbcode", "mkl_lp64", "mkl_lp64_omp")), ]), ("aarch64", [ "third-party//OpenBLAS:OpenBLAS", diff --git a/docs/source/community/design.rst b/docs/source/community/design.rst index 73ed7e1447b8f..16b1500afcdd1 100644 --- a/docs/source/community/design.rst +++ b/docs/source/community/design.rst @@ -119,7 +119,7 @@ This principle began as **Python First**: PyTorch is not a Python binding into a monolithic C++ framework. It is built to be deeply integrated into Python. You can use it naturally like you would use `NumPy `__, - `SciPy `__, `scikit-learn <(https://scikit-learn.org/>`__, + `SciPy `__, `scikit-learn `__, or other Python libraries. You can write your new neural network layers in Python itself, using your favorite libraries and use packages such as `Cython `__ and diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index 09fd9e858b87e..225486cdedac9 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -394,3 +394,6 @@ The following utility functions are related to serialization: .. autofunction:: set_default_load_endianness .. autofunction:: get_default_mmap_options .. autofunction:: set_default_mmap_options +.. autofunction:: add_safe_globals +.. autofunction:: clear_safe_globals +.. 
autofunction:: get_safe_globals diff --git a/functorch/compile/__init__.py b/functorch/compile/__init__.py index 96b853cd2e27e..e7548a5ff6b91 100644 --- a/functorch/compile/__init__.py +++ b/functorch/compile/__init__.py @@ -25,7 +25,6 @@ from torch._functorch.partitioners import ( default_partition, draw_graph, - draw_joint_graph, min_cut_rematerialization_partition, ) from torch._functorch.python_key import pythonkey_decompose diff --git a/pyproject.toml b/pyproject.toml index 3ff4b94447f9c..07f0750820978 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ ignore = [ "B019", "B023", "B028", # No explicit `stacklevel` keyword argument found - "B904", # Migrate from TRY200 "E402", "C408", # C408 ignored because we like the dict keyword argument syntax "E501", # E501 is not flexible enough, we're using B950 instead @@ -90,6 +89,7 @@ ignore = [ ] select = [ "B", + "B904", # Re-raised error without specifying the cause via the from keyword "C4", "G", "E", @@ -133,7 +133,6 @@ select = [ "RUF017", "RUF018", # no assignment in assert "TRY002", # ban vanilla raise (todo fix NOQAs) - "TRY200", # TODO: migrate from deprecated alias "TRY302", "TRY401", # verbose-log-message "UP", diff --git a/setup.py b/setup.py index 84f3d48c958e8..93245d971be8b 100644 --- a/setup.py +++ b/setup.py @@ -151,9 +151,6 @@ # USE_REDIS # Whether to use Redis for distributed workflows (Linux only) # -# USE_ZSTD -# Enables use of ZSTD, if the libraries are found -# # USE_ROCM_KERNEL_ASSERT=1 # Enable kernel assert in ROCm platform # diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index 42b67d8cb25c2..b0e296ad23095 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -10,6 +10,7 @@ set(TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/functional.cpp ${TORCH_API_TEST_DIR}/init.cpp ${TORCH_API_TEST_DIR}/integration.cpp + ${TORCH_API_TEST_DIR}/ivalue.cpp ${TORCH_API_TEST_DIR}/jit.cpp ${TORCH_API_TEST_DIR}/memory.cpp ${TORCH_API_TEST_DIR}/meta_tensor.cpp diff --git a/test/cpp/api/ivalue.cpp b/test/cpp/api/ivalue.cpp new file mode 100644 index 0000000000000..fa8dcc25cd4d4 --- /dev/null +++ b/test/cpp/api/ivalue.cpp @@ -0,0 +1,63 @@ +#include + +#include + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include + +using namespace torch::test; +using namespace torch::nn; +using namespace torch::optim; + +TEST(IValueTest, DeepcopyTensors) { + torch::Tensor t0 = torch::randn({2, 3}); + torch::Tensor t1 = torch::randn({3, 4}); + torch::Tensor t2 = t0.detach(); + torch::Tensor t3 = t0; + torch::Tensor t4 = t1.as_strided({2, 3}, {3, 1}, 2); + std::vector tensor_vector = {t0, t1, t2, t3, t4}; + c10::List tensor_list(tensor_vector); + torch::IValue tensor_list_ivalue(tensor_list); + + c10::IValue::CompIdentityIValues ivalue_compare; + + // Make sure our setup configuration is correct + ASSERT_TRUE(ivalue_compare(tensor_list[0].get(), tensor_list[3].get())); + ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[1].get())); + ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[2].get())); + ASSERT_FALSE(ivalue_compare(tensor_list[1].get(), tensor_list[4].get())); + ASSERT_TRUE(tensor_list[0].get().isAliasOf(tensor_list[2].get())); + + c10::IValue copied_ivalue = tensor_list_ivalue.deepcopy(); + c10::List copied_list = copied_ivalue.toList(); + + // Make sure our setup configuration is correct + ASSERT_TRUE(ivalue_compare(copied_list[0].get(), copied_list[3].get())); + 
ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[1].get())); + ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[2].get())); + ASSERT_FALSE(ivalue_compare(copied_list[1].get(), copied_list[4].get())); + // NOTE: this is actually incorrect. Ideally, these _should_ be aliases. + ASSERT_FALSE(copied_list[0].get().isAliasOf(copied_list[2].get())); + + ASSERT_TRUE(copied_list[0].get().toTensor().allclose( + tensor_list[0].get().toTensor())); + ASSERT_TRUE(copied_list[1].get().toTensor().allclose( + tensor_list[1].get().toTensor())); + ASSERT_TRUE(copied_list[2].get().toTensor().allclose( + tensor_list[2].get().toTensor())); + ASSERT_TRUE(copied_list[3].get().toTensor().allclose( + tensor_list[3].get().toTensor())); + ASSERT_TRUE(copied_list[4].get().toTensor().allclose( + tensor_list[4].get().toTensor())); +} diff --git a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py index 4c22ea347156e..9139b62f13673 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py @@ -67,13 +67,13 @@ def _test_clip_grad_norm( ) comm_mode = CommDebugMode() with comm_mode: + # foreach is default to turn on so we don't need to specify it. total_norm = torch.nn.utils.clip_grad_norm_( model.parameters(), max_norm=max_norm, norm_type=norm_type, - foreach=True, ) - self.assertEqual(ref_total_norm, total_norm) + self.assertEqual(ref_total_norm, total_norm.full_tensor()) # Expect one all-reduce per mesh dim for partial -> replicate expected_all_reduces = len(total_norm.placements) self.assertEqual( diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index 115c1f93227c6..283b8ab2b944b 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -244,7 +244,7 @@ def _test_reduce_scatter( group = fsdp_param_group.mesh_info.shard_process_group self.assertEqual(group.size(), self.world_size) all_reduce_stream = torch.cuda.Stream() - view_out_event = foreach_reduce( + post_reduce_event, _ = foreach_reduce( fsdp_params, unsharded_grads, group, @@ -254,8 +254,10 @@ def _test_reduce_scatter( device=self.device, all_reduce_group=None, all_reduce_stream=all_reduce_stream, + all_reduce_grads=True, + partial_reduce_output=None, ) - torch.cuda.current_stream().wait_event(view_out_event) + torch.cuda.current_stream().wait_event(post_reduce_event) # Check reduce-scatter correctness predivide_factor, postdivide_factor = _get_gradient_divide_factors( diff --git a/test/distributed/_composable/fsdp/test_fully_shard_init.py b/test/distributed/_composable/fsdp/test_fully_shard_init.py index 3dfaab80dbe1f..73e078c0b2f2d 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_init.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_init.py @@ -24,6 +24,10 @@ Shard, ) from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp._init_utils import ( + _init_inter_node_process_group, + _init_intra_node_process_group, +) from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -672,7 +676,7 @@ def world_size(self) -> int: return 4 @unittest.skipIf(not TEST_CUDA, "no cuda") - def test_process_group_init(self): + def test_1d_process_group_init(self): assert self.world_size == 4, f"{self.world_size}" # For 
convenience, use device mesh's infra to construct the DP PG # (in practice, the trainer would do it manually via `new_group()`) @@ -684,11 +688,10 @@ def test_process_group_init(self): dp_pg = ref_dp_mesh.get_group(0) # Check the `from_group()` API for correctness - dp_mesh = DeviceMesh.from_group(dp_pg, "cuda") - # We only compare the mesh tensors instead of the DeviceMesh objects - # since mesh_dim_names attributes and parent mesh are different. + dp_mesh = DeviceMesh.from_group(dp_pg, "cuda", mesh_dim_names=("dp",)) + # Only compare the mesh tensors, not `DeviceMesh` objects themselves, + # since the ref has a parent mesh, while the `from_group` one does not self.assertEqual(dp_mesh.mesh, ref_dp_mesh.mesh) - # self.assertFalse(hasattr(dp_mesh, "_coordinate_on_dim")) self.assertEqual(dp_mesh._coordinate_on_dim, ref_dp_mesh._coordinate_on_dim) self.assertEqual(dp_mesh._dim_group_infos, ref_dp_mesh._dim_group_infos) @@ -722,7 +725,9 @@ def test_process_group_init(self): loss.backward() self.assertEqual(loss, ref_loss) for param, ref_param in zip(model.parameters(), ref_model.parameters()): - # we cannot directly compare param and ref_param because their parent mesh is different. + # Cannot compare `DTensor`s directly since their meshes are not + # equal due to the ref parameter's mesh having a parent mesh while + # the other's mesh does not self.assertEqual(param.to_local(), ref_param.to_local()) self.assertEqual(param.device_mesh.mesh, ref_param.device_mesh.mesh) self.assertEqual(param.grad.to_local(), ref_param.grad.to_local()) @@ -730,6 +735,83 @@ def test_process_group_init(self): param.grad.device_mesh.mesh, ref_param.grad.device_mesh.mesh ) + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_2d_process_group_init(self): + shard_mesh_dim_size = 2 + assert ( + self.world_size % shard_mesh_dim_size == 0 + ), f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}" + replicate_mesh_dim_size = self.world_size // shard_mesh_dim_size + mesh_dim_names = ("replicate", "shard") + ref_mesh = init_device_mesh( + "cuda", + (replicate_mesh_dim_size, shard_mesh_dim_size), + mesh_dim_names=mesh_dim_names, + ) + + # Use the global PG as the parent group (in practice, this could be a + # subgroup of the global PG) + dp_group = dist.distributed_c10d._get_default_group() + dp_shard_group = _init_intra_node_process_group(shard_mesh_dim_size) + dp_replicate_group = _init_inter_node_process_group( + dp_group, replicate_mesh_dim_size + ) + mesh_tensor = torch.tensor( + dist.get_process_group_ranks(dp_group), dtype=torch.int + ).view(replicate_mesh_dim_size, shard_mesh_dim_size) + + # Check the `from_group()` API for correctness + mesh = DeviceMesh.from_group( + [dp_replicate_group, dp_shard_group], + "cuda", + mesh_dim_names=mesh_dim_names, + mesh=mesh_tensor, + ) + self.assertEqual(mesh.mesh, ref_mesh.mesh) + self.assertEqual(mesh._coordinate_on_dim, ref_mesh._coordinate_on_dim) + for (tag, ranks, group_name), (ref_tag, ref_ranks, ref_group_name) in zip( + mesh._dim_group_infos, ref_mesh._dim_group_infos + ): + # Since we manually constructed new subgroups, the test and ref + # groups are not the same + self.assertEqual(ranks, ref_ranks) + for mesh_dim_name in mesh_dim_names: + child_mesh = mesh[mesh_dim_name] + ref_child_mesh = ref_mesh[mesh_dim_name] + self.assertEqual(child_mesh, ref_child_mesh) + child_ranks = dist.distributed_c10d.get_process_group_ranks( + child_mesh.get_group() + ) + ref_child_ranks = dist.distributed_c10d.get_process_group_ranks( + ref_child_mesh.get_group() + ) + 
self.assertEqual(child_ranks, ref_child_ranks) + + # Check HSDP forward/backward parity + torch.manual_seed(42) + mlp_dim = 8 + ref_model = MLP(mlp_dim) + for param in ref_model.parameters(): + dist.broadcast(param.detach(), src=0) + model = copy.deepcopy(ref_model) + + # Parallelize the test model with the ref mesh + for module in (ref_model.in_proj, ref_model.out_proj, ref_model): + fully_shard(module, mesh=ref_mesh) + # Parallelize the test model with the new mesh from the PG + for module in (model.in_proj, model.out_proj, model): + fully_shard(module, mesh=mesh) + + inp = torch.randn((4, mlp_dim), device="cuda") + ref_loss = ref_model(inp).sum() + ref_loss.backward() + loss = model(inp).sum() + loss.backward() + self.assertEqual(loss, ref_loss) + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + self.assertEqual(param, ref_param) + self.assertEqual(param.grad, ref_param.grad) + class TestFullyShardHSDPBroadcast(FSDPTestMultiThread): @property diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index eec060d3004cc..392596549d771 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -672,6 +672,12 @@ def test_gradient_accumulation(self): "mode": ["all", "root_only", "some_mlps"], "reshard_after_backward": [False, True], "offload_policy": [OffloadPolicy(), CPUOffloadPolicy()], + # For HSDP only: + # `True`: reduce-scatter only (no all-reduce) each microbatch + # until the last microbatch + # `False`: neither reduce-scatter nor all-reduce each + # microbatch until the last microbatch + "reduce_scatter_only": [False, True], }, self._test_gradient_accumulation, ) @@ -683,15 +689,20 @@ def _test_gradient_accumulation( mode: str, reshard_after_backward: bool, offload_policy: OffloadPolicy, + reduce_scatter_only: bool, # for HSDP ): if ( - not reshard_after_backward - and (reshard_after_forward is not False or mode == "some_mlps") - ) or ( - isinstance(offload_policy, CPUOffloadPolicy) - and reshard_after_forward is not True + ( + not reshard_after_backward + and (reshard_after_forward is not False or mode == "some_mlps") + ) + or ( + isinstance(offload_policy, CPUOffloadPolicy) + and reshard_after_forward is not True + ) + or (mesh.ndim != 2 and reduce_scatter_only) ): - return # skip since not common + return # skip since not common or applicable torch.manual_seed(42) batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3) @@ -713,29 +724,35 @@ def _test_gradient_accumulation( ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) optim = torch.optim.Adam(model.parameters(), lr=1e-2) + def set_grad_sync_flag( + module: nn.Module, is_last_microbatch: bool, recurse: bool = True + ): + if reduce_scatter_only: + module.set_requires_all_reduce(is_last_microbatch, recurse=recurse) + else: + module.set_requires_gradient_sync(is_last_microbatch, recurse=recurse) + + def set_backward_flags(_model: nn.Module, is_last_microbatch: bool): + if mode == "all": + set_grad_sync_flag(_model, is_last_microbatch) + if not reshard_after_backward: + _model.set_reshard_after_backward(is_last_microbatch) + elif mode == "some_mlps": + for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]: + set_grad_sync_flag(mlp, is_last_microbatch) + if not reshard_after_backward: + mlp.set_reshard_after_backward(is_last_microbatch) + elif mode == "root_only": + set_grad_sync_flag(model, 
is_last_microbatch, recurse=False) + if not reshard_after_backward: + model.set_reshard_after_backward(is_last_microbatch, recurse=False) + torch.manual_seed(42 + self.rank + 1) for iter_idx in range(5): with CommDebugMode() as comm_mode: for microbatch_idx in range(num_microbatches): is_last_microbatch = microbatch_idx == num_microbatches - 1 - if mode == "all": - model.set_requires_gradient_sync(is_last_microbatch) - if not reshard_after_backward: - model.set_reshard_after_backward(is_last_microbatch) - elif mode == "some_mlps": - for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]: - mlp.set_requires_gradient_sync(is_last_microbatch) - if not reshard_after_backward: - mlp.set_reshard_after_backward(is_last_microbatch) - elif mode == "root_only": - model.set_requires_gradient_sync( - is_last_microbatch, recurse=False - ) - if not reshard_after_backward: - model.set_reshard_after_backward( - is_last_microbatch, recurse=False - ) - + set_backward_flags(model, is_last_microbatch) inp = torch.randn(batch_size, lin_dim, device="cuda") losses: List[torch.Tensor] = [] for _model in (ref_model, model): @@ -760,10 +777,15 @@ def _test_gradient_accumulation( elif mode == "root_only": # Expect additional reduce-scatters for all MLPs expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1) - self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count) expected_all_reduce_count = ( expected_reduce_scatter_count if mesh.ndim == 2 else 0 ) + if reduce_scatter_only: + # Specially for HSDP if only reduce-scattering but not + # all-reducing until the last microbatch, expect one + # reduce-scatter per MLP plus for the root per microbatch + expected_reduce_scatter_count = (num_mlps + 1) * num_microbatches + self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count) self.assertEqual(all_reduce_count, expected_all_reduce_count) # Expect one all-gather per MLP plus one for the root's linear in diff --git a/test/distributed/_tensor/test_view_ops.py b/test/distributed/_tensor/test_view_ops.py index 429e62588651a..2ea89e34789bf 100644 --- a/test/distributed/_tensor/test_view_ops.py +++ b/test/distributed/_tensor/test_view_ops.py @@ -11,9 +11,9 @@ from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.ops.view_ops import ( Broadcast, + dim_maps, Flatten, InputDim, - ops, Repeat, Singleton, Split, @@ -130,8 +130,8 @@ def world_size(self) -> int: return 6 def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): - spec = ops[op] - rules = spec.dim_map(*args, **kwargs) + dim_map = dim_maps[op] + rules = dim_map(*args, **kwargs) outputs = op(*args, **kwargs) flat_args = pytree.arg_tree_leaves(*args) in_shape = flat_args[0].shape @@ -163,7 +163,6 @@ def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): ) for in_shard in all_sharding_choices: - # print(f' |--- {in_shard}') in_dt = distribute_tensor(args[0], device_mesh, in_shard) comm_mode = CommDebugMode() @@ -180,7 +179,7 @@ def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): self.assertEqual(outputs, full_out) def dimmap_test(self, op, args, expected_rule_output): - rules = ops[op].dim_map(*args) + rules = dim_maps[op](*args) self.assertEqual(rules, expected_rule_output) self.call_dt_test(op, args, {}, self.device_mesh) @@ -229,7 +228,7 @@ def test_view_ops(self): ) with self.assertRaises(AssertionError): - ops[torch.broadcast_to].dim_map(randn(24, 36), (1, 2, 4)) + dim_maps[torch.broadcast_to](randn(24, 36), (1, 2, 4)) self.dimmap_test( torch.broadcast_to, 
@@ -495,14 +494,14 @@ def test_complex_view_ops(self): InputDim(0), Flatten((InputDim(1), InputDim(2))), ) - view_as_complex_rule = ops[torch.view_as_complex].dim_map(inp) + view_as_complex_rule = dim_maps[torch.view_as_complex](inp) self.assertEqual(view_as_complex_rule, expected_view_as_complex_rule) expected_view_as_real_rule = ( InputDim(0), Split(InputDim(1), (13, 2), 0), Split(InputDim(1), (13, 2), 1), ) - view_as_real_rule = ops[torch.view_as_real].dim_map(intermediate) + view_as_real_rule = dim_maps[torch.view_as_real](intermediate) self.assertEqual(view_as_real_rule, expected_view_as_real_rule) # test sharded computation correctness diff --git a/test/distributed/pipelining/test_unflatten.py b/test/distributed/pipelining/test_unflatten.py index 9c388d279cdfd..37eaf599e4d81 100644 --- a/test/distributed/pipelining/test_unflatten.py +++ b/test/distributed/pipelining/test_unflatten.py @@ -45,7 +45,6 @@ def test_unflatten(self): constant = torch.ones(1, 16, 256, 256) mod = M() - print("Original model:\n", mod) pipe = pipeline( mod, @@ -58,21 +57,19 @@ def test_unflatten(self): orig_state_dict = mod.state_dict() # Check qualnames - print("\nParameters of each stage:") for stage_idx in range(pipe.num_stages): - print(f"\nStage {stage_idx}:") stage_mod = pipe.get_stage_module(stage_idx) for param_name, param in stage_mod.named_parameters(): assert ( param_name in orig_state_dict ), f"{param_name} not in original state dict" - print(f"{param_name}: {param.size()}") + print("Param qualname test passed") # Check equivalence ref = mod(x, constant) out = pipe(x, constant)[0] torch.testing.assert_close(out, ref) - print(f"\nEquivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") + print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") if __name__ == "__main__": diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 54030d1f1d42b..775d3f9cc03df 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -195,6 +195,18 @@ def test_all_gather_into_tensor_single(self) -> None: assert torch.allclose(output, expect) assert output.eq(expect).all() + # Test out-variant of all_gather_into_tensor + output = torch.empty(expect.shape, device=self.device) + output = torch.ops._c10d_functional.all_gather_into_tensor_out( + input, + self.world_size, + "default", + out=output, + ) + output = torch.ops._c10d_functional.wait_tensor(output) + assert torch.allclose(output, expect) + assert output.eq(expect).all() + # Test Python API and AsyncCollectiveTensor output = all_gather_tensor( input, diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index a0629f054ae02..5a958acdbdd74 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2577,6 +2577,27 @@ def test_all_reduce_coalesced_nccl(self): ), ) + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_all_reduce_coalesced_nccl_float8_errors(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="nccl", store=store, rank=self.rank, world_size=self.world_size + ) + process_group = c10d.distributed_c10d._get_default_group() + device = torch.device("cuda:%d" % self.rank) + tensors = [ + torch.full( + (60 + i,), self.rank + 1 + i, device=device, dtype=torch.float + ).to(torch.float8_e4m3fn) + for i in range(5) + ] + with self.assertRaisesRegex( + RuntimeError, + "Float8 dtypes are not currenlty supported for NCCL 
reductions", + ): + torch.distributed.all_reduce_coalesced(tensors, group=process_group) + @requires_nccl() @skip_if_lt_x_gpu(2) def test_all_reduce_coalesced_manager_nccl(self): @@ -2940,6 +2961,56 @@ def test_reduce_scatter_tensor_coalesced(self): dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i]) self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size) + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_reduce_scatter_base_k_float8_errors(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + output_tensor = ( + torch.zeros(2, dtype=torch.float32).to(torch.float8_e4m3fn).to(self.rank) + ) + input_tensors = ( + torch.arange(self.world_size * 2, dtype=torch.float32) + .to(torch.float8_e4m3fn) + .to(self.rank) + ) + input_tensors = torch.reshape(input_tensors, (self.world_size, 2)) + with self.assertRaisesRegex( + RuntimeError, + "Float8 dtypes are not currenlty supported for NCCL reductions", + ): + dist.reduce_scatter_tensor(output_tensor, input_tensors) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_reduce_scatter_tensor_coalesced_float8_errors(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + output_tensors = torch.zeros(2, 2).to(torch.float8_e5m2).to(self.rank) + input_tensors = [ + torch.ones(2, 2).to(torch.float8_e5m2).to(self.rank) + for _ in range(self.world_size) + ] + + with self.assertRaisesRegex( + RuntimeError, + "Float8 dtypes are not currenlty supported for NCCL reductions", + ): + with dist._coalescing_manager(): + for i in range(self.world_size): + dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i]) + self.assertEqual(output_tensors, input_tensors[self.rank]) + class SetDeviceMethod(Enum): TORCH_CUDA_SET = auto() # torch.cuda.set_device @@ -2980,6 +3051,28 @@ def test_allgather_base(self): dist.all_gather_into_tensor(output_tensor, tensor) self.assertEqual(output_tensor, tensor) + @requires_nccl() + @skip_if_lt_x_gpu(1) + @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_allgather_float8(self, float8_dtype): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + device = "cuda" + tensor = torch.ones(10, 16, device=torch.device(device)).to(float8_dtype) + output_tensor = torch.zeros(10, 16, device=torch.device(device)).to( + float8_dtype + ) + dist.all_gather_into_tensor(output_tensor, tensor) + self.assertEqual(output_tensor.view(torch.float32), tensor.view(torch.float32)) + + +instantiate_parametrized_tests(NcclProcessGroupWithDispatchedCollectivesTests) + class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase): def setUp(self): diff --git a/test/distributed/test_control_collectives.py b/test/distributed/test_control_collectives.py new file mode 100644 index 0000000000000..fb0067f2dd2e9 --- /dev/null +++ b/test/distributed/test_control_collectives.py @@ -0,0 +1,189 @@ +# Owner(s): ["oncall: distributed"] + +from datetime import timedelta +from multiprocessing.pool import ThreadPool + +import torch +import torch.distributed as dist +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestCollectives(TestCase): + def test_barrier(self) -> None: + store = dist.HashStore() + + world_size = 2 + + 
def f(rank: int) -> None: + collectives = dist._StoreCollectives(store, rank, world_size) + collectives.barrier("foo", timedelta(seconds=10), True) + + with ThreadPool(world_size) as pool: + pool.map(f, range(world_size)) + + def test_broadcast(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(seconds=10) + + def f(rank: int) -> None: + collectives = dist._StoreCollectives(store, rank, world_size) + if rank == 2: + collectives.broadcast_send("foo", b"data", timeout) + else: + out = collectives.broadcast_recv("foo", timeout) + self.assertEqual(out, b"data") + + with ThreadPool(world_size) as pool: + pool.map(f, range(world_size)) + + def test_gather(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(seconds=10) + + def f(rank: int) -> None: + collectives = dist._StoreCollectives(store, rank, world_size) + if rank == 2: + out = collectives.gather_recv("foo", str(rank), timeout) + self.assertEqual(out, [b"0", b"1", b"2", b"3"]) + else: + collectives.gather_send("foo", str(rank), timeout) + + with ThreadPool(world_size) as pool: + pool.map(f, range(world_size)) + + def test_scatter(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(seconds=10) + + def f(rank: int) -> None: + collectives = dist._StoreCollectives(store, rank, world_size) + if rank == 2: + out = collectives.scatter_send( + "foo", [str(i) for i in range(world_size)], timeout + ) + else: + out = collectives.scatter_recv("foo", timeout) + self.assertEqual(out, str(rank).encode()) + + with ThreadPool(world_size) as pool: + pool.map(f, range(world_size)) + + def test_all_sum(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(seconds=10) + + def f(rank: int) -> None: + collectives = dist._StoreCollectives(store, rank, world_size) + out = collectives.all_sum("foo", rank, timeout) + self.assertEqual(out, sum(range(world_size))) + + with ThreadPool(world_size) as pool: + pool.map(f, range(world_size)) + + def test_broadcast_timeout(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(milliseconds=1) + collectives = dist._StoreCollectives(store, 1, world_size) + with self.assertRaisesRegex(Exception, "Wait timeout"): + collectives.broadcast_recv("foo", timeout) + + def test_gather_timeout(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(milliseconds=1) + collectives = dist._StoreCollectives(store, 1, world_size) + with self.assertRaisesRegex( + Exception, "gather failed -- missing ranks: 0, 2, 3" + ): + collectives.gather_recv("foo", "data", timeout) + + def test_scatter_timeout(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(milliseconds=1) + collectives = dist._StoreCollectives(store, 1, world_size) + with self.assertRaisesRegex(Exception, "Wait timeout"): + collectives.scatter_recv("foo", timeout) + + def test_all_gather_timeout(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(milliseconds=1) + collectives = dist._StoreCollectives(store, 1, world_size) + with self.assertRaisesRegex( + Exception, "all_gather failed -- missing ranks: 0, 2, 3" + ): + collectives.all_gather("foo", "data", timeout) + + def test_barrier_timeout(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(milliseconds=1) + collectives = dist._StoreCollectives(store, 1, world_size) + with self.assertRaisesRegex( + Exception, "barrier failed -- missing ranks: 0, 2, 3" + ): + 
collectives.barrier("foo", timeout, True) + + def test_all_sum_timeout(self) -> None: + store = dist.HashStore() + + world_size = 4 + timeout = timedelta(milliseconds=1) + collectives = dist._StoreCollectives(store, 1, world_size) + with self.assertRaisesRegex( + Exception, "barrier failed -- missing ranks: 0, 2, 3" + ): + collectives.all_sum("foo", 1, timeout) + + def test_unique(self) -> None: + store = dist.HashStore() + + collectives = dist._StoreCollectives(store, 1, 1) + collectives.broadcast_send("foo", "bar") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.broadcast_send("foo", "bar") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.broadcast_recv("foo") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.gather_send("foo", "bar") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.gather_recv("foo", "asdf") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.scatter_send("foo", ["asdf"]) + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.scatter_recv("foo") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.all_gather("foo", "bar") + + with self.assertRaisesRegex(Exception, "Key foo has already been used"): + collectives.all_sum("foo", 2) + + +if __name__ == "__main__": + assert ( + not torch.cuda._initialized + ), "test_distributed must not have initialized CUDA context on main process" + + run_tests() diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index e6c1e27e23ce3..8f70ee2f0b7d8 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -168,7 +168,7 @@ def test_fake_pg_device_mesh(self): self.assertEqual(global_tensor.shape, (self.world_size * 2, 8)) @with_comms - def test_from_group(self): + def test_from_group_with_global_pg(self): # Simple test: check `from_group` for a global PG vs. 
directly # initializing via `init_device_mesh` global_pg = _get_default_group() @@ -180,6 +180,23 @@ def test_from_group(self): ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim ) + @with_comms + def test_from_group_with_invalid_mesh(self): + global_pg = _get_default_group() + global_pg_size = global_pg.size() + assert global_pg_size == 4, "Test assumes global world size of 4" + invalid_mesh = [[0, 1], [2, 3]] # 2D mesh when we need 1D + regex = r"Invalid mesh \[\[0, 1\], \[2, 3\]\] for ProcessGroup with ranks \[0, 1, 2, 3\]" + with self.assertRaisesRegex(ValueError, regex): + DeviceMesh.from_group(global_pg, "cuda", invalid_mesh) + + device_mesh = init_device_mesh(self.device_type, (2, 2)) + groups = device_mesh.get_group() + invalid_mesh = (0, 1, 2, 3) # 1D mesh when we need 2D + regex = r"Expects mesh with ndim equal to number of ProcessGroups but got mesh \[0, 1, 2, 3\] and 2 ProcessGroups" + with self.assertRaisesRegex(ValueError, regex): + DeviceMesh.from_group(groups, self.device_type, invalid_mesh) + def test_raises_invalid_device_type(self): with self.assertRaisesRegex( RuntimeError, @@ -280,8 +297,8 @@ def test_device_mesh_parent_child_hash(self): ep_mesh_1 = DeviceMesh(self.device_type, mesh_group_1) ep_mesh_2 = DeviceMesh(self.device_type, mesh_group_2) ep_mesh = ep_mesh_1 if self.rank < self.world_size // 2 else ep_mesh_2 - # # ep_mesh is considered different from mesh_2d["TP"] - # # since mesh_2d["TP"] has a parent mesh while ep_mesh does not. + # ep_mesh is considered different from mesh_2d["TP"] + # since mesh_2d["TP"] has a parent mesh while ep_mesh does not. self.assertEqual(mesh_2d["TP"]._flatten_mesh_list, ep_mesh._flatten_mesh_list) self.assertEqual(mesh_2d["TP"].mesh.shape, ep_mesh.mesh.shape) self.assertEqual(mesh_2d["TP"].device_type, ep_mesh.device_type) @@ -307,6 +324,47 @@ def test_device_mesh_parent_child_hash(self): self.assertEqual(hash(ep_mesh), hash(another_mesh)) self.assertEqual(ep_mesh, another_mesh) + @with_comms + def test_from_group_with_mesh_shape(self): + """Tests ``from_group`` when passing ``mesh_shape`` as 2D.""" + # Consider two different logical views of the same mesh: + # - (4, 2) ("dp", "tp") mesh + # - (2, 2, 2) ("dp_replicate", "dp_shard", "tp") mesh + mesh_shape = (2, 2, 2) + mesh_dim_names = ("dp_replicate", "dp_shard", "tp") + ref_mesh = init_device_mesh( + self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names + ) + + dp_shard_group = ref_mesh["dp_shard"].get_group() + dp_replicate_group = ref_mesh["dp_replicate"].get_group() + + dp_mesh = DeviceMesh.from_group( + [dp_replicate_group, dp_shard_group], + self.device_type, + mesh=ref_mesh.mesh[:, :, ref_mesh.get_local_rank(2)], + mesh_dim_names=mesh_dim_names[:2], + ) + + ref_mesh_dp_dim_group_infos = ref_mesh._dim_group_infos[:2] + for (_, ref_ranks, _), (_, ranks, _) in zip( + ref_mesh_dp_dim_group_infos, dp_mesh._dim_group_infos + ): + self.assertEqual(ref_ranks, ranks) + # Cannot check directly for mesh equality since parent meshes are not + # the same since the ref's parent mesh is 3D + self.assertEqual(dp_mesh["dp_replicate"].mesh, ref_mesh["dp_replicate"].mesh) + for (_, ref_ranks, _), (_, ranks, _) in zip( + dp_mesh["dp_replicate"]._dim_group_infos, + ref_mesh["dp_replicate"]._dim_group_infos, + ): + self.assertEqual(ref_ranks, ranks) + self.assertEqual(dp_mesh["dp_shard"].mesh, ref_mesh["dp_shard"].mesh) + for (_, ref_ranks, _), (_, ranks, _) in zip( + dp_mesh["dp_shard"]._dim_group_infos, ref_mesh["dp_shard"]._dim_group_infos + ): + 
self.assertEqual(ref_ranks, ranks) + class InitDeviceMeshTest(DTensorTestBase): @property diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index fdb23e3f590f3..472e9c56bae63 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1164,6 +1164,32 @@ def test_tuple_contains(a, b): return a + b return a - b + @unittest.skipIf( + sys.version_info < (3, 9), + "SET_UPDATE was added at Python 3.9", + ) + @make_test + def test_set_update_bytecode(x): + # This produces bytecode SET_UPDATE since python 3.9 + var = {"apple", "banana", "cherry"} + if isinstance(var, set): + return x + 1 + else: + return x - 1 + + @unittest.skipIf( + sys.version_info < (3, 9), + "SET_UPDATE was added at Python 3.9", + ) + @make_test + def test_set_update_list_with_duplicated_items(x): + list1 = ["apple", "banana", "apple"] + list2 = ["orange", "banana"] + if len({*list1, *list2}) == 3: + return x + 1 + else: + return x - 1 + @make_test def test_set_contains(a, b): vals = set(["a", "b", "c"]) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index d28b67f3aa940..96bf924e09990 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -33,6 +33,7 @@ from torch import nn from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import CompileCounter, rand_strided, same +from torch._inductor.utils import fresh_inductor_cache from torch.nn import functional as F from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION @@ -400,7 +401,7 @@ def _iter_ex(self, resolve: bool) -> Iterator[Any]: try: return ListConfig.ListIterator(self, resolve) except Exception: - raise AssertionError + raise AssertionError from None def __init__(self): self._content = [ @@ -4971,6 +4972,22 @@ def fn(x): opt_fn = torch.compile(fn, backend="eager") opt_fn(np.ones([3, 3])) + def test_issue126128(self): + def fn(): + x = torch.randn(1, 10) + y = torch.randn(10, 1) + return torch.mm(x, y).sum() + + def fn2(): + x = torch.randn(10, 100) + y = torch.randn(100, 10) + return torch.mm(x, y).sum() + + with fresh_inductor_cache(): + torch.compile(fn)() + + torch.compile(fn2)() + instantiate_parametrized_tests(ReproTests) diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 9b49f5ff8bb6a..cb47a0b728a37 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -394,6 +394,36 @@ def f4(v): self.assertEqual(f3(r), optimize(f3)(r)) self.assertEqual(f4(r), optimize(f4)(r)) + def test_to_tensor(self): + def f1(): + a = np.random.uniform(low=-1, high=1, size=(20, 1)) + return torch.tensor([a, a, a, a], dtype=torch.float64, device="cpu") + + def f2(): + a = torch.tensor([[[123]]]) + return torch.tensor([a, a]) + + def f3(): + a = torch.tensor(123) + return torch.tensor([a, a]) + + def f4(): + a = torch.tensor(123) + b = torch.tensor([[[456]]]) + return torch.tensor([a, b]) + + def f5(): + a = np.array([1, 2]) + return torch.tensor([a, a]) + + optimize = torch.compile(backend="aot_eager", fullgraph=True) + + self.assertEqual(f1().shape, optimize(f1)().shape) + self.assertEqual(f2(), optimize(f2)()) + self.assertEqual(f3(), optimize(f3)()) + self.assertEqual(f4(), optimize(f4)()) + self.assertEqual(f5(), optimize(f5)()) + def test_sym_int_conversion(self): def f(x): y = x.size(0) diff --git a/test/dynamo_expected_failures/TestExamplesCorrectnessCPU.test_maml_regression_mechanism_make_functional_cpu 
b/test/dynamo_expected_failures/TestExamplesCorrectnessCPU.test_maml_regression_mechanism_make_functional_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExamplesCorrectnessCUDA.test_maml_regression_mechanism_functional_call_cuda b/test/dynamo_expected_failures/TestExamplesCorrectnessCUDA.test_maml_regression_mechanism_functional_call_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExamplesCorrectnessCUDA.test_maml_regression_mechanism_make_functional_cuda b/test/dynamo_expected_failures/TestExamplesCorrectnessCUDA.test_maml_regression_mechanism_make_functional_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestContentStoreCPU.test_repeated_hash_cpu b/test/dynamo_expected_failures/TestQuantizePT2E.test_multi_users_without_output_observer similarity index 100% rename from test/dynamo_expected_failures/TestContentStoreCPU.test_repeated_hash_cpu rename to test/dynamo_expected_failures/TestQuantizePT2E.test_multi_users_without_output_observer diff --git a/test/dynamo_expected_failures/TestExamplesCorrectnessCPU.test_maml_regression_mechanism_functional_call_cpu b/test/dynamo_expected_failures/TestSubclassSerialization.test_allowlist_for_weights_only similarity index 100% rename from test/dynamo_expected_failures/TestExamplesCorrectnessCPU.test_maml_regression_mechanism_functional_call_cpu rename to test/dynamo_expected_failures/TestSubclassSerialization.test_allowlist_for_weights_only diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 2d7e88bfc111d..b343dbff27a79 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -8,6 +8,7 @@ from torch._export.wrappers import _mark_strict_experimental from torch._functorch.aot_autograd import aot_export_module +from torch.export._trace import _convert_ts_to_export_experimental from torch.testing import FileCheck @@ -106,6 +107,36 @@ def forward(self, x): ): ep = torch.export.export(M(), inp, strict=False) + def test_torchscript_module_export(self): + class M(torch.nn.Module): + def forward(self, x): + return x.cos() + x.sin() + + model_to_trace = M() + inps = (torch.randn(4, 4),) + traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps) + + exported_module = _convert_ts_to_export_experimental( + traced_module_by_torchscript, inps + ) + + self.assertTrue(torch.allclose(exported_module(*inps), model_to_trace(*inps))) + + def test_torchscript_module_export_single_input(self): + class M(torch.nn.Module): + def forward(self, x): + return x.cos() + x.sin() + + model_to_trace = M() + inps = torch.randn(4, 4) + traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps) + + exported_module = _convert_ts_to_export_experimental( + traced_module_by_torchscript, inps + ) + + self.assertTrue(torch.allclose(exported_module(inps), model_to_trace(inps))) + if __name__ == "__main__": run_tests() diff --git a/test/export/test_export.py b/test/export/test_export.py index cec463fa3dc0e..406e1f55dd804 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1522,7 +1522,6 @@ def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs): self._test_export_same_as_eager(kw_func, args, kwargs) @testing.expectedFailureSerDer # we don't save placeholder metadata - @testing.expectedFailureSerDerPreDispatch @testing.expectedFailureNonStrict def test_linear_conv(self): class 
MyLinear(torch.nn.Module): @@ -2902,7 +2901,6 @@ def forward(self, xs, y): ) @testing.expectedFailureSerDer # We don't preserve metadata on graph module - @testing.expectedFailureSerDerPreDispatch @testing.expectedFailureNonStrict def test_retrace_graph_level_meta_preservation(self): class Foo(torch.nn.Module): @@ -3692,7 +3690,6 @@ def forward(self, q, k, v): self.assertEqual(ep.module()(*inputs), m(*inputs)) @testing.expectedFailureSerDer # symfloat nyi - @testing.expectedFailureSerDerPreDispatch # symfloat nyi def test_sym_sqrt(self): import math diff --git a/test/export/test_export_predispatch.py b/test/export/test_export_predispatch.py deleted file mode 100644 index 2075cba58ca67..0000000000000 --- a/test/export/test_export_predispatch.py +++ /dev/null @@ -1,50 +0,0 @@ -# Owner(s): ["oncall: export"] - -try: - from . import test_export, testing -except ImportError: - import test_export - import testing -from torch.export._trace import _export - -test_classes = {} - - -def mocked_predispatch_export(*args, **kwargs): - # If user already specified strict, don't make it non-strict - ep = _export(*args, **kwargs, pre_dispatch=True) - return ep.run_decompositions() - - -def make_dynamic_cls(cls): - suffix = "_pre_dispatch" - - cls_prefix = "PreDispatchExport" - - test_class = testing.make_test_cls_with_mocked_export( - cls, - cls_prefix, - suffix, - mocked_predispatch_export, - xfail_prop="_expected_failure_pre_dispatch", - ) - - test_classes[test_class.__name__] = test_class - # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING - globals()[test_class.__name__] = test_class - test_class.__module__ = __name__ - return test_class - - -tests = [ - test_export.TestDynamismExpression, - test_export.TestExport, -] -for test in tests: - make_dynamic_cls(test) -del test - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - run_tests() diff --git a/test/export/test_serdes.py b/test/export/test_serdes.py index 253c6db818198..bd11cd7f83662 100644 --- a/test/export/test_serdes.py +++ b/test/export/test_serdes.py @@ -9,7 +9,6 @@ import testing from torch.export import export, load, save -from torch.export._trace import _export test_classes = {} @@ -23,21 +22,10 @@ def mocked_serder_export(*args, **kwargs): return loaded_ep -def mocked_serder_export_pre_dispatch(*args, **kwargs): - ep = _export(*args, **kwargs, pre_dispatch=True) - buffer = io.BytesIO() - save(ep, buffer) - buffer.seek(0) - loaded_ep = load(buffer) - return loaded_ep - - def make_dynamic_cls(cls): suffix = "_serdes" - suffix_pre_dispatch = "_serdes_pre_dispatch" cls_prefix = "SerDesExport" - cls_prefix_pre_dispatch = "SerDesExportPreDispatch" test_class = testing.make_test_cls_with_mocked_export( cls, @@ -47,21 +35,10 @@ def make_dynamic_cls(cls): xfail_prop="_expected_failure_serdes", ) - test_class_pre_dispatch = testing.make_test_cls_with_mocked_export( - cls, - cls_prefix_pre_dispatch, - suffix_pre_dispatch, - mocked_serder_export_pre_dispatch, - xfail_prop="_expected_failure_serdes_pre_dispatch", - ) - test_classes[test_class.__name__] = test_class - test_classes[test_class_pre_dispatch.__name__] = test_class_pre_dispatch # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING globals()[test_class.__name__] = test_class - globals()[test_class_pre_dispatch.__name__] = test_class_pre_dispatch test_class.__module__ = __name__ - test_class_pre_dispatch.__module__ = __name__ tests = [ diff --git a/test/export/test_tools.py b/test/export/test_tools.py new file mode 100644 index 0000000000000..b8ab7616fd679 --- 
/dev/null +++ b/test/export/test_tools.py @@ -0,0 +1,67 @@ +# Owner(s): ["oncall: export"] + +import torch +from torch._dynamo.test_case import TestCase +from torch._export.tools import report_exportability + +from torch.testing._internal.common_utils import run_tests + +torch.library.define( + "testlib::op_missing_meta", + "(Tensor(a!) x, Tensor(b!) z) -> Tensor", + tags=torch.Tag.pt2_compliant_tag, +) + + +@torch.library.impl("testlib::op_missing_meta", "cpu") +@torch._dynamo.disable +def op_missing_meta(x, z): + x.add_(5) + z.add_(5) + return x + z + + +class TestExportTools(TestCase): + def test_report_exportability_basic(self): + class Module(torch.nn.Module): + def forward(self, x, y): + return x[0] + y + + f = Module() + inp = ([torch.ones(1, 3)], torch.ones(1, 3)) + + report = report_exportability(f, inp) + self.assertTrue(len(report) == 1) + self.assertTrue(report[""] is None) + + def test_report_exportability_with_issues(self): + class Unsupported(torch.nn.Module): + def forward(self, x): + return torch.ops.testlib.op_missing_meta(x, x.cos()) + + class Supported(torch.nn.Module): + def forward(self, x): + return x.sin() + + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.unsupported = Unsupported() + self.supported = Supported() + + def forward(self, x): + y = torch.nonzero(x) + return self.unsupported(y) + self.supported(y) + + f = Module() + inp = (torch.ones(4, 4),) + + report = report_exportability(f, inp, strict=False, pre_dispatch=True) + + self.assertTrue(report[""] is not None) + self.assertTrue(report["unsupported"] is not None) + self.assertTrue(report["supported"] is None) + + +if __name__ == "__main__": + run_tests() diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 285e410a79edc..81b85a4fe42f9 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -140,6 +140,8 @@ ("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv2d_pointwise.binary", datetime.date(2024, 12, 31)), + # BC-breaking change in can_cast signature: 'from' -> 'from_' + ("aten::can_cast", datetime.date(2024, 5, 31)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index ffa71a7e905b5..cfbd96e7368d9 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -4835,70 +4835,6 @@ def f(a, b, c, d): self.assertEqual(get_num_ins_outs(fw_graph), (4, 2)) self.assertEqual(get_num_ins_outs(bw_graph), (2, 4)) - @unittest.skipIf(not USE_NETWORKX, "networkx not available") - def test_min_cut_partitioner_recomputable_ops(self): - def f(x): - return x * x * x - - recomputable_ops = [] - partition_fn = partial( - min_cut_rematerialization_partition, recomputable_ops=recomputable_ops - ) - - fw_graph, bw_graph = get_fw_bw_graph( - f, [torch.randn(3, requires_grad=True)], partition_fn - ) - # Expected forward graph: - # opcode name target args kwargs - # ------------- --------- --------------- -------------------------- -------- - # placeholder primals_1 primals_1 () {} - # call_function mul aten.mul.Tensor (primals_1, primals_1) {} - # call_function mul_1 aten.mul.Tensor (mul, primals_1) {} - # output output output ([mul_1, primals_1, mul],) {} - 
self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) - # Expected backward graph: - # opcode name target args kwargs - # ------------- ---------- --------------- ----------------------- -------- - # placeholder primals_1 primals_1 () {} - # placeholder mul mul () {} - # placeholder tangents_1 tangents_1 () {} - # call_function mul_2 aten.mul.Tensor (tangents_1, mul) {} - # call_function mul_3 aten.mul.Tensor (tangents_1, primals_1) {} - # call_function mul_4 aten.mul.Tensor (mul_3, primals_1) {} - # call_function add aten.add.Tensor (mul_2, mul_4) {} - # call_function add_1 aten.add.Tensor (add, mul_4) {} - # output output output ([add_1],) {} - self.assertEqual(get_num_ins_outs(bw_graph), (3, 1)) - - recomputable_ops = [torch.ops.aten.mul] - partition_fn = partial( - min_cut_rematerialization_partition, recomputable_ops=recomputable_ops - ) - fw_graph, bw_graph = get_fw_bw_graph( - f, [torch.randn(3, requires_grad=True)], partition_fn - ) - # Expected forward graph: - # opcode name target args kwargs - # ------------- --------- --------------- ---------------------- -------- - # placeholder primals_1 primals_1 () {} - # call_function mul aten.mul.Tensor (primals_1, primals_1) {} - # call_function mul_1 aten.mul.Tensor (mul, primals_1) {} - # output output output ([mul_1, primals_1],) {} - self.assertEqual(get_num_ins_outs(fw_graph), (1, 2)) - # Expected backward graph: - # opcode name target args kwargs - # ------------- ---------- --------------- ----------------------- -------- - # placeholder primals_1 primals_1 () {} - # placeholder tangents_1 tangents_1 () {} - # call_function mul aten.mul.Tensor (primals_1, primals_1) {} # RECOMPUTED - # call_function mul_2 aten.mul.Tensor (tangents_1, mul) {} - # call_function mul_3 aten.mul.Tensor (tangents_1, primals_1) {} - # call_function mul_4 aten.mul.Tensor (mul_3, primals_1) {} - # call_function add aten.add.Tensor (mul_2, mul_4) {} - # call_function add_1 aten.add.Tensor (add, mul_4) {} - # output output output ([add_1],) {} - self.assertEqual(get_num_ins_outs(bw_graph), (2, 1)) - def test_contiguous(self): # The test simulates the condition where transpose followed by view # happens in the backward pass. 
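
For context on the partitioner test removed above: a minimal sketch (not part of this patch) of how min_cut_rematerialization_partition is typically wired into aot_function. It uses only symbols still exported from functorch.compile per the import change earlier in this diff; the recomputable_ops variant exercised by the deleted test is intentionally omitted, nop_compiler is a hypothetical helper name, and the min-cut partitioner needs networkx at runtime.

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition


def f(x):
    return x * x * x


def nop_compiler(gm, example_inputs):
    # Inspect the captured forward/backward graph, then run it unchanged.
    print(gm.code)
    return gm


compiled_f = aot_function(
    f,
    fw_compiler=nop_compiler,
    bw_compiler=nop_compiler,
    partition_fn=min_cut_rematerialization_partition,
)
out = compiled_f(torch.randn(3, requires_grad=True))
out.sum().backward()
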
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 074d075fc848c..87299d796f6c7 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] import functools +import logging import re import sys import unittest @@ -51,6 +52,14 @@ def hook3(gI, gO): class TestCompiledAutograd(TestCase): + def setUp(self) -> None: + super().setUp() + compiled_autograd.reset() + + def tearDown(self) -> None: + super().tearDown() + compiled_autograd.reset() + def check_output_and_recompiles( self, fn, count=1, compiler_fn=compiler_fn, compile_fn=False ): @@ -322,6 +331,7 @@ def bytecode_hook(code, out_code): handle.remove() def test_inputs_aliasing_bytecode_stack_restore(self): + logging.getLogger().setLevel(logging.WARNING) from torch.testing._internal.logging_tensor import LoggingTensor # Create a graph that allows inputs stealing @@ -752,6 +762,52 @@ def backward(ctx, gO_1, gO_2, gO_3): self.check_output_and_recompiles(fn, count=2) + @unittest.skipIf(not HAS_CUDA, "requires cuda") + def test_logging_tensor_flaky(self) -> None: + # when you first run some test using triton and then run test_inputs_aliasing_bytecode_stack_restore + # resulting in: + # - pytest: `TypeError: unsupported operand type(s) for +: 'Tensor' and 'LoggingTensor'` + # - python: `TypeError: not all arguments converted during string formatting` + + # 1. some triton involving test + def fn(): + def _fn(x): + return x + + x = torch.arange( + 1, 10, requires_grad=True, dtype=torch.float16, device="cuda" + ) + out = _fn(x) + loss = out.sum() + loss.backward() + + with compiled_autograd.enable(compiler_fn): + fn() + + logging.getLogger().setLevel( + logging.WARNING + ) # triton setup overwrote it to INFO + # 2. 
test_inputs_aliasing_bytecode_stack_restore + from torch.testing._internal.logging_tensor import LoggingTensor + + def forward(inputs): + add = inputs[0] + 1 + add_1 = add + inputs[1] + out = add_1.cpu() + return (out,) + + gm = torch.fx.symbolic_trace(forward) + print(gm.print_readable()) + torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"]) + compiled_fn = torch.compile(gm) + + inputs = [ + torch.ones(1000000, dtype=torch.float32), + LoggingTensor(torch.ones(1)), + ] + + compiled_fn(inputs) + @unittest.skipIf(not HAS_CUDA, "requires cuda") def test_custom_fn_output_metadata(self): def my_compiler_fn(gm): @@ -1477,64 +1533,62 @@ def fn(): ) def test_verbose_logs_cpp(self): - script = """ -import torch + torch._logging.set_logs(compiled_autograd_verbose=True) -def compiler_fn(gm): - return torch.compile(gm, backend="eager") + def fn(): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.ReLU(), + torch.nn.Linear(4, 4), + torch.nn.ReLU(), + ) + for i in [10, 11, 12]: + model.zero_grad() + x = torch.randn([i, 4]) + result = model(x).sum() + result.backward() + yield model[0].weight.grad + yield model[0].bias.grad + yield model[2].weight.grad + yield model[2].bias.grad -def main(): - torch._logging.set_logs(compiled_autograd_verbose=True) - model = torch.nn.Sequential( - torch.nn.Linear(4, 4), - torch.nn.ReLU(), - torch.nn.Linear(4, 4), - torch.nn.ReLU(), - ) + logs, ctx = logs_to_string( + torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose" + ) + with ctx(): + self.check_output_and_recompiles(fn, count=2) - for i in range(10, 100): - x = torch.randn([i, 4]) - result = model(x).sum() - with torch._dynamo.compiled_autograd.enable(compiler_fn): - result.backward() + patterns1 = [ + r".*Cache miss due to new autograd node: torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), " + r"previous key sizes=\[\]\n", + ] -main() -""" - stdout, _ = self.run_process_no_exception(script) - stdout = stdout.decode("utf-8") - - patterns = [ - r"\[python_compiled_autograd.cpp\] Creating cache entry for SumBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for ReluBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for AddmmBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for TBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for ReluBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for AddmmBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for TBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for ReluBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for AddmmBackward0, with key of size (\d+)\n", - 
r"\[python_compiled_autograd.cpp\] Creating cache entry for TBackward0, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] Creating cache entry for torch::autograd::AccumulateGrad, with key of size (\d+)\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", - r"\[python_compiled_autograd.cpp\] cache miss: marking sizes\[(\d+)\] as dynamic\n", + # recompile + patterns2 = [ + r".*Cache miss due to changed shapes: marking size idx (\d+) of torch::autograd::GraphRoot \(NodeCall 0\) as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of SumBackward0 \(NodeCall 1\) as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of SumBackward0 \(NodeCall 1\) as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of ReluBackward0 \(NodeCall 2\) as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of AddmmBackward0 \(NodeCall 3\) as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of torch::autograd::AccumulateGrad " + r"\(NodeCall 5\) as dynamic\n", + r".*Cache miss due to changed shapes: marking size idx (\d+) of ReluBackward0 \(NodeCall 6\) as dynamic\n", ] - pattern = r"".join(patterns) - matches = re.findall(pattern, stdout) - self.assertEqual(len(matches), 1) - self.assertEqual(len(matches[0]), len(patterns)) + all_logs = logs.getvalue() + + pattern1 = r"".join(patterns1) + matches1 = re.findall(pattern1, all_logs) + self.assertEqual(len(matches1), 1) + assert isinstance( + matches1[0], str + ) # for a single match: matches1=['match'], for multiple matches: matches1=[('match1', 'match2')]... 
+ self.assertEqual(len(matches1), len(patterns1)) + + pattern2 = r"".join(patterns2) + matches2 = re.findall(pattern2, all_logs) + self.assertEqual(len(matches2), 1) + self.assertEqual(len(matches2[0]), len(patterns2)) def test_snapshot_verbose_logs_flag(self): def fn(): diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index b2d0ed91809f9..7100837e9b92f 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -136,6 +136,8 @@ class KernelCounts(NamedTuple): "test_sgd_momentum_foreach_cuda": 5, "test_sgd_weight_decay_maximize_cuda": 4, "test_sgd_weight_decay_maximize_cpu": 4, + "test_sgd_weight_decay_cpu": 4, + "test_sgd_weight_decay_cuda": 4, "test_sgd_momentum_weight_decay_foreach_cuda": 2, "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2, "test_sgd_cuda": 4, diff --git a/test/inductor/test_coordinate_descent_tuner.py b/test/inductor/test_coordinate_descent_tuner.py index 70618c06e9ec6..fdd3abb143927 100644 --- a/test/inductor/test_coordinate_descent_tuner.py +++ b/test/inductor/test_coordinate_descent_tuner.py @@ -16,7 +16,7 @@ except ImportError: if __name__ == "__main__": sys.exit(0) - raise unittest.SkipTest("requires triton") # noqa: TRY200 + raise unittest.SkipTest("requires triton") # noqa: B904 from torch._inductor import config from torch._inductor.runtime.coordinate_descent_tuner import CoordescTuner diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index b8fdbc49bd387..0888f3ad47a10 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -71,7 +71,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): if config.abi_compatible: xfail_list = [ - "test_bernoulli1_cpu", # cpp fallback op naming issue "test_conv2d_binary_inplace_fusion_failed_cpu", "test_conv2d_binary_inplace_fusion_pass_cpu", "test_dynamic_qlinear_cpu", @@ -297,6 +296,24 @@ class BaseTest(NamedTuple): test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), ), + BaseTest( + "test_qlinear_gelu", + "cpu", + test_mkldnn_pattern_matcher.TestPatternMatcher(), + condition=torch.backends.mkldnn.is_available(), + ), + BaseTest( + "test_qlinear_add", + "cpu", + test_mkldnn_pattern_matcher.TestPatternMatcher(), + condition=torch.backends.mkldnn.is_available(), + ), + BaseTest( + "test_qlinear_add_relu", + "cpu", + test_mkldnn_pattern_matcher.TestPatternMatcher(), + condition=torch.backends.mkldnn.is_available(), + ), BaseTest( "test_qlinear_dequant_promotion", "cpu", diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py deleted file mode 100644 index 505ae2f69a4ed..0000000000000 --- a/test/inductor/test_cpu_select_algorithm.py +++ /dev/null @@ -1,265 +0,0 @@ -# Owner(s): ["oncall: cpu inductor"] -import functools - -import sys -import unittest -from unittest.mock import patch - -import torch -import torch._dynamo.config -import torch._dynamo.config as dynamo_config -import torch._inductor.config as inductor_config -import torch._inductor.select_algorithm as select_algorithm -from torch._dynamo.utils import counters -from torch._inductor.test_case import run_tests, TestCase -from torch.testing._internal.common_device_type import ( - dtypes, - instantiate_device_type_tests, -) - -from torch.testing._internal.common_utils import IS_MACOS, parametrize, TEST_MKL - -try: - try: - from . 
import test_torchinductor - except ImportError: - import test_torchinductor -except unittest.SkipTest: - if __name__ == "__main__": - sys.exit(0) - raise - -check_model = test_torchinductor.check_model - -aten = torch.ops.aten - - -def patches(fn): - def skip_cache(self, choices, name, key, benchmark): - if benchmark is None: - return {} - timings = benchmark(choices) - for choice, timing in timings.items(): - if isinstance(choice, select_algorithm.ExternKernelCaller): - # we intentionally make ATEN kernel slower to cover the cases - # where template kernels are always chosen with fusions applied - # and correctness checks at runtime. - timings[choice] = timing * 1000 - return timings - - for patcher in [ - dynamo_config.patch(verbose=True), - inductor_config.patch( - debug=True, - max_autotune=True, - epilogue_fusion=True, - max_autotune_gemm_backends="CPP,ATEN", - ), - patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)), - patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache), - ]: - fn = patcher(fn) - - @functools.wraps(fn) - def wrapped(*args, **kwargs): - counters.clear() - torch.manual_seed(12345) - return fn(*args, **kwargs) - - return wrapped - - -class TestSelectAlgorithm(TestCase): - common = check_model - - @inductor_config.patch({"freezing": True}) - @patches - @torch.no_grad - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - @parametrize("batch_size", (1, 2, 1000)) - @parametrize("in_features", (1, 1000)) - @parametrize("out_features", (1, 1024)) - @parametrize("bias", (True, False)) - @parametrize("input_3d", (True, False)) - @dtypes(torch.float, torch.bfloat16, torch.half) - def test_linear_static_shapes( - self, batch_size, in_features, out_features, bias, input_3d, dtype - ): - class M(torch.nn.Module): - def __init__(self, bias): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features, bias) - - def forward(self, x): - return self.linear(x) - - counters.clear() - mod = M(bias=bias).to(dtype=dtype).eval() - B = (2, batch_size) if input_3d else (batch_size,) - v = torch.randn(*B, in_features).to(dtype=dtype) - # For bfloat16 and half, we have to relax the tolerance - # due to the difference associave orders in different - # kernel implementations - atol, rtol = 1e-4, 1e-4 - if dtype == torch.half or dtype == torch.bfloat16: - atol, rtol = 1e-2, 1e-2 - with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)): - self.common(mod, (v,), atol=atol, rtol=rtol) - self.assertEqual( - counters["inductor"]["select_algorithm_autotune"], - 1 if out_features != 1 else 0, - ) - - @inductor_config.patch({"freezing": True}) - @patches - @torch.no_grad - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - @parametrize("bias", (True, False)) - @dtypes(torch.float) - def test_linear_input_transpose(self, bias, dtype): - batch_size = 384 - in_features = 196 - out_features = 384 - - class M(torch.nn.Module): - def __init__(self, bias): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features, bias) - - @torch.compile - def forward(self, x): - return self.linear(x) - - counters.clear() - mod = M(bias=bias).to(dtype=dtype).eval() - v = torch.randn(in_features, batch_size).to(dtype=dtype) - self.common(mod, (v.transpose(0, 1),)) - # TODO(jgong5): support transposed input - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0) - - @inductor_config.patch({"freezing": True}) - @patches - @torch.no_grad - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - @parametrize("bias", 
(True, False)) - @parametrize( - "epilogue", - ( - "relu", - "gelu", - "silu", - "sigmoid", - "tanh", - "hardswish", - "hardsigmoid", - "leaky_relu", - "hardtanh", - "add", - "sub", - "mul", - "div", - ), - ) - @dtypes(torch.float) - def test_linear_with_pointwise(self, bias, epilogue, dtype): - batch_size = 384 - in_features = 196 - out_features = 384 - - class M(torch.nn.Module): - def __init__(self, bias, epilogue, other): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features, bias) - if epilogue == "relu": - self.epilogue = torch.nn.ReLU() - elif epilogue == "gelu": - self.epilogue = torch.nn.GELU() - elif epilogue == "silu": - self.epilogue = torch.nn.SiLU() - elif epilogue == "sigmoid": - self.epilogue = torch.nn.Sigmoid() - elif epilogue == "tanh": - self.epilogue = torch.nn.Tanh() - elif epilogue == "hardswish": - self.epilogue = torch.nn.Hardswish() - elif epilogue == "hardsigmoid": - self.epilogue = torch.nn.Hardsigmoid() - elif epilogue == "leaky_relu": - self.epilogue = torch.nn.LeakyReLU() - elif epilogue == "hardtanh": - self.epilogue = torch.nn.Hardtanh() - elif epilogue == "add": - self.epilogue = lambda x: x + other - elif epilogue == "sub": - self.epilogue = lambda x: x - other - elif epilogue == "mul": - self.epilogue = lambda x: x * other - elif epilogue == "div": - self.epilogue = lambda x: x / other - - def forward(self, x): - return self.epilogue(self.linear(x)) - - counters.clear() - v = torch.randn(batch_size, in_features).to(dtype=dtype) - u = torch.randn(batch_size, out_features).to(dtype=dtype) - mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval() - self.common(mod, (v,)) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) - self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) - - @inductor_config.patch({"freezing": True}) - @patches - @torch.no_grad - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - @parametrize("bias", (True, False)) - @dtypes(torch.float) - def test_linear_with_transpose(self, bias, dtype): - batch_size = 384 - in_features = 196 - out_features = 128 - - class M(torch.nn.Module): - def __init__(self, bias): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features, bias) - - def forward(self, x, y): - return self.linear(x).transpose(0, 1) + y - - counters.clear() - mod = M(bias=bias).to(dtype=dtype).eval() - v = torch.randn(batch_size, in_features).to(dtype=dtype) - u = torch.randn(out_features, batch_size).to(dtype=dtype) - self.common(mod, (v, u)) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) - self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) - - -@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) -class _DynamicShapesTestBase(TestCase): - pass - - -class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase): - common = check_model - test_linear_dynamic_shapes = TestSelectAlgorithm.test_linear_static_shapes - test_linear_with_pointwise_dynamic_shapes = ( - TestSelectAlgorithm.test_linear_with_pointwise - ) - test_linear_with_transpose_dynamic_shapes = ( - TestSelectAlgorithm.test_linear_with_transpose - ) - - -instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu") -instantiate_device_type_tests( - TestSelectAlgorithmDynamicShapes, globals(), only_for="cpu" -) - - -if __name__ == "__main__": - from torch.testing._internal.inductor_utils import HAS_CPU - - if HAS_CPU and not IS_MACOS: - run_tests() diff --git 
a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py index 5bbe588d3a84e..5cb8af9db165a 100644 --- a/test/inductor/test_cuda_cpp_wrapper.py +++ b/test/inductor/test_cuda_cpp_wrapper.py @@ -97,7 +97,6 @@ class DynamicShapesCudaWrapperCudaTests(InductorTestCase): if config.abi_compatible: xfail_list = [ - "test_bernoulli1_cuda", # cpp fallback op naming issue "test_profiler_mark_wrapper_call_cuda", "test_scaled_dot_product_attention_cuda_dynamic_shapes", ] diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index db02d19310097..f303330bc1140 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -32,7 +32,7 @@ import triton from triton import language as tl except ImportError: - raise unittest.SkipTest("requires triton") # noqa: TRY200 + raise unittest.SkipTest("requires triton") # noqa: B904 try: from . import test_torchinductor diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 9df905d2ad547..f3a9026a3c805 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1,8 +1,9 @@ # Owner(s): ["module: inductor"] import functools +import unittest from collections import namedtuple -from typing import Callable +from typing import Callable, Optional from unittest import expectedFailure, skip, skipUnless from unittest.mock import patch @@ -58,14 +59,8 @@ def create_attention(score_mod): # --------- Useful score mod functions for testing --------- - -test_score_mods = [ - _identity, - _causal, - _rel_bias, - _rel_causal, - _generate_alibi_bias(8), -] +def _inverse_causal(score, b, h, m, n): + return torch.where(m <= n, score, float("-inf")) def _times_two(score, b, h, m, n): @@ -79,13 +74,11 @@ def _squared(score, b, h, m, n): def _head_offset(dtype: torch.dtype): - """Captured Buffer - Note: this builds a score_mod with index of a type - """ + """Captured Buffer""" head_offset = torch.rand(H, device="cuda", dtype=dtype) def score_mod(score, b, h, m, n): - return score * index(head_offset, [h]) + return score * head_offset[h] return score_mod @@ -103,20 +96,19 @@ def _trig2(score, b, h, m, n): return z -def _buffer_reduced(dtype: torch.dtype): - """Reduction in captured buffer""" - batch_offsets = torch.rand(B, 8, device="cuda", dtype=dtype) - - def score_mod(score, b, h, m, n): - batch_vals = index(batch_offsets, [b]) - return score + batch_vals.sum() - - return score_mod - +test_score_mods = [ + _identity, + _times_two, + _squared, + _causal, + _inverse_causal, + _rel_bias, + _rel_causal, + _generate_alibi_bias(8), +] captured_buffers_map = { "_head_offset": _head_offset, - "_buffer_reduced": _buffer_reduced, } B = 4 @@ -125,18 +117,35 @@ def score_mod(score, b, h, m, n): D = 64 -class TestTemplatedSDPA(InductorTestCase): - def _check_equal(self, golden_out, ref_out, compiled_out, dtype): +def query_key_value_clones( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dtype: torch.dtype = None, +): + """Clones the query, key, and value tensors and moves them to the specified dtype.""" + if dtype is None: + dtype = query.dtype + query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad) + key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad) + value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad) + return query_ref, key_ref, value_ref + + +class TestFlexAttention(InductorTestCase): + def _check_equal( + self, + golden_out: torch.Tensor, + 
ref_out: torch.Tensor, + compiled_out: torch.Tensor, + fudge_factor: float, + tensor_name: Optional[str] = None, + ): compiled_error = (golden_out - compiled_out).abs().mean() ref_error = (golden_out - ref_out).abs().mean() - # Note, it seems like we really are less accurate than the float32 - # computation, likely due to the online softmax - if dtype == torch.float32: - fudge_factor = 10.0 - else: - fudge_factor = 1.1 if compiled_error > ref_error * fudge_factor: - msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." + name = tensor_name if tensor_name is not None else "" + msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." self.assertTrue(False, msg) def run_test( @@ -150,15 +159,45 @@ def run_test( ): sdpa_partial = create_attention(score_mod) compiled_sdpa = torch.compile(sdpa_partial) - q = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - k = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - v = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - golden_out = sdpa_partial( - q.to(torch.float64), k.to(torch.float64), v.to(torch.float64) - ) - ref_out = sdpa_partial(q, k, v) + q = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) + q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) + golden_out = sdpa_partial(q_gold, k_gold, v_gold) + ref_out = sdpa_partial(q_ref, k_ref, v_ref) compiled_out = compiled_sdpa(q, k, v) - self._check_equal(golden_out, ref_out, compiled_out, dtype) + + backward_grad = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + + golden_out.backward(backward_grad.to(torch.float64)) + ref_out.backward(backward_grad) + compiled_out.backward(backward_grad) + + with torch.no_grad(): + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + + # Check output + self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out") + + # Check gradients + q_fudge_factor = 2.5 * fudge_factor + self._check_equal( + q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query" + ) + k_fudge_factor = 4 * fudge_factor + self._check_equal( + k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" + ) + v_fudge_factor = 8 * fudge_factor + self._check_equal( + v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" + ) def run_dynamic_test( self, @@ -196,12 +235,20 @@ def run_dynamic_test( # Compiling with dynamic shape in the first batch. compiled_sdpa = torch.compile(sdpa_partial, dynamic=True) compiled_out1 = compiled_sdpa(q1, k1, v1) - self._check_equal(golden_out1, ref_out1, compiled_out1, dtype) + + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + + self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) # No re-compilation, use the compiled dynamic shape version.
compiled_out2 = compiled_sdpa(q2, k2, v2) - self._check_equal(golden_out2, ref_out2, compiled_out2, dtype) + self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) def run_automatic_dynamic_test( @@ -251,20 +298,28 @@ def run_automatic_dynamic_test( # 2, the second batch is compiled with dynamic shape # 3, no re-compilation in the third batch torch._dynamo.reset() + + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + # The first batch. compiled_sdpa = torch.compile(sdpa_partial) compiled_out1 = compiled_sdpa(q1, k1, v1) - self._check_equal(golden_out1, ref_out1, compiled_out1, dtype) + self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) # The second batch (automatic dynamic). compiled_out2 = compiled_sdpa(q2, k2, v2) - self._check_equal(golden_out2, ref_out2, compiled_out2, dtype) + self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2) # The third batch (no re-compilation). compiled_out3 = compiled_sdpa(q3, k3, v3) - self._check_equal(golden_out3, ref_out3, compiled_out3, dtype) + self._check_equal(golden_out3, ref_out3, compiled_out3, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2) @supported_platform @@ -318,6 +373,21 @@ def score_mod(score, b, h, m, n): self.run_test(score_mod, dtype) + @supported_platform + @common_utils.parametrize("dtype", test_dtypes) + def test_captured_buffers_all_dims(self, dtype: torch.dtype): + head_scale = torch.randn(H, device="cuda") + batch_scale = torch.randn(B, device="cuda") + tok_scale = torch.randn(S, device="cuda") + + def all_bias(score, batch, head, token_q, token_kv): + score = score + tok_scale[token_q] + score = score + batch_scale[batch] + score = score + head_scale[head] + return score + + self.run_test(all_bias, dtype) + @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) def test_seq_masking(self, dtype): @@ -422,7 +492,7 @@ def score_mod_func(score, b, h, q, kv): make_tensor = functools.partial( torch.randn, - (2, 2, 8, 4), + (2, 2, 128, 4), device="cuda", dtype=torch.float64, requires_grad=True, @@ -458,6 +528,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) + @unittest.skip("Silu decomp failing for full in backwards") def test_silu_on_score(self, dtype): def silu_score(score, b, h, q, kv): return torch.nn.functional.silu(score) @@ -597,23 +668,6 @@ def njt_score_mod(qk, b, h, q, kv): self.run_test(causal_njt, dtype) - @supported_platform - def test_backwards_fails(self): - make_tensor = functools.partial( - torch.randn, - (B, H, S, D), - dtype=torch.float32, - device="cuda", - requires_grad=True, - ) - q, k, v = make_tensor(), make_tensor(), make_tensor() - func = torch.compile(_flex_attention, backend="inductor", fullgraph=True) - with self.assertRaisesRegex( - AssertionError, "flex_attention_backward is not an OpOverload" - ): - out = func(q, k, v, _identity) - out.backward(torch.ones_like(out)) - @supported_platform def test_mixed_dtypes_fails(self): query = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda") @@ -641,6 +695,7 @@ def score_mod(score, b, h, m, n): self.run_test(score_mod) 
@supported_platform + @skip("TODO: Figure out why this is erroring") @patch.object(torch._inductor.config, "max_autotune", True) def test_max_autotune_with_captured(self): head_scale = torch.randn(H, device="cuda") @@ -776,7 +831,7 @@ def test_aot_eager_gradcheck(self, score_mod): ) @supported_platform - @common_utils.parametrize("score_mod_name", ["_head_offset", "_buffer_reduced"]) + @common_utils.parametrize("score_mod_name", ["_head_offset"]) @common_utils.parametrize("mode", ["eager", "aot_eager"]) def test_captured_score_mod_aot_eager_gradcheck( self, score_mod_name: str, mode: str @@ -864,13 +919,10 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): joint_graph, """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", """ - + """alias_5: "f64[2, 2, 8, 4]", alias_7: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): + def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", alias_3: "f64[2, 2, 8, 4]", alias_5: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): fw_graph = self.fw_graph joint_graph = self.joint_graph - flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, """ - + """primals_3, alias_5, alias_7, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_5 """ - + """= alias_7 = tangents_1 = fw_graph = joint_graph = None + flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, alias_3, alias_5, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_3 = alias_5 = tangents_1 = fw_graph = joint_graph = None getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0] getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1] getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None @@ -888,11 +940,11 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3 mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None return [add, None, None, None, None] -""", +""", # noqa: B950 ) -common_utils.instantiate_parametrized_tests(TestTemplatedSDPA) +common_utils.instantiate_parametrized_tests(TestFlexAttention) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index 1bd546e5b4dfb..1ec1dd9f89e95 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -2,6 +2,8 @@ import sys +import unittest + from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, skipIfRocm from torch.testing._internal.inductor_utils import HAS_CUDA @@ -13,14 +15,12 @@ sys.exit(0) raise unittest.SkipTest("requires sympy/functorch/filelock") # noqa: F821 -import unittest - import torch -from test_torchinductor import run_and_get_cpp_code from torch._C import FileCheck from torch._dynamo.utils import same from torch._inductor import config from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import run_and_get_cpp_code from torch.export import Dim from torch.utils._triton import has_triton diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index cbf9dd89c506b..756de35df84cf 100644 --- 
a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -1589,6 +1589,144 @@ def test_qlinear_gelu_int8_mixed_bf16(self): (torch.randn((2, 4)),), gelu, int8_mixed_bf16=True ) + def _qlinear_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False): + r""" + This testcase will quantize two consecutive Linear->Add(->relu) patterns as: + X + / \ + linear(X) linear(X) + \ / + Add + | + Optional(relu) + / \ + linear(X) linear(X) + \ / + Add + | + Optional(relu) + | + Y + """ + + def fake_quant(x): + # to produce a float32 result as extra input + qlib = torch.ops.quantized_decomposed + x = qlib.quantize_per_tensor.default(x, 0.0166785, 42, 0, 255, torch.uint8) + x = qlib.dequantize_per_tensor.default( + x, 0.0166785, 42, 0, 255, torch.uint8 + ) + return x + + class M(torch.nn.Module): + def __init__( + self, + add_fn, + use_relu, + fake_quant_before_extra_input, + ): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + self.add_fn = add_fn + self.relu = torch.nn.ReLU() + self.linear3 = torch.nn.Linear(4, 4) + self.linear4 = torch.nn.Linear(4, 4) + self.add_fn2 = add_fn + self.relu2 = torch.nn.ReLU() + self.use_relu = use_relu + self.fake_quant_before_extra_input = fake_quant_before_extra_input + + def forward(self, x): + x1 = self.linear1(x) + x2 = self.linear2(x) + if self.fake_quant_before_extra_input: + x2 = fake_quant(x2) + tmp = self.add_fn(x1, x2) + if self.use_relu: + tmp = self.relu(tmp) + tmp1 = self.linear3(tmp) + tmp2 = self.linear4(tmp) + if self.fake_quant_before_extra_input: + tmp2 = fake_quant(tmp2) + res = self.add_fn2(tmp1, tmp2) + if self.use_relu: + res = self.relu2(res) + return res + + add_fn_list = [ + lambda x, y: x + y, + lambda x, y: y + x, + lambda x, y: x.add_(y), + lambda x, y: y.add_(x), + ] + fake_quant_x2_list = [False, True] if int8_mixed_bf16 else [False] + cases = itertools.product(add_fn_list, fake_quant_x2_list) + for add_fn, fq_x2 in cases: + mod = M(add_fn, use_relu, fq_x2).eval() + v = torch.randn((4, 4), dtype=torch.float32, requires_grad=False).add(1) + + def matcher_check_fn(): + # 1. Dequant-linear pattern matched in quantization weight prepack * 4 + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 4 + ) + # pattern = [dequant_per_tensor, (convert_dtype), dequant_per_channel, (convert_dtype), permute, addmm] + nodes_per_match = 6 if int8_mixed_bf16 else 4 + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 4 * nodes_per_match, + ) + # 2. 
Qlinear Binary Unary fusion in post-grad fusion pass * 2 + self.assertEqual( + counters["inductor"]["qlinear_binary_matcher_count"], 2 + ) + # Two linear-binary patterns are matched + # matched pattern1 = [qlinear, add, (convert dtype), (relu), quantize_per_tensor] + # matched pattern2 = [qlinear, add, (convert dtype), (relu)] + # If add_fn is x.add_(y), x is bf16 and y is fp32, there is a to_bf16 node after binary + to_bf16_after_binary = 2 * (add_fn == add_fn_list[2] and fq_x2) + self.assertEqual( + counters["inductor"]["qlinear_binary_matcher_nodes"], + 5 + 2 * use_relu + to_bf16_after_binary, + ) + + for is_qat in [False, True]: + self._test_common( + mod, + (v,), + check_quantization=True, + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + matcher_check_fn=matcher_check_fn, + is_qat=is_qat, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qlinear_add_cpu(self): + self._qlinear_add_cpu_test_helper() + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfRocm + def test_qlinear_add_int8_mixed_bf16(self): + self._qlinear_add_cpu_test_helper(int8_mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qlinear_add_relu_cpu(self): + self._qlinear_add_cpu_test_helper(use_relu=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfRocm + def test_qlinear_add_relu_int8_mixed_bf16(self): + self._qlinear_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True) + def _qlinear_dequant_promotion_cpu_test_helper( self, inputs, diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py index b16e5e5d62edf..bb37368f95676 100644 --- a/test/inductor/test_pad_mm.py +++ b/test/inductor/test_pad_mm.py @@ -169,6 +169,16 @@ def forward(self, a, b): res2, (code,) = run_and_get_code(compiled_fn, a, b) self.assertEqual(res1, res2) + @inductor_config.patch(force_shape_pad=True) + def test_zero_dim(self): + def addmm(x, a, b): + return torch.addmm(x, a, b) + + x = torch.randn(100).cuda() + a = torch.randn(0, 10).cuda() + b = torch.randn(10, 100).cuda() + self.assertEqual(torch.compile(addmm)(x, a, b), addmm(x, a, b)) + @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON") def test_pad_bmm_dyn_b(self): B = 10 diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 3a7b66d660658..1201e68f277e6 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -36,6 +36,8 @@ expectedFailureCodegenDynamic, rand_strided, same, + skipIfPy312, + xfailIfPy312, ) from torch._inductor.codegen.common import DataTypePropagation, OptimizationContext from torch._inductor.fx_passes import pad_mm @@ -46,6 +48,7 @@ aoti_eager_cache_dir, load_aoti_eager_cache, run_and_get_code, + run_and_get_cpp_code, run_and_get_triton_code, ) from torch._inductor.virtualized import V @@ -342,29 +345,6 @@ def clone_preserve_strides(x, device=None): return out -def run_and_get_cpp_code(fn, *args, **kwargs): - # We use the patch context manager instead of using it as a decorator. - # In this way, we can ensure that the attribute is patched and unpatched correctly - # even if this run_and_get_cpp_code function is called multiple times.
- with patch.object(config, "debug", True): - torch._dynamo.reset() - import io - import logging - - log_capture_string = io.StringIO() - ch = logging.StreamHandler(log_capture_string) - from torch._inductor.graph import output_code_log - - output_code_log.addHandler(ch) - prev_level = output_code_log.level - output_code_log.setLevel(logging.DEBUG) - result = fn(*args, **kwargs) - s = log_capture_string.getvalue() - output_code_log.setLevel(prev_level) - output_code_log.removeHandler(ch) - return result, s - - def check_model( self: TestCase, model, @@ -861,6 +841,86 @@ def fn(a): self.assertTrue(kernel_lib_path in kernel_libs_abs_path) + @skipCUDAIf(not SM80OrLater, "Requires sm80") + def test_eager_aoti_with_scalar(self): + namespace_name = "aten" + op_name = "add" + op_overload_name = "Tensor" + op_name_with_overload = f"{op_name}.{op_overload_name}" + + dispatch_key = "CPU" + device = torch.device("cpu") + if self.device.lower() == "cuda": + dispatch_key = "CUDA" + device = torch.device("cuda") + + # Test the difference between scalar tensor and scalar + a = torch.scalar_tensor(1.0, device=device) + b = torch.scalar_tensor(2.0, device=device) + + kernel_lib_path = aoti_compile_with_persistent_cache( + namespace_name, + op_name_with_overload, + a.device.type, + False, + torch.ops.aten.add, + args=(a, b), + kwargs={"alpha": 3.0}, + ) + self.assertTrue(Path(kernel_lib_path).exists()) + device_kernel_cache = aoti_eager_cache_dir(namespace_name, device.type) + kernel_conf = device_kernel_cache / f"{op_name_with_overload}.json" + self.assertTrue(kernel_conf.exists()) + json_data = load_aoti_eager_cache( + namespace_name, op_name_with_overload, a.device.type + ) + op_info = json_data[0] + self.assertTrue(isinstance(op_info, dict)) + self.assertTrue("meta_info" in op_info) + self.assertTrue(len(op_info["meta_info"]) == 3) + self.assertTrue(op_info["meta_info"][0]["sizes"] == []) + self.assertTrue(op_info["meta_info"][0]["strides"] == []) + # Scalar Tensor + self.assertTrue("scalar_value" not in op_info["meta_info"][0]) + self.assertTrue(op_info["meta_info"][1]["sizes"] == []) + self.assertTrue(op_info["meta_info"][1]["strides"] == []) + # Scalar Tensor + self.assertTrue("scalar_value" not in op_info["meta_info"][1]) + self.assertTrue(op_info["meta_info"][2]["sizes"] == []) + self.assertTrue(op_info["meta_info"][2]["strides"] == []) + # Scalar + self.assertTrue("scalar_value" in op_info["meta_info"][2]) + + with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: + a = torch.randn(128, device=device) + b = torch.randn(128, device=device) + + scalar_values = [1.0, 2.0, 3.0] + ref_values = [] + for scalar_value in scalar_values: + ref_values.append(torch.add(a, b, alpha=scalar_value)) + + qualified_op_name = f"{namespace_name}::{op_name}" + _, overload_names = torch._C._jit_get_operation(qualified_op_name) + for overload_name in overload_names: + try: + reg_op_name = qualified_op_name + schema = torch._C._get_schema(reg_op_name, overload_name) + if schema.overload_name: + reg_op_name = f"{reg_op_name}.{schema.overload_name}" + torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821 + reg_op_name, dispatch_key + ) + except Exception as e: + continue + + res_values = [] + for scalar_value in scalar_values: + res_values.append(torch.add(a, b, alpha=scalar_value)) + + self.assertEqual(len(ref_values), len(res_values)) + self.assertEqual(ref_values, res_values) + @skipCUDAIf(not SM80OrLater, "Requires sm80") def test_torch_compile_override_registration(self): dynamic = False @@ 
-2743,6 +2803,7 @@ def fn(a, b): check_lowp=False, ) + @skipIfPy312 # segfaults @config.patch(force_mixed_mm=True) def test_mixed_mm(self): def fn(a, b): @@ -2757,6 +2818,7 @@ def fn(a, b): check_lowp=True, ) + @skipIfPy312 # segfaults @config.patch(force_mixed_mm=True) def test_mixed_mm2(self): def fn(a, b, scale, bias): @@ -9448,6 +9510,7 @@ def fn(inp, offsets): self.common(fn, (inp, offsets), check_lowp=False) + @xfailIfPy312 @requires_gpu() @config.patch(assume_aligned_inputs=False) def test_config_option_dont_assume_alignment(self): diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index fdb9a8c37a47c..9bd873ac747b3 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -438,6 +438,8 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "mH", "rsub", "triu", + "cummax", + "cummin", } diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index 6375512cc1289..d8c74c0a3841a 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -14,7 +14,7 @@ except ImportError: if __name__ == "__main__": sys.exit(0) - raise unittest.SkipTest("requires triton") # noqa: TRY200 + raise unittest.SkipTest("requires triton") # noqa: B904 from torch._inductor import config from torch._inductor.runtime.hints import TRITON_MAX_BLOCK diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 43d1307fcfa58..60ce45317238e 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -188,6 +188,32 @@ def fn(x, w, a, b): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) + @skipCUDAIf(not HAS_CUDA, "requires cuda") + @dynamo_config.patch({"capture_scalar_outputs": True}) + def test_vertical_pointwise_reduction_fusion(self, device): + # Tests fusing a pointwise & reduction op with unbacked numel/rnumel. + def fn(x, y, repeats): + u0 = repeats.item() + unbacked = y.expand(u0, *y.shape) # [u0, 1, 16] + + # Note: We add x to both pointwise and reduction. Otherwise, the + # scheduler will refuse to fuse ops whose only common buffer has + # unbacked symints. 
+ pointwise = unbacked + x + reduction = torch.sum(pointwise + x) + return pointwise, reduction + + example_inputs = ( + torch.randn(32, 16).cuda(), + torch.randn(1, 16).cuda(), + torch.tensor(32).cuda(), + ) + + actual = torch.compile(fn, fullgraph=True)(*example_inputs) + expected = fn(*example_inputs) + torch.testing.assert_close(actual, expected) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + instantiate_device_type_tests( TestUnbackedSymints, globals(), only_for=(GPU_TYPE, "cpu") diff --git a/test/quantization/core/experimental/test_adaround_eager.py b/test/quantization/core/experimental/test_adaround_eager.py new file mode 100644 index 0000000000000..33a16f21bd0ff --- /dev/null +++ b/test/quantization/core/experimental/test_adaround_eager.py @@ -0,0 +1,118 @@ +# Owner(s): ["oncall: speech_infra"] + +import copy + +import torch +import torch.nn as nn +from torch.ao.quantization.experimental.adaround_optimization import ( + AdaptiveRoundingOptimizer, +) + +from torch.nn import functional as F +from torch.quantization.observer import MinMaxObserver +from torch.testing._internal.common_quantization import QuantizationTestCase + + +def forward_wrapper(fetcher): + def forward(module, input, output): + fetcher.append(input[0].detach()) + fetcher.append(output.detach()) + + return forward + + +class TestAdaround(QuantizationTestCase): + def feedforawrd_callback( + self, + model, + data, + ) -> None: + model(data) + + def run_adaround(self, model, img_data): + adaround_optimizer = AdaptiveRoundingOptimizer( + model, + self.feedforawrd_callback, + forward_wrapper, + img_data, + max_iter=100, + batch_size=10, + ) + adarounded_model = adaround_optimizer.run_adaround() + return adarounded_model + + def get_fake_quant(self, model): + hard_fake_quant_model = copy.deepcopy(model) + for _, module in hard_fake_quant_model.named_modules(): + if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)): + weight_observer = MinMaxObserver( + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + ) + weight_observer(module.weight) + scale, zero_point = weight_observer.calculate_qparams() + fake_quant_module = torch.fake_quantize_per_tensor_affine( + module.weight, + scale=scale, + zero_point=zero_point, + quant_min=-128, + quant_max=127, + ) + module.weight.data.copy_(fake_quant_module) + return hard_fake_quant_model + + def test_linear_chain(self): + class LinearChain(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(3, 4) + self.linear2 = nn.Linear(4, 5) + self.linear3 = nn.Linear(5, 6) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + + float_model = LinearChain() + img_data = [torch.rand(10, 3, dtype=torch.float) for _ in range(50)] + adarounded_model = self.run_adaround(float_model, img_data) + fq_model = self.get_fake_quant(float_model) + rand_input = torch.rand(10, 3) + with torch.no_grad(): + ada_out = adarounded_model(rand_input) + fq_out = fq_model(rand_input) + float_out = float_model(rand_input) + ada_loss = F.mse_loss(ada_out, float_out) + fq_loss = F.mse_loss(fq_out, float_out) + self.assertTrue(ada_loss.item() < fq_loss.item()) + + def test_conv_chain(self): + class ConvChain(nn.Module): + def __init__(self): + super().__init__() + self.conv2d1 = nn.Conv2d(3, 4, 5, 5) + self.conv2d2 = nn.Conv2d(4, 5, 5, 5) + self.conv2d3 = nn.Conv2d(5, 6, 5, 5) + + def forward(self, x): + x = self.conv2d1(x) + x = self.conv2d2(x) + x = self.conv2d3(x) + return x 
+ + float_model = ConvChain() + img_data = [torch.rand(10, 3, 125, 125, dtype=torch.float) for _ in range(50)] + adarounded_model = self.run_adaround(float_model, img_data) + fq_model = self.get_fake_quant(float_model) + rand_input = torch.rand(10, 3, 256, 256) + with torch.no_grad(): + ada_out = adarounded_model(rand_input) + fq_out = fq_model(rand_input) + float_out = float_model(rand_input) + ada_loss = F.mse_loss(ada_out, float_out) + fq_loss = F.mse_loss(fq_out, float_out) + self.assertTrue(ada_loss.item() < fq_loss.item()) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index b96e1ff12ac3d..75cf3c4445716 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -2278,5 +2278,46 @@ def validate(self, model: torch.fx.GraphModule) -> None: node_list, ) + def test_multi_users_without_output_observer(self): + """ + Test the case in which a node is used by multiple users, + and had its output observer removed. + """ + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + x = self.conv(x) + return x, x + 1 + + example_inputs = (torch.randn(1, 3, 5, 5),) + m = M() + m = capture_pre_autograd_graph(m, example_inputs) + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(), + ) + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + + # Remove output observer + observer_to_remove = None + for n in m.graph.nodes: + if n.op == "output": + observer_to_remove = n.args[0][0] + assert observer_to_remove.op == "call_module" + assert observer_to_remove.target.startswith("activation_post_process_") + break + assert observer_to_remove is not None + observer_to_remove.replace_all_uses_with(observer_to_remove.args[0]) + m.graph.erase_node(observer_to_remove) + m.recompile() + + # Convert should succeed + m = convert_pt2e(m) + m(*example_inputs) + instantiate_parametrized_tests(TestQuantizePT2E) diff --git a/test/run_test.py b/test/run_test.py index 5b24a00789964..71ab08199f7a7 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -181,6 +181,7 @@ def __contains__(self, item): "test_jit_legacy", "test_cuda_nvml_based_avail", "test_jit_cuda_fuser", + "distributed/_tensor/test_attention", ] XPU_BLOCKLIST = [ @@ -239,7 +240,8 @@ def __contains__(self, item): "test_native_mha", # OOM "test_module_hooks", # OOM "inductor/test_max_autotune", - "inductor/test_cutlass_backend", # slow due to many nvcc compilation steps + "inductor/test_cutlass_backend", # slow due to many nvcc compilation steps, + "inductor/test_flex_attention", # OOM ] # A subset of onnx tests that cannot run in parallel due to high memory usage. 
ONNX_SERIAL_LIST = [ @@ -406,7 +408,7 @@ def run_test( stepcurrent_key = f"{test_file}_{test_module.shard}_{os.urandom(8).hex()}" if options.verbose: - unittest_args.append(f'-{"v"*options.verbose}') # in case of pytest + unittest_args.append(f'-{"v" * options.verbose}') # in case of pytest if test_file in RUN_PARALLEL_BLOCKLIST: unittest_args = [ diff --git a/test/test_autograd.py b/test/test_autograd.py index e20e8b18ebae7..3ae37e18e7a3f 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2612,7 +2612,7 @@ def coro_no_grad(n=10): except UnrecoverableException: self.assertFalse(torch.is_grad_enabled()) - raise SecondaryException + raise SecondaryException from None @torch.enable_grad() def coro_enable_grad(n=10): @@ -2624,7 +2624,7 @@ def coro_enable_grad(n=10): except UnrecoverableException: self.assertTrue(torch.is_grad_enabled()) - raise SecondaryException + raise SecondaryException from None with torch.enable_grad(): coro = coro_no_grad() diff --git a/test/test_cuda.py b/test/test_cuda.py index 93e08eff4df6d..cc3e2380f2664 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -37,7 +37,11 @@ instantiate_device_type_tests, onlyCUDA, ) -from torch.testing._internal.common_optimizers import optim_db, optims +from torch.testing._internal.common_optimizers import ( + _get_optim_inputs_including_global_cliquey_kwargs, + optim_db, + optims, +) from torch.testing._internal.common_utils import ( freeze_rng_state, gcIfJetson, @@ -3200,111 +3204,6 @@ def _test_graphed_optimizer( for p_control, p_graphed in zip(params_control, params_graphed): self.assertEqual(p_control, p_graphed) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_optims(self): - # Needs generalization if we want to extend this test to non-Adam-like optimizers. 
- cases = ( - [ - ( - optimizer_ctor, - { - "lr": 0.1, - "betas": (0.8, 0.7), - "foreach": foreach, - "decoupled_weight_decay": decoupled_weight_decay, - "weight_decay": weight_decay, - }, - ) - for optimizer_ctor, foreach, decoupled_weight_decay, weight_decay in product( - ( - torch.optim.NAdam, - torch.optim.RAdam, - ), - ( - False, - True, - ), - ( - False, - True, - ), - ( - 0.0, - 0.1, - ), - ) - ] - + [ - ( - torch.optim.Rprop, - {"lr": 0.1, "foreach": foreach, "maximize": maximize}, - ) - for foreach, maximize in product( - ( - False, - True, - ), - ( - False, - True, - ), - ) - ] - + [ - ( - optimizer_ctor, - { - "lr": 0.1, - "betas": (0.8, 0.7), - "foreach": foreach, - "amsgrad": amsgrad, - }, - ) - for optimizer_ctor, foreach, amsgrad in product( - (torch.optim.Adam, torch.optim.AdamW), - (False, True), - (False, True), - ) - ] - + [ - ( - optimizer_ctor, - {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}, - ) - for optimizer_ctor, amsgrad in product( - (torch.optim.Adam, torch.optim.AdamW), (False, True) - ) - ] - + [ - ( - optimizer_ctor, - { - "lr": 0.1, - "foreach": foreach, - "maximize": maximize, - "weight_decay": weight_decay, - }, - ) - for optimizer_ctor, foreach, maximize, weight_decay in product( - ( - torch.optim.Adamax, - torch.optim.ASGD, - torch.optim.Adadelta, - torch.optim.RMSprop, - ), - (False, True), - (False, True), - (0, 0.1), - ) - ] - ) - - for optimizer_ctor, kwargs in cases: - with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs): - self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs) - @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -3376,123 +3275,6 @@ def test_graph_optims_with_explicitly_capturable_param_groups(self): self.assertEqual(ref_p1, param1) self.assertEqual(ref_p2, param2) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_scaling_fused_optimizers(self): - cases = [ - ( - optimizer_ctor, - {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}, - ) - for optimizer_ctor, amsgrad in product( - (torch.optim.Adam, torch.optim.AdamW), (False, True) - ) - ] + list( - product( - (torch.optim.SGD,), - [ - { - "lr": 0.1, - "momentum": 0.0, - "dampening": d, - "weight_decay": w, - "nesterov": n, - "fused": True, - } - for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,)) - ] - + [ - { - "lr": 0.1, - "momentum": 0.5, - "dampening": d, - "weight_decay": w, - "nesterov": n, - "fused": True, - } - for d, w, n in product((0.0,), (0.0, 0.5), (True, False)) - ], - ) - ) - - steps_warmup = 3 - steps_train = 2 - - for OptClass, kwargs in cases: - has_capturable_arg = OptClass in (torch.optim.Adam, torch.optim.AdamW) - for actually_do_graphs in (True, False) if has_capturable_arg else (True,): - params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] - params_control = [p.clone().requires_grad_() for p in params] - params_graphed = [p.clone().requires_grad_() for p in params] - - # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. 
- grads = [ - [torch.randn_like(p) for p in params] - for _ in range(steps_warmup + steps_train) - ] - with torch.no_grad(): - grads_control = [[g.clone() for g in gs] for gs in grads] - grads_graphed = [[g.clone() for g in gs] for gs in grads] - - # Gradient Scaler - scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0) - with torch.no_grad(): - scaler_for_control._lazy_init_scale_growth_tracker( - torch.device("cuda") - ) - - scaler_for_graphed = torch.cuda.amp.GradScaler() - scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) - with torch.no_grad(): - scaler_for_graphed._lazy_init_scale_growth_tracker( - torch.device("cuda") - ) - - # Control (capturable=False) - if has_capturable_arg: - kwargs["capturable"] = False - opt = OptClass(params_control, **kwargs) - - for i in range(steps_warmup + steps_train): - for j, p in enumerate(params_control): - p.grad = grads_control[i][j] - scaler_for_control.step(opt) - scaler_for_control.update() - - # capturable=True - if has_capturable_arg: - kwargs["capturable"] = True - opt = OptClass(params_graphed, **kwargs) - - for i in range(steps_warmup): - for j, p in enumerate(params_graphed): - p.grad = grads_graphed[i][j] - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - if actually_do_graphs: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - for i in range(steps_train): - if actually_do_graphs: - for j, p in enumerate(params_graphed): - p.grad.copy_(grads_graphed[i + steps_warmup][j]) - g.replay() - else: - # Passing capturable=True to the constructor and running without graphs should still be - # numerically correct, even if it's not ideal for performance. - for j, p in enumerate(params_graphed): - p.grad = grads_graphed[i + steps_warmup][j] - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - for p_control, p_graphed in zip(params_control, params_graphed): - self.assertEqual(p_control, p_graphed) - @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -4698,10 +4480,175 @@ def test_no_triton_on_import(self): self.assertEqual(rc, "False", "Triton was imported when importing torch!") +@torch.testing._internal.common_utils.markDynamoStrictTest class TestCudaOptims(TestCase): # These tests will be instantiate with instantiate_device_type_tests # to apply the new OptimizerInfo structure. 
+ @onlyCUDA + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >=5.3 required for graphs" + ) + @optims( + [optim for optim in optim_db if optim.has_capturable_arg], + dtypes=[torch.float32], + ) + def test_graph_optims(self, device, dtype, optim_info): + optim_cls = optim_info.optim_cls + all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( + device, dtype, optim_info, skip=("differentiable",) + ) + + steps_warmup = 3 + steps_train = 2 + + for optim_input in all_optim_inputs: + kwargs = optim_input.kwargs + + # lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam + # and torch.optim.adamw + kwargs["lr"] = 0.1 + + for actually_do_graphs in (True, False): + params = [ + torch.randn((i + 5, i + 5), device=device) for i in range(2) + ] + [torch.randn((), device=device)] + params_control = [p.clone().requires_grad_() for p in params] + params_graphed = [p.clone().requires_grad_() for p in params] + + grads = [ + [torch.randn_like(p) for p in params] + for _ in range(steps_warmup + steps_train) + ] + + # Control (capturable=False) + kwargs["capturable"] = False + + opt = optim_cls(params_control, **kwargs) + for i in range(steps_warmup + steps_train): + for j, p in enumerate(params_control): + p.grad = grads[i][j] + opt.step() + + # capturable=True + kwargs["capturable"] = True + opt = optim_cls(params_graphed, **kwargs) + + for i in range(steps_warmup): + for j, p in enumerate(params_graphed): + p.grad = grads[i][j] + opt.step() + + if actually_do_graphs: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + opt.step() + + for i in range(steps_train): + if actually_do_graphs: + for j, p in enumerate(params_graphed): + p.grad.copy_(grads[i + steps_warmup][j]) + g.replay() + else: + # Passing capturable=True to the constructor and running without graphs should still be + # numerically correct, even if it's not ideal for performance. + for j, p in enumerate(params_graphed): + p.grad = grads[i + steps_warmup][j] + opt.step() + + for p_control, p_graphed in zip(params_control, params_graphed): + self.assertEqual(p_control, p_graphed) + + @onlyCUDA + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + @optims( + [optim for optim in optim_db if "fused" in optim.supported_impls], + dtypes=[torch.float32], + ) + def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info): + optim_cls = optim_info.optim_cls + + steps_warmup = 3 + steps_train = 2 + + optim_inputs = optim_info.optim_inputs_func(device=device) + + for optim_input in optim_inputs: + kwargs = optim_input.kwargs + kwargs["fused"] = True + + for actually_do_graphs in ( + (True, False) if optim_info.has_capturable_arg else (True,) + ): + params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)] + params_control = [p.clone().requires_grad_() for p in params] + params_graphed = [p.clone().requires_grad_() for p in params] + + # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. 
+ grads = [ + [torch.randn_like(p) for p in params] + for _ in range(steps_warmup + steps_train) + ] + with torch.no_grad(): + grads_control = [[g.clone() for g in gs] for gs in grads] + grads_graphed = [[g.clone() for g in gs] for gs in grads] + + # Gradient Scaler + scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0) + with torch.no_grad(): + scaler_for_control._lazy_init_scale_growth_tracker(device) + + scaler_for_graphed = torch.cuda.amp.GradScaler() + scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) + with torch.no_grad(): + scaler_for_graphed._lazy_init_scale_growth_tracker(device) + + # Control (capturable=False) + if optim_info.has_capturable_arg: + kwargs["capturable"] = False + opt = optim_cls(params_control, **kwargs) + + for i in range(steps_warmup + steps_train): + for j, p in enumerate(params_control): + p.grad = grads_control[i][j] + scaler_for_control.step(opt) + scaler_for_control.update() + + # capturable=True + if optim_info.has_capturable_arg: + kwargs["capturable"] = True + opt = optim_cls(params_graphed, **kwargs) + + for i in range(steps_warmup): + for j, p in enumerate(params_graphed): + p.grad = grads_graphed[i][j] + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + if actually_do_graphs: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + for i in range(steps_train): + if actually_do_graphs: + for j, p in enumerate(params_graphed): + p.grad.copy_(grads_graphed[i + steps_warmup][j]) + g.replay() + else: + # Passing capturable=True to the constructor and running without graphs should still be + # numerically correct, even if it's not ideal for performance. + for j, p in enumerate(params_graphed): + p.grad = grads_graphed[i + steps_warmup][j] + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + for p_control, p_graphed in zip(params_control, params_graphed): + self.assertEqual(p_control, p_graphed) + @onlyCUDA @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" diff --git a/test/test_fx.py b/test/test_fx.py index eadcd750aeded..a58abb906d89c 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -4078,7 +4078,7 @@ def test_function_back_compat(self): f"unintended, please revert it. If it was intended, check with the FX " \ f"team to ensure that the proper deprecation protocols have been followed " \ f"and subsequently --accept the change." 
- raise AssertionError(msg) # noqa: TRY200 + raise AssertionError(msg) # noqa: B904 def test_class_member_back_compat(self): """ diff --git a/test/test_nn.py b/test/test_nn.py index 008354ad721eb..76bc614f025d9 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8205,6 +8205,16 @@ def help(input, conv, memory_format): weight = torch.empty([1, 0, 1], dtype=dtype, device=device) torch._C._nn.slow_conv3d(inp, weight, 1) + with self.assertRaisesRegex(RuntimeError, re.escape("2D kernel_size expected")): + torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[], padding=[1, 1], stride=[1, 1], + weight=torch.rand([1, 1])) + with self.assertRaisesRegex(RuntimeError, re.escape("2D stride expected")): + torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[1, 1], padding=[1, 1], stride=[], + weight=torch.rand([1, 1])) + with self.assertRaisesRegex(RuntimeError, re.escape("2D padding expected")): + torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[1, 1], padding=[], stride=[1, 1], + weight=torch.rand([1, 1])) + def test_InstanceNorm1d_general(self, device): b = random.randint(3, 5) c = random.randint(3, 5) diff --git a/test/test_optim.py b/test/test_optim.py index 717e892246722..7fa612e89da01 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -604,8 +604,16 @@ def _compare_between(self, inputs, models, optimizers, assert_eq_kwargs=None, as for input, model, optimizer in zip(inputs, models, optimizers): optimizer.zero_grad() + if i == 3: + # Freeze a layer to test if the step of this layer in 'fused' or 'foreach' + # is same as the step in 'forloop'. + model[2].requires_grad_(False) + if i == 5: + # Unfreeze the layer after 2 iters. + model[2].requires_grad_(True) + # Test that step behaves as expected (a no-op) when grads are set to None - if i != 3: + if i != 2: output = model(input) loss = output.sum() loss.backward() diff --git a/test/test_serialization.py b/test/test_serialization.py index 49f8880885ec4..1be1b06ab7863 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -15,8 +15,10 @@ import shutil import pathlib import platform +from collections import OrderedDict from copy import deepcopy from itertools import product +from types import ModuleType from torch._utils_internal import get_file_path_2 from torch._utils import _rebuild_tensor @@ -27,9 +29,10 @@ from torch.testing._internal.common_utils import ( IS_FILESYSTEM_UTF8_ENCODING, TemporaryDirectoryName, TestCase, IS_FBCODE, IS_WINDOWS, TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName, - parametrize, instantiate_parametrized_tests, AlwaysWarnTypedStorageRemoval, serialTest) + parametrize, instantiate_parametrized_tests, AlwaysWarnTypedStorageRemoval, serialTest, skipIfTorchDynamo) from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_dtype import all_types_and_complex_and +from torch.testing._internal.two_tensor import TwoTensor # noqa: F401 if not IS_WINDOWS: from mmap import MAP_SHARED, MAP_PRIVATE @@ -493,6 +496,15 @@ def test_serialization_map_location(self): def map_location(storage, loc): return storage + def generate_map_locations(device_type): + return [ + {'cuda:0': device_type + ':0'}, + device_type, + device_type + ':0', + torch.device(device_type), + torch.device(device_type, 0) + ] + def load_bytes(): with open(test_file_path, 'rb') as f: return io.BytesIO(f.read()) @@ -504,34 +516,39 @@ def load_bytes(): 'cpu', torch.device('cpu'), ] - gpu_0_map_locations = [ - {'cuda:0': 
'cuda:0'}, - 'cuda', - 'cuda:0', - torch.device('cuda'), - torch.device('cuda', 0) - ] + gpu_0_map_locations = generate_map_locations('cuda') gpu_last_map_locations = [ f'cuda:{torch.cuda.device_count() - 1}', ] + xpu_0_map_locations = generate_map_locations('xpu') + xpu_last_map_locations = [ + f'xpu:{torch.xpu.device_count() - 1}', + ] - def check_map_locations(map_locations, tensor_class, intended_device): + def check_map_locations(map_locations, dtype, intended_device): for fileobject_lambda in fileobject_lambdas: for map_location in map_locations: tensor = torch.load(fileobject_lambda(), map_location=map_location) self.assertEqual(tensor.device, intended_device) - self.assertIsInstance(tensor, tensor_class) - self.assertEqual(tensor, tensor_class([[1.0, 2.0], [3.0, 4.0]])) + self.assertEqual(tensor.dtype, dtype) + self.assertEqual(tensor, torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=dtype, device=intended_device)) - check_map_locations(cpu_map_locations, torch.FloatTensor, torch.device('cpu')) + check_map_locations(cpu_map_locations, torch.float, torch.device('cpu')) if torch.cuda.is_available(): - check_map_locations(gpu_0_map_locations, torch.cuda.FloatTensor, torch.device('cuda', 0)) + check_map_locations(gpu_0_map_locations, torch.float, torch.device('cuda', 0)) check_map_locations( gpu_last_map_locations, - torch.cuda.FloatTensor, + torch.float, torch.device('cuda', torch.cuda.device_count() - 1) ) + if torch.xpu.is_available(): + check_map_locations(xpu_0_map_locations, torch.float, torch.device('xpu', 0)) + check_map_locations( + xpu_last_map_locations, + torch.float, + torch.device('xpu', torch.xpu.device_count() - 1) + ) @unittest.skipIf(torch.cuda.is_available(), "Testing torch.load on CPU-only machine") def test_load_nonexistent_device(self): @@ -1024,7 +1041,7 @@ def __reduce__(self): self.assertIsNone(torch.load(f, weights_only=False)) f.seek(0) # Safe load should assert - with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported class"): + with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"): torch.load(f, weights_only=True) @parametrize('weights_only', (False, True)) @@ -4094,6 +4111,23 @@ def __setstate__(self, state): class TestEmptySubclass(torch.Tensor): ... +# ONLY use SubclassSpoof subclasses for the subclass spoof tests since we modify them +# Cannot define locally in test or pickle will fail. +class TestEmptySubclassSpoof(TestEmptySubclass): + ... + +class TestWrapperSubclassSpoof(TestWrapperSubclass): + ... 
+ +class RebuildFromTypeV2Spoof(torch.Tensor): + def __new__(cls, elem, naughty, **kwargs): + if naughty: + raise RuntimeError("naughty") + return super().__new__(cls, elem) + + def __reduce_ex__(self, protocol): + return (torch._tensor._rebuild_from_type_v2, (RebuildFromTypeV2Spoof, torch.Tensor, (True,), {})) + class TestSubclassSerialization(TestCase): def test_tensor_subclass_wrapper_serialization(self): @@ -4173,6 +4207,203 @@ def test_empty_class_serialization(self): f.seek(0) tensor2 = torch.load(f) + def _create_bad_func(self, name): + def bad_func(self, *args, **kwargs): + raise RuntimeError(f"running {name}") + return bad_func + + @parametrize("wrapper", (True, False)) + def test_tensor_subclass_method_spoofing(self, wrapper): + ''' + This tests seeks to do the following: + - determine which methods of a tensor subclass might be called during unpickling (weights_only=False) + we consider these methods "risky" for weights_only + - ensure that we ban overriding this group of methods on a tensor subclass by default (weights_only=True) + - ensure that tensor subclass that doesn't override any of these can be unpickled (weights_only=True) + + We achieve this by overriding all methods of a tensor subclass to raise a RuntimeError + when called. We then try to unpickle a tensor subclass with weights_only=False and ensure that + only the RuntimeErrors that we expect are thrown. + + We then load with weights_only and ensure that weights_only will fail unless all the risky methods + are not overriden by resetting the risky methods to the non-overriden version in a loop and calling load. + The final weights_only load call when all the risky methods are no longer overriden. + ''' + subclass = TestWrapperSubclassSpoof if wrapper else TestEmptySubclassSpoof + t = subclass(torch.randn(2, 3)) + # To trigger setattr for the non-wrapper case + if not wrapper: + t.foo = 'bar' + inp = {'weight': t} + + with TemporaryFileName() as f: + torch.save(inp, f) + loaded = torch.load(f, weights_only=True) + self.assertEqual(loaded['weight'], inp['weight']) + + restore_methods = dict() + methods = [func for func in dir(subclass) if callable(getattr(subclass, func))] + for method in methods: + if method != "__class__": + restore_methods[method] = getattr(subclass, method) + setattr(subclass, method, self._create_bad_func(method)) + # These additional methods might be called during getattr or setattr + # but are not in methods above (not defined on tensor base class) + subclass.__get__ = self._create_bad_func("__get__") + subclass.__set__ = self._create_bad_func("__set__") + subclass.__getattr__ = self._create_bad_func("__getattr__") + restore_methods["__get__"] = None + restore_methods["__getattr__"] = None + restore_methods["__set__"] = None + + try: + # Check that weights_only=False load raises the RuntimeErrors we expect + with self.assertRaisesRegex(RuntimeError, "running __getattribute__"): + torch.load(f, weights_only=False) + subclass.__getattribute__ = restore_methods['__getattribute__'] + with self.assertRaisesRegex(RuntimeError, "running __setstate__"): + torch.load(f, weights_only=False) + subclass.__setstate__ = restore_methods['__setstate__'] + with self.assertRaisesRegex(RuntimeError, "running __setattr__"): + torch.load(f, weights_only=False) + subclass.__setattr__ = restore_methods['__setattr__'] + # should finally work + torch.load(f, weights_only=False) + + # Check that weights_only=True catches that risky methods are overriden + subclass.__setstate__ = self._create_bad_func("__setstate__") + 
subclass.__getattribute__ = self._create_bad_func("__getattribute__") + subclass.__setattr__ = self._create_bad_func("__setattr__") + with self.assertRaisesRegex(pickle.UnpicklingError, + "methods: __getattribute__=True __getattr__=True __get__=True " + "__setattr__=True __set__=True __setstate__=True"): + torch.load(f, weights_only=True) + risky_methods = ['__get__', '__set__', '__getattr__', '__setattr__', '__getattribute__', '__setstate__'] + for i, meth in enumerate(risky_methods): + setattr(subclass, meth, restore_methods[meth]) + if i != len(risky_methods) - 1: + # When the given methods are not all back to default, load should still throw + # but reflect which methods are no longer overriden + with self.assertRaisesRegex(pickle.UnpicklingError, f"{meth}=False"): + torch.load(f, weights_only=True) + else: + # When the given methods are all back to default, weights_only load should finally work + loaded = torch.load(f, weights_only=True) + finally: + for method, func in restore_methods.items(): + setattr(subclass, method, func) + a = subclass(torch.randn(2, 3)) + + @skipIfTorchDynamo("name 'SYNTHETIC_LOCAL' is not defined") + def test_safe_globals_for_weights_only(self): + ''' + Tests import semantic for tensor subclass and the {add/get/clear}_safe_globals APIs + ''' + # Needed to prevent UnboundLocalError: local variable 'TwoTensor' referenced before assignment + global TwoTensor + t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3)) + p = torch.nn.Parameter(t) + sd = OrderedDict([('t', t), ('p', p)]) + + with tempfile.NamedTemporaryFile() as f: + torch.save(sd, f) + # unimport TwoTensor + try: + del sys.modules['torch.testing._internal.two_tensor'] + + # Loading tensor subclass with weights_only=True should fail + # if tensor subclass has not been imported + with self.assertRaisesRegex(pickle.UnpicklingError, + "expect `torch.testing._internal.two_tensor` to be present in `sys.modules`"): + f.seek(0) + sd = torch.load(f, weights_only=True) + + # Loading tensor subclass with weights_only=True should work + # if target methods are not overriden and user has imported the subclass + from torch.testing._internal.two_tensor import TwoTensor + f.seek(0) + sd = torch.load(f, weights_only=True) + self.assertEqual(sd['t'], t) + self.assertEqual(sd['p'], p) + + # Loading tensor subclass with weights_only=True should fail + # if __setstate__ is overriden + f.seek(0) + restore_setstate = TwoTensor.__setstate__ + try: + TwoTensor.__setstate__ = lambda self, state: self.__dict__.update(state) + with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"): + torch.load(f, weights_only=True) + + # Loading tensor subclass with overriden __setstate__ with weights_only=True should work + # if the class is marked safe + f.seek(0) + torch.serialization.add_safe_globals([TwoTensor]) + self.assertTrue(torch.serialization.get_safe_globals() == [TwoTensor]) + sd = torch.load(f, weights_only=True) + self.assertEqual(sd['t'], t) + self.assertEqual(sd['p'], p) + + # Should fail again when safe globals are cleared + torch.serialization.clear_safe_globals() + f.seek(0) + with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"): + torch.load(f, weights_only=True) + finally: + TwoTensor.__setstate__ = restore_setstate + finally: + from torch.testing._internal.two_tensor import TwoTensor + + + def test_tensor_subclass_parent_module_method_spoofing(self): + ''' + Tests that weights_only load does not call any methods of the parent module + that contains the tensor subclass. 
+ + We achieve this by overriding all methods of a module we add to sys.modules to raise a RuntimeError + when called. We then try to unpickle a tensor subclass with weights_only=True and ensure that + no RuntimeErrors are thrown. + ''' + # Simulates user doing `import spoof_mod` where `spoof_mod` contains `TestEmptySubclass` + class SpoofModule(ModuleType): + pass + + spoof_mod = SpoofModule('bla') + spoof_mod.TestEmptySubclass = TestEmptySubclass + inp = {'weight': TestEmptySubclass(torch.randn(2, 3))} + TestEmptySubclass.__module__ = 'spoof_mod' + sys.modules['spoof_mod'] = spoof_mod + + try: + with TemporaryFileName() as f: + torch.save(inp, f) + torch.load(f, weights_only=True) + restore_methods = dict() + methods = [func for func in dir(SpoofModule) if callable(getattr(SpoofModule, func))] + for method in methods: + if method != "__class__": + restore_methods[method] = getattr(SpoofModule, method) + setattr(SpoofModule, method, self._create_bad_func(method)) + SpoofModule.__get__ = self._create_bad_func("__get__") + SpoofModule.__getattr__ = self._create_bad_func("__getattr__") + loaded = torch.load(f, weights_only=True) + self.assertEqual(loaded['weight'], inp['weight']) + finally: + TestEmptySubclass.__module__ = __name__ + del sys.modules['spoof_mod'] + + def test_rebuild_from_type_v2_spoof(self): + t = RebuildFromTypeV2Spoof(torch.randn(2, 3), False) + inp = {'weight': t} + + with TemporaryFileName() as f: + torch.save(inp, f) + # subclass will be pushed onto unpickler's stack as a string + # and only gets converted to the type if it is argument 1 to _rebuild_from_type_v2 + with self.assertRaisesRegex(TypeError, "'str' object is not callable"): + loaded = torch.load(f, weights_only=True) + + instantiate_device_type_tests(TestBothSerialization, globals()) instantiate_parametrized_tests(TestSubclassSerialization) diff --git a/test/test_torch.py b/test/test_torch.py index 81da78f9a8820..c8cff93bd1bf6 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -29,6 +29,8 @@ from functools import partial from torch import multiprocessing as mp from torch.testing import make_tensor +from torch.testing._internal.common_optimizers import ( + optim_db, optims, _get_optim_inputs_including_global_cliquey_kwargs) from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, run_tests, IS_JETSON, @@ -5877,8 +5879,13 @@ def _run_scaling_case(self, device, run, unskipped, skipped, atol=1e-7, optimize self.assertEqual(c, s, atol=atol, rtol=1e-05) - # Compares no scaling + no autocasting against scaling + autocasting. - def _grad_scaling_autocast_test(self, *, device="cuda", atol=1e-3, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None): + @onlyNativeDeviceTypes + @parametrize("foreach, fused", [(None, None), (True, None), (None, True)]) + @optims( + [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]], + dtypes=[torch.float32] + ) + def test_grad_scaling_autocast(self, device, dtype, optim_info, foreach, fused): try_pickle = False def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api): @@ -5902,6 +5909,9 @@ def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_ optimizer.step() return scaler + optimizer_ctor = optim_info.optim_cls + + # Compares no scaling + no autocasting against scaling + autocasting. # NOTE(mkozuki): With current way of testing, `torch.optim.Adam` is failing in spite of `foreach` and `fused`. 
# Giving some flexibility to this test might help. context = contextlib.nullcontext @@ -5911,71 +5921,51 @@ def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_ with context(): # sets atol=1e-3 because we're comparing pure fp32 arithmetic vs a mixture of fp16 and fp32 self._run_scaling_case( - device, run, unskipped=3, skipped=1, atol=atol, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, + device, run, unskipped=3, skipped=1, atol=1e-3, + optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": foreach, "fused": fused}, ) # this will be picked up by try_pickle within run(): try_pickle = True self._run_scaling_case( - device, run, unskipped=3, skipped=1, atol=atol, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, + device, run, unskipped=3, skipped=1, atol=1e-3, + optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": foreach, "fused": fused}, ) - @onlyNativeDeviceTypes - def test_grad_scaling_autocast(self, device): - device = torch.device(device) - for optimizer_ctor in (torch.optim.SGD, torch.optim.Adam, torch.optim.AdamW): - self._grad_scaling_autocast_test(device=device.type, optimizer_ctor=optimizer_ctor) - - @onlyNativeDeviceTypes - def test_grad_scaling_autocast_foreach(self, device): - device = torch.device(device) - for optimizer_ctor in (torch.optim.SGD, torch.optim.Adam, torch.optim.AdamW): - self._grad_scaling_autocast_test(device=device.type, optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": True}) - - @onlyNativeDeviceTypes - def test_grad_scaling_autocast_fused(self, device): - device = torch.device(device) - for optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW): - self._grad_scaling_autocast_test(device=device.type, optimizer_ctor=optimizer_ctor, optimizer_kwargs={"fused": True}) - # Make sure that the parameters become nonsense when scaled gradients are finite # but they get invalidated before `optimizer.step`, after `GradScaler.unscale_` @onlyNativeDeviceTypes - def test_params_invalidated_with_grads_invalidated_between_unscale_and_step(self, device): - device = torch.device(device) - for optimizer_ctor, optimizer_kwargs in product( - (torch.optim.Adam, torch.optim.AdamW), - ( - {"foreach": False, "fused": False}, - {"foreach": True, "fused": False}, - {"foreach": False, "fused": True}, - ), - ): - with self.subTest(optimizer=optimizer_ctor, optimizer_kwargs=optimizer_kwargs): - self._test_grads_invalidated_between_unscale_and_step(device.type, optimizer_ctor, optimizer_kwargs) - - def _test_grads_invalidated_between_unscale_and_step(self, device, optimizer_ctor, optimizer_kwargs): - model, _, optimizer, _, data, loss_fn, _ = _create_scaling_case( - device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, - ) - scaler = torch.GradScaler(device=device, init_scale=128.0) + @optims( + [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]], + dtypes=[torch.float32] + ) + def test_params_invalidated_with_grads_invalidated_between_unscale_and_step(self, device, dtype, optim_info): + optimizer_ctor = optim_info.optim_cls + all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( + device, dtype, optim_info, skip=("differentiable",)) + + for optim_input in all_optim_inputs: + model, _, optimizer, _, data, loss_fn, _ = _create_scaling_case( + device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optim_input.kwargs, + ) + scaler = torch.GradScaler(device=device, init_scale=128.0) - for input, target in data: - 
optimizer.zero_grad() - with torch.autocast(device_type=device, dtype=torch.half): - output = model(input) - loss = loss_fn(output, target) - scaler.scale(loss).backward() - scaler.unscale_(optimizer) + for input, target in data: + optimizer.zero_grad() + with torch.autocast(device_type=device, dtype=torch.half): + output = model(input) + loss = loss_fn(output, target) + scaler.scale(loss).backward() + scaler.unscale_(optimizer) - # deliberately break grads - for j, param in enumerate(model.parameters()): - param.grad.copy_(torch.inf if j % 2 else torch.nan) + # deliberately break grads + for j, param in enumerate(model.parameters()): + param.grad.copy_(torch.inf if j % 2 else torch.nan) - scaler.step(optimizer) - scaler.update() + scaler.step(optimizer) + scaler.update() - self.assertTrue(all((p.isnan().any() or p.isinf().any()) for p in model.parameters())) + self.assertTrue(all((p.isnan().any() or p.isinf().any()) for p in model.parameters())) @onlyNativeDeviceTypes def test_grad_scale_will_not_overflow(self, device): diff --git a/test/test_xpu.py b/test/test_xpu.py index 74cc891a9e624..a3838f1d5a05d 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -1,6 +1,7 @@ # Owner(s): ["module: intel"] import sys +import tempfile import unittest import torch @@ -270,6 +271,40 @@ def convert_boolean_tensors(x): self.assertEqual(expect, actual) + def test_serialization_array_with_storage(self): + x = torch.randn(5, 5).xpu() + y = torch.zeros(2, 5, dtype=torch.int, device="xpu") + q = [x, y, x, y.storage()] + with tempfile.NamedTemporaryFile() as f: + torch.save(q, f) + f.seek(0) + q_copy = torch.load(f) + self.assertEqual(q_copy, q, atol=0, rtol=0) + q_copy[0].fill_(5) + self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0) + self.assertEqual(q_copy[0].dtype, torch.float) + self.assertEqual(q_copy[1].dtype, torch.int) + self.assertEqual(q_copy[2].dtype, torch.float) + self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage)) + self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage)) + q_copy[1].fill_(10) + y.fill_(10) + self.assertEqual(q_copy[3], y.storage()) + + def test_serialization_array_with_empty(self): + x = [ + torch.randn(4, 4).xpu(), + torch.tensor([], dtype=torch.float, device=torch.device("xpu")), + ] + with tempfile.NamedTemporaryFile() as f: + torch.save(x, f) + f.seek(0) + x_copy = torch.load(f) + for original, copy in zip(x, x_copy): + self.assertEqual(copy, original) + self.assertIs(type(copy), type(original)) + self.assertEqual(copy.get_device(), original.get_device()) + instantiate_device_type_tests(TestXpu, globals(), only_for="xpu") diff --git a/test/torch_np/numpy_tests/core/test_indexing.py b/test/torch_np/numpy_tests/core/test_indexing.py index f0d9023ddbff0..77844e77a6e0b 100644 --- a/test/torch_np/numpy_tests/core/test_indexing.py +++ b/test/torch_np/numpy_tests/core/test_indexing.py @@ -740,7 +740,7 @@ def _get_multi_index(self, arr, indices): try: indx = np.array(indx, dtype=np.intp) except ValueError: - raise IndexError + raise IndexError from None in_indices[i] = indx elif indx.dtype.kind != "b" and indx.dtype.kind != "i": raise IndexError( @@ -902,7 +902,7 @@ def _get_multi_index(self, arr, indices): arr = arr.reshape(arr.shape[:ax] + mi.shape + arr.shape[ax + 1 :]) except ValueError: # too many dimensions, probably - raise IndexError + raise IndexError from None ax += mi.ndim continue diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index 
38e2df73d5b86..bf9aab8ebcee2 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -409,7 +409,7 @@ def make_array(size, offset, strides): try: r = np.ndarray([size], dtype=int, buffer=x, offset=offset * x.itemsize) except Exception as e: - raise RuntimeError(e) # noqa: TRY200 + raise RuntimeError(e) # noqa: B904 r.strides = strides = strides * x.itemsize return r @@ -6304,7 +6304,7 @@ def test_flat_element_deletion(self): except TypeError: pass except Exception: - raise AssertionError + raise AssertionError from None class TestConversion(TestCase): diff --git a/test/torch_np/numpy_tests/core/test_scalar_methods.py b/test/torch_np/numpy_tests/core/test_scalar_methods.py index addc550ed3379..2e763c6636a84 100644 --- a/test/torch_np/numpy_tests/core/test_scalar_methods.py +++ b/test/torch_np/numpy_tests/core/test_scalar_methods.py @@ -132,7 +132,7 @@ def test_roundtrip(self, ftype, frac_vals, exp_vals): df = np.longdouble(d) except (OverflowError, RuntimeWarning): # the values may not fit in any float type - raise SkipTest("longdouble too small on this platform") # noqa: TRY200 + raise SkipTest("longdouble too small on this platform") # noqa: B904 assert_equal(nf / df, f, f"{n}/{d}") diff --git a/test/torch_np/numpy_tests/lib/test_function_base.py b/test/torch_np/numpy_tests/lib/test_function_base.py index fa1168840635c..d0eda87b0108a 100644 --- a/test/torch_np/numpy_tests/lib/test_function_base.py +++ b/test/torch_np/numpy_tests/lib/test_function_base.py @@ -1435,7 +1435,7 @@ def test_keywords_no_func_code(self): try: vectorize(random.randrange) # Should succeed except Exception: - raise AssertionError # noqa: TRY200 + raise AssertionError # noqa: B904 def test_keywords2_ticket_2100(self): # Test kwarg support: enhancement ticket 2100 diff --git a/test/torch_np/numpy_tests/linalg/test_linalg.py b/test/torch_np/numpy_tests/linalg/test_linalg.py index 616c7b95f5c90..3a5c21745e246 100644 --- a/test/torch_np/numpy_tests/linalg/test_linalg.py +++ b/test/torch_np/numpy_tests/linalg/test_linalg.py @@ -1958,7 +1958,7 @@ def test_xerbla_override(self): pid = os.fork() except (OSError, AttributeError): # fork failed, or not running on POSIX - raise SkipTest("Not POSIX or fork failed.") # noqa: TRY200 + raise SkipTest("Not POSIX or fork failed.") # noqa: B904 if pid == 0: # child; close i/o file handles diff --git a/third_party/cpp-httplib b/third_party/cpp-httplib new file mode 160000 index 0000000000000..3b6597bba913d --- /dev/null +++ b/third_party/cpp-httplib @@ -0,0 +1 @@ +Subproject commit 3b6597bba913d51161383657829b7e644e59c006 diff --git a/third_party/cub b/third_party/cub deleted file mode 160000 index d106ddb991a56..0000000000000 --- a/third_party/cub +++ /dev/null @@ -1 +0,0 @@ -Subproject commit d106ddb991a56c3df1b6d51b2409e36ba8181ce4 diff --git a/third_party/ideep b/third_party/ideep index 8a6cc4e09dc50..55ca0191687aa 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit 8a6cc4e09dc509f04f83c085e38786b1fb44e14d +Subproject commit 55ca0191687aaf19aca5cdb7881c791e3bea442b diff --git a/third_party/mkl-dnn.BUILD b/third_party/mkl-dnn.BUILD index dac4f9e3e8cf8..9a688a52b1cf6 100644 --- a/third_party/mkl-dnn.BUILD +++ b/third_party/mkl-dnn.BUILD @@ -63,9 +63,9 @@ template_rule( out = "include/oneapi/dnnl/dnnl_version.h", substitutions = { "@DNNL_VERSION_MAJOR@": "3", - "@DNNL_VERSION_MINOR@": "3", - "@DNNL_VERSION_PATCH@": "6", - "@DNNL_VERSION_HASH@": 
"86e6af5974177e513fd3fee58425e1063e7f1361", + "@DNNL_VERSION_MINOR@": "4", + "@DNNL_VERSION_PATCH@": "2", + "@DNNL_VERSION_HASH@": "1137e04ec0b5251ca2b4400a4fd3c667ce843d67", }, ) diff --git a/third_party/onnx-tensorrt b/third_party/onnx-tensorrt deleted file mode 160000 index c153211418a7c..0000000000000 --- a/third_party/onnx-tensorrt +++ /dev/null @@ -1 +0,0 @@ -Subproject commit c153211418a7c57ce071d9ce2a41f8d1c85a878f diff --git a/third_party/zstd b/third_party/zstd deleted file mode 160000 index aec56a52fbab2..0000000000000 --- a/third_party/zstd +++ /dev/null @@ -1 +0,0 @@ -Subproject commit aec56a52fbab207fc639a1937d1e708a282edca8 diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index d5f6837cba01e..a7eb81341eb54 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -1130,6 +1130,18 @@ def replace_special_case(hint: str) -> str: ) ) ], + "xpu": [ + "def xpu({}) -> Tensor: ...".format( + ", ".join( + [ + "self", + "device: Optional[Union[_device, _int, str]] = None", + "non_blocking: _bool = False", + "memory_format: torch.memory_format = torch.preserve_format", + ] + ) + ) + ], "cpu": [ "def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Tensor: ..." ], diff --git a/tools/stats/upload_test_stats_intermediate.py b/tools/stats/upload_test_stats_intermediate.py new file mode 100644 index 0000000000000..77cab472367bb --- /dev/null +++ b/tools/stats/upload_test_stats_intermediate.py @@ -0,0 +1,29 @@ +import argparse +import sys + +from tools.stats.test_dashboard import upload_additional_info +from tools.stats.upload_test_stats import get_tests + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Upload test stats to Rockset") + parser.add_argument( + "--workflow-run-id", + required=True, + help="id of the workflow to get artifacts from", + ) + parser.add_argument( + "--workflow-run-attempt", + type=int, + required=True, + help="which retry of the workflow this is", + ) + args = parser.parse_args() + + print(f"Workflow id is: {args.workflow_run_id}") + + test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt) + + # Flush stdout so that any errors in Rockset upload show up last in the logs. + sys.stdout.flush() + + upload_additional_info(args.workflow_run_id, args.workflow_run_attempt, test_cases) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index ac70396c468e2..0599da2117fbb 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1196,6 +1196,7 @@ def _has_storage(x: Tensor) -> _bool: ... def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ... def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ... def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, str], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ... +def _check_tp_alloc_is_default(cls: Type) -> _bool: ... # NB: There is no Capsule type in typing, see # https://code.activestate.com/lists/python-dev/139675/ diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 28d790e3d6903..74a73a3ddaa46 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -210,6 +210,20 @@ class PrefixStore(Store): @property def underlying_store(self) -> Store: ... +class _ControlCollectives: + def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ... 
+ def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ... + def broadcast_recv(self, key: str, timeout: timedelta) -> str: ... + def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ... + def gather_recv(self, key: str, timeout: timedelta) -> str: ... + def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ... + def scatter_recv(self, key: str, timeout: timedelta) -> str: ... + def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ... + def all_sum(self, key: str, data: str, timeout: timedelta) -> int: ... + +class _StoreCollectives(_ControlCollectives): + def __init__(self, store: Store, rank: int, world_size: int) -> None: ... + class _DistributedBackendOptions: def __init__(self): ... @property diff --git a/torch/_C/_dynamo/compiled_autograd.pyi b/torch/_C/_dynamo/compiled_autograd.pyi index 8ec4fbbdae8c3..d2067a5839210 100644 --- a/torch/_C/_dynamo/compiled_autograd.pyi +++ b/torch/_C/_dynamo/compiled_autograd.pyi @@ -7,4 +7,4 @@ def set_autograd_compiler( ) -> Optional[Callable[[], AutogradCompilerInstance]]: ... def clear_cache() -> None: ... def is_cache_empty() -> bool: ... -def set_verbose_logging(enable: bool) -> bool: ... +def set_verbose_logger(fn: Optional[Callable[[str], None]]) -> bool: ... diff --git a/torch/_custom_ops.py b/torch/_custom_ops.py index c13b0aaf339ad..c09a8ae68543f 100644 --- a/torch/_custom_ops.py +++ b/torch/_custom_ops.py @@ -250,7 +250,7 @@ def impl_abstract(qualname, *, func=None): """ import torch.library - return torch.library.impl_abstract(qualname, func, _stacklevel=2) + return torch.library.register_fake(qualname, func, _stacklevel=2) def impl_save_for_backward(qualname, *, func=None): diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index e8e61042d4746..7a87a2c7d575c 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -37,6 +37,10 @@ def snapshot_verbose_logging_enabled(): ) +def cpp_verbose_log_fn(msg: str) -> None: + verbose_log.debug(msg) + + def maybe_clone(x): if x is not None: return clone_preserve_strides(x) @@ -292,9 +296,8 @@ def enable(compiler_fn): prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler( functools.partial(AutogradCompilerInstance, compiler_fn) ) - torch._C._dynamo.compiled_autograd.set_verbose_logging( - snapshot_verbose_logging_enabled() - ) + if snapshot_verbose_logging_enabled(): + torch._C._dynamo.compiled_autograd.set_verbose_logger(cpp_verbose_log_fn) global compiled_autograd_enabled, compiled_autograd_enabled_count compiled_autograd_enabled = True compiled_autograd_enabled_count += 1 @@ -319,3 +322,11 @@ def disable(): if prior: compiled_autograd_enabled = True torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) + + +# return to starting state of a new process +def reset() -> None: + compiled_autograd_enable = False + assert compiled_autograd_enabled_count == 0 + torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) + torch._C._dynamo.compiled_autograd.set_verbose_logger(None) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index c1733280e31f3..d5c24a67d9e25 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -312,7 +312,7 @@ def profile_wrapper(*args, **kwargs): retval = prof.runcall(func, *args, **kwargs) profile_latency = time.time() - start_ts prof.disable() - log.info( + log.warning( "### Cprofile for %s trace id [%s] took %.3f seconds ###", func.__name__, trace_id, @@ 
-322,7 +322,7 @@ def profile_wrapper(*args, **kwargs): try: prof.dump_stats(profile_path) except PermissionError: - log.info("Cannot write to %s", str(profile_path)) + log.warning("Cannot write to %s", str(profile_path)) svg_path = profile_path.with_suffix(".svg") try: gprof2dot_process = subprocess.Popen( @@ -341,9 +341,9 @@ def profile_wrapper(*args, **kwargs): ["dot", "-Tsvg", "-o", str(svg_path)], stdin=gprof2dot_process.stdout, ) - log.info("Generated SVG from profile at %s", str(svg_path)) + log.warning("Generated SVG from profile at %s", str(svg_path)) except FileNotFoundError: - log.info( + log.warning( "Failed to generate SVG from profile -- dumping stats instead." "Try installing gprof2dot and dot for a better visualization" ) diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 67dd492fe8512..4b4b37a34da94 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -360,7 +360,7 @@ def same_two_models( fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd) except Exception: if require_fp64: - raise RuntimeError("Could not generate fp64 outputs") # noqa: TRY200 + raise RuntimeError("Could not generate fp64 outputs") # noqa: B904 log.warning("Could not generate fp64 outputs") try: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 391bdfcf02020..db35c0f631e8c 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1387,7 +1387,7 @@ def graph_with_interpreter(*args): )(*example_fake_inputs) except CondOpArgsMismatchError as e: # Wrap the internal error to the user-facing error - raise UserError( # noqa: TRY200 + raise UserError( # noqa: B904 UserErrorType.DYNAMIC_CONTROL_FLOW, str(e), case_name="cond_operands", diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index d6fb3e2145b73..093809703405f 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1504,6 +1504,14 @@ def SET_ADD(self, inst): assert obj.mutable_local return obj.call_method(self, "add", [v], {}) + def SET_UPDATE(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg] + assert isinstance(obj, SetVariable) + assert obj.mutable_local + obj.call_method(self, "update", [v], {}) + def LIST_APPEND(self, inst): v = self.pop() assert inst.argval > 0 @@ -2494,7 +2502,7 @@ def inline_call_( sub_locals, closure_cells = func.bind_args(parent, args, kwargs) except TypeError as e: # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info - raise ArgsMismatchError( # noqa: TRY200 + raise ArgsMismatchError( # noqa: B904 "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format( reason=str(e), func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}", diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index b4c022e8d8c24..9e9abe84228b1 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -349,6 +349,12 @@ def xfailIfPy312(fn): return fn +def skipIfPy312(fn): + if sys.version_info >= (3, 12): + return unittest.skip(fn) + return fn + + # Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py # and test/dynamo/test_dynamic_shapes.py def expectedFailureDynamic(fn): diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 21f3da2f61e6f..fcfbde1a6a799 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1751,7 +1751,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): 
elif isinstance( cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode ): - raise UserError( # noqa: TRY200 + raise UserError( # noqa: B904 UserErrorType.CONSTRAINT_VIOLATION, "Tried to use data-dependent value in the subsequent computation. " "This can happen when we encounter unbounded dynamic value that is unknown during tracing time. " diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 8f9ab01088a70..41b9fbd836ae1 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -2379,6 +2379,8 @@ def create(tx, value) -> VariableTracker: return PlacementVariable(value) elif DeviceMeshVariable.is_device_mesh(value): return DeviceMeshVariable(value) + elif isinstance(value, re.Pattern): + return RegexPatternVariable(value) unimplemented( f"Unexpected type in sourceless builder {value_type.__module__}.{value_type.__qualname__}" ) @@ -2399,6 +2401,7 @@ def make_type_handlers(): ) handlers[dict] = lambda tx, value: ConstDictVariable( {create(tx, k): create(tx, v) for k, v in value.items()}, + type(value), mutable_local=MutableLocal(), ) handlers[list] = lambda tx, value: ListVariable( @@ -2410,6 +2413,7 @@ def make_type_handlers(): handlers[torch.Size] = lambda tx, value: SizeVariable( [create(tx, x) for x in value] ) + handlers[collections.OrderedDict] = handlers[dict] handlers[immutable_dict] = handlers[dict] handlers[immutable_list] = handlers[list] handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value) diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index c8eabc2c88799..0724a80621f76 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -407,6 +407,8 @@ def call_method( args: List[VariableTracker], kwargs: Dict[str, VariableTracker], ) -> "VariableTracker": + from . 
import ListVariable, TupleVariable + # We foward the calls to the dictionary model if name == "add": assert not kwargs @@ -426,6 +428,24 @@ def call_method( return variables.UserFunctionVariable( polyfill.set_isdisjoint ).call_function(tx, [self, args[0]], {}) + elif ( + name == "update" + and len(args) == 1 + and isinstance( + args[0], + ( + SetVariable, + ListVariable, + TupleVariable, + ), + ) + and self.mutable_local + ): + if isinstance(args[0], (ListVariable, TupleVariable)): + arg = SetVariable(args[0].unpack_var_sequence(tx)) + else: + arg = args[0] + return super().call_method(tx, "update", (arg,), kwargs) return super().call_method(tx, name, args, kwargs) def getitem_const(self, arg: VariableTracker): diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 0a6af76690df6..c5b0d9f586c8a 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -65,9 +65,9 @@ def var_getattr(self, tx, name: str) -> VariableTracker: getattr_static(torch._C._SDPAParams, name) except AttributeError: # Using raise from is too verbose here - raise Unsupported( # noqa: TRY200 + raise Unsupported( f"Unsupported torch._C._SDPAParams attribute {name}" - ) + ) from None proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) if self.source is not None: diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 361527050b163..a1adbcf614bca 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1019,7 +1019,7 @@ def evaluate_expr(self, output_graph=None): try: return guard_scalar(self.sym_num) except GuardOnDataDependentSymNode as e: - raise UserError( # noqa: TRY200 + raise UserError( # noqa: B904 UserErrorType.ANTI_PATTERN, f"Consider annotating your code using torch._check*(). {str(e)}", case_name="constrain_as_size_example", diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 05aee5f28f66c..105a7ee2594bc 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -149,7 +149,10 @@ def capture_pre_autograd_graph( kwargs = {} if export_api_rollout_check(): - log.warning("Using torch.export._trace._export") + @lru_cache + def print_export_warning(): + log.warning("Using torch.export._trace._export") + print_export_warning() module = torch.export._trace._export(f, args, kwargs, dynamic_shapes=dynamic_shapes, pre_dispatch=True).module() else: log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"}) diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 42c20aa555000..aff3d444c960b 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -110,7 +110,9 @@ def make_fake_params_buffers( return faked_params_buffers # type: ignore[return-value] -def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes): +def make_fake_inputs( + nn_module, args, kwargs, dynamic_shapes, _is_torch_jit_trace=False +): """ Given an nn module, example inputs, and constraints, return a new fake mode, fake inputs created in that mode whose dynamic shape dimensions are constrained @@ -127,7 +129,7 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes): # - [post-tracing] guards.py processes input shape equalities. 
constraints = torch.export.dynamic_shapes._process_dynamic_shapes( - nn_module, args, kwargs, dynamic_shapes + nn_module, args, kwargs, dynamic_shapes, _is_torch_jit_trace=_is_torch_jit_trace ) constraints = constraints or [] t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict) @@ -136,13 +138,6 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes): if constraint.shared is not None: t_constraints[constraint.shared.t_id][constraint.shared.dim] = constraint - code = nn_module.forward.__code__ - co_fields = { - "co_name": code.co_name, - "co_filename": code.co_filename, - "co_firstlineno": code.co_firstlineno, - } - context = torch._guards.TracingContext.try_get() if context is not None: # This occurs when we are exporting within dynamo. There already exists @@ -153,11 +148,22 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes): len(constraints) == 0 ), "Found constraints when tracing with a toplevel tracing context." fake_mode = context.fake_mode - else: + elif not _is_torch_jit_trace: + code = nn_module.forward.__code__ + co_fields = { + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + } fake_mode = FakeTensorMode( shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields), allow_non_fake_inputs=True, ) + else: + fake_mode = FakeTensorMode( + shape_env=ShapeEnv(tracked_fakes=[]), + allow_non_fake_inputs=True, + ) if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: raise ValueError( "Detected fake_mode does not have a shape_env with tracked fakes. " @@ -166,7 +172,11 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes): ) with fake_mode: - original_signature = inspect.signature(nn_module.forward) + # FIXME(ycao) ScriptMethod doesn't have signature, I am using an empty one to unblock + if not _is_torch_jit_trace: + original_signature = inspect.signature(nn_module.forward) + else: + original_signature = None sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list) fake_args, fake_kwargs = tree_map_with_path( lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources), @@ -215,6 +225,7 @@ def produce_guards_and_solve_constraints( equalities_inputs: EqualityConstraint, original_signature: inspect.Signature, _disable_forced_specializations: Optional[bool] = False, + _is_torch_jit_trace=False, ): """ Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions, @@ -259,9 +270,13 @@ def produce_guards_and_solve_constraints( ) dim_constraints.remove_redundant_dynamic_results() forced_specializations = dim_constraints.forced_specializations() - msg = dim_constraints.prettify_results( - original_signature, constraint_violation_error, forced_specializations - ) + if not _is_torch_jit_trace: + msg = dim_constraints.prettify_results( + original_signature, constraint_violation_error, forced_specializations + ) + else: + # FIXME(ycao): This is a hack to get around missing signature from ScriptMethod + msg = "dummy constraint violation message" if constraint_violation_error: constraint_violation_error.args = (constraint_violation_error.args[0] + msg,) elif forced_specializations: diff --git a/torch/_export/tools.py b/torch/_export/tools.py new file mode 100644 index 0000000000000..d76392993bd20 --- /dev/null +++ b/torch/_export/tools.py @@ -0,0 +1,139 @@ +import logging +import warnings +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +import torch.export +import torch.export._trace +from torch._utils_internal import 
log_export_usage + +log = logging.getLogger(__name__) + +__all__ = ["report_exportability"] + + +def _generate_inputs_for_submodules( + model: torch.nn.Module, + target_submodules: Iterable[str], + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, +) -> Dict[str, Tuple[Any, Any]]: + """ + Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this + function doesn't work. + + Args: + model: root model. + inputs: inputs to the root model. + target_submodules: submodules that we want to generate inputs for. + + Returns: + A dict that maps from submodule name to its inputs. + """ + kwargs = kwargs or {} + + handles = [] + results = {} + submodule_to_names = {mod: name for name, mod in model.named_modules()} + + def pre_forward(module, module_args, module_kwargs): + results[submodule_to_names[module]] = (module_args, module_kwargs) + + try: + for name, mod in model.named_modules(): + if name in target_submodules: + handles.append( + mod.register_forward_pre_hook(pre_forward, with_kwargs=True) + ) + model(*args, **kwargs) + except Exception as e: + warnings.warn( + f"Failed to generate submodule inputs because of the following error:\n{e}" + ) + finally: + for h in handles: + h.remove() + return results + + +def report_exportability( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + *, + strict: bool = True, + pre_dispatch: bool = False, +) -> Dict[str, Optional[Exception]]: + """ + Report exportability issues for a module in one-shot. + + Args: + mod: root module. + args: args to the root module. + kwargs: kwargs to the root module. + Returns: + A dict that maps from submodule name to the exception that was raised when trying to export it. + `None` means the module is exportable without issue. 
+ Sample output: + { + '': UnsupportedOperatorException(func=), + 'submod_1': UnsupportedOperatorException(func=), + 'submod_2': None + } + """ + + log_export_usage(event="export.report_exportability") + + kwargs = kwargs or {} + + all_submod_names = [name for name, _ in mod.named_modules() if name != ""] + submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs) + + report: Dict[str, Optional[Exception]] = {} + + def try_export(module, module_name, args, kwargs): + nonlocal submod_inputs, report, strict, pre_dispatch + + if args is not None or kwargs is not None: + try: + torch.export._trace._export( + module, + args, + kwargs, + strict=strict, + pre_dispatch=pre_dispatch, + ) + report[module_name] = None + log.info("Successfully exported `%s`", module_name) + return + except Exception as e: + short_msg = repr(e).split("\n")[0] + log.warning( + "Failed exporting `%s` with exception: %s", module_name, short_msg + ) + report[module_name] = e + + for name, submod in module.named_children(): + sub_module_name = name if module_name == "" else f"{module_name}.{name}" + + submod_args, submod_kwargs = submod_inputs.get( + sub_module_name, (None, None) + ) + + try_export(submod, sub_module_name, submod_args, submod_kwargs) + + return + + try_export(mod, "", args, kwargs) + + unique_issues = set() + for exception in report.values(): + if exception is not None: + key = repr(exception).split("\\n")[0] + unique_issues.add(key) + + log.warning("Found %d export issues:", len(unique_issues)) + for issue in unique_issues: + log.warning(issue) + + return report diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 59648ccadab25..19fc4e9bdc4d4 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -118,7 +118,7 @@ def get_keystr(key_path: KeyPath) -> str: sympy.Eq(node_dim.node.expr, arg_dim), symbol ) if solution is None: - raise RuntimeError( # noqa: TRY200 + raise RuntimeError( # noqa: B904 f"Expected input {node.name}.shape[{j}] = {arg_dim} to be " f"of the form {node_dim.node.expr}, where {symbol} is an integer" ) diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 320a899e6b646..0b6e02da80d21 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -83,26 +83,6 @@ def _force_contiguous(x): return x -def _compute_output_meta_with_inductor_strides(fw_module, fwd_output_strides): - out = [n.meta["val"] for n in (list(fw_module.graph.nodes)[-1].args[0])] - # will only be set for inductor - if not fwd_output_strides: - return out - - from torch.fx.experimental.symbolic_shapes import statically_known_true - - for i in range(len(out)): - if not isinstance(out[i], Tensor): - continue - if all( - statically_known_true(s1 == s2) - for s1, s2 in zip(out[i].stride(), fwd_output_strides[i]) - ): - continue - out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i]) - return out - - # See Note [Tangents must be contiguous, Part 2] def coerce_runtime_tangent(x, metadata_tensor): if not isinstance(x, torch.Tensor): diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index a1fb2980ed1d4..c1b9a3b29f2e2 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -486,15 +486,18 @@ def _compute_output_meta_with_inductor_strides(self): fwd_output_strides 
= self.fwd_output_strides if not fwd_output_strides: return out - with TracingContext.get().fake_mode.shape_env.suppress_guards(): - for i in range(len(out)): - if not isinstance(out[i], Tensor): - continue - if all( - s1 == s2 for s1, s2 in zip(out[i].stride(), fwd_output_strides[i]) - ): - continue - out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i]) + + from torch.fx.experimental.symbolic_shapes import statically_known_true + + for i in range(len(out)): + if not isinstance(out[i], Tensor): + continue + if all( + statically_known_true(s1 == s2) + for s1, s2 in zip(out[i].stride(), fwd_output_strides[i]) + ): + continue + out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i]) return out # To be called post compile diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py index ffa37e59f04df..c9c750835a9f1 100644 --- a/torch/_functorch/compile_utils.py +++ b/torch/_functorch/compile_utils.py @@ -1,6 +1,8 @@ # mypy: ignore-errors +from typing import Callable + import torch import torch.fx as fx from torch.utils import _pytree as pytree @@ -9,7 +11,7 @@ aten = torch.ops.aten -def get_aten_target(node): +def get_aten_target(node: fx.Node) -> Callable: if hasattr(node.target, "overloadpacket"): return node.target.overloadpacket return node.target diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 5749477c6e98c..c559951f38094 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -136,6 +136,10 @@ # of tensors in question. fake_tensor_propagate_real_tensors = False +# Controls the default graph output format used by draw_graph +# Supported formats are defined here https://graphviz.org/docs/outputs/ +torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg") + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index ba549e5bd6e20..0956ee7e367c4 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - import copy import functools import heapq @@ -9,7 +7,10 @@ import operator import os from collections import defaultdict -from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union +from dataclasses import dataclass, replace +from typing import Callable, Dict, List, Optional, Set, Tuple, Union + +import sympy import torch import torch._inductor.inductor_prims @@ -28,19 +29,84 @@ from . 
import config from .compile_utils import fx_graph_cse, get_aten_target -if TYPE_CHECKING: - import sympy - AOT_PARTITIONER_DEBUG = config.debug_partitioner log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + + +@dataclass +class OpTypes: + """Class for keeping track of different operator categories""" + + fusible_ops: Set[Callable] + compute_intensive_ops: Set[Callable] + random_ops: Set[Callable] + view_ops: Set[Callable] + recomputable_ops: Set[Callable] + + def is_fusible(self, node: fx.Node): + return get_aten_target(node) in self.fusible_ops + + def is_compute_intensive(self, node: fx.Node): + return get_aten_target(node) in self.compute_intensive_ops + + def is_random(self, node: fx.Node): + return get_aten_target(node) in self.random_ops + + def is_view(self, node: fx.Node): + return get_aten_target(node) in self.view_ops + + def is_recomputable(self, node: fx.Node): + return get_aten_target(node) in self.recomputable_ops + + +@dataclass +class NodeInfo: + # Be careful about iterating over these explicitly, as their order may not + # be deterministic + inputs: List[fx.Node] + _required_fw_nodes: Set[fx.Node] + required_bw_nodes: Set[fx.Node] + unclaimed_nodes: Set[fx.Node] + fw_order: Dict[fx.Node, int] + + @functools.cached_property + def required_fw_nodes(self) -> List[fx.Node]: + return sorted( + (n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n] + ) + + def is_required_fw(self, n: fx.Node) -> bool: + return n in self._required_fw_nodes -def must_recompute(node): + def is_required_bw(self, n: fx.Node) -> bool: + return n in self.required_bw_nodes + + def is_unclaimed(self, n: fx.Node) -> bool: + return n in self.unclaimed_nodes + + def get_fw_order(self, n: fx.Node) -> int: + assert n in self._required_fw_nodes, f"Node {n} not in fw nodes!" + return self.fw_order[n] + + +@dataclass +class MinCutOptions: + ban_if_used_far_apart: bool + ban_if_long_fusible_chains: bool + ban_if_materialized_backward: bool + ban_if_not_in_allowlist: bool + ban_if_reduction: bool + + +def must_recompute(node: fx.Node) -> bool: return node.meta.get("recompute", False) -def has_recomputable_ops(fx_g): +def has_recomputable_ops(fx_g: fx.GraphModule) -> bool: found = False for node in fx_g.graph.nodes: if must_recompute(node): @@ -48,7 +114,7 @@ def has_recomputable_ops(fx_g): return False -def has_recomputable_rng_ops(fx_g): +def has_recomputable_rng_ops(fx_g: fx.GraphModule) -> bool: for node in fx_g.graph.nodes: if ( must_recompute(node) @@ -59,7 +125,7 @@ def has_recomputable_rng_ops(fx_g): return False -def sym_node_size(node): +def sym_node_size(node: fx.Node) -> int: if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)): return 1 assert isinstance(node.meta["val"], torch.SymFloat) @@ -74,7 +140,9 @@ def __repr__(self): InvalidNode = InvalidNodeBase() -def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs): +def _extract_graph_with_inputs_outputs( + joint_graph: fx.Graph, inputs: List[fx.Node], outputs: List[fx.Node] +) -> fx.Graph: """ Given a graph, extracts out a subgraph that takes the specified nodes as inputs and returns the specified outputs. 
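[Editor's note, not part of the patch] The partitioners.py hunk above introduces typed NodeInfo/OpTypes/MinCutOptions dataclasses; NodeInfo in particular exposes required_fw_nodes as a functools.cached_property sorted by fw_order rather than handing out the raw set. The sketch below (names invented here, purely illustrative) shows why that pattern gives a deterministic, compute-once ordering.

```python
import functools
from dataclasses import dataclass
from typing import Dict, List, Set


@dataclass
class TinyNodeInfo:
    _required_fw_nodes: Set[str]   # membership queries stay O(1) on the set
    fw_order: Dict[str, int]       # forward-order index recorded per node

    @functools.cached_property
    def required_fw_nodes(self) -> List[str]:
        # Set iteration order is non-deterministic, so sort by the recorded
        # forward order; cached_property computes this once per instance.
        return sorted(self._required_fw_nodes, key=lambda n: self.fw_order[n])


info = TinyNodeInfo(_required_fw_nodes={"b", "a"}, fw_order={"a": 0, "b": 1})
assert info.required_fw_nodes == ["a", "b"]
```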
@@ -136,36 +204,38 @@ def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs): return new_graph -def _is_primal(node): +def _is_primal(node: fx.Node) -> bool: return ( node.op == "placeholder" - and "tangents" not in node.target + and "tangents" not in str(node.target) and not _is_bwd_seed_offset(node) and not _is_fwd_seed_offset(node) ) -def _is_tangent(node): - return node.op == "placeholder" and "tangents" in node.target +def _is_tangent(node: fx.Node) -> bool: + return node.op == "placeholder" and "tangents" in str(node.target) -def _is_bwd_seed_offset(node): +def _is_bwd_seed_offset(node: fx.Node) -> bool: return node.op == "placeholder" and ( - "bwd_seed" in node.target or "bwd_base_offset" in node.target + "bwd_seed" in str(node.target) or "bwd_base_offset" in str(node.target) ) -def _is_fwd_seed_offset(node): +def _is_fwd_seed_offset(node: fx.Node) -> bool: return node.op == "placeholder" and ( - "fwd_seed" in node.target or "fwd_base_offset" in node.target + "fwd_seed" in str(node.target) or "fwd_base_offset" in str(node.target) ) -def _is_backward_state(node): +def _is_backward_state(node: fx.Node) -> bool: return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState) -def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs): +def _extract_fwd_bwd_outputs( + joint_module: fx.GraphModule, *, num_fwd_outputs +) -> Tuple[List[fx.Node], List[fx.Node]]: outputs = pytree.arg_tree_leaves( *(node.args for node in joint_module.graph.find_nodes(op="output")) ) @@ -174,7 +244,7 @@ def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs): return fwd_outputs, bwd_outputs -def _remove_by_name(saved_values, name): +def _remove_by_name(saved_values: List[fx.Node], name: str): for saved_value in saved_values: if saved_value.name == name: saved_values.remove(saved_value) @@ -182,8 +252,12 @@ def _remove_by_name(saved_values, name): def _extract_fwd_bwd_modules( - joint_module: fx.GraphModule, saved_values, saved_sym_nodes, *, num_fwd_outputs -): + joint_module: fx.GraphModule, + saved_values: List[fx.Node], + saved_sym_nodes: List[fx.Node], + *, + num_fwd_outputs: int, +) -> Tuple[fx.GraphModule, fx.GraphModule]: fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( joint_module, num_fwd_outputs=num_fwd_outputs ) @@ -359,14 +433,10 @@ def default_partition( ) -def _prod(x): - s = 1 - for i in x: - s *= i - return s +INT_INF = int(1e6) -def _tensor_nbytes(numel, dtype): +def _tensor_nbytes(numel: int, dtype) -> int: return numel * dtype.itemsize @@ -374,10 +444,7 @@ def _size_of(node: fx.Node) -> int: if "val" in node.meta: val = node.meta["val"] if isinstance(val, py_sym_types): - if isinstance(val, torch.SymInt): - return 1 - else: - return 999999 + return 1 # NB: The fallback values here are meaningless, maybe we should respect # torch._inductor.config.unbacked_symint_fallback (but this is a # layering violation) @@ -391,28 +458,18 @@ def _size_of(node: fx.Node) -> int: return _tensor_nbytes(hint_int(val.numel(), fallback=4098), val.dtype) raise RuntimeError(f"Unknown metadata type {type(val)}") - - # Only needed since we don't always trace with fake tensors. - if "tensor_meta" in node.meta: - metadata = node.meta["tensor_meta"] - # TODO: What is to_size_hint suppose to be? 
- numel = _prod(map(to_size_hint, metadata.shape)) # noqa: F821 - dtype = metadata.dtype - else: - return 0 - - return _tensor_nbytes(numel, dtype) + raise RuntimeError("We should always have `val` metadata on the nodes") # Used for some investigative purposes -def _count_ops(graph): +def _count_ops(graph: fx.Graph): from collections import defaultdict - cnt = defaultdict(int) + cnt: Dict[str, int] = defaultdict(int) for node in graph.nodes: if node.op == "call_function": cnt[node.target.__name__] += 1 - print(sorted(cnt.items(), key=operator.itemgetter(1), reverse=True)) + print(sorted(cnt.items(), key=lambda x: x[1], reverse=True)) @functools.lru_cache(None) @@ -433,14 +490,14 @@ def pointwise_ops(): return ops -def sort_depths(args, depth_map): +def sort_depths(args, depth_map: Dict[fx.Node, int]) -> List[Tuple[fx.Node, int]]: arg_depths = { arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node) } - return sorted(arg_depths.items(), key=operator.itemgetter(1), reverse=True) + return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True) -def reordering_to_mimic_autograd_engine(gm): +def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: """ This pass finds the first bwd node in the graph (by looking at users of tangents) and then reorders the graph by walking from this node to all the @@ -464,7 +521,7 @@ def reordering_to_mimic_autograd_engine(gm): """ new_graph = fx.Graph() - env = {} + env: Dict[fx.Node, fx.Node] = {} # Add new placeholder nodes in the order specified by the inputs for node in gm.graph.find_nodes(op="placeholder"): @@ -517,7 +574,12 @@ def insert_node_in_graph(node): return new_gm -def functionalize_rng_ops(joint_module, fw_module, bw_module, num_sym_nodes): +def functionalize_rng_ops( + joint_module: fx.GraphModule, + fw_module: fx.GraphModule, + bw_module: fx.GraphModule, + num_sym_nodes: int, +) -> Tuple[fx.GraphModule, fx.GraphModule]: # During user-driven activation checkpointing, we have to ensure that a rng # op in fwd yields the same output as the recomputed rng op in the bwd. To # do this, we use functionalize wrappers to wrap the random ops and share @@ -591,11 +653,15 @@ def get_sample_rng_state(device): run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state run_with_rng_state = torch._prims.rng_prims.run_with_rng_state - + bw_tangent_start_node = None for node in bw_module.graph.find_nodes(op="placeholder"): if "tangent" in node.name: bw_tangent_start_node = node break + if bw_tangent_start_node is None: + raise RuntimeError( + "Couldn't find tangent node in graph inputs. This is unexpected, please file a bug if you see this" + ) fw_rng_state_outputs = [] for base_node, node_pair in recomputable_rng_ops_map.items(): @@ -665,7 +731,7 @@ def get_sample_rng_state(device): return fw_module, bw_module -def cleanup_recompute_tags(joint_module): +def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: """ If there are two consecutive checkpointed blocks with no operator in between, we would still want to stash the tensor at the boundary of @@ -683,332 +749,50 @@ def cleanup_recompute_tags(joint_module): return joint_module -def min_cut_rematerialization_partition( - joint_module: fx.GraphModule, - _joint_inputs, - compiler="inductor", - recomputable_ops=None, - *, - num_fwd_outputs, -) -> Tuple[fx.GraphModule, fx.GraphModule]: - """ - Partitions the joint graph such that the backward recomputes the forward. - Recomputing helps in trading off memory bandwidth with computation. 
- - To create the fwd and bwd graph, we copy the joint graph, manually set the - outputs to just original forward or backward outputs. And then we run the - resulting graphs through dead code elimination. - - .. warning:: - This API is experimental and likely to change. - - Args: - joint_module(fx.GraphModule): The joint forward and backward graph. This - is the result of AOT Autograd tracing. - _joint_inputs: The inputs to the joint graph. This is unused. - compiler: This option determines the default set of recomputable ops. - Currently, there are two options: ``nvfuser`` and ``inductor``. - recomputable_ops: This is an optional set of recomputable ops. If this - is not None, then this set of ops will be used instead of the - default set of ops. - num_fwd_outputs: The number of outputs from the forward graph. - - Returns: - Returns the generated forward and backward Fx graph modules. - """ - try: - import networkx as nx - except ImportError as e: - raise RuntimeError( - "Need networkx installed to perform smart recomputation " "heuristics" - ) from e - - joint_module.graph.eliminate_dead_code() - joint_module.recompile() - - fx_g = joint_module.graph - - # add the CSE pass - if config.cse: - cse_graph = fx_graph_cse(fx_g) - joint_module.graph = cse_graph - joint_graph = joint_module.graph - - graph_has_recomputable_ops = has_recomputable_ops(joint_module) - graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) - if graph_has_recomputable_ops: - joint_module = cleanup_recompute_tags(joint_module) - - name_to_node = {} - for node in joint_module.graph.nodes: - name_to_node[node.name] = node - - def classify_nodes(joint_module): - required_bw_nodes = set() - for node in joint_module.graph.nodes: - if node.op == "placeholder" and "tangents" in node.target: - required_bw_nodes.add(node) - if node in required_bw_nodes: - required_bw_nodes.update(node.users) - - primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_seed_offset_inputs = list( - filter(_is_fwd_seed_offset, joint_module.graph.nodes) - ) - inputs = primal_inputs + fwd_seed_offset_inputs - fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( - joint_module, num_fwd_outputs=num_fwd_outputs - ) - required_bw_nodes.update( - o for o in bwd_outputs if o is not None and o.op != "output" - ) - forward_only_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, inputs, fwd_outputs - ) - required_fw_nodes = { - name_to_node[node.name] - for node in forward_only_graph.nodes - if node.op != "output" - } - unclaimed_nodes = { - node - for node in joint_module.graph.nodes - if node not in required_fw_nodes and node not in required_bw_nodes - } - return ( - fwd_outputs, - required_fw_nodes, - required_bw_nodes, - unclaimed_nodes, - inputs, - ) - - ( - orig_fw_outputs, - required_fw_nodes, - required_bw_nodes, - unclaimed_nodes, - inputs, - ) = classify_nodes(joint_module) - - # networkx blows up on graphs with no required backward nodes - # Since there's nothing to partition anyway, and the default partitioner can "handle" - # this case, send our graph over to the default partitioner. 
- if len(required_bw_nodes) == 0: - return default_partition( - joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs - ) - - def is_fusible(a, b): - # We can perform "memory fusion" into a cat, but cat cannot be a - # producer to a fusion - if get_aten_target(b) == aten.cat: - return True - return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops - - fw_order = 0 - for node in joint_module.graph.nodes: - if node in required_fw_nodes: - node.fw_order = fw_order - fw_order += 1 - - for node in reversed(joint_module.graph.nodes): - if node not in required_fw_nodes: - node.dist_from_bw = 0 - else: - node.dist_from_bw = int(1e9) - for user in node.users: - node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) - - aten = torch.ops.aten - prims = torch.ops.prims - - # compiler == "nvfuser" is the default set of recomputable ops - default_recomputable_ops = [ - aten.add, - aten.sub, - aten.div, - aten.atan2, - aten.mul, - aten.max, - aten.min, - aten.pow, - aten.remainder, - aten.fmod, - aten.__and__, - aten.__or__, - aten.__xor__, - aten.__lshift__, - aten.__rshift__, - aten.eq, - aten.ne, - aten.ge, - aten.gt, - aten.le, - aten.lt, - aten.abs, - aten.bitwise_not, - aten.ceil, - aten.floor, - aten.frac, - aten.neg, - aten.relu, - aten.round, - aten.silu, - aten.trunc, - aten.log, - aten.log10, - aten.log1p, - aten.log2, - aten.lgamma, - aten.exp, - aten.expm1, - aten.erf, - aten.erfc, - aten.cos, - aten.acos, - aten.cosh, - aten.sin, - aten.asin, - aten.sinh, - aten.tan, - aten.atan, - aten.tanh, - aten.atanh, - aten.sqrt, - aten.rsqrt, - aten.reciprocal, - aten.sigmoid, - aten.softplus, - aten.threshold, - aten.threshold_backward, - aten.clamp, - aten.where, - aten.lerp, - aten.addcmul, - aten.gelu, - aten.gelu_backward, - aten.sum, - aten.mean, - aten._grad_sum_to_size, - aten.sum_to_size, - aten.amax, - aten.to, - aten.type_as, - operator.getitem, - aten.squeeze, - aten.unsqueeze, - aten.rsub, - aten._to_copy, - ] # noqa: E501,B950 - view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] - if compiler == "inductor": - default_recomputable_ops += [ - prims.div, - prims.convert_element_type, - aten.clone, - aten._to_copy, - aten.full_like, - prims.var, - prims.sum, - aten.var, - aten.std, - prims.broadcast_in_dim, - aten.select, - aten._unsafe_view, - aten.view, - aten.expand, - aten.slice, - aten.reshape, - aten.broadcast_tensors, - aten.scalar_tensor, - aten.ones, - aten.new_zeros, - aten.lift_fresh_copy, - aten.arange, - aten.triu, - aten.var_mean, - aten.isinf, - aten.any, - aten.full, - aten.as_strided, - aten.zeros, - aten.argmax, - aten.maximum, - prims.iota, - prims._low_memory_max_pool2d_offsets_to_indices, - ] # noqa: E501,B950 - view_ops += [ - aten.view, - aten.slice, - aten.t, - prims.broadcast_in_dim, - aten.expand, - aten.as_strided, - aten.permute, - ] - # Natalia said that we should allow recomputing indexing :) - default_recomputable_ops += [aten.index, aten.gather] - default_recomputable_ops += view_ops - - default_recomputable_ops += pointwise_ops() - - default_recomputable_ops += [ - aten.zeros_like, - ] - - default_recomputable_ops += [method_to_operator(m) for m in magic_methods] - recomputable_ops = ( - set(recomputable_ops) - if recomputable_ops is not None - else set(default_recomputable_ops) - ) - - random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] - compute_intensive_ops = [ - aten.mm, - aten.convolution, - aten.convolution_backward, - aten.bmm, - aten.addmm, - aten._scaled_dot_product_flash_attention, - 
aten._scaled_dot_product_efficient_attention, - aten.upsample_bilinear2d, - ] # noqa: E501,B950 +def get_saved_values( + joint_graph: fx.Graph, + node_info: NodeInfo, + min_cut_options: MinCutOptions, + dont_ban=None, +): + if dont_ban is None: + dont_ban = set() + op_types = get_default_op_list() - fusible_ops = recomputable_ops | set(random_ops) if AOT_PARTITIONER_DEBUG: joint_module_ops = { str(node.target._overloadpacket) - for node in joint_module.graph.nodes + for node in joint_graph.nodes if node.op == "call_function" and hasattr(node.target, "_overloadpacket") } - ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops} + ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops} print("Ops banned from rematerialization: ", ops_ignored) print() - BAN_IF_USED_FAR_APART = config.ban_recompute_used_far_apart - BAN_IF_LONG_FUSIBLE_CHAINS = config.ban_recompute_long_fusible_chains - BAN_IF_MATERIALIZED_BACKWARDS = config.ban_recompute_materialized_backward - BAN_IF_NOT_IN_ALLOWLIST = config.ban_recompute_not_in_allowlist - BAN_IF_REDUCTION = config.ban_recompute_reductions + def is_fusible(a, b): + # We can perform "memory fusion" into a cat, but cat cannot be a + # producer to a fusion + if get_aten_target(b) == aten.cat: + return True + return op_types.is_fusible(a) and op_types.is_fusible(b) - if config.aggressive_recomputation: - BAN_IF_MATERIALIZED_BACKWARDS = False - BAN_IF_USED_FAR_APART = False - BAN_IF_LONG_FUSIBLE_CHAINS = False - BAN_IF_NOT_IN_ALLOWLIST = False + try: + import networkx as nx + except ImportError as e: + raise RuntimeError( + "Need networkx installed to perform smart recomputation " "heuristics" + ) from e def is_materialized_backwards(node): - if get_aten_target(node) in view_ops: + if op_types.is_view(node): return False cur_nodes = {node} while len(cur_nodes) > 0: cur = cur_nodes.pop() for user in cur.users: - if user not in required_fw_nodes and not is_fusible(cur, user): + if not node_info.is_required_fw(user) and not is_fusible(cur, user): return True - if get_aten_target(user) in view_ops: + if op_types.is_view(user): cur_nodes.add(user) return False @@ -1020,17 +804,15 @@ def should_ban_recomputation(node): return False if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: return False - # NB: "recompute" == 0 means that must save this node. if node.meta.get("recompute", None) == 0: return True - if BAN_IF_NOT_IN_ALLOWLIST: - if get_aten_target(node) not in recomputable_ops: + if min_cut_options.ban_if_not_in_allowlist: + if not op_types.is_recomputable(node): return True else: - ignored_ops = random_ops + compute_intensive_ops - if get_aten_target(node) in ignored_ops: + if op_types.is_random(node) or op_types.is_compute_intensive(node): return True # If a node *must* be materialized in the backwards pass, then we @@ -1038,7 +820,9 @@ def should_ban_recomputation(node): # general, the assumption we make is that recomputing a node in the # backwards pass is "free". However, if a node must be materialized # in the backwards pass, then recomputing it is never free. - if is_materialized_backwards(node) and BAN_IF_MATERIALIZED_BACKWARDS: + if min_cut_options.ban_if_materialized_backward and is_materialized_backwards( + node + ): log.info("materialized backwards: %s %s", node, tuple(node.users)) return True @@ -1046,16 +830,15 @@ def should_ban_recomputation(node): # modification appears to have made this heuristic a lot less critical # for performance. 
# NB: As of PR #121692, this hack no longer seems necessary. - if not graph_has_recomputable_ops: - if compiler == "inductor" and node.dist_from_bw > config.max_dist_from_bw: - return True + if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw: + return True # If the output of an op is 4x smaller (arbitrary choice), # then we don't allow recomputation. The idea here is that for # things like reductions, saving the output of the reduction is very # cheap/small, and it makes sure we don't do things like recompute # normalizations in the backwards. - if BAN_IF_REDUCTION: + if min_cut_options.ban_if_reduction: input_tensors_size = sum( _size_of(i) for i in node.args if isinstance(i, fx.Node) ) @@ -1069,9 +852,14 @@ def is_materialized(node): return not all(is_fusible(node, user) for user in node.users) - def get_node_weight(node) -> int: + def get_node_weight(node) -> float: mem_sz = _size_of(node) + if isinstance(node.meta["val"], py_sym_types): + # We never want to save symfloats + if not isinstance(node.meta["val"], torch.SymInt): + return INT_INF + # Heuristic to bias towards nodes closer to the backwards pass # Complete guess about current value mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1))) @@ -1084,6 +872,11 @@ def get_node_weight(node) -> int: banned_nodes = set() def ban_recomputation_if_allowed(node): + if op_types.is_view(node): + return False + if node in dont_ban: + return False + # breakpoint() # This bans recomputation of the node unless we've been forced not to by # user annotation # NB: "recompute" > 0 means that user annotation has asked us to @@ -1106,8 +899,8 @@ def ban_recomputation_if_allowed(node): if node.op == "output": continue - if node in required_bw_nodes: - if node not in inputs: + if node in node_info.required_bw_nodes: + if node not in node_info.inputs: nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) continue # If someone saves a input for backward as-is and backward @@ -1126,7 +919,7 @@ def ban_recomputation_if_allowed(node): # If a node can't be recomputed (too expensive or involves randomness), # we prevent it from being recomputed by adding an inf edge to the source # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed. - if node in required_fw_nodes and should_ban_recomputation(node): + if node_info.is_required_fw(node) and should_ban_recomputation(node): ban_recomputation_if_allowed(node) # Checks if a node is actually a tuple. Can be simplified to just an isinstance check if we always use faketensors. @@ -1135,12 +928,13 @@ def ban_recomputation_if_allowed(node): ) or ("val" in node.meta and not isinstance(node.meta["val"], torch.Tensor)) if is_sym_node(node): - weight = sym_node_size(node) + weight = float(sym_node_size(node)) elif is_non_tensor_node: - weight = 0 if isinstance(node.meta.get("val"), BackwardState) else math.inf + weight = ( + 0.0 if isinstance(node.meta.get("val"), BackwardState) else math.inf + ) else: weight = get_node_weight(node) - # Creates the weights on the "node" edge nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight) for user in node.users: @@ -1168,35 +962,40 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: Finds the first unfusible node in the chain of nodes starting from `start_nodes` and returns its position. 
""" - sorted_nodes = [] + sorted_nodes: List[Tuple[int, fx.Node, bool]] = [] for n in start_nodes: - heapq.heappush(sorted_nodes, (n.fw_order, n, True)) + heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True)) while len(sorted_nodes) > 0: _, node, node_is_fusible = heapq.heappop(sorted_nodes) if not node_is_fusible: - return node.fw_order + return node_info.get_fw_order(node) for user in node.users: - if user in required_fw_nodes: - if user.fw_order > max_range: + if node_info.is_required_fw(user): + if node_info.get_fw_order(user) > max_range: continue heapq.heappush( - sorted_nodes, (user.fw_order, user, is_fusible(node, user)) + sorted_nodes, + (node_info.get_fw_order(user), user, is_fusible(node, user)), ) return max_range - if BAN_IF_USED_FAR_APART: - for used_node in required_fw_nodes: + if min_cut_options.ban_if_used_far_apart: + for used_node in node_info.required_fw_nodes: orders = [ - user.fw_order for user in used_node.users if user in required_fw_nodes + node_info.get_fw_order(user) + for user in used_node.users + if node_info.is_required_fw(user) + ] + fw_users = [ + user for user in used_node.users if node_info.is_required_fw(user) ] - fw_users = [user for user in used_node.users if user in required_fw_nodes] if len(orders) > 0: first_unfusible_use = find_first_unfusible(fw_users, max(orders)) for user in tuple(used_node.users): if ( - user in required_fw_nodes - and user.fw_order > first_unfusible_use + node_info.is_required_fw(user) + and node_info.get_fw_order(user) > first_unfusible_use and is_fusible(used_node, user) ): if user in banned_nodes: @@ -1204,10 +1003,10 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: log.info( "used above/below fusible %s:(%s) -> %s -> %s:(%s)", used_node, - used_node.fw_order, + node_info.get_fw_order(used_node), first_unfusible_use, user, - user.fw_order, + node_info.get_fw_order(user), ) ban_recomputation_if_allowed(user) @@ -1222,47 +1021,51 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36 - if BAN_IF_LONG_FUSIBLE_CHAINS: + if min_cut_options.ban_if_long_fusible_chains: visited = set() for start_node in joint_graph.nodes: - if start_node not in required_fw_nodes: + if not node_info.is_required_fw(start_node): continue - fusible = [(start_node.fw_order, start_node)] - start_order = start_node.fw_order + fusible = [(node_info.get_fw_order(start_node), start_node)] + start_order = node_info.get_fw_order(start_node) while len(fusible) > 0: _, cur = heapq.heappop(fusible) if cur in visited: continue visited.add(cur) # 100 is arbitrary choice to try and prevent degenerate cases - if cur.fw_order > start_order + 100 and len(fusible) == 0: + if ( + node_info.get_fw_order(cur) > start_order + 100 + and len(fusible) == 0 + ): log.info( "too long %s %s %s %s", cur, start_node, - cur.fw_order, - start_node.fw_order, + node_info.get_fw_order(cur), + node_info.get_fw_order(start_node), ) ban_recomputation_if_allowed(cur) break for user in cur.users: if ( - user in required_fw_nodes + node_info.is_required_fw(user) and is_fusible(cur, user) and user not in banned_nodes ): - heapq.heappush(fusible, (user.fw_order, user)) + heapq.heappush(fusible, (node_info.get_fw_order(user), user)) try: cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink") except Exception: print("Failed to compute min-cut on following graph:") print("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph))) + 
visualize_min_cut_graph(nx_graph) raise reachable, non_reachable = partition - cutset = set() + cutset: Set[Tuple[str, str]] = set() for u, nbrs in ((n, nx_graph[n]) for n in reachable): cutset.update((u, v) for v in nbrs if v in non_reachable) @@ -1272,14 +1075,347 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: node_name = node_in[:-3] cut_nodes.add(node_name) + name_to_node = get_name_to_node(joint_graph) # To make this stuff deterministic - node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)} + node_idx = {node: idx for idx, node in enumerate(joint_graph.nodes)} saved_values = sorted( (name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x] ) + return saved_values, banned_nodes + + +def visualize_min_cut_graph(nx_graph): + import networkx as nx + import pydot + + dot_format = nx.nx_pydot.to_pydot(nx_graph).to_string() + dot_graph = pydot.graph_from_dot_data(dot_format)[0] + for edge in dot_graph.get_edges(): + weight = nx_graph[edge.get_source()][edge.get_destination()]["capacity"] + # Set edge label to weight + edge.set_label(str(weight)) + # Color edges with weight 'inf' as red + if weight == float("inf"): + edge.set_color("red") + print("Visualizing the failed graph to min_cut_failed.svg") + dot_graph.write_svg("min_cut_failed.svg") + + +def get_default_op_list() -> OpTypes: + default_recomputable_ops: List[Callable] = [ + aten.add, + aten.sub, + aten.div, + aten.atan2, + aten.mul, + aten.max, + aten.min, + aten.pow, + aten.remainder, + aten.fmod, + aten.__and__, + aten.__or__, + aten.__xor__, + aten.__lshift__, + aten.__rshift__, + aten.eq, + aten.ne, + aten.ge, + aten.gt, + aten.le, + aten.lt, + aten.abs, + aten.bitwise_not, + aten.ceil, + aten.floor, + aten.frac, + aten.neg, + aten.relu, + aten.round, + aten.silu, + aten.trunc, + aten.log, + aten.log10, + aten.log1p, + aten.log2, + aten.lgamma, + aten.exp, + aten.expm1, + aten.erf, + aten.erfc, + aten.cos, + aten.acos, + aten.cosh, + aten.sin, + aten.asin, + aten.sinh, + aten.tan, + aten.atan, + aten.tanh, + aten.atanh, + aten.sqrt, + aten.rsqrt, + aten.reciprocal, + aten.sigmoid, + aten.softplus, + aten.threshold, + aten.threshold_backward, + aten.clamp, + aten.where, + aten.lerp, + aten.addcmul, + aten.gelu, + aten.gelu_backward, + aten.sum, + aten.mean, + aten._grad_sum_to_size, + aten.sum_to_size, + aten.amax, + aten.to, + aten.type_as, + operator.getitem, + aten.squeeze, + aten.unsqueeze, + aten.rsub, + aten._to_copy, + ] # noqa: E501,B950 + recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] + recomputable_view_ops += [ + aten.view, + aten.slice, + aten.t, + prims.broadcast_in_dim, + aten.expand, + aten.as_strided, + aten.permute, + ] + view_ops = recomputable_view_ops + default_recomputable_ops += [ + prims.div, + prims.convert_element_type, + aten.clone, + aten._to_copy, + aten.full_like, + prims.var, + prims.sum, + aten.var, + aten.std, + prims.broadcast_in_dim, + aten.select, + aten._unsafe_view, + aten.view, + aten.expand, + aten.slice, + aten.reshape, + aten.broadcast_tensors, + aten.scalar_tensor, + aten.ones, + aten.new_zeros, + aten.lift_fresh_copy, + aten.arange, + aten.triu, + aten.var_mean, + aten.isinf, + aten.any, + aten.full, + aten.as_strided, + aten.zeros, + aten.argmax, + aten.maximum, + prims.iota, + prims._low_memory_max_pool2d_offsets_to_indices, + ] # noqa: E501,B950 + # Natalia said that we should allow recomputing indexing :) + default_recomputable_ops += [aten.index, aten.gather] + default_recomputable_ops += view_ops 
+ + default_recomputable_ops += pointwise_ops() + + default_recomputable_ops += [ + aten.zeros_like, + ] + + default_recomputable_ops += [method_to_operator(m) for m in magic_methods] + recomputable_ops = set(default_recomputable_ops) + + random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] + compute_intensive_ops = [ + aten.mm, + aten.convolution, + aten.convolution_backward, + aten.bmm, + aten.addmm, + aten._scaled_dot_product_flash_attention, + aten._scaled_dot_product_efficient_attention, + aten.upsample_bilinear2d, + ] # noqa: E501,B950 + + fusible_ops = recomputable_ops | set(random_ops) + return OpTypes( + set(fusible_ops), + set(compute_intensive_ops), + set(random_ops), + set(view_ops), + set(recomputable_ops), + ) + + +def get_name_to_node(graph: fx.Graph): + name_to_node = {} + for node in graph.nodes: + name_to_node[node.name] = node + return name_to_node + + +def choose_saved_values_set( + joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1 +) -> List[fx.Node]: + min_cut_options = MinCutOptions( + ban_if_used_far_apart=config.ban_recompute_used_far_apart, + ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains, + ban_if_materialized_backward=config.ban_recompute_materialized_backward, + ban_if_not_in_allowlist=config.ban_recompute_not_in_allowlist, + ban_if_reduction=config.ban_recompute_reductions, + ) + + if config.aggressive_recomputation: + min_cut_options = replace( + min_cut_options, + ban_if_used_far_apart=False, + ban_if_long_fusible_chains=False, + ban_if_materialized_backward=False, + ban_if_not_in_allowlist=False, + ) + + if memory_budget == 0: + return node_info.inputs + + runtime_optimized_saved_values, _ = get_saved_values( + joint_graph, + node_info, + min_cut_options, + ) + return runtime_optimized_saved_values + + +def min_cut_rematerialization_partition( + joint_module: fx.GraphModule, + _joint_inputs, + compiler="inductor", + *, + num_fwd_outputs, +) -> Tuple[fx.GraphModule, fx.GraphModule]: + """ + Partitions the joint graph such that the backward recomputes the forward. + Recomputing helps in trading off memory bandwidth with computation. + + To create the fwd and bwd graph, we copy the joint graph, manually set the + outputs to just original forward or backward outputs. And then we run the + resulting graphs through dead code elimination. + + .. warning:: + This API is experimental and likely to change. + + Args: + joint_module(fx.GraphModule): The joint forward and backward graph. This + is the result of AOT Autograd tracing. + _joint_inputs: The inputs to the joint graph. This is unused. + compiler: This option determines the default set of recomputable ops. + Currently, there are two options: ``nvfuser`` and ``inductor``. + recomputable_ops: This is an optional set of recomputable ops. If this + is not None, then this set of ops will be used instead of the + default set of ops. + num_fwd_outputs: The number of outputs from the forward graph. + + Returns: + Returns the generated forward and backward Fx graph modules. 
+ """ + + joint_module.graph.eliminate_dead_code() + joint_module.recompile() + + fx_g = joint_module.graph + + # add the CSE pass + if config.cse: + cse_graph = fx_graph_cse(fx_g) + joint_module.graph = cse_graph + joint_graph = joint_module.graph + + graph_has_recomputable_ops = has_recomputable_ops(joint_module) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) + if graph_has_recomputable_ops: + joint_module = cleanup_recompute_tags(joint_module) + + def classify_nodes(joint_module): + name_to_node = get_name_to_node(joint_module.graph) + required_bw_nodes = set() + for node in joint_module.graph.nodes: + if node.op == "placeholder" and "tangents" in node.target: + required_bw_nodes.add(node) + if node in required_bw_nodes: + for user in node.users: + required_bw_nodes.add(user) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list( + filter(_is_fwd_seed_offset, joint_module.graph.nodes) + ) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( + joint_module, num_fwd_outputs=num_fwd_outputs + ) + required_bw_nodes.update( + o for o in bwd_outputs if o is not None and o.op != "output" + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs + ) + required_fw_nodes: Set[fx.Node] = { + name_to_node[node.name] + for node in forward_only_graph.nodes + if node.op != "output" + } + unclaimed_nodes = { + node + for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes + } + fw_cnt = 0 + fw_order = {} + for node in joint_module.graph.nodes: + if node in required_fw_nodes: + fw_order[node] = fw_cnt + fw_cnt += 1 + return NodeInfo( + inputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes, fw_order + ) + + node_info = classify_nodes(joint_module) + + # networkx blows up on graphs with no required backward nodes + # Since there's nothing to partition anyway, and the default partitioner can "handle" + # this case, send our graph over to the default partitioner. 
+ if len(node_info.required_bw_nodes) == 0: + return default_partition( + joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs + ) + + for node in reversed(joint_module.graph.nodes): + if node.op == "output": + node.dist_from_bw = int(1e9) + elif not node_info.is_required_fw(node): + node.dist_from_bw = 0 + else: + node.dist_from_bw = int(1e9) + for user in node.users: + node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) + + saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget=1) # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes = list(filter(is_sym_node, saved_values)) saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) + # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols fw_module, bw_module = _extract_fwd_bwd_modules( joint_module, @@ -1312,7 +1448,7 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: } remat_nodes = fw_module_nodes & bw_module_nodes - counts = defaultdict(int) + counts: Dict[str, int] = defaultdict(int) for node in fw_module.graph.nodes: if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"): counts[str(node.target._overloadpacket)] += 1 @@ -1321,7 +1457,7 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: ) print( "Count of Ops Rematerialized: ", - sorted(counts.items(), key=operator.itemgetter(1), reverse=True), + sorted(counts.items(), key=lambda x: x[1], reverse=True), ) return fw_module, bw_module @@ -1331,7 +1467,7 @@ def draw_graph( fname: str, figname: str = "fx_graph", clear_meta: bool = True, - prog: Union[str, List[str]] = None, + prog: Optional[Union[str, List[str]]] = None, parse_stack_trace: bool = False, dot_graph_shape: Optional[str] = None, ) -> None: @@ -1342,7 +1478,7 @@ def draw_graph( node.meta = {} base, ext = os.path.splitext(fname) if not ext: - ext = ".svg" + ext = "." 
+ config.torch_compile_graph_format print(f"Writing FX graph to file: {base}{ext}") g = graph_drawer.FxGraphDrawer( traced, @@ -1357,13 +1493,3 @@ def draw_graph( write_method(fname) else: write_method(fname, prog=prog) - - -def draw_joint_graph( - graph: torch.fx.GraphModule, - joint_inputs, - file_name: str = "full_graph.png", - dot_graph_shape: Optional[str] = None, -): - draw_graph(graph, file_name, dot_graph_shape=dot_graph_shape) - return default_partition(graph, joint_inputs) diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index b5e1385da346b..f4586a0a57b0c 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -406,17 +406,20 @@ def flex_attention_autograd( score_mod: Callable, *other_buffers: Tuple[torch.Tensor, ...], ) -> Tuple[torch.Tensor, torch.Tensor]: - input_requires_grad = any(t.requires_grad for t in (query, key, value)) - if torch.is_grad_enabled() and input_requires_grad: - example_vals = [ - torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad) - ] + [torch.zeros((), dtype=torch.int) for _ in range(4)] - fw_graph, bw_graph = create_fw_bw_graph(score_mod, example_vals, other_buffers) - else: - fw_graph, bw_graph = score_mod, None - out, logsumexp = FlexAttentionAutogradOp.apply( - query, key, value, fw_graph, bw_graph, *other_buffers - ) + with TransformGetItemToIndex(): + input_requires_grad = any(t.requires_grad for t in (query, key, value)) + if torch.is_grad_enabled() and input_requires_grad: + example_vals = [ + torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad) + ] + [torch.zeros((), dtype=torch.int) for _ in range(4)] + fw_graph, bw_graph = create_fw_bw_graph( + score_mod, example_vals, other_buffers + ) + else: + fw_graph, bw_graph = score_mod, None + out, logsumexp = FlexAttentionAutogradOp.apply( + query, key, value, fw_graph, bw_graph, *other_buffers + ) return out, logsumexp @@ -449,9 +452,10 @@ def sdpa_dense_backward( score_mod = torch.vmap(score_mod, in_dims=(0, None, 0, None, None) + in_dim_buffers) score_mod = torch.vmap(score_mod, in_dims=(0, 0, None, None, None) + in_dim_buffers) - post_mod_scores = score_mod(scores, b, h, m, n, *other_buffers).to( - working_precision - ) + with TransformGetItemToIndex(): + post_mod_scores = score_mod(scores, b, h, m, n, *other_buffers).to( + working_precision + ) softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1)) @@ -485,9 +489,10 @@ def sdpa_dense_backward( in_dims=(0, 0, None, None, None, 0) + in_dim_buffers, out_dims=out_dims, ) - grad_scores, *_ = joint_score_mod( - scores, b, h, m, n, grad_score_mod, *other_buffers - ) + with TransformGetItemToIndex(): + grad_scores, *_ = joint_score_mod( + scores, b, h, m, n, grad_score_mod, *other_buffers + ) grad_scores = grad_scores.to(query.dtype) grad_query = grad_scores @ key @@ -524,8 +529,9 @@ def trace_flex_attention_backward( torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad) ] + [torch.zeros((), dtype=torch.int) for _ in range(4)] bw_example_vals = fw_example_vals + [torch.zeros((), dtype=query.dtype)] - fw_graph = make_fx(fw_graph)(*fw_example_vals, *other_buffers) - joint_graph = make_fx(joint_graph)(*bw_example_vals, *other_buffers) + with TransformGetItemToIndex(): + fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *other_buffers) + joint_graph = reenter_make_fx(joint_graph)(*bw_example_vals, *other_buffers) proxy_mode.tracer.root.register_module("fw_graph", fw_graph) 
proxy_mode.tracer.root.register_module("joint_graph", joint_graph) node_args = ( diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 0516fc55e074a..0d7cd8cece49e 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -81,7 +81,7 @@ def aot_compile( ) if in_spec is not None and received_spec != in_spec: - raise ValueError( # noqa: TRY200 + raise ValueError( # noqa: B904 "Trying to flatten user inputs with exported input tree spec: \n" f"{in_spec}\n" "but actually got inputs with tree spec of: \n" diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 0850b4a94bdc4..70b4671431115 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1961,6 +1961,14 @@ def _compile_consts_linux(consts: bytes) -> str: return consts_o def _compile_consts_darwin(consts: bytes) -> str: + if config.aot_inductor.debug_dump_consts_bin: + _, _binary_constants_path = write( + consts, + "bin", + specified_dir=specified_output_path, + ) + log.debug("binary constants path: %s", _binary_constants_path) + is_large_consts = len(consts) > 1024 consts_asm = "\t.section\t__DATA,__data\n" consts_asm += "\t.globl\t__binary_constants_bin_start\n" diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index a6ac6234ab083..8641f89a7d3a3 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -601,6 +601,8 @@ class OverridesData: ) +# NB: if you add a new special function, don't forget to update +# torch._inductor.ops_handler too pointwise_overrides_data: Dict[str, OverridesData] = dict( airy_ai=OverridesData( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, @@ -1508,7 +1510,7 @@ def _bound_variable(name, *args, **kwargs): return ValueRanges.unknown() fx_node = V.interpreter.current_node - if fx_node.target == name and self.node_to_bounds is not None: + if fx_node.target == name: assert isinstance(self.node_to_bounds, dict) return self.node_to_bounds.get(fx_node, ValueRanges.unknown()) elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 7a026b9b3c6df..a0beddbf9bd39 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -8,7 +8,7 @@ import sys from copy import copy, deepcopy from enum import Enum -from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import sympy @@ -17,10 +17,9 @@ from torch._inductor import dependencies from torch._prims_common import is_float_dtype from torch.utils import _pytree as pytree -from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges -from ..._dynamo.utils import counters from .. 
import codecache, config, ir, metrics from ..codegen.wrapper import WrapperCodeGen @@ -46,7 +45,7 @@ sympy_subs, ) -from ..virtualized import NullKernelHandler, ops, OpsValue, V +from ..virtualized import ops, OpsValue, V from .common import ( BracesBuffer, CppWrapperKernelArgs, @@ -1503,6 +1502,7 @@ def __init__(self, args, num_threads): self.local_reduction_init = IndentedBuffer() self.local_reduction_stores = IndentedBuffer() self.is_reduction = False + self.non_parallel_reduction_prefix = IndentedBuffer() self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") self.preloads = IndentedBuffer() self.poststores = IndentedBuffer() @@ -1517,6 +1517,7 @@ def _gen_parallel_reduction_buffers( dtype, reduction_combine_fn=reduction_combine, reduction_init_fn=reduction_init, + welford_weight_reciprocal_vec_fn=None, ): if config.cpp.dynamic_threads and not self.parallel_reduction_prefix: self.parallel_reduction_prefix.writeline( @@ -1553,6 +1554,15 @@ def _gen_parallel_reduction_buffers( "}", ], ) + if ( + reduction_type == "welford_reduce" + and welford_weight_reciprocal_vec_fn + and hasattr(self, "weight_recp_vec_range") + and "vec" in f"{acc_type}" + ): + self.local_reduction_init.writeline( + welford_weight_reciprocal_vec_fn(dtype, num_threads) + ) def get_reduction_var_pattern(self, line: str): return re.search("tmp_acc[0-9]+", line) @@ -1881,6 +1891,8 @@ def get_reduction_code_buffer(loops, buffer="prefix"): prefix = kernel.reduction_prefix if loop.parallel: prefix = prefix + kernel.parallel_reduction_prefix + else: + prefix = prefix + kernel.non_parallel_reduction_prefix return prefix def gen_loops(loops: List[LoopLevel], in_reduction=False): @@ -2319,9 +2331,25 @@ def reduction(self, dtype, src_dtype, reduction_type, value): self.reduction_prefix.writeline( f"{acc_type_vec} {acc_vec} = {self.reduction_init_vec(reduction_type, dtype)};" ) - self.stores.writeline( - f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, value)};" + # save the reciprocal of weights for welford reduce if using static shape + reduction_size = functools.reduce( + lambda x, y: x * y, self.ranges[self.reduction_depth :] ) + if reduction_type == "welford_reduce": + reduction_factor = ( + self.tiling_factor if self.tiling_idx >= self.reduction_depth else 1 + ) + self.weight_recp_vec_range = FloorDiv(reduction_size, reduction_factor) + self.non_parallel_reduction_prefix.writeline( + self.welford_weight_reciprocal_vec(dtype, None) + ) + self.stores.writeline( + f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, value, True)};" + ) + else: + self.stores.writeline( + f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, value)};" + ) self._gen_parallel_reduction_buffers( acc, acc_type, @@ -2335,6 +2363,7 @@ def reduction(self, dtype, src_dtype, reduction_type, value): dtype, reduction_combine_fn=self.reduction_combine_vec, reduction_init_fn=self.reduction_init_vec, + welford_weight_reciprocal_vec_fn=self.welford_weight_reciprocal_vec, ) tmpvar: Union[str, CSEVariable] if self.tiling_idx >= self.reduction_depth: @@ -2436,7 +2465,18 @@ def reduction_acc_type_vec(self, reduction_type, dtype): return vec_type - def reduction_combine_vec(self, reduction_type, var, next_value): + def welford_weight_reciprocal_vec(self, dtype, num_threads=None): + vec_num_range_thread = ( + CeilDiv(self.weight_recp_vec_range, num_threads) + if num_threads + else self.weight_recp_vec_range + ) + vec_num_range_thread_expr = cexpr_index(vec_num_range_thread) + return f"static 
WeightRecp<{self._get_vec_type(dtype)}> weight_recps({vec_num_range_thread_expr});" + + def reduction_combine_vec( + self, reduction_type, var, next_value, use_weight_recps=False + ): if reduction_type == "max": return f"at::vec::maximum({var}, {next_value})" elif reduction_type == "min": @@ -2448,7 +2488,10 @@ def reduction_combine_vec(self, reduction_type, var, next_value): elif reduction_type == "xor_sum": return f"{var} ^ {next_value}" elif reduction_type == "welford_reduce": - return f"welford_combine({var}, {next_value})" + if use_weight_recps: + return f"welford_combine({var}, {next_value}, &weight_recps)" + else: + return f"welford_combine({var}, {next_value})" elif reduction_type == "welford_combine": if isinstance(next_value, tuple): # When reading a value from Inductor IR we have a tuple of variable names @@ -2748,8 +2791,9 @@ def store_reduction(self, name, index, value): return self.simd_vec def __exit__(self, exc_type, exc_val, exc_tb): + assert self._orig_wrapper_code is not None # Restore the wrapper_code - V.graph.wrapper_code = self._orig_wrapper_code # type: ignore[assignment] + V.graph.wrapper_code = self._orig_wrapper_code self.exit_stack.__exit__(exc_type, exc_val, exc_tb) def __enter__(self): @@ -3147,11 +3191,27 @@ def is_memory_copy_scheduler_node(node: SchedulerNode): body: ir.LoopBody = node._body _legalize_lowp_fp(body) - def codegen_functions(self, fn_list, var_sizes_list, vec_dtype=torch.float): - # TODO(jgong5): remove vec_dtype arg with alternative tiling factors for various dtypes - assert len(fn_list) == len(var_sizes_list) + def codegen_nodes(self, nodes: List[SchedulerNode]): + # Legalize BF16 node by adding to_dtype explicitly + self.legalize_lowp_fp_dtype(nodes) + self.data_type_propagation(nodes) + + assert len(nodes) >= 1 + first_node = nodes[0] + vec_dtype = ( + first_node._lowp_fp_type # type: ignore[attr-defined] + if all( + hasattr(_node, "_lowp_fp_type") + and _node._lowp_fp_type == first_node._lowp_fp_type # type: ignore[attr-defined] + for _node in nodes + ) + else torch.float + ) + kernel_group = self.kernel_group - group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group self.set_ranges(group, reduction_group) @@ -3167,22 +3227,22 @@ def codegen_kernel(cls, *args): def run(kernel): vars, reduction_vars = kernel.set_ranges(group, reduction_group) in_suffix = False - for fn, var_sizes in zip(fn_list, var_sizes_list): - if var_sizes in [ + for node in nodes: + if node.group[1] in [ (group, reduction_group), (group + reduction_group, ()), ]: assert not in_suffix - fn(vars, reduction_vars) + node.run(vars, reduction_vars) else: in_suffix = True - assert var_sizes == ( + assert node.group[1] == ( group, (), - ), f"unexpected group: {var_sizes} != {group}, {reduction_group}" + ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}" # we can fuse in some extra pointwise into the suffix with kernel.write_to_suffix(): - fn(vars, ()) + node.run(vars, ()) scalar_kernel = codegen_kernel(CppKernel) V.graph.removed_buffers |= scalar_kernel.removed_buffers @@ -3194,8 +3254,8 @@ def run(kernel): def select_tiling_indices(tiling_factor): all_index = [] - for fn, var_sizes in zip(fn_list, var_sizes_list): - rw = dependencies.extract_read_writes(fn, *var_sizes) + for node in nodes: + rw = dependencies.extract_read_writes(node._body, *node._sizes) all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)] contig_vars = 
set() contig_vars_list = [] @@ -3309,41 +3369,6 @@ def select_tiling(dtype: torch.dtype = torch.float): inner_main_loop.set_kernel(tile2d_kernel) inner_tail_loop.set_kernel(vec_kernel) - def codegen_loop_bodies(self, loop_bodies, var_sizes_list): - # TODO(jgong5): support lowp legalization - for body in loop_bodies: - DataTypePropagation.propagate_loopbody(body) - self.codegen_functions(loop_bodies, var_sizes_list) - - def codegen_nodes(self, nodes: List[SchedulerNode]): - # Legalize BF16 node by adding to_dtype explicitly - self.legalize_lowp_fp_dtype(nodes) - self.data_type_propagation(nodes) - - assert len(nodes) >= 1 - first_node = nodes[0] - vec_dtype = ( - first_node._lowp_fp_type # type: ignore[attr-defined] - if all( - hasattr(_node, "_lowp_fp_type") - and _node._lowp_fp_type == first_node._lowp_fp_type # type: ignore[attr-defined] - for _node in nodes - ) - else torch.float - ) - - def fn(node, *index_vars): - node.decide_inplace_update() - node.mark_run() - if isinstance(V.kernel, NullKernelHandler): - return node._body(*index_vars) - else: - return node.codegen(index_vars) - - fn_list = [functools.partial(fn, node) for node in nodes] - var_sizes_list = [node.group[1] for node in nodes] - self.codegen_functions(fn_list, var_sizes_list, vec_dtype) - def codegen_loops(self, code, worksharing): self.codegen_loops_impl(self.loop_nest, code, worksharing) @@ -3408,9 +3433,6 @@ def reset_kernel_group(self): def fuse(self, node1, node2): if node1.is_foreach() or node2.is_foreach(): return ForeachKernelSchedulerNode.fuse(node1, node2) - elif node1.is_template(): - assert not node2.is_template() - return FusedSchedulerNode.fuse(node1, node2) else: if ( self._why_fuse_nodes(node1, node2) @@ -3543,8 +3565,6 @@ def _can_fuse_horizontal_impl(self, node1, node2): return self._why_fuse_nodes(node1, node2) is not None def can_fuse_horizontal(self, node1, node2): - if node1.is_template() or node2.is_template(): - return False if ( len(node1.get_nodes()) + len(node2.get_nodes()) > config.cpp.max_horizontal_fusion_size @@ -3609,9 +3629,7 @@ def _get_outer_loop_fusion_depth(self, node1, node2): def can_fuse_vertical_outer_loop(self, node1, node2): return ( - not node1.is_template() - and not node2.is_template() - and node1.get_names() & node2.ancestors + node1.get_names() & node2.ancestors and not ( self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() @@ -3627,11 +3645,6 @@ def get_fusion_pair_priority(self, node1, node2): return 0 def can_fuse_vertical(self, node1, node2): - if node2.is_template(): - # TODO(jgong5): support pre-op fusion with template - return False - if node1.is_template(): - return not node2.is_reduction() return ( self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() ) or self.can_fuse_vertical_outer_loop(node1, node2) @@ -3688,43 +3701,6 @@ def codegen_node( if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: self._set_flush_status(True) - def is_cpp_template(self, node: BaseSchedulerNode) -> bool: - return isinstance(node, SchedulerNode) and isinstance( - node.node, ir.CppTemplateBuffer - ) - - def codegen_template( - self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode] - ): - """ - Codegen a CPP template, possibly with fused epilogues - """ - counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) - assert self.is_cpp_template( - template_node - ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" - template_node = 
cast(SchedulerNode, template_node) - _, (_, rnumel) = template_node.group - assert rnumel == () - ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) - epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes] - assert all( - isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes - ), "Epilogue nodes must all be instances of ir.ComputedBuffer" - kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes) - with kernel: - for node in [template_node, *epilogue_nodes]: - node.decide_inplace_update() - node.mark_run() - src_code = render() - - with V.set_kernel_handler(kernel): - node_schedule = [template_node, *epilogue_nodes] - kernel_name = self.define_kernel(src_code, node_schedule, kernel.args) - kernel.call_kernel(kernel_name, ctb) - V.graph.removed_buffers |= kernel.removed_buffers - self.scheduler.free_buffers() - def _get_scheduled_num_args(self): return self.kernel_group.get_num_args() @@ -3734,7 +3710,7 @@ def ready_to_flush(self): def codegen_sync(self): pass - def define_kernel(self, src_code, nodes, kernel_args=None): + def define_kernel(self, src_code, nodes): wrapper = V.graph.wrapper_code fused_name = ( get_fused_kernel_name(nodes, config.cpp.descriptive_names) @@ -3750,8 +3726,7 @@ def define_kernel(self, src_code, nodes, kernel_args=None): src_code = src_code.replace("#pragma CMT", "//") compile_wrapper = IndentedBuffer() - args = self.kernel_group.args if kernel_args is None else kernel_args - _, _, arg_types = args.cpp_argdefs() + _, _, arg_types = self.kernel_group.args.cpp_argdefs() if not V.graph.cpp_wrapper: compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''") compile_wrapper.splice(src_code, strip=True) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py deleted file mode 100644 index 4d2a640515f5c..0000000000000 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ /dev/null @@ -1,436 +0,0 @@ -from typing import cast, List, Optional - -import torch -import torch.utils -from .. 
import ir, lowering as L - -from ..kernel.mm_common import mm_args -from ..select_algorithm import DataProcessorTemplateWrapper -from ..utils import cache_on_self, has_free_symbols, parallel_num_threads -from ..virtualized import V -from .cpp_micro_gemm import create_micro_gemm -from .cpp_template import CppTemplate - -from .cpp_template_kernel import CppTemplateKernel -from .cpp_utils import GemmBlocking - -GEMM_TEMPLATE = r""" -{{template.header().getvalue()}} - -{{micro_gemm.codegen_define(kernel)}} - -extern "C" -{{kernel.def_kernel(inputs={"X": X, "W": W, "inp": inp}, outputs={"Y": Y})}} -{ - {{kernel.maybe_codegen_profile()}} - constexpr int64_t num_threads = {{num_threads}}; - constexpr int64_t N = {{kernel.size(GemmOut, 1)}}; - constexpr int64_t K = {{kernel.size(X, 1)}}; - constexpr int64_t M0 = {{micro_gemm.register_blocking.block_m}}; - constexpr int64_t N0 = {{micro_gemm.register_blocking.block_n}}; - constexpr int64_t K0 = {{micro_gemm.register_blocking.block_k}}; - constexpr int64_t N0_blocks = (N + N0 - 1) / N0; - constexpr int64_t K0_blocks = (K + K0 - 1) / K0; - - static_assert(N % N0 == 0, "N dimension must be multiple of N0"); - - // TODO(jgong5): improve cache blocking with CPU info (Mc, Kc) - {%- if is_dynamic_M %} - const int64_t M = {{kernel.size(GemmOut, 0)}}; - const int64_t M0_blocks = (M + M0 - 1) / M0; - {%- if num_threads > 1 %} - const auto [Mt_blocks, Nt_blocks, Kt_blocks] = mm_get_thread_blocking(M, N, K, M0, N0, K0, num_threads); - {%- else %} - const auto Mt_blocks = M0_blocks; - const auto Nt_blocks = N0_blocks; - const auto Kt_blocks = K0_blocks; - {%- endif %} - const int64_t Mc_blocks = Mt_blocks; - const int64_t Kc_blocks = Kt_blocks; - {%- else %} - constexpr int64_t M = {{kernel.size(GemmOut, 0)}}; - constexpr int64_t M0_blocks = (M + M0 - 1) / M0; - constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}}; - constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}}; - constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}}; - constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}}; - constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}}; - {%- endif %} - - // TODO(jgong5): support k-slicing - {{kernel.assert_function}}(Kt_blocks == K0_blocks, "Do not support k slicing yet."); - // make sure all partitions are assigned - {{kernel.assert_function}}( - Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= M0_blocks * N0_blocks * K0_blocks, - "Not all partitions are assigned." 
- ); - - {%- if num_threads > 1 %} - #pragma omp parallel num_threads({{num_threads}}) - { - int tid = omp_get_thread_num(); - int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end; - mm_get_thread_blocks( - tid, M0_blocks, N0_blocks, K0_blocks, Mt_blocks, Nt_blocks, Kt_blocks, - m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end); - {%- else %} - { - int64_t m_block_start = 0; - int64_t m_block_end = M0_blocks; - int64_t n_block_start = 0; - int64_t n_block_end = N0_blocks; - int64_t k_block_start = 0; - int64_t k_block_end = K0_blocks; - {%- endif %} - for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { - const int64_t m_start = mc * M0; - const int64_t m_end = std::min((mc + Mc_blocks) * M0, M); - const int64_t m_size = m_end - m_start; - for (int64_t nc = n_block_start; nc < n_block_end; ++nc) { - const int64_t n_start = nc * N0; - const int64_t n_size = N0; - {%- if use_local_acc %} - {{ kernel.define_buffer("acc_local_buf", ["m_end - m_start", "N0"]) }} - {%- set acc = kernel.local_buffers["acc_local_buf"] %} - {%- else %} - {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %} - {%- endif %} - {%- if inp is not none and beta != 0 %} - for (int64_t m = 0; m < m_size; ++m) { - #pragma omp simd - for (int64_t n = 0; n < n_size; ++n) { - {{kernel.index(acc, ["m", "n"])}} = {{beta}} * {{kernel.index(inp, ["m + m_start", "n + n_start"])}}; - } - } - {%- endif %} - for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { - int64_t k_start = kc * K0; - int64_t k_end = std::min((kc + Kc_blocks) * K0, K); - {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} - {%- set tile_W_3d = kernel.slice_nd(W, [("nc", "nc + 1"), ("k_start", "k_end"), ()]) %} - {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} - {%- if inp is not none and beta != 0 %} - {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True)|indent(20, false) }} - {%- else %} - if (kc == k_block_start) { - {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False)|indent(24, false) }} - } else { - {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True)|indent(24, false) }} - } - {%- endif %} - } - {%- if reindexer is not none %} - {%- set Y_maybe_transposed = kernel.permute(Y, reindexer([0,1])) %} - {%- else %} - {%- set Y_maybe_transposed = Y %} - {%- endif %} - {%- set tile_Y = kernel.slice_nd(Y_maybe_transposed, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %} - {{ kernel.store_output( - tile_Y, acc, epilogue_nodes, offsets=("m_start", "n_start"), reindexer=reindexer - )|indent(16, false) - }} - } - } - } -} -""" - - -class CppPackedGemmTemplate(CppTemplate): - def __init__( - self, - input_nodes, - layout: ir.Layout, - num_threads: int, - register_blocking: GemmBlocking, - beta=1, - alpha=1, - ): - assert layout.dtype in [torch.float, torch.bfloat16, torch.half] - super().__init__("packed_gemm", input_nodes, layout) - self.beta = beta - self.alpha = alpha - self.num_threads = num_threads - self.register_blocking = register_blocking - m, n = layout.size - _, k = input_nodes[0].get_size() - self.m, self.n, self.k = m, n, k - self.is_dynamic_M = has_free_symbols((m,)) - - @cache_on_self - def thread_blocking(self) -> GemmBlocking: - # TODO(jgong5): allow tuning various blocking options - def get_factors(number): - factors = [] - # priorize more evenly divided factors - for i 
in range(int(number**0.5), 0, -1): - if number % i == 0: - factors.append(number // i) - factors.append(i) - return factors - - def get_blocking(num_threads, factor, m_blocks, n_blocks, k_blocks): - thread_block_n = (n_blocks + factor - 1) // factor - cofactor = num_threads // factor - thread_block_m = (m_blocks + cofactor - 1) // cofactor - return GemmBlocking(thread_block_m, thread_block_n, k_blocks) - - assert ( - not self.is_dynamic_M - ), "Unable to determine thread blocking for dynamic M." - register_blocking = self.register_blocking - m_blocks = (self.m + register_blocking.block_m - 1) // register_blocking.block_m - n_blocks = (self.n + register_blocking.block_n - 1) // register_blocking.block_n - k_blocks = (self.k + register_blocking.block_k - 1) // register_blocking.block_k - factors = get_factors(self.num_threads) - assert len(factors) > 0 - for factor in factors: - if n_blocks % factor == 0 and m_blocks % (self.num_threads // factor) == 0: - return get_blocking( - self.num_threads, factor, m_blocks, n_blocks, k_blocks - ) - for factor in factors: - if n_blocks % factor == 0: - return get_blocking( - self.num_threads, factor, m_blocks, n_blocks, k_blocks - ) - cofactor = self.num_threads // factor - if m_blocks % cofactor == 0: - return get_blocking( - self.num_threads, factor, m_blocks, n_blocks, k_blocks - ) - raise AssertionError("Should not reach here.") - - @cache_on_self - def cache_blocking(self) -> GemmBlocking: - # TODO(jgong5): improve cache blocking with CPU info - assert ( - not self.is_dynamic_M - ), "Unable to determine cache blocking for dynamic M." - thread_blocking = self.thread_blocking() - return GemmBlocking(thread_blocking.block_m, 1, thread_blocking.block_k) - - @staticmethod - def add_choices( - choices, - layout, - input_nodes, - beta=1, - alpha=1, - trans_w=False, - input_indices=None, - ): - if input_indices is None: - input_indices = list(range(len(input_nodes))) - - def reorder_and_filter(inputs, layout_or_out): - if len(input_indices) == 2: - x_idx = input_indices[0] - w_idx = input_indices[1] - return [inputs[x_idx], inputs[w_idx]], layout_or_out - else: - assert ( - len(input_indices) == 3 - ), "Cpp Packed GEMM template requires 2 or 3 input nodes." 
- # assume the input order is [inp, x, w] and we reorder it to [x, w, inp] - inp_idx = input_indices[0] - x_idx = input_indices[1] - w_idx = input_indices[2] - return [inputs[x_idx], inputs[w_idx], inputs[inp_idx]], layout_or_out - - def maybe_to_dense(inputs, layout_or_out): - new_inputs = list(inputs) - if isinstance(inputs[1], torch.Tensor): - W = inputs[1] - new_inputs[1] = W.to_dense() if W.is_mkldnn else W - return new_inputs, layout_or_out - - def normalize_shapes(inputs, layout_or_out): - if not trans_w: - return inputs, layout_or_out - - new_inputs = list(inputs) - X = inputs[0] - W = inputs[1] - B = inputs[2] if len(inputs) > 2 else None - if isinstance(W, ir.IRNode): - if trans_w: - if not isinstance(W, ir.TensorBox): - W = ir.TensorBox(W) - W = L.permute(W, [1, 0]) - else: - if trans_w: - assert isinstance(W, torch.Tensor) - W = W.transpose(0, 1) - if B is not None: - if isinstance(B, ir.IRNode): - if not isinstance(B, ir.TensorBox): - B = ir.TensorBox(B) - B = L.expand(B, (X.get_size()[0], B.get_size()[-1])) - else: - assert isinstance(B, torch.Tensor) - B = B.expand(X.shape[0], B.shape[-1]) - new_inputs[1] = W - if B is not None: - new_inputs[2] = B - return new_inputs, layout_or_out - - # TODO(jgong5): decide proper number of threads per problem size - num_threads = parallel_num_threads() - new_inputs, _ = normalize_shapes( - *maybe_to_dense(*reorder_and_filter(input_nodes, layout)) - ) - m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) - micro_gemm = create_micro_gemm( - "micro_gemm", - m, - n, - k, - input_dtype=layout.dtype, - output_dtype=torch.float, - alpha=alpha, - num_threads=num_threads, - ) - assert micro_gemm is not None - _, block_n, _ = micro_gemm.register_blocking - - def pack_weight(inputs, layout_or_out): - W = inputs[1] - new_inputs = list(inputs) - if isinstance(W, ir.IRNode): - if not isinstance(W, ir.TensorBox): - W = ir.TensorBox(W) - k, n = W.get_size() - assert ( - n % block_n == 0 - ), f"The last dimension of W must be a multiple of {block_n}." - blocked_w = L.permute( - L.view(W, (k, n // block_n, block_n)), - [1, 0, 2], - ) - blocked_w = ir.ExternKernel.realize_input(blocked_w) - blocked_w = ir.ExternKernel.require_contiguous(blocked_w) - if isinstance(blocked_w, ir.ReinterpretView): - # normalize stride to be "contiguous_strides" per size - # this avoids the problems in L.view during template codegen - assert isinstance(blocked_w.layout, ir.FixedLayout) - blocked_w.layout = ir.FixedLayout( - blocked_w.layout.device, - blocked_w.layout.dtype, - blocked_w.layout.size, - ir.FlexibleLayout.contiguous_strides(blocked_w.layout.size), - blocked_w.layout.offset, - ) - else: - k, n = list(W.shape) - blocked_w = ( - W.reshape(k, n // block_n, block_n).transpose(0, 1).contiguous() - ) - # normalize stride to be "contiguous_strides" per size - # this avoids the problems in L.view during template codegen - new_stride = [1] - for sz in reversed(blocked_w.shape[1:]): - new_stride.insert(0, new_stride[0] * sz) - blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride) - new_inputs[1] = blocked_w - return new_inputs, layout_or_out - - def preprocessor(inputs, layout): - return pack_weight( - *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout))) - ) - - def postprocessor(output): - if isinstance(output, ir.TensorBox): - # prepack the weight as input to the template buffer - # TODO(jgong5): prune the unused constants in V.graph - # Should we implement it with constant folding in the scheduler instead? 
- template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) - assert isinstance(template_buffer, ir.CppTemplateBuffer) - new_input_nodes, _ = reorder_and_filter(input_nodes, layout) - W_node = new_input_nodes[1] - assert W_node.get_name() in V.graph.constants - W = V.graph.constants[W_node.get_name()] - new_input_nodes[1] = W - new_input_nodes, _ = pack_weight( - *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) - ) - W_packed = new_input_nodes[1] - W_packed_constant = V.graph.add_tensor_constant(W_packed) - template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input( - W_packed_constant - ) - return output - - template = DataProcessorTemplateWrapper( - CppPackedGemmTemplate, - preprocessor, - postprocessor, - input_nodes=input_nodes, - layout=layout, - num_threads=num_threads, - register_blocking=micro_gemm.register_blocking, - beta=beta, - alpha=alpha, - ) - template.maybe_append_choice(choices) - return template - - def render( # type: ignore[override] - self, - kernel: CppTemplateKernel, - template_buffer_node: Optional[ir.CppTemplateBuffer] = None, - epilogue_nodes: Optional[List[ir.IRNode]] = None, - **kwargs, - ) -> str: - assert len(self.input_nodes) >= 2 - - X, W = self.input_nodes[0], self.input_nodes[1] - inp = self.input_nodes[2] if len(self.input_nodes) > 2 else None - Y = self.output_node - - if template_buffer_node is not None: - # Use the updated prepacked weight buffer - W = template_buffer_node.inputs[1] - Y = template_buffer_node - - template_buffer = Y - Y_is_transposed = False - use_local_acc = self.layout.dtype != torch.float - if epilogue_nodes: - Y = cast(ir.Buffer, epilogue_nodes[-1]) - assert Y.get_name() in V.kernel.inplace_update_buffers - if Y.get_stride() == list(reversed(template_buffer.get_stride())): - Y_is_transposed = True - - micro_gemm = create_micro_gemm( - f"{kernel.kernel_name}_micro_gemm", - self.m, - self.n, - self.k, - input_dtype=self.layout.dtype, - output_dtype=torch.float, - alpha=self.alpha, - num_threads=self.num_threads, - ) - assert micro_gemm is not None - assert self.register_blocking == micro_gemm.register_blocking - - options = dict( - X=X, - W=W, - inp=inp, - Y=Y, - GemmOut=template_buffer, - beta=self.beta, - alpha=self.alpha, - num_threads=self.num_threads, - micro_gemm=micro_gemm, - is_dynamic_M=self.is_dynamic_M, - template=self, - kernel=kernel, - epilogue_nodes=epilogue_nodes, - reindexer=(lambda x: list(reversed(x))) if Y_is_transposed else None, - use_local_acc=use_local_acc, - ) - return self._template_from_string(GEMM_TEMPLATE).render(**options) diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py deleted file mode 100644 index 375da4ec12581..0000000000000 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ /dev/null @@ -1,447 +0,0 @@ -from collections import namedtuple -from typing import Dict, List, Optional, Type - -import sympy - -import torch - -from .. import ir -from ..codecache import pick_vec_isa, VecAVX2, VecAVX512 -from ..utils import IndentedBuffer, parallel_num_threads -from ..virtualized import V -from .common import KernelTemplate -from .cpp_template_kernel import CppTemplateKernel -from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp - - -class CppMicroGemm: - """ - A class that codegens a kernel that computes small-sized matrix multiplication. - - A micro GEMM kernel is responsible for register blocking, instruction selection, - and other CPU architecture-specific optimizations. 
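
The `reindexer` used in `render()` above handles the case where the epilogue output has the reverse of the GEMM buffer's strides, i.e. it is a transposed view. A small stand-alone illustration (plain PyTorch, hypothetical shapes):

import torch

gemm_out = torch.empty(6, 4)            # strides (4, 1)
y = gemm_out.t()                        # strides (1, 4): reversed strides => transposed view
assert list(y.stride()) == list(reversed(gemm_out.stride()))

reindexer = lambda idx: list(reversed(idx))
i, j = 2, 3
gemm_out[i, j] = 1.0
# storing gemm_out[i, j] through the transposed epilogue buffer uses y[j, i]
assert y[tuple(reindexer([i, j]))] == gemm_out[i, j]
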
- - The subclasses need to override `codegen_define` to define the kernel function - that is called by the code generated by `codegen_call`. - """ - - # TODO(jgong5): support constant shapes and lds as template args. - DECLARE_KERNEL = r""" -template -inline void {{kernel_name}}( - const {{input_t}}* __restrict__ A, - const {{input_t}}* __restrict__ B, - {{output_t}}* __restrict__ C, - int64_t M, - int64_t N, - int64_t K, - int64_t lda, - int64_t ldb, - int64_t ldc -) -""" - - def __init__( - self, - name, - input_dtype, - output_dtype, - compute_dtype, - register_blocking, - alpha=1, - ): - self.name = name - self.input_dtype = input_dtype - self.output_dtype = output_dtype - self.compute_dtype = compute_dtype - self.register_blocking = register_blocking - self.alpha = alpha - - def get_common_options(self): - return { - "torch": torch, - "kernel_name": self.name, - "input_dtype": self.input_dtype, - "output_dtype": self.output_dtype, - "compute_dtype": self.compute_dtype, - "input_t": DTYPE_TO_CPP[self.input_dtype], - "output_t": DTYPE_TO_CPP[self.output_dtype], - "compute_t": DTYPE_TO_CPP[self.compute_dtype], - "alpha": self.alpha, - } - - def get_kernel_declaration(self): - options = self.get_common_options() - return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options) - - def codegen_define(self, kernel: CppTemplateKernel) -> str: - raise NotImplementedError - - def codegen_call( - self, - kernel: CppTemplateKernel, - A: ir.Buffer, - B: ir.Buffer, - C: ir.Buffer, - accum: bool, - ) -> str: - """ - Generate the code for calling the templated kernel that computes - `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise. - """ - A_ptr = f"&({kernel.index(A, [0, 0])})" - B_ptr = f"&({kernel.index(B, [0, 0])})" - C_ptr = f"&({kernel.index(C, [0, 0])})" - M = kernel.size(C, 0) - N = kernel.size(C, 1) - K = kernel.size(A, 1) - lda = kernel.stride(A, 0) - ldb = kernel.stride(B, 0) - ldc = kernel.stride(C, 0) - res = IndentedBuffer() - res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(") - with res.indent(): - res.writeline(f"{A_ptr},") - res.writeline(f"{B_ptr},") - res.writeline(f"{C_ptr},") - res.writeline(f"{M},") - res.writeline(f"{N},") - res.writeline(f"{K},") - res.writeline(f"{lda},") - res.writeline(f"{ldb},") - res.writeline(f"{ldc}") - res.writeline(");") - return res.getvalue() - - -CppMicroGemmConfig = namedtuple( - "CppMicroGemmConfig", - [ - "input_dtype", - "output_dtype", - "compute_dtype", - "vec_isa_cls", - "register_blocking", - ], -) - -micro_gemm_configs: Dict[Type[CppMicroGemm], List[CppMicroGemmConfig]] = {} - - -def register_micro_gemm(*configs): - def inner(cls): - assert ( - cls not in micro_gemm_configs - ), f"Duplicate micro_gemm registration for {cls}" - assert len(configs) > 0, f"No micro_gemm configs provided for {cls}" - micro_gemm_configs[cls] = list(configs) - return cls - - return inner - - -def generate_gemm_config( - vec_isa_cls, - register_blockings, - input_dtype=torch.float, - output_dtype=None, - compute_dtype=None, -): - if output_dtype is None: - output_dtype = input_dtype - if compute_dtype is None: - compute_dtype = output_dtype - return [ - CppMicroGemmConfig( - input_dtype, - output_dtype, - compute_dtype, - vec_isa_cls, - GemmBlocking(*blocking), - ) - for blocking in register_blockings - ] - - -class CppMicroGemmRef(CppMicroGemm): - """ - A reference implementation of the CppMicroGemm class with naive C++ code. - It is used for correctness debugging. 
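
Its contract is small enough to state as a Python sketch (for intuition only, not part of the codegen): the kernel computes `C (+)= alpha * A @ B`, with `accum` selecting between overwrite and accumulate and the accumulation carried out in fp32.

import torch

def micro_gemm_ref(A, B, C, *, accum: bool, alpha: float = 1.0):
    # C (+)= alpha * A @ B, matching the templated `accum` parameter
    out = alpha * (A.float() @ B.float())
    C.copy_(C + out if accum else out)

A = torch.randn(8, 16, dtype=torch.bfloat16)
B = torch.randn(16, 48, dtype=torch.bfloat16)
C = torch.zeros(8, 48)                 # fp32 accumulator, as in the output_dtype above
micro_gemm_ref(A, B, C, accum=False)
micro_gemm_ref(A, B, C, accum=True)    # second call accumulates on top
torch.testing.assert_close(C, 2 * (A.float() @ B.float()))
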
- """ - - TEMPLATE_ENTRY = r""" -{{declare_kernel}} { - for (int64_t m = 0; m < M; ++m) { - for (int64_t n = 0; n < N; ++n) { - {{compute_t}} result = accum ? C[m * ldc + n] : 0; - for (int64_t k = 0; k < K; ++k) { - result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}}; - } - C[m * ldc + n] = result; - } - } -} -""" - - def __init__(self, name, input_dtype, output_dtype, compute_dtype, alpha): - super().__init__( - name, input_dtype, output_dtype, compute_dtype, GemmBlocking(1, 1, 1), alpha - ) - - def codegen_define(self, kernel: CppTemplateKernel) -> str: - options = { - "declare_kernel": self.get_kernel_declaration(), - **self.get_common_options(), - } - return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options) - - -@register_micro_gemm( - *generate_gemm_config( - VecAVX512, [(8, 48, 1), (8, 32, 1), (16, 16, 1)], input_dtype=torch.float - ), - *generate_gemm_config( - VecAVX512, - [(8, 48, 1), (8, 32, 1), (16, 16, 1)], - input_dtype=torch.bfloat16, - output_dtype=torch.float, - ), - *generate_gemm_config( - VecAVX512, - [(8, 48, 1), (8, 32, 1), (16, 16, 1)], - input_dtype=torch.half, - output_dtype=torch.float, - ), - *generate_gemm_config( - VecAVX2, [(4, 24, 1), (4, 16, 1), (8, 8, 1)], input_dtype=torch.float - ), - *generate_gemm_config( - VecAVX2, - [(4, 24, 1), (4, 16, 1), (8, 8, 1)], - input_dtype=torch.bfloat16, - output_dtype=torch.float, - ), - *generate_gemm_config( - VecAVX2, - [(4, 24, 1), (4, 16, 1), (8, 8, 1)], - input_dtype=torch.half, - output_dtype=torch.float, - ), -) -class CppMicroGemmFP32Vec(CppMicroGemm): - """ - This class generates the code for micro gemm using fp32 vec instructions for compute. - It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output. 
- """ - - TEMPLATE_ENTRY = r""" -{{declare_kernel}} { - TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); - TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); - // TODO(jgong5): loop unroll for M and N - for (int64_t m = 0; m < M; m += {{block_m}}) { - int64_t block_m = std::min(M - m, {{block_m}}); - for (int64_t n = 0; n < N; n += {{block_n}}) { - if (block_m == {{block_m}}) { - {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>( - A + m * lda, - B + n, - C + m * ldc + n, - K, - lda, - ldb, - ldc - ); - } else { - switch (block_m) { - {%- for b in range(block_m - 1, 0, -1) %} - case {{b}}: - {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>( - A + m * lda, - B + n, - C + m * ldc + n, - K, - lda, - ldb, - ldc - ); - break; - {%- endfor %} - default: - {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); - } - } - } - } -} -""" - - TEMPLATE_KERNEL = r""" -template -inline void {{kernel_name}}_kernel( - const {{input_t}}* __restrict__ A, - const {{input_t}}* __restrict__ B, - {{output_t}}* __restrict__ C, - int64_t K, - int64_t lda, - int64_t ldb, - int64_t ldc -) { - using Vectorized = at::vec::Vectorized<{{compute_t}}>; - using VectorizedIn = at::vec::Vectorized<{{input_t}}>; - constexpr auto VLEN = Vectorized::size(); - constexpr auto ROWS = BLOCK_M; - constexpr auto COLS = BLOCK_N / VLEN; - - Vectorized va; - at::vec::VectorizedN<{{compute_t}}, COLS> vb; - at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc; - - auto loadc = [&](auto i) { - if constexpr (accum) { - constexpr int row = i / COLS; - constexpr int col = i % COLS; - vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); - } else { - vc[i] = Vectorized(0.0f); - } - }; - c10::ForcedUnroll{}(loadc); - - auto compute = [&, COLS](auto i, int k) { - constexpr int row = i / COLS; - constexpr int col = i % COLS; - - if constexpr (col == 0) { - {%- if alpha != 1 %} - va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}}); - {%- else %} - va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k])); - {%- endif %} - } - - if constexpr (row == 0) { - {%- if input_dtype == torch.bfloat16 or input_dtype == torch.float16 %} - auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN); - vb[col] = at::vec::convert<{{compute_t}}>(b); - {%- else %} - vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN); - {%- endif %} - } - - constexpr int idx = row * COLS + col; - vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); - }; - - {{kernel.unroll_pragma(4)}} - for (int k = 0; k < K; ++k) { - c10::ForcedUnroll{}(compute, k); - } - - // store to C - auto storec = [&](auto i) { - constexpr int row = i / COLS; - constexpr int col = i % COLS; - vc[i].store(C + row * ldc + col * VLEN); - }; - c10::ForcedUnroll{}(storec); -} -""" - - def codegen_define(self, kernel: CppTemplateKernel) -> str: - options = { - "declare_kernel": self.get_kernel_declaration(), - "kernel": kernel, - "block_m": self.register_blocking.block_m, - "block_n": self.register_blocking.block_n, - "block_k": self.register_blocking.block_k, - **self.get_common_options(), - } - result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( - options - ) - result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( - options - ) - return result - - -def create_micro_gemm( - name, - m, - n, - k, - input_dtype, - output_dtype=None, - compute_dtype=None, - alpha=1, - num_threads=-1, - use_ref=False, -) -> Optional[CppMicroGemm]: - def create_from_config(cls, 
config: CppMicroGemmConfig): - return cls( - name, - config.input_dtype, - config.output_dtype, - config.compute_dtype, - config.register_blocking, - alpha, - ) - - assert isinstance(n, int) or n.is_number, n - assert isinstance(k, int) or k.is_number, k - m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m - assert isinstance(m, int), m - if output_dtype is None: - output_dtype = input_dtype - if compute_dtype is None: - compute_dtype = output_dtype - if num_threads < 0: - num_threads = parallel_num_threads() - vec_isa = pick_vec_isa() - matched_configs = [] - for cls, configs in micro_gemm_configs.items(): - for config in configs: - if not isinstance(vec_isa, config.vec_isa_cls): - continue - if ( - config.input_dtype == input_dtype - and config.output_dtype == output_dtype - and config.compute_dtype == compute_dtype - ): - block_m, block_n, block_k = config.register_blocking - # TODO(jgong5): support n % n_block_size != 0 - if n % block_n != 0: - continue - # Criteria on the ranking of configurations - # 1. Dividable by block sizes (block_m, block_k) - # 2. Number of mxn blocks is large enough to occupy all the threads - # 3. Register blocks are larger - dividable_score = 0 - if k % block_k == 0: - dividable_score += 1 - if m % block_m == 0: - dividable_score += 1 - occupancy_score = 0 - n_blocks = n // block_n - total_mxn_blocks = n // block_n * ((m + block_m - 1) // block_m) - if n_blocks >= num_threads: - occupancy_score += 1 - if total_mxn_blocks >= num_threads: - occupancy_score += 1 - matched_configs.append( - ( - (dividable_score, occupancy_score, block_m * block_n * block_k), - cls, - config, - ) - ) - if len(matched_configs) == 0: - if use_ref: - return CppMicroGemmRef( - name, input_dtype, output_dtype, compute_dtype, alpha - ) - else: - return None - # TODO(jgong5): allow autotuning on choices of configs - return create_from_config(*max(matched_configs, key=lambda x: x[0])[1:]) diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 45f874fc4d269..7e3483ca99948 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -5,7 +5,6 @@ #include #include #include -#include #include // WARNING: be extra careful when including more ATen/c10 header files here! 
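
The ranking in `create_micro_gemm` above reduces to the following sketch (the `rank_key` helper is hypothetical): among configurations whose dtypes and ISA match and whose `block_n` divides N, prefer blockings that divide M and K evenly, then blockings that produce at least `num_threads` M x N blocks, then larger register tiles.

import math

def rank_key(m, n, k, num_threads, blocking):
    block_m, block_n, block_k = blocking
    dividable = int(k % block_k == 0) + int(m % block_m == 0)
    n_blocks = n // block_n
    total_mxn_blocks = n_blocks * math.ceil(m / block_m)
    occupancy = int(n_blocks >= num_threads) + int(total_mxn_blocks >= num_threads)
    return (dividable, occupancy, block_m * block_n * block_k)

candidates = [(8, 48, 1), (8, 32, 1), (16, 16, 1)]     # the AVX512 fp32 blockings above
m, n, k, num_threads = 8, 1024, 64, 32
valid = [b for b in candidates if n % b[1] == 0]       # drops (8, 48, 1): 48 does not divide 1024
best = max(valid, key=lambda b: rank_key(m, n, k, num_threads, b))
print(best)   # (8, 32, 1): M stays divisible by block_m (16 would not), and its
              # 32 N-blocks are enough to occupy all 32 threads
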
@@ -46,7 +45,7 @@ template struct Welford { T mean = T(0); T m2 = T(0); - T weight = T(0); + int64_t index = 0; }; @@ -59,41 +58,57 @@ struct IsVecType>: std::true_type {}; #endif template -Welford welford_combine(const Welford &a, const Welford &b) { - if constexpr (!IsVecType::value) { - if (a.weight == 0) { - return b; - } - if (b.weight == 0) { - return a; +struct WeightRecp { + using scalar_t = typename T::value_type; + int64_t N; + std::vector weight_recps; + WeightRecp(int64_t N) : N(N) { + weight_recps.reserve(N); + for (const auto i : c10::irange(N)) { + weight_recps.push_back( + scalar_t(static_cast(1) / static_cast(i + 1))); } } - auto delta = b.mean - a.mean; - auto new_weight = a.weight + b.weight; - auto wb_over_w = b.weight / new_weight; - if constexpr (IsVecType::value) { - // Guard against division by zero - wb_over_w = T::blendv(wb_over_w, T(0), new_weight == T(0)); +}; + +template +Welford welford_combine(const Welford &a, const Welford &b) { + if (a.index == 0) { + return b; } + if (b.index == 0) { + return a; + } + auto delta = b.mean - a.mean; + auto new_index = a.index + b.index; + auto wb_over_w = T(b.index) / T(new_index); auto result = Welford{ a.mean + delta * wb_over_w, - a.m2 + b.m2 + delta * delta * a.weight * wb_over_w, - new_weight + a.m2 + b.m2 + delta * delta * T(a.index) * wb_over_w, + new_index, }; return result; } template -Welford welford_combine(const Welford &acc, T data) { +Welford welford_combine(const Welford &acc, T data, const WeightRecp* w=nullptr) { // Add a single data point + int64_t index = acc.index + 1; auto delta = data - acc.mean; - auto new_weight = acc.weight + T(1); - auto new_mean = acc.mean + delta / new_weight; + T new_mean; + if constexpr (!IsVecType::value) { + new_mean = acc.mean + delta / T(index); + } else { + new_mean = acc.mean + + ((w == nullptr || acc.index >= w->weight_recps.size()) + ? 
delta / T(index) + : delta * T(w->weight_recps[acc.index])); + } auto new_delta = data - new_mean; auto result = Welford{ new_mean, acc.m2 + delta * new_delta, - new_weight + index }; return result; } @@ -178,10 +193,11 @@ template Welford welford_vec_reduce_all(Welford> acc) { using Vec = at::vec::Vectorized; for (size_t n = 1; n < Vec::size(); n *= 2) { + auto index = acc.index; auto shuffled = Welford{ vec_shuffle_down(acc.mean, n), vec_shuffle_down(acc.m2, n), - vec_shuffle_down(acc.weight, n) + index, }; acc = welford_combine(acc, shuffled); } @@ -194,8 +210,7 @@ Welford welford_vec_reduce_all(Welford> acc.m2.store(array); result.m2 = array[0]; - acc.weight.store(array); - result.weight = array[0]; + result.index = acc.index; return result; } @@ -294,100 +309,3 @@ atomic_add(volatile T *addr, T offset) { std::atomic *atomic_addr = (std::atomic *)addr; atomic_addr->fetch_add(offset, std::memory_order_relaxed); } - -std::tuple mm_get_thread_blocking( - int64_t M, - int64_t N, - int64_t K, - int64_t M0, - int64_t N0, - int64_t K0, - int num_threads) { - auto get_factors = [](int64_t number) { - int count = 0; - for (int64_t i = std::sqrt(number); i > 0; --i) { - if (number % i == 0) { - count += 2; - } - } - auto factors = std::make_unique(count); - int index = 0; - for (int64_t i = std::sqrt(number); i > 0; --i) { - if (number % i == 0) { - factors[index++] = number / i; - factors[index++] = i; - } - } - return std::make_tuple(std::move(factors), count); - }; - - auto get_blocking = [](int64_t num_threads, - int64_t factor, - int64_t m_blocks, - int64_t n_blocks, - int64_t k_blocks) { - int64_t thread_block_n = (n_blocks + factor - 1) / factor; - int64_t cofactor = num_threads / factor; - int64_t thread_block_m = (m_blocks + cofactor - 1) / cofactor; - return std::make_tuple(thread_block_m, thread_block_n, k_blocks); - }; - - int64_t m_blocks = (M + M0 - 1) / M0; - int64_t n_blocks = (N + N0 - 1) / N0; - int64_t k_blocks = (K + K0 - 1) / K0; - - auto [factors, count] = get_factors(num_threads); - assert(count > 0); - - for (int i = 0; i < count; ++i) { - int64_t factor = factors[i]; - if (n_blocks % factor == 0 && - m_blocks % (num_threads / factor) == 0) { - return get_blocking( - num_threads, factor, m_blocks, n_blocks, k_blocks); - } - } - - for (int i = 0; i < count; ++i) { - int64_t factor = factors[i]; - if (n_blocks % factor == 0) { - return get_blocking( - num_threads, factor, m_blocks, n_blocks, k_blocks); - } - int64_t cofactor = num_threads / factor; - if (m_blocks % cofactor == 0) { - return get_blocking( - num_threads, factor, m_blocks, n_blocks, k_blocks); - } - } - - assert(false && "Should not reach here."); - // Dummy return to avoid compiler warning - return std::make_tuple(0, 0, 0); -} - -inline void mm_get_thread_blocks( - int thread_id, - int64_t M_blocks, - int64_t N_blocks, - int64_t K_blocks, - int64_t Mt_blocks, - int64_t Nt_blocks, - int64_t Kt_blocks, - int64_t& m_block_start, - int64_t& m_block_end, - int64_t& n_block_start, - int64_t& n_block_end, - int64_t& k_block_start, - int64_t& k_block_end) { - int64_t num_Kt = (K_blocks + Kt_blocks - 1) / Kt_blocks; - k_block_start = (thread_id % num_Kt) * Kt_blocks; - k_block_end = std::min(k_block_start + Kt_blocks, K_blocks); - thread_id /= num_Kt; - int64_t num_Nt = (N_blocks + Nt_blocks - 1) / Nt_blocks; - n_block_start = (thread_id % num_Nt) * Nt_blocks; - n_block_end = std::min(n_block_start + Nt_blocks, N_blocks); - thread_id /= num_Nt; - m_block_start = std::min(thread_id * Mt_blocks, M_blocks); - m_block_end 
= std::min(m_block_start + Mt_blocks, M_blocks); -} diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py deleted file mode 100644 index 222d6a2e57ba0..0000000000000 --- a/torch/_inductor/codegen/cpp_template.py +++ /dev/null @@ -1,116 +0,0 @@ -import functools -import itertools -import logging - -import sys -from typing import List, Optional -from unittest.mock import patch - -import sympy - -from .. import codecache, config, ir -from ..autotune_process import CppBenchmarkRequest, TensorMeta -from ..utils import IndentedBuffer, Placeholder, unique -from ..virtualized import V -from .common import KernelTemplate -from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel - -log = logging.getLogger(__name__) - - -class CppTemplate(KernelTemplate): - index_counter = itertools.count() - - def __init__( - self, - name: str, - input_nodes, - layout: ir.Layout, - ): - super().__init__(name) - self.input_nodes = input_nodes - self.output_node: ir.Buffer = ir.Buffer("buf_out", layout) - self.layout = layout - - def generate(self, **kwargs): - kernel_name = f"cpp_{self.name}" - with patch.object( - V.graph, "get_dtype", self._fake_get_dtype(self.output_node) - ), CppTemplateKernel( - kernel_name=kernel_name, - ) as kernel: - code = kernel.render(self, **kwargs) - _, call_args, _ = kernel.args.python_argdefs() - log.debug("Generated Code:\n%s", code) - log.debug( - "Args: cpp_argdefs: %s, python_argdefs: %s", - kernel.args.cpp_argdefs(), - kernel.args.python_argdefs(), - ) - - expected_args = list( - unique(input_node.get_name() for input_node in self.input_nodes) - ) - expected_args.extend([self.output_node.get_name()]) - assert list(call_args)[: len(expected_args)] == expected_args, ( - call_args, - expected_args, - ) - extra_args = V.graph.sizevars.size_hints( - map(sympy.expand, call_args[len(expected_args) :]) - ) - - kernel_hash_name = f"cpp_{self.name}_{next(self.index_counter)}" - - # Create the BenchmarkRequest for CPP - bmreq = CppBenchmarkRequest( - kernel_name=kernel_name, - input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), - output_tensor_meta=TensorMeta.from_irnodes(self.output_node), - extra_args=extra_args, - source_code=code, - ) - - def make_kernel_render( - template_node: ir.CppTemplateBuffer, - epilogue_nodes: Optional[List[ir.IRNode]] = None, - ): - kernel = CppTemplateKernel( - kernel_name=str(Placeholder.KERNEL_NAME), - ) - render = functools.partial( - kernel.render, - self, - template_buffer_node=template_node, - epilogue_nodes=epilogue_nodes, - **kwargs, - ) - return kernel, render - - return CppTemplateCaller( - kernel_hash_name, - self.name, - self.input_nodes, - self.output_node.get_layout(), - make_kernel_render, - bmreq, - self, - ) - - def header(self) -> IndentedBuffer: - res = IndentedBuffer() - res.writeline(codecache.cpp_prefix()) - res.splice( - """ - #include "c10/util/Unroll.h" - """ - ) - enable_kernel_profile = ( - config.cpp.enable_kernel_profile and sys.platform == "linux" - ) - if enable_kernel_profile: - res.writelines(["#include "]) - return res - - def render(self, **kwargs) -> str: - raise NotImplementedError diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py deleted file mode 100644 index deff54e10eb99..0000000000000 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ /dev/null @@ -1,329 +0,0 @@ -import itertools -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import sympy -from 
sympy.parsing.sympy_parser import parse_expr - -import torch -from .. import codecache, config, ir, lowering as L - -from ..autotune_process import CppBenchmarkRequest -from ..select_algorithm import PartialRender -from ..utils import sympy_index_symbol -from ..virtualized import V -from .common import Kernel, OpOverrides -from .cpp import CppKernelProxy, KernelGroup -from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferScope - - -def parse_expr_with_index_symbols(expr): - if isinstance(expr, sympy.Expr): - return expr - elif isinstance(expr, (list, tuple)): - return [parse_expr_with_index_symbols(e) for e in expr] - else: - expr = parse_expr(str(expr)) - int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols} - return expr.subs(int_symbols) - - -def wrap_with_tensorbox(node) -> ir.TensorBox: - return ( - ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node) - ) - - -class CppTemplateKernel(Kernel): - overrides = OpOverrides - - def __init__(self, kernel_name): - super().__init__() - self.kernel_name = kernel_name - self.render_hooks = {} - self.local_buffers = {} - - def render(self, template, **kwargs): - return PartialRender( - template.render(kernel=self, **kwargs), self.render_hooks - ).finalize() - - def def_kernel( - self, - inputs: Dict[str, ir.Buffer], - outputs: Dict[str, ir.Buffer], - ) -> str: - for name, inp in inputs.items(): - if inp is not None: - self.args.input_buffers[inp.get_name()] = name - for name, out in outputs.items(): - if out.get_name() not in self.args.inplace_buffers: - self.args.output_buffers[out.get_name()] = name - unique_sizevars = { - s - for input in inputs.values() - if input is not None - for sym in itertools.chain(input.get_size(), input.get_stride()) - if isinstance(sym, sympy.Expr) - for s in sym.free_symbols - } - unique_sizevars |= { - s - for output in outputs.values() - for sym in itertools.chain(output.get_size(), output.get_stride()) - if isinstance(sym, sympy.Expr) - for s in sym.free_symbols - } - sizevars = sorted(unique_sizevars, key=str) - for sizevar in sizevars: - self.args.sizevars[sizevar] = f"k{sizevar}" - - def hook(): - cpp_argdefs, _, _ = self.args.cpp_argdefs() - return f"void {self.kernel_name}({', '.join(cpp_argdefs)})" - - placeholder = "" - assert placeholder not in self.render_hooks - self.render_hooks[placeholder] = hook - return placeholder - - def call_kernel(self, name: str, node: ir.CppTemplateBuffer): - wrapper = V.graph.wrapper_code - _, call_args, arg_types = self.args.cpp_argdefs() - wrapper.generate_kernel_call(name, call_args, cuda=False, arg_types=arg_types) - - def dtype(self, node: ir.Buffer) -> str: - return DTYPE_TO_CPP[node.get_dtype()] - - def acc_dtype(self, node: ir.Buffer) -> str: - if node.get_dtype() in [torch.float32, torch.bfloat16, torch.half]: - return "float" - else: - raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") - - def size(self, node: ir.Buffer, dim: int) -> str: - return cexpr_index(self.rename_indexing(node.get_size()[dim])) - - def stride(self, node: ir.Buffer, dim: int) -> str: - return cexpr_index(self.rename_indexing(node.get_stride()[dim])) - - def index(self, node: ir.Buffer, indices: List[Any]) -> str: - indexer = node.make_indexer() - index = indexer(parse_expr_with_index_symbols(indices)) - index = self.rename_indexing(index) - outer_name = node.get_name() - inner_name = ( - outer_name - if outer_name in self.local_buffers - else self.args.input(node.get_name()) - ) - return 
f"{inner_name}[{cexpr_index(index)}]" - - def slice_nd(self, node, ranges: List[Tuple[Any, Any]]) -> ir.ReinterpretView: - """ - Slice the given node with a list of ranges (start and end) corresponding to its dims. - The dim is not sliced if the corresponding range is empty. - """ - assert len(ranges) == len(node.get_size()) - sliced = wrap_with_tensorbox(node) - for dim, _range in enumerate(ranges): - if len(_range) == 0: - continue - assert len(_range) == 2 - start, end = parse_expr_with_index_symbols(_range) - sliced = L.slice_(sliced, dim, start, end, clamp=False) - assert isinstance(sliced.data, ir.ReinterpretView) - return sliced.data - - def view(self, node, sizes: List[Any]) -> ir.View: - node = wrap_with_tensorbox(node) - sizes = parse_expr_with_index_symbols(sizes) - return L.view(node, sizes).data - - def permute(self, node, dims): - node = wrap_with_tensorbox(node) - permuted = L.permute(node, dims).data - assert isinstance(permuted, ir.ReinterpretView) - return permuted - - @property - def assert_function(self) -> str: - if V.graph.aot_mode: - return "AOTI_TORCH_CHECK" - else: - return "TORCH_CHECK" - - def maybe_codegen_profile(self) -> str: - if config.cpp.enable_kernel_profile: - graph_id = V.graph.graph_id - prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" - return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef({{}}));' - else: - return "" - - def unroll_pragma(self, unroll): - if codecache.is_gcc(): - return f"#pragma GCC unroll {unroll}" - else: - return f"#pragma unroll {unroll}" - - def define_buffer(self, name, sizes: List[Any], dtype=torch.float) -> str: - """Define kernel local buffer""" - sizes = parse_expr_with_index_symbols(sizes) - buf = ir.Buffer(name, ir.FixedLayout(torch.device("cpu"), dtype, sizes)) - self.local_buffers[name] = buf - ctype = f"{DTYPE_TO_CPP[dtype]}" - numel = f"{cexpr_index(buf.get_numel())}" - return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();" - - def store_pointwise_nodes( - self, - dst: ir.Buffer, - nodes: List[ir.IRNode], - offsets: Optional[List[sympy.Expr]] = None, - reindexer: Optional[Callable[[List[Any]], List[Any]]] = None, - ) -> str: - var_sizes = (tuple(dst.get_size()), ()) - var_ranges = {sympy.Symbol(f"z{i}"): sz for i, sz in enumerate(var_sizes[0])} - if not offsets: - offsets = [sympy.Integer(0)] * len(var_sizes[0]) - assert len(offsets) == len(var_sizes[0]) - output_index = dst.get_layout().make_indexer()(var_ranges.keys()) - kernel_group = KernelGroup() - kernel_group.args = self.args - cpp_kernel_proxy = CppKernelProxy(kernel_group) - bodies = [] - var_sizes_list = [] - for i, node in enumerate(nodes): - output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name() - node = node.data if isinstance(node, ir.ComputedBuffer) else node - assert isinstance(node, ir.Pointwise), node - - def fn(*args): - assert len(args) == 2 - assert len(args[0]) == len(var_sizes[0]) - assert len(args[1]) == 0 - new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type] - if reindexer is not None: - new_args = reindexer(new_args) - V.ops.store( - output_name, - output_index, - node.make_loader()(new_args).value, - ) - - body = ir.LoopBody(fn, (list(var_ranges.keys()), ()), var_ranges) - bodies.append(body) - var_sizes_list.append(var_sizes) - - cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) - kernel_group.finalize_kernel(cpp_kernel_proxy, []) - return kernel_group.loops_code.getvalue() - - def store_output( - 
self, - dst: ir.Buffer, - src: ir.Buffer, - epilogue_nodes: Optional[List[ir.IRNode]] = None, - offsets: Optional[List[Any]] = None, - reindexer: Optional[Callable[[List[Any]], List[Any]]] = None, - ): - """ - Store the `src` buffer to the `dst` buffer. The size of `src` and `dst` should match. - If `epilogue_nodes` is provided, the `src` buffer is firstly computed with the epilogues - before stored to `dst`. The `epilogues_nodes` are all pointwise. - - Notes: - 1. `src` and `dst` buffer could be the same buffer in which case we are doing in-place compute - and stores. In case `epilogue_nodes` are not provided, we do nothing. - 2. The `epilogue_nodes`, if exist, have computations on `src` before storing to `dst` but since - they come form the original Inductor IR, they might need to be adjusted before working with - `src` and `dst` as outlined below: - a) `src` or `dst` buffer could be a sub-slice of the ranges the `epilogue_nodes`work on. - In this case, the `offsets` could be provided to adjust the indices passed to - `epilogue_nodes` during codegen and the data ranges are also configured according to - the sizes of `src` and `dst`. - b) `dst` might be indexed in a different way as the `epilogue_nodes`, hence a `reindexer` is - needed on the indices to `epilogue_nodes` to match the indexing of `dst`. - """ - assert dst.get_size() == src.get_size() - if offsets: - offsets = parse_expr_with_index_symbols(offsets) - if epilogue_nodes: - return self.store_pointwise_nodes(dst, epilogue_nodes, offsets, reindexer) - else: - if dst.get_name() != src.get_name(): - # src is local - copy = L.copy(dst, src).data.data - with LocalBufferScope(self) as scope: - scope.add_local_buffer(src) - return self.store_pointwise_nodes(dst, [copy]) - else: - assert dst.layout == src.layout - return "" - - -class CppTemplateCaller(ir.ChoiceCaller): - """ - CppTemplateCaller - - This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller. - Attributes: - name (str): The name of the caller. - category (str): The category of the caller. - bmreq (CppBenchmarkRequest): The benchmark request for the caller. - template_buffer (ir.CppTemplateBuffer): The template buffer for the caller. 
- """ - - def __init__( - self, - name: str, - category: str, - input_nodes: List[ir.Buffer], - layout: ir.Layout, - make_kernel_render: Callable[ - [ir.CppTemplateBuffer, Optional[List[ir.IRNode]]], str - ], - bmreq: CppBenchmarkRequest, - template: "CppTemplate", # type: ignore[name-defined] # noqa: F821 - info_kwargs: Optional[ - Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]] - ] = None, - ): - super().__init__(name, input_nodes, layout) - self.category = category - self.make_kernel_render = make_kernel_render - self.bmreq = bmreq - self.template = template - self.info_kwargs = info_kwargs - - def precompile(self) -> None: - assert self.bmreq is not None - self.bmreq.precompile() - - def benchmark(self, *args, out) -> float: - assert self.bmreq is not None - return self.bmreq.benchmark(*args, output_tensor=out) - - def hash_key(self) -> str: - return "-".join( - [ - self.category, - self.bmreq.hash_key, - ] - ) - - def info_dict( - self, - ) -> Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]]: - return {"backend": "CPP", "op_type": "unknown"} - - def output_node(self) -> ir.TensorBox: - return ir.TensorBox.create( - ir.CppTemplateBuffer( - layout=self.layout, - inputs=self.input_nodes, - make_kernel_render=self.make_kernel_render, - template=self.template, - choice=self, - ) - ) diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 54904f33f20bc..7e6f06b9e507e 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -1,15 +1,8 @@ -import contextlib import math -from collections import namedtuple -from typing import Dict -from unittest.mock import patch - import torch -from .. import ir -from ..virtualized import V -from .common import ExprPrinter, Kernel +from .common import ExprPrinter DTYPE_TO_CPP = { torch.float32: "float", @@ -62,8 +55,6 @@ INDEX_TYPE = "long" -GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) - class CppPrinter(ExprPrinter): def _print_Integer(self, expr): @@ -241,58 +232,3 @@ def value_to_cpp(value, cpp_type): return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" else: return f"static_cast<{cpp_type}>({repr(value)})" - - -class LocalBufferScope: - """ - This class creates a context that helps to generate code involving Inductor IR with - function local buffers. These buffers are constructed during the codegen process and - are used to store intermediate results such as local accumulators. We do not want to - add them to `V.graph` since they are not global and we do not want to add them as - function arguments either. So we patch the codegen processes under this scope to support - these buffers without exposure to the outside world. 
- """ - - def __init__(self, kernel: Kernel): - self.kernel = kernel - self.exit_stack = contextlib.ExitStack() - self.local_buffers: Dict[str, ir.Buffer] = {} - - def __enter__(self): - self.exit_stack.__enter__() - original_get_dtype = V.graph.get_dtype - - def get_dtype(name): - if name in self.local_buffers: - return self.local_buffers[name].get_dtype() - return original_get_dtype(name) - - self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype)) - - original_input = self.kernel.args.input - - def input(name): - if name in self.local_buffers: - return name - return original_input(name) - - self.exit_stack.enter_context(patch.object(self.kernel.args, "input", input)) - - original_output = self.kernel.args.output - - def output(name): - if name in self.local_buffers: - return name - return original_output(name) - - self.exit_stack.enter_context(patch.object(self.kernel.args, "output", output)) - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.local_buffers.clear() - self.exit_stack.__exit__(exc_type, exc_val, exc_tb) - - def add_local_buffer(self, buffer: ir.Buffer): - assert buffer.get_name() not in self.local_buffers - self.local_buffers[buffer.get_name()] = buffer diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 6ce230714632a..9595f1da6f957 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -552,10 +552,8 @@ def write_wrapper_decl(self): ), "Fails to get the dtype of the sympy.Expr" cpp_dtype = DTYPE_TO_CPP[dtype] if config.abi_compatible: - self.prefix.writeline(f"{cpp_dtype} {input_key};") - dtype_str = str(dtype).split(".")[-1] - self.prefix.writeline( - f"aoti_torch_item_{dtype_str}(inputs[{idx}], &{input_key});" + self.codegen_tensor_item( + dtype, f"inputs[{idx}]", input_key, self.prefix ) else: self.prefix.writeline( @@ -890,6 +888,19 @@ def codegen_scalar_to_tensor(self, output: str): ) return name + def codegen_tensor_item( + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + ): + assert ( + config.abi_compatible + ), "codegen_tensor_item is only used for the ABI-compatible mode" + dtype_str = str(dtype).split(".")[-1] + writer = indented_buffer or self + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" + ) + @cache_on_self def get_output_refs(self): return [ @@ -1120,7 +1131,7 @@ def g(args): ) def get_c_shim_func_name(self, kernel): - if not config.abi_compatible: + if not config.abi_compatible or kernel.startswith("aoti_torch_"): return kernel assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'" @@ -1376,10 +1387,9 @@ def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: def codegen_dynamic_scalar(self, node): (data,) = (t.codegen_reference() for t in node.inputs) if config.abi_compatible: - dtype = node.inputs[0].get_dtype() - dtype_str = str(dtype).split(".")[-1] - self.writeline(f"{DTYPE_TO_CPP[dtype]} {node.sym}_raw;") - self.writeline(f"aoti_torch_item_{dtype_str}({data}, &{node.sym}_raw);") + self.codegen_tensor_item( + node.inputs[0].get_dtype(), data, f"{node.sym}_raw" + ) else: convert_type = DTYPE_TO_ATEN[node.inputs[0].get_dtype()].replace( "at::k", "to" @@ -1763,12 +1773,13 @@ def codegen_conditional(self, conditional): outer_outputs.append(out.get_name()) if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): - predicate = 
f"{conditional.predicate.get_name()}_scalar" - self.writeline(f"bool {predicate};") # in ABI-compatible mode, we need to use the ABI shim function # to extract a C++ bool from the unrelying scalar bool Tensor - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool({conditional.predicate.codegen_reference()}, &{predicate}));" + predicate = f"{conditional.predicate.get_name()}_scalar" + self.codegen_tensor_item( + torch.bool, + conditional.predicate.codegen_reference(), + predicate, ) else: # the predicate is not a Tensor: SymBool or Python bool @@ -1847,12 +1858,7 @@ def codegen_while_loop(self, while_loop): if config.abi_compatible: cond_result = f"{cond_result_name}_scalar" - self.writeline(f"bool {cond_result};") - # in ABI-compatible mode, we need to use the ABI shim function - # to extract a C++ bool from the unrelying scalar bool Tensor - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool({cond_result_name}, &{cond_result}));" - ) + self.codegen_tensor_item(torch.bool, cond_result_name, cond_result) else: cond_result = f"{cond_result_name}.item()" self.writeline(f"if (!{cond_result}) break;") diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py new file mode 100644 index 0000000000000..2a002fa3677f6 --- /dev/null +++ b/torch/_inductor/codegen/simd.py @@ -0,0 +1,1917 @@ +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +from typing import ( + Any, + Callable, + Counter, + DefaultDict, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) + +import sympy + +import torch +import torch._logging + +from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT +from ..._dynamo.utils import counters +from .. 
import config, ir, scheduler +from ..codecache import code_hash + +from ..dependencies import Dep, MemoryDep, StarDep, WeakDep +from ..ir import TritonTemplateBuffer +from ..optimize_indexing import indexing_dtype_strength_reduction +from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK +from ..runtime.runtime_utils import get_max_y_grid, green_text, yellow_text +from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse +from ..utils import ( + get_dtype_size, + IndentedBuffer, + Placeholder, + sympy_dot, + sympy_index_symbol, + sympy_product, + sympy_subs, + unique, +) +from ..virtualized import V +from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter +from .multi_kernel import MultiKernel + +if TYPE_CHECKING: + pass + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") + + +pexpr = PythonPrinter().doprint + + +@dataclasses.dataclass +class IndexingOptions: + index_str: str + mask_vars: Set[sympy.Symbol] + mask_str: str + expand_str: Optional[str] + _has_rindex: bool + index: sympy.Expr + + def has_mask(self): + return bool(self.mask_vars) + + def has_rindex(self): + return self._has_rindex + + def has_tmpmask(self): + return "tmp" in self.mask_str + + def has_rmask(self): + return "rmask" in self.mask_str + + +@dataclasses.dataclass +class IterationRanges: + """ + Each range tree represents multiple sets of iteration indexing + in a single tiled dimension in the output kernel. + + If you have two loops ranges one (4, 3, 2) and another (4, 6), + then the range tree will be: + 4 (i0) + 3 (i1) 6 (i3) + 2 (i2) + Where i0 is shared between both loops, but then the split into + different indexing vars. All loop ranges must iterate over + the same number of elements. 
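
Concretely (plain integer arithmetic rather than the sympy expressions the code builds), the (4, 3, 2) / (4, 6) example in the docstring above decomposes a flat index with the same FloorDiv / ModularIndexing pattern that `IterationRangesRoot.lookup()` emits:

numel = 4 * 3 * 2
for xindex in range(numel):
    i0 = xindex // (3 * 2)       # FloorDiv(xindex, 6)            length 4, shared by both loops
    i1 = (xindex // 2) % 3       # ModularIndexing(xindex, 2, 3)  length 3, first loop only
    i2 = xindex % 2              # ModularIndexing(xindex, 1, 2)  length 2, first loop only
    i3 = xindex % 6              # ModularIndexing(xindex, 1, 6)  length 6, second loop only
    # both loops walk the same 24 elements, just with different splits
    assert xindex == (i0 * 3 + i1) * 2 + i2 == i0 * 6 + i3
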
+ """ + + def __init__( + self, + name: str, + var_list: List[sympy.Symbol], + var_ranges: Dict[sympy.Symbol, sympy.Expr], + numel: sympy.Expr, + prefix: str, + *, + kernel: SIMDKernel, + divisor=sympy.Integer(1), + length=sympy.Integer(1), + root: IterationRangesRoot, + ): + super().__init__() + self.name = name + self.var_list = var_list + self.var_ranges = var_ranges + self.numel = numel + self.prefix = prefix + self.divisor = divisor + self.length = length + self.kernel = kernel + self.root = root + + def symbol(self): + return sympy_index_symbol(self.name) + + +class IterationRangesRoot(IterationRanges): + def __init__( + self, + name: str, + numel: sympy.Expr, + # TODO: this is probably SymTy.INDEX and SymTy.RINDEX + prefix: str, + index: int, + kernel: SIMDKernel, + pid_cache=None, + *, + is_loop: bool, + tensor_dim: Optional[int], + grid_dim: Optional[int], + has_zdim: bool, + ): + if pid_cache is None: + pid_cache = {} + super().__init__( + name=name, + var_list=[], + var_ranges={}, + numel=numel, + prefix=prefix, + kernel=kernel, + root=self, + ) + self.index = index + # Store all the nodes in one flat list + self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {} + # This is for re-ordering program ID in triton mm template + # pid_cache["tl.program_id(0)"] = pid_m + self.pid_cache: Dict[str, str] = pid_cache + + # True if the dimension is implemented as a single program looping over + # the full dimension (currently only used for non-persistent reduction) + assert not is_loop or (prefix == "r" and grid_dim is None) + self.is_loop = is_loop + # Index of corresponding dimension on triton tensors + self.tensor_dim = tensor_dim + # Index of corresponding dimension in the triton grid + self.grid_dim = grid_dim + self.has_zdim = has_zdim + + def __repr__(self): + return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)" + + def cache_clear(self): + for node in self.nodes.values(): + node.cache_clear() + + def lookup(self, divisor, length): + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor) + else: + expr = ModularIndexing( + sympy_index_symbol(f"{self.prefix}index"), divisor, length + ) + + if expr not in self.nodes: + node = IterationRangesEntry( + f"{self.prefix}{next(V.kernel.iter_vars_count)}", + divisor, + length, + expr, + self, + ) + V.kernel.range_tree_nodes[node.symbol()] = node + self.var_list.append(node.symbol()) + self.var_ranges[node.symbol()] = length + self.nodes[expr] = node + return self.nodes[expr] + + def construct_entries(self, lengths: List[sympy.Expr]): + divisor = sympy.Integer(1) + itervars = [] + for length in reversed(lengths): + itervars.append(self.lookup(divisor, length)) + divisor = divisor * length + return list(reversed(itervars)) + + def construct(self, lengths: List[sympy.Expr]): + return [e.symbol() for e in self.construct_entries(lengths)] + + def vars_and_sizes(self, index: sympy.Expr): + """Figure out vars from this tree used in index""" + nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols] + nodes = [n for n in nodes if n and n.prefix == self.prefix] + nodes.sort(key=lambda x: V.graph.sizevars.size_hint(x.divisor)) + divisor = sympy.Integer(1) + index_vars = [] + sizes = [] + + def add(node): + nonlocal divisor + index_vars.append(node.symbol()) + sizes.append(node.length) + divisor = divisor * node.length + + for node in nodes: + if not 
V.graph.sizevars.statically_known_equals(node.divisor, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(node.divisor, divisor))) + divisor = node.divisor + add(node) + if not V.graph.sizevars.statically_known_equals(self.numel, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(self.numel, divisor))) + + return list(reversed(index_vars)), list(reversed(sizes)) + + def ranges_code(self): + assert self.tensor_dim is not None + size = self.kernel.indexing_size_str(self.tensor_dim) + index_dtype = self.kernel.index_dtype + convert = f".to({index_dtype})" if index_dtype != "tl.int32" else "" + return f"tl.arange(0, {self.prefix.upper()}BLOCK){size}{convert}" + + def scalar_code(self, value): + index_dtype = self.kernel.index_dtype + ndim = self.kernel.triton_tensor_ndim() + size = [1] * ndim + return f"tl.full({size}, {value}, {index_dtype})" + + def get_pid(self): + assert self.grid_dim is not None + key = f"tl.program_id({self.grid_dim})" + # y_grid has a limit, so express it in terms of y and z in case of overflow. + # z grid is only exercised when max_tiles == 3 (off by default). + if ( + self.grid_dim == 1 + and not self.has_zdim + and not (isinstance(self.numel, int) and self.numel <= get_max_y_grid()) + ): + key = f"{key} * (tl.program_id({self.grid_dim + 1}) + 1)" + pid = self.pid_cache.get(key, key) + if self.kernel.index_dtype != "tl.int32": + return f"{pid}.to({self.kernel.index_dtype})" + return pid + + def codegen_header(self, code): + x = self.prefix + if self.is_loop: + code.writeline(f"{self.name} = {x}offset + {x}base") + elif self.grid_dim is None: + # no need to "{x}offset = " + code.writeline(f"{self.name} = {self.ranges_code()}") + code.writeline(f"{x}offset = 0") + else: + if self.tensor_dim is not None: + line = f"{x}offset + {self.ranges_code()}" + else: + line = self.scalar_code(f"{x}offset") + code.writelines( + [ + f"{x}offset = {self.get_pid()} * {x.upper()}BLOCK", + f"{self.name} = {line}", + ] + ) + code.writeline(f"{x}mask = {self.name} < {x}numel") + + +class IterationRangesEntry(IterationRanges): + def __init__( + self, + name: str, + divisor: sympy.Expr, + length: sympy.Expr, + expr: sympy.Expr, + parent: IterationRanges, + ): + super().__init__( + name=name, + numel=parent.numel / length, + var_list=parent.var_list, + var_ranges=parent.var_ranges, + prefix=parent.prefix, + divisor=divisor, + length=length, + kernel=parent.kernel, + root=parent.root, + ) + self.parent = parent + self.codegen = functools.lru_cache(None)(self._codegen) + self.expr = expr + + def __repr__(self): + return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})" + + def set_name(self, name): + self.codegen = lambda: name # type: ignore[assignment] + self.codegen.cache_clear = lambda: None # type: ignore[method-assign] + self.name = name + + def cache_clear(self): + self.codegen.cache_clear() + + def _codegen(self): + V.kernel.codegen_iteration_ranges_entry(self) + return self.name + + def precomputed_args(self): + # for dynamic shapes, find parts of indexing expressions that have to be precomputed + precomputed_args: List[sympy.Expr] = [] + if isinstance(self.expr, sympy.Symbol): + return precomputed_args + assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr) + for arg in self.expr.args[1:]: + if not isinstance(arg, (sympy.Integer, sympy.Symbol)): + symbols = arg.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, SymT.SIZE) for s in symbols + ): + 
precomputed_args.append(arg) + return precomputed_args + + def __hash__(self): + return hash(self.name) + + def __eq__(self, other): + return self.name == other.name + + +def triton_constant(value): + if value == float("inf"): + return 'float("inf")' + elif value == float("-inf"): + return 'float("-inf")' + elif math.isnan(value): + return 'float("nan")' + return repr(value) + + +class SIMDKernel(Kernel): + """ + Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests. + """ + + sexpr = pexpr + kexpr: Callable[[sympy.Expr], str] + allow_block_ptr = False + + def __init__( + self, + *groups, + index_dtype: str, + mutations: Optional[Set[str]] = None, + pid_cache=None, + reduction_hint=ReductionHint.DEFAULT, + disable_persistent_reduction=False, + ): + if pid_cache is None: + pid_cache = {} + super().__init__() + self.body = IndentedBuffer() + self.indexing_code = IndentedBuffer() + self.numels = [V.graph.sizevars.simplify(s) for s in groups] + self.mutations: Set[str] = mutations if mutations is not None else set() + self.range_trees: List[IterationRangesRoot] = [] + self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {} + self.iter_vars_count = itertools.count() + self.inside_reduction = self.numels[-1] != 1 + self.reduction_hint = reduction_hint + self.index_dtype: str = index_dtype + self.last_usage: Set[str] = set() + self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list) + self.persistent_reduction: bool = ( + not disable_persistent_reduction + ) and self.should_use_persistent_reduction() + self.no_x_dim = self.want_no_x_dim() + self.code_hash = None + + # define this in a closure to make cache local to object + @functools.lru_cache(None) + def simplify_indexing(index: sympy.Expr): + index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) + for tree in self.range_trees: + index = self.combine_contiguous_dims(index, tree) + return index + + self.simplify_indexing = simplify_indexing + self.initialize_range_tree(pid_cache) + + def want_no_x_dim(self): + return False + + def initialize_range_tree(self, pid_cache): + no_r_dim = not self.inside_reduction or self.numels[-1] == 1 + + prefixes = "zyxr" + active_prefixes = prefixes[-len(self.numels) :] + + grid_dims = "xyz" + if self.no_x_dim: + tensor_dims = "r" + elif no_r_dim: + tensor_dims = "xyz" + else: + tensor_dims = "xyzr" + + tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes) + + for i, prefix in enumerate(active_prefixes): + is_reduction = prefix == "r" + tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None + grid_dim = None if is_reduction else grid_dims.find(prefix) + index = i if grid_dim is None else grid_dim + self.range_trees.append( + IterationRangesRoot( + f"{prefix}index", + self.numels[i], + prefix, + index, + self, + pid_cache=pid_cache, + is_loop=is_reduction and not self.persistent_reduction, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + has_zdim="z" in active_prefixes, + ) + ) + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + prior = self.inside_reduction + self.inside_reduction = False + try: + return self.store(name, index, value) + finally: + self.inside_reduction = prior + + def should_use_persistent_reduction(self) -> bool: + return False # defined in subclass + + def var_ranges(self): + return dict( + itertools.chain.from_iterable( + tree.var_ranges.items() for tree in self.range_trees + ) + ) + + def triton_tensor_ndim(self): + return 
sum(int(tree.tensor_dim is not None) for tree in self.range_trees) + + def indexing_size_str(self, i): + sizes = ["None"] * self.triton_tensor_ndim() + sizes[i] = ":" + return f"[{', '.join(sizes)}]" + + def dense_size_list(self) -> List[str]: + sizes = ["1"] * self.triton_tensor_ndim() + for tree in self.range_trees: + if tree.tensor_dim is None: + continue + + if tree.prefix != "r" or self.inside_reduction: + sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK" + return sizes + + def dense_size_str(self): + sizes = self.dense_size_list() + return f"[{', '.join(sizes)}]" + + def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): + """ + More aggressive simplification to merge contiguous dims + """ + if isinstance(index, (sympy.Integer, sympy.Symbol)): + return index + index_vars, sizes = tree.vars_and_sizes(index) + if len(sizes) <= 1: + return index + new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, index_prevent_reordering([index], index_vars, sizes) + ) + if new_sizes == sizes: + return index + new_index_vars = tree.construct(new_sizes) + new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars)))) + return new_index + + def set_last_usage(self, nodes): + if not self.inside_reduction or self.persistent_reduction: + return + self.last_usage = set( + itertools.chain.from_iterable( + n.last_usage for n in nodes if n is not EnableReduction + ) + ) + + def disable_reduction(self): + should_flush = self.range_trees[-1].is_loop + + @contextlib.contextmanager + def ctx(): + if self.numels[-1] == 1: + assert not self.inside_reduction + yield + return + if should_flush: + # calling codegen_body() will flush all the pending buffers + # and write out a reduction loop + self.codegen_body() + self.inside_reduction = False + try: + yield + if should_flush: + # flush out any code before opening the next loop + self.codegen_body() + finally: + self.inside_reduction = True + + return ctx() + + def set_ranges(self, *lengths): + assert len(lengths) == len(self.range_trees) + return [ + ranges.construct(length) + for length, ranges in zip(lengths, self.range_trees) + ] + + @staticmethod + def _split_iteration_ranges( + groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] + ): + sv = V.graph.sizevars + new_ranges: List[List[sympy.Expr]] = [[] for _ in groups] + remaining = [sv.simplify(g) for g in groups] + var_count = itertools.count() + + def add_range(i, expr): + expr = sv.simplify(expr) + if not sv.statically_known_multiple_of(remaining[i], expr): + raise CantSplit + # guard on the last item out + remaining[i] = FloorDiv(remaining[i], expr) + new_ranges[i].append(expr) + return next(var_count) + + def make_combined(size, idx1, idx2): + def getter(flat_vars): + return size * flat_vars[idx1] + flat_vars[idx2] + + return getter + + return_getters_groups = [] + current_group = 0 + for length_group in lengths: + return_getters = [] + for size in length_group: + if sv.statically_known_equals(size, 1): # type: ignore[arg-type] + return_getters.append(lambda _: sympy.Integer(0)) + continue + + while current_group < len(remaining) and sv.statically_known_equals( + remaining[current_group], 1 # type: ignore[arg-type] + ): + # scroll to next group with remaining elements + current_group += 1 + + if current_group + 1 < len(remaining) and sv.statically_known_gt( + size, remaining[current_group] + ): + # need to break size in two + if not sv.statically_known_multiple_of( + size, remaining[current_group] + ): + raise CantSplit + 
size1 = remaining[current_group] + size2 = FloorDiv(size, remaining[current_group]) + return_getters.append( + make_combined( + size2, + add_range(current_group, size1), + add_range(current_group + 1, size2), + ) + ) + else: + return_getters.append( + operator.itemgetter(add_range(current_group, size)) + ) + return_getters_groups.append(return_getters) + + assert all( + V.graph.sizevars.size_hint(s) == 1 for s in remaining + ), f"failed to set ranges {remaining} {lengths}" + + return new_ranges, return_getters_groups + + @classmethod + def is_compatible( + cls, groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] + ): + try: + cls._split_iteration_ranges(groups, lengths) + return True + except CantSplit: + return False + + def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]): + """ + We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1). + + To do this we need to split up the iteration space of i0 into something like: + for i1 in s0: + for i2 in s1: + i0 = i1*s1 + i2 + .... + + This function matches and resplits lengths to the groups of + this kernel to enable tiled + non-tiled fusions. + """ + groups = [rt.numel for rt in self.range_trees] + if not self.inside_reduction: + groups[-1] = sympy.Integer(1) + + if len(lengths) == len(self.range_trees) and all( + V.graph.sizevars.simplify(sympy_product(x) - g) == 0 + for x, g in zip(lengths, groups) + ): + return self.set_ranges(*lengths) + + new_ranges, return_getters_groups = self._split_iteration_ranges( + groups, lengths + ) + itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges))) + return [[fn(itervars) for fn in fns] for fns in return_getters_groups] + + def is_indirect_indexing(self, index: sympy.Expr): + # tmpX means indirect indexing + return free_symbol_is_type(index, SymT.TMP) + + def is_broadcasted(self, index: sympy.Expr): + # Note. This may not be correct when there is indirect indexing + if self.is_indirect_indexing(index): + return False + + index_numels = [1] * len(self.numels) + for symbol in index.free_symbols: + if symbol not in self.range_tree_nodes: + # Non-iterated variables, e.g. strides + continue + entry = self.range_tree_nodes[symbol] # type: ignore[index] + assert isinstance(entry.parent, IterationRangesRoot) + index_numels[entry.parent.index] *= entry.length + + # If the index variables only iterate over a subset of the kernel + # numels, then it must be broadcasted. + simplify = V.graph.sizevars.simplify + return any( + simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type] + for idx_range, iter_range in zip(index_numels, self.numels) + ) + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in triton code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the triton kernel. + + Index expressions often need to be passed in as arguments to the triton kernel. + Rename_indexing and codegen_indexing keep track of the needed indices and add + new parameters to the function signature. 
+ """ + if isinstance(index, list): + return f"[{', '.join(map(self.index_to_str, index))}]" + return self.kexpr( # type: ignore[call-arg] + self.rename_indexing(self.codegen_indexing(index)) + ) + + def indexing( + self, + index: sympy.Expr, + *, + copy_shape=None, + dense_indexing=False, + override_mask=None, + block_ptr=False, + ): + """ + Compute the index and mask to pass to tl.load() or tl.store() + """ + index = self.simplify_indexing(index) + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + # last resort, if no range vars are in the expr, hoist it + # TODO instead of trying to blindly find complicated exprs, we should hoist the + # inputs/outputs sizes and strides, but at the time indexing is generated + # kernel inputs and outputs are not set yet, we'd need a deeper refactor + # to do it this way + + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) + # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE)) + for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + index = self.simplify_indexing(index) + index_vars = index.free_symbols + has_rindex = False + + mask_vars: Set[str] = set() + for var in index_vars: + assert isinstance(var, sympy.Symbol) + has_rindex = has_rindex or symbol_is_type(var, SymT.RINDEX) + if override_mask: + pass + elif symbol_is_type(var, SymT.TMP): + # indirect indexing + cse_var = self.cse.varname_map[var.name] + mask_vars.update(cse_var.mask_vars) + elif symbol_is_type( + var, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + SymT.INDEX, + SymT.FLOAT, + SymT.UNBACKED_FLOAT, + ), + ): + pass + else: + # var is one of xN, yN or rN + assert symbol_is_type( + var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK) + ), var.name + mask_vars.add(f"{var.name[0]}mask") + + need_dense = ( + config.triton.dense_indexing + or dense_indexing + or self._load_mask is not None + ) and index != 0 + + have_dense = True + have_loop_vars = False + dense_mask_vars = set() + + for tree in self.active_range_trees(): + if index_vars.intersection(tree.var_list): + have_loop_vars = True + else: + have_dense = False + dense_mask_vars.add(f"{tree.prefix}mask") + + if ( + block_ptr + and self.allow_block_ptr + and config.triton.use_block_ptr + and not override_mask + and not self._load_mask + and len(mask_vars - dense_mask_vars) == 0 + and not self.is_indirect_indexing(index) + and have_loop_vars + # workaround https://github.com/openai/triton/issues/2821 + and self.index_dtype == "tl.int32" + ): + index_relative_to_xyr_index = sympy_subs( + index, {v: t.expr for v, t in self.range_tree_nodes.items()} + ) + range_trees = self.active_range_trees(reorder=True) + symbols = [t.symbol() for t in range_trees] + strides = [sympy.Wild(f"stride_{s}", exclude=symbols) for s in symbols] + offset = sympy.Wild("_offset", exclude=symbols) + m = index_relative_to_xyr_index.match(sympy_dot(symbols, strides) + offset) + # TODO(jansel): it is sometimes possible to do higher dimensional block_ptrs with + # a tl.reshape the correct block. We will miss these cases today. 
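+            # For example, an index that simplifies to 32*yindex + xindex
+            # matches with stride_yindex=32, stride_xindex=1 and offset=0
+            # (illustrative), which is then handed to BlockPtrOptions.create below.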
+ if m: + self.filter_masks(mask_vars) + from .triton import BlockPtrOptions + + return BlockPtrOptions.create( + [m[s] for s in strides], + m[offset], + range_trees, + mask_vars, # type: ignore[arg-type] + ) + + expand_str = None + index_str = self.index_to_str(index) + if isinstance(index, sympy.Integer): + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" + return IndexingOptions( + index_str, set(), "None", expand_str, has_rindex, index + ) + + if need_dense and not have_dense: + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + index_str = f"tl.broadcast_to({index_str}, {expand_str})" + mask_vars = dense_mask_vars + elif not have_loop_vars and copy_shape: + index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)" + mask_vars = dense_mask_vars + + if override_mask: + mask_vars = {override_mask} + + if self._load_mask: + mask_vars.add(self._load_mask) + + self.filter_masks(mask_vars) + + mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None" + return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex, index) # type: ignore[arg-type] + + def active_range_trees(self, reorder=False): + trees = [ + t for t in self.range_trees if t.prefix != "r" or self.inside_reduction + ] + if reorder and len(trees) > 1: + count = sum(t.prefix in "xyz" for t in trees) + assert "".join(t.prefix for t in trees[:count]) == "zyx"[-count:], [ + t.prefix for t in trees[:count] + ] + trees[:count] = reversed(trees[:count]) + return trees + + def filter_masks(self, mask_vars): + for tree in self.range_trees: + # Masks are superfluous if we only have one element + if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type] + mask_vars.discard(f"{tree.prefix}mask") + continue + # Masks are superfluous if numel is a multiple of BLOCK + # (We use the fact that BLOCK is required by triton to be a power of 2) + if tree.prefix.upper() not in TRITON_MAX_BLOCK: + continue + max_block = TRITON_MAX_BLOCK[tree.prefix.upper()] + # Optional optimization: if block divides numel exactly, we will + # never need to do a masked load to handle stragglers at the end. + # It's faster to avoid masking at all. But it is sound to always + # mask. + if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): # type: ignore[arg-type] + mask_vars.discard(f"{tree.prefix}mask") + + def codegen_indexing(self, expr: sympy.Expr): + expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) + for sym in sorted(expr.free_symbols, key=str): + if sym in self.range_tree_nodes: + # if indexing expression is complicated, we precompute it on the host side + # and send the result as a kernel argument + replacements = {} + for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] + replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) + if len(replacements) > 0: + self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] + self.range_tree_nodes[sym].expr, replacements # type: ignore[index] + ) + self.range_tree_nodes[sym].codegen() # type: ignore[index] + return expr + + @contextlib.contextmanager + def mask_loads(self, mask): + """Context manager to add an additional mask to tl.load/store""" + prior = self._load_mask + if prior: + mask = self.cse.generate(self.compute, f"{mask} & {prior}") + + self._load_mask = mask + try: + # TODO(jansel): do we need a reshape here? 
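+            # Typical use (illustrative): `with self.mask_loads(m):` wraps the
+            # codegen of loads/stores that must honor the extra mask; the
+            # prior mask is restored in the finally block below.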
+ yield mask + finally: + self._load_mask = prior + + def load_mask(self, var): + mask = "" + mask_vars = set(var.mask_vars) + if self._load_mask: + mask_vars.add(self._load_mask) + + if mask_vars: + mask = ( + f"{next(iter(mask_vars))}" + if len(mask_vars) == 1 + # sorted for deterministic order + else f"({' & '.join(sorted(map(str, mask_vars)))})" + ) + return mask + + def get_strides_of_load(self, index: sympy.Expr): + """ + This gets the stride of the index for each of the tiling variables + (technically, it does it at index 0) + + For example, if + xindex = x0 + 512*x1 + 1024*r0 + x0 = (xindex//512) + x1 = (xindex % 512) + r0 = rindex // 1024 + + this function would return + {xindex: 512, rindex: 1024} + """ + index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()} + index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type] + strides = {} + for range_tree in self.range_trees: + s = sympy_index_symbol(range_tree.name) + strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs( + index_in_tile_vars, {s: 0} + ) + return strides + + @staticmethod + def _map_tuple_or_scalar(fn, value): + if isinstance(value, tuple): + return tuple(map(fn, value)) + return fn(value) + + def estimate_kernel_num_bytes(self): + """ + Try the best to estimate the total size (in bytes) of the + kernel's inputs and outputs, which is used for estimating the memory + throughput of this kernel. This information is used for checking how + far we are from the peak memory bandwidth. It's important that + we want to avoid overestimating the sizes of the inputs and outputs, + because it can wrongfully give us a very large memory traffic value, + which may be even larger than the theoretical bandwidth and thus + become very misleading. This is particularly problematic for cases + where we slice some inputs. In those cases, we should only count + the size of the "slices" instead of the original inputs, because + only the slices contribute to the real memory traffic. + """ + nbytes = [] + ninplace_args = len(unique(self.args.inplace_buffers.values())) + _, call_args, _ = self.args.python_argdefs() + + # For pointwise and reduction kernels, this is the upper-bound numels + # for the output buffer. + # FIXME: This is not exactly right for cases like below: + # def foo(tensor0, tensor1): + # x0 = narrow(tensor0) + # return cat(x0, tensor1) + # For this example, we will end up overestimate the size for the + # slice s0. Potentially, we could have precise inputs information + # if we maintained the original inputs of the Pointwise kernel created + # for the "cat". However, I think it might be a bit overwhelming that + # we add such complexity only for handling some particular cases for + # benchmarking. + out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels)) + for i, arg in enumerate(call_args): + # "buf" may be narrowed. In this case, the number of memory accesses + # should be estimated based on the reinterpreted layout. + # On the other hand, buf may be broadcasted. In this case, + # counting the size of the underline storage would give us + # a better estimation in terms of memory accesses. + if arg not in self.buf_accesses: + nbytes.append(0) + continue + arg_numel = V.graph.get_numel(arg) + buf_size = V.graph.sizevars.size_hint(arg_numel) + if buf_size > out_numel: + # This arg points to a buf that has been sliced. + # We need to count each individual slice to have + # a better estimation. 
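+                # Approximation: count the distinct access indices for this
+                # buffer and assume each one touches about out_numel elements
+                # (StarDep/WeakDep accesses get a synthetic placeholder index).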
+ indices: Set[Any] = set() + no_index_dep_count = 0 + for dep in self.buf_accesses[arg]: + if isinstance(dep, (StarDep, WeakDep)): + indices.add(f"no_index_dep_{no_index_dep_count}") + no_index_dep_count += 1 + else: + indices.add(dep.index) + numel = len(indices) * out_numel + else: + numel = buf_size + dtype = V.graph.get_dtype(arg) + dtype_size = get_dtype_size(dtype) + nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) + return sum(nbytes) + + def warn_mix_layout(self, kernel_name): + """ + Print message if the kernel have mixed layout inputs. + Only care about 4D tensor for now. + """ + if ( + len(self.args.input_buffers) == 1 + and len(self.args.output_buffers) == 1 + and len(self.args.inplace_buffers) == 0 + ): + # even if input buffer and output buffer have different layout, + # this can be a layout conversion kernel. No need to warn for + # the mix layouts. + return + + argdefs, call_args, signature = self.args.python_argdefs() + uniform_stride_order = None + for arg_name in call_args: + buf = V.graph.get_buffer(arg_name) + if buf and len(buf.layout.size) == 4: + # ignore the tensor if only 1 dimension is non-zero + if len([x for x in buf.layout.size if x == 1]) == 3: + continue + stride_order = ir.get_stride_order(buf.layout.stride) + if uniform_stride_order is None: + uniform_stride_order = stride_order + elif uniform_stride_order != stride_order: + msg = yellow_text( + f"Expected stride order {uniform_stride_order}, but found stride order" + + f" {stride_order} for kernel {kernel_name}" + ) + log.warning(msg) + + stride_order_list = [ + ir.get_stride_order(V.graph.get_buffer(name).layout.stride) + if V.graph.get_buffer(name) + else None + for name in call_args + ] + size_list = [ + V.graph.get_buffer(name).layout.size + if V.graph.get_buffer(name) + else None + for name in call_args + ] + source_list = [ + "GraphInput" + if name in V.graph.graph_inputs + else "IntermediateBuffer" + if name in V.graph.name_to_buffer + else None + for name in call_args + ] + + msg = yellow_text( + f" param names {argdefs}\n buf names {call_args}\n strides {stride_order_list}" + + f"\n sizes {size_list}\n sources {source_list}\n" + ) + log.warning(msg) + return + msg = green_text( + f"All the inputs for the triton kernel {kernel_name} have uniform layout" + ) + log.warning(msg) + + def codegen_kernel(self): + raise NotImplementedError + + def codegen_body(self): + pass + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): + raise NotImplementedError + + +class SIMDScheduling(BaseScheduling): + kernel_type = SIMDKernel # override in subclass + int32_type = "torch.int32" + int64_type = "torch.int64" + + def __init__(self, scheduler): + self.scheduler = scheduler + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + def can_fuse(self, node1, node2): + """ + Hook called by Scheduler to determine if the Triton backend + can fuse node1 and node2. These nodes might already be + FusedSchedulerNodes. 
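+
+        Roughly: pointwise/pointwise and reduction/reduction pairs must have
+        matching (numel, rnumel) groups; a pointwise node can fuse into a
+        reduction when its ranges split cleanly onto the reduction's groups;
+        tiling heuristics may still veto an otherwise legal fusion.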
+ """ + if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance( + node2, scheduler.ForeachKernelSchedulerNode + ): + return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2) + + _, (numel1, rnumel1) = node1.group + _, (numel2, rnumel2) = node2.group + why = WhyNoFuse(node1, node2) + + if node1.is_split_scan() and not node2.is_split_scan(): + if node2.is_reduction(): + why("Split scan cannot fuse with reductions") + elif node2.is_split_scan() and not node1.is_split_scan(): + if node1.is_reduction(): + why("Split scan cannot fuse with reductions") + + if node1.is_reduction() and node2.is_reduction(): + reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2 + if not reduction_can_fuse: + why( + "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return reduction_can_fuse + + if not node1.is_reduction() and not node2.is_reduction(): + if not (numel1 == numel2 and rnumel1 == rnumel2): + why( + "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return False + + if node1.is_template(): + # Only allow fusion for TritonTemplates for now. + # Fusion for CUDATemplates are not supported. + is_triton_template = isinstance(node1.node, TritonTemplateBuffer) + if not is_triton_template: + why("node1 is not TritonTemplateBuffer") + return is_triton_template + + # check for a bad combined tiling + tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) + tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1) + tiling3 = self.select_tiling( + node1.get_nodes() + node2.get_nodes(), numel1, rnumel1 + ) + if config.triton.tiling_prevents_pointwise_fusion: + cond = True + if len(tiling1) > 2: + if len(tiling2) > 2: + cond = tiling1 == tiling2 == tiling3 + else: + cond = tiling1 == tiling3 + elif len(tiling2) > 2: + cond = tiling2 == tiling3 + if not cond: + why( + "tiling mismatch (%s, %s, %s)", + tiling1, + tiling2, + tiling3, + ) + return False + + return True + + if not node1.is_reduction() and node2.is_reduction(): + assert rnumel1 == 1 and rnumel2 != 1 + if numel1 == numel2 * rnumel2: + if not all( + SIMDKernel.is_compatible((numel2, rnumel2), n.get_ranges()) + for n in node1.get_nodes() + ): + why("nodes numel/rnumel incompatibility") + return False + if ( + config.triton.tiling_prevents_reduction_fusion + and not node1.is_template() + ): + is_reduction_tiling_valid = self.select_tiling( + node1.get_nodes(), numel1 + ) in ( + (numel1, 1), + (numel2, rnumel2, 1), + ) + if not is_reduction_tiling_valid: + why("invalid tiling for reduction") + return is_reduction_tiling_valid + return True + + if numel1 != numel2: + why("nodes numel incompatibility") + return numel1 == numel2 + + assert node1.is_reduction() and not node2.is_reduction() + # swap args to hit the case above + return self.can_fuse_horizontal(node2, node1) + + can_fuse_vertical = can_fuse + can_fuse_horizontal = can_fuse + + def generate_node_schedule(self, nodes, numel, rnumel): + node_schedule: List[Any] = [] + current_loop_writes: Set[str] = set() + + # Writes with a reduced shape, meaning they are only present once the + # reduction loop has ended + current_loop_reduced_writes = set() + current_loop_has_writes = False + done = set() + + def fits_in_main_body(n): + _, (node_numel, node_rnumel) = n.group + return (node_numel == numel and node_rnumel == rnumel) or ( + node_numel == numel * rnumel and node_rnumel == 1 + ) + + def fits_outside_reduction(n): + _, (node_numel, node_rnumel) = 
n.group
+            return node_numel == numel and node_rnumel == 1 and rnumel != 1
+
+        def schedule_node_in_loop(n):
+            nonlocal current_loop_has_writes
+            done.add(n)
+            node_schedule.append(n)
+            current_loop_has_writes = True
+            # A scan is modelled as a reduction in the scheduler but has a
+            # full sized output that can be used inside the loop body
+            if (
+                n.is_reduction()
+                and isinstance(n, scheduler.SchedulerNode)
+                and isinstance(n.node, ir.ComputedBuffer)
+                and not isinstance(n.node.data, ir.Scan)
+            ):
+                current_loop_reduced_writes.add(n.get_name())
+
+        @contextlib.contextmanager
+        def end_current_reduction_loop():
+            nonlocal current_loop_has_writes
+            if current_loop_has_writes:
+                # flush out any other runnable nodes to reduce number of loops
+                for other_node in nodes[index + 1 :]:
+                    if (
+                        other_node not in done
+                        and fits_in_main_body(other_node)
+                        and not (current_loop_reduced_writes & other_node.ancestors)
+                    ):
+                        schedule_node_in_loop(other_node)
+
+            if node_schedule and node_schedule[-1] is EnableReduction:
+                node_schedule.pop()
+            else:
+                node_schedule.append(DisableReduction)
+            yield
+            node_schedule.append(EnableReduction)
+            current_loop_reduced_writes.clear()
+            current_loop_has_writes = False
+
+        for index, node in enumerate(nodes):
+            if node in done:
+                continue
+            done.add(node)
+
+            def requires_closing_previous_reduction(node, node_schedule):
+                if rnumel == 1:
+                    return False
+                if not current_loop_reduced_writes & node.ancestors:
+                    return False
+                assert node_schedule and not isinstance(
+                    node_schedule[-1], (EnableReduction, DisableReduction)
+                )
+                return bool(current_loop_reduced_writes)
+
+            if fits_in_main_body(node):
+                if requires_closing_previous_reduction(node, node_schedule):
+                    with end_current_reduction_loop():
+                        pass  # need to start a new reduction loop
+
+                schedule_node_in_loop(node)
+            elif fits_outside_reduction(node):
+                with end_current_reduction_loop():
+                    node_schedule.append(node)
+            else:
+                raise NotImplementedError(
+                    f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}"
+                )
+
+        return node_schedule
+
+    def codegen_node(
+        self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode]
+    ):
+        """
+        Given a set of pre-fused nodes, generate a Triton kernel.
+ """ + + nodes: List[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] + + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + buf_accesses = collections.defaultdict(list) + for node in nodes: + for access in node.read_writes.reads | node.read_writes.writes: + buf_accesses[access.name].append(access) + + schedule_log.debug("Schedule:\n %s", node_schedule) + + return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel) + + @staticmethod + def reduction_hint(node): + assert node.is_reduction() + if all( + dep.is_contiguous() + for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes) + ): + return ReductionHint.INNER + else: + return node.node.data.reduction_hint + + @staticmethod + def can_use_32bit_indexing( + numel: sympy.Expr, buffers: Iterable[Union[ir.Buffer, ir.TensorBox]] + ) -> bool: + int_max = torch.iinfo(torch.int32).max + size_hint = V.graph.sizevars.size_hint + has_hint = V.graph.sizevars.shape_env.has_hint + + def within_32bit(e): + # Allow for unhinted e as long as we can still statically prove + # (e.g., via ValueRanges) that it is still in bounds + if V.graph.sizevars.is_expr_static_and_true(e <= int_max): + return True + # Otherwise, the hint MUST exist and be in range + return has_hint(e) and size_hint(e) <= int_max + + if not within_32bit(numel): + return False + + # Any use of a MultiOutputLayout will create a buffer with a + # Layout whose sizes are accounted for + buf_sizes = [ + buf.get_layout().storage_size() + for buf in buffers + if not isinstance(buf.get_layout(), ir.MultiOutputLayout) + ] + + if not all(within_32bit(size) for size in buf_sizes): + return False + + # Only install guards for 32-bit indexing as there is no correctness + # issue with using 64-bit for everything + V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type] + for size in buf_sizes: + V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type] + return True + + @classmethod + def select_index_dtype(cls, node_schedule, numel, reduction_numel): + # Gather all used buffer names + buffer_names = set() + for node in node_schedule: + if not isinstance(node, scheduler.BaseSchedulerNode): + continue + + buffer_names.update(node.get_names()) + buffer_names.update(node.used_buffer_names()) + + # Get buffers objects + + def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]: + buf = V.graph.get_buffer(name) + if buf is None: + raise RuntimeError(f"Failed to find buffer matching name {name}") + return buf + + buffers = [V.graph.get_buffer(name) for name in buffer_names] + + # In theory we can separately check xnumel and rnumel are <= int_max + # but some indexers do use the full linear index so we need to be + # conservative here. + total_numel = numel * reduction_numel + + if SIMDScheduling.can_use_32bit_indexing(total_numel, buffers): + return cls.int32_type + return cls.int64_type + + def has_non_contiguous_pw_in_reduction_kernel(self, node_schedule, numel, rnumel): + pointwise_nodes = list( + filter( + lambda n: n not in (EnableReduction, DisableReduction) + and not n.is_reduction() + and n.group[1][0] == numel * rnumel, + node_schedule, + ) + ) + for node in pointwise_nodes: + # An index can be an integer when loading a random seed. 
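+            # A node only counts as non-contiguous here if it has a MemoryDep
+            # that is neither contiguous nor stride-1 in its last dimension
+            # (integer indices such as seeds are ignored).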
+ if not all( + not isinstance(dep, MemoryDep) + or dep.is_contiguous() + or isinstance(dep.index, (sympy.Integer, int)) + or dep.stride1_for_last_dim() + for dep in itertools.chain( + node.read_writes.reads, node.read_writes.writes + ) + ): + return True + return False + + def get_kernel_args(self, node_schedule, numel, reduction_numel): + reductions = list( + filter( + lambda n: n not in (EnableReduction, DisableReduction) + and n.is_reduction(), + node_schedule, + ) + ) + if len(reductions) > 0: + hints = [self.reduction_hint(n) for n in reductions] + if hints.count(hints[0]) == len(hints): + reduction_hint_val = hints[0] + else: + reduction_hint_val = ReductionHint.DEFAULT + + if ( + reduction_hint_val == ReductionHint.INNER + and self.has_non_contiguous_pw_in_reduction_kernel( + node_schedule, numel, reduction_numel + ) + ): + reduction_hint_val = ReductionHint.DEFAULT + else: + reduction_hint_val = ReductionHint.DEFAULT + + mutations = set() + for node in node_schedule: + if hasattr(node, "get_mutations"): + mutations.update(node.get_mutations()) + + index_dtype = self.select_index_dtype(node_schedule, numel, reduction_numel) + + return reduction_hint_val, mutations, index_dtype + + def codegen_node_schedule( + self, node_schedule, buf_accesses, numel, reduction_numel + ): + from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel + + tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel) + ( + reduction_hint_val, + mutations, + index_dtype, + ) = self.get_kernel_args(node_schedule, numel, reduction_numel) + + is_split_scan = any( + isinstance(node, BaseSchedulerNode) and node.is_split_scan() + for node in node_schedule + ) + kernel_type = TritonSplitScanKernel if is_split_scan else self.kernel_type + kernel_args = tiled_groups + kernel_kwargs = { + "reduction_hint": reduction_hint_val, + "mutations": mutations, + "index_dtype": index_dtype, + } + kernel = kernel_type( + *kernel_args, + **kernel_kwargs, + ) + kernel.buf_accesses = buf_accesses + + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + + with V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + log.debug("Generating kernel code with kernel_name: %s", kernel_name) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + + if kernel.persistent_reduction and config.triton.multi_kernel: + kernel2 = self.kernel_type( + *kernel_args, + **kernel_kwargs, + disable_persistent_reduction=True, + ) + self.codegen_node_schedule_with_kernel(node_schedule, kernel2) + with V.set_kernel_handler(kernel2): + src_code2 = kernel2.codegen_kernel() + kernel_name2 = self.define_kernel(src_code2, node_schedule, kernel) + kernel2.kernel_name = kernel_name2 + kernel2.code_hash = code_hash(src_code2) + + final_kernel = MultiKernel([kernel, kernel2]) + else: + final_kernel = kernel # type: ignore[assignment] + + with V.set_kernel_handler(final_kernel): + for node in node_schedule: + if node not in (EnableReduction, DisableReduction): + node.mark_run() + + self.codegen_comment(node_schedule) + final_kernel.call_kernel(final_kernel.kernel_name) + if config.nan_asserts: + final_kernel.codegen_nan_check() + if config.warn_mix_layout: + final_kernel.warn_mix_layout(kernel_name) + + V.graph.removed_buffers |= final_kernel.removed_buffers + V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove + + if ( + V.graph.wrapper_code.supports_intermediate_hooks + and config.generate_intermediate_hooks + 
): + # Not every node in the schedule will actually be live on output; + # we can't check dead buffers. + live_outs = kernel.args.live_output_buffers() + for node in node_schedule: + if not isinstance(node, scheduler.BaseSchedulerNode): + continue + name = node.get_name() + if name not in live_outs: + continue + origin_node = node.node.get_origin_node() + if origin_node is not None: + counters["inductor"]["intermediate_hooks"] += 1 + V.graph.wrapper_code.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {name})" + ) + + self.scheduler.free_buffers() + + def codegen_node_schedule_with_kernel(self, node_schedule, kernel): + def current_reduction_nodes(nodes): + return itertools.takewhile(lambda n: n is not DisableReduction, nodes) + + with kernel: + stack = contextlib.ExitStack() + kernel.set_last_usage(current_reduction_nodes(node_schedule)) + + for node in node_schedule: + if node not in (EnableReduction, DisableReduction): + node.decide_inplace_update() + for i, node in enumerate(node_schedule): + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + kernel.set_last_usage(current_reduction_nodes(node_schedule[i:])) + else: + # TODO - use split ranges ? + indexing_dtype_strength_reduction(node._body) + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + node.codegen(index_vars) + + def codegen_template( + self, template_node, epilogue_nodes, only_gen_src_code=False + ) -> Optional[str]: + """ + Codegen a triton template + + If `only_gen_src_code` the src code will be returned instead of codegen'd into the wrapper + """ + _, (numel, rnumel) = template_node.group + assert rnumel == 1 + kernel, render = template_node.node.make_kernel_render(template_node.node) + with kernel: + if not only_gen_src_code: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + partial_code = render() + for node in epilogue_nodes: + node.codegen(kernel.split_and_set_ranges(node.get_ranges())) + + # finalize must be called after adding epilogue above + with V.set_kernel_handler(kernel): + # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion. 
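+            # render() may return either a plain string or a PartialRender;
+            # only in the latter case does finalize() need to run to fill in
+            # the deferred pieces of the kernel source.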
+ src_code = ( + partial_code + if isinstance(partial_code, str) + else partial_code.finalize() + ) + node_schedule = [template_node, *epilogue_nodes] + + if config.benchmark_kernel: + num_gb = kernel.estimate_kernel_num_bytes() / 1e9 + grid_args = V.graph.sizevars.size_hints(kernel.call_sizes) + assert kernel.meta is not None, "meta is None" + grid = kernel.grid_fn(*grid_args, kernel.meta) + src_code = ( + f"{kernel.imports_for_benchmark_kernel()}\n" + f"{src_code}\n" + f"{kernel.codegen_kernel_benchmark(num_gb, grid).getvalue()}" + ) + + if only_gen_src_code: + return src_code + + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + + self.codegen_comment(node_schedule) + kernel.call_kernel(kernel_name, template_node.node) + V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove + self.scheduler.free_buffers() + return None + + def codegen_sync(self): + V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize()) + + def codegen_foreach(self, foreach_node): + from .triton_foreach import ForeachKernel + + for partitions_with_metadata in ForeachKernel.horizontal_partition( + foreach_node.get_subkernel_nodes(), self + ): + kernel = ForeachKernel() + for nodes, tiled_groups, numel, rnumel in partitions_with_metadata: + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + ( + reduction_hint_val, + mutations, + index_dtype, + ) = self.get_kernel_args(node_schedule, numel, rnumel) + + subkernel = kernel.create_sub_kernel( + *tiled_groups, + reduction_hint=reduction_hint_val, + mutations=mutations, + index_dtype=index_dtype, + ) + + self.codegen_node_schedule_with_kernel( + node_schedule, + subkernel, + ) + + with V.set_kernel_handler(subkernel): + for node in node_schedule: + if node not in (EnableReduction, DisableReduction): + node.mark_run() + V.graph.removed_buffers |= subkernel.removed_buffers + V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove + + src_code = kernel.codegen_kernel() + kernel_name = self.define_kernel(src_code, [foreach_node], kernel) + self.codegen_comment([foreach_node]) + kernel.call_kernel(V.graph.wrapper_code, kernel_name) + + self.scheduler.free_buffers() + + @staticmethod + @functools.lru_cache(32) + def candidate_tilings(node): + ranges, reduction_ranges = node.get_ranges() + if len(ranges) <= 1: + return () + + rw = node.pointwise_read_writes() + assert len(rw.range_vars) == len(ranges) + + # isinstance(dep, MemoryDep): this filters out StarDeps. 
StarDeps refer to reads + # that need to access the entire tensor; they don't contribute read indexing + # information (and practically, they don't have dep.index so they can't be used + # for stride_hints below + dep_sources = [rw.reads, rw.writes] + assert all( + isinstance(dep, (MemoryDep, StarDep)) + for dep in itertools.chain.from_iterable(dep_sources) + ) + deps = [ + dep + for dep in itertools.chain.from_iterable(dep_sources) + if dep.name not in V.graph.removed_buffers and isinstance(dep, MemoryDep) + ] + write_names = {dep.name for dep in rw.writes} + + tilings: List[CandidateTiling] = [] + + for dep in deps: + strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars) + assert len(strides) == len(ranges) + try: + split = strides.index(1) + 1 + if split == len(ranges): + continue + if all(s == 0 for s in strides[split:]): + # if this is a broadcasted tensor and all dimensions after split are broadcast, + # this is not a real split + continue + + except ValueError: + continue + tiled_groups = ( + V.graph.sizevars.simplify(sympy_product(ranges[:split])), + V.graph.sizevars.simplify(sympy_product(ranges[split:])), + ) + # score by number of elements + score = V.graph.sizevars.size_hint( + sympy_product( + size for size, stride in zip(ranges, strides) if stride != 0 + ) + ) + if dep.name in write_names: + # ngimel said contiguous writes is more important than reads + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[0]): + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[1]): + score *= 2 + + if ( + V.graph.sizevars.size_hint( + score - sympy_product(itertools.chain(ranges, reduction_ranges)) + ) + >= 0 + ): + tilings.append(CandidateTiling(tiled_groups, score, dep.name)) + return tilings + + @classmethod + def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)): + """ + Heuristics to decide how to tile kernels. + Currently, we tile based on stride-1 dimensions. + + Returns: + `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel` + + """ + if reduction_numel != 1 or config.triton.max_tiles <= 1: + # TODO(jansel): should we tile reductions? + # do perf hint here if stride-1 dim is not being reduced + if perf_hint_log.level <= logging.WARNING: + for node in EnableReduction.filter(node_schedule): + if len(cls.candidate_tilings(node)) > 0: + perf_hint_log.info("reduction over non-contiguous dims") + break + return (numel, reduction_numel) + + seen_names = set() + candidate_tiles: Counter[Any] = collections.Counter() + for node in EnableReduction.filter(node_schedule): + for tiling in cls.candidate_tilings(node): + if tiling.name in seen_names: + continue + seen_names.add(tiling.name) + candidate_tiles[tiling.tiling] += tiling.score + + ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()] + + if config.triton.max_tiles >= 3: + # Consider adding a third dimension of tiling, but only + # when a1 is a multiple of b1; otherwise, you have a lot + # of stragglers which is annoying to generate code for. + # + # NB: More than three max tiles is not enabled by default. 
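+            # Illustrative example: given ranked 2D tilings (s0, s1*s2) and
+            # (s0*s1, s2), the trailing dim s2 of the second divides the
+            # trailing dim s1*s2 of the first, so the 3D candidate
+            # (s0, s1, s2) is proposed below.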
+ + # Add one 3D tiling choice + for i in range(1, len(ranked_tilings)): + a0, a1 = ranked_tilings[0] + b0, b1 = ranked_tilings[i] + if V.graph.sizevars.size_hint(a1 - b1) == 0: + continue + if V.graph.sizevars.size_hint(a1 - b1) < 0: + # swap so a0 is bigger + a0, a1 = ranked_tilings[i] + b0, b1 = ranked_tilings[0] + assert V.graph.sizevars.size_hint(a1 - b1) > 0 + if V.graph.sizevars.statically_known_multiple_of(a1, b1): + tiling = (a0, FloorDiv(a1, b1), b1) + ranked_tilings = [tiling] + ranked_tilings + break # only 1 choice for now + + if len(ranked_tilings) > 1: + perf_hint_log.info("possibly bad tiling: %s", ranked_tilings) + + for tiled_groups in ranked_tilings: + new_groups = (*tiled_groups, reduction_numel) + if all( + SIMDKernel.is_compatible(new_groups, node.get_ranges()) + for node in node_schedule + if isinstance(node, scheduler.SchedulerNode) + ): + return new_groups + + return (numel, reduction_numel) + + def flush(self): + pass + + def ready_to_flush(self) -> bool: + return False + + def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False): + @dataclasses.dataclass + class LastUsageHolder: + n: Any + last_usage: Any + + def __del__(self): + self.n.last_usage = self.last_usage + + last_usage_holders = [LastUsageHolder(n, n.last_usage) for n in nodes] + + # empty last_usage. May cause more aggressive 'evict_last'. Should be fine. + for n in nodes: + n.last_usage = set() + + if not nodes[0].is_template(): + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + + tiled_groups = self.select_tiling(node_schedule, numel, rnumel) + reduction_hint_val, mutations, index_dtype = self.get_kernel_args( + node_schedule, numel, rnumel + ) + + kernel = self.kernel_type( + *tiled_groups, + reduction_hint=reduction_hint_val, + mutations=mutations, + index_dtype=index_dtype, + ) + + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + with config.patch( + "benchmark_kernel", benchmark_kernel + ), V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + else: + template_node = nodes[0] + epilogue_nodes = nodes[1:] + + with config.patch("benchmark_kernel", benchmark_kernel): + src_code = self.codegen_template( + template_node, epilogue_nodes, only_gen_src_code=True + ) + + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + return src_code + + def codegen_comment(self, node_schedule): + pass + + def define_kernel(self, src_code, node_schedule, kernel): + raise NotImplementedError + + +@dataclasses.dataclass +class CandidateTiling: + tiling: Tuple[sympy.Expr, sympy.Expr] + score: int # higher is better + name: Optional[str] = None + + @staticmethod + def is_good_size(s): + """Somewhat arbitrary heuristic used to boost scores for some sizes""" + s = V.graph.sizevars.size_hint(s) + return s >= 32 and (s % 32 == 0) + + +class DisableReduction: + """ + Marker to invoke `kernel.disable_reduction()`. This closes a + reduction loop and allows for pointwise ops to occur on the output + of a reduction. + """ + + +class EnableReduction: + """ + Marker to end a DisableReduction block. + """ + + @staticmethod + def filter(node_schedule): + """ + Get the nodes from node_schedule skipping those in a + DisableReduction block. 
+ """ + disabled = False + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + # Don't tile stuff outside the main reduction loop + disabled = node is DisableReduction + elif disabled: + pass + else: + yield node + + +class CantSplit(Exception): + pass diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index e015cd7dbf6ad..183d28605b87a 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1,30 +1,13 @@ from __future__ import annotations -import collections -import contextlib import dataclasses import functools import itertools import logging -import math -import operator import os import textwrap from functools import lru_cache -from typing import ( - Any, - Callable, - cast, - Counter, - DefaultDict, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Union, -) +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union import sympy @@ -33,42 +16,24 @@ import torch.utils._pytree as pytree from torch._dynamo.utils import preserve_rng_state -from torch._inductor.metrics import is_metric_table_enabled, log_kernel_metadata from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties from torch._prims_common import is_integer_dtype -from torch.utils._sympy.functions import FloorDiv, ModularIndexing -from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT -from torch.utils._sympy.value_ranges import ValueRanges from torch.utils._triton import has_triton_package +from ...utils._sympy.value_ranges import ValueRanges -from ..._dynamo.utils import counters -from .. import config, ir, scheduler +from .. import config, ir from ..codecache import code_hash, get_path, PyCodeCache -from ..dependencies import Dep, MemoryDep, StarDep, WeakDep -from ..ir import IRNode, TritonTemplateBuffer -from ..optimize_indexing import indexing_dtype_strength_reduction +from ..ir import IRNode +from ..metrics import is_metric_table_enabled, log_kernel_metadata from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK -from ..runtime.runtime_utils import ( - do_bench_gpu, - get_max_y_grid, - green_text, - next_power_of_2, - yellow_text, -) -from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse +from ..runtime.runtime_utils import do_bench_gpu, next_power_of_2 from ..utils import ( cache_on_self, get_bounds_index_expr, - get_dtype_size, get_fused_kernel_name, get_kernel_metadata, is_welford_reduction, Placeholder, - sympy_dot, - sympy_index_symbol, - sympy_product, - sympy_subs, - unique, ) from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V from ..wrapper_benchmark import get_kernel_category_by_source_code @@ -77,17 +42,21 @@ CSEVariable, DeferredLine, IndentedBuffer, - index_prevent_reordering, - Kernel, OpOverrides, PythonPrinter, SizeArg, TensorArg, ) -from .multi_kernel import MultiKernel +from .simd import ( + IndexingOptions, + IterationRangesEntry, + pexpr, + SIMDKernel, + SIMDScheduling, + triton_constant, +) from .triton_utils import config_of, signature_of, signature_to_meta - log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") @@ -133,27 +102,6 @@ def gen_common_triton_imports(): return imports.getvalue() -@dataclasses.dataclass -class IndexingOptions: - index_str: str - mask_vars: Set[sympy.Symbol] - mask_str: str - expand_str: Optional[str] - _has_rindex: bool - - def has_mask(self): - 
return bool(self.mask_vars) - - def has_rindex(self): - return self._has_rindex - - def has_tmpmask(self): - return "tmp" in self.mask_str - - def has_rmask(self): - return "rmask" in self.mask_str - - @dataclasses.dataclass class BlockPtrOptions: constant_offset: sympy.Expr @@ -416,7 +364,6 @@ def _print_RoundDecimal(self, expr): texpr = TritonPrinter().doprint -pexpr = PythonPrinter().doprint def triton_compute_type(dtype): @@ -455,16 +402,6 @@ def triton_acc_type(dtype): return triton_compute_type(dtype) -def triton_constant(value): - if value == float("inf"): - return 'float("inf")' - elif value == float("-inf"): - return 'float("-inf")' - elif math.isnan(value): - return 'float("nan")' - return repr(value) - - class TritonCSEVariable(CSEVariable): def __init__(self, name, bounds: ValueRanges[Any]): super().__init__(name, bounds) @@ -487,9 +424,6 @@ def update_on_args(self, name, args, kwargs): # those reads should subsequently be masked, self.mask_vars.update({f"{arg.name[0]}mask"}) - def __repr__(self): - return f"TritonCSEVariable(name={self.name})" - class TritonOverrides(OpOverrides): """Map element-wise ops to Triton""" @@ -965,283 +899,6 @@ def _typecheck_TritonKernelOverrides(h: TritonKernelOverrides) -> OpsHandler[str return h -@dataclasses.dataclass -class IterationRanges: - """ - Each range tree represents multiple sets of iteration indexing - in a single tiled dimension in the output kernel. - - If you have two loops ranges one (4, 3, 2) and another (4, 6), - then the range tree will be: - 4 (i0) - 3 (i1) 6 (i3) - 2 (i2) - Where i0 is shared between both loops, but then the split into - different indexing vars. All loop ranges must iterate over - the same number of elements. - """ - - def __init__( - self, - name: str, - var_list: List[sympy.Symbol], - var_ranges: Dict[sympy.Symbol, sympy.Expr], - numel: sympy.Expr, - prefix: str, - *, - kernel: TritonKernel, - divisor=sympy.Integer(1), - length=sympy.Integer(1), - root: IterationRangesRoot, - ): - super().__init__() - self.name = name - self.var_list = var_list - self.var_ranges = var_ranges - self.numel = numel - self.prefix = prefix - self.divisor = divisor - self.length = length - self.kernel = kernel - self.root = root - - def symbol(self): - return sympy_index_symbol(self.name) - - -class IterationRangesRoot(IterationRanges): - def __init__( - self, - name: str, - numel: sympy.Expr, - # TODO: this is probably SymTy.INDEX and SymTy.RINDEX - prefix: str, - index: int, - kernel: TritonKernel, - pid_cache=None, - *, - is_loop: bool, - tensor_dim: Optional[int], - grid_dim: Optional[int], - has_zdim: bool, - ): - if pid_cache is None: - pid_cache = {} - super().__init__( - name=name, - var_list=[], - var_ranges={}, - numel=numel, - prefix=prefix, - kernel=kernel, - root=self, - ) - self.index = index - # Store all the nodes in one flat list - self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {} - # This is for re-ordering program ID in triton mm template - # pid_cache["tl.program_id(0)"] = pid_m - self.pid_cache: Dict[str, str] = pid_cache - - # True if the dimension is implemented as a single program looping over - # the full dimension (currently only used for non-persistent reduction) - assert not is_loop or (prefix == "r" and grid_dim is None) - self.is_loop = is_loop - # Index of corresponding dimension on triton tensors - self.tensor_dim = tensor_dim - # Index of corresponding dimension in the triton grid - self.grid_dim = grid_dim - self.has_zdim = has_zdim - - def __repr__(self): - return 
f"IterationRangesRoot({self.name!r}, {self.numel}, ...)" - - def cache_clear(self): - for node in self.nodes.values(): - node.cache_clear() - - def lookup(self, divisor, length): - """ - Lookup a given RangeTreeEntry, creating it if needed - """ - if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): - expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor) - else: - expr = ModularIndexing( - sympy_index_symbol(f"{self.prefix}index"), divisor, length - ) - - if expr not in self.nodes: - node = IterationRangesEntry( - f"{self.prefix}{next(V.kernel.iter_vars_count)}", - divisor, - length, - expr, - self, - ) - V.kernel.range_tree_nodes[node.symbol()] = node - self.var_list.append(node.symbol()) - self.var_ranges[node.symbol()] = length - self.nodes[expr] = node - return self.nodes[expr] - - def construct_entries(self, lengths: List[sympy.Expr]): - divisor = sympy.Integer(1) - itervars = [] - for length in reversed(lengths): - itervars.append(self.lookup(divisor, length)) - divisor = divisor * length - return list(reversed(itervars)) - - def construct(self, lengths: List[sympy.Expr]): - return [e.symbol() for e in self.construct_entries(lengths)] - - def vars_and_sizes(self, index: sympy.Expr): - """Figure out vars from this tree used in index""" - nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols] - nodes = [n for n in nodes if n and n.prefix == self.prefix] - nodes.sort(key=lambda x: V.graph.sizevars.size_hint(x.divisor)) - divisor = sympy.Integer(1) - index_vars = [] - sizes = [] - - def add(node): - nonlocal divisor - index_vars.append(node.symbol()) - sizes.append(node.length) - divisor = divisor * node.length - - for node in nodes: - if not V.graph.sizevars.statically_known_equals(node.divisor, divisor): - # fill in unused index var - add(self.lookup(divisor, FloorDiv(node.divisor, divisor))) - divisor = node.divisor - add(node) - if not V.graph.sizevars.statically_known_equals(self.numel, divisor): - # fill in unused index var - add(self.lookup(divisor, FloorDiv(self.numel, divisor))) - - return list(reversed(index_vars)), list(reversed(sizes)) - - def ranges_code(self): - assert self.tensor_dim is not None - size = self.kernel.indexing_size_str(self.tensor_dim) - index_dtype = self.kernel.index_dtype - convert = f".to({index_dtype})" if index_dtype != "tl.int32" else "" - return f"tl.arange(0, {self.prefix.upper()}BLOCK){size}{convert}" - - def scalar_code(self, value): - index_dtype = self.kernel.index_dtype - ndim = self.kernel.triton_tensor_ndim() - size = [1] * ndim - return f"tl.full({size}, {value}, {index_dtype})" - - def get_pid(self): - assert self.grid_dim is not None - key = f"tl.program_id({self.grid_dim})" - # y_grid has a limit, so express it in terms of y and z in case of overflow. - # z grid is only exercised when max_tiles == 3 (off by default). 
- if ( - self.grid_dim == 1 - and not self.has_zdim - and not (isinstance(self.numel, int) and self.numel <= get_max_y_grid()) - ): - key = f"{key} * (tl.program_id({self.grid_dim + 1}) + 1)" - pid = self.pid_cache.get(key, key) - if self.kernel.index_dtype != "tl.int32": - return f"{pid}.to({self.kernel.index_dtype})" - return pid - - def codegen_header(self, code): - x = self.prefix - if self.is_loop: - code.writeline(f"{self.name} = {x}offset + {x}base") - elif self.grid_dim is None: - # no need to "{x}offset = " - code.writeline(f"{self.name} = {self.ranges_code()}") - code.writeline(f"{x}offset = 0") - else: - if self.tensor_dim is not None: - line = f"{x}offset + {self.ranges_code()}" - else: - line = self.scalar_code(f"{x}offset") - code.writelines( - [ - f"{x}offset = {self.get_pid()} * {x.upper()}BLOCK", - f"{self.name} = {line}", - ] - ) - code.writeline(f"{x}mask = {self.name} < {x}numel") - - -class IterationRangesEntry(IterationRanges): - def __init__( - self, - name: str, - divisor: sympy.Expr, - length: sympy.Expr, - expr: sympy.Expr, - parent: IterationRanges, - ): - super().__init__( - name=name, - numel=parent.numel / length, - var_list=parent.var_list, - var_ranges=parent.var_ranges, - prefix=parent.prefix, - divisor=divisor, - length=length, - kernel=parent.kernel, - root=parent.root, - ) - self.parent = parent - self.codegen = functools.lru_cache(None)(self._codegen) - self.expr = expr - - def __repr__(self): - return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})" - - def set_name(self, name): - self.codegen = lambda: name # type: ignore[assignment] - self.codegen.cache_clear = lambda: None # type: ignore[method-assign] - self.name = name - - def cache_clear(self): - self.codegen.cache_clear() - - def writeline(self, line): - if self.root.is_loop: - V.kernel.indexing_code.writeline(line) - else: - # lift non-reduction stores outside loop - V.kernel.body.writeline(line) - - def _codegen(self): - self.writeline(f"{self.name} = " + texpr(V.kernel.rename_indexing(self.expr))) - return self.name - - def precomputed_args(self): - # for dynamic shapes, find parts of indexing expressions that have to be precomputed - precomputed_args: List[sympy.Expr] = [] - if isinstance(self.expr, sympy.Symbol): - return precomputed_args - assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr) - for arg in self.expr.args[1:]: - if not isinstance(arg, (sympy.Integer, sympy.Symbol)): - symbols = arg.free_symbols - if len(symbols) > 0 and all( - symbol_is_type(s, SymT.SIZE) for s in symbols - ): - precomputed_args.append(arg) - return precomputed_args - - def __hash__(self): - return hash(self.name) - - def __eq__(self, other): - return self.name == other.name - - class HelperFunctions: """An ordered set of helper functions.""" @@ -1281,11 +938,11 @@ def __getitem__(self, idx): return self.finalized_helpers[idx] -class TritonKernel(Kernel): +class TritonKernel(SIMDKernel): overrides = TritonKernelOverrides # type: ignore[assignment] - sexpr = pexpr - helper_functions: HelperFunctions + kexpr: Callable[[sympy.Expr], str] = texpr + allow_block_ptr = True def __init__( self, @@ -1297,54 +954,35 @@ def __init__( min_elem_per_thread=0, disable_persistent_reduction=False, ): - if pid_cache is None: - pid_cache = {} - super().__init__() - self.numels = [V.graph.sizevars.simplify(s) for s in groups] - self.mutations: Set[str] = mutations if mutations is not None else set() - self.range_trees: List[IterationRangesRoot] = [] - 
self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {} - self.iter_vars_count = itertools.count() - self.inside_reduction = self.numels[-1] != 1 - self.body = IndentedBuffer() - self.indexing_code = IndentedBuffer() + super().__init__( + *groups, + index_dtype=index_dtype, + mutations=mutations, + reduction_hint=reduction_hint, + pid_cache=pid_cache, + disable_persistent_reduction=disable_persistent_reduction, + ) self.suffix: IndentedBuffer = IndentedBuffer() # type: ignore[assignment] self.outside_loop_vars: Set[Any] = set() - self.reduction_hint = reduction_hint - self.index_dtype: str = index_dtype self.min_elem_per_thread = min_elem_per_thread - self.last_usage: Set[str] = set() self.block_ptr_id = itertools.count() - # buffer accesses in the kernel - self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list) - - self.persistent_reduction: bool = ( - not disable_persistent_reduction - ) and self.should_use_persistent_reduction() - self.no_x_dim = ( - self.reduction_hint == ReductionHint.INNER - and self.persistent_reduction - and len(self.numels) == 2 - and self.numels[-1] >= 256 - ) - self.initialize_range_tree(pid_cache) - self.helper_functions = HelperFunctions() # A set of autotuning hints to pass as part of triton_meta self.autotune_hints: Set[AutotuneHint] = set() + self.triton_meta: Optional[Dict[str, object]] = None - # define this in a closure to make cache local to object - @functools.lru_cache(None) - def simplify_indexing(index: sympy.Expr): - index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) - for tree in self.range_trees: - index = self.combine_contiguous_dims(index, tree) - return index + self.codegen_range_tree() - self.simplify_indexing = simplify_indexing - self.code_hash = None - self.triton_meta: Optional[Dict[str, object]] = None + def codegen_range_tree(self): + for tree in self.range_trees: + # reduction indexing goes inside a loop + if not tree.is_loop: + tree.codegen_header(self.body) + if self.inside_reduction and self.range_trees[-1].is_loop: + # workaround for this issue: + # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7 + self.body.writeline(f"rbase = {self.range_trees[-1].ranges_code()}") def need_numel_args(self): r""" @@ -1384,507 +1022,21 @@ def should_use_persistent_reduction(self) -> bool: V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint)) # type: ignore[arg-type] return True - def set_last_usage(self, nodes): - if not self.inside_reduction or self.persistent_reduction: - return - self.last_usage = set( - itertools.chain.from_iterable( - n.last_usage for n in nodes if n is not EnableReduction - ) - ) - - def initialize_range_tree(self, pid_cache): - no_r_dim = not self.inside_reduction or self.numels[-1] == 1 - - prefixes = "zyxr" - active_prefixes = prefixes[-len(self.numels) :] - - grid_dims = "xyz" - if self.no_x_dim: - tensor_dims = "r" - elif no_r_dim: - tensor_dims = "xyz" - else: - tensor_dims = "xyzr" - - tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes) - - for i, prefix in enumerate(active_prefixes): - is_reduction = prefix == "r" - tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None - grid_dim = None if is_reduction else grid_dims.find(prefix) - index = i if grid_dim is None else grid_dim - self.range_trees.append( - IterationRangesRoot( - f"{prefix}index", - self.numels[i], - prefix, - index, - self, - pid_cache=pid_cache, - is_loop=is_reduction and not self.persistent_reduction, - tensor_dim=tensor_dim, - 
grid_dim=grid_dim, - has_zdim="z" in active_prefixes, - ) - ) - for tree in self.range_trees: - # reduction indexing goes inside a loop - if not tree.is_loop: - tree.codegen_header(self.body) - if self.inside_reduction and self.range_trees[-1].is_loop: - # workaround for this issue: - # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7 - self.body.writeline(f"rbase = {self.range_trees[-1].ranges_code()}") - - def disable_reduction(self): - should_flush = self.range_trees[-1].is_loop - - @contextlib.contextmanager - def ctx(): - if self.numels[-1] == 1: - assert not self.inside_reduction - yield - return - if should_flush: - # calling codegen_body() will flush all the pending buffers - # and write out a reduction loop - self.codegen_body() - self.inside_reduction = False - try: - yield - if should_flush: - # flush out any code before opening the next loop - self.codegen_body() - finally: - self.inside_reduction = True - - return ctx() - - def set_ranges(self, *lengths): - assert len(lengths) == len(self.range_trees) - return [ - ranges.construct(length) - for length, ranges in zip(lengths, self.range_trees) - ] - - @staticmethod - def _split_iteration_ranges( - groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] - ): - sv = V.graph.sizevars - new_ranges: List[List[sympy.Expr]] = [[] for _ in groups] - remaining = [sv.simplify(g) for g in groups] - var_count = itertools.count() - - def add_range(i, expr): - expr = sv.simplify(expr) - if not sv.statically_known_multiple_of(remaining[i], expr): - raise CantSplit - # guard on the last item out - remaining[i] = FloorDiv(remaining[i], expr) - new_ranges[i].append(expr) - return next(var_count) - - def make_combined(size, idx1, idx2): - def getter(flat_vars): - return size * flat_vars[idx1] + flat_vars[idx2] - - return getter - - return_getters_groups = [] - current_group = 0 - for length_group in lengths: - return_getters = [] - for size in length_group: - if sv.statically_known_equals(size, 1): # type: ignore[arg-type] - return_getters.append(lambda _: sympy.Integer(0)) - continue - - while ( - current_group < len(remaining) - and sv.size_hint(remaining[current_group]) == 1 - ): - # scroll to next group with remaining elements - current_group += 1 - - if sv.size_hint(size) > sv.size_hint(remaining[current_group]): - # need to break size in two - if not sv.statically_known_multiple_of( - size, remaining[current_group] - ): - raise CantSplit - size1 = remaining[current_group] - size2 = FloorDiv(size, remaining[current_group]) - return_getters.append( - make_combined( - size2, - add_range(current_group, size1), - add_range(current_group + 1, size2), - ) - ) - else: - return_getters.append( - operator.itemgetter(add_range(current_group, size)) - ) - return_getters_groups.append(return_getters) - - assert all( - V.graph.sizevars.size_hint(s) == 1 for s in remaining - ), f"failed to set ranges {remaining} {lengths}" - - return new_ranges, return_getters_groups - - @classmethod - def is_compatible( - cls, groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] - ): - try: - cls._split_iteration_ranges(groups, lengths) - return True - except CantSplit: - return False - - def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]): - """ - We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1). - - To do this we need to split up the iteration space of i0 into something like: - for i1 in s0: - for i2 in s1: - i0 = i1*s1 + i2 - .... 
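Editorial aside, not part of the patch: the split sketched in the docstring above can be checked numerically. `s0` and `s1` below are arbitrary example sizes.

```python
# Sketch: the tiled loops visit exactly the flat indices i0 in range(s0 * s1).
s0, s1 = 3, 4
flat = list(range(s0 * s1))
tiled = [i1 * s1 + i2 for i1 in range(s0) for i2 in range(s1)]
assert flat == tiled  # i0 = i1*s1 + i2 enumerates the same iteration space
```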
- - This function matches and resplits lengths to the groups of - this kernel to enable tiled + non-tiled fusions. - """ - groups = [rt.numel for rt in self.range_trees] - if not self.inside_reduction: - groups[-1] = sympy.Integer(1) - - if len(lengths) == len(self.range_trees) and all( - V.graph.sizevars.simplify(sympy_product(x) - g) == 0 - for x, g in zip(lengths, groups) - ): - return self.set_ranges(*lengths) - - new_ranges, return_getters_groups = self._split_iteration_ranges( - groups, lengths - ) - itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges))) - return [[fn(itervars) for fn in fns] for fns in return_getters_groups] - - def is_indirect_indexing(self, index: sympy.Expr): - # tmpX means indirect indexing - return free_symbol_is_type(index, SymT.TMP) - - def is_broadcasted(self, index: sympy.Expr): - # Note. This may not be correct when there is indirect indexing - if self.is_indirect_indexing(index): - return False - - index_numels = [1] * len(self.numels) - for symbol in index.free_symbols: - if symbol not in self.range_tree_nodes: - # Non-iterated variables, e.g. strides - continue - entry = self.range_tree_nodes[symbol] # type: ignore[index] - assert isinstance(entry.parent, IterationRangesRoot) - index_numels[entry.parent.index] *= entry.length - - # If the index variables only iterate over a subset of the kernel - # numels, then it must be broadcasted. - simplify = V.graph.sizevars.simplify - return any( - simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type] - for idx_range, iter_range in zip(index_numels, self.numels) - ) - - def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): - """ - More aggressive simplification to merge contiguous dims - """ - if isinstance(index, (sympy.Integer, sympy.Symbol)): - return index - index_vars, sizes = tree.vars_and_sizes(index) - if len(sizes) <= 1: - return index - new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( - index_vars, sizes, index_prevent_reordering([index], index_vars, sizes) - ) - if new_sizes == sizes: - return index - new_index_vars = tree.construct(new_sizes) - new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars)))) - return new_index - - def index_to_str(self, index: sympy.Expr) -> str: - """ - Convert an index expr to a string that can be used in triton code. - e.g. a sympy expression "s2" may actually appear as "ks1" in the triton kernel. - - Index expressions often need to be passed in as arguments to the triton kernel. - Rename_indexing and codegen_indexing keep track of the needed indices and add - new parameters to the function signature. 
- """ - if isinstance(index, list): - return f"[{', '.join(map(self.index_to_str, index))}]" - return texpr(self.rename_indexing(self.codegen_indexing(index))) - - def indexing( - self, - index: sympy.Expr, - *, - copy_shape=None, - dense_indexing=False, - override_mask=None, - block_ptr=False, - ) -> Union[IndexingOptions, BlockPtrOptions]: - """ - Compute the index and mask to pass to tl.load() or tl.store() - """ - index = self.simplify_indexing(index) - index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) - # if simple replacements didn't get rid of floor/ceil, try full subs - if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): - index = index.subs(V.graph.sizevars.precomputed_replacements) - # last resort, if no range vars are in the expr, hoist it - # TODO instead of trying to blindly find complicated exprs, we should hoist the - # inputs/outputs sizes and strides, but at the time indexing is generated - # kernel inputs and outputs are not set yet, we'd need a deeper refactor - # to do it this way - - if len(index.atoms(sympy.ceiling)): - for a in index.atoms(sympy.ceiling): - # for nested exprs, atoms yields top level first (?) - # so if everything goes fine, lower level replacements will come up empty - symbols = a.free_symbols - if len(symbols) > 0 and all( - symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE)) - for s in symbols - ): - replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} - index = sympy_subs(index, replacements) - - index = self.simplify_indexing(index) - index_vars = index.free_symbols - has_rindex = False - - mask_vars: Set[str] = set() - for var in index_vars: - assert isinstance(var, sympy.Symbol) - has_rindex = has_rindex or symbol_is_type(var, SymT.RINDEX) - if override_mask: - pass - elif symbol_is_type(var, SymT.TMP): - # indirect indexing - cse_var = self.cse.varname_map[var.name] - mask_vars.update(cse_var.mask_vars) - elif symbol_is_type( - var, - ( - SymT.UNBACKED_INT, - SymT.SIZE, - SymT.PRECOMPUTED_SIZE, - SymT.INDEX, - SymT.FLOAT, - SymT.UNBACKED_FLOAT, - ), - ): - pass - else: - # var is one of xN, yN or rN - assert symbol_is_type( - var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK) - ), var.name - mask_vars.add(f"{var.name[0]}mask") - - need_dense = ( - config.triton.dense_indexing - or dense_indexing - or self._load_mask is not None - ) and index != 0 - - have_dense = True - have_loop_vars = False - dense_mask_vars = set() - - for tree in self.active_range_trees(): - if index_vars.intersection(tree.var_list): - have_loop_vars = True - else: - have_dense = False - dense_mask_vars.add(f"{tree.prefix}mask") - - if ( - block_ptr - and config.triton.use_block_ptr - and not override_mask - and not self._load_mask - and len(mask_vars - dense_mask_vars) == 0 - and not self.is_indirect_indexing(index) - and have_loop_vars - # workaround https://github.com/openai/triton/issues/2821 - and self.index_dtype == "tl.int32" - ): - index_relative_to_xyr_index = sympy_subs( - index, {v: t.expr for v, t in self.range_tree_nodes.items()} - ) - range_trees = self.active_range_trees(reorder=True) - symbols = [t.symbol() for t in range_trees] - strides = [sympy.Wild(f"stride_{s}", exclude=symbols) for s in symbols] - offset = sympy.Wild("_offset", exclude=symbols) - m = index_relative_to_xyr_index.match(sympy_dot(symbols, strides) + offset) - # TODO(jansel): it is sometimes possible to do higher dimensional block_ptrs with - # a tl.reshape the correct block. We will miss these cases today. 
- if m: - self.filter_masks(mask_vars) - return BlockPtrOptions.create( - [m[s] for s in strides], - m[offset], - range_trees, - mask_vars, # type: ignore[arg-type] - ) - - expand_str = None - index_str = self.index_to_str(index) - if isinstance(index, sympy.Integer): - expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() - index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" - return IndexingOptions(index_str, set(), "None", expand_str, has_rindex) - - if need_dense and not have_dense: - expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() - index_str = f"tl.broadcast_to({index_str}, {expand_str})" - mask_vars = dense_mask_vars - elif not have_loop_vars and copy_shape: - index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)" - mask_vars = dense_mask_vars - - if override_mask: - mask_vars = {override_mask} - - if self._load_mask: - mask_vars.add(self._load_mask) - - self.filter_masks(mask_vars) - - mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None" - return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex) # type: ignore[arg-type] - - def active_range_trees(self, reorder=False): - trees = [ - t for t in self.range_trees if t.prefix != "r" or self.inside_reduction - ] - if reorder and len(trees) > 1: - count = sum(t.prefix in "xyz" for t in trees) - assert "".join(t.prefix for t in trees[:count]) == "zyx"[-count:], [ - t.prefix for t in trees[:count] - ] - trees[:count] = reversed(trees[:count]) - return trees - - def filter_masks(self, mask_vars): - for tree in self.range_trees: - # Masks are superfluous if we only have one element - if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type] - mask_vars.discard(f"{tree.prefix}mask") - continue - # Masks are superfluous if numel is a multiple of BLOCK - # (We use the fact that BLOCK is required by triton to be a power of 2) - if tree.prefix.upper() not in TRITON_MAX_BLOCK: - continue - max_block = TRITON_MAX_BLOCK[tree.prefix.upper()] - # Optional optimization: if block divides numel exactly, we will - # never need to do a masked load to handle stragglers at the end. - # It's faster to avoid masking at all. But it is sound to always - # mask. 
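A minimal editorial sketch of the mask-elimination rule stated in the comment above; the function name is hypothetical and only standard Python is assumed.

```python
def mask_needed(numel: int, block: int) -> bool:
    # Triton block sizes are powers of two; a mask is only required when the
    # last block would read past numel (i.e. block does not divide numel).
    return numel % block != 0

assert mask_needed(1000, 128)      # 1000 = 7*128 + 104 -> tail must be masked
assert not mask_needed(1024, 128)  # exact multiple -> every access is in bounds
```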
- if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): # type: ignore[arg-type] - mask_vars.discard(f"{tree.prefix}mask") - - def var_ranges(self): - return dict( - itertools.chain.from_iterable( - tree.var_ranges.items() for tree in self.range_trees - ) + def want_no_x_dim(self): + return ( + self.reduction_hint == ReductionHint.INNER + and self.persistent_reduction + and len(self.numels) == 2 + and self.numels[-1] >= 256 ) - def codegen_indexing(self, expr: sympy.Expr): - expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) - for sym in sorted(expr.free_symbols, key=str): - if sym in self.range_tree_nodes: - # if indexing expression is complicated, we precompute it on the host side - # and send the result as a kernel argument - replacements = {} - for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] - replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) - if len(replacements) > 0: - self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] - self.range_tree_nodes[sym].expr, replacements # type: ignore[index] - ) - self.range_tree_nodes[sym].codegen() # type: ignore[index] - return expr - - @contextlib.contextmanager - def mask_loads(self, mask): - """Context manager to add an additional mask to tl.load/store""" - prior = self._load_mask - if prior: - mask = self.cse.generate(self.compute, f"{mask} & {prior}") - - self._load_mask = mask - try: - # TODO(jansel): do we need a reshape here? - yield mask - finally: - self._load_mask = prior - def generate_assert(self, check): return torch.version.hip is None and super().generate_assert(check) - def load_mask(self, var): - mask = "" - mask_vars = set(var.mask_vars) - if self._load_mask: - mask_vars.add(self._load_mask) - - if mask_vars: - mask = ( - f"{next(iter(mask_vars))}" - if len(mask_vars) == 1 - # sorted for deterministic order - else f"({' & '.join(sorted(map(str, mask_vars)))})" - ) - return mask - @property def assert_function(self) -> str: return "tl.device_assert" - def get_strides_of_load(self, index: sympy.Expr): - """ - This gets the stride of the index for each of the tiling variables - (technically, it does it at index 0) - - For example, if - xindex = x0 + 512*x1 + 1024*r0 - x0 = (xindex//512) - x1 = (xindex % 512) - r0 = rindex // 1024 - - this function would return - {xindex: 512, rindex: 1024} - """ - index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()} - index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type] - strides = {} - for range_tree in self.range_trees: - s = sympy_index_symbol(range_tree.name) - strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs( - index_in_tile_vars, {s: 0} - ) - return strides - def codegen_block_ptr( self, name: str, var: str, indexing: BlockPtrOptions, other="" ) -> Tuple[str, Optional[DeferredLine], str]: @@ -2125,12 +1277,6 @@ def reduction_resize(self, value): sizes[-1] = "None" return f"{value}[{', '.join(sizes)}]" - @staticmethod - def _map_tuple_or_scalar(fn, value): - if isinstance(value, tuple): - return tuple(map(fn, value)) - return fn(value) - def reduction( self, dtype: torch.dtype, @@ -2684,68 +1830,6 @@ def imports_for_benchmark_kernel(self): ) ) - def estimate_kernel_num_bytes(self): - """ - Try the best to estimate the total size (in bytes) of the - kernel's inputs and outputs, which is used for estimating the memory - throughput of this kernel. 
This information is used for checking how - far we are from the peak memory bandwidth. It's important that - we want to avoid overestimating the sizes of the inputs and outputs, - because it can wrongfully give us a very large memory traffic value, - which may be even larger than the theoretical bandwidth and thus - become very misleading. This is particularly problematic for cases - where we slice some inputs. In those cases, we should only count - the size of the "slices" instead of the original inputs, because - only the slices contribute to the real memory traffic. - """ - nbytes = [] - ninplace_args = len(unique(self.args.inplace_buffers.values())) - _, call_args, _ = self.args.python_argdefs() - - # For pointwise and reduction kernels, this is the upper-bound numels - # for the output buffer. - # FIXME: This is not exactly right for cases like below: - # def foo(tensor0, tensor1): - # x0 = narrow(tensor0) - # return cat(x0, tensor1) - # For this example, we will end up overestimate the size for the - # slice s0. Potentially, we could have precise inputs information - # if we maintained the original inputs of the Pointwise kernel created - # for the "cat". However, I think it might be a bit overwhelming that - # we add such complexity only for handling some particular cases for - # benchmarking. - out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels)) - for i, arg in enumerate(call_args): - # "buf" may be narrowed. In this case, the number of memory accesses - # should be estimated based on the reinterpreted layout. - # On the other hand, buf may be broadcasted. In this case, - # counting the size of the underline storage would give us - # a better estimation in terms of memory accesses. - if arg not in self.buf_accesses: - nbytes.append(0) - continue - arg_numel = V.graph.get_numel(arg) - buf_size = V.graph.sizevars.size_hint(arg_numel) - if buf_size > out_numel: - # This arg points to a buf that has been sliced. - # We need to count each individual slice to have - # a better estimation. 
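Editorial aside, not part of the patch: a toy calculation of why the slice-aware estimate described above matters. The numbers are made up.

```python
dtype_size = 4                   # fp32 bytes per element
out_numel = 1_000_000            # elements this kernel actually touches
full_buffer_numel = 16_000_000   # size of the buffer the slice was taken from

naive_bytes = full_buffer_numel * dtype_size  # 64 MB: overstates real traffic
slice_bytes = out_numel * dtype_size          # 4 MB: what the kernel reads
```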
- indices: Set[Any] = set() - no_index_dep_count = 0 - for dep in self.buf_accesses[arg]: - if isinstance(dep, (StarDep, WeakDep)): - indices.add(f"no_index_dep_{no_index_dep_count}") - no_index_dep_count += 1 - else: - indices.add(dep.index) - numel = len(indices) * out_numel - else: - numel = buf_size - dtype = V.graph.get_dtype(arg) - dtype_size = get_dtype_size(dtype) - nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) - return sum(nbytes) - def _get_heuristic(self): if self.persistent_reduction: assert self.inside_reduction @@ -2991,28 +2075,6 @@ def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexp if tree.prefix == "x" and self.no_x_dim: code.writeline("XBLOCK: tl.constexpr = 1") - def triton_tensor_ndim(self): - return sum(int(tree.tensor_dim is not None) for tree in self.range_trees) - - def indexing_size_str(self, i): - sizes = ["None"] * self.triton_tensor_ndim() - sizes[i] = ":" - return f"[{', '.join(sizes)}]" - - def dense_size_list(self) -> List[str]: - sizes = ["1"] * self.triton_tensor_ndim() - for tree in self.range_trees: - if tree.tensor_dim is None: - continue - - if tree.prefix != "r" or self.inside_reduction: - sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK" - return sizes - - def dense_size_str(self): - sizes = self.dense_size_list() - return f"[{', '.join(sizes)}]" - def _get_grid_fn(self): return "grid" @@ -3076,439 +2138,22 @@ def codegen_nan_check(self): line = f"assert not {arg}.isinf().any().item()" wrapper.writeline(line) - def warn_mix_layout(self, kernel_name): - """ - Print message if the kernel have mixed layout inputs. - Only care about 4D tensor for now. - """ - if ( - len(self.args.input_buffers) == 1 - and len(self.args.output_buffers) == 1 - and len(self.args.inplace_buffers) == 0 - ): - # even if input buffer and output buffer have different layout, - # this can be a layout conversion kernel. No need to warn for - # the mix layouts. 
- return - - argdefs, call_args, signature = self.args.python_argdefs() - uniform_stride_order = None - for arg_name in call_args: - buf = V.graph.get_buffer(arg_name) - if buf and len(buf.layout.size) == 4: - # ignore the tensor if only 1 dimension is non-zero - if len([x for x in buf.layout.size if x == 1]) == 3: - continue - stride_order = ir.get_stride_order(buf.layout.stride) - if uniform_stride_order is None: - uniform_stride_order = stride_order - elif uniform_stride_order != stride_order: - msg = yellow_text( - f"Expected stride order {uniform_stride_order}, but found stride order" - + f" {stride_order} for kernel {kernel_name}" - ) - log.warning(msg) - - stride_order_list = [ - ir.get_stride_order(V.graph.get_buffer(name).layout.stride) - if V.graph.get_buffer(name) - else None - for name in call_args - ] - size_list = [ - V.graph.get_buffer(name).layout.size - if V.graph.get_buffer(name) - else None - for name in call_args - ] - source_list = [ - "GraphInput" - if name in V.graph.graph_inputs - else "IntermediateBuffer" - if name in V.graph.name_to_buffer - else None - for name in call_args - ] - - msg = yellow_text( - f" param names {argdefs}\n buf names {call_args}\n strides {stride_order_list}" - + f"\n sizes {size_list}\n sources {source_list}\n" - ) - log.warning(msg) - return - msg = green_text( - f"All the inputs for the triton kernel {kernel_name} have uniform layout" - ) - log.warning(msg) - def create_cse_var(self, *args, **kwargs): return TritonCSEVariable(*args, **kwargs) - -class TritonScheduling(BaseScheduling): - def __init__(self, scheduler): - self.scheduler = scheduler - - def group_fn(self, sizes): - return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) - - def can_fuse(self, node1, node2): - """ - Hook called by Scheduler to determine if the Triton backend - can fuse node1 and node2. These nodes might already be - FusedSchedulerNodes. - """ - if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance( - node2, scheduler.ForeachKernelSchedulerNode - ): - return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2) - - _, (numel1, rnumel1) = node1.group - _, (numel2, rnumel2) = node2.group - why = WhyNoFuse(node1, node2) - - if node1.is_split_scan() and not node2.is_split_scan(): - if node2.is_reduction(): - why("Split scan cannot fuse with reductions") - elif node2.is_split_scan() and not node1.is_split_scan(): - if node1.is_reduction(): - why("Split scan cannot fuse with reductions") - - if node1.is_reduction() and node2.is_reduction(): - reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2 - if not reduction_can_fuse: - why( - "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)", - numel1, - numel2, - rnumel1, - rnumel2, - ) - return reduction_can_fuse - - if not node1.is_reduction() and not node2.is_reduction(): - if not (numel1 == numel2 and rnumel1 == rnumel2): - why( - "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)", - numel1, - numel2, - rnumel1, - rnumel2, - ) - return False - - if node1.is_template(): - # Only allow fusion for TritonTemplates for now. - # Fusion for CUDATemplates are not supported. 
- is_triton_template = isinstance(node1.node, TritonTemplateBuffer) - if not is_triton_template: - why("node1 is not TritonTemplateBuffer") - return is_triton_template - - # check for a bad combined tiling - tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) - tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1) - tiling3 = self.select_tiling( - node1.get_nodes() + node2.get_nodes(), numel1, rnumel1 - ) - if config.triton.tiling_prevents_pointwise_fusion: - cond = True - if len(tiling1) > 2: - if len(tiling2) > 2: - cond = tiling1 == tiling2 == tiling3 - else: - cond = tiling1 == tiling3 - elif len(tiling2) > 2: - cond = tiling2 == tiling3 - if not cond: - why( - "tiling mismatch (%s, %s, %s)", - tiling1, - tiling2, - tiling3, - ) - return False - - return True - - if not node1.is_reduction() and node2.is_reduction(): - assert rnumel1 == 1 and rnumel2 != 1 - if numel1 == numel2 * rnumel2: - if not all( - TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges()) - for n in node1.get_nodes() - ): - why("nodes numel/rnumel incompatibility") - return False - if ( - config.triton.tiling_prevents_reduction_fusion - and not node1.is_template() - ): - is_reduction_tiling_valid = self.select_tiling( - node1.get_nodes(), numel1 - ) in ( - (numel1, 1), - (numel2, rnumel2, 1), - ) - if not is_reduction_tiling_valid: - why("invalid tiling for reduction") - return is_reduction_tiling_valid - return True - - if numel1 != numel2: - why("nodes numel incompatibility") - return numel1 == numel2 - - assert node1.is_reduction() and not node2.is_reduction() - # swap args to hit the case above - return self.can_fuse_horizontal(node2, node1) - - can_fuse_vertical = can_fuse - can_fuse_horizontal = can_fuse - - def generate_node_schedule(self, nodes, numel, rnumel): - node_schedule: List[Any] = [] - current_loop_writes: Set[str] = set() - - # Writes with a reduced shape, meaning they are only present once the - # reduction loop has ended - current_loop_reduced_writes = set() - current_loop_has_writes = False - done = set() - - def fits_in_main_body(n): - _, (node_numel, node_rnumel) = n.group - return (node_numel == numel and node_rnumel == rnumel) or ( - node_numel == numel * rnumel and node_rnumel == 1 - ) - - def fits_outside_reduction(n): - _, (node_numel, node_rnumel) = n.group - return node_numel == numel and node_rnumel == 1 and rnumel != 1 - - def schedule_node_in_loop(n): - nonlocal current_loop_has_writes - done.add(n) - node_schedule.append(n) - current_loop_has_writes = True - # A scan is modelled as a reduction in the scheduler but has a - # full sized output that can be used inside the loop body - if ( - n.is_reduction() - and isinstance(n, scheduler.SchedulerNode) - and isinstance(n.node, ir.ComputedBuffer) - and not isinstance(n.node.data, ir.Scan) - ): - current_loop_reduced_writes.add(n.get_name()) - - @contextlib.contextmanager - def end_current_reduction_loop(): - nonlocal current_loop_has_writes - if current_loop_has_writes: - # flush out any other runnable nodes to reduce number of loops - for other_node in nodes[index + 1 :]: - if ( - node not in done - and fits_in_main_body(other_node) - and not (current_loop_reduced_writes & other_node.ancestors) - ): - schedule_node_in_loop(node) - - if node_schedule and node_schedule[-1] is EnableReduction: - node_schedule.pop() - else: - node_schedule.append(DisableReduction) - yield - node_schedule.append(EnableReduction) - current_loop_reduced_writes.clear() - current_loop_has_writes = False - - for index, node in 
enumerate(nodes): - if node in done: - continue - done.add(node) - - def requires_closing_previous_reduction(node, node_schedule): - if rnumel == 1: - return False - if not current_loop_reduced_writes & node.ancestors: - return False - assert node_schedule and not isinstance( - node_schedule[-1], (EnableReduction, DisableReduction) - ) - return bool(current_loop_reduced_writes) - - if fits_in_main_body(node): - if requires_closing_previous_reduction(node, node_schedule): - with end_current_reduction_loop(): - pass # need to start a new reduction loop - - schedule_node_in_loop(node) - elif fits_outside_reduction(node): - with end_current_reduction_loop(): - node_schedule.append(node) - else: - raise NotImplementedError( - f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}" - ) - - return node_schedule - - def codegen_node( - self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode] - ): - """ - Given a set of pre-fused nodes, generate a Triton kernel. - """ - - nodes: List[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] - - _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group - - node_schedule = self.generate_node_schedule(nodes, numel, rnumel) - buf_accesses = collections.defaultdict(list) - for node in nodes: - for access in node.read_writes.reads | node.read_writes.writes: - buf_accesses[access.name].append(access) - - schedule_log.debug("Schedule:\n %s", node_schedule) - - return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel) - - @staticmethod - def reduction_hint(node): - assert node.is_reduction() - if all( - dep.is_contiguous() - for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes) - ): - return ReductionHint.INNER + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): + line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}" + if entry.root.is_loop: + self.indexing_code.writeline(line) else: - return node.node.data.reduction_hint - - @staticmethod - def can_use_32bit_indexing( - numel: sympy.Expr, buffers: Iterable[Union[ir.Buffer, ir.TensorBox]] - ) -> bool: - int_max = torch.iinfo(torch.int32).max - size_hint = V.graph.sizevars.size_hint - has_hint = V.graph.sizevars.shape_env.has_hint - - def within_32bit(e): - # Allow for unhinted e as long as we can still statically prove - # (e.g., via ValueRanges) that it is still in bounds - if V.graph.sizevars.is_expr_static_and_true(e <= int_max): - return True - # Otherwise, the hint MUST exist and be in range - return has_hint(e) and size_hint(e) <= int_max - - if not within_32bit(numel): - return False - - # Any use of a MultiOutputLayout will create a buffer with a - # Layout whose sizes are accounted for - buf_sizes = [ - buf.get_layout().storage_size() - for buf in buffers - if not isinstance(buf.get_layout(), ir.MultiOutputLayout) - ] - - if not all(within_32bit(size) for size in buf_sizes): - return False - - # Only install guards for 32-bit indexing as there is no correctness - # issue with using 64-bit for everything - V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type] - for size in buf_sizes: - V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type] - return True - - @staticmethod - def select_index_dtype(node_schedule, numel, reduction_numel): - # Gather all used buffer names - buffer_names = set() - for node in node_schedule: - if not isinstance(node, scheduler.BaseSchedulerNode): - continue - - buffer_names.update(node.get_names()) - 
buffer_names.update(node.used_buffer_names()) - - # Get buffers objects - - def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]: - buf = V.graph.get_buffer(name) - if buf is None: - raise RuntimeError(f"Failed to find buffer matching name {name}") - return buf - - buffers = [V.graph.get_buffer(name) for name in buffer_names] - - # In theory we can separately check xnumel and rnumel are <= int_max - # but some indexers do use the full linear index so we need to be - # conservative here. - total_numel = numel * reduction_numel - - if TritonScheduling.can_use_32bit_indexing(total_numel, buffers): - return "tl.int32" - return "tl.int64" - - def has_non_contiguous_pw_in_reduction_kernel(self, node_schedule, numel, rnumel): - pointwise_nodes = list( - filter( - lambda n: n not in (EnableReduction, DisableReduction) - and not n.is_reduction() - and n.group[1][0] == numel * rnumel, - node_schedule, - ) - ) - for node in pointwise_nodes: - # An index can be an integer when loading a random seed. - if not all( - not isinstance(dep, MemoryDep) - or dep.is_contiguous() - or isinstance(dep.index, (sympy.Integer, int)) - or dep.stride1_for_last_dim() - for dep in itertools.chain( - node.read_writes.reads, node.read_writes.writes - ) - ): - return True - return False - - def get_kernel_args(self, node_schedule, numel, reduction_numel): - reductions = list( - filter( - lambda n: n not in (EnableReduction, DisableReduction) - and n.is_reduction(), - node_schedule, - ) - ) - if len(reductions) > 0: - hints = [self.reduction_hint(n) for n in reductions] - if hints.count(hints[0]) == len(hints): - reduction_hint_val = hints[0] - else: - reduction_hint_val = ReductionHint.DEFAULT - - if ( - reduction_hint_val == ReductionHint.INNER - and self.has_non_contiguous_pw_in_reduction_kernel( - node_schedule, numel, reduction_numel - ) - ): - reduction_hint_val = ReductionHint.DEFAULT - else: - reduction_hint_val = ReductionHint.DEFAULT - - mutations = set() - for node in node_schedule: - if hasattr(node, "get_mutations"): - mutations.update(node.get_mutations()) + # lift non-reduction stores outside loop + self.body.writeline(line) - index_dtype = self.select_index_dtype(node_schedule, numel, reduction_numel) - return reduction_hint_val, mutations, index_dtype +class TritonScheduling(SIMDScheduling): + int32_type = "tl.int32" + int64_type = "tl.int64" + kernel_type = TritonKernel def codegen_comment(self, node_schedule): wrapper = V.graph.wrapper_code @@ -3536,123 +2181,7 @@ def codegen_comment(self, node_schedule): f"{wrapper.comment} Fused node name list: {', '.join(node_names)}" ) - def codegen_node_schedule( - self, node_schedule, buf_accesses, numel, reduction_numel - ): - from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel - - tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel) - ( - reduction_hint_val, - mutations, - index_dtype, - ) = self.get_kernel_args(node_schedule, numel, reduction_numel) - - is_split_scan = any( - isinstance(node, BaseSchedulerNode) and node.is_split_scan() - for node in node_schedule - ) - kernel_type = TritonSplitScanKernel if is_split_scan else TritonKernel - kernel_args = tiled_groups - kernel_kwargs = { - "reduction_hint": reduction_hint_val, - "mutations": mutations, - "index_dtype": index_dtype, - } - kernel = kernel_type( - *kernel_args, - **kernel_kwargs, - ) - kernel.buf_accesses = buf_accesses - - self.codegen_node_schedule_with_kernel(node_schedule, kernel) - - with V.set_kernel_handler(kernel): - src_code = 
kernel.codegen_kernel() - - kernel_name = self.define_kernel(src_code, node_schedule) - log.debug("Generating kernel code with kernel_name: %s", kernel_name) - kernel.kernel_name = kernel_name - kernel.code_hash = code_hash(src_code) - - if kernel.persistent_reduction and config.triton.multi_kernel: - kernel2 = TritonKernel( - *kernel_args, - **kernel_kwargs, - disable_persistent_reduction=True, - ) - self.codegen_node_schedule_with_kernel(node_schedule, kernel2) - with V.set_kernel_handler(kernel2): - src_code2 = kernel2.codegen_kernel() - kernel_name2 = self.define_kernel(src_code2, node_schedule) - kernel2.kernel_name = kernel_name2 - kernel2.code_hash = code_hash(src_code2) - - final_kernel = MultiKernel([kernel, kernel2]) - else: - final_kernel = kernel # type: ignore[assignment] - - with V.set_kernel_handler(final_kernel): - for node in node_schedule: - if node not in (EnableReduction, DisableReduction): - node.mark_run() - - self.codegen_comment(node_schedule) - final_kernel.call_kernel(final_kernel.kernel_name) - if config.nan_asserts: - final_kernel.codegen_nan_check() - if config.warn_mix_layout: - final_kernel.warn_mix_layout(kernel_name) - - V.graph.removed_buffers |= final_kernel.removed_buffers - V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove - - if ( - V.graph.wrapper_code.supports_intermediate_hooks - and config.generate_intermediate_hooks - ): - # Not every node in the schedule will actually be live on output; - # we can't check dead buffers. - live_outs = kernel.args.live_output_buffers() - for node in node_schedule: - if not isinstance(node, scheduler.BaseSchedulerNode): - continue - name = node.get_name() - if name not in live_outs: - continue - origin_node = node.node.get_origin_node() - if origin_node is not None: - counters["inductor"]["intermediate_hooks"] += 1 - V.graph.wrapper_code.writeline( - f"run_intermediate_hooks({origin_node.name!r}, {name})" - ) - - self.scheduler.free_buffers() - - def codegen_node_schedule_with_kernel(self, node_schedule, kernel): - def current_reduction_nodes(nodes): - return itertools.takewhile(lambda n: n is not DisableReduction, nodes) - - with kernel: - stack = contextlib.ExitStack() - kernel.set_last_usage(current_reduction_nodes(node_schedule)) - - for node in node_schedule: - if node not in (EnableReduction, DisableReduction): - node.decide_inplace_update() - for i, node in enumerate(node_schedule): - if node is DisableReduction: - stack.enter_context(kernel.disable_reduction()) - elif node is EnableReduction: - stack.close() - kernel.set_last_usage(current_reduction_nodes(node_schedule[i:])) - else: - # TODO - use split ranges ? 
- indexing_dtype_strength_reduction(node._body) - index_vars = kernel.split_and_set_ranges(node.get_ranges()) - node.codegen(index_vars) - - def define_kernel(self, src_code, node_schedule): + def define_kernel(self, src_code, node_schedule, kernel): wrapper = V.graph.wrapper_code if src_code in wrapper.src_to_kernel: kernel_name = wrapper.src_to_kernel[src_code] @@ -3704,293 +2233,6 @@ def define_kernel(self, src_code, node_schedule): return kernel_name - def codegen_template( - self, template_node, epilogue_nodes, only_gen_src_code=False - ) -> Optional[str]: - """ - Codegen a triton template - - If `only_gen_src_code` the src code will be returned instead of codegen'd into the wrapper - """ - _, (numel, rnumel) = template_node.group - assert rnumel == 1 - kernel, render = template_node.node.make_kernel_render(template_node.node) - with kernel: - if not only_gen_src_code: - for node in [template_node, *epilogue_nodes]: - node.mark_run() - partial_code = render() - for node in epilogue_nodes: - node.codegen(kernel.split_and_set_ranges(node.get_ranges())) - - # finalize must be called after adding epilogue above - with V.set_kernel_handler(kernel): - # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion. - src_code = ( - partial_code - if isinstance(partial_code, str) - else partial_code.finalize() - ) - node_schedule = [template_node, *epilogue_nodes] - - if config.benchmark_kernel: - num_gb = kernel.estimate_kernel_num_bytes() / 1e9 - grid_args = V.graph.sizevars.size_hints(kernel.call_sizes) - assert kernel.meta is not None, "meta is None" - grid = kernel.grid_fn(*grid_args, kernel.meta) - src_code = ( - f"{kernel.imports_for_benchmark_kernel()}\n" - f"{src_code}\n" - f"{kernel.codegen_kernel_benchmark(num_gb, grid).getvalue()}" - ) - - if only_gen_src_code: - return src_code - - kernel_name = self.define_kernel(src_code, node_schedule) - - self.codegen_comment(node_schedule) - kernel.call_kernel(kernel_name, template_node.node) - V.graph.removed_buffers |= kernel.removed_buffers - V.graph.inplaced_to_remove |= kernel.inplaced_to_remove - self.scheduler.free_buffers() - return None - - def codegen_sync(self): - V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize()) - - def codegen_foreach(self, foreach_node): - from .triton_foreach import ForeachKernel - - for partitions_with_metadata in ForeachKernel.horizontal_partition( - foreach_node.get_subkernel_nodes(), self - ): - kernel = ForeachKernel() - for nodes, tiled_groups, numel, rnumel in partitions_with_metadata: - node_schedule = self.generate_node_schedule(nodes, numel, rnumel) - ( - reduction_hint_val, - mutations, - index_dtype, - ) = self.get_kernel_args(node_schedule, numel, rnumel) - - subkernel = kernel.create_sub_kernel( - *tiled_groups, - reduction_hint=reduction_hint_val, - mutations=mutations, - index_dtype=index_dtype, - ) - - self.codegen_node_schedule_with_kernel( - node_schedule, - subkernel, - ) - - with V.set_kernel_handler(subkernel): - for node in node_schedule: - if node not in (EnableReduction, DisableReduction): - node.mark_run() - V.graph.removed_buffers |= subkernel.removed_buffers - V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove - - src_code = kernel.codegen_kernel() - kernel_name = self.define_kernel(src_code, [foreach_node]) - self.codegen_comment([foreach_node]) - kernel.call_kernel(V.graph.wrapper_code, kernel_name) - - self.scheduler.free_buffers() - - @staticmethod - @functools.lru_cache(32) - def candidate_tilings(node): - ranges, 
reduction_ranges = node.get_ranges() - if len(ranges) <= 1: - return () - - rw = node.pointwise_read_writes() - assert len(rw.range_vars) == len(ranges) - - # isinstance(dep, MemoryDep): this filters out StarDeps. StarDeps refer to reads - # that need to access the entire tensor; they don't contribute read indexing - # information (and practically, they don't have dep.index so they can't be used - # for stride_hints below - dep_sources = [rw.reads, rw.writes] - assert all( - isinstance(dep, (MemoryDep, StarDep)) - for dep in itertools.chain.from_iterable(dep_sources) - ) - deps = [ - dep - for dep in itertools.chain.from_iterable(dep_sources) - if dep.name not in V.graph.removed_buffers and isinstance(dep, MemoryDep) - ] - write_names = {dep.name for dep in rw.writes} - - tilings: List[CandidateTiling] = [] - - for dep in deps: - strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars) - assert len(strides) == len(ranges) - try: - split = strides.index(1) + 1 - if split == len(ranges): - continue - if all(s == 0 for s in strides[split:]): - # if this is a broadcasted tensor and all dimensions after split are broadcast, - # this is not a real split - continue - - except ValueError: - continue - tiled_groups = ( - V.graph.sizevars.simplify(sympy_product(ranges[:split])), - V.graph.sizevars.simplify(sympy_product(ranges[split:])), - ) - # score by number of elements - score = V.graph.sizevars.size_hint( - sympy_product( - size for size, stride in zip(ranges, strides) if stride != 0 - ) - ) - if dep.name in write_names: - # ngimel said contiguous writes is more important than reads - score *= 2 - if CandidateTiling.is_good_size(tiled_groups[0]): - score *= 2 - if CandidateTiling.is_good_size(tiled_groups[1]): - score *= 2 - - if ( - V.graph.sizevars.size_hint( - score - sympy_product(itertools.chain(ranges, reduction_ranges)) - ) - >= 0 - ): - tilings.append(CandidateTiling(tiled_groups, score, dep.name)) - return tilings - - @classmethod - def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)): - """ - Heuristics to decide how to tile kernels. - Currently, we tile based on stride-1 dimensions. - - Returns: - `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel` - - """ - if reduction_numel != 1 or config.triton.max_tiles <= 1: - # TODO(jansel): should we tile reductions? - # do perf hint here if stride-1 dim is not being reduced - if perf_hint_log.level <= logging.WARNING: - for node in EnableReduction.filter(node_schedule): - if len(cls.candidate_tilings(node)) > 0: - perf_hint_log.info("reduction over non-contiguous dims") - break - return (numel, reduction_numel) - - seen_names = set() - candidate_tiles: Counter[Any] = collections.Counter() - for node in EnableReduction.filter(node_schedule): - for tiling in cls.candidate_tilings(node): - if tiling.name in seen_names: - continue - seen_names.add(tiling.name) - candidate_tiles[tiling.tiling] += tiling.score - - ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()] - - if config.triton.max_tiles >= 3: - # Consider adding a third dimension of tiling, but only - # when a1 is a multiple of b1; otherwise, you have a lot - # of stragglers which is annoying to generate code for. - # - # NB: More than three max tiles is not enabled by default. 
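Editorial aside, not part of the patch: the 3D-tiling merge described in the comment above, with made-up sizes.

```python
a0, a1 = 64, 4096    # hypothetical best-scoring 2D tiling (a0 * a1 == numel)
b0, b1 = 1024, 256   # hypothetical runner-up; a1 is a multiple of b1

tiling_3d = (a0, a1 // b1, b1)   # -> (64, 16, 256)
assert a0 * a1 == b0 * b1 == tiling_3d[0] * tiling_3d[1] * tiling_3d[2]
```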
- - # Add one 3D tiling choice - for i in range(1, len(ranked_tilings)): - a0, a1 = ranked_tilings[0] - b0, b1 = ranked_tilings[i] - if V.graph.sizevars.size_hint(a1 - b1) == 0: - continue - if V.graph.sizevars.size_hint(a1 - b1) < 0: - # swap so a0 is bigger - a0, a1 = ranked_tilings[i] - b0, b1 = ranked_tilings[0] - assert V.graph.sizevars.size_hint(a1 - b1) > 0 - if V.graph.sizevars.statically_known_multiple_of(a1, b1): - tiling = (a0, FloorDiv(a1, b1), b1) - ranked_tilings = [tiling] + ranked_tilings - break # only 1 choice for now - - if len(ranked_tilings) > 1: - perf_hint_log.info("possibly bad tiling: %s", ranked_tilings) - - for tiled_groups in ranked_tilings: - new_groups = (*tiled_groups, reduction_numel) - if all( - TritonKernel.is_compatible(new_groups, node.get_ranges()) - for node in node_schedule - if isinstance(node, scheduler.SchedulerNode) - ): - return new_groups - - return (numel, reduction_numel) - - def flush(self): - pass - - def ready_to_flush(self) -> bool: - return False - - def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False): - @dataclasses.dataclass - class LastUsageHolder: - n: Any - last_usage: Any - - def __del__(self): - self.n.last_usage = self.last_usage - - last_usage_holders = [LastUsageHolder(n, n.last_usage) for n in nodes] - - # empty last_usage. May cause more aggressive 'evict_last'. Should be fine. - for n in nodes: - n.last_usage = set() - - if not nodes[0].is_template(): - _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group - node_schedule = self.generate_node_schedule(nodes, numel, rnumel) - - tiled_groups = self.select_tiling(node_schedule, numel, rnumel) - reduction_hint_val, mutations, index_dtype = self.get_kernel_args( - node_schedule, numel, rnumel - ) - - kernel = TritonKernel( - *tiled_groups, - reduction_hint=reduction_hint_val, - mutations=mutations, - index_dtype=index_dtype, - ) - - self.codegen_node_schedule_with_kernel(node_schedule, kernel) - with config.patch( - "benchmark_kernel", benchmark_kernel - ), V.set_kernel_handler(kernel): - src_code = kernel.codegen_kernel() - else: - template_node = nodes[0] - epilogue_nodes = nodes[1:] - - with config.patch("benchmark_kernel", benchmark_kernel): - src_code = self.codegen_template( - template_node, epilogue_nodes, only_gen_src_code=True - ) - - src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") - return src_code - @preserve_rng_state() def benchmark_fused_nodes(self, nodes): src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True) @@ -4061,50 +2303,3 @@ def store_cache(): ) store_cache() return ms, mod.__file__ - - -@dataclasses.dataclass -class CandidateTiling: - tiling: Tuple[sympy.Expr, sympy.Expr] - score: int # higher is better - name: Optional[str] = None - - @staticmethod - def is_good_size(s): - """Somewhat arbitrary heuristic used to boost scores for some sizes""" - s = V.graph.sizevars.size_hint(s) - return s >= 32 and (s % 32 == 0) - - -class DisableReduction: - """ - Marker to invoke `kernel.disable_reduction()`. This closes a - reduction loop and allows for pointwise ops to occur on the output - of a reduction. - """ - - -class EnableReduction: - """ - Marker to end a DisableReduction block. - """ - - @staticmethod - def filter(node_schedule): - """ - Get the nodes from node_schedule skipping those in a - DisableReduction block. 
- """ - disabled = False - for node in node_schedule: - if node in (EnableReduction, DisableReduction): - # Don't tile stuff outside the main reduction loop - disabled = node is DisableReduction - elif disabled: - pass - else: - yield node - - -class CantSplit(Exception): - pass diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index 8df904946e4aa..2a8e0142fbd4c 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -4,12 +4,9 @@ import torch._inductor.runtime.hints from torch._inductor import config +from torch._inductor.codegen.simd import IterationRangesRoot -from torch._inductor.codegen.triton import ( - IterationRangesRoot, - triton_compute_type, - TritonKernel, -) +from torch._inductor.codegen.triton import triton_compute_type, TritonKernel from torch._prims_common import prod diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 79af641514bd6..db8a6d9ae3b63 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -232,14 +232,21 @@ def is_fbcode(): force_same_precision = ( True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1" ) + # Specify candidate backends for gemm autotune. -# Possible choices are combinations of: ATen, Triton, CUTLASS, CPP. +# Possible choices are combinations of: ATen, Triton, CUTLASS. # ATen: default Pytorch ATen kernels. # Triton: Triton templates defined in torch inductor. # CUTLASS: Cutlass templates and kernels. -# CPP: CPP templates and kernels for CPU. max_autotune_gemm_backends = os.environ.get( - "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP" + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON" +).upper() + +# Specify the size of the search space for GEMM autotuning. 
+# DEFAULT - balance between compile time overhead and performance +# EXHAUSTIVE - maximize performance +max_autotune_gemm_search_space = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT" ).upper() # the value used as a fallback for the unbacked SymInts @@ -720,6 +727,10 @@ class aot_inductor: debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1" + debug_dump_consts_bin: bool = ( + os.environ.get("AOT_INDUCTOR_DEBUG_DUMP_CONSTS_BIN", "0") == "1" + ) + # Serialized tree spec for flattening inputs serialized_in_spec = "" diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index df282629e2ce7..43f7e009af83b 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -1,4 +1,5 @@ import functools +import itertools import operator from typing import List, Optional, Union @@ -325,6 +326,17 @@ def should_pad_bench( if m_padded_length == k_padded_length == n_padded_length == 0: return False + def realize_symbols(ds): + return [d if isinstance(d, int) else d.node.hint for d in ds] + + if any( + dim == 0 + for dim in itertools.chain( + realize_symbols(mat1.shape), realize_symbols(mat2.shape) + ) + ): + return False + if torch._inductor.config.force_shape_pad: return True @@ -342,9 +354,6 @@ def should_pad_bench( if cached_pad is not None: return cached_pad - def realize_symbols(ds): - return [d if isinstance(d, int) else d.node.hint for d in ds] - def realize_tensor(t): if isinstance(t, FakeTensor): size_hints = realize_symbols(t.size()) diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 0d4fc3b429332..4476a9ccd512d 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -224,17 +224,26 @@ def generate_pattern_with_binary( binary_post_op, computation_call, extra_input_pattern, - int8_mixed_bf16_with_inplace_add=False, + dtype_convert=False, + swap_inputs=False, ): - binary_pattern = CallFunction( - binary_post_op, - computation_call, - extra_input_pattern, + binary_pattern = ( + CallFunction( + binary_post_op, + extra_input_pattern, + computation_call, + ) + if swap_inputs + else CallFunction( + binary_post_op, + computation_call, + extra_input_pattern, + ) ) return _may_generate_pattern_with_dtype_convert( binary_pattern, KeywordArg("convert_dtype_after_inplace_add"), - int8_mixed_bf16_with_inplace_add, + dtype_convert, ) @@ -435,10 +444,109 @@ def qlinear(match: Match, *args, **kwargs): return qlinear -def _is_valid_quantized_conv_binary_optimization_pattern(): - # Check if it's a valid Conv Binary Pattern: - # * qconv2d_pointwise should only has one users - # * Extra input of binary node comes from dequant pattern +def _register_quantized_linear_binary_lowering( + pattern, + pass_number, + computation_op, + binary_unary_attr, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qlinear_binary_optimization_pattern(), + pass_number=pass_number, + ) + def qlinear_binary(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + x2 = ( + kwargs["accum"] + if binary_unary_attr.binary_op_name == "sum" + else kwargs["other"] + ) + x2_scale = 1.0 + x2_zp = 0 + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # bias + b = kwargs["b"] if "b" in kwargs else None + # Output 
QParams + o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0 + o_zero_point = kwargs["o_zp"] if output_dtype is None else 0 + + x2.realize() + from .mkldnn_fusion import _can_be_inplace + + if binary_unary_attr.binary_op_name == "sum": + assert _can_be_inplace( + x2 + ), "QLinear Binary Inplace Fusion requires accum is not an alias or mutation." + + # if the binary post op is sum but output dtype is not the same as accum, + # use accum's dtype as output dtype + out_dtype = output_dtype + if ( + output_dtype + and binary_unary_attr.binary_op_name == "sum" + and output_dtype != x2.dtype + ): + out_dtype = x2.dtype + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + o_inv_scale, + o_zero_point, + out_dtype, + x2, + x2_scale, + x2_zp, + binary_unary_attr.binary_op_name, + binary_unary_attr.alpha, + binary_unary_attr.unary_op_name, + binary_unary_attr.scalars_attr, + binary_unary_attr.algorithm_attr, + ) + counters["inductor"]["qlinear_binary_matcher_count"] += 1 + counters["inductor"]["qlinear_binary_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qlinear_binary + + +def _is_valid_qconv_binary_optimization_pattern(): + return _is_valid_quantized_op_binary_optimization_pattern( + torch.ops.onednn.qconv2d_pointwise + ) + + +def _is_valid_qlinear_binary_optimization_pattern(): + return _is_valid_quantized_op_binary_optimization_pattern( + torch.ops.onednn.qlinear_pointwise, + # we don't insert q-dq for extra input due to accuracy issues + extra_input_from_dequant=False, + ) + + +def _is_valid_quantized_op_binary_optimization_pattern( + qop, extra_input_from_dequant=True +): + # Check if it's a valid Binary Pattern for qconv2d and qlinear: + # * qop_pointwise should only has one users + # * If extra_input_from_dequant is True, extra input of binary node should come from dequant pattern # * the two inputs of binary node should have attribute "meta" and should be tensors # * the two inputs of binary node should have the same shape # * All users of the extra input in this pattern should be @@ -446,8 +554,8 @@ def _is_valid_quantized_conv_binary_optimization_pattern(): # connected to the compute node. 
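Editorial aside, not part of the patch: a sketch of the graph shape the checks above accept.

```python
# For qlinear (extra_input_from_dequant=False):
#
#     qlinear_pointwise        other           <- extra input, no q-dq required
#                \              /
#                  add.Tensor                  <- sole user of qlinear_pointwise
#                       |
#              optional relu / quant
#
# For qconv2d the same helper is used with extra_input_from_dequant=True, so the
# extra input must instead come from a dequantize_per_tensor node.
```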
def fn(match): output_dtype = _get_pattern_output_dtype(match) - compute_node = filter_nodes(match.nodes, torch.ops.onednn.qconv2d_pointwise)[0] - # qconv2d_pointwise should only have one user + compute_node = filter_nodes(match.nodes, qop)[0] + # qop_pointwise should only have one user if len(compute_node.users) != 1: return False binary_node_inputs = next(iter(compute_node.users)).args @@ -460,9 +568,12 @@ def fn(match): break assert extra_input_of_binary_node is not None # Extra input of binary node comes from dequant pattern - if (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or ( - extra_input_of_binary_node.target - != quantized_decomposed.dequantize_per_tensor.default + if extra_input_from_dequant and ( + (not isinstance(extra_input_of_binary_node, torch.fx.Node)) + or ( + extra_input_of_binary_node.target + != quantized_decomposed.dequantize_per_tensor.default + ) ): return False @@ -489,9 +600,13 @@ def fn(match): from .mkldnn_fusion import _get_remaining_users extra_input_of_pattern = ( - match.kwargs["accum"] - if output_dtype is None - else match.kwargs["accum_after_dequant"] + match.kwargs["other"] + if "other" in match.kwargs + else ( + match.kwargs["accum"] + if output_dtype is None or (not extra_input_from_dequant) + else match.kwargs["accum_after_dequant"] + ) ) if ( len( @@ -517,7 +632,7 @@ def _register_quantized_conv_binary_lowering( ): @register_lowering_pattern( pattern, - extra_check=_is_valid_quantized_conv_binary_optimization_pattern(), + extra_check=_is_valid_qconv_binary_optimization_pattern(), pass_number=pass_number, ) def qconv_binary(match: Match, *args, **kwargs): @@ -884,6 +999,228 @@ def __init__( binary_unary_attr, # binary_unary_attr ) + # QLinear + r""" + Supported linear-binary(-unary) patterns + + linear(X) extra input + \ / + Add + | + Optional(relu) + | + Y + + 1. int8-mixed-fp32 + +---+---------------+-----------+------------------------------+---------+ + | # | Add type | Quant out | Pattern | Post op | + +---+---------------+-----------+------------------------------+---------+ + | 1 | In-/out-place | Yes | linear + fp32 -> (relu) -> q | add | + +---+---------------+-----------+------------------------------+---------+ + | 2 | In-/out-place | No | linear + fp32 -> (relu) | sum | + +---+---------------+-----------+------------------------------+---------+ + + 2. 
int8-mixed-bf16 + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | # | X2 dtype | Add type | Quant out | Pattern | Post op | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 1 | BF16 | In-/out-place | Yes | linear + bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 2 | BF16 | In-/out-place | No | linear + bf16 -> (relu) | sum | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 3 | FP32 | Out-place | Yes | linear + fp32 -> (relu) -> q | add | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 4 | FP32 | Out-place | No | linear + fp32 -> (relu) | sum | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 5 | FP32 | In-place left | Yes | linear + fp32 -> to_bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 6 | FP32 | In-place left | No | linear + fp32 -> to_bf16 -> (relu) | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + + Note + (1) The positions of linear and the extra input can be swapped. + (2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the + extra input, we don't match that pattern because we cannot match all these patterns in 3 passes. + """ + for x_scale_zp_are_tensors in (False, True): + qlinear_binary_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + unary_postop_list = ["none", "relu"] + unary_postop_dict = { + "none": None, + "relu": aten.relu.default, + } + convert_dtype_after_binary_list = [False, True] + + # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output + # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16, + # totally 3 patterns (2 are identical) + swap_binary_inputs_list = [False, True] + int8_mixed_bf16_list = [False, True] + combinations = itertools.product( + unary_postop_list, + int8_mixed_bf16_list, + swap_binary_inputs_list, + convert_dtype_after_binary_list, + ) + qlinear_binary_replace_patterns = {} + for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations: + if not int8_mixed_bf16 and cvt_dtype_binary: + # No convert node after binary node if dtypes are all fp32 + continue + qlinear_binary_replace_patterns.update( + { + BinaryUnaryAttr( + "add", 1.0, unary_op, [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + # If fp32 extra input is inplace added to bf16 linear output, + # a to_bf16 node is inserted after binary + dtype_convert=cvt_dtype_binary, + swap_inputs=swap_inputs, + ), + unary_postop_dict[unary_op], + ), + ) + } + ) + for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 0, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (2) of 
int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "sum", 1.0, "relu", [], "" + ): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 1, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "add", 1.0, "relu", [], "" + ): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 1, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 2, # pass_number + qlinear_binary_op, # computation_op + # Output dtype should be the same as accum's dtype but we don't know + # its dtype. 
So, leave it to be determined in the lowering function + binary_unary_attr, + ) + # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output + # Covers (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "add", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_linear_binary_lowering( + patterns, + 2, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + def _is_valid_quantized_maxpool2d_optimization_pattern(): def fn(match): diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 672509b20a568..bfb7b8dea7ebb 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1303,6 +1303,8 @@ def debug(msg): torch.ops.aten.mkldnn_rnn_layer.default, torch.ops.onednn.qlinear_pointwise.default, torch.ops.onednn.qlinear_pointwise.tensor, + torch.ops.onednn.qlinear_pointwise.binary, + torch.ops.onednn.qlinear_pointwise.binary_tensor, ] need_fixed_channels_last_layout += [ torch.ops.mkldnn._convolution_pointwise.default, diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index bd092791eb7c2..689877ba69281 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3595,7 +3595,10 @@ def __init__( self.mutated_inputs = mutated_inputs if mutated_inputs is not None: # Ensure that the mutated inputs are only allowed for certain nodes - allowed_set = {torch.ops.higher_order.flex_attention} + allowed_set = { + torch.ops.higher_order.flex_attention, + torch.ops.higher_order.flex_attention_backward, + } current_node = V.graph.current_node.target assert ( current_node in allowed_set @@ -3718,13 +3721,6 @@ def get_workspace_size(self): return self.workspace_size if self.workspace_size is not None else 0 -class CppTemplateBuffer(TemplateBuffer): - def __init__(self, layout, inputs, make_kernel_render, template, choice): - super().__init__(layout, inputs, make_kernel_render) - self.template = template - self.choice = choice - - @dataclasses.dataclass class InputsKernel(Buffer): inputs: List[Buffer] @@ -3930,6 +3926,21 @@ def should_allocate(self): return True +def get_aten_cpp_kernel_name(kernel): + # Calling with the default kernel name can lead to ambiguous behavior like the following example. + # repeat_interleave(const at::Tensor & repeats, c10::optional output_size=c10::nullopt) + # repeat_interleave(const at::Tensor & self, int64_t repeats, + # c10::optional dim=c10::nullopt, c10::optional output_size=c10::nullopt) + if not isinstance(kernel, torch._ops.OpOverload) or kernel.namespace != "aten": + return None + opname = ( + kernel.__name__.split(".")[0] + if kernel._overloadname == "default" + else kernel.__name__.replace(".", "_") + ) + return f"at::_ops::{opname}::call" + + @dataclasses.dataclass class ExternKernel(InputsKernel): constant_args: Tuple[Any, ...] 
= () @@ -3973,7 +3984,8 @@ def __init__( self.kwargs = kwargs if kwargs else {} self.output_view = output_view self.python_kernel_name = python_kernel_name - self.cpp_kernel_name = cpp_kernel_name + # If cpp_kernel_name is None, we will try to construct it from op_overload + self.cpp_kernel_name = cpp_kernel_name or get_aten_cpp_kernel_name(op_overload) self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel self.op_overload = op_overload self.collect_arg_kwarg_properties() @@ -4016,6 +4028,40 @@ def collect_arg_kwarg_properties(self): else {} ) + def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False): + # Previously, we want to maintain forward-compatibility by skipping + # default args in the serialized artifacts in fbcode. However, + # some of our shim interfaces require default values being set. + # Discussed with Sherlock offline and we decided to allow serializing + # default args into the C++ wrapper code for now. We will refine this + # part if we see real FC requirement. More details related to FC + # can be found at: + # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing + assert isinstance(args, (list, tuple)) + if isinstance(args, tuple): + args = list(args) + assert self.arg_properties, "ExternKernel.arg_properties should not be empty" + + n_args = len(args) + n_pos_args = len(self.arg_properties) + # For cpp wrapper, if some positional args are not provided, we need to check + # if they're in the kwargs or use their default value + if n_args < n_pos_args: + log.debug( + "%s has %d unprovided positional arguments. " + "Will check if they are in the keyword arguments or will use default values.", + self.op_overload, + n_pos_args - n_args, + ) + for i in range(n_args, n_pos_args): + arg_name = self.arg_properties[i]["name"] + args.append( + kwargs[arg_name] + if arg_name in kwargs + else self.arg_properties[i]["default_value"] + ) + return args + def decide_layout(self): if isinstance(self.layout, FlexibleLayout): self.apply_constraint() @@ -4030,7 +4076,15 @@ def codegen(self, wrapper): raise NotImplementedError def get_kernel_name(self): - return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name + return ( + ( + V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name) # type: ignore[attr-defined] + if config.abi_compatible + else self.cpp_kernel_name + ) + if V.graph.cpp_wrapper + else self.python_kernel_name + ) @staticmethod def copy_input(x): @@ -4726,9 +4780,17 @@ class InplaceBernoulliFallback(ExternKernel): def codegen(self, wrapper): (x,) = (t.codegen_reference() for t in self.inputs) - wrapper.writeline( - f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}" - ) + + if V.graph.cpp_wrapper and config.abi_compatible: + # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here, + # which needs to be explicitly generated for cpp wrapper + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}" + ) + else: + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}" + ) def should_allocate(self): return False @@ -4739,20 +4801,19 @@ def get_mutation_names(self): def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() - def __init__(self, x, *constant_args): + def __init__(self, op_overload, x, *constant_args): super().__init__( None, NoneLayout(x.get_device()), # 
type: ignore[arg-type] self.unwrap_storage([x]), constant_args, + op_overload=op_overload, ) self.name = V.graph.register_buffer(self) self.python_kernel_name = "aten.bernoulli_" - self.cpp_kernel_name = ( - "aoti_torch_bernoulli_" - if config.abi_compatible - else "at::native::bernoulli_" - ) + if not config.abi_compatible: + # TODO: this should be simplified once we switch to ABI-compatible only + self.cpp_kernel_name = "at::native::bernoulli_" mark_node_as_mutating(self, x) @@ -5128,25 +5189,7 @@ class ExternKernelNode: } -def get_aten_cpp_kernel_name(kernel): - # Calling with the default kernel name can lead to ambiguous behavior like the following example. - # repeat_interleave(const at::Tensor & repeats, c10::optional output_size=c10::nullopt) - # repeat_interleave(const at::Tensor & self, int64_t repeats, - # c10::optional dim=c10::nullopt, c10::optional output_size=c10::nullopt) - assert ( - isinstance(kernel, torch._ops.OpOverload) and kernel.namespace == "aten" - ), "Invalid aten kernel" - opname = ( - kernel.__name__.split(".")[0] - if kernel._overloadname == "default" - else kernel.__name__.replace(".", "_") - ) - return f"at::_ops::{opname}::call" - - class FallbackKernel(ExternKernelAlloc): - args_default_value: List[Dict[str, Any]] - def __init__( self, layout, @@ -5158,12 +5201,23 @@ def __init__( *, unbacked_bindings=None, ): + if ( + kernel == aten.mul.Tensor + and len(tensor_args) == 1 + and len(nontensor_args) == 1 + ): + # When aten.mul.Tensor's second arg is constant, cpp wrapper expects + # to call mul_Scalar. A more proper fix is to do it in decomposition. + # See https://github.com/pytorch/pytorch/issues/123478 + kernel = aten.mul.Scalar + super().__init__( layout, tuple(tensor_args), tuple(nontensor_args), op_overload=kernel, ) + # We need output buffers for generating kernel arguments in the # abi-compatible mode, where we retrieve outputs by pass each individual # output through the abi-compatible interface. 
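
The cpp kernel name that ExternKernel now derives from op_overload (and that the mul.Scalar special case above relies on) follows the rule in the relocated get_aten_cpp_kernel_name: default overloads drop the ".default" suffix, named overloads replace "." with "_". A minimal sketch exercising that rule; the overloads listed are only illustrative examples:

    import torch

    def _cpp_kernel_name_demo():
        aten = torch.ops.aten
        for op in (aten.relu.default, aten.mul.Scalar, aten.repeat_interleave.self_int):
            # Same mapping as get_aten_cpp_kernel_name above.
            opname = (
                op.__name__.split(".")[0]
                if op._overloadname == "default"
                else op.__name__.replace(".", "_")
            )
            print(f"{op} -> at::_ops::{opname}::call")
            # aten.relu.default               -> at::_ops::relu::call
            # aten.mul.Scalar                 -> at::_ops::mul_Scalar::call
            # aten.repeat_interleave.self_int -> at::_ops::repeat_interleave_self_int::call
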
@@ -5179,7 +5233,6 @@ def __init__( ), ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported" self.op_overload = kernel - self.unflatten_args = unflatten_args self.kwargs = {} if kwargs is None else kwargs V.graph.warn_fallback(self.python_kernel_name) @@ -5341,41 +5394,6 @@ def is_not_write(arg): self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr] self.cpp_op_schema = get_cpp_op_schema(kernel) - self.init_args_default_value(kernel._schema) - - def init_args_default_value(self, schema): - self.args_default_value = [ - { - "name": x.name, - "type": x.real_type, - "value": x.default_value, - } - for x in schema.arguments - if not x.kwarg_only - ] - - def get_pos_arg_value(self, pos, kwargs): - # positional args may be provided in kwargs - pos_arg_name = self.args_default_value[pos]["name"] - if pos_arg_name in kwargs: - log.debug( - "Found argument %s with value %s from kwargs", - pos_arg_name, - kwargs[pos_arg_name], - ) - return kwargs[pos_arg_name] - - assert hasattr( - self, "args_default_value" - ), "self.args_default_value has to be provided" - assert pos < len( - self.args_default_value - ), f"expected the index {pos} to be smaller than len(self.args_default_value): {len(self.args_default_value)}" - arg_default_value = self.args_default_value[pos]["value"] - log.debug( - "Use default value %s for argument %s", arg_default_value, pos_arg_name - ) - return arg_default_value def codegen_args(self): @dataclasses.dataclass @@ -5388,6 +5406,7 @@ def __repr__(self): tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] args, kwargs = self.unflatten_args(tensor_args, self.constant_args) if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): + args = self.fill_non_provided_args(args, kwargs) args = [ V.graph.wrapper_code.val_to_cpp_arg_str(param.real_type, x) for param, x in zip(self.op_overload._schema.arguments, args) @@ -5395,17 +5414,6 @@ def __repr__(self): else: args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args] - # Previously, we want to maintain forward-compatibility by skipping - # default args in the serialized artifacts in fbcode. However, - # some of our shim interfaces require default values being set. - # Discussed with Sherlock offline and we decided to allow serializing - # default args into the C++ wrapper code for now. We will refine this - # part if we see real FC requirement. More details related to FC - # can be found at: - # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing - if V.graph.cpp_wrapper and hasattr(self, "args_default_value"): - self.fill_non_provided_args(args, kwargs, convert_val_to_str=True) - # let self.codegen_kwargs handle kwargs self.kwargs.update(kwargs) return args @@ -5441,30 +5449,6 @@ def get_mutation_names(self): assert len(self.mutation_names) <= 1 return self.mutation_names - def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False): - assert isinstance(args, (list, tuple)) - if isinstance(args, tuple): - args = list(args) - assert hasattr(self, "args_default_value") - n_args = len(args) - n_pos_args = len(self.args_default_value) - # For cpp wrapper, if some positional args are not provided, we need to check - # if they're in the kwargs or use their default value - if n_args < n_pos_args: - log.debug( - "%s has %d unprovided positional arguments. 
" - "Will check if they are in the keyword arguments or will use default values.", - self.op_overload, - n_pos_args - n_args, - ) - pos_args = [ - self.get_pos_arg_value(i, kwargs) for i in range(n_args, n_pos_args) - ] - if convert_val_to_str: - pos_args = [V.graph.wrapper_code.val_to_arg_str(x) for x in pos_args] - args.extend(pos_args) - return args - # ProxyExecutor Design Note # We export the ExternFallbackNodes (for custom ops) into a serialized file # and run it with a host side proxy executor to address the ABI problem @@ -5539,15 +5523,6 @@ def codegen(self, wrapper): if kernel.namespace == "aten": # type: ignore[union-attr] # Aten Fallback Ops assert isinstance(kernel, torch._ops.OpOverload) - - if ( - kernel == aten.mul.Tensor - and len(self.inputs) == 1 - and len(self.constant_args) == 1 - ): - # When aten.mul.Tensor's second arg is constant, cpp wrapper expects to call mul_Scalar - kernel = aten.mul.Scalar - if V.graph.cpp_wrapper: if ( config.is_fbcode() @@ -5562,10 +5537,6 @@ def codegen(self, wrapper): ) self.use_runtime_dispatch = True self.set_cpp_kernel(kernel) - else: - self.cpp_kernel_name = get_aten_cpp_kernel_name(kernel) - schema = kernel._schema # type: ignore[union-attr] - self.init_args_default_value(schema) else: self.python_kernel_name = str(kernel) elif kernel.namespace == "_quantized": # type: ignore[union-attr] @@ -6280,7 +6251,7 @@ def codegen(self, wrapper): ) @classmethod - def create(cls, x, packed_w, orig_w, B, batch_size): + def create(cls, x, packed_w, orig_w, batch_size): x = cls.require_stride1(cls.realize_input(x)) orig_w = cls.require_stride1(cls.realize_input(orig_w)) *m, _ = x.get_size() @@ -6288,11 +6259,7 @@ def create(cls, x, packed_w, orig_w, B, batch_size): output_size = list(m) + [oc] output_stride = make_contiguous_strides_for(output_size) inputs = [x, packed_w, orig_w] - constant_args = [batch_size] - if B is not None: - inputs += [B] - else: - constant_args.insert(0, None) + constant_args = [None, batch_size] return MKLPackedLinear( layout=FixedLayout( @@ -6339,7 +6306,7 @@ def codegen(self, wrapper): ) @classmethod - def create(cls, x, w, B, attr, scalars, algorithm): + def create(cls, x, w, b, attr, scalars, algorithm): x = cls.require_contiguous(cls.realize_input(x)) w = cls.require_contiguous(cls.realize_input(w)) @@ -6347,9 +6314,9 @@ def create(cls, x, w, B, attr, scalars, algorithm): oc, ic = w.get_size() inputs = [x, w] constant_args = [attr, scalars if scalars else [-1], algorithm] - if B is not None: - B = cls.require_contiguous(cls.realize_input(B)) - inputs.append(B) + if b is not None: + b = cls.require_contiguous(cls.realize_input(b)) + inputs.append(b) else: constant_args.insert(0, None) @@ -7191,6 +7158,232 @@ def create( ) +class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + x_scale_zp_are_tensors=False, + ): + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp, x2] + - const_args is: [x_scale, x_zp, o_inv_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp, x2] + - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name=( + "torch.ops.onednn.qlinear_pointwise.binary_tensor" + 
if x_scale_zp_are_tensors + else "torch.ops.onednn.qlinear_pointwise.binary" + ), + cpp_kernel_name="onednn::qlinear_pointwise", + ) + self.cpp_kernel_overload_name = ( + "binary_tensor" if x_scale_zp_are_tensors else "binary" + ) + self.cpp_kernel_key = "qlinear_pointwise_binary" + x_scale_type_str, x_zp_type_str = ( + ("at::Tensor", "at::Tensor") + if x_scale_zp_are_tensors + else ("double", "int64_t") + ) + self.cpp_op_schema = f""" + at::Tensor( + at::Tensor act, + {x_scale_type_str} act_scale, + {x_zp_type_str} act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + double inv_output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::optional other, + double other_scale, + int64_t other_zero_point, + c10::string_view binary_post_op, + double binary_alpha, + c10::string_view unary_post_op, + torch::List> unary_post_op_args, + c10::string_view unary_post_op_algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + w_scale, w_zp, other = args[-3], args[-2], args[-1] + if self.x_scale_zp_are_tensors: + assert len(args) >= 5 + x_scale, x_zp = args[-5], args[-4] + ( + o_inv_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-10:] + else: + assert len(const_args) >= 8 + ( + x_scale, + x_zp, + o_inv_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-12:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_inv_scale, + o_zp, + output_dtype, + other, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + x_scale: float, + x_zp: int, + weight: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zp: "TensorBox", + bias: "TensorBox", + o_inv_scale: float, + output_zero_point: int, + output_dtype, + other: "TensorBox", + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_linear_fusion_create( + cls, + x, + weight, + bias, + ) + + if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox): + x_scale.realize() + x_zp.realize() + inputs = inputs + [x_scale, x_zp] + x_scale_zp_are_tensors = True + else: + assert isinstance(x_scale, float) and isinstance(x_zp, int) + constant_args = constant_args + [x_scale, x_zp] + x_scale_zp_are_tensors = False + w_scale.realize() + w_zp.realize() + inputs = inputs + [w_scale, w_zp] + if binary_attr == "sum": + other = cls.require_stride_order(other, req_stride_order) + inputs.append(other) + constant_args = constant_args + [ + o_inv_scale, + output_zero_point, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + 
may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + + if binary_attr == "sum": + packed = QLinearPointwiseBinaryPT2E( + layout=NoneLayout(other.get_device()), + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + mark_node_as_mutating(packed, other) + # Return other since it has been inplace changed. + return packed.inputs[-1] + + if output_dtype is not None: + assert output_dtype in [torch.float32, torch.bfloat16] + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. + kernel_layout.dtype = output_dtype + + return QLinearPointwiseBinaryPT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + + @dataclasses.dataclass class MutableBox(IRNode): """ diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index a780d3709cb0c..32dff9d46668c 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1,17 +1,39 @@ """ Triton Implementation of the flex_attention Kernel""" + import logging -from typing import Any, List +import math +from enum import auto, Enum +from typing import Any, List, Tuple import torch +from torch._prims_common import make_contiguous_strides_for from .. import config -from ..lowering import empty_strided, lowerings, register_lowering +from ..ir import ( + ComputedBuffer, + FixedLayout, + FlexibleLayout, + InputBuffer, + IRNode, + StorageBox, + Subgraph, + TensorBox, +) +from ..lowering import empty_strided, full, lowerings, register_lowering from ..select_algorithm import autotune_select_algorithm, TritonTemplate log = logging.getLogger(__name__) aten = torch.ops.aten -def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta): +class SubgraphType(Enum): + """The type of subgraph for which we want to generate an output buffer.""" + + FWD = auto() # Forward pass + JOINT_FWD = auto() # The recompute step fo the of the bwds kernel + JOINT_BWD = auto() # The bwd pass of the joint + + +def flex_attention_grid(batch_size, num_heads, num_queries, d_model, meta): """How is this kernel parallelized? We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1) Each block is responsible for iterating over blocks of keys and values calculating @@ -22,9 +44,117 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta): return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * num_heads, 1) -sdpa_template = TritonTemplate( - name="sdpa", - grid=sdpa_grid, +def create_placeholder( + name: str, dtype: torch.dtype, device: torch.device +) -> TensorBox: + """Creates a placeholder input buffers for producing subgraph_output.""" + input_buffer = InputBuffer(name, FixedLayout(device, dtype, [1], [1])) + return TensorBox.create(input_buffer) + + +def index_to_other_buffers(cnt: int, graph_type: SubgraphType) -> int: + """This function needs to be aware of the signatures for flex_attention_forward + and flex_attention_backward. 
If new args are added, or the signature changes + be sure to update the indexing math + + Args: + cnt (int): The current index of the placeholder node + is_joint_graph (bool): Whether or not this subgraph represents the joint graph + """ + # Current fwd_args = [query, key, value, score_mod, *other_buffers] + # For fwd_graphs we have 5 dummy values this when the first lifted args + # is seen cnt = 5 and the start of the index_buffers is at args[4] + # thus we subtract 1 from the current cnt + if graph_type == SubgraphType.FWD: + return cnt - 1 + + # Current bwd_args = [q, k, v, out, lse, grad_out, fw_graph, joint_graph, *other_buffers] + # We have 5 dummy values but the start of other_buffers is at index 8 + if graph_type == SubgraphType.JOINT_FWD: + return cnt + 3 + + # Same bwd args but now with 6 dummy values while other_buffers still start at 8 + if graph_type == SubgraphType.JOINT_BWD: + return cnt + 2 + + +def build_subgraph_buffer( + args: Tuple[IRNode], + placeholder_inps: List[TensorBox], + subgraph: Subgraph, + graph_type: SubgraphType, +) -> ComputedBuffer: + """This function's goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that were passed into the flex_attention kernel + placeholder_inps: The list of scalar inputs, these were created on the fly through `create_placeholder` + subgraph: The Subgraph ir for which to produce the output node + graph_type: The type of subgraph for which we want to produce the output node, see enum above for details + """ + cnt = 0 + env = {} + for node in subgraph.graph_module.graph.nodes: + # There are two classes of placeholder inpts that we need + # to handle differently. For the first n_scalar_inps inputs + # we expect that these placeholders were generated by the make_fx call + # in the flex Attention HOP. So we need to create a new placeholder + # TensorBox for each of these inputs. For the rest of the inputs we + # expect that these are lifted inputs that fill up the '*other_buffers' + # tuple and already have corresponding TensorBoxes passed in as args. 
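        # Worked example of the index_to_other_buffers offsets, with numbers
        # taken from the signatures documented above:
        #   FWD:       fwd_args = [q, k, v, score_mod, *other_buffers]; with 5
        #              scalar placeholders the first lifted input is seen at
        #              cnt == 5 and must map to args[5 - 1] == args[4].
        #   JOINT_FWD: bwd_args = [q, k, v, out, lse, grad_out, fw_graph,
        #              joint_graph, *other_buffers]; still 5 placeholders, so
        #              cnt == 5 maps to args[5 + 3] == args[8].
        #   JOINT_BWD: same args but 6 placeholders (including "out"), so
        #              cnt == 6 maps to args[6 + 2] == args[8].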
+ if node.op == "placeholder": + is_lifted_input = cnt >= len(placeholder_inps) + lifted_input_index = index_to_other_buffers(cnt, graph_type) + env[node] = ( + args[lifted_input_index] if is_lifted_input else placeholder_inps[cnt] + ) + cnt += 1 + elif node.op == "call_function": + # For call_function we use the default lowerings and pass in the + # already created TensorBoxes as args + from torch.utils._pytree import tree_map + + env[node] = lowerings[node.target]( + *tree_map(lambda x: env[x] if x in env else x, node.args) + ) + elif node.op == "output": + # For the output node we need to create a ComputedBuffer + # which represents the actual score modification + # The joint_graph's output should be of the form[grad_score, None, None, None, None] + # This is because only the 'score' requires grad and the other outputs are + # the non-differentiable index scalars + if graph_type == SubgraphType.FWD or graph_type == SubgraphType.JOINT_FWD: + output_node = node.args[0] + else: + output_node = node.args[0][0] + output_buffer = env[output_node] + assert isinstance(output_buffer, TensorBox), ( + "The output node for flex attention's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for the flex attention subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + # Create the ComputedBuffer directly that will be inlined into the modification block + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=output_buffer.data.get_device(), + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: ignore[arg-type] + ) + return subgraph_buffer + + raise ValueError("TemplatedAttention was passed a subgraph with no output node!") + + +flex_attention_template = TritonTemplate( + name="flex_attention", + grid=flex_attention_grid, source=r""" {{def_kernel("Q", "K", "V", "LSE")}} # Sub notation for this kernel: @@ -118,6 +248,7 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta): m = offs_m[:, None] n = start_n + offs_n[None, :] {{ modification( + subgraph_number=0, score="qk", b="off_hz // H", h="off_hz % H", @@ -192,7 +323,7 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta): } -def _get_default_config(query): +def _get_default_config_fwd(query) -> Tuple[int, int, int, int]: dtype = query.get_dtype() head_dim = query.get_size()[-1] default_config = None @@ -218,143 +349,394 @@ def _get_default_config(query): return default_config +def _get_default_config_bwd(query) -> Tuple[int, int, int, int]: + head_dim = query.get_size()[-1] + dtype = query.get_dtype() + + if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100 + if dtype == torch.float32: + return (64, 64, 4, 1) + return (128, 128, 4, 3) + elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0): # A100 + return (32, 32, 4, 1) + else: # modest hardware or extremely large head_dim + return (32, 32, 4, 1) + + # TODO: We probably also need a layout constraint? 
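
The four-element config tuples returned above are (BLOCK_M, BLOCK_N, num_warps, num_stages), matching how the forward and backward lowerings below unpack them. A rough sketch of the launch grids they imply, using made-up example shapes:

    import math

    B, H, M = 4, 8, 2048                      # hypothetical batch, heads, queries
    BLOCK_M, BLOCK_N, num_warps, num_stages = 128, 128, 4, 3

    # flex_attention_grid: (cdiv(num_queries, BLOCK_M), batch_size * num_heads, 1)
    fwd_grid = (math.ceil(M / BLOCK_M), B * H, 1)   # -> (16, 32, 1)

    # flex_attention_backward_grid (defined below) only parallelizes over B * H
    bwd_grid = (B * H, 1, 1)                        # -> (32, 1, 1)
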
@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None) def flex_attention(*args, **kwargs): - from torch._prims_common import make_contiguous_strides_for - from ..ir import ( - ComputedBuffer, - FixedLayout, - FlexibleLayout, - InputBuffer, - StorageBox, - TensorBox, - ) - query, key, value, subgraph, *other_buffers = args + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("score", query.get_dtype()), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + subgraph_buffer = build_subgraph_buffer( + args, placeholder_inps, subgraph, graph_type=SubgraphType.FWD + ) + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + query.get_size(), + make_contiguous_strides_for(query.get_size()), + ) + # see NOTE:[TritonTemplates with multiple outputs] + logsumexp_shape = query.get_size()[:-1] # [B, H, M] + logsumexp = empty_strided( + logsumexp_shape, + None, + dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + choices: List[Any] = [] + configs: List[Tuple[int, int, int, int]] = [] + configs.append(_get_default_config_fwd(query)) + if config.max_autotune: + configs += [ + (128, 64, 4, 3), + (128, 128, 4, 3), + (128, 128, 8, 2), + (64, 128, 4, 3), + (64, 64, 4, 3), + ] - def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer: - return TensorBox.create( - InputBuffer( - name, - FixedLayout( - query.get_device(), - dtype, - [ - 1, - ], - [ - 1, - ], - ), - ) + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. + for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: + flex_attention_template.maybe_append_choice( + choices=choices, + input_nodes=[query, key, value, logsumexp], + layout=layout, + subgraphs=[ + subgraph_buffer, + ], + mutated_inputs=[ + logsumexp, + ], + num_stages=num_stages, + num_warps=num_warps, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=query.get_size()[-1], + # For now, we always assume the "sound" option + SCORE_MOD_IS_LINEAR=False, + ROWS_GUARANTEED_SAFE=False, + OUTPUT_LOGSUMEXP=True, ) + inputs_for_autotuning = [query, key, value, logsumexp] + list(other_buffers) + return ( + autotune_select_algorithm( + "flex_attention", choices, inputs_for_autotuning, layout + ), + logsumexp, + ) - scalar_inps = ["score", "b", "h", "m", "n"] - env = {} - cnt = 0 - placeholder_inps = [ - create_placeholder(name, dtype) + +# ---------------------------- Backward HOP Implementation ---------------------------- + + +def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, meta): + """How is this kernel parallelized? + Currently this is only parallelizing over batch * num_heads, but we can, and want to + parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require + atomic updates to some grad values or to have a two pass kernel design. 
+ """ + return (batch_size * num_heads, 1, 1) + + +flex_attention_backward_template = TritonTemplate( + name="flex_attention_backward", + grid=flex_attention_backward_grid, + source=r""" +{{def_kernel("Q", "K", "V", "OUT", "LSE", "DELTA", "DO", "DQ", "DV")}} + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # OUT: Forward output, LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT* DO, axis=1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values, D: Model dimension + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # (Modifiable) Config options: + # BLOCK_M + # BLOCK_N + # SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + + # Define Q Strides + stride_qz = {{stride("Q", 0)}} + stride_qh = {{stride("Q", 1)}} + stride_qm = {{stride("Q", 2)}} + stride_qk = {{stride("Q", 3)}} + # Define K Strides + stride_kz = {{stride("K", 0)}} + stride_kh = {{stride("K", 1)}} + stride_kn = {{stride("K", 2)}} + stride_kk = {{stride("K", 3)}} + # Define V Strides + stride_vz = {{stride("V", 0)}} + stride_vh = {{stride("V", 1)}} + stride_vn = {{stride("V", 2)}} + stride_vk = {{stride("V", 3)}} + + Z = {{size("Q", 0)}} + H = {{size("Q", 1)}} + N_CTX = {{size("Q", 2)}} + + qk_scale = 1.0 + MATMUL_PRECISION = Q.dtype.element_ty + + off_hz = tl.program_id(0) + off_z = off_hz // H # batch idx + off_h = off_hz % H # head idx + + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + + # Asserting contiguous for now... 
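    # Note the assumption carried by the comment above: DO and DQ are indexed
    # with Q's strides and DV with V's strides, i.e. they are presumed to share
    # those contiguous layouts. The lowering below allocates grad_query and
    # grad_value contiguously; grad_out (DO) is simply assumed contiguous here.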
+ DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_vz + off_h * stride_vh + + # TODO I think that this should be N_CTX/BLOCK_N blocks + for start_n in range(0, NUM_Q_BLOCKS): + # We are not doing the causal optimization yet allowing us to start further down the + # kv column + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_DMODEL) + + # initialize pointers to value-like data + q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) + do_ptrs = DO + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) + + # pointer to row-wise quantities in value-like data + D_ptrs = DELTA + off_hz * N_CTX + l_ptrs = LSE + off_hz * N_CTX + + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + + # Key and Value stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + + for start_m in range(0, NUM_Q_BLOCKS * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + + if SCORE_MOD_IS_LINEAR: + qk_scale *= 1.44269504 + q = (q * qk_scale).to(MATMUL_PRECISION) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, tl.trans(k.to(MATMUL_PRECISION)), acc=qk) + pre_mod_scores = qk + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = offs_m_curr[:, None] + n = offs_n[None, :] + {{ modification( + subgraph_number=0, + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + out="qk" + ) | indent_except_first(3) }} + # TODO: In the case that score_mod is linear, this can be LICMed + if not SCORE_MOD_IS_LINEAR: + qk *= 1.44269504 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(tl.trans(p.to(MATMUL_PRECISION)), do) + + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) # [BLOCKM, 1] + + # compute ds = p * (dp - delta[:, None]) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, tl.trans(v)) + ds = p * dp + + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + score="pre_mod_scores", + b="off_z", + h="off_h", + m="m", + n="n", + out="ds" + ) | indent_except_first(3) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(MATMUL_PRECISION)), q) + # compute dq + dq = tl.load(dq_ptrs) + dq += tl.dot(ds.to(MATMUL_PRECISION), k) + + # Store grad_query + tl.store(dq_ptrs, dq) + + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + + # write-back + index_n = offs_n[:, None] + index_k = offs_k[None, :] + + # Store grad_key and grad_value + dv_ptrs = DV + (index_n * stride_vn + index_k * stride_vk) + tl.store(dv_ptrs, dv) + + # TODO generalize and add proper mask support + mask = (index_n != -1) & (index_k != -1) + {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + + """, +) + + +# TODO: We probably also need a 
layout constraint? +@register_lowering( + torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None +) +def flex_attention_backward(*args, **kwargs): + ( + query, + key, + value, + out, + logsumexp, + grad_out, + fw_graph, + joint_graph, + *other_buffers, + ) = args + + device = query.get_device() + dtype = query.get_dtype() + + fwd_placeholder_inps = [ + create_placeholder(name, dtype, device) for name, dtype in [ - ("score", query.get_dtype()), + ("score", dtype), ("b", torch.int32), ("h", torch.int32), ("m", torch.int32), ("n", torch.int32), ] ] - for node in subgraph.graph_module.graph.nodes: - # There are two classes of placeholder inpts that we need - # to handle differently. For the first n_scalar_inps inputs - # we expect that these placeholders were generated by the make_fx call - # in the flex Attention HOP. So we need to create a new placeholder - # TensorBox for each of these inputs. For the rest of the inputs we - # expect that these are lifted inputs that fill up the '*other_buffers' - # tuple and already have corresponding TensorBoxes passed in as args. - if node.op == "placeholder": - is_lifted_input = cnt >= len(scalar_inps) - env[node] = args[cnt - 1] if is_lifted_input else placeholder_inps[cnt] - cnt += 1 - elif node.op == "call_function": - # For call_function we use the defulat lowerings and pass in the - # already created TensorBoxes as args - from torch.utils._pytree import tree_map + fw_subgraph_buffer = build_subgraph_buffer( + args, fwd_placeholder_inps, fw_graph, graph_type=SubgraphType.JOINT_FWD + ) - env[node] = lowerings[node.target]( - *tree_map(lambda x: env[x] if x in env else x, node.args) - ) - elif node.op == "output": - # For the output node we need to create a ComputedBuffer - # which represents the actual score modification + joint_placeholder_inps = fwd_placeholder_inps + [ + create_placeholder("out", dtype, device) + ] + joint_subgraph_buffer = build_subgraph_buffer( + args, joint_placeholder_inps, joint_graph, graph_type=SubgraphType.JOINT_BWD + ) - output_buffer = env[node.args[0]] - assert isinstance(output_buffer.data, StorageBox), ( - "The output node for the flex attention subgraph must be a StorageBox, but got: ", - type(output_buffer), - ) - # Create the ComputedBuffer directly that will be inlined into the modification block - subgraph_buffer = ComputedBuffer( - name=None, - layout=FlexibleLayout( - device=output_buffer.data.get_device(), - dtype=output_buffer.data.get_dtype(), - size=output_buffer.data.get_size(), - ), - data=output_buffer.data.data, # type: ignore[arg-type] - ) + layout_k = FixedLayout( + key.get_device(), + key.get_dtype(), + key.get_size(), + make_contiguous_strides_for(key.get_size()), + ) - layout = FixedLayout( - output_buffer.get_device(), - query.get_dtype(), - query.get_size(), - make_contiguous_strides_for(query.get_size()), - ) - # see NOTE:[TritonTemplates with multiple outputs] - logsumexp_shape = query.get_size()[:-1] # [B, H, M] - logsumexp = empty_strided( - logsumexp_shape, - None, - dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype - device=output_buffer.get_device(), - ) - choices: List[Any] = [] - configs: List[Any] = [] - configs.append(_get_default_config(query)) - if config.max_autotune: - configs += [ - (128, 64, 4, 3), - (128, 128, 4, 3), - (128, 128, 8, 2), - (64, 128, 4, 3), - (64, 64, 4, 3), - ] - # Note, we don't need to pass in the captured buffers explicitly - # because they're implicitly added by the score_mod function - # We do need 
to explicitly pass it in for autotuning though. - for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: - sdpa_template.maybe_append_choice( - choices=choices, - input_nodes=[query, key, value, logsumexp], - layout=layout, - subgraphs=subgraph_buffer, - mutated_inputs=[ - logsumexp, - ], - num_stages=num_stages, - num_warps=num_warps, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_DMODEL=query.get_size()[-1], - # For now, we always assume the "sound" option - SCORE_MOD_IS_LINEAR=False, - ROWS_GUARANTEED_SAFE=False, - OUTPUT_LOGSUMEXP=True, - ) - inputs_for_autotuning = [query, key, value, logsumexp] + list(other_buffers) - return ( - autotune_select_algorithm( - "sdpa", choices, inputs_for_autotuning, layout - ), + # Create delta which will is needed for the bwd's kernel + mul_delta = lowerings[aten.mul](out, grad_out) + delta = lowerings[aten.sum](mul_delta, axis=-1) + + # see NOTE:[TritonTemplates with multiple outputs] + grad_query = full( + query.get_size(), 0.0, dtype=dtype, device=device + ) # torch.zeros equivalent + grad_query.realize() + grad_value = empty_strided(value.get_size(), None, dtype=dtype, device=device) + + choices: List[Any] = [] + configs: List[Tuple[int, int, int, int]] = [] + configs.append(_get_default_config_bwd(query)) + if config.max_autotune: + configs += [ + (128, 128, 4, 3), + (128, 128, 8, 1), + (64, 64, 4, 3), + (64, 64, 8, 1), + ] + + for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: + flex_attention_backward_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + out, logsumexp, - ) - raise ValueError("TemplatedAttention was passed a subgraph with no output node!") + delta, + grad_out, + grad_query, + grad_value, + ], + layout=layout_k, # We use store_output only for grad_key + subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer], + mutated_inputs=[grad_query, grad_value], + num_stages=num_stages, + num_warps=num_warps, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=query.get_size()[-1], + NUM_Q_BLOCKS=math.ceil(query.get_size()[-2] / BLOCK_M), + # For now, we always assume the "sound" option + SCORE_MOD_IS_LINEAR=False, + ) + inputs_for_autotuning = [ + query, + key, + value, + out, + logsumexp, + delta, + grad_out, + grad_query, + grad_value, + ] + list(other_buffers) + + grad_key = autotune_select_algorithm( + "flex_attention_backward", choices, inputs_for_autotuning, layout_k + ) + return ( + grad_query, + grad_key, + grad_value, + ) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index fa14b4406de69..593da39d2bf63 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional import torch -from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate from torch._inductor.virtualized import V from .. 
import config as inductor_config from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate @@ -18,7 +17,6 @@ ) from ..utils import ( use_aten_gemm_kernels, - use_cpp_packed_gemm_template, use_cutlass_template, use_max_autotune, use_triton_template, @@ -158,13 +156,6 @@ def tuned_mm(mat1, mat2, *, layout=None): if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k): CUTLASSGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) - if use_cpp_packed_gemm_template(layout, mat1, mat2): - CppPackedGemmTemplate.add_choices( - choices, - layout, - [mat1, mat2], - ) - if len(choices) == 0 and not use_aten_gemm_kernels(): log.warning("No choices for GEMM, using ATen backend as fallback") choices.append(aten_mm.bind((mat1, mat2), aten_layout)) @@ -320,15 +311,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): beta=beta, ) - if use_cpp_packed_gemm_template(layout, mat1, mat2): - CppPackedGemmTemplate.add_choices( - choices, - layout, - [inp_expanded, mat1, mat2], - alpha=alpha, - beta=beta, - ) - add_aten_fallback = False if len(choices) == 0: log.warning("No choices for GEMM, using ATen backend as fallback") diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 26d08183b0e55..76511e19a49da 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,4 +1,5 @@ import functools +import itertools import logging from typing import cast, List, Tuple @@ -113,39 +114,50 @@ def filtered_configs( # List of dictionaries to store the kernel configs. Configs that evaluate to true -# will be utilised on the target platform -mm_kernel_configs = [ - # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" - {"config": (16, 32, 16, 3, 2), "cond": True}, - {"config": (16, 32, 32, 4, 2), "cond": True}, - {"config": (16, 32, 32, 5, 2), "cond": True}, - {"config": (32, 32, 16, 1, 2), "cond": True}, - {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None}, - {"config": (32, 64, 32, 5, 8), "cond": True}, - {"config": (64, 32, 32, 5, 8), "cond": True}, - {"config": (64, 32, 128, 5, 4), "cond": True}, - {"config": (64, 64, 16, 2, 4), "cond": True}, - {"config": (64, 64, 32, 2, 4), "cond": True}, - {"config": (64, 64, 64, 3, 8), "cond": True}, - {"config": (64, 64, 128, 3, 4), "cond": True}, - {"config": (64, 64, 128, 5, 4), "cond": True}, - {"config": (64, 128, 32, 3, 4), "cond": True}, - {"config": (64, 128, 32, 4, 8), "cond": True}, - {"config": (64, 128, 64, 4, 4), "cond": True}, - {"config": (64, 128, 128, 4, 4), "cond": True}, - {"config": (128, 64, 32, 2, 2), "cond": True}, - {"config": (128, 64, 32, 3, 4), "cond": True}, - {"config": (128, 64, 32, 4, 8), "cond": True}, - {"config": (128, 64, 64, 3, 8), "cond": True}, - {"config": (128, 64, 128, 4, 8), "cond": True}, - {"config": (128, 128, 32, 2, 8), "cond": True}, - {"config": (128, 128, 32, 3, 4), "cond": True}, - {"config": (128, 128, 32, 4, 4), "cond": True}, - {"config": (128, 128, 64, 3, 4), "cond": True}, - {"config": (128, 128, 64, 3, 8), "cond": True}, - {"config": (128, 128, 64, 5, 4), "cond": True}, - {"config": (128, 128, 64, 5, 8), "cond": True}, -] +# will be utilised on the target platform. 
The configs are as follows: +# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) +mm_kernel_configs = ( + [ + {"config": (16, 32, 16, 3, 2), "cond": True}, + {"config": (16, 32, 32, 4, 2), "cond": True}, + {"config": (16, 32, 32, 5, 2), "cond": True}, + {"config": (32, 32, 16, 1, 2), "cond": True}, + {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None}, + {"config": (32, 64, 32, 5, 8), "cond": True}, + {"config": (64, 32, 32, 5, 8), "cond": True}, + {"config": (64, 32, 128, 5, 4), "cond": True}, + {"config": (64, 64, 16, 2, 4), "cond": True}, + {"config": (64, 64, 32, 2, 4), "cond": True}, + {"config": (64, 64, 64, 3, 8), "cond": True}, + {"config": (64, 64, 128, 3, 4), "cond": True}, + {"config": (64, 64, 128, 5, 4), "cond": True}, + {"config": (64, 128, 32, 3, 4), "cond": True}, + {"config": (64, 128, 32, 4, 8), "cond": True}, + {"config": (64, 128, 64, 4, 4), "cond": True}, + {"config": (64, 128, 128, 4, 4), "cond": True}, + {"config": (128, 64, 32, 2, 2), "cond": True}, + {"config": (128, 64, 32, 3, 4), "cond": True}, + {"config": (128, 64, 32, 4, 8), "cond": True}, + {"config": (128, 64, 64, 3, 8), "cond": True}, + {"config": (128, 64, 128, 4, 8), "cond": True}, + {"config": (128, 128, 32, 2, 8), "cond": True}, + {"config": (128, 128, 32, 3, 4), "cond": True}, + {"config": (128, 128, 32, 4, 4), "cond": True}, + {"config": (128, 128, 64, 3, 4), "cond": True}, + {"config": (128, 128, 64, 3, 8), "cond": True}, + {"config": (128, 128, 64, 5, 4), "cond": True}, + {"config": (128, 128, 64, 5, 8), "cond": True}, + ] + if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE" + else [ + {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True} + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5] + for num_warps in [2, 4, 8] + ] +) int8_mm_kernel_configs = [ {"config": (64, 64, 32, 2, 4), "cond": True}, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 389ff16e39025..07899fe2ccd09 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1788,7 +1788,12 @@ def bernoulli_(x, *args): "cpu" ), "this should be handled in decomps unless config.fallback_random or the device is CPU" x.realize() - ir.InplaceBernoulliFallback(x, *args) + op_overload = ( + aten.bernoulli_.float + if len(args) == 0 or isinstance(args[0], float) + else aten.bernoulli_.Tensor + ) + ir.InplaceBernoulliFallback(op_overload, x, *args) return x @@ -4554,7 +4559,7 @@ def fn(idx): factor = ops.index_expr(hend - hstart, torch.int32) divide_factors.append(factor) divide_factor = functools.reduce(ops.mul, divide_factors) - return ops.div(fn_sum(idx, x_loader), divide_factor) + return ops.truediv(fn_sum(idx, x_loader), divide_factor) rv = Pointwise.create( device=x.get_device(), @@ -5318,7 +5323,7 @@ def log_add_exp_helper(a_tuple, b_tuple): def cummax(x, axis=None): if len(x.get_size()) == 0: assert axis in [0, -1] - return clone(x), torch.empty_like(x, dtype=torch.int64) + return clone(x), empty_like(x, dtype=torch.int64) dtype = x.get_dtype() combine_fn = ir.get_reduction_combine_fn( @@ -5348,7 +5353,7 @@ def cummax(x, axis=None): def cummin(x, axis=None): if len(x.get_size()) == 0: assert axis in [0, -1] - return clone(x), torch.empty_like(x, dtype=torch.int64) + return clone(x), empty_like(x, dtype=torch.int64) dtype = x.get_dtype() combine_fn = ir.get_reduction_combine_fn( diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 
075a9b8b709e5..5a12a5c090bf8 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -1,37 +1,14 @@ -from typing import List, Optional +from typing import List import torch import torch.utils._pytree as pytree -from torch._inductor.kernel.mm_common import mm_args from . import ir -from .codegen.cpp_gemm_template import CppPackedGemmTemplate from .ir import TensorBox -from .lowering import ( - add, - add_needs_realized_inputs, - aten, - permute, - register_lowering, - to_dtype, - view, -) -from .select_algorithm import ( - autotune_select_algorithm, - ChoiceCaller, - ExternKernelChoice, -) -from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune -from .virtualized import V +from .lowering import add, add_needs_realized_inputs, aten, register_lowering, to_dtype def register_onednn_fusion_ops(): if torch._C._has_mkldnn: - aten_mkldnn_linear_unary = ExternKernelChoice( - torch.ops.mkldnn._linear_pointwise, - "mkldnn::_linear_pointwise", - has_out_variant=False, - kernel_creator=ir.LinearUnary.create, - ) cpu_needs_realized_inputs = [ torch.ops.mkldnn._convolution_pointwise, torch.ops.mkldnn._convolution_pointwise_, @@ -139,75 +116,11 @@ def convolution_binary_inplace( @register_lowering(torch.ops.mkldnn._linear_pointwise) def linear_unary( - x: TensorBox, - w: TensorBox, - b: TensorBox, - attr, - scalars, - algorithm, - layout=None, + x: TensorBox, w: TensorBox, b: TensorBox, attr, scalars, algorithm ): - x_size = x.get_size() - if len(x_size) > 2: - # GEMM template needs 2D input, normalize input shape here - x = view(x, [-1, x_size[-1]]) - choices: List[ChoiceCaller] = [] - if len(choices) == 0 or use_aten_gemm_kernels(): - choices.append( - aten_mkldnn_linear_unary.bind( - (x, w), - layout, - B=None, - attr=attr, - scalars=scalars, - algorithm=algorithm, - ) - if b is None - else aten_mkldnn_linear_unary.bind( - (x, w, b), - layout, - attr=attr, - scalars=scalars, - algorithm=algorithm, - ) - ) - if use_max_autotune(): - transposed_w = permute(w, [1, 0]) - *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout) - # TODO(jgong5): support epilogue fusion - if ( - use_cpp_packed_gemm_template(layout, x, transposed_w) - and attr == "none" - ): - if b is None: - CppPackedGemmTemplate.add_choices( - choices, - layout, - [x, w], - trans_w=True, - ) - else: - CppPackedGemmTemplate.add_choices( - choices, - layout, - [x, w, b], - trans_w=True, - input_indices=[2, 0, 1], - ) - assert w.get_name() in V.graph.constants - input_gen_fns = { - 1: lambda x: V.graph.constants[x.get_name()], - } - result = autotune_select_algorithm( - "linear_unary", - choices, - [x, w] if b is None else [x, w, b], - layout, - input_gen_fns=input_gen_fns, + return TensorBox.create( + ir.LinearUnary.create(x, w, b, attr, scalars, algorithm) ) - if len(x_size) > 2: - result = view(result, (*x_size[:-1], result.get_size()[-1])) - return result @register_lowering(torch.ops.mkldnn._linear_pointwise.binary) def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr): @@ -425,13 +338,71 @@ def qlinear_unary( ) ) - if torch._C.has_mkl: - aten_mkl_linear = ExternKernelChoice( - torch.ops.mkl._mkl_linear, - "mkl::_mkl_linear", - has_out_variant=False, - kernel_creator=ir.MKLPackedLinear.create, + @register_lowering( + torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None + ) + @register_lowering( + torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None + ) + def qlinear_binary( + x: TensorBox, + 
x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + o_inv_scale, + o_zero_point, + output_dtype, + x2: TensorBox, + x2_scale, + x2_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ): + if binary_attr == "sum": + if output_dtype in [ + torch.float32, + torch.bfloat16, + ] and x2.get_dtype() in [torch.float32, torch.bfloat16]: + if x2.get_dtype() != output_dtype: + # For int8-mixed-bf16 quantization and inplace add, + # there is case when accum dtype is float32 but output dtype is bfloat16. + # Since the accum will be inplaced changed with post op sum, + # we will do accum dtype convertion here. + x2 = to_dtype(x2, output_dtype) + else: + assert ( + x2.get_dtype() == output_dtype + ), "dtype of accum for qlinear post op sum should be the same as output" + return TensorBox.create( + ir.QLinearPointwiseBinaryPT2E.create( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_inv_scale, + o_zero_point, + output_dtype, + x2, + x2_scale, + x2_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ) ) + + if torch._C.has_mkl: cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) @register_lowering(torch.ops.mkl._mkl_linear) @@ -439,47 +410,11 @@ def mkl_packed_linear( x: TensorBox, packed_w: TensorBox, orig_w: TensorBox, - b: Optional[TensorBox], + b: TensorBox, batch_size, - *, - layout=None, ): - choices: List[ChoiceCaller] = [] - if use_max_autotune(): - transposed_w = permute(orig_w, [1, 0]) - *_, layout, x, transposed_w = mm_args( - x, transposed_w, layout=layout - ) - if use_cpp_packed_gemm_template(layout, x, transposed_w): - CppPackedGemmTemplate.add_choices( - choices, - layout, - [x, packed_w, orig_w], - trans_w=True, - input_indices=[0, 2], - ) - - if len(choices) == 0 or use_aten_gemm_kernels(): - choices.append( - aten_mkl_linear.bind( - (x, packed_w, orig_w), layout, B=None, batch_size=batch_size - ) - ) - - assert packed_w.get_name() in V.graph.constants - assert orig_w.get_name() in V.graph.constants - # packed_w is a mkldnn tensor which we can't generate directly - # so we use the weights from the original tensor in autotune. - input_gen_fns = { - 1: lambda x: V.graph.constants[x.get_name()], - 2: lambda x: V.graph.constants[x.get_name()], - } - result: TensorBox = autotune_select_algorithm( - "packed_linear", - choices, - [x, packed_w, orig_w], - layout, - input_gen_fns=input_gen_fns, + result = TensorBox.create( + ir.MKLPackedLinear.create(x, packed_w, orig_w, batch_size) ) if b is not None: result = add(result, b) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 88f9d406c2e18..71395c71c9b6a 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -17,7 +17,6 @@ import torch import torch.utils._pytree as pytree -from torch.fx.graph import inplace_methods, magic_methods from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str T = TypeVar("T") @@ -146,6 +145,12 @@ def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> """ ... + def identity(self, x: T) -> T: + """ + Returns x as is. This is used to trigger CSE. + """ + ... + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # These operations are only available in a "kernel" context. Check # torch._inductor.codegen.common.CSEProxy for their typical implementation @@ -408,9 +413,6 @@ def to_int(self, x0: T) -> T: def trunc(self, x0: T) -> T: ... 
- def truncdiv(self, x0: T, x1: T) -> T: - ... - def ceil(self, x0: T) -> T: ... @@ -447,28 +449,195 @@ def sub(self, x0: T, x1: T) -> T: def mul(self, x0: T, x1: T) -> T: ... - def floordiv(self, x0: T, x1: T) -> T: + def pow(self, x0: T, x1: T) -> T: ... - def truediv(self, x0: T, x1: T) -> T: + def and_(self, x0: T, x1: T) -> T: ... - def div(self, x0: T, x1: T) -> T: + def or_(self, x0: T, x1: T) -> T: ... - def mod(self, x0: T, x1: T) -> T: + def xor(self, x0: T, x1: T) -> T: ... - def pow(self, x0: T, x1: T) -> T: + # These are metaprogrammed by MockHandler._init_cls + def lshift(self, x0: T, x1: T) -> T: ... - def and_(self, x0: T, x1: T) -> T: + def rshift(self, x0: T, x1: T) -> T: ... - def or_(self, x0: T, x1: T) -> T: + def getitem(self, x0: T, x1: T) -> T: + # TODO: this is probably just illegal lol ... - def xor(self, x0: T, x1: T) -> T: + def matmul(self, x0: T, x1: T) -> T: + # TODO: this is probably just illegal lol + ... + + def invert(self, x0: T) -> T: + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These are "special" operators. These only exist if the target + # language actually supports the operator. Keep this in sync with + # pointwise_overrides_data. + + def airy_ai(self, x: T) -> T: + ... + + def bessel_j0(self, x: T) -> T: + ... + + def bessel_j1(self, x: T) -> T: + ... + + def bessel_y0(self, x: T) -> T: + ... + + def bessel_y1(self, x: T) -> T: + ... + + def digamma(self, x: T) -> T: + ... + + def erfcx(self, x: T) -> T: + ... + + def fma(self, x: T, y: T, z: T) -> T: + ... + + def igamma(self, x: T, y: T) -> T: + ... + + def igammac(self, x: T, y: T) -> T: + ... + + def gammainc(self, x: T, y: T) -> T: + ... + + def gammaincc(self, x: T, y: T) -> T: + ... + + def i0(self, x: T) -> T: + ... + + def i0e(self, x: T) -> T: + ... + + def i1(self, x: T) -> T: + ... + + def i1e(self, x: T) -> T: + ... + + def log_ndtr(self, x: T) -> T: + ... + + def modified_bessel_i0(self, x: T) -> T: + ... + + def modified_bessel_i1(self, x: T) -> T: + ... + + def modified_bessel_k0(self, x: T) -> T: + ... + + def modified_bessel_k1(self, x: T) -> T: + ... + + def ndtr(self, x: T) -> T: + ... + + def ndtri(self, x: T) -> T: + ... + + def polygamma(self, x: T, y: T) -> T: + ... + + def scaled_modified_bessel_k0(self, x: T) -> T: + ... + + def scaled_modified_bessel_k1(self, x: T) -> T: + ... + + def spherical_bessel_j0(self, x: T) -> T: + ... + + def zeta(self, x: T, y: T) -> T: + ... + + def chebyshev_polynomial_t(self, x: T, y: T) -> T: + ... + + def chebyshev_polynomial_u(self, x: T, y: T) -> T: + ... + + def chebyshev_polynomial_v(self, x: T, y: T) -> T: + ... + + def chebyshev_polynomial_w(self, x: T, y: T) -> T: + ... + + def legendre_polynomial_p(self, x: T, y: T) -> T: + ... + + def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T: + ... + + def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T: + ... + + def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T: + ... + + def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T: + ... + + def hermite_polynomial_h(self, x: T, y: T) -> T: + ... + + def hermite_polynomial_he(self, x: T, y: T) -> T: + ... + + def laguerre_polynomial_l(self, x: T, y: T) -> T: + ... 
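
Most of the binary operators above are metaprogrammed onto MockHandler from plain format strings (see the dictionary in _init_cls below), and the division-like operators deliberately carry different rounding semantics, which the truncdiv/floordiv/mod/remainder docstrings that follow spell out. A small pure-Python illustration of the C-style versus Python-style split (illustrative helpers only, not the handler code):

import math

def truncdiv(a, b):
    # C-style integer division: round the true quotient toward zero
    return math.trunc(a / b)

def floordiv(a, b):
    # Python-style integer division: floor the true quotient (same as a // b)
    return math.floor(a / b)

def c_mod(a, b):
    # C-style modulus: the sign follows the dividend (LHS)
    return a - truncdiv(a, b) * b

def py_remainder(a, b):
    # Python-style modulus: the sign follows the divisor (RHS), same as a % b
    return a - floordiv(a, b) * b

print(truncdiv(-7, 2), floordiv(-7, 2))   # -3 -4
print(c_mod(-7, 2), py_remainder(-7, 2))  # -1 1
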
+ + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These operators are a bit special, because they are conventionally + # natively supported in both Python and C, but the semantics differ so + # care must be taken + + def truncdiv(self, x0: T, x1: T) -> T: + """C-style trunc division between integers only. Computes the true + division of two numbers and rounds the result to zero. + """ + ... + + def floordiv(self, x0: T, x1: T) -> T: + """Python-style floor division between integers only. Computes the + true division of two numbers and floors the result. + """ + ... + + def truediv(self, x0: T, x1: T) -> T: + """True division between floats. Integer inputs are NOT valid: to do + Python style (int, int) -> float division, promote the inputs to float + first.""" + ... + + def div(self, x0: T, x1: T) -> T: + """TODO: to be removed. This renders as / no matter what the backend is + which is incoherent.""" + ... + + def mod(self, x0: T, x1: T) -> T: + """C-style modulus, take sign from LHS (x0).""" + ... + + def remainder(self, x0: T, x1: T) -> T: + """Python-style modulus, take sign from RHS (x1).""" ... # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -578,9 +747,27 @@ def inner(*args): return inner - for name, format_string in itertools.chain( - magic_methods.items(), inplace_methods.items() - ): + for name, format_string in { + "add": "{} + {}", + "sub": "{} - {}", + "mul": "{} * {}", + "floordiv": "{} // {}", + "truediv": "{} / {}", + "mod": "{} % {}", # careful, depending on target semantics varies + "pow": "{} ** {}", + "lshift": "{} << {}", + "rshift": "{} >> {}", + "and_": "{} & {}", + "or_": "{} | {}", + "xor": "{} ^ {}", + "eq": "{} == {}", + "ne": "{} != {}", + "lt": "{} < {}", + "gt": "{} > {}", + "le": "{} <= {}", + "ge": "{} >= {}", + "neg": "-{}", + }.items(): setattr(cls, name, make_handler(format_string)) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index f98333875f5c8..456e0c50567d5 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -336,7 +336,7 @@ def decide_inplace_update(self): isinstance(self, (SchedulerNode,)) and config.inplace_buffers and ( - not isinstance(V.kernel, torch._inductor.codegen.triton.TritonKernel) + not isinstance(V.kernel, torch._inductor.codegen.simd.SIMDKernel) or getattr(V.kernel, "mutations", None) is not None ) ): @@ -390,7 +390,7 @@ def decide_inplace_update(self): ) # mutations not tracked in cpp kernels if isinstance( - V.kernel, torch._inductor.codegen.triton.TritonKernel + V.kernel, torch._inductor.codegen.simd.SIMDKernel ): V.kernel.mutations.add(input_node.get_name()) V.kernel.mutations.add(self.get_name()) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 39d8334fe7d50..5cb10e1820cf9 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -6,6 +6,7 @@ import math import operator +import os import sys import textwrap import time @@ -103,7 +104,7 @@ def __init__( prefix_args=0, suffix_args=0, epilogue_fn=identity, - subgraphs=None, + subgraphs: Optional[List[ir.ComputedBuffer]] = None, *, index_dtype, ): @@ -114,7 +115,7 @@ def __init__( ) self.input_nodes = input_nodes self.output_node = output_node - self.named_input_nodes = {} + self.named_input_nodes = {} # type: ignore[var-annotated] self.defines = defines self.kernel_name = kernel_name self.template_mask = None @@ -128,10 +129,10 @@ def __init__( self.prefix_args = prefix_args self.suffix_args = 
suffix_args self.epilogue_fn = epilogue_fn - self.render_hooks = dict() + self.render_hooks = dict() # type: ignore[var-annotated] self.triton_meta: Optional[Dict[str, object]] = None - # For Templated Attention - self.subgraphs = subgraphs + # For Templated Attention this can be a list of ir.Subgraph + self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs def need_numel_args(self): return False @@ -271,19 +272,28 @@ def stride(self, name, index): val = self.named_input_nodes[name].get_stride()[index] return texpr(self.rename_indexing(val)) - def modification(self, **fixed_inputs) -> str: - """This function generates the code body to populate - a 'modification' placeholder within a template + def modification(self, subgraph_number: int, **fixed_inputs) -> str: + """This creates a modification function for a subgraph. + To use this inside a template, the first argument should specify which subgraph to codegen for - TODO come up with standardized way to modify templates, with - potential multiple modifications + Args: + subgraph_number (int): The index of the subgraph in self.subgraphs """ + assert isinstance(subgraph_number, int) + assert isinstance(self.subgraphs, list) + assert subgraph_number < len( + self.subgraphs + ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" + + subgraph = self.subgraphs[subgraph_number] def add_input(name): return self.args.input(name) + name = f"PlaceholderSubstitution_{subgraph_number}" + class PlaceholderSubstitution(V.WrapperHandler): # type: ignore[name-defined] - self.name = "PlaceholderSubstitution" + self.name = name def load(self, name: str, index: sympy.Expr): if name not in fixed_inputs: @@ -297,15 +307,14 @@ def load(self, name: str, index: sympy.Expr): def indirect_indexing(self, index_var, size, check): return sympy_index_symbol(str(index_var)) - # if self.modification_cache is None: with V.set_ops_handler(PlaceholderSubstitution(V.ops)): assert isinstance( - self.subgraphs, ir.ComputedBuffer - ), "Expected the subgraph to be a ComputedBuffer" - if isinstance(self.subgraphs.data, ir.InputBuffer): - out = self.subgraphs.data.make_loader()((1,)) + subgraph, ir.ComputedBuffer + ), f"Expected the subgraph to be a ComputedBuffer, got {type(subgraph)}" + if isinstance(subgraph.data, ir.InputBuffer): + out = subgraph.data.make_loader()((1,)) else: - out = self.subgraphs.data.inner_fn((1,)) + out = subgraph.data.inner_fn((1,)) self.codegen_body() self.body.writeline(f"{fixed_inputs['out']} = {out.value}") @@ -320,11 +329,18 @@ def store_output( indices: Union[List[Any], Tuple[Any]], val: str, mask: Optional[str] = None, + indent_width: int = 4, ): - """ - Hook called from template code to store the final output - (if the buffer hasn't been optimized away), then append any - epilogue fusions. + """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away. + + Args: + indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of + these indices and output strides must match `val`. + val (str): The value to store. + mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask + will be applied to the store. + indent_width (int): The number of spaces to use for indentation. This is used when the call to + store_output is indented in the kernel definition. 
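
The new indent_width argument lets the rendered epilogue line up with wherever the store_output placeholder sits inside the Triton template. A simplified, standalone sketch of the idea using textwrap (not the actual render hook):

import textwrap

def render_hook(body: str, indent_width: int = 4) -> str:
    # Re-indent the generated body to match the indentation of the template line
    # that contained the placeholder; strip() drops the leading indent of the
    # first line because the placeholder position is already indented.
    return textwrap.indent(body, " " * indent_width).strip()

body = "acc = acc + tmp0\ntl.store(out_ptr0 + xindex, acc, mask)"
print(render_hook(body, indent_width=8))
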
""" assert isinstance(indices, (list, tuple)) assert isinstance(val, str) @@ -348,7 +364,7 @@ def store_output( self.range_trees[0].lookup(sympy.Integer(1), sympy_product(lengths)).set_name( "xindex" ) - self.template_mask = mask + self.template_mask = mask # type: ignore[assignment] self.template_indices = indices output_index = self.output_node.get_layout().make_indexer()(index_symbols) output_index = self.rename_indexing(output_index) @@ -373,7 +389,7 @@ def store_output( def hook(): # more stuff might have been added since the codegen_body above self.codegen_body() - return textwrap.indent(self.body.getvalue(), " ").strip() + return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() assert "" not in self.render_hooks self.render_hooks[""] = hook @@ -438,11 +454,8 @@ def indexing( block_ptr=block_ptr, ) - def initialize_range_tree(self, pid_cache): - super().initialize_range_tree(pid_cache) - # ignore default codegen - self.body.clear() - self.indexing_code.clear() + def codegen_range_tree(self): + pass # ignore default codegen def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): wrapper = V.graph.wrapper_code @@ -696,19 +709,17 @@ def __init__( has_out_variant=True, op_overload=None, use_fallback_kernel=False, - kernel_creator=None, ): super().__init__() name = name or kernel.__name__ assert callable(kernel) - assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}" + assert not hasattr(extern_kernels, name), "duplicate extern kernel" self.name = name self.cpp_kernel_name = cpp_kernel self.has_out_variant = has_out_variant setattr(extern_kernels, name, kernel) self.op_overload = op_overload self.use_fallback_kernel = use_fallback_kernel - self.kernel_creator = kernel_creator def to_callable(self): return getattr(extern_kernels, self.name) @@ -875,8 +886,6 @@ def output_node(self): inner = ir.FallbackKernel.create( self.choice.op_overload, *self.input_nodes, **self.kwargs ) - elif self.choice.kernel_creator is not None: - inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs) else: cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc inner = cls( @@ -899,86 +908,6 @@ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType } -class DataProcessorChoiceCallerWrapper: - def __init__(self, wrapped, preprocessor, postprocessor): - self._wrapped = wrapped - if preprocessor is not None: - self._preprocessor = preprocessor - else: - self._preprocessor = lambda x, y: (x, y) - if postprocessor is not None: - self._postprocessor = postprocessor - else: - self._postprocessor = lambda x: x - - def __getattr__(self, name): - return getattr(self._wrapped, name) - - def benchmark(self, *args, out) -> float: - new_args, new_out = self._preprocessor(args, out) - result = self._wrapped.benchmark(*new_args, out=new_out) - new_out = self._postprocessor(new_out) - if out is not new_out: - out.copy_(new_out) - return result - - def output_node(self) -> ir.TensorBox: - result = self._wrapped.output_node() - return self._postprocessor(result) - - def __repr__(self) -> str: - return f"DataProcessorChoiceCallerWrapper({self._wrapped})" - - -class DataProcessorTemplateWrapper: - """ - A wrapper class for a kernel template. - - This class together with `DataProcessorChoiceCallerWrapper` provides a convenient way to - preprocess and postprocess data before and after using the wrapped template. 
A typical - usage is to reorder or filter the input nodes in order to match the expected input of other - kernel choices like a ATen kernel. A more complicated usage is to prepack the weights. - See the example from :mod:`cpp_gemm_template` for more details. - """ - - def __init__( - self, - wrapped_template_cls, - preprocessor, - postprocessor, - **kwargs, - ): - if preprocessor is not None: - self._preprocessor = preprocessor - else: - self._preprocessor = lambda x, y: (x, y) - if postprocessor is not None: - self._postprocessor = postprocessor - else: - self._postprocessor = lambda x: x - assert "input_nodes" in kwargs - assert "layout" in kwargs - kwargs["input_nodes"], kwargs["layout"] = preprocessor( - kwargs["input_nodes"], kwargs["layout"] - ) - self._wrapped = wrapped_template_cls(**kwargs) - - def __getattr__(self, name): - return getattr(self._wrapped, name) - - def maybe_append_choice(self, choices, **kwargs): - return type(self._wrapped).maybe_append_choice(self, choices, **kwargs) - - def generate(self, **kwargs): - choice_caller = self._wrapped.generate(**kwargs) - return DataProcessorChoiceCallerWrapper( - choice_caller, self._preprocessor, self._postprocessor - ) - - def __repr__(self) -> str: - return f"DataProcessorTemplateWrapper({self._wrapped})" - - class ErrorFromChoice(RuntimeError): def __init__(self, msg, choice: ChoiceCaller, inputs_str): msg += f"\nFrom choice {choice}\n{inputs_str}" @@ -990,6 +919,13 @@ class NoValidChoicesError(RuntimeError): pass +@functools.lru_cache(None) +def get_env_num_workers() -> Optional[int]: + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + return None + + class AlgorithmSelectorCache(PersistentCache): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -1019,8 +955,7 @@ def __call__( # Templates selected with input_gen_fns require specific input data to avoid IMA # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection - # TODO(jgong5): support multi-template on CPU - if input_gen_fns is not None or layout.device.type == "cpu": + if input_gen_fns is not None: return_multi_template = False # TODO - assert that we have not mutating kernels here @@ -1055,11 +990,10 @@ def no_op(*args, **kwargs): or precompilation_timeout_seconds <= 0 ): return no_op - num_workers = min( - config.compile_threads, - torch.get_num_threads(), - len(choices), - ) + + env_workers = get_env_num_workers() + num_workers = env_workers if env_workers is not None else (len(choices)) + if num_workers <= 0: return no_op @@ -1151,7 +1085,7 @@ def wait_on_futures(): else: raise e except ImportError: - raise e + raise e from None executor.shutdown(wait=True) @@ -1258,9 +1192,7 @@ def get_inputs(): } example_inputs = list(unique_example_inputs.values()) example_inputs_extern = [ - unique_example_inputs[input_node.get_name()] - if unique_example_inputs[input_node.get_name()].is_mkldnn - else torch.as_strided( + torch.as_strided( unique_example_inputs[input_node.get_name()], V.graph.sizevars.size_hints( input_node.get_size(), @@ -1349,7 +1281,7 @@ def benchmark_in_current_process(choices): ) timing = float("inf") except AssertionError as e: - raise AssertionError( # noqa: TRY200 + raise AssertionError( # noqa: B904 f"Incorrect result from choice {choice}\n\n{e}" ) except Exception as e: @@ -1362,7 +1294,7 @@ def benchmark_in_current_process(choices): else: raise e except ImportError: - raise e + raise e from None timings[choice] = timing @@ -1421,7 
+1353,7 @@ def log_results( result = timings[choice] if result: sys.stderr.write( - f" {choice.name} {result:.4f} ms {best_time/result:.1%}\n" + f" {choice.name} {result:.4f} ms {best_time / result:.1%}\n" ) else: sys.stderr.write( diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 65a9cb8379078..b6288b34fafa1 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -310,6 +310,14 @@ def statically_known_leq(self, left: Expr, right: Expr) -> bool: expr = left <= right return self.is_expr_static_and_true(expr) + # See Note - [On Statically Known] + def statically_known_geq(self, left: Expr, right: Expr) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right. + """ + expr = left >= right + return self.is_expr_static_and_true(expr) + # See Note - [On Statically Known] def statically_known_lt(self, left: Expr, right: Expr) -> bool: """ @@ -318,6 +326,14 @@ def statically_known_lt(self, left: Expr, right: Expr) -> bool: expr = left < right return self.is_expr_static_and_true(expr) + # See Note - [On Statically Known] + def statically_known_gt(self, left: Expr, right: Expr) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is greater than right. + """ + expr = left > right + return self.is_expr_static_and_true(expr) + # See Note - [On Statically Known] def statically_known_multiple_of(self, numerator: Expr, denominator: Expr) -> bool: """ diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index bcf586862a716..59baad51885e0 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -339,7 +339,7 @@ def print_performance( ): timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)]) took = torch.median(timings) / times - print(f"{took/baseline:.6f}") + print(f"{took / baseline:.6f}") return took @@ -726,6 +726,8 @@ def fresh_inductor_cache(cache_entries=None): except Exception: log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir) raise + finally: + clear_inductor_caches() def argsort(seq) -> List[int]: @@ -985,48 +987,6 @@ def use_cutlass_template(layout, m, n, k): return res -def _use_template_for_cpu(layout): - return use_max_autotune() and layout.device.type == "cpu" - - -def use_cpp_packed_gemm_template(layout, mat1, mat2): - from . 
import ir - from .codegen.cpp_micro_gemm import create_micro_gemm - from .kernel.mm_common import mm_args - - if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"): - return False - - if not config.cpp.weight_prepack: - return False - - layout_dtypes = [torch.float32, torch.bfloat16, torch.half] - m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2) - # TODO(jgong5): support dynamic shapes for n or k - if has_free_symbols((n, k)): - return False - if isinstance(mat2, ir.BaseView): - mat2 = mat2.unwrap_view() - micro_gemm = create_micro_gemm( - "micro_gemm", - m, - n, - k, - input_dtype=layout.dtype, - output_dtype=torch.float, - num_threads=parallel_num_threads(), - ) - # TODO(jgong5): support n % n_block_size != 0 - return ( - layout.dtype in layout_dtypes - and micro_gemm is not None - and n % micro_gemm.register_blocking[1] == 0 - and mat1.get_stride()[-1] == 1 # TODO(jgong5): support transposed input - and isinstance(mat2, ir.StorageBox) - and mat2.is_module_buffer() - ) - - def use_aten_gemm_kernels(): return not use_max_autotune() or _use_autotune_backend("ATEN") @@ -1496,7 +1456,7 @@ def dump_node_schedule(node_schedule): An API that can be used in pdb to dump a node_schedule. Right mainly dump the read/write dependencies but can add more as needed. """ - from torch._inductor.codegen.triton import DisableReduction, EnableReduction + from torch._inductor.codegen.simd import DisableReduction, EnableReduction from torch._inductor.scheduler import SchedulerNode print(f"Node schedule with {len(node_schedule)} nodes") @@ -1618,16 +1578,23 @@ def aoti_compile_with_persistent_cache( """ Compile the given function with persistent cache for AOTI eager mode. """ - flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) - assert all( - isinstance(input, torch.Tensor) for input in flattened_inputs - ), "Only support tensor for now" assert not dynamic, "Only support static shape for now" + type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool} + supported_scalar_types = tuple(type_to_torch_dtype.keys()) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + if not all( + isinstance(input, (supported_scalar_types, torch.Tensor)) + for input in flattened_inputs + ): + raise NotImplementedError("Only support tensor, int, float, bool for now") persistent_cache = aoti_eager_cache_dir(ns, device_type) - persistent_cache.mkdir(parents=True, exist_ok=True) + if not persistent_cache.exists(): + persistent_cache.mkdir(parents=True) + persistent_cache_lib = persistent_cache / "lib" - persistent_cache_lib.mkdir(parents=True, exist_ok=True) + if not persistent_cache_lib.exists(): + persistent_cache_lib.mkdir() with mock.patch.dict( os.environ, @@ -1649,18 +1616,30 @@ def aoti_compile_with_persistent_cache( ) kernel_metadata_items = [] - for input_tensor in flattened_inputs: + for input in flattened_inputs: # TODO(Eikan): To add dynamic support metadata: Dict[str, Any] = {} metadata["is_dynamic"] = dynamic - metadata["device_type"] = f"{input_tensor.device.type}" - if is_cpu_device([input_tensor]): - metadata["device_index"] = -1 + + if isinstance(input, torch.Tensor): + metadata["device_type"] = f"{input.device.type}" + if is_cpu_device([input]): + metadata["device_index"] = -1 + else: + metadata["device_index"] = input.device.index + metadata["dtype"] = f"{input.dtype}" + metadata["sizes"] = list(input.size()) + metadata["strides"] = list(input.stride()) else: - metadata["device_index"] = input_tensor.device.index - metadata["dtype"] = 
f"{input_tensor.dtype}" - metadata["sizes"] = list(input_tensor.size()) - metadata["strides"] = list(input_tensor.stride()) + assert isinstance(input, supported_scalar_types) + # Scalar tensor + metadata["device_type"] = device_type + metadata["device_index"] = -1 if device_type == "cpu" else 0 + metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" + metadata["sizes"] = [] + metadata["strides"] = [] + metadata["scalar_value"] = input + kernel_metadata_items.append(metadata) kernel_meta_info: Dict[str, Any] = {} @@ -1696,3 +1675,26 @@ def aoti_compile_with_persistent_cache( return kernel_lib_path except Exception as e: return "" + + +def run_and_get_cpp_code(fn, *args, **kwargs): + # We use the patch context manager instead of using it as a decorator. + # In this way, we can ensure that the attribute is patched and unpatched correctly + # even if this run_and_get_cpp_code function is called multiple times. + with unittest.mock.patch.object(config, "debug", True): + torch._dynamo.reset() + import io + import logging + + log_capture_string = io.StringIO() + ch = logging.StreamHandler(log_capture_string) + from torch._inductor.graph import output_code_log + + output_code_log.addHandler(ch) + prev_level = output_code_log.level + output_code_log.setLevel(logging.DEBUG) + result = fn(*args, **kwargs) + s = log_capture_string.getvalue() + output_code_log.setLevel(prev_level) + output_code_log.removeHandler(ch) + return result, s diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py index 32a9aa8c87110..d77989cd829b9 100644 --- a/torch/_library/fake_class_registry.py +++ b/torch/_library/fake_class_registry.py @@ -134,8 +134,10 @@ def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] returns an instance of the fake class. All tensors in the fake object should also be properly fakified with to_fake_tensor() in from_real. + Examples: # For a custom class Foo defined in test_custom_class_registration.cpp: + TORCH_LIBRARY(_TorchScriptTesting, m) { m.class_("_TensorQueue") .def(torch::init()) @@ -144,6 +146,7 @@ def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] .def("top", &TensorQueue::top) .def("size", &TensorQueue::size) .def("clone_queue", &TensorQueue::clone_queue) + .def("__obj_flatten__", &TensorQueue::__obj_flatten__) .def_pickle( // __getstate__ [](const c10::intrusive_ptr& self) @@ -166,8 +169,7 @@ def __init__(self, queue): @classmethod def __obj_unflatten__(cls, flattened_ctx): - ctx = {flattened_ctx[0]: flattened_ctx[1]} - return cls(**ctx) + return cls(**dict(ctx)) def push(self, x): self.queue.append(x) @@ -178,6 +180,11 @@ def pop(self): def size(self): return len(self.queue) + In this example, the original TensorQeue need to addd a __obj_flatten__ method + to the class TensorQueue and the flattend result is passed into FakeTensorQueue's + __obj_unflatten__ as inputs to create a fake class. This protocol allows pytorch to look + at the contents of the script object and properly handle them in the subsystems + like dynamo, aot_aotugrad or more. 
""" def inner(fake_class: HasStaticMethodFromReal): diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index 0530c12df3047..10463b864f440 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -26,12 +26,13 @@ register_log("torch", "torch") register_log("distributed", DISTRIBUTED) register_log( - "dist_c10d", ["torch.distributed.distributed_c10d", "torch.distributed.rendezvous"] + "c10d", ["torch.distributed.distributed_c10d", "torch.distributed.rendezvous"] ) register_log( - "dist_ddp", ["torch.nn.parallel.distributed", "torch._dynamo.backends.distributed"] + "ddp", ["torch.nn.parallel.distributed", "torch._dynamo.backends.distributed"] ) -register_log("dist_fsdp", ["torch.distributed.fsdp"]) +register_log("pp", ["torch.distributed.pipelining"]) +register_log("fsdp", ["torch.distributed.fsdp"]) register_log("onnx", "torch.onnx") register_log("export", ["torch._dynamo", "torch.export", *DYNAMIC]) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index a4edd839c651e..93e45bfb1d845 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -19,6 +19,7 @@ corresponding_real_dtype, elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND, + FloatLike, IntLike, make_contiguous_strides_for, Number, @@ -3286,6 +3287,15 @@ def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs): return +@register_meta([aten._foreach_pow_.Scalar]) +def meta__foreach_pow__scalar(self, exponent): + torch._check( + isinstance(exponent, FloatLike), + lambda: f"exponent must be a float but got {type(exponent)}", + ) + return + + @register_meta([aten._foreach_pow.ScalarAndTensor]) def meta__foreach_pow_scalar_and_tensor(self, exponent): # Only foreach_pow has a ScalarAndTensor method and needs special diff --git a/torch/_numpy/_util.py b/torch/_numpy/_util.py index ff219d930731c..477d3d44671ad 100644 --- a/torch/_numpy/_util.py +++ b/torch/_numpy/_util.py @@ -178,7 +178,7 @@ def _try_convert_to_tensor(obj): tensor = torch.as_tensor(obj) except Exception as e: mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}." 
- raise NotImplementedError(mesg) # noqa: TRY200 + raise NotImplementedError(mesg) # noqa: B904 return tensor diff --git a/torch/_numpy/linalg.py b/torch/_numpy/linalg.py index 2232419db1b2e..093851142dbca 100644 --- a/torch/_numpy/linalg.py +++ b/torch/_numpy/linalg.py @@ -38,7 +38,7 @@ def wrapped(*args, **kwds): try: return func(*args, **kwds) except torch._C._LinAlgError as e: - raise LinAlgError(*e.args) # noqa: TRY200 + raise LinAlgError(*e.args) # noqa: B904 return wrapped diff --git a/torch/_numpy/testing/utils.py b/torch/_numpy/testing/utils.py index cd3d3407f582e..f757860e12183 100644 --- a/torch/_numpy/testing/utils.py +++ b/torch/_numpy/testing/utils.py @@ -247,7 +247,7 @@ def assert_equal(actual, desired, err_msg="", verbose=True): assert_equal(actualr, desiredr) assert_equal(actuali, desiredi) except AssertionError: - raise AssertionError(msg) # noqa: TRY200 + raise AssertionError(msg) # noqa: B904 # isscalar test to check cases such as [np.nan] != np.nan if isscalar(desired) != isscalar(actual): @@ -279,7 +279,7 @@ def assert_equal(actual, desired, err_msg="", verbose=True): except (DeprecationWarning, FutureWarning) as e: # this handles the case when the two types are not even comparable if "elementwise == comparison" in e.args[0]: - raise AssertionError(msg) # noqa: TRY200 + raise AssertionError(msg) # noqa: B904 else: raise @@ -426,7 +426,7 @@ def _build_err_msg(): assert_almost_equal(actualr, desiredr, decimal=decimal) assert_almost_equal(actuali, desiredi, decimal=decimal) except AssertionError: - raise AssertionError(_build_err_msg()) # noqa: TRY200 + raise AssertionError(_build_err_msg()) # noqa: B904 if isinstance(actual, (ndarray, tuple, list)) or isinstance( desired, (ndarray, tuple, list) @@ -726,7 +726,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval="nan"): names=("x", "y"), precision=precision, ) - raise ValueError(msg) # noqa: TRY200 + raise ValueError(msg) # noqa: B904 def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False): @@ -2272,7 +2272,7 @@ def check_free_memory(free_bytes): try: mem_free = _parse_size(env_value) except ValueError as exc: - raise ValueError( # noqa: TRY200 + raise ValueError( # noqa: B904 f"Invalid environment variable {env_var}: {exc}" ) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index ac6a60d0078c7..68675c7517360 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -3830,7 +3830,7 @@ def _check_stack_inputs(tensors: TensorSequenceType) -> None: entry_shape = tensors[0].shape for i in range(1, len(tensors)): assert tensors[i].shape == entry_shape, ( - f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0" + f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 " f"and {tensors[i].shape} at entry {i}" ) @@ -6298,7 +6298,7 @@ def _compute_sizes(seq, scalar_type): try: handle = seq[0] except Exception: - raise ValueError( # noqa: TRY200 + raise ValueError( # noqa: B904 f"could not determine the shape of object type '{type(seq).__name__}'" ) seq = handle @@ -6358,12 +6358,24 @@ def _infer_scalar_type(obj): # Analogous to recursive_store # xref: recursive_store in torch/csrc/utils/tensor_new.cpp -def _recursive_build(scalarType: torch.dtype, obj: TensorOrNumberLikeType): - if isinstance(obj, Tensor) and obj.ndim <= 1: +def _recursive_build( + scalarType: torch.dtype, obj: Union[TensorOrNumberLikeType, TensorSequenceType] +): + if isinstance(obj, Tensor) and obj.numel() == 1: return obj.detach().to(dtype=scalarType, device="cpu", 
copy=True).view(()) + elif isinstance(obj, Tensor): + # It is invalid to call ".tensor([...])" with a non-scalar tensor in eager mode + # >>> torch.tensor([torch.randn(2)]) + # ValueError: only one element tensors can be converted to Python scalars + # + # But it is possible with a NumPy array + # >>> torch.tensor([np.random.uniform(size=(2,))]).shape + # torch.Size([1, 2]) + return obj.detach().to(dtype=scalarType, device="cpu", copy=True) elif isinstance(obj, Number): return torch.scalar_tensor(obj, dtype=scalarType) + # seq can be a list of tensors seq = obj return torch.stack([_recursive_build(scalarType, item) for item in seq]) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index ba8e899dc9437..6d22f9dcf9845 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2195,13 +2195,13 @@ def merge_dicts(*dicts): add_docstr( torch.can_cast, r""" -can_cast(from, to) -> bool +can_cast(from_, to) -> bool Determines if a type conversion is allowed under PyTorch casting rules described in the type promotion :ref:`documentation `. Args: - from (dtype): The original :class:`torch.dtype`. + from\_ (dtype): The original :class:`torch.dtype`. to (dtype): The target :class:`torch.dtype`. Example:: diff --git a/torch/_utils.py b/torch/_utils.py index 2e48fe9a1a9de..1bb726252dee4 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -52,71 +52,40 @@ def _type(self, dtype=None, non_blocking=False, **kwargs): return dtype(self.size()).copy_(self, non_blocking) -def _hpu(self, device=None, non_blocking=False, **kwargs): - """Returns a copy of this object in HPU memory. +def _to(self, device, non_blocking=False): + """Returns a copy of this object in device memory. - If this object is already in HPU memory and on the correct device, then - no copy is performed and the original object is returned. + If this object is already on the correct device, then no copy is performed + and the original object is returned. Args: - device (int): The destination HPU id. Defaults to the current device. + device (int): The destination device. non_blocking (bool): If ``True`` and the source is in pinned memory, the copy will be asynchronous with respect to the host. Otherwise, the argument has no effect. - **kwargs: For compatibility, may contain the key ``async`` in place of - the ``non_blocking`` argument. """ - non_blocking = _get_async_or_non_blocking("hpu", non_blocking, kwargs) - hpu = getattr(torch, "hpu", None) - assert hpu is not None, "HPU device module is not loaded" - if self.is_hpu: - if device is None: - device = hpu.current_device() - if self.get_device() == device: - return self - else: - if device is None: - device = -1 - with hpu.device(device): - assert not self.is_sparse, "sparse storage is not supported for HPU tensors" - untyped_storage = torch.UntypedStorage(self.size(), device=torch.device("hpu")) - untyped_storage.copy_(self, non_blocking) - return untyped_storage - - -def _cuda(self, device=None, non_blocking=False, **kwargs): - """Returns a copy of this object in CUDA memory. - - If this object is already in CUDA memory and on the correct device, then - no copy is performed and the original object is returned. + if self.device == device: + return self - Args: - device (int): The destination GPU id. Defaults to the current device. - non_blocking (bool): If ``True`` and the source is in pinned memory, - the copy will be asynchronous with respect to the host. Otherwise, - the argument has no effect. 
- **kwargs: For compatibility, may contain the key ``async`` in place of - the ``non_blocking`` argument. - """ - non_blocking = _get_async_or_non_blocking("cuda", non_blocking, kwargs) - if self.is_cuda: - if device is None: - device = torch.cuda.current_device() - if self.get_device() == device: - return self - else: - if device is None: - device = -1 - with torch.cuda.device(device): - if self.is_sparse: - new_type = getattr(torch.cuda.sparse, self.__class__.__name__) - indices = torch.Tensor._indices(self).cuda(device, non_blocking) - values = torch.Tensor._values(self).cuda(device, non_blocking) + device_module = getattr(torch, device.type, None) + assert ( + device_module is not None + ), f"{device.type.upper()} device module is not loaded" + with device_module.device(device): + if self.is_sparse and hasattr(device_module, "sparse"): + new_type = getattr(device_module.sparse, self.__class__.__name__) + indices = getattr(torch.Tensor._indices(self), device.type)( + device, non_blocking + ) + values = getattr(torch.Tensor._values(self), device.type)( + device, non_blocking + ) return new_type(indices, values, self.size()) else: - untyped_storage = torch.UntypedStorage( - self.size(), device=torch.device("cuda") - ) + assert ( + not self.is_sparse + ), f"sparse storage is not supported for {device.type.upper()} tensors" + untyped_storage = torch.UntypedStorage(self.size(), device=device) untyped_storage.copy_(self, non_blocking) return untyped_storage diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 44dd8223862ae..6c9f3b61ae8ba 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -9,6 +9,10 @@ # - `torch.nn.Parameter` # - `collections.Counter` # - `collections.OrderedDict` +# Additionally, users can use an allowlist for adding classes they have deemed as safe using +# `_add_safe_globals()` (`torch.serialization.add_safe_globals`) +# `_clear_safe_globals()` (`torch.serialization.clear_safe_globals`) +# `_get_safe_globals()` (`torch.serialization.get_safe_globals`) # Based of https://github.com/python/cpython/blob/main/Lib/pickle.py # Expected to be useful for loading PyTorch model weights @@ -19,6 +23,7 @@ import functools as _functools from collections import Counter, OrderedDict +from inspect import getattr_static from pickle import ( APPEND, APPENDS, @@ -59,11 +64,57 @@ UnpicklingError, ) from struct import unpack -from sys import maxsize -from typing import Any, Dict, List +from sys import maxsize, modules +from typing import Any, Dict, List, Type import torch +_marked_safe_globals_list: List[Any] = [] + + +def _add_safe_globals(safe_globals: List[Any]): + global _marked_safe_globals_list + _marked_safe_globals_list += safe_globals + + +def _get_safe_globals() -> List[Any]: + global _marked_safe_globals_list + return _marked_safe_globals_list + + +def _clear_safe_globals(): + global _marked_safe_globals_list + _marked_safe_globals_list = [] + + +# Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals +# For example if user had a script like +# torch.load(file_a) +# torch.serialization._add_safe_globals([torch.foo]) +# torch.load(file_b) +# the dynamic additions to safe_globals would not be picked up by +# _get_allowed_globals due to the lru_cache +def _get_user_allowed_globals(): + rc: Dict[str, Any] = {} + for f in _marked_safe_globals_list: + rc[f"{f.__module__}.{f.__name__}"] = f + return rc + + +def _tensor_rebuild_functions(): + return { + torch._utils._rebuild_parameter, + 
torch._utils._rebuild_parameter_with_state, + torch._utils._rebuild_qtensor, + torch._utils._rebuild_tensor, + torch._utils._rebuild_tensor_v2, + torch._utils._rebuild_tensor_v3, + torch._utils._rebuild_sparse_tensor, + torch._utils._rebuild_meta_tensor_no_storage, + torch._utils._rebuild_nested_tensor, + torch._utils._rebuild_wrapper_subclass, + } + # Unpickling machinery @_functools.lru_cache(maxsize=1) @@ -75,6 +126,7 @@ def _get_allowed_globals(): "torch.serialization._get_layout": torch.serialization._get_layout, "torch.Size": torch.Size, "torch.Tensor": torch.Tensor, + "torch.device": torch.device, } # dtype for t in torch.storage._dtype_to_storage_type_map().keys(): @@ -103,17 +155,7 @@ def _get_allowed_globals(): ]: rc[str(qt)] = qt # Rebuild functions - for f in [ - torch._utils._rebuild_parameter, - torch._utils._rebuild_parameter_with_state, - torch._utils._rebuild_qtensor, - torch._utils._rebuild_tensor, - torch._utils._rebuild_tensor_v2, - torch._utils._rebuild_tensor_v3, - torch._utils._rebuild_sparse_tensor, - torch._utils._rebuild_meta_tensor_no_storage, - torch._utils._rebuild_nested_tensor, - ]: + for f in _tensor_rebuild_functions(): rc[f"torch._utils.{f.__name__}"] = f # Handles Tensor Subclasses, Tensor's with attributes. @@ -128,6 +170,11 @@ def __init__(self, file, *, encoding: str = "bytes"): self.readline = file.readline self.read = file.read self.memo: Dict[int, Any] = {} + # tensor subclass types found from GLOBAL instructions that have passed the criteria + # to be allowed as the second argument to `torch._tensor._rebuild_from_type_v2` + # This enables rebuilding of tensor subclasses defined outside the `torch` package. + # See [Note: Criteria for allowing out-of-core tensor subclasses] for details on the criteria. + self.tensor_subclasses_found: Dict[str, Type] = {} def load(self): """Read a pickled object representation from the open file. @@ -151,8 +198,124 @@ def load(self): full_path = f"{module}.{name}" if full_path in _get_allowed_globals(): self.append(_get_allowed_globals()[full_path]) + elif full_path in _get_user_allowed_globals(): + self.append(_get_user_allowed_globals()[full_path]) else: - raise RuntimeError(f"Unsupported class {full_path}") + # The logic in this branch handles user-defined tensor subclasses. + # We can automatically allow and raise and error for anything that is not provably safe. + # [Note: Criteria for allowing out-of-core tensor subclasses] + # GLOBAL '.' instructions will get the class and + # push the string (not the actual type) while adding the type to the dictionary keyed + # by the string onto the unpickler's stack if they satisfy the following conditions: + # (1) The that defines them is in `sys.modules` + # (we will use getattr_static to access it to ensure no code execution) + # (2) They inherit from `torch.Tensor` + # (2) The class is not overriding any of the `torch.Tensor` methods listed here: + # `__getattr__`, `__get__`, `__getattribute__`, `__setstate__`, `__set__`, + # and `tp_alloc` + # The methods that we ban overriding were selected in a test-driven manner + # by overriding every callable method on a tensor subclass and determinining + # which might get called during unpickling. + # When executing REDUCE, the string will be appropriately converted back to the type only + # for `torch._tensor._rebuild_from_type_v2` as other use of the class could use methods + # we didn't audit. + if module == "__builtin__": + raise RuntimeError( + f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. 
" + "Please use `torch.serialization.add_safe_globals` to allowlist this global " + "if you trust this class/function." + ) + elif module not in modules: + # TODO: add a link here to a doc that explains to users what we mean by trust + raise RuntimeError( + f"Found GLOBAL `{full_path}` instruction in the pickle file but `{full_path}` was " + f"not in the pre-defined list of allowed globals that are considered safe by the " + "weights_only unpickler for rebuilding state_dicts. This is the expected behavior if " + f"`{full_path}` is a class or function that is not in the list of allowed globals " + f"If `{full_path}` is NOT a tensor subclass, you might consider" + "`torch.serialization.add_safe_globals` if it is appropriate. However, if it is a " + "user-defined tensor subclass not defined in the `torch` package, this error might arise " + f"as we expect `{module}` to be present in `sys.modules` (i.e. it " + "must be imported in the current environment), but this was not the case. " + f"If you intend to unpickle a tensor subclass `{full_path}` please import `{name}` from " + f"`{module}`. Note that having this imported will *only* allow the type `{full_path}` to " + "be passed as the second argument to `torch._tensor._rebuild_from_type_v2`, which should " + "enable the tensor subclass to be unpickled without any arbitrary code execution as long " + # If the user imports and these are overridden the next error will prompt them to use + # torch.serialization.add_safe_globals. + "a sa pre-defined list of methods called when unpickling are not overridden. In " + "particular, the methods are `__getattr__`, `__get__`, `__getattribute__`, `__setstate__`, " + "`__set__`, as well as the implementation of `tp_alloc`." + ) + else: + try: + class_type = getattr_static(modules[module], name) + except AttributeError as e: + raise AttributeError( + "For safety during weights_only loading, we use inspect.getattr_state to " + f"get {name} from {module}, if {module} implements the descriptor protocol, " + "__getattr__ or __getattribute__ these will not be called." 
+ ) from e + # None of the objects here contain any data from the pickle so this is safe + if isinstance(class_type, type) and issubclass( + class_type, torch.Tensor + ): + # getattr is called by the getattr call in `_rebuild_from_type_v2` + custom_get_attribute = ( + class_type.__getattribute__ + is not torch.Tensor.__getattribute__ + ) + custom_get = ( + getattr_static(class_type, "__get__", None) is not None + ) + custom_get_attr = ( + getattr_static(class_type, "__getattr__", None) + is not None + ) + # Tensor.__setstate__ might be called in `_rebuild_from_type_v2` + custom_set_state = ( + class_type.__setstate__ is not torch.Tensor.__setstate__ + ) + # setattr is called in `torch._utils._set_obj_state` + custom_set_attr = ( + class_type.__setattr__ is not object.__setattr__ + ) + custom_set = ( + getattr_static(class_type, "__set__", None) is not None + ) + # tp_alloc is called by `Tensor._rebuild_wrapper_subclass` and `Tensor.as_subclass` + has_custom_tp_alloc = ( + not torch._C._check_tp_alloc_is_default(class_type) + ) + custom_methods = { + "__getattribute__": custom_get_attribute, + "__getattr__": custom_get_attr, + "__get__": custom_get, + "__setattr__": custom_set_attr, + "__set__": custom_set, + "__setstate__": custom_set_state, + "tp_alloc": has_custom_tp_alloc, + } + if any(custom_methods.values()): + error = "" + for k, v in custom_methods.items(): + error += f" {k}={v}" + raise RuntimeError( + f"Trying to unpickle tensor subclass `{full_path}` that has defined a custom " + f"version for one of these methods:{error}. Please check whether you trust these " + "methods and allowlist the subclass with `torch.serialization.add_safe_globals` if so." + ) + # push the string full_path onto the stack (in REBUILD, there is special logic to + # access this from tensor_subclasses_found for rebuild_from_type_v2) + self.tensor_subclasses_found[full_path] = class_type + self.append(full_path) + else: + raise RuntimeError( + f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " + "Please use `torch.serialization.add_safe_globals` to allowlist this global " + "if you trust this class/function." + ) + elif key[0] == NEWOBJ[0]: args = self.stack.pop() cls = self.stack.pop() @@ -162,10 +325,33 @@ def load(self): elif key[0] == REDUCE[0]: args = self.stack.pop() func = self.stack[-1] - if func not in _get_allowed_globals().values(): + if ( + func not in _get_allowed_globals().values() + and func not in _get_user_allowed_globals().values() + ): raise RuntimeError( f"Trying to call reduce for unrecognized function {func}" ) + # Special handling for tensor subclass type found in GLOBAL that is pushed + # onto stack as str to prevent it from being used anywhere except the + # second arg of _rebuild_from_type_v2 and within argument tuple for _rebuild_wrapper_subclass + # _rebuild_from_type_v2 is called with args (func, type, func_args, state) + # where both type and, when func is rebuild_wrapper_subclass, func_args[0] could be the subclass type + # Since we pushed these subclass types onto the stack as strings, convert them to the actual + # type here. 
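
The criteria above boil down to a small predicate over the class object; the sketch below restates them in plain Python (it omits the tp_alloc check, which needs the private torch._C helper, and is not the unpickler's actual code):

import sys
from inspect import getattr_static

import torch

def looks_like_safe_tensor_subclass(module: str, name: str) -> bool:
    # The defining module must already be imported, the class must subclass
    # torch.Tensor, and none of the attribute-access / setstate hooks that run
    # during rebuilding may be overridden.
    if module not in sys.modules:
        return False
    cls = getattr_static(sys.modules[module], name)
    if not (isinstance(cls, type) and issubclass(cls, torch.Tensor)):
        return False
    overridden = [
        cls.__getattribute__ is not torch.Tensor.__getattribute__,
        getattr_static(cls, "__getattr__", None) is not None,
        getattr_static(cls, "__get__", None) is not None,
        cls.__setstate__ is not torch.Tensor.__setstate__,
        cls.__setattr__ is not object.__setattr__,
        getattr_static(cls, "__set__", None) is not None,
    ]
    return not any(overridden)

class MyTensor(torch.Tensor):
    pass

print(looks_like_safe_tensor_subclass(__name__, "MyTensor"))  # a plain subclass passes
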
+ if func is torch._tensor._rebuild_from_type_v2 and type(args[1]) is str: + args_after = args[2:] + if ( + args[0] is torch._utils._rebuild_wrapper_subclass + and type(args[2][0]) is str + ): + new_arg_tuple = ( + self.tensor_subclasses_found[args[2][0]], + ) + args[2][1:] + args_after = (new_arg_tuple,) + args[3:] + args = ( + args[:1] + (self.tensor_subclasses_found[args[1]],) + args_after + ) self.stack[-1] = func(*args) elif key[0] == BUILD[0]: state = self.stack.pop() diff --git a/torch/ao/quantization/experimental/adaround_fake_quantize.py b/torch/ao/quantization/experimental/adaround_fake_quantize.py new file mode 100644 index 0000000000000..4d988bbb25bb2 --- /dev/null +++ b/torch/ao/quantization/experimental/adaround_fake_quantize.py @@ -0,0 +1,148 @@ +from typing import Tuple + +import torch +from torch.ao.quantization.fake_quantize import _is_symmetric_quant +from torch.ao.quantization.utils import is_per_tensor +from torch.quantization import FakeQuantize +from torch.quantization.observer import MinMaxObserver + + +class AdaroundFakeQuantizer(FakeQuantize): + """ + This is a FakeQuantizer that enables an adaptive rounding fake quantizer. + Adaround is a technique to adaptively round weights, derived from the paper https://arxiv.org/pdf/2004.10568.pdf + For HTP compatibility, we are targeting to use symmetric quantization + """ + + scale: torch.Tensor + zero_point: torch.Tensor + V: torch.nn.Parameter + + # pyre-fixme[3]: Return type must be annotated. + def __init__( + self, + observer=MinMaxObserver, + qscheme=torch.per_tensor_symmetric, # not used, but needed for fakequant + quant_min: int = -128, + quant_max: int = 127, + ch_axis: int = 0, + # pyre-fixme[2]: Parameter must be annotated. + **observer_kwargs, + ): + super().__init__( + observer=observer, + qscheme=qscheme, + quant_min=quant_min, + quant_max=quant_max, + is_dynamic=False, + **observer_kwargs, + ) + # Populate quant_min/quant_max to observer_kwargs if valid + if quant_min is not None and quant_max is not None: + assert ( + quant_min <= quant_max + ), "quant_min must be less than or equal to quant_max" + # pyre-fixme[4]: Attribute must be annotated. 
+ self.qscheme = qscheme + self.is_per_tensor: bool = is_per_tensor(qscheme) + self.is_symmetric: bool = _is_symmetric_quant(qscheme) + assert self.is_symmetric, "Only symmetric quantization is supported" + self.ch_axis: int = ch_axis + + self.scale = torch.tensor([], requires_grad=False) + self.zero_point = torch.tensor([], requires_grad=False) + self.V = torch.nn.Parameter(torch.tensor([]), requires_grad=True) + # Fixed Stretch parameters + self.zeta: torch.Tensor = torch.tensor(1.1, requires_grad=False) + self.gamma: torch.Tensor = torch.tensor(-0.1, requires_grad=False) + self.sigmoid = torch.nn.Sigmoid() + self.use_soft_rounding = True + + @torch.jit.export + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + return self.scale, self.zero_point + + @torch.jit.export + def extra_repr(self) -> str: + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, " + f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, " + f"scale={self.scale}, zero_point={self.zero_point}, (self.V >= 0).int().sum()={(self.V >= 0).int().sum()}" + ) + + def enable_weight_fake_quant(self) -> None: + self.fake_quant_enabled[0] = 1 + + def get_rectified_sigmoid_func(self) -> torch.Tensor: + if self.use_soft_rounding: + return torch.clamp( + self.sigmoid(self.V) * (self.zeta - self.gamma) + self.gamma, + min=0, + max=1, + ) + else: + # This will dump a binary solution + return (self.V >= 0).int() + + @torch.jit.ignore + def update_scale( + self, X: torch.Tensor, _scale: torch.Tensor, _zero_point: torch.Tensor + ) -> None: + if self.scale.numel() == 0: + self.scale.data = _scale.to(X.device) + self.zero_point = _zero_point.to(X.device) + else: + self.scale.data = _scale + if not self.is_symmetric: + self.zero_point = _zero_point + else: + self.zero_point = torch.zeros_like(_zero_point) + for i in range(X.dim()): + if i == self.ch_axis: + continue + self.zero_point = self.zero_point.unsqueeze(i) + X_q = X / self.scale + X_q_floor = torch.floor(X_q) + residual = X_q - X_q_floor # [0,1) + assert torch.all( + torch.ge(residual, 0) + ), "residual should be non-negative [0, 1)" + V_init = -torch.log((self.zeta - self.gamma) / (residual - self.gamma) - 1) + self.V.data = V_init + + def forward(self, X: torch.Tensor) -> torch.Tensor: + if self.observer_enabled[0] == 1: + X_detached = X.detach() + self.activation_post_process(X_detached) + _scale, _zero_point = self.activation_post_process.calculate_qparams() + _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to( + self.zero_point.device + ) + dims = list(range(X.dim())) + if not self.is_per_tensor: + dims.remove(self.ch_axis) + if not self.is_per_tensor: + for i in range(X.dim()): + if i == self.ch_axis: + continue + _scale = _scale.unsqueeze(i) + _zero_point = _zero_point.unsqueeze(i) + self.update_scale(X_detached, _scale, _zero_point) + + if self.fake_quant_enabled[0] == 1: + # Perform soft quantization + # See the equation (23) in Adaround paper + h_v = self.get_rectified_sigmoid_func() + X_q = X / self.scale + # Straight-Through Estimator for floor function + X_q_floor = torch.floor(X_q) + self.zero_point + # Regardless of rounding, gradient should be able to flow back to self.V from X_q_dq. + # With adaround, we don't train weight, but train V only. 
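
The soft rounding used in this forward pass is just a stretched sigmoid clamped to [0, 1]; a standalone numeric sketch with the same zeta/gamma constants (not the quantizer class itself):

import torch

zeta, gamma = 1.1, -0.1
V = torch.randn(5)

# Rectified sigmoid h(V) = clamp(sigmoid(V) * (zeta - gamma) + gamma, 0, 1).
# Unlike a plain sigmoid it reaches exactly 0 and 1 for sufficiently negative or
# positive V, which lets the optimizer push each weight to "round down" or "round up".
h_soft = torch.clamp(torch.sigmoid(V) * (zeta - gamma) + gamma, min=0, max=1)
h_hard = (V >= 0).int()  # the binary solution used once optimization has finished

print(h_soft)
print(h_hard)
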
+ X_q_dq = ( + torch.clamp(X_q_floor + h_v, min=self.quant_min, max=self.quant_max) + - self.zero_point + ) * self.scale + return X_q_dq + else: + return X diff --git a/torch/ao/quantization/experimental/adaround_loss.py b/torch/ao/quantization/experimental/adaround_loss.py new file mode 100644 index 0000000000000..8080d72cc6da2 --- /dev/null +++ b/torch/ao/quantization/experimental/adaround_loss.py @@ -0,0 +1,96 @@ +from typing import Tuple + +import numpy as np +import torch +from torch.nn import functional as F + +ADAROUND_ZETA: float = 1.1 +ADAROUND_GAMMA: float = -0.1 + + +class AdaptiveRoundingLoss(torch.nn.Module): + """ + Adaptive Rounding Loss functions described in https://arxiv.org/pdf/2004.10568.pdf + rounding regularization is eq [24] + reconstruction loss is eq [25] except regularization term + """ + + def __init__( + self, + max_iter: int, + warm_start: float = 0.2, + beta_range: Tuple[int, int] = (20, 2), + reg_param: float = 0.001, + ) -> None: + super().__init__() + self.max_iter = max_iter + self.warm_start = warm_start + self.beta_range = beta_range + self.reg_param = reg_param + + def rounding_regularization( + self, + V: torch.Tensor, + curr_iter: int, + ) -> torch.Tensor: + """ + Major logic copied from the official Adaround implementation. + Apply rounding regularization to the input tensor V. + """ + assert ( + curr_iter < self.max_iter + ), "Current iteration must be strictly less than max iteration" + if curr_iter < self.warm_start * self.max_iter: + return torch.tensor(0.0) + else: + start_beta, end_beta = self.beta_range + warm_start_end_iter = self.warm_start * self.max_iter + + # compute the relative progress of the current iteration + rel_iter = (curr_iter - warm_start_end_iter) / ( + self.max_iter - warm_start_end_iter + ) + beta = end_beta + 0.5 * (start_beta - end_beta) * ( + 1 + np.cos(rel_iter * np.pi) + ) + + # A rectified sigmoid for soft-quantization as formulated in eq [23] of https://arxiv.org/pdf/2004.10568.pdf + h_alpha = torch.clamp( + torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA, + min=0, + max=1, + ) + + # Apply rounding regularization + # This regularization term helps the rounding term converge to a binary solution (either 0 or 1) at the end of optimization. + inner_term = torch.add(2 * h_alpha, -1).abs().pow(beta) + regularization_term = torch.add(1, -inner_term).sum() + return regularization_term * self.reg_param + + def reconstruction_loss( + self, + soft_quantized_output: torch.Tensor, + original_output: torch.Tensor, + ) -> torch.Tensor: + """ + Compute the reconstruction loss between the soft quantized output and the original output.
+ """ + return F.mse_loss( + soft_quantized_output, original_output, reduction="none" + ).mean() + + def forward( + self, + soft_quantized_output: torch.Tensor, + original_output: torch.Tensor, + V: torch.Tensor, + curr_iter: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the asymmetric reconstruction formulation as eq [25] + """ + regularization_term = self.rounding_regularization(V, curr_iter) + reconstruction_term = self.reconstruction_loss( + soft_quantized_output, original_output + ) + return regularization_term, reconstruction_term diff --git a/torch/ao/quantization/experimental/adaround_optimization.py b/torch/ao/quantization/experimental/adaround_optimization.py new file mode 100644 index 0000000000000..7304f885a6f36 --- /dev/null +++ b/torch/ao/quantization/experimental/adaround_optimization.py @@ -0,0 +1,238 @@ +import copy +import logging +from typing import Any, Callable, List, Optional, Tuple, Type, Union + +import torch +from torch.ao.quantization.experimental.adaround_fake_quantize import ( + AdaroundFakeQuantizer, +) +from torch.ao.quantization.experimental.adaround_loss import AdaptiveRoundingLoss +from torch.ao.quantization.observer import MinMaxObserver +from torch.nn import functional as F +from torch.nn.parallel import DataParallel +from torch.utils.data import DataLoader, TensorDataset + +logger: logging.Logger = logging.getLogger(__name__) + + +class AdaptiveRoundingOptimizer: + def __init__( + self, + model: Union[torch.nn.Module, torch.nn.DataParallel], + callback: Callable[[torch.nn.Module, List[Any]], None], + forward_hook_wrapper: Callable[[List[torch.Tensor]], Callable], + data: List[Any], + observer: Type[torch.ao.quantization.observer.ObserverBase] = MinMaxObserver, + max_iter=10000, + dtype: torch.dtype = torch.qint8, + quant_min=-128, + quant_max=127, + qscheme: torch.qscheme = torch.per_tensor_symmetric, + batch_size: int = 256, + ): + self.model = model + self.q_model = copy.deepcopy(self.model) + self.device = torch.device("cuda") if torch.cuda.is_available() else None + self.callback = callback + self.forward_hook_wrapper = forward_hook_wrapper + # TODO rather than having a data as list type or, we better pass *iterator* instead of list + self.data = data + self.batch_size = min(batch_size, len(data)) + self.max_iter = max_iter + self.adaptive_round_loss_fn = AdaptiveRoundingLoss( + max_iter=self.max_iter, warm_start=0.2 + ) + self.dtype = dtype + self.observer = observer + self.quant_min = quant_min + self.quant_max = quant_max + self.qscheme = qscheme + + def run_adaround(self) -> torch.nn.Module: + layer_list: List[Tuple[str, torch.nn.Module, torch.nn.Module]] = [] + for (name, module), q_module in zip( + self.model.named_modules(), self.q_model.modules() + ): + if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)): + # Knowing activation ahead-of-time would be helpful for asymmetric formulation + # But this is challenging in eager mode, but graph module. 
+ layer_list.append((name, module, q_module)) + logger.info(f"Total number of layers : {len(layer_list)}") # noqa: G004 + + for name, module, q_module in layer_list: + logger.info( + f"Kick start adaptive rounding on {name} module {module}" # noqa: G004 + ) + self.optimize_adaptive_rounding( + module, + q_module, + None, + ) + + return ( + self.q_model.module + if isinstance(self.q_model, DataParallel) + else self.q_model + ) + + def get_data_inp_out( + self, module: torch.nn.Module, q_module: torch.nn.Module, data: List[Any] + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + fp_out: List[torch.Tensor] = [] + q_input: List[torch.Tensor] = [] + fp_input: List[torch.Tensor] = [] + fp32_fetcher: List[torch.Tensor] = [] + quant_fetcher: List[torch.Tensor] = [] + handler1 = module.register_forward_hook(self.forward_hook_wrapper(fp32_fetcher)) + handler2 = q_module.register_forward_hook( + self.forward_hook_wrapper(quant_fetcher) + ) + for data_ in data: + with torch.no_grad(): + self.callback(self.model, data_) + self.callback(self.q_model, data_) + fp32_output = fp32_fetcher[1] + quant_input = quant_fetcher[0] + fp_out.append(fp32_output) + q_input.append(quant_input) + fp_input.append(fp32_fetcher[0]) + handler1.remove() + handler2.remove() + return q_input, fp_out, fp_input + + @torch.no_grad() + def feed_forward(self, x, weight, module): + if isinstance(module, torch.nn.Conv1d): + out = torch.nn.functional.conv1d( + x, + weight, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + elif isinstance(module, torch.nn.Linear): + out = torch.nn.functional.linear( + x, + weight, + bias=module.bias, + ) + else: + raise NotImplementedError + return out + + def _compute_and_display_local_losses( + self, + ada_quantizer: AdaroundFakeQuantizer, + q_module: torch.nn.Module, + q_inp: torch.Tensor, + fp_out: torch.Tensor, + ): + with torch.no_grad(): + ada_quantizer.use_soft_rounding = False + q_w_hard_round = ada_quantizer(q_module.weight) + out_hard_quant = self.feed_forward(q_inp, q_w_hard_round, q_module) + ada_quantizer.use_soft_rounding = True + q_w_soft_round = ada_quantizer(q_module.weight) + out_soft_quant = self.feed_forward(q_inp, q_w_soft_round, q_module) + soft_quant_loss = F.mse_loss(out_soft_quant, fp_out) + hard_quant_loss = F.mse_loss(out_hard_quant, fp_out) + logger.info( + f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}" # noqa: G004 + ) + + def optimize_adaptive_rounding( + self, + module: torch.nn.Module, + q_module: torch.nn.Module, + activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + ) -> None: + ada_quantizer = AdaroundFakeQuantizer( + dtype=self.dtype, + observer=self.observer, + qscheme=self.qscheme, + quant_min=self.quant_min, + quant_max=self.quant_max, + reduce_range=False, + ) + ada_quantizer.enable_observer() + ada_quantizer(q_module.weight) + ada_quantizer.disable_observer() + ada_quantizer.enable_fake_quant() + optimizer = torch.optim.Adam([ada_quantizer.V]) + inp, out, fp_in = self.get_data_inp_out(module, q_module, self.data) + + logger.info("==================== Before adaround ====================") + test_in, test_out, fp_test_in = self.get_data_inp_out( + module, q_module, self.data[0] + ) + + assert ( + torch.abs(test_out[0] - module(fp_test_in[0])).sum().item() == 0 + ), "In-placed activation is detected, please do not use activation in-placed" + # Stack the tensors in each list into a single tensor + # Assuming inp and out are 
your lists of tensors + inp_tensor = torch.vstack(inp) + out_tensor = torch.vstack(out) + dataset = TensorDataset(inp_tensor, out_tensor) + dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) + + self._compute_and_display_local_losses( + ada_quantizer, q_module, test_in[0], test_out[0] + ) + global_idx = 0 + one_iter = len(out) // self.batch_size + for iteration in range(self.max_iter // one_iter): + reconstruction_loss = regularization_loss = torch.tensor(0) + for q_inp, fp_out in dataloader: + optimizer.zero_grad() + q_weight = ada_quantizer(q_module.weight) + if isinstance(module, torch.nn.Conv1d): + q_out = torch.nn.functional.conv1d( + q_inp, + q_weight, + stride=q_module.stride, + padding=q_module.padding, + dilation=q_module.dilation, + groups=q_module.groups, + ) + elif isinstance(q_module, torch.nn.Linear): + q_out = torch.nn.functional.linear( + q_inp, + q_weight, + bias=q_module.bias, + ) + else: + raise NotImplementedError + regularization_loss, reconstruction_loss = self.adaptive_round_loss_fn( + fp_out, + q_out, + ada_quantizer.V, + curr_iter=global_idx, + ) + loss = regularization_loss + reconstruction_loss + loss.backward() + optimizer.step() + global_idx += 1 + if global_idx >= self.max_iter: + break + if global_idx >= self.max_iter: + break + if iteration % 30 == 0: + logger.info( + f"glob iter {global_idx} regularization_loss {regularization_loss.item()} " # noqa: G004 + f"reconstruction_loss {reconstruction_loss.item()}" # noqa: G004 + ) + logger.info("==================== After adaround ====================") + self._compute_and_display_local_losses( + ada_quantizer, q_module, test_in[0], test_out[0] + ) + + ada_quantizer.use_soft_rounding = True + ada_quantizer.V.requires_grad = False + ada_quantizer = ada_quantizer.eval() + q_weight = ada_quantizer(q_module.weight) + # At the end of optimization, we need to copy the adarounded weight back to the original module + q_module.weight.data.copy_(q_weight) + # Eager mode requires observer to be set as "weight_fake_quant" to be parsed + q_module.weight_fake_quant = ada_quantizer.activation_post_process diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index 728506037b558..049f4e3135d9b 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -926,8 +926,14 @@ def _lower_dynamic_weighted_ref_functional( # Linear prepack args: (quantized weights[, bias]) # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups]) prepack_args = [quantized_weight] + remaining_func_args + prepack_kwargs = {} if func_node.target == F.linear: prepack_op = get_linear_prepack_op_for_dtype(weight_dtype) + kwargs = func_node.kwargs.copy() + if 'bias' in kwargs: + prepack_kwargs['B'] = kwargs['bias'] + del kwargs['bias'] + func_node.kwargs = kwargs elif func_node.target in CONV_FUNCTIONAL_OPS: prepack_op = get_qconv_prepack_op(func_node.target) # For conv1d, the stride, padding, and dilation args may be ints, @@ -939,7 +945,7 @@ def _lower_dynamic_weighted_ref_functional( else: raise ValueError(f"Lowering is not supported for op '{func_node.target}'") with model.graph.inserting_before(func_node): - packed_weight = model.graph.create_node("call_function", prepack_op, tuple(prepack_args), {}) + packed_weight = model.graph.create_node("call_function", prepack_op, tuple(prepack_args), prepack_kwargs) # Step 3: Replace reference pattern with the corresponding quantized op 
func_node.target = q_relu_func if relu_node is not None else q_func diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index c47e820735787..5ea1f939a3b69 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -136,7 +136,7 @@ def _port_metadata_for_output_quant_nodes( node_users = _filter_sym_size_users(node) if len(node_users) != 1: - raise InternalError(f"Expecting {node} to have single user") + logger.warning(f"Expecting {node} to have single user") # noqa: G004 q_node = node_users.pop() if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS: logger.warning( diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 3be764220e0de..9ff9131435f4c 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -422,6 +422,19 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { END_HANDLE_TH_ERRORS } +PyObject* THPModule_check_tp_alloc_is_default( + PyObject* _unused, + PyObject* cls) { + HANDLE_TH_ERRORS + TORCH_CHECK_TYPE( + PyType_Check(cls), + "cls must be a type (got ", + Py_TYPE(cls)->tp_name, + ")"); + return PyBool_FromLong(Py_TYPE(cls)->tp_alloc == PyType_GenericAlloc); + END_HANDLE_TH_ERRORS +} + PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) { // adds a __doc__ string to a function, similar to numpy's arr_add_docstring static std::vector all_docs; @@ -1268,6 +1281,10 @@ static PyMethodDef TorchMethods[] = { // NOLINT {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr}, {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr}, {"_swap_tensor_impl", THPModule_swap_tensor_impl, METH_VARARGS, nullptr}, + {"_check_tp_alloc_is_default", + THPModule_check_tp_alloc_is_default, + METH_O, + nullptr}, {"_init_names", THPModule_initNames, METH_O, nullptr}, {"_has_distributed", THPModule_hasDistributed, METH_NOARGS, nullptr}, {"_set_default_tensor_type", diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp index d392c0213b84a..9d525f0d56400 100644 --- a/torch/csrc/distributed/c10d/Functional.cpp +++ b/torch/csrc/distributed/c10d/Functional.cpp @@ -196,6 +196,19 @@ at::Tensor all_gather_into_tensor( inputs, group_size, std::move(group_name))[0]; } +at::Tensor& all_gather_into_tensor_out( + at::Tensor& input, + int64_t group_size, + std::string group_name, + at::Tensor& output) { + c10d::AllgatherOptions opts; + + auto group = c10d::resolve_process_group(group_name); + auto work = group->_allgather_base(output, input, opts); + c10d::RankLocal::get().register_work(output, work); + return output; +} + at::Tensor allocate_reduce_scatter_output( const at::Tensor& input, const int64_t group_size) { @@ -321,6 +334,13 @@ TORCH_LIBRARY(_c10d_functional, m) { c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_), {at::Tag::pt2_compliant_tag}); + m.def( + "all_gather_into_tensor_out(Tensor input, int group_size, str group_name, *, Tensor(a!) 
out) -> Tensor(a!)", + torch::dispatch( + c10::DispatchKey::CompositeExplicitAutograd, + ::all_gather_into_tensor_out), + {at::Tag::pt2_compliant_tag}); + m.def( "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor", torch::dispatch( diff --git a/torch/csrc/distributed/c10d/HashStore.hpp b/torch/csrc/distributed/c10d/HashStore.hpp index 1453c0a72808a..3697d62301ba3 100644 --- a/torch/csrc/distributed/c10d/HashStore.hpp +++ b/torch/csrc/distributed/c10d/HashStore.hpp @@ -22,7 +22,7 @@ class TORCH_API HashStore : public Store { std::vector get(const std::string& key) override; void wait(const std::vector& keys) override { - wait(keys, Store::kDefaultTimeout); + wait(keys, timeout_); } void wait( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index a609da1654b95..7586058475ff1 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1,4 +1,3 @@ - #ifdef USE_C10D_NCCL #include @@ -64,6 +63,10 @@ std::map ncclDataType = { {at::kLong, ncclInt64}, {at::kHalf, ncclHalf}, {at::kBool, ncclUint8}, + {at::kFloat8_e5m2, ncclUint8}, + {at::kFloat8_e4m3fn, ncclUint8}, + {at::kFloat8_e4m3fnuz, ncclUint8}, + {at::kFloat8_e5m2fnuz, ncclUint8}, #if HAS_NCCL_BF16_DATATYPE {at::kBFloat16, ncclBfloat16}, #endif @@ -1567,6 +1570,8 @@ void ProcessGroupNCCL::watchdogHandler() { data.strings["last_enqueued_work_name"] = lastEnqueuedWorkName_; data.strings["last_started_work_name"] = lastStartedWorkName_; data.strings["last_completed_work_name"] = lastCompletedWorkName_; + data.strings["pg_name"] = pg_name_; + data.strings["pg_desc"] = pg_desc_; logger->log(data); lastStatusUpdateTime = std::chrono::steady_clock::now(); } @@ -3037,6 +3042,9 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( const AllreduceOptions& opts) { TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); auto tensor = tensors.back(); + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); #ifdef IS_NCCLX tensor = tensor.coalesce(); at::Tensor outputTensor = @@ -3151,7 +3159,9 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce( return c10::make_intrusive(); } } - + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); // @lint-ignore CLANGTIDY RECORD_PARAM_COMMS_DATA( static_cast( @@ -3178,6 +3188,9 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { auto total_numel = check_gpu_tensors_same_device(tensors); + TORCH_CHECK( + !isFloat8Type(tensors.back().scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); // @lint-ignore CLANGTIDY RECORD_PARAM_COMMS_DATA( @@ -3550,6 +3563,9 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( check_gpu_single_tensor(outputTensor); // @lint-ignore CLANGTIDY auto inputTensors_ = inputTensors.back(); + TORCH_CHECK( + !isFloat8Type(outputTensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); RECORD_PARAM_COMMS_DATA( static_cast( @@ -3661,6 +3677,9 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( // @lint-ignore CLANGTIDY const auto& tensor = outputTensor; + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); RECORD_PARAM_COMMS_DATA( static_cast( this->getSequenceNumberForGroup() + 1), // seq + 1 to match 
collective @@ -3721,6 +3740,9 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( std::vector& outputs, std::vector& inputs, const ReduceScatterOptions& opts) { + TORCH_CHECK( + !isFloat8Type(inputs.back().scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); return collectiveCoalesced( inputs, outputs, diff --git a/torch/csrc/distributed/c10d/Store.hpp b/torch/csrc/distributed/c10d/Store.hpp index af715ba98a794..993284fa7cc56 100644 --- a/torch/csrc/distributed/c10d/Store.hpp +++ b/torch/csrc/distributed/c10d/Store.hpp @@ -97,4 +97,33 @@ class TORCH_API Store : public torch::CustomClassHolder { std::chrono::milliseconds timeout_; }; +/* +StoreTimeoutGuard is a RAII guard that will set the store timeout and restore it +when it returns. +*/ +class StoreTimeoutGuard { + public: + explicit StoreTimeoutGuard( + Store& store, + const std::chrono::milliseconds& timeout) + : store_(store) { + oldTimeout_ = store.getTimeout(); + store.setTimeout(timeout); + } + + ~StoreTimeoutGuard() { + store_.setTimeout(oldTimeout_); + } + + /* Disabling copy and move semantics */ + StoreTimeoutGuard(const StoreTimeoutGuard&) = delete; + StoreTimeoutGuard& operator=(const StoreTimeoutGuard&) = delete; + StoreTimeoutGuard(StoreTimeoutGuard&&) = delete; + StoreTimeoutGuard& operator=(StoreTimeoutGuard&&) = delete; + + private: + Store& store_; + std::chrono::milliseconds oldTimeout_; +}; + } // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp b/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp new file mode 100644 index 0000000000000..b98f9a71fb024 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace c10d { + +using namespace std::chrono_literals; + +class TORCH_API ControlCollectives : public torch::CustomClassHolder { + public: + virtual void barrier( + const std::string& key, + std::chrono::milliseconds timeout = 5min, + bool block = true) = 0; + + virtual void broadcastSend( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) = 0; + virtual std::vector broadcastRecv( + const std::string& key, + std::chrono::milliseconds timeout = 5min) = 0; + + virtual void gatherSend( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) = 0; + virtual std::vector> gatherRecv( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) = 0; + + virtual std::vector scatterSend( + const std::string& key, + const std::vector>& data, + std::chrono::milliseconds timeout = 5min) = 0; + virtual std::vector scatterRecv( + const std::string& key, + std::chrono::milliseconds timeout = 5min) = 0; + + virtual std::vector> allGather( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) = 0; + + virtual int64_t allSum( + const std::string& key, + int64_t data, + std::chrono::milliseconds timeout = 5min) = 0; +}; + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp new file mode 100644 index 0000000000000..995899441d461 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp @@ -0,0 +1,222 @@ +#include +#include +#include +#include +#include 
+#include +#include + +namespace { +std::string getRankKey(const std::string& key, int rank) { + return fmt::format("{}/{}", key, rank); +} +} // namespace + +namespace c10d { + +StoreCollectives::StoreCollectives( + c10::intrusive_ptr<::c10d::Store> store, + int rank, + int worldSize) + : store_(std::move(store)), rank_(rank), worldSize_(worldSize) {} + +void StoreCollectives::barrier( + const std::string& key, + std::chrono::milliseconds timeout, + bool blocking) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + auto num_members_key = fmt::format("{}/num_members", key); + auto last_members_key = fmt::format("{}/last_members", key); + + auto idx = store_->add(num_members_key, 1); + store_->set(getRankKey(key, rank_), "joined"); + + if (idx == worldSize_) { + store_->set(last_members_key, ""); + } else if (blocking) { + try { + store_->wait({last_members_key}); + } catch (const std::exception& e) { + std::string msg = "barrier failed -- missing ranks: "; + for (int i = 0; i < worldSize_; i++) { + if (i == rank_) { + continue; + } + auto rank_key = getRankKey(key, i); + if (!store_->check({rank_key})) { + msg += fmt::format("{}, ", i); + } + } + throw std::runtime_error(msg + e.what()); + } + } +} + +void StoreCollectives::broadcastSend( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + store_->set(key, data); +} + +std::vector StoreCollectives::broadcastRecv( + const std::string& key, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + return store_->get(key); +} + +void StoreCollectives::gatherSend( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + auto rank_key = getRankKey(key, rank_); + store_->set(rank_key, data); +} + +std::vector> StoreCollectives::gatherRecv( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + std::vector keys; + keys.reserve(worldSize_); + + for (int i = 0; i < worldSize_; i++) { + if (i == rank_) { + continue; + } + auto rank_key = getRankKey(key, i); + keys.emplace_back(rank_key); + } + + std::vector> results; + results.reserve(worldSize_); + + try { + results = store_->multiGet(keys); + } catch (const std::exception& e) { + std::string msg = "gather failed -- missing ranks: "; + for (int i = 0; i < worldSize_; i++) { + if (i == rank_) { + continue; + } + auto rank_key = getRankKey(key, i); + if (!store_->check({rank_key})) { + msg += fmt::format("{}, ", i); + } + } + throw std::runtime_error(msg + e.what()); + } + + // insert local data + results.insert(results.begin() + rank_, data); + return results; +} + +std::vector StoreCollectives::scatterSend( + const std::string& key, + const std::vector>& data, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + std::vector keys; + keys.reserve(worldSize_); + for (int i = 0; i < worldSize_; i++) { + if (i == rank_) { + continue; + } + auto rank_key = getRankKey(key, i); + keys.emplace_back(rank_key); + } + auto local = data.at(rank_); + + std::vector> toSend{data}; + + toSend.erase(toSend.begin() + rank_); + + store_->multiSet(keys, toSend); + + return local; +} + +std::vector StoreCollectives::scatterRecv( + const std::string& key, + std::chrono::milliseconds timeout) { + 
enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + auto rank_key = getRankKey(key, rank_); + return store_->get(rank_key); +} + +std::vector> StoreCollectives::allGather( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + auto localKey = getRankKey(key, rank_); + store_->set(localKey, data); + + std::vector keys; + keys.reserve(worldSize_); + + for (int i = 0; i < worldSize_; i++) { + auto rank_key = getRankKey(key, i); + keys.emplace_back(rank_key); + } + + try { + return store_->multiGet(keys); + } catch (const std::exception& e) { + std::string msg = "all_gather failed -- missing ranks: "; + for (int i = 0; i < worldSize_; i++) { + if (i == rank_) { + continue; + } + auto rank_key = getRankKey(key, i); + if (!store_->check({rank_key})) { + msg += fmt::format("{}, ", i); + } + } + throw std::runtime_error(msg + e.what()); + } +} + +int64_t StoreCollectives::allSum( + const std::string& key, + int64_t value, + std::chrono::milliseconds timeout) { + enforceUnique(key); + StoreTimeoutGuard g{*store_, timeout}; + + store_->add(key, value); + + barrier(key + "/barrier", timeout); + + return store_->add(key, 0); +} + +void StoreCollectives::enforceUnique(const std::string& key) { + auto it = seenKeys_.find(key); + TORCH_INTERNAL_ASSERT( + it == seenKeys_.end(), "Key ", key, " has already been used."); + seenKeys_.emplace(key); +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp new file mode 100644 index 0000000000000..7d3eb5038565e --- /dev/null +++ b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include +#include + +namespace c10d { + +class TORCH_API StoreCollectives : public ControlCollectives { + public: + explicit StoreCollectives( + c10::intrusive_ptr store, + int rank, + int worldSize); + + void barrier( + const std::string& key, + std::chrono::milliseconds timeout = 5min, + bool block = true) override; + + void broadcastSend( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) override; + std::vector broadcastRecv( + const std::string& key, + std::chrono::milliseconds timeout = 5min) override; + + void gatherSend( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) override; + std::vector> gatherRecv( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) override; + + std::vector scatterSend( + const std::string& key, + const std::vector>& data, + std::chrono::milliseconds timeout = 5min) override; + std::vector scatterRecv( + const std::string& key, + std::chrono::milliseconds timeout = 5min) override; + + std::vector> allGather( + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) override; + + int64_t allSum( + const std::string& key, + int64_t data, + std::chrono::milliseconds timeout = 5min) override; + + private: + void enforceUnique(const std::string& key); + + private: + c10::intrusive_ptr store_; + int rank_; + int worldSize_; + + c10::FastSet seenKeys_{}; +}; + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 483becbce0094..505b64e2a6976 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ 
b/torch/csrc/distributed/c10d/init.cpp @@ -6,6 +6,9 @@ #include #include #include +#include +#include +#include #ifndef _WIN32 #include #include @@ -136,6 +139,34 @@ namespace torch::distributed::c10d { namespace { +py::bytes toPyBytes(const std::vector& data) { + return py::bytes(reinterpret_cast(data.data()), data.size()); +} + +std::vector toPyBytes( + const std::vector>& data) { + std::vector out; + out.reserve(data.size()); + for (const std::vector& data_ : data) { + out.emplace_back(reinterpret_cast(data_.data()), data_.size()); + } + return out; +} + +std::vector toVec8(const std::string& data) { + std::vector out{data.begin(), data.end()}; + return out; +} + +std::vector> toVec8(const std::vector& data) { + std::vector> out; + out.reserve(data.size()); + for (auto& data_ : data) { + out.emplace_back(toVec8(data_)); + } + return out; +} + template using shared_ptr_class_ = py::class_>; @@ -166,8 +197,7 @@ class PythonStore : public ::c10d::Store { pybind11::get_overload(static_cast(this), "set"); TORCH_INTERNAL_ASSERT(fn, "Not implemented."); // Call function with a py::bytes object for the value. - fn(key, - py::bytes(reinterpret_cast(value.data()), value.size())); + fn(key, toPyBytes(value)); } // Note: this function manually calls the Python-side overload @@ -184,7 +214,7 @@ class PythonStore : public ::c10d::Store { // std::vector. There is no API for directly accessing // the contents of the py::bytes object. std::string str = pybind11::cast(fn(key)); - return std::vector(str.begin(), str.end()); + return toVec8(str); } // Note: this function manually calls the Python-side overload @@ -204,14 +234,8 @@ class PythonStore : public ::c10d::Store { // std::vector. There is no API for directly accessing // the contents of the py::bytes object. std::string str = pybind11::cast( - fn(key, - py::bytes( - reinterpret_cast(expectedValue.data()), - expectedValue.size()), - py::bytes( - reinterpret_cast(desiredValue.data()), - desiredValue.size()))); - return std::vector(str.begin(), str.end()); + fn(key, toPyBytes(expectedValue), toPyBytes(desiredValue))); + return toVec8(str); } int64_t add(const std::string& key, int64_t value) override { @@ -253,8 +277,7 @@ class PythonStore : public ::c10d::Store { return Store::append(key, value); } // Call function with a py::bytes object for the value. - fn(key, - py::bytes(reinterpret_cast(value.data()), value.size())); + fn(key, toPyBytes(value)); } std::vector> multiGet( @@ -287,14 +310,7 @@ class PythonStore : public ::c10d::Store { return Store::multiSet(keys, values); } - std::vector bytes; - bytes.reserve(values.size()); - for (auto& value : values) { - bytes.emplace_back( - reinterpret_cast(value.data()), value.size()); - } - - fn(keys, bytes); + fn(keys, toPyBytes(values)); } bool hasExtendedApi() const override { @@ -973,10 +989,7 @@ and :class:`~torch.distributed.HashStore`). 
"set", [](::c10d::Store& store, const std::string& key, - const std::string& value) { - std::vector value_(value.begin(), value.end()); - store.set(key, value_); - }, + const std::string& value) { store.set(key, toVec8(value)); }, py::call_guard(), R"( Inserts the key-value pair into the store based on the supplied ``key`` and @@ -1001,14 +1014,9 @@ Example:: const std::string& key, const std::string& expected_value, const std::string& desired_value) -> py::bytes { - std::vector expectedValue_( - expected_value.begin(), expected_value.end()); - std::vector desiredValue_( - desired_value.begin(), desired_value.end()); - auto value = - store.compareSet(key, expectedValue_, desiredValue_); - return py::bytes( - reinterpret_cast(value.data()), value.size()); + auto value = store.compareSet( + key, toVec8(expected_value), toVec8(desired_value)); + return toPyBytes(value); }, py::call_guard(), R"( @@ -1040,8 +1048,7 @@ Example:: py::gil_scoped_release guard; return store.get(key); }(); - return py::bytes( - reinterpret_cast(value.data()), value.size()); + return toPyBytes(value); }, R"( Retrieves the value associated with the given ``key`` in the store. If ``key`` is not @@ -1240,8 +1247,7 @@ Example:: [](::c10d::Store& store, const std::string& key, const std::string& value) { - std::vector value_(value.begin(), value.end()); - store.append(key, value_); + store.append(key, toVec8(value)); }, py::call_guard(), R"( @@ -1268,14 +1274,7 @@ Example:: py::gil_scoped_release guard; return store.multiGet(keys); }(); - std::vector res; - for (auto& value : values) { - auto bytes = py::bytes( - reinterpret_cast(value.data()), - value.size()); - res.push_back(bytes); - } - return res; + return toPyBytes(values); }, R"( Retrieve all values in ``keys``. If any key in ``keys`` is not @@ -1298,12 +1297,7 @@ Example:: [](::c10d::Store& store, const std::vector& keys, const std::vector& values) { - std::vector> vals; - vals.reserve(values.size()); - for (auto& value : values) { - vals.emplace_back(value.begin(), value.end()); - } - store.multiSet(keys, vals); + store.multiSet(keys, toVec8(values)); }, py::call_guard(), R"( @@ -1487,6 +1481,212 @@ that adds a prefix to each key inserted to the store. &::c10d::PrefixStore::getUnderlyingNonPrefixStore, R"(Recursively to get the store before layers of wrapping with PrefixStore.)"); + using namespace std::chrono_literals; + + auto collectives = + py::class_< + ::c10d::ControlCollectives, + c10::intrusive_ptr<::c10d::ControlCollectives>>( + module, + "_ControlCollectives", + R"( +Base class for all ControlCollectives implementations. +)") + .def( + "barrier", + &::c10d::ControlCollectives::barrier, + py::arg("key"), + py::arg("timeout") = 5min, + py::arg("block") = true, + py::call_guard(), + R"( +Blocks until all workers have entered this function. + +Arguments: + key (str): The unique key used to identify this operation. + timeout (duration): The timeout for this operation. + block (bool): whether to block this working waiting on the results of the barrier. +)") + .def( + "all_sum", + &::c10d::ControlCollectives::allSum, + py::arg("key"), + py::arg("data"), + py::arg("timeout") = 5min, + py::call_guard(), + R"( +Computes a sum across all workers and returns the final value. + +Arguments: + key (str): The unique key used to identify this operation. + data (int): The data to sum. + timeout (duration): The timeout for this operation. 
+)") + .def( + "broadcast_send", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + const std::string& data, + std::chrono::milliseconds timeout = 5min) { + collectives.broadcastSend(key, toVec8(data), timeout); + }, + py::arg("key"), + py::arg("data"), + py::arg("timeout") = 5min, + py::call_guard(), + R"( +Sends data to all other workers. Must be only called from one worker. + +Arguments: + key (str): The unique key used to identify this operation. + data (str): The data to send. + timeout (duration): The timeout for this operation. +)") + .def( + "broadcast_recv", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + std::chrono::milliseconds timeout = 5min) { + auto out = [&]() { + py::gil_scoped_release guard; + return collectives.broadcastRecv(key, timeout); + }(); + return toPyBytes(out); + }, + py::arg("key"), + py::arg("timeout") = 5min, + R"( +Receives data broadcasted from 1 worker. + +Arguments: + key (str): The unique key used to identify this operation. + timeout (duration): The timeout for this operation. +)") + .def( + "gather_send", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + const std::string& data, + std::chrono::milliseconds timeout = 5min) { + collectives.gatherSend(key, toVec8(data), timeout); + }, + py::arg("key"), + py::arg("data"), + py::arg("timeout") = 5min, + py::call_guard(), + R"( +Sends data to one other worker. + +Arguments: + key (str): The unique key used to identify this operation. + data (str): The data to send. + timeout (duration): The timeout for this operation. +)") + .def( + "gather_recv", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + const std::string& data, + std::chrono::milliseconds timeout = 5min) { + auto out = [&]() { + py::gil_scoped_release guard; + return collectives.gatherRecv(key, toVec8(data), timeout); + }(); + return toPyBytes(out); + }, + py::arg("key"), + py::arg("data"), + py::arg("timeout") = 5min, + R"( +Receives data broadcasted from all workers. Must only be called by one worker. + +Arguments: + key (str): The unique key used to identify this operation. + timeout (duration): The timeout for this operation. +)") + + .def( + "scatter_send", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + const std::vector& data, + std::chrono::milliseconds timeout = 5min) { + auto out = [&]() { + py::gil_scoped_release guard; + return collectives.scatterSend(key, toVec8(data), timeout); + }(); + return toPyBytes(out); + }, + py::arg("key"), + py::arg("data"), + py::arg("timeout") = 5min, + R"( +Sends rank specific data to all other workers. + +Arguments: + key (str): The unique key used to identify this operation. + data (str): The data to send. + timeout (duration): The timeout for this operation. +)") + .def( + "scatter_recv", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + std::chrono::milliseconds timeout = 5min) { + auto out = [&]() { + py::gil_scoped_release guard; + return collectives.scatterRecv(key, timeout); + }(); + return toPyBytes(out); + }, + py::arg("key"), + py::arg("timeout") = 5min, + R"( +Receives rank specific data from one worker. + +Arguments: + key (str): The unique key used to identify this operation. + timeout (duration): The timeout for this operation. 
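An aside before the remaining bindings: taken together, the docstrings above describe a small store-backed rendezvous API. A hedged sketch of how these bindings might be driven from Python follows; the import path and the TCPStore construction are assumptions made for illustration and are not taken from this diff.

# Hedged usage sketch for the _StoreCollectives bindings added above; the import
# location and the store setup are assumptions, not part of this patch.
import datetime
import torch.distributed as dist
from torch._C._distributed_c10d import _StoreCollectives  # assumed location

def worker(rank: int, world_size: int) -> None:
    store = dist.TCPStore(
        "127.0.0.1",
        29500,
        world_size,
        is_master=(rank == 0),
        timeout=datetime.timedelta(seconds=60),
    )
    col = _StoreCollectives(store, rank, world_size)

    col.barrier("startup")                       # every rank blocks here
    total = col.all_sum("sample_count", 1000)    # sum of per-rank contributions
    if rank == 0:
        col.broadcast_send("config", b"lr=0.1")  # exactly one sender per key
    else:
        cfg = col.broadcast_recv("config")       # all other ranks receive
    stats = col.all_gather("stats", f"rank{rank} ok".encode())

Note that each key may be used for only one collective per _StoreCollectives instance (the enforceUnique check above), which is why the sketch uses a distinct key per call.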
+)") + + .def( + "all_gather", + [](::c10d::ControlCollectives& collectives, + const std::string& key, + const std::string& data, + std::chrono::milliseconds timeout = 5min) { + auto out = [&]() { + py::gil_scoped_release guard; + return collectives.allGather(key, toVec8(data), timeout); + }(); + return toPyBytes(out); + }, + py::arg("key"), + py::arg("data"), + py::arg("timeout") = 5min, + R"( +Sends data to all workers and receives data from all other workers. + +Arguments: + key (str): The unique key used to identify this operation. + data (str): The data to send. + timeout (duration): The timeout for this operation. +)"); + + intrusive_ptr_class_<::c10d::StoreCollectives>( + module, + "_StoreCollectives", + collectives, + R"( +An implementation of ControlCollectives that uses the provided store as the underlying +communication mechanism. + )") + .def( + py::init, int, int>(), + py::arg("store"), + py::arg("rank"), + py::arg("world_size")); + auto processGroup = py::class_< ::c10d::ProcessGroup, diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index fb27b39b28e6a..3a79a7bc63721 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include /* @@ -50,14 +51,6 @@ at trace time. namespace torch::dynamo::autograd { using c10::SymInt; -// snapshot of python verbose logging toggle -static bool is_verbose_logging_enabled; -static constexpr std::string_view VLOG_PREFIX = - "[python_compiled_autograd.cpp] "; -std::ostream& vcout() { - return std::cout << VLOG_PREFIX; -} - static PyObject* wrap_int_list(const std::vector& inputs) { PyObject* pyinput = PyTuple_New(static_cast(inputs.size())); for (const auto i : c10::irange(inputs.size())) { @@ -91,6 +84,82 @@ static void check(bool result) { check(nullptr); } +// snapshot of python verbose logging toggle +static PyObject* python_verbose_logger = nullptr; +struct VerboseLogger { + static std::optional maybe_create() { + if (python_verbose_logger == nullptr) { + return std::nullopt; + } + return VerboseLogger(); + } + + void verbose_log_fn(std::string_view msg) const { + TORCH_CHECK(python_verbose_logger != nullptr); + check(PyObject_CallFunction(python_verbose_logger, "s", msg.data())); + } + + void log_node_check( + const Node& fn, + size_t size_inputs_num, + std::unordered_set cached_keys, + const CacheKey& key, + size_t node_idx) { + std::string node_name = + fn.name() + " (NodeCall " + std::to_string(node_idx) + ")"; + + cumulative_sizes_per_node[size_inputs_num] = node_name; + + if (!logged_node_miss && cached_keys.find(key) == cached_keys.end()) { + _log_node_miss(typeid(fn), cached_keys, key, node_name); + logged_node_miss = true; + } + } + + void _log_node_miss( + const std::type_info& node_type, + std::unordered_set cached_keys, + const CacheKey& key, + const std::string& node_name) const { + std::ostringstream oss; + oss << "Cache miss due to new autograd node: " << node_name + << " with key size " << std::to_string(key.key_size) + << ", previous key sizes=["; + + for (auto it = cached_keys.begin(); it != cached_keys.end(); it++) { + if (it->node_type != node_type) { + continue; + } + oss << it->key_size; + if (std::next(it) != cached_keys.end()) { + oss << ","; + } + } + oss << "]"; + verbose_log_fn(oss.str()); + } + + void log_dynamic_shapes_check(size_t size_idx) const { + if (cumulative_sizes_per_node.empty()) { + return; + } + + auto it = 
cumulative_sizes_per_node.lower_bound(size_idx); + TORCH_CHECK(it != cumulative_sizes_per_node.end()); + size_t start_idx = + it == cumulative_sizes_per_node.begin() ? 0 : std::prev(it)->first; + verbose_log_fn( + "Cache miss due to changed shapes: marking size idx " + + std::to_string(size_idx - start_idx) + " of " + it->second + + " as dynamic"); + } + + // track which size index belongs to which node + std::map cumulative_sizes_per_node; + // only log cache miss due to node key once + bool logged_node_miss = false; +}; + struct CacheNode { // A node in the shadow graph, we follow next edges until we reach the end of // the graph @@ -135,7 +204,9 @@ struct CacheNode { CacheNode& operator=(const CacheNode&) = delete; CacheNode& operator=(CacheNode&&) = delete; - bool check_dynamic_sizes(AutogradCompilerCall& call) { + bool check_dynamic_sizes( + AutogradCompilerCall& call, + const std::optional& vlogger) { /* We start off by assuming everything is static, then we mark things as dynamic when we see them change. This function: @@ -161,9 +232,8 @@ struct CacheNode { if (changed_value) { if (!was_dynamic) { cache_hit = false; - if (is_verbose_logging_enabled) { - vcout() << "cache miss: marking sizes[" << i << "] as dynamic" - << std::endl; + if (vlogger.has_value()) { + vlogger->log_dynamic_shapes_check(i); } } expected = SizeInput(SizeInput::DYNAMIC, data[i].value); @@ -257,11 +327,18 @@ static PyObject* is_cache_empty(PyObject* dummy, PyObject* args) { END_HANDLE_TH_ERRORS; } -static PyObject* set_verbose_logging(PyObject* dummy, PyObject* args) { +static PyObject* set_verbose_logger(PyObject* dummy, PyObject* args) { HANDLE_TH_ERRORS; - if (!PyArg_ParseTuple(args, "p", &is_verbose_logging_enabled)) { + PyObject* logger = nullptr; + if (!PyArg_ParseTuple(args, "O", &logger)) { Py_RETURN_FALSE; } + + if (logger == Py_None) { + python_verbose_logger = nullptr; + } else { + python_verbose_logger = logger; + } Py_RETURN_TRUE; END_HANDLE_TH_ERRORS; } @@ -271,7 +348,7 @@ static PyMethodDef _methods[] = { {"set_autograd_compiler", set_autograd_compiler, METH_VARARGS, nullptr}, {"clear_cache", clear_cache, METH_NOARGS, nullptr}, {"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr}, - {"set_verbose_logging", set_verbose_logging, METH_VARARGS, nullptr}, + {"set_verbose_logger", set_verbose_logger, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; static struct PyModuleDef _module = { @@ -353,6 +430,8 @@ CacheNode* _compiled_autograd_impl( calls.reserve( check_exec_info ? 
graph_task.exec_info_.size() : dependencies.size() + 1); + int i = 0; + std::optional vlogger = VerboseLogger::maybe_create(); while (!worklist.empty()) { std::shared_ptr fn = std::move(worklist.back()); worklist.pop_back(); @@ -367,10 +446,17 @@ CacheNode* _compiled_autograd_impl( node_args.collect(call.node->next_edges()); } CacheKey key = node_args.key(); - if (is_verbose_logging_enabled && - cache->lookup(key, /*create=*/false) == nullptr) { - vcout() << "Creating cache entry for " << fn->name() - << ", with key of size " << key.key_size << std::endl; + if (vlogger.has_value()) { + std::unordered_set cached_keys; + for (const auto& [k, _] : cache->next) { + cached_keys.emplace(k); + } + vlogger->log_node_check( + *fn, + compiler_call.all_size_inputs.size(), + std::move(cached_keys), + key, + i); } cache = cache->lookup(key); } @@ -395,10 +481,11 @@ CacheNode* _compiled_autograd_impl( worklist.emplace_back(edge.function); } } + i++; } // TODO(jansel): some dynamic sizes seem to be ints not symints - if (!cache->check_dynamic_sizes(compiler_call)) { + if (!cache->check_dynamic_sizes(compiler_call, vlogger)) { // cache miss, need to capture FX graph ClosingTHPObjectPtr py_compiler( check(PyObject_CallNoArgs((the_autograd_compiler)))); @@ -454,7 +541,7 @@ CacheNode* _compiled_autograd_impl( inputs = THPVariable_UnpackList(pyinputs); } - if (is_verbose_logging_enabled) { + if (python_verbose_logger != nullptr) { std::string _node_name = call.node->name(); THPObjectPtr node_name(PyUnicode_FromString(_node_name.data())); TORCH_INTERNAL_ASSERT(node_name != nullptr); diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp index 238050f501223..1ada9415ea12e 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp @@ -96,8 +96,13 @@ bool unpack_tensors( const std::vector& arguments, const torch::jit::Stack& stack, const c10::Device& device, - std::vector& inputs) { + std::vector& inputs, + bool with_scalar = false) { for (size_t idx = 0; idx < stack.size(); idx++) { + if (!with_scalar && stack[idx].isScalar()) { + continue; + } + if (!unpack_ivalue(arguments[idx], stack[idx], device, inputs)) { return false; } @@ -106,6 +111,40 @@ bool unpack_tensors( return true; } +std::vector get_tensor_parameter_index( + const std::vector& arguments, + const torch::jit::Stack& stack) { + std::vector tensor_parameter_index; + for (size_t idx = 0; idx < stack.size(); idx++) { + if (stack[idx].isScalar() || stack[idx].isTensor()) { + // scalar and tensor + tensor_parameter_index.push_back(idx); + } else if (stack[idx].isTensorList()) { + // tensor list + std::fill_n( + std::back_inserter(tensor_parameter_index), + stack[idx].toListRef().size(), + idx); + } else if (stack[idx].isOptionalTensorList()) { + // optional tensor list: std::vector> + for (const auto& item : stack[idx].toListRef()) { + if (item.toOptional().has_value()) { + tensor_parameter_index.push_back(idx); + } + } + } else if ( + *arguments[idx].real_type() == + *c10::getTypePtr>()) { + // optional tensor + if (stack[idx].toOptional().has_value()) { + tensor_parameter_index.push_back(idx); + } + } + } + + return tensor_parameter_index; +} + } // namespace AOTIPythonKernelHolder::AOTIPythonKernelHolder( @@ -149,14 +188,19 @@ bool AOTIPythonKernelHolder::cache_lookup( "Not implemented for operations that return a non-Tensor value."); std::vector inputs; - auto res = unpack_tensors(op.schema().arguments(), *stack, device_, inputs); + 
auto res = + unpack_tensors(op.schema().arguments(), *stack, device_, inputs, true); TORCH_CHECK_NOT_IMPLEMENTED( res && inputs.size() > 0, "Not implemented for operations that contain a parameter which is ", "not one of the following types: at::Tensor, at::TensorList, ", "std::optional, std::vector>."); - auto inputs_metadata = get_inputs_metadata(inputs); + auto tensor_parameter_index = + get_tensor_parameter_index(op.schema().arguments(), *stack); + TORCH_INTERNAL_ASSERT(tensor_parameter_index.size() == inputs.size()); + auto inputs_metadata = get_inputs_metadata( + inputs, op.schema().arguments(), tensor_parameter_index); auto aoti_kernel_state = aoti_kernel_cache_.find(inputs_metadata); if (aoti_kernel_state == aoti_kernel_cache_.end()) { return false; @@ -197,18 +241,49 @@ void AOTIPythonKernelHolder::cache_hit( } AOTIKernelMetadata AOTIPythonKernelHolder::get_inputs_metadata( - const std::vector& inputs) { + const std::vector& inputs, + const std::vector& inputs_argument, + const std::vector& inputs_argument_index) { AOTIKernelMetadata inputs_metadata; - for (const auto& input : inputs) { + for (size_t idx = 0; idx < inputs.size(); ++idx) { + auto input = inputs[idx]; + auto input_info = inputs_argument[inputs_argument_index[idx]]; + auto device = input.device(); if (device.is_cpu()) { // If the device is CPU, set the device index to -1. device = c10::Device(device.type(), -1); } + c10::Scalar scalar_value((double)1.0); + auto tensor_type = input.scalar_type(); + + bool is_scalar = input_info.type()->isSubtypeOf(*c10::NumberType::get()); + if (is_scalar) { + if (c10::isFloatingType(input.scalar_type())) { + auto scalar_numeric_value = input.item().toDouble(); + tensor_type = c10::ScalarType::Double; + scalar_value = c10::Scalar(scalar_numeric_value); + } else if (c10::isIntegralType(input.scalar_type(), false)) { + auto scalar_numeric_value = input.item().toUInt64(); + tensor_type = c10::ScalarType::UInt64; + scalar_value = c10::Scalar(scalar_numeric_value); + } else if (input.scalar_type() == c10::ScalarType::Bool) { + auto scalar_numeric_value = input.item().toBool(); + tensor_type = c10::ScalarType::Bool; + scalar_value = c10::Scalar(scalar_numeric_value); + } else { + TORCH_CHECK( + false, + "Unsupported scalar tensor type: ", + c10::toString(input.scalar_type())); + } + } + inputs_metadata.emplace_back( - false, // is symbloic - input.scalar_type(), + false, + tensor_type, + c10::IValue(scalar_value), device, input.sizes().vec(), input.strides().vec()); @@ -269,6 +344,7 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { reinterpret_cast(data_type_obj.ptr())->scalar_type; auto sizes = metadata["sizes"].cast>(); auto strides = metadata["strides"].cast>(); + bool is_scalar = metadata.contains("scalar_value"); std::vector> sym_optional_sizes; std::vector> sym_optional_strides; @@ -279,10 +355,34 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { sym_optional_strides.push_back(std::optional(stride)); } - // Now you can use these variables in your code + // If an input parameter is a scalar, its detailed value is cached. + // This is done to ensure correctness during subsequent checks. 
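An aside before the scalar-handling code that follows: the comment above is the crux of this kernel_holder change. When a cached eager-mode AOTI kernel is looked up for an op that takes a Python scalar, the scalar's concrete value appears to need to participate in the lookup so that a call differing only in that scalar does not reuse the wrong compiled artifact. A toy illustration of such a cache key is below; the names and structure are invented for the example and do not correspond to the C++ types in this file.

# Toy cache-key sketch; invented names, illustration only.
import torch

def toy_cache_key(t: torch.Tensor, scalar=None):
    key = (t.dtype, t.device.type, tuple(t.shape), tuple(t.stride()))
    if scalar is not None:
        # The scalar's concrete value is part of the key, mirroring the
        # "scalar_value" metadata cached above.
        key += (type(scalar).__name__, scalar)
    return key

x = torch.randn(4, 4)
assert toy_cache_key(x, 2.0) != toy_cache_key(x, 3.0)  # must not share an entry
assert toy_cache_key(x, 2.0) == toy_cache_key(x, 2.0)  # identical call can reuse it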
+ c10::Scalar scalar_value((double)1.0); + if (is_scalar) { + if (c10::isFloatingType(data_type)) { + auto scalar_numeric_value = metadata["scalar_value"].cast(); + data_type = c10::ScalarType::Double; + scalar_value = c10::Scalar(scalar_numeric_value); + } else if (c10::isIntegralType(data_type, false)) { + auto scalar_numeric_value = metadata["scalar_value"].cast(); + data_type = c10::ScalarType::UInt64; + scalar_value = c10::Scalar(scalar_numeric_value); + } else if (data_type == c10::ScalarType::Bool) { + auto scalar_numeric_value = metadata["scalar_value"].cast(); + data_type = c10::ScalarType::Bool; + scalar_value = c10::Scalar(scalar_numeric_value); + } else { + TORCH_CHECK( + false, + "Unsupported scalar tensor type: ", + c10::toString(data_type)); + } + } + tensor_metadata_list.emplace_back( is_dynamic, data_type, + c10::IValue(scalar_value), c10::Device(c10::Device(device_type).type(), device_index), sizes, strides); diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.h b/torch/csrc/inductor/aoti_eager/kernel_holder.h index 9cbcc217d7c30..b67e4e7d4464e 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.h +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.h @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -82,7 +83,10 @@ class AOTIPythonKernelHolder : public c10::OperatorKernel { void init_aoti_kernel_cache(); // Abstract the meta information of each tensor for the given operation. The // meta infomation will be used for cache lookup as the key. - AOTIKernelMetadata get_inputs_metadata(const std::vector&); + AOTIKernelMetadata get_inputs_metadata( + const std::vector& inputs, + const std::vector& inputs_argument, + const std::vector& inputs_argument_index); // Load the AOTIModelContainerRunner object from the given file path. 
std::shared_ptr load_aoti_model_runner( const std::string&); diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp index e89c59142328f..a49fab21d671e 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp @@ -1,5 +1,6 @@ #if !defined(C10_MOBILE) && !defined(ANDROID) #include +#include namespace torch::inductor { @@ -17,6 +18,24 @@ TensorMetadata::TensorMetadata( std::vector strides) : is_symbolic_(is_symbolic), dtype_(dtype), + scalar_value_((float)1.0), + device_(device), + sizes_(sizes), + strides_(strides) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !is_symbolic_, "Not support symbolic shape now"); +} + +TensorMetadata::TensorMetadata( + bool is_symbolic, + c10::ScalarType dtype, + c10::IValue scalar_value, + c10::Device device, + std::vector sizes, + std::vector strides) + : is_symbolic_(is_symbolic), + dtype_(dtype), + scalar_value_(scalar_value), device_(device), sizes_(sizes), strides_(strides) { @@ -29,15 +48,39 @@ bool TensorMetadata::operator==(const TensorMetadata& other) const { !is_symbolic_, "Not support symbolic shape now"); return this->is_symbolic_ == other.is_symbolic_ && this->dtype_ == other.dtype_ && + this->scalar_value_ == other.scalar_value_ && this->device_.type() == other.device_.type() && this->sizes_ == other.sizes_ && this->strides_ == other.strides_; } +std::ostream& operator<<( + std::ostream& stream, + const TensorMetadata& tensor_metadata) { + stream << "is_symbolic_: " << tensor_metadata.is_symbolic_ << std::endl; + stream << "dtype_: " << tensor_metadata.dtype_ << std::endl; + stream << "scalar_value_: " << tensor_metadata.scalar_value_.type()->str() + << "(" << tensor_metadata.scalar_value_ << ")" << std::endl; + stream << "device_: " << tensor_metadata.device_ << std::endl; + stream << "sizes_: "; + for (const auto& size : tensor_metadata.sizes_) { + stream << size << " "; + } + stream << std::endl; + stream << "strides_: "; + for (const auto& stride : tensor_metadata.strides_) { + stream << stride << " "; + } + stream << std::endl; + return stream; +} + size_t TensorMetadataHash::operator()( const TensorMetadata& tensor_metadata) const { auto hash = std::hash()(tensor_metadata.is_symbolic_); hash = c10::hash_combine( hash, std::hash()(tensor_metadata.dtype_)); + hash = + c10::hash_combine(hash, c10::IValue::hash(tensor_metadata.scalar_value_)); hash = c10::hash_combine( hash, std::hash()(tensor_metadata.device_.type())); diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h index c7f8315d2707a..5c22e9b75f65b 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h @@ -33,6 +33,8 @@ struct TensorMetadata { bool is_symbolic_; // Dtype of a tensor(For scalar, we will wrap it as a scalar tensor) c10::ScalarType dtype_; + // Concrete scalar value. Serve for operations w/ scalar parameter + c10::IValue scalar_value_; // Device of a tensor. c10::Device device_; // Sizes of a tensor. 
Currently, we only support static shape and use int64_t @@ -49,6 +51,13 @@ struct TensorMetadata { c10::Device device, std::vector sizes, std::vector strides); + TensorMetadata( + bool is_symbolic, + c10::ScalarType dtype, + c10::IValue scalar_value, + c10::Device device, + std::vector sizes, + std::vector strides); bool operator==(const TensorMetadata& other) const; }; diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index b05c52c6a3876..6fa7df75c0566 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -72,6 +72,9 @@ extern "C" { struct AtenTensorOpaque; using AtenTensorHandle = AtenTensorOpaque*; +struct AtenGeneratorOpaque; +using AtenGeneratorHandle = AtenGeneratorOpaque*; + struct AOTIProxyExecutorOpaque; using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*; diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 8058618f97486..2c7f05dd84cd9 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -47,6 +47,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d(AtenTensorHandle self AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0); @@ -58,6 +60,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cummax(AtenTensorHandle self, in AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT 
AOTITorchError aoti_torch_cpu_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); @@ -94,15 +97,18 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Scalar(double self, AtenTens AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError 
aoti_torch_cpu_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self); @@ -113,6 +119,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort(AtenTensorHandle self, int6 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_upsample_bicubic2d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_h, double* scales_w, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_upsample_linear1d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_upsample_trilinear3d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_d, double* scales_h, double* scales_w, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index df0099d37bed3..1dceac240e40a 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -55,6 +55,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d(AtenTensorHandle sel AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d_backward(AtenTensorHandle grad_output, 
AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0); @@ -66,6 +68,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cummax(AtenTensorHandle self, i AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); @@ -101,15 +104,18 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Scalar(double self, AtenTen AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, 
AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self); @@ -120,6 +126,7 @@ 
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort(AtenTensorHandle self, int AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_upsample_bicubic2d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_h, double* scales_w, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_upsample_linear1d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_upsample_trilinear3d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_d, double* scales_h, double* scales_w, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/utils.h b/torch/csrc/inductor/aoti_torch/utils.h index 0964479caabd8..44ca34b1c6e8d 100644 --- a/torch/csrc/inductor/aoti_torch/utils.h +++ b/torch/csrc/inductor/aoti_torch/utils.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -32,6 +33,16 @@ inline AtenTensorHandle tensor_pointer_to_tensor_handle(at::Tensor* tensor) { return reinterpret_cast(tensor); } +inline at::Generator* generator_handle_to_generator_pointer( + AtenGeneratorHandle handle) { + return reinterpret_cast(handle); +} + +inline AtenGeneratorHandle generator_pointer_to_generator_handle( + at::Generator* generator) { + return reinterpret_cast(generator); +} + inline AtenTensorHandle new_tensor_handle(at::Tensor&& tensor) { at::Tensor* new_tensor = new at::Tensor(std::move(tensor)); return tensor_pointer_to_tensor_handle(new_tensor); @@ -61,6 +72,13 @@ inline std::optional pointer_to_optional( : c10::nullopt; } +template <> +inline std::optional pointer_to_optional( + AtenGeneratorHandle* ptr) { + return ptr ? 
c10::make_optional(*generator_handle_to_generator_pointer(*ptr)) + : c10::nullopt; +} + inline std::optional pointer_to_optional_device( int32_t* device_type, int32_t device_index) { diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp index 1b9932ed34d4d..45b99eb8e47aa 100644 --- a/torch/csrc/jit/api/module.cpp +++ b/torch/csrc/jit/api/module.cpp @@ -323,7 +323,7 @@ Module Module::deepcopy(std::optional device) const { Module Module::clone(bool inplace) const { std::unordered_map type_remap; - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; const std::unordered_set ignored_methods; const std::unordered_set ignored_attributes; return clone_impl( @@ -335,7 +335,7 @@ Module Module::clone( const std::unordered_set& ignored_methods, const std::unordered_set& ignored_attributes) const { std::unordered_map type_remap; - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; return clone_impl( type_remap, inplace, memo, ignored_methods, ignored_attributes); } @@ -343,7 +343,7 @@ Module Module::clone( Module Module::clone_impl( std::unordered_map& type_remap, bool inplace, - IValue::HashAliasedIValueMap memo, + IValue::HashIdentityIValueMap memo, const std::unordered_set& ignored_methods, const std::unordered_set& ignored_attributes) const { // Create a new _ivalue in the same compilation unit. diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index 0787210a4aefe..e779542e315fa 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -301,7 +301,7 @@ struct TORCH_API Module : public Object { Module clone_impl( std::unordered_map& type_remap, bool inplace, - IValue::HashAliasedIValueMap memo, + IValue::HashIdentityIValueMap memo, const std::unordered_set& ignored_methods, const std::unordered_set& ignored_attributes) const; diff --git a/torch/csrc/jit/passes/onnx/constant_map.cpp b/torch/csrc/jit/passes/onnx/constant_map.cpp index 8fd1bed0b7a1b..e249d0a83a64f 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.cpp +++ b/torch/csrc/jit/passes/onnx/constant_map.cpp @@ -48,6 +48,14 @@ c10::optional ConstantValueMap::GetAllGraphInputsStatic() { return ConstantValueMap::getInstance().allGraphInputsStatic; } +void ConstantValueMap::SetAllGraphInputsReliableComputed(bool computed) { + ConstantValueMap::getInstance().allGraphInputsReliableComputed = computed; +} + +bool ConstantValueMap::GetAllGraphInputsReliableComputed() { + return ConstantValueMap::getInstance().allGraphInputsReliableComputed; +} + void ConstantValueMap::SetShape( const std::string& tensorName, const c10::SymbolicShape& shapeValue) { @@ -277,6 +285,7 @@ void ConstantValueMap::ClearMaps() { ConstantValueMap::getInstance().symbolDimMap.clear(); ConstantValueMap::getInstance().dimSymbolMap.clear(); ConstantValueMap::getInstance().allGraphInputsStatic = c10::nullopt; + ConstantValueMap::getInstance().allGraphInputsReliableComputed = false; } // For debug only. 
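Usage sketch (illustrative only, not part of the patch): the AtenGeneratorHandle entries added to the C shim above (the bernoulli_/exponential/uniform fallbacks and the *_generator variants of rand/randn/randint) give AOTInductor-generated code a C-ABI path for RNG ops that carry an optional at::Generator. A minimal Python model exercising that op surface is sketched below; whether a particular op actually reaches one of these fallbacks depends on the Inductor lowering in use, and the aot_compile entry point is assumed here.

import torch

class RngOps(torch.nn.Module):
    def forward(self, x):
        # in-place RNG ops taking an optional Generator; these correspond to the
        # bernoulli_/uniform_ style fallbacks declared in the c_shim headers above
        mask = torch.empty_like(x).bernoulli_(0.5)
        noise = torch.empty_like(x).uniform_(-1.0, 1.0)
        return x * mask + noise

# AOTInductor entry point assumed for this sketch; the exact API may differ by version
so_path = torch._export.aot_compile(RngOps(), (torch.randn(4, 4),))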
diff --git a/torch/csrc/jit/passes/onnx/constant_map.h b/torch/csrc/jit/passes/onnx/constant_map.h index 303d373eea56f..4261e45cc56c2 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.h +++ b/torch/csrc/jit/passes/onnx/constant_map.h @@ -29,6 +29,9 @@ class ConstantValueMap { static void SetAllGraphInputsStatic(bool all_static); static c10::optional GetAllGraphInputsStatic(); + static void SetAllGraphInputsReliableComputed(bool computed); + static bool GetAllGraphInputsReliableComputed(); + static void SetShape( const std::string& tensorName, const c10::SymbolicShape& shapeValue); @@ -108,6 +111,8 @@ class ConstantValueMap { DimSymbolMap dimSymbolMap; // Stores if all graph-level inputs have static shape c10::optional allGraphInputsStatic; + // True if reliable has been computed for all graph inputs + bool allGraphInputsReliableComputed; }; } // namespace jit diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index dd79754f4c016..65d065adeb2b5 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -2035,11 +2035,17 @@ void UpdateReliable(Node* n) { } } +// Traverse the graph inputs and compute reliability (e.g., are shapes static). +// Since the inputs do not change during export, we save computation time by +// marking it as computed and subsequently skipping. void SetGraphInputTypeReliable(const Graph* g) { - for (auto graph_input : g->inputs()) { - if (!ConstantValueMap::HasTypeReliable(graph_input->debugName())) { - ConstantValueMap::SetTypeReliable(graph_input->debugName(), true); + if (!ConstantValueMap::GetAllGraphInputsReliableComputed()) { + for (auto graph_input : g->inputs()) { + if (!ConstantValueMap::HasTypeReliable(graph_input->debugName())) { + ConstantValueMap::SetTypeReliable(graph_input->debugName(), true); + } } + ConstantValueMap::SetAllGraphInputsReliableComputed(true); } } diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index e5df64f1929c7..de1cff1ba9d19 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -92,7 +92,7 @@ class ModuleCloneHelper { const ModuleQConfigMap& module_qconfig_map, bool inplace = false) { std::unordered_map type_remap; - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; return clone_impl( module, module_qconfig_map, type_remap, inplace, std::move(memo)); } @@ -103,7 +103,7 @@ class ModuleCloneHelper { const ModuleQConfigMap& module_qconfig_map, std::unordered_map& type_remap, bool inplace, - IValue::HashAliasedIValueMap memo) { + IValue::HashIdentityIValueMap memo) { auto qconfig = module_qconfig_map.at(module._ivalue()); auto type = module.type(); // Create a new _ivalue in the same compilation unit. 
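Illustrative note (sketch, not part of the patch): the allGraphInputsReliableComputed flag memoizes SetGraphInputTypeReliable, which ONNX shape/type inference otherwise repeats over the graph inputs while an export is running; ClearMaps resets the flag so every export starts fresh. User code does not change; a plain export such as the one below is the path that benefits.

import io
import torch

class Small(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).sum(dim=1)

# ONNX shape/type inference runs during export; graph-input reliability is now
# marked once per export rather than being recomputed as nodes are processed.
torch.onnx.export(Small(), (torch.randn(4, 8),), io.BytesIO(), opset_version=17)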
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 971b6c76ca47e..c46762a88615b 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -668,13 +668,13 @@ static constexpr std::array magic_method_names = { }; struct DeepCopyMemoTable { - std::shared_ptr map; + std::shared_ptr map; }; IValue pyIValueDeepcopy(const IValue& ivalue, const py::dict& memo) { if (!memo.contains(py::str("__torch_script_memo_table"))) { memo["__torch_script_memo_table"] = - DeepCopyMemoTable{std::make_shared()}; + DeepCopyMemoTable{std::make_shared()}; } auto& ivalue_memo = *py::cast(memo["__torch_script_memo_table"]).map; diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index d100d8090c074..9e8a995ec977d 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -28,8 +28,27 @@ #include #include +#ifdef USE_DISTRIBUTED +#include +#endif // USE_DISTRIBUTED + using namespace at; +// Collective property attributes +// https://github.com/pytorch/pytorch/issues/124674 +#ifdef USE_DISTRIBUTED +constexpr auto kETCommsName = "collective_name"; +constexpr auto kETInMsgNelems = "in_msg_nelems"; +constexpr auto kETOutMsgNelems = "out_msg_nelems"; +constexpr auto kETInSplit = "in_split_size"; +constexpr auto kETOutSplit = "out_split_size"; +constexpr auto kETGlobalRankStart = "global_rank_start"; +constexpr auto kETGlobalRankStride = "global_rank_stride"; +constexpr auto kETGroupSize = "pg_size"; +constexpr auto kETProcessGroupName = "pg_name"; +constexpr auto kETProcessGroupDesc = "pg_desc"; +#endif // USE_DISTRIBUTED + namespace torch { namespace profiler { namespace impl { @@ -258,6 +277,19 @@ static std::ofstream openOutputFile(const std::string& name) { return stream; } +static inline std::string getAttrJson( + const std::string& name, + const std::string& type, + const std::string& value) { + // note name and type are not quoted but value should be if it is a string. 
+ return fmt::format( + R"JSON( + {{"name": "{}", "type": "{}", "value": {}}})JSON", + name, + type, + value); +} + static void writeJsonNode( std::ofstream& out, const std::string& name, @@ -277,14 +309,15 @@ static void writeJsonNode( const std::string& output_types = "[]", const std::string& operator_schema = "", const std::string& kernel_backend = "", - const std::string& kernel_file = "") { + const std::string& kernel_file = "", + const std::string& additiona_attrs = "") { out << fmt::format( R"JSON( {{ "id": {}, "name": "{}", "ctrl_deps": {}, "inputs": {{"values": {}, "shapes": {}, "types": {}}}, "outputs": {{"values": {}, "shapes": {}, "types": {}}}, - "attrs": [{{"name": "rf_id", "type": "uint64", "value": {}}},{{"name": "fw_parent", "type": "uint64", "value": {}}},{{"name": "seq_id", "type": "int64", "value": {}}},{{"name": "scope", "type": "uint64", "value": {}}},{{"name": "tid", "type": "uint64", "value": {}}},{{"name": "fw_tid", "type": "uint64", "value": {}}},{{"name": "op_schema", "type": "string", "value": "{}"}},{{"name": "kernel_backend", "type": "string", "value": "{}"}},{{"name": "kernel_file", "type": "string", "value": "{}"}}] + "attrs": [{{"name": "rf_id", "type": "uint64", "value": {}}},{{"name": "fw_parent", "type": "uint64", "value": {}}},{{"name": "seq_id", "type": "int64", "value": {}}},{{"name": "scope", "type": "uint64", "value": {}}},{{"name": "tid", "type": "uint64", "value": {}}},{{"name": "fw_tid", "type": "uint64", "value": {}}},{{"name": "op_schema", "type": "string", "value": "{}"}},{{"name": "kernel_backend", "type": "string", "value": "{}"}},{{"name": "kernel_file", "type": "string", "value": "{}"}}{}] }})JSON", id, name, @@ -303,7 +336,8 @@ static void writeJsonNode( fw_tid, operator_schema, kernel_backend, - kernel_file); + kernel_file, + additiona_attrs); } inline std::string timeString(const std::time_t timepoint) { @@ -332,7 +366,7 @@ static bool initExecutionTraceStart(ExecutionTraceObserver& ob) { ob.out << fmt::format( R"JSON({{ - "schema": "1.0.4-chakra.0.0.4", "pid": {}, "time": "{}", "start_ts": {}, + "schema": "1.1.0-chakra.0.0.4", "pid": {}, "time": "{}", "start_ts": {}, "nodes": [)JSON", ob.pid, ob.record_time, @@ -486,6 +520,56 @@ inline void handleKernelBackendInfo( } } +// Additional attributes for commounication collectives +inline std::string getCommsNodeAttrs(const RecordFunction& fn) { + std::vector attrs; + +#ifdef USE_DISTRIBUTED + // We rely on paramcommsdebug object that is available in thread local info + auto debugInfo = dynamic_cast( + c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO)); + if (debugInfo == nullptr) { + LOG(WARNING) << "ParamCommsDebugInfo not available for function: " + << fn.name(); + return ", " + getAttrJson("debug", "string", "\"missing comms info\""); + } + + // get NcclMeta from record function, this used ParamCommsDebugInfo above + auto meta = saveNcclMeta(fn, false /*truncate*/); + + auto addAttr = + [&](const char* commsMetaName, const char* etMetaName, const char* type) { + auto it = meta.find(commsMetaName); + if (it != meta.end()) { + attrs.push_back(getAttrJson(etMetaName, type, it->second)); + } + }; + + addAttr(kCommsName, kETCommsName, "string"); + addAttr(kDtype, kDtype, "string"); + + addAttr(kInMsgNelems, kETInMsgNelems, "uint64"); + addAttr(kOutMsgNelems, kETOutMsgNelems, "uint64"); + + // following two metadata are lists. 
+ addAttr(kInSplit, kETInSplit, "string"); + addAttr(kOutSplit, kETOutSplit, "string"); + + addAttr(kGlobalRankStart, kETGlobalRankStart, "uint64"); + addAttr(kGlobalRankStride, kETGlobalRankStride, "uint64"); + + // pg_name is a string. + addAttr(kProcessGroupName, kETProcessGroupName, "string"); + addAttr(kProcessGroupDesc, kETProcessGroupDesc, "string"); + + addAttr(kGroupSize, kETGroupSize, "uint64"); + +#endif // USE_DISTRIBUTED + + // XXX consider using as string stream? + return attrs.size() == 0 ? "" : fmt::format(", {}", fmt::join(attrs, ", ")); +} + static void recordOperatorStart( ExecutionTraceObserver& ob, FunctionCallContext& fc, @@ -645,6 +729,9 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) { op_schema_str = json_str_escape(c10::toString(op_schema.value())); } + const std::string additiona_attrs = + fn.isNcclMeta() ? getCommsNodeAttrs(fn) : ""; + writeJsonNode( ob->out, fc.name, @@ -664,7 +751,8 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) { vectorToString(output_types), op_schema_str, fc.kernel_backend, - fc.kernel_file); + fc.kernel_file, + additiona_attrs); ob->out << ","; } catch (const std::exception& e) { LOG(WARNING) << "Exception in execution trace observer: [" << fc.name diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index f301596fca813..21e16a7e7eaee 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -334,25 +334,22 @@ std::vector inputTypes(const at::RecordFunction& fn) { // ---------------------------------------------------------------------------- // -- NCCL Metadata ----------------------------------------------------------- // ---------------------------------------------------------------------------- -#ifdef USE_DISTRIBUTED -static constexpr auto kCommsName = "Collective name"; -static constexpr auto kDtype = "dtype"; -static constexpr auto kInMsgNelems = "In msg nelems"; -static constexpr auto kOutMsgNelems = "Out msg nelems"; -static constexpr auto kInSplit = "In split size"; -static constexpr auto kOutSplit = "Out split size"; -static constexpr auto kGlobalRankStart = "Global rank start"; -static constexpr auto kGlobalRankStride = "Global rank stride"; -static constexpr auto kGroupSize = "Group size"; -static constexpr auto kProcessGroupName = "Process Group Name"; -static constexpr auto kProcessGroupDesc = "Process Group Description"; -static constexpr auto kGroupRanks = "Process Group Ranks"; static constexpr int32_t kTruncatLength = 30; -#endif // USE_DISTRIBUTED + +template +inline std::string format_list(ListLikeType list, bool truncate) { + if (truncate && list.size() > kTruncatLength) { + return fmt::format( + "\"[{}, ...]\"", + fmt::join(list.begin(), list.begin() + kTruncatLength, ", ")); + } + return fmt::format("\"[{}]\"", fmt::join(list.begin(), list.end(), ", ")); +} std::unordered_map saveNcclMeta( - const at::RecordFunction& fn) { + const at::RecordFunction& fn, + bool truncate) { std::unordered_map map; #ifdef USE_DISTRIBUTED auto debugInfo = dynamic_cast( @@ -369,34 +366,13 @@ std::unordered_map saveNcclMeta( kDtype, fmt::format("\"{}\"", c10::toString(debugInfo->getDType()))); map.emplace(kInMsgNelems, std::to_string(debugInfo->getInMessageNelems())); map.emplace(kOutMsgNelems, std::to_string(debugInfo->getOutMessageNelems())); + auto& inSplitSizes = debugInfo->getInputSplitSizes(); - if (!inSplitSizes.empty() && inSplitSizes.size() <= kTruncatLength) { - map.emplace( - kInSplit, fmt::format("\"[{}]\"", 
fmt::join(inSplitSizes, ", "))); - } else if (inSplitSizes.size() > kTruncatLength) { - map.emplace( - kInSplit, - fmt::format( - "\"[{}, ...]\"", - fmt::join( - inSplitSizes.begin(), - inSplitSizes.begin() + kTruncatLength, - ", "))); - } + map.emplace(kInSplit, format_list(inSplitSizes, truncate)); + auto& outSplitSizes = debugInfo->getOutputSplitSizes(); - if (!outSplitSizes.empty() && outSplitSizes.size() <= kTruncatLength) { - map.emplace( - kOutSplit, fmt::format("\"[{}]\"", fmt::join(outSplitSizes, ", "))); - } else if (outSplitSizes.size() > kTruncatLength) { - map.emplace( - kOutSplit, - fmt::format( - "\"[{}, ...]\"", - fmt::join( - outSplitSizes.begin(), - outSplitSizes.begin() + kTruncatLength, - ", "))); - } + map.emplace(kOutSplit, format_list(outSplitSizes, truncate)); + auto globalRankStart = debugInfo->getGlobalRankStart(); if (globalRankStart >= 0) { map.emplace(kGlobalRankStart, std::to_string(globalRankStart)); @@ -415,20 +391,7 @@ std::unordered_map saveNcclMeta( map.emplace(kProcessGroupDesc, fmt::format("\"{}\"", group_desc)); } auto& groupRanks = debugInfo->getGroupRanks(); - if (!groupRanks.empty() && groupRanks.size() <= kTruncatLength) { - map.emplace( - kGroupRanks, fmt::format("\"[{}]\"", fmt::join(groupRanks, ", "))); - } else if (groupRanks.size() > kTruncatLength) { - map.emplace( - kGroupRanks, - fmt::format( - "\"[{}, ..., {}]\"", - fmt::join( - groupRanks.begin(), - groupRanks.begin() + kTruncatLength - 1, - ", "), - groupRanks.back())); - } + map.emplace(kGroupRanks, format_list(groupRanks, truncate)); #endif // USE_DISTRIBUTED return map; } diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index c8216c93f41c5..3c995b49e602b 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -100,7 +100,7 @@ TORCH_API std::vector inputTypes(const at::RecordFunction& fn); std::unordered_map TORCH_API saveExtraArgs(const at::RecordFunction& fn); std::unordered_map TORCH_API -saveNcclMeta(const at::RecordFunction& fn); +saveNcclMeta(const at::RecordFunction& fn, bool truncate = true); uint64_t TORCH_API computeFlops( const std::string& op_name, @@ -157,6 +157,21 @@ struct HashCombine { } }; +#ifdef USE_DISTRIBUTED +constexpr auto kCommsName = "Collective name"; +constexpr auto kDtype = "dtype"; +constexpr auto kInMsgNelems = "In msg nelems"; +constexpr auto kOutMsgNelems = "Out msg nelems"; +constexpr auto kInSplit = "In split size"; +constexpr auto kOutSplit = "Out split size"; +constexpr auto kGlobalRankStart = "Global rank start"; +constexpr auto kGlobalRankStride = "Global rank stride"; +constexpr auto kGroupSize = "Group size"; +constexpr auto kProcessGroupName = "Process Group Name"; +constexpr auto kProcessGroupDesc = "Process Group Description"; +constexpr auto kGroupRanks = "Process Group Ranks"; +#endif // USE_DISTRIBUTED + } // namespace impl } // namespace profiler } // namespace torch diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index eb7a690fa9589..3e7dce97b54c9 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -54,6 +54,8 @@ def is_available() -> bool: set_debug_level, set_debug_level_from_env, _make_nccl_premul_sum, + _ControlCollectives, + _StoreCollectives, ) class _DistributedPdb(pdb.Pdb): diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index f279703151591..b7264cb34d6dd 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ 
b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -125,9 +125,11 @@ def foreach_reduce( orig_dtype: torch.dtype, reduce_dtype: Optional[torch.dtype], device: torch.device, - all_reduce_group: Optional[dist.ProcessGroup], + all_reduce_group: Optional[dist.ProcessGroup], # not `None` iff HSDP all_reduce_stream: torch.cuda.Stream, -) -> torch.cuda.Event: + all_reduce_grads: bool, + partial_reduce_output: Optional[torch.Tensor], # only used for HSDP +) -> Tuple[torch.cuda.Event, Optional[torch.Tensor]]: """ ``unsharded_grads`` owns the references to the gradients computed by autograd, so clearing the list frees the gradients. @@ -163,36 +165,43 @@ def foreach_reduce( # computed in the default stream current_stream.wait_stream(reduce_scatter_stream) unsharded_grads.clear() - post_reduce_output = reduce_scatter_input.new_empty( - (reduce_scatter_output_numel,) - ) + reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,)) _div_if_needed(reduce_scatter_input, predivide_factor) dist.reduce_scatter_tensor( - output=post_reduce_output, + output=reduce_output, input=reduce_scatter_input, group=reduce_scatter_group, op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM, ) - view_out_stream = reduce_scatter_stream - if all_reduce_group is not None: - view_out_stream = all_reduce_stream - all_reduce_stream.wait_stream(reduce_scatter_stream) - with torch.cuda.stream(all_reduce_stream): - dist.all_reduce( - post_reduce_output, - group=all_reduce_group, - op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM, - ) - with torch.cuda.stream(view_out_stream): - _div_if_needed(post_reduce_output, postdivide_factor) - post_reduce_output = _to_dtype_if_needed(post_reduce_output, orig_dtype) - # - View out and accumulate + post_reduce_stream = reduce_scatter_stream + if all_reduce_group is not None: # HSDP + # Accumulations must run in the reduce-scatter stream + if not all_reduce_grads: + if partial_reduce_output is not None: + partial_reduce_output += reduce_output + else: + partial_reduce_output = reduce_output + return post_reduce_stream.record_event(), partial_reduce_output + if partial_reduce_output is not None: + reduce_output += partial_reduce_output + post_reduce_stream = all_reduce_stream + all_reduce_stream.wait_stream(reduce_scatter_stream) + with torch.cuda.stream(all_reduce_stream): + dist.all_reduce( + reduce_output, + group=all_reduce_group, + op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM, + ) + with torch.cuda.stream(post_reduce_stream): + _div_if_needed(reduce_output, postdivide_factor) + reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype) + # View out and accumulate sharded gradients flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] for padded_unsharded_size, fsdp_param in zip( padded_unsharded_sizes, fsdp_params ): new_sharded_grad = torch.as_strided( - post_reduce_output, + reduce_output, size=fsdp_param.sharded_size, stride=fsdp_param.contiguous_sharded_stride, storage_offset=flat_grad_offset, @@ -220,12 +229,12 @@ def foreach_reduce( fsdp_param.sharded_param.grad = new_sharded_dtensor_grad padded_sharded_numel = padded_unsharded_size.numel() // world_size flat_grad_offset += padded_sharded_numel - post_reduce_view_out_event = view_out_stream.record_event() + post_reduce_event = post_reduce_stream.record_event() # The RS output is allocated in the RS stream and used in the default # stream (for optimizer). 
To ensure its memory is not reused for later # RSs, we do not need extra synchronization since the sharded parameters # hold refs through the end of backward. - return post_reduce_view_out_event + return post_reduce_event, None def foreach_reduce_scatter_copy_in( diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index 94b0249177697..1395e3487847f 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -117,20 +117,29 @@ def _from_local_no_grad( global_stride: Tuple[int, ...], ) -> DTensor: """ - This method is similar to ``DTensor.from_local()`` except it avoids some - CPU overhead by avoiding default args and not being differentiable. + This method is similar to ``DTensor.from_local()`` except that in eager mode + it avoids some CPU overhead by avoiding default args and not being differentiable. """ - return DTensor( - # Use the local tensor directly instead of constructing a new tensor - # variable, e.g. with `view_as()`, since this is not differentiable - local_tensor, - device_mesh, - placements, - shape=global_size, - dtype=local_tensor.dtype, - requires_grad=local_tensor.requires_grad, - stride=global_stride, - ) + if not torch._dynamo.compiled_autograd.compiled_autograd_enabled: + return DTensor( + # Use the local tensor directly instead of constructing a new tensor + # variable, e.g. with `view_as()`, since this is not differentiable + local_tensor, + device_mesh, + placements, + shape=global_size, + dtype=local_tensor.dtype, + requires_grad=local_tensor.requires_grad, + stride=global_stride, + ) + else: + return DTensor.from_local( + local_tensor, + device_mesh, + placements, + shape=global_size, + stride=global_stride, + ) def _to_dtype_if_needed( diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 9e9813102db3a..ea2307222ce14 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -138,11 +138,15 @@ def __init__( # Holds the reduce-scatter/all-reduce view-out CUDA event that marks the end of # the group's post-backward (e.g. 
reduce-scatter, all-reduce and div), which # should be waited on at the end of backward - self._post_reduce_view_out_event: Optional[torch.cuda.Event] = None + self._post_reduce_event: Optional[torch.cuda.Event] = None # Holds the reshard-after-forward CUDA event when resharding to a # different world size, which should be waited on in the next unshard self._reshard_after_forward_event: Optional[torch.cuda.Event] = None + # Only for HSDP, if accumulating gradients without all-reduce, save the + # partial reduce output (only reduce-scattered but not all-reduced) + self._partial_reduce_output: Optional[torch.Tensor] = None + # Initialization # def _init_mp_dtypes(self) -> None: for fsdp_param in self.fsdp_params: @@ -273,6 +277,8 @@ def _record_post_forward(self) -> None: self._post_forward_indices.append(post_forward_index) def pre_backward(self, *unused: Any): + if self._training_state == TrainingState.PRE_BACKWARD: + return with torch.profiler.record_function("FSDP::pre_backward"): self._training_state = TrainingState.PRE_BACKWARD self.unshard() # no-op if prefetched @@ -311,7 +317,7 @@ def post_backward(self, *unused: Any): if len(fsdp_params_with_grad) == 0: return with torch.profiler.record_function("FSDP::post_backward_reduce"): - self._post_reduce_view_out_event = foreach_reduce( + self._post_reduce_event, self._partial_reduce_output = foreach_reduce( fsdp_params_with_grad, unsharded_grads, self._reduce_scatter_process_group, @@ -319,16 +325,16 @@ def post_backward(self, *unused: Any): self._orig_dtype, self._reduce_dtype, self.device, - self._all_reduce_process_group - if self._is_hsdp and self.all_reduce_grads - else None, + self._all_reduce_process_group if self._is_hsdp else None, self.comm_ctx.all_reduce_stream, + self.all_reduce_grads, + self._partial_reduce_output, ) def finalize_backward(self): - if self._post_reduce_view_out_event is not None: - torch.cuda.current_stream().wait_event(self._post_reduce_view_out_event) - self._post_reduce_view_out_event = None + if self._post_reduce_event is not None: + torch.cuda.current_stream().wait_event(self._post_reduce_event) + self._post_reduce_event = None for fsdp_param in self.fsdp_params: if fsdp_param.grad_offload_event is not None: fsdp_param.grad_offload_event.synchronize() diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index bab24c283063d..15a00e83f0863 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn from torch.autograd import Variable -from torch.autograd.graph import register_multi_grad_hook from torch.distributed._composable_state import ( _get_module_state, _insert_module_state, @@ -201,11 +200,12 @@ def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any: ) return output - def _pre_backward(self, *unused: Any) -> None: + def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor: self._training_state = TrainingState.PRE_BACKWARD self._register_root_post_backward_final_callback() if self._fsdp_param_group: - self._fsdp_param_group.pre_backward(*unused) + self._fsdp_param_group.pre_backward() + return grad def _root_post_backward_final_callback(self) -> None: with torch.profiler.record_function("FSDP::root_post_backward_callback"): @@ -235,7 +235,8 @@ def _register_pre_backward_hook(self, output: Any) -> Any: t for t in flat_outputs if (torch.is_tensor(t) and t.requires_grad) ) if tensors: - 
register_multi_grad_hook(tensors, self._pre_backward, mode="any") + for tensor in tensors: + tensor.register_hook(self._pre_backward) return output def _register_root_post_backward_final_callback(self): diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index a5204701731c4..981b82987462e 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -208,7 +208,7 @@ def set_is_last_backward(self, is_last_backward: bool) -> None: state._state_ctx.is_last_backward = is_last_backward def set_requires_gradient_sync( - self, requires_gradient_sync: bool, recurse: bool = True + self, requires_gradient_sync: bool, *, recurse: bool = True ) -> None: """ Sets if the module should sync gradients. This can be used to implement @@ -231,16 +231,13 @@ def set_requires_gradient_sync( fsdp_param_group.all_reduce_grads = requires_gradient_sync def set_requires_all_reduce( - self, requires_all_reduce: bool, recurse: bool = True + self, requires_all_reduce: bool, *, recurse: bool = True ) -> None: """ Sets if the module should all-reduce gradients. This can be used to implement gradient accumulation with only reduce-scatter but not all-reduce for HSDP. """ - # TODO: post_reduce_output += fsdp_param.sharded_param.grad - # after reduce-scatter and before all-reduce - raise NotImplementedError("requires_all_reduce is not yet supported in HSDP") self_module = cast(nn.Module, self) modules = list(self_module.modules()) if recurse else [self_module] for module in modules: @@ -250,7 +247,7 @@ def set_requires_all_reduce( fsdp_param_group.all_reduce_grads = requires_all_reduce def set_reshard_after_backward( - self, reshard_after_backward: bool, recurse: bool = True + self, reshard_after_backward: bool, *, recurse: bool = True ) -> None: """ Sets if the module should reshard parameters after backward. 
This can diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index b1250eddf0377..8d598713cf50d 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -894,6 +894,12 @@ def _all_to_all_single_meta( return input.new_empty(out_size) +def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out): + shape = list(input.size()) + shape[0] *= group_size + return input.new_empty(shape) + + def _all_gather_into_tensor_native_meta(input, group_size, group_name): shape = list(input.size()) shape[0] *= group_size @@ -932,6 +938,9 @@ def _reduce_scatter_tensor_coalesced_native_meta( lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") + lib_impl.impl( + "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" + ) lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") lib_impl.impl( "all_gather_into_tensor_coalesced", diff --git a/torch/distributed/_spmd/batch_dim_utils.py b/torch/distributed/_spmd/batch_dim_utils.py index afb9dd2e7d3b4..6d36b2e381187 100644 --- a/torch/distributed/_spmd/batch_dim_utils.py +++ b/torch/distributed/_spmd/batch_dim_utils.py @@ -9,11 +9,7 @@ from torch import Tensor from torch.distributed._tensor import DeviceMesh, Replicate, Shard -from torch.distributed._tensor.ops.view_ops import ( - DimSpec, - InputDim, - ops as view_op_rules, -) +from torch.distributed._tensor.ops.view_ops import dim_maps, DimSpec, InputDim from torch.distributed._tensor.placement_types import _Partial, DTensorSpec aten = torch.ops.aten @@ -80,12 +76,12 @@ def compute_batch_dim(self, node: fx.Node, full_reduction=False) -> int: return self.batch_dim_map[node] if node.target in self.dim_rule_map: - view_op_rule = view_op_rules[self.dim_rule_map[node.target]] # type: ignore[index] + dim_map = dim_maps[self.dim_rule_map[node.target]] # type: ignore[index] args_val = pytree.tree_map_only(fx.Node, lambda n: n.meta["val"], node.args) kwargs_val = pytree.tree_map_only( fx.Node, lambda n: n.meta["val"], node.kwargs ) - output_dim_rules = view_op_rule.dim_map(*args_val, **kwargs_val) + output_dim_rules = dim_map(*args_val, **kwargs_val) def collect_input_dim(cmd: DimSpec, input_dims: Set[int]): if isinstance(cmd, InputDim): diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index f7afe41e753c5..7c391c4821aa3 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -10,7 +10,12 @@ from torch.distributed._tensor.ops.utils import normalize_to_torch_size from torch.distributed._tensor.placement_types import Placement, Replicate, Shard from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh -from torch.optim.optimizer import _foreach_supported_types +from torch.optim.optimizer import ( + _foreach_supported_types as _optim_foreach_supported_types, +) +from torch.utils._foreach_utils import ( + _foreach_supported_types as _util_foreach_supported_types, +) # All public APIs from dtensor package @@ -25,10 +30,13 @@ ] -# Append DTensor to the list of supported types for foreach implementation of optimizer -# so that we will try to use foreach over the for-loop implementation on CUDA. 
-if DTensor not in _foreach_supported_types: - _foreach_supported_types.append(DTensor) +# Append DTensor to the list of supported types for foreach implementation for optimizer +# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA. +if DTensor not in _optim_foreach_supported_types: + _optim_foreach_supported_types.append(DTensor) + +if DTensor not in _util_foreach_supported_types: + _util_foreach_supported_types.append(DTensor) def _dtensor_init_helper( diff --git a/torch/distributed/_tensor/op_schema.py b/torch/distributed/_tensor/op_schema.py index 7d5bd691395be..4918bffec6213 100644 --- a/torch/distributed/_tensor/op_schema.py +++ b/torch/distributed/_tensor/op_schema.py @@ -161,6 +161,14 @@ def output_ndim(self): def output_shape(self): return self.strategies[0].output_spec.shape + @property + def ndim(self): + return self.output_ndim + + @property + def shape(self): + return self.output_shape + class TupleStrategy(StrategyType): """ diff --git a/torch/distributed/_tensor/ops/view_ops.py b/torch/distributed/_tensor/ops/view_ops.py index be72cc9509f58..598c973170f4a 100644 --- a/torch/distributed/_tensor/ops/view_ops.py +++ b/torch/distributed/_tensor/ops/view_ops.py @@ -16,23 +16,24 @@ import torch from torch import Tensor -from torch._subclasses.fake_tensor import unset_fake_temporarily -from torch.distributed._tensor._utils import compute_local_shape from torch.distributed._tensor.api import Shard from torch.distributed._tensor.op_schema import ( OpSchema, - OutputSharding, + OpStrategy, + PlacementStrategy, RuntimeSchemaInfo, + StrategyType, ) from torch.distributed._tensor.ops.utils import ( + generate_redistribute_costs, normalize_dim, normalize_dims, prod, - register_prop_rule, + register_op_strategy, ) from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate -from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing +from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten @@ -454,68 +455,41 @@ def dim_reduction( ) -@dataclass -class Op: - dim_map: Callable[..., DimMap] - shape_argnum: Optional[int] = None - - -ops: Dict[Callable[..., torch.Tensor], Op] = { - torch.atleast_1d: Op(dim_map=lambda x: dim_pad_left(x.ndim, 1)), - torch.atleast_2d: Op(dim_map=lambda x: dim_pad_left(x.ndim, 2)), - torch.atleast_3d: Op(dim_map=lambda x: dim_atleast_3d(x.ndim)), - torch.broadcast_to: Op( - dim_map=lambda input, shape: expand(input.shape, shape), shape_argnum=1 - ), - Tensor.expand: Op( - dim_map=lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), - shape_argnum=1, - ), - torch.flatten: Op(dim_map=lambda tensor: dim_flatten(tensor.ndim)), - torch.movedim: Op( - dim_map=lambda input, source, destination: dim_movedim( - input.ndim, source, destination - ) - ), - torch.permute: Op( - dim_map=lambda input, dims: tuple( - InputDim(i) for i in normalize_dims(dims, input.ndim) - ) - ), - torch.ravel: Op(dim_map=lambda tensor: dim_flatten(tensor.ndim)), - Tensor.repeat: Op(dim_map=lambda self, *sizes: dim_repeat(self.ndim, sizes)), - torch.reshape: Op( - dim_map=lambda input, shape: view_groups(input.shape, shape), - shape_argnum=1, +dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = { + torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1), + torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2), + torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim), + torch.broadcast_to: lambda input, shape: expand(input.shape, shape), + Tensor.expand: lambda self, *sizes: 
expand(self.shape, normalize_sizes(sizes)), + torch.flatten: lambda tensor: dim_flatten(tensor.ndim), + torch.movedim: lambda input, source, destination: dim_movedim( + input.ndim, source, destination ), - torch.squeeze: Op(dim_map=lambda input, dim=None: dim_squeeze(input.shape, dim)), - torch.tile: Op(dim_map=lambda input, dims: dim_tile(input.ndim, dims)), - torch.transpose: Op( - dim_map=lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1) + torch.permute: lambda input, dims: tuple( + InputDim(i) for i in normalize_dims(dims, input.ndim) ), - torch.unsqueeze: Op(dim_map=lambda input, dim: dim_unsqueeze(input.ndim, dim)), - Tensor.view: Op( - dim_map=lambda input, *shape: view_groups(input.shape, shape), - shape_argnum=1, - ), - torch.view_as_complex: Op( - dim_map=lambda input: dim_flatten(input.ndim, input.ndim - 2) - ), - torch.view_as_real: Op(dim_map=lambda input: dim_view_as_real(input.shape)), + torch.ravel: lambda tensor: dim_flatten(tensor.ndim), + Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes), + torch.reshape: lambda input, shape: view_groups(input.shape, shape), + torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim), + torch.tile: lambda input, dims: dim_tile(input.ndim, dims), + torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1), + torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim), + Tensor.view: lambda input, *shape: view_groups(input.shape, shape), + torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2), + torch.view_as_real: lambda input: dim_view_as_real(input.shape), } def propagate_shape_and_sharding( - in_shard: Sequence[Placement], + input_src_placements: Sequence[Placement], local_in_shape: Shape, rule: DimMap, mesh_sizes: Shape, -) -> Tuple[Shape, Optional[Sequence[Placement]], torch.Tensor]: +) -> Tuple[Sequence[Placement], Sequence[Placement]]: """ - Determine output sharding and tensor shape based on given global tensor shape and input sharding. - - Takes as input the global shape of the tensor, and the input sharding, - and produce corresponding output sharding and shape of the output tensor. + Determine input target sharding and output sharding based on + given global tensor shape and input source sharding. Sharding propagation follows mapped dimensions: - An output dimension that maps directly to an input dimension is sharded equally @@ -524,16 +498,13 @@ def propagate_shape_and_sharding( - An output dimension that is a split of the input dimension can only be sharded if the leftmost split size is divisible by the mesh dimension """ - assert len(in_shard) == len(mesh_sizes) - sharded_in_dims: Set[int] = {s.dim for s in in_shard if isinstance(s, Shard)} + assert len(input_src_placements) == len(mesh_sizes) # for each input dim, for each mesh dim, provides a list of possible shardable dimensions - shardable_dims: torch.Tensor = torch.ones( - (len(local_in_shape), len(mesh_sizes)), dtype=torch.bool - ) + mesh_ndim = len(mesh_sizes) + shardable_dims: Dict[int, List[bool]] = {} # in case an input dimension disappears (e.g. 
collapsing, reduction) # we cannot shard in that dimension (we need a replication fall-back rule) - seen_input_dims: Set[int] = set() def collect_used_inputs(cmd: DimSpec) -> None: @@ -545,28 +516,19 @@ def collect_used_inputs(cmd: DimSpec) -> None: for cmd in rule: collect_used_inputs(cmd) for dim in range(len(local_in_shape)): - shardable_dims[dim, :] = dim in seen_input_dims + shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim - def get_dim_size(cmd: DimSpec) -> Tuple[int, Optional[InputDim]]: + def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: if isinstance(cmd, InputDim): - seen_input_dims.add(cmd.input_dim) - return ( - local_in_shape[cmd.input_dim], - cmd if cmd.input_dim in sharded_in_dims else None, - ) + return cmd elif isinstance(cmd, Flatten): for dim in cmd.input_dims[1:]: if isinstance(dim, InputDim): - shardable_dims[dim.input_dim, :] = False + shardable_dims[dim.input_dim] = [False] * mesh_ndim dim0 = cmd.input_dims[0] - return ( - prod(get_dim_size(a)[0] for a in cmd.input_dims), - dim0 - if isinstance(dim0, InputDim) and dim0.input_dim in sharded_in_dims - else None, - ) + return dim0 if isinstance(dim0, InputDim) else None elif isinstance(cmd, Split): - _, in_dim = get_dim_size(cmd.input_dim) + in_dim = get_in_dim_to_shard(cmd.input_dim) out_size = cmd.group_shape[cmd.split_id] if cmd.split_id == 0 and in_dim is not None: # we need to check that the input dimension is divisible @@ -579,14 +541,13 @@ def get_dim_size(cmd: DimSpec) -> Tuple[int, Optional[InputDim]]: # but we will allow it if that's the input and it's compatible # 1. is this dimension shardable on each individual mesh dim? - for mesh_dim, mesh_dim_size in enumerate(mesh_sizes): - shardable_dims[in_dim.input_dim, mesh_dim] = ( - out_size % mesh_dim_size == 0 - ) + shardable_dims[in_dim.input_dim] = [ + out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes + ] # 2. here we special case things like [Shard(0), Shard(0)] submesh_size = 1 - for size, shard in zip(mesh_sizes, in_shard): + for size, shard in zip(mesh_sizes, input_src_placements): if isinstance(shard, Shard) and shard.dim == in_dim: submesh_size *= size assert ( @@ -594,158 +555,113 @@ def get_dim_size(cmd: DimSpec) -> Tuple[int, Optional[InputDim]]: ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." 
# we will only shard our first component of the split - return out_size, in_dim if cmd.split_id == 0 else None - elif isinstance(cmd, Singleton): - return 1, None - elif isinstance(cmd, Broadcast): - return cmd.dim_size, None - elif isinstance(cmd, NewDim): - return cmd.size, None + return in_dim if cmd.split_id == 0 else None elif isinstance(cmd, Repeat): - size, in_dim = get_dim_size(cmd.input_dim) + in_dim = get_in_dim_to_shard(cmd.input_dim) if in_dim is not None: - shardable_dims[in_dim.input_dim, :] = False - return size * cmd.times, None + shardable_dims[in_dim.input_dim] = [False] * mesh_ndim + return None else: - raise RuntimeError(f"cmd not found: {cmd}, in rule: {rule}") + return None - dim_map = {} - out_shape = [] + # for each output dim, find the corresponding input dim in terms of sharding prop + shard_dim_map = {} for dim, cmd in enumerate(rule): - out_size, in_dim = get_dim_size(cmd) - out_shape.append(out_size) + in_dim = get_in_dim_to_shard(cmd) if in_dim is not None: - dim_map[in_dim.input_dim] = dim + shard_dim_map[in_dim.input_dim] = dim - needs_reshard = any( - isinstance(placement, Shard) and not shardable_dims[placement.dim][mesh_dim] - for mesh_dim, placement in enumerate(in_shard) - ) - - output_placements = ( - None - if needs_reshard - else [Shard(dim_map[s.dim]) if isinstance(s, Shard) else s for s in in_shard] - ) + input_tgt_placements = [ + Replicate() + if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim] + else p + for mesh_dim, p in enumerate(input_src_placements) + ] + output_placements = [ + Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p + for p in input_tgt_placements + ] - return (tuple(out_shape), output_placements, shardable_dims) + return input_tgt_placements, output_placements -def register_prop_rule_map( +def register_op_strategy_map( aten_op_overload: torch._ops.OpOverload, local_op_name: Callable[..., torch.Tensor], schema_info: Optional[RuntimeSchemaInfo] = None, ) -> None: - spec: Op = ops[local_op_name] - - @register_prop_rule(aten_op_overload, schema_info=schema_info) - def reshape_prop(op_schema: OpSchema) -> OutputSharding: - rules = spec.dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) - input_dtensor_spec = cast(DTensorSpec, op_schema.args_schema[0]) - mesh = input_dtensor_spec.mesh - - assert isinstance( - input_dtensor_spec, DTensorSpec - ), "Expected first input to be a DTensorSpec" - global_in_shape = input_dtensor_spec.shape + dim_map: Callable[..., DimMap] = dim_maps[local_op_name] + + @register_op_strategy(aten_op_overload, schema_info=schema_info) + def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + rules = dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + global_in_shape = input_strategy.output_shape assert global_in_shape is not None, "Shape required." 
- with disable_proxy_modes_tracing(), unset_fake_temporarily(): - ( - global_out_shape, - shard_out, - shardable_dims, - ) = propagate_shape_and_sharding( - input_dtensor_spec.placements, + output_strategy = OpStrategy([]) + for input_placement_strategy in input_strategy.strategies: + input_src_spec = input_placement_strategy.output_spec + + input_tgt_placements, output_placements = propagate_shape_and_sharding( + input_src_spec.placements, tuple(global_in_shape), rules, mesh.shape, ) - if shard_out is not None: - # no reshard needed - output_dtensor_spec = DTensorSpec(mesh=mesh, placements=tuple(shard_out)) - - # We only need the local shape to lower the call into the local op - args = op_schema.args_schema - shape_argnum = spec.shape_argnum - if shape_argnum is not None: - # compute the local shape from the global shape, then return - # a resharding even if we don't really reshard, the only reason - # for this type of resharding is to lower the global shape to - # local shape - local_out_shape = compute_local_shape( - list(global_out_shape), mesh, shard_out - ) - - suggested_schema = OpSchema( - op=op_schema.op, - args_schema=args[:shape_argnum] - + (tuple(local_out_shape),) - + args[shape_argnum + 1 :], - kwargs_schema=op_schema.kwargs_schema, - ) - return OutputSharding( - output_spec=output_dtensor_spec, - redistribute_schema=suggested_schema, - needs_redistribute=True, - ) - - return OutputSharding(output_spec=output_dtensor_spec) - - else: # TODO: optimize this. we shouldn't simply blindly replicate # unshardable dims ... # FIXME: this can be wrong for situations where we have # [Shard(0), Shard(0)] - suggested_placements = [ - p - if not isinstance(p, Shard) or shardable_dims[p.dim][mesh_dim] - else Replicate() - for mesh_dim, p in enumerate(input_dtensor_spec.placements) + input_tgt_spec = DTensorSpec( + placements=tuple(input_tgt_placements), + mesh=input_src_spec.mesh, + tensor_meta=input_src_spec.tensor_meta, + ) + redistribute_costs = [ + generate_redistribute_costs(input_strategy, input_tgt_spec) ] - return OutputSharding( - output_spec=None, - redistribute_schema=OpSchema( - op=op_schema.op, - args_schema=( - DTensorSpec( - placements=tuple(suggested_placements), - mesh=input_dtensor_spec.mesh, - tensor_meta=input_dtensor_spec.tensor_meta, - ), - ) - + op_schema.args_schema[1:], - kwargs_schema=op_schema.kwargs_schema, - ), + + output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements)) + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_spec, + input_specs=(input_tgt_spec,), + redistribute_cost=redistribute_costs, + ) ) + return output_strategy -register_prop_rule_map(aten.squeeze.default, torch.squeeze) -register_prop_rule_map( + +register_op_strategy_map(aten.squeeze.default, torch.squeeze) +register_op_strategy_map( aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map(aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)) -register_prop_rule_map( +register_op_strategy_map( + aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map( +register_op_strategy_map( aten._unsafe_view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map( +register_op_strategy_map( aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map( +register_op_strategy_map( aten.expand.default, Tensor.expand, 
schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map( +register_op_strategy_map( aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map( +register_op_strategy_map( aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map( +register_op_strategy_map( aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1) ) -register_prop_rule_map(aten.view_as_complex.default, torch.view_as_complex) -register_prop_rule_map(aten.view_as_real.default, torch.view_as_real) +register_op_strategy_map(aten.view_as_complex.default, torch.view_as_complex) +register_op_strategy_map(aten.view_as_real.default, torch.view_as_real) diff --git a/torch/distributed/_tensor/sharding_prop.py b/torch/distributed/_tensor/sharding_prop.py index 9acf6aa0c9195..d173a91a771c0 100644 --- a/torch/distributed/_tensor/sharding_prop.py +++ b/torch/distributed/_tensor/sharding_prop.py @@ -45,15 +45,21 @@ def __init__(self) -> None: # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {} self.propagate_op_sharding = lru_cache(None)(self.propagate_op_sharding_non_cached) # type: ignore[method-assign] - # op map to save indices of size (and stride) args which may need to be modified in sharding prop - self.op_to_size_and_stride_idx: Dict[ + # op map to save indices of shape (and stride) args which may need to be modified in sharding prop + self.op_to_shape_and_stride_idx: Dict[ OpOverload, Union[int, Tuple[int, int]] ] = { + # new factory ops aten.new_empty.default: 1, aten.new_full.default: 1, aten.new_ones.default: 1, aten.new_zeros.default: 1, aten.new_empty_strided.default: (1, 2), + # view ops + aten.expand.default: 1, + aten.reshape.default: 1, + aten.view.default: 1, + aten._unsafe_view.default: 1, } def register_sharding_prop_rule( @@ -260,16 +266,19 @@ def spec_to_strategy(spec: object) -> object: ) suggestion_schema._inplace_rewrap_schema_suggestion(op_schema) - # size and stride args need to be modified for new factory ops, potentially - if op_schema.op in self.op_to_size_and_stride_idx: + # shape and stride args need to be modified for + # view ops and new factory ops, potentially + if op_schema.op in self.op_to_shape_and_stride_idx: assert isinstance(output_strategy.output_spec, DTensorSpec) # It happens when the output has the same shape as the input # and the input placements are not all Replicate(). 
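A rough sketch of why the shape argument of these ops has to be rewritten at all (the function below is an illustrative, even-divisibility-only stand-in for the real compute_local_shape helper): the caller passes the global shape, but each rank runs the local op on its shard.

from typing import List, Optional, Sequence

def local_shape(
    global_shape: Sequence[int],
    mesh_sizes: Sequence[int],
    shard_dims: Sequence[Optional[int]],  # tensor dim sharded by each mesh dim, or None
) -> List[int]:
    shape = list(global_shape)
    for mesh_size, dim in zip(mesh_sizes, shard_dims):
        if dim is not None:
            shape[dim] //= mesh_size  # assumes even divisibility for brevity
    return shape

# x.view(8, 4) on a tensor sharded over a 2-way mesh on dim 0: each rank must
# execute the local view with shape (4, 4), not the global (8, 4).
print(local_shape([8, 4], mesh_sizes=[2], shard_dims=[0]))  # [4, 4]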
if output_strategy.output_spec.is_sharded(): - needs_redistribute = True - suggestion_schema = self._adjust_size_and_stride_args( - op_schema, output_strategy.output_spec, mesh + schema = suggestion_schema or op_schema + assert isinstance(out_tensor_meta, TensorMeta) + suggestion_schema = self._adjust_shape_and_stride_args( + out_tensor_meta, schema, output_strategy.output_spec, mesh ) + needs_redistribute = True # construct output spec for the op if op_schema.return_type_tuple_tensor_like(): @@ -442,29 +451,31 @@ def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy: # for eager execution, we just select the one with the minimal redistribute cost return strategy.strategies[strategy_costs.index(min(strategy_costs))] - def _adjust_size_and_stride_args( - self, op_schema: OpSchema, spec: DTensorSpec, mesh: DeviceMesh + def _adjust_shape_and_stride_args( + self, + out_tensor_meta: TensorMeta, + schema: OpSchema, + spec: DTensorSpec, + mesh: DeviceMesh, ) -> OpSchema: - size_stride_idx = self.op_to_size_and_stride_idx[op_schema.op] - if isinstance(size_stride_idx, tuple): - size_idx, stride_idx = size_stride_idx + shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op] + if isinstance(shape_stride_idx, tuple): + shape_idx, stride_idx = shape_stride_idx else: - size_idx = size_stride_idx + shape_idx = shape_stride_idx stride_idx = None - expected_input_schema = list(op_schema.args_schema) - size = cast(list, expected_input_schema[size_idx]) - # # adjust size to be the same as that of the _local_tensor - # # of the DTensor input arg at index 0, which is inferred - expected_input_schema[size_idx] = compute_local_shape( - size, mesh, spec.placements + expected_input_schema = list(schema.args_schema) + # adjust shape to be the same as that of the _local_tensor + # of the DTensor input arg at index 0, which is inferred + expected_input_schema[shape_idx] = compute_local_shape( + out_tensor_meta.shape, mesh, spec.placements ) # adjust the stride arg for aten.new_empty_strided.default if stride_idx: - stride = cast(list, expected_input_schema[stride_idx]) expected_input_schema[stride_idx] = compute_local_stride( - stride, mesh, spec.placements + out_tensor_meta.stride, mesh, spec.placements ) - return OpSchema(op_schema.op, tuple(expected_input_schema), {}) + return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema) diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index e0a4d8886fc7d..e7072d6230126 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -135,7 +135,6 @@ class StateDictOptions: - ``strict``: the ``strict`` option when ``set_state_dict`` calls model.load_state_dict(). - The default value is False. 
- ``broadcast_from_rank0``: when the option is True, rank0 should receive a full state_dict and will broadcast the tensors in the state_dict/ diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 05dc7710215de..c0981a549c6b8 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -209,7 +209,7 @@ def __init__( self.mesh = ( mesh.detach().to(dtype=torch.int) if isinstance(mesh, torch.Tensor) - else torch.tensor(mesh, dtype=torch.int) + else torch.tensor(mesh, device="cpu", dtype=torch.int) ) self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None @@ -451,21 +451,67 @@ def get_group( return dim_groups @staticmethod - def from_group(group: ProcessGroup, device_type: str) -> "DeviceMesh": + def from_group( + group: Union[ProcessGroup, List[ProcessGroup]], + device_type: str, + mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None, + *, + mesh_dim_names: Optional[Tuple[str, ...]] = None, + ) -> "DeviceMesh": """ Contstructs a :class:`DeviceMesh` with ``device_type`` from an existing :class:`ProcessGroup`. - The constructed device mesh is assumed to be 1D. + The constructed device mesh has number of dimensions equal to the + number of groups passed. If more than one group is passed, then the + ``mesh`` argument is required. """ - # Manually define `_dim_group_infos` instead of relying on the - # normal logic since we already have the PG - group_ranks = get_process_group_ranks(group) - mesh = DeviceMesh(device_type, group_ranks, _init_backend=False) - mesh._dim_group_infos = [ - (_get_group_tag(group), group_ranks, group.group_name) + if isinstance(group, ProcessGroup): + group_ranks = get_process_group_ranks(group) + if ( + isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks + ) or (mesh is not None and mesh != group_ranks): + raise ValueError( + f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}" + ) + mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int) + device_mesh = DeviceMesh( + device_type, + mesh, + mesh_dim_names=mesh_dim_names, + _init_backend=False, + ) + device_mesh._dim_group_infos = [ + (_get_group_tag(group), group_ranks, group.group_name) + ] + return device_mesh + groups = list(group) + if len(groups) == 0: + raise ValueError("Expects at least one ProcessGroup to be passed") + if mesh is None: + raise ValueError("Must pass mesh if passing multiple ProcessGroups") + mesh = ( + mesh.detach().to(dtype=torch.int, device="cpu") + if isinstance(mesh, torch.Tensor) + else torch.tensor(mesh, device="cpu", dtype=torch.int) + ) + if mesh.ndim != len(groups): + raise ValueError( + "Expects mesh with ndim equal to number of ProcessGroups but got " + f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups" + ) + device_mesh = DeviceMesh( + device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False + ) + device_mesh._dim_group_infos = [ + ( + _get_group_tag(group), + get_process_group_ranks(group), + group.group_name, + ) + for group in groups ] - return mesh + return device_mesh def size(self, mesh_dim: Optional[int] = None) -> int: return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 9e86551514228..70283cada9287 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -408,6 +408,21 @@ def __new__(cls, op: Callable, tensor: torch.Tensor, peer: int, _check_single_tensor(tensor, "tensor") return 
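A hedged usage sketch of the extended from_group (assuming a 4-rank job launched with torchrun; the 2x2 layout and group split below are only an example): passing one ProcessGroup per mesh dimension now yields an n-D mesh, and the rank layout must be supplied via the mesh argument.

# Sketch only: run with `torchrun --nproc_per_node=4 this_file.py`.
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

dist.init_process_group("gloo")  # "nccl" on GPU machines
rank = dist.get_rank()

# Mesh [[0, 1], [2, 3]] with dims ("dp", "tp"): dim 0 groups are the columns
# [0, 2] / [1, 3], dim 1 groups are the rows [0, 1] / [2, 3]. new_group must
# be called by every rank for every subgroup.
dp_groups = [dist.new_group([0, 2]), dist.new_group([1, 3])]
tp_groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
my_dp = dp_groups[rank % 2]
my_tp = tp_groups[rank // 2]

mesh = DeviceMesh.from_group(
    [my_dp, my_tp],
    "cpu",
    mesh=[[0, 1], [2, 3]],
    mesh_dim_names=("dp", "tp"),
)
print(mesh)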
object.__new__(cls) + def __repr__(self): + my_group_rank = get_rank(self.group) + peer_group_rank = get_group_rank(self.group, self.peer) if self.group else self.peer + op_name = self.op.__name__ + group_name = self.group.group_name if self.group else "default_pg" + if "send" in op_name: + s = my_group_rank + d = peer_group_rank + elif "recv" in op_name: + s = peer_group_rank + d = my_group_rank + else: + return super().__repr__() + + return f"P2POp({op_name} pg={group_name}, s={s}, d={d}, {self.tensor.shape}, {self.tensor.dtype})" class _CollOp: """ @@ -737,7 +752,7 @@ def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, log ) if timedelta(seconds=(time.time() - start)) > timeout: - raise DistStoreError( # noqa: TRY200 + raise DistStoreError( # noqa: B904 "Timed out initializing process group in store based barrier on " f"rank {rank}, for key: {store_key} (world_size={world_size}, " f"num_workers_joined={worker_count}, timeout={timeout} error={e})" diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 4c584dc32e700..eb0b110f25ee9 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -246,7 +246,7 @@ def __init__( if not log_dir: log_dir = tempfile.mkdtemp(prefix="torchelastic_") elif not os.path.exists(log_dir): - os.makedirs(log_dir) + os.makedirs(log_dir, exist_ok=True) else: if os.path.isfile(log_dir): raise NotADirectoryError(f"log_dir: {log_dir} is a file") diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index 7d9394ef1fbda..c1d77bf410b59 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -421,29 +421,14 @@ def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs): # ``named_children`` + `named_parameter(recurse=False)``. # This hack is a must to make the traversal work. # TODO: Remove this hack once DMP + FSDP is not supported. + # It turns out that recursive wrapping may trigger this as + # well. if ( submodule_name == "_fsdp_wrapped_module" or submodule_name == "_dmp_wrapped_module" ): - if ( - not torch.distributed._functional_collectives.is_torchdynamo_compiling() - ): - # TODO(voz): Don't graph break on this - warnings.warn( - "An unexpected prefix is detected. This case " - " should only happen when using DMP with FSDP. " - f"prefix = {prefix}, " - f"submodule_name = {submodule_name}" - ) new_prefix = prefix elif submodule_name == "module": - warnings.warn( - "An unexpected prefix is detected. This case " - " should only happen when DDP wraps the outer " - " modules while FSDP wraps the inner ones." - f"prefix = {prefix}, " - f"submodule_name = {submodule_name}" - ) new_prefix = prefix f(submodule, new_prefix, new_tree_level, *args, **kwargs) diff --git a/torch/distributed/fsdp/_debug_utils.py b/torch/distributed/fsdp/_debug_utils.py index 4ed76476e56b6..a41a817724e57 100644 --- a/torch/distributed/fsdp/_debug_utils.py +++ b/torch/distributed/fsdp/_debug_utils.py @@ -57,7 +57,7 @@ def dump_and_reset(cls, msg: str) -> None: # This cannot be combined with DETAIL distributed log # as the profiling will be very incorrect. 
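The new __repr__ is meant to make batched point-to-point debugging readable; a small sketch of what it prints (assuming a 2-rank job, e.g. torchrun --nproc_per_node=2; the exact string in the comment is illustrative):

import torch
import torch.distributed as dist

dist.init_process_group("gloo")
rank = dist.get_rank()
peer = 1 - rank

t = torch.zeros(4)
op = dist.P2POp(dist.isend if rank == 0 else dist.irecv, t, peer)
# On rank 0 this prints roughly:
#   P2POp(isend pg=default_pg, s=0, d=1, torch.Size([4]), torch.float32)
print(op)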
if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO: - logger.warning("%s %s", msg, cls.results) + logger.info("%s %s", msg, cls.results) cls.reset() diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 163cde70b3f97..b066f930ebaf5 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -1511,7 +1511,7 @@ def _allgather_orig_param_states( """ fsdp_state = fsdp_param_info.state if fsdp_state.rank == 0 and dist.get_debug_level() == dist.DebugLevel.DETAIL: - logger.warning( + logger.info( "Memory Summary before calling to _allgather_orig_param_states %s", fsdp_state._device_handle.memory_summary(), ) diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index bf4116688bff9..c8a256299bd52 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -411,34 +411,71 @@ def _step_microbatches( fwd_sends_to_wait: List[dist.Work] = [] bwd_sends_to_wait: List[dist.Work] = [] + def is_forward_step(i): + assert i >= 0, i + return i < self._n_microbatches + + def is_backward_step(i): + assert i < total_steps, i + return i >= warmup_steps and self._has_backward + + def is_1f1b_step(i): + return is_forward_step(i) and is_backward_step(i) + + def is_warmup_step(i): + return is_forward_step(i) and not is_backward_step(i) + + def is_cooldown_step(i): + return not is_forward_step(i) and is_backward_step(i) + + def should_coalesce_fwd_send_bwd_recv(fwd_send_i): + return ( + is_1f1b_step(fwd_send_i) + or (is_warmup_step(fwd_send_i) and is_cooldown_step(fwd_send_i + 1)) + or ( + fwd_send_i >= 1 + and is_warmup_step(fwd_send_i - 1) + and is_cooldown_step(fwd_send_i) + ) + ) + + def should_coalesce_bwd_send_fwd_recv(bwd_send_i): + # The backward send to prev stage should be coalesced with the fwd recv from the previous stage + return bwd_send_i >= warmup_steps and is_1f1b_step(bwd_send_i + 1) + # bwd chunk counter bwd_mb_index = 0 self._stage._configure_data_parallel_mode(last_backward=False) for i in range(total_steps): - if i < self._n_microbatches: - # forward + if is_forward_step(i): with record_function(f"Forward {i}"): ops = self._stage.get_fwd_recv_ops() + if should_coalesce_bwd_send_fwd_recv(i - 1): + ops.extend(self._stage.get_bwd_send_ops()) + works = sorted_batch_isend_irecv(ops) for work in works.values(): work.wait() output = self._stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] - ops = self._stage.get_fwd_send_ops() - works = sorted_batch_isend_irecv(ops) - fwd_sends_to_wait.extend(works.values()) + if not should_coalesce_fwd_send_bwd_recv(i): + ops = self._stage.get_fwd_send_ops() + works = sorted_batch_isend_irecv(ops) + fwd_sends_to_wait.extend(works.values()) self._maybe_compute_loss(self._stage, output, target_mbs, i) - if i >= warmup_steps and self._has_backward: + if is_backward_step(i): self._stage._configure_data_parallel_mode( last_backward=(i == total_steps - 1) ) - - # backward with record_function(f"Backward {bwd_mb_index}"): ops = self._stage.get_bwd_recv_ops() + + if should_coalesce_fwd_send_bwd_recv(i): + ops.extend(self._stage.get_fwd_send_ops()) + works = sorted_batch_isend_irecv(ops) for work in works.values(): work.wait() @@ -446,9 +483,12 @@ def _step_microbatches( loss = self._maybe_get_loss(self._stage, bwd_mb_index) self._stage.backward_one_chunk(loss=loss) - ops = self._stage.get_bwd_send_ops() - works = 
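To make the new phase predicates concrete, a standalone walk-through with assumed values (no stages involved; warmup_steps and total_steps are hard-coded here, whereas the real schedule derives them from the stage index):

n_microbatches = 4
warmup_steps = 2
total_steps = n_microbatches + warmup_steps
has_backward = True

def is_forward_step(i):
    return i < n_microbatches

def is_backward_step(i):
    return i >= warmup_steps and has_backward

for i in range(total_steps):
    fwd, bwd = is_forward_step(i), is_backward_step(i)
    phase = "1f1b" if fwd and bwd else ("warmup" if fwd else "cooldown")
    print(i, phase)
# 0 warmup, 1 warmup, 2 1f1b, 3 1f1b, 4 cooldown, 5 cooldown
# Coalescing then pairs the fwd-send of a 1F1B step (or of the last warmup
# step) with the following bwd-recv, and the bwd-send of a step whose
# successor is a 1F1B step with that successor's fwd-recv.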
sorted_batch_isend_irecv(ops) - bwd_sends_to_wait.extend(works.values()) + if not should_coalesce_bwd_send_fwd_recv(i): + # see Note: coalesced bwd-send/fwd-recv + ops = self._stage.get_bwd_send_ops() + works = sorted_batch_isend_irecv(ops) + bwd_sends_to_wait.extend(works.values()) + bwd_mb_index += 1 # Wait for all forward sends to finish diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 799ec6d5e0d17..204a60a340228 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -11,12 +11,13 @@ import torch import torch.fx as fx from torch.export import ExportedProgram +from torch.export.unflatten import _assign_attr, _AttrKind, _sink_params from torch.fx.node import map_aggregate from torch.fx.passes.split_module import split_module from ._backward import _null_coalesce_accumulate, stage_backward from ._debug import PIPPY_VERBOSITY -from ._unflatten import _assign_attr, _AttrKind, _outline_submodules, _sink_params +from ._unflatten import _outline_submodules from ._utils import QualnameMapMixin from .microbatch import split_args_kwargs_into_chunks, TensorChunkSpec @@ -303,7 +304,7 @@ def _pipe_split(): return None -@torch.library.impl_abstract("pippy::_pipe_split") # type: ignore[no-redef] +@torch.library.register_fake("pippy::_pipe_split") # type: ignore[no-redef] def _pipe_split(): # noqa: F811 return None @@ -869,8 +870,8 @@ def move_param_to_callee( # After moving the params to their corresponding hierarchies, we also # need to move the `get_attr` nodes from the root of the graph to those # hierarchies. - inputs_to_state: Dict[str, str] = { - attr.name: attr.target for attr in attr_nodes + inputs_to_state: Dict[str, List[str]] = { + attr.name: [attr.target] for attr in attr_nodes } # This is done by (1) `_sind_params` at each submodule; for name, submod in split.named_children(): @@ -1281,7 +1282,7 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]): except AttributeError as e: raise AttributeError( f'Specified target {qualname} referenced nonexistent module {".".join(atoms[:i+1])}' - ) + ) from e mod_to_wrap = getattr(predecessor_module, atoms[-1]) mod_to_wrap._orig_forward = mod_to_wrap.forward diff --git a/torch/distributed/pipelining/_unflatten.py b/torch/distributed/pipelining/_unflatten.py index 684fcfbc1d6d7..27241d17874c4 100644 --- a/torch/distributed/pipelining/_unflatten.py +++ b/torch/distributed/pipelining/_unflatten.py @@ -1,453 +1,8 @@ # Copyright (c) Meta Platforms, Inc. 
and affiliates -# This file is a copy of private utilities in pytorch/torch/export/unflatten.py -# pylint: skip-file - -import copy -import operator -from enum import Enum -from typing import cast, Dict, List, Optional, Union +from typing import Dict import torch -import torch.fx._pytree as fx_pytree -import torch.utils._pytree as pytree -from torch.export.exported_program import ( - ConstantArgument, - ModuleCallSignature, - SymIntArgument, - TensorArgument, -) -from torch.export.unflatten import InterpreterModule - - -class _AttrKind(Enum): - PARAMETER = "parameter" - BUFFER = "buffer" - CONSTANT = "constant" - - -# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module -# This installs empty Modules where none exist yet if they are subpaths of target -def _assign_attr( - from_obj: Union[torch.Tensor, torch.ScriptObject], - to_module: torch.nn.Module, - target: str, - attr_kind: _AttrKind, - persistent: bool = True, -): - *prefix, field = target.split(".") - for item in prefix: - t = getattr(to_module, item, None) - - if t is None: - t = torch.nn.Module() - setattr(to_module, item, t) - to_module = t - - if attr_kind == _AttrKind.PARAMETER: - assert isinstance(from_obj, torch.nn.Parameter) - to_module.register_parameter(field, from_obj) - elif attr_kind == _AttrKind.BUFFER: - assert isinstance(from_obj, torch.Tensor) - to_module.register_buffer(field, from_obj, persistent=persistent) - elif attr_kind == _AttrKind.CONSTANT: - assert isinstance(from_obj, (torch.Tensor, torch.ScriptObject)) - setattr(to_module, field, from_obj) - - -def _is_prefix(candidate, target): - """Check whether `candidate` is a prefix of `target`.""" - return len(candidate) < len(target) and target[: len(candidate)] == candidate - - -def _compute_accessor(parent_fqn: str, child_fqn: str) -> str: - if parent_fqn == "": - # Handle the root module correctly. 
- return child_fqn - - parent_split = parent_fqn.split(".") - child_split = child_fqn.split(".") - - assert ( - child_split[: len(parent_split)] == parent_split - ), f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'" - return ".".join(child_split[len(parent_split) :]) - - -def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module): - def graph_dump(graph: torch.fx.Graph) -> str: - ret = [] - nodes_idx: Dict[int, int] = {} - - def arg_dump(arg) -> str: - if isinstance(arg, torch.fx.Node): - return "%" + str(nodes_idx[id(arg)]) - return str(arg) - - for i, node in enumerate(graph.nodes): - args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)] - args_dump += [ - f"{key}={value}" - for key, value in pytree.tree_map(arg_dump, node.kwargs).items() - ] - target = node.target if node.op == "call_function" else "" - ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})") - nodes_idx[id(node)] = i - return "\n".join(ret) - - assert graph_dump(x.graph) == graph_dump(y.graph) - - -def _add_spec(gm: torch.nn.Module, spec) -> str: - i = 0 - while hasattr(gm, f"_spec_{i}"): - i += 1 - name = f"_spec_{i}" - setattr(gm, name, spec) - return name - - -def _generate_flatten(gm: torch.nn.Module, node, spec) -> torch.fx.Node: - name = _add_spec(gm, spec) - spec_node = gm.graph.get_attr(name) - return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node)) - - -def _generate_unflatten(gm: torch.nn.Module, nodes, spec) -> torch.fx.Node: - name = _add_spec(gm, spec) - spec_node = gm.graph.get_attr(name) - return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node)) - - -def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Module): - *prefix, field = target.split(".") - - for item in prefix: - submod = getattr(mod, item, None) - - if submod is None: - submod = torch.nn.Module() - setattr(mod, item, submod) - - if not isinstance(submod, torch.nn.Module): - return False - - mod = submod - - mod.add_module(field, module_to_add) - - -class _ModuleFrame: - def __init__( - self, - flat_graph, - nodes, - seen_nodes, - seen_modules, - parent, - module_stack, - module_id, - module_call_graph: Optional[Dict[str, ModuleCallSignature]] = None, - module: Optional[torch.nn.Module] = None, - ): - self.flat_graph = flat_graph - self.nodes = nodes - self.seen_nodes = seen_nodes - self.seen_modules = seen_modules - self.parent = parent - self.module_stack = module_stack - self.module_id = module_id - - self.module_call_graph = module_call_graph - self.verbose = False - - self.fqn = self.module_stack[-1] - if module is not None: - self.module = module - else: - self.module = InterpreterModule(torch.fx.Graph()) - if self.module_id in self.seen_modules: - self.cached_graph_module = self.seen_modules[self.module_id] - else: - self.cached_graph_module = None - self.seen_modules[self.module_id] = self.module - - self.graph = self.module.graph - - # Mapping of nodes in the flat graph to nodes in this graph. 
- self.node_map: Dict[torch.fx.Node, torch.fx.Node] = {} - self.node_to_placeholder = {} - - self.parent_call_module: Optional[torch.fx.Node] = None - if parent is not None: - accessor = _compute_accessor(parent.fqn, self.fqn) - _add_submodule( - parent.module, - accessor, - self.module - if self.cached_graph_module is None - else self.cached_graph_module, - ) - self.parent_call_module = parent.graph.call_module(accessor) - - signature = self.get_signature() - - if signature is not None and self.parent is not None: - assert signature.in_spec.num_children == 2 - args_spec = signature.in_spec.children_specs[0] - kwargs_spec = signature.in_spec.children_specs[1] - assert args_spec.context is None - assert kwargs_spec.context is not None - - with self.graph.inserting_after(None): - arg_nodes = [] - for idx in range(args_spec.num_children): - arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}")) - kwarg_nodes = {} - for name in kwargs_spec.context: - kwarg_nodes[name] = self.graph.placeholder(name) - flat_args = _generate_flatten( - self.module, - (tuple(arg_nodes), kwarg_nodes), - signature.in_spec, - ) - for idx, arg in enumerate(signature.inputs): - flat_arg_node = self.graph.create_node( - op="call_function", - target=operator.getitem, - args=(flat_args, idx), - name=arg.name - if not isinstance(arg, ConstantArgument) - else f"_constant_{idx}", - ) - if isinstance(arg, ConstantArgument): - continue - flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) - self.node_to_placeholder[self.seen_nodes[arg.name]] = flat_arg_node - - with self.parent.graph.inserting_before(self.parent_call_module): - input_nodes: List[Optional[torch.fx.Node]] = [] - for input in signature.inputs: - if isinstance(input, ConstantArgument) and input.value is None: - input_nodes.append(None) - else: - assert isinstance(input, (TensorArgument, SymIntArgument)) - input_nodes.append( - self.parent.remap_input(self.seen_nodes[input.name]) - ) - - inputs_node = _generate_unflatten( - self.parent.module, - input_nodes, - signature.in_spec, - ) - - args_node = self.parent.graph.call_function( - operator.getitem, (inputs_node, 0) - ) - kwargs_node = self.parent.graph.call_function( - operator.getitem, (inputs_node, 1) - ) - arg_nodes = [ - self.parent.graph.call_function(operator.getitem, (args_node, i)) - for i in range(args_spec.num_children) - ] - kwarg_nodes = { - k: self.parent.graph.call_function( - operator.getitem, (kwargs_node, k) - ) - for k in kwargs_spec.context - } - assert self.parent_call_module is not None - self.parent_call_module.args = tuple(arg_nodes) - self.parent_call_module.kwargs = kwarg_nodes - - def add_placeholder(self, x): - assert x.graph is self.flat_graph - # x is not in subgraph, create a new placeholder for subgraph - with self.graph.inserting_before(None): - placeholder_node = self.graph.placeholder(x.name, type_expr=x.type) - # copy all meta fields, even if some fields might be irrelvant for - # the placeholder node - placeholder_node.meta = copy.copy(x.meta) - self.node_to_placeholder[x] = placeholder_node - - def remap_input(self, x): - assert x.graph is self.flat_graph - if x in self.node_map: - return self.node_map[x] - if x not in self.node_to_placeholder: - self.add_placeholder(x) - if self.parent_call_module is not None: - # Important to *prepend* the output to match how we are - # inserting placeholder nodes. 
- self.parent_call_module.insert_arg(0, self.parent.remap_input(x)) - return self.node_to_placeholder[x] - - def get_signature(self): - if self.module_call_graph is not None: - return self.module_call_graph.get(self.fqn) - return None - - def finalize_outputs(self): - orig_outputs = [] - signature = self.get_signature() - - if signature is not None and self.parent is not None: - for output in signature.outputs: - if isinstance(output, (TensorArgument, SymIntArgument)): - orig_outputs.append(self.seen_nodes[output.name]) - else: - raise RuntimeError( - f"Unsupported data type for output node: {output}" - ) - - tree_out_node = _generate_unflatten( - self.module, - tuple( - self.node_map[self.seen_nodes[output.name]] - for output in orig_outputs - ), - signature.out_spec, - ) - parent_out: Optional[torch.fx.Node] = _generate_flatten( - self.parent.module, self.parent_call_module, signature.out_spec - ) - graph_outputs: Union[torch.fx.Node, List[torch.fx.Node]] = tree_out_node - else: - graph_outputs = [] - # Iterate through nodes we have copied into self.graph. - for orig_node in self.node_map.keys(): - for user_node in orig_node.users: - if user_node.name not in self.seen_nodes: - # external user node, need to expose as an output - orig_outputs.append(orig_node) - graph_outputs.append(self.node_map[orig_node]) - break - - parent_out = self.parent_call_module - if len(graph_outputs) == 1: - graph_outputs = graph_outputs[0] - - assert isinstance(graph_outputs, (list, torch.fx.Node)) - - self.graph.output(graph_outputs) - - # Rewrite outputs in parent module - if parent_out is None: - return - - parent_out.meta["val"] = ( - graph_outputs.meta.get("val") - if isinstance(graph_outputs, torch.fx.Node) - else [o.meta.get("val") for o in graph_outputs] - ) - - if len(orig_outputs) == 1 and signature is None: - self.parent.node_map[orig_outputs[0]] = parent_out - else: - for i, orig_output in enumerate(orig_outputs): - # Use Proxy to record getitem access. - proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index] - proxy_out.meta["val"] = orig_output.meta.get("val") - self.parent.node_map[orig_output] = proxy_out - - if self.cached_graph_module is not None: - _verify_graph_equivalence(self.cached_graph_module, self.module) - - def copy_node(self, node): - self.print("copying", node.format_node()) - self.node_map[node] = self.graph.node_copy(node, self.remap_input) - self.seen_nodes[node.name] = node - - def run_outer(self): - i = 0 - for node in self.flat_graph.nodes: - self.print(i, node.meta.get("nn_module_stack"), node.format_node()) - i += 1 - - # Copy all graph inputs - node_idx: int = 0 - node = self.nodes[node_idx] - while node.op == "placeholder": - self.copy_node(node) - node_idx += 1 - node = self.nodes[node_idx] - - self.run_from(node_idx) - - # Copy graph outputs - for node in self.flat_graph.nodes: - if node.op == "output": - self.copy_node(node) - - def print(self, *args, **kwargs): - if self.verbose: - print(*args, **kwargs) - - def run_from(self, node_idx): - module_idx = 0 - # Walk through the graph, building up a new graph with the right submodules - while node_idx < len(self.nodes): - node = self.nodes[node_idx] - assert node.op != "placeholder" - - self.print() - self.print("STEP", node_idx, node.format_node()) - self.print(self.module_stack) - if node.op == "output": - if len(self.module_stack) == 1: - # We want the output node of the original graph to be handled - # specially by the outermost stack frame (in run_outer). So - # skip finalization here. 
- return node_idx - - # We've reached the end of the graph. Wrap up all the existing stack frames. - self.finalize_outputs() - return node_idx - - node_module_stack = ( - [path for path, ty in node.meta["nn_module_stack"].values()] - if "nn_module_stack" in node.meta - else self.module_stack - ) - if node_module_stack[: len(self.module_stack)] != self.module_stack: - # This means that the current module is done executing and the - # current node is the beginning of a new module. - # - # In this case, we should finalize this module and return without - # incrementing the node counter. - self.finalize_outputs() - self.print("outlining", self.fqn) - self.print(self.graph) - return node_idx - - assert node_module_stack is not None - - if _is_prefix(self.module_stack, node_module_stack): - # This means that the current node represents the execution of a new - # module. - next_module = node_module_stack[len(self.module_stack)] - self.print("Creating new stack frame for", next_module) - # Run a nested version of module outliner from the current node - # counter. Once it is complete, continue from that point. - node_idx = _ModuleFrame( - self.flat_graph, - self.nodes, - self.seen_nodes, - self.seen_modules, - self, - self.module_stack + [next_module], - list(node.meta["nn_module_stack"].keys())[len(self.module_stack)], - self.module_call_graph, - ).run_from(node_idx) - module_idx += 1 - continue - - # The only remaining possibility is that we are in the right stack - # frame. Copy the node into this frame's graph and increment the node counter. - assert node_module_stack == self.module_stack - self.copy_node(node) - node_idx += 1 +from torch.export.unflatten import _ModuleFrame def _outline_submodules(orig_graph: torch.fx.Graph): @@ -463,80 +18,9 @@ def _outline_submodules(orig_graph: torch.fx.Graph): None, [""], "", + {}, module=new_module, ).run_outer() new_module.graph.lint() new_module.recompile() return new_module - - -def _sink_params( - module: torch.nn.Module, - inputs_to_state: Dict[str, str], - scope: List[str], -): - """Sink params, buffers, and constants from graph inputs into get_attr nodes. - - Exported modules are purely functional, so they pass their parameters and - buffers in as inputs to the graph. - - To replicate eager's semantics, we need to get them from the module state - via get_attr instead. - - module: GraphModule, potentially containining nested submodules. - inputs_to_state: mapping graph input names to the corresponding key in the state_dict. - scope: tracks where we are in the module hierarchy, so that we can emit the - right `getattr(self, "foo.bar")` calls, etc. - """ - # We need to use _modules here instead of named_children(), because we - # explicitly want duplicate modules to show up in the traversal. 
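Since the private copy above is deleted in favor of the helpers in torch.export.unflatten, a minimal sketch of the public entry point those helpers back (the toy modules are arbitrary):

import torch
from torch.export import export, unflatten

class Inner(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.inner = Inner()

    def forward(self, x):
        return self.inner(self.linear(x))

ep = export(Outer(), (torch.randn(2, 4),))
unflat = unflatten(ep)      # rebuilds the Outer/Inner module hierarchy
print(type(unflat.inner))   # typically an InterpreterModule standing in for Inner
out = unflat(torch.randn(2, 4))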
- for name, submodule in module._modules.items(): - _sink_params(cast(torch.nn.Module, submodule), inputs_to_state, scope + [name]) - - if not hasattr(module, "graph"): - # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList) - return - - graph = module.graph - inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes)) - the_last_input = inputs[-1] - - # Also remove from call_module nodes - call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes) - for node in call_module_nodes: - node.args = tuple(filter(lambda n: n.name not in inputs_to_state, node.args)) - - for node in inputs: - if node.name not in inputs_to_state: - continue - - if len(node.users) > 0: - state_name = inputs_to_state[node.name].split(".") - # If there's a mismatch beteewn scope name and state name, then there must be multuple scopes - # pointing to the same state name, meaning some modules are shared. In such case, we can simply - # skip updating the current node because another later iteration will take care of this input - # node when the unique match between scope and state name occurs. - # To make sure this always happen, we should enforce the invariant that no placeholder node - # in the unflattened graph appears in inputs_to_state dict, which means all the extra input - # nodes have been handled. - if state_name[: len(scope)] != scope: - continue - attr_path = state_name[len(scope) :] - state_attr = _recursive_getattr(module, attr_path) - assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject)) - - # Make sure the newly created get_attr node is placed after the last placeholder node - with graph.inserting_after(the_last_input): - new_node = graph.create_node("get_attr", ".".join(attr_path)) - - node.replace_all_uses_with(new_node, propagate_meta=True) - graph.erase_node(node) - if isinstance(module, InterpreterModule): - module.finalize() - - -def _recursive_getattr(obj, attr_path): - for attr in attr_path: - obj = getattr(obj, attr) - - return obj diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 72ae886a392c5..c85a82c8c4c58 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -452,9 +452,9 @@ def _export_to_torch_ir( **kwargs, ) except (ConstraintViolationError, ValueRangeError) as e: - raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: TRY200 + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 except GuardOnDataDependentSymNode as e: - raise UserError( # noqa: TRY200 + raise UserError( # noqa: B904 UserErrorType.ANTI_PATTERN, f"Consider annotating your code using torch._check*(). {str(e)}", case_name="constrain_as_size_example", @@ -468,7 +468,7 @@ def _export_to_torch_ir( return gm_torch_level -def _export_non_strict( +def _export_to_aten_ir( mod: torch.nn.Module, fake_args, fake_kwargs, @@ -478,6 +478,7 @@ def _export_non_strict( transform=lambda x: x, # TODO(zhxchen17) Revisit if this is needed later. pre_dispatch=False, should_insert_runtime_assertion=False, + _is_torch_jit_trace=False, ): # [NOTE] If the user is exporting under training mode, we want to detect if there is any # state change in the autograd global state and error. 
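The _is_torch_jit_trace plumbing added here exists so that jit-traced modules, which have no inspectable forward signature, can go through non-strict export; a hedged sketch of the experimental bridge that uses it (the _convert_ts_to_export_experimental helper is introduced a bit further down, and being private it may change):

import torch
from torch.export._trace import _convert_ts_to_export_experimental

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

example = (torch.randn(2), torch.randn(2))
traced = torch.jit.trace(M(), example)   # a TopLevelTracedModule

# Non-strict export of the traced module, skipping the signature-dependent
# steps that _is_torch_jit_trace guards; returns an unlifted fx.GraphModule.
exported_module = _convert_ts_to_export_experimental(traced, example)
print(exported_module(*example))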
If the user is exporting under inference @@ -632,19 +633,21 @@ def make_argument_spec(i, node) -> ArgumentSpec: constants = rewrite_script_object_meta(gm) constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) - # prettify names for placeholder nodes - placeholder_naming_pass( - gm, - export_graph_signature, - mod, - fake_args, - fake_kwargs, - fake_params_buffers, - constants, - ) + # FIXME: Skipping this because traced modules do not have signature yet + if not _is_torch_jit_trace: + # prettify names for placeholder nodes + placeholder_naming_pass( + gm, + export_graph_signature, + mod, + fake_args, + fake_kwargs, + fake_params_buffers, + constants, + ) @dataclasses.dataclass - class _ExportedProgramNonStrict: + class _ExportedArtifact: gm: torch.fx.GraphModule sig: ExportGraphSignature constants: Dict[ @@ -656,7 +659,7 @@ class _ExportedProgramNonStrict: ], ] - return _ExportedProgramNonStrict( + return _ExportedArtifact( gm, export_graph_signature, constants, @@ -889,6 +892,48 @@ def wrapper(*args, **kwargs): return wrapper +def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None): + torch._C._jit_set_texpr_fuser_enabled(False) + + def process_trace_inputs_for_export(example_inputs, example_kwarg_inputs): + if not isinstance(example_inputs, tuple): + example_inputs = (example_inputs,) + + if example_kwarg_inputs is None: + example_kwarg_inputs = {} + return example_inputs, example_kwarg_inputs + + class _WrapperModule(torch.nn.Module): + def __init__(self, f): + super().__init__() + self.f = f + + def forward(self, *args, **kwargs): + return self.f(*args, **kwargs) + + from torch.jit._trace import TopLevelTracedModule + + export_args, export_kwargs = process_trace_inputs_for_export(args, kwargs) + + if isinstance(traced_callable, TopLevelTracedModule): + return _export( + traced_callable, + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + else: + return _export( + _WrapperModule(traced_callable), + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + @_log_export_wrapper @_disable_prexisiting_fake_mode def _export( @@ -901,6 +946,7 @@ def _export( preserve_module_call_signature: Tuple[str, ...] 
= (), pre_dispatch: bool = False, _disable_forced_specializations: Optional[bool] = False, + _is_torch_jit_trace: bool = False, ) -> ExportedProgram: """ Traces either an nn.Module's forward function or just a callable with PyTorch @@ -969,7 +1015,10 @@ def _export( flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs)) original_state_dict = mod.state_dict(keep_vars=True) - forward_arg_names = _get_forward_arg_names(mod, args, kwargs) + if not _is_torch_jit_trace: + forward_arg_names = _get_forward_arg_names(mod, args, kwargs) + else: + forward_arg_names = None if not strict: out_spec = None @@ -1048,7 +1097,9 @@ def forward(self, *args, **kwargs): fake_kwargs, equalities_inputs, original_signature, - ) = make_fake_inputs(mod, args, kwargs, dynamic_shapes) + ) = make_fake_inputs( + mod, args, kwargs, dynamic_shapes, _is_torch_jit_trace=_is_torch_jit_trace + ) fake_params_buffers = make_fake_params_buffers( fake_mode, _get_params_buffers(mod) @@ -1062,7 +1113,7 @@ def forward(self, *args, **kwargs): new_fake_constant_attrs, map_fake_to_real, ): - ep_non_strict = _export_non_strict( + aten_export_artifact = _export_to_aten_ir( patched_mod, new_fake_args, new_fake_kwargs, @@ -1071,16 +1122,17 @@ def forward(self, *args, **kwargs): pre_dispatch=pre_dispatch, transform=_tuplify_outputs, should_insert_runtime_assertion=not strict, + _is_torch_jit_trace=_is_torch_jit_trace, ) - # ep_non_strict.constants contains only fake script objects, we need to map them back - ep_non_strict.constants = { + # aten_export_artifact.constants contains only fake script objects, we need to map them back + aten_export_artifact.constants = { fqn: map_fake_to_real[obj] if isinstance(obj, FakeScriptObject) else obj - for fqn, obj in ep_non_strict.constants.items() + for fqn, obj in aten_export_artifact.constants.items() } - ep_non_strict.gm.meta["inline_constraints"] = { + aten_export_artifact.gm.meta["inline_constraints"] = { k: v for k, v in fake_mode.shape_env.var_to_range.items() if free_unbacked_symbols(k) @@ -1088,25 +1140,26 @@ def forward(self, *args, **kwargs): num_lifted = len( [ spec - for spec in ep_non_strict.sig.input_specs + for spec in aten_export_artifact.sig.input_specs if spec.kind != InputKind.USER_INPUT ] ) try: produce_guards_and_solve_constraints( fake_mode, - ep_non_strict.gm, + aten_export_artifact.gm, equalities_inputs, original_signature, _disable_forced_specializations=_disable_forced_specializations, + _is_torch_jit_trace=_is_torch_jit_trace, ) except (ConstraintViolationError, ValueRangeError) as e: - raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: TRY200 + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 combined_args = _combine_args(mod, args, kwargs) range_constraints = make_constraints( fake_mode, - ep_non_strict.gm, + aten_export_artifact.gm, combined_args, dynamic_shapes, num_lifted, @@ -1114,7 +1167,7 @@ def forward(self, *args, **kwargs): assert out_spec is not None - gm = ep_non_strict.gm + gm = aten_export_artifact.gm gm.meta["forward_arg_names"] = forward_arg_names module_call_signatures = { @@ -1141,25 +1194,30 @@ def forward(self, *args, **kwargs): node.replace_all_uses_with(new_node) gm.graph.erase_node(node) - res = CollectTracepointsPass(module_call_signatures, ep_non_strict.sig)(gm) + res = CollectTracepointsPass( + module_call_signatures, aten_export_artifact.sig + )(gm) assert res is not None gm = res.graph_module - _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants) + 
_rewrite_non_persistent_buffers( + mod, aten_export_artifact.sig, aten_export_artifact.constants + ) _verify_nn_module_stack(gm) _verify_stack_trace(gm) - _verify_placeholder_names(gm, ep_non_strict.sig) + if not _is_torch_jit_trace: + _verify_placeholder_names(gm, aten_export_artifact.sig) exported_program = ExportedProgram( root=gm, graph=gm.graph, - graph_signature=ep_non_strict.sig, + graph_signature=aten_export_artifact.sig, state_dict=original_state_dict, range_constraints=range_constraints, module_call_graph=_make_module_call_graph( _EXPORT_MODULE_HIERARCHY, orig_in_spec, out_spec, module_call_signatures ), example_inputs=(args, kwargs), - constants=ep_non_strict.constants, + constants=aten_export_artifact.constants, ) return exported_program @@ -1256,7 +1314,7 @@ def forward(self, *args, **kwargs): # NOTE: graph module expects only positional args constant_attrs = _gather_constant_attrs(mod) - ep_non_strict = _export_non_strict( + aten_export_artifact = _export_to_aten_ir( gm_torch_level, _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs), {}, @@ -1266,9 +1324,9 @@ def forward(self, *args, **kwargs): should_insert_runtime_assertion=not strict, ) - gm = ep_non_strict.gm - export_graph_signature = ep_non_strict.sig - constants = ep_non_strict.constants + gm = aten_export_artifact.gm + export_graph_signature = aten_export_artifact.sig + constants = aten_export_artifact.constants # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes for metadata in params_buffers_to_node_meta.values(): @@ -1324,15 +1382,17 @@ def forward(self, *args, **kwargs): _rewrite_dynamo_tensor_constants( orig_mod_buffers=set(mod.buffers()), traced_mod_buffers=dict(gm_torch_level.named_buffers()), - graph_signature=ep_non_strict.sig, - constants=ep_non_strict.constants, + graph_signature=aten_export_artifact.sig, + constants=aten_export_artifact.constants, ) # 2. Restore FQN of param/buffers param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level) _replace_param_buffer_names(param_buffer_table, export_graph_signature) # 3. Remove non-persistent buffers from the graph signature - _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants) + _rewrite_non_persistent_buffers( + mod, aten_export_artifact.sig, aten_export_artifact.constants + ) # 4. Rewrite constants to have the same FQN as the original module. 
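A quick illustration of the invariant that the FQN-restoration steps above preserve (a sketch using the public API; the module is arbitrary): parameters and buffers in the exported program keep the fully qualified names they had on the original nn.Module.

import torch
from torch.export import export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)
        self.register_buffer("scale", torch.tensor(2.0))

    def forward(self, x):
        return self.linear(x) * self.scale

ep = export(M(), (torch.randn(1, 3),))
print(sorted(ep.state_dict.keys()))
# ['linear.bias', 'linear.weight', 'scale'], the same FQNs as M().state_dict()
print([s.target for s in ep.graph_signature.input_specs])
# lifted param/buffer inputs point back at those FQNs; the user input has no target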
_remap_constants(constant_attrs, export_graph_signature, constants) diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 52ce64e4dcadc..2fdb7916eeebf 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -22,7 +22,7 @@ def _check_input_constraints_pre_hook(self, *args, **kwargs): flat_args_with_path, received_spec = pytree.tree_flatten_with_path(args) if received_spec != self._in_spec: - raise ValueError( # noqa: TRY200 + raise ValueError( # noqa: B904 "Trying to flatten user inputs with exported input tree spec: \n" f"{self._in_spec}\n" "but actually got inputs with tree spec of: \n" diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index eba8333323445..a4ed16e975b80 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -602,18 +602,20 @@ def f(t, *dynamic_shapes): return tree_map(f, tree, *dynamic_shapes, is_leaf=is_leaf) -def _combine_args(f, args, kwargs): +def _combine_args(f, args, kwargs, _is_torch_jit_trace=False): # combine args and kwargs following the signature of f, as it happens # in the body of f when called with *args, **kwargs if isinstance(f, ExportedProgram): f = f.module() - signature = ( - inspect.signature(f.forward) - if isinstance(f, torch.nn.Module) - else inspect.signature(f) - ) - kwargs = kwargs if kwargs is not None else {} - return signature.bind(*args, **kwargs).arguments + if not _is_torch_jit_trace: + signature = ( + inspect.signature(f.forward) + if isinstance(f, torch.nn.Module) + else inspect.signature(f) + ) + kwargs = kwargs if kwargs is not None else {} + return signature.bind(*args, **kwargs).arguments + return args class ShapesCollection: @@ -692,6 +694,7 @@ def _process_dynamic_shapes( args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + _is_torch_jit_trace=False, ) -> Optional[List[Constraint]]: from torch._dynamo.exc import UserError, UserErrorType @@ -720,7 +723,7 @@ def root_value(): if solution is not None: return int(solution[1]) # type: ignore[call-overload] else: - raise UserError( # noqa: TRY200 + raise UserError( # noqa: B904 UserErrorType.CONSTRAINT_VIOLATION, f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be " f"of the form {expr}, where {symbol} is an integer", @@ -858,7 +861,9 @@ def assoc_shape(t, dynamic_shape): _tree_map(assoc_shape, combined_args, dynamic_shapes) - combined_args = _combine_args(f, args, kwargs) + combined_args = _combine_args( + f, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace + ) if not isinstance(dynamic_shapes, dict): assert isinstance(dynamic_shapes, (tuple, list)) combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index f5635038c4e26..ffb3467055b31 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -448,7 +448,7 @@ def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs): res = pytree.tree_unflatten(res, self.call_spec.out_spec) except Exception: _, received_spec = pytree.tree_flatten(res) - raise error.InternalError( # noqa: TRY200 + raise error.InternalError( # noqa: B904 "Trying to flatten user outputs with exported output tree spec: \n" f"{self.call_spec.out_spec}\n" "but actually got outputs with tree spec of: \n" diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 95a9f568443e9..fa44b63067865 100644 --- 
a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -310,7 +310,7 @@ def __call__(self, obj, *args, **kwargs): _WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr, ) - raise e.with_traceback(None) # noqa: TRY200 + raise e.with_traceback(None) # noqa: B904 else: raise e diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 7b36918928d38..3952bb6525171 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -262,10 +262,14 @@ def _update_partition_map(node: Node, id: int): return [partition for partition in partitions_by_id.values() if partition.size() > 0] - def fuse_partitions(self, partitions: List[Partition]) -> GraphModule: + def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") -> GraphModule: logger.debug("Fusing partitions...") # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ] - return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions]) + return fuse_by_partitions( + self.graph_module, + [list(partition.nodes) for partition in partitions], + prefix=prefix, + ) # remove non-compute-ops that sits at the boundary of a partition. def remove_bookend_non_compute_ops(self, partitions: List[Partition]): @@ -323,7 +327,7 @@ def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: if len(remove_node) != 0: partition.nodes = partition.nodes - remove_node - def partition_and_fuse(self) -> GraphModule: + def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule: partitions = self.propose_partitions() - fused_gm = self.fuse_partitions(partitions) + fused_gm = self.fuse_partitions(partitions, prefix=prefix) return fused_gm diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 9d24162500ac9..6d050c78f7540 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -373,7 +373,7 @@ def _run_and_compare( self._store_outputs(a_result, b_result, submodule) except Exception as e: report.append(f"Exception raised when running {submod_name}: {e}") - raise FxNetMinimizerRunFuncError( # noqa: TRY200 + raise FxNetMinimizerRunFuncError( # noqa: B904 f"Exception raised when running {submod_name}: {e}" ) diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 8976690ed73a1..3423ea3dad5a4 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -218,11 +218,11 @@ def erase_nodes(gm: GraphModule, nodes: NodeList): @compatibility(is_backward_compatible=False) -def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule: +def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList], prefix: str = "fused_") -> GraphModule: for partition_id, nodes in enumerate(partitions): sorted_nodes = topo_sort(nodes) - submodule_name = "fused_" + str(partition_id) + submodule_name = prefix + str(partition_id) sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name) insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) diff --git a/torch/nn/attention/_flex_attention.py b/torch/nn/attention/_flex_attention.py index 6f158c20db6de..c56374fcbc40d 100644 --- a/torch/nn/attention/_flex_attention.py +++ b/torch/nn/attention/_flex_attention.py @@ -96,6 +96,8 @@ def score_mod( raise ValueError( "NYI: The target sequence length (L) of the query tensor must match the source sequence length (S) of 
the key tensor." ) + if query.size(-2) % 128 != 0: + raise ValueError("NYI: S and L must be a multiple of 128") if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support.") @@ -149,7 +151,7 @@ def _rel_causal( token_q: torch.Tensor, token_kv: torch.Tensor, ) -> torch.Tensor: - return torch.where(token_q <= token_kv, score + (token_q - token_kv), float("-inf")) + return torch.where(token_q >= token_kv, score + (token_q - token_kv), float("-inf")) def _generate_alibi_bias(num_heads: int): diff --git a/torch/nn/utils/_named_member_accessor.py b/torch/nn/utils/_named_member_accessor.py index 3a82b2b426aa0..e46318b0d3acb 100644 --- a/torch/nn/utils/_named_member_accessor.py +++ b/torch/nn/utils/_named_member_accessor.py @@ -147,7 +147,7 @@ def get_submodule(self, name: str) -> "torch.nn.Module": f"{module._get_name()} has no attribute `{attr}`" ) from ex if not isinstance(submodule, torch.nn.Module): - raise TypeError( # noqa: TRY200 + raise TypeError( # noqa: B904 f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module" ) self.memo[name] = submodule diff --git a/torch/onnx/_internal/fx/decomposition_skip.py b/torch/onnx/_internal/fx/decomposition_skip.py index 425e8604468b8..7fb971a3307a5 100644 --- a/torch/onnx/_internal/fx/decomposition_skip.py +++ b/torch/onnx/_internal/fx/decomposition_skip.py @@ -71,7 +71,7 @@ def register_custom_op(cls): new_op_qualname = f"{_NEW_OP_NAMESPACE}::{cls.new_op_name}" torch.library.define(new_op_qualname, cls.new_op_schema) torch.library.impl(new_op_qualname, "default", cls.replacement) - torch.library.impl_abstract(new_op_qualname, cls.abstract) + torch.library.register_fake(new_op_qualname, cls.abstract) @classmethod def replacement(cls, *args, **kwargs): diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index a87aadc81803c..f53f8b427e9f9 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -22,13 +22,6 @@ __all__ = ["ASGD", "asgd"] -def _to_tensor(x, device=None): - if not isinstance(x, torch.Tensor): - return torch.tensor(x, device=device) - - return x - - class ASGD(Optimizer): def __init__( self, @@ -264,9 +257,9 @@ def _single_tensor_asgd( mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t))) else: step = _get_value(step_t) - new_eta = _to_tensor(lr / ((1 + lambd * lr * step) ** alpha)) + new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha)) eta.copy_(new_eta) - new_mu = _to_tensor(1 / max(1, step - t0)) + new_mu = torch.as_tensor(1 / max(1, step - t0)) mu.copy_(new_mu) @@ -381,27 +374,23 @@ def _multi_tensor_asgd( torch._foreach_copy_(grouped_mus, new_mus) del new_mus - # update eta = lr / (1 + lambd * lr * step^alpha) - new_etas = torch._foreach_pow(grouped_state_steps, alpha) - torch._foreach_mul_(new_etas, lambd) + # update eta = lr / ((1 + lambd * lr * step)^alpha) + new_etas = torch._foreach_mul(grouped_state_steps, lambd) torch._foreach_mul_(new_etas, lr) torch._foreach_add_(new_etas, 1) + torch._foreach_pow_(new_etas, alpha) torch._foreach_reciprocal_(new_etas) torch._foreach_mul_(new_etas, lr) torch._foreach_copy_(grouped_etas, new_etas) else: - step = grouped_state_steps[0].item() - new_etas = [] - new_mus = [] - - for i in range(len(grouped_mus)): - new_eta = _to_tensor( - lr / (1 + lambd * lr * step**alpha), device=device - ) - new_etas.append(new_eta) - new_mu = _to_tensor(1 / max(1, step - t0), device=device) - new_mus.append(new_mu) - + new_etas = [ + torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device) + for step in 
grouped_state_steps + ] + new_mus = [ + torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device) + for step in grouped_state_steps + ] torch._foreach_copy_(grouped_etas, new_etas) torch._foreach_copy_(grouped_mus, new_mus) diff --git a/torch/serialization.py b/torch/serialization.py index df839408ee776..a7703b9964d0d 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1,4 +1,5 @@ import difflib +import functools import os import io import shutil @@ -31,7 +32,7 @@ STORAGE_KEY_SEPARATOR = ',' FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]] -MAP_LOCATION: TypeAlias = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]] +MAP_LOCATION: TypeAlias = Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]] STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage] IS_WINDOWS = sys.platform == "win32" @@ -58,6 +59,9 @@ 'LoadEndianness', 'get_default_load_endianness', 'set_default_load_endianness', + 'clear_safe_globals', + 'get_safe_globals', + 'add_safe_globals', ] @@ -147,6 +151,27 @@ def set_default_mmap_options(flags: int): f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}") _default_mmap_options = flags +def clear_safe_globals() -> None: + ''' + Clears the list of globals that are safe for ``weights_only`` load. + ''' + _weights_only_unpickler._clear_safe_globals() + +def get_safe_globals() -> List[Any]: + ''' + Returns the list of user-added globals that are safe for ``weights_only`` load. + ''' + return _weights_only_unpickler._get_safe_globals() + +def add_safe_globals(safe_globals: List[Any]) -> None: + ''' + Marks the given globals as safe for ``weights_only`` load. + + Args: + safe_globals (List[Any]): list of globals to mark as safe + ''' + _weights_only_unpickler._add_safe_globals(safe_globals) + def _is_zipfile(f) -> bool: # This is a stricter implementation than zipfile.is_zipfile(). # zipfile.is_zipfile() is True if the magic number appears anywhere in the @@ -252,14 +277,6 @@ def _cpu_tag(obj): return 'cpu' -def _cuda_tag(obj): - if obj.device.type == 'cuda': - return 'cuda:' + str(obj.device.index) - -def _hpu_tag(obj): - if obj.device.type == 'hpu': - return 'hpu:' + str(obj.device.index) - def _mps_tag(obj): if obj.device.type == 'mps': return 'mps' @@ -270,8 +287,9 @@ def _meta_tag(obj): return 'meta' -def _privateuse1_tag(obj): - backend_name = torch._C._get_privateuse1_backend_name() +def _backend_tag(backend_name, obj): + if backend_name == 'privateuse1': + backend_name = torch._C._get_privateuse1_backend_name() if obj.device.type == backend_name: if obj.device.index is None: return backend_name @@ -284,66 +302,6 @@ def _cpu_deserialize(obj, location): return obj -def validate_cuda_device(location): - device = torch.cuda._utils._get_device_index(location, True) - - if not torch.cuda.is_available(): - raise RuntimeError('Attempting to deserialize object on a CUDA ' - 'device but torch.cuda.is_available() is False. ' - 'If you are running on a CPU-only machine, ' - 'please use torch.load with map_location=torch.device(\'cpu\') ' - 'to map your storages to the CPU.') - device_count = torch.cuda.device_count() - if device >= device_count: - raise RuntimeError('Attempting to deserialize object on CUDA device ' - f'{device} but torch.cuda.device_count() is {device_count}. 
Please use ' - 'torch.load with map_location to map your storages ' - 'to an existing device.') - return device - - -def _cuda_deserialize(obj, location): - if location.startswith('cuda'): - device = validate_cuda_device(location) - if getattr(obj, "_torch_load_uninitialized", False): - with torch.cuda.device(device): - return torch.UntypedStorage(obj.nbytes(), device=torch.device(location)) - else: - return obj.cuda(device) - - -def validate_hpu_device(location): - hpu = getattr(torch, "hpu", None) - assert hpu is not None, "HPU device module is not loaded" - device = hpu._utils._get_device_index(location, optional=True) - - if not hpu.is_available(): - raise RuntimeError('Attempting to deserialize object on a HPU ' - 'device but torch.hpu.is_available() is False. ' - 'If you are running on a CPU-only machine, ' - 'please use torch.load with map_location=torch.device(\'cpu\') ' - 'to map your storages to the CPU.') - device_count = hpu.device_count() - if device >= device_count: - raise RuntimeError('Attempting to deserialize object on HPU device ' - f'{device} but torch.hpu.device_count() is {device_count}. Please use ' - 'torch.load with map_location to map your storages ' - 'to an existing device.') - return device - - -def _hpu_deserialize(obj, location): - if location.startswith('hpu'): - hpu = getattr(torch, "hpu", None) - assert hpu is not None, "HPU device module is not loaded" - device = validate_hpu_device(location) - if getattr(obj, "_torch_load_uninitialized", False): - with hpu.device(device): - return torch.UntypedStorage(obj.nbytes(), device=torch.device(location)) - else: - return obj.hpu(device) - - def _mps_deserialize(obj, location): if location.startswith('mps'): return obj.mps() @@ -354,18 +312,18 @@ def _meta_deserialize(obj, location): return torch.UntypedStorage(obj.nbytes(), device='meta') -def _validate_privateuse1_device(location, backend_name): +def _validate_device(location, backend_name): ''' - Check whether the device index of privateuse1 is valid + Check whether the device index of specified backend is valid - Register a device_module of privateuse1 by torch._register_device_module. - Implement the following methods in device_module like cuda: - device_module._utils._get_device_index(location, True), + In case of privateuse1 backend, your must first register a device_module for + privateuse1 using torch._register_device_module. Implement the following + methods in device_module like cuda: device_module._utils._get_device_index(location, True), device_module.device_count(). Args: location: string of device - backend_name: the name of privateuse1, which can be renamed + backend_name: the backend name or the name of privateuse1, which can be renamed Returns: device_index: int @@ -378,6 +336,7 @@ def _validate_privateuse1_device(location, backend_name): device_module = getattr(torch, backend_name) if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'): device_index = device_module._utils._get_device_index(location, True) + device = torch.device(backend_name, device_index) else: device = torch.device(location) device_index = device.index if device.index else 0 @@ -394,29 +353,32 @@ def _validate_privateuse1_device(location, backend_name): f'{device_index} but torch.{backend_name}.device_count() is {device_count}. 
' 'Please use torch.load with map_location to map your storages ' 'to an existing device.') - return device_index + return device + + +def validate_cuda_device(location): + return _validate_device(location, 'cuda').index + +def validate_hpu_device(location): + return _validate_device(location, 'hpu').index -def _privateuse1_deserialize(obj, location): - backend_name = torch._C._get_privateuse1_backend_name() + +def _deserialize(backend_name, obj, location): + if backend_name == 'privateuse1': + backend_name = torch._C._get_privateuse1_backend_name() if location.startswith(backend_name): - if not hasattr(obj, backend_name): - raise RuntimeError(f'Attempting to load the storages to the {backend_name.upper()} device ' - f'but torch.storage._StorageBase.{backend_name}() or ' - f'torch.storage.TypedStorage.{backend_name}() is not generated. ' - 'Please use torch.utils.generate_methods_for_privateuse1_backend ' - f'to generate storage.{backend_name}() method first.') - device_index = _validate_privateuse1_device(location, backend_name) - return getattr(obj, backend_name)(device_index) + device = _validate_device(location, backend_name) + return obj.to(device=device) register_package(10, _cpu_tag, _cpu_deserialize) -register_package(20, _cuda_tag, _cuda_deserialize) +register_package(20, functools.partial(_backend_tag, 'cuda'), functools.partial(_deserialize, 'cuda')) register_package(21, _mps_tag, _mps_deserialize) register_package(22, _meta_tag, _meta_deserialize) -register_package(23, _privateuse1_tag, _privateuse1_deserialize) -register_package(24, _hpu_tag, _hpu_deserialize) - +register_package(23, functools.partial(_backend_tag, 'privateuse1'), functools.partial(_deserialize, 'privateuse1')) +register_package(24, functools.partial(_backend_tag, 'hpu'), functools.partial(_deserialize, 'hpu')) +register_package(25, functools.partial(_backend_tag, 'xpu'), functools.partial(_deserialize, 'xpu')) def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]): for _, tagger, _ in _package_registry: @@ -1014,7 +976,9 @@ def load( UNSAFE_MESSAGE = ( "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`" " will likely succeed, but it can result in arbitrary code execution." - "Do it only if you get the file from a trusted source. WeightsUnpickler error: " + " Do it only if you get the file from a trusted source. Alternatively, to load" + " with `weights_only` please check the recommended steps in the following error message." + " WeightsUnpickler error: " ) # Add ability to force safe only weight loads via environment variable if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']: diff --git a/torch/storage.py b/torch/storage.py index 306dd99a93add..32070783f4940 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -1,7 +1,7 @@ import io import torch -from ._utils import _type, _cuda, _hpu +from ._utils import _type, _to from torch.types import Storage from typing import cast, Any, Dict as _Dict, Optional as _Optional, TypeVar, Type, Union import copy @@ -38,8 +38,37 @@ def size(self) -> int: return self.nbytes() def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 - def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 - def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... 
# type: ignore[empty-body, misc, type-var] # noqa: E704 + + def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[type-var] # noqa: E704 + """Returns a copy of this object in CUDA memory. + + If this object is already in CUDA memory and on the correct device, then + no copy is performed and the original object is returned. + + Args: + device (int): The destination GPU id. Defaults to the current device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. + """ + device2 = torch.device('cuda', device) if device else torch.device('cuda') + return self.to(device=device2, non_blocking=non_blocking) + + def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[type-var] # noqa: E704 + """Returns a copy of this object in HPU memory. + + If this object is already in HPU memory and on the correct device, then + no copy is performed and the original object is returned. + + Args: + device (int): The destination HPU id. Defaults to the current device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. + """ + device2 = torch.device('hpu', device) if device else torch.device('hpu') + return self.to(device=device2, non_blocking=non_blocking) + def element_size(self) -> int: ... # type: ignore[empty-body, type-var] # noqa: E704 def get_device(self) -> int: @@ -153,6 +182,9 @@ def _to(self, dtype): storage = storage.clone() return storage + def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var] # noqa: E704 + return _to(self, device, non_blocking) + def double(self): """Casts this storage to double type.""" return self._to(torch.double) @@ -382,8 +414,6 @@ def _load_from_bytes(b): _StorageBase.type = _type # type: ignore[assignment] -_StorageBase.cuda = _cuda # type: ignore[assignment] -_StorageBase.hpu = _hpu # type: ignore[assignment] @lru_cache(maxsize=None) @@ -812,20 +842,27 @@ def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> Unio else: return self._untyped_storage.type(dtype, non_blocking) - def cuda(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var] + def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[misc, type-var] _warn_typed_storage_removal() if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: raise RuntimeError("Cannot create CUDA storage with quantized dtype") - cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs) + cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking) return self._new_wrapped_storage(cuda_storage) - def hpu(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var] + def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[misc, type-var] _warn_typed_storage_removal() if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: raise RuntimeError("Cannot create HPU storage with quantized dtype") - hpu_storage: torch.UntypedStorage = self._untyped_storage.hpu(device, non_blocking, **kwargs) + hpu_storage: torch.UntypedStorage = self._untyped_storage.hpu(device, non_blocking) return self._new_wrapped_storage(hpu_storage) + def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var] 
+ _warn_typed_storage_removal() + if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: + raise RuntimeError(f"Cannot create {device.type.upper()} storage with quantized dtype") + to_storage: torch.UntypedStorage = self._untyped_storage.to(device=device, non_blocking=non_blocking) + return self._new_wrapped_storage(to_storage) + def element_size(self): _warn_typed_storage_removal() return self._element_size() @@ -1209,8 +1246,9 @@ def _get_legacy_storage_class(self): return None TypedStorage.type.__doc__ = _type.__doc__ -TypedStorage.cuda.__doc__ = _cuda.__doc__ -TypedStorage.hpu.__doc__ = _hpu.__doc__ +TypedStorage.cuda.__doc__ = _StorageBase.cuda.__doc__ +TypedStorage.hpu.__doc__ = _StorageBase.hpu.__doc__ +TypedStorage.to.__doc__ = _to.__doc__ class _LegacyStorageMeta(type): dtype: torch.dtype diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 5a66923373f74..5abacf2df1d61 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -123,6 +123,8 @@ def __init__( supported_impls: Tuple[str] = ("foreach", "differentiable"), # the optim supports passing in sparse gradients as well as dense grads supports_sparse: bool = False, + # the optimizer constructor supports passing in capturable as a kwarg + has_capturable_arg: bool = False, # the optim only supports one config: sparse grads w/ dense params, see SparseAdam only_supports_sparse_grads: bool = False, # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests, @@ -147,6 +149,7 @@ def __init__( self.scheduler_inputs = scheduler_inputs self.supported_impls = supported_impls self.supports_sparse = supports_sparse + self.has_capturable_arg = has_capturable_arg self.metadata_for_sparse = metadata_for_sparse self.only_supports_sparse_grads = only_supports_sparse_grads self.supports_complex = supports_complex @@ -311,10 +314,11 @@ def optim_inputs_func_adadelta(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), + OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize", + desc="maximize, weight_decay", ), OptimizerInput( params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho" @@ -528,9 +532,14 @@ def optim_inputs_func_adamax(device, dtype=None): ), OptimizerInput( params=None, - kwargs={"weight_decay": 0.1, "maximize": True}, + kwargs={"maximize": True}, desc="maximize", ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize, weight_decay", + ), ] + (cuda_supported_configs if "cuda" in str(device) else []) @@ -590,6 +599,7 @@ def optim_inputs_func_asgd(device, dtype=None): ] return [ OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"), OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"), OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"), OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), @@ -682,14 +692,20 @@ def optim_inputs_func_nadam(device, dtype=None): ), OptimizerInput( params=None, - kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3}, + kwargs={ + "weight_decay": 0.1, + }, desc="weight_decay", ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3}, + desc="weight_decay, 
momentum_decay", + ), OptimizerInput( params=None, kwargs={ "weight_decay": 0.1, - "momentum_decay": 6e-3, "decoupled_weight_decay": True, }, desc="decoupled_weight_decay", @@ -817,11 +833,26 @@ def optim_inputs_func_rmsprop(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), + OptimizerInput( + params=None, + kwargs={ + "maximize": True, + }, + desc="maximize", + ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "centered": True}, desc="centered", ), + OptimizerInput( + params=None, + kwargs={ + "maximize": True, + "weight_decay": 0.1, + }, + desc="maximize, weight_decay", + ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1}, @@ -835,7 +866,7 @@ def optim_inputs_func_rmsprop(device, dtype=None): "momentum": 0.1, "maximize": True, }, - desc="maximize", + desc="maximize, centered, weight_decay, w/ momentum", ), ] + (cuda_supported_configs if "cuda" in str(device) else []) @@ -906,7 +937,15 @@ def optim_inputs_func_sgd(device, dtype=None): OptimizerInput( params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr" ), + OptimizerInput( + params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay" + ), OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), OptimizerInput( params=None, kwargs={"momentum": 0.9, "dampening": 0.5}, @@ -915,18 +954,13 @@ def optim_inputs_func_sgd(device, dtype=None): OptimizerInput( params=None, kwargs={"momentum": 0.9, "weight_decay": 0.1}, - desc="non-zero weight_decay", + desc="weight_decay w/ momentum", ), OptimizerInput( params=None, kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1}, desc="nesterov", ), - OptimizerInput( - params=None, - kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize", - ), ] @@ -1096,6 +1130,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adadelta, optim_error_inputs_func=optim_error_inputs_func_adadelta, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1231,6 +1266,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_error_inputs_func=optim_error_inputs_func_adam, supported_impls=("foreach", "differentiable", "fused"), supports_fused_on=("cpu", "cuda"), + has_capturable_arg=True, decorators=( # Expected floating point error between fused and compiled forloop DecorateInfo( @@ -1297,6 +1333,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adamax, optim_error_inputs_func=optim_error_inputs_func_adamax, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1347,6 +1384,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_error_inputs_func=optim_error_inputs_func_adamw, supported_impls=("foreach", "differentiable", "fused"), supports_fused_on=("cpu", "cuda"), + has_capturable_arg=True, decorators=( # Expected error between compiled forloop and fused optimizers DecorateInfo( @@ -1413,6 +1451,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_asgd, optim_error_inputs_func=optim_error_inputs_func_asgd, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, 
skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1450,6 +1489,13 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "TestOptimRenewed", "test_defaults_changed_to_foreach", ), + DecorateInfo( + unittest.skip( + "ASGD internally changes the weights even with zero grad" + ), + "TestOptimRenewed", + "test_step_is_noop_for_zero_grads", + ), ), ), OptimizerInfo( @@ -1498,6 +1544,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_nadam, optim_error_inputs_func=optim_error_inputs_func_nadam, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1553,6 +1600,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_radam, optim_error_inputs_func=optim_error_inputs_func_radam, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1598,6 +1646,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_rmsprop, optim_error_inputs_func=optim_error_inputs_func_rmsprop, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1647,6 +1696,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_rprop, optim_error_inputs_func=optim_error_inputs_func_rprop, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # Rprop doesn't update for non-contiguous, see #118117 diff --git a/torch/testing/_internal/custom_op_db.py b/torch/testing/_internal/custom_op_db.py index 3177fb9c8bb5d..ee170cc360586 100644 --- a/torch/testing/_internal/custom_op_db.py +++ b/torch/testing/_internal/custom_op_db.py @@ -458,7 +458,7 @@ def source1_fake(x): lib.define("source2(Tensor x) -> Tensor") -@torch.library.impl_abstract("_torch_testing::source2", lib=lib) +@torch.library.register_fake("_torch_testing::source2", lib=lib) def _(x): return x.clone() @@ -467,7 +467,7 @@ def _(x): def source3_fake(x): return x.clone() -torch.library.impl_abstract("_torch_testing::source3", source3_fake, lib=lib) +torch.library.register_fake("_torch_testing::source3", source3_fake, lib=lib) @torch.library.custom_op("_torch_testing::source4", mutates_args=()) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 1805134130936..b9873b9950fae 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -48,6 +48,10 @@ _verify_param_shape_across_processes, _sync_module_states, ) +from torch.profiler import ( + ExecutionTraceObserver, + ProfilerActivity, +) from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars, _MixedPrecision @@ -6867,7 +6871,20 @@ def test_ddp_grad_div_uneven_inputs(self): net.zero_grad() torch.cuda.synchronize(device=self.rank) - def _test_ddp_profiling(self, profiler_ctx): + def _test_ddp_profiling(self, profiler_ctx, profiler_ctx2=None): + """Runs DDP based model training and captures profiles. + This test will do two profiler runs. + 1. An inital basic run to check if profiler events are correctly captured. + 2. 
A second profiling pass after running some iterations of DDP, to check robustness of thread local state. + + args + profiler_ctx : Profiler context manager for pass 1 + profiler_ctx2 : Profiler context manager for pass 2. + This can be left out as None, in which case a deepcopy + of profiler_ctx is used. + Returns: + prof: Instantiated profiler object that can be used for post analysis. + """ batch = 3 dim = 10 num_iters = 6 @@ -6878,7 +6895,8 @@ def _test_ddp_profiling(self, profiler_ctx): model.cuda(self.rank), device_ids=[self.rank], ) - profiler_ctx_copy = copy.deepcopy(profiler_ctx) + if profiler_ctx2 is None: + profiler_ctx2 = copy.deepcopy(profiler_ctx) with profiler_ctx as prof: for i in range(num_iters): @@ -6913,7 +6931,7 @@ def _test_ddp_profiling(self, profiler_ctx): loss = net(inp).sum() loss.backward() # Now enable the profiler. - with profiler_ctx_copy as prof: + with profiler_ctx2 as prof: loss = net(inp).sum() loss.backward() @@ -6971,6 +6989,90 @@ def test_ddp_profiling_torch_profiler(self): self.assertEqual(a1["Out msg nelems"], 1, msg=f"{a1}") self.assertEqual(a1["dtype"], "Int", msg=f"{a1}") + def _validate_execution_trace_nccl(self, et_file: str) -> None: + """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms" + We test for basic fields in theese nodes in the Execution Trace. + """ + with open(et_file) as f: + et = json.load(f) + + nccl_meta_nodes = [n for n in et["nodes"] if n["name"] == "record_param_comms"] + self.assertEqual(len(nccl_meta_nodes), 3) + per_coll_meta = defaultdict(list) + + # Sanity check NCCL metadata nodes + for n in nccl_meta_nodes: + attrs_list = n.get("attrs", []) + self.assertGreater(len(attrs_list), 0) + attrs = {a["name"]: a["value"] for a in attrs_list} + + collname = attrs.get("collective_name", "") + self.assertNotEqual(collname, "") + self.assertNotEqual(attrs.get("dtype", ""), "") + + per_coll_meta[collname].append(attrs) + if collname in {"wait"}: + continue + + self.assertEqual(attrs["pg_name"], "0") # yes this is a string + self.assertEqual(attrs["pg_desc"], "default_pg") + self.assertEqual(attrs["pg_size"], 2) + + self.assertGreaterEqual(attrs.get("in_msg_nelems", -1), 0) + self.assertGreaterEqual(attrs.get("out_msg_nelems", -1), 0) + self.assertTrue("in_split_size" in attrs.keys()) + self.assertTrue("out_split_size" in attrs.keys()) + self.assertEqual(attrs.get("global_rank_start", -1), 0) + self.assertEqual(attrs.get("global_rank_stride", -1), 1) + + # print(per_coll_meta) + self.assertEqual(len(per_coll_meta["allreduce"]), 2) + self.assertEqual(len(per_coll_meta["wait"]), 1) + + # check allreduce message sizes + a0 = per_coll_meta["allreduce"][0] + self.assertEqual(a0["out_msg_nelems"], 100, msg=f"{a0}") + self.assertEqual(a0["dtype"], "Float", msg=f"{a0}") + a1 = per_coll_meta["allreduce"][1] + self.assertEqual(a1["out_msg_nelems"], 1, msg=f"{a1}") + self.assertEqual(a1["dtype"], "Int", msg=f"{a1}") + + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang") + @skip_but_pass_in_sandcastle_if( + IS_MACOS or IS_WINDOWS, + "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124", + ) + @unittest.skipIf(BACKEND != "nccl", "Tests nccl metadata primarily.") + def test_ddp_profiling_execution_trace(self): + self.assertEqual(dist.get_backend(), "nccl") + # Create a temp file to save execution trace data + fp = tempfile.NamedTemporaryFile("w+t", 
suffix=".et.json", delete=False) + fp.close() + et_file = fp.name + + et = ExecutionTraceObserver().register_callback(et_file) + + # first profiler context need not have ET + torch_profiler_ctx1 = torch.profiler.profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) + # collect ET in second profiler pass + torch_profiler_ctx2 = torch.profiler.profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + execution_trace_observer=et + ) + prof = self._test_ddp_profiling( + profiler_ctx=torch_profiler_ctx1, + profiler_ctx2=torch_profiler_ctx2, + ) + + print(f"Execution trace saved at {fp.name}") + self._validate_execution_trace_nccl(et_file) + + @skip_if_lt_x_gpu(2) @skip_but_pass_in_sandcastle_if( BACKEND not in DistTestCases.backend_feature["ddp"], diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 4772fb42a9631..1602c1ef65625 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -82,7 +82,7 @@ def foo_impl_cuda(x, z): return x, z, x + z -@torch.library.impl_abstract("testlib::mutating_custom_op") +@torch.library.register_fake("testlib::mutating_custom_op") def foo_impl_abstract(x, z): return x, z, x + z @@ -118,9 +118,9 @@ def score_mod(score, b, h, m, n): return score + h yield SampleInput( - make_arg(2, 2, 64, 8, low=0.1, high=2), - make_arg(2, 2, 64, 8, low=0.1, high=2), - make_arg(2, 2, 64, 8, low=0.1, high=2), + make_arg(2, 2, 128, 8, low=0.1, high=2), + make_arg(2, 2, 128, 8, low=0.1, high=2), + make_arg(2, 2, 128, 8, low=0.1, high=2), score_mod, ) diff --git a/torch/testing/_internal/logging_tensor.py b/torch/testing/_internal/logging_tensor.py index 5ddd537474404..8b7faf45b3c3c 100644 --- a/torch/testing/_internal/logging_tensor.py +++ b/torch/testing/_internal/logging_tensor.py @@ -11,6 +11,7 @@ import functools from torch._C._profiler import gather_traceback, symbolize_tracebacks +logger = logging.getLogger("LoggingTensor") _dtype_abbrs = { torch.bfloat16: "bf16", @@ -135,8 +136,8 @@ def emit(self, record): if self.tracebacks_list is not None: self.tracebacks_list.append(record.traceback) -def log_input(name: str, var: object): - logging.getLogger("LoggingTensor").info("input", (name,), {}, var) # noqa: PLE1205 +def log_input(name: str, var: object) -> None: + logger.info("input", (name,), {}, var) # noqa: PLE1205 class GatherTraceback(logging.Filter): def __init__(self, python=True, script=True, cpp=False): @@ -151,7 +152,6 @@ def filter(self, record): @contextlib.contextmanager def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]: collect_traceback = python_tb or script_tb or cpp_tb - logger = logging.getLogger("LoggingTensor") log_list: List[str] = [] tracebacks_list: List[str] = [] handler = LoggingTensorHandler( diff --git a/torch/testing/_internal/opinfo/definitions/sparse.py b/torch/testing/_internal/opinfo/definitions/sparse.py index e6f0ad0e6f514..3e1f816d9f73f 100644 --- a/torch/testing/_internal/opinfo/definitions/sparse.py +++ b/torch/testing/_internal/opinfo/definitions/sparse.py @@ -25,7 +25,7 @@ def _check_fail(sample): except sample.error_type: pass except Exception as msg: - raise AssertionError( # noqa: TRY200 + raise AssertionError( # noqa: B904 f"{op_info.name} on {sample.sample_input=} expected exception " f"{sample.error_type}: {sample.error_regex}, got {type(msg).__name__}: {msg}" ) @@ -39,7 +39,7 @@ def _check_success(sample): try: op_info(sample.input, *sample.args, **sample.kwargs) except Exception as msg: 
- raise AssertionError( # noqa: TRY200 + raise AssertionError( # noqa: B904 f"{op_info.name} on {sample=} expected to succeed " f", got {type(msg).__name__}: {msg}" ) diff --git a/torch/utils/_foreach_utils.py b/torch/utils/_foreach_utils.py index 840df0432b53a..6f8a9b5b7e237 100644 --- a/torch/utils/_foreach_utils.py +++ b/torch/utils/_foreach_utils.py @@ -15,6 +15,8 @@ def _get_fused_kernels_supported_devices() -> List[str]: TensorListList: TypeAlias = List[List[Optional[Tensor]]] Indices: TypeAlias = List[int] +_foreach_supported_types = [torch.Tensor] + # This util function splits tensors into groups by device and dtype, which is useful before sending # tensors off to a foreach implementation, which requires tensors to be on one device and dtype. @@ -44,4 +46,4 @@ def _device_has_foreach_support(device: torch.device) -> bool: def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool: - return _device_has_foreach_support(device) and all(t is None or type(t) == torch.Tensor for t in tensors) + return _device_has_foreach_support(device) and all(t is None or type(t) in _foreach_supported_types for t in tensors) diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 427333b07c16d..e8c4a57d84c80 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1,11 +1,21 @@ +import math + import sympy from sympy import S from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or -import math __all__ = [ - "FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", "Pow", "TrueDiv", - "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", "Round", "RoundDecimal", + "FloorDiv", + "ModularIndexing", + "CleanDiv", + "CeilDiv", + "Pow", + "TrueDiv", + "LShift", + "RShift", + "IsNonOverlappingAndDenseIndicator", + "Round", + "RoundDecimal", ] @@ -21,6 +31,7 @@ class FloorDiv(sympy.Function): 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. 2. 
Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) """ + nargs = (2,) precedence = 50 # precedence of mul # noqa: F811 @@ -53,11 +64,14 @@ def _eval_is_integer(self): @classmethod def eval(cls, base, divisor): def check_supported_type(x): - if (x.is_integer is False and x.is_real is False and x.is_complex) or x.is_Boolean: + if ( + x.is_integer is False and x.is_real is False and x.is_complex + ) or x.is_Boolean: raise TypeError( f"unsupported operand type(s) for //: " f"'{type(base).__name__}' and '{type(divisor).__name__}'" - f", expected integer or real") + f", expected integer or real" + ) check_supported_type(base) check_supported_type(divisor) @@ -77,7 +91,9 @@ def check_supported_type(x): return sympy.Mul(base, -1) if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): return base // divisor - if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance(divisor, (sympy.Integer, sympy.Float)): + if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance( + divisor, (sympy.Integer, sympy.Float) + ): return sympy.floor(base / divisor) if isinstance(base, FloorDiv): return FloorDiv(base.args[0], base.args[1] * divisor) @@ -125,7 +141,9 @@ def eval(cls, base, divisor, modulus): gcd = sympy.gcd(base, divisor) if gcd != 1: return ModularIndexing( - sympy.simplify(base / gcd), sympy.simplify(divisor / gcd), modulus + sympy.simplify(base / gcd), + sympy.simplify(divisor / gcd), + modulus, ) except sympy.PolynomialError: pass # https://github.com/pytorch/pytorch/issues/108276 @@ -178,6 +196,7 @@ def eval(cls, c, p, q): elif c == sympy.false: return q + class Mod(sympy.Function): """ We maintain this so that we avoid SymPy correctness issues, such as: @@ -263,16 +282,17 @@ class LShift(sympy.Function): @classmethod def eval(cls, base, shift): if shift < 0: - raise ValueError('negative shift count') - return base * 2 ** shift + raise ValueError("negative shift count") + return base * 2**shift class RShift(sympy.Function): @classmethod def eval(cls, base, shift): if shift < 0: - raise ValueError('negative shift count') - return base // 2 ** shift + raise ValueError("negative shift count") + return base // 2**shift + # Overloaded to be compatible with regular Python. # https://github.com/pytorch/pytorch/issues/90900 @@ -284,7 +304,8 @@ def eval(cls, base, exp): elif base.is_zero and exp < 0: raise ZeroDivisionError(f"{base} cannot be raised to a negative power") else: - return base ** exp + return base**exp + # Overloaded to be compatible with regular Python. # https://github.com/pytorch/pytorch/issues/90900 @@ -317,13 +338,14 @@ def eval(cls, *args): # in dim 0. if all(isinstance(a, sympy.Integer) for a in args): # sym_node imported in torch.__init__. 
Local import to avoid an import cycle - from torch.fx.experimental.symbolic_shapes import eval_is_non_overlapping_and_dense + from torch.fx.experimental.symbolic_shapes import ( + eval_is_non_overlapping_and_dense, + ) size_args = args[0:dim] stride_args = args[dim:] return eval_is_non_overlapping_and_dense( - [int(a) for a in size_args], - [int(a) for a in stride_args] + [int(a) for a in size_args], [int(a) for a in stride_args] ) return None @@ -361,7 +383,11 @@ def eval(cls, number, ndigits): if number.is_integer and ndigits >= 0: return number elif isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): - value_type, output_type = (int, sympy.Integer) if isinstance(number, sympy.Integer) else (float, sympy.Float) + value_type, output_type = ( + (int, sympy.Integer) + if isinstance(number, sympy.Integer) + else (float, sympy.Float) + ) return output_type(round(value_type(number), int(ndigits))) @@ -401,6 +427,7 @@ def eval(cls, a): return OpaqueUnaryFn + # Keep in sync with math_op_names in torch/fx/experimental/sym_node.py OpaqueUnaryFn_sqrt = make_opaque_unary_fn("sqrt") OpaqueUnaryFn_cos = make_opaque_unary_fn("cos") diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index eae126b1b4dcd..504fe757d4f2c 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -137,8 +137,8 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None: try: if not sympy_generic_le(lower, upper): raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]") - except TypeError: - raise TypeError(f"Could not compare {lower} <= {upper}") # noqa: TRY200 + except TypeError as e: + raise TypeError(f"Could not compare {lower} <= {upper}") from e # Because this is a frozen class object.__setattr__(self, "lower", lower) object.__setattr__(self, "upper", upper) diff --git a/torch/utils/_traceback.py b/torch/utils/_traceback.py index fa73b9f41cd66..9f4d04c55105f 100644 --- a/torch/utils/_traceback.py +++ b/torch/utils/_traceback.py @@ -128,7 +128,7 @@ def report_compile_source_on_error(): tb.tb_next = tb_next tb_next = tb - raise exc.with_traceback(tb_next) # noqa: TRY200 + raise exc.with_traceback(tb_next) # noqa: B904 def shorten_filename(fn, *, base=None): """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.""" diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index cd281bc0d3fca..c646ce0c0c110 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -916,7 +916,7 @@ def add_embedding( "warning: Embedding dir exists, did you set global_step for add_embedding()?" ) else: - raise FileExistsError( + raise NotADirectoryError( f"Path: `{save_path}` exists, but is a file. Cannot proceed." 
) else: diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index f77527a156beb..4a300c3cc3010 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -25,6 +25,8 @@ "aten.avg_pool2d.default", "aten.avg_pool3d_backward.default", "aten.avg_pool3d.default", + "aten.bernoulli_.float", + "aten.bernoulli_.Tensor", "aten.bmm.out", "aten.bucketize.Tensor", "aten.cat.default", diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index 1f99e3a9f3fae..f123bc879cd34 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -34,6 +34,7 @@ BaseTy.Layout: "int32_t", # Represent enum as int BaseTy.MemoryFormat: "int32_t", # Represent enum as int BaseTy.ScalarType: "int32_t", # Represent enum as int + BaseTy.Generator: "AtenGeneratorHandle", } base_type_to_aten_type = { @@ -48,6 +49,7 @@ BaseTy.Layout: "c10::Layout", BaseTy.MemoryFormat: "c10::MemoryFormat", BaseTy.ScalarType: "c10::ScalarType", + BaseTy.Generator: "at::Generator", } base_type_to_callsite_expr = { @@ -62,6 +64,7 @@ BaseTy.Layout: "static_cast", BaseTy.MemoryFormat: "static_cast", BaseTy.ScalarType: "static_cast", + BaseTy.Generator: "*generator_handle_to_generator_pointer", } @@ -89,7 +92,7 @@ def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str ], ) else: - # TODO: BaseTy.Dimname, BaseTy.Generator, etc. + # TODO: BaseTy.Dimname, etc. raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}") elif isinstance(typ, OptionalType): c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name( @@ -246,18 +249,18 @@ def gen_declaration_and_definition( return declaration_definition_cache[(func_name, device, backend_call)] if schema.is_out_fn(): - # out_variant has out arguments in the front, and it's ok to ignore return value + # out_variant has out arguments in the front, and it's ok to ignore return values # because C shim functions only return AOTITorchError - # Somehow at::native out-variant functions have out arguments in the back args, callsite_exprs = gen_arguments( - [*schema.arguments.flat_non_out, *schema.arguments.out] - if "at::native" in backend_call - else [*schema.arguments.out, *schema.arguments.flat_non_out], + [*schema.arguments.out, *schema.arguments.flat_non_out] ) ret_assignments: List[str] = [] else: args, callsite_exprs = gen_arguments(schema.arguments.flat_all) - ret_declarations, ret_assignments = gen_returns(schema) + # ignore return values for inplace ops + ret_declarations, ret_assignments = ( + ([], []) if schema.name.name.inplace else gen_returns(schema) + ) args.extend(ret_declarations) declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
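Usage sketch for the new clear_safe_globals / get_safe_globals / add_safe_globals helpers added to torch/serialization.py, which manage the allow-list consulted by weights_only loading. This is a minimal sketch, not part of the patch: the Config class and the checkpoint path are illustrative assumptions, and it assumes the class pickles through the standard reduce path handled by the weights-only unpickler.

    import torch

    class Config:                      # illustrative custom class (assumption)
        def __init__(self, lr: float) -> None:
            self.lr = lr

    torch.save({"cfg": Config(0.1), "w": torch.zeros(2)}, "ckpt.pt")

    # Without allow-listing, the weights_only unpickler rejects the custom class.
    try:
        torch.load("ckpt.pt", weights_only=True)
    except Exception as err:
        print("weights_only load failed:", err)

    # Allow-list the class, reload, then inspect and clear the user-added globals.
    torch.serialization.add_safe_globals([Config])
    state = torch.load("ckpt.pt", weights_only=True)
    print(torch.serialization.get_safe_globals())   # now contains Config
    torch.serialization.clear_safe_globals()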
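Several files in this patch (torch/onnx/_internal/fx/decomposition_skip.py, torch/testing/_internal/custom_op_db.py, torch/testing/_internal/hop_db.py) switch from torch.library.impl_abstract to torch.library.register_fake, the newer name for registering the FakeTensor/meta rule of a custom op. A sketch of the same define/impl/register_fake pattern the patch uses, with a hypothetical demo::twice op that is not part of the patch:

    import torch

    qualname = "demo::twice"           # hypothetical namespace::name (assumption)
    torch.library.define(qualname, "(Tensor x) -> Tensor")

    # Eager kernel for the op.
    torch.library.impl(qualname, "default", lambda x: x * 2)

    # Fake/meta rule used under FakeTensor tracing; this is what impl_abstract
    # registered before the rename to register_fake.
    torch.library.register_fake(qualname, lambda x: torch.empty_like(x))

    x = torch.randn(3)
    assert torch.equal(torch.ops.demo.twice(x), x * 2)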
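The _multi_tensor_asgd change above reworks the foreach sequence so that eta follows the same formula as the single-tensor path, lr / ((1 + lambd * lr * step) ** alpha), where the previous sequence effectively computed lr / (1 + lambd * lr * step ** alpha). A small numeric check of the rewritten sequence, using arbitrary example hyperparameters (the values below are assumptions, only the sequence of _foreach_* calls mirrors the patch):

    import torch

    lambd, lr, alpha = 1e-4, 1e-2, 0.75                 # arbitrary example values
    steps = [torch.tensor(10.0, dtype=torch.float64),
             torch.tensor(250.0, dtype=torch.float64)]

    # Reference: the single-tensor path's formula.
    expected = [lr / ((1 + lambd * lr * s.item()) ** alpha) for s in steps]

    # The rewritten foreach sequence from _multi_tensor_asgd.
    etas = torch._foreach_mul(steps, lambd)
    torch._foreach_mul_(etas, lr)
    torch._foreach_add_(etas, 1)
    torch._foreach_pow_(etas, alpha)
    torch._foreach_reciprocal_(etas)
    torch._foreach_mul_(etas, lr)

    for eta, ref in zip(etas, expected):
        assert abs(eta.item() - ref) < 1e-12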